GAN学习笔记

1. GAN原理

论文链接:Generative Adversarial Networks

生成式对抗网络(GAN, Generative Adversarial Networks)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。原始 GAN 理论中,并不要求 G 和 D 都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为 G 和 D 。一个优秀的GAN应用需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想。
Ian J. Goodfellow等人于2014年10月在Generative Adversarial Networks中提出了一个通过对抗过程估计生成模型的新框架。框架中同时训练两个模型:捕获数据分布的生成模型G,和估计样本来自训练数据的概率的判别模型D。G的训练程序是将D错误的概率最大化。这个框架对应一个最大值集下限的双方对抗游戏。可以证明在任意函数G和D的空间中,存在唯一的解决方案,使得G重现训练数据分布,而D=0.5。在G和D由多层感知器定义的情况下,整个系统可以用反向传播进行训练。在训练或生成样本期间,不需要任何马尔科夫链或展开的近似推理网络。实验通过对生成的样品的定性和定量评估证明了本框架的潜力。
—— 摘自百度百科

GAN是由两部分组成的,第一部分是生成,第二部分是对抗。简单来说,就是有一个生成网络G和一个判别网络D,通过训练让两个网络相互竞争,生成网络G接受一个随机噪声z来生成假的数据G(z),对抗网络D通过判别器去判别真伪概率,最后希望生成器G生成的数据能够以假乱真。在最理想的状态下,D(G(z)) = 0.5。

以上原理的数学公式为:

式子中,x表示真实数据,z表示噪声,G(z)表示G网络根据z生成的数据,D(x)表示D网络判断真实数据是否为真的概率,因此D(x)接近1越好。而D(G(z))代表D网络判断G网络生成的虚假数据是真实的概率。
因此,对于D网络(辨别器):

  • 如果x来自$P_{data}$,那么D(x)要越大越好,可以用$\log(D(x)) \uparrow$表示。
  • 如果x来自于$P_{generator}$,那么D(G(z))越小越好,进而表示为$\log[1−D(G(z))] \uparrow$。
  • 因此需要最大化$max_D$
    对于G网络(生成器):
  • $D(G(z))$越大越好,进而表示为log[1−D(G(z))]↓
  • 因此需要最小化$min_{G}$。

第一步我们训练D,D是希望V(D,G)越大越好,所以是加上梯度(ascending)。第二步训练G时,V(D,G)越小越好,所以是减去梯度(descending)。整个训练过程交替进行。

2. GAN实例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from torch import nn,optim
import torchvision.transforms as tfs

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

transforms = tfs.Compose([
tfs.Resize((32,32)),
tfs.ToTensor(),
#tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

flat_img = 32*32*3
noise_dim = 100

img = Image.open('1.jpg')
real_img = transforms(img)

torch.manual_seed(2)
fake_img = torch.rand(1,noise_dim)

plt.imshow(np.transpose(real_img.numpy(),(1,2,0)))
#print(real_img)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Sequential(
nn.Linear(flat_img, 1024),
nn.ReLU(),
nn.Linear(1024, 2048),
nn.ReLU(),
nn.Linear(2048, 1),
nn.Sigmoid() #sigmoid常用于二分类问题
)

def forward(self, img):
img = img.view(1, -1)
out = self.linear(img)
return out
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Sequential(
nn.Linear(noise_dim, 1024),
nn.LeakyReLU(),
nn.Linear(1024, 2048),
nn.LeakyReLU(),
nn.Linear(2048, flat_img)
)

def forward(self, latent_space):
latent_space = latent_space.view(1, -1)
out = self.linear(latent_space)
return out
1
2
3
4
5
6
7
8
9
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

discr = Discriminator().to(device)
gen = Generator().to(device)

opt_d = optim.SGD(discr.parameters(), lr=0.001, momentum=0.9)
opt_g = optim.SGD(gen.parameters(), lr=0.001, momentum=0.9)

criterion = nn.BCELoss()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
epochs = 200
discr_e = 4
gen_e = 4

#whole model training starts here
for epoch in range(epochs):

#discriminator training
for k in range(discr_e):
out_d1 = discr(real_img.to(device))
#loss for real image
loss_d1 = criterion(out_d1, torch.ones((1, 1)).to(device))

out_d2 = gen(fake_img.to(device)).detach()
#loss for fake image
loss_d2 = criterion(discr(out_d2.to(device)), torch.zeros((1, 1)).to(device))

opt_d.zero_grad()
loss_d = loss_d1+loss_d2
loss_d.backward()
opt_d.step()

#generator training
for i in range(gen_e):
out_g = gen(fake_img.to(device))
#Binary cross entropy loss
loss_g = criterion(discr(out_g.to(device)), torch.ones(1, 1).to(device))
#Loss function in the GAN paper
#[log(1 - D(G(z)))]
#loss_g = torch.log(torch.ones(1, 1).to(device) - (discr(out_g.to(device))))

opt_g.zero_grad()
loss_g.backward()
opt_g.step()

if (epoch+1)%10==0:
print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f}'.format(epoch+1,epochs,loss_d.data.item(),loss_g.data.item()))

