Random Variables in Aesara

AeP: PRNG and RandomVariable representation in Aesara

Summary

I propose a new approach to represent RandomVariable and PRNG states in Aesara's IR, based on the design of splittable PRNGs. The representation introduces minimal change to the existing RandomVariable interface while being more expressive. It should be easy to transpile to Aesara's current compilation target, and is compatible with the higher-level RandomStream interface.

References

On counter-based PRNGs:

On splittable PRNG design for functional programs:

Desiderata

A good PRNG design satisfies the following conditions:

  1. It is expressive: the behavior of the system is predictable by the caller, and allows them to express any probabilistic program;
  2. It makes it possible to build reproducible programs ("seeding");
  3. It is embedded in Aesara's IR;
  4. It can be manipulated by Aesara's rewrite system;
  5. It can be easily transpiled to current backends;
  6. It enables vectorization with generalized universal functions.

Motivation behind this proposal

This proposal is motivated by two issues that illustrate the shortcomings of the current representation of RandomVariables and PRNG states in Aesara:

  • In #1036 the use of default_output to hide the PRNG state from the user is causing multiple headaches in the etuplization of RandomVariables and the unification/reification of expressions with RandomVariables. This is the only Op in Aesara that makes use of this property, and "special-casing" the etuplization logic for RandomVariables has often appeared to be the easiest solution.
  • In #66 in AeMCMC, expressing expansions like the following convolution of two normal variables is overly complex:

    import aesara.tensor as at
    from etuples import etuple, etuplize
    from kanren import var
    
    mu_x, mu_y, sigma2_x, sigma2_y = var(), var(), var(), var()
    
    rng, size, dtype = var(), var(), var()
    X_et = etuple(
        etuplize(at.random.normal),
        rng,
        size,
        dtype,
        etuple(
            etuplize(at.add),
            mu_x,
            mu_y
        ),
        etuple(
            etuplize(at.add),
            sigma2_x,
            sigma2_y,
        )
    )
    
    rng_x, size_x, dtype_x = var(), var(), var()
    rng_y, size_y, dtype_y = var(), var(), var()
    Y_et = etuple(
        etuplize(at.add),
        etuple(
            etuplize(at.random.normal),
            rng_x,
            size_x,
            dtype_x,
            mu_x,
            sigma2_x
        ),
        etuple(
            etuplize(at.random.normal),
            rng_y,
            size_y,
            dtype_y,
            mu_y,
            sigma2_y
        )
    )
    

    It is indeed not clear what the values of rng_x and rng_y should be given the value of rng. A few other application-related shortcomings of the current representation will be given below.

Proposal

In the following we focus on the symbolic representation of random variables and PRNG states in Aesara's IR. We leave discussions about compilation targets and solution to the previous issues for the end.

If we represent the internal state of the PRNG by the type RandState (short for RandomStateType), the current design of RandomVariables can be summarized by the following simplified signature:

RandomVariable :: RandState -> (RandState, TensorVariable)

In other words, RandomVariables are responsible for both advancing the state of the PRNG and producing a random value. This double responsibility is what creates graph dependencies between nodes that have otherwise no data dependency. The following snippet illustrates this:

import aesara
import aesara.tensor as at

rng = at.random.type.RandomStateType()('rng')

rng_x, x_rv = at.random.normal(0, 1, rng=rng, name='x').owner.outputs
rng_y, y_rv = at.random.normal(0, 1, rng=rng_x, name='y').owner.outputs
z_rv = at.random.normal(0, 1, rng=rng_y, name='z')
w_at = x_rv + y_rv + z_rv

aesara.dprint(w_at)
# Elemwise{add,no_inplace} [id A]
#  |Elemwise{add,no_inplace} [id B]
#  | |normal_rv{0, (0, 0), floatX, False}.1 [id C] 'x'
#  | | |rng [id D]
#  | | |TensorConstant{[]} [id E]
#  | | |TensorConstant{11} [id F]
#  | | |TensorConstant{0} [id G]
#  | | |TensorConstant{1} [id H]
#  | |normal_rv{0, (0, 0), floatX, False}.1 [id I] 'y'
#  |   |normal_rv{0, (0, 0), floatX, False}.0 [id C]
#  |   |TensorConstant{[]} [id J]
#  |   |TensorConstant{11} [id K]
#  |   |TensorConstant{0} [id L]
#  |   |TensorConstant{1} [id M]
#  |normal_rv{0, (0, 0), floatX, False}.1 [id N] 'z'
#    |normal_rv{0, (0, 0), floatX, False}.0 [id I]
#    |TensorConstant{[]} [id O]
#    |TensorConstant{11} [id P]
#    |TensorConstant{0} [id Q]
#    |TensorConstant{1} [id R]

As we can see in the graph representation, rng_x (id C) is being used as an input to y and rng_y (id I) is being used as an input to z. There is however no data dependency between x, y or z. The intuition that they should not be linked is probably what led to "hiding" these PRNG state outputs so they are not re-used, and the RandomStream interface.

Creating spurious sequential dependencies by threading PRNG states is indeed unsatisfactory from a representation perspective, and unnecessarily complicates the rewrites. It is also problematic for two other reasons:

  • Parallelization and Vectorization: Using random variables in user-defined generalized universal functions is going to require a lot of compiler magic to make sure that the random state is updated properly, and the behavior will be completely opaque to the user;
  • The fact that callers cannot be intentional about what they do with the random state is limiting. This can be necessary in practical applications, for instance to implement coupled sampling algorithms in which two algorithms share the same random state.

A natural idea is to simplify the design of RandomVariables so that they are only responsible for one thing: generating a random value from a PRNG state. The Op thus creates an Apply node that takes a RandState (using the above notation) as input and outputs a (random) Variable:

RandomVariable :: RandState -> Variable

Providing a RandState to a RandomVariable needs to be intentional, and this must be reflected in the user interface. We thus make rng an explicit input of the RandomVariable's __call__ method. This way a user can write:

import aesara.tensor as at

# rng_x, rng_y and rng_z are created before that.
x_rv = at.random.normal(rng_x, 0, 1)
y_rv = at.random.normal(rng_y, 0, 1)
z_rv = at.random.normal(rng_z, 0, 1)

Or, if they want the PRNG state to be shared (silly example, but a legitimate need):

import aesara.tensor as at

# rng_x is created before that and shared by the three variables.
x_rv = at.random.normal(rng_x, 0, 1)
y_rv = at.random.normal(rng_x, 0, 1)
z_rv = at.random.normal(rng_x, 0, 1)

This interface presupposes the existence of two operators. First, to build reproducible programs, we need an operator that creates a RandState from a seed, which can be the constructor of RandState itself:

RandState.__init__ :: Seed -> RandState

And then, we need another operator that creates an updated RandState from a RandState, so that RandomVariables created with these two different states would output different numbers. Let's call it next:

next :: RandState -> RandState

We can thus fill in the blanks in the previous code examples:

import aesara.tensor as at

rng_x = at.random.RandState(0)
rng_y = at.random.next(rng_x)
rng_z = at.random.next(rng_y)

x_rv = at.random.normal(rng_x, 0, 1)
y_rv = at.random.normal(rng_y, 0, 1)
z_rv = at.random.normal(rng_z, 0, 1)

w_at = x_rv + y_rv + z_rv

The code has been specifically formatted to illustrate what we gain from this approach. x_rv, y_rv and z_rv have lost their direct dependency; we could easily execute these three statements in parallel. What we have done implicitly is to create two graphs: the graph between random variables which reflects the dependencies (or lack thereof) on each other's values, and the graph of the updates of the PRNG states. These graphs almost evolve in parallel.

This is similar to what I understand the RandomStream interface does: moving the updates of the PRNG states to the update graphs generated by Aesara's shared variables.

The next operator is however not completely satisfactory. Let us consider a more complex situation, where call is a function that requires a RandState:

import aesara.tensor as at

rng_x = at.random.RandState(0)
rng_y = at.random.next(rng_x)

x_rv = call(rng_x)
y_rv = call(rng_y)
z_at = x_rv + y_rv

We can easily find an implementation of call that makes the previous code generate a random state collision:

def call(rng_a):
    a_rv = at.random.normal(rng_a, 0, 1)
    rng_b = at.random.next(rng_a)
    b_rv = at.random.normal(rng_b, 0, 1)
    return a_rv * b_rv
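
To make the collision concrete, here is a small pure-Python mock. It is only a sketch: MockState, next_state, normal and call are hypothetical stand-ins for the proposed operators (they are not Aesara APIs), and the random value is simply a deterministic function of the state.

```python
import random
from typing import NamedTuple


class MockState(NamedTuple):
    """Hypothetical stand-in for RandState: a seed and a step counter."""

    key: int
    counter: int = 0


def next_state(rng):
    # Stand-in for the `next` operator: advance the counter by one.
    return MockState(rng.key, rng.counter + 1)


def normal(rng, mu, sigma):
    # The draw is a deterministic function of the state only.
    return random.Random(rng.key * 1_000_003 + rng.counter).gauss(mu, sigma)


def call(rng_a):
    # Same structure as `call` above, but it also reports the states it used.
    rng_b = next_state(rng_a)
    return normal(rng_a, 0, 1) * normal(rng_b, 0, 1), (rng_a, rng_b)


rng_x = MockState(0)
rng_y = next_state(rng_x)

x, states_x = call(rng_x)
y, states_y = call(rng_y)

# The second state used inside call(rng_x) is exactly rng_y: the two calls
# share a draw and are therefore not independent.
print(set(states_x) & set(states_y))
```

The caller has no way to see this overlap without inspecting the body of call, which is the core of the problem.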

To avoid this kind of issue, we must thus require user-defined functions to return the last PRNG state along with the result:

import aesara.tensor as at

def call(rng_a):
    a_rv = at.random.normal(rng_a, 0, 1)
    rng_b = at.random.next(rng_a)
    b_rv = at.random.normal(rng_b, 0, 1)
    return (a_rv * b_rv), at.random.next(rng_b)


rng = at.random.RandState(0)
x_rv, rng_x = call(rng)
y_rv, rng_y = call(rng_x)
z_at = x_rv + y_rv

Threading PRNG state is still necessary to guarantee correctness and the two call functions cannot be called in parallel. The issue arises because, even though we have separated PRNG state update and random value generation, our symbolic structure is still sequential: each RandState has one and only one ancestor. We can of course circumvent this issue knowing how many times next is called within the function, by "jumping" the same number of times to obtain rng_y, but this can quickly become complex (what if call is imported from somewhere else?).

It would make things easier if a RandState could have several children, and if each of these children led to a separate stream of random numbers. Let us define the following split operator:

split :: RandState -> (RandState, RandState)

We require that we can never get the same RandState by calling split any number of times on either the left or right returned state. In other words, split should implicitly define a binary tree in which all the nodes are unique. This can be easily represented by letting RandState hold a number in binary format. The left child state is obtained by appending 0 to the parent's state and the right child state by appending 1:

from typing import NamedTuple


class RandState(NamedTuple):
    key: int
    node_id: int = 0b1


def split(rng):
    # Append one bit to the parent's identifier: 0 for left, 1 for right.
    left = RandState(rng.key, rng.node_id << 1)
    right = RandState(rng.key, (rng.node_id << 1) + 1)
    return left, right


rng = RandState(0)
l, r = split(rng)
ll, lr = split(l)

print(rng, l, lr)
# RandState(key=0, node_id=1) RandState(key=0, node_id=2) RandState(key=0, node_id=5)

If the generator called by RandomVariable can be made a deterministic function of this binary value, the computations are fully reproducible. We added a key attribute that can be specified by the user at initialization to seed the PRNG state. The tree structure is of course explicit in our graph representation, since l and r depend on rng via the split operator. Nevertheless, we can increment this internal state when building the graph in a way that allows us to compile without traversing the graph.
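
The uniqueness requirement can be checked on a small tree. The following is a minimal sketch under the one-bit-per-split encoding described above; the descendants helper is hypothetical and only exists for this check.

```python
from typing import NamedTuple


class RandState(NamedTuple):
    key: int
    node_id: int = 0b1


def split(rng):
    # Append one bit to the parent's identifier: 0 for left, 1 for right.
    left = RandState(rng.key, rng.node_id << 1)
    right = RandState(rng.key, (rng.node_id << 1) | 1)
    return left, right


def descendants(rng, depth):
    """Return every state reachable from `rng` with at most `depth` splits."""
    states = [rng]
    frontier = [rng]
    for _ in range(depth):
        frontier = [child for state in frontier for child in split(state)]
        states.extend(frontier)
    return states


states = descendants(RandState(key=0), depth=5)
# A full binary tree of depth 5 has 2**6 - 1 = 63 nodes, all distinct.
print(len(states), len(set(states)))
# 63 63
```

Since the leading bit is always 1, the binary expansion of node_id reads as the path from the root, which is why no two nodes can collide.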

The next operator we previously defined becomes redundant within this representation. Since its interaction with the split operator would require careful thought we leave it aside in the following. Using the new operator our toy example becomes:

import aesara.tensor as at

rng = at.random.RandState(0)
rng_x, rng_y = at.random.split(rng)

x_rv = at.random.normal(rng_x, 0, 1)
y_rv = at.random.normal(rng_y, 0, 1)
z_at = x_rv + y_rv

Note that the "main" sub-graph that contains random variables, and the PRNG sub-graph are still minimally connected.

Finally, it is also natural to implement the splitn operator represented by:

splitn :: RandState -> Int -> (RandState, ..., RandState)

So we can write the following code:

at.random.split = at.random.Split()

rng = at.random.default_rng()
rng_v, rng_w, rng_x, rng_y = at.random.splitn(rng, 4)

v_rv = at.random.normal(rng_v, 0, 1)
w_rv = at.random.normal(rng_w, 0, 1)
x_rv = at.random.normal(rng_x, 0, 1)
y_rv = at.random.normal(rng_y, 0, 1)
z_at = v_rv + w_rv + x_rv + y_rv
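
One possible way to obtain splitn from split is to keep splitting the right branch and collect the left children. This is a sketch on the mock RandState from earlier, not the proposed Op; other constructions (for instance a balanced tree) would work as well.

```python
from typing import NamedTuple


class RandState(NamedTuple):
    key: int
    node_id: int = 0b1


def split(rng):
    return (
        RandState(rng.key, rng.node_id << 1),
        RandState(rng.key, (rng.node_id << 1) | 1),
    )


def splitn(rng, n):
    """Return `n` distinct child states by repeatedly splitting the right branch."""
    children = []
    current = rng
    for _ in range(n):
        child, current = split(current)
        children.append(child)
    return tuple(children)


rng_v, rng_w, rng_x, rng_y = splitn(RandState(key=0), 4)
print([bin(s.node_id) for s in (rng_v, rng_w, rng_x, rng_y)])
# ['0b10', '0b110', '0b1110', '0b11110']
```

As with split, the parent state should not be reused after the call, otherwise a later split of the same parent would reproduce these children.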

Implementation

When it comes to practical implementations, this representation is only convenient for counter-based PRNGs like the Philox generator implemented in NumPy: we generate a pair of (key, counter) from our RandStates and pass these as an input to the generator.

  • RandState and split implementation

    The mock implementation of RandState and split above is naive in the sense that the counter space \(\mathcal{S}\) of real PRNGs does not usually extend indefinitely. In practice we will need to compress the state using a hashing function that also increments the key. To be immediately compatible with NumPy in the perform function we can use Philox's hash function to update the state as we build the graph. Since the hash is deterministic we can still walk the RandState tree in our representation and cheaply recompute the states should we need to.

    Op and Variable implementations to come.

  • RandomVariable

    The modifications to RandomVariable Ops are minimal:

    • __call__ now takes a RandState as a positional argument;
    • make_node only returns out_var. The default_output attribute is not needed anymore.
  • RandomStream

    We can keep the RandomStream API, use a shared variable to hold the RandState and handle the splitting internally. The RNG sub-graphs are now found in the updates' graph.

    As a second step we may consider instantiating RandState as shared variables by default to decouple both the random variable and the PRNG state graphs. I am not sure of the tradeoffs here, but it may alleviate concerns related to graph rewrites.

Compilation

It is essential that our representation of PRNG states and RandomVariables in the graph can be easily transpiled to the existing targets (C, Numba, JAX) and future targets. In the following I outline the transpilation process for the current targets.

  • Numba

    After #1245 Aesara will support NumPy's Generator API. Furthermore NumPy has support for Philox as a BitGenerator, a counter-based PRNG which can easily accommodate splittable PRNG representations. Assuming we can map each path in the PRNG graph to a (key, counter) tuple, the transpilation of RandomStreams using the Philox BitGenerator should be straightforward. For the explicit splitting interface, we can directly translate the RandomVariables to NumPy Generators and seed these generators at compile time, so that:

    at.random.normal(rng, 0, 1)
    

    Becomes:

    gen = np.random.Generator(np.random.Philox(counter=rng.counter, key=rng.key))
    gen.normal(0, 1)
    
  • JAX

    Transpilation to JAX would be straightforward, as JAX uses a splittable PRNG representation. We will simply need to perform the following substitutions:

    rng = at.random.RandState(0)
    rng_key = jax.random.PRNGKey(0)
    
    at.random.split(rng)
    jax.random.split(rng_key)
    
    at.random.splitn(rng, 10)
    jax.random.split(rng_key, 10)
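
Returning to the NumPy target, the mapping from a (key, counter) pair to a Generator can be sketched as follows. The to_generator helper is hypothetical, and using the node identifier directly as the Philox counter is an assumption made for illustration only: neighbouring counters produce overlapping streams after a few draws, so a real implementation would need to space or hash them as discussed in the Implementation section.

```python
import numpy as np


def to_generator(key, counter):
    """Build a NumPy Generator from a (key, counter) pair (hypothetical helper)."""
    return np.random.Generator(np.random.Philox(key=key, counter=counter))


# Two "states" that share the seed but sit at different nodes of the tree.
draw_left = to_generator(key=0, counter=0b10).normal(0, 1)
draw_right = to_generator(key=0, counter=0b11).normal(0, 1)

# Rebuilding a generator from the same pair reproduces the same value.
draw_left_again = to_generator(key=0, counter=0b10).normal(0, 1)

print(draw_left == draw_left_again, draw_left != draw_right)
```

Because the generator is rebuilt from the pair at each use, the result does not depend on the order in which the random variables are evaluated.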
    

Back to the motivating issues

The problems linked to the existence of the default_output attribute disappear since RandomVariables do not return PRNG states anymore. The one-to-many difficulty we are facing with the relations between etuplized graphs also disappears with a split operator. Using the example from the beginning we can for instance write:

import aesara.tensor as at
from etuples import etuple, etuplize
from kanren import var

mu_x, mu_y, sigma2_x, sigma2_y = var(), var(), var(), var()

rng, size, dtype = var(), var(), var()
X_et = etuple(
    etuplize(at.random.normal),
    rng,
    size,
    dtype,
    etuple(
        etuplize(at.add),
        mu_x,
        mu_y
    ),
    etuple(
        etuplize(at.add),
        sigma2_x,
        sigma2_y,
    )
)

rng_x, size_x, dtype_x = var(), var(), var()
rng_y, size_y, dtype_y = var(), var(), var()
Y_et = etuple(
    etuplize(at.add),
    etuple(
        etuplize(at.random.normal),
        etuple(
            nth,
            0,
            etuple(
                split,
                rng,
            )
        ),
        size_x,
        dtype_x,
        mu_x,
        sigma2_x
    ),
    etuple(
        etuplize(at.random.normal),
        etuple(
            nth,
            1,
            etuple(
                split,
                rng,
            )
        ),
        size_y,
        dtype_y,
        mu_y,
        sigma2_y
    )
)

Which is guaranteed to be collision-free by construction of the split operator, as long as the PRNG state used by the original normal distribution isn't passed to a split operator somewhere else in the original graph (todo: specify API requirements to guarantee uniqueness of the random numbers).

Playground

We would like to be able to control randomness from outside of Aesara. From JAX it is rather clear how that would work:

import aesara
import aesara.tensor as at

rng = at.random.RandState()
result = at.random.normal(rng, 0, 1)

# JAX compilation
import jax

fn = aesara.function([rng], result, mode="JAX")
rng_key = jax.random.PRNGKey(0)
fn(rng_key)

But for NumPy/Numba it is not as clear what we should be using.

# Numba compilation
import numpy as np

fn = aesara.function([rng], result, mode="NUMBA")
rng_key = np.random.default_rng(0)
fn(rng_key)

AeP: New abstraction for RandomVariables

Problem

With the abstractions that are currently available in Aesara/AePPL one wants to define the truncation operator as:

import aesara.tensor as at
import aesara.tensor.random as ar
from aeppl import truncate

x_rv = truncate(ar.normal(0, 1), lower=0)

This reads as defining a new random variable \(X'\) such that \(X' = \operatorname{truncate} \circ X\) where \(X \sim \operatorname{Normal}(0, 1)\). As such it is indistinguishable from the existing operator:

y_rv = ar.normal(0, 1)
x_rv = at.clip(y_rv, 0, float("inf"))

Which also reads as defining a new random variable \(X' = \operatorname{clip} \circ X\) with \(X \sim \operatorname{Normal}(0, 1)\). However, we want the truncate operator here to mean something else, namely an operator that takes a probability measure \(\operatorname{Normal}(0, 1)\) and returns a new measure \(\hat{\mathbbm{1}}_{[0, \infty[} \operatorname{Normal}(0, 1)\).
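
The difference between the two readings shows up numerically. The sketch below uses only the standard library and samples the truncated measure by rejection, which is one possible sampling scheme and not necessarily what truncate would do.

```python
import random

rng = random.Random(0)
draws = [rng.gauss(0, 1) for _ in range(10_000)]

# clip: a deterministic function of X; negative draws pile up at 0.
clipped = [max(x, 0.0) for x in draws]

# truncate: a new measure; here sampled by rejecting negative draws.
truncated = [x for x in draws if x >= 0.0]

frac_at_zero_clipped = sum(x == 0.0 for x in clipped) / len(clipped)
frac_at_zero_truncated = sum(x == 0.0 for x in truncated) / len(truncated)
print(frac_at_zero_clipped, frac_at_zero_truncated)
```

Clipping leaves roughly half of the mass as an atom at 0, whereas the truncated measure has no atom and only lives on the positive half-line.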

Measures as first class citizens

This issue arises because measures are only implicitly defined in Aesara, where we manipulate RandomVariable objects instead. If we want to avoid the kind of ambiguity that we pointed at earlier, we need to treat them as first class citizens. We can thus define:

import aesara.tensor.random as ar

x_m = ar.normal(0, 1)
assert isinstance(x_m, ar.Measure)

where x_m is now a Measure. To bridge the gap with the existing RandomVariable API in Aesara, we define the aesara.tensor.random.sample operator, which takes a measure and returns an independent random variable that admits this measure every time it is called. sample(rng, m) takes an rng key \(k \in \Omega\) and a Measure object m that represents the measure \(\mathbb{P}\), and returns a random variable \(X: \Omega \to E\) where \(E\) is the event space:

sample :: Key -> Measure -> RandomVariable

So to define a normally-distributed random variable one would write:

import aesara.tensor.random as ar

x = ar.sample(rng, ar.normal(0, 1))
assert isinstance(x, ar.RandomVariable)

This is somewhat similar to what is done in the Anglican probabilistic programming language implemented in Clojure. We retain the existing ambiguity that exists between random variables and their realizations, but that does not matter mathematically speaking.

We can keep something similar to the RandomStream interface by defining:

import aesara.tensor as at
import aesara.tensor.random as ar

srng = ar.RandomStream(0)
x = srng.sample(ar.normal(0, 1))

We will assume in the following that these primitives are available in Aesara.
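
To fix ideas, the separation between measures and sampling can be mocked in plain Python. Everything here is hypothetical: the Normal class and the sample function only mimic the proposed interface, and a concrete draw stands in for the symbolic random variable.

```python
import math
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class Normal:
    """Hypothetical measure object: it can be sampled and has a log-density."""

    mu: float
    sigma: float

    def draw(self, rng):
        return rng.gauss(self.mu, self.sigma)

    def logdensity(self, x):
        z = (x - self.mu) / self.sigma
        return -0.5 * z * z - math.log(self.sigma) - 0.5 * math.log(2 * math.pi)


def sample(rng, measure):
    # sample :: Key -> Measure -> RandomVariable (here: a concrete draw)
    return measure.draw(rng)


x_m = Normal(0.0, 1.0)  # a measure, not a random variable
x = sample(random.Random(0), x_m)  # a realization of X ~ Normal(0, 1)
print(x, x_m.logdensity(x))
```

The measure carries the density, while sample is the only place where a PRNG state is consumed.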

Implications

Writing probabilistic programs

It cannot hurt to add an example of a simple probabilistic program and define what is supposed to happen in different circumstances. Let's take the example on the homepage of AePPL's documentation:

import aesara
import aesara.tensor as at
import aesara.tensor.random as ar

srng = ar.RandomStream(0)
S_rv = srng.sample(ar.invgamma(0.5, 0.5), name="S")
Y_rv = srng.sample(ar.normal(0, at.sqrt(S_rv)), name="Y")

# This returns samples from the prior distribution
fn = aesara.function([], [Y_rv])

Transforming measures

It is now possible to apply transformations to Measures directly, as with the motivating example:

import aesara.tensor.random as ar
import aeppl.transforms as aet

x_tr = aet.truncate(ar.normal(0, 1), lower=0)
x = ar.sample(rng, x_tr)
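
As an illustration of what a transformed measure could carry, here is a sketch of the log-density of a truncated normal. It assumes that truncate renormalizes the restricted measure; TruncatedNormal and the helper functions are hypothetical names, not AePPL objects.

```python
import math
from dataclasses import dataclass


def normal_logpdf(x, mu, sigma):
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def normal_cdf(x, mu, sigma):
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


@dataclass(frozen=True)
class TruncatedNormal:
    """Hypothetical result of truncate(normal(mu, sigma), lower=lower)."""

    mu: float
    sigma: float
    lower: float

    def logdensity(self, x):
        if x < self.lower:
            return -math.inf
        # Renormalize by the mass that Normal(mu, sigma) puts on [lower, inf).
        mass = 1.0 - normal_cdf(self.lower, self.mu, self.sigma)
        return normal_logpdf(x, self.mu, self.sigma) - math.log(mass)


x_tr = TruncatedNormal(0.0, 1.0, lower=0.0)
shift = x_tr.logdensity(1.0) - normal_logpdf(1.0, 0.0, 1.0)
print(x_tr.logdensity(-1.0), shift)
```

With lower=0 the retained mass is one half, so on the support the log-density is shifted by log 2 relative to the untruncated normal.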

Logdensity

Log-densities are associated with measures, so it is natural to write:

import aeppl

aeppl.logdensity(ar.normal(0, 1), x)

Random Variable algebra

We can still perform operations on random variables:

import aesara.tensor as at
import aesara.tensor.random as ar

srng = ar.RandomStream(0)
x_rv = srng.sample(ar.normal(0, 1))
y_rv = srng.sample(ar.normal(0, 1))
z_rv = x_rv + y_rv

We can still implicitly define mixture distributions with:

import aeppl
import aesara.tensor as at
import aesara.tensor.random as ar

srng = at.random.RandomStream(0)

i = srng.sample(ar.bernoulli(0.5))
x = srng.sample(ar.normal(0, 1))
y = srng.sample(ar.normal(1, 1))

z = at.stack([x, y])
out_rv = z[i]

Or maybe now more simply using the measure-theoretic definition of mixtures as:

import aeppl
import aesara.tensor as at
import aesara.tensor.random as ar

srng = at.random.RandomStream(0)
rv = srng.sample(0.5 * ar.normal(0, 1) + 0.5 * ar.normal(1, 1))

Which has nice properties that follow from the fact that measures can be represented as integrals of the density: the density of the mixture is the weighted sum of the component densities, so that, for a value x:

aeppl.logdensity(0.5 * ar.normal(0, 1) + 0.5 * ar.normal(1, 1), x)
# is equivalent to
at.log(
    0.5 * at.exp(aeppl.logdensity(ar.normal(0, 1), x))
    + 0.5 * at.exp(aeppl.logdensity(ar.normal(1, 1), x))
)
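
A quick numerical check of how the mixture relates to its components, using only the standard library. The linearity holds for densities; on the log scale the combination is a log-sum-exp, which is not the same as a weighted sum of log-densities. The function names are illustrative.

```python
import math


def normal_logpdf(x, mu, sigma):
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def mixture_logpdf(x, weights, logpdfs):
    # log(sum_k w_k * p_k(x)), computed with the log-sum-exp trick.
    terms = [math.log(w) + lp(x) for w, lp in zip(weights, logpdfs)]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


components = [
    lambda x: normal_logpdf(x, 0.0, 1.0),
    lambda x: normal_logpdf(x, 1.0, 1.0),
]
x = 0.3
log_mix = mixture_logpdf(x, [0.5, 0.5], components)
direct = math.log(
    0.5 * math.exp(components[0](x)) + 0.5 * math.exp(components[1](x))
)
weighted_logs = 0.5 * components[0](x) + 0.5 * components[1](x)
print(log_mix, direct, weighted_logs)
```

The first two numbers agree, while the third is strictly smaller whenever the component densities differ at x (Jensen's inequality).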

This raises the question of what to do with the following statement when doing forward sampling:

z_rv = srng.sample(0.5 * ar.normal(0, 1) + 0.5 * ar.normal(1, 1))

Disintegration of probabilistic programs

What we are usually interested in, however, is the log-density associated with a probabilistic program, a graph that contains random variables. We could write as before:

import aeppl
import aesara.tensor as at
import aesara.tensor.random as ar

srng = ar.RandomStream(0)

S_rv = srng.sample(ar.invgamma(0.5, 0.5))
Y_rv = srng.sample(ar.normal(0., at.sqrt(S_rv)))

logdensity, (y_vv, s_vv) = aeppl.joint_logdensity(Y_rv, S_rv)

This is slightly awkward because logdensity applies to Measures while joint_logdensity works with RandomVariables. I am not sure if we want to resolve this ambiguity by using a different name that might be confusing to the user.

Broadcasting and the size argument

import aesara.tensor as at
import aesara.tensor.random as ar

mu = at.as_tensor([-1., 0., 1.])
sigma = 1.  # sigma was left undefined; a unit scale is assumed here
x_m = ar.normal(mu, sigma)
x_rv = ar.sample(rng, x_m)

x_m here defines an array of three different measures, and sample returns one random variable per element of the array. We would need to be careful with definitions and types here but I can get on board with this. The shape of x_m is called the batch shape. Each sample has the dimensionality of the random variables' event space, so the shape of one sample is called the event shape. The shape of x_rv is defined as:

batch_shape + event_shape

We can define a vector of 10 iid random variables with the iid argument to sample:

import aesara.tensor as at
import aesara.tensor.random as ar

x_rv = ar.sample(rng, ar.normal(0, 1), iid=(10,))

In the more complex case:

import aesara.tensor as at
import aesara.tensor.random as ar

mu = at.as_tensor([-1., 0., 1.])
sigma = 1.  # sigma was left undefined; a unit scale is assumed here
x_m = ar.normal(mu, sigma)
x_rv = ar.sample(rng, x_m, iid=(10,))
assert x_rv.shape == (10, 3)
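
The shape convention used in these examples can be written down as a one-line helper. This is only a sketch: sample_shape is a hypothetical function, and placing the iid dimensions first is inferred from the (10, 3) assertion above.

```python
def sample_shape(iid, batch_shape, event_shape):
    """Shape of the array returned by sample, under the convention above."""
    return tuple(iid) + tuple(batch_shape) + tuple(event_shape)


# Three scalar normal measures (mu has shape (3,)), ten iid copies of each.
print(sample_shape(iid=(10,), batch_shape=(3,), event_shape=()))
# (10, 3)

# A single bivariate measure sampled once: only the event shape remains.
print(sample_shape(iid=(), batch_shape=(), event_shape=(2,)))
# (2,)
```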

Summary

import aeppl
import aesara.tensor as at
import aesara.tensor.random as ar

x_m = ar.normal(0, 1)
x = ar.sample(rng, x_m)
logdensity = aeppl.logdensity(x_m, x)

RandomVariable Ops

We have a default_rng function, but the result does not behave as a generator in numpy.

from aesara.tensor.random import default_rng
rng = default_rng(32)
rng.type
from aesara.tensor.random.basic import NormalRV

norm = NormalRV()
norm_rv = norm(0, 1, size=(2,), rng=rng)

norm_rv.eval()

Aesara also defines aliases for the RandomVariable Ops:

from aesara.tensor.random import normal

normal_rv = normal(0, 1, size=(2,), rng=rng)
normal_rv.eval()

Let's look at the graphs that are produced:

import aesara
from aesara.tensor.random import default_rng, normal

rng = default_rng(0)
a_rv = normal(0, 1, rng=rng)
b_rv = normal(0, 1, rng=rng)
c_tt = a_rv + b_rv

d_rv = normal(0, 1, rng=rng)

aesara.dprint(c_tt * d_rv)

How does RandomGeneratorType work? It looks like it has internal state.

Define custom random variables

It is fairly simple as srng.gen(RV, *args) will call RV()(random_state, *args).

srng.gen(zero_truncated_betabinom, eta_at, kappa_rv, n_at)

where the RandomVariable is implemented as:

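import warnings

import aesara
import aesara.tensor as at
import numpy as np
import scipy.stats
from aeppl.logprob import CheckParameterValue, _logprob
from aesara.tensor.random.op import RandomVariable

# Clipping bounds used by the sampler. They are not defined in this note,
# so the values below are assumed placeholders.
near_zero = 1e-10
near_one = 1.0 - 1e-10


class ZeroTruncatedBetaBinomial(RandomVariable):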
    r"""A zero-truncated beta-binomial distribution.

    This distribution is implemented in the :math:`\kappa`
    and :math:`\eta` parameterization, which is related to
    the standard :math:`\alpha` and :math:`\beta` parameterization
    of the beta-binomial through the following:

    .. math::
        \alpha = \eta / \kappa \\
        \beta = (1 - \eta) / \kappa

    Truncation aside, for a :math:`Y \sim \operatorname{BetaBinom}\left(N, \eta, \kappa\right)`,  # noqa: E501

    .. math::
        \operatorname{E}\left[ Y \right] = N \eta \\
        \operatorname{Var}\left[ Y \right] = N \eta (1 - \eta) (N \kappa + 1) / (\kappa + 1)


    Under this parameterization, :math:`\kappa` in the standard beta-binomial
    serves as an over-dispersion term with the following properties:

    .. math::
        \lim_{\kappa \to 0} \operatorname{Var}\left[ Y \right] = N \eta (1 - \eta) \\
        \lim_{\kappa \to \infty} \operatorname{Var}\left[ Y \right] = N^2 \eta (1 - \eta)

    In other words, :math:`\kappa` modulates between the standard binomial
    variance and :math:`N`-times that variance.

    The un-truncated probability mass function (PMF) is as follows:

    .. math::
        \frac{\operatorname{B}\left(\frac{\eta}{\kappa} + y, n - y + \frac{1 - \eta}{\kappa}\right) {\binom{n}{y}}}{\operatorname{B}\left(\frac{\eta}{\kappa}, \frac{1 - \eta}{\kappa}\right)}  # noqa: E501

    and the zero-truncated PMF is as follows:

    .. math::
        \frac{\operatorname{B}\left(\frac{\eta}{\kappa} + y, - \frac{\eta}{\kappa} + n - y + \frac{1}{\kappa}\right) {\binom{n}{y}}}{\operatorname{B}\left(\frac{\eta}{\kappa}, - \frac{\eta}{\kappa} + \frac{1}{\kappa}\right) - \operatorname{B}\left(\frac{\eta}{\kappa}, - \frac{\eta}{\kappa} + n + \frac{1}{\kappa}\right)}  # noqa: E501

    """
    name = "zero_truncated_betabinom"
    ndim_supp = 0
    ndims_params = [0, 0, 0]
    dtype = "int64"
    _print_name = ("ZeroTruncBetaBinom", "\\operatorname{BetaBinom}_{>0}")

    def __init__(self, rejection_threshold=200, **kwargs):
        """
        Parameters
        ----------
        rejection_threshold
            The number of rejection iterations to perform before raising an
            exception.
        """
        self.rejection_threshold = rejection_threshold
        super().__init__(**kwargs)

    def __call__(self, eta, kappa, n, size=None, **kwargs):
        """
        Parameters
        ----------
        eta
        kappa
        n
        """

        self.eta = at.as_tensor_variable(eta, dtype=aesara.config.floatX)
        self.kappa = at.as_tensor_variable(kappa, dtype=aesara.config.floatX)
        self.n = at.as_tensor_variable(n, dtype=np.int64)

        return super().__call__(eta, kappa, n, size=size, **kwargs)

    def rng_fn(self, rng, eta, kappa, n, size):
        """A naive hybrid rejection + inverse sampler."""

        n = np.asarray(n, dtype=np.int64)
        eta = np.asarray(eta, dtype=np.float64)
        kappa = np.asarray(kappa, dtype=np.float64)

        # Values below this will produce errors (plus, it means this is really
        # a binomial)
        alpha = np.clip(eta / kappa, near_zero, 1e100)
        beta = np.clip((1 - eta) / kappa, near_zero, 1e100)

        # def zt_bb_inv(n, alpha, beta, size=None):
        #     """A zero-truncated beta-binomial inverse sampler."""
        #     # bb_dist = scipy.stats.betabinom(n, alpha, beta)
        #     beta_smpls = np.clip(
        #         scipy.stats.beta(alpha, beta).rvs(size=size), 1e-10, np.inf
        #     )
        #     binom_dist = scipy.stats.binom(n, beta_smpls)
        #     u = np.random.uniform(size=size)
        #     F_0 = binom_dist.cdf(0)
        #     samples = binom_dist.ppf(F_0 + u * (1 - F_0))
        #     return samples

        samples = scipy.stats.betabinom(n, alpha, beta).rvs(size=size, random_state=rng)
        alpha = np.broadcast_to(alpha, samples.shape)
        beta = np.broadcast_to(beta, samples.shape)
        n = np.broadcast_to(n, samples.shape)
        rejects = samples <= 0

        thresh_count = 0
        while rejects.any():
            _n = n[rejects] if np.size(n) > 1 else n
            _alpha = alpha[rejects] if np.size(alpha) > 1 else alpha
            _beta = beta[rejects] if np.size(beta) > 1 else beta
            _size = rejects.sum()

            beta_smpls = np.clip(
                scipy.stats.beta(_alpha, _beta).rvs(size=_size, random_state=rng),
                near_zero,
                near_one,
            )
            samples[rejects] = scipy.stats.binom(_n, beta_smpls).rvs(
                size=_size, random_state=rng
            )
            # samples[rejects] = scipy.stats.betabinom(_n, _alpha, _beta).rvs(size=_size)  # noqa: E501

            new_rejects = samples <= 0
            if new_rejects.sum() == rejects.sum():
                if thresh_count > self.rejection_threshold:
                    # # Attempt rejection sampling until the rejection results
                    # # get stuck, then use the inverse-sampler
                    # samples[rejects] = zt_bb_inv(_n, _alpha, _beta, size=_size)
                    # break
                    # raise ValueError("The sampling rejection threshold was met")
                    warnings.warn(
                        "The sampling rejection threshold was met "
                        "and mean values were used as sample values"
                    )
                    sp_ref_dist = scipy.stats.betabinom(_n, _alpha, _beta)
                    trunc_mean = sp_ref_dist.mean() / (1 - sp_ref_dist.cdf(0))
                    assert np.all(trunc_mean >= 1)
                    samples[rejects] = trunc_mean
                    break
                else:
                    thresh_count += 1
            else:
                thresh_count = 0

            rejects = new_rejects

        return samples


zero_truncated_betabinom = ZeroTruncatedBetaBinomial()


def _logp(value, eta, kappa, n):
    return (
        # binomln(n, value)
        -at.log(n + 1)
        # - betaln(n - value + 1, value + 1)
        # + betaln(value + alpha, n - value + beta)
        # - betaln(alpha, beta)
        - at.gammaln(n - value + 1)
        - at.gammaln(value + 1)
        + at.gammaln(n + 2)
        + at.gammaln(value + eta / kappa)
        + at.gammaln(n - value + (1 - eta) / kappa)
        - at.gammaln(1 / kappa + n)
        - at.gammaln(eta / kappa)
        - at.gammaln((1 - eta) / kappa)
        + at.gammaln(1 / kappa)
    )


@_logprob.register(ZeroTruncatedBetaBinomial)
def zero_truncated_betabinom_logprob(op, values, *inputs, **kwargs):
    (values,) = values
    (eta, kappa, n) = inputs[3:]

    l0 = (
        # gammaln(alpha + beta)
        # + gammaln(n + beta)
        # - gammaln(beta)
        # - gammaln(alpha + beta + n)
        at.gammaln(1 / kappa)
        + at.gammaln(n + (1 - eta) / kappa)
        - at.gammaln((1 - eta) / kappa)
        - at.gammaln(1 / kappa + n)
    )

    log1mP0 = at.log1mexp(l0)
    # log1mP0 = 0

    res = CheckParameterValue("values <= n, eta > 0, kappa > 0")(
        at.switch(values > 0, _logp(values, eta, kappa, n) - log1mP0, -np.inf),
        at.all(values <= n),
        at.all(eta > 0),
        at.all(kappa > 0),
    )
    return res

Note that the log-probability of this random variable is defined by dispatching _logprob on the ZeroTruncatedBetaBinomial Op, as done above.
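
The log-probability terms can be sanity-checked outside of Aesara. The sketch below rewrites _logp with math.lgamma and verifies that both the un-truncated and the zero-truncated PMFs sum to one; the function names and parameter values are chosen for illustration.

```python
import math


def betabinom_logpmf(y, eta, kappa, n):
    # Same terms as `_logp` above, written with math.lgamma.
    return (
        -math.log(n + 1)
        - math.lgamma(n - y + 1)
        - math.lgamma(y + 1)
        + math.lgamma(n + 2)
        + math.lgamma(y + eta / kappa)
        + math.lgamma(n - y + (1 - eta) / kappa)
        - math.lgamma(1 / kappa + n)
        - math.lgamma(eta / kappa)
        - math.lgamma((1 - eta) / kappa)
        + math.lgamma(1 / kappa)
    )


def zero_truncated_logpmf(y, eta, kappa, n):
    if y <= 0:
        return -math.inf
    # Renormalize by 1 - P(Y = 0), as `log1mP0` does in the logprob above.
    log_p0 = betabinom_logpmf(0, eta, kappa, n)
    return betabinom_logpmf(y, eta, kappa, n) - math.log1p(-math.exp(log_p0))


eta, kappa, n = 0.3, 0.5, 10
total = sum(math.exp(betabinom_logpmf(y, eta, kappa, n)) for y in range(n + 1))
total_zt = sum(
    math.exp(zero_truncated_logpmf(y, eta, kappa, n)) for y in range(1, n + 1)
)
print(total, total_zt)
```

Both sums should be equal to one up to floating-point error, which confirms that the truncation term only redistributes the mass that the beta-binomial puts on zero.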

Sampling vs Logprobability aeppl

  • How to define the logprob of a custom distribution?

Shapes

Shapes are always a mess when it comes to random variables. In Aesara three quantities determine the shape of a random variable:

  • ndim_supp: the number of dimensions of the RV's support;
  • ndims_params: the number of dimensions of each of the RV's parameters;
  • size: the sample size.

Remember that shapes in Aesara can be determined at runtime! We assume in the following that:

batch_shape = size
len(sample_shape) = ndim_supp
shape = batch_shape + sample_shape

We should also have a look at the broadcasting rules, because they are not all obvious.

import aesara.tensor as at
from aesara.tensor.random import RandomStream

srng = RandomStream(0)
a_rv = srng.normal(0, 1, size=(2,3))
print(a_rv.eval())
mu = at.as_tensor([1., 2., 3.])
a_rv = srng.normal(mu, 1, size=(2,3))
print(a_rv.eval())
mu = at.as_tensor([1., 2.])
a_rv = srng.normal(mu, 1, size=(2,3))
print(a_rv.eval())

More complex is the case where the random variable is non-scalar, as with the multivariate normal. Here the "event shape" is equal to (2,). The resulting shape, if we assume event_shape and batch_shape are tuples, is given by:

shape = batch_shape + event_shape

import numpy as np

mu = np.r_[1, 2]
sigma = np.array([[.5, .4], [.4, .6]])  # a covariance matrix must be symmetric
a_rv = srng.multivariate_normal(mu, sigma, size=(2, 5))
print(a_rv.eval().shape)
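
The same shape rules can be observed directly with NumPy's Generator, which the Aesara random variables follow closely. This is a sketch for comparison and does not go through Aesara; the covariance matrix is an arbitrary symmetric example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Scalar event, parameters broadcast against `size`.
shape_ok = rng.normal(np.array([1.0, 2.0, 3.0]), 1.0, size=(2, 3)).shape
print(shape_ok)  # (2, 3)

# A mean of shape (2,) cannot be broadcast to size=(2, 3).
broadcast_failed = False
try:
    rng.normal(np.array([1.0, 2.0]), 1.0, size=(2, 3))
except ValueError:
    broadcast_failed = True
print(broadcast_failed)  # True

# Vector event: the event shape (2,) comes last, after the batch shape (2, 5).
mean = np.array([1.0, 2.0])
cov = np.array([[0.5, 0.4], [0.4, 0.6]])
shape_mvn = rng.multivariate_normal(mean, cov, size=(2, 5)).shape
print(shape_mvn)  # (2, 5, 2)
```

Parameters are broadcast against the trailing dimensions of size, which is why a mean of length 3 works with size=(2, 3) and a mean of length 2 does not.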

See Eric Ma's blog post on the topic.

Problems with RandomStream

Proposal

import aesara.tensor as at

rng = at.random.RandomState()

# RandomVariables divide the rng
a_rv, rng = at.random.normal(rng, 0, 1)
b_rv, _ = at.random.normal(rng, 0, 1)

# We have to update the rng manually
a_rv = at.random.normal(rng, 0, 1)
rng = at.random.update(rng)
b_rv = at.random.normal(rng, 0, 1)

rng_a, rng_b = at.random.split(rng)
a_rv = at.random.normal(rng_a, 0, 1)
b_rv = at.random.normal(rng_b, 0, 1)

rngs = at.random.split(rng, 10)
rvs = []
for rng in rngs:
    rvs.append(at.random.normal(rng, 0, 1))

How does that solve the previous issues?

  1. Monkey patching to specialize the RV Ops
  2. RVs in S-expressions and rewrites

What does that complicate?

def standard_normal():
