AePPL, discrete mixtures and conditionals

Let us consider a family of probability distributions \(\left\{\mathbb{P}_\mu : \mu \in \mathcal{M} \right\}\) where \(\mathcal{M} = \left\{1, \dots, K\right\}\), and let \(\mathbb{Q}\) be a probability distribution on \(\mathcal{M}\). A discrete mixture model assumes that the value of a random variable \(Y\) is drawn from one of the \(\mathbb{P}_\mu\): for each draw we first sample an index \(m \in \mathcal{M}\) from \(\mathbb{Q}\), and then sample \(Y\) from \(\mathbb{P}_m\).
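As a plain NumPy sketch of this two-step process, independent of AePPL, we can assume \(K = 4\) Gaussian components \(\mathbb{P}_m = \mathcal{N}(\text{loc}_m, 1)\) and a categorical \(\mathbb{Q}\). The component family, the weights and the helper name `sample_mixture` are illustrative assumptions, not part of AePPL.

```python
import numpy as np

# Illustrative assumption: P_m = Normal(loc[m], 1) and Q = Categorical(weights)
loc = np.array([-1.0, 0.0, 1.0, 2.0])
weights = np.array([0.2, 0.3, 0.1, 0.4])


def sample_mixture(rng, n):
    """Draw n values of Y by ancestral sampling (hypothetical helper)."""
    # Step 1: draw a component index m ~ Q for every sample
    m = rng.choice(len(weights), size=n, p=weights)
    # Step 2: draw Y ~ P_m, using the location of the selected component
    y = rng.normal(loc[m], 1.0)
    return m, y


rng = np.random.default_rng(0)
m, y = sample_mixture(rng, 100_000)
# The empirical component frequencies approach the weights of Q
print(np.bincount(m, minlength=len(weights)) / len(m))
```

Keeping the index `m` alongside `y` mirrors the way AePPL treats the index as a random variable in its own right.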

We can then write the likelihood of \(Y\) conditional on \(m\) as:

\[ \mathbb{P}(Y \mid m) = \mathbb{P}_m(Y) \]
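Multiplying by the probability of the index gives the joint density of \(Y\) and \(m\), and summing over \(m\) gives the marginal density of \(Y\). Calling `aeppl.joint_logprob` on both the mixture and its index, as in the examples that follow, corresponds to the logarithm of the joint; a dedicated Mixture distribution object would typically evaluate the marginal instead.

```latex
\mathbb{P}(Y, m) = \mathbb{Q}(m)\,\mathbb{P}_m(Y),
\qquad
\mathbb{P}(Y) = \sum_{m=1}^{K} \mathbb{Q}(m)\,\mathbb{P}_m(Y)
```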

Many probabilistic programming languages implement mixtures via an ad-hoc Mixture distribution object. In Aesara/AePPL, mixtures are instead expressed through the generative process described above, using the following Aesara constructs:

Indexing an array

We can define a mixture by using a random variable with discrete support (Bernoulli, Categorical, etc.) to index an array of random variables:

import aeppl
import aesara
import aesara.tensor as at
import numpy as np

srng = at.random.RandomStream(0)

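# Four normal components with locations -1, 0, 1, 2 and unit scale
loc = np.array([-1, 0, 1, 2])
N_rv = srng.normal(loc, 1.)

# Mixture weights: a categorical index selects one of the components
p = np.array([0.2, 0.3, 0.1, 0.4])
I_rv = srng.categorical(p)

# The mixture variable is the component picked out by the index
Y_rv = N_rv[I_rv]

sample_fn = aesara.function((), Y_rv)
print(sample_fn())

# Joint log-probability of the mixture value and the index;
# y_vv and i_vv are the value variables AePPL creates for Y_rv and I_rv
logprob, (y_vv, i_vv) = aeppl.joint_logprob(Y_rv, I_rv)
logprob_fn = aesara.function((y_vv, i_vv), logprob)
print(logprob_fn(10, 0))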
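As a sanity check, and assuming `joint_logprob` returns the log-probability of the index plus the log-density of the selected component, the printed value for \(y = 10\), \(i = 0\) should equal \(\log 0.2 + \log \mathcal{N}(10; -1, 1)\). The pure-Python sketch below does that arithmetic and, for contrast, also computes the marginal log-density obtained by summing over the index. Both helper functions are hypothetical reference implementations, not AePPL functions.

```python
import math


def normal_logpdf(x, loc, scale):
    """Log-density of a Normal(loc, scale) distribution at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def indexed_mixture_joint_logp(y, i, loc, p):
    """Hypothetical reference: log Q(i) + log N(y; loc[i], 1)."""
    return math.log(p[i]) + normal_logpdf(y, loc[i], 1.0)


def indexed_mixture_marginal_logp(y, loc, p):
    """log sum_m Q(m) N(y; loc[m], 1), computed stably with log-sum-exp."""
    terms = [math.log(pm) + normal_logpdf(y, lm, 1.0) for lm, pm in zip(loc, p)]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


loc = [-1, 0, 1, 2]
p = [0.2, 0.3, 0.1, 0.4]
print(indexed_mixture_joint_logp(10, 0, loc, p))
print(indexed_mixture_marginal_logp(10, loc, p))
```

The marginal is always at least as large as any single joint term, since it adds up the contributions of every component.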

Using aesara.tensor.where

We can also define mixtures using aesara.tensor.where; the condition can itself depend on other random variables:

import aeppl
import aesara
import aesara.tensor as at


srng = at.random.RandomStream(0)

x_rv = srng.normal(0, 1)
y_rv = srng.cauchy(0, 1)
i_rv = srng.normal(0, 2)

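# Y is drawn from the normal when i >= 1, and from the Cauchy otherwise
Y_rv = at.where(at.ge(i_rv, 1.), x_rv, y_rv)

sample_fn = aesara.function((), Y_rv)
print(sample_fn())

# Pass the mixture Y_rv (not its Cauchy component y_rv)
logprob, (y_vv, i_vv) = aeppl.joint_logprob(Y_rv, i_rv)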
logprob_fn = aesara.function((y_vv, i_vv), logprob)
print(logprob_fn(10, -1.))
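The same kind of check applies here. With \(i = -1\) the condition \(i \geq 1\) is false, so \(Y\) comes from the Cauchy branch. Assuming the joint log-probability is the log-density of \(i\) under \(\mathcal{N}(0, 2)\) plus the log-density of the selected branch, the printed value should match the hand computation below; the helper names are hypothetical.

```python
import math


def normal_logpdf(x, loc, scale):
    """Log-density of a Normal(loc, scale) distribution at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def cauchy_logpdf(x, loc, scale):
    """Log-density of a Cauchy(loc, scale) distribution at x."""
    z = (x - loc) / scale
    return -math.log(math.pi * scale * (1 + z * z))


def where_mixture_joint_logp(y, i):
    """Hypothetical reference for Y = where(i >= 1, Normal(0, 1), Cauchy(0, 1)), i ~ Normal(0, 2)."""
    # The branch is selected deterministically by the value of i
    if i >= 1.0:
        branch = normal_logpdf(y, 0.0, 1.0)
    else:
        branch = cauchy_logpdf(y, 0.0, 1.0)
    return normal_logpdf(i, 0.0, 2.0) + branch


print(where_mixture_joint_logp(10, -1.0))
```

Unlike the indexing example, the "index" here is continuous: every value of \(i\) below 1 selects the same component.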

TODO Using aesara.ifelse.ifelse

We will also soon be able to use aesara.ifelse to define mixtures in the same way we use aesara.tensor.where:

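import aeppl
import aesara.tensor as at
from aesara.ifelse import ifelse

srng = at.random.RandomStream(0)

x_rv = srng.normal(0, 1)
y_rv = srng.cauchy(0, 1)
i_rv = srng.bernoulli(0.5)

# Y is drawn from the normal when i == 1, and from the Cauchy otherwise
Y_rv = ifelse(i_rv, x_rv, y_rv)

# Not yet supported by AePPL at the time of writing
logprob, (y_vv, i_vv) = aeppl.joint_logprob(Y_rv, i_rv)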
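Until this is supported, the log-probability we would expect can be written down by hand. Assuming the same joint decomposition as in the previous examples, it is the Bernoulli log-mass of \(i\) plus the log-density of the branch that `ifelse` selects (normal when \(i = 1\), Cauchy when \(i = 0\)). The helper below is a hypothetical reference, not an AePPL function.

```python
import math


def normal_logpdf(x, loc, scale):
    """Log-density of a Normal(loc, scale) distribution at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def cauchy_logpdf(x, loc, scale):
    """Log-density of a Cauchy(loc, scale) distribution at x."""
    z = (x - loc) / scale
    return -math.log(math.pi * scale * (1 + z * z))


def ifelse_mixture_joint_logp(y, i, p=0.5):
    """Hypothetical reference for Y = ifelse(i, Normal(0, 1), Cauchy(0, 1)), i ~ Bernoulli(p)."""
    if i not in (0, 1):
        return -math.inf  # outside the Bernoulli support
    index_logp = math.log(p) if i == 1 else math.log1p(-p)
    branch_logp = normal_logpdf(y, 0.0, 1.0) if i == 1 else cauchy_logpdf(y, 0.0, 1.0)
    return index_logp + branch_logp


print(ifelse_mixture_joint_logp(10, 0))
print(ifelse_mixture_joint_logp(10, 1))
```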
