# Rewriting Aesara graphs

This article is a port of Brandon Willard's Tour of the Symbolic PyMC library, and is a simplified version of the example in Aesara's documentation. The text is almost a verbatim copy of the original, but mistakes are obviously mine.

In this document we will be implementing a symbolic "search-and-replace" that changes aesara graphs like `at.dot(A, x+y)`

to `at.dot(A, x) + at.dot(A, y)`

. In other words we will demonstrate how to implement the distributive property of the matrix multiplication so it can be applied to any aesara graph. Aesara allows one to implement rewrite rules like the distributive property—and many other sophisticated manipulation of graphs—by providing flexible, pure Python versions of core operations in symbolic computation. These operations are then combined and orchestrated through the relational programming DSL miniKanren.

More specifically, we’ll introduce the basic unification and reification operations and explicitly show how they relate to graph manipulation and the modeling of high-level mathematical relations. Along the way, we’ll cover some of the necessary details behind Aesara graphs.

We start by creating a graph of our target expressions–i.e. `at.dot(A, x + y)`

in Aesara. We need to do this in order to determine exactly what we’re searching for and–later–what to put in its place.

import aesara.tensor as at A_tt = at.matrix("A") x_tt = at.vector("x") y_tt = at.vector("y") z_tt = at.dot(A_tt, x_tt + y_tt)

We can get a text print-out of the graph using the debug print function `dprint`

```
import aesara
aesara.dprint(z_tt)
```

dot [id A] '' |A [id B] |Elemwise{add,no_inplace} [id C] '' |x [id D] |y [id E]

The output of `dprint`

shows the underlying operators (`dot`

, `add`

) and their arguments.

To "math/search for" combinations of Aesara operations–or, as we just saw, graphs–we use **unification**; to "replace" parts of the graph (well, produce a copy with replaced parts) we use **reificatoin**. Aesara provides support for these via **expression-tuples**.

## S-expressions

We can convert an Aesara graphs into an S-expression-like form using etuples.

from etuples import etuple, etuplize from IPython.lib.pretty import pprint z_et = etuplize(z_tt) pprint(z_et)

e( e(aesara.tensor.math.Dot), A, e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), x, y))

An `etuple`

is like a normal `tuple`

, except that its first element is a `Callable`

and the remaining elements are the `Callable`

's arguments. As above, a pretty-printed `etuple`

looks like a `tuple`

prefixed by an `e`

.

By working with `etuples`

we can use **arbitrary Python functions** in conjunction with Aesara graphs and logic variable arguments. Basically, and `etuple`

can be manipulated until all of its constituent logic variables are replaced with valid arguments to the function/operators. At that point the etuple can be evaluated.

For instance we can create an `etuple`

that uses the function `at.add`

with a logic variable argument.

from unification import var x_lv, y_lv = var('x'), var('y') add_pattern = etuple(etuplize(at.add), x_lv, y_lv)

It wouldn't normally be possible to call the `at.add`

function with these argument types, as demonstrated in this example:

try: at.add(x_lv, 1) except NotImplementedError as e: print(str(e))

Cannot convert ~x to a tensor variable.

We'll get a similar error if we attempt to evaluate the `etuple`

by accessing its `ExpressionTuple.evaled_obj`

property. However, after performing a simple manipulation that replaces the logic variable with a valid input to `at.add`

(reificatoin), we are able to evaluate the `etuple`

and obtain an Aesara Tensor result, as demonstrated by the following code:

from unification import reify new_add_pattern = reify(add_pattern, {x_lv: 1., y_lv: 1.}) pprint(new_add_pattern)

e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), 1.0, 1.0)

pprint(new_add_pattern.evaled_obj)

Elemwise{add,no_inplace}.0

Working with S-expressions is much like manipulating a subset of Python AST, so, when using `etuples`

, one is–in effect-**meta programming** (e.g. by automating the production and evaluation of Aesara graphs using Python code). As a matter of fact, `etuples`

could be recast as `ast.Expr`

and `ast.Call`

objects that, though the use of `eval`

, could achieve the same results-albeit without the more convenient tuple-like structuring.

## Operators and their parameters

In etuplized-graph-print the `etuple`

form of our matrix-multuplication graph `z_et`

produced Aesaa operators

pprint(z_et[0])

e(aesara.tensor.math.Dot)

## Unification and reification

With the ability to use logic variables and Aesara graphs together, we can now "search" or "match" arbitrary graphs using **unification** and produce new graphs by replacing logic variables using **reification**.

