【第一节】初识GAN,GAN是什么意思?GAN的直觉理解,使用 PyTorch 零基础训练一个自己的GAN
发表于: 2021-02-01 10:12:00 | 已被阅读: 6 | 分类于: GAN
GAN 的全称是“Generative Adversarial Networks”,翻译成中文,意思是“生成式对抗网络”。深度学习领域的同学即使没有深入研究过 GAN,应该大多数也听说过这个概念。GAN 非常有意思,举个狭义的例子,训练好的图像 GAN 看起来就像是一个能够作画的机器人,而且它可以根据不同的输入特征做出不同的画。有不少学者称,生成网络才是真正的人工智能网络,由此可见业界对 GAN 的高度评价。
GAN 的“烂大街”形象解释
这里还是啰嗦一下关于 GAN 的烂大街的形象解释:GAN 的训练过程是一个“对抗”过程,谁在对抗呢?可以想象是画作的鉴别者和作假者在对抗。
刚开始时,鉴别者和作假者的水平都比较低。
于是,鉴别者开始提升自己的鉴别能力,并且能够轻易鉴别出作假者画出来的“假画”。当然,为了让作假者心服口服,鉴别者要把自己发现“假的地方”告诉作假者。作假者知道自己假画的“不足”后,开始改进,提升自己的作假能力。
在作假者提升自己作假能力期间,鉴别者没闲着,也在不断的提升自己鉴别能力。随着时间推移,鉴别者的鉴别能力越来越强,作假者的作假能力也越来越强,直到最后,作假者终于能够作出在一般人看来与真画无异的画。再回到 GAN 网络,此时我们便称 GAN 网络的训练收敛了。
到这里读者可以考虑两个“意外情况”:
- 如果鉴别者只要看到画的作者是作假者,就说它是假画;
- 鉴别者每次告诉作假者“假的地方”时,都撒谎了,瞎说的是别的地方假,但是鉴别者在业界很权威,作假者相信了。
此时,GAN 网络还能收敛吗?
事实上,这两个“意外情况”在实际的 GAN 训练过程中经常出现,那么应该怎么深入理解这种“意外情况”,以及如何解决呢?本系列文章在日后再详谈。
设计 GAN 网络
根据前面一小节 GAN 的形象解释,很容易就可以设计出 GAN 网络。以非常容易获得的 mnist 手写数字数据集为例,本文将使用 PyTorch,从头设计一个 GAN 网络并训练之,让机器学会“写数字”。
mnist 是一组 28x28 的手写数字数据集。
首先,前面提到 GAN 网络一般包括两个网络——鉴别者和作假者,先定义这两个网络:
class Generator(nn.Module):
def __init__(self, nz=16):
super(Generator, self).__init__()
self.nz = nz
self.fc_1 = nn.Linear(self.nz, 28 * 7, bias=False)
self.leaky_ReLU_1 = nn.LeakyReLU(0.2, inplace=True)
self.fc_2 = nn.Linear(28 * 7, 28 * 14, bias=False)
self.leaky_ReLU_2 = nn.LeakyReLU(0.2, inplace=True)
self.fc_3 = nn.Linear(28 * 14, 28 * 28, bias=False)
self.tanh = nn.Tanh()
def forward(self, z):
bs = z.size(0)
z = self.leaky_ReLU_1(self.fc_1(z))
z = self.leaky_ReLU_2(self.fc_2(z))
z = self.tanh(self.fc_3(z))
z = torch.reshape(z, (bs, 1, 28, 28))
return z
这里称“作假者”为Generator
,“鉴别者”为Discriminator
。可以看出 Generator 的结构相当简单——由几个全连接层(fc)构成主干网络,最后通过 reshape
层将网路学习到的特征整形为 1x28x28 的手写数字图片,tanh
层的作用则是将特征缩放到 -1.0 到 1.0 的范围内,以保证生成的图像像素值都在合理范围内,如下图所示:
现在将注意力放到 Generator 的输入上,不难发现它是一组一维向量,这组向量通常从一个随机分布(比如高斯分布)中随机采样。
为了行文方便,下面称这组一维向量为
隐变量
。
当然了,即使隐变量来自于随机数,它也有自己的具体含义:
我们可以将 Generator 的过程反向来看——即输入为 1x28x28 的手写数字图片,输出为一维向量,这就和通常的图像特征提取网络很像了.
不难理解,隐变量其实就是手写数字的抽象低维特征,它包含着手写数字诸如类别、粗细、字体、倾斜程度等特征信息。因此,Generator 其实就是根据输入的隐变量,生成对应特征的手写数字图片。
弄明白 Generator 后,我们接着定义 Discriminator 网络,一般来说,它的复杂度与 Generator 相当,所以 Discriminator 几乎与 Generator 是对称的,如下图所示:
Discriminator 在最后将所有特征提取为 size=1,并且通过 sigmoid
算子将这个特征映射到 0~1.0 范围内,以方便判断输入的图片是“真”还是“假”。通常来说,我们可以认为 Discriminator 最终输出的值越接近 1.0,输入的图片越“真”,否则就越“假”。这就相当于 Discriminator 对输入图片打分,越“真”的图片得分越高。相应的 Python 代码:
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.fc_3 = nn.Linear(28 * 28, 28 * 14, bias=False)
self.leaky_ReLU_2 = nn.LeakyReLU(0.2, inplace=True)
self.fc_2 = nn.Linear(28 * 14, 28 * 7, bias=False)
self.leaky_ReLU_1 = nn.LeakyReLU(0.2, inplace=True)
self.fc_1 = nn.Linear(28 * 7, 1, bias=False)
def forward(self, x):
bs = x.size(0)
x = x.view(bs, -1)
x = self.leaky_ReLU_2(self.fc_3(x))
x = self.leaky_ReLU_1(self.fc_2(x))
x = torch.sigmoid(self.fc_1(x))
return x
准备 MNIST 数据集
在 PyTorch 中使用 MNIST 数据集是非常方便的,使用下面的 Python 代码,可以自动从网络下载到:
class GanDataSet(torch_data.Dataset):
def __init__(self, nz=16):
super(GanDataSet, self).__init__()
if not os.path.exists('./tmp'):
os.makedirs('./tmp')
self.mnist = datasets.MNIST(root='./tmp', train=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]), download=True)
self.nz = nz
def __len__(self):
return self.mnist.__len__()
def __getitem__(self, item):
real = self.mnist.__getitem__(item)[0] # item consist of [data, label]
z = torch.randn(self.nz).float()
return real, z
为了方便,数据集的每一个 item 由真实的手写数字图片和隐变量组成,隐变量随机从 randn 分布中采样,随后输送给 Generator,用于生成“假的”手写数字图片。
训练 GAN 网络
GAN 的训练过程其实就是 Generator 和 Discriminator 对抗的过程,Generator 需要不断提升自己的生成能力,Discriminator 需要不断提升自己的鉴别能力。这个过程反映到工程上倒是非常简单了,无非就是交替训练 Generator 和 Discriminator 而已。
设真实手写数字图片样本 \( x \sim p(x) \),隐变量 \( z \sim q(z) \),将 Generator 网络看作函数 \( G \),则 \( G(z) \) 即为生成的“假”样本,同样也将 Discriminator 网络看作函数 \( D \),则 \( D(x) \) 和 \( D(G(z)) \) 分别表示真实样本和假样本的得分值。
训练 Discriminator
在训练判别器 Discriminator 时,为了提升 Discriminator 的判别能力,我们需要站在 Discriminator 的角度,自然希望它对真实样本打的分越高越好,对假样本打的分越低越好,即期望下面这个目标
越大越好:
\( $\underset{D}{argmax}{\mathbb{E}_{x \sim p(x)), z \sim q(z))}}log(D(x))-log(D(G(z))) \)$
注意,训练 Discriminator 时,我们并不需要作假者 Generator 的其他信息(梯度),只关注它生成的假样本本身。至此,便不难写出对应的 Discriminator 训练代码,请看:
def train_discriminator(real_x, fake_z):
for p in G.parameters():
p.requires_grad = False
for p in D.parameters():
p.requires_grad = True
with torch.no_grad():
fake_x = G(fake_z)
real_conf, fake_conf = D(real_x), D(fake_x)
loss = (torch.log(fake_conf + 1e-6) - torch.log(real_conf + 1e-6)).mean()
D_optimizer.zero_grad()
loss.backward()
D_optimizer.step()
return loss
还需要注意的是,鉴于 PyTorch 优化器默认最小化 loss,而我们期望能够最大化上述目标
,因此上述 Python 代码中的 loss 其实是期望目标
的相反数。
训练 Generator
训练作假者 Generator 时,自然是要站在作假者的角度,我们希望生成的假样本越“真”越好,也即“假”样本在鉴别者 Discriminator 那边的得分值越高越好:
\( $\underset{G}{argmax}{\mathbb{E}_{z \sim q(z))}}log(D(G(z))) \)$
不难写出相应的 Python 代码,请看:
def train_generator(real_x, fake_z):
for p in D.parameters():
p.requires_grad = False
for p in G.parameters():
p.requires_grad = True
fake_x = G(fake_z)
fake_conf = D(fake_x)
loss = -torch.log(fake_conf + 1e-6).mean()
G_optimizer.zero_grad()
loss.backward()
G_optimizer.step()
return loss
至此,鉴别者 Discriminator 和作假者 Generator 的训练代码就完成了,结合 MNIST 数据集代码,将整个训练过程串起来即可完成生成模型,完整的代码可以参考:
todo:
add link.
训练结果
我们已经知道,GAN 的训练过程是一个对抗过程,在刚开始时,鉴别者 Discriminator 和作假者 Generator 的水平都很差,生成的假样本明显很假,请看:
上图为训练 1 个 epoch 后的生成“假”样本,虽然有一点手写数字的感觉,但和真实样本相比有着明显的区别。不过,随着训练的进行,Generator 生成的“假”样本将会越来越真实,下图是训练 50 个 epoch 后的生成“假”样本:
可见,随着训练的进行,Generator 生成的“假”样本质量还是明显变好了的。