【画像生成AI】DCGANで手書き文字の生成

本記事では、GPUSOROBANのインスタンスを使った、DCGANによる画像生成の例を紹介します。
GPUSOROBANは高性能なGPUインスタンスが低コストで使えるクラウドサービスです。
サービスについて詳しく知りたい方は、GPUSOROBANの公式サイトを御覧ください。


目次[非表示]

  1. 1.DCGANとは
  2. 2.環境構築
  3. 3.学習に使う手書き文字画像の確認
  4. 4.各設定、データの前処理
  5. 5.Generatorのモデル構築
  6. 6.Discriminatorのモデル構築
  7. 7.画像の生成・表示
  8. 8.正解数の定義
  9. 9.学習の実行
  10. 10.誤差と正解率の推移


DCGANとは


DCGANとは、Generator(生成器)とDiscriminator(識別器)の2つのモデルを互いに競わせるように学習して、生成を行うGANの派生モデルになります。
Generatorは、ランダムなノイズを入力として、Discriminatorが本物と誤認しするデータを生成できるように学習します。一方でDiscriminatorは、本物のデータとGeneratorが生成した偽物のデータを正しく識別できるように学習します。
下図は、GeneratorとDiscriminatorがどのように学習をしているかを示す概念図になります。



GeneratorとDiscriminator



下図は本記事で実施する、手書き文字の画像を生成するモデルの構成になります。
DCGANの特徴として、Generatorに逆畳み込み層、Discriminatorに畳み込み層を使うことで、GANよりも自然な画像の生成をすることができます。


Generatorモデル


環境構築


環境はGPUSOROBANのインスタンスを使用します。GPUSOROBANは、高性能なGPUインスタンスが格安で使えるクラウドサービスです。インスタンスの作成方法、接続方法はこちらの記事を御覧ください。
インスタンス起動後、PyTorch、matplotlib、scikit-learn、JupyterLabの4つのライブラリをインストールします。
PyTorchのインストールについては、こちらの記事をご参照ください。
matplotlibとscikit-learnについては、下記のコマンドでインストールします。


pip install matplotlib scikit-learn


JupyterLabを使用する場合は、こちらの記事を参考にJupyterLabのインストールから起動までを実行してください。


学習に使う手書き文字画像の確認


DCGANに用いる学習用のデータを用意します。
scikit-learnから、8×8の手書き数字の画像データを読み込んで表示します。


import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
digits_data = datasets.load_digits()
n_img = 10  # 表示する画像の数
plt.figure(figsize=(10, 10))
for i in range(n_img):
    # 入力画像
    ax = plt.subplot(16, 16, i+1)
    plt.imshow(digits_data.data[i].reshape(8, 8), cmap="Greys_r")
    ax.get_xaxis().set_visible(False)  # 軸を非表示
    ax.get_yaxis().set_visible(False)
plt.show()
print("データの形状:", digits_data.data.shape)
print("ラベル:", digits_data.target[:n_img])


次ような画像が表示されます。こちらが学習で使うデータになります。


学習データ


各設定、データの前処理


DCGANに必要な各種パラメータの設定から、データの読み込み、前処理を行います。


import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
import torch
from torch.utils.data import DataLoader
# 各設定値
img_size = 8  # 画像の高さと幅
n_noise = 64  # ノイズの数を指定
# 各相の設定値はclassで行う
eta = 0.001  # 学習係数
epochs = 200  # 学習回数
interval = 20  # 経過の表示間隔
batch_size = 16
# 学習データの読み込み、前処理
digits_data = datasets.load_digits() # 学習データの読み込み
x_train = np.asarray(digits_data.data) # numpyの配列に変換
x_train = x_train / 16*2-1  # 学習データの範囲を-1から1の範囲に指定(Generator出力のtanhに合わせるため)
t_train = digits_data.target # 手書き文字のラベルと取り出す
x_train = torch.tensor(x_train, dtype=torch.float) # 学習データをPytorchのテンソルに変換
train_dataset = torch.utils.data.TensorDataset(x_train) # データセットの設定
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # データローダの設定