We start by making "patterns" or templates for the subgraphs we would like to match. Patterns, in this case, take the form of S-expressions with the desired structure and logic variables in place of "unknown" or arbitrary terms that we might like to reference elsewhere.

`dot_pattern`

represents an S-expression that evaluateds to a graph in which two terms are matrix-multiplied.

from aesara.tensor.math import Dot A_lv, B_lv = var("A"), var("B") dot_pattern = etuple(etuple(Dot), A_lv, B_lv)

"Matching" a graph against this pattern is called **unification**. Unificatoin of two graphs implies unification of all sub-graphs and elements between them. When unification is successful, it returns a map of logic variables and their unified values. If there are no logic variables in the graphs it simply returns an empty map. If unification fails, it returns `False`

–at least in the implementation we use.

### Unification

We can perform unification using the function `unify`

. The result is a `dict`

mapping logic variables to their unified values.

from unification import unify s = unify(dot_pattern, z_et) pprint(s)

{~A: A, ~B: e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), x, y)}

The logic variable `A`

has been correctly unified with `A_tt`

, while the logic variable `B`

has been correctly unified with the addition of `x_tt`

and `y_tt`

.

### Reification

Using `reify`

we can "fill-in"–or replace—the logic variables of our "pattern" with the matches obtained by `unify`

that are held within the variable s, or we could specify our own substitutions based on that information.

In the following snipped we simply exchange the `A_tt`

tensor with another `X_tt`

tensor and create a new graph with that value. The end result is a version of the original graph `z_et`

, with the new tensor.

X_tt = at.matrix("X") s[A_lv] = X_tt z_et_re = reify(dot_pattern, s) pprint(z_et_re)

e( e(aesara.tensor.math.Dot), X, e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), x, y))

### Finishing our implementation

We can also reify an entirely different graph using the values extracted from the graph `z_et`

. In this case, we create an "output" pattern graph, to complement our new "input" pattern graph `dot_pattern`

. If we combine our dot product and addition `etuple`

patterns, we can extract all the arguments needed as input to a distributed multiplication pattern.

```
output_pattern = etuple(etuplize(at.add), etuple(etuple(Dot), A_lv, x_lv), etuple(etuple(Dot), B_lv, y_lv))
pprint(output_pattern)
```

e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), e(e(aesara.tensor.math.Dot), ~A, ~x), e(e(aesara.tensor.math.Dot), ~B, ~y))

With logic variables `A_lv`

, `x_lv`

and `y_lv`

mapped to their template-corresponding objects in another graph, we can reify `output_pattern`

and obtain a reified version of said graph.

Using the previous unification results contained in `s`

we only need to reify `output_pattern`

with those mappings. However, since our pattern refers to logic variables `x_lv`

and `y_lv`

we'll need to unify these logic variables with the appropriate terms in the graph.

```
s_add = unify(s[B_lv], add_pattern, s)
pprint(s_add)
```

{~A: X, ~B: e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), x, y), ~x: x, ~y: y}

```
z_new = reify(output_pattern, s_add)
aesara.dprint(z_new.evaled_obj)
```

Elemwise{add,no_inplace} [id A] '' |dot [id B] '' | |X [id C] | |x [id D] |InplaceDimShuffle{x} [id E] '' |dot [id F] '' |Elemwise{add,no_inplace} [id G] '' | |x [id D] | |y [id H] |y [id H]

Using only the basics of unification and reification provided by Aesara one can extract specific elements from Aesara graphs and use them to implement mathematical identities/relations. Through clever use of multiple mathematical relations, one can–for example–construct graph **optimizations** that turn large classes of user-defined statistical models into computational tractable reformulations. Similarly, one can construct "normal forms" for models, making it possible to determine whether or not a user-defined model is suitable for a specific sampler.

Next we will introduce another major element of Aesara that orchestrates and simplifies sequences of unifications like we used earlier, provides control-flow-like capabilities, produces fully reified results of arbitrary forms and does so within a genuinely declarative formalism that carries much of the same power of logical programming: miniKanren!

## Relational programming in miniKanren

Aesara uses a Python implementation of the embedded domain-specific language miniKanren–provided by the `kanren`

package–to orchestrate more sophisticated uses of unification and reification. For a quick intro, see the basic introduction provided by the `kanren`

package. We'll cover most of the same basic material here.

To start, miniKanren uses **goals** (in the same sense as logic programming) to assert relations, and the `run`

