Multidimensional transpose

Zach Ocean

It’s not obvious to me, prima facie, how tensor.T should behave for tensors with more than two dimensions, i.e. when len(tens.shape) > 2.

It turns out that the transpose reverses the order of all of the tensor’s dimensions, so if tens has shape (x, y, z), then tens.T has shape (z, y, x).
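One way to check the shape rule is to compare tens.T with permute, passing the dimension indices in reverse order. This is a minimal sketch assuming a recent PyTorch, where .T on a tensor that is not 2-D still works but emits a deprecation warning that points to permute as the alternative; the explicit permute call is the warning-free way to write the same reversal.

```python
import warnings

import torch

tens = torch.arange(24).reshape(2, 3, 4)  # shape (x, y, z) = (2, 3, 4)

# Reverse every dimension explicitly: (x, y, z) -> (z, y, x)
reversed_dims = tuple(range(tens.ndim - 1, -1, -1))  # (2, 1, 0)
permuted = tens.permute(*reversed_dims)

with warnings.catch_warnings():
    # Recent PyTorch versions warn when .T is used on a tensor that is not 2-D
    warnings.simplefilter("ignore", UserWarning)
    transposed = tens.T

print(tens.shape)        # torch.Size([2, 3, 4])
print(transposed.shape)  # torch.Size([4, 3, 2])
print(torch.equal(transposed, permuted))  # True
```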

Furthermore, tens.T[k, j, i] == tens[i, j, k] for every valid index (i, j, k). I haven’t found the transpose of more than two dimensions intuitive to picture geometrically, so thinking in terms of this elementwise rule is the most useful approach I’ve found.
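The elementwise rule can be checked by brute force on a small tensor. This sketch simply loops over every index triple (i, j, k); the warning filter assumes a recent PyTorch that warns about .T on 3-D tensors.

```python
import itertools
import warnings

import torch

tens = torch.arange(24).reshape(2, 3, 4)
with warnings.catch_warnings():
    # .T on a 3-D tensor emits a deprecation warning in recent PyTorch
    warnings.simplefilter("ignore", UserWarning)
    tens_t = tens.T  # shape (4, 3, 2)

x, y, z = tens.shape
# Check tens.T[k, j, i] == tens[i, j, k] for every index triple
all_match = all(
    tens_t[k, j, i].item() == tens[i, j, k].item()
    for i, j, k in itertools.product(range(x), range(y), range(z))
)
print(all_match)  # True
```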

import torch

# Simple transpose of a 2d tensor
data = torch.arange(6).reshape(2, 3)
print(data) # [[0, 1, 2], [3, 4, 5]]
print(data.T) # [[0, 3], [1, 4], [2, 5]]
tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[0, 3],
        [1, 4],
        [2, 5]])
# Less intuitive: transpose of a 3d tensor
data = torch.arange(24).reshape(2, 3, 4)
print(data)
"""
[
  [[0, 1, 2, 3],
   [4, 5, 6, 7],
   [8, 9, 10, 11]]
],
[
  [[12, 13, 14, 15],
   [16, 17, 18, 19],
   [20, 21, 22, 23],
]
"""
print(data.T) # Shape: (4, 3, 2). Value: 
"""
[
  [[0, 12]
   [4, 16],
   [8, 20]]
],
[
  [[1, 13],
   [5, 17],
   [8, 21]
],
[
  [[2, 14],
   [6, 18],
   [10, 22]]
],
[
  [[3, 15],
   [7, 19],
   [11, 23]]
],
"""
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
tensor([[[ 0, 12],
         [ 4, 16],
         [ 8, 20]],

        [[ 1, 13],
         [ 5, 17],
         [ 9, 21]],

        [[ 2, 14],
         [ 6, 18],
         [10, 22]],

        [[ 3, 15],
         [ 7, 19],
         [11, 23]]])
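One way to read the printed result: because data.T[k, j, i] equals data[i, j, k], the k-th (3, 2) block of data.T is the ordinary 2-D transpose of the slice data[:, :, k]. This is my own framing of the elementwise rule rather than something the PyTorch docs state, and the sketch below just checks it for the example tensor (again assuming a recent PyTorch that warns about .T on 3-D tensors).

```python
import warnings

import torch

data = torch.arange(24).reshape(2, 3, 4)
with warnings.catch_warnings():
    # .T on a 3-D tensor emits a deprecation warning in recent PyTorch
    warnings.simplefilter("ignore", UserWarning)
    data_t = data.T  # shape (4, 3, 2)

# Block k of data.T should equal the 2-D transpose of the slice data[:, :, k]
for k in range(data.shape[2]):
    block = data_t[k]        # shape (3, 2)
    slice_k = data[:, :, k]  # shape (2, 3)
    assert torch.equal(block, slice_k.T)

print(data_t[0])
# tensor([[ 0, 12],
#         [ 4, 16],
#         [ 8, 20]])
```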