AePPL needs the notion of measure to be able to perform operations such as truncate, or for more exotic use cases such as distributions defined on a non-trivial manifold. We will note the measure defined on the manifold (often , but not necessarily). For instance the parametrized measure (distribution) is defined as:

And the measure as:

We note a tensor that contains elements of type . Broadcasting rules apply when parameters of different dimensionalities apply, and the result is a tensor of measures of type where is the event space’s type.
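As a rough illustration of the broadcasting rule, the sketch below computes the shape of the resulting tensor of measures from the shapes of its parameters. It assumes NumPy-style broadcasting and works on plain tuples rather than Aesara's symbolic shapes; `batch_shape` is a hypothetical helper, not part of AePPL.

```python
from itertools import zip_longest

def batch_shape(*param_shapes):
    """Broadcast parameter shapes, NumPy-style, into the shape of the tensor of measures."""
    result = []
    # Align shapes on their trailing dimensions; missing dimensions count as 1.
    for dims in zip_longest(*(reversed(s) for s in param_shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {param_shapes} cannot be broadcast together")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))

# A column of 3 locations combined with a row of 4 scales
# describes a 3 x 4 grid of normal measures.
shape = batch_shape((3, 1), (4,))
```

A scalar parameter has shape `()` and therefore broadcasts against any other parameter.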

To connect measures with Aesara's RandomVariables we define the sample operator. sample is a function that takes a PRNG key and a measure and returns an element of the measure's event space. Under the hood, sample finds the RandomVariable that corresponds to the given measure and creates a new apply node by calling the Op.

sample :: PRNGKey -> Measure -> TensorVariable
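To make the signature concrete, here is a minimal sketch that mimics sample without Aesara. Python's `random.Random` stands in for the PRNG key, and the hypothetical `ToyNormal` class and `SAMPLERS` registry stand in for a measure and for the lookup of its RandomVariable. The real operator would build a graph node rather than draw a number eagerly.

```python
import random
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyNormal:
    """Hypothetical stand-in for a normal measure; not an AePPL class."""
    loc: float
    scale: float

# Hypothetical registry mapping a measure type to the function that draws from it,
# playing the role of the measure -> RandomVariable lookup.
SAMPLERS = {ToyNormal: lambda rng, m: rng.gauss(m.loc, m.scale)}

def sample(rng: random.Random, measure):
    """Find the sampler registered for the measure's type and call it."""
    return SAMPLERS[type(measure)](rng, measure)

# The same seed yields the same draw, as expected of an explicit PRNG key.
x1 = sample(random.Random(0), ToyNormal(0.0, 1.0))
x2 = sample(random.Random(0), ToyNormal(0.0, 1.0))
```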

Base types

We must first define the types that the event space can take. Our goal is to get a minimum viable example for the normal distribution, so we will limit ourselves to the reals and the positive reals.

from aesara.raise_op import CheckAndRaise
import aesara.tensor as at
import abc
 
class CheckParameterValue(CheckAndRaise):
    """Implements a parameter value check in a graph."""
 
    def __init__(self, msg=""):
        super().__init__(TypeError, msg)
 
    def __str__(self):
        return f"Check{{{self.msg}}}"
 
class Domain(abc.ABC):
    @abc.abstractmethod
    def __call__(self, x):
        pass
 
class Real(Domain):
    def __call__(self, x):
        return CheckParameterValue("real")(x, at.isfinite(x))
 
class Positive(Domain):
    def __call__(self, x):
        return CheckParameterValue("x >= 0")(x, at.ge(x, 0), at.isfinite(x))
 
reals = Real()
positive = Positive()

The constraints that characterize the types are represented in the Aesara graph with a CheckParameterValue assertion.
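The effect of these checks is easiest to see outside of a graph. The sketch below is an eager, plain-Python analogue of the two domains, assuming that a failed check should raise a `TypeError` as `CheckParameterValue` is configured to do. The `Eager*` names are hypothetical and only illustrate the semantics.

```python
import math

class EagerReal:
    """Accept finite floats, mirroring the Real domain on concrete values."""
    def __call__(self, x):
        if not math.isfinite(x):
            raise TypeError("real")
        return x

class EagerPositive:
    """Accept finite, non-negative floats, mirroring the `x >= 0` check."""
    def __call__(self, x):
        if not (math.isfinite(x) and x >= 0):
            raise TypeError("x >= 0")
        return x

eager_reals, eager_positive = EagerReal(), EagerPositive()
checked = eager_positive(2.0)  # valid values pass through unchanged
```

In the graph version the value is likewise returned unchanged, and the error is only raised when the compiled function is evaluated on an invalid input.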

Measures

We now define the type for the measure . We include information about the base measure, noting and for the Lebesgue and counting measures respectively.

import aesara.tensor as at
import abc
 
class Measure(abc.ABC):
    """A variable that represents a probability measure."""
    base_measure: "Measure"
 
class PrimitiveMeasure(abc.ABC):
    """A primitive measure"""
    domain: Domain
 
class Lebesgue(PrimitiveMeasure):
    def __init__(self, domain: Domain):
        self.domain = domain
 
class NormalMeasure(Measure):
    def __init__(self, loc, scale):
        self.parameters = {
            "loc": reals(loc),
            "scale": positive(scale)
        }
        self.base_measure = Lebesgue(reals)
        self.rv = at.random.normal
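The counting measure mentioned above is not covered by the code. One way it might look is sketched here with plain Python classes so that it runs on its own. `Counting`, `PoissonMeasure` and `is_discrete` are hypothetical names, and the domains are plain labels rather than graph checks.

```python
import abc

class PrimitiveMeasure(abc.ABC):
    """A primitive (base) measure over some domain."""
    def __init__(self, domain):
        self.domain = domain

class Lebesgue(PrimitiveMeasure):
    """Base measure for continuous distributions."""

class Counting(PrimitiveMeasure):
    """Hypothetical base measure for discrete distributions."""

class Measure(abc.ABC):
    base_measure: PrimitiveMeasure

class NormalMeasure(Measure):
    def __init__(self, loc, scale):
        self.parameters = {"loc": loc, "scale": scale}
        self.base_measure = Lebesgue("reals")

class PoissonMeasure(Measure):
    """Hypothetical discrete counterpart of NormalMeasure."""
    def __init__(self, rate):
        self.parameters = {"rate": rate}
        self.base_measure = Counting("non-negative integers")

def is_discrete(measure):
    """A density taken against the counting measure is a probability mass function."""
    return isinstance(measure.base_measure, Counting)
```

Keeping the base measure on the object lets later operations decide, for example, whether a density should be summed or integrated.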

Sample

import aesara.tensor.random as ar
from multipledispatch import dispatch
 
def sample(rng, measure: Measure):
    # Forward the checked parameters to the RandomVariable attached to the measure.
    return measure.rv(rng=rng, **measure.parameters)

Logdensity

from functools import singledispatch
 
import aesara.tensor as at
import numpy as np
 
@singledispatch
def logdensity(m: Measure, values):
    """Fallback for measures that have no registered log-density."""
    raise NotImplementedError(f"No density associated with the provided {m}")
 
@logdensity.register(NormalMeasure)
def normal_logdensity(m: NormalMeasure, values):
    (value,) = values
    mu, sigma = m.parameters["loc"], m.parameters["scale"]
    # Log-density of the normal distribution with respect to the Lebesgue measure.
    res = (
        -0.5 * at.pow((value - mu) / sigma, 2)
        - at.log(at.sqrt(2.0 * np.pi))
        - at.log(sigma)
    )
    return res
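As a sanity check on the expression above, the following sketch evaluates the same three terms on plain floats and compares the result with the logarithm of the density given by Python's `statistics.NormalDist`. The function takes `value`, `mu` and `sigma` directly instead of a measure object, purely to keep the example free of Aesara.

```python
import math
from statistics import NormalDist

def normal_logdensity(value, mu, sigma):
    """Log-density of N(mu, sigma) with respect to the Lebesgue measure."""
    return (
        -0.5 * ((value - mu) / sigma) ** 2
        - math.log(math.sqrt(2.0 * math.pi))
        - math.log(sigma)
    )

ours = normal_logdensity(0.3, 1.0, 2.0)
reference = math.log(NormalDist(mu=1.0, sigma=2.0).pdf(0.3))
```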
 

Other

def logdensity(measure, x):
    return _logdensity(measure, measure.base_measure, x)
 
def sample(rng_key, measure):
    ...  # body left unwritten in these notes
 
from functools import singledispatch
 
class Measure(abc.ABC):
    """A variable that represents a probability measure."""
 
class Lebesgue(PrimitiveMeasure):
 
    def __init__(self, domain):
        self.domain = domain
 
class NormalMeasure(Measure):
 
    def __init__(self, mu, sigma):
        self.mu = reals(mu)
        self.sigma = positive(sigma)
        self.base_measure = Lebesgue(reals)