out=gen(fake_img.to(device)).detach()
out_score=discr(out_g.to(device))
loss = criterion(out_score, torch.ones(1, 1).to(device))
print("score:",out_score.item(),"loss:",loss.item())

out=out.reshape((3,32,32)).cpu()

#print(out)
plt.subplot(1,2,1)
plt.title('fake')
plt.imshow(np.transpose(out.numpy(),(1,2,0)))
plt.subplot(1,2,2)
plt.title('real')
plt.imshow(np.transpose(real_img.numpy(),(1,2,0)))

3. DCGAN原理

https://arxiv.org/pdf/1511.06434.pdf

DCGAN的原理和GAN是一样的。只不过DCGANs体系结构有所改变:

  • 使用指定步长的卷积层代替池化层
  • 在生成器和鉴别器中使用batch norm。
  • 移除全连接层,以实现更深层次的体系结构,减少参数。
  • 在生成器中使用ReLU激活,但输出使用Tanh。
  • 在鉴别器中使用LeakyReLU激活

DCGAN中提到了网络的训练细节:

  • 使用Adam算法更新参数,betas=(0.5, 0.999);
  • batch size选为128;
  • 权重使用正太分布,均值为0,标准差为0.02;
  • 学习率0.0002。

4. DCGAN实例

生成动漫头像,数据集来自https://www.kaggle.com/soumikrakshit/anime-faces

1
2
3
4
5
6
7
8
9
10
11
import os
import numpy as np
import imageio
from tqdm.auto import tqdm
import torch,torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt

avatar_img_path = "E:/python/dataset/anime face/data"
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
trans = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

noise_dim = 100
batch_size = 16
beta1=0.5
'''
#自定义数据集
file_train=[]
for image_name in tqdm(os.listdir(avatar_img_path)):
file_train.append(os.path.join(avatar_img_path,image_name))

def default_loader(path):
img = imageio.imread(path)
img = img/255
img = trans(img)
return img

class trainset(Dataset):
def __init__(self, loader=default_loader):
#定义好 image 的路径
self.images = file_train
self.target = 0
self.loader = loader

def __getitem__(self, index):
fn = self.images[index]
img = self.loader(fn)
target = self.target
return img,target

def __len__(self):
return len(self.images)
'''
img_dataset=torchvision.datasets.ImageFolder("E:/python/dataset/anime face", transform=trans)
#img_dataset=trainset()
img_dataloader=DataLoader(img_dataset,batch_size=batch_size,shuffle=True)
#print(img_dataset)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Generator(nn.Module):
def __init__(self, z_dim):
super(Generator,self).__init__()
self.z_dim = z_dim
self.generator = nn.Sequential(
nn.ConvTranspose2d(self.z_dim,512,4,1,0,bias=False),
nn.BatchNorm2d(num_features=512),
nn.ReLU(True),
nn.ConvTranspose2d(512,256,4,2,1,bias=False),
nn.BatchNorm2d(num_features=256),
nn.ReLU(True),
nn.ConvTranspose2d(256,128,4,2,1,bias=False),
nn.BatchNorm2d(num_features=128),
nn.ReLU(True),
nn.ConvTranspose2d(128,64,4,2,1,bias=False),
nn.BatchNorm2d(num_features=64),
nn.ReLU(True),
nn.ConvTranspose2d(64,3,4,2,1,bias=False),
nn.Tanh()
)
self.weight_init()

def weight_init(self):
for m in self.generator.modules():
if isinstance(m, nn.ConvTranspose2d):
nn.init.normal_(m.weight.data, 0, 0.02)

elif isinstance(m, nn.BatchNorm2d):
nn.init.normal_(m.weight.data, 0, 0.02)
nn.init.constant_(m.bias.data, 0)

def forward(self, x):
out = self.generator(x)
return out
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class Discriminator(nn.Module):
def __init__(self):
"""
initialize

:param image_size: tuple (3, h, w)
"""
super(Discriminator,self).__init__()
self.discriminator = nn.Sequential(
nn.Conv2d(3,64,4,2,1,bias=False),
nn.BatchNorm2d(num_features=64),
nn.LeakyReLU(0.2),
nn.Conv2d(64,128,4,2,1,bias=False),
nn.BatchNorm2d(num_features=128),
nn.LeakyReLU(0.2),
nn.Conv2d(128,256,4,2,1,bias=False),
nn.BatchNorm2d(num_features=256),
nn.LeakyReLU(0.2),
nn.Conv2d(256,512,4,2,1,bias=False),
nn.BatchNorm2d(num_features=512),
nn.LeakyReLU(0.2),
nn.Conv2d(512,1,4,2,0,bias=False),
nn.Sigmoid()
)
self.weight_init()

def weight_init(self):
for m in self.discriminator.modules():
if isinstance(m, nn.ConvTranspose2d):
nn.init.normal_(m.weight.data, 0, 0.02)

elif isinstance(m, nn.BatchNorm2d):
nn.init.normal_(m.weight.data, 0, 0.02)
nn.init.constant_(m.bias.data, 0)

def forward(self, x):
out = self.discriminator(x)
out = out.view(x.size(0), -1)
return out
1
2
3
4
5
6
7
8
9
device = torch.device("cuda:0" if torch.cuda.is_available else "cpu")

generator = Generator(noise_dim).to(device)
discriminator = Discriminator().to(device)

bce_loss = nn.BCELoss()
#optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(beta1, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.00005, betas=(beta1, 0.999))
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(beta1, 0.999))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
epochs=20

fixed_z=torch.randn(batch_size,noise_dim,1,1,device=device)
for epoch in range(epochs):
for step,(image,_) in enumerate(img_dataloader):
batch_size=image.size(0)
#=====训练辨别器====
optimizer_D.zero_grad()
# 计算判别器对真实样本给出为真的概率
d_out_real = discriminator(image.type(torch.FloatTensor).to(device))
real_loss = bce_loss(d_out_real, torch.ones(size=(batch_size, 1)).to(device))
real_scores = d_out_real
#real_loss.backward()
# 计算判别器对假样本给出为真的概率
noise = torch.randn(batch_size,noise_dim,1,1,device=device)
fake_img = generator(noise)
d_out_fake = discriminator(fake_img.detach())
fake_loss = bce_loss(d_out_fake, torch.zeros(size=(batch_size, 1)).to(device))
fake_scores = d_out_fake
#fake_loss.backward()
# 更新判别器参数
d_loss = (real_loss + fake_loss)/2
d_loss.backward()
optimizer_D.step()

#=====训练生成器====
optimizer_G.zero_grad()
# 计算判别器对伪造样本的输出的为真样本的概率值
d_out_fake = discriminator(fake_img)
# 计算生成器伪造样本不被认为是真的损失
g_loss = bce_loss(d_out_fake, torch.ones(size=(batch_size, 1)).to(device))
# 更新生成器
g_loss.backward()
optimizer_G.step()

