Variational Importance Sampling

Many distributions have densities that are easy to evaluate but are hard to sample from. When we need samples from such a distribution, we have to resort to some tricks. We'll look at the connection between two of them, importance sampling and variational inference, and see a way to use them together for fast inference.

Importance sampling

Importance sampling aims to make it easy to compute expected values. Say we have a distribution \(p\), and we'd like to compute the average of some function \(f\) under that distribution (or equivalently, the expected value of a "push-forward along \(f\)"). If we could sample \(x_j\sim p\), we'd be interested in

\[ \mathbb{E}_p[f]\approx \frac{1}{N}\sum_{j=1}^N f(x_j)\ . \]

If \(p\) is hard to sample (or if we don't get enough samples in areas where \(f\) is very large), there's a trick we can use:

  1. Find an appropriate proposal distribution \(q\). This should have the same support (areas of non-zero density) as \(p\), and should be easy (or fast) to sample. [We often also want \(q(x) \approx C |f(x)|\ p(x)\), but that won't be our focus.]

  2. Take \(N\) samples \(x_j\sim q\).

  3. Calculate

\[ \mathbb{E}_p[f]\approx \frac{1}{N}\sum_{j=1}^N f(x_j)\frac{p(x_j)}{q(x_j)}\ . \]

This \(\frac{p(x_j)}{q(x_j)}\) multiplier is the importance weight. To get some intuition for this, suppose that in some region \(q\) is twice as large as \(p\). Because we sampled from \(q\) instead of \(p\), we'll have twice as many points in that region as we should. But the importance weight in that region is \(^1\!/_2\), so this cancels out the effects of oversampling.
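As a minimal sketch of the three steps above (not code from this post's implementation), here is the estimate in plain Julia. The target \(p\) = Normal(0, 1), the proposal \(q\) = Normal(0, 2), the test function \(f(x) = x^2\), and the helper names normal_logpdf and importance_estimate are all assumptions chosen for illustration; the true answer is \(\mathbb{E}_p[x^2] = 1\).

```julia
using Random, Statistics

# Log-density of Normal(μ, σ), written by hand so the sketch needs no packages
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Importance-sampling estimate of E_p[f] from draws x_j ~ q
function importance_estimate(f, logp, logq, x)
    w = exp.(logp.(x) .- logq.(x))      # importance weights p(x_j) / q(x_j)
    return mean(f.(x) .* w)
end

rng = MersenneTwister(1)
N = 100_000
x = 2.0 .* randn(rng, N)                 # x_j ~ q = Normal(0, 2)

logp(x) = normal_logpdf(x, 0.0, 1.0)     # target p = Normal(0, 1)
logq(x) = normal_logpdf(x, 0.0, 2.0)     # proposal q = Normal(0, 2)

estimate = importance_estimate(x -> x^2, logp, logq, x)
println(estimate)                        # should be close to E_p[x^2] = 1
```

The proposal here is deliberately wider than the target, so every region where \(p\) has mass is well covered by samples from \(q\).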

Some minor adjustments

The previous section gave a simplified description of the classical goal of importance sampling. But it's not very usable as-is.

First, especially when working in high dimensions, it's common for probability densities to span many orders of magnitude. To avoid underflow (treating nonzero values as zero), we usually prefer to work with log-densities. So let's define

\[ \ell(x) = \log p(x) - \log q(x)\ . \]

Second, there's an implicit assumption above that we start out with a single \(f\) of interest. What if we don't yet have such an \(f\)? Or what if our goal is more general?

For example, sequential importance sampling, as used in particle filters, uses (log-)importance-weighted samples as a discrete approximation to a distribution of interest. This allows efficient computation with distributions that could be intractable to reason about directly.

Motivated by this, our approach will involve taking a sample of \(x_j\) values from \(q\), computing the log-weights \(\ell(x_j)\), and then using these \((x_j, \ell(x_j))\) pairs to reason about \(p\). This is exactly our goal: use \(\ell\)-weighted samples from \(q\) to reason about \(p\).
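To make the \((x_j, \ell(x_j))\) pairs concrete, here is a small sketch of turning log-weights into usable weights. It normalizes by the sum of the weights (the self-normalized variant of importance sampling, an assumption beyond the formula given earlier), and the helper names and the toy numbers are made up for illustration.

```julia
# Convert log-weights ℓ_j into weights that sum to one, staying in floating-point range
function normalized_weights(ℓ)
    w = exp.(ℓ .- maximum(ℓ))   # shift so the largest weight is exactly 1
    return w ./ sum(w)
end

# Weighted average of f over the (x_j, ℓ_j) pairs
weighted_mean(f, x, ℓ) = sum(normalized_weights(ℓ) .* f.(x))

x = [0.0, 1.0, 2.0]
ℓ = [-1000.0, -1001.0, -1002.0]   # exp(-1000) underflows to 0.0 in Float64
w = normalized_weights(ℓ)
println(w)
println(weighted_mean(identity, x, ℓ))
```

Exponentiating these log-weights directly would give all zeros; subtracting the maximum first keeps the relative weights intact.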

Variational inference

Not all \(q\)s are created equal. If we're trying to find a \(q\) that's a good approximation to \(p\), what criteria should we use?

Since we're using \(\ell\)-weighted samples from \(q\) to reason about \(p\) (and since we already know the answer), it's helpful to understand what could lead to a value \(\ell(x)\) being large. From the definition \(\ell(x) = \log p(x) - \log q(x)\), we can see that this requires \(p(x)\) to be large, and \(q(x)\) to be small.

Now, we could try to make \(p(x)\) really large, by concentrating \(q\) near its peak. But that would make \(q\) large as well, so at some point we hit diminishing returns. Similarly, making \(q(x)\) really small would require spreading it over a wide range. But most of that range would have small \(p(x)\) values, which again hurts us at some point. These two concerns work in tension with each other, so maximizing \(\ell\) requires balancing them.

Why would large values of \(\ell\) be a good thing? It comes down to a matter of efficiency. Say we're computing an expected value, and we could choose between one heavy point or lots of redundant light ones with the same total weight. The computation is the same, except that we have to pay for all those extra points. Actually, we have to pay twice: Once for the extra sampling cost, and again for the extra evaluation cost. And if there's no redundancy of this sort (i.e., the light points are widely distributed), most computations will be overwhelmed by the effect of heavier points.

All in all, the "right" thing to do in this context is to try to make \(\ell\) large, specifically to maximize its expected value:

\[ \mathbb{E}_q[\ell] \approx \frac{1}{N}\sum_{j=1}^N \left[ \log p(x_j) - \log q(x_j) \right] \]

[The \(\mathbb{E}_q\) notation just means "draw random values from \(q\) and find the average".]
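Here is a rough sketch of that average in Julia, for a toy case that is not from this post: \(p\) = Normal(0, 1) and \(q\) = Normal(\(\mu\), \(\sigma\)). The function name expected_logweight and the parameter choices are assumptions for illustration. Because this \(p\) is normalized, the expected log-weight equals the negative KL divergence from \(q\) to \(p\), which is \(-\mu^2/2\) when \(\sigma = 1\).

```julia
using Random, Statistics

normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Monte Carlo estimate of E_q[ℓ] for p = Normal(0, 1) and q = Normal(μ, σ)
function expected_logweight(rng, μ, σ, N)
    x = μ .+ σ .* randn(rng, N)                          # x_j ~ q
    ℓ = normal_logpdf.(x, 0.0, 1.0) .- normal_logpdf.(x, μ, σ)
    return mean(ℓ)
end

rng = MersenneTwister(1)
matched = expected_logweight(rng, 0.0, 1.0, 100_000)   # q equal to p
shifted = expected_logweight(rng, 1.0, 1.0, 100_000)   # q centered away from p
println((matched, shifted))
```

The matched case gives exactly zero, and the shifted case should land near \(-0.5\): moving \(q\) away from \(p\) lowers the expected log-weight.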

Though we came to it from a completely different angle than you'll usually see, this quantity we've just described is the primary value of interest in variational inference, the evidence lower bound or "ELBO":

\[ \text{ELBO}=\mathbb{E}_q[\ell] \]

So the whole game with variational inference boils down to finding a good approximation \(q\) by choosing the one that maximizes the ELBO.
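The name can be made concrete with a short derivation, under the assumption (not stated above) that \(p\) is an unnormalized density with normalizing constant \(Z = \int p(x)\,dx\), so that \(p/Z\) is the normalized target:

\[ \text{ELBO} = \mathbb{E}_q\left[\log p - \log q\right] = \log Z - \mathbb{E}_q\left[\log q - \log \tfrac{p}{Z}\right] = \log Z - \mathrm{KL}\left(q \,\middle\|\, \tfrac{p}{Z}\right) \le \log Z\ . \]

When \(p\) is an unnormalized posterior, \(Z\) is the marginal likelihood (the "evidence"), so the ELBO is a lower bound on the log-evidence. Since \(\log Z\) does not depend on \(q\), maximizing the ELBO is the same as minimizing the KL divergence from \(q\) to the normalized target.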

Maximizing the ELBO

We finally have the tools we need to describe the main idea of this post, variational importance sampling. This algorithm is a work in progress. Though anecdotally promising, its convergence properties are not yet known.

As of June 2019 I've not seen this idea in the literature; I'll update this post if I learn of previous work in this area. Anyway, here's the fundamental idea:

Iteratively sample \(q\), use \(e^\ell\) to weight the samples, and use those weights to refit \(q\)

In Julia-like pseudocode, it looks like this:

for iter in 1:numiters
	x = rand(q, N)
	ℓ = logpdf(p,x) - logpdf(q,x)
	q = fit(q, x, exp(ℓ - maximum(ℓ)) + 1/N)
end

There are a couple of aspects of this worth some extra discussion.

First, we've assumed the existence of a fit method that allows weights. This puts another constraint on the available choices for q, but this method is commonly available for distributions in Julia's Distributions.jl.

Second, exp(ℓ - maximum(ℓ)) is a vector with a maximum of one. This helps avoid overflow, which is much more dangerous than underflow, and acts as a fast pseudo-normalization in preparation for what comes next.

Finally, about that \(^1\!/_N\)... If our sample is dominated by just a few points, a given iteration could "zoom in" much too far, or lose identifiability. So we add \(^1\!/_N\) in a similar spirit to Laplace smoothing.

It's convenient to think of this as effectively fitting the models from two samples that happen to share the same points, but have different weights. The first has weights of exp(ℓ - maximum(ℓ)), so the maximum is one, and the total weight is anywhere between 1 and \(N\). The second sample involves the same points, but each has a weight of \(^1\!/_N\), so the total weight is one.

If the \(\ell_j\) are all the same, the weights will all be the same too, and they'll still be the same after adding \(^1\!/_N\). Refitting will have very little effect (only noise). Maybe this is just by chance, in which case the next round of samples will get to take another crack at it. Or, maybe we've reached convergence.

At the other extreme, maybe we end up with one shifted weight \(e^{\ell_j - \max \ell}=1\), with all others underflowing. In this case, this one point (now effectively the entirety of the "first sample") is taken to have a total weight equal to the entire "second sample". The result of adding the two "samples" is then halfway between the one extraordinary point we have found, and the full sample from the previous round.
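The "halfway" claim can be checked directly, taking a weighted mean as a stand-in for whatever statistic the fit uses (an assumption for illustration). Suppose point \(x_k\) has shifted weight \(e^{\ell_k - \max \ell} = 1\) and all others underflow to zero. After adding \(^1\!/_N\), the weights are

\[ w_k = 1 + \tfrac{1}{N}, \qquad w_j = \tfrac{1}{N}\ \ (j \ne k), \qquad \sum_{j=1}^N w_j = 2\ , \]

so the weighted mean is

\[ \frac{\sum_{j=1}^N w_j x_j}{\sum_{j=1}^N w_j} = \frac{x_k + \frac{1}{N}\sum_{j=1}^N x_j}{2} = \frac{x_k + \bar{x}}{2}\ , \]

exactly halfway between the dominant point and the plain sample mean \(\bar{x}\).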

There may be approaches much better than the above \(^1\!/_N\) trick, and the optimal update might depend on characteristics of the chosen variational family. Again, this is ongoing work.
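For a concrete feel for the loop and the \(^1\!/_N\) smoothing discussed above, here is a dependency-free sketch for a one-dimensional Normal family. It is an illustration under assumptions, not the implementation used later in this post: fit_normal is a hand-written weighted fit standing in for fit, and the target (an unnormalized Normal(3, 0.5)), the starting point, and the values of N and numiters are all made up.

```julia
using Random, Statistics

normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Weighted maximum-likelihood fit of a Normal; a hand-written stand-in for `fit`
function fit_normal(x, w)
    μ = sum(w .* x) / sum(w)
    σ = sqrt(sum(w .* (x .- μ) .^ 2) / sum(w))
    return μ, σ
end

function variational_importance_sampling(logp, μ, σ; N = 2000, numiters = 30,
                                         rng = MersenneTwister(1))
    for iter in 1:numiters
        x = μ .+ σ .* randn(rng, N)                   # sample from the current q
        ℓ = logp.(x) .- normal_logpdf.(x, μ, σ)       # log importance weights
        w = exp.(ℓ .- maximum(ℓ)) .+ 1 / N            # shifted weights plus smoothing
        μ, σ = fit_normal(x, w)                       # refit q
    end
    return μ, σ
end

# Hypothetical target: an unnormalized Normal(3, 0.5) log-density
logp(x) = -0.5 * ((x - 3.0) / 0.5)^2

μ, σ = variational_importance_sampling(logp, 0.0, 10.0)
println((μ, σ))
```

Starting from a very wide \(q\), the fitted mean and standard deviation should settle near 3 and 0.5 in this toy case. That says nothing about convergence in general, which remains an open question as noted above.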

Statistical inference

To this point, we've been talking about sampling \(x\). Sure, the name doesn't really matter, and this or any variational inference method can be used for this general problem of approximating a distribution. But there's at least a connotation that we're interested in sampling data. In fact, that's not usually the case; we've swept under the rug what is by far the most common application of variational inference, namely Bayesian inference.

In a Bayesian context, our \(p\) is not a distribution over data, but rather a "posterior" distribution over parameters. Instead of taking a parameter and giving a way to evaluate data, it takes data and gives a way to evaluate parameters. Going "backward" is the whole point. As you might expect, sampling from \(p\) is hard, and is really the problem of Bayesian inference.

Let's consider a simple example. Say we have a collection of \((x_j, y_j)\) pairs. Given the \(x_j\)s, we might model the \(y_j\)s as being produced like this:

\[ \begin{aligned} α &∼ \text{Normal}(0,1) \\ β &∼ \text{Normal}(0,2) \\ \hat{y}_j &= α + β x_j \\ y_j &∼ \text{Normal}(\hat{y}_j, 1) \end{aligned} \]

The posterior distribution \((\alpha, \beta| x, y)\) plays the role of \(p\). As above, we're not allowed to use \(p\) to generate a sample (that's too expensive). All we can use it for is to evaluate proposals. So we could write it like this:

function logp(α,β)
    ℓp = 0.0
    ℓp += logpdf(Normal(0,1), α)
    ℓp += logpdf(Normal(0,2), β)
    yhat = α .+ β .* x
    ℓp += sum(logpdf.(Normal.(yhat, 1), y) )
    return ℓp
end

This is not efficient!! For real-world use we'd have lots of data, and would probably want to do the sum in parallel. I've done zero performance tuning on any of this. Fortunately, Julia is fast enough that performance tuning is just gravy :)
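One cheap improvement, sketched here as a suggestion rather than as part of the original code: the function above reads x and y from global scope, and Julia's performance tips advise against non-constant globals. Wrapping the same body in a closure over the data avoids that. The name make_logp and the toy data are hypothetical; the sketch assumes Distributions.jl is installed.

```julia
using Distributions

# Build the log-posterior (up to a constant) as a closure over the data
function make_logp(x, y)
    function logp(α, β)
        ℓp = logpdf(Normal(0, 1), α)                # prior on α
        ℓp += logpdf(Normal(0, 2), β)               # prior on β
        yhat = α .+ β .* x
        ℓp += sum(logpdf.(Normal.(yhat, 1), y))     # likelihood
        return ℓp
    end
    return logp
end

# Hypothetical toy data lying exactly on y = 3 + 4x
x = [0.0, 1.0, 2.0]
y = 3.0 .+ 4.0 .* x
logp = make_logp(x, y)
println(logp(3.0, 4.0) > logp(0.0, 0.0))   # the generating parameters score higher
```

The returned function has the same two-argument signature as before, so it can be passed anywhere the original logp is expected.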

An aside

Before we get to \(q\), we need to mention something strange about \(p\). In the above implementation of logp, the arguments are both Real. But we'd really like to be able to pass it a pair of vectors, and avoid rewriting the whole thing.

Julia is great with broadcasting, so we could just call logp.(α_vec, β_vec) and be done with it. But the difficulty goes beyond Julia.

Having to work "up a dimension" is one of the things that makes Bayesian inference hard. We can never just talk about a value, but instead about a distribution of values.

Fredrik Bagge Carlson's MonteCarloMeasurements.jl is a big help with this. You can pass around a cloud of Particles as if it's just a number, and most of the bookkeeping is taken care of for you. You even get a free performance boost!

Anyway, here's the code for the inference loop:

function runInference(x,y,logp)
    N = 1000

    # initialize q
    q = MvNormal(2,100000.0) # Really this would be fit from a sample from the prior
    α,β = Particles(N,q)
    m = asmatrix(α,β)
    ℓ = sum(logp(α,β)) - Particles(logpdf(q,m))

    numiters = 60
    elbo = Vector{Float64}(undef, numiters)
    for j in 1:numiters
        α,β = Particles(N,q)                       # sample from the current q
        m = asmatrix(α,β)
        ℓ = logp(α,β) - Particles(logpdf(q,m))     # log importance weights
        elbo[j] = mean(ℓ)
        ss = suffstats(MvNormal, m,  exp(ℓ - maximum(ℓ)).particles .+ 1/N)
        q = fit_mle(MvNormal, ss)                  # refit q from the weighted sample
    end
    (α,β,q,ℓ,elbo)
end

There's still a little bit of converting between representations, notably the little asmatrix helper function, for a matrix representation of a tuple of particles:

asmatrix(ps...) = Matrix([ps...])'

Still, for a first pass, I'm pretty happy with it.

The data were generated from (α=3,β=4). Here are the final inferred values:

julia> @show α β ℓ;
α = 2.98 ± 0.1
β = 4.0 ± 0.11
ℓ = -151.0 ± 0.055

This shows a posterior standard deviation of around 0.1 for both \(\alpha\) and \(\beta\). We also still have access to the underlying particles, in case we'd like bivariate scatter plots, or anything else we're interested in.

Let's see how things went with convergence:

[Figure: negative ELBO by iteration]

Note that the \(y\)-axis is on a log scale. At least for this simple example, we see an exponential rate of convergence!

I'm planning to integrate this into Soss.jl. Until then, if you'd like to try this for yourself, you should be able to copy and paste this gist into any REPL for Julia 1.1.1 or above.

The description here of variational inference is not rigorous, but is intended to build intuition in the context of importance sampling. For a more thorough discussion, see Blei et al. (2016).
