Julia for Probabilistic Metaprogramming

Since around 2010, I've been involved with using and developing probabilistic programming languages. So when I learn about a new language, one of my first questions is whether it's a good fit for this kind of development. In this post, I'll talk a bit about working in this area with Julia, to motivate my Soss project.

Domain-Specific Languages

At a high level, a probabilistic programming language is a kind of domain-specific language, or DSL. As you might expect, this is just a language with a specific purpose, in contrast with a general-purpose programming language.

Even if this term is new to you, you've probably used lots of these. HTML and Markdown for web content, PostScript for (2D) printers, G-code for 3D printers, SQL for database queries... all are examples of DSLs.

These examples have their own syntax, so parsing them requires code made specifically for that purpose. The tremendous flexibility of this approach comes at a cost far beyond just parsing the language: none of the language tooling we usually take for granted comes for free.

Want syntax highlighting? You'll need to write it. Connect with external tools? You have to teach it how. Make a change to the language, and you'll need to change the parser accordingly. And this affects more than just developers, since designers of a language can rarely anticipate all the things a user might want to do.

An alternative to this heavy-handed approach is to embed a DSL in some host language, producing an embedded DSL, or EDSL. This is just a library in a host language, written in such a way as to feel like its own language. The distinction from "just a library" is a bit fuzzy, and arguing over it is usually pointless. The biggest advantage of having the term at all is that it gives us a way of thinking about the problem when designing this sort of thing.

EDSLs are especially popular in functional languages, due to the ability of monads to "overload the semicolon". In Python, the strongest support for EDSL development comes from things like metaclasses, decorators, and double underscore (or "dunderscore") methods like __add__.
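Julia's counterpart to Python's __add__ is adding a method to an existing operator such as Base.:+, which multiple dispatch then selects by argument type. As a minimal sketch (the types Term, Var, Op and the function render are hypothetical names made up for illustration), overloading + and * is enough to turn ordinary Julia arithmetic syntax into a tiny embedded language that builds an expression tree instead of computing a number:

```julia
# Nodes of a tiny embedded expression language (hypothetical types)
abstract type Term end

struct Var <: Term
    name::Symbol
end

struct Op <: Term
    op::Symbol
    args::Vector{Term}
end

# Overload host-language operators so ordinary Julia syntax builds a tree
Base.:+(a::Term, b::Term) = Op(:+, Term[a, b])
Base.:*(a::Term, b::Term) = Op(:*, Term[a, b])

# One possible interpretation of the tree: render it as a fully parenthesized string
render(v::Var) = string(v.name)
render(o::Op) = "(" * join(map(render, o.args), " $(o.op) ") * ")"

x, y = Var(:x), Var(:y)
expr = x * y + x        # Julia's own precedence rules decide the tree shape
render(expr)            # "((x * y) + x)"
```

The tree can then be given any interpretation we like; render is just one, and evaluating or differentiating the same tree would be others.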

Relative to a standalone DSL, the ability to leverage tools in the host language makes EDSLs much quicker to develop, and gives users the freedom to combine elements of the EDSL and host language in ways the developers never anticipated. In exchange, EDSLs usually have significant interpretive overhead. This can usually be overcome, given enough engineering budget.

Fortunately, there's a middle ground. But let's build some context before getting to that.

Probabilistic Programming

Probabilistic programming languages, or PPLs, are usually implemented as a DSL, either standalone (like Stan or Hakaru) or embedded (like PyMC3 or Pyro in Python, Anglican in Clojure, or Figaro in Scala). In either case, it's typical for interactions with probability distributions to take two forms, often abstractly referred to as sample and observe.

Let's take a simple example. Say we flip a (maybe unfair) coin \(N\) times and observe the sequence of heads and tails, and we'd like to infer \(P(\text{heads})\) for the coin. So, something like this:

\[ \begin{aligned} p &\sim \text{Uniform}(0,1) \\ y_n &\sim \text{Bernoulli}(p), n\in\{1,\cdots,N\}\ . \end{aligned} \]

From this way of writing it, there's not much distinction between \(p\) and \(y_n\). But of course, there's a big difference; we know the value of the \(y_n\)s, and would like to use that to reason about \(p\). So we often think of this as having two steps:

  1. Sample a value for \(p\) from the given distribution
  2. Observe each known \(y_n\), using the dependence to update the distribution of \(p\).

This is still a bit vague, and that's the point; the exact form of sample and observe depends on what inference routine we're using.
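One way to make this concrete is to write the model once against abstract sample and observe operations, and let each routine supply its own versions. The sketch below does this with plain closures and only Julia's standard library. UnitUniform, Flip, draw, logdensity, coin_model, simulate, and logjoint are all hypothetical names, and passing callbacks is just one possible design, not how Soss is implemented.

```julia
using Random

# Minimal stand-ins for Uniform(0,1) and Bernoulli(p), so the sketch needs
# no packages (hypothetical types, not from Distributions.jl)
struct UnitUniform end
struct Flip
    p::Float64
end

draw(rng, ::UnitUniform) = rand(rng)
draw(rng, d::Flip) = rand(rng) < d.p
logdensity(::UnitUniform, x) = 0 <= x <= 1 ? 0.0 : -Inf
logdensity(d::Flip, x) = x == 1 ? log(d.p) : log(1 - d.p)

# The coin model, written once against two abstract operations
function coin_model(y, sample, observe)
    p = sample(:p, UnitUniform())
    for yn in y
        observe(Flip(p), yn)
    end
    return p
end

# Interpretation 1: forward simulation. `sample` draws from the prior, and
# `observe` ignores the known value and draws a fake one instead.
function simulate(rng, y)
    fake = Bool[]
    sample(name, d) = draw(rng, d)
    observe(d, yn) = (push!(fake, draw(rng, d)); nothing)
    p = coin_model(y, sample, observe)
    return p, fake
end

# Interpretation 2: log joint density at a fixed p. `sample` returns the
# given value and scores it; `observe` scores the known data value.
function logjoint(y, p)
    total = 0.0
    sample(name, d) = (total += logdensity(d, p); p)
    observe(d, yn) = (total += logdensity(d, yn); nothing)
    coin_model(y, sample, observe)
    return total
end
```

For example, simulate(MersenneTwister(1), [1,0,1,1,0,1,1,0]) returns a prior draw of p with eight fake flips, while logjoint([1,0,1,1,0,1,1,0], 0.6) evaluates \(5\log 0.6 + 3\log 0.4\). The model code is identical in both cases; only the meaning of sample and observe changes.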

Perhaps the simplest thing we could do is rejection sampling. Here, sample means "sample" in the usual sense, and observe means "sample and filter". Here's some pseudo-code:

function rejectionSample(y, numSamples)
    N = length(y)
    posteriorSample = zeros(numSamples)  
    sampleNum = 1
    while sampleNum <= numSamples  
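        # sample: draw a candidate p from the Uniform(0,1) prior
        p = rand(Uniform(0,1))

        # observe: simulate N flips, keeping p only if they match y exactly
        proposal = rand(Bernoulli(p), N)
        if y == proposal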
            posteriorSample[sampleNum] = p
            sampleNum += 1
        end
    end
    return posteriorSample
end

Did the syntax highlighting give it away? I lied about the "pseudo" part; this is valid code in Julia, using the Distributions.jl library. Here, let's try it out:

julia> using Distributions
julia> rejectionSample([1,0,1,1,0,1,1,0],20)
20-element Array{Float64,1}:
 0.8158976408772047 
 0.6450191099491256 
 0.45250580931562356
 0.6910986247512794 
 0.7140409029883459 
 0.7590618621191185 
 0.7657057442482806 
 0.7298275374588723 
 0.671549906910313  
 0.3973116677286921 
 0.3295084397546948 
 0.6253119802386622 
 0.47958145829736454
 0.4319525168178511 
 0.5541398855968984 
 0.5504098499307508 
 0.7751201367243916 
 0.8307277566881002 
 0.37112539281292123
 0.40474937271123146

The code above is horribly inefficient, because so many proposals are rejected, and (relatedly) because we're using a vector of Bernoulli samples instead of a single Binomial, so a proposal has to match the exact sequence rather than just the number of heads. Even so, Julia runs it at a good pace:

julia> using BenchmarkTools
julia> @btime rejectionSample([1,0,1,1,0,1,1,0],20)
  157.213 μs (3378 allocations: 475.13 KiB)

... (same result)
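The Binomial remark above can be made concrete. Below is a minimal sketch, using only Julia's standard library rather than Distributions.jl, that accepts a proposal whenever the simulated number of heads matches the observed count instead of requiring the exact sequence. With Distributions.jl one would draw the count directly with rand(Binomial(N, p)); here it is simulated from N uniform draws to keep the block dependency-free. The name rejection_sample_count is hypothetical, not part of any package.

```julia
using Random, Statistics

# Rejection sampling that compares only the number of heads. Conditioning on
# the count k gives the same posterior for p as matching the exact sequence,
# because the two likelihoods differ only by the constant binomial(N, k).
function rejection_sample_count(rng::AbstractRNG, y, num_samples)
    N, k = length(y), count(==(1), y)
    posterior = zeros(num_samples)
    i = 1
    while i <= num_samples
        p = rand(rng)                            # sample p ~ Uniform(0,1)
        heads = count(_ -> rand(rng) < p, 1:N)   # one Binomial(N, p) draw, via N flips
        if heads == k                            # observe: filter on the count only
            posterior[i] = p
            i += 1
        end
    end
    return posterior
end

rng = MersenneTwister(2019)
draws = rejection_sample_count(rng, [1,0,1,1,0,1,1,0], 20)
```

Matching the count is binomial(8, 5) = 56 times as likely as matching the exact sequence, so for this data the sketch should need roughly 56 times fewer proposals, while targeting the same posterior.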

The theory says that Uniform(0,1) is the same distribution as Beta(1,1), and since the Beta is conjugate to the Bernoulli, observing five ones and three zeros should bring us to a Beta(1+5,1+3). This makes it easy to do a quick check:

julia> fit(Beta, rejectionSample([1,0,1,1,0,1,1,0],100000))
Beta{Float64}(α=6.032293654775521, β=4.018746844089239)

On my laptop, this takes about 2.5 seconds.
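The conjugate update can also be checked by hand. Writing \(\pi(p \mid y)\) for the posterior density and using the same eight observations (five ones, three zeros), Bayes' rule gives

\[ \begin{aligned} \pi(p \mid y) &\propto \underbrace{p^{1-1}(1-p)^{1-1}}_{\text{Beta}(1,1)} \prod_{n=1}^{8} p^{y_n}(1-p)^{1-y_n} \\ &= p^{5}(1-p)^{3} \\ &= p^{6-1}(1-p)^{4-1}\ , \end{aligned} \]

which is proportional to the Beta(6,4) density. Its mean is \(6/(6+4) = 0.6\), and the fitted \(\alpha \approx 6.03\), \(\beta \approx 4.02\) are close to the exact values.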

So Meta

In rejectionSample above, we passed in the observed value as an argument to the function. This is a common pattern, whatever inference algorithm we're using, so let's abstract it. Much of the point of probabilistic programming is a separation of concerns between the model and the inference algorithm, so we'll use a representation like this:

coin = @model y begin
    N = length(y)
    p ~ Uniform(0,1)
    y ⩪ Bernoulli(p) |> iid(N)
end

This incorporates the concepts from above:

  • ~ means sample
  • ⩪ means observe

The ⩪ symbol is unusual, but it's easy to input in Julia; just type \dotsim <TAB>.

There are a couple of other things about this that might be unfamiliar. First, @model is a macro. To (over)simplify, this means the body of coin will be parsed (to make sure it's valid Julia code), but not evaluated (unless the macro says so, which @model doesn't). The result is an abstract syntax tree, or AST, that can be manipulated and transformed to generate high-performance code. We saw another macro above, @btime for benchmarking. Macros and related techniques for code generation are known as metaprogramming.
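To see what "parsed but not evaluated" means in practice, here is a toy macro that simply returns the parsed body as data, plus a function that walks the resulting Expr and collects every variable on the left of a ~. Both @quoted and sampled_vars are hypothetical names for this sketch. Soss's @model does far more, but it starts from the same place: an AST rather than a value.

```julia
# A toy macro (hypothetical name): hand back the parsed body as data,
# without evaluating it.
macro quoted(ex)
    return QuoteNode(ex)
end

# Walk the AST and collect the left-hand side of every `lhs ~ rhs` call
function sampled_vars!(found, ex)
    if ex isa Expr
        if ex.head == :call && length(ex.args) == 3 && ex.args[1] == :~
            push!(found, ex.args[2])
        end
        foreach(arg -> sampled_vars!(found, arg), ex.args)
    end
    return found
end
sampled_vars(ex) = sampled_vars!(Any[], ex)

body = @quoted begin
    N = length(y)
    p ~ Uniform(0, 1)
end

body.head            # :block
sampled_vars(body)   # Any[:p]
```

Note that y and Uniform are never looked up, so this works even though neither is defined. That is exactly the freedom a macro like @model has to reinterpret the code before anything runs.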

The |> in the line defining y is a pipe, similar to the concept of the same name in UNIX-like systems. This is just for notational convenience; Bernoulli(p) |> iid(N) is equivalent to iid(N)(Bernoulli(p)).

The iid function isn't part of Julia itself; it's defined in Soss. It's a concept from statistics: iid(N) means there should be N copies that are independent and identically distributed. [Unlike "www", this initialism actually saves a few syllables. Use it enough, and maybe the world will have a net gain!]
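The pipe and iid(N) fit together because iid(N) itself returns a function. The sketch below does not reproduce Soss's iid, which works with distributions; it only illustrates that curried shape, using zero-argument functions as stand-ins for distributions. toy_iid and flip are hypothetical names.

```julia
# Hypothetical sketch of the curried shape behind `iid(N)`: toy_iid(N) returns
# a function that turns a one-draw sampler into an N-draw sampler.
toy_iid(N::Integer) = sampler -> (() -> [sampler() for _ in 1:N])

# A single Bernoulli(p) draw as a zero-argument function (hypothetical helper)
flip(p) = () -> rand() < p

# Piping and direct application are the same call
ten_flips = flip(0.3) |> toy_iid(10)   # same as toy_iid(10)(flip(0.3))
ten_flips()                            # a fresh 10-element Vector{Bool} on each call
```

Because toy_iid(10) is an ordinary function value, it can be piped into, stored, or passed around like any other.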

Wrapping up

There's huge potential in this approach. Models are expressed at a high level, and code transformation lets developers specify inference algorithms that generate high-performance code.

The @model macro, along with much more, is implemented in my Soss package. There's plenty more to say about it, including:

  • Composability
  • Programmatic model transformation
  • Performant code generation
  • Inference algorithms

But this will have to wait for another time. Thanks for reading!
