# Measures in AePPL

AePPL needs a notion of measure to support operations such as `truncate`, as well as more exotic use cases such as distributions defined on a non-trivial manifold. We write \(\mathbb{M} \mathcal{M}\) for the type of measures defined on the manifold \(\mathcal{M}\) (often \(\mathbb{R}\), but not necessarily). For instance, the parametrized \(\operatorname{normal}\) measure (distribution) is defined as:

\[ \frac{\mu : \mathbb{R}, \sigma : \mathbb{R}^+}{\operatorname{normal}(\mu, \sigma) \in \mathbb{M}\mathbb{R}} \]

And the \(\operatorname{dirichlet}\) measure as:

\[ \frac{\boldsymbol{\alpha} : (\mathbb{R}^+)^n}{\operatorname{dirichlet}(\boldsymbol{\alpha}) \in \mathbb{M}\Delta_n} \]
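To make the event space \(\Delta_n\) concrete, the sketch below draws one value from a Dirichlet distribution and checks that it lies on the simplex: every component is positive and the components sum to one. NumPy's sampler is only a stand-in here, not AePPL's machinery.

```python
import numpy as np

# NumPy's Dirichlet sampler stands in for an AePPL dirichlet measure.
rng = np.random.default_rng(0)
alpha = np.array([1.0, 2.0, 3.0])  # concentration parameters, all > 0

x = rng.dirichlet(alpha)

# A point of the simplex: strictly positive components that sum to one.
on_simplex = bool(np.all(x > 0) and np.isclose(x.sum(), 1.0))
print(x.shape, on_simplex)  # (3,) True
```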

We write \(\mathbb{T}A\) for a tensor whose elements have type \(A\). When parameters have different dimensionalities, broadcasting rules apply, and the result is a tensor of measures of type \(\mathbb{T}(\mathbb{M}E)\), where \(E\) is the type of the event space.
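As an illustration of the broadcasting rule, NumPy's broadcasting (whose semantics Aesara follows) combines a column of three locations with a row of two scales into a \(3 \times 2\) batch of normal measures, i.e. a value of type \(\mathbb{T}(\mathbb{M}\mathbb{R})\). Sampling it yields one real number per measure. NumPy stands in for Aesara in this sketch.

```python
import numpy as np

mu = np.array([[0.0], [1.0], [2.0]])  # shape (3, 1): three locations
sigma = np.array([1.0, 10.0])         # shape (2,): two scales

# Batch shape of the resulting tensor of normal measures.
batch_shape = np.broadcast_shapes(mu.shape, sigma.shape)
print(batch_shape)  # (3, 2)

# One draw per measure in the tensor.
rng = np.random.default_rng(0)
draws = rng.normal(loc=mu, scale=sigma)
print(draws.shape)  # (3, 2)
```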

To connect measures to Aesara's `RandomVariable`s, we define the `sample` operator. \(\operatorname{sample}(k, m)\) is a function that takes a PRNG key \(k \in \Omega\) and a measure \(m : \mathbb{M}E\), and returns an element \(e \in E\). Under the hood, `sample` finds the `RandomVariable` that corresponds to the given measure and creates a new apply node by calling that `Op`.

```
sample :: PRNGKey -> Measure -> TensorVariable
```
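The following is a minimal pure-Python sketch of the idea behind `sample`, not AePPL's implementation. It assumes the PRNG key is an integer seed, uses NumPy samplers in place of Aesara `RandomVariable`s, and looks up the sampler that corresponds to the measure's type in a hypothetical `SAMPLERS` table. Because the key is an explicit argument, calling `sample` twice with the same key and measure returns the same value.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class Normal:
    """Hypothetical stand-in for a normal measure in M(R)."""

    loc: float
    scale: float


# Hypothetical table mapping a measure type to the sampler implementing it
# (the role played by `RandomVariable`s in Aesara).
SAMPLERS = {
    Normal: lambda gen, m: gen.normal(m.loc, m.scale),
}


def sample(key: int, measure):
    """Draw one element of the event space from `measure` using `key`."""
    try:
        sampler = SAMPLERS[type(measure)]
    except KeyError:
        raise NotImplementedError(f"No sampler registered for {measure!r}") from None
    # A fresh generator built from the key: the result depends only on the key.
    gen = np.random.default_rng(key)
    return sampler(gen, measure)


m = Normal(loc=0.0, scale=1.0)
print(sample(42, m) == sample(42, m))  # True: same key, same sample
```

Making the key explicit, rather than hiding a global random state, keeps `sample` a pure function of its inputs, which matches the signature above.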

## Base types

We must first define the types \(E\) that the event space can take. Our goal is a minimum viable example for the normal distribution, so we will limit ourselves to \(\mathbb{R}\) and \(\mathbb{R}^+\).

```python
import abc

import aesara.tensor as at
from aesara.raise_op import CheckAndRaise


class CheckParameterValue(CheckAndRaise):
    """Implements a parameter value check in a graph."""

    def __init__(self, msg=""):
        super().__init__(TypeError, msg)

    def __str__(self):
        return f"Check{{{self.msg}}}"


class Domain(abc.ABC):
    """The set of values a parameter (or an event) can take."""

    @abc.abstractmethod
    def __call__(self, x):
        """Return `x` wrapped in a check that it belongs to the domain."""


class Real(Domain):
    def __call__(self, x):
        # Real numbers: only check that every entry is finite.
        return CheckParameterValue("real")(x, at.all(at.isfinite(x)))


class Positive(Domain):
    def __call__(self, x):
        # Strictly positive, finite values (e.g. a scale parameter).
        return CheckParameterValue("x > 0")(
            x, at.all(at.gt(x, 0)), at.all(at.isfinite(x))
        )


reals = Real()
positive = Positive()
```

The constraints that characterize the types are represented in the Aesara graph with a `CheckParameterValue` assertion.
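For readers without Aesara at hand, here is an eager NumPy analogue of these checks. It is a sketch, not the graph-based version: the hypothetical `check_real` and `check_positive` helpers validate a concrete value immediately and raise a `TypeError` (the exception type given to `CheckAndRaise` above), whereas `CheckParameterValue` only raises when the compiled graph is evaluated. `check_positive` requires strictly positive values, as a scale parameter must be.

```python
import numpy as np


def check_real(x):
    """Return `x` as an array if every entry is finite, else raise."""
    x = np.asarray(x)
    if not np.all(np.isfinite(x)):
        raise TypeError("Check{real}")
    return x


def check_positive(x):
    """Return `x` as an array if every entry is finite and strictly positive."""
    x = np.asarray(x)
    if not (np.all(np.isfinite(x)) and np.all(x > 0)):
        raise TypeError("Check{x > 0}")
    return x


check_real([0.5, -3.0])   # passes
check_positive(2.0)       # passes
try:
    check_positive(-1.0)  # a negative scale is rejected
except TypeError as err:
    print(err)            # Check{x > 0}
```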

## Measures

We now define the type for the measure \(\mathbb{M}E\). We include information about the base measure, writing \(\mathbb{L}\) and \(\mathbb{C}\) for the Lebesgue and counting measures respectively.

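```python
import abc

import aesara.tensor as at

# Assumes `Domain`, `reals` and `positive` from the "Base types" section are in scope.


class PrimitiveMeasure(abc.ABC):
    """A primitive (base) measure, such as Lebesgue or counting."""

    domain: Domain


class Lebesgue(PrimitiveMeasure):
    def __init__(self, domain: Domain):
        self.domain = domain


class Measure(abc.ABC):
    """A variable that represents a probability measure."""

    base_measure: PrimitiveMeasure


class NormalMeasure(Measure):
    def __init__(self, loc, scale):
        # Parameters are wrapped in their domain checks; the insertion order
        # matches the positional arguments of `at.random.normal`.
        self.parameters = {
            "loc": reals(loc),
            "scale": positive(scale),
        }
        self.base_measure = Lebesgue(reals)
        self.rv = at.random.normal
```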
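The base measure matters because a density is always taken with respect to it. The pure-Python sketch below (hypothetical class names, no Aesara) attaches a base measure to each measure: a normal measure has a density with respect to \(\mathbb{L}\) on \(\mathbb{R}\), while a Bernoulli measure, used here only as an example of a discrete distribution, is defined with respect to \(\mathbb{C}\) on \(\{0, 1\}\), so its density is a probability mass.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Lebesgue:
    """Hypothetical Lebesgue base measure on a named domain."""

    domain: str


@dataclass(frozen=True)
class Counting:
    """Hypothetical counting base measure on a finite set."""

    domain: frozenset


@dataclass(frozen=True)
class Normal:
    loc: float
    scale: float
    base_measure = Lebesgue("reals")  # class attribute, not a dataclass field


@dataclass(frozen=True)
class Bernoulli:
    p: float
    base_measure = Counting(frozenset({0, 1}))


def density(m, x):
    """Density of `m` at `x` with respect to `m.base_measure`."""
    if isinstance(m, Normal):
        z = (x - m.loc) / m.scale
        return math.exp(-0.5 * z * z) / (m.scale * math.sqrt(2 * math.pi))
    if isinstance(m, Bernoulli):
        # With respect to the counting measure, the density is a probability mass.
        if x not in m.base_measure.domain:
            return 0.0
        return m.p if x == 1 else 1.0 - m.p
    raise NotImplementedError(f"No density for {m!r}")


print(density(Bernoulli(0.3), 1))                    # 0.3
print(type(Normal(0.0, 1.0).base_measure).__name__)  # Lebesgue
```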

## Sample

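```python
# Assumes `Measure` from the "Measures" section is in scope.


def sample(rng, measure: Measure):
    """Create a new `RandomVariable` apply node that samples from `measure`.

    `rng` is a shared random generator variable, e.g.
    `aesara.shared(np.random.default_rng(42))`.
    """
    # The parameters are passed positionally, in the order the `Op` expects.
    return measure.rv(*measure.parameters.values(), rng=rng)
```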

## Logdensity

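```python
from functools import singledispatch

import aesara.tensor as at
import numpy as np

# Assumes `Measure` and `NormalMeasure` from the "Measures" section are in scope.


@singledispatch
def logdensity(m: Measure, values):
    raise NotImplementedError(f"No density associated with the provided {m}")


@logdensity.register(NormalMeasure)
def normal_logdensity(m: NormalMeasure, values):
    (value,) = values
    mu, sigma = m.parameters["loc"], m.parameters["scale"]
    # Log-density of the normal distribution w.r.t. the Lebesgue measure.
    res = (
        -0.5 * at.pow((value - mu) / sigma, 2)
        - at.log(at.sqrt(2.0 * np.pi))
        - at.log(sigma)
    )
    return res
```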
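As a quick sanity check of the expression above, the NumPy sketch below evaluates the same three terms, \(-\tfrac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^2 - \log\sqrt{2\pi} - \log\sigma\), and compares them with the logarithm of the normal density \(\frac{1}{\sigma\sqrt{2\pi}} e^{-(x-\mu)^2 / (2\sigma^2)}\). At \(x = \mu = 0\), \(\sigma = 1\), both reduce to \(-\tfrac{1}{2}\log 2\pi \approx -0.9189\). The function names are hypothetical eager stand-ins for the graph version.

```python
import numpy as np


def normal_logdensity_np(x, mu, sigma):
    """The same three terms as the Aesara graph, evaluated eagerly."""
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(np.sqrt(2.0 * np.pi)) - np.log(sigma)


def normal_density_np(x, mu, sigma):
    """Normal density with respect to the Lebesgue measure."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))


x = np.linspace(-3.0, 3.0, 7)
print(np.allclose(normal_logdensity_np(x, 0.5, 2.0), np.log(normal_density_np(x, 0.5, 2.0))))  # True
print(normal_logdensity_np(0.0, 0.0, 1.0))  # approximately -0.9189
```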

## Other

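```python
def logdensity(measure, x):
    # Dispatch on both the measure and its base measure; `_logdensity` is
    # assumed to be a multiple-dispatch function defined elsewhere.
    return _logdensity(measure, measure.base_measure, x)


def sample(rng_key, measure):
    ...  # not yet implemented
```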

```python
import abc

# Assumes `reals` and `positive` from the "Base types" section are in scope.


class Measure(abc.ABC):
    """A variable that represents a probability measure."""


class PrimitiveMeasure(abc.ABC):
    """A primitive (base) measure."""


class Lebesgue(PrimitiveMeasure):
    def __init__(self, domain):
        self.domain = domain


class NormalMeasure(Measure):
    def __init__(self, mu, sigma):
        self.mu = reals(mu)
        self.sigma = positive(sigma)
        self.base_measure = Lebesgue(reals)
```