
Variational Autoencoder (VAE) in PyTorch

This post should be quick, as it is just a port of the previous Keras code. For the intuition and derivation of the Variational Autoencoder (VAE), plus the Keras implementation, check this post. The full code is available in my GitHub repo: https://github.com/wiseodd/generative-models.

The networks

Let’s begin by importing the stuff we need.

import torch
import torch.nn.functional as nn  # note: the functional API is aliased as "nn" here (used for relu, sigmoid, BCE below)
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data


mnist = input_data.read_data_sets('../MNIST_data', one_hot=True)  # TensorFlow's helper is used only to load and batch MNIST
mb_size = 64                          # minibatch size
Z_dim = 100                           # dimensionality of the latent code z
X_dim = mnist.train.images.shape[1]   # 784, i.e. flattened 28x28 images
y_dim = mnist.train.labels.shape[1]   # 10, one-hot labels (not used in the plain VAE below)
h_dim = 128                           # hidden layer size
c = 0                                 # counter for output figures saved during training
lr = 1e-3                             # learning rate

Now, recall that in a VAE there are two networks: the encoder \( Q(z \vert X) \) and the decoder \( P(X \vert z) \). So, let’s build our \( Q(z \vert X) \) first:

def xavier_init(size):
    # Gaussian init with std = sqrt(2 / fan_in), returned as a trainable Variable
    in_dim = size[0]
    xavier_stddev = 1. / np.sqrt(in_dim / 2.)
    return Variable(torch.randn(*size) * xavier_stddev, requires_grad=True)


Wxh = xavier_init(size=[X_dim, h_dim])
bxh = Variable(torch.zeros(h_dim), requires_grad=True)

Whz_mu = xavier_init(size=[h_dim, Z_dim])
bhz_mu = Variable(torch.zeros(Z_dim), requires_grad=True)

Whz_var = xavier_init(size=[h_dim, Z_dim])
bhz_var = Variable(torch.zeros(Z_dim), requires_grad=True)


def Q(X):
    h = nn.relu(X @ Wxh + bxh.repeat(X.size(0), 1))
    z_mu = h @ Whz_mu + bhz_mu.repeat(h.size(0), 1)
    z_var = h @ Whz_var + bhz_var.repeat(h.size(0), 1)
    return z_mu, z_var

Our \( Q(z \vert X) \) is a two-layer net that outputs \( \mu \) and \( \log \sigma^2 \), the parameters of the encoded Gaussian distribution (the variable z_var above holds the log-variance, not \( \Sigma \) itself). So, let’s create a function to sample from it:

def sample_z(mu, log_var):
    # Reparameterization trick: draw eps ~ N(0, I), then shift and scale by mu and sigma
    eps = Variable(torch.randn(mb_size, Z_dim))
    return mu + torch.exp(log_var / 2) * eps
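
Concretely, instead of sampling \( z \) directly from \( \mathcal{N}(\mu, \mathrm{diag}(\sigma^2)) \), we draw \( \epsilon \sim \mathcal{N}(0, I) \) and compute \( z = \mu + \sigma \odot \epsilon \) with \( \sigma = \exp(\frac{1}{2} \log \sigma^2) \), which is exactly what the code above does. This keeps the sampling step differentiable with respect to \( \mu \) and \( \log \sigma^2 \), so gradients can flow back into the encoder.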

Let’s construct the decoder \( P(X \vert z) \), which is also a two-layer net:

Wzh = xavier_init(size=[Z_dim, h_dim])
bzh = Variable(torch.zeros(h_dim), requires_grad=True)

Whx = xavier_init(size=[h_dim, X_dim])
bhx = Variable(torch.zeros(X_dim), requires_grad=True)


def P(z):
    h = nn.relu(z @ Wzh + bzh.repeat(z.size(0), 1))
    X = nn.sigmoid(h @ Whx + bhx.repeat(h.size(0), 1))
    return X

Note: the biases are expanded with b.repeat(X.size(0), 1) because of this PyTorch issue; broadcasting the bias addition was not yet supported at the time of writing.

Training

Now, the interesting part: training the VAE model. As always, each training step consists of a forward pass, a loss computation, a backward pass, and a parameter update.

params = [Wxh, bxh, Whz_mu, bhz_mu, Whz_var, bhz_var,
          Wzh, bzh, Whx, bhx]

solver = optim.Adam(params, lr=lr)
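
Note that because the networks are built from raw Variables rather than torch.nn modules, we gather the parameters into a list by hand and pass them directly to Adam; with an nn.Module-based model, one would call model.parameters() instead.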

for it in range(100000):
    X, _ = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X))

    # Forward
    # ...

    # Loss
    # ...

    # Backward
    # ...

    # Update
    # ...

    # Housekeeping: reset the gradients, since backward() accumulates them
    for p in params:
        p.grad.data.zero_()
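
As a small aside (not part of the original code): since every parameter is registered with the optimizer, the manual zeroing loop above could equivalently be replaced by a single call to solver.zero_grad().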

Now, the forward step:

    z_mu, z_var = Q(X)
    z = sample_z(z_mu, z_var)
    X_sample = P(z)

That is it; we just call the functions we defined before. Let’s continue with the loss, which consists of two parts: the reconstruction loss and the KL divergence between the encoded distribution and the prior:

    recon_loss = nn.binary_cross_entropy(X_sample, X, size_average=False)  # summed over batch and pixels
    kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var)     # KL(Q(z|X) || N(0, I)); z_var holds log sigma^2
    loss = recon_loss + kl_loss
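
For reference, the KL term is the closed-form KL divergence between the encoder's diagonal Gaussian and the standard normal prior: \( D_{KL}\left(\mathcal{N}(\mu, \sigma^2) \,\Vert\, \mathcal{N}(0, I)\right) = \frac{1}{2} \sum_j \left( \sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2 \right) \). Since z_var stores \( \log \sigma^2 \), the code recovers \( \sigma^2 \) via torch.exp(z_var), and the sum runs over both the latent dimensions and the minibatch.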

The backward and update steps are as easy as calling a function, since we use the autograd feature of PyTorch:

    # Backward
    loss.backward()

    # Update
    solver.step()

After that, we can inspect the loss, or occasionally visualize \( P(X \vert z) \) to check the training progress, as sketched below.
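
For example, here is a minimal sketch along these lines (the exact plotting code in the repo differs slightly, and the out/ directory name is an assumption here): every 1000 iterations inside the training loop, we decode noise drawn from the prior and save a grid of generated digits, which is also why matplotlib, gridspec, and os were imported above.

    # Sketch: run inside the training loop; the 'out/' directory is assumed
    if it % 1000 == 0:
        print('Iter: {}; Loss: {:.4}'.format(it, loss.data[0]))

        # Decode 16 latent codes drawn from the prior N(0, I)
        samples = P(Variable(torch.randn(16, Z_dim))).data.numpy()

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

        if not os.path.exists('out/'):
            os.makedirs('out/')

        plt.savefig('out/{}.png'.format(str(c).zfill(3)), bbox_inches='tight')
        c += 1
        plt.close(fig)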

The full code can be found here: https://github.com/wiseodd/generative-models.