Lucas Maystre

Bits of ML

This page contains brief summaries of selected topics in ML research. It is mostly a way for me to review and digest ideas that I find particularly interesting. In each snippet, I try to recap the problem & the key insights in the proposed solution. Note that this page reflects my limited understanding of these topics; I am *not* an expert in most (all?) of these. If you disagree with anything I wrote here, I would be thankful if you could let me know.

Variational Inference

Added on 2018-11-01

Let $\mathbf{z}$ be a vector of latent variables with prior distribution $p(\mathbf{z})$. Consider a vector of observations $\mathbf{x}$ linked to the latent variables through the observation model $p(\mathbf{x} \mid \mathbf{z})$. The central problem in Bayesian inference is the computation of the posterior distribution

$$p(\mathbf{z} \mid \mathbf{x}) = \frac{p(\mathbf{x} \mid \mathbf{z}) \, p(\mathbf{z})}{p(\mathbf{x})}.$$

A subproblem of interest is the computation of $\log p(\mathbf{x}) = \log \int p(\mathbf{x} \mid \mathbf{z}) p(\mathbf{z}) d\mathbf{z}$, known as the marginal log-likelihood. This can be interpreted as the evidence for the model, and provides a principled way to perform model selection—i.e., to choose $p(\mathbf{z})$ and $p(\mathbf{x} \mid \mathbf{z})$. In many models of interest (even models as simple as logistic regression), computing the posterior or the marginal log-likelihood is intractable.
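To make the marginal log-likelihood concrete, here is a minimal sketch for a toy model that is *not* taken from the text above: a scalar latent $z \sim \mathcal{N}(0, 1)$ and a scalar observation $x \mid z \sim \mathcal{N}(z, \sigma^2)$. In this conjugate case the integral has a closed form, $x \sim \mathcal{N}(0, 1 + \sigma^2)$, so a brute-force numerical integration can be checked against it. The function names and the integration grid are my own choices for illustration.

```python
import math


def log_normal_pdf(x, mean, var):
    """Log-density of a univariate Gaussian N(mean, var) evaluated at x."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def log_evidence_numeric(x, noise_var, lo=-10.0, hi=10.0, n=20001):
    """Approximate log ∫ p(x | z) p(z) dz with the trapezoid rule.

    Assumes the prior z ~ N(0, 1) and the likelihood x | z ~ N(z, noise_var).
    """
    h = (hi - lo) / (n - 1)
    total = 0.0
    for i in range(n):
        z = lo + i * h
        weight = 0.5 if i in (0, n - 1) else 1.0  # trapezoid end-point weights
        log_joint = log_normal_pdf(x, z, noise_var) + log_normal_pdf(z, 0.0, 1.0)
        total += weight * math.exp(log_joint)
    return math.log(total * h)


def log_evidence_exact(x, noise_var):
    """Closed form for the same toy model: x ~ N(0, 1 + noise_var)."""
    return log_normal_pdf(x, 0.0, 1.0 + noise_var)
```

Brute-force integration like this only works because $z$ is one-dimensional; with many latent variables the grid grows exponentially, which is one way to see why the general problem is intractable.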

Variational inference is a family of methods that attempt to approximate the posterior distribution by a simpler distribution (e.g., a multivariate Gaussian). The key idea is to transform the computation of the posterior into an optimization problem. Letting $\mathcal{Q}$ be a family of distributions (the variational family), the inference problem turns into the optimization problem

$$\min_{q \in \mathcal{Q}} \; \ell\big(p(\mathbf{z} \mid \mathbf{x}), q(\mathbf{z})\big),$$

where $\ell$ is a functional that tells us how close $q(\mathbf{z})$ is to $p(\mathbf{z} \mid \mathbf{x})$. Strictly speaking, any method that attempts to solve this problem should be called a “variational inference method”, but usually what people refer to when using this term is $\ell(p, q) = \mathrm{KL}(q \Vert p)$. There are two key advantages of optimizing this “reverse KL-divergence”:

  1. It is computationally tractable, and
  2. it can be seen as maximizing a lower bound on the marginal log-likelihood.

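We can show this through a series of algebraic manipulations starting from the KL-divergence.

$$
\begin{aligned}
\mathrm{KL}[q(\mathbf{z}) \Vert p(\mathbf{z} \mid \mathbf{x})]
  &= \mathbb{E}_{q}[\log q(\mathbf{z}) - \log p(\mathbf{z} \mid \mathbf{x})] \\
  &= \mathbb{E}_{q}[\log q(\mathbf{z}) - \log p(\mathbf{x} \mid \mathbf{z}) - \log p(\mathbf{z})] + \log p(\mathbf{x}) \\
  &= \mathrm{KL}[q(\mathbf{z}) \Vert p(\mathbf{z})] - \mathbb{E}_{q}[\log p(\mathbf{x} \mid \mathbf{z})] + \log p(\mathbf{x}).
\end{aligned}
$$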

Even though $\mathrm{KL}[q(\mathbf{z}) \Vert p(\mathbf{z} \mid \mathbf{x})]$ cannot be computed directly, we can optimize it via gradient descent, because the intractable term $\log p(\mathbf{x})$ does not depend on $q$. The term $\mathrm{KL}[q(\mathbf{z}) \Vert p(\mathbf{z})]$ (which can be thought of as a regularization term) and its derivative can often be computed in closed form. The term $\mathbb{E}_{q}[\log p(\mathbf{x} \mid \mathbf{z})]$ is more challenging, but can often be simplified. In many important cases (e.g., Gaussian processes with non-conjugate observation models), the observations $\{x_i\}$ are independent and every observation depends on a single latent variable. The term then simplifies to $\sum_i \mathbb{E}_{q(z_i)}[\log p(x_i \mid z_i)]$, where $q(z_i)$ is the marginal of $q$ over $z_i$, and can be computed efficiently as a sum of one-dimensional integrals.
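As a rough illustration of how the two terms are evaluated in practice, here is a sketch under assumptions that go beyond the text: binary labels $y_i \in \{-1, +1\}$ with a logistic likelihood $p(y_i \mid z_i) = \sigma(y_i z_i)$, a fully factorized Gaussian $q(z_i) = \mathcal{N}(m_i, s_i^2)$, and, for simplicity, an independent standard normal prior on each $z_i$ (a Gaussian process prior would need the multivariate KL formula instead). The one-dimensional integrals use a plain trapezoid rule rather than a more refined quadrature, and all function names are hypothetical. The code only evaluates the objective; the gradient step is left out.

```python
import math


def log_sigmoid(t):
    """Numerically stable log(1 / (1 + exp(-t)))."""
    return -(max(-t, 0.0) + math.log1p(math.exp(-abs(t))))


def expected_log_lik(y, m, s, n=2001, width=8.0):
    """E_{z ~ N(m, s^2)}[log sigmoid(y * z)] via a one-dimensional trapezoid rule.

    The integral is taken over the standardized variable u = (z - m) / s,
    restricted to the interval [-width, width].
    """
    h = 2.0 * width / (n - 1)
    total = 0.0
    for i in range(n):
        u = -width + i * h
        weight = 0.5 if i in (0, n - 1) else 1.0  # trapezoid end-point weights
        density = math.exp(-0.5 * u * u) / math.sqrt(2.0 * math.pi)
        total += weight * density * log_sigmoid(y * (m + s * u))
    return total * h


def kl_to_std_normal(m, s):
    """Closed-form KL[N(m, s^2) || N(0, 1)]."""
    return 0.5 * (s * s + m * m - 1.0) - math.log(s)


def elbo(ys, ms, ss):
    """Sum of expected log-likelihoods minus the KL regularizer."""
    fit = sum(expected_log_lik(y, m, s) for y, m, s in zip(ys, ms, ss))
    reg = sum(kl_to_std_normal(m, s) for m, s in zip(ms, ss))
    return fit - reg
```

For example, `elbo([1, -1], [0.5, -0.5], [0.8, 0.8])` evaluates the objective for two observations; maximizing it over the means and standard deviations would be the variational inference step.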

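Minimizing the reverse KL-divergence is equivalent to maximizing a lower bound on the marginal log-likelihood; this viewpoint comes from the fact that

$$
\begin{aligned}
\log p(\mathbf{x})
  &= \mathrm{KL}[q(\mathbf{z}) \Vert p(\mathbf{z} \mid \mathbf{x})] + \mathbb{E}_{q}[\log p(\mathbf{x} \mid \mathbf{z})] - \mathrm{KL}[q(\mathbf{z}) \Vert p(\mathbf{z})] \\
  &\ge \mathbb{E}_{q}[\log p(\mathbf{x} \mid \mathbf{z})] - \mathrm{KL}[q(\mathbf{z}) \Vert p(\mathbf{z})],
\end{aligned}
$$

where the inequality holds because the KL-divergence is non-negative.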

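An alternative way to obtain this result (one that is often mentioned in the literature) uses Jensen’s inequality:

$$
\begin{aligned}
\log p(\mathbf{x})
  &= \log \int p(\mathbf{x} \mid \mathbf{z}) \, p(\mathbf{z}) \, d\mathbf{z}
   = \log \mathbb{E}_{q}\left[ \frac{p(\mathbf{x} \mid \mathbf{z}) \, p(\mathbf{z})}{q(\mathbf{z})} \right] \\
  &\ge \mathbb{E}_{q}\left[ \log \frac{p(\mathbf{x} \mid \mathbf{z}) \, p(\mathbf{z})}{q(\mathbf{z})} \right]
   = \mathbb{E}_{q}[\log p(\mathbf{x} \mid \mathbf{z})] - \mathrm{KL}[q(\mathbf{z}) \Vert p(\mathbf{z})].
\end{aligned}
$$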
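The lower bound on the marginal log-likelihood can be checked numerically in a case where everything is available in closed form. The sketch below assumes a toy model of my own choosing, not one from the text: $z \sim \mathcal{N}(0, 1)$, $x \mid z \sim \mathcal{N}(z, \sigma^2)$, and a Gaussian $q(z) = \mathcal{N}(m, s^2)$. The exact posterior is then $\mathcal{N}\big(x / (1 + \sigma^2), \sigma^2 / (1 + \sigma^2)\big)$, so the bound should be tight when $q$ equals the posterior and strictly loose otherwise. Function names are hypothetical.

```python
import math


def log_evidence(x, noise_var):
    """Exact log p(x) for z ~ N(0, 1), x | z ~ N(z, noise_var)."""
    var = 1.0 + noise_var
    return -0.5 * (math.log(2.0 * math.pi * var) + x * x / var)


def elbo(x, noise_var, m, s):
    """E_q[log p(x | z)] - KL[q || p(z)] for q = N(m, s^2), in closed form."""
    # E_q[(x - z)^2] = (x - m)^2 + s^2 for a Gaussian q.
    expected_log_lik = (
        -0.5 * math.log(2.0 * math.pi * noise_var)
        - ((x - m) ** 2 + s * s) / (2.0 * noise_var)
    )
    # KL[N(m, s^2) || N(0, 1)]
    kl = 0.5 * (s * s + m * m - 1.0) - math.log(s)
    return expected_log_lik - kl


def exact_posterior(x, noise_var):
    """Mean and standard deviation of p(z | x) in the same toy model."""
    mean = x / (1.0 + noise_var)
    std = math.sqrt(noise_var / (1.0 + noise_var))
    return mean, std
```

With `m, s = exact_posterior(x, noise_var)`, the value of `elbo(x, noise_var, m, s)` matches `log_evidence(x, noise_var)` up to floating-point error, while any other choice of `m` and `s` gives a strictly smaller value; the gap is exactly the reverse KL-divergence between $q$ and the posterior.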