ฉันจะเข้าถึงองค์ประกอบของเทนเซอร์ 3 มิติโดยใช้ดัชนีที่ระบุใน TensorFlow ได้อย่างไร

Aug 16 2020

ฉันกำลังพยายามหาแถวของเทนเซอร์ 3 มิติตามลำดับดัชนีที่เฉพาะเจาะจง ปัจจัยการผลิตมีดังนี้

import tensorflow as tf

matrix = tf.constant([
    [[0, 1], [2, 3], [4, 5], [6, 7]], 
    [[8, 9], [10, 11], [12, 13], [14, 15]], 
    [[16, 17], [18, 19], [20, 21], [22, 23]], 
    [[24, 25], [26, 27], [28, 29], [30, 31]], 
    [[32, 33], [34, 35], [36, 37], [38, 39]]
])

indx = tf.constant([[3,2,1,0], [0,1,2,3], [1,0,3,2], [0,3,1,2], [1,2,3,0]])

# required output tensor:
[[[6, 7], [4, 5], [2, 3], [0, 1]],
 [[8, 9], [10, 11], [12, 13], [14, 15]],
 [[18, 19], [16, 17], [22, 23], [20, 21]],
 [[24, 25], [30, 31], [26, 27], [28, 29]],
 [[34, 35], [36, 37], [38, 39], [32, 33]]]

tf.gather_nd()ฉันดิ้นรนกับ ข้อเสนอแนะใด ๆ ? ฉันเห็นว่ามันเกิดขึ้นที่นี่ แต่ฉันไม่แน่ใจว่าจะใช้กับเมทริกซ์ทั้งหมดโดยไม่ใช้forลูปหรือtf.map_fn

print(tf.gather_nd(matrix[0], tf.expand_dims(indx, -1)[0]).numpy().tolist())
print(tf.gather_nd(matrix[1], tf.expand_dims(indx, -1)[1]).numpy().tolist())
print(tf.gather_nd(matrix[2], tf.expand_dims(indx, -1)[2]).numpy().tolist())
print(tf.gather_nd(matrix[3], tf.expand_dims(indx, -1)[3]).numpy().tolist())
print(tf.gather_nd(matrix[4], tf.expand_dims(indx, -1)[4]).numpy().tolist())

"""
[[6, 7], [4, 5], [2, 3], [0, 1]]
[[8, 9], [10, 11], [12, 13], [14, 15]]
[[18, 19], [16, 17], [22, 23], [20, 21]]
[[24, 25], [30, 31], [26, 27], [28, 29]]
[[34, 35], [36, 37], [38, 39], [32, 33]]
"""

แก้ไข: ฉันถามคำถามที่คล้ายกันเกี่ยวกับ numpy คำตอบการจัดทำดัชนีที่ชาญฉลาดช่วยแก้ปัญหารุ่นที่เป็นตัวเลขได้ แต่ยากที่จะนำไปใช้กับ Tensors อย่าลังเลที่จะดูคำตอบที่ยอมรับได้ที่นี่: ฉันจะรับองค์ประกอบจากเมทริกซ์ 3 มิติโดยใช้ดัชนีที่ระบุเป็นตัวเลขได้อย่างไร

คำตอบ

Snehal Aug 16 2020 at 11:10

มันโง่! มีฟังก์ชั่นที่ยอดเยี่ยมอยู่แล้วซึ่งทำงานกับอาร์เรย์หลายมิติในเทนเซอร์โฟลว์ tf.gather()ตรวจสอบอาร์กิวเมนต์batch_dimsสำหรับข้อมูลเพิ่มเติม

>> tf.gather(matrix, indx, batch_dims=1).numpy().tolist()
[[[6, 7], [4, 5], [2, 3], [0, 1]],
 [[8, 9], [10, 11], [12, 13], [14, 15]],
 [[18, 19], [16, 17], [22, 23], [20, 21]],
 [[24, 25], [30, 31], [26, 27], [28, 29]],
 [[34, 35], [36, 37], [38, 39], [32, 33]]]