PyTorch - torch.bmm[1]

input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(input, mat2)
res.size()
# torch.Size([10, 3, 5])

explain

  • Performs a batch matrix-matrix product of matrices stored in input and mat2.
    • Input과 mat2에 저장된 행렬의 batch matrix-matrix 곱셈을 수행한다.
  • input and mat2 must be 3-D tensors each containing the same number of matrices.
    • Input과 mat2는 각각 동일한 수의 행렬을 포함하는 3D 텐서여야 합니다.
  • If input is a (b $\times$ n $\times$ m)(b×n×m) tensor, mat2 is a (b $\times$ m $\times$ p)(b×m×p) tensor, out will be a (b $\times$ n $\times$ p)(b×n×p) tensor.
    • 만약 입력 텐서가 (b $\times$ n $\times$ m)(b×n×m) 이고, mat2 텐서가 (b $\times$ m $\times$ p)(b $\times$ m $\times$ p) 이면, out 텐서는 (b $\times$ n $\times$ p)(b $\times$ n $\times$ p) 이다.
\[out_{i}=x_{i}^{exponent}\]
  • Note


Reference

1. TORCH.BMM, PyTorch, https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm