Hidden Markov Models in Aesara

Implementing Hidden Markov Models is not easy in the existing probabilistic programming libraries in Python, because they cannot reconstruct the log-probability of a model whose generative process contains loops. In Aesara we simply implement the model as it is written, as with .

Let us consider a hidden Markov model with \(N\) time steps and \(M\) possible states. The observation model is:

\begin{align*} (Y_t | S_t = c) &\sim \operatorname{N}\left(\mu_{c}, \sigma_{c}\right)\\ S_t &\sim \operatorname{Categorical}\left(\Gamma_0\right) \end{align*}

The hidden state sequence is defined by the Markov relation:

\begin{equation*} P(S_t = j \mid S_{t-1} = i) = \Gamma_{ij} \end{equation*}

where the transition matrix \(\Gamma\) is assumed to be constant over time.
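For reference, the generative process described by these equations can be sketched in plain NumPy, outside of Aesara. This is a minimal sketch under stated assumptions: the helper name sample_hmm is hypothetical, the first state is assumed to be drawn from \(\Gamma_0\) (the first row of \(\Gamma\)), and every later state is drawn from the row of \(\Gamma\) selected by the previous state.

```python
import numpy as np


def sample_hmm(n_steps, Gamma, mus, sigma=1.0, seed=0):
    """Draw (y, s) from a Gaussian HMM with transition matrix Gamma.

    Hypothetical helper. Assumes the first state is drawn from the first
    row of Gamma (Gamma_0) and each later state from row Gamma[s_{t-1}].
    """
    rng = np.random.default_rng(seed)
    Gamma = np.asarray(Gamma, dtype=float)
    mus = np.asarray(mus, dtype=float)
    n_states = Gamma.shape[0]

    s = np.empty(n_steps, dtype=int)
    # Initial state: S_0 ~ Categorical(Gamma_0) (assumption, see lead-in)
    s[0] = rng.choice(n_states, p=Gamma[0])
    for t in range(1, n_steps):
        # Markov transition: P(S_t = j | S_{t-1} = i) = Gamma[i, j]
        s[t] = rng.choice(n_states, p=Gamma[s[t - 1]])
    # Emissions: Y_t | S_t = c ~ N(mu_c, sigma)
    y = rng.normal(mus[s], sigma)
    return y, s


y, s = sample_hmm(10, [[0.9, 0.1], [0.2, 0.8]], mus=[-10.0, 10.0])
print(s)
print(y)
```

With a strongly diagonal \(\Gamma\) such as the one above, the sampled states tend to persist for several steps before switching.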

import aesara
import aesara.tensor as at

srng = at.random.RandomStream(0)

N_tt = at.iscalar("N")
M_tt = at.iscalar("M")
# Each row of the transition matrix is drawn from a flat Dirichlet(1, ..., 1)
Gamma_rv = srng.dirichlet(at.ones((M_tt, M_tt)), name="Gamma")

mu_tt = at.vector("mus")
sigma_tt = 1.

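def scan_fn(Gamma_t):
    # Every step draws S_t from the first row of Gamma, i.e. S_t ~ Categorical(Gamma_0),
    # so the states are independent of S_{t-1} given Gamma
    S_t = srng.categorical(Gamma_t[0], name="S_t")
    # Emission: Y_t | S_t = c ~ N(mu_c, sigma)
    Y_t = srng.normal(mu_tt[S_t], sigma_tt, name="Y_t")
    return Y_t, S_t

(Y_rv, S_rv), updates = aesara.scan(
    fn=scan_fn,
    non_sequences=[Gamma_rv],  # the same Gamma is passed to every step
    outputs_info=[{}, {}],     # no taps: neither output is fed back into scan_fn
    strict=True,
    n_steps=N_tt
)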

sample_fn = aesara.function((N_tt, M_tt, mu_tt), (Y_rv, S_rv), updates=updates)
print(sample_fn(10, 2, [-10, 10]))
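To see what the compiled sampler computes, the graph above can be mirrored by an ordinary Python loop: \(\Gamma\) is drawn once, before the loop, and the body of scan_fn runs \(N\) times with that same \(\Gamma\). This is a hypothetical sketch (sample_like_scan is not part of Aesara) that uses NumPy's Generator in place of Aesara's RandomStream, so it will not reproduce the numbers printed by sample_fn for the same seed.

```python
import numpy as np


def sample_like_scan(n_steps, n_states, mus, sigma=1.0, seed=0):
    """Hypothetical NumPy mirror of sample_fn: returns (y, s) of length n_steps."""
    rng = np.random.default_rng(seed)
    mus = np.asarray(mus, dtype=float)
    # Gamma_rv: one Dirichlet(1, ..., 1) draw per row, made once before the loop
    Gamma = rng.dirichlet(np.ones(n_states), size=n_states)
    ys = np.empty(n_steps)
    ss = np.empty(n_steps, dtype=int)
    for t in range(n_steps):
        # Body of scan_fn: the state always comes from Gamma[0]
        ss[t] = rng.choice(n_states, p=Gamma[0])
        ys[t] = rng.normal(mus[ss[t]], sigma)
    return ys, ss


print(sample_like_scan(10, 2, [-10.0, 10.0]))
```

Because the loop body never reads the previous state, this version (like scan_fn) produces independent states; the Markov dependency would require feeding S_{t-1} back into each step.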

Thanks to AePPL, we can easily compute this model's log-probability function:

from aeppl import joint_logprob
import numpy as np

# Value variables: stand-ins for the values at which the log-probability is evaluated
y_vv = Y_rv.clone()
s_vv = S_rv.clone()
Gamma_vv = Gamma_rv.clone()

values = {
    y_vv: np.random.normal(0, 1., size=10),
    s_vv: np.ones(10, dtype="int"),
    M_tt: 2,
    N_tt: 10,
    mu_tt: [-1., 1.],
    Gamma_vv: [[.5, .5], [.5, .5]],
}

logprob = joint_logprob({Y_rv: y_vv, S_rv: s_vv, Gamma_rv: Gamma_vv})
logprob_fn = aesara.function(list(values.keys()), logprob)
print(logprob_fn(*values.values()))
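As a cross-check, the log-density of this graph can also be written out by hand. This is a minimal sketch, assuming that joint_logprob sums every term into a single scalar and that the flat Dirichlet prior applies row by row; the helpers dirichlet_logpdf and joint_logprob_reference are hypothetical, and the state term follows scan_fn (every state drawn from \(\Gamma_0\)), not the Markov relation.

```python
import numpy as np
from math import lgamma, log, pi


def dirichlet_logpdf(x, alpha):
    """log Dir(x | alpha) for a single probability vector x."""
    x = np.asarray(x, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    norm = lgamma(alpha.sum()) - sum(lgamma(a) for a in alpha)
    return norm + float(np.sum((alpha - 1.0) * np.log(x)))


def joint_logprob_reference(y, s, mus, Gamma, sigma=1.0):
    """Hypothetical hand-written log p(Gamma, s, y) for the scan model above."""
    y = np.asarray(y, dtype=float)
    s = np.asarray(s, dtype=int)
    mus = np.asarray(mus, dtype=float)
    Gamma = np.asarray(Gamma, dtype=float)
    # Prior: each row of Gamma ~ Dirichlet(1, ..., 1)
    lp = sum(dirichlet_logpdf(row, np.ones_like(row)) for row in Gamma)
    # States: every S_t ~ Categorical(Gamma[0]), as in scan_fn
    lp += float(np.sum(np.log(Gamma[0][s])))
    # Emissions: Y_t | S_t = c ~ N(mu_c, sigma)
    z = (y - mus[s]) / sigma
    lp += float(np.sum(-0.5 * z**2 - log(sigma) - 0.5 * log(2 * pi)))
    return lp


print(joint_logprob_reference(np.ones(10), np.ones(10, dtype=int), [-1.0, 1.0], [[0.5, 0.5], [0.5, 0.5]]))
```

For these inputs the result is about -16.12: the Dirichlet(1, 1) density is 1, so the prior contributes 0, the states contribute 10 log 0.5, and each observation sits exactly on its mean, contributing -0.5 log 2π. To compare with logprob_fn, pass it the same y values instead of the random draw used above.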
