Variational Autoencoders


In this tutorial, we use PyBlaze to train a variational autoencoder (VAE). More precisely, we want to generate handwritten digits that resemble those in the MNIST dataset.

Later on, we will repeat the same tutorial and train a Wasserstein GAN instead of a VAE.

Note: This tutorial currently lacks explanations and theory. These will be added in the future.

[1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as D
import torchvision
import torchvision.transforms as transforms
import pyblaze.nn as xnn
import pyblaze.nn.functional as X
import matplotlib.pyplot as plt

%reload_ext autoreload
%autoreload 2

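# 'seaborn-notebook' was renamed to 'seaborn-v0_8-notebook' in Matplotlib 3.6
plt.style.use('seaborn-v0_8-notebook' if 'seaborn-v0_8-notebook' in plt.style.available else 'seaborn-notebook')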
plt.rcParams['figure.dpi'] = 150

Loading the Data

First, we load the data. As before, torchvision makes our life easier:

[2]:
train_val_dataset = torchvision.datasets.MNIST(
    root="~/Downloads/", train=True, download=True, transform=transforms.ToTensor()
)
test_dataset = torchvision.datasets.MNIST(
    root="~/Downloads/", train=False, download=True, transform=transforms.ToTensor()
)

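As in the previous tutorial, the data loaders are easy to set up. Note that random_split and loader are not part of PyTorch's Dataset API; they are convenience methods that PyBlaze makes available on datasets: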

[3]:
train_dataset, val_dataset = train_val_dataset.random_split(0.8, 0.2)
train_loader = train_dataset.loader(batch_size=256, num_workers=4, shuffle=True)
val_loader = val_dataset.loader(batch_size=2048)
test_loader = test_dataset.loader(batch_size=2048)
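If you want to reproduce this step without PyBlaze's dataset helpers, the following is a rough plain-PyTorch sketch built on `torch.utils.data.random_split` and `DataLoader`. The helper name `make_loaders`, the fixed seed, and the rounding of the 80/20 split are our own choices, and `num_workers` is omitted for brevity; PyBlaze's `random_split(0.8, 0.2)` may differ in its details.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def make_loaders(dataset, train_frac=0.8, train_batch=256, eval_batch=2048, seed=0):
    """Hypothetical helper: split `dataset` into train/val and wrap both in DataLoaders."""
    n_train = int(round(train_frac * len(dataset)))
    n_val = len(dataset) - n_train
    generator = torch.Generator().manual_seed(seed)  # makes the split reproducible
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=generator)
    train_loader = DataLoader(train_set, batch_size=train_batch, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=eval_batch)
    return train_set, val_set, train_loader, val_loader


# Usage with a small random stand-in dataset (1,000 fake 28x28 images with labels)
fake = TensorDataset(torch.rand(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))
train_set, val_set, train_loader, val_loader = make_loaders(fake)
print(len(train_set), len(val_set))  # 800 200
```

Applied to the 60,000 MNIST training images, the same call would yield 48,000 training and 12,000 validation images.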

Before setting up the model, let's look at a few randomly sampled images from the training data:

[4]:
plt.figure()

images = [train_dataset[i] for i in np.random.choice(len(train_dataset), 10, replace=False)]
for i, (image, _) in enumerate(images):
    plt.subplot(1, 10, i+1)
    plt.imshow(image[0], cmap='binary')
    plt.axis('off')

plt.show()
[Figure: ten randomly sampled MNIST training images]

Defining the Model

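Having looked at the data, we can now define the model. The encoder uses convolutional layers to map each image to the mean and log-variance of a 16-dimensional Gaussian over the latent space, and the decoder uses transposed convolutions to scale a latent code back up to a 28×28 image.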

[5]:
class Encoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3),
            nn.ReLU(),
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc_mu = nn.Linear(576, 16)
        self.fc_logvar = nn.Linear(576, 16)

    def forward(self, x):
        z = self.conv(x)
        z = z.view(z.size(0), -1)
        return self.fc_mu(z), self.fc_logvar(z)


class Decoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 2048)
        self.conv = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 5, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 6)
        )

    def forward(self, x):
        z = self.fc(x)
        z = z.view(-1, 128, 4, 4)
        return self.conv(z)


class VAE(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        mu, logvar = self.encoder(x)
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)
        z = mu + eps * std
        return self.decoder(z), mu, logvar
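The forward method of VAE uses the reparameterization trick: instead of sampling z directly from N(mu, sigma^2), it samples eps from a standard normal and computes z = mu + eps * sigma with sigma = exp(0.5 * logvar). The standalone sketch below (the function name `reparameterize` is ours) isolates that step and shows that the sample stays differentiable with respect to both `mu` and `logvar`.

```python
import torch


def reparameterize(mu, logvar):
    """Draw z ~ N(mu, exp(logvar)) as a differentiable function of mu and logvar."""
    std = (0.5 * logvar).exp()   # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)  # noise from N(0, I), independent of the parameters
    return mu + eps * std


torch.manual_seed(0)
mu = torch.zeros(4, 16, requires_grad=True)
logvar = torch.zeros(4, 16, requires_grad=True)
z = reparameterize(mu, logvar)
z.sum().backward()               # gradients flow back through the sampling step
print(z.shape, mu.grad.sum().item())  # torch.Size([4, 16]) 64.0
```

Because the randomness is confined to eps, backpropagation treats the sample as an ordinary function of the encoder outputs.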

Having defined the model, we can initialize it and check how many parameters it has:

[6]:
model = VAE()

print(f'Total parameters:   {sum(p.numel() for p in model.parameters()):6,}')
print(f'Encoder parameters: {sum(p.numel() for p in model.encoder.parameters()):6,}')
print(f'Decoder parameters: {sum(p.numel() for p in model.decoder.parameters()):6,}')
Total parameters:   464,577
Encoder parameters: 172,512
Decoder parameters: 292,065
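These numbers can be checked by hand. The sketch below recomputes the spatial sizes from the standard output-size formulas for Conv2d, MaxPool2d, and ConvTranspose2d (dilation 1, no output padding), which also explains the hard-coded 576 in the encoder and the 128 * 4 * 4 = 2048 reshape in the decoder. The helper names are ours.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Conv2d / MaxPool2d output size (dilation 1, floor division)
    return (size + 2 * padding - kernel) // stride + 1

def convT_out(size, kernel, stride=1, padding=0):
    # ConvTranspose2d output size (dilation 1, output_padding 0)
    return (size - 1) * stride - 2 * padding + kernel

def conv_params(c_in, c_out, kernel):
    # weights + biases of a square-kernel (transposed) convolution
    return c_in * c_out * kernel * kernel + c_out

def linear_params(n_in, n_out):
    return n_in * n_out + n_out

# Encoder spatial size: 28 -> 26 -> 22 -> 11 -> 7 -> 3
s = conv_out(28, 3)
s = conv_out(s, 5)
s = conv_out(s, 2, stride=2)  # MaxPool2d(2) uses stride = kernel size
s = conv_out(s, 5)
s = conv_out(s, 2, stride=2)
flat = 64 * s * s             # 576 features fed to fc_mu and fc_logvar

# Decoder spatial size: 4 -> 11 -> 23 -> 28
t = convT_out(4, 5, stride=2)
t = convT_out(t, 5, stride=2, padding=1)
t = convT_out(t, 6)

enc = (conv_params(1, 32, 3) + conv_params(32, 64, 5) + conv_params(64, 64, 5)
       + 2 * linear_params(flat, 16))
dec = (linear_params(16, 128 * 4 * 4) + conv_params(128, 64, 5)
       + conv_params(64, 32, 5) + conv_params(32, 1, 6))
print(flat, t, enc, dec, enc + dec)  # 576 28 172512 292065 464577
```

Most of the decoder's parameters sit in the first transposed convolution (204,864 of 292,065), because it maps 128 channels to 64 with 5×5 kernels.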

Training the Model

[7]:
optimizer = optim.Adam(model.parameters())
loss = xnn.VAELoss(nn.BCEWithLogitsLoss(reduction='none'))
engine = xnn.AutoencoderEngine(model, expects_data_target=True)
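xnn.VAELoss wraps the reconstruction loss passed to it. We have not reproduced PyBlaze's exact implementation here; as a point of reference, the sketch below computes the standard negative ELBO that a Gaussian VAE loss is usually built from: binary cross-entropy between decoder logits and pixels, summed over the pixels of each image, plus the closed-form KL divergence between N(mu, sigma^2) and N(0, I), averaged over the batch. The function name `vae_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F


def vae_loss(logits, target, mu, logvar):
    """Negative ELBO per example, averaged over the batch (standard Gaussian-VAE form)."""
    # Reconstruction term: per-pixel BCE on logits, summed over all pixels of each image
    rec = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
    rec = rec.flatten(start_dim=1).sum(dim=1)
    # KL(N(mu, sigma^2) || N(0, I)) in closed form, summed over latent dimensions
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)
    return (rec + kl).mean()


# Usage with random tensors shaped like a small MNIST batch and a 16-dim latent space
logits = torch.randn(8, 1, 28, 28)
target = torch.rand(8, 1, 28, 28)
mu, logvar = torch.randn(8, 16), torch.randn(8, 16)
print(vae_loss(logits, target, mu, logvar))
```

Summing rather than averaging over pixels keeps the reconstruction and KL terms on a comparable scale; the resulting values per image are in the same range as the losses logged below.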
[8]:
history = engine.train(
    train_loader,
    val_data=val_loader,
    epochs=50,
    eval_every=5,
    optimizer=optimizer,
    loss=loss,
    callbacks=[
        xnn.BatchProgressLogger()
    ]
)
Epoch 1/50:
 [Elapsed 0:00:05 | 33.81 it/s] loss: 224.27307, val_loss: 159.77618
Epoch 2/50:
 [Elapsed 0:00:03 | 48.45 it/s] loss: 136.16475
Epoch 3/50:
 [Elapsed 0:00:03 | 47.73 it/s] loss: 118.60185
Epoch 4/50:
 [Elapsed 0:00:03 | 47.37 it/s] loss: 113.22937
Epoch 5/50:
 [Elapsed 0:00:03 | 48.84 it/s] loss: 110.84675
Epoch 6/50:
 [Elapsed 0:00:05 | 33.62 it/s] loss: 109.05308, val_loss: 108.66338
Epoch 7/50:
 [Elapsed 0:00:03 | 48.16 it/s] loss: 107.91442
Epoch 8/50:
 [Elapsed 0:00:03 | 48.23 it/s] loss: 106.88257
Epoch 9/50:
 [Elapsed 0:00:03 | 48.17 it/s] loss: 106.14308
Epoch 10/50:
 [Elapsed 0:00:04 | 46.60 it/s] loss: 105.65662
Epoch 11/50:
 [Elapsed 0:00:05 | 32.55 it/s] loss: 104.96011, val_loss: 105.40081
Epoch 12/50:
 [Elapsed 0:00:03 | 47.57 it/s] loss: 104.49724
Epoch 13/50:
 [Elapsed 0:00:03 | 48.58 it/s] loss: 104.16228
Epoch 14/50:
 [Elapsed 0:00:04 | 46.98 it/s] loss: 103.70670
Epoch 15/50:
 [Elapsed 0:00:03 | 47.52 it/s] loss: 103.35032
Epoch 16/50:
 [Elapsed 0:00:05 | 32.94 it/s] loss: 102.94538, val_loss: 104.24427
Epoch 17/50:
 [Elapsed 0:00:03 | 47.96 it/s] loss: 102.75075
Epoch 18/50:
 [Elapsed 0:00:03 | 47.43 it/s] loss: 102.43628
Epoch 19/50:
 [Elapsed 0:00:04 | 46.99 it/s] loss: 102.14312
Epoch 20/50:
 [Elapsed 0:00:03 | 48.86 it/s] loss: 102.00880
Epoch 21/50:
 [Elapsed 0:00:05 | 34.45 it/s] loss: 101.70227, val_loss: 102.42951
Epoch 22/50:
 [Elapsed 0:00:03 | 47.53 it/s] loss: 101.51430
Epoch 23/50:
 [Elapsed 0:00:03 | 47.14 it/s] loss: 101.41744
Epoch 24/50:
 [Elapsed 0:00:03 | 48.21 it/s] loss: 101.11735
Epoch 25/50:
 [Elapsed 0:00:04 | 46.84 it/s] loss: 100.98458
Epoch 26/50:
 [Elapsed 0:00:05 | 34.03 it/s] loss: 100.77367, val_loss: 101.82356
Epoch 27/50:
 [Elapsed 0:00:03 | 47.59 it/s] loss: 100.74920
Epoch 28/50:
 [Elapsed 0:00:03 | 47.90 it/s] loss: 100.51038
Epoch 29/50:
 [Elapsed 0:00:03 | 47.73 it/s] loss: 100.34299
Epoch 30/50:
 [Elapsed 0:00:03 | 47.22 it/s] loss: 100.36984
Epoch 31/50:
 [Elapsed 0:00:05 | 34.22 it/s] loss: 100.08789, val_loss: 101.58796
Epoch 32/50:
 [Elapsed 0:00:03 | 47.51 it/s] loss: 99.98875
Epoch 33/50:
 [Elapsed 0:00:03 | 47.47 it/s] loss: 99.85116
Epoch 34/50:
 [Elapsed 0:00:03 | 47.31 it/s] loss: 99.75701
Epoch 35/50:
 [Elapsed 0:00:03 | 48.25 it/s] loss: 99.61550
Epoch 36/50:
 [Elapsed 0:00:05 | 33.42 it/s] loss: 99.44988, val_loss: 100.72999
Epoch 37/50:
 [Elapsed 0:00:03 | 47.57 it/s] loss: 99.42586
Epoch 38/50:
 [Elapsed 0:00:03 | 47.50 it/s] loss: 99.36446
Epoch 39/50:
 [Elapsed 0:00:03 | 47.76 it/s] loss: 99.26171
Epoch 40/50:
 [Elapsed 0:00:03 | 48.82 it/s] loss: 99.17992
Epoch 41/50:
 [Elapsed 0:00:05 | 33.88 it/s] loss: 99.09690, val_loss: 100.31133
Epoch 42/50:
 [Elapsed 0:00:03 | 47.69 it/s] loss: 98.97497
Epoch 43/50:
 [Elapsed 0:00:03 | 49.46 it/s] loss: 98.90181
Epoch 44/50:
 [Elapsed 0:00:03 | 48.44 it/s] loss: 98.79760
Epoch 45/50:
 [Elapsed 0:00:03 | 48.53 it/s] loss: 98.63344
Epoch 46/50:
 [Elapsed 0:00:05 | 33.40 it/s] loss: 98.56735, val_loss: 100.23586
Epoch 47/50:
 [Elapsed 0:00:04 | 46.33 it/s] loss: 98.47528
Epoch 48/50:
 [Elapsed 0:00:04 | 45.59 it/s] loss: 98.47498
Epoch 49/50:
 [Elapsed 0:00:03 | 47.31 it/s] loss: 98.40941
Epoch 50/50:
 [Elapsed 0:00:05 | 33.68 it/s] loss: 98.32704, val_loss: 100.99587

Inspecting the Model

Reconstructing Images

[9]:
reconstructed = engine.predict(
    test_loader,
    reconstruct=True
)
[10]:
plt.figure(dpi=150)

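images, _ = next(iter(test_loader))  # first test batch; test_loader is not shuffled, so it lines up with the predictions
for i in range(10):
    plt.subplot(1, 10, i+1)
    # Original on top, reconstruction (decoder logits -> sigmoid) below
    plt.imshow(np.concatenate([
        images[i][0].numpy(),
        reconstructed[0][i].sigmoid().numpy().reshape(28, 28)
    ]), cmap='binary')
    plt.axis('off')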

plt.show()
[Figure: ten test images (top row) and their reconstructions (bottom row)]
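The decoder outputs logits rather than pixel intensities (which is why the loss uses BCEWithLogitsLoss), so the plotting code applies sigmoid before display and stacks each original above its reconstruction. The helper below (its name is ours) isolates that step for a single image.

```python
import numpy as np
import torch


def stack_with_reconstruction(original, logits):
    """Return a (56, 28) array: the original image on top, its reconstruction below."""
    recon = torch.sigmoid(logits).reshape(28, 28).numpy()  # logits -> intensities in [0, 1]
    return np.concatenate([original.reshape(28, 28).numpy(), recon], axis=0)


panel = stack_with_reconstruction(torch.rand(1, 28, 28), torch.randn(1, 28, 28))
print(panel.shape)  # (56, 28)
```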

Generating Images

[11]:
dim = 16
distribution = D.Normal(torch.zeros(dim), torch.ones(dim))
dist_data = xnn.NoiseDataset(distribution)
[12]:
generated = engine.predict(
    dist_data.loader(batch_size=10),
    iterations=1,
    reconstruct=False
)
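Generating images amounts to drawing latent codes from the standard normal prior and passing them through the decoder alone; the encoder is not involved. If you want to do this without NoiseDataset and engine.predict, a plain-PyTorch sketch follows. The function `sample_images` is hypothetical, and the single linear layer in the usage example is only a stand-in for model.decoder.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def sample_images(decoder, n=10, dim=16):
    """Decode n latent codes drawn from N(0, I) into n 28x28 images with values in [0, 1]."""
    z = torch.randn(n, dim)  # samples from the prior
    logits = decoder(z)      # the decoder returns logits
    return torch.sigmoid(logits).reshape(n, 28, 28)


# Stand-in decoder for illustration; with the trained model, pass model.decoder instead.
stand_in = nn.Sequential(nn.Linear(16, 28 * 28))
samples = sample_images(stand_in)
print(samples.shape)  # torch.Size([10, 28, 28])
```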
[13]:
plt.figure()

for i in range(10):
    plt.subplot(1, 10, i+1)
    plt.imshow(generated[i].sigmoid().numpy().reshape(28, 28), cmap='binary')
    plt.axis('off')

plt.show()
[Figure: ten digits generated from random latent codes]