【画像生成AI】CycleGANによる画像のスタイル変換

本記事ではGPUSOROBANのインスタンスを使ったCycleGANによる画像のスタイル変換を紹介します。

GPUSOROBAN(高速コンピューティング)は高性能なGPUインスタンスが低コストで使えるクラウドサービスです。
サービスについて詳しく知りたい方は、
GPUSOROBANの公式サイトを御覧ください。


目次[非表示]

  1. 1.CycleGANとは
  2. 2.敵対的損失(Adversarial Loss)
  3. 3.サイクル一貫性損失(Cycle Consistency Loss)
  4. 4.自己同一性損失(Identity Mapping Loss)
  5. 5.環境構築
  6. 6.データセットのダウンロード・前処理
  7. 7.Generatorの定義
  8. 8.Discriminatorの定義
  9. 9.各関数の定義
  10. 10.クラスのインスタンス化
  11. 11.ハイパーパラメータの設定と学習の実行
  12. 12.学習済みモデルのテスト


CycleGANとは


CycleGANは、"Cycle-Consistent Generative Adversarial Networks"の略で、異なるドメイン間で画像のスタイルを変換するための技術です。スタイル変換とはデータの外見的特徴を変換することです。ドメインとはある特定の特徴を有するデータの集合を指します。例えば馬の画像やシマウマの画像は異なる特徴をもつドメインになります。画像のスタイル変換とは、あるドメインの画像をインプットし異なるドメインの画像へ変換することです。本記事では馬の画像をシマウマの画像に変換するタスクを行います。

cyclegan01


出典:https://arxiv.org/abs/1703.10593 Zhu, J., Park, T., Isola, P., & Efros, A. A. (2017). Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks. 


CycleGANのアーキテクチャには、「敵対的損失(Adversarial Loss)」「サイクル一貫性損失(Cycle Consistency Loss)」「自己同一性損失(Identity Mapping Loss)」の3つの特徴があります。


敵対的損失(Adversarial Loss)

CycleGANのアーキテクチャは、画像生成モデルGANのGenerator(生成器)とDiscriminator(判別器)という2つのネットワークをベースにしています。Generatorは、ランダムなノイズを入力として、Discriminatorが本物と誤認するデータを生成できるように学習します。一方でDiscriminatorは、本物のデータとGeneratorが生成した偽物のデータを正しく識別できるように学習します。これらのGeneratorとDiscriminatorは、互いに競わせるように学習していきます。これらの学習プロセスで使用される損失を「敵対的損失(Adversarial Loss)」と呼びます。


敵対的損失の例を図で解説します。図の緑の矢印に注目すると、馬(A)の本物画像を入力し、シマウマ(B)の偽物画像を生成しています。生成したシマウマ(B)の偽物画像をDiscriminatorBに入力し、シマウマの画像として偽物[0]か本物[1]かの判別をします。この判別結果による損失がGenerator Bにフィードバックされ(緑の点線)、本物に近い偽物を生成できるように学習をします。


一方のDiscriminatorBは、シマウマ(B)の偽物画像とシマウマ(B)の本物画像が入力され、シマウマの画像として偽物[0]と本物[1]かを判別します。この判別結果による損失をDiscriminatorBにフィードバックし(緑の点線)、偽物と本物をただしく判別できるように学習します。


通常のGANでは、図の緑の矢印で示した一つのネットワークのみになりますが、CycleGANでは青の矢印で示すように別のネットワークがあります。青の矢印では、緑の矢印で説明したことの逆を実行しています。シマウマ(B)の本物画像のインプットから始まり、最終的にそれぞれのGenerator AとDiscriminator Aが学習する流れになります。

cyclegan02

(図)敵対的損失のしくみ
Generator A = 馬(A)の偽物画像を生成するGenerator
Generator B = シマウマ(B)の偽物画像を生成するGenerator
Discriminator A = 馬(A)の本物画像と偽物画像を判別するDiscriminator
Discriminator B = シマウマ(B)の本物画像と偽物画像を判別するDiscriminator


サイクル一貫性損失(Cycle Consistency Loss)

「サイクル一貫性損失(Cycle Consistency Loss)」はCycleGANにおいて、主要な特徴になります。

サイクル一貫性損失を図の例に説明します。図の緑の矢印に注目してみると、馬(A)の本物画像をGenerator Bに入力し、シマウマ(B)の偽物画像を生成しています。次にこのシマウマ(B)の偽物画像をGenerator Aに入力し、馬(A)の偽物画像を生成します。


このプロセスをまとめると、「馬(A)の本物画像 → シマウマ(B)の偽物画像 → 馬(A)の偽物画像」となり、馬(A)の本物画像から馬(A)の偽物画像を再構成したことになります。途中でシマウマに変換されているものの、馬から生成された馬になりますので、このプロセスにおける馬(A)の本物画像と馬(A)の偽物画像は、同じになるべき(一貫性があるべき)というのがサイクル一貫性損失の根幹になります。この考えに基づいて、馬(A)の本物画像と馬(A)の偽物画像のピクセル間の差を損失として計算し(緑の点線)、この損失をフィードバックし、Generator BおよびGenerator Aを学習します。


