Stabilizing GAN Training with Random Projections

Check the code on GitHub.

In this paper, Neyshabur et al. (2017) introduce a framework to stabilize GAN training by projecting each input image through multiple fixed random filters and feeding each projection to a different discriminator. Training GAN models is unstable in high-dimensional spaces, and one problem that can arise during training is saturation of the discriminator: the discriminator wins the game and the generator stops receiving useful updates (the diminished-gradients problem).

What are GANs?

Generative models in general provide a way to model structure in complex distributions. They have been useful for generating data points (images, music, etc.).

Generative Adversarial Networks are generative models that set up a minimax game between two models: a generator and a discriminator. The discriminator is a simple classifier that tries to distinguish real data coming from the training distribution from fake data produced by the generator.

The generator model takes simple random noise \(z\) sampled from a Gaussian or uniform distribution (Gaussian tends to work better; see ganhacks). The general objective of the GAN minimax game is \(\min_{G}\max_{D} V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log(1- D(G(z)))]\)
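As a minimal sketch of how the two sides of this objective translate into training losses (assuming PyTorch, with `D` and `G` standing in for any discriminator and generator networks; this is illustrative, not the paper's exact code):

```python
import torch
import torch.nn.functional as F

def discriminator_loss(D, G, real_batch, z):
    # D maximizes log D(x) + log(1 - D(G(z))); equivalently, it
    # minimizes the corresponding binary cross-entropy.
    fake_batch = G(z).detach()  # do not backprop into G here
    real_logits = D(real_batch)
    fake_logits = D(fake_batch)
    loss_real = F.binary_cross_entropy_with_logits(
        real_logits, torch.ones_like(real_logits))
    loss_fake = F.binary_cross_entropy_with_logits(
        fake_logits, torch.zeros_like(fake_logits))
    return loss_real + loss_fake

def generator_loss(D, G, z):
    # Non-saturating variant: G maximizes log D(G(z)) instead of
    # minimizing log(1 - D(G(z))), which gives stronger gradients
    # early in training when D easily rejects fake samples.
    fake_logits = D(G(z))
    return F.binary_cross_entropy_with_logits(
        fake_logits, torch.ones_like(fake_logits))
```

The non-saturating generator loss used here is a common practical substitute for the minimax term; it directly addresses the diminished-gradients problem mentioned above.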

General GAN problems

Illustration of the GAN framework

What are random projections?

Illustration of the proposed random projection framework

Before explaining the proposed approach, one important thing to understand is random projections.

Random projections are simply a set of random filters generated before training and applied to the input images during training, creating multiple projections of the data into a lower-dimensional space.

These random filters are kept fixed during training, so each discriminator looks at a different low-dimensional view of the input dataset. The filters are drawn i.i.d. from a Gaussian distribution and scaled to have unit L2 norm.
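A minimal sketch of how such filters could be generated (NumPy; the number of filters and their shape below are illustrative assumptions, not the paper's exact configuration):

```python
import numpy as np

def make_random_projection_filters(k, filter_shape, seed=0):
    """Draw k fixed random filters, i.i.d. Gaussian, scaled to unit L2 norm."""
    rng = np.random.default_rng(seed)
    filters = []
    for _ in range(k):
        w = rng.standard_normal(filter_shape)
        w /= np.linalg.norm(w)  # scale to unit L2 norm
        filters.append(w)
    return filters  # generated once, then frozen for all of training

# e.g. 8 projections, each a (channels, height, width) filter
projections = make_random_projection_filters(8, (3, 8, 8))
```

The key point is that the filters are sampled once before training and never updated, so each discriminator's view of the data stays consistent.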

What’s the importance of multiple discriminators in a low-dimensional space?

In this setup, the generator gets meaningful gradient signals from different discriminators, each looking at a low-dimensional set of features. The more discriminators (and thus projections) you use, the better the diversity and quality of the generated samples.

Proposal

In this game setup, the generator tries to fool an array of discriminators. Each discriminator, trained on its own projection of the input training images, tries to maximize its classification accuracy of real vs. fake.

The generator receives gradient signals from the array of discriminators and tries to create samples that fool all of them: \(\min_{G}\max_{D_1,\dots,D_K} \sum_{k=1}^{K} V(D_k,G) = \sum_{k=1}^{K} \left( \mathbb{E}_{x \sim p_{data}(x)}[\log D_k(W_k^T x)] + \mathbb{E}_{z \sim p_z(z)} [\log(1- D_k(W_k^T G(z)))] \right)\)
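A rough sketch of the generator's side of this objective (PyTorch; the `project` helper, the convolutional form of the projection, and the filter shapes are all assumptions for illustration, and the non-saturating loss again replaces the minimax term):

```python
import torch
import torch.nn.functional as F

def project(x, W_k):
    # Apply the fixed filter W_k as a strided convolution to get a
    # low-dimensional view of the batch; the exact projection used
    # in the paper may differ (this is an assumption).
    return F.conv2d(x, W_k, stride=2)

def multi_disc_generator_loss(discriminators, filters, G, z):
    # The generator sums the adversarial signal over all K
    # discriminators, each seeing its own fixed projection of G(z).
    fake = G(z)
    loss = 0.0
    for D_k, W_k in zip(discriminators, filters):
        fake_logits = D_k(project(fake, W_k))
        # non-saturating form of the log(1 - D_k(...)) term
        loss = loss + F.binary_cross_entropy_with_logits(
            fake_logits, torch.ones_like(fake_logits))
    return loss
```

Each discriminator \(D_k\) is still trained independently on its own projection, exactly as in the single-discriminator loss above, just with \(W_k^T x\) in place of \(x\).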

Experimental Results

Experimental results reported on the CelebFaces dataset

In their experiments, a simple DCGAN architecture was used on the CelebFaces dataset. The details of the architecture, along with the experiments, are explained in a GitHub notebook.

Constructive feedback

The idea of using an ensemble of discriminators to stabilize GAN training is interesting and showed promising results at the time. That stability, however, comes at the expense of extra computation and memory, since K discriminators (one per projection) must be trained in parallel.

Side Note

I would recommend reading the paper itself and checking the related work, this is just a summary to give you a rough idea of what is going on.

References

  1. Neyshabur, B., Bhojanapalli, S., & Chakrabarti, A. (2017). Stabilizing GAN Training with Multiple Random Projections. CoRR, abs/1705.07831. http://arxiv.org/abs/1705.07831