Back to basics. As part of my “reinventing-the-wheels” project to understand things deeply, I took some time to reimplement Conditional Generative Adversarial Nets from scratch. This is a note on that process.

Background

The Generative Adversarial Networks framework was introduced by Ian Goodfellow in 2014. Since then, many different types of GANs have been invented, and GANs have been one of the most active research areas in the machine learning community.

A few months after the original GAN paper was submitted, Conditional Generative Adversarial Nets (cGAN) was proposed. According to arXiv, the original GAN paper was submitted on 10 Jun 2014 and the cGAN paper on 6 Nov 2014.

The core motivation of the cGAN paper was that although GANs showed impressive image generation capabilities, there was no way to control or specify the type of image to generate; for instance, generating the digit ‘1’ from the MNIST dataset.

The proposed conditioning method is to simply provide some extra information y, such as class labels, to both the generator and the discriminator.
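Restating the objective from the paper: the standard GAN value function becomes a conditional one, with both the discriminator and the generator receiving y as an additional input:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x \mid y)\big] +
  \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y))\big)\big]
```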

Figure: conditional adversarial net architecture. (Figure from the Conditional Generative Adversarial Nets paper.)

The authors demonstrated its effectiveness through MNIST experiments in which the generator and discriminator were conditioned on one-hot class labels.

Implementation

OK, time to code. My implementations: the original GAN (gan.py is the main file; models are defined in models/original_gan.py) and Conditional GAN.

The key differences are:

  1. Model definition: the input size of the first layer is now z_dim + num_classes.
  2. Training: concatenate noise z with the label, represented as a one-hot vector.

Model definition

class Generator(nn.Module):
    def __init__(self, batch, z_dim, out_shape, num_classes):
        super().__init__()

        self.batch_size = batch
        self.z_dim = z_dim
        self.out_shape = out_shape
        self.num_classes = num_classes

        # First layer takes noise concatenated with a one-hot label
        self.fc1 = nn.Linear(z_dim + num_classes, 256)
        ...

Similarly, the discriminator is modified to take input images concatenated with a one-hot label.
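A minimal sketch of such a conditioned discriminator, assuming flattened 28×28 MNIST images (the layer sizes here are illustrative, not the exact ones from my repo):

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, in_shape, num_classes):
        super().__init__()
        # First layer takes the flattened image concatenated with a one-hot label
        self.fc1 = nn.Linear(in_shape + num_classes, 256)
        self.fc2 = nn.Linear(256, 1)
        self.act = nn.LeakyReLU(0.2)

    def forward(self, x, y):
        # x: [batch, in_shape] flattened image, y: [batch, num_classes] one-hot
        h = self.act(self.fc1(torch.cat((x, y), dim=1)))
        return torch.sigmoid(self.fc2(h))
```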

Training

generated_imgs = generator( torch.cat((z, label), dim=1) )
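Putting it together, one generator update looks roughly like this. `generator_step` is a hypothetical helper, and the discriminator is assumed to take `(image, label)` as two arguments; the loss is the standard binary cross-entropy GAN loss:

```python
import torch
import torch.nn as nn

def generator_step(generator, discriminator, opt_g, z, label_onehot):
    """One (illustrative) generator update: push D to score fakes as real."""
    criterion = nn.BCELoss()
    real_target = torch.ones(z.size(0), 1)  # generator wants D to output 1

    opt_g.zero_grad()
    # Condition the generator by concatenating noise with the one-hot label ...
    generated_imgs = generator(torch.cat((z, label_onehot), dim=1))
    # ... and feed the same labels to the discriminator
    g_loss = criterion(discriminator(generated_imgs, label_onehot), real_target)
    g_loss.backward()
    opt_g.step()
    return g_loss.item()
```

The discriminator step is symmetric: score real images and generated images, each paired with their labels, against targets of 1 and 0 respectively.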

Quick tip: I found torch.Tensor.scatter_ useful for converting integer class labels into N-dimensional one-hot encodings (e.g., converting [3] into [0,0,0,1,0,0,0,0,0,0] where N=10). Since MNIST labels are integers, I convert them into one-hot vectors of shape [batch_size × num_classes] using the following function:

def convert_onehot(label, num_classes):
    """Return the one-hot encoding of the given labels.

    Args:
        label: A 1-D tensor of integer class labels, shape [batch_size].
        num_classes: The total number of classes for y.

    Returns:
        one_hot: Encoded tensor of shape [batch_size, num_classes].
    """
    # scatter_ expects the index tensor to have the same number of dimensions
    # as the target, so reshape the labels from [batch_size] to [batch_size, 1]
    one_hot = torch.zeros(label.shape[0], num_classes).scatter_(1, label.unsqueeze(1), 1)
    return one_hot
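For what it's worth, recent PyTorch versions also ship a built-in `torch.nn.functional.one_hot` that does the same thing (it returns integer tensors, so a cast to float is needed before concatenating with the noise):

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([3, 0, 9])

# scatter_-based conversion (index reshaped to [batch, 1] for dim=1)
one_hot = torch.zeros(labels.shape[0], 10).scatter_(1, labels.unsqueeze(1), 1)

# Built-in equivalent
same = F.one_hot(labels, num_classes=10).float()

print(torch.equal(one_hot, same))  # prints True
```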

Experimental results

The left image shows data from a particular batch; the right image shows the generated output conditioned on the corresponding labels.

Figure: true images (left) and generated images (right).

As you can see, the generated digits in the right image match those in the left. This is because the model is conditioned on the label, allowing it to generate a specific digit. Simple, yet it works.

Further reading

cGANs with Projection Discriminator is an interesting follow-up. The authors propose a novel way to incorporate conditional information into the GAN discriminator.

Rather than simply concatenating additional information to the input vector, they introduce a projection-based approach. Here is the key quote from the paper:

We propose a novel, projection based way to incorporate the conditional information into the discriminator of GANs that respects the role of the conditional information in the underlining probabilistic model. This approach is in contrast with most frameworks of conditional GANs used in application today, which use the conditional information by concatenating the (embedded) conditional vector to the feature vectors.
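My reading of that idea: instead of concatenating y to the input, the discriminator's output becomes (roughly) ψ(φ(x)) + yᵀVφ(x), i.e. an unconditional score plus an inner product between a learned class embedding and the feature vector. A minimal sketch, with made-up layer sizes:

```python
import torch
import torch.nn as nn

class ProjectionDiscriminator(nn.Module):
    """Sketch of a projection discriminator: psi(phi(x)) + <embed(y), phi(x)>."""
    def __init__(self, in_shape, num_classes, feat_dim=128):
        super().__init__()
        self.phi = nn.Sequential(nn.Linear(in_shape, feat_dim), nn.ReLU())  # feature extractor
        self.psi = nn.Linear(feat_dim, 1)                 # unconditional score
        self.embed = nn.Embedding(num_classes, feat_dim)  # V: one row per class

    def forward(self, x, y):
        # x: [batch, in_shape], y: integer class labels of shape [batch]
        h = self.phi(x)
        # Inner product between the class embedding and the feature vector
        proj = (self.embed(y) * h).sum(dim=1, keepdim=True)
        return self.psi(h) + proj
```

Note this outputs unbounded logits (no sigmoid), matching the hinge/adversarial losses typically used with this architecture.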

The figure below compares the main variants of cGANs. (a) is the approach I implemented above; (d) is the projection-based method.

Figure 1: Discriminator models for conditional GANs Figure from cGANs with Projection Discriminator paper

I plan to implement this approach and write a follow-up post.


References

[1]: Generative Adversarial Networks

[2]: Conditional Generative Adversarial Nets

[3]: cGANs with Projection Discriminator