Practical Guide to Variational Inference

Update: since I wrote this blog post 5 years ago, it’s been quite a ride along the variational inference path, both for me and for the state of the art! There was no mention of VAEs, normalizing flows, or autodiff variational inference because they had not been invented yet (though BBVI was around by this time). This post is focused on the tricky analytic derivations, which still have their place for certain models requiring lower-variance updates during inference.

 

There are a few standard techniques for performing inference on hierarchical Bayesian models. Finding the posterior distribution over parameters or performing prediction requires an intractable integral for most Bayesian models, arising from the need to marginalise ("integrate out") nuisance parameters. In the face of this intractability there are two main ways to perform approximate inference: either transform the integration into a sampling problem (e.g., Gibbs sampling, slice sampling) or an optimisation problem (e.g., expectation-maximisation, variational inference).

Probably the most straightforward method is Gibbs sampling (see Chapter 29 of "Information Theory, Inference, and Learning Algorithms" by David MacKay) because you only need to derive conditional probability distributions for each random variable and then sample from these distributions in turn. Of course you have to handle convergence of the Markov chain and make sure successive samples are sufficiently uncorrelated, but you can't go far wrong with the derivation of the conditional distributions themselves. The downside of sampling methods is their slow speed. A related issue is that sampling methods are not ideal for online, scalable inference (e.g., learning from streaming social network data).

For these reasons, I have spent the last 6 months learning how to apply variational inference to my mobility models. While there are some very good sources describing variational inference (e.g., chapter 10 of "Pattern Recognition and Machine Learning" by Bishop, this tutorial by Fox and Roberts, this tutorial by Blei), I feel that the operational details can get lost among the theoretical motivation. This makes it hard for someone just starting out to know what steps to follow. Having successfully derived variational inference for several custom hierarchical models (e.g., stick-breaking hierarchical HMMs, extended mixture models), I'm writing a practical summary for anyone about to go down the same path. So, here is my summary of how you actually apply variational Bayes to your model.

 

Preliminaries

I'm omitting an in-depth motivation because it has been covered so well by the aforementioned tutorials. But briefly, the way that mean-field variational inference transforms an integration problem into an optimisation problem is by first assuming that your model factorises further than you originally specified. It then defines a measure of error between the simpler factorised model and the original model (usually the Kullback-Leibler divergence, which measures dissimilarity between two distributions, though it is not a true distance because it is not symmetric). The optimisation problem is to minimise this error by modifying the parameters of the factorised model (i.e., the variational parameters).

Something that can be confusing is that these variational parameters play a similar role in the variational model to the (often fixed) hyperparameters in the original model, which is to control things like prior mean, variance, and concentration. The difference is that you will be updating the variational parameters to optimise the factorised model, while the fixed hyperparameters of the original model are left alone. The way that you do this is by using the following equations for the optimal distributions over the parameters and latent variables, which follow from the assumptions made earlier:

$$\mathrm{ln} \; q^*(V_i) = \mathbb{E}_{-V_i}\left( \mathrm{ln} \; p(X, Z, V | \alpha) \right)$$

$$\mathrm{ln} \; q^*(Z_i) = \mathbb{E}_{-Z_i}\left( \mathrm{ln} \; p(X, Z, V | \alpha) \right)$$

where \(X\) is the observed data, \(V\) is the set of parameters, \(Z\) is the set of latent variables, and \(\alpha\) is the set of hyperparameters. Another source of possible confusion is that these equations do not explicitly include the variational parameters, yet these parameters are the primary object of interest in the variational scheme. In the steps below, I describe how to derive the update equations for the variational parameters from these equations.
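These two equations come from the standard decomposition of the log marginal likelihood (see, e.g., chapter 10 of Bishop):

$$\mathrm{ln} \; p(X | \alpha) = \mathcal{L}(q) + \mathrm{KL}\left( q(Z, V) \, \| \, p(Z, V | X, \alpha) \right), \qquad \mathcal{L}(q) = \mathbb{E}_{q}\left( \mathrm{ln} \; p(X, Z, V | \alpha) - \mathrm{ln} \; q(Z, V) \right)$$

Because the left-hand side does not depend on \(q\), minimising the KL divergence is equivalent to maximising the lower bound \(\mathcal{L}(q)\), and maximising \(\mathcal{L}(q)\) with respect to one factor of \(q\) at a time (holding the others fixed) yields the optimal-factor equations above. \(\mathcal{L}(q)\) is also a convenient quantity to monitor for convergence.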

1. Write down the joint probability of your model

Specify the distributions and conditional dependencies of the data, parameters, and latent variables for your original model. Then write down the joint probability of the model, given the hyperparameters. In the following steps, I'm assuming that all the distributions are conjugate to each other (e.g., multinomial data have Dirichlet priors, Gaussian data have Gaussian-Wishart priors and so on).

The joint probability will usually look like this:

$$p(X, Z, V | \alpha) = p(V | \alpha) \prod_n^N \mathrm{<data \; likelihood \; of \; V, Z_n>} \mathrm{<probability \; of \; Z_n>}$$

where \(N\) is the number of observations. For example, in a mixture model, the data likelihood is \(p(X_n | Z_n, V)\) and the probability of \(Z_n\) is \(p(Z_n | V)\). An HMM has the same form, except that \(Z_n\) now has probability \(p(Z_n | Z_{n-1}, V)\). A Kalman filter is an HMM with continuous \(Z_n\). A topic model introduces an outer product over documents and an additional set of (global) parameters.
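To make this concrete, take a Bayesian Gaussian mixture model with mixing weights \(\pi\), component means and precisions \(\{\mu_k, \Lambda_k\}\), and one-of-K indicator vectors \(Z_n\), so that \(V = \{\pi, \mu, \Lambda\}\) (here I split the hyperparameters \(\alpha\) into a Dirichlet part \(\alpha_{\pi}\) and a Gaussian-Wishart part \(\alpha_{\mu \Lambda}\)):

$$p(X, Z, V | \alpha) = p(\pi | \alpha_{\pi}) \prod_k p(\mu_k, \Lambda_k | \alpha_{\mu \Lambda}) \prod_n^N \prod_k \left( \pi_k \, \mathcal{N}(X_n | \mu_k, \Lambda_k^{-1}) \right)^{Z_{n,k}}$$

where \(p(\pi | \alpha_{\pi})\) is Dirichlet and \(p(\mu_k, \Lambda_k | \alpha_{\mu \Lambda})\) is Gaussian-Wishart, in keeping with the conjugacy assumption above.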

2. Decide on the independence assumptions for the variational model

Decide on the factorisation that will allow tractable inference on the simpler model. The assumption that the latent variables are independent of the parameters is a common way to achieve this. Interestingly, you will find that a single assumption of factorisation will often induce further factorisations as a consequence. These come "for free" in the sense that you get simpler and easier equations without having to make any additional assumptions about the structure of the variational model.

Your variational model will probably factorise like this:

$$q(Z, V) = q(Z) q(V)$$

and you will probably get \(q(V) = \prod_i q(V_i)\) as a set of induced factorisations.
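For the mixture-model example, for instance, the single assumption \(q(Z, V) = q(Z) q(V)\) induces

$$q(Z, V) = \prod_n^N q(Z_n) \; q(\pi) \prod_k q(\mu_k, \Lambda_k)$$

with no further assumptions needed: the factorisations over \(n\) and over \(k\) fall out of the update equations themselves.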

3. Derive the variational update equations

We now address the optimisation problem of minimising the difference between the factorised model and the original one.

Parameters

Use the general formula that we saw earlier:

$$\mathrm{ln} \; q^*(V_i) = \mathbb{E}_{-V_i}\left( \mathrm{ln}\; p(X, Z, V | \alpha) \right)$$

The trick is that most of the terms in \(p(X, Z, V | \alpha)\) do not involve \(V_i\), so can be removed from the expectation and absorbed into a single constant (which becomes a normalising factor when you take the exponential of both sides). You will get something that looks like this:

$$\mathrm{ln} \; q^*(V_i) = \mathbb{E}_{-V_i}\left( \mathrm{ln} \; p(V_i | \alpha) + \sum_n^N \mathrm{ln} \; \mathrm{<data \; likelihood \; of \; V_i, Z_n>} \right) + \mathrm{constant}$$

What you are left with is the log prior distribution of \(V_i\) plus the total log data likelihood of \(V_i\) given \(Z_n\). Even in these two remaining terms, you can often find factors that do not involve \(V_i\), so a lot of the work in this step involves discarding irrelevant parts.

The remaining work, assuming you chose conjugate distributions for your model, is to manipulate the equations to look like the prior distribution of \(V_i\) (i.e., to have the same functional form as \(p(V_i | \alpha)\)). You will end up with something that looks like this:

$$\mathrm{ln} \; q^*(V_i) = \mathrm{ln} \; p(V_i | \alpha_i') + \mathrm{constant}$$

where your goal is to find the value of \(\alpha_i'\) through equation manipulation. \(\alpha_i'\) is your variational parameter, and it will involve expectations of other parameters \(V_{-i}\) and/or \(Z\) (if it didn't, then you wouldn't need an iterative method). It's helpful to remember at this point that there are standard equations to calculate \(\mathbb{E} \left( \mathrm{ln} \; V_j \right)\) for common types of distribution (e.g., if \(q(V_j)\) is Dirichlet with parameters \(\alpha_j'\), then \(\mathbb{E} \left( \mathrm{ln} \; V_{j,k} \right) = \psi(\alpha_{j,k}') - \psi(\sum_{k'} \alpha_{j,k'}')\), where \(\psi\) is the digamma function). Sometimes you will have to do further manipulation to find expectations of other functions of the parameters. A worked example for the mixing weights of a mixture model is given below, before we turn to the expectations of the latent variables \(\mathbb{E}(Z)\).
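For the mixing weights \(\pi\) of the mixture-model example (Dirichlet prior \(p(\pi | \alpha_{\pi})\) and \(p(Z_n | \pi) = \prod_k \pi_k^{Z_{n,k}}\)), the manipulation gives

$$\mathrm{ln} \; q^*(\pi) = \sum_k (\alpha_{\pi,k} - 1) \, \mathrm{ln} \; \pi_k + \sum_n^N \sum_k \mathbb{E}(Z_{n,k}) \, \mathrm{ln} \; \pi_k + \mathrm{constant}$$

which has the functional form of a Dirichlet with variational parameter \(\alpha_k' = \alpha_{\pi,k} + \sum_n^N \mathbb{E}(Z_{n,k})\). In code, this update and the digamma expectation are one-liners. Here is a minimal sketch (the function names and the numpy/scipy dependencies are my own choices, not part of any particular library):

```python
import numpy as np
from scipy.special import digamma

def update_dirichlet(alpha_prior, resp):
    """Variational update for the mixing weights: alpha'_k = alpha_prior_k + sum_n E[Z_nk].

    alpha_prior : (K,) prior Dirichlet hyperparameters
    resp        : (N, K) responsibilities, i.e. E[Z_nk]
    """
    return alpha_prior + resp.sum(axis=0)

def expected_log_pi(alpha_var):
    """E[ln pi_k] under q(pi) = Dirichlet(alpha_var), using the digamma function."""
    return digamma(alpha_var) - digamma(alpha_var.sum())
```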

Latent variables

Start with:

$$\mathrm{ln} \; q^*(Z_n) = \mathbb{E}_{-Z_{n}}\left( \mathrm{ln} \; p(X, Z, V | \alpha) \right)$$

and try to factor out \(Z_{n}\). This will usually be the largest update equation because you will not be able to absorb many terms into the constant: you need to keep both the parameters generating the latent variables and the parameters that control their effect on the observed data. Using the example of independent multinomial (one-of-K) \(Z_n\) (e.g., in a mixture model), this works out to be:

$$\mathrm{ln} \; q^*(Z_{n,k}) = \mathbb{E}_{-Z_{n,k}}\left( Z_{n,k} \mathrm{ln} \; V_k + Z_{n,k} \mathrm{ln} \; p(X_n | V_k) \right) + \mathrm{constant}$$

Factorising out \(Z_{n,k}\), and noting that it is a binary indicator so its expectation is just the probability that it equals one, we get:

$$\mathrm{ln} \; \mathbb{E}(Z_{n,k}) = \mathbb{E}(\mathrm{ln} \; V_k) + \mathbb{E}(\mathrm{ln} \; p(X_n | V_k)) + \mathrm{constant}$$
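In code, this update is just a normalised exponential of the expected log terms. A minimal sketch, assuming the mixture-model example (so \(\mathbb{E}(\mathrm{ln} \; V_k)\) is the expected log mixing weight from the previous sketch) and that \(\mathbb{E}(\mathrm{ln} \; p(X_n | V_k))\) has already been computed as an \(N \times K\) matrix:

```python
import numpy as np

def update_responsibilities(expected_log_weights, expected_log_lik):
    """E-step: ln E[Z_nk] = E[ln pi_k] + E[ln p(X_n | V_k)] + constant.

    expected_log_weights : (K,) array, e.g. from expected_log_pi above
    expected_log_lik     : (N, K) matrix of E[ln p(X_n | V_k)]
    """
    log_rho = expected_log_weights[None, :] + expected_log_lik
    log_rho -= log_rho.max(axis=1, keepdims=True)    # subtract the max for numerical stability
    resp = np.exp(log_rho)
    return resp / resp.sum(axis=1, keepdims=True)    # the constant becomes this normalisation
```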

4. Implement the update equations

Put your update equations from step 3 into code. Iterate over the parameters (M-step) and latent variables (E-step) in turn until your parameters converge. Multiple restarts from random initialisations of the expected latent variables are recommended, as variational inference converges to a local optimum.
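To give a feel for the overall structure, here is a deliberately simplified, runnable sketch of the coordinate-ascent loop, reusing update_dirichlet, expected_log_pi, and update_responsibilities from the sketches above. To keep it short, the Gaussian components are held fixed (pretend their means and covariances are known), so only the mixing weights and the latent assignments are inferred; a full implementation would also update the Gaussian-Wishart factors and monitor the lower bound for convergence:

```python
import numpy as np
from scipy.stats import multivariate_normal

def fit_mixing_weights(X, means, covs, alpha_prior, n_iter=50, seed=0):
    """Toy coordinate-ascent VI: infer mixing weights and assignments only."""
    rng = np.random.default_rng(seed)
    N, K = X.shape[0], len(means)
    # With fixed components, E[ln p(X_n | V_k)] is just the exact log density.
    log_lik = np.column_stack([multivariate_normal.logpdf(X, mean=m, cov=c)
                               for m, c in zip(means, covs)])
    resp = rng.dirichlet(np.ones(K), size=N)            # random initialisation of E[Z]
    for _ in range(n_iter):
        alpha_var = update_dirichlet(alpha_prior, resp)  # update parameters (M-step)
        resp = update_responsibilities(                  # update latent variables (E-step)
            expected_log_pi(alpha_var), log_lik)
    return alpha_var, resp
```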

The video below shows what variational inference looks like on a mixture model. The green scatter points represent the observed data, the blue diamonds are the ground truth means (not known by the model, obviously), the red dots are the inferred means, and the ellipses are the inferred covariance matrices: