Measures in AePPL

AePPL needs a notion of measure to support operations such as truncation, as well as more exotic use cases such as distributions defined on a non-trivial manifold. We write $$\mathbb{M} \mathcal{M}$$ for the type of measures defined on the manifold $$\mathcal{M}$$ (often $$\mathbb{R}$$, but not necessarily). For instance, the parametrized $$\operatorname{normal}$$ measure (distribution) is defined as:

$\frac{\mu : \mathbb{R}, \sigma : \mathbb{R}^+}{\operatorname{normal}(\mu, \sigma) \in \mathbb{M}\mathbb{R}}$

And the $$\operatorname{dirichlet}$$ measure as:

$\frac{\boldsymbol{\alpha} : (\mathbb{R}^+)^n}{\operatorname{dirichlet}(\boldsymbol{\alpha}) \in \mathbb{M}\Delta_n}$

We write $$\mathbb{T}A$$ for a tensor whose elements have type $$A$$. When parameters have different dimensionalities, broadcasting rules apply, and the result is a tensor of measures of type $$\mathbb{T}(\mathbb{M}E)$$, where $$E$$ is the type of the event space.
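As a rough illustration, the shape of the resulting $$\mathbb{T}(\mathbb{M}E)$$ can be computed from the parameter shapes alone. This sketch assumes the parameters broadcast with NumPy's rules, as Aesara tensors do; `measure_batch_shape` is a hypothetical helper, not part of AePPL.

```python
import numpy as np


def measure_batch_shape(*param_shapes):
    """Shape of the tensor of measures obtained by broadcasting the parameters.

    Hypothetical helper: assumes NumPy broadcasting semantics.
    """
    return np.broadcast_shapes(*param_shapes)


# normal(mu, sigma) with mu of shape (3,) and a scalar sigma:
# three normal measures, i.e. a T(M R) of shape (3,).
print(measure_batch_shape((3,), ()))      # (3,)

# mu of shape (2, 1) and sigma of shape (4,) give a 2 x 4 grid of measures.
print(measure_batch_shape((2, 1), (4,)))  # (2, 4)
```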

To connect measures with Aesara's RandomVariables, we define the sample operator. $$\operatorname{sample}(k, m)$$ is a function that takes a PRNG key $$k \in \Omega$$ and a measure $$m : \mathbb{M}E$$, and returns an element $$e \in E$$. Under the hood, sample finds the RandomVariable that corresponds to the given measure and creates a new apply node by calling that Op.

sample :: PRNGKey -> Measure -> TensorVariable
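The signature above can be mimicked without Aesara to show the intended behavior. In this sketch, NumPy's `np.random.Generator` stands in for the PRNG key $$k \in \Omega$$, and a registry keyed by the measure's type plays the role of finding the corresponding RandomVariable. `ToyNormal`, `SAMPLERS` and `toy_sample` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ToyNormal:
    """Hypothetical stand-in for the measure normal(loc, scale) : M R."""

    loc: float
    scale: float


# Maps each measure type to the sampler that plays the role of its RandomVariable.
SAMPLERS = {
    ToyNormal: lambda rng, m: rng.normal(loc=m.loc, scale=m.scale),
}


def toy_sample(rng: np.random.Generator, measure):
    """sample : PRNGKey -> Measure -> E, looked up from the measure's type."""
    try:
        sampler = SAMPLERS[type(measure)]
    except KeyError:
        raise NotImplementedError(f"No sampler registered for {measure!r}") from None
    return sampler(rng, measure)


rng = np.random.default_rng(0)
x = toy_sample(rng, ToyNormal(loc=1.0, scale=2.0))  # one element of E = R
```

The same key and measure always produce the same draw, which is the property an explicit PRNG key is meant to provide.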


Base types

We first need to define the types $$E$$ that the event space can take. Since our goal is a minimum viable example for the normal distribution, we limit ourselves to $$\mathbb{R}$$ and $$\mathbb{R}^+$$.

import abc

import aesara.tensor as at
from aesara.raise_op import CheckAndRaise


class CheckParameterValue(CheckAndRaise):
    """Implements a parameter value check in a graph."""

    def __init__(self, msg=""):
        super().__init__(TypeError, msg)

    def __str__(self):
        return f"Check{{{self.msg}}}"


class Domain(abc.ABC):
    """A set of admissible values, enforced by a check in the graph."""

    @abc.abstractmethod
    def __call__(self, x):
        pass


class Real(Domain):
    def __call__(self, x):
        # CheckAndRaise expects scalar conditions, hence the reduction with `at.all`.
        return CheckParameterValue("real")(x, at.all(at.isfinite(x)))


class Positive(Domain):
    def __call__(self, x):
        # Strictly positive: a normal scale of 0 would be degenerate.
        return CheckParameterValue("x > 0")(
            x, at.all(at.gt(x, 0)), at.all(at.isfinite(x))
        )


reals = Real()
positive = Positive()


The constraints that characterize these types are represented in the Aesara graph as CheckParameterValue assertions, which raise an error when the graph is evaluated with a value outside the domain.
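Because the graph checks only fire at evaluation time, an eager NumPy analogue can make the constraints easier to inspect. `check_real` and `check_positive` are hypothetical helpers that mirror the intent of `Real` and `Positive`, raising `TypeError` as `CheckParameterValue` does; here $$\mathbb{R}^+$$ is taken to mean strictly positive, an assumption motivated by the normal scale.

```python
import numpy as np


def check_real(x):
    """Return x as a float array if every entry is finite, else raise TypeError."""
    x = np.asarray(x, dtype=float)
    if not np.all(np.isfinite(x)):
        raise TypeError("Check{real}")
    return x


def check_positive(x):
    """Return x as a float array if every entry is finite and strictly positive."""
    x = np.asarray(x, dtype=float)
    if not (np.all(np.isfinite(x)) and np.all(x > 0)):
        raise TypeError("Check{x > 0}")
    return x


check_real([0.0, -3.5])  # passes
check_positive(2.0)      # passes
try:
    check_positive([1.0, 0.0])
except TypeError as err:
    print(err)           # Check{x > 0}
```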

Measures

We now define the measure type $$\mathbb{M}E$$. Each measure also records its base measure; we write $$\mathbb{L}$$ and $$\mathbb{C}$$ for the Lebesgue and counting measures, respectively.
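Keeping track of the base measure matters because a density is only defined relative to one. For example, the log-density of $$\operatorname{normal}(\mu, \sigma)$$ with respect to $$\mathbb{L}$$ on $$\mathbb{R}$$ is

$\log \frac{d\,\operatorname{normal}(\mu, \sigma)}{d\mathbb{L}}(x) = -\frac{1}{2}\left(\frac{x - \mu}{\sigma}\right)^2 - \log\sqrt{2\pi} - \log\sigma$

while for a measure on the integers, the density with respect to $$\mathbb{C}$$ is simply the probability mass function.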

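import abc

import aesara.tensor as at

# `Domain`, `reals` and `positive` are defined in the "Base types" section.


class Measure(abc.ABC):
    """A variable that represents a probability measure."""

    # String annotation: `PrimitiveMeasure` is defined below.
    base_measure: "PrimitiveMeasure"


class PrimitiveMeasure(abc.ABC):
    """A primitive (base) measure, such as the Lebesgue or counting measure."""

    domain: Domain


class Lebesgue(PrimitiveMeasure):
    def __init__(self, domain: Domain):
        self.domain = domain


class NormalMeasure(Measure):
    def __init__(self, loc, scale):
        # Wrapping the parameters in their domains adds the checks to the graph.
        self.parameters = {
            "loc": reals(loc),
            "scale": positive(scale),
        }
        self.base_measure = Lebesgue(reals)
        # The RandomVariable that `sample` calls for this measure.
        self.rv = at.random.normal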


Sample

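# `Measure` is defined in the "Measures" section.


def sample(rng, measure: Measure):
    # `rng` is an Aesara random generator variable,
    # e.g. `aesara.shared(np.random.default_rng(0))`.
    # Call the measure's RandomVariable with its (checked) parameters.
    return measure.rv(rng=rng, **measure.parameters)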


Logdensity

from functools import singledispatch

import aesara.tensor as at
import numpy as np

# `Measure` and `NormalMeasure` are defined in the "Measures" section.


@singledispatch
def logdensity(m: Measure, values):
    raise NotImplementedError(f"No density associated with the provided {m}")


@logdensity.register(NormalMeasure)
def normal_logdensity(m: NormalMeasure, values):
    (value,) = values
    mu, sigma = m.parameters["loc"], m.parameters["scale"]
    # Log-density with respect to the Lebesgue measure on R.
    res = (
        -0.5 * at.pow((value - mu) / sigma, 2)
        - at.log(at.sqrt(2.0 * np.pi))
        - at.log(sigma)
    )
    return res

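As a sanity check on the expression in `normal_logdensity`, the same terms can be evaluated on plain floats and compared with the log of the textbook density. `normal_logdensity_float` and `normal_pdf` are hypothetical scalar re-implementations used only for this check.

```python
import math


def normal_logdensity_float(value, mu, sigma):
    """Same terms as normal_logdensity, evaluated on Python floats."""
    return (
        -0.5 * ((value - mu) / sigma) ** 2
        - math.log(math.sqrt(2.0 * math.pi))
        - math.log(sigma)
    )


def normal_pdf(value, mu, sigma):
    """Density of normal(mu, sigma) with respect to the Lebesgue measure."""
    z = (value - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


# At the mode of a standard normal the log-density is -log(sqrt(2 * pi)).
print(round(normal_logdensity_float(0.0, 0.0, 1.0), 4))  # -0.9189

for value, mu, sigma in [(0.3, -1.0, 2.0), (5.0, 4.5, 0.1)]:
    assert math.isclose(
        normal_logdensity_float(value, mu, sigma),
        math.log(normal_pdf(value, mu, sigma)),
    )
```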


Other

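def logdensity(measure, x):
    # `_logdensity` selects an implementation from both the measure and its base measure.
    return _logdensity(measure, measure.base_measure, x)


def sample(rng_key, measure):
    raise NotImplementedError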


import abc

# `PrimitiveMeasure`, `reals` and `positive` are defined in the sections above.


class Measure(abc.ABC):
    """A variable that represents a probability measure."""


class Lebesgue(PrimitiveMeasure):
    def __init__(self, domain):
        self.domain = domain


class NormalMeasure(Measure):
    def __init__(self, mu, sigma):
        self.mu = reals(mu)
        self.sigma = positive(sigma)
        self.base_measure = Lebesgue(reals)
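The `logdensity` wrapper in this section delegates to `_logdensity(measure, measure.base_measure, x)`, so it picks an implementation from the pair (measure type, base measure type). One way to sketch that double dispatch without extra dependencies is a registry keyed by both types. Everything below (`register_logdensity`, `Normal`, `Lebesgue`, `Counting`) is a hypothetical stand-in, not AePPL's API.

```python
import math
from dataclasses import dataclass

# Hypothetical registry: (measure type, base measure type) -> implementation.
_LOGDENSITY_IMPLS = {}


def register_logdensity(measure_type, base_type):
    """Register an implementation for one (measure, base measure) pair."""

    def decorator(fn):
        _LOGDENSITY_IMPLS[(measure_type, base_type)] = fn
        return fn

    return decorator


def _logdensity(measure, base_measure, x):
    # Select the implementation from the types of both the measure and its base.
    key = (type(measure), type(base_measure))
    if key not in _LOGDENSITY_IMPLS:
        raise NotImplementedError(
            f"No density for {key[0].__name__} w.r.t. {key[1].__name__}"
        )
    return _LOGDENSITY_IMPLS[key](measure, x)


def logdensity(measure, x):
    return _logdensity(measure, measure.base_measure, x)


@dataclass(frozen=True)
class Lebesgue:
    """Hypothetical primitive measure; the domain is only a label here."""

    domain: str = "reals"


@dataclass(frozen=True)
class Counting:
    """Hypothetical counting measure; no density is registered against it below."""

    domain: str = "integers"


@dataclass(frozen=True)
class Normal:
    mu: float
    sigma: float
    base_measure: object = Lebesgue()


@register_logdensity(Normal, Lebesgue)
def _normal_wrt_lebesgue(m, x):
    z = (x - m.mu) / m.sigma
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi) - math.log(m.sigma)


print(logdensity(Normal(0.0, 1.0), 0.0))  # ≈ -0.9189
```

Because the base measure is part of the key, the same family can expose different densities depending on its reference measure, which is the information the $$\mathbb{L}$$ / $$\mathbb{C}$$ annotation records.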