简单使用PyTorch搭建GAN模型

2014年,Goodfellow等人则提出生成对抗收集(Generative Adversarial Network, GAN),可以让我们完全依靠呆板进修来生成极为逼真的图片。GAN的横空出世使得整个人工智能行业都为之震动,计算机视觉和图象生成领域发生了巨变。本文将带大家了解GAN的工作原理,并介绍如何通过PyTorch简单上手GAN。

作者|Ta-Ying Cheng,牛津大学博士研究生,Medium技术博主,多篇文章均被平台官方刊物Towards Data Science收录

翻译|颂贤

以往人们普遍认为生成图象是不可能完成的任务,因为按照传统的呆板进修思路,我们根本没有真值(ground truth)可以拿来检验生成的图象是否合格。

2014年,Goodfellow等人则提出生成对抗收集(Generative Adversarial Network, GAN),可以让我们完全依靠呆板进修来生成极为逼真的图片。GAN的横空出世使得整个人工智能行业都为之震动,计算机视觉和图象生成领域发生了巨变。

本文将带大家了解GAN的工作原理,并介绍如何通过PyTorch简单上手GAN

GAN的原理

按照传统的方法,模型的预测结果可以直接与已有的真值进行比较。然而,我们却很难定义和衡量到底怎样才算作是“正确的”生成图象。

Goodfellow等人则提出了一个有趣的解决办法:我们可以先训练好一个分类工具,来自动区分生成图象和真正图象。这样一来,我们就可以用这个分类工具来训练一个生成收集,直到它可以输入完全以假乱真的图象,连分类工具自己都没有办法评判真假。 简单使用PyTorch搭建GAN模型 按照这一思路,我们便有了GAN:也就是一个生成器(generator)和一个辨别器(discriminator)。生成器负责根据给定的数据集生成图象,辨别器则负责区分图象是真是假。GAN的运作过程如上图所示。

丧失函数

在GAN的运作过程中,我们可以发现一个明显的矛盾:同时优化生成器和辨别器是很困难的。可以想象,这两个模型有着完全相反的宗旨:生成器想要尽可能伪造出真正的东西,而辨别器则必须要识破生成器生成的图象。

为了说明这一点,我们设D(x)为辨别器的输入,即x是真正图象的概率,并设G(z)为生成器的输入。辨别器类似于一种二进制的分类器,所以其宗旨是使该函数的结果最大化:简单使用PyTorch搭建GAN模型 这一函数本质上是非负的二元交叉熵丧失函数。另一方面,生成器的宗旨是最小化辨别器做出正确判断的机率,因此它的宗旨是使上述函数的结果最小化。

因此,最终的丧失函数将会是两个分类器之间的极小极大博弈,表示如下: 简单使用PyTorch搭建GAN模型 理论上来说,博弈的最终结果将是让辨别器判断成功的概率收敛到0.5。然而在实践中,极大极小博弈通常会导致收集不收敛,因此仔细调整模型训练的参数异常重要。

在训练GAN时,我们尤其要注意进修率等超参数,进修率比较小时能让GAN在输入乐音较多的情况下也能有较为统一的输入。

计算环境

本文将指导大家通过PyTorch搭建整个程序(包括torchvision)。同时,我们将会使用Matplotlib来让GAN的生成结果可视化。以下代码可以导入上述所有库:

"""
Import necessary libraries to create a generative adversarial network
The code is mainly developed using the PyTorch library
"""
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms
from model import discriminator, generator
import numpy as np
import matplotlib.pyplot as plt

数据集

数据集对于训练GAN来说异常重要,尤其考虑到我们在GAN中处理的通常是非结构化数据(一般是图片、视频等),任意一class都可以有数据的分布。这种数据分布恰恰是GAN生成输入的基础。

为了更好地演示GAN的搭建过程,本文将带大家使用最简单的MNIST数据集,其中含有6万张手写阿拉伯数字的图片。

像MNIST这样高质量的非结构化数据集都可以在格物钛的公开数据集网站上找到。事实上,格物钛Open Datasets平台涵盖了很多优质的公开数据集,同时也可以实现数据集托管及一站式搜索的功能,这对AI开发者来说,是相当实用的社区平台。 简单使用PyTorch搭建GAN模型

