Random Walk Rosenbluth-Metropolis-Hastings in Aesara

Right before I started working on MCX I wrote a small benchmark for PyTorch, TensorFlow and JAX on a very simple problem: using the random walk Rosenbluth-Metropolis-Hastings algorithm to sample from a mixture distribution. MCX was discontinued a bit more than a year ago, when I started working on a PPL based on Aesara. So let me revisit this simple example using Aesara!

The full code was added to the repository. For this example we use the C backend, though Aesara also offers a Numba and a JAX backend.

Mixture model

In the original blog post I set out to sample from a mixture distribution with 4 components, and had to write the corresponding log-probability density function by hand, i.e. without using a PPL. Implementing a mixture model in Aesara is straightforward. There is no need for a dedicated Mixture distribution (as in, e.g., PyMC); you simply write down the generative process:

import aesara.tensor as at
import numpy as np

srng = at.random.RandomStream(0)

loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])

N_rv = srng.normal(loc, scale, name="N")
I_rv = srng.categorical(weights, name="I")
Y_rv = N_rv[I_rv]
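
To make the generative process explicit, here is a plain-NumPy sketch of what the graph above describes: draw all four normal components, draw a component index from the categorical distribution, and keep the selected component. The function name sample_mixture is hypothetical and this is not part of the original code; it simply mirrors N_rv[I_rv] with NumPy's Generator API.

```python
import numpy as np

# Mixture parameters, copied from the Aesara model above.
loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])


def sample_mixture(rng, n):
    """Draw n samples with the same two-step process as N_rv[I_rv]."""
    # One draw of all four components per sample, like N_rv (shape (n, 4)).
    components = rng.normal(loc, scale, size=(n, 4))
    # One component index per sample, like I_rv.
    idx = rng.choice(4, size=n, p=weights)
    # Keep the selected component for each sample, like N_rv[I_rv].
    return components[np.arange(n), idx]


rng = np.random.default_rng(0)
samples = sample_mixture(rng, 100_000)
# The mixture mean is sum_i w_i * loc_i = 0.92; the sample mean should be close.
print(samples.mean(), weights @ loc)
```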

We can generate forward samples from this model by compiling the model graph into a function with Y_rv as its output:

import aesara

sample_fn = aesara.function((), Y_rv)
samples = [sample_fn() for _ in range(10000)]

print(samples[:10])
[array(2.51645571), array(0.16094803), array(4.16173818), array(-0.75365736), array(0.91897138), array(-1.96086176), array(2.60226408), array(2.28198192), array(-1.05260784), array(1.38469404)]

If you are not familiar with Theano/Aesara, aesara.function may surprise you. What does it do exactly? When you manipulate Aesara tensors, you are not manipulating numbers; you are describing the computation to perform on the inputs. As a consequence, the result of an Aesara operation is a graph:

aesara.dprint(Y_rv)
Subtensor{int64} [id A]
 |normal_rv{0, (0, 0), floatX, False}.1 [id B] 'N'
 | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FBAB334F680>) [id C]
 | |TensorConstant{[]} [id D]
 | |TensorConstant{11} [id E]
 | |TensorConstant{[-2.   0. .. 3.2  2.5]} [id F]
 | |TensorConstant{[1.2 1.  5.  2.8]} [id G]
 |ScalarFromTensor [id H]
   |categorical_rv{0, (1,), int64, False}.1 [id I] 'I'
     |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FBAB17AE7A0>) [id J]
     |TensorConstant{[]} [id K]
     |TensorConstant{4} [id L]
     |TensorConstant{[0.2 0.3 0.1 0.4]} [id M]

aesara.function is therefore used to compile the graph into a function that can be executed. To do so, we need to specify the function's inputs and outputs. In this case there are no inputs, and Y_rv is the output.
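
To make the "graph first, numbers later" idea concrete, here is a deliberately tiny, hypothetical pure-Python analogue. Node, as_node and compile_fn are illustrative names, not Aesara's API; Aesara's real graphs and its compiler do much more (graph rewrites, C code generation, and so on). The point is only that arithmetic on a Node builds a graph, and a separate "compile" step turns that graph into a callable.

```python
class Node:
    """Hypothetical toy graph node: it records an operation instead of computing it."""

    def __init__(self, op=None, inputs=(), name=None):
        self.op = op          # callable run at evaluation time; None marks a graph input
        self.inputs = inputs  # parent nodes
        self.name = name

    def __add__(self, other):
        # Build a new node; no addition happens here.
        return Node(lambda a, b: a + b, (self, as_node(other)))

    def __mul__(self, other):
        return Node(lambda a, b: a * b, (self, as_node(other)))


def as_node(value):
    """Wrap a plain Python number as a constant node (an op with no inputs)."""
    if isinstance(value, Node):
        return value
    return Node(lambda: value)


def compile_fn(inputs, output):
    """Turn a graph into a callable, loosely in the spirit of aesara.function."""

    def fn(*values):
        env = dict(zip(inputs, values))  # bind each input node to a concrete value

        def evaluate(node):
            if node in env:
                return env[node]
            return node.op(*(evaluate(parent) for parent in node.inputs))

        return evaluate(output)

    return fn


x = Node(name="x")
y = x * 2 + 1          # y is a graph, not a number
f = compile_fn([x], y)
print(f(3.0))          # 7.0
```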

To compute the log-probability density function we can use AePPL's joint_logprob function. AePPL transforms the Aesara model graph into a graph that computes the model's joint log-probability (see, working with computation graphs is nice!). We pass a dictionary that maps the random variables Y_rv and I_rv to the value variables at which their log-probability is evaluated:

from aeppl import joint_logprob

y_vv = Y_rv.clone()
i_vv = I_rv.clone()
logprob = joint_logprob({Y_rv: y_vv, I_rv: i_vv})

print(logprob.eval({y_vv: 10., i_vv: 3}))
-6.452221131239579
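
As a sanity check, the same number can be reproduced by hand: with I = 3, the joint log-density is log w_3 + log N(10; 2.5, 2.8). The sketch below is a plain-NumPy reference computation, not AePPL code; normal_logpdf and joint_logprob_np are hypothetical helper names.

```python
import numpy as np

loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])


def normal_logpdf(x, mu, sigma):
    """Log-density of N(mu, sigma^2) evaluated at x."""
    z = (x - mu) / sigma
    return -0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)


def joint_logprob_np(y, i):
    """log p(Y=y, I=i) = log w_i + log N(y; loc_i, scale_i)."""
    return np.log(weights[i]) + normal_logpdf(y, loc[i], scale[i])


print(f"{joint_logprob_np(10.0, 3):.4f}")  # -6.4522
```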

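Here we do not really care about the values that I_rv takes, so we marginalize the log-probability density function over I_rv. Since the joint log-probability returned by joint_logprob already includes the log-weight log P(I = i), we only need to sum the joint probabilities over the four components, in log space:

logprob = []
for i in range(4):
    # Joint log-density with I_rv fixed to component i; it already contains log(weights[i]).
    i_vv = at.as_tensor(i, dtype="int64")
    logprob.append(joint_logprob({Y_rv: y_vv, I_rv: i_vv}))
logprob = at.stack(logprob, axis=0)
# log p(y) = log sum_i p(y, I=i); adding at.log(weights) here would count the weights twice.
total_logprob = at.logsumexp(logprob)

print(f"{float(total_logprob.eval({y_vv: 10.})):.4f}")
-5.3514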
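
As a cross-check, here is a plain-NumPy sketch of the marginal log-density log p(y) = logsumexp_i(log w_i + log N(y; loc_i, scale_i)). Because the joint log-density of (Y, I = i) already contains log w_i, each weight must enter the sum exactly once. The helper names are hypothetical, and logsumexp is written out with the usual max-subtraction trick for numerical stability.

```python
import numpy as np

loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])


def normal_logpdf(x, mu, sigma):
    """Log-density of N(mu, sigma^2) evaluated at x (elementwise)."""
    z = (x - mu) / sigma
    return -0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)


def logsumexp(a):
    """Numerically stable log(sum(exp(a))): subtract the max before exponentiating."""
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))


def marginal_logprob_np(y):
    # Joint log-density log w_i + log N(y; loc_i, scale_i) for every component i...
    joint = np.log(weights) + normal_logpdf(y, loc, scale)
    # ...then sum over i in log space: log p(y) = log sum_i p(y, I=i).
    return logsumexp(joint)


print(f"{marginal_logprob_np(10.0):.4f}")  # -5.3514
```

Adding np.log(weights) a second time to joint gives about -6.9619 instead, which is the symptom of double-counting the mixture weights.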

Implement the algorithm

The random walk Rosenbluth-Metropolis-Hastings algorithm is also straightforward to implement:

def rw_metropolis_kernel(srng, logprob_fn):
    """Build the random walk Rosenbluth-Metropolis-Hastings (RMH) kernel."""

    def one_step(position, logprob):
        """Generate one sample using the random walk RMH algorithm.

        Parameters
        ----------
        position:
            The initial position.
        logprob:
            The initial value of the log-probability.

        Returns
        -------
        The next position and the corresponding value of the log-probability.

        """
        # Gaussian random walk proposal; one independent move per chain.
        move_proposal = 0.1 * srng.normal(0, 1, size=position.shape)
        proposal = position + move_proposal
        proposal_logprob = logprob_fn(proposal)

        # Accept with probability min(1, p(proposal) / p(position)), in log space.
        log_uniform = at.log(srng.uniform(size=position.shape))
        do_accept = log_uniform < proposal_logprob - logprob

        position = at.where(do_accept, proposal, position)
        logprob = at.where(do_accept, proposal_logprob, logprob)

        return position, logprob

    return one_step
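
For comparison, here is a plain-NumPy sketch of the same random walk RMH step applied to a batch of independent chains, in the spirit of the NumPy baseline used in the comparison (it is not the benchmark code from the repository). The names mixture_logprob, rmh_step and sample are hypothetical. The default step size of 0.1 matches the Aesara kernel; the usage example assumes a larger step of 2.0 so that the chains mix within 1000 steps.

```python
import numpy as np

loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])


def mixture_logprob(x):
    """Marginal mixture log-density, evaluated elementwise on an array of positions."""
    x = np.asarray(x)[..., None]  # trailing axis for the 4 components
    z = (x - loc) / scale
    joint = np.log(weights) - 0.5 * z**2 - np.log(scale) - 0.5 * np.log(2 * np.pi)
    m = joint.max(axis=-1, keepdims=True)  # stable logsumexp over components
    return (m + np.log(np.exp(joint - m).sum(axis=-1, keepdims=True)))[..., 0]


def rmh_step(rng, position, logprob, step_size=0.1):
    """One random walk RMH step for a batch of independent chains."""
    # Independent Gaussian move for every chain.
    proposal = position + step_size * rng.standard_normal(position.shape)
    proposal_logprob = mixture_logprob(proposal)
    # Accept with probability min(1, p(proposal) / p(position)), in log space.
    log_uniform = np.log(rng.uniform(size=position.shape))
    do_accept = log_uniform < proposal_logprob - logprob
    position = np.where(do_accept, proposal, position)
    logprob = np.where(do_accept, proposal_logprob, logprob)
    return position, logprob, do_accept


def sample(rng, n_chains, n_samples, step_size=0.1):
    """Run n_chains chains from 0 and return final positions and the acceptance rate."""
    position = np.zeros(n_chains)
    logprob = mixture_logprob(position)
    accepted = 0
    for _ in range(n_samples):
        position, logprob, do_accept = rmh_step(rng, position, logprob, step_size)
        accepted += do_accept.sum()
    return position, logprob, accepted / (n_chains * n_samples)


rng = np.random.default_rng(0)
position, logprob, acceptance = sample(rng, n_chains=1000, n_samples=1000, step_size=2.0)
print(position.mean(), acceptance)
```

Carrying the current log-probability along with the position, as the Aesara kernel does, means each step evaluates the target density only once, at the proposal.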

Syntactically, aesara.tensor looks like a drop-in replacement for NumPy. Remember, however, that these functions do not act on numbers: they add operations to an existing computation graph. In particular, logprob_fn is a function that takes a graph (possibly a single variable) and returns the graph that computes the value of the log-probability density function.

So, does it work?

Let us sample 1000 chains concurrently for an increasing number of samples and compare the running time to NumPy's and JAX's:

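[Figure: rmh-aesara-comparison.png, running time as a function of the number of samples for Aesara (C backend), NumPy and JAX]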

For a small number of samples, Aesara (C backend) and JAX spend most of their time compiling the kernel, and NumPy is faster. Past \(10^4\) samples NumPy falls behind, and Aesara catches up with JAX at around \(10^5\) samples.

Perspectives

Aesara is still young and holds a lot of promise; come help us! Here is what you can expect to change in this example in the near future:

Marginalize automatically. AePPL will soon be able to automatically marginalize over discrete random variables (see related issue).

Vectorize computation. For performance reasons, the multiple-chain sampler is currently written much like its NumPy counterpart, but you should soon be able to write the kernel for a single chain and use the equivalent of np.vectorize or jax.vmap to vectorize the computation (see related issue).

Work with different backends. You will soon be able to compile this example using Aesara's JAX and Numba backends (work in progress; you can already try it!). This means you will be able to interact with different ecosystems and leverage the strengths of different compilers and hardware devices with the same model expression in Python. It also makes your model code more future-proof, since the backend can change underneath it.

Build samplers automatically. AeMCMC analyzes your model graph and builds an efficient sampler for it.
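
The vectorization item above can already be illustrated in plain NumPy: write a kernel for a single chain, then let np.vectorize map it over chains. This is a hypothetical sketch (logprob_single and one_step_single are illustrative names), not the upcoming Aesara feature. np.vectorize still loops in Python, so it provides the batched interface but none of the speed of jax.vmap or a compiled graph. Random numbers are drawn outside the kernel and passed in, which keeps the single-chain function deterministic given its inputs.

```python
import numpy as np

loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])


def logprob_single(x):
    """Marginal mixture log-density at a single scalar position."""
    joint = (np.log(weights) - 0.5 * ((x - loc) / scale) ** 2
             - np.log(scale) - 0.5 * np.log(2 * np.pi))
    m = joint.max()  # stable logsumexp over the 4 components
    return m + np.log(np.exp(joint - m).sum())


def one_step_single(position, logprob, noise, uniform, step_size=0.1):
    """Random walk RMH step for ONE chain; randomness is passed in as arguments."""
    proposal = position + step_size * noise
    proposal_logprob = logprob_single(proposal)
    if np.log(uniform) < proposal_logprob - logprob:
        return proposal, proposal_logprob
    return position, logprob


# Map the single-chain kernel over arrays of chains (a Python loop under the hood).
one_step_batched = np.vectorize(one_step_single, otypes=[float, float])

rng = np.random.default_rng(0)
n_chains = 100
position = np.zeros(n_chains)
logprob = np.vectorize(logprob_single, otypes=[float])(position)
for _ in range(100):
    noise = rng.standard_normal(n_chains)
    uniform = rng.uniform(size=n_chains)
    position, logprob = one_step_batched(position, logprob, noise, uniform)
```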

Still not sure what Aesara is about? Read Brandon Willard's explanation.