A Prelude to Pyro

Lately I've been exploring Pyro, a recent entry in probabilistic programming from Uber AI Labs. It's an exciting development, with huge potential for large-scale applications.

In technical writing, it's common (at least for me) to realize partway through that I need to add some introductory material before moving on. In writing about Pyro, this happened quite a bit, to the point that it warranted this post as a kind of warm-up.

What is Pyro?

Pyro bills itself as "Deep Universal Probabilistic Programming". Some of those terms might be new, so let's deconstruct it.

"Deep"

This is probably the most widely-known of these terms, and of course refers to deep learning. Pyro is built on PyTorch, a popular deep-learning library from Facebook.

PyTorch is similar in some ways to TensorFlow/Keras, but uses dynamic computational graphs; the graph defining the computational steps is built from scratch with each execution. We'll come back to this, as it turns out to be a crucial ingredient for Pyro's approach.
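
As a tiny (hypothetical) illustration of what "dynamic" means here, ordinary Python control flow can depend on values computed along the way, and gradients simply follow whichever path was taken:

    import torch

    def f(x):
        # The graph is rebuilt on every call, so branching can
        # depend on the data itself.
        if x.sum() > 0:
            return 2 * x
        return x - 1

    x = torch.randn(3, requires_grad=True)
    f(x).sum().backward()  # gradients follow the branch that ran
    print(x.grad)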

"Probabilistic Programming"

In the broadest sense, probabilistic programming is "programming language support for reasoning about probabilities". The scope of this varies widely, depending on what you mean by "support" and "reasoning", but by far the most common interpretation (and the target of Pyro) involves automation of Bayesian inference.

Decades of research and development have produced many probabilistic programming languages; perhaps the best known are BUGS, JAGS, and most recently Stan. These (and many less widely-known developments) share the constraint that the number and type of random choices are known statically, before the model is "run".

This is a reasonable assumption for a huge variety of models, and the constraint allows for specialized inference methods. For example, the "GS" in BUGS and JAGS stands for "Gibbs sampling". Stan imposes the additional constraint that all parameters (but not data) must be continuous, in exchange for the No-U-Turn Sampler ("NUTS") and Automatic Differentiation Variational Inference ("ADVI").

While the above constraints are reasonable, they do still limit the models that can be expressed...

"Universal"

Much of the history of general-purpose programming languages was shaped by the influence and early design decisions of two languages, FORTRAN and Lisp. Probabilistic programming has followed a similar pattern, with the FORTRAN-like pragmatism of BUGS contrasted by the Lisp-like emphasis on flexibility and expressiveness in Church.

In "universal" languages like Church (and descendants like Anglican and WebPPL), a simulation is a model. An arbitrary number of random choices are made along the way, and we reason about those choices based on the observations. No need to constrain anything. Want new random variables at execution time? They're all yours. Stochastic recursion? Sure, go for it.

This may sound needlessly general, a sort of Bayesian Bacchanalia. But it's what we need to easily express nonparametric Bayesian models, which have plenty of applications in real-world problems. Say you're using latent Dirichlet allocation for text analysis. How many topics do you need? What if you had a thousand times as many documents - would you possibly want more topics? Hierarchical Dirichlet process mixture models are just the thing.

There is some inconsistency in the community about terminology. For some time the term "Turing complete" was used to describe this concept. But this isn't so useful, since, for example, Stan is Turing complete. Pyro's web site says it's universal because "Pyro can represent any computable probability distribution", which is consistent with the definition from researchers like Dan Roy. But it's very common to use the term more loosely, often without the support of a formal proof.

Bayesian Inference

In most cases, building a Bayesian model involves specifying a prior \(P(\theta)\) and likelihood \(P(x|\theta)\). The posterior distribution is then

\[ P(\theta | x) = \frac{P(\theta) P(x|\theta)}{P(x)}\ . \]

The goal of Bayesian inference is to "understand" this distribution. Design choices in inference come down to what kind of understanding we're after, and what cost we're willing to pay to get there.

Perhaps the simplest useful thing we can do with the posterior is to find the value of \(\theta\) that maximizes \(P(\theta|x)\). This is Maximum a Posteriori ("MAP") estimation, and ridge and lasso regression arise as special cases.
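
To see why, note that \(P(x)\) doesn't depend on \(\theta\), so MAP estimation maximizes the log-likelihood plus the log-prior:

\[ \hat\theta_{\text{MAP}} = \arg\max_\theta \left[ \log P(x|\theta) + \log P(\theta) \right]\ . \]

For linear regression, a Gaussian prior on the coefficients turns the log-prior into an \(\ell_2\) penalty (ridge), while a Laplace prior gives an \(\ell_1\) penalty (lasso).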

MAP estimation is an optimization problem, and is usually very fast to compute. But there are limitations. As a point estimate, it treats the result as "the one true estimate", with no accounting for uncertainty. It's also sensitive to reparameterization: a change of variables introduces a Jacobian factor into the density, so a substitution like \(\theta=\log \tau\) usually leads to \(\hat\theta \neq \log \hat\tau\) .

Because of these and related shortcomings, Bayesians tend to eschew point estimates like MAP, instead preferring to sample from the posterior distribution. Markov chain Monte Carlo ("MCMC") methods work well for problems that are moderate in both data size and model complexity. But the benefits of sampling come at a price: sampling tends to be much slower and less scalable than optimization.

Variational inference offers a middle ground: approximate the posterior with a parameterized distribution. This turns the sampling problem into an optimization problem of finding the parameters that give the "best" approximation. By adjusting the complexity of the approximation, we can trade speed for approximation quality or vice versa.

Variational Inference

Before we get into variational inference, it's convenient to change our notation a bit. We've been writing everything in terms of \(P(\cdots)\), but now we'll have two different distributions with some things in common, and we need to be able to keep everything straight.

To match most of the literature, we'll write \(p\) for the original distribution, and \(q\) for the approximation. We'll also use \(z\) for unobserved random variables; this corresponds to \(\theta\) above. So the goal is to find a distribution \(q(z)\) that's a good approximation to the posterior \(p(z|x)\).

We still need to quantify what makes an approximation "good". One reasonable approach is to try to minimize the Kullback-Leibler divergence \(\text{KL}[q(z) \| p(z|x)]\). This turns out to be intractable to compute, but minimizing it is equivalent to maximizing a related quantity, the evidence lower bound, or "ELBO",

\[ \text{ELBO}(p,q) = \mathbb{E}_q[\log p(z,x)] - \mathbb{E}_q[\log q(z)]\ . \]

The big idea of variational inference is to tune the approximation \(q\) by maximizing the ELBO.
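
The equivalence is a quick identity to check: expanding \(p(z|x) = p(z,x)/p(x)\) inside the KL divergence gives

\[ \log p(x) = \text{ELBO}(p,q) + \text{KL}[q(z) \| p(z|x)]\ . \]

Since \(\log p(x)\) doesn't depend on \(q\), maximizing the ELBO minimizes the KL divergence. And since the KL divergence is nonnegative, the ELBO really is a lower bound on the log evidence, hence the name.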

Let's think through how to interpret this. If we have a proposed approximation \(q\), the first term gives a reward when \(p\) is large on samples from \(q\). But we could cheat by choosing \(q\) to concentrate on the MAP estimate of \(p\). So the second term balances this, encouraging \(q\) to be "spread out". You might recognize the second term (including the minus sign) as the entropy of \(q\).

Ok, so we need to find \(q\) to maximize the ELBO. What possibilities should we try? In its original form, variational inference used the calculus of variations (hence the name), allowing very loose constraints. But in common use, we'll select some parameterized form. In a simple case, we might take \(q\) to be \( z \sim \text{Normal}(\mu, \sigma)\ . \) Then \(q\) is parameterized by \((\mu,\sigma)\), which we could write as \(q_{\mu, \sigma}\).

In this context, \((\mu,\sigma)\) is called the variational parameter. In the generic formulation, this is usually called \(\lambda\). So we're given \(p(z,x)\) and a parameterized form \(q_\lambda(z)\), and need to find \(\lambda\) to maximize \(\text{ELBO}(p,q_\lambda)\).

The formulation to this point still requires stepping through the entirety of the data \(x\). It's faster than MCMC, but doesn't yet offer a way of handling large data. In typical machine learning, we get around this by dropping gradient descent in favor of stochastic gradient descent. In variational methods, the role of SGD is played by stochastic variational inference ("SVI").
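
The recipe is the standard stochastic one (sketched here for a model with global latent variables and a likelihood that factors over the \(N\) observations): estimate the likelihood term of the ELBO from a random minibatch \(B\), rescaled to the full data size,

\[ \mathbb{E}_q[\log p(x|z)] \approx \frac{N}{|B|} \sum_{i \in B} \mathbb{E}_q[\log p(x_i|z)]\ . \]

The result is a noisy but unbiased estimate, whose gradients are exactly the kind of signal a stochastic optimizer needs.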

For more details on variational inference, see
Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational Inference: A Review for Statisticians. Journal of the American Statistical Association, 112(518), 859–877.

Back to Pyro

We've discussed the components of "deep universal probabilistic programming". Pyro models can involve any of these, or all of them at once. Variational autoencoders, deep Markov models, Gaussian processes... it really has a "sky's the limit" kind of feel.

Pyro supports a variety of inference methods, but its main focus is on stochastic variational inference. Pyro allows more generality than described above, through optional "fixed but unknown" parameters included in \(p\). The user specifies \(p\) and \(q\) through functions model and guide, respectively, both of which take the same arguments (the data).

Parameters to optimize are introduced using the param function, while stochastic choices use sample.
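
Here's a minimal sketch of how these pieces fit together (the model, data, and hyperparameters are invented for illustration, and the API details may vary across Pyro versions): estimating the mean of some noisy observations with SVI.

    import torch
    from torch.distributions import constraints
    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    def model(data):
        # Prior p(z): a unit normal over the unknown mean.
        z = pyro.sample("z", dist.Normal(0., 1.))
        # Likelihood p(x|z): observations centered at z.
        with pyro.plate("data", len(data)):
            pyro.sample("x", dist.Normal(z, 1.), obs=data)

    def guide(data):
        # Variational parameters, introduced by param.
        mu = pyro.param("mu", torch.tensor(0.))
        sigma = pyro.param("sigma", torch.tensor(1.),
                           constraint=constraints.positive)
        # The approximation q(z).
        pyro.sample("z", dist.Normal(mu, sigma))

    pyro.clear_param_store()  # start fresh on re-runs
    data = torch.randn(100) + 3.
    svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
    for step in range(2000):
        svi.step(data)
    print(pyro.param("mu"), pyro.param("sigma"))

In this sketch the model has no param calls (no "fixed but unknown" parameters), so this is plain variational inference; the guide's mu and sigma play the role of the variational parameter \(\lambda\).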

Overall, the setup is like this:

Component              | Pyro code            | In model \(p\)           | In guide \(q\)
Parameters to optimize | param(...)           | \(\varphi\)              | \(\lambda\)
Prior                  | sample(...)          | \(p_\varphi(z)\)         | \(q_\lambda(z)\)
Likelihood             | sample(..., obs=...) | \(p_\varphi(x \vert z)\) | Not allowed

This approach turns out to be remarkably flexible:

  • If \(z=\emptyset\), it gives maximum likelihood estimation over \(\varphi\)
  • If \(\varphi=\emptyset\) and \(q_\lambda\) is a delta distribution at \(\lambda\), it gives MAP estimation (see the sketch after this list)
  • If \(\varphi=\emptyset\), it gives variational inference
  • If \(\varphi\equiv\lambda\) and \(p_\varphi (z)\equiv q_\lambda (z)\), it gives maximum marginal likelihood, also known as type II maximum likelihood or empirical Bayes
    [N.B. I don't know that constraints for this item can currently be expressed in Pyro]
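
As a concrete instance of the MAP case, we can swap a point-mass guide into the earlier example (again just a sketch; Delta is Pyro's point-mass distribution):

    def map_guide(data):
        # q puts all its mass at z_map. The entropy term of the ELBO is
        # then zero, so maximizing the ELBO maximizes log p(z, x) over z.
        z_map = pyro.param("z_map", torch.tensor(0.))
        pyro.sample("z", dist.Delta(z_map))

Running SVI with this guide optimizes z_map directly, recovering MAP estimation as promised.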

Conclusion

While deep learning has generated a lot of mainstream excitement, probabilistic programming is still dramatically underused, and universal probabilistic programming even more so. Pyro's combination of these with scalable and flexible variational inference has the potential to change that. It's notoriously difficult to predict the influence of a new software library, but Pyro is certainly one to keep an eye on.