Conditional Variational Autoencoders (CVAEs) are a powerful and creative tool in AI. To really get to the heart of how they work, it helps to understand Autoencoders (AEs) first. We'll then move on to Variational Autoencoders (VAEs), which are themselves a special version of the basic Autoencoder.
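At its core, an Autoencoder is a type of neural network that learns to do two things. First, it compresses (encodes) its input through a narrow layer, called the bottleneck, into a small, dense representation in a "latent space". Second, it reconstructs (decodes) the original input from that compressed representation.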
The entire network is trained by comparing the final output to the original input, and the goal is to make the reconstruction as close to the original as possible. Because the bottleneck is intentionally small, the autoencoder is forced to learn the most essential and distinctive features of the data and throw away the non-essential "noise."
An input image of a '7' is made of 784 pixels (a 28x28 grid). The encoder might be forced to compress this down to, say, just 32 numbers. To do this successfully, it can't memorize the pixels. It has to learn the abstract idea of a '7': "a horizontal line at the top connected to a diagonal line going down." The specific thickness of the line and the exact angle are noise it learns to ignore. So, a well-trained AE gives us a powerful feature extractor. The compressed representation in the latent space is a rich, meaningful summary of the input.
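Here is the bottleneck idea in miniature. This toy sketch is not the architecture used later in this post. It is a linear autoencoder in plain Python that squeezes 2-D points through a single latent number. The data are hypothetical points that all lie on the line x2 = 2·x1, so one number is enough to describe each of them. The weights, learning rate and step count are arbitrary choices made for illustration.

```python
# Toy 2-D data that secretly lies on a 1-D line: every point is (t, 2t).
# A 1-number bottleneck is enough, if the model discovers that structure.
data = [(t, 2 * t) for t in (-1.0, -0.5, 0.0, 0.5, 1.0)]

# Hypothetical linear encoder (2 -> 1) and decoder (1 -> 2) weights.
a, b = 0.1, 0.1   # encoder: z = a*x1 + b*x2
c, d = 0.1, 0.1   # decoder: x_hat = (c*z, d*z)

def loss():
    """Mean squared reconstruction error over the dataset."""
    total = 0.0
    for x1, x2 in data:
        z = a * x1 + b * x2
        total += (c * z - x1) ** 2 + (d * z - x2) ** 2
    return total / len(data)

initial_loss = loss()
lr = 0.05
for step in range(2000):
    ga = gb = gc = gd = 0.0
    for x1, x2 in data:
        z = a * x1 + b * x2              # encode into the bottleneck
        r1, r2 = c * z - x1, d * z - x2  # reconstruction residuals
        # Chain rule: gradients of the squared error w.r.t. decoder weights...
        gc += 2 * r1 * z
        gd += 2 * r2 * z
        # ...and back through z to the encoder weights
        gz = 2 * r1 * c + 2 * r2 * d
        ga += gz * x1
        gb += gz * x2
    n = len(data)
    a -= lr * ga / n
    b -= lr * gb / n
    c -= lr * gc / n
    d -= lr * gd / n

final_loss = loss()
print(f"initial loss {initial_loss:.4f} -> final loss {final_loss:.6f}")
```

The data really are one-dimensional, so gradient descent drives the reconstruction error to almost zero. If the points were scattered randomly in 2-D, a single number could not reconstruct them perfectly. A linear encoder would then keep only the direction that explains the most variation.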
While the AE's latent space is great for compressing data it has seen, it's not good for generating new data. The space is often disjointed and uneven. The cluster of points representing '1's might be far away from the cluster for '8's, with empty, meaningless gaps in between. If we were to just pick a random point from one of those empty gaps and feed it to the decoder, it would have no idea what to do. The output would likely be a blurry, unrecognizable blob. Can we organize this latent space so that it's smooth and continuous, allowing us to pick any point and generate a plausible new image? This is exactly the problem that a Variational Autoencoder (VAE) is designed to solve.
Let's dive into how it does that.
A VAE introduces a brilliant twist to solve the "empty gaps" problem in the latent space. Instead of outputting a single, precise point (a vector), the encoder outputs the parameters of a probability distribution. Typically, this is a simple Gaussian (or normal) distribution. So, for any given input image, the encoder outputs two things:

- a mean vector μ, which marks where the input's region is centered in the latent space;
- a standard deviation vector σ, which controls how spread out that region is.
These two vectors define a "fuzzy" region or a cloud of possible points in the latent space. We then sample a point z (in our latent space) from this distribution and pass it to the decoder.
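A VAE is trained to balance two competing objectives:

- Reconstruction: the decoder's output should match the input as closely as possible, just like in a regular AE.
- Regularization: a KL divergence term pushes each encoded distribution toward a standard normal distribution N(0, I), which encourages the distributions of different inputs to overlap.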
By forcing the distributions to overlap, the VAE fills in the gaps between data clusters, and the latent space becomes smooth and continuous. Now the space between the cluster for '1's and the cluster for '7's contains points that decode into plausible digit shapes. This structured space is what gives VAEs their generative power. After training, we can throw away the encoder and sample a random point z from the standard normal prior N(0, I). The decoder will then generate a digit image that looks like it came from the original dataset.
The ultimate goal of the authors of the VAE paper was to build a powerful generative model. The most principled way to do this involves understanding the relationship between an observed data point (x) and its latent representation (z). This relationship is described by the posterior probability p(z∣x), which tells you, "Given this image, what is the likely latent code that generated it?"
According to Bayes' theorem:
$$p(z \mid x) = \frac{p(x \mid z)\,p(z)}{p(x)}$$
The big problem here is the denominator, p(x), known as the evidence. To calculate it, you have to compute the integral $p(x) = \int p(x \mid z)\,p(z)\,dz$. For any non-trivial model this integral is intractable. It has no closed-form solution, and estimating it numerically means covering an enormous, high-dimensional space of possible z values.
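To make the integral concrete, here is a hedged sketch of a toy model where it is actually tractable. It uses a 1-D prior p(z) = N(0, 1) and an assumed decoder p(x∣z) = N(z, 0.25). For this linear-Gaussian case p(x) has a closed form. We can therefore compare it with a brute-force Monte Carlo estimate that averages p(x∣z) over samples drawn from the prior. The model, the noise variance and the function names are illustrative assumptions, not part of a real VAE.

```python
import math
import random

NOISE_VAR = 0.25  # assumed decoder noise: p(x|z) = N(z, NOISE_VAR)

def normal_pdf(x, mean, var):
    """Density of a 1-D Gaussian N(mean, var) at x."""
    return math.exp(-(x - mean) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def evidence_exact(x):
    # For this toy linear-Gaussian model the integral has a closed form:
    # p(x) = N(x; 0, 1 + NOISE_VAR)
    return normal_pdf(x, 0.0, 1.0 + NOISE_VAR)

def evidence_monte_carlo(x, num_samples, rng):
    # p(x) = E_{z ~ p(z)}[p(x|z)], estimated by averaging over prior samples
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)  # z ~ p(z) = N(0, 1)
        total += normal_pdf(x, z, NOISE_VAR)
    return total / num_samples

x = 1.0
print("exact      :", evidence_exact(x))
print("Monte Carlo:", evidence_monte_carlo(x, 100_000, random.Random(0)))
```

In one dimension this estimate converges quickly. With a neural-network decoder and a latent space of tens of dimensions there is no closed form. Also, for any particular image almost every z drawn from the prior gives a near-zero p(x∣z), so this naive average needs an impractical number of samples.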
The solution was to define a simpler family of distributions, let's call it $q_\phi(z \mid x)$ (e.g., a Gaussian), controlled by parameters ϕ. The goal then becomes to tune ϕ so that our simple distribution q is as "close" as possible to the true, intractable posterior p(z∣x).
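The standard way to measure the "closeness" between two distributions is the Kullback-Leibler (KL) divergence. So the objective becomes: minimize $\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z \mid x)\big)$. In other words, make our approximation as close as possible to the true posterior. Strictly speaking, KL divergence is not a distance, since it is not symmetric, but it is zero only when the two distributions match.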
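By definition:
$$\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z \mid x)\big) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log q_\phi(z \mid x) - \log p(z \mid x)\big]$$
Apply Bayes' rule and take logs:
$$\log p(z \mid x) = \log p(x \mid z) + \log p(z) - \log p(x)$$
Replace log p(z∣x) in the first equation with this expression:
$$\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z \mid x)\big) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log q_\phi(z \mid x) - \log p(x \mid z) - \log p(z) + \log p(x)\big]$$
Now we can group the terms inside the expectation. The term log p(x) does not depend on z, so it can be pulled out of the expectation:
$$\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z \mid x)\big) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log q_\phi(z \mid x) - \log p(z)\big] - \mathbb{E}_{q_\phi(z \mid x)}\big[\log p(x \mid z)\big] + \log p(x)$$
The first expectation is just the definition of the KL divergence between q_ϕ(z∣x) and the prior p(z):
$$\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z \mid x)\big) = \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big) - \mathbb{E}_{q_\phi(z \mid x)}\big[\log p(x \mid z)\big] + \log p(x)$$
Combine the first and second terms on the right-hand side under a single negative sign:
$$\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z \mid x)\big) = -\Big(\mathbb{E}_{q_\phi(z \mid x)}\big[\log p(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)\Big) + \log p(x)$$
Now rearrange the equation to solve for log p(x):
$$\log p(x) = \Big(\mathbb{E}_{q_\phi(z \mid x)}\big[\log p(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)\Big) + \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z \mid x)\big)$$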
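We want to maximize the likelihood (the probability of observing our data). KL divergence is never negative, so $\mathrm{KL}(q_\phi(z \mid x) \,\|\, p(z \mid x)) \ge 0$. The remaining term, $\mathbb{E}_{q_\phi(z \mid x)}[\log p(x \mid z)] - \mathrm{KL}(q_\phi(z \mid x) \,\|\, p(z))$, is therefore always a lower bound on the log-likelihood log p(x). That is why it is called the evidence lower bound (ELBO). Because log p(x) does not depend on ϕ, pushing the ELBO up also shrinks $\mathrm{KL}(q_\phi(z \mid x) \,\|\, p(z \mid x))$, which was our original objective. So, instead of maximizing the intractable likelihood directly, we maximize the ELBO.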
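The identity log p(x) = ELBO + KL(q_ϕ(z∣x) ∥ p(z∣x)) can be checked numerically on a toy model where everything has a closed form. The sketch below assumes a 1-D prior N(0, 1) and a decoder p(x∣z) = N(z, 0.25), both illustrative choices. It also assumes a Gaussian q = N(q_mean, q_var) whose parameters we pick by hand instead of learning them.

```python
import math

NOISE_VAR = 0.25  # assumed toy decoder p(x|z) = N(z, NOISE_VAR); prior p(z) = N(0, 1)

def log_normal(x, mean, var):
    """log N(x; mean, var) for a 1-D Gaussian."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

def kl_gauss(m1, v1, m2, v2):
    """KL( N(m1, v1) || N(m2, v2) ) for 1-D Gaussians."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

def elbo(x, q_mean, q_var):
    # E_q[log p(x|z)] in closed form, using E_q[(x - z)^2] = (x - q_mean)^2 + q_var
    expected_loglik = (-0.5 * math.log(2 * math.pi * NOISE_VAR)
                       - ((x - q_mean) ** 2 + q_var) / (2 * NOISE_VAR))
    return expected_loglik - kl_gauss(q_mean, q_var, 0.0, 1.0)

def true_posterior(x):
    # Conjugate Gaussian posterior p(z|x) for this toy model
    var = NOISE_VAR / (1.0 + NOISE_VAR)
    return x / (1.0 + NOISE_VAR), var

x = 1.0
log_px = log_normal(x, 0.0, 1.0 + NOISE_VAR)  # exact log-evidence for this toy model
post_mean, post_var = true_posterior(x)
for q_mean, q_var in [(0.3, 0.5), (0.0, 1.0), (post_mean, post_var)]:
    bound = elbo(x, q_mean, q_var)
    gap = kl_gauss(q_mean, q_var, post_mean, post_var)  # KL(q || p(z|x))
    print(f"q=N({q_mean:.2f}, {q_var:.2f}): ELBO={bound:.4f}  "
          f"ELBO+KL={bound + gap:.4f}  log p(x)={log_px:.4f}")
```

For every choice of q, ELBO + KL matches log p(x). The ELBO reaches log p(x) exactly when q equals the true posterior, because the gap KL(q ∥ p(z∣x)) is then zero.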
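From the derivation above, the ELBO can be written as:
$$\mathrm{ELBO} = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)$$
This equation is the VAE's training objective. We maximize it or, equivalently, minimize its negative, which serves as the loss function. The first term rewards accurate reconstructions and the second keeps each encoded distribution close to the prior p(z). These are exactly the two competing objectives described earlier.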
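In code, the negative ELBO usually becomes a per-example loss with two parts. The sketch below is plain Python and makes two assumptions that the post does not state. First, the decoder outputs per-pixel probabilities that are scored with binary cross-entropy (a Bernoulli likelihood, common for MNIST). Second, the encoder outputs the log-variance log σ² rather than σ itself. The expectation is approximated with the single sample z that produced `x_recon`, and the KL term uses the closed form for a diagonal Gaussian against N(0, I).

```python
import math

def negative_elbo(x, x_recon, mu, logvar, eps=1e-7):
    """Per-example VAE loss: reconstruction BCE + KL(q(z|x) || N(0, I)).

    x       -- pixel intensities in [0, 1]
    x_recon -- decoder output probabilities for the same pixels (from one sample z)
    mu      -- encoder means, one per latent dimension
    logvar  -- encoder log-variances log(sigma^2), one per latent dimension
    """
    # Reconstruction term: single-sample estimate of -E_q[log p(x|z)],
    # assuming an independent Bernoulli likelihood per pixel (binary cross-entropy)
    recon = 0.0
    for xi, pi in zip(x, x_recon):
        pi = min(max(pi, eps), 1.0 - eps)  # clamp to avoid log(0)
        recon -= xi * math.log(pi) + (1.0 - xi) * math.log(1.0 - pi)

    # Regularization term: closed-form KL between N(mu, sigma^2) and N(0, 1), summed over dims
    kl = -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))
    return recon + kl

# Tiny usage example with made-up numbers (4 "pixels", 2 latent dimensions)
loss = negative_elbo(x=[1.0, 0.0, 1.0, 0.0],
                     x_recon=[0.9, 0.2, 0.7, 0.1],
                     mu=[0.5, -0.3],
                     logvar=[-0.2, 0.1])
print(f"negative ELBO: {loss:.4f}")
```

In a deep-learning framework the same two terms would be computed on tensors, and the arithmetic would be identical.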
In short, we started by trying to minimize the KL divergence between our approximation and the intractable posterior. Expanding it with Bayes' rule exposed the log-likelihood along with a few other terms. Solving for the log-likelihood then gave us a lower bound (the ELBO) plus a KL divergence term that is never negative. Training a VAE therefore reduces to maximizing the ELBO or, equivalently, minimizing its negative.
Before we move on to writing the code, there is another problem we face when training our VAE: we cannot backpropagate through the sampling step between the encoder and the decoder. Drawing a random sample is not a differentiable function of the distribution's parameters, so gradients cannot flow back through it. That means nothing before the decoder (i.e., the encoder) can be trained with ordinary backpropagation. The authors of the original paper came up with a neat little trick to resolve this issue, known as the reparameterization trick.
Instead of sampling z directly, we sample a random noise vector ϵ from a standard normal distribution, $\epsilon \sim \mathcal{N}(0, I)$. We then compute z as a deterministic function: $z = \mu + \sigma \odot \epsilon$. Here, μ and σ (the mean and standard deviation) are the deterministic outputs of the encoder network. This clever trick separates the random part (ϵ) from the network's parameters, allowing gradients to flow through μ and σ to train the encoder. So, element-wise, $\partial z / \partial \mu = 1$ and $\partial z / \partial \sigma = \epsilon$.
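The sketch below illustrates the trick on scalars in plain Python. The noise ϵ is drawn once and then held fixed, so z becomes an ordinary deterministic function of μ and σ. Central finite differences then recover ∂z/∂μ = 1 and ∂z/∂σ = ϵ. The sketch also checks that z built this way still has mean μ and standard deviation σ. The values of μ and σ and the random seed are arbitrary.

```python
import random

def reparameterize(mu, sigma, eps):
    """z = mu + sigma * eps; all randomness lives in eps, drawn outside this function."""
    return mu + sigma * eps

rng = random.Random(42)
mu, sigma = 0.5, 2.0

# 1) Gradients: draw eps once, then treat it as a constant
eps = rng.gauss(0.0, 1.0)
h = 1e-6
dz_dmu = (reparameterize(mu + h, sigma, eps) - reparameterize(mu - h, sigma, eps)) / (2 * h)
dz_dsigma = (reparameterize(mu, sigma + h, eps) - reparameterize(mu, sigma - h, eps)) / (2 * h)
print(f"dz/dmu ~ {dz_dmu:.6f} (expected 1), dz/dsigma ~ {dz_dsigma:.6f} (expected eps = {eps:.6f})")

# 2) Distribution: z = mu + sigma * eps is still a sample from N(mu, sigma^2)
n = 100_000
samples = [reparameterize(mu, sigma, rng.gauss(0.0, 1.0)) for _ in range(n)]
sample_mean = sum(samples) / n
sample_std = (sum((s - sample_mean) ** 2 for s in samples) / n) ** 0.5
print(f"sample mean {sample_mean:.3f} (mu = {mu}), sample std {sample_std:.3f} (sigma = {sigma})")
```

In a deep-learning framework, automatic differentiation computes these same derivatives through the addition and multiplication. That is what lets the encoder be trained end to end.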
I trained a simple VAE on the MNIST digits dataset. The encoder sends each 28x28 image through a 32-channel convolution layer → ReLU → 64-channel convolution → ReLU → Flatten → Linear layer with 256 outputs → ReLU. Two parallel 20-unit Linear layers then output μ and σ. The decoder is the exact reverse of the encoder, using transposed convolutions in place of convolutions. Trained for 10 epochs, this network ended with a loss of around 65. A sample from this model is below. Not bad, eh!
Let’s move on to Conditional VAEs. The goal of a CVAE is to give us control over the generation process. Instead of asking a VAE to generate a random digit, we want to be able to command it to "generate a 4" or "generate a 9."
The way a CVAE achieves this is surprisingly simple. We feed the "condition" (the piece of information we want to control, like the digit's label) as an additional input to both the encoder and the decoder.
This condition is typically represented as a one-hot vector. For example, if we're working with digits 0-9, the label '4' would become the vector [0, 0, 0, 0, 1, 0, 0, 0, 0, 0].
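Here's how it changes the model's job:

- The encoder now sees both the image and its label, so it learns $q_\phi(z \mid x, c)$.
- The decoder is also told which digit it is drawing, so z no longer needs to encode the digit's identity and is free to capture its style (stroke thickness, slant, and so on).
- The decoder takes both z and c, $p(x \mid z, c)$, and reconstructs that specific digit in that style.

To generate, we sample z, choose the label c we want, and decode.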
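Here is a minimal sketch of how the condition can be attached. It assumes the simplest scheme, which concatenates the one-hot label to each network's input. The post does not say exactly where the label enters its convolutional encoder. Appending c to the flattened image, as below, suits a fully connected encoder. A convolutional encoder could instead append it after the Flatten step or tile it as extra input channels. The zero-filled image and latent vectors are placeholders.

```python
def one_hot(label, num_classes=10):
    """Encode an integer class label as a one-hot list of floats."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label {label} is outside [0, {num_classes})")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

# Placeholder inputs: a flattened 28x28 image and a 20-D latent sample
image = [0.0] * (28 * 28)
z = [0.0] * 20
c = one_hot(4)

# Encoder q(z | x, c): image values followed by the label
encoder_input = image + c   # 784 + 10 = 794 values
# Decoder p(x | z, c): latent code followed by the label
decoder_input = z + c       # 20 + 10 = 30 values

print(c)
print(len(encoder_input), len(decoder_input))
```

To generate a specific digit, the decoder input keeps the same layout: z is sampled from the prior and c is set to the desired label.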
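The ELBO for a CVAE is similar to that of our original VAE:
$$\mathrm{ELBO} = \mathbb{E}_{q_\phi(z \mid x, c)}\big[\log p(x \mid z, c)\big] - \mathrm{KL}\big(q_\phi(z \mid x, c) \,\|\, p(z \mid c)\big)$$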
As you can see, we've simply added the condition c to every probability distribution.
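$p(z \mid c)$ is the prior, which we assumed to be a standard normal distribution in our VAE math. For CVAEs, we can still assume a standard normal. Alternatively, we can train another network to produce this prior, $p(z \mid c) = \mathcal{N}\big(\mu_{\text{prior}}(c), \sigma^2_{\text{prior}}(c)\big)$: we feed in the condition c and it outputs the prior's mean and standard deviation. We can then train all three components together: the encoder, the decoder and the latent prior network.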
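With a learned prior, the KL term in the ELBO compares two arbitrary diagonal Gaussians rather than comparing q against N(0, I). Below is a hedged sketch of that closed form in plain Python. It assumes both the encoder and the prior network output a mean and a log-variance for each latent dimension. The numbers in the example are made up and stand in for the encoder's output and the prior network's output for some condition c.

```python
import math

def kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """KL( N(mu_q, diag(exp(logvar_q))) || N(mu_p, diag(exp(logvar_p))) ), summed over dimensions."""
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        var_q, var_p = math.exp(lq), math.exp(lp)
        # Per-dimension KL between two 1-D Gaussians
        total += 0.5 * (lp - lq + (var_q + (mq - mp) ** 2) / var_p - 1.0)
    return total

# Hypothetical encoder output q(z | x, c) for one example
mu_q, logvar_q = [0.3, -1.0], [0.2, -0.5]
# Hypothetical learned-prior output p(z | c) for the same label c
mu_p, logvar_p = [0.1, -0.8], [0.0, -0.3]

print("KL vs learned prior :", kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p))
print("KL vs N(0, I) prior :", kl_diag_gaussians(mu_q, logvar_q, [0.0, 0.0], [0.0, 0.0]))
```

Setting the prior mean and log-variance to zero recovers the standard-normal KL used by the plain VAE. The same function therefore covers both CVAE variants.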
In many CVAEs, we assume that the "style" (z) is independent of the "content", i.e., the condition (c). For example, how slanted or thick someone's handwriting is shouldn't fundamentally depend on whether they are writing a "4" or a "9". Under this assumption, p(z∣c) simplifies to just p(z), which is our good old standard normal distribution, N(0, I).
I trained two models: one with a simple standard normal prior and one with a learned prior. The encoder and decoder use the same architecture as the VAE above, with the addition of a one-hot encoded class vector representing the digit. The learned prior network is a simple single-layer feed-forward network with ReLU activation. I trained both models for 20 epochs.
Some samples from the simple CVAE:
Some samples from the learned-prior CVAE:
The learned-prior CVAE produced far worse samples than the simple CVAE. Sometimes simple is better.
In my next post, we'll take a deep dive into some of the popular derivatives of VAEs. Ciao!