본문 바로가기
AI/기술

GAN[Generative Adversarial Network] 구현하기 w/ Pytorch

by ai.forme 2020. 11. 30.
반응형

2014년에 발표된 GAN은 이미 너무나 많은 갈래가 생겼습니다. 한 갈래인 CycleGAN에 대해 알게 되었는데, 정말 신기하다는 생각이 들었습니다.

그래서 이왕 궁금한 거, GAN부터 구현해보는 생각을 하게 되었습니다.

 

과제가 쌓여있지만... 일단... 하고 싶은 것부터 하련다..

 

구현 코드는Pytorch-GAN repo를 참고하였습니다.

GAN의 논문은 누구나 읽어볼 가치가 있습니다!

 

 

어떤 것을 만들어볼까?


누구나 하는 MNIST로 구현하자기엔 너무 지겨워, kaggle에서 구한 데이터들로 진행하였습니다.

총 8가지로 비행기, 고양이, 차, 강아지, 꽃, 과일, 오토바이, 사람입니다.

parser.add_argument("--object", type=str, default='person', help='which object to generate')

이 부분에서 object를 쉽게 바꿀 수 있게 설정해놓았습니다. 사실 차원을 추가하여 한 번에 모든 종류의 사진들이 생성되도록 하고 싶었는데,

Training Set의 개수가 적어 학습이 잘 이루어지지 않았습니다.

 

구조


 

  • data : 이미지들이 들어갈 폴더 (총 8가지 내부 폴더로 구성)
  • images : 생성된 이미지들이 저장될 폴더 (코드를 돌렸을 때 폴더가 생성되어 그 안에 저장)
  • trained_models : 훈련된 모델들이 저장될 폴더 (코드를 돌렸을 때 폴더가 생성되어 그 안에 저장)
  • models : 모델들의 Class 파일이 저장될 폴더 (Generator, Discriminator)
  • utils : Training을 진행하는데 필요한 함수들이 저장될 폴더
  • main.py 실행 파일

Hyperparameter


config.yaml 파일을 만드려다가 좀 더 친숙한 argparse를 사용하기로 결정하였습니다.

실제로 직관적이기 때문에 사용하기 편하다고 생각합니다.

# Parsing Arguments
def parse_args():

    parser = argparse.ArgumentParser()
    
    parser.add_argument("--n_epochs", type=int, default=500, help="number of epochs for training")
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=5e-4, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=48, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=3, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
    parser.add_argument("--seed", type=int, default=777, help="seed number")
    parser.add_argument("--object", type=str, default='person', help='which object to generate')
    
    args = parser.parse_args()
    
    return args
    

 

Datasets & Dataloader


Data 폴더 구조

각각의 종류에 해당하는 이미지들은 imageSequence 폴더 안에 들어있습니다.

Dataloader은 torchvision의 dataset와 torch의 DataLoader을 사용하였습니다.

 

# torchvision
import torchvision
import torchvision.transforms as transforms

# torch
import torch
from torch.utils.data import DataLoader
from torchvision import datasets

def set_dataloader(args):

    dataset = datasets.ImageFolder(root='./data/{}'.format(args.object),
                                   transform=transforms.Compose([
                                   transforms.Resize((args.img_size, args.img_size)),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5], [0.5])
                                   ]))

    dataloader = DataLoader(dataset=dataset,
                            batch_size=args.batch_size,
                            shuffle=True
                            )
    
    return dataloader
    

 

Models (Generator & Discriminator)


GAN에서는 두 가지 모델이 사용됩니다.

 

Generator은 입력 벡터로부터 이미지를 생성해냅니다. 작명 참 잘했습니다.

Discriminator는 입력된 이미지가 진짜(Real)인지 , 가짜(Fake)인지 구별해냅니다.

 

1. Generator는 Discriminator을 속이도록

2. Discriminator은 Generator을 올바르게 판단하도록

 

GAN의 학습이 이루어집니다. 이런 일련의 과정을 반복하면서

Generator은 더욱더 진짜 같은 이미지를 생성하고, Discriminator은 판단력이 상승하게 됩니다.

 

경찰(Discriminator)과 위조지폐범(Generator)으로 많이 비유됩니다.

위조지폐범은 더욱 진짜 같은 지폐를 만들고, 경찰은 이를 더욱 잘 판별하고.

참 적절한 비유라고 생각합니다! (GAN을 사용하여 정말 진짜 같은 위조지폐를 만들 수 있을지도?)

 

Generator

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F

# numpy
import numpy as np

class Generator(nn.Module):

    def __init__(self, latent_dim, image_shape):
        super(Generator, self).__init__()

        self.latent_dim = latent_dim
        self.image_shape = image_shape

        def block(input_fea, output_fea, normalize=True):
            layers =[nn.Linear(input_fea, output_fea)]
            if normalize:
                layers.append(nn.BatchNorm1d(output_fea, 0.5))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(image_shape))),
            nn.Tanh()
        )
    
    def forward(self, z):
        image = self.model(z)
        image = image.view(image.size(0), *self.image_shape)
        return image
    

 

Discriminator

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F

# numpy
import numpy as np

class Discriminator(nn.Module):

    def __init__(self, image_shape):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(image_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, image):
        image_flat = image.view(image.size(0), -1)
        validity = self.model(image_flat)
        return validity
        

 

 

Seed & Device & Loss Function & Optimizer


Seed를 설정 안 하여 매번 돌릴 때마다 결과가 달라져 곤혹을 한 두 번 겪고 나니 항상 seed를 설정하게 되었습니다.

실제로 설정 안 할 경우 불편합니다..

 

 Seed

import random
import numpy as np

# torch
import torch

def set_seed(args) :
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    

 

Device & Loss Function & Optimizer

def main():
    
    # Set seed & dataloader
    set_seed(args)
    dataloader = set_dataloader(args)

    # Set the shape of image [3, 64, 64]
    image_shape = (args.channels, args.img_size, args.img_size)

    # Set device where training will be progressed & base Tensor type (dtype=float32)
    if torch.cuda.is_available():
        device = torch.device("cuda: 0")
        Tensor = torch.cuda.FloatTensor
    else:
        device = torch.device("cpu") 
        Tensor = torch.FloatTensor

    print('Current Device : {} \t Base Tensor : {}'.format(device, Tensor))

    # Initialize Loss Function & Models
    criterion = torch.nn.BCELoss().to(device)
    generator = Generator(args.latent_dim, image_shape).to(device)
    discriminator = Discriminator(image_shape).to(device)

    # Initialize Optimizer for Generator & Discriminator
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    

 

Training


대망의 학습 함수 차례입니다.

사용되는 모델이 2개라, Optimizer가 2개인 것을 제외하고는 특별한 것이 없습니다.

다만 기존 ML 모델 학습과는 다르게, Eval의 과정이 없고 Train의 과정만 존재합니다.

 

GAN 모델의 Evaluation 방법은 여러 논문을 통해 발표되었지만, 아직 잘 모르니까

그냥 눈으로 판별합시다! 하핫.. 공부해서 돌아오겠습니다.

import os
import numpy as np

# torch
import torch
from torch.autograd import Variable

# torchvision
import torchvision
from torchvision.utils import save_image

# datetime
from datetime import datetime

def train(args, device, dataloader, criterion, generator, discriminator, optimizer_G, optimizer_D, Tensor):

    experiment_time = datetime.today().strftime("%Y%m%d_%H_%M")
    result_dir = 'images/{}'.format(experiment_time)
    model_dir = 'trained_models/{}'.format(experiment_time)

    os.makedirs(result_dir, exist_ok=False)
    os.makedirs(model_dir, exist_ok=False)

    for epoch in range(args.n_epochs):
        for idx, data in enumerate(dataloader):
            images, labels = data[0].to(device), data[1].to(device)

            # Adversarial ground truths
            valid = Variable(Tensor(images.size(0), 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(images.size(0), 1).fill_(0.0), requires_grad=False)

            # Configure Input
            real_images = Variable(images.type(Tensor))

            ###################
            # Train Generator #
            ###################

            optimizer_G.zero_grad()

            # Sampel noise as generator input
            z = Variable(Tensor(np.random.normal(0, 1, (images.size(0), args.latent_dim))))

            # Generate a batch of images
            gen_images = generator(z)

            # Loss measures generator's ability to fool the discriminator
            loss_G = criterion(discriminator(gen_images), valid)

            loss_G.backward()
            optimizer_G.step()

            #######################
            # Train Discriminator #
            #######################

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            loss_real = criterion(discriminator(real_images), valid)
            loss_fake = criterion(discriminator(gen_images.detach()), fake)
            loss_D = (loss_real + loss_fake) / 2

            loss_D.backward()
            optimizer_D.step()

            if idx%20 == 0 :
                print('[Epoch {:d}/{:d}] \t [Batch {:d}/{:d}] \t [Loss_G : {:.4f}] \t [Loss_D : {:0.4f}]'
                    .format(epoch, args.n_epochs, idx, len(dataloader), loss_G.item(), loss_D.item()))

            batches_done = epoch * len(dataloader) + idx

            if batches_done % args.sample_interval == 0:
                print('Save sample Image')
                save_image(gen_images.data[:25], '{}/{:d}.png'.format(result_dir, batches_done), nrow=5, normalize=True)
    
    print('Everything Done.. Saving Model')

    # Setting the Path to save model
    PATH_G = model_dir + '/generator.pth'
    PATH_D = model_dir + '/discriminator.pth'

    # Save Both Generator & Discriminator
    torch.save(generator.state_dict(), PATH_G)
    torch.save(discriminator.state_dict(), PATH_D)

 

이상으로 Pytorch를 사용한 GAN 구현이 끝났다.

 

Result


600 Epoch을 돌려서 만들어진 결과물입니다.

 

흠.. 뭔가 아직 학습이 많이 부족해 보이네요.

더 많은 Data를 확보하고 Fine Tuning 한 후에 다시 실험해야 될 것 같습니다.

어쨌든.. 사람 같아 보이기는 하잖아요..?

 

최근의 GAN은 정말 성능이 좋습니다. (사람의 눈으로는 구별 불가능할 정도..)

이것은 정말 애기의 장난 수준입니다! 그것도 안될지도..

반응형