# Rewriting Aesara graphs

This article is a port of Brandon Willard's Tour of the Symbolic PyMC library, and is a simplified version of the example in Aesara's documentation. The text is almost a verbatim copy of the original, but mistakes are obviously mine.

In this document we will implement a symbolic "search-and-replace" that changes Aesara graphs like `at.dot(A, x + y)` to `at.dot(A, x) + at.dot(A, y)`. In other words, we will demonstrate how to implement the distributive property of matrix multiplication so it can be applied to any Aesara graph. Aesara allows one to implement rewrite rules like the distributive property, and many other sophisticated manipulations of graphs, by providing flexible, pure Python versions of core operations in symbolic computation. These operations are then combined and orchestrated through the relational programming DSL miniKanren.
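Before turning to graphs, the identity itself can be checked numerically. The following is a minimal sketch in plain Python (no Aesara involved); the helper names `matvec` and `vec_add` are illustrative and not part of any library.

```python
def matvec(A, v):
    """Multiply a matrix (a list of rows) by a vector (a list of numbers)."""
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def vec_add(u, v):
    """Add two vectors elementwise."""
    return [a + b for a, b in zip(u, v)]


A = [[1, 2], [3, 4]]
x = [5, 6]
y = [7, 8]

lhs = matvec(A, vec_add(x, y))             # A (x + y)
rhs = vec_add(matvec(A, x), matvec(A, y))  # A x + A y
assert lhs == rhs
```

The rewrite we build in this article performs the same transformation symbolically, on the graph, rather than on concrete numbers.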

More specifically, we’ll introduce the basic unification and reification operations and explicitly show how they relate to graph manipulation and the modeling of high-level mathematical relations. Along the way, we’ll cover some of the necessary details behind Aesara graphs.

We start by creating a graph of our target expression, i.e. `at.dot(A, x + y)`, in Aesara. We need to do this in order to determine exactly what we’re searching for and, later, what to put in its place.

```python
import aesara.tensor as at

A_tt = at.matrix("A")
x_tt = at.vector("x")
y_tt = at.vector("y")
z_tt = at.dot(A_tt, x_tt + y_tt)
```

We can get a text print-out of the graph using the debug print function `dprint`:

```python
import aesara

aesara.dprint(z_tt)
```
```
dot [id A] ''
 |A [id B]
 |Elemwise{add,no_inplace} [id C] ''
   |x [id D]
   |y [id E]
```

The output of `dprint` shows the underlying operators (`dot`, `add`) and their arguments.

To "match/search for" combinations of Aesara operations (or, as we just saw, graphs) we use unification; to "replace" parts of the graph (well, produce a copy with replaced parts) we use reification. Aesara provides support for these via expression-tuples.

## S-expressions

We can convert an Aesara graph into an S-expression-like form using etuples.

```python
from etuples import etuple, etuplize
from IPython.lib.pretty import pprint

z_et = etuplize(z_tt)
pprint(z_et)
```
```
e(
  e(aesara.tensor.math.Dot),
  A,
  e(
    e(
      aesara.tensor.elemwise.Elemwise,
      <frozendict {}>),
    x,
    y))
```

An `etuple` is like a normal `tuple`, except that its first element is a `Callable` and the remaining elements are the `Callable`'s arguments. As above, a pretty-printed `etuple` looks like a `tuple` prefixed by an `e`.

By working with `etuples` we can use arbitrary Python functions in conjunction with Aesara graphs and logic variable arguments. Basically, an `etuple` can be manipulated until all of its constituent logic variables are replaced with valid arguments to the function or operator. At that point the `etuple` can be evaluated.
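To make the idea of deferred evaluation concrete, here is a toy sketch of an expression-tuple in plain Python. It is not the `etuples` implementation: the class `ToyETuple` and the helper `e` are hypothetical names, and the sketch assumes the operator is a plain callable (in the real library the operator can itself be an `etuple`).

```python
import operator


class ToyETuple(tuple):
    """A tuple whose first element is a callable and the rest are its arguments."""

    @property
    def evaled_obj(self):
        # Evaluate nested expression-tuples first, then call the operator.
        op, *args = self
        args = [a.evaled_obj if isinstance(a, ToyETuple) else a for a in args]
        return op(*args)


def e(*elements):
    """Build a toy expression-tuple from an operator and its arguments."""
    return ToyETuple(elements)


# Nothing is computed when the expression is built...
expr = e(operator.add, 1, e(operator.mul, 2, 3))
# ...only when the evaluated object is requested: 1 + (2 * 3).
result = expr.evaled_obj
```

Because the expression is an ordinary tuple until it is evaluated, its elements can be inspected or swapped out first, which is what unification and reification rely on.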

For instance we can create an `etuple` that uses the function `at.add` with a logic variable argument.

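```python
from unification import var

x_lv, y_lv = var('x'), var('y')
# An addition whose two arguments are still unknown
add_pattern = etuple(etuplize(at.add), x_lv, y_lv)
```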

It wouldn't normally be possible to call the `at.add` function with these argument types, as demonstrated in this example:

```python
try:
    # Logic variables are not valid tensor inputs, so this call fails
    at.add(x_lv, y_lv)
except NotImplementedError as e:
    print(str(e))
```
```
Cannot convert ~x to a tensor variable.
```

We'll get a similar error if we attempt to evaluate the `etuple` by accessing its `ExpressionTuple.evaled_obj` property. However, after performing a simple manipulation that replaces the logic variables with valid inputs to `at.add` (reification), we are able to evaluate the `etuple` and obtain an Aesara tensor result, as demonstrated by the following code:

```python
from unification import reify

# Replace both logic variables with valid inputs to `at.add`
new_add_pattern = reify(add_pattern, {x_lv: 1.0, y_lv: 1.0})
pprint(new_add_pattern)
```
```
e(
  e(
    aesara.tensor.elemwise.Elemwise,
    <frozendict {}>),
  1.0,
  1.0)
```
```python
pprint(new_add_pattern.evaled_obj)
```
```
Elemwise{add,no_inplace}.0
```

Working with S-expressions is much like manipulating a subset of the Python AST, so, when using `etuples`, one is in effect metaprogramming (e.g. by automating the production and evaluation of Aesara graphs using Python code). As a matter of fact, `etuples` could be recast as `ast.Expr` and `ast.Call` objects that, through the use of `eval`, could achieve the same results, albeit without the more convenient tuple-like structuring.
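As a rough illustration of that comparison, the sketch below builds the call `add(1.0, 1.0)` with the standard-library `ast` module and evaluates it. It uses `operator.add` as a stand-in for `at.add` (an assumption made to keep the example free of third-party imports), and it wraps the call in `ast.Expression`, which is the node type `compile` expects in `"eval"` mode.

```python
import ast
import operator

# Build the call `add(1.0, 1.0)` as an AST instead of an expression-tuple.
call = ast.Call(
    func=ast.Name(id="add", ctx=ast.Load()),
    args=[ast.Constant(value=1.0), ast.Constant(value=1.0)],
    keywords=[],
)
tree = ast.fix_missing_locations(ast.Expression(body=call))

# Evaluation requires compiling the tree and supplying a namespace for `add`.
value = eval(compile(tree, filename="<ast>", mode="eval"), {"add": operator.add})
```

The AST version needs explicit node types, contexts, and a namespace, which is the extra ceremony the tuple-like structure avoids.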

## Operators and their parameters

The `etuple` form of our matrix-multiplication graph `z_et`, printed above, contains Aesara operators, which are themselves represented as `etuples`:

```python
pprint(z_et[0])
```
```
e(aesara.tensor.math.Dot)
```

## Unification and reification

With the ability to use logic variables and Aesara graphs together, we can now "search" or "match" arbitrary graphs using unification and produce new graphs by replacing logic variables using reification.

We start by making "patterns" or templates for the subgraphs we would like to match. Patterns, in this case, take the form of S-expressions with the desired structure and logic variables in place of "unknown" or arbitrary terms that we might like to reference elsewhere.

`dot_pattern` represents an S-expression that evaluates to a graph in which two terms are matrix-multiplied.

```python
from aesara.tensor.math import Dot

A_lv, B_lv = var("A"), var("B")
dot_pattern = etuple(etuple(Dot), A_lv, B_lv)
```

"Matching" a graph against this pattern is called unification. Unification of two graphs implies unification of all sub-graphs and elements between them. When unification is successful, it returns a map of logic variables and their unified values. If there are no logic variables in the graphs it simply returns an empty map. If unification fails, it returns `False`, at least in the implementation we use.

### Unification

We can perform unification using the function `unify`. The result is a `dict` mapping logic variables to their unified values.

```python
from unification import unify

s = unify(dot_pattern, z_et)
pprint(s)
```
```
{~A: A,
 ~B: e(
   e(
     aesara.tensor.elemwise.Elemwise,
     <frozendict {}>),
   x,
   y)}
```

The logic variable `A` has been correctly unified with `A_tt`, while the logic variable `B` has been correctly unified with the addition of `x_tt` and `y_tt`.
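To show what happens during unification, here is a toy unifier over nested tuples in plain Python. It is a sketch under simplifying assumptions, not the `unification` package's implementation: `Var`, `walk` and `toy_unify` are hypothetical names, terms are limited to tuples and atoms, and there is no occurs check.

```python
class Var:
    """A toy logic variable, compared by identity."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"~{self.name}"


def walk(term, s):
    """Follow the bindings in substitution `s` until a non-bound term is reached."""
    while isinstance(term, Var) and term in s:
        term = s[term]
    return term


def toy_unify(u, v, s=None):
    """Return a substitution that makes `u` and `v` equal, or False on failure."""
    s = {} if s is None else s
    u, v = walk(u, s), walk(v, s)
    if u is v:
        return s
    if isinstance(u, Var):
        return {**s, u: v}
    if isinstance(v, Var):
        return {**s, v: u}
    if isinstance(u, tuple) and isinstance(v, tuple):
        if len(u) != len(v):
            return False
        # Unify element by element, threading the substitution through.
        for a, b in zip(u, v):
            s = toy_unify(a, b, s)
            if s is False:
                return False
        return s
    return s if u == v else False


A, B = Var("A"), Var("B")
pattern = ("dot", A, B)
graph = ("dot", "A_matrix", ("add", "x", "y"))

matched = toy_unify(pattern, graph)                     # binds A and B
no_vars = toy_unify(("dot", 1), ("dot", 1))             # empty map
failed = toy_unify(pattern, ("add", "x", "y"))          # operators differ
```

As in the Aesara example, the variable standing for the second argument is bound to a whole sub-expression rather than to a single leaf.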

### Reification

Using `reify` we can "fill in", or replace, the logic variables of our "pattern" with the matches obtained by `unify` that are held in the variable `s`, or we could specify our own substitutions based on that information.

In the following snippet we simply exchange the `A_tt` tensor for another tensor, `X_tt`, and create a new graph with that value. The end result is a version of the original graph `z_et` with the new tensor.

```python
X_tt = at.matrix("X")
s[A_lv] = X_tt

z_et_re = reify(dot_pattern, s)
pprint(z_et_re)
```
```
e(
  e(aesara.tensor.math.Dot),
  X,
  e(
    e(
      aesara.tensor.elemwise.Elemwise,
      <frozendict {}>),
    x,
    y))
```
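Reification is the simpler of the two operations. The following toy version, again in plain Python over nested tuples, is a sketch and not the `unification` package's code; `Var` and `toy_reify` are hypothetical names.

```python
class Var:
    """A toy logic variable, compared by identity."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"~{self.name}"


def toy_reify(term, s):
    """Replace every logic variable in `term` that has a binding in `s`."""
    if isinstance(term, Var):
        # Bound variables are replaced (recursively); unbound ones are kept.
        return toy_reify(s[term], s) if term in s else term
    if isinstance(term, tuple):
        return tuple(toy_reify(t, s) for t in term)
    return term


A, B, C = Var("A"), Var("B"), Var("C")
s = {A: "X", B: ("add", "x", "y")}

new_graph = toy_reify(("dot", A, B), s)
partly_reified = toy_reify(("dot", A, C), s)  # C has no binding and stays put
```

Variables without a binding survive reification unchanged, so a pattern can be filled in gradually as more substitutions become available.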

### Finishing our implementation

We can also reify an entirely different graph using the values extracted from the graph `z_et`. In this case, we create an "output" pattern graph to complement our "input" pattern graph `dot_pattern`. If we combine our dot product and addition `etuple` patterns, we can extract all the arguments needed as input to a distributed multiplication pattern.

```python
output_pattern = etuple(etuplize(at.add), etuple(etuple(Dot), A_lv, x_lv), etuple(etuple(Dot), A_lv, y_lv))
pprint(output_pattern)
```
```
e(
  e(
    aesara.tensor.elemwise.Elemwise,
    <frozendict {}>),
  e(e(aesara.tensor.math.Dot), ~A, ~x),
  e(e(aesara.tensor.math.Dot), ~A, ~y))
```

With logic variables `A_lv`, `x_lv` and `y_lv` mapped to their template-corresponding objects in another graph, we can reify `output_pattern` and obtain a reified version of said graph.

Using the previous unification results contained in `s`, we only need to reify `output_pattern` with those mappings. However, since our pattern refers to the logic variables `x_lv` and `y_lv`, we first need to unify these logic variables with the appropriate terms in the graph.

```python
s_add = unify(s[B_lv], add_pattern, s)
pprint(s_add)
```
```
{~A: X,
 ~B: e(
   e(
     aesara.tensor.elemwise.Elemwise,
     <frozendict {}>),
   x,
   y),
 ~x: x,
 ~y: y}
```
```python
z_new = reify(output_pattern, s_add)
aesara.dprint(z_new.evaled_obj)
```
```
Elemwise{add,no_inplace} [id A] ''
 |dot [id B] ''
 | |X [id C]
 | |x [id D]
 |dot [id E] ''
   |X [id C]
   |y [id F]
```

Using only the basics of unification and reification provided by Aesara one can extract specific elements from Aesara graphs and use them to implement mathematical identities/relations. Through clever use of multiple mathematical relations, one can, for example, construct graph optimizations that turn large classes of user-defined statistical models into computationally tractable reformulations. Similarly, one can construct "normal forms" for models, making it possible to determine whether or not a user-defined model is suitable for a specific sampler.

Next we will introduce another major element of Aesara that orchestrates and simplifies sequences of unifications like the ones we used earlier, provides control-flow-like capabilities, produces fully reified results of arbitrary forms, and does so within a genuinely declarative formalism that carries much of the same power as logic programming: miniKanren!

## Relational programming in miniKanren

Aesara uses a Python implementation of the embedded domain-specific language miniKanren–provided by the `kanren` package–to orchestrate more sophisticated uses of unification and reification. For a quick intro, see the basic introduction provided by the `kanren` package. We'll cover most of the same basic material here.

To start, miniKanren uses goals (in the same sense as logic programming) to assert relations, and the `run` function evaluates those goals and allows one to specify the exact amount and type of reified output desired from the states that satisfy the goals.

In their most basic form, miniKanren states are simply the substitution maps returned by unification, which–in the normal course of operations–are not dealt with directly.

### The basic goals

Normally, a user will only need to construct compound goals from a basic set of primitives. Arguably, the most primitive goal is the equivalence relation under unification denoted by `eq` in Python.

In the following code block we ask for all successful results/reifications (signified by the `0` argument) of the logic variable `var('q')` for the goal `eq(var('q'), 1)`, i.e. unify `var('q')` with `1`.

```python
from kanren import run, eq

q_lv = var('q')
mk_res = run(0, q_lv, eq(q_lv, 1))
pprint(mk_res)
```
```
(1,)
```

Since miniKanren's `run` always returns a stream of results, we obtain a tuple containing the reified values of `q_lv` under the one possible state for which our stated goal successfully evaluates.

The other basic primitives represent conjunction and disjunction of miniKanren goals: `lall` and `lany` respectively.

```python
from kanren import lall

mk_res = run(0, q_lv, lall(eq(q_lv, 1), eq(q_lv, 2)))
pprint(mk_res)
```

We just used `lall` to obtain the conjunction of two unification goals. Since we requested that the same logic variable be unified with `1` and `2` simultaneously, which is impossible, we got back an empty stream of results, indicating failure.

Goal disjunction, `lany`, will split a state stream across goals, producing new distinct states for each:

```python
from kanren import lany

mk_res = run(0, q_lv, lany(eq(q_lv, 1), eq(q_lv, 2)))
pprint(mk_res)
```

The goal disjunction result shows that the logic variable `q_lv` can be unified with either `1` or `2` under the two unification goals.

A common pattern of disjunction and conjunction is called `conde`, and it mirrors the Lisp function `cond`, which is effectively a type of compound `if ... elif ... elif ...`. Specifically, `conde([x_1, ...], ..., [y_1, ...])` is the same as `lany(lall(x_1, ...), ..., lall(y_1, ...))`, i.e. a disjunction of goal conjunctions.

```python
from kanren import conde

r_lv = var("r")

mk_res = run(
    0,
    [q_lv, r_lv],
    conde(
        [eq(q_lv, 1), eq(r_lv, 10)],
        [eq(q_lv, 2), eq(r_lv, 20)]
    )
)
pprint(mk_res)
```
```
([1, 10], [2, 20])
```

We introduced another logic variable `r_lv` and requested the reified values of a list containing both logic variables. The output reflects the idea that if `q_lv` is "equal" to `1`, then `r_lv` is "equal" to `10`, etc. Unlike normal conditionals, the clauses/branches aren't exclusive; each one is realized whenever the goals in that branch can succeed.

The following code demonstrates when `conde` can behave more like a traditional conditional statement.

```python
mk_res = run(0, [q_lv, r_lv],
             lall(eq(q_lv, 1),
                  conde(
                      [eq(q_lv, 1), eq(r_lv, 10)],
                      [eq(q_lv, 2), eq(r_lv, 20)],
                  )))
pprint(mk_res)
```
```
([1, 10],)
```
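One way to build intuition for these primitives is to treat a goal as a function that takes a state (a substitution map) and yields zero or more new states. The toy sketch below does that in plain Python. All names (`Var`, `toy_eq`, `toy_lall`, `toy_lany`, `toy_conde`, `toy_run`) are hypothetical, the unification handles only atoms and variables, `toy_run` has no result-count argument, and results come back in a fixed order, which the real `kanren` package does not necessarily guarantee.

```python
from itertools import chain


class Var:
    """A toy logic variable, compared by identity."""

    def __init__(self, name):
        self.name = name


def walk(t, s):
    """Follow the bindings in state `s` until a non-bound term is reached."""
    while isinstance(t, Var) and t in s:
        t = s[t]
    return t


def toy_eq(u, v):
    """Goal: succeed when `u` and `v` unify (atoms and variables only)."""
    def goal(s):
        a, b = walk(u, s), walk(v, s)
        if a is b:
            yield s
        elif isinstance(a, Var):
            yield {**s, a: b}
        elif isinstance(b, Var):
            yield {**s, b: a}
        elif a == b:
            yield s
    return goal


def toy_lall(*goals):
    """Conjunction: thread every state through each goal in turn."""
    def goal(s):
        states = [s]
        for g in goals:
            states = [s2 for s1 in states for s2 in g(s1)]
        return iter(states)
    return goal


def toy_lany(*goals):
    """Disjunction: collect the states produced by each goal."""
    def goal(s):
        return chain.from_iterable(g(s) for g in goals)
    return goal


def toy_conde(*clauses):
    """A disjunction of conjunctions, one conjunction per clause."""
    return toy_lany(*(toy_lall(*clause) for clause in clauses))


def toy_run(out, *goals):
    """Return the value of `out` in every state that satisfies all the goals."""
    def value(t, s):
        t = walk(t, s)
        return [value(x, s) for x in t] if isinstance(t, list) else t
    return tuple(value(out, s) for s in toy_lall(*goals)({}))


q, r = Var("q"), Var("r")
clauses = ([toy_eq(q, 1), toy_eq(r, 10)], [toy_eq(q, 2), toy_eq(r, 20)])

both = toy_run(q, toy_lall(toy_eq(q, 1), toy_eq(q, 2)))
either = toy_run(q, toy_lany(toy_eq(q, 1), toy_eq(q, 2)))
pairs = toy_run([q, r], toy_conde(*clauses))
guarded = toy_run([q, r], toy_eq(q, 1), toy_conde(*clauses))
```

In this sketch a failing goal simply yields nothing, so conjunction with a failed goal produces an empty result and disjunction skips the failed branch.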

### A better implementation

Since miniKanren uses unification and reification, we can apply its basic goals to Aesara graphs, as we did earlier, and reproduce the entire implementation in a much more concise manner. For reference, the block restates the output pattern, in which `A_lv` appears in both products.

```python
output_pattern = etuple(etuplize(at.add), etuple(etuple(Dot), A_lv, x_lv), etuple(etuple(Dot), A_lv, y_lv))
mk_res = run(1, output_pattern, eq(dot_pattern, z_et), eq(add_pattern, B_lv))
pprint(mk_res)
```
```
(e(
   e(
     aesara.tensor.elemwise.Elemwise,
     <frozendict {}>),
   e(e(aesara.tensor.math.Dot), A, x),
   e(e(aesara.tensor.math.Dot), A, y)),)
```

We obtain an `etuple` that we can evaluate to get the graph:

```python
aesara.dprint(mk_res[0].evaled_obj)
```
```
Elemwise{add,no_inplace} [id A] ''
 |dot [id B] ''
 | |A [id C]
 | |x [id D]
 |dot [id E] ''
   |A [id C]
   |y [id F]
```

We did not need to use the conjunction operation `lall` explicitly, because all remaining goal arguments to `run` are automatically applied in conjunction.

Before moving on to the next section and goal construction, let us summarize everything we did in a self-contained example:

```python
import aesara
import aesara.tensor as at
from aesara.tensor.math import Dot

from etuples import etuple, etuplize
from kanren import eq, run
from unification import var

from IPython.lib.pretty import pprint

# Define the graph we want to "modify"
A_tt = at.matrix("A")
x_tt = at.vector("x")
y_tt = at.vector("y")
z_tt = at.dot(A_tt, x_tt + y_tt)

z_et = etuplize(z_tt)

# Input patterns and logic variables
x_lv, y_lv = var('x'), var('y')
add_pattern = etuple(etuplize(at.add), x_lv, y_lv)

A_lv, B_lv = var('A'), var('B')
dot_pattern = etuple(etuple(Dot), A_lv, B_lv)

# Output pattern: A x + A y
output_pattern = etuple(etuplize(at.add), etuple(etuple(Dot), A_lv, x_lv), etuple(etuple(Dot), A_lv, y_lv))

# Using miniKanren
mk_res = run(1, output_pattern, eq(dot_pattern, z_et), eq(add_pattern, B_lv))
aesara.dprint(mk_res[0].evaled_obj)
```
```
Elemwise{add,no_inplace} [id A] ''
 |dot [id B] ''
 | |A [id C]
 | |x [id D]
 |dot [id E] ''
   |A [id C]
   |y [id F]
```

When combinations of miniKanren goals comprise logical units, we can wrap their construction in functions which we call goal constructors.

### Goal constructors

Using our distributive law example, we can create a goal constructor that creates our combined pattern and applies it in one go.

```python
def distributeo(in_g, out_g):
    """Create a goal that represents the distribution of matrix multiplication over addition."""
    # Fresh logic variables, local to each call
    A_lv, x_lv, y_lv = var(), var(), var()
    dot_pattern = etuple(etuple(Dot), A_lv, etuple(etuplize(at.add), x_lv, y_lv))
    dist_pattern = etuple(etuplize(at.add), etuple(etuple(Dot), A_lv, x_lv), etuple(etuple(Dot), A_lv, y_lv))

    return lall(eq(in_g, dot_pattern), eq(out_g, dist_pattern))
```

Our goal constructor represents the relation for the distribution of matrix multiplication over addition. Because it is a relation, it can be run both ways, i.e. it can "expand" a multiplication by distributing it through addition, and it can "contract" it by doing the opposite.

In the following example we "expand" the multiplication:

```python
q_lv = var()
mk_res = run(1, q_lv, distributeo(z_et, q_lv))
z_expanded_et = mk_res[0].evaled_obj
aesara.dprint(z_expanded_et)
```
```
Elemwise{add,no_inplace} [id A] ''
 |dot [id B] ''
 | |A [id C]
 | |x [id D]
 |dot [id E] ''
   |A [id C]
   |y [id F]
```

And in the following example we "contract" the previously expanded result:

```python
q_lv = var()
mk_res = run(1, q_lv, distributeo(q_lv, z_expanded_et))
z_contracted_et = mk_res[0].evaled_obj
aesara.dprint(z_contracted_et)
```
```
dot [id A] ''
 |A [id B]
 |Elemwise{add,no_inplace} [id C] ''
   |x [id D]
   |y [id E]
```

### Graph-based goals

In most situations the desired graphs will be subgraphs of much larger ones. Aesara introduces some miniKanren goals that apply other goals throughout graphs until a fixed point is reached. This sequence of operations is generally necessary for graph simplification and rewriting.
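The idea of applying a rewrite throughout a graph until nothing changes can be sketched without any libraries. The example below works on nested tuples and uses an ordinary function as the rewrite rule; it is a functional stand-in for illustration, not how the relational walking goals are implemented, and the names `distribute`, `rewrite_once` and `rewrite_to_fixed_point` are hypothetical.

```python
def distribute(term):
    """Rewrite ("dot", A, ("add", x, y)) to ("add", ("dot", A, x), ("dot", A, y))."""
    if (
        isinstance(term, tuple) and len(term) == 3 and term[0] == "dot"
        and isinstance(term[2], tuple) and len(term[2]) == 3 and term[2][0] == "add"
    ):
        _, A, (_, x, y) = term
        return ("add", ("dot", A, x), ("dot", A, y))
    return term


def rewrite_once(rule, term):
    """Apply `rule` to every sub-term, children first, then to the term itself."""
    if isinstance(term, tuple):
        term = tuple(rewrite_once(rule, t) for t in term)
    return rule(term)


def rewrite_to_fixed_point(rule, term):
    """Repeat full passes over the term until a pass changes nothing."""
    while True:
        new_term = rewrite_once(rule, term)
        if new_term == term:
            return term
        term = new_term


# The matching subgraph is buried inside a larger expression.
graph = ("add", ("mul", 2, ("dot", "A", ("add", "x", "y"))), 1.0)
rewritten = rewrite_to_fixed_point(distribute, graph)

# A nested sum needs more than one pass before no rule applies anymore.
nested = ("dot", "A", ("add", "x", ("add", "y", "z")))
fully_distributed = rewrite_to_fixed_point(distribute, nested)
```

The second example shows why a single pass is not always enough: the first pass exposes a new `dot` over a sum, which only the next pass can distribute.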

In the following example we create a new graph, `z_graph_et`, that contains `at.dot(A, x + y)` as a subgraph. Its printed `etuple` form corresponds to an expression like `2 * at.dot(A, x + y) + 1.0`:

```
e(
  e(
    aesara.tensor.elemwise.Elemwise,
    <frozendict {}>),
  e(
    e(
      aesara.tensor.elemwise.Elemwise,
      <aesara.scalar.basic.Mul at 0x7fd9c9f1a560>,
      <frozendict {}>),
    e(
      e(aesara.tensor.elemwise.DimShuffle, (), ('x',), True),
      TensorConstant{2}),
    e(
      e(aesara.tensor.math.Dot),
      A,
      e(
        e(
          aesara.tensor.elemwise.Elemwise,
          <frozendict {}>),
        x,
        y))),
  e(
    e(aesara.tensor.elemwise.DimShuffle, (), ('x',), True),
    TensorConstant{1.0}))
```

We define `graph_walko`, a function that walks term graphs and applies our `distributeo` goal throughout the graph until the applicable subgraph is found (and replaced):

```python
from etuples.core import ExpressionTuple
from kanren.graph import walko
from kanren import eq
from functools import partial

graph_walko = partial(walko, rator_goal=eq)

q_lv = var()
mk_res = run(1, q_lv, graph_walko(distributeo, z_graph_et, q_lv))
aesara.dprint(mk_res[0].evaled_obj)
```
```
Elemwise{add,no_inplace} [id A] ''
|Elemwise{mul,no_inplace} [id B] ''
| |InplaceDimShuffle{x} [id C] ''
| | |TensorConstant{2} [id D]