Multidimensional transpose
Zach Ocean
It’s not obvious to me prima facie how tensor.T should behave for tensors with more than two dimensions, i.e. len(tens.shape) > 2.
It turns out that the transpose reverses the order of all dimensions, so if tens has shape (x, y, z) then tens.T has shape (z, y, x).
Furthermore, tens.T[i, j, k] = tens[k, j, i]. I haven’t found it as intuitive geometrically to think about the transpose above 2 dimensions, so considering the elementwise transposition is the best way I’ve found.
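The elementwise rule can be checked directly. A small sketch (variable names are my own) that loops over every index triple and asserts tens.T[i, j, k] == tens[k, j, i]:

```python
import torch

# For a 3d tensor, .T reverses all dimensions:
# tens.T[i, j, k] == tens[k, j, i] for every index triple.
tens = torch.arange(24).reshape(2, 3, 4)
assert tens.T.shape == (4, 3, 2)
for i in range(4):
    for j in range(3):
        for k in range(2):
            assert tens.T[i, j, k] == tens[k, j, i]
```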
import torch
# Simple transpose of a 2d tensor
data = torch.arange(6).reshape(2, 3)
print(data) # [[0, 1, 2], [3, 4, 5]]
print(data.T) # [[0, 3], [1, 4], [2, 5]]
tensor([[0, 1, 2],
[3, 4, 5]])
tensor([[0, 3],
[1, 4],
[2, 5]])
# Less intuitive: transpose of a 3d tensor
data = torch.arange(24).reshape(2, 3, 4)
print(data)
"""
[[[0, 1, 2, 3],
  [4, 5, 6, 7],
  [8, 9, 10, 11]],
 [[12, 13, 14, 15],
  [16, 17, 18, 19],
  [20, 21, 22, 23]]]
"""
print(data.T) # Shape: (4, 3, 2). Value:
"""
[[[0, 12],
  [4, 16],
  [8, 20]],
 [[1, 13],
  [5, 17],
  [9, 21]],
 [[2, 14],
  [6, 18],
  [10, 22]],
 [[3, 15],
  [7, 19],
  [11, 23]]]
"""
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
tensor([[[ 0, 12],
[ 4, 16],
[ 8, 20]],
[[ 1, 13],
[ 5, 17],
[ 9, 21]],
[[ 2, 14],
[ 6, 18],
[10, 22]],
[[ 3, 15],
[ 7, 19],
[11, 23]]])
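Since .T on an n-dimensional tensor just reverses the dimension order, it is equivalent to an explicit permute with the axes listed in reverse. A quick check (note: recent PyTorch versions warn that .T on tensors with more than 2 dimensions is deprecated in favor of permute, which makes the intent explicit):

```python
import torch

data = torch.arange(24).reshape(2, 3, 4)
# .T reverses all dimensions, which for a 3d tensor
# is the same as permuting axes (2, 1, 0).
assert torch.equal(data.T, data.permute(2, 1, 0))
```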