¿Cómo puedo acceder a elementos de un tensor 3D usando índices especificados en TensorFlow?
Estoy tratando de obtener las filas de un tensor 3D en un orden específico de índices. Aquí están las entradas:
import tensorflow as tf
matrix = tf.constant([
[[0, 1], [2, 3], [4, 5], [6, 7]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[16, 17], [18, 19], [20, 21], [22, 23]],
[[24, 25], [26, 27], [28, 29], [30, 31]],
[[32, 33], [34, 35], [36, 37], [38, 39]]
])
indx = tf.constant([[3,2,1,0], [0,1,2,3], [1,0,3,2], [0,3,1,2], [1,2,3,0]])
# required output tensor:
[[[6, 7], [4, 5], [2, 3], [0, 1]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[18, 19], [16, 17], [22, 23], [20, 21]],
[[24, 25], [30, 31], [26, 27], [28, 29]],
[[34, 35], [36, 37], [38, 39], [32, 33]]]
Estoy luchando con tf.gather_nd()
. ¿Cualquier sugerencia? Puedo ver que está sucediendo aquí, pero no estoy seguro de cómo aplicarlo en toda la matriz sin usar for
loop otf.map_fn
print(tf.gather_nd(matrix[0], tf.expand_dims(indx, -1)[0]).numpy().tolist())
print(tf.gather_nd(matrix[1], tf.expand_dims(indx, -1)[1]).numpy().tolist())
print(tf.gather_nd(matrix[2], tf.expand_dims(indx, -1)[2]).numpy().tolist())
print(tf.gather_nd(matrix[3], tf.expand_dims(indx, -1)[3]).numpy().tolist())
print(tf.gather_nd(matrix[4], tf.expand_dims(indx, -1)[4]).numpy().tolist())
"""
[[6, 7], [4, 5], [2, 3], [0, 1]]
[[8, 9], [10, 11], [12, 13], [14, 15]]
[[18, 19], [16, 17], [22, 23], [20, 21]]
[[24, 25], [30, 31], [26, 27], [28, 29]]
[[34, 35], [36, 37], [38, 39], [32, 33]]
"""
EDITAR: Hice una pregunta similar con respecto a numpy. Una respuesta de indexación inteligente resuelve la versión numpy, pero es difícil aplicarla en Tensors. No dude en echar un vistazo a la respuesta aceptada aquí: ¿Cómo puedo obtener elementos de la matriz 3D utilizando índices especificados en numpy?
Respuestas
¡Duh, eso fue estúpido! Ya hay una gran función disponible que trabaja en una matriz multidimensional en tensorflow; tf.gather()
Consulte el argumento batch_dims para obtener más información.
>> tf.gather(matrix, indx, batch_dims=1).numpy().tolist()
[[[6, 7], [4, 5], [2, 3], [0, 1]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[18, 19], [16, 17], [22, 23], [20, 21]],
[[24, 25], [30, 31], [26, 27], [28, 29]],
[[34, 35], [36, 37], [38, 39], [32, 33]]]