function evaluates those goals and allows one to specify the exact amount and type of reified output desired from the **states** that satisfy the goals.

In their most basic form, miniKanren **states** are simply the substitution maps returned by unification, which–in the normal course of operations–are not dealt with directly.

### The basic goals

Normally, a user will only need to construct compound goals from a basic set of primitives. Arguably, the most primitive goal is the equivalence relation under unification denoted by `eq`

in Python.

In the following code block we ask for all successful results/reifications (signified by the `0`

argument) of the logic variable `var('q')`

for the goal `eq(var('q'), 1)`

, i.e. unify `var('q')`

with `1`

.

from kanren import run, eq q_lv = var('q') mk_res = run(0, q_lv, eq(q_lv, 1)) pprint(mk_res)

(1,)

Since miniKanren's `run`

always returns a stream of results, we obtain a tuple containing the reified values of `q_lv`

under the one possible state for which our stated goal successfully evaluates.

The other basic primitives represent conjunction and disjunction of miniKanren goals: `lall`

and `lany`

respectively.

from kanren import lall mk_res = run(0, q_lv, lall(eq(q_lv, 1), eq(q_lv, 2))) pprint(mk_res)

We just used `lall`

to obtain the conjunction of two unificatoin goals. Since we requested the same logic variable be unified with `1`

and `2`

simultaneously, which is imposssibe, we got back an empty stream of results–indicating failure.

Goal disjunction, `lany`

, will split a state stream accross goals, producing new distrinct states for each:

from kanren import lany mk_res = run(0, q_lv, lany(eq(q_lv, 1), eq(q_lv, 2))) pprint(mk_res)

The goal disjunction result shows that the logic variable `q_lv`

can be unified with either `1`

or `2`

under the two unification goals.

A common pattern of disjuntion and conjunction is called `conde`

, and it mirrors the Lisp function `cond`

, which is effectively a type compound `if ... elif ... elif ...`

. Specifically, `conde([x_1, ...], ..., [y_1,...])`

is the same as `lany(lall(x_1,...), ..., lall(y_1, ...))`

-i.e. a disjunction of goal conjunctions.

from kanren import conde r_lv = var("r") mk_res = run( 0, [q_lv, r_lv], conde( [eq(q_lv, 1), eq(r_lv, 10)], [eq(q_lv, 2), eq(r_lv, 20)] ) ) pprint(mk_res)

([1, 10], [2, 20])

We introduced another logic variable `r_lv`

and requested the reified values of a list containing both logic variables. The output resembles the idea thatif `q_lv`

is "equal" to `1`

, then `r_lv`

is "equal" to `10`

, etc. Unlike normal conditionals, each clause/branch isn't exclusive, instead each is realized when the goals in a branch can be successful.

The following code demonstrated when `conde`

can behave more like a traditional statement.

```
mk_res = run(0, [q_lv, r_lv],
lall(eq(q_lv, 1),
conde(
[eq(q_lv, 1), eq(r_lv, 10)],
[eq(q_lv, 2), eq(r_lv, 20)],
)))
pprint(mk_res)
```

([1, 10],)

### A better implementation

Since miniKanren uses unification and reification, we can apply its basic goals to Aesara graphs, as we did earlier, and reproduce the entire implementation in a much more concise manner.

```
mk_res = run(1, output_pattern, eq(dot_pattern, z_et), eq(add_pattern, B_lv))
pprint(mk_res)
```

(e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), e(e(aesara.tensor.math.Dot), A, x), e( e(aesara.tensor.math.Dot), e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), x, y), y)),)

We obtain an etuple that we can evaluate to get the graph

aesara.dprint(mk_res[0].evaled_obj)

Elemwise{add,no_inplace} [id A] '' |dot [id B] '' | |A [id C] | |x [id D] |InplaceDimShuffle{x} [id E] '' |dot [id F] '' |Elemwise{add,no_inplace} [id G] '' | |x [id D] | |y [id H] |y [id H]

We did not need to use the conjunction operation `lall`

explicitly, because all remaining goal arguments to `run`

are automatically applied in conjunction.

Before moving on to the next section and goal construction, let us summarize everything we did in a self-contained exampe:

