# Random Walk Rosenbluth-Metropolis-Hastings in Aesara

Right before I started working on MCX I wrote a simple benchmark for PyTorch, TensorFlow and JAX on a very simple problem: using the random walk Rosenbluth-Metropolis-Hastings algorithm to sample from a mixture distribution. MCX was discontinued a bit more than a year ago, when I started working on a PPL based on Aesara. So let me revisit this simple example using Aesara!

The full code was added to the repository. For this example we use the *C backend*, though Aesara also offers a Numba and a JAX backend.

## Mixture model

In the original blog post I set out to sample from a mixture distribution with 4 components. I had to write the corresponding log-probability density function by hand, i.e. without using a PPL. Implementing a mixture model in `Aesara` is straightforward. No need for a `Mixture` distribution (as in e.g. PyMC); you just write it like it is:

```python
import aesara.tensor as at
import numpy as np

srng = at.random.RandomStream(0)

loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])

N_rv = srng.normal(loc, scale, name="N")
I_rv = srng.categorical(weights, name="I")
Y_rv = N_rv[I_rv]
```

We can generate forward samples from this model by compiling the model graph, choosing `Y_rv` as an output:

```python
import aesara

sample_fn = aesara.function((), Y_rv)
samples = [sample_fn() for _ in range(10000)]
print(samples[:10])
```

```
[array(2.51645571), array(0.16094803), array(4.16173818), array(-0.75365736), array(0.91897138), array(-1.96086176), array(2.60226408), array(2.28198192), array(-1.05260784), array(1.38469404)]
```
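For comparison, here is a plain-NumPy sketch of the same ancestral sampling that the compiled graph performs: draw the component index first, then the normal it selects. (The seed and sample size are arbitrary choices for this sketch.)

```python
import numpy as np

rng = np.random.default_rng(0)

loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])

# Ancestral sampling: draw the component index I first,
# then draw Y from the normal component it selects.
idx = rng.choice(4, size=10_000, p=weights)
samples = rng.normal(loc[idx], scale[idx])
```

The sample mean should be close to the mixture mean, `weights @ loc = 0.92`.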

If you are not familiar with Theano/Aesara, the call to `aesara.function` may surprise you. What does it do exactly? When you manipulate Aesara tensors, you are not manipulating numbers; you are *describing the computation to perform on the inputs*. As a result, what an Aesara operation returns is a graph:

```python
aesara.dprint(Y_rv)
```

```
Subtensor{int64} [id A]
 |normal_rv{0, (0, 0), floatX, False}.1 [id B] 'N'
 | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FBAB334F680>) [id C]
 | |TensorConstant{[]} [id D]
 | |TensorConstant{11} [id E]
 | |TensorConstant{[-2. 0. .. 3.2 2.5]} [id F]
 | |TensorConstant{[1.2 1. 5. 2.8]} [id G]
 |ScalarFromTensor [id H]
   |categorical_rv{0, (1,), int64, False}.1 [id I] 'I'
     |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FBAB17AE7A0>) [id J]
     |TensorConstant{[]} [id K]
     |TensorConstant{4} [id L]
     |TensorConstant{[0.2 0.3 0.1 0.4]} [id M]
```

`aesara.function` is therefore used to *compile* the graph into a function that can be executed. For that, we need to specify the inputs and outputs of the function. In this case there are no inputs, and the value of `Y_rv` is the output.
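The describe-then-compile idea can be illustrated with a toy graph class. To be clear, this is a purely hypothetical sketch: none of these names exist in Aesara, and Aesara's internals are far more sophisticated; the toy only mirrors the workflow of building a graph first and compiling it later.

```python
# Hypothetical toy illustration of deferred evaluation, NOT Aesara code.
class Var:
    """A node that records an operation instead of computing it."""

    def __init__(self, op=None, inputs=()):
        self.op, self.inputs = op, inputs

    def __add__(self, other):
        return Var(op=lambda a, b: a + b, inputs=(self, other))

    def __mul__(self, other):
        return Var(op=lambda a, b: a * b, inputs=(self, other))


def compile_graph(inputs, output):
    """Walk the graph once per call and return an ordinary Python function."""

    def evaluate(node, env):
        if node in env:          # an input: look up its value
            return env[node]
        return node.op(*(evaluate(i, env) for i in node.inputs))

    return lambda *args: evaluate(output, dict(zip(inputs, args)))


x, y = Var(), Var()
z = x * y + x                    # nothing is computed yet: z is a graph
f = compile_graph([x, y], z)     # now z becomes an executable function
f(2.0, 3.0)                      # 8.0
```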

To compute the log-probability density function we can use AePPL's `joint_logprob` function. AePPL transforms the Aesara model graph into the graph that computes the model's joint log-probability (see, working with computation graphs is nice!). We pass a dictionary that tells it which value to associate with the random variables `Y_rv` and `I_rv`:

```python
from aeppl import joint_logprob

y_vv = Y_rv.clone()
i_vv = I_rv.clone()

logprob = joint_logprob({Y_rv: y_vv, I_rv: i_vv})
print(logprob.eval({y_vv: 10., i_vv: 3}))
```

```
-6.452221131239579
```
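We can check this value by hand: with `y = 10` and `i = 3`, the joint log-probability is `log p(I=3) + log N(10; loc[3], scale[3])`. A small NumPy verification, with the normal log-density written out explicitly:

```python
import numpy as np

y, i = 10.0, 3
loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])

# log N(y; loc[i], scale[i])
norm_logpdf = (
    -0.5 * np.log(2 * np.pi)
    - np.log(scale[i])
    - 0.5 * ((y - loc[i]) / scale[i]) ** 2
)
# joint log-probability: log p(I=i) + log p(Y=y | I=i)
joint = np.log(weights[i]) + norm_logpdf
print(joint)
```

This reproduces the value printed above.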

Here we do not really care about the values that `I_rv` takes, so we marginalize the log-probability density function over `I_rv`:

```python
logprob = []
for i in range(4):
    i_vv = at.as_tensor(i, dtype="int64")
    logprob.append(joint_logprob({Y_rv: y_vv, I_rv: i_vv}))

logprob = at.stack(logprob, axis=0)
total_logprob = at.logsumexp(at.log(weights) + logprob)
print(total_logprob.eval({y_vv: 10.}))
```

```
-6.961941398089025
```
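The same reduction can be written in plain NumPy, mirroring the graph above term by term: each `joint` entry is the normal log-density plus the categorical term (exactly what `joint_logprob` returns, as checked earlier), and the log-weights are then added before the `logsumexp` reduction:

```python
import numpy as np

y = 10.0
loc = np.array([-2, 0, 3.2, 2.5])
scale = np.array([1.2, 1, 5, 2.8])
weights = np.array([0.2, 0.3, 0.1, 0.4])

# Per-component joint log-probability,
# log p(I=i) + log N(y; loc[i], scale[i]), for every i at once.
joint = (
    np.log(weights)
    - 0.5 * np.log(2 * np.pi)
    - np.log(scale)
    - 0.5 * ((y - loc) / scale) ** 2
)

# Add the log-weights and reduce with a numerically stable
# logsumexp, as in the Aesara graph above.
terms = np.log(weights) + joint
m = terms.max()
total = m + np.log(np.exp(terms - m).sum())
print(total)
```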

## Implement the algorithm

The random walk Rosenbluth-Metropolis-Hastings algorithm is also straightforward to implement:

```python
def rw_metropolis_kernel(srng, logprob_fn):
    """Build the random walk Rosenbluth-Metropolis-Hastings (RMH) kernel."""

    def one_step(position, logprob):
        """Generate one sample using the random walk RMH algorithm.

        Parameters
        ----------
        position:
            The initial position.
        logprob:
            The initial value of the log-probability.

        Returns
        -------
        The next position and value of the log-probability.

        """
        move_proposal = 0.1 * srng.normal(0, 1)
        proposal = position + move_proposal
        proposal_logprob = logprob_fn(proposal)

        log_uniform = at.log(srng.uniform())
        do_accept = log_uniform < proposal_logprob - logprob

        position = at.where(do_accept, proposal, position)
        logprob = at.where(do_accept, proposal_logprob, logprob)

        return position, logprob

    return one_step
```
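To see the algorithm in action outside of Aesara, here is a plain-NumPy version of the same kernel, run on a standard normal target. The target, the step size of 1.0, and the chain length are arbitrary choices for this sketch (the Aesara kernel above uses a step size of 0.1):

```python
import numpy as np


def rw_metropolis_numpy(rng, logprob_fn, position, n_samples, step_size=1.0):
    """Plain-NumPy sketch of the random walk RMH kernel."""
    samples = np.empty(n_samples)
    logprob = logprob_fn(position)
    for i in range(n_samples):
        # Propose a move and compute its log-probability.
        proposal = position + step_size * rng.normal()
        proposal_logprob = logprob_fn(proposal)
        # Accept with probability min(1, p(proposal) / p(position)).
        if np.log(rng.uniform()) < proposal_logprob - logprob:
            position, logprob = proposal, proposal_logprob
        samples[i] = position
    return samples


rng = np.random.default_rng(0)
# Standard normal target: log p(x) = -x^2 / 2 (up to a constant).
samples = rw_metropolis_numpy(rng, lambda x: -0.5 * x**2, 0.0, 20_000)
```

After discarding a burn-in period, the chain's mean and variance should be close to 0 and 1.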

Syntactically, `aesara.tensor` looks like a drop-in replacement for `numpy`. Remember, however, that these functions do not act on numbers but add an operation to an existing computation graph. In particular, `logprob_fn` is a function that takes a graph (possibly a single variable) and returns the graph that computes the value of the log-probability density function.

## So, does it work?

Let us sample 1000 chains concurrently for an increasing number of samples and compare the running time to NumPy's and JAX's:

For small numbers of samples, Aesara (C backend) and JAX spend most of their time compiling the kernel, and NumPy is faster. Past \(10^4\) samples NumPy lags behind, with Aesara catching up with JAX around \(10^5\) samples.

## Perspectives

Aesara is still young and holds many promises for the future, come help us! Here is what you can expect to change with this example in the near future:

**Marginalize automatically.** `AePPL` will soon allow you to automatically marginalize over discrete random variables (see the related issue).

**Vectorize computation.** The implementation of the multiple-chain sampler currently stays close to NumPy's for performance reasons, but you should soon be able to write the kernel for a single chain and use the equivalent of `np.vectorize` or `jax.vmap` to vectorize the computation (see the related issue).
**Work with different backends.** You will soon be able to compile this example using Aesara's JAX backend and Numba backend (work in progress, you can already try it!). This means you will be able to interact with different ecosystems and leverage the strengths of different compilers / hardware devices with the *same model expression* in Python. It also means that your model code is more future-proof, as you can move the backend under it.

**Build samplers automatically.** AeMCMC analyzes your model graph and builds an efficient sampler for it.

Still not sure what Aesara is about? Read Brandon Willard's explanation.