青の矢印は、緑の矢印の逆のプロセスを実行しています。シマウマ(B)の本物画像から始まり、シマウマ(B)の偽物画像を再構成し、この損失をGenerator BとGenerator Aにフィードバックし学習しています。再構成された画像が元の入力画像に近づくように学習され、画像の一貫性をもつことで変換の品質が向上します。

cyclegan03

(図)サイクル一貫性損失のしくみ
Generator A = 馬(A)の偽物画像を生成するGenerator
Generator B = シマウマ(B)の偽物画像を生成するGenerator

cyclegan04

参考:https://arxiv.org/abs/1703.10593
Zhu, J., Park, T., Isola, P., & Efros, A. A. (2017). Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks. 


自己同一性損失(Identity Mapping Loss)

CycleGANの3つめの特徴の「自己同一性損失(Identity Mapping Loss)」について説明します。
図の緑の矢印に注目してみると、馬(A)の本物画像をGenerator Aに入力し、馬(A)の偽物画像を生成しています。Generator Aは、馬を生成するGeneratorになりますので、馬を入力した場合は、馬から馬がそのまま生成されることになります。馬から馬を生成するため、入力と出力に差異はないという考えのもと、馬(A)の本物画像と馬(A)の偽物画像のピクセル画像の差を損失として計算し、Generator Aにフィードバックして学習します。


一方で青の矢印では、緑の矢印の逆を実行しています。シマウマ(B)の本物画像の入力から始まり、シマウマ(B)の偽物画像を生成し、2つの画像差を損失として計算し、Generator Bにフィードバックしています。


なぜこのような損失計算をするかというと、元の画像を変更する必要がないときは、Generatorに変更を加えさせないよう学習させるためです。論文では、自己同一性損失を使うことで、入力画像の色を保持する効果があったと言われています。(自己同一性損失がない場合は、色が大幅に変化していた)

cyclegan05


(図)自己同一性損失のしくみ
Generator A = 馬(A)の偽物画像を生成するGenerator
Generator B = シマウマ(B)の偽物画像を生成するGenerator


cyclegan06


(図)自己同一性損失を使用した場合に生成画像の色調が保持される例
出典:
https://arxiv.org/abs/1703.10593 Zhu, J., Park, T., Isola, P., & Efros, A. A. (2017). Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks. 


環境構築

環境はGPUSOROBANのnvd4-80-1ulインスタンスを使用します。
nvd4-80-1ulは、NVIDIA A100を搭載した高性能GPUインスタンスです。機械学習を高速化するTensorコアや大容量GPUメモリ(80GB)が特徴です。後にGoogle ColaboratoryのNVIDIA T4との学習時間を比較してみます。


GPUSOROBANのインスタンスの作成方法、秘密鍵の設置方法については、会員登録~インスタンス作成手順の記事をご覧ください。
インスタンスの作成と秘密鍵の設定が完了しましたら、アクセスサーバーおよびインスタンスに接続をします。
本記事ではJupyterLabを使用するため、上記の手順書のインスタンス接続のコマンドと異なりますので、ご注意ください。


アクセスサーバーへの接続

ssh -L 20122:(インスタンスのIPアドレス):22 -l user as-highreso.com -p 30022 -i .ssh\ackey.txt


インスタンスへの接続方法

ssh -L 8888:localhost:8888 user@localhost -p 20122 -i .ssh\mykey.txt


インスタンス接続が完了しましたら、PyTorch、JupyterLabをインストールします。
(参考)PyTorchのインストール(Ubuntu)の記事
(参考)Jupyter Labのインストール(Ubuntu)の記事


Jupyterを起動後に、各種ライブラリをインストールします。


パッケージをアップデートします。

!sudo apt update


各種ライブラリをインストールおよびインポートします。

!pip install matplotlib
!pip install opencv-python
!sudo apt-get install -y libglib2.0-0 -y
!sudo apt-get install unzip

import os
import numpy as np
import random 
import copy
import itertools
from PIL import Image
from matplotlib import pyplot as plt
import cv2

import torch
import torch.nn as nn
from torch import tensor
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim import lr_scheduler


データセットのダウンロード・前処理


一般公開されている馬とシマウマのデータセットをダウンロードし、解凍します。

!mkdir -p datasets
dataname= "horse2zebra"
URL="https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/"+dataname+".zip"
ZIP_FILE="./datasets/"+dataname+".zip"
!wget -N $URL -O $ZIP_FILE
!unzip $ZIP_FILE -d ./datasets/
!rm -f $ZIP_FILE


データを格納するディレクトリを作成し、ダウンロードしたデータセットを移動します。

