たけのこブログ

凡人が頑張って背伸びするブログ

pytorchで複数の異なるデータセットを簡単に統合して解析する方法

背景

GANの手法でもconditional GANの一つであるpix2pixのような画像を解析する際に、例えば「細身体系の男性Aさん」を「マッチョな男性Aさん」へ生成したいときにはそれぞれ全く別の画像セットを用意させお互いに人物などで紐づいていれば揃えて入力させて訓練させる必要があります。

しかしながら、pytorchに関しては訓練をする前にdataloaderでデータセットを作ってそれをfor文でバッチ数ごとに呼び出してループさせることが多く、二つのデータセットを別々に用意するとfor文の構成などでかなり面倒なことになります。ですので、上記のような解析手法を達成する際にはどうしても「両方のデータセット」を同時に読み込んで上げる必要がありますし、何ならちゃんとデータセットもdataloaderのままでデータもshuffleして欲しい。あと、複数もdataloaderしないで一回に纏めておきたいです。

そんな時は、ConcatDatasetというクラスを作成して複数のデータセットを結合してから纏めてdataloaderで読み込んでしまうのがオススメです。実はpytorchのチュートリアル部分ではデータセットを作成するための定義について説明があるのですが、それを応用してタプルとして複数のデータセットを結合させる形になります。以下のような形で達成できます。

class ConcatDataset(torch.utils.data.Dataset):
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        return min(len(d) for d in self.datasets)

def load_datasets():
    thin_transform  = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,0.5,0.5,), std=(0.5,0.5,0.5,))
    ])
    
    muscle_transform  = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,0.5,0.5,), std=(0.5,0.5,0.5,))
    ])
    thin_trainsets = datasets.ImageFolder(root = './GAN_datasets/s1',transform=thin_transform)
    muscle_trainsets = datasets.ImageFolder(root = './GAN_datasets/s2',transform=muscle_transform)
    Image_datasets = ConcatDataset(thin_trainsets,muscle_trainsets)
    train_loader = torch.utils.data.DataLoader(
             Image_datasets,
             batch_size=256, shuffle=True,
             num_workers=4, pin_memory=True)
    return train_loader

次にこの状態でエポック数ごとに訓練する時ですが、以下のような形で書けば大丈夫です。これで、異なるデータセットを同時にpix2pixなどで学習することが可能です。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
(中略:モデルなどを定義する)
for i in range(100):
        for (thin, muscle) in dataset:
            # thin[0] が細身男性の画像、thin[1]がラベル
            # muscle[0]がマッスル男性の画像、muscle[1]がラベル
            thin, muscle = thin[0].to(device), muscle[0].to(device)
以下略

まとめ

今回は、複数のデータセットを統合して解析するための方法について説明させて頂きました。GANなどで全く異なる画像やテキストに対してデータセットを統合して適用させたいケースはかなり出てくると思うので、その時の助けになれば幸いです。