たけのこブログ

凡人が頑張って背伸びするブログ

PyTorchでよくあるエラーの対処方法(次元やチャンネル数)

チャンネル数や次元数、ちゃんと一致してる...?

PyTorchとは、pythonディープラーニング用の機械学習ライブラリで、Facebook人工知能研究グループAI Research labによって開発されました。Define by runの特徴を有しており動的にモデルを組むことができるので、コードを非常に簡潔にできます。

しかし、このPyTorch...如何せん日本語のエラー対応の記事が少ないです(海外のコミュニティでは盛んに見受けられます)。そこで今回は、PyTorchで深層学習を実装する上で遭遇しやすいエラーに対する対処方法についてまとめてみたいと思います。

PyTorchで機械学習の機能を実装する上で一番陥りやすいエラーは「チャンネル数や次元が一致しないことで生じるエラー」だと思います。例えば、以下のようなものが考えられます。

# チャンネル数の対応が間違っている場合に起きるエラー
RuntimeError: Given groups=1, weight[64, 3, 3, 3], so expected input[16, 64, 256, 256] to have 3 channels, but got 64 channels instead
# 次元が一致していない場合に乗じるエラー
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [6, 3, 5, 5], but got 3-dimensional input of size [3, 256, 256] instead

今回は、上記のエラーへの対処方法について説明したいと思います。

チャンネル数の不一致に対するエラーの原因は、PyTorchとnumpy, Keras間のデータ形式の違いから生じているケース、あるいはRGB以外に透明度(α channel)が混入しているケースなどがある

# チャンネル数の対応が間違っている場合に起きるエラー
RuntimeError: Given groups=1, weight[64, 3, 3, 3], so expected input[16, 64, 256, 256] to have 3 channels, but got 64 channels instead

numpyやKerasはNWHC形式に対応しており、それぞれ(バッチ数、画像の幅、画像の高さ、チャンネル数)として取り出されます。一方で、PyTorchの場合はNCWH形式に対応しているため、4階テンソルの形が(バッチ数、チャンネル数、画像の幅、画像の高さ)に対応しています。

例えば、以下のコードだと画像をテンソルに変換してもNWHC形式に対応してしまうため、PyTorchのConv2dなどの畳み込み演算で上記のエラーが生じます。

from PIL import Image
import numpy as np

img = Image.open('hogehoge')
numpy_img = np.asarray(img, np.float32) / 255.0
tensor_img = np.expand_dims(numpy_image, axis=0) # このままだとNWHCの形式

この対策方法として、あくまで一例ですが以下のようにしてNCWHの形式に変換する必要があります。

rgb_tensor = rgb_array.permute(0,3,1,2)

ここではPyTorchのpermute関数を使用してランクの入れ替えを行なっていますが、numpyのtransposeでも同様の変換が可能です。

また、もう一つの原因としては、画像にαチャンネル(透明度)が入っている場合が考えられます。これも通常のRGBの画像処理をPyTorchで実行する場合はウェブサイトの記事などでもRGBの3チャンネルを採用している場合がほとんどなので、もしも使用している画像に透明度が入っている場合には、以下のようにしてRGB形式に変換する必要があります。

img = Image.open(self.paths[index]).convert('RGB')

次元が一致していない場合に乗じるエラーはsqueezeやunsqueezeで次元を追加してやるとうまくいくケースが多い

# 次元が一致していない場合に乗じるエラー
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [6, 3, 5, 5], but got 3-dimensional input of size [3, 256, 256] instead

一般にPyTorchで訓練などを実行する際にはDataLoaderやDatasetを使ってイテレーションを行うことが多いと思いますので、その際はNCWH形式に画像などが変換されるため問題が起きにくいのですが、バッチではなく(一つ一つの画像を)オンラインで処理する場合やDatasetやDataLoaderなどに変換してない場合などは4階のテンソルによるNCWH形式になっていないので上記のエラーが発生するケースが多いです。

この時の対処方法は簡単で、unsqueeze(0)などを用いて次元を一つ追加して4階のテンソルに変換してやれば良いだけです。これで無事にNCWH形式に変換されているので、作成したネットワークモデルに入力として入れてもエラーが生じません。

inputs = inputs.unsqueeze(0) # 上記のエラーの場合は、これによって[3, 256, 256]から[1, 3, 256, 256]に変換される

逆のケースも同じで、[1, 3, 256, 256]から元の画像形式[3, 256, 256]にしたい場合はsqueeze()を用いればエラーを回避することができます。

まとめ

PyTorchの実装上におけるチャンネル数や次元数の不一致によって生じるエラーの対処方法についてまとめました。

他にもエラーの原因として、全結合し忘れているからview(-1)のステップを全結合する前に入れたり、GANやAutoEncoderなどで逆伝搬されない固定入力としてデコーダや識別器に代入させるために、勾配を持たないテンソルを生成するため.detach().detach().clone()などを行うなど、様々な実装やエラーにおける対処方法があります。

基本的にPyTorchで生じるエラーの多くは、英語ですが以下のサイトでまとめてあります。もしわからないことがあればこちらに投げるのが良いかなと思います(英語で質問しないといけませんが)。

discuss.pytorch.org

参考文献

discuss.pytorch.org

discuss.pytorch.org