मैं TensorFlow में निर्दिष्ट सूचकांकों का उपयोग करके 3D टेंसर के तत्वों का उपयोग कैसे कर सकता हूं?
मैं सूचक के एक विशिष्ट क्रम में एक 3D टेंसर की पंक्तियों को प्राप्त करने की कोशिश कर रहा हूं। ये हैं इनपुट्स:
import tensorflow as tf
matrix = tf.constant([
[[0, 1], [2, 3], [4, 5], [6, 7]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[16, 17], [18, 19], [20, 21], [22, 23]],
[[24, 25], [26, 27], [28, 29], [30, 31]],
[[32, 33], [34, 35], [36, 37], [38, 39]]
])
indx = tf.constant([[3,2,1,0], [0,1,2,3], [1,0,3,2], [0,3,1,2], [1,2,3,0]])
# required output tensor:
[[[6, 7], [4, 5], [2, 3], [0, 1]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[18, 19], [16, 17], [22, 23], [20, 21]],
[[24, 25], [30, 31], [26, 27], [28, 29]],
[[34, 35], [36, 37], [38, 39], [32, 33]]]
मैं संघर्ष कर रहा हूं tf.gather_nd()
। कोई उपाय? मैं देख सकता हूं कि यह यहां हो रहा है, लेकिन मुझे यकीन नहीं है कि for
लूप का उपयोग किए बिना पूरे मैट्रिक्स पर कैसे लागू किया जाएtf.map_fn
print(tf.gather_nd(matrix[0], tf.expand_dims(indx, -1)[0]).numpy().tolist())
print(tf.gather_nd(matrix[1], tf.expand_dims(indx, -1)[1]).numpy().tolist())
print(tf.gather_nd(matrix[2], tf.expand_dims(indx, -1)[2]).numpy().tolist())
print(tf.gather_nd(matrix[3], tf.expand_dims(indx, -1)[3]).numpy().tolist())
print(tf.gather_nd(matrix[4], tf.expand_dims(indx, -1)[4]).numpy().tolist())
"""
[[6, 7], [4, 5], [2, 3], [0, 1]]
[[8, 9], [10, 11], [12, 13], [14, 15]]
[[18, 19], [16, 17], [22, 23], [20, 21]]
[[24, 25], [30, 31], [26, 27], [28, 29]]
[[34, 35], [36, 37], [38, 39], [32, 33]]
"""
EDIT: मैंने अंक के संबंध में एक समान प्रश्न पूछा। एक चतुर अनुक्रमण उत्तर सुन्न संस्करण को हल करता है, लेकिन इसे टेन्सर पर लागू करना कठिन है। बेझिझक दिए गए उत्तर पर एक नज़र डालें: मैं सुपीरियर में निर्दिष्ट सूचकांकों का उपयोग करके 3D मैट्रिक्स से तत्व कैसे प्राप्त कर सकता हूं?
जवाब
दुआ, यह बेवकूफी थी! पहले से ही एक बहुत ही शानदार फ़ंक्शन उपलब्ध है जो टेंसरफ़्लो में बहुआयामी सरणी पर काम करता है; tf.gather()
की जाँच करें batch_dims अधिक जानकारी के लिए तर्क।
>> tf.gather(matrix, indx, batch_dims=1).numpy().tolist()
[[[6, 7], [4, 5], [2, 3], [0, 1]],
[[8, 9], [10, 11], [12, 13], [14, 15]],
[[18, 19], [16, 17], [22, 23], [20, 21]],
[[24, 25], [30, 31], [26, 27], [28, 29]],
[[34, 35], [36, 37], [38, 39], [32, 33]]]