Basic Math behind Adversarial Learning

This post discusses GANs, analysing the learning paradigm with a focus on objective functions and optimization. A link to a notebook GAN implementation on the Fashion-MNIST dataset, with explanations, is included at the end.

Generative modelling, touted to be one of the hottest areas in deep learning, was taken by storm by Goodfellow’s 2014 paper titled Generative Adversarial Networks. The idea was simple enough: pit one module against another to learn a probability distribution as well as possible.

Over the last couple of years, GANs have managed to widen the horizons of generative modelling with unprecedented breakthroughs, including some really fun (but computationally expensive) applications. Check out Text2Image - a GAN that takes in text describing an imaginary setting and manages to construct an image, sketch2image - a network that converts sketches to complete images and Font Generation - learning and reproducing Chinese characters with a conditional GAN.

Figure: GAN. Source: skymind.ai

Let’s skim through the GAN learning paradigm once before we dive into the math. There are primarily two modules involved here, the discriminator ($D$) and the generator ($G$). The interplay of these two networks constitutes adversarial learning because they are engaged in a zero-sum game: any gain made by one network comes at the expense of the other. Let the distribution we intend to learn (we shall restrict ourselves to images in this post) be denoted $p_{data}(x)$. The generator tries to approximate $p_{data}(x)$ as well as possible by constructing images, while the discriminator aims to tell these constructed fake images apart from their real counterparts sampled from $p_{data}(x)$. The two modules are typically trained together, with gradients backpropagated and weights adjusted separately for each. Over multiple iterations, the generator develops a better approximation of the distribution while the discriminator simultaneously learns to distinguish the two distributions more effectively.

Loss function

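The loss function for the discriminator is straightforward. We intend to have a discriminator capable of separating genuine images from fake ones, and hence it is a simple two-class classification model. We use binary cross-entropy as the loss function, defined as follows

$$L(\hat{y}, y) = -\left[y\log\hat{y} + (1-y)\log(1-\hat{y})\right]$$

where $y$ is the true label and $\hat{y}$ is the predicted probability.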

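The label for samples from $p_{data}(x)$ (genuine images) is set to $1$, while the label for samples from the generator is set to $0$. For the genuine data we have $y=1$ and $\hat{y}=D(x)$, giving a cross-entropy loss of

$$-\log(D(x))$$

so minimising this loss is the same as maximising $\log(D(x))$.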

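We feed random noise $z$ as input to the generator. For a generated image, $y=0$ and $\hat{y}=D(G(z))$, giving a cross-entropy loss of

$$-\log(1-D(G(z)))$$

so minimising this loss is the same as maximising $\log(1-D(G(z)))$.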

The objective of the discriminator is to maximise both these terms. To understand the reason, look at the terms separately. $\log(D(x))$ attains its maximum value at $D(x)=1$, since $D(x)\in [0,1]$. This confirms that the discriminator should push its outputs on genuine images closer to $1$. On the other hand, $\log(1-D(G(z)))$ attains its maximum at $D(G(z))=0$ in the given range. Achieving this maximum value would mean the discriminator successfully classifies images produced by the generator as fake. Hence, the discriminator objective becomes

$$\max_D \left[\log(D(x)) + \log(1-D(G(z)))\right]$$
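As a quick numerical illustration of the discriminator's side of the game, here is a minimal sketch in NumPy. The function names and the discriminator outputs (`0.9`, `0.1`, `0.5`) are made up for illustration; they are not taken from the linked notebook.

```python
import numpy as np

def binary_cross_entropy(y_hat, y, eps=1e-12):
    """Binary cross-entropy for a predicted probability y_hat and a label y in {0, 1}."""
    y_hat = np.clip(y_hat, eps, 1.0 - eps)  # keep the logs finite
    return -(y * np.log(y_hat) + (1.0 - y) * np.log(1.0 - y_hat))

def discriminator_objective(d_real, d_fake):
    """log D(x) + log(1 - D(G(z))), the quantity the discriminator maximises."""
    return np.log(d_real) + np.log(1.0 - d_fake)

# Hypothetical discriminator outputs for one real and one generated image.
confident = discriminator_objective(d_real=0.9, d_fake=0.1)
confused = discriminator_objective(d_real=0.5, d_fake=0.5)

# The objective is minus the summed cross-entropy losses of the two samples.
total_loss = binary_cross_entropy(0.9, 1) + binary_cross_entropy(0.1, 0)

print(confident, confused, total_loss)
```

A discriminator that scores the real image high and the fake image low obtains a larger objective than one that outputs $0.5$ for both, which is the behaviour the two log terms are meant to reward.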

The ideal generator would produce images that the discriminator classifies as genuine. In other words, it attempts to push $D(G(z))$ towards $1$, which is the same as minimising the second term of the discriminator objective, $\log(1-D(G(z)))$. One caveat: early in training the generator produces meaningless images, so $D(G(z))$ tends to be very small and $\log(1-D(G(z)))$ saturates, giving the generator weak gradients. For this reason the original paper suggests maximising $\log(D(G(z)))$ in practice. For the analysis that follows, however, we keep the minimax form and choose

$$\min_G \log(1-D(G(z)))$$
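To see how the two candidate generator objectives behave early in training, the sketch below compares their gradients with respect to the discriminator's logit. It assumes the discriminator ends in a sigmoid, $D = \sigma(a)$, and the logit value `-5.0` is an arbitrary stand-in for a discriminator that is confident the sample is fake.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def grad_log_one_minus_d(a):
    """d/da of log(1 - sigmoid(a)), which simplifies to -sigmoid(a)."""
    return -sigmoid(a)

def grad_log_d(a):
    """d/da of log(sigmoid(a)), which simplifies to 1 - sigmoid(a)."""
    return 1.0 - sigmoid(a)

logit = -5.0  # assumed value: D(G(z)) = sigmoid(-5) is roughly 0.0067
g_minimax = grad_log_one_minus_d(logit)
g_alternative = grad_log_d(logit)

print(g_minimax, g_alternative)
```

With these assumed values the gradient of $\log(1-D(G(z)))$ is close to zero, while the gradient of $\log(D(G(z)))$ is close to one, so the second form gives the generator a much stronger learning signal when its samples are easily rejected.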

Combining the two objectives to obtain an idea of the GAN objective, we get

$$\min_G \max_D \left[\log(D(x)) + \log(1-D(G(z)))\right]$$

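We’re getting close, but this is not our objective function. The $x$ and $z$ here refer to isolated data points, and present an incomplete picture. Here’s the objective proposed in Goodfellow’s original GAN paper, which takes expectations over the data and the noise

$$\min_G \max_D V(D,G) = E_{x\sim p_{data}(x)}[\log(D(x))] + E_{z\sim p_{z}(z)}[\log(1-D(G(z)))]$$

where $p_{z}(z)$ is the distribution of the input noise. We write $p_{g}(x)$ for the probability distribution learned by the generator, that is, the distribution of $G(z)$ when $z\sim p_{z}(z)$.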

We have established that the training objective of the discriminator is to maximise $V(D,G)$. Hence, the optimal discriminator given a generator $G$ is $D_G^*$ given by

$$D_G^* = \arg\max_D V(D,G)$$

Here, the original GAN paper makes the leap to

$$D_G^* = \arg\max_D \int_x \left[p_{data}(x)\log(D(x)) + p_{g}(x)\log(1-D(x))\right]dx$$

To make this simplification, we require a calculus brush-up. If the probability density function of a random variable $X$ is given by $p_X(x)$, we can calculate the probability density function of some $Y$ such that $Y=G(X)$, for an invertible and differentiable $G$, as follows

$$p_Y(y) = p_X\left(G^{-1}(y)\right)\left|\frac{d}{dy}G^{-1}(y)\right|$$
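The pdf transformation rule is easy to check by simulation. The sketch below assumes, purely for illustration, that $X$ is standard normal and that the transformation is $G(x)=e^x$ (invertible and differentiable); it compares the density predicted by the rule with a histogram of transformed samples.

```python
import numpy as np

rng = np.random.default_rng(0)

def p_x(x):
    """Standard normal density, the assumed density of X."""
    return np.exp(-0.5 * x**2) / np.sqrt(2.0 * np.pi)

def g(x):
    """Assumed transformation Y = G(X) = exp(X)."""
    return np.exp(x)

def p_y(y):
    """Density of Y obtained from the change-of-variables rule."""
    g_inv = np.log(y)      # G^{-1}(y)
    jacobian = 1.0 / y     # |d G^{-1}(y) / dy|
    return p_x(g_inv) * jacobian

# Empirical density of Y on [0.2, 3.0], normalised by the total sample count.
samples = g(rng.standard_normal(200_000))
edges = np.linspace(0.2, 3.0, 29)
counts, _ = np.histogram(samples, bins=edges)
empirical = counts / (samples.size * np.diff(edges))

centres = 0.5 * (edges[:-1] + edges[1:])
max_abs_error = np.max(np.abs(empirical - p_y(centres)))
print(max_abs_error)
```

The histogram and the formula should agree up to sampling noise and binning error. A real generator network is generally not invertible, so this is only an illustration of the rule itself.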

In our case the transformation is $x=G(z)$: the generator maps noise $z$, with density $p_z(z)$, to samples $x$ from the same family of images as the data, and the density of these samples is $p_g(x)$. Consider the second term of the expression for $D_G^*$, written as an integral over $z$

$$\int_z p_{z}(z)\log(1-D(G(z)))\,dz$$

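Substituting $x=G(z)$, so that $z=G^{-1}(x)$ and $dz=\left|\frac{d}{dx}G^{-1}(x)\right|dx$, and using the pdf transformation rule $p_{g}(x)=p_{z}\left(G^{-1}(x)\right)\left|\frac{d}{dx}G^{-1}(x)\right|$, we obtain the required expression

$$\int_x p_{g}(x)\log(1-D(x))\,dx$$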

To summarize, the objective function is now given by

$$V(D,G) = \int_x \left[p_{data}(x)\log(D(x)) + p_{g}(x)\log(1-D(x))\right]dx$$

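To maximise this and find the optimal discriminator $D_G^*$, we differentiate the integrand of $V(D,G)$ with respect to $D(x)$ and set the derivative to zero

$$\frac{p_{data}(x)}{D(x)} - \frac{p_{g}(x)}{1-D(x)} = 0 \implies D_G^*(x) = \frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}$$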
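The closed form for the optimal discriminator can be checked by brute force. The sketch below fixes hypothetical density values `p_data = 0.6` and `p_g = 0.2` at a single point $x$, evaluates the integrand $p_{data}\log D + p_{g}\log(1-D)$ on a grid of $D$ values, and compares the best grid point with $\frac{p_{data}}{p_{data}+p_{g}}$.

```python
import numpy as np

def pointwise_objective(d, p_data, p_g):
    """Integrand of V(D, G) at one point x, as a function of d = D(x)."""
    return p_data * np.log(d) + p_g * np.log(1.0 - d)

p_data, p_g = 0.6, 0.2  # assumed density values at a single point x

d_grid = np.linspace(0.001, 0.999, 9981)  # grid spacing of 1e-4
d_best = d_grid[np.argmax(pointwise_objective(d_grid, p_data, p_g))]
d_closed_form = p_data / (p_data + p_g)

print(d_best, d_closed_form)
```

The grid search and the formula should agree up to the grid resolution.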

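It is easy to verify that the second derivative of the integrand, $-\frac{p_{data}(x)}{D(x)^2}-\frac{p_{g}(x)}{(1-D(x))^2}$, is negative, confirming that $D_G^*(x)=\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}$ is indeed a maximum. The optimal generator $G^*$ is then defined as the one that minimises the objective when the discriminator is optimal

$$G^* = \arg\min_G V(D_G^*,G) = \arg\min_G \int_x \left[p_{data}(x)\log\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)} + p_{g}(x)\log\frac{p_{g}(x)}{p_{data}(x)+p_{g}(x)}\right]dx$$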

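Add and subtract $(\log2)p_{data}(x)$ and $(\log2)p_{g}(x)$ inside the integral and group terms to obtain

$$G^* = \arg\min_G \int_x \left[-\log2\left(p_{data}(x)+p_{g}(x)\right) + p_{data}(x)\left(\log2+\log\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}\right) + p_{g}(x)\left(\log2+\log\frac{p_{g}(x)}{p_{data}(x)+p_{g}(x)}\right)\right]dx$$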

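Consider each of these terms individually. The first one simplifies to $-\log2\int_{x}\left(p_{data}(x)+p_{g}(x)\right)dx=-\log2\,(1+1)=-2\log2$, as the integral of a probability density function over its domain is $1$. The second and third terms reduce to

$$\int_x p_{data}(x)\log\frac{p_{data}(x)}{\left(p_{data}(x)+p_{g}(x)\right)/2}\,dx \quad\text{and}\quad \int_x p_{g}(x)\log\frac{p_{g}(x)}{\left(p_{data}(x)+p_{g}(x)\right)/2}\,dx$$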

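If you’ve read my first blog post on AutoEncoders, you can infer that each of the terms above has the same form as the expression for KL divergence. If not, it doesn’t matter as this expression simplifies to a different (and better) metric. To recap,

$$G^* = \arg\min_G \left[-2\log2 + \int_x p_{data}(x)\log\frac{p_{data}(x)}{\left(p_{data}(x)+p_{g}(x)\right)/2}\,dx + \int_x p_{g}(x)\log\frac{p_{g}(x)}{\left(p_{data}(x)+p_{g}(x)\right)/2}\,dx\right]$$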

Jensen-Shannon Divergence (JSD)

JSD, like KL divergence, is a measure of similarity between two probability distributions. Given two probability distributions $P$ and $Q$, we define a third distribution $M$ such that $M=\frac{P+Q}{2}$. The JSD is then defined as

$$JSD(P\|Q) = \frac{1}{2}KL(P\|M) + \frac{1}{2}KL(Q\|M)$$

JSD is a more reliable metric for comparing distributions than KL divergence. The latter can, at times, blow up unexpectedly. Consider the expression for KL divergence

$$KL(P\|Q) = \int_x P(x)\log\frac{P(x)}{Q(x)}\,dx$$

At points where $P(x)>0,Q(x)=0$ the KL divergence is infinite, meaning that even a single such point can derail our algorithm. Moreover, JSD is a symmetric function, which removes the need to choose between the two orderings of the asymmetric KL divergence, each of which has its own drawbacks.
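The contrast between the two divergences shows up even on a tiny example. The sketch below uses discrete distributions (sums instead of integrals) and two made-up three-point distributions `p` and `q` whose supports only partly overlap.

```python
import numpy as np

def kl_divergence(p, q):
    """KL(P || Q) for discrete distributions; infinite if Q is 0 where P > 0."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    support = p > 0  # terms with P(x) = 0 contribute nothing
    if np.any(q[support] == 0):
        return np.inf
    return float(np.sum(p[support] * np.log(p[support] / q[support])))

def js_divergence(p, q):
    """JSD(P || Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M) with M = (P + Q) / 2."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

# Hypothetical distributions with partly disjoint supports.
p = [0.5, 0.5, 0.0]
q = [0.0, 0.5, 0.5]

kl_pq = kl_divergence(p, q)
jsd_pq = js_divergence(p, q)
jsd_qp = js_divergence(q, p)
print(kl_pq, jsd_pq, jsd_qp)
```

Here the KL divergence is infinite, while the JSD stays finite, is the same in both directions, and never exceeds $\log2$ (with natural logarithms).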

Let us rewrite $G^*$ in terms of JSD now

$$G^* = \arg\min_G \left[-2\log2 + 2\,JSD(p_{data}\|p_{g})\right]$$

The minimum value, $-2\log2$, is achieved when the $JSD$ term is $0$, that is, when $p_{g}=p_{data}$. This makes perfect sense as the optimal generator would be the one that brings $p_{g}$ closest to $p_{data}$. Recall our expression for the optimal discriminator

$$D_G^*(x) = \frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}$$

When $p_{data}(x)=p_{g}(x)$, we obtain $D_G^*=\frac{1}{2}$. This, again, makes sense: when the generator reproduces the data distribution perfectly, even the optimal discriminator cannot tell genuine images from generated ones, so the best it can do is assign a probability of $\frac{1}{2}$ to every input.
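The relation between the objective at the optimal discriminator and the JSD can also be verified numerically. This is a minimal sketch on discrete distributions, with made-up, strictly positive probability vectors `p_data` and `p_g` standing in for the densities.

```python
import numpy as np

def kl(p, q):
    """KL(P || Q) for strictly positive discrete distributions."""
    return float(np.sum(p * np.log(p / q)))

def jsd(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def value_at_optimal_d(p_data, p_g):
    """V(D*, G) with D*(x) = p_data(x) / (p_data(x) + p_g(x))."""
    d_star = p_data / (p_data + p_g)
    return float(np.sum(p_data * np.log(d_star) + p_g * np.log(1.0 - d_star)))

# Hypothetical distributions over three outcomes.
p_data = np.array([0.1, 0.4, 0.5])
p_g = np.array([0.3, 0.3, 0.4])

lhs = value_at_optimal_d(p_data, p_g)
rhs = -2.0 * np.log(2.0) + 2.0 * jsd(p_data, p_g)

# When the generator matches the data, D* = 1/2 everywhere and V = -2 log 2.
v_matched = value_at_optimal_d(p_data, p_data)
print(lhs, rhs, v_matched)
```

The two sides agree, and the matched case attains the lower bound $-2\log2$.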

Optimization

Figure: the GAN training algorithm from the original paper.

Here’s the training paradigm proposed in Goodfellow’s original paper. The algorithm uses the Monte Carlo approximation of expectation, that is $\frac{1}{m}\sum\limits_{i=1}^m x^{(i)}\approx E[x]$ for samples $x^{(1)},\dots,x^{(m)}$. The long expressions in the algorithm are our objectives defined earlier.
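As a rough illustration of the Monte Carlo approximation, the sketch below estimates the two expectations in the GAN objective by mini-batch averages. Everything concrete here is an assumption made for the example: one-dimensional data drawn from a normal distribution with mean 4, a fixed logistic `discriminator`, and an affine `generator`.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def discriminator(x, w=1.0, c=-2.0):
    """Hypothetical fixed discriminator D(x) = sigmoid(w * x + c)."""
    return sigmoid(w * x + c)

def generator(z, a=1.0, b=0.0):
    """Hypothetical fixed generator G(z) = a * z + b."""
    return a * z + b

def estimate_value(m):
    """Mini-batch estimate of E[log D(x)] + E[log(1 - D(G(z)))]."""
    x = rng.normal(loc=4.0, scale=1.0, size=m)  # assumed data distribution
    z = rng.standard_normal(m)                  # noise prior
    real_term = np.mean(np.log(discriminator(x)))
    fake_term = np.mean(np.log(1.0 - discriminator(generator(z))))
    return real_term + fake_term

small_batch = estimate_value(64)
large_batch = estimate_value(200_000)
print(small_batch, large_batch)
```

Estimates from small mini-batches are noisy, while larger batches settle close to the true expectation.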

The discriminator is trained first: for $k$ steps, it samples a mini-batch of noise samples and a mini-batch of real examples, evaluates its objective and updates its parameters by gradient ascent. The generator then samples a fresh mini-batch of noise, maps it to images (which follow $p_g$), evaluates its objective and updates its parameters by gradient descent. As mentioned in the header of the algorithm, the value $k$ is a hyperparameter usually set to $1$. We shall revisit this hyperparameter in the post on W-GANs, where a modification is necessitated.
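The alternating update scheme can be sketched on a toy problem in plain NumPy. This is not the notebook implementation: the one-dimensional data distribution, the logistic discriminator $D(x)=\sigma(wx+c)$, the affine generator $G(z)=az+b$, the function name `train_toy_gan` and all hyperparameter values are assumptions chosen so that the gradients can be written by hand.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def train_toy_gan(steps=500, m=64, k=1, lr=0.05):
    w, c = 0.0, 0.0  # discriminator D(x) = sigmoid(w * x + c)
    a, b = 1.0, 0.0  # generator G(z) = a * z + b
    for _ in range(steps):
        # k discriminator updates: gradient ascent on
        # mean(log D(x)) + mean(log(1 - D(G(z)))).
        for _ in range(k):
            x = rng.normal(4.0, 1.0, size=m)       # assumed real data
            g = a * rng.standard_normal(m) + b     # generated samples
            d_real, d_fake = sigmoid(w * x + c), sigmoid(w * g + c)
            grad_w = np.mean((1.0 - d_real) * x) - np.mean(d_fake * g)
            grad_c = np.mean(1.0 - d_real) - np.mean(d_fake)
            w += lr * grad_w
            c += lr * grad_c
        # One generator update: gradient descent on mean(log(1 - D(G(z)))).
        z = rng.standard_normal(m)
        d_fake = sigmoid(w * (a * z + b) + c)
        grad_a = np.mean(-d_fake * w * z)
        grad_b = np.mean(-d_fake * w)
        a -= lr * grad_a
        b -= lr * grad_b
    return (w, c), (a, b)

(w, c), (a, b) = train_toy_gan()
print("discriminator:", w, c, "generator:", a, b)
```

A model this small is only meant to show the order of the updates. Alternating gradient steps of this kind can oscillate rather than converge, so the final parameters should not be read as a well-trained GAN.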

Takeaways

This post is the first part in a series on adversarial learning for generative modelling. This is a link to a TensorFlow implementation of a basic GAN on the Fashion-MNIST dataset, along with code explanations wherever necessary.

GANs, although effective for a number of image generation tasks, have their limitations. We discuss these along with possible solutions in the next post.

Written on June 28, 2019