Làm cách nào để truy cập các phần tử của tensor 3D bằng cách sử dụng các chỉ số được chỉ định trong TensorFlow?

Aug 16 2020

Tôi đang cố gắng lấy các hàng của tensor 3D theo một thứ tự chỉ số cụ thể. Đây là các đầu vào:

import tensorflow as tf

matrix = tf.constant([
    [[0, 1], [2, 3], [4, 5], [6, 7]], 
    [[8, 9], [10, 11], [12, 13], [14, 15]], 
    [[16, 17], [18, 19], [20, 21], [22, 23]], 
    [[24, 25], [26, 27], [28, 29], [30, 31]], 
    [[32, 33], [34, 35], [36, 37], [38, 39]]
])

indx = tf.constant([[3,2,1,0], [0,1,2,3], [1,0,3,2], [0,3,1,2], [1,2,3,0]])

# required output tensor:
[[[6, 7], [4, 5], [2, 3], [0, 1]],
 [[8, 9], [10, 11], [12, 13], [14, 15]],
 [[18, 19], [16, 17], [22, 23], [20, 21]],
 [[24, 25], [30, 31], [26, 27], [28, 29]],
 [[34, 35], [36, 37], [38, 39], [32, 33]]]

Tôi đang đấu tranh với tf.gather_nd(). Bất kì lời đề nghị nào? Tôi có thể thấy nó đang xảy ra ở đây nhưng tôi không chắc chắn cách áp dụng trên toàn bộ ma trận mà không sử dụng forvòng lặp hoặctf.map_fn

print(tf.gather_nd(matrix[0], tf.expand_dims(indx, -1)[0]).numpy().tolist())
print(tf.gather_nd(matrix[1], tf.expand_dims(indx, -1)[1]).numpy().tolist())
print(tf.gather_nd(matrix[2], tf.expand_dims(indx, -1)[2]).numpy().tolist())
print(tf.gather_nd(matrix[3], tf.expand_dims(indx, -1)[3]).numpy().tolist())
print(tf.gather_nd(matrix[4], tf.expand_dims(indx, -1)[4]).numpy().tolist())

"""
[[6, 7], [4, 5], [2, 3], [0, 1]]
[[8, 9], [10, 11], [12, 13], [14, 15]]
[[18, 19], [16, 17], [22, 23], [20, 21]]
[[24, 25], [30, 31], [26, 27], [28, 29]]
[[34, 35], [36, 37], [38, 39], [32, 33]]
"""

CHỈNH SỬA: Tôi đã hỏi một câu hỏi tương tự liên quan đến numpy. Một câu trả lời lập chỉ mục thông minh giải quyết được phiên bản numpy, nhưng rất khó để áp dụng nó trên Tensors. Vui lòng xem câu trả lời được chấp nhận tại đây: Làm cách nào để lấy các phần tử từ ma trận 3D bằng cách sử dụng các chỉ số được chỉ định trong numpy?

Trả lời

Snehal Aug 16 2020 at 11:10

Duh, điều đó thật ngu ngốc! Đã có sẵn một chức năng rất tuyệt vời hoạt động trên mảng đa chiều trong tensorflow; tf.gather()Kiểm tra đối số batch_dims để biết thêm thông tin.

>> tf.gather(matrix, indx, batch_dims=1).numpy().tolist()
[[[6, 7], [4, 5], [2, 3], [0, 1]],
 [[8, 9], [10, 11], [12, 13], [14, 15]],
 [[18, 19], [16, 17], [22, 23], [20, 21]],
 [[24, 25], [30, 31], [26, 27], [28, 29]],
 [[34, 35], [36, 37], [38, 39], [32, 33]]]