!mkdir -p checkpoint/{dataname}
!mkdir -p output/{dataname}/img
!mkdir -p datasets/{dataname}/trainA/A
!mkdir -p datasets/{dataname}/trainB/B
!mkdir -p datasets/{dataname}/testA/A
!mkdir -p datasets/{dataname}/testB/B
!mkdir -p datasets/{dataname}/testC/C
!mv datasets/{dataname}/trainA/.jpg datasets/{dataname}/trainA/A
!mv datasets/{dataname}/trainB/.jpg datasets/{dataname}/trainB/B
!mv datasets/{dataname}/testA/.jpg datasets/{dataname}/testA/A
!mv datasets/{dataname}/testB/.jpg datasets/{dataname}/testB/B


格納したデータセットを確認します。

ls datasets/horse2zebra/trainA/A
ls datasets/horse2zebra/trainB/B
ls datasets/horse2zebra/testA/A
ls datasets/horse2zebra/testB/B


バッチサイズおよびデータローダーの設定をします。
論文を踏襲しバッチサイズを1に設定しています。バッチサイズが小さい方がCycleganの画像変換の精度が向上すると言われています。

#バッチサイズ
batch_size = 1 #バッチサイズ
mean = np.array([0.5, 0.5, 0.5]) #平均値
std = np.array([0.5, 0.5, 0.5]) #標準偏差

#データローダー
data_transforms = transforms.Compose([
                transforms.Resize(286, Image.BICUBIC),#286x286 大きめに拡大 Image.BICUBICは画像補完のしくみ
                transforms.RandomCrop(256),#拡大した画像を256のサイズでランダムに切り抜き、1枚の画像を複数に水増し。
                transforms.ToTensor(),#テンソル型に変換
                transforms.Normalize(mean, std)
                ])
datasets_A = datasets.ImageFolder(os.path.join('datasets', dataname, 'trainA'), data_transforms) #馬(A)の学習用データセットの作成(ラベリングと前処理) 
datasets_B = datasets.ImageFolder(os.path.join('datasets', dataname, 'trainB'), data_transforms) #シマウマ(B)の学習用データセットの作成(ラベリングと前処理) 
loaders_A = torch.utils.data.DataLoader(datasets_A, batch_size=batch_size, shuffle=True, num_workers=1) # データセットからバッチ単位でデータを取得。画像をランダムに取り出せるようデータローダーに格納する。num_workersはcpuの稼働数。
loaders_B = torch.utils.data.DataLoader(datasets_B, batch_size=batch_size, shuffle=True, num_workers=1) 


画像を表示する関数を定義します。

#画像を表示する関数
def im_show(image):
    im = image.detach().numpy() #pytorchのテンソルから切り離し、numpyの配列に変換する
    im = im.transpose(1, 2, 0) #チャンネル、高さ、幅の順序を、チャンネル、幅、高さの順序に変換
    im = std * im + mean #画像を標準化するために、stdとmeanを使って計算
    plt.axis('off')#軸を非表示にする
    plt.imshow(im)


学習に使用する画像の表示します。

#画像表示テスト(馬)
im = next(iter(loaders_A))[0][0] #最初のバッチを取得する
im_show(im)

cyclegan07


#画像表示テスト(シマウマ)
im = next(iter(loaders_B))[0][0]
im_show(im)

cyclegan08


Generatorの定義


Generatorで使用するResidual blockを定義します。
Residual blockは、複数のResNetをまとめたものになります。ResNetを使用する目的は、深い層のネットワークを構築する際の勾配消失問題に対処するためです。通常の畳み込みニューラルネットワークでは、層が深くなるにつれ勾配がゼロに近づく勾配消失の問題があります。ResNetでは、入力データが層のブロック(Convolution block)をスキップして出力に直接接続します。これにより配消失問題を緩和します。

cyclegan09


データの正規化において、バッチノーマライゼーションではなく、インスタンスノーマライゼーションを使用しています。
バッチノーマライゼーションはミニバッチごとに平均・分散で正規化することから、ミニバッチのサイズが小さいと推定される平均・分散が正確ではなくなり、学習が安定しないことがあります。一方でインスタンスノーマライゼーションは、各サンプルごと(例えば、画像内のピクセルごと)に正規化します。この手法は、各サンプルにおいて平均・分散に正規化するため、ミニバッチ中のサンプル数が少なくても問題ありません。


#Residual network
class Resnet(nn.Module):
    def init(self, in_features): #入力のときの特徴マップの次元 特徴マップとは畳み込みのよって得られるテンソル
        super(Resnet, self).init()
        # convolution blockの作成
        conv_block = [  nn.ReflectionPad2d(1),#パディング
                        nn.Conv2d(in_features, in_features, kernel_size=3),#畳み込み フィルターのサイズ3x3
                        nn.InstanceNorm2d(in_features),#インスタンス正規化
                        nn.ReLU(inplace=True), #inplaceでメモリを有効活用する
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)  ]
        # conv blockを順番に実行する
        # args: 複数の引数(arguments)をタプルとして渡す。
        # **kwargs: 複数のキーワード引数(keyword arguments)を辞書として渡す。
        self.conv_block = nn.Sequential(conv_block) 
    def forward(self, x):
        return x + self.conv_block(x)


