Come posso accedere agli elementi di un tensore 3D utilizzando indici specificati in TensorFlow?
Sto cercando di ottenere le righe di un tensore 3D in un ordine specifico di indici. Ecco gli input:
import tensorflow as tf
matrix = tf.constant([
[[0, 1], [2, 3], [4, 5], [6, 7]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[16, 17], [18, 19], [20, 21], [22, 23]],
[[24, 25], [26, 27], [28, 29], [30, 31]],
[[32, 33], [34, 35], [36, 37], [38, 39]]
])
indx = tf.constant([[3,2,1,0], [0,1,2,3], [1,0,3,2], [0,3,1,2], [1,2,3,0]])
# required output tensor:
[[[6, 7], [4, 5], [2, 3], [0, 1]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[18, 19], [16, 17], [22, 23], [20, 21]],
[[24, 25], [30, 31], [26, 27], [28, 29]],
[[34, 35], [36, 37], [38, 39], [32, 33]]]
Sto lottando con tf.gather_nd()
. Qualche suggerimento? Posso vedere che sta succedendo qui ma non sono sicuro di come applicare su tutta la matrice senza usare for
loop otf.map_fn
print(tf.gather_nd(matrix[0], tf.expand_dims(indx, -1)[0]).numpy().tolist())
print(tf.gather_nd(matrix[1], tf.expand_dims(indx, -1)[1]).numpy().tolist())
print(tf.gather_nd(matrix[2], tf.expand_dims(indx, -1)[2]).numpy().tolist())
print(tf.gather_nd(matrix[3], tf.expand_dims(indx, -1)[3]).numpy().tolist())
print(tf.gather_nd(matrix[4], tf.expand_dims(indx, -1)[4]).numpy().tolist())
"""
[[6, 7], [4, 5], [2, 3], [0, 1]]
[[8, 9], [10, 11], [12, 13], [14, 15]]
[[18, 19], [16, 17], [22, 23], [20, 21]]
[[24, 25], [30, 31], [26, 27], [28, 29]]
[[34, 35], [36, 37], [38, 39], [32, 33]]
"""
EDIT: ho posto una domanda simile riguardo a numpy. Una risposta di indicizzazione intelligente risolve la versione numpy, ma è difficile applicarla su Tensors. Sentiti libero di dare un'occhiata alla risposta accettata qui: come posso ottenere elementi dalla matrice 3D utilizzando indici specificati in numpy?
Risposte
Duh, è stato stupido! È già disponibile una grande funzione che funziona su array multidimensionali in tensorflow; tf.gather()
Controlla l' argomento batch_dims per ulteriori informazioni.
>> tf.gather(matrix, indx, batch_dims=1).numpy().tolist()
[[[6, 7], [4, 5], [2, 3], [0, 1]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[18, 19], [16, 17], [22, 23], [20, 21]],
[[24, 25], [30, 31], [26, 27], [28, 29]],
[[34, 35], [36, 37], [38, 39], [32, 33]]]