硬件需求

一般来说,虽然可以使用CPU来训练神经收集,但最佳选择其实是GPU,因为这样可以大幅提升训练速度。我们可以用下面的代码来测试自己的呆板能否用GPU来训练:

"""
Determine if any GPUs are available
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

实现

收集结构

由于数字是异常简单的信息,我们可以将辨别器和生成器这两层结构都组建成全连接层(fully connected layers)。

我们可以用以下代码在PyTorch中搭建辨别器和生成器: 

"""
Network Architectures
The following are the discriminator and generator architectures
"""

class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 1)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return nn.Sigmoid()(x)


class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(128, 1024)
        self.fc2 = nn.Linear(1024, 2048)
        self.fc3 = nn.Linear(2048, 784)
        self.activation = nn.ReLU()

def forward(self, x):
    x = self.activation(self.fc1(x))
    x = self.activation(self.fc2(x))
    x = self.fc3(x)
    x = x.view(-1, 1, 28, 28)
    return nn.Tanh()(x)

训练

在训练GAN的时候,我们需要一边优化辨别器,一边改进生成器,因此每次迭代我们都需要同时优化两个互相矛盾的丧失函数。

对于生成器,我们将输入一些随机乐音,让生成器来根据乐音的微小改变输入的图象:

"""
Network training procedure
Every step both the loss for disciminator and generator is updated
Discriminator aims to classify reals and fakes
Generator aims to generate images as realistic as possible
"""
for epoch in range(epochs):
    for idx, (imgs, _) in enumerate(train_loader):
        idx += 1

        # Training the discriminator
        # Real inputs are actual images of the MNIST dataset
        # Fake inputs are from the generator
        # Real inputs should be classified as 1 and fake as 0
        real_inputs = imgs.to(device)
        real_outputs = D(real_inputs)
        real_label = torch.ones(real_inputs.shape[0], 1).to(device)

        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)
        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)

        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)

        D_loss = loss(outputs, targets)
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # Training the generator
        # For generator, goal is to make the discriminator believe everything is 1
        noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
        noise = noise.to(device)

        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
        G_loss = loss(fake_outputs, fake_targets)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if idx % 100 == 0 or idx == len(train_loader):
            print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))

    if (epoch+1) % 10 == 0:
        torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')

结果

经过100个训练时期之后,我们就可以对数据集进行可视化处理,直接看到模型从随机乐音生成的数字: 简单使用PyTorch搭建GAN模型 我们可以看到,生成的结果和真正的数据异常相像。考虑到我们在这里只是搭建了一个异常简单的模型,实际的使用效果会有异常大的上升空间。

不仅是有样学样

GAN和以往呆板视觉专家提出的想法都不一样,而利用GAN进行的具体场景使用更是让许多人赞叹深度收集的无限潜力。下面我们来看一下两个最为出名的GAN延申使用。

CycleGAN

朱俊彦等人2017年发表的CycleGAN可以在没有配对图片的情况下将一张图片从X域直接转换到Y域,比如把马变成斑马、将热夏变成隆冬、把莫奈的画变成梵高的画等等。这些看似天方夜谭的转换CycleGAN都能轻松做到,并且结果异常准确。 简单使用PyTorch搭建GAN模型

GauGAN

英伟达则通过GAN让人们可以只需要寥寥数笔勾勒出自己的想法,便能得到一张极为逼真的真正场景图片。虽然这种使用需要的计算成本极为高昂,但是GauGAN凭借它的转换能力探索出了前所未有的研究和使用领域。

简单使用PyTorch搭建GAN模型

结语

相信看到这里,你已经知道了GAN的大致工作原理,并且可以自己动手简单搭建一个GAN了。

给TA打赏
共{{data.count}}人
人已打赏
AI

产业实践推动科技创新,京东科技集团3篇论文当选ICASSP 2021

2021-8-25 14:47:00

AI

UC伯克利教授Pieter Abbeel开课了:六节课初学「深度加强进修」,讲义免费下载

2021-8-26 14:01:00

0 条回复 A文章作者 M管理员
    暂无讨论,说说你的看法吧
个人中心
今日签到
搜索