【PyTorch 入門】PyTorch の次元操作 permute, transpose, reshape, view って何が違うの?

こんにちは、機械学習の講師をしている木下です!

ディープラーニングのコーディングをしていく際には、PyTorch などのフレームワークを使う機会が多いと思います。そんな中でよく出てくる、テンソルや次元確認・次元操作という言葉に面食らった経験もあるのではないでしょうか。

PyTorch で次元操作をする際にも、permute,transpose,reshape,view など様々な関数・メソッドが存在し混乱を招く原因となっています。

そこで、本記事では PyTorch の次元操作を徹底解説していきます!今までなんとなく関数やメソッドを使用していた方もこの記事を読めば、きちんと理解しながら PyTorch のコーディングを行うことができます!

こんな人におすすめ!

  • そもそもデータの次元がよくわからない!
  • PyTorch を用いたディープラーニングを実装を学びたい!
  • 次元操作の関数の違い・使い分けを知りたい!


今回扱う関数・メソッドは以下の通りです。また、これらに加えて次元確認のメソッドについても確認していきましょう。

  • permute
  • transpose
  • reshape
  • view

実行環境・使用するデータ

サンプルデータの作成

予め PyTorch をインポートしておきましょう。PyTorch は Google Colabratory 上には予めインストールされているので、インストールは不要です。

PyTorch のインポート
# インポート
import torch
# バージョンの確認
torch.__version__

""" 1.12.1+cu113 """

それでは、これから扱っていくサンプルデータを作成します。PyTorch では randn を利用することで、任意の次元のデータをランダムに作成することができます。ここでは、各次元の要素数が 2, 3, 5 の 3 次元データを作成してみます。

データの作成
# データの作成
x = torch.randn(2, 3, 5)

# データの確認(環境によって値は代わります)
print(x)


""" tensor([[[-0.6246,  1.0766,  1.4077,  1.3389,  1.3747],
         [ 1.4110, -0.6644, -0.2251,  0.3428,  1.1569],
         [ 0.0392,  0.5391,  0.0154,  1.3016, -0.8394]],

        [[ 2.3794, -0.0364,  1.2753, -0.6133, -0.8352],
         [ 1.0868,  1.0627,  1.3589, -1.4522, -1.7674],
         [-2.6390, -0.6833,  2.4647,  1.1037,  0.2417]]])"""

データの次元数、要素数の確認

PyTorch ではデータの次元数や要素数が異なるとコードの実行時にエラーがでます。以下のコードを活用し、常にデータの形状を確認する習慣をつけましょう。

データの次元数確認
# データの次元数確認のメソッド dim 
x.dim()

# データの次元数確認のメソッド ndimension も全く同等のメソッド
x.ndimension()

# ndim 属性を参照しても同様
x.ndim

いずれを実行しても、今回の次元数である 3 が表示されます。

それぞれの要素数は以下のコードで確認しましょう。

データの要素数確認
# PyTorch の size メソッドを用いることが一般的
x.size()

# numpy と同じように shape 属性でも確認可能
x.shape

いずれもそれぞれの要素数、torch.Size([2, 3, 5]) が表示されます。

各関数・メソッドの解説

それでは、各関数・メソッドの違いについてみていきましょう。

permute

まずは、permute について解説していきます。
permute は軸(次元)を並び替えます。第一引数に並び替えたいテンソル、第二引数に並び替える順番をタプル型で指定します。

permute の使い方
# 第一引数に並び替えたいテンソル、第二引数に並び替える順番をタプル型
x_permute = torch.permute(x, (2, 0, 1))
x_permute.size()

例えば、上記のように 2, 0, 1 と指定すると元々の軸(次元)の 2 番目、0 番目、1 番目の軸の順番に並び替えられるので、torch.Size([5, 2, 3]) のように形が変わります。

少し発展的ですが、2 次元のテンソルの場合、転置(torch.t())した後のテンソルにも適用可能です。

転置用データの作成
# 転置用データの作成(2 次元である必要がある)
y = torch.randn(2, 3)

# データの確認(環境によって値は代わります)
print(y)

""" tensor([[ 0.6869,  0.3730, -0.0598],
        [ 0.8805, -0.0262,  1.2460]])"""

転置用後に permute の実行
# 転置後に要素を入れ替える
y_permute = torch.t(y).permute(1, 0)
y_permute.size()

上記を実行すると転置で次元が入れ替わった後、さらに入れ替えをおこなっているため、作成時と同様torch.Size([2, 3]) という形になっていることがわかります。

transpose

それでは次に transpose を見てみましょう。
transpose は二つの軸(次元)を入れ替えるための関数です。

transpose の使い方
# 第一引数にデータ、第二引数、第三引数に入れ替えたい軸を指定
x_transpose = torch.transpose(x, 2, 1)
x_transpose.size()

上記のコードでは 2 番目と 1 番目の軸を入れ替えているためサイズは torch.Size([2, 5, 3]) のようになります(Python は 0 はじまりなことに注意してください)。
ここで、3 つ以上の軸を並び替えることを試みてみます。

3 軸以上の並び替え
# 第一引数にデータ、第二引数、第三引数、第四引数に入れ替えたい軸を指定
x_transpose = torch.transpose(x, 2, 1, 0)
x_transpose.size()

# 実行結果
"""
TypeError: transpose() received an invalid combination of arguments - got (Tensor, int, int, int), but expected one of:
(Tensor input, int dim0, int dim1)
(Tensor input, name dim0, name dim1)
"""