import aesara import aesara.tensor as at from aesara.tensor.math import Dot from etuples import etuple, etuplize from kanren import eq, run from unification import var from IPython.lib.pretty import pprint # Define the graph we want to "modify" A_tt = at.matrix("A") x_tt = at.vector("x") y_tt = at.vector("y") z_tt = at.dot(A_tt, x_tt + y_tt) z_et = etuplize(z_tt) # Input patterns and logic variables x_lv, y_lv = var('x'), var('y') add_pattern = etuple(etuplize(at.add), x_lv, y_lv) A_lv, B_lv = var('A'), var('B') dot_pattern = etuple(etuple(Dot), A_lv, B_lv) # Output pattern output_pattern = etuple(etuplize(at.add), etuple(etuple(Dot), A_lv, x_lv), etuple(etuple(Dot), B_lv, y_lv)) # Using miniKanren mk_res = run(1, output_pattern, eq(dot_pattern, z_et), eq(add_pattern, B_lv)) aesara.dprint(mk_res[0].evaled_obj)

Elemwise{add,no_inplace} [id A] '' |dot [id B] '' | |A [id C] | |x [id D] |InplaceDimShuffle{x} [id E] '' |dot [id F] '' |Elemwise{add,no_inplace} [id G] '' | |x [id D] | |y [id H] |y [id H]

When combinations of miniKanren goals comprise logical units, we can wrap their construction in functions which we call **goal constructors**.

### Goals Constructors

Using our distributive law example, we can create a goal constructor that creates our combined pattern and applies it in one go.

def distributeo(in_g, out_g): """Create a oal that represents commuted matrix multiplicatoin and addition.""" A_lv, x_lv, y_lv = var(), var(), var() dot_pattern = etuple(etuple(Dot), A_lv, etuple(etuplize(at.add), x_lv, y_lv)) dist_pattern = etuple(etuplize(at.add), etuple(etuple(Dot), A_lv, x_lv), etuple(etuple(Dot), A_lv, y_lv)) return lall(eq(in_g, dot_pattern), eq(out_g, dist_pattern))

Our goal constructor represent the **relation** for distribution of matrix multiplication and addition. In this sense, it can be run **both** ways i.e. it can "expand" a multiplication by distributing it through addition, and it can "contract" it by doing the opposite.

In the following example we "expand" the multiplication:

q_lv = var() mk_res = run(1, q_lv, distributeo(z_et, q_lv)) z_expanded_et = mk_res[0].evaled_obj aesara.dprint(z_expanded_et)

Elemwise{add,no_inplace} [id A] '' |dot [id B] '' | |A [id C] | |x [id D] |dot [id E] '' |A [id C] |y [id F]

And in the following example we "contract" the previously expanded result

q_lv = var() mk_res = run(1, q_lv, distributeo(q_lv, z_expanded_et)) z_contracted_et = mk_res[0].evaled_obj aesara.dprint(z_contracted_et)

dot [id A] '' |A [id B] |Elemwise{add,no_inplace} [id C] '' |x [id D] |y [id E]

### Graph-based goals

In most situation the desired graphs will be subgraphs of much larger ones. Aesara introduces some miniKanren goals that apply other goals throughout graphs until a fixed-point is reached. This sequence of operations is generally necessary for graph simplification and rewriting.

In the following example we create a new graph that contains `at.dot(A, x+y)`

as a subgraph.

e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Mul at 0x7fd9c9f1a560>, <frozendict {}>), e( e(aesara.tensor.elemwise.DimShuffle, (), ('x',), True), TensorConstant{2}), e( e(aesara.tensor.math.Dot), A, e( e( aesara.tensor.elemwise.Elemwise, <aesara.scalar.basic.Add at 0x7fd9c9f1a440>, <frozendict {}>), x, y))), e( e(aesara.tensor.elemwise.DimShuffle, (), ('x',), True), TensorConstant{1.0}))

We define `graph_walko`

, a function that walks term graphs and will apply our `distributeo`

goal throughout the graph until the applicable subgraph is found (and replaced)

from etuples.core import ExpressionTuple from kanren.graph import walko from kanren import eq from functools import partial graph_walko = partial(walko, rator_goal=eq) q_lv = var() mk_res = run(1, q_lv, graph_walko(distributeo, z_graph_et, q_lv)) aesara.dprint(mk_res[0].evaled_obj)

Elemwise{add,no_inplace} [id A] '' |Elemwise{mul,no_inplace} [id B] '' | |InplaceDimShuffle{x} [id C] '' | | |TensorConstant{2} [id D] | |Elemwise{add,no_inplace} [id E] '' | |dot [id F] '' | | |A [id G] | | |x [id H] | |dot [id I] '' | |A [id G] | |y [id J] |InplaceDimShuffle{x} [id K] '' |TensorConstant{1.0} [id L]