Como posso acessar os elementos de um tensor 3D usando índices especificados no TensorFlow?
Estou tentando obter as linhas de um tensor 3D em uma ordem específica de índices. Aqui estão as entradas:
import tensorflow as tf
matrix = tf.constant([
[[0, 1], [2, 3], [4, 5], [6, 7]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[16, 17], [18, 19], [20, 21], [22, 23]],
[[24, 25], [26, 27], [28, 29], [30, 31]],
[[32, 33], [34, 35], [36, 37], [38, 39]]
])
indx = tf.constant([[3,2,1,0], [0,1,2,3], [1,0,3,2], [0,3,1,2], [1,2,3,0]])
# required output tensor:
[[[6, 7], [4, 5], [2, 3], [0, 1]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[18, 19], [16, 17], [22, 23], [20, 21]],
[[24, 25], [30, 31], [26, 27], [28, 29]],
[[34, 35], [36, 37], [38, 39], [32, 33]]]
Estou lutando com tf.gather_nd()
. Alguma sugestão? Posso ver que está acontecendo aqui, mas não tenho certeza de como aplicar em toda a matriz sem usar for
loop outf.map_fn
print(tf.gather_nd(matrix[0], tf.expand_dims(indx, -1)[0]).numpy().tolist())
print(tf.gather_nd(matrix[1], tf.expand_dims(indx, -1)[1]).numpy().tolist())
print(tf.gather_nd(matrix[2], tf.expand_dims(indx, -1)[2]).numpy().tolist())
print(tf.gather_nd(matrix[3], tf.expand_dims(indx, -1)[3]).numpy().tolist())
print(tf.gather_nd(matrix[4], tf.expand_dims(indx, -1)[4]).numpy().tolist())
"""
[[6, 7], [4, 5], [2, 3], [0, 1]]
[[8, 9], [10, 11], [12, 13], [14, 15]]
[[18, 19], [16, 17], [22, 23], [20, 21]]
[[24, 25], [30, 31], [26, 27], [28, 29]]
[[34, 35], [36, 37], [38, 39], [32, 33]]
"""
EDIT: Eu fiz uma pergunta semelhante com relação ao numpy. Uma resposta de indexação inteligente resolve a versão entorpecida, mas é difícil aplicá-la no Tensors. Sinta-se à vontade para dar uma olhada na resposta aceita aqui: Como posso obter elementos da matriz 3D usando índices especificados em numpy?
Respostas
Duh, isso foi estúpido! Já existe uma função muito grande disponível que funciona em array multi-dimensional em tensorflow; tf.gather()
Verifique o argumento batch_dims para obter mais informações.
>> tf.gather(matrix, indx, batch_dims=1).numpy().tolist()
[[[6, 7], [4, 5], [2, 3], [0, 1]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[18, 19], [16, 17], [22, 23], [20, 21]],
[[24, 25], [30, 31], [26, 27], [28, 29]],
[[34, 35], [36, 37], [38, 39], [32, 33]]]