Generatorを定義します。Generatorは5つのブロックで構成されています。それぞれのブロックの説明は次のとおりです。


1.Convolution block:畳み込みによって、入力画像から特徴を抽出します。
2.Down sampling:ダウンサンプリングは、特徴マップのサイズを縮小し、画像の情報を階層的に抽象化し、高次の特徴を抽出します。
3.Residual block:畳み込みと入力のスキップ接続を繰り返し、細かい特徴を抽出します。
4.Up sampling:ダウンサンプリングによって縮小した特徴マップを元のサイズに戻し、解像度を上げて画像を生成します。
5.Output layer:パディングと畳み込みでチャンネル数に合わせた画像を生成し、Tanhを使用して-1から1の範囲で出力しています。

cyclegan10


Discriminatorの定義

Discriminatorでは、主に畳み込み→正規化→LeakyReLuを繰り返すダウンサンプリングをして、特徴量の抽出を行っています。

cyclegan11

逆伝播での勾配消失問題に対処するために、活性化関数にLeakyReLUを使用しています。通常のReLUでは負の入力で、0が出力されるため、微分ができず勾配消失に陥る可能性があります。LeakyReLUは負の入力に対し、微小な負の値を出力することができます。微分値が常に0にならないので、勾配消失問題の対処が可能です。図はReLUとLeaky ReLuの違いを示しています。

cyclegan12

出力画像のサイズについては、下記の式で計算されます。
O = (I - F + 2P)/S + 1
O:出力画像のサイズ(高さ or 幅)
I:入力画像のサイズ(高さ or 幅)
F:カーネル(フィルタ)のサイズ(高さ or 幅)
P:パディング幅
S:ストライド幅


#Discriminatorの定義
class Discriminator(nn.Module):
    def init(self, in_channels=3, out_channels=1):
        super(Discriminator, self).init()
        num_features = 64 #ベースとなる特徴量
        #kernel size:4, stride:2, padding:1
        #leakyReLu:0.2,最小2まで減衰させる
        model = [   nn.Conv2d(in_channels=in_channels, out_channels=num_features, kernel_size=4, stride=2, padding=1),#入力チャンネル3,出力チャンネル64 入力画像サイズ256x256 出力画像サイズ 128x128
                    nn.LeakyReLU(0.2, inplace=True) ]
        model += [  nn.Conv2d(num_features, 128, 4, stride=2, padding=1),#入力チャンネル64 出力チャンネル128 入力画像サイズ128x128 出力画像サイズ64x64
                    nn.InstanceNorm2d(128), 
                    nn.LeakyReLU(0.2, inplace=True) ]
        model += [  nn.Conv2d(128, 256, 4, stride=2, padding=1),#入力チャンネル128 出力チャンネル256 入力画像サイズ64x64 出力画像サイズ32x32
                    nn.InstanceNorm2d(256), 
                    nn.LeakyReLU(0.2, inplace=True) ]
        model += [  nn.Conv2d(256, 512, 4, stride=1,  padding=1),#入力チャンネル256 出力チャンネル512 入力画像サイズ32x32 出力画像サイズ31x31
                    nn.InstanceNorm2d(512), 
                    nn.LeakyReLU(0.2, inplace=True) ]
        model += [nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4,stride=1, padding=1)]#入力チャンネル512 出力チャンネル1 入力画像サイズ31x31 出力画像サイズ30x30
        self.model = nn.Sequential(*model)
    def forward(self, x):
        x =  self.model(x)
        # Average pooling and flatten
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)#xの平均値を取る
        #上はこれと同義 -> return F.avg_pool2d(x, [30, 30]).view(1, 1) #特徴マップと同サイズのフィルタを割り当て、1x1の画素に展開


重みの初期化を行う関数について定義します。

#重みの初期化
def weights_init(m):
    classname = m.class.name
    if classname.find('Conv') != -1:#畳み込みに関するクラスのみ取り出し、初期化する
        nn.init.normal_(m.weight.data, 0.0, 0.02) #平均0.0 標準偏差0.02


定義したGeneratorを表示します。

#Generatorの表示
Generator()


cyclegan13


定義したDiscriminatorを表示します。

#Discriminatorの表示
Discriminator()


cyclegan14.


各関数の定義


学習率のスケジューリングを定義します。