Generatorのモデル構築


Pytorchのnnモジュールを使って、Generatorのモデルを構築します。
Generatorでは逆畳み込み層を3層重ねた構成で、ノイズから画像を生成します。
逆畳み込み層はPytorchのConvTranspose2dにより実装します。
出力層の活性化関数には、Discriminatorへの入力を-1から1の範囲にするためにtanhを使います。


import torch.nn as nn
import torch.nn.functional as F
class Generator(nn.Module):
    def __init__(self):
        super().__init__() # 逆畳み込み層の初期設定
        # 入力画像1x1, カーネル3x3  → 出力画像3x3
        self.convt_1 = nn.ConvTranspose2d(n_noise, 64, 3)  # 引数(入力のチャンネル数,出力のチャンネル数,カーネルのサイズ)
        # 入力画像3x3, カーネル3x3 → 出力画像5x5
        self.convt_2 = nn.ConvTranspose2d(64, 32, 3)
        # 入力画像5x5, カーネル4x4 → 出力画像8x8 
        self.convt_3 = nn.ConvTranspose2d(32, 1, 4)
    # Generatorの順伝播
    def forward(self, x):
        x = x.view(-1, n_noise, 1, 1)  # 引数(バッチサイズ(自動), チャンネル数, 高さ, 幅)
        x = F.relu(self.convt_1(x))
        x = F.relu(self.convt_2(x))
        x = torch.tanh(self.convt_3(x)) # nn.functionalモジュールのtanhは非推奨のため,torchを使用
        return x
generator = Generator()
generator.cuda()  # GPU対応
print(generator)


Discriminatorのモデル構築


PyTorchのnnモジュールを用いて、Discriminatorのモデルを構築します。
Discriminatorでは、畳込み層を3層重ねて画像の特徴を抽出します。
最後の層の活性化関数には、0から1までの値で本物かどうかを識別するためにsigmoid関数を使います。
逆伝播での勾配消失問題に対処するために、活性化関数にLeakyReLUを使用しています。通常のReLUでは負の入力で、0が出力されるため、微分ができず勾配消失に陥る可能性があります。LeakyReLUは負の入力に対し、微小な負の値を出力することができます。微分値が常に0にならないので、勾配消失問題の対処が可能です。


import torch.nn as nn
import torch.nn.functional as F
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__() # 畳み込み層の初期設定
        # 入力画像8x8,カーネル4x4 -> 出力画像 5x5
        self.conv_1 = nn.Conv2d(1, 16,  4)  # 引数(入力のチャンネル数,出力のチャンネル数,カーネルのサイズ)
        # 入力画像5x5,カーネル3x3 -> 出力画像 3x3
        self.conv_2 = nn.Conv2d(16, 32, 3)
        # 入力画像3x3,カーネル3x3 -> 出力画像 1x1
        self.conv_3 = nn.Conv2d(32, 1, 3)
    # Discriminatorの順伝播
    def forward(self, x):
        x = x.view(-1, 1, img_size, img_size)  # 画像の形状に整形 引数(バッチサイズ, チャンネル数, 高さ, 幅)
        x = F.leaky_relu(self.conv_1(x), negative_slope=0.2)# LeakyRelu,negative_slopeは負の領域での傾き
        x = F.leaky_relu(self.conv_2(x), negative_slope=0.2)
        x = torch.sigmoid(self.conv_3(x)) # nn.functionalモジュールのsigmoidは非推奨のため,torchを使用
        x = x.view(-1, 1)  # 引数(バッチサイズ(自動), 出力の数)
        return x
discriminator = Discriminator()
discriminator.cuda()  # GPU対応
print(discriminator)


画像の生成・表示


画像を生成して表示するための関数を定義します。
画像は、学習済みのGenertorにノイズを入力することで生成されます。
画像は16×16枚生成されますが、並べて一枚の画像にした上で表示されます。


# 画像を生成して表示
def generate_images(i):
    # 画像の生成
    n_rows = 16  # 行数
    n_cols = 16  # 列数
    noise = torch.randn(n_rows * n_cols, n_noise).cuda() # 正規分布に従った乱数を生成
    g_imgs = generator(noise)
    g_imgs = g_imgs/2 + 0.5  # 0-1の範囲にする(元のtanhが-1から1の範囲であり、nupmyの画像表示するため)
    g_imgs = g_imgs.cpu().detach().numpy()
    img_size_spaced = img_size + 2
    matrix_image = np.zeros((img_size_spaced*n_rows, img_size_spaced*n_cols))  # 全体の画像
    #  生成された画像を並べて一枚の画像にする
    for r in range(n_rows):
        for c in range(n_cols):
            g_img = g_imgs[r*n_cols + c].reshape(img_size, img_size)
            top = r*img_size_spaced # 画像を配置する位置
            left = c*img_size_spaced # 画像を配置する位置
            matrix_image[top : top+img_size, left : left+img_size] = g_img
    plt.figure(figsize=(8, 8))
    plt.imshow(matrix_image.tolist(), cmap="Greys_r", vmin=0.0, vmax=1.0)
    plt.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)  # 軸目盛りのラベルと線を消す
    plt.show() # 画像の表示


正解数の定義


Discriminatorによる識別の正解数を、カウントする関数を定義します。
Discriminatorの精度の計算に使用します。


def count_correct(y, t):
    correct = torch.sum((torch.where(y<0.5, 0, 1) ==  t).float()) # yが0.5より小さい場合は0
    return correct.item() #torchのテンソルから、pythonのスカラー値に変換


学習の実行


構築したDCGANのモデルを使って、学習を行います。
Generatorが生成した偽物の画像には正解ラベル0、本物の画像には正解ラベル1を与えてDiscriminatorを学習します。その後にGeneratorを学習しますが、この場合の正解ラベルは1になります。
損失関数には、二値の交差エントロピー誤差を使用し、オプティマイザーにはAdamを使用しています。


