简单使用PyTorch搭建GAN模型

2014年,Goodfellow等人则提出天生对抗网络(Generative Adversarial Network, GAN),能够让我们完全依靠机器学习来天生极为逼真的图片。GAN的横空出世使得整个人工智能行业都为之震动,计算机视觉和图象天生领域发生了巨变。本文将带大家了解GAN的工作原理,并介绍如何通过PyTorch简单上手GAN。

作者|Ta-Ying Cheng,牛津大学博士研究生,Medium技术博主,多篇文章均被平台官方刊物Towards Data Science收录

翻译|颂贤

以往人们普遍认为天生图象是不可能完成的任务,因为按照传统的机器学习思路,我们根本没有真值(ground truth)可以拿来检验天生的图象是否合格。

2014年,Goodfellow等人则提出天生对抗网络(Generative Adversarial Network, GAN),能够让我们完全依靠机器学习来天生极为逼真的图片。GAN的横空出世使得整个人工智能行业都为之震动,计算机视觉和图象天生领域发生了巨变。

本文将带大家了解GAN的工作原理,并介绍如何通过PyTorch简单上手GAN

GAN的原理

按照传统的方法,模型的预测结果可以直接与已有的真值进行比较。然而,我们却很难定义和衡量到底怎样才算作是“正确的”天生图象。

Goodfellow等人则提出了一个有趣的解决办法:我们可以先训练好一个分类工具,来自动区分天生图象和真实图象。这样一来,我们就可以用这个分类工具来训练一个天生网络,直到它能够输出完全以假乱真的图象,连分类工具自己都没有办法评判真假。 图 1. GAN的运作流程. 图源作者. 按照这一思路,我们便有了GAN:也就是一个天生器(generator)和一个辨别器(discriminator)。天生器负责根据给定的数据集天生图象,辨别器则负责区分图象是真是假。GAN的运作流程如上图所示。

损失函数

在GAN的运作流程中,我们可以发现一个明显的矛盾:同时优化天生器和辨别器是很困难的。可以想象,这两个模型有着完全相反的目标:天生器想要尽可能伪造出真实的东西,而辨别器则必须要识破天生器天生的图象。

为了说明这一点,我们设D(x)为辨别器的输出,即x是真实图象的概率,并设G(z)为天生器的输出。辨别器类似于一种二进制的分类器,所以其目标是使该函数的结果最大化:请添加图片描述 这一函数本质上是非负的二元交叉熵损失函数。另一方面,天生器的目标是最小化辨别器做出正确判断的机率,因此它的目标是使上述函数的结果最小化。

因此,最终的损失函数将会是两个分类器之间的极小极大博弈,表示如下: 请添加图片描述 理论上来说,博弈的最终结果将是让辨别器判断成功的概率收敛到0.5。然而在实践中,极大极小博弈通常会导致网络不收敛,因此仔细调整模型训练的参数非常重要。

在训练GAN时,我们尤其要注意学习率等超参数,学习率比较小时能让GAN在输入噪音较多的情况下也能有较为统一的输出。

计算环境

本文将指导大家通过PyTorch搭建整个程序(包括torchvision)。同时,我们将会使用Matplotlib来让GAN的天生结果可视化。以下代码能够导入上述所有库:

“””
Import necessary libraries to create a generative adversarial network
The code is mainly developed using the PyTorch library
“””
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms
from model import discriminator, generator
import numpy as np
import matplotlib.pyplot as plt

数据集

数据集对于训练GAN来说非常重要,尤其考虑到我们在GAN中处理的通常是非结构化数据(一般是图片、视频等),任意一class都可以有数据的分布。这种数据分布恰恰是GAN天生输出的基础。

为了更好地演示GAN的搭建流程,本文将带大家使用最简单的MNIST数据集,其中含有6万张手写阿拉伯数字的图片。

像MNIST这样高质量的非结构化数据集都可以在格物钛的公开数据集网站上找到。事实上,格物钛Open Datasets平台涵盖了很多优质的公开数据集,同时也可以实现数据集托管及一站式搜索的功能,这对AI开发者来说,是相当实用的社区平台。 请添加图片描述

硬件需求

一般来说,虽然可以使用CPU来训练神经网络,但最佳选择其实是GPU,因为这样可以大幅提升训练速度。我们可以用下面的代码来测试自己的机器能否用GPU来训练:

“””
Determine if any GPUs are available
“””
device = torch.device(‘cuda’ if torch.cuda.is_available() else ‘cpu’)

实现

网络结构

由于数字是非常简单的信息,我们可以将辨别器和天生器这两层结构都组建成全连接层(fully connected layers)。

我们可以用以下代码在PyTorch中搭建辨别器和天生器: 

“””
Network Architectures
The following are the discriminator and generator architectures
“””

class discriminator(nn.Module):
   def __init__(self):
       super(discriminator, self).__init__()
       self.fc1 = nn.Linear(784, 512)
       self.fc2 = nn.Linear(512, 1)
       self.activation = nn.LeakyReLU(0.1)

   def forward(self, x):
       x = x.view(-1, 784)
       x = self.activation(self.fc1(x))
       x = self.fc2(x)
       return nn.Sigmoid()(x)

class generator(nn.Module):
   def __init__(self):
       super(generator, self).__init__()
       self.fc1 = nn.Linear(128, 1024)
       self.fc2 = nn.Linear(1024, 2048)
       self.fc3 = nn.Linear(2048, 784)
       self.activation = nn.ReLU()

def forward(self, x):
   x = self.activation(self.fc1(x))
   x = self.activation(self.fc2(x))
   x = self.fc3(x)
   x = x.view(-1, 1, 28, 28)
   return nn.Tanh()(x)

训练

在训练GAN的时候,我们需要一边优化辨别器,一边改进天生器,因此每次迭代我们都需要同时优化两个互相矛盾的损失函数。

对于天生器,我们将输入一些随机噪音,让天生器来根据噪音的微小改变输出的图象:

“””
Network training procedure
Every step both the loss for disciminator and generator is updated
Discriminator aims to classify reals and fakes
Generator aims to generate images as realistic as possible
“””
for epoch in range(epochs):
   for idx, (imgs, _) in enumerate(train_loader):
       idx += 1

       # Training the discriminator
       # Real inputs are actual images of the MNIST dataset
       # Fake inputs are from the generator
       # Real inputs should be classified as 1 and fake as 0
       real_inputs = imgs.to(device)
       real_outputs = D(real_inputs)
       real_label = torch.ones(real_inputs.shape[0], 1).to(device)

       noise = (torch.rand(real_inputs.shape[0], 128) – 0.5) / 0.5
       noise = noise.to(device)
       fake_inputs = G(noise)
       fake_outputs = D(fake_inputs)
       fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)

       outputs = torch.cat((real_outputs, fake_outputs), 0)
       targets = torch.cat((real_label, fake_label), 0)

       D_loss = loss(outputs, targets)
       D_optimizer.zero_grad()
       D_loss.backward()
       D_optimizer.step()

       # Training the generator
       # For generator, goal is to make the discriminator believe everything is 1
       noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
       noise = noise.to(device)

       fake_inputs = G(noise)
       fake_outputs = D(fake_inputs)
       fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
       G_loss = loss(fake_outputs, fake_targets)
       G_optimizer.zero_grad()
       G_loss.backward()
       G_optimizer.step()

       if idx % 100 == 0 or idx == len(train_loader):
           print(‘Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}’.format(epoch, idx, D_loss.item(), G_loss.item()))

   if (epoch+1) % 10 == 0:
       torch.save(G, ‘Generator_epoch_{}.pth’.format(epoch))
       print(‘Model saved.’)

结果

经过100个训练时期之后,我们就可以对数据集进行可视化处理,直接看到模型从随机噪音天生的数字: 请添加图片描述 我们可以看到,天生的结果和真实的数据非常相像。考虑到我们在这里只是搭建了一个非常简单的模型,实际的应用效果会有非常大的上升空间。

不仅是有样学样

GAN和以往机器视觉专家提出的想法都不一样,而利用GAN进行的具体场景应用更是让许多人赞叹深度网络的无限潜力。下面我们来看一下两个最为出名的GAN延申应用。

CycleGAN

朱俊彦等人2017年发表的CycleGAN能够在没有配对图片的情况下将一张图片从X域直接转换到Y域,比如把马变成斑马、将热夏变成隆冬、把莫奈的画变成梵高的画等等。这些看似天方夜谭的转换CycleGAN都能轻松做到,并且结果非常准确。 请添加图片描述

GauGAN

英伟达则通过GAN让人们能够只需要寥寥数笔勾勒出自己的想法,便能得到一张极为逼真的真实场景图片。虽然这种应用需要的计算成本极为高昂,但是GauGAN凭借它的转换能力探索出了前所未有的研究和应用领域。

请添加图片描述

结语

相信看到这里,你已经知道了GAN的大致工作原理,并且能够自己动手简单搭建一个GAN了。

原创文章,作者:格物钛Graviti,如若转载,请注明出处:https://www.iaiol.com/news/jian-dan-shi-yong-pytorch-da-jian-gan-mo-xing/

(0)
上一篇 2021年 8月 25日 下午2:47
下一篇 2021年 8月 26日 下午2:01

相关推荐

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注