#学習率の減衰についてのクラス
class LambdaLR():
    def init(self, epochs, offset, decay_epoch):
        self.epochs = epochs
        self.offset = offset #保存データを読み込んだときに始める位置
        self.decay_epoch = decay_epoch #どこから減衰させるか 学習率の減衰が開始されるエポック数
    def step(self, epoch):
      lambda_lr = epoch + self.offset - self.decay_epoch
      lambda_lr = max(0, lambda_lr) #0と比較して大きい方を保持する 0以下は0とする
      lambda_lr /= self.epochs - self.decay_epoch #分母
      lambda_lr = 1.0 - lambda_lr
      return lambda_lr
      #ある一定の時刻までは同じ学習係数になるが、減衰ポイントを通過した時点で、徐々に学習係数を下げていき、徐々に細かく学習
      #offsetを加算したエポック数とdecay_epochを比較。減衰を開始するエポックよりも前の場合は0を返し、それ以降は減衰率を計算
      #計算結果は学習率の割合(0から1の間)として返される


一度生成した画像を再利用するクラスを定義します。


#一度生成した画像を再利用するクラス

#バッファが最大サイズに達した場合には古いデータを削除しながら新しいデータを追加する

#データの追加はランダムに行われ、一定の割合で古いデータを置き換える
class ReplayBuffer():
    def init(self, max_size=50):#最大何個までバッファに溜め込むかの初期値
        self.max_size = max_size
        self.items = []
    def call(self, item):
        return_item = None
        if len(self.items) < self.max_size: #最大50になるまで画像を溜め込んでおく
            self.items.append(item)
            return_item = item
        else: #50以上の場合
            if random.uniform(0,1) > 0.5: # 一様乱数を発生させて 1/2の確率でバッファを使う。
                idx = random.randint(0, self.max_size - 1)
                buffer = copy.copy(self.items[idx]) # Pop buffer
                return_item = buffer
                self.items[idx] = item # Push buffer
            else:
                return_item = item
        return return_item



学習中のログを表示する関数を定義します。

#学習中のログを表示する関数
def make_log(epoch, i, loss_info):
    log  = "epoch: " + str(epoch)
    log += " iter: " + str(i) #イテレーション
    log += " loss: " + str(loss_info)
    return log


チェックポイントのパスを生成する関数を定義します。

#モデルのチェックポイントを保存するためのパス
def checkpoint_path(filename):
    return os.path.join('checkpoint', dataname, filename)


学習に関する情報を保存する関数を定義します。

#学習に関する情報を保存する関数
def save_train_info():
    train_info = '' #train_infoの初期化
    train_info += 'epochs: %d\n' % epochs
    train_info += 'decay_epoch: %d\n' %decay_epoch
    train_info += 'learning_rate: %f\n' %learning_rate
    train_info += 'lambda_A: %f\n' %lambda_A
    train_info += 'lambda_B: %f\n' %lambda_B
    train_info += 'lambda_id: %f\n' %lambda_id
    train_info += 'criterion_gan: %s\n' %criterion_gan._get_name()
    train_info += 'criterion_cyc: %s\n' %criterion_cyc._get_name()
    train_info += 'criterion_id: %s\n' %criterion_id._get_name()
    with open(checkpoint_path('train_info.txt'), 'w') as f:
        f.write(train_info)


学習中に生成されたデータやモデルの状態を保存する関数を定義します。

#学習中に生成されたデータやモデルの状態を保存する関数
def save_data(epoch, train_log):
    torch.save(G_A.state_dict(), checkpoint_path('G_A.pth'))#Generator G_Aの状態を保存
    torch.save(G_B.state_dict(), checkpoint_path('G_B.pth'))#Generator G_Bの状態を保存
    torch.save(D_A.state_dict(), checkpoint_path('D_A.pth'))#Discriminator D_Aの状態を保存
    torch.save(D_B.state_dict(), checkpoint_path('D_B.pth'))#Discriminator D_Bの状態を保存
    #トレーニングログの保存
    if epoch == 1 and epoch_from == 1:
        write_mode = 'w' #新規データに書き込み
    else:
        write_mode = 'a' #既存データに追加
    #学習ログの保存
    with open(checkpoint_path('train_log.txt'), write_mode) as f:
        f.write(train_log)
    #epoch数の保存
    with open(checkpoint_path('epoch.txt'), 'w') as f:
        f.write(str(epoch))
    print("Saved.")



保存されたモデルの状態を読み込む関数を定義します。

#保存されたモデルの状態の読み込み
def load_data(flag):
    global epoch_from #関数の外で定義した変数を関数の中で変更する
    #Falseの場合、モデルをリセットする場合、終了する
    if not flag:
        rs = input('Do you reset the model? (Y/n)')
        if rs == 'Y':
            epoch_from = 1
            return
    with open(checkpoint_path('epoch.txt')) as f:
        epoch_count = f.read()

    #GeneratorとDiscriminatorのモデルの状態を読み込み
    G_A.load_state_dict(torch.load(checkpoint_path('G_A.pth')))
    G_B.load_state_dict(torch.load(checkpoint_path('G_B.pth')))
    D_A.load_state_dict(torch.load(checkpoint_path('D_A.pth')))
    D_B.load_state_dict(torch.load(checkpoint_path('D_B.pth')))

    epoch_from = int(epoch_count) + 1

    print("Loaded.")


