Variational Autoencoders

Autoencoder is an unsupervised model - a deep neural network architecture - which consists of an encoder and a decoder
  • The encoder compresses the input into a lower-dimensional representation;
  • The decoder reconstructs the original input from that compressed representation.

The architecture is pretty simple: the number of neurons per layer decreases through the encoder part and then increases again through the decoder part.

Input image => Dense(256) => Dense(64) => Dense(2) => Dense(64) => Dense(256) => Output (reconstructed image)

As one might expect, the loss compares the input image/data with its reconstruction, as the "auto" (self-supervised) part of the name implies.
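
To make this concrete, here is a minimal sketch (not taken from the original Autoencoder post) of such a plain Autoencoder in Keras, assuming flattened 28x28 inputs and a mean-squared-error reconstruction loss:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense, Reshape

# A plain (non-variational) Autoencoder matching the layer sizes above;
# the loss compares each input image with its own reconstruction.
autoencoder = Sequential([
    Flatten(input_shape=(28, 28)),        # encoder
    Dense(256, activation='relu'),
    Dense(64, activation='relu'),
    Dense(2),                             # 2-dimensional bottleneck
    Dense(64, activation='relu'),         # decoder
    Dense(256, activation='relu'),
    Dense(28*28, activation='sigmoid'),
    Reshape((28, 28))
])
autoencoder.compile(optimizer='adam', loss='mse')
# autoencoder.fit(X_train, X_train, epochs=10, batch_size=32)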

Variational AutoEncoder (VAE) is the probabilistic twist on the Autoencoder: instead of the deterministic outputs of the encoder and decoder in an Autoencoder, both components of a VAE output probability distributions.

As in the Autoencoder, the loss of a VAE contains the reconstruction loss between the input data and the reconstructed data. In addition, it has a regularization term - the KL divergence between the posterior distribution (the output of the encoder) and a prior (usually a simple isotropic Gaussian).
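
In other words, for an input $x$ with latent code $z$, the VAE is trained to minimize the negative evidence lower bound (ELBO)

$$\mathcal{L}(x) = -\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big] + \mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big),$$

where the first term is the reconstruction loss and the second is the KL regularization term; this is exactly the loss we will assemble in code below.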

In this post, we go through an implementation of a VAE with TensorFlow, TensorFlow Probability, and Keras. The example below is from the Probabilistic Deep Learning with TensorFlow 2 course on Coursera, which, by the way, I highly recommend if you want to get familiar with the TensorFlow Probability module.

Import required packages


import tensorflow as tf
import tensorflow_probability as tfp
import seaborn as sns
import matplotlib
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Flatten, Dense, Reshape

print("tensorflow", tf.__version__)
print("tensorflow probability", tfp.__version__)
print("matplotlib", matplotlib.__version__)
print("numpy", np.__version__)
print("seaborn", sns.__version__)

tfd = tfp.distributions
tfpl = tfp.layers

tensorflow 2.8.0
tensorflow probability 0.14.0
matplotlib 3.8.0
numpy 1.26.0
seaborn 0.13.0


Fashion MNIST dataset

As we did in the Autoencoder post, we use the Fashion MNIST dataset from Zalando, a publicly traded German online retailer of shoes, fashion and beauty products active across Europe.

The dataset consists of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes. We don't use those labels; we only use the images, since the goal is to compress and reconstruct a given image. Let's get started.

# Fashion MNIST dataset

(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# If using Bernoulli, we can scale by simply dividing by the max value
#X_train = X_train.astype('float32')/256.
#X_test = X_test.astype('float32')/256.
# If we use the Beta distribution, defined on the interval (0, 1),
# we need to scale the images into that range (avoiding both 0 and 1)
X_train = X_train.astype('float32')/256. + 0.5/256
X_test = X_test.astype('float32')/256. + 0.5/256

print(X_train.shape)

class_names = np.array([
    'T-shirt/top', 
    'Trouser/pants', 
    'Pullover shirt', 
    'Dress',
    'Coat', 
    'Sandal', 
    'Shirt', 
    'Sneaker', 
    'Bag',
    'Ankle boot'
])
(60000, 28, 28)

# Show some examples of the data

n_examples = 1000
example_images = X_test[0:n_examples]
example_labels = y_test[0:n_examples]

fig, axes = plt.subplots(1, 5, figsize=(15, 4))
for i in range(len(axes)):
    axes[i].imshow(example_images[i], cmap='binary')
    axes[i].set_title(class_names[example_labels[i]])
    axes[i].axis('off')



Encoder

As in the Autoencoder post, we keep the encoded dimension at 2, i.e., the output of the encoder is a 2-dimensional vector - 2 random variables from a probability distribution. And, as mentioned earlier, we need a prior distribution for the KL divergence loss between it and the posterior distribution (the distribution over those 2 random variables output by the encoder).

encoded_dim = 2

# Identity covariance matrix by default
prior = tfd.MultivariateNormalDiag(
    loc=tf.zeros(encoded_dim)
)

Here, we list three different versions of the encoder. You can focus on the first version only and come back to the other versions later if you are interested in learning more.

The posterior distribution is also a multivariate Gaussian, but its parameters will be learned during training. The KLDivergenceAddLoss layer is a pass-through layer that automatically adds the KL divergence between the prior and the posterior to the main loss - the reconstruction loss.

# Encoder version 1
# Feel free to skip other versions and come back later

encoder = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(256, activation='relu'),
    Dense(64, activation='relu'),
    Dense(tfpl.MultivariateNormalTriL.params_size(encoded_dim)),
    tfpl.MultivariateNormalTriL(encoded_dim),
    tfpl.KLDivergenceAddLoss(prior)
])

print(encoder.losses, end='\n\n')
print(encoder(example_images), end='\n\n')
print(encoder.losses)

[<tf.Tensor 'kl_divergence_add_loss_1/kldivergence_loss/batch_total_kl_divergence:0' shape=() dtype=float32>]

tfp.distributions._TensorCoercible("sequential_2_multivariate_normal_tri_l_1_tensor_coercible", batch_shape=[1000], event_shape=[2], dtype=float32)

[<tf.Tensor: shape=(), dtype=float32, numpy=0.89796853>]

The second version makes some of the arguments of KLDivergenceAddLoss explicit.

# Encoder version 2: with some arguments

encoder = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(256, activation='relu'),
    Dense(64, activation='relu'),
    Dense(tfpl.MultivariateNormalTriL.params_size(encoded_dim)),
    tfpl.MultivariateNormalTriL(encoded_dim),
    tfpl.KLDivergenceAddLoss(
        prior,
        use_exact_kl=False,
        weight=1.5,
        test_points_fn=lambda d: d.sample(10),
        test_points_reduce_axis=0
    )
])
  • weight: the multiple of the KL divergence that is added to the loss, useful for implementing a 𝛽-VAE, where 𝛽 is the weight of the KL term.
  • test_points_fn: receives the batch of posterior distributions and returns a tensor of samples of shape (n_samples, batch_size, dim_z). Each sample $z_{ij}$ - the $i$-th sample for observation $x_j$, located at (i, j, :) in the sample tensor - is mapped to the scalar $\log q(z_{ij} \mid x_j) - \log p(z_{ij})$, so the tensor of samples is converted into a tensor of values of shape (n_samples, batch_size).
  • test_points_reduce_axis: the axis of that tensor to average over (reduce_mean) when computing the loss added to the model; see the sketch after this list.
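
To make these arguments concrete, here is a rough sketch - not the layer's actual implementation - of the Monte Carlo estimate they describe, using the prior defined above and assuming posterior is the batch of distributions produced by the encoder (e.g. encoder(example_images)):

# Hypothetical sketch of the estimate that KLDivergenceAddLoss adds to the loss
def approx_kl(posterior, prior, n_samples=10):
    # test_points_fn: draw samples of shape (n_samples, batch_size, dim_z)
    z = posterior.sample(n_samples)
    # log q(z_ij | x_j) - log p(z_ij) for every sample: shape (n_samples, batch_size)
    log_ratio = posterior.log_prob(z) - prior.log_prob(z)
    # test_points_reduce_axis=0: average over the sample axis
    return tf.reduce_mean(log_ratio, axis=0)

# The layer then multiplies this estimate by `weight` before adding it to the model's losses.
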
We can do exactly the same thing without using KLDivergenceAddLoss. Instead, we can declare a KLDivergenceRegularizer as below and pass it to the activity_regularizer parameter of the MultivariateNormalTriL layer directly.

# Encoder version 3: Using KLDivergenceRegularizer

divergence_regularizer = tfpl.KLDivergenceRegularizer(
    prior,
    use_exact_kl=False,
    test_points_fn=lambda d: d.sample(10),
    test_points_reduce_axis=0
)

encoder = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(256, activation='relu'),
    Dense(64, activation='relu'),
    Dense(tfpl.MultivariateNormalTriL.params_size(encoded_dim)),
    tfpl.MultivariateNormalTriL(
        encoded_dim, 
        activity_regularizer=divergence_regularizer
    ),
])

We can first look at the encodings of some example images before training. As one might expect, no clusters appear, since the encoder has not been trained yet.

pretrain_example_encodings = encoder(example_images).mean().numpy()

# Plot encoded examples before training 

f, ax = plt.subplots(1, 1, figsize=(7, 7))
sns.scatterplot(x=pretrain_example_encodings[:, 0],
                y=pretrain_example_encodings[:, 1],
                hue=class_names[example_labels], ax=ax,
                palette=sns.color_palette("colorblind", 10));
ax.set_xlabel('Encoding dimension 1'); ax.set_ylabel('Encoding dimension 2')
ax.set_title('Encodings of example images before training')



Decoder

For the decoder, we also consider two different versions/options for modeling the output distribution: the first uses a Bernoulli distribution and the second uses a Beta distribution. TensorFlow Probability provides an IndependentBernoulli layer which we can use directly for the first version. For the second, since there is no such independent Beta layer, we bake one from scratch using the DistributionLambda layer.

# Decoder version 1: Using IndependentBernoulli

decoder = Sequential([
    Dense(64, activation='relu', input_shape=(encoded_dim,)),
    Dense(256, activation='relu'),
    Dense(28*28),
    tfpl.IndependentBernoulli((28, 28))
])

# Decoder version 2: Using Independent Beta Distribution
# Since there is no IndependentBeta layer, bake one from scratch

decoder = Sequential([
    Dense(64, activation='relu', input_shape=(encoded_dim,)),
    Dense(256, activation='relu'),
    Dense(28*28*2, activation='exponential'), # non-negative params for the Beta distribution
    Reshape((28, 28, 2)),
    tfpl.DistributionLambda(
        lambda t: tfd.Independent(
            tfd.Beta(
                concentration1=t[..., 0],
                concentration0=t[..., 1]
            )
        )
    )
])

VAE

Finally, we use both the encoder and the decoder to build our VAE model. For the loss, as mentioned earlier, the KLDivergenceAddLoss layer already adds the KL/regularization term automatically, so here we only need to specify the reconstruction (negative log-likelihood) loss.

vae = Model(
    inputs=encoder.inputs, 
    outputs=decoder(encoder.outputs)
)

def log_loss(x_true, p_x_given_z):
    return -tf.reduce_sum(p_x_given_z.log_prob(x_true))

vae.compile(loss=log_loss,)
vae.fit(
    x=X_train, 
    y=X_train,
    validation_data=(X_test, X_test),
    epochs=10,
    batch_size=32
)

WARNING:tensorflow:@custom_gradient grad_fn has 'variables' in signature, but no ResourceVariables were used on the forward pass.
1875/1875 [==============================] - 34s 16ms/step - loss: -55578.1367 - val_loss: -62465.1016
Epoch 2/10
1875/1875 [==============================] - 28s 15ms/step - loss: -64829.3750 - val_loss: -65042.4648
Epoch 3/10
1875/1875 [==============================] - 29s 15ms/step - loss: -67461.9844 - val_loss: -70393.7500
Epoch 4/10
1875/1875 [==============================] - 29s 15ms/step - loss: -69034.2734 - val_loss: -66066.8672
Epoch 5/10
1875/1875 [==============================] - 29s 15ms/step - loss: -70065.8438 - val_loss: -68278.2109
Epoch 6/10
1875/1875 [==============================] - 29s 15ms/step - loss: -70965.3438 - val_loss: -69826.4219
Epoch 7/10
1875/1875 [==============================] - 29s 15ms/step - loss: -71667.5000 - val_loss: -72257.3984
Epoch 8/10
1875/1875 [==============================] - 29s 15ms/step - loss: -72232.3594 - val_loss: -68337.1094
Epoch 9/10
1875/1875 [==============================] - 29s 16ms/step - loss: -72632.1641 - val_loss: -71673.4922
Epoch 10/10
1875/1875 [==============================] - 30s 16ms/step - loss: -72972.3281 - val_loss: -72791.6250

Results

First, we can reconstruct and plot some examples, this time using the trained model.

# Generate an example reconstruction

example_reconstruction = vae(example_images).mean().numpy().squeeze()

# Plot the example reconstructions

fig, axs = plt.subplots(2, 6, figsize=(16, 5))

for j in range(6):
    axs[0, j].imshow(example_images[j, :, :].squeeze(), cmap='binary')
    axs[1, j].imshow(example_reconstruction[j, :, :], cmap='binary')
    axs[0, j].axis('off')
    axs[1, j].axis('off')

Finally, we can check whether the encoded images exhibit some clusters after training. As we can observe in the right figure, some clusters emerge that contain images with the same or similar labels.

# Compute example encodings after training

posttrain_example_encodings = encoder(example_images).mean().numpy()

# Compare the example encodings before and after training

f, axs = plt.subplots(nrows=1, ncols=2, figsize=(15, 7))
sns.scatterplot(
    x=pretrain_example_encodings[:, 0],
    y=pretrain_example_encodings[:, 1],
    hue=class_names[example_labels], ax=axs[0],
    palette=sns.color_palette("colorblind", 10)
)
sns.scatterplot(
    x=posttrain_example_encodings[:, 0],
    y=posttrain_example_encodings[:, 1],
    hue=class_names[example_labels], 
    ax=axs[1],
    palette=sns.color_palette("colorblind", 10)
)

axs[0].set_title('Encodings of example images before training');
axs[1].set_title('Encodings of example images after training');

for ax in axs: 
    ax.set_xlabel('Encoding dimension 1')
    ax.set_ylabel('Encoding dimension 2')
    ax.legend(loc='upper right')



In this post, we introduced the Variational AutoEncoder, the probabilistic twist on the Autoencoder. In contrast to the Autoencoder, it is designed and trained to generate images, and it is not deterministic: given an input image, the encoder and decoder output distributions rather than fixed values. For example, a VAE allows sampling from the encoder and decoder distributions, which leads to different results for the same input image.
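
For instance, once training has finished, we can generate brand-new images by sampling latent codes from the prior and decoding them. Here is a minimal sketch using the prior and decoder defined above (the six-image layout is just for illustration):

# Sample latent codes from the prior and decode them into images
z = prior.sample(6)                           # shape (6, encoded_dim)
generated = decoder(z)                        # a distribution over 28x28 images
generated_images = generated.mean().numpy()   # use the distribution mean as the image

fig, axs = plt.subplots(1, 6, figsize=(16, 3))
for j in range(6):
    axs[j].imshow(generated_images[j], cmap='binary')
    axs[j].axis('off')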