Generative Adversarial Networks (GANs) have become wildly popular in the artificial intelligence field, particularly for their ability to generate new, realistic data from existing datasets. In this blog post, we'll focus on a specific type of GAN - the Wasserstein GAN (WGAN). The WGAN, introduced in 2017 by researchers Martin Arjovsky, Soumith Chintala, and Léon Bottou, addresses some of the issues found in the standard GAN model, such as the vanishing gradient problem and mode collapse.
The primary difference between a WGAN and a standard GAN lies in the choice of loss function. Instead of the Jensen-Shannon-based loss of the original GAN, WGANs minimize (an approximation of) the Wasserstein-1 distance (also known as the Earth Mover's distance) between the real and generated distributions, which leads to more stable training and a loss that correlates with sample quality in practice.
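In its dual (Kantorovich-Rubinstein) form, the WGAN objective can be written as:

\min_G \max_{\|f\|_L \le 1} \; \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] - \mathbb{E}_{z \sim p(z)}[f(G(z))]

Here f is the critic (the WGAN's name for the discriminator), constrained to be 1-Lipschitz. The original paper enforces this constraint by clipping the critic's weights to a small range, which is exactly what we will do below.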
In this blog post, we'll walk through the key concepts behind Wasserstein GANs, implement a WGAN using PyTorch, and share some code snippets to help you get started quickly.
This tutorial assumes that you have some basic experience with Python and machine learning concepts. You should also be familiar with the basics of GANs and PyTorch. If you're not familiar with these topics, here is a great introductory guide to GANs, and here is a good resource to learn PyTorch.
Before we start implementing the WGAN, let's import the necessary libraries:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
Next, we define the hyperparameters that we'll use throughout our implementation:
batch_size = 64
epochs = 100
latent_dim = 100
lr = 1e-4
weight_clip = 0.01
disc_updates = 5
Here, batch_size, epochs, and lr represent the batch size, number of training epochs, and learning rate, respectively. latent_dim is the size of the latent space from which we sample random noise to generate images. weight_clip is the value at which the discriminator's weights are clamped after each update; this keeps the weights in a compact space and (roughly) enforces the Lipschitz constraint required by the Wasserstein loss. Finally, disc_updates is the number of discriminator updates performed for each generator update.
Now, let's define the architectures for the generator and the discriminator. We will use simple feedforward neural networks with linear layers and LeakyReLU activations. Note that, unlike in a standard GAN, the discriminator (called the critic in the WGAN paper) has no sigmoid at the end, because it outputs an unbounded real-valued score rather than a probability:
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, output_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.layers(x)


class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        return self.layers(x)
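As a quick sanity check (a minimal sketch using the hyperparameters defined above), we can push a batch of random noise through both networks and confirm the shapes:

# Random noise -> fake images -> critic scores
z = torch.randn(batch_size, latent_dim)          # (64, 100)
fake = Generator(latent_dim, 28*28)(z)           # (64, 784), values in [-1, 1] from the Tanh
score = Discriminator(28*28)(fake)               # (64, 1), unbounded real-valued scores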
Now we need to prepare our dataset and create DataLoader objects that will handle the batching of our data. For this example, we will use the MNIST dataset:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

mnist_data = datasets.MNIST("data", train=True, download=True, transform=transform)
dataloader = DataLoader(mnist_data, batch_size=batch_size, shuffle=True, num_workers=4)
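Because the generator ends in a Tanh, the real images must live in the same [-1, 1] range, which is exactly what Normalize([0.5], [0.5]) achieves. A quick check on a single batch (a minimal sketch) confirms this:

# Inspect one batch: shapes and pixel range after normalization
images, _ = next(iter(dataloader))
print(images.shape)                               # torch.Size([64, 1, 28, 28])
print(images.min().item(), images.max().item())   # roughly -1.0 and 1.0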
Finally, let's create our generator and discriminator objects, define our optimizers, and train the Wasserstein GAN:
generator = Generator(latent_dim, 28*28).cuda()
discriminator = Discriminator(28*28).cuda()

optimizer_G = optim.RMSprop(generator.parameters(), lr=lr)
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=lr)

for epoch in range(epochs):
    for i, (data, _) in enumerate(dataloader):
        # Flatten images to vectors; use data.size(0) because the last batch may be smaller than batch_size
        real_data = data.view(data.size(0), -1).cuda()

        # Train the discriminator (critic) disc_updates times per generator step
        for _ in range(disc_updates):
            z = torch.randn(real_data.size(0), latent_dim).cuda()
            fake_data = generator(z)
            # Maximize D(real) - D(fake) by minimizing its negation
            disc_loss = -(torch.mean(discriminator(real_data)) - torch.mean(discriminator(fake_data.detach())))

            optimizer_D.zero_grad()
            disc_loss.backward()
            optimizer_D.step()

            # Clip discriminator weights so they stay within a compact space
            for p in discriminator.parameters():
                p.data.clamp_(-weight_clip, weight_clip)

        # Train the generator: maximize D(G(z)) by minimizing -D(G(z))
        gen_data = generator(z)
        gen_loss = -torch.mean(discriminator(gen_data))

        optimizer_G.zero_grad()
        gen_loss.backward()
        optimizer_G.step()

    print(f"Epoch {epoch+1}/{epochs}")
That's it! With this code, you'll be able to train a Wasserstein GAN on the MNIST dataset and generate realistic handwritten digits. Note that there are several other hyperparameters and architecture choices that you can tune to achieve better results. Feel free to experiment and see what works best for your problem.
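To actually look at the results, you can sample the trained generator and save an image grid to disk, for example with torchvision's save_image utility (a minimal sketch; the file name is arbitrary):

from torchvision.utils import save_image

# Sample noise, generate digits, reshape to image tensors, and save an 8x8 grid
with torch.no_grad():
    z = torch.randn(64, latent_dim).cuda()
    samples = generator(z).view(-1, 1, 28, 28)
save_image(samples, "wgan_samples.png", nrow=8, normalize=True)  # normalize=True rescales [-1, 1] to [0, 1]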
In this blog post, we covered the basics of Wasserstein GANs, including their motivation, differences from standard GANs, and a PyTorch implementation for training a WGAN on the MNIST dataset. We hope this information helps you better understand how Wasserstein GANs work and inspires you to apply them to your projects. Happy coding!