In this post, I’ll implement feature matching, a simple technique to mitigate mode collapse in GANs.

What is mode collapse?

Mode collapse is one of the well-known problems in training GANs. Real-world data distributions are often multimodal (they have multiple peaks). In that case, the generator can fool the discriminator by producing a single plausible sample from just one of the peaks. When mode collapse happens, the generator’s output therefore tends to be one sample, or a small set of very similar samples.

While playing with DCGAN, I experienced this myself. These are sample (real) images from the OpenAI Gym CarRacing environment.

realImg

And these are the generated images. After a certain point in training, the same image is generated again and again.

generatedImg

The graph below shows the losses over 300 epochs. It indicates that the generator is not generating realistic data.

loss_DCGAN_300

Proposed solutions/mitigations

Several mitigations have been proposed in previous studies:

  • Feature Matching
  • Minibatch Discrimination
  • WGAN
  • Unrolled GAN

In this post, I’ll explore feature matching [1] a bit (and explore the rest in future posts).

The technique was proposed in the paper “Improved Techniques for Training GANs”. It is fairly simple to implement, which makes it easy to check whether it improves the situation.

Instead of directly maximizing the output of the discriminator, the new objective requires the generator to generate data that matches the statistics of the real data, where we use the discriminator only to specify the statistics that we think are worth matching.

In this regard, the generator’s loss function is changed to

$$\left\lVert \mathbb{E}_{x \sim p_{\text{data}}} f(x) - \mathbb{E}_{z \sim p_z(z)} f(G(z)) \right\rVert_2^2$$

where $f(x)$ denotes the activations on an intermediate layer of the discriminator.

Let’s try this out. To implement it, first modify the discriminator to return an intermediate tensor along with the final output.

You can use an existing DCGAN implementation, such as the PyTorch DCGAN tutorial, as your starting point. The following code is not from that tutorial, but the main point is simply to return the intermediate tensor, so you can copy an existing implementation and adapt it.

import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 32, 4, 2, 1, bias=False)

        self.conv2 = nn.Conv2d(32, 64, 4, 2, 1, bias=False)
        self.bn2   = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 128, 4, 2, 1, bias=False)
        self.bn3   = nn.BatchNorm2d(128)

        self.conv4 = nn.Conv2d(128, 256, 4, 2, 1, bias=False)
        self.bn4   = nn.BatchNorm2d(256)

        self.conv5 = nn.Conv2d(256, 1, 4, 1, 0, bias=False)

        self.leaky_relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.leaky_relu( self.conv1(x) )
        x = self.leaky_relu( self.bn2(self.conv2(x)) )
        x = self.leaky_relu( self.bn3(self.conv3(x)) )
        x = intermediate = self.leaky_relu( self.bn4(self.conv4(x)) ) # keep these intermediate activations for feature matching
        x = self.conv5(x)
        return x.squeeze(), intermediate
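
To sanity-check the modified forward pass, here is a minimal sketch (the 64×64 RGB input size is my assumption, implied by the layer shapes above):

import torch

model_D = Discriminator()
dummy = torch.randn(8, 3, 64, 64)   # a dummy batch of eight 64x64 RGB images
out, intermediate = model_D(dummy)
print(out.shape)            # torch.Size([8]), one raw logit per image
print(intermediate.shape)   # torch.Size([8, 256, 4, 4]), features used for matching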

Then, change the generator’s loss function from BCE to MSE and compute the loss as follows.

loss_f_G = nn.MSELoss() # MSE to implement Feature Matching
...
real_out, inter_real = model_D(real_img)
...
out, inter_fake = model_D(fake_img)
loss_G = loss_f_G(inter_real, inter_fake)
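
As a side note, the paper matches the expected value of the features over the data rather than per-sample activations, and it is common to detach the real features so they act as a fixed target. Below is a minimal sketch of one generator step in that style; model_G, optimizer_G, nz, batch_size, and device are placeholder names I’m assuming, not part of the snippet above.

# A sketch of a single generator update using feature matching on batch statistics.
# model_G, optimizer_G, nz, batch_size, and device are assumed to exist elsewhere.
noise = torch.randn(batch_size, nz, 1, 1, device=device)
fake_img = model_G(noise)

_, inter_real = model_D(real_img)
_, inter_fake = model_D(fake_img)

# Compare batch means of the intermediate features (the paper matches expectations);
# detach the real features so they serve as a fixed target for the generator.
loss_G = loss_f_G(inter_fake.mean(dim=0), inter_real.detach().mean(dim=0))

optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()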

Basically, that’s it. Here are the experimental results.

Feature Matching

The quality of the generated images is fine, and the losses of both the generator and the discriminator look good. This simple technique nicely improves GAN performance.

However, one thing to note is that the generator loss started rising at some point and never came back down. I frequently see this kind of behavior when training GANs, and it makes it hard to decide when to stop training. I’ll study this further.

loss_DCGAN_FM_300

In this post, I introduced feature matching and confirmed that this simple approach actually solved the mode collapse problem in my experiment. I’m going to implement the other known techniques later.


References

[1] T. Salimans, I. Goodfellow, W. Zaremba, V. Cheung, A. Radford, and X. Chen. “Improved Techniques for Training GANs.” arXiv:1606.03498, 2016.