Let us consider a family $\{P_i\}$ of probability distributions, where $i \in \{1, \dots, n\}$, and a probability distribution $\pi$ on the domain $\{1, \dots, n\}$. A discrete mixture model assumes that the value of a random variable $Y$ can be drawn from one of the $P_i$: at each step we draw one value $i$ from $\pi$, and then draw $Y$ from $P_i$.
We can then write the likelihood of $Y$ conditional on $i$ as:
$$P(Y = y \mid i) = P_i(Y = y)$$
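The two-step generative process described above can be sketched in plain NumPy, independently of Aesara. This is only an illustration: the function name `sample_mixture`, the choice of normal components, and the weights and means are assumptions made for the example.

```python
import numpy as np


def sample_mixture(rng, weights, locs, scale=1.0, size=1):
    """Draw `size` values from a mixture of normal components (illustrative)."""
    weights = np.asarray(weights, dtype=float)
    locs = np.asarray(locs, dtype=float)
    # Step 1: draw a component index for every sample from the mixing distribution.
    idx = rng.choice(len(weights), size=size, p=weights)
    # Step 2: draw each value from the component selected by its index.
    values = rng.normal(loc=locs[idx], scale=scale)
    return idx, values


rng = np.random.default_rng(0)
weights = [0.2, 0.3, 0.1, 0.4]  # assumed mixing probabilities
locs = [-1.0, 0.0, 1.0, 2.0]    # assumed component means
idx, values = sample_mixture(rng, weights, locs, size=100_000)

# The sample mean should be close to the weighted average of the component means.
print(values.mean(), np.dot(weights, locs))
```

Keeping the index draw explicit, rather than hiding it inside a dedicated mixture object, mirrors the generative view taken in the rest of this page.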
Many probabilistic programming languages implement mixtures via an ad-hoc Mixture distribution object. In Aesara/AePPL, mixtures are expressed via the generative process described above, using the following Aesara constructs:
- Indexing an array with a random variable;
- aesara.tensor.where;
- aesara.ifelse.ifelse.
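Whichever construct is used, the generative formulation fixes the log-density: the joint log-density of a value and its component index is the log-probability of the index plus the log-density of the value under that component, and summing over the index gives the usual mixture density. The following is a hand-written NumPy sketch of that arithmetic for normal components, not AePPL's implementation; the helper names and numerical values are assumptions chosen for illustration.

```python
import numpy as np


def normal_logpdf(y, loc, scale=1.0):
    """Log-density of a normal distribution, written out explicitly."""
    z = (y - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def joint_logdensity(y, i, weights, locs):
    """log p(y, i) = log weights[i] + log N(y; locs[i], 1)."""
    return np.log(weights[i]) + normal_logpdf(y, locs[i])


def marginal_logdensity(y, weights, locs):
    """log p(y) = log sum_i weights[i] * N(y; locs[i], 1), via log-sum-exp."""
    terms = np.log(weights) + normal_logpdf(y, locs)
    m = terms.max()
    return m + np.log(np.exp(terms - m).sum())


weights = np.array([0.2, 0.3, 0.1, 0.4])  # assumed mixing probabilities
locs = np.array([-1.0, 0.0, 1.0, 2.0])    # assumed component means

print(joint_logdensity(0.5, 1, weights, locs))
print(marginal_logdensity(0.5, weights, locs))
```

The joint form is what one conditions on when the index is observed or sampled; the marginal form is what an ad-hoc Mixture object would typically expose.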
Indexing an array
We can define a mixture by using a random variable with discrete support (Bernoulli, Categorical, etc.) to index an array of random variables:
import aeppl
import aesara
import aesara.tensor as at
import numpy as np
srng = at.random.RandomStream(0)
loc = np.array([-1, 0, 1, 2])
N_rv = srng.normal(loc, 1.)
p = np.array([0.2, 0.3, 0.1, 0.4])
I_rv = srng.categorical(p)
Y_rv = N_rv[I_rv]
sample_fn = aesara.function((), Y_rv)
print(sample_fn())
logprob, (y_vv, i_vv) = aeppl.joint_logprob(Y_rv, I_rv)
logprob_fn = aesara.function((y_vv, i_vv), logprob)
print(logprob_fn(10, 0))
Using aesara.tensor.where
We can also define mixtures using aesara.tensor.where; the conditional can be based on other random variables:
import aeppl
import aesara
import aesara.tensor as at
srng = at.random.RandomStream(0)
x_rv = srng.normal(0, 1)
y_rv = srng.cauchy(0, 1)
i_rv = srng.normal(0, 2)
Y_rv = at.where(at.ge(i_rv, 1.), x_rv, y_rv)
sample_fn = aesara.function((), Y_rv)
print(sample_fn())
# Joint log-probability of the mixture Y_rv and the selector i_rv
logprob, (y_vv, i_vv) = aeppl.joint_logprob(Y_rv, i_rv)
logprob_fn = aesara.function((y_vv, i_vv), logprob)
print(logprob_fn(10, -1.))
TODO Using aesara.ifelse.ifelse
We will also soon be able to use aesara.ifelse.ifelse to define mixtures in the same way we use aesara.tensor.where:
import aeppl
import aesara
import aesara.tensor as at
from aesara.ifelse import ifelse
srng = at.random.RandomStream(0)
x_rv = srng.normal(0, 1)
y_rv = srng.cauchy(0, 1)
i_rv = srng.bernoulli(0.5)
# Select x_rv when i_rv is true, y_rv otherwise
Y_rv = ifelse(i_rv, x_rv, y_rv)
logprob, (y_vv, i_vv) = aeppl.joint_logprob(Y_rv, i_rv)
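Until this is supported, a hand-computed reference value can be useful for checking results later. The sketch below evaluates the joint log-probability implied by the generative model above (a Bernoulli(0.5) selector choosing between a standard normal and a standard Cauchy) in plain NumPy. It assumes that ifelse-based mixtures would follow the same semantics as the other constructs; the function names are hypothetical and not part of AePPL.

```python
import numpy as np


def normal_logpdf(y):
    """Log-density of a standard normal distribution."""
    return -0.5 * y**2 - 0.5 * np.log(2.0 * np.pi)


def cauchy_logpdf(y):
    """Log-density of a standard Cauchy distribution."""
    return -np.log(np.pi) - np.log1p(y**2)


def expected_joint_logprob(y, i, p=0.5):
    """Joint log-probability of (Y, i) when i ~ Bernoulli(p) selects the branch."""
    # Log-probability of the selector value itself.
    selector_logp = np.log(p) if i else np.log1p(-p)
    # Log-density of y under the branch that the selector picks:
    # the normal branch when i is true, the Cauchy branch otherwise.
    branch_logp = normal_logpdf(y) if i else cauchy_logpdf(y)
    return selector_logp + branch_logp


print(expected_joint_logprob(0.3, 1))
print(expected_joint_logprob(0.3, 0))
```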