Manera vectorizada de aplicar una máscara de 3 dimensiones a RGB en pytorch
Tengo un tensor HxWx3 que representa una imagen RGB y un tensor de máscara HxWx3 (booleano) como entrada. Se supone que para cada (i, j) en el tensor de máscara hay exactamente un valor verdadero (que es exactamente uno de R\G\B activado). Quiero aplicar la máscara a la imagen para obtener un tensor V HxW (o HxWx1) donde V[i,j]='el valor R\G\B coincidente según la máscara'.
Al usar Problema al aplicar una máscara binaria a una imagen RGB con numpy , pude lograr lo siguiente:
>>> X*mask
tensor([[[ 9., 10.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 20.]],
[[ 0., 0.],
[30., 0.]]])
Pero como se indicó, quiero un solo dim HxW y no HxWx3 como resultado.
Ilustración:

Respuestas
Suponiendo que para cada i,j solo se conserva un único valor R/G/B, simplemente puede hacer lo siguiente:
(X*mask).sum(axis=2)
Esto debería darle la salida deseada (HxW).