Find horseshoe priors in an Aesara graph

In this short post we will try to unify subgraphs of an Aesara graph with a pattern that represents a horseshoe prior, generalizing the corresponding unification in Aemcmc. We will use the following negative binomial regression throughout:

import aesara.tensor as at

srng = at.random.RandomStream(0)

X_tt = at.matrix('X')
h_tt = at.scalar('h', dtype="int32")

tau_rv = srng.halfcauchy(1, size=1)
lmbda_rv = srng.halfcauchy(1, size=10)
beta_rv = srng.normal(0, tau_rv * lmbda_rv)

eta = X_tt @ beta_rv
p = at.sigmoid(-eta)
Y_rv = srng.nbinom(h_tt, p)

Unification on the graphical structure

Remember that the horseshoe prior is mathematically defined as:

\begin{align*} \tau &\sim \operatorname{HalfCauchy}(0, 1)\\ \lambda_{j} &\sim \operatorname{HalfCauchy}(0, 1)\quad j \in \left[1,\dots, k\right]\\ \beta_{j} &\sim \operatorname{Normal}(0, \tau \;\lambda_{j})\quad j \in \left[1,\dots, k\right]\\ \end{align*}
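To make the definition concrete, here is a minimal sketch that draws from this prior using only the Python standard library. The helper names `sample_half_cauchy` and `sample_horseshoe` are illustrative and not part of Aesara. The sketch assumes that the second parameter of the normal distribution is a standard deviation, as it is in the Aesara model above.

```python
import math
import random


def sample_half_cauchy(rng, scale=1.0):
    """Draw from HalfCauchy(0, scale) by inverse-CDF sampling."""
    # CDF: F(x) = (2 / pi) * arctan(x / scale), so x = scale * tan(pi * u / 2)
    return scale * math.tan(math.pi * rng.random() / 2.0)


def sample_horseshoe(k, rng):
    """Draw (tau, lambdas, betas) from a horseshoe prior with k coefficients."""
    tau = sample_half_cauchy(rng)  # global shrinkage, shared by all coefficients
    lambdas = [sample_half_cauchy(rng) for _ in range(k)]  # local shrinkage
    # beta_j ~ Normal(0, tau * lambda_j), with tau * lambda_j a standard deviation
    betas = [rng.gauss(0.0, tau * lam) for lam in lambdas]
    return tau, lambdas, betas


rng = random.Random(0)
tau, lambdas, betas = sample_horseshoe(10, rng)
```

Each coefficient gets its own local scale while sharing a single global scale, which is exactly the product structure the pattern has to recognize in the graph.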

We would like to be able to tell whether the model that the graph implements contains a horseshoe prior, and if so return the variables that represent \(\lambda\) and \(\tau\). We start by defining a horseshoe pattern against which we will try to match the model, using unification's logic variables and etuples' emulation of S-expressions:

from unification import var
from etuples import etuple, etuplize

horseshoe_1_lv, horseshoe_2_lv = var('horseshoe_1'), var('horsehoe_2')
zero_lv = var('zero')
horseshoe_pattern = etuple(
    etuplize(at.random.normal),
    var(),
    var(),
    var(),
    zero_lv,
    etuple(
        etuplize(at.mul),
        horseshoe_1_lv,
        horseshoe_2_lv)
)
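To illustrate what matching such a pattern involves, here is a toy sketch in plain Python, with nested tuples standing in for S-expressions. `Var` and `toy_unify` are hypothetical stand-ins written for this example; they are not the `unification` library's implementation, and they only handle logic variables on the pattern side.

```python
class Var:
    """A logic variable: a named placeholder that can be bound to any term."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"~{self.name}"


def toy_unify(pattern, term, s=None):
    """Return a substitution {Var: term} that makes pattern equal to term, or False."""
    s = {} if s is None else s
    if isinstance(pattern, Var):
        if pattern in s:
            # Already bound: the existing binding must agree with this term
            return toy_unify(s[pattern], term, s)
        return {**s, pattern: term}
    if isinstance(pattern, tuple) and isinstance(term, tuple):
        if len(pattern) != len(term):
            return False
        for p, t in zip(pattern, term):
            s = toy_unify(p, t, s)
            if s is False:
                return False
        return s
    # Anything else (operators, constants) must be equal
    return s if pattern == term else False


scale_1, scale_2 = Var("horseshoe_1"), Var("horseshoe_2")
pattern = ("normal", 0, ("mul", scale_1, scale_2))

tau = ("halfcauchy", 0, 1)
lmbda = ("halfcauchy", 0, 1, "size=10")
beta = ("normal", 0, ("mul", tau, lmbda))

s = toy_unify(pattern, beta)
no_match = toy_unify(pattern, ("normal", 0, ("add", tau, lmbda)))
```

Here `s` binds `horseshoe_1` to `tau` and `horseshoe_2` to `lmbda`, while `no_match` is `False` because the operator `"add"` does not equal `"mul"`.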

We can then unify this pattern with the subgraph beta_rv using pythological's unification package:

from unification import unify
from IPython.lib.pretty import pprint

s = unify(horseshoe_pattern, etuplize(beta_rv))
pprint(s)
{~_32: RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FCCBFFD8740>),
 ~_33: TensorConstant{[]},
 ~_34: TensorConstant{11},
 ~zero: TensorConstant{0},
 ~horseshoe_1: e(
   e(
     aesara.tensor.random.basic.HalfCauchyRV,
     'halfcauchy',
     0,
     (0, 0),
     'floatX',
     False),
   RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FCCBFFD8F20>),
   TensorConstant{(1,) of 1},
   TensorConstant{11},
   TensorConstant{1},
   TensorConstant{1.0}),
 ~horsehoe_2: e(
   e(
     aesara.tensor.random.basic.HalfCauchyRV,
     'halfcauchy',
     0,
     (0, 0),
     'floatX',
     False),
   RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FCCBFFCB820>),
   TensorConstant{(1,) of 10},
   TensorConstant{11},
   TensorConstant{1},
   TensorConstant{1.0})}

Model unification

The pattern we used in the previous section only checks that the graphical structure matches the horseshoe prior, not that the correct distributions and parameters are used. However, we can only identify an Aesara graph with the mathematical model if both the structure and the distributions match (or if the pattern represents the normal form of the model under a set of rewrite rules, but we'll keep things simple for now).

Let us thus improve the design by defining a unify_horseshoe function that checks not only the basic graphical structure but also the shapes and distributions. This function behaves similarly to unify: it returns the random variables (lmbda_rv, tau_rv) when the subgraph is identified, and False if unification is impossible.

from aesara.graph.unify import eval_if_etuple
from etuples import etuple, etuplize
from unification import unify, var

def unify_horseshoe(graph):
    horseshoe_1_lv, horseshoe_2_lv = var('horseshoe_1'), var('horseshoe_2')
    zero_lv = var('zero')
    horseshoe_pattern = etuple(
        etuplize(at.random.normal),
        var(),
        var(),
        var(),
        zero_lv,
        etuple(
            etuplize(at.mul),
            horseshoe_1_lv,
            horseshoe_2_lv)
    )

    s = unify(graph, horseshoe_pattern)
    if s is False:
        return False

    # Check that horseshoe_1 was unified with a half-cauchy distributed RV
    halfcauchy_1 = eval_if_etuple(s[horseshoe_1_lv])
    if halfcauchy_1.owner is None or not isinstance(
        halfcauchy_1.owner.op, type(at.random.halfcauchy)
    ):
        return False

    # Check that horseshoe_2 was unified with a half-cauchy distributed RV
    halfcauchy_2 = eval_if_etuple(s[horseshoe_2_lv])
    if halfcauchy_2.owner is None or not isinstance(
        halfcauchy_2.owner.op, type(at.random.halfcauchy)
    ):
        return False
    # The RV with shape (1,) is the global parameter tau; the other one is lambda
    if halfcauchy_1.type.shape == (1,):
        lmbda_rv = halfcauchy_2
        tau_rv = halfcauchy_1
    elif halfcauchy_2.type.shape == (1,):
        lmbda_rv = halfcauchy_1
        tau_rv = halfcauchy_2
    else:
        return False

    return (lmbda_rv, tau_rv)

Again we check that we can unify the subgraph beta_rv:

print(unify_horseshoe(beta_rv))
(halfcauchy_rv{0, (0, 0), floatX, False}.out, halfcauchy_rv{0, (0, 0), floatX, False}.out)

Walk the graph

unify_horseshoe only succeeds when the pattern matches the whole graph rooted at the variable it is given. Indeed, if we try to unify the horseshoe pattern with Y_rv we get:

print(unify_horseshoe(Y_rv))
False

To identify subgraphs we thus need to walk through the graph (here breadth-first) and attempt unification at each step:

from aesara.graph.basic import walk
from aesara.tensor.random.op import RandomVariable

def expand(variable):
    # Return the inputs of the node that produced `variable`, if there is one
    if variable.owner:
        return variable.owner.inputs
    return None

s = False
for node in walk([Y_rv], expand, bfs=True):
    # Root variables (inputs, constants) have no owner and cannot match
    if node.owner is None or not isinstance(node.owner.op, RandomVariable):
        continue
    s = unify_horseshoe(node)
    if s:
        break
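The walk-and-match idea above does not depend on Aesara. As a rough sketch, here it is on a toy expression tree made of nested tuples. `walk_bfs` and `match_horseshoe` are hypothetical helpers written for this example, and the hand-written structural check stands in for a real unification step.

```python
from collections import deque


def walk_bfs(root):
    """Yield every sub-expression of a nested-tuple expression, breadth-first."""
    queue = deque([root])
    while queue:
        node = queue.popleft()
        yield node
        if isinstance(node, tuple):
            # node[0] is the operator; the remaining entries are its inputs
            queue.extend(node[1:])


def match_horseshoe(node):
    """Return the two scale terms if node is normal(0, mul(hc, hc)), else False."""
    if not (isinstance(node, tuple) and len(node) == 3 and node[:2] == ("normal", 0)):
        return False
    scale = node[2]
    if not (isinstance(scale, tuple) and len(scale) == 3 and scale[0] == "mul"):
        return False
    factors = scale[1:]
    if all(isinstance(f, tuple) and f[:1] == ("halfcauchy",) for f in factors):
        return factors
    return False


tau = ("halfcauchy", 0, 1)
lmbda = ("halfcauchy", 0, 1)
beta = ("normal", 0, ("mul", tau, lmbda))
y = ("nbinom", "h", ("sigmoid", ("neg", ("dot", "X", beta))))

# Matching at the root fails; matching at each visited node eventually succeeds
found = next((m for m in map(match_horseshoe, walk_bfs(y)) if m), False)
```

Stopping at the first match mirrors the `break` in the loop over the Aesara graph; collecting every match instead would only require replacing `next` with a list comprehension.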

We can check that \(\tau\) and \(\lambda\) have been correctly identified:

pprint(s)
(halfcauchy_rv{0, (0, 0), floatX, False}.out,
 halfcauchy_rv{0, (0, 0), floatX, False}.out)

Being able to unify a pattern with a subgraph is a (small) first step towards being able to assign sampling steps to random variables in an arbitrary graph. In the following post we will show how we can automatically build a Gibbs sampler for (sub)graphs that represent a horseshoe prior.