Least Squares GAN

Thanks to f-GAN, which established a general framework for GAN training, we have recently seen GAN variants that, unlike the original GAN, minimize distances or divergences other than the Jensen-Shannon divergence (JSD).

One of those modifications is Wasserstein GAN (WGAN), which replaces JSD with the Wasserstein distance. It works very well, and its authors claim that it cures the mode collapse problem and gives GAN a meaningful loss function. Although the implementation is quite straightforward, the theory behind WGAN is heavy, and it requires a “hack”, namely weight clipping. Moreover, training and convergence are slower than with the original GAN.

Now, the question is: could we design a GAN that works well and is faster, simpler, and more intuitive than WGAN?

The answer is yes. What we need is to go back to basics.

Least Squares GAN

The main idea of LSGAN is to use a loss function that provides a smooth and non-saturating gradient in the discriminator \( D \). We want \( D \) to “pull” data generated by the generator \( G \) towards the real data manifold \( P_{data}(X) \), so that \( G \) generates data that are similar to \( P_{data}(X) \).

As we know, in the original GAN \( D \) uses log loss. The decision boundary looks something like this:

[Figure: log loss decision boundary]

\( D \) uses the sigmoid function, which saturates very quickly, so even for a data point \( x \) that is still fairly small, it quickly ignores the distance of \( x \) to the decision boundary \( w \). In other words, log loss essentially does not penalize an \( x \) that lies far away from \( w \) in the manifold: as long as \( x \) is correctly labeled, we’re happy. Consequently, as \( x \) gets bigger and bigger, the gradient of \( D \) quickly goes to \( 0 \), because log loss doesn’t care about the distance, only the sign.

Log loss is therefore not effective for learning the manifold of \( P_{data}(X) \). The generator \( G \) is trained using the gradient of \( D \). If that gradient saturates to \( 0 \), then \( G \) won’t have the information it needs to learn \( P_{data}(X) \).

Enter \( L2 \) loss:

[Figure: L2 decision boundary]

With \( L2 \) loss, data that are far away from \( w \) (in this context, the regression line of \( P_{data}(X) \)) are penalized in proportion to their squared distance, so the gradient grows linearly with the distance. The gradient therefore only becomes \( 0 \) when \( w \) perfectly captures all of \( x \). This guarantees that \( D \) yields informative gradients as long as \( G \) has not captured the data manifold.
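To make the contrast concrete, here is a minimal sketch that compares the two gradients numerically. It assumes \( s \) is the raw output of \( D \) before any sigmoid, and that both losses use the target label 1; the function names are made up for this illustration.

```python
import math


def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))


def log_loss_grad(s):
    # Derivative of -log(sigmoid(s)) with respect to s (target label 1)
    return sigmoid(s) - 1.0


def l2_grad(s, target=1.0):
    # Derivative of 0.5 * (s - target)**2 with respect to s
    return s - target


# As s moves away from the boundary, the log loss gradient vanishes
# while the L2 gradient keeps growing linearly.
for s in [2.0, 5.0, 10.0, 20.0]:
    print(f"s={s:5.1f}  log loss grad={log_loss_grad(s):+.2e}  L2 grad={l2_grad(s):+.2e}")
```

The log loss gradient shrinks exponentially in \( s \), so a sample that is already on the “real” side of the boundary gets almost no signal, no matter how far it is from the target. The \( L2 \) gradient is simply the distance to the target.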

During optimization, the only way for the \( L2 \) loss of \( D \) to be small is for \( G \) to generate \( x \) that are close to \( w \). This way, \( G \) will actually learn to match \( P_{data}(X) \)!

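The overall training objective of LSGAN can then be stated as follows:

\[ \min_D V(D) = \frac{1}{2} \mathbb{E}_{x \sim P_{data}(X)} \left[ (D(x) - b)^2 \right] + \frac{1}{2} \mathbb{E}_{z \sim P_z(Z)} \left[ (D(G(z)) - a)^2 \right] \]
\[ \min_G V(G) = \frac{1}{2} \mathbb{E}_{z \sim P_z(Z)} \left[ (D(G(z)) - c)^2 \right] \]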

Above, we choose \( b = 1 \) as the label for real data and \( a = 0 \) as the label for fake data. Finally, \( c = 1 \), as we want \( G \) to fool \( D \).

Those values are not the only valid ones, though. The authors of LSGAN provide some theory showing that optimizing the above loss is the same as minimizing the Pearson \( \chi^2 \) divergence if \( b - c = 1 \) and \( b - a = 2 \). Hence, choosing \( a = -1, b = 1, c = 0 \) is equally valid.
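As a rough illustration of how \( a \), \( b \), and \( c \) enter the loss, the sketch below evaluates both losses on a few made-up discriminator outputs. It assumes the usual least squares form: \( D \) pushes real outputs towards \( b \) and fake outputs towards \( a \), while \( G \) pushes fake outputs towards \( c \). The helper names and the sample values are hypothetical.

```python
def mean(values):
    values = list(values)
    return sum(values) / len(values)


def d_loss(d_real, d_fake, a, b):
    # D wants real outputs near b and fake outputs near a
    return 0.5 * (mean((d - b) ** 2 for d in d_real) + mean((d - a) ** 2 for d in d_fake))


def g_loss(d_fake, c):
    # G wants D's outputs on fake data near c
    return 0.5 * mean((d - c) ** 2 for d in d_fake)


def matches_pearson_condition(a, b, c):
    # Condition quoted in the text: b - c = 1 and b - a = 2
    return b - c == 1 and b - a == 2


d_real = [0.9, 1.2]   # made-up outputs of D on real data
d_fake = [-0.8, 0.1]  # made-up outputs of D on generated data
for a, b, c in [(0, 1, 1), (-1, 1, 0)]:
    print((a, b, c), d_loss(d_real, d_fake, a, b), g_loss(d_fake, c),
          matches_pearson_condition(a, b, c))
```

Note that the 0/1 labels \( a = 0, b = 1, c = 1 \) do not satisfy the Pearson condition, since \( b - c = 0 \). They are still a valid choice of targets; only the \( \chi^2 \) interpretation depends on that condition.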

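With \( a = 0 \), \( b = 1 \), and \( c = 1 \), our final loss is as follows:

\[ L_D = \frac{1}{2} \mathbb{E}_{x \sim P_{data}(X)} \left[ (D(x) - 1)^2 \right] + \frac{1}{2} \mathbb{E}_{z \sim P_z(Z)} \left[ D(G(z))^2 \right] \]
\[ L_G = \frac{1}{2} \mathbb{E}_{z \sim P_z(Z)} \left[ (D(G(z)) - 1)^2 \right] \]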

LSGAN implementation in PyTorch

Let’s outline the modifications done by LSGAN to the original GAN:

  1. Remove \( \log \) from \( D \)
  2. Use \( L2 \) loss instead of log loss

So let’s begin with the first item on the checklist:

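import torch
import torch.optim as optim

# The post does not state these values; they are assumptions for flattened 28x28 MNIST images
mb_size, z_dim, h_dim, X_dim, lr = 32, 10, 128, 784, 1e-3

# Generator: maps noise z to an image with pixel values in [0, 1]
G = torch.nn.Sequential(
    torch.nn.Linear(z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, X_dim),
    torch.nn.Sigmoid()
)

# Discriminator: outputs an unbounded score
D = torch.nn.Sequential(
    torch.nn.Linear(X_dim, h_dim),
    torch.nn.ReLU(),
    # No sigmoid
    torch.nn.Linear(h_dim, 1),
)

G_solver = optim.Adam(G.parameters(), lr=lr)
D_solver = optim.Adam(D.parameters(), lr=lr)


def reset_grad():
    # Clear the gradients that backward() accumulated in both networks
    G.zero_grad()
    D.zero_grad()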

The rest is straightforward, following the loss function above.

for it in range(1000000):
    # Sample noise and a minibatch of real images
    # (`mnist` is the dataset object loaded in the full code)
    z = torch.randn(mb_size, z_dim)
    X, _ = mnist.train.next_batch(mb_size)
    X = torch.from_numpy(X)

    # Discriminator forward pass
    G_sample = G(z)
    D_real = D(X)
    D_fake = D(G_sample)

    # Discriminator loss: real outputs towards 1, fake outputs towards 0
    D_loss = 0.5 * (torch.mean((D_real - 1)**2) + torch.mean(D_fake**2))

    D_loss.backward()
    D_solver.step()
    reset_grad()

    # Generator forward pass
    G_sample = G(z)
    D_fake = D(G_sample)

    # Generator loss: fake outputs towards 1, to fool D
    G_loss = 0.5 * torch.mean((D_fake - 1)**2)

    G_loss.backward()
    G_solver.step()
    reset_grad()
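For a quick check without downloading MNIST, the same update can be wrapped in a function and run on random data. This is only a sketch under stated assumptions, not the author's code: the toy sizes, the `train_step` name, and the random batch standing in for real images are made up, and `.detach()` is added so the discriminator step does not backpropagate into \( G \).

```python
import torch

torch.manual_seed(0)
mb_size, z_dim, h_dim, X_dim = 8, 4, 16, 12  # toy sizes, chosen arbitrarily

G = torch.nn.Sequential(
    torch.nn.Linear(z_dim, h_dim), torch.nn.ReLU(),
    torch.nn.Linear(h_dim, X_dim), torch.nn.Sigmoid(),
)
D = torch.nn.Sequential(
    torch.nn.Linear(X_dim, h_dim), torch.nn.ReLU(),
    torch.nn.Linear(h_dim, 1),  # no sigmoid: D outputs a raw score
)
G_solver = torch.optim.Adam(G.parameters(), lr=1e-3)
D_solver = torch.optim.Adam(D.parameters(), lr=1e-3)


def train_step(X):
    z = torch.randn(X.size(0), z_dim)

    # Discriminator step: real outputs towards 1, fake outputs towards 0
    D_solver.zero_grad()
    D_loss = 0.5 * (torch.mean((D(X) - 1) ** 2) + torch.mean(D(G(z).detach()) ** 2))
    D_loss.backward()
    D_solver.step()

    # Generator step: fake outputs towards 1
    G_solver.zero_grad()
    G_loss = 0.5 * torch.mean((D(G(z)) - 1) ** 2)
    G_loss.backward()
    G_solver.step()

    return D_loss.item(), G_loss.item()


X = torch.rand(mb_size, X_dim)  # random stand-in for a batch of real images
d_loss, g_loss = train_step(X)
print(d_loss, g_loss)
```

Both losses are means of squares, so they should come out as finite, non-negative numbers. Random data will of course not produce meaningful samples; the point is only to exercise one full update.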

The full code is available at https://github.com/wiseodd/generative-models.

Conclusion

In this post we looked at LSGAN, which modifies the original GAN by using \( L2 \) loss instead of log loss.

We looked at the intuition for why \( L2 \) loss could help a GAN learn the data manifold. We also looked at the intuition for why a GAN could not learn effectively using log loss.

Finally, we implemented LSGAN in PyTorch. We found that the implementation of LSGAN is very simple, amounting to just two line changes.

References

  1. Nowozin, Sebastian, Botond Cseke, and Ryota Tomioka. “f-GAN: Training generative neural samplers using variational divergence minimization.” Advances in Neural Information Processing Systems. 2016.
  2. Mao, Xudong, et al. “Multi-class Generative Adversarial Networks with the L2 Loss Function.” arXiv preprint arXiv:1611.04076 (2016).