学習のステップを定義します。

#学習ステップの定義
fake_A = fake_B = None #関数の中でエラーがあった場合でも、関数の外でグローバルな変数の値を確認できる
def train_batch(num_iter, batch):
    # debug
    global fake_A , fake_B #関数の外で定義し変数を、関数の中で書き換えられるようにしている
    # Set model input
    real_A = batch[0][0].cuda() #本物画像Aがインデックス0,後ろの0はデータセットの画像を表す。
    real_B = batch[1][0].cuda() #本物画像Bがインデックス1,後ろの0はデータセットの画像を表す。
    ### Generator ###
    optimizer_G.zero_grad() #損失勾配をゼロにする(勾配の初期化)
    # 敵対的損失(Adversarial Loss)
    fake_A = G_A(real_B) #GeneratorAに本物画像Bを入れて偽物画像Aを作る
    pred = D_A(fake_A) #偽物画像AをDiscriminatorAに入れて本物らしさの確率を出す
    loss_gan_A = criterion_gan(pred, target_real) #偽物画像Aと本物画像を比べたときの平均二乗誤差
    fake_B = G_B(real_A) #GeneratorBに本物画像Aを入れて偽物画像Bを作る
    pred = D_B(fake_B) #偽物画像BをDiscriminatorBに入れて本物らしさの確率を出す
    loss_gan_B = criterion_gan(pred, target_real) #偽物画像Bと本物画像の誤差
    # サイクル一貫性損失(Cycle Consistency loss)
    rec_A = G_A(fake_B) #GeneroatorAに偽物画像Bを入れて再構成画像Aを作る
    loss_cyc_A = criterion_cyc(rec_A, real_A) * lambda_A #再構成画像Aと本物画像Aの誤差(L1ノルム),係数Lambda_A
    rec_B = G_B(fake_A) #GeneroatorBに偽物画像Aを入れて再構成Bを作る
    loss_cyc_B = criterion_cyc(rec_B, real_B) * lambda_B #再構成画像Bと本物画像Bの誤差(L1ノルム),係数Lambda_B
    # 自己同一性損失(Identity loss)
    id_A = G_A(real_A) #GeneratorAに本物画像Aをいれて自己同一性画像Aを作る
    loss_id_A = criterion_id(id_A, real_A) * lambda_A * lambda_id #自己同一性画像Aを本物画像Aと比較して誤差を得る,係数LambdaAを掛ける,加えてLabda_idを掛けて割り引く
    id_B = G_B(real_B) #GeneratorBに本物画像Bをいれて自己同一性Bを作る
    loss_id_B = criterion_id(id_B, real_B) * lambda_B * lambda_id #自己同一性画像Aを本物画像Bと比較して誤差を得る,係数LambdaBを掛ける,加えてLabda_idを掛けて割り引く
    # GeneratorのTotal loss
    loss_G = loss_gan_A + loss_gan_B + loss_cyc_A + loss_cyc_B + loss_id_A + loss_id_B #敵対的損失A,B + サイクル一貫性損失A,B + 自己同一性損失A,Bの総和
    # Generatorのパラメータ更新
    loss_G.backward() #Generotorの誤差逆伝播
    optimizer_G.step() #Generotorのパラメータの更新
    ### Discriminator ###
    # In practice, we divide the objective by 2 while optimizing D,
    #  which slows down the rate at which D learns, relative to the rate of G.
    optimizer_D.zero_grad() #勾配を初期化する
    # Sampling from buffers
    fake_A = fake_A_buffer(fake_A) #生成した偽物画像Aを50%の確率で再利用する
    fake_B = fake_B_buffer(fake_B) #生成した偽物画像Bを50%の確率で再利用する
    # Discriminator A
    # Real loss
    pred = D_A(real_A) #DiscriminatorAに本物画像Aを入れて予測結果を求める
    loss_D_real = criterion_gan(pred, target_real) #予測結果と真のラベルの誤差を求める
    # Fake loss
    pred = D_A(fake_A.detach()) #DiscriminatorAに偽物画像Aを入れて予測結果を求める, detach()はTensor情報を切り離してreal_Aと同じ状態にする
    loss_D_fake = criterion_gan(pred, target_fake) #予測結果と真のラベルの誤差を求める
    # DiscriminatorAのTotal loss
    loss_D_A = (loss_D_real + loss_D_fake) * 0.5 #Discriminatorが強く学習しすぎないように、Generatorとバランスさせる。0.5をかけて重み付け。
    # DiscriminatorAの誤差逆伝播
    loss_D_A.backward() 
    # DiscriminatorB
    # Real loss
    pred = D_B(real_B) #DiscriminatorBに本物画像Bを入れて予測結果を求める
    loss_D_real = criterion_gan(pred, target_real)  #予測結果と真のラベルの誤差を求める
    # Fake loss
    pred = D_B(fake_B.detach()) #DiscriminatorBに偽物画像Aを入れて予測結果を求める, detach()はTensor情報を切り離してreal_Bと同じ状態にする
    loss_D_fake = criterion_gan(pred, target_fake) #予測結果と真のラベルの誤差を求める
    # Total loss
    loss_D_B = (loss_D_real + loss_D_fake) * 0.5 #Discriminatorが強く学習しすぎないように、Generatorとバランスさせる。0.5をかけて重み付け。
    # DiscriminatorBの誤差逆伝播
    loss_D_B.backward() 
    #DiscriminatorAとBを併せてパラメーター更新
    optimizer_D.step() 
    # 学習過程の情報
    loss_info = {'G': '%.4f' % loss_G.item(),
            'D': '%.4f' % (loss_D_A.item() + loss_D_B.item()),
            'G_id': '%.4f' % (loss_id_A.item() + loss_id_B.item()),
            'G_gan': '%.4f' % (loss_gan_A.item() + loss_gan_B.item()),
            'G_cyc': '%.4f' % (loss_cyc_A.item() + loss_cyc_B.item()),
            'D_A': '%.4f' % loss_D_A.item(),
            'D_B': '%.4f' % loss_D_B.item()
            }
    return loss_info