from torch import optim
# 二値の交差エントロピー誤差関数
loss_func = nn.BCELoss()
# Adam generatorとdiscriminatorで別々のオプティマイザーを使う
optimizer_gen = optim.Adam(generator.parameters())
optimizer_disc = optim.Adam(discriminator.parameters())
# ログ
error_record_fake = []  # 偽物画像の誤差記録
acc_record_fake = []  # 偽物画像の精度記録
error_record_real = []  # 本物画像の誤差記録
acc_record_real = []  # 本物画像の精度記録
# DCGANの学習
generator.train() #generatorの学習モード
discriminator.train() #discrimnatorの学習モード
for i in range(epochs):
    loss_fake = 0 # 偽物を入れたときの誤差
    correct_fake = 0 # 偽物を入れたときの正解数
    loss_real = 0 # 本物を入れたときの誤差
    correct_real = 0 # 本物をいれたときの正解数
    n_total = 0 # データの総数(精度の計算に使用)
    for j, (x,) in enumerate(train_loader):  # ミニバッチ(x,)を取り出す
        n_total += x.size()[0]  # バッチサイズを累積
        # ノイズから画像を生成しDiscriminatorを学習
        noise = torch.randn(x.size()[0], n_noise).cuda()
        imgs_fake = generator(noise)  # 画像の生成
        t = torch.zeros(x.size()[0], 1).cuda()  # 正解は0(偽物が0)
        y = discriminator(imgs_fake) # discriminatorの出力
        loss = loss_func(y, t) # 誤差の計算
        optimizer_disc.zero_grad() # 勾配のリセット
        loss.backward()
        optimizer_disc.step()  # Discriminatorのみパラメータを更新
        loss_fake += loss.item()
        correct_fake += count_correct(y, t)
        # 本物の画像を使ってDiscriminatorを学習
        imgs_real= x.cuda()
        t = torch.ones(x.size()[0], 1).cuda()  # 正解は1(本物が1)
        y = discriminator(imgs_real)
        loss = loss_func(y, t)
        optimizer_disc.zero_grad()
        loss.backward()
        optimizer_disc.step()  # Discriminatorのみパラメータを更新
        loss_real += loss.item()
        correct_real += count_correct(y, t)
        # Generatorを学習
        noise = torch.randn(x.size()[0]*2, n_noise).cuda()  # バッチサイズを2倍にする(discrimnatorは本物と偽物で2回学習しているため)
        imgs_fake = generator(noise)  # 画像の生成
        t = torch.ones(x.size()[0]*2, 1).cuda()  # 正解は1(本物が1)
        y = discriminator(imgs_fake) 
        loss = loss_func(y, t)
        optimizer_gen.zero_grad()
        loss.backward()
        optimizer_gen.step()  # Generatorのみパラメータを更新
    loss_fake /= j+1  # 誤差
    error_record_fake.append(loss_fake)
    acc_fake = correct_fake / n_total  # 精度
    acc_record_fake.append(acc_fake)
    loss_real /= j+1  # 誤差
    error_record_real.append(loss_real)
    acc_real = correct_real / n_total  # 精度
    acc_record_real.append(acc_real)
    # 一定間隔で誤差と精度、および生成された画像を表示
    if i % interval == 0:
        print ("Epochs:", i)
        print ("Error_fake:", loss_fake , "Acc_fake:", acc_fake)
        print ("Error_real:", loss_real , "Acc_real:", acc_real)
        generate_images(i)


下図は、未学習(Epoch:0)時点の出力画像になります。完全なノイズで数字の形をしていません。

未学習(Epoch:0)時点の出力画像


下図は、学習途中(Epoch:20)時点の出力画像になります。若干数字のような形になってきています。

Discriminatorを学習時点の画像


下図は、学習完了(Epoch:200)時点の出力画像になります。正解データに近い画像が生成されています。


Discriminator学習後画像


参考までに正解データはこちらです。

Generator正解データ


誤差と正解率の推移


学習中における、誤差と正解率の推移を確認します。
Discriminatorに本物画像を識別した際の誤差の推移と、偽物画像の識別した際の誤差の推移をグラフに表示します。併せて正解率の推移も表示します。


# 誤差の推移
plt.plot(range(len(error_record_fake)), error_record_fake, label="Error_fake")
plt.plot(range(len(error_record_real)), error_record_real, label="Error_real")
plt.legend()
plt.xlabel("Epochs")
plt.ylabel("Error")
plt.show()
# 正解率の推移
plt.plot(range(len(acc_record_fake)), acc_record_fake, label="Acc_fake")
plt.plot(range(len(acc_record_real)), acc_record_real, label="Acc_real")
plt.legend()
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.show()


誤差の推移

Discriminator使用後の偽物と本物の誤差の推移グラフ


正解率の推移

Discriminatorに本物画像に識別した際の正解率の推移


DCGANで画像生成をして、GeneratorとDiscriminatorが競合するように学習し、その結果生じた均衡のなかで、本物らしい画像が形作られていくことが確認できました。
本環境には、GPUSOROBANのインスタンスを使用しました。
GPUSOROBANは高性能なGPUインスタンスが低コストで使えるクラウドサービスです。
サービスについて詳しく知りたい方は、GPUSOROBANの公式サイトを御覧ください。




MORE INFORMATION

GPUでお困りの方はGPUSOROBANで解決!
お気軽にご相談ください

10日間無料トライアル
詳しい資料はこちら
質問・相談はこちら
ページトップへ戻る