Generative Adversarial Networks

  • Deep neural networks are used mainly for supervised learning tasks such as classification and regression. Generative Adversarial Networks, or GANs, use neural networks for a very different purpose: generative modeling. Generative modeling is an unsupervised learning task that involves automatically discovering and learning the regularities or patterns in input data, in such a way that the model can be used to generate new examples that plausibly could have been drawn from the original dataset.
  • While there are many approaches to generative modeling, a Generative Adversarial Network takes the following approach:
    • There are two neural networks: a Generator and a Discriminator.
    • The generator generates a "fake" sample given a random noise vector.
    • The discriminator attempts to detect whether a given sample is "real" (picked from the training data) or "fake" (generated by the generator).
    • Training happens in tandem: we train the discriminator for one or more steps, then train the generator, and repeat (a minimal sketch of this alternation follows this list). This way both the generator and the discriminator get better at doing their jobs.
  • In this tutorial, we'll train a GAN to generate images of handwritten digits. We'll use the MNIST database of handwritten digits, which is freely available for download, as the training data, and try to generate new MNIST-like handwritten digits with the GAN.
  • We'll use the PyTorch library, which is developed mainly by Meta (formerly known as Facebook).
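
  • To make the alternating training procedure concrete, here's a minimal sketch of one pass over the data (pseudocode only; the train_discriminator and train_generator functions are implemented step by step later in this tutorial).

    # One pass over the data (sketch only; the real implementation is built below)
    for real_images, _ in data_loader:
        train_discriminator(real_images)   # update D to tell real images from fakes
        train_generator()                  # update G to fool D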

Load the Data

  • We begin by downloading and importing the data as a PyTorch dataset using the MNIST helper class from torchvision.datasets.

    import torch
    import torchvision
    from torchvision.transforms import ToTensor, Normalize, Compose
    from torchvision.datasets import MNIST
    
    mnist = MNIST(root='data', 
                  train=True, 
                  download=True,
                  transform=Compose([ToTensor(), Normalize(mean=(0.5,), std=(0.5,))]))
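
  • As a quick check (just for illustration), the training split contains 60,000 images.

    print(len(mnist))   # 60000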
    
  • Note that we are transforming the pixel values from the range [0, 1] to the range [-1, 1]. The reason for doing this will become clear when we define the generator network. Let's look at a sample tensor from the data.

    img, label = mnist[0]
    print('Label: ', label)
    print(img[:,10:15,10:15])
    torch.min(img), torch.max(img)
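
  • As a quick sanity check (a small illustrative snippet, not part of the training pipeline), the Normalize((0.5,), (0.5,)) transform maps every pixel value x in [0, 1] to (x - 0.5) / 0.5, which lies in [-1, 1].

    x = torch.tensor([0.0, 0.5, 1.0])
    print((x - 0.5) / 0.5)   # tensor([-1.,  0.,  1.])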
    
  • Let's plot one of the images.

    import matplotlib.pyplot as plt
    %matplotlib inline
    
    plt.imshow(img[0], cmap='gray')
    plt.show()
    print('Label:', label)
    
  • Finally, let's create a dataloader to load the images in batches.

    from torch.utils.data import DataLoader
    
    batch_size = 100
    data_loader = DataLoader(mnist, batch_size, shuffle=True)
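
  • As a quick check (illustrative only), each batch from the data loader contains 100 images of shape 1x28x28, along with their labels.

    for images, labels in data_loader:
        print(images.shape)   # torch.Size([100, 1, 28, 28])
        print(labels.shape)   # torch.Size([100])
        break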
    
  • We'll also create a device which can be used to move the data and models to a graphics processing unit (GPU), if one is available, since training on a GPU is considerably faster than on a CPU.

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    

Discriminator Network

  • The discriminator takes an image as input and tries to classify it as "real" or "generated". In this sense, it's like any other neural network. While we could use a more complicated network, we'll use a simple one with 3 linear layers.
  • Since each MNIST image is 28 pixels wide and 28 pixels high, it contains 28x28 = 784 pixels. We'll flatten each image into a vector of size 784 before feeding it to the discriminator.

    image_size = 784
    hidden_size = 256
    import torch.nn as nn
    
    D = nn.Sequential(
        nn.Linear(image_size, hidden_size),
        nn.LeakyReLU(0.2),
        nn.Linear(hidden_size, hidden_size),
        nn.LeakyReLU(0.2),
        nn.Linear(hidden_size, 1),
        nn.Sigmoid())
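
  • As a quick check (an illustrative snippet, not part of training), we can flatten the sample image loaded earlier into a vector of size 784 and pass it through the discriminator; the output is a single score per image.

    img_flat = img.reshape(-1, image_size)   # shape: [1, 784]
    print(D(img_flat).shape)                 # torch.Size([1, 1])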
    
  • We use the Leaky ReLU activation for the discriminator. Unlike the regular ReLU function, Leaky ReLU allows a small gradient signal to pass through for negative inputs: instead of a slope of 0, it uses a small positive slope (0.2 here) for negative values in the backward pass. As a result, the gradients from the discriminator flow more strongly into the generator.
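
  • For instance (a small illustrative snippet), compare how ReLU and Leaky ReLU treat negative inputs.

    x = torch.tensor([-2.0, -1.0, 0.0, 1.0])
    print(nn.ReLU()(x))           # tensor([0., 0., 0., 1.])
    print(nn.LeakyReLU(0.2)(x))   # tensor([-0.4000, -0.2000,  0.0000,  1.0000])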

  • Just like any other binary classification model, the output of the discriminator is a single number between 0 and 1, which can be interpreted as the probability of the input image being real i.e. picked from the original MNIST dataset. An output close to 0 means the discriminator considers the image fake i.e. GAN-generated.

  • After defining the discriminator, let's move the model to the chosen device.

    D.to(device)
    

Generator Network

  • The input to the generator is typically a vector of random noise. Once again, to keep things simple, we'll use a neural network with 3 layers, and the output will be a vector of size 784, which can be transformed to a 28x28 pixel image.

    latent_size = 64
    G = nn.Sequential(
        nn.Linear(latent_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, image_size),
        nn.Tanh())
    
  • We use the hyperbolic tangent (tanh) activation function for the output layer of the generator. tanh outputs values in the range [-1, 1], which matches the range of the normalized training images; this is why we transformed the pixel values to [-1, 1] when loading the data.

  • Let's generate an output vector using the generator and view it as an image by transforming and denormalizing the output.

    y = G(torch.randn(2, latent_size))
    # Reshape to 28x28 images and denormalize from [-1, 1] back to [0, 1]
    gen_imgs = (y.reshape((-1, 28, 28)).detach() + 1) / 2
    plt.imshow(gen_imgs[0], cmap='gray')
    plt.show()
    
  • Now move the generator to the chosen device.

    G.to(device)
    

Discriminator Training

  • Since the discriminator is a binary classification model, we can use the binary cross entropy loss function to quantify how well it is able to differentiate between real and generated images.

    criterion = nn.BCELoss()
    d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
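
  • As a reminder (an illustrative computation, not part of the training code), binary cross entropy compares a predicted probability p against a target label y using -[y*log(p) + (1-y)*log(1-p)], averaged over the batch.

    p = torch.tensor([[0.9], [0.2]])   # example discriminator outputs
    y = torch.tensor([[1.0], [0.0]])   # target labels: real=1, fake=0
    print(nn.BCELoss()(p, y))          # tensor(0.1643), the mean of -log(0.9) and -log(0.8)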
    
  • Let's define helper functions to reset gradients and train the discriminator.

    def reset_grad():
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
    
    def train_discriminator(images):
        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
    
        # Loss for real images
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
    
        # Loss for fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
    
        # Combine losses
        d_loss = d_loss_real + d_loss_fake
        # Reset gradients
        reset_grad()
        # Compute gradients
        d_loss.backward()
        # Adjust the parameters using backprop
        d_optimizer.step()
    
        return d_loss, real_score, fake_score
    
  • Here are the steps involved in training the discriminator.

    1. We expect the discriminator to output 1 if the image was picked from the real MNIST dataset, and 0 if it was generated.
    2. We first pass a batch of real images, and compute the loss, setting the target labels to 1.
    3. Then we generate a batch of fake images using the generator, pass them into the discriminator, and compute the loss, setting the target labels to 0 (fake).
    4. Finally we add the two losses and use the overall loss to perform gradient descent to adjust the weights of the discriminator.
  • It's important to note that we don't change the weights of the generator model while training the discriminator, since d_optimizer.step() only updates D.parameters().
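
  • A common optional tweak (not used in the code above) is to detach the fake images before passing them to the discriminator, so that d_loss.backward() doesn't spend time computing gradients for the generator's parameters. The resulting parameter updates are identical, because reset_grad() clears those gradients before the generator is trained anyway.

    # Optional variant inside train_discriminator:
    outputs = D(fake_images.detach())   # blocks gradients from flowing into G
    d_loss_fake = criterion(outputs, fake_labels)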

Generator Training

  • Since the outputs of the generator are images, it's not obvious how we can train the generator. This is where we employ a rather elegant trick, which is to use the discriminator as a part of the loss function.
  • Here's how it works: we generate a batch of images using the generator, and pass them into the discriminator.
  • We calculate the loss by setting the target labels to 1 i.e. real. We do this because the generator's objective is to "fool" the discriminator.
  • We use the loss to perform gradient descent i.e. change the weights of the generator, so it gets better at generating real-like images.

  • Here's what this looks like in code.

    g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
    def train_generator():
        # Generate fake images and calculate loss
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        labels = torch.ones(batch_size, 1).to(device)
        g_loss = criterion(D(fake_images), labels)
    
        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        return g_loss, fake_images
    

Training the Model

  • Let's create a directory where we can save intermediate outputs from the generator, to visually inspect the progress of the model.

    import os
    
    sample_dir = 'samples'
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)
    
  • Let's save a batch of real images that we can use for visual comparison while looking at the generated images.

    from IPython.display import Image
    from torchvision.utils import save_image
    
    # Save some real images
    for images, _ in data_loader:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(images, os.path.join(sample_dir, 'real_images.png'), nrow=10)
        break
    
    Image(os.path.join(sample_dir, 'real_images.png'))
    
  • We'll also define a helper function to save a batch of generated images to disk at the end of every epoch. We'll use a fixed set of input vectors to the generator to see how the individual generated images evolve over time as we train the model.

    sample_vectors = torch.randn(batch_size, latent_size).to(device)
    
    def save_fake_images(index):
        fake_images = G(sample_vectors)
        fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
        fake_fname = 'fake_images-{0:0=4d}.png'.format(index)
        print('Saving', fake_fname)
        save_image(fake_images, os.path.join(sample_dir, fake_fname), nrow=10)
    
    # Before training
    save_fake_images(0)
    Image(os.path.join(sample_dir, 'fake_images-0000.png'))
    
  • We are now ready to train the model. In each iteration of the training loop, we first train the discriminator on a batch of real images, and then train the generator.

  • The training might take a while if you're not using a GPU.

    num_epochs = 20
    total_step = len(data_loader)
    d_losses, g_losses, real_scores, fake_scores = [], [], [], []
    
    for epoch in range(num_epochs):
        for i, (images, _) in enumerate(data_loader):
            # Load a batch & transform to vectors
            images = images.reshape(batch_size, -1).to(device)
    
            # Train the discriminator and generator
            d_loss, real_score, fake_score = train_discriminator(images)
            g_loss, fake_images = train_generator()
    
            # Inspect the losses
            if (i+1) % 200 == 0:
                d_losses.append(d_loss.item())
                g_losses.append(g_loss.item())
                real_scores.append(real_score.mean().item())
                fake_scores.append(fake_score.mean().item())
                print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                      .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                              real_score.mean().item(), fake_score.mean().item()))
    
        # Sample and save images
        if (epoch+1) % 5 == 0:
            save_fake_images(epoch+1)
    
  • Here's how the generated images look after the 5th, 10th, and 20th epochs of training.

    Image('./samples/fake_images-0005.png')
    Image('./samples/fake_images-0010.png')
    Image('./samples/fake_images-0020.png')
    
  • We can also visualize how the loss changes over time. Visualizing losses is quite useful for debugging the training process. For GANs, we expect the generator's loss to reduce over time, without the discriminator's loss getting too high.

    plt.plot(d_losses, '-')
    plt.plot(g_losses, '-')
    plt.xlabel('logging step')
    plt.ylabel('loss')
    plt.legend(['Discriminator', 'Generator'])
    plt.title('Losses')
    plt.show()
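
  • We also recorded the discriminator's average scores for real and generated images, D(x) and D(G(z)), during training. Plotting them (a small optional snippet) is another helpful diagnostic: if the discriminator completely overpowers the generator, D(G(z)) collapses toward 0, while at equilibrium both scores tend toward 0.5.

    plt.plot(real_scores, '-')
    plt.plot(fake_scores, '-')
    plt.xlabel('logging step')
    plt.ylabel('score')
    plt.legend(['Real score D(x)', 'Fake score D(G(z))'])
    plt.title('Scores')
    plt.show()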
    