クラスのインスタンス化


モデルをインスタンス化し、重みの初期化設定を適用します。

# Generator,Discriminatorのインスタンス化
G_A = Generator().cuda()
G_B = Generator().cuda()
D_A = Discriminator().cuda()
D_B = Discriminator().cuda()

# 重みの初期化設定をモデルに適用する
G_A.apply(weights_init)
G_B.apply(weights_init)
D_A.apply(weights_init)
D_B.apply(weights_init)

# ReplayBufferのインスタンス化
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()


損失関数をインスタンス化します。

敵対的損失の計算においては、平均二乗誤差(MSE)を使用しています。通常使用されるBinaryCrossEntropy(BCE)と比べ、MSEのほうがGeneratorとDiscriminatorの乖離が小さく、学習が安定しやすいと言われています。


サイクル一貫性損失、自己同一性損失の計算においては、L1ノルムを使用します。L1ノルムを使うことで、ピクセル間の差の絶対値として表現し、小さな誤差にも対応します。

#損失関数のインスタンス化
criterion_gan = torch.nn.MSELoss() #敵対的損失,平均二乗誤差
criterion_cyc = torch.nn.L1Loss() #サイクル一貫性損失,L1ノルム
criterion_id = torch.nn.L1Loss() #自己同一性損失,L1ノルム

#Coefficient λ
lambda_A = 10.0 #サイクル一貫性損失、自己同一性損失の重みを調整する係数
lambda_B = 10.0 #サイクル一貫性損失、自己同一性損失の重みを調整する係数
lambda_id = 0.5 #自己同一性損失の重みを調整する係数


敵対的損失を計算するためのターゲットを設定します。

#GeneratorとDiscriminatorの学習中に、敵対的損失を計算するために使用するターゲット
target_real = torch.ones([1,1]).cuda() #本物画像を表す目標テンソル
target_fake = torch.zeros([1,1]).cuda() #偽物画像を表す目標テンソル


ハイパーパラメータの設定と学習の実行

論文を参考に学習率、エポック、学習率の減衰について指定しています。

#ハイパーパラメータ
learning_rate = 0.0002 #学習率0.0002
epochs = 200 #200
epoch_from = 1 #1
decay_epoch = 100 # 100 学習率の減衰(何エポック目から減衰させるか)


保存データの利用確認をします。

#保存データの利用確認
load_data(False) #保存データを利用するか -> Falseの場合:保存データのロードをスキップする。 Trueの場合は、続きの開始エポック数が表示される。
epoch_from

cyclegan15


オプティマイザと学習スケジューラーの設定をします。

#オプティマイザ
optimizer_G = torch.optim.Adam(itertools.chain(G_A.parameters(), G_B.parameters()), lr=learning_rate, betas=(0.5, 0.999)) #Adam, GeneratorAとGeneratorBのパラメータを結合してオプティマイザに渡す。betas=勾配とその二乗の移動平均を計算するために使用される係数
optimizer_D = torch.optim.Adam(itertools.chain(D_A.parameters(), D_B.parameters()), lr=learning_rate, betas=(0.5, 0.999)) #Adam, DiscriminatorAとDiscriminatorBのパラメータを結合してオプティマイザに渡す。

#学習率スケジューラ
G_lr_scheduler = lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(epochs, epoch_from, decay_epoch).step) #エポックごとに学習率を変化させるスケジューラ。
D_lr_scheduler = lr_scheduler.LambdaLR(optimizer_D, lr_lambda=LambdaLR(epochs, epoch_from, decay_epoch).step)


