Pytorch, ¿eliminar un bucle for al agregar la permutación de un vector a las entradas de una matriz?
Estoy tratando de implementar este documento y me quedé con este simple paso. Aunque esto tiene que ver con la atención, lo que me preocupa es cómo implementar una permutación de un vector agregado a una matriz sin usar bucles for.
Los puntajes de atención tienen un vector de sesgo aprendido agregado, la teoría es que codifica la posición relativa (ji) de los dos tokens que representa el puntaje

por lo que alfa es una matriz T x T, T depende del lote que se reenvía, y B es un vector de polarización aprendido cuya longitud tiene que ser fija y tan grande como 2T. Mi implementación actual, que creo que hace lo que sugiere el documento, es:
def __init__(...):
...
self.bias = torch.nn.Parameter(torch.randn(config.n),requires_grad = True)
stdv = 1. / math.sqrt(self.bias.data.size(0))
self.bias.data.uniform_(-stdv, stdv)
def forward(..)
...
#n = 201 (2* max_seq_len + 1)
B_matrix = torch.zeros(self.T, self.T) # 60 x 60
for i in range(self.T):
B_matrix[i] = self.bias[torch.arange(start=n//2-i, end=n//2-i+T)])]
attention_scores = attention_scores + B_matrix.unsqueeze(0)
# 64 x 60 x 60
...
Esta es la única parte relevante.
B_matrix = torch.zeros(self.T, self.T) # 60 x 60
for i in range(self.T):
B_matrix[i] = self.bias[torch.arange(start=n//2-i, end=n//2-i+T)])]
básicamente tratando de no usar un bucle for para repasar cada fila.
pero sé que esto debe ser realmente ineficiente y costoso cuando este modelo es muy grande. Estoy haciendo un bucle for explícito sobre cada fila para obtener una permutación del vector de sesgo aprendido.
¿Alguien puede ayudarme de una mejor manera, quizás a través de la transmisión inteligente?
Después de pensarlo, no necesito crear una instancia de una matriz cero, pero ¿todavía no puedo deshacerme del bucle for? y no puede usar la recopilación ya que B_matrix tiene un tamaño diferente al de un vector b en mosaico.
functor = lambda i : bias[torch.arange(start=n//2-i, end=n//2-i+T)]
B_matrix = torch.stack([functor(i) for i in torch.arange(T)])
Respuestas
No pude averiguar qué n
se suponía que debía estar en su código, pero creo que el siguiente ejemplo torch.meshgridproporciona lo que está buscando.
Suponiendo que
n, m = 10, 20 # arbitrary
a = torch.randn(n, m)
b = torch.randn(n + m)
después
for i in range(n):
for j in range(m):
a[i, j] = a[i, j] + b[n - i + j]
es equivalente a
ii, jj = torch.meshgrid(torch.arange(n), torch.arange(m))
a = a + b[n - ii + jj]
aunque esto último es una operación fuera de lugar, lo que suele ser algo bueno. Si realmente quería una operación en el lugar, reemplácela a =
con a[...] =
.
Tenga en cuenta que este es un ejemplo de indexación de matriz de enteros donde indexamos b
usando un tensor que tiene la misma forma que a
.