実は上記のように、エラーが起きてしまうため 3 つ以上の軸を同時に並び替えることはできません

また、transpose も permute と同様に転置後の処理が可能な関数です。

reshape

次に見かけることも多い reshape を見ていきます。
reshape は軸の並び替えだけでなく、次元数や要素数を変更することができます。

ただし、合計の要素数が合うように気をつける必要があります。例えば、今回の torch.Size([2, 3, 5]) であれば、全てを掛け合わせた 30 に要素数を合わせる必要があります。例えば、

reshape の使い方 1
# 第一引数にデータ、第二引数に各軸の要素数を指定
x_reshape = torch.reshape(x, (1, 5, 6))
x_reshape.size()

上記のコードでは 3 次元を保ったまま各要素数を変更しています。torch.Size([1, 5, 6]) と表示されたのではないでしょうか。以下のように軸の数を増減させることも可能です。

3 軸以上の並び替え
# 第一引数にデータ、第二引数に各軸の要素数を指定(要素数の掛け算が 30 になっていることに注意)
x_reshape = torch.reshape(x, (1, 5, 2, 3, 1))
x_reshape.size()

試しに、要素数の合計が異なるため、エラーが起きてしまう例も見ておきましょう。

エラーが起きてしまう例
# 第一引数にデータ、第二引数に各軸の要素数を指定(要素数の掛け算が 30 になっていないことに注意)
x_reshape = torch.reshape(x, (1, 2, 3))
x_reshape.size()

# 実行結果
"""
RuntimeError: shape '[1, 2, 3]' is invalid for input of size 30
"""

また、 reshape は転置後の処理が可能です。これは view と異なる点になります。
発展的な内容ですが、reshape はメモリ上で要素順に並んでいない場合は、コピーを作ってから処理をするので、このような処理が可能になっています。

view

最後に view について解説を行っていきます。view は基本的に reshape と同じような操作が可能です。

view の使い方 1
# 第一引数にデータ、第二引数に各軸の要素数を指定
x_view = x.view(1, 5, 6)
x_view.size()

上記のように reshape と全く同じ使い方で、出力のサイズも torch.Size([1, 5, 6]) と同様のものになります。

それでは何が違うのでしょうか。試しに以下のコードを実行してみてください。

view の使い方 2
# 転置後のデータの第一引数にデータ、第二引数に各軸の要素数を指定
y_view = torch.t(y).view(1, 6)
y_view.size()

# 実行結果
"""
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
"""

エラーが出てしまいます。実は view は転置後の処理ができないのです。発展的な内容ですが view はメモリ上で要素順に並んでいない場合は、処理ができずエラーとなってしまいます。

view で無理やり実行する場合には以下のように一度メモリ上でデータを並び替える必要がある点に注意しましょう。

view の使い方 3
# contiguous() でメモリ上の並びかえをしてから実行
y_view = torch.t(y).contiguous().view(1, 6)
y_view.size()

それぞれの使い分けについて

これらの使い分けには決まったルールはありませんが、以下の点を意識しておくとコーディングが楽になるかもしれません!

permute と transpose の違いについて

permute の方が上位互換のように見えますが、実際には transpose による2軸の入れ替えで事足りてしまうことの方が多いです。

引数の指定が transpose の方が少なくて済むので、ネットで調べると transpose を使う例が多く見つかります。

reshape と view の違いについて

メモリについては、発展的な内容なため、慣れてくるまで理解する必要はあまりないと思います。

重要なのは reshape はメモリ上にコピーを取る可能性があるので、演算処理の負荷がその分重くなる場合があるという点を意識しておくことです。

ディープラーニングの層が深くなると、その分演算処理に負荷がかかるので、処理を軽くするために view を使うことが多いというのが現状です。

簡単なモデルであれば reshape, 複雑なモデルであれば view と使い分けてみてください。

最後に

本記事では、初学者がつまづきがちな次元削減の関数について整理し、使い分けのアドバイスまで行いました。
PyTorch を用いると 100% 使用する関数なのでここでぜひ理解しておきましょう!

以上、Python 学習している方々のお力添えになれば幸いです!

こちらの記事もオススメ

まずは無料で学びたい方・最速で学びたい方へ

まずは無料で学びたい方: Python&機械学習入門コースがおすすめ

Python&機械学習入門コース

AI・機械学習を学び始めるならまずはここから!経産省の Web サイトでも紹介されているわかりやすいと評判の Python&機械学習入門コースが無料で受けられます!
さらにステップアップした脱ブラックボックスコースや、IT パスポートをはじめとした資格取得を目指すコースもなんと無料です!

無料で学ぶ

最速で学びたい方:キカガクの長期コースがおすすめ

一生学び放題

続々と転職・キャリアアップに成功中!受講生ファーストのサポートが人気のポイントです!

AI・機械学習・データサイエンスといえばキカガク!
非常に需要が高まっている最先端スキルを「今のうちに」習得しませんか?

無料説明会を週 2 開催しています。毎月受講生の定員がございますので確認はお早めに!

説明会ではこんなことをお話します!
  • 国も企業も育成に力を入れている先端 IT 人材とは
  • キカガクの研修実績
  • 長期コースでの学び方、できるようになること
  • 料金・給付金について
  • 質疑応答

参考リンク

参考 Python 3.9.2 ドキュメント

参考 Chainer チュートリアル