学習に関する情報を保存します。

学習に関する情報を保存
save_train_info()


学習を実行します。

#Train loop
import time
start_time_total = time.time()
for epoch in range(epoch_from, epochs + 1):
    start_time = time.time()
    lr = optimizer_G.param_groups[0]['lr'] #現在の学習率係数の表示
    train_log = "learning_rate: " + str('%07f\n' %lr) #学習係数ログ
    print(train_log)
    for i, batch in enumerate(zip(loaders_A, loaders_B)):#バッチごとに損失情報を表示
        loss_info = train_batch(i, batch)
        if i % 100 == 0: #100回に1度だけ表示する
            batch_log = make_log(epoch, i, loss_info)
            print(batch_log)
            train_log += batch_log + '\n'
    G_lr_scheduler.step() #Generatorの学習率を更新
    D_lr_scheduler.step() #Discriminatorの学習率を更新
    save_data(epoch, train_log)    
    end_time = time.time()
    execution_time = (end_time - start_time)  / 60
    print("実行時間: {}分".format(execution_time))
end_time_total = time.time()
execution_time_total = (end_time_total - start_time_total) / 60
print("実行時間: {}分".format(execution_time_total)) 


cyclegan16

(↑GPUSOROBAN nvd4-80-1インスタンスでの学習時間:1epoch平均2.7分, 200epoch合計で約9時間)


GPUSOROBANのnvd4-80-1インスタンスで、200epochの学習で約9時間(550分)ほどの時間がかかりました。1epochあたりの学習時間は平均2.7分になります。


一方でGoogle ColaboratoryのNVIDIA T4インスタンスでは、約3時間を経過した19epochの段階で自動的にランタイムが切断され、途中停止していました。Google Colaboratoryでは1epochあたりの学習時間は平均11.4分になりましたので、単純計算でGPUSOROBANが4倍速い結果になりました。

cyclegan17

cyclegan18

(↑Google ColaboratoryのNVIDIA T4インスタンスの学習時間:1epoch平均11.4分, 200epoch合計は不明 ※学習が途中で停止していたため)


学習済みモデルのテスト

ここからは学習済みモデルのテストを行います。


テスト用データの前処理およびデータローダーでデータセットを読み込みます。

#テスト用データの前処理およびデータローダーでデータセットを読み込み
mean, std = 0.5, 0.5

#Pytorch dataloader
data_transforms = transforms.Compose([
                transforms.Resize(256), 
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
                ])
datasets_A = datasets.ImageFolder(os.path.join('datasets', dataname, 'testA'), data_transforms)#testAの読み込み
loaders_A = torch.utils.data.DataLoader(datasets_A, batch_size=1, shuffle=True, num_workers=1)


生成した画像を保存する関数を定義します。

#生成した画像を保存する関数
def im_save(image, num, mode='fake'):
    im = image.detach().numpy() #テンソルをNumpy配列に変換
    im = im.transpose(1, 2, 0) # c h w -> h w c
    im = std * im + mean #画像の正規化
    im *= 256 #画像のピクセル値を[0, 255]の範囲にスケーリング
    
    im = im[:, :, [2, 1, 0]] #GBRからRGBに変換する
    filepath = os.path.join('output', dataname, 'img', mode + '_%06d.png' %num)
    # print(filepath)
    cv2.imwrite(filepath, im)


学習済みモデルのデータをロードします。

#学習済みモデルのデータロード
def load_model_data(flag):
    with open(checkpoint_path('epoch.txt')) as f:
        epoch_count = f.read()
    G_A.load_state_dict(torch.load(checkpoint_path('G_A.pth')))
    G_B.load_state_dict(torch.load(checkpoint_path('G_B.pth')))
    print("epoch %d Loaded." %int(epoch_count))


load_model_dataの関数を実行します。

load_model_data(True)


馬(A)を学習済みモデルにインプットして、シマウマ(B)に変換されるかepochごとに確認していきます。


10epochまで学習したモデルでは、部分的にシマウマに変換されていることが分かります。

cyclegan20

cyclegan21
50epochまでの学習したモデルでは、全体的にシマウマの模様に変換されましたが、変換前の馬の色味が一部残っています。

cyclegan23

cyclegan25


200epochまでの学習では、全体的にシマウマに変換され、元の馬の色味もほぼ残っていません。


cyclegan28

cyclegan29

CycleGANによるスタイル画像変換の説明は以上になります。

本環境には、GPUSOROBAN(高速コンピューティング)のインスタンスを使用しました。
高速コンピューティングは、高性能なGPUインスタンスが低コストで使えるクラウドサービスです。
サービスについて詳しく知りたい方は、公式サイトをご覧ください。

MORE INFORMATION

GPUでお困りの方はGPUSOROBANで解決!
お気軽にご相談ください

10日間無料トライアル
詳しい資料はこちら
質問・相談はこちら