import aesara.tensor as at
srng = at.random.RandomStream(0)
sigma_rv = srng.normal(1.)
mu_rv = srng.normal(0, 1)
Y_rv = srng.normal(mu_rv, sigma_rv)Sample prior
Sampling from the prior predictive distribution is a useful tool for debugging. One thus needs this function to be:
- fast
- Easy to use
- Easy to customize; is it easy to change the value of a parameter.
import aesara
from aesara.graph.basic import io_toposort
from aesara.tensor.random.op import RandomVariable
import jax
rng_key = jax.random.PRNGKey(3)
def count_model_rvs(rv_out):
"""Count the number of `RandomVariable` in a model"""
return len([node for node in io_toposort([], [rv_out]) if isinstance(node.op, RandomVariable)])
def split_to_tuple(rng_key, num):
keys = jax.numpy.split(jax.random.split(rng_key, num), num)
return tuple([key.squeeze() for key in keys])
def prior_sample(rng_key, num_samples, rv_out):
"""Return prior predictive samples"""
prior_fn = aesara.function([], Y_rv, mode="JAX").vm.jit_fn
num_rvs = count_model_rvs(Y_rv)
def take_one_sample(rng_key):
keys = [{"jax_state": key} for key in split_to_tuple(rng_key, num_rvs)]
return prior_fn(*keys)[0]
return jax.vmap(take_one_sample)(jax.random.split(rng_key, num_samples))samples = prior_sample(rng_key, 10, Y_rv)
samplesprior_sampler(Y_rv, mu_rv).run(rng_key, 1000, {a_tt: 1.})