(1) I am working on a class project where I compare the performance of a GAN and a WGAN. Since the only difference between GAN and WGAN is the Wasserstein loss, I chose one neural-network architecture and trained both with it (so only the loss functions differ).
However, the WGAN performs much worse than the GAN, and I’m not sure why. Is the performance of the Wasserstein loss model-dependent? If I had to compare GAN and WGAN while holding the architecture fixed, what architecture should I choose?
Usually, the same architecture and parameters would not be good for training both GAN and WGAN.
In a typical GAN, you want to avoid making the discriminator more powerful than the generator, and you want to avoid training the discriminator so much that it “overpowers” the generator and always spots the fakes.
In WGAN, you want to make the discriminator (the critic) as powerful as possible, possibly by giving it a larger network, and you also want to train it for as long as computationally feasible — several iterations for every one iteration the generator trains. The theory behind WGAN requires the critic to have converged to the optimal discriminating function at each generator step, so this is important.
If for some reason, you really need to fix one architecture, choose one where the generator is about the same size as the discriminator, and then make sure when you’re training the WGAN that you really train the discriminator a lot — maybe 10x more than the generator.
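The schedule above can be sketched on toy data. This is a minimal illustration, not a real WGAN: all names and constants (`n_critic`, the clipping bound, the 1-D linear generator `g(z) = a*z + c` and critic `f(x) = w*x`, real data drawn from N(4, 1)) are assumptions chosen so the gradients can be written by hand. What matters is the structure: several clipped critic updates per generator update.

```python
import numpy as np

rng = np.random.default_rng(0)
a, c = 1.0, 0.0        # generator parameters: g(z) = a*z + c
w = 0.1                # linear critic: f(x) = w*x
clip, lr = 0.5, 0.05   # weight-clipping bound and learning rate (hypothetical)
n_critic = 5           # critic updates per generator update

for step in range(500):
    # Train the critic several times per generator step, clipping its
    # weight to (crudely) enforce the Lipschitz constraint.
    for _ in range(n_critic):
        real = rng.normal(4.0, 1.0, 64)
        fake = a * rng.normal(0.0, 1.0, 64) + c
        # Critic ascends E[f(real)] - E[f(fake)]; d/dw = mean(real) - mean(fake)
        w += lr * (real.mean() - fake.mean())
        w = float(np.clip(w, -clip, clip))
    # One generator step: descend -E[f(fake)] = -w * (a * mean(z) + c)
    z = rng.normal(0.0, 1.0, 64)
    a -= lr * (-w * z.mean())
    c -= lr * (-w)

print(c)  # the generator's output mean should settle near the real mean, 4.0
```

Because the generator only moves once for every `n_critic` critic updates, the critic stays close to optimal, which is exactly the regime the WGAN theory assumes.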
(2) Problems with GANs
- It’s hard to reach a Nash equilibrium — the two neural networks (generator and discriminator) are trained simultaneously, each updating its own cost function independently without considering the other’s updates. This procedure does not guarantee convergence to the equilibrium, which is the stated objective.
- Vanishing gradient — when the discriminator works as intended, D(x) approaches 1 for x drawn from Pr and 0 for generated samples. The generator’s loss L then saturates: as the discriminator gets increasingly better, the gradient tends to 0, leaving nothing to update the generator with during training.
- Use a better metric of distribution similarity — the loss function proposed in the vanilla GAN (by Goodfellow et al.) corresponds to the JS divergence between the real distribution Pr and the model distribution P(theta). This metric fails to provide a meaningful signal when the two distributions are disjoint.
Replacing JS divergence with the Wasserstein metric gives a much smoother value space.
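The contrast between the two metrics can be checked numerically. As a sketch, take two point masses on a unit-spaced 1-D grid (the grid and the offsets `theta` are illustrative choices): the JS divergence is stuck at log 2 no matter how far apart the distributions are, while the Wasserstein-1 distance grows with the offset and therefore still provides a usable gradient signal.

```python
import numpy as np

xs = np.arange(0.0, 11.0)  # 1-D grid with unit spacing

def js(p, q):
    """Jensen-Shannon divergence between discrete distributions (natural log)."""
    m = 0.5 * (p + q)
    def kl(a, b):
        nz = a > 0
        return float(np.sum(a[nz] * np.log(a[nz] / b[nz])))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def w1(p, q):
    """Wasserstein-1 on a unit-spaced grid: L1 distance between the CDFs."""
    return float(np.sum(np.abs(np.cumsum(p) - np.cumsum(q))))

p = np.zeros_like(xs); p[0] = 1.0          # "real" distribution: point mass at 0
for theta in (1, 2, 5):
    q = np.zeros_like(xs); q[theta] = 1.0  # "model": point mass at theta
    print(theta, round(js(p, q), 4), w1(p, q))
# → 1 0.6931 1.0
# → 2 0.6931 2.0
# → 5 0.6931 5.0
```

JS saturates at log 2 ≈ 0.6931 as soon as the supports are disjoint, so its gradient with respect to theta is zero; W1 equals |theta| and decreases smoothly as the model distribution moves toward the real one.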
Training a Generative Adversarial Network thus faces a dilemma:
- If the discriminator works as required, the gradient of the loss function tends to zero. The loss can no longer drive updates, so training becomes very slow or the model gets stuck.
- If the discriminator behaves badly, the generator does not receive accurate feedback, and the loss function no longer reflects how close the generated distribution is to the real one.
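The first horn of the dilemma can be made concrete with a toy calculation. Assume (hypothetically) a 1-D discriminator D(x) = sigmoid(k·(x − 2)), a fake sample at x = 0, and real data around x = 4; the sharpness k stands in for how well-trained the discriminator is. The gradient of the minimax generator loss log(1 − D(x)) with respect to x is −D(x)·k, which collapses as D grows confident, whereas a linear Wasserstein critic f(x) = w·x always supplies the constant gradient w.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

x_fake = 0.0  # generator output (hypothetical); real data sits around x = 4
for k in (1.0, 5.0, 20.0):
    d = sigmoid(k * (x_fake - 2.0))  # D(x) = sigmoid(k * (x - 2))
    # d/dx log(1 - sigmoid(f(x))) = -sigmoid(f(x)) * f'(x) = -D(x) * k:
    # the exponential decay of D(x_fake) beats the linear growth of k.
    grad_gan = -d * k
    grad_wgan = 0.5  # linear critic f(x) = w*x gives the constant gradient w
    print(k, grad_gan, grad_wgan)
```

As k increases (a better discriminator), `grad_gan` shrinks toward zero — the vanishing-gradient problem described above — while the Wasserstein critic’s gradient is unaffected.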