Maneira vetorizada de aplicar uma máscara tridimensional ao RGB no pytorch
Eu tenho um tensor HxWx3 representando uma imagem RGB e um tensor de máscara HxWx3 (booleano) como entrada. Assume-se que para cada (i,j) no tensor da máscara existe exatamente um valor verdadeiro (que é exatamente um de R\G\B ativado). Desejo aplicar a máscara à imagem para resultar em um tensor V HxW (ou HxWx1) onde V[i,j]='o valor R\G\B correspondente de acordo com a máscara'.
Usando Problema ao aplicar máscara binária a uma imagem RGB com numpy , consegui o seguinte:
>>> X*mask
tensor([[[ 9., 10.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 20.]],
[[ 0., 0.],
[30., 0.]]])
Mas, como afirmado, quero um único dim HxW e não HxWx3 como resultado.
Ilustração:

Respostas
Assumindo que para cada i,j apenas um único valor R/G/B é retido, você pode simplesmente fazer:
(X*mask).sum(axis=2)
Isso deve fornecer a saída desejada (HxW).