# #################################################
# 4:打印损失,保存图片
if step % 200 == 0:
generator.eval()
fixed_image = generator(fixed_z)
generator.train()
print("[epoch: {}/{}], [iter: {}], [G loss: {:.3f}], [D loss: {:.3f}], [R Score: {:.3f}], [F Score: {:.3f}]".format(epoch+1,epochs,step, g_loss.item(), d_loss.item(),real_scores.data.mean(), fake_scores.data.mean()))
utils.save_image(fixed_image.detach(), str(epoch+1)+"fake.jpg",normalize=True)
utils.save_image(image,str(epoch+1)+"real.jpg",normalize=True)

结果如下:
[epoch: 1/20], [iter: 0], [G loss: 0.699], [D loss: 0.694], [R Score: 0.499], [F Score: 0.500]
[epoch: 1/20], [iter: 200], [G loss: 0.803], [D loss: 0.715], [R Score: 0.512], [F Score: 0.529]
[epoch: 1/20], [iter: 400], [G loss: 0.734], [D loss: 0.692], [R Score: 0.492], [F Score: 0.491]
[epoch: 1/20], [iter: 600], [G loss: 0.730], [D loss: 0.693], [R Score: 0.496], [F Score: 0.496]
[epoch: 1/20], [iter: 800], [G loss: 0.748], [D loss: 0.686], [R Score: 0.500], [F Score: 0.492]
[epoch: 1/20], [iter: 1000], [G loss: 0.745], [D loss: 0.680], [R Score: 0.514], [F Score: 0.499]
[epoch: 1/20], [iter: 1200], [G loss: 0.715], [D loss: 0.701], [R Score: 0.527], [F Score: 0.532]
[epoch: 2/20], [iter: 0], [G loss: 0.762], [D loss: 0.679], [R Score: 0.524], [F Score: 0.508]
[epoch: 2/20], [iter: 200], [G loss: 0.815], [D loss: 0.686], [R Score: 0.507], [F Score: 0.498]
[epoch: 2/20], [iter: 400], [G loss: 0.836], [D loss: 0.665], [R Score: 0.509], [F Score: 0.479]
[epoch: 2/20], [iter: 600], [G loss: 0.759], [D loss: 0.694], [R Score: 0.523], [F Score: 0.520]
[epoch: 2/20], [iter: 800], [G loss: 0.973], [D loss: 0.646], [R Score: 0.551], [F Score: 0.499]
[epoch: 2/20], [iter: 1000], [G loss: 0.926], [D loss: 0.671], [R Score: 0.531], [F Score: 0.495]
[epoch: 2/20], [iter: 1200], [G loss: 1.100], [D loss: 0.582], [R Score: 0.497], [F Score: 0.362]

第7个epoch:

batch_size以及其他参数可自行调整。

5. WGAN原理

论文:Wasserstein GAN
Towards Principled Methods for Training Generative Adversarial Networks

总所周知,GAN的训练存在很多问题和挑战:

  • 训练困难,需要精心设计模型结构,协调G和D的训练程度
  • G和D的损失函数无法指示训练过程,缺乏一个有意义的指标和生成图片的质量相关联
  • 模式崩坏(mode collapse),生成的图片虽然看起来像是真的,但是缺乏多样性

WGAN相比较于传统的GAN,做了如下修改:

  • D最后一层去掉sigmoid
  • G和D的loss不取log
  • 每次更新D的参数后,将其绝对值截断到不超过一个固定常数c
  • 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行

G的损失函数原本为$\mathbb{E} {z \sim p z}[\log(1-D(G(z)))]$ ,其导致的结果是,如果D训练得太好,G将学习不到有效的梯度。但是,如果D训练得不够好,G也学习不到有效的梯度。
因此以上损失函数导致GAN训练特别不稳定,需要小心协调G和D的训练程度。

WGAN参考资料:
https://zhuanlan.zhihu.com/p/44169714
https://www.cnblogs.com/Allen-rg/p/10305125.html