Wasserstein GAN

PAPER
Generative Adversarial Networks
Improved Techniques for Training GANs
Towards Principled Methods for Training Generative Adversarial Networks
Wasserstein Generative Adversarial Networks
Improved Training of Wasserstein GANs

Kullback–Leibler and Jensen–Shannon Divergence

  KL-divergence and JS-divergence are commonly used to quantify how different two probability distributions $p$ and $q$ are. The KL-divergence is formulated as follows.

\[KL(p\Vert q)=\int p(x)\log\frac{p(x)}{q(x)}\text{d}x\]

  In practice, $p$ and $q$ are often assumed to be Gaussian to simplify the calculation. Notably, the KL-divergence is asymmetric: wherever $p(x)$ is close to zero, the integrand contributes almost nothing no matter how poorly $q(x)$ matches, so the mismatch in those regions is essentially ignored. The JS-divergence balances $p$ and $q$ as:

\[JS(p\Vert q)=\frac{1}{2}KL(p\Vert \frac{p+q}{2})+\frac{1}{2}KL(q\Vert\frac{p+q}{2})\]

  The JS-divergence is symmetric, so swapping $p$ and $q$ leaves it unchanged and the resulting signal is more stable. However, both the KL- and the JS-divergence rely on the strong assumption that $p$ and $q$ have non-negligible overlap.
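
  The overlap assumption can be checked numerically. Below is a minimal sketch (the discrete distributions and the eps smoothing constant are made up for illustration): when the supports of $p$ and $q$ are disjoint, the KL-divergence blows up while the JS-divergence saturates at $\log 2$.

import numpy as np

def kl(p, q, eps=1e-12):
    # KL(p || q) for discrete distributions; eps avoids log(0) and division by zero
    return np.sum(p * np.log((p + eps) / (q + eps)))

def js(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Two distributions with disjoint support
p = np.array([0.5, 0.5, 0.0, 0.0])
q = np.array([0.0, 0.0, 0.5, 0.5])

print(kl(p, q))  # huge (on the order of log(1/eps)), i.e. effectively infinite
print(js(p, q))  # ~0.693 = log(2), no matter how far apart the supports are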

Generative Adversarial Network

  A GAN contains two modules: a generator $G$ that synthesizes fake samples close to the real data distribution, and a discriminator $D$ that learns to determine whether a sample comes from $G$ or from the real data distribution $p_r$. During training, $D$ and $G$ play the following two-player minimax game:

\[\begin{aligned} \min_G\max_DL(D,G) &= \mathbb{E}_{x\sim p_r}[\log D(x)]+\mathbb{E}_{z\sim p_z(z)}[\log(1-D(G(z)))] \\ &= \mathbb{E}_{x\sim p_r}[\log D(x)]+\mathbb{E}_{x\sim p_g(x)}[\log(1-D(x))] \end{aligned}\]

  In other words, the generator $G$ is trained to fool the discriminator $D$, while $D$ is trained to tell the real data from the generated samples. The code for training a GAN on the MNIST dataset is as follows.

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hyperparameters (typical values for MNIST)
latent_dim = 100
epochs = 200
batch_size = 64

# MNIST dataloader; images are normalized to [-1, 1] to match the Tanh output of the generator
dataloader = DataLoader(
    datasets.MNIST(".", train=True, download=True,
                   transform=transforms.Compose([transforms.ToTensor(),
                                                 transforms.Normalize([0.5], [0.5])])),
    batch_size=batch_size, shuffle=True)


class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.encoder = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )
        self.img_shape = img_shape

    def forward(self, z):
        # z: [batch_size, latent_dim]
        img = self.encoder(z)

        return img.view(img.shape[0], *self.img_shape)


class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid() # removed in WGAN
        )

    def forward(self, img):
        # img: [batch_size, c * h * w]
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity


# Initialize generator and discriminator
generator = Generator(latent_dim, (1, 28, 28))
discriminator = Discriminator((1, 28, 28))
adversarial_loss = torch.nn.BCELoss()

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Adversarial ground truths
        valid = torch.FloatTensor(imgs.size(0), 1).fill_(1.0)
        fake = torch.FloatTensor(imgs.size(0), 1).fill_(0.0)

        # Train Generator to fool the discriminator
        optimizer_G.zero_grad()
        z = torch.FloatTensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator to classify real from generated samples
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

  We first consider the optimal discriminator $D$ for any given generator $G$. The training criterion for the discriminator is to maximize $L(G,D)$, for which the optimal $D^*$ admits a closed-form solution.

\[L(G,D)=\int \big[p_r(x)\log D(x)+p_g(x)\log(1-D(x))\big]\text{d}x\] \[D^*(x)=\arg\max_{D(x)} L(G,D)=\frac{p_r(x)}{p_r(x)+p_g(x)}\] \[\max_DL(G,D)=\mathbb{E}_{x\sim p_r}[\log D^*(x)]+\mathbb{E}_{x\sim p_g}[\log(1-D^*(x))]\]
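
  For completeness, $D^*$ follows from maximizing the integrand pointwise: for a fixed $x$, setting the derivative with respect to $D(x)$ to zero gives

\[\frac{\partial}{\partial D(x)}\Big[p_r(x)\log D(x)+p_g(x)\log(1-D(x))\Big]=\frac{p_r(x)}{D(x)}-\frac{p_g(x)}{1-D(x)}=0\ \Rightarrow\ D^*(x)=\frac{p_r(x)}{p_r(x)+p_g(x)}\]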

  Given the global optimum $p_g=p_r$, the optimal discriminator becomes $D^*(x)=1/2$ and the minimum of $\max_D L(G,D)$ is $-2\log2$. Rethinking the distance between $p_r$ and $p_g$, we can get

\[JS(p_r\Vert p_g)=\log2+\frac{1}{2}L(G,D^*)\] \[L(G,D^*)=2JS(p_r\Vert p_g)-2\log2\]
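
  This identity can be verified by expanding the JS-divergence with $D^*$ plugged in:

\[\begin{aligned} JS(p_r\Vert p_g) &= \frac{1}{2}\int p_r(x)\log\frac{2p_r(x)}{p_r(x)+p_g(x)}\text{d}x+\frac{1}{2}\int p_g(x)\log\frac{2p_g(x)}{p_r(x)+p_g(x)}\text{d}x \\ &= \log2+\frac{1}{2}\Big(\mathbb{E}_{x\sim p_r}[\log D^*(x)]+\mathbb{E}_{x\sim p_g}[\log(1-D^*(x))]\Big) \\ &= \log2+\frac{1}{2}L(G,D^*) \end{aligned}\]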

  In essence, under the optimal discriminator the GAN loss quantifies the distance between the real data distribution $p_r$ and the generated data distribution $p_g$ via the JS-divergence. Since the JS-divergence is non-negative, $L(G,D^*)$ is bounded below by $-2\log2$, with equality exactly when $p_g=p_r$.

Problems in GAN

  Although GAN has shown significant potential in image generation, its training is notoriously unstable. One possible reason is that the generator and the discriminator are updated simultaneously yet independently; such alternating gradient updates on a minimax objective are not guaranteed to converge.

  The other possible cause is that $p_r$ and $p_g$ lie on low-dimensional manifolds, where they are almost certainly disjoint. In this case a perfect discriminator exists, and its gradient is zero almost everywhere, so once the discriminator approaches this optimum the generator can hardly update due to vanishing gradients.

\[\lim_{\Vert D-D^*\Vert\to0}\nabla_\theta\mathbb{E}_{z\sim p(z)}[\log(1-D(G_\theta(z)))]=0\]

  As the discriminator gets better, the gradient of the generator vanishes and training stalls. A related failure is mode collapse, in which the generator keeps producing the same few outputs regardless of the input noise. See Arjovsky and Bottou for more details.
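
  A toy calculation makes the saturation concrete. For a sigmoid critic output $D=\sigma(s)$ we have $\frac{\partial}{\partial s}\log(1-\sigma(s))=-\sigma(s)$, so the gradient flowing back to the generator shrinks together with $D(G(z))$ as the discriminator grows confident:

import torch

# When the discriminator confidently rejects a fake sample, its logit s is very
# negative, D(G(z)) = sigmoid(s) is close to 0, and the gradient of the generator
# loss log(1 - D(G(z))) with respect to s (and hence to the generator) vanishes.
for s_val in [0.0, -5.0, -20.0]:
    s = torch.tensor(s_val, requires_grad=True)
    loss = torch.log(1 - torch.sigmoid(s))
    loss.backward()
    print(f"logit={s_val:6.1f}  D(G(z))={torch.sigmoid(s).item():.2e}  dloss/ds={s.grad.item():.2e}")

# logit=   0.0  D(G(z))=5.00e-01  dloss/ds=-5.00e-01
# logit=  -5.0  D(G(z))=6.69e-03  dloss/ds=-6.69e-03
# logit= -20.0  D(G(z))=2.06e-09  dloss/ds=-2.06e-09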

Improved Techniques for Training GANs

  (1) Adding noise. Vanishing gradients occur largely because $p_r$ and $p_g$ are disjoint. We can add continuous noise to the inputs of the discriminator, thereby smoothing the distributions and spreading their probability mass so that they overlap.

  (2) Softer metrics of distribution distance. When $p_r$ and $p_g$ are disjoint, the JS-divergence cannot provide a meaningful gradient signal. The Wasserstein metric is introduced to replace the JS-divergence because it remains informative even when the two distributions do not overlap.

  As suggested in Salimans, et al., further improved techniques for training GANs include (3) feature matching, (4) mini-batch discrimination, (5) historical averaging, (6) one-sided label smoothing, and (7) virtual batch normalization. See the original paper for more details; a short sketch of (1) and (6) follows below.
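
  As a concrete illustration of (1) and (6), the discriminator update in the GAN loop above could be modified as follows. This is only a sketch reusing the names from that loop; the noise scale sigma and the 0.9 target are hypothetical hyperparameter choices, not values prescribed by the papers.

# (1) Instance noise: perturb both real and fake inputs with the same continuous
#     noise so that the two distributions effectively overlap
sigma = 0.1  # hypothetical noise scale, often annealed towards 0 during training
real_noisy = imgs + sigma * torch.randn_like(imgs)
fake_noisy = gen_imgs.detach() + sigma * torch.randn_like(gen_imgs)

# (6) One-sided label smoothing: use a soft target (e.g. 0.9) for real samples only
valid_smooth = torch.full((imgs.size(0), 1), 0.9)

real_loss = adversarial_loss(discriminator(real_noisy), valid_smooth)
fake_loss = adversarial_loss(discriminator(fake_noisy), fake)
d_loss = real_loss + fake_loss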

Wasserstein GAN

  Wasserstein distance $W(p,q)$ is the minimum cost of transporting the whole probability mass of $p$ to match the probability mass of $q$, which is defined as

\[W(p,q)=\inf_{\gamma\in\Gamma(p,q)}\mathbb{E}_{(x,y)\sim\gamma}[\Vert x-y\Vert]=\inf_{\gamma\in\Gamma(p,q)}\sum_{x,y}\gamma(x,y)\Vert x-y\Vert\]

where $\inf$ denotes the infimum and $\Gamma(p,q)$ is the set of all joint distributions whose marginals are $p$ and $q$ (the second form is the discrete case). In essence, the Wasserstein distance is the minimum "work" of moving mass, treating $\gamma(x,y)$ as the amount of mass transported from $x$ to $y$ and $\Vert x-y\Vert$ as the distance it travels. Even when two distributions lie on low-dimensional manifolds without overlap, the Wasserstein distance still provides a meaningful, smoothly varying value.
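
  Continuing the earlier discrete example, the Wasserstein distance keeps growing as the two distributions move apart instead of saturating like the JS-divergence. A small sketch using SciPy's 1-D wasserstein_distance (an extra dependency, used here only for illustration):

import numpy as np
from scipy.stats import wasserstein_distance

# Two point masses on the real line: p at 0, q at distance d away.
# JS(p, q) is stuck at log(2) for every d > 0, but W(p, q) grows with d
# and therefore still tells the generator which direction to move.
for d in [0.1, 1.0, 10.0]:
    w = wasserstein_distance([0.0], [d])   # unit mass at 0 vs. unit mass at d
    print(f"d = {d:5.1f}   W(p, q) = {w:.1f}")

# d =   0.1   W(p, q) = 0.1
# d =   1.0   W(p, q) = 1.0
# d =  10.0   W(p, q) = 10.0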

  However, the infimum in $W(p,q)$ is intractable. According to the Kantorovich-Rubinstein duality, we can obtain

\[W(p,q)=\frac{1}{K}\sup_{\Vert f\Vert_L\leq K}\big(\mathbb{E}_{x\sim p}[f(x)]-\mathbb{E}_{x\sim q}[f(x)]\big)\]

where $\sup$ denotes the supremum (the counterpart of the infimum) and the function $f$ is required to be K-Lipschitz continuous. Suppose $f_\omega$ is parameterized by $\omega$; the discriminator (critic) of WGAN is then optimized by

\[L(p_r,p_g)=W(p_r,p_g)=\max_{\omega}\big(\mathbb{E}_{x\sim p_r}[f_\omega(x)]-\mathbb{E}_{z\sim p(z)}[f_\omega(g_\theta(z))]\big)\]

  Now comes the question of maintaining the K-Lipschitz continuity of $f_\omega$ during training. Arjovsky et al. present a simple yet practical trick: clamp the weights $\omega$ to a fixed box such as $[-0.01,0.01]$ after every gradient update, which keeps $\omega$ in a compact space and thus ensures the Lipschitz continuity of $f_\omega$. The specific algorithm and a PyTorch implementation of WGAN are as follows.

[Figure: the WGAN training algorithm from Arjovsky et al.]

# WGAN replaces Adam with RMSprop and updates the generator only every n_critic critic steps.
# Note: the critic is the Discriminator defined above with its final Sigmoid removed.
n_critic = 5
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=5e-5)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=5e-5)

for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        real_imgs = imgs.float()

        # Train Discriminator
        optimizer_D.zero_grad()
        z = torch.FloatTensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))
        fake_imgs = generator(z).detach()
        # Adversarial loss, inverting the sign to find the maximum
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
        loss_D.backward()
        optimizer_D.step()

        # Clip weights of discriminator
        for p in discriminator.parameters():
            p.data.clamp_(-0.01, 0.01)

        # Train the generator every n_critic iterations
        if i % n_critic == 0:
            optimizer_G.zero_grad()
            gen_imgs = generator(z)
            # Adversarial loss
            loss_G = -torch.mean(discriminator(gen_imgs))
            loss_G.backward()
            optimizer_G.step()

  Empirically, the WGAN paper recommends RMSProp or SGD as the optimizer rather than a momentum-based optimizer such as Adam. Gulrajani et al. propose an alternative way to enforce the Lipschitz constraint via a gradient penalty, as follows, where $\hat{x}\sim p_{\hat{x}}$ is sampled uniformly along straight lines between pairs of points drawn from $p_r$ and $p_g$. See the original paper for more details.

\[L=\underbrace{-\mathbb{E}_{x\sim p_r}[D(x)]+\mathbb{E}_{x\sim p_g}[D(x)]}_{\text{Original critic loss}}+\underbrace{\lambda\,\mathbb{E}_{\hat{x}\sim p_{\hat{x}}}[(\Vert\nabla_{\hat{x}}D(\hat{x})\Vert_2-1)^2]}_{\text{Gradient penalty}}\]
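
  A minimal sketch of how this penalty could be computed inside the critic update above, reusing real_imgs, fake_imgs, and discriminator from the WGAN loop; with the gradient penalty, the weight-clipping step would be dropped. The coefficient $\lambda=10$ follows the value suggested by Gulrajani et al.

lambda_gp = 10  # penalty coefficient suggested in the WGAN-GP paper

# Sample points uniformly along straight lines between real and fake samples
alpha = torch.rand(real_imgs.size(0), 1, 1, 1)
interpolates = (alpha * real_imgs + (1 - alpha) * fake_imgs).requires_grad_(True)

# Gradient of the critic output with respect to the interpolated inputs
d_interpolates = discriminator(interpolates)
gradients = torch.autograd.grad(
    outputs=d_interpolates,
    inputs=interpolates,
    grad_outputs=torch.ones_like(d_interpolates),
    create_graph=True,
)[0]

# Penalize deviations of the gradient norm from 1
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

loss_D = (-torch.mean(discriminator(real_imgs))
          + torch.mean(discriminator(fake_imgs))
          + lambda_gp * gradient_penalty)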