Cyclical SGLD in Blackjax

Cyclical schedule

from typing import NamedTuple

import jax.numpy as jnp

class ScheduleState(NamedTuple):
    step_size: float
    do_sample: bool

def build_schedule(
    num_training_steps, num_cycles=4, initial_step_size=1e-3, exploration_ratio=0.25
):
    cycle_length = num_training_steps // num_cycles

    def schedule_fn(step_id):
        do_sample = False
        if ((step_id % cycle_length) / cycle_length) >= exploration_ratio:
            do_sample = True

        cos_out = jnp.cos(jnp.pi * (step_id % cycle_length) / cycle_length) + 1
        step_size = 0.5 * cos_out * initial_step_size

        return ScheduleState(step_size, do_sample)

    return schedule_fn

Let us visualize the schedule for 200k training steps divided into 4 cycles. In each cycle, a quarter of the steps are dedicated to exploration.
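
Here is a minimal plotting sketch; matplotlib, the variable names, and the 0.09 initial step size are our choices for illustration, not prescribed by the example.

import matplotlib.pyplot as plt

num_viz_steps = 200_000
viz_schedule_fn = build_schedule(num_viz_steps, 4, 0.09, 0.25)
viz_schedule = [viz_schedule_fn(i) for i in range(num_viz_steps)]

step_sizes = [s.step_size for s in viz_schedule]

fig, ax = plt.subplots()
ax.plot(step_sizes, lw=0.8)
# Shade the sampling phases (do_sample == True) of each cycle
ax.fill_between(
    range(num_viz_steps),
    0,
    step_sizes,
    where=[s.do_sample for s in viz_schedule],
    alpha=0.3,
)
ax.set_xlabel("training step")
ax.set_ylabel("step size")
plt.show()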


Cyclical SGLD step

We will reproduce one of the paper's examples: sampling from a mixture of 25 Gaussians laid out on a grid.

import itertools

import jax
import jax.scipy as jsp
import jax.numpy as jnp

lmbda = 1/25
positions = [-4, -2, 0, 2, 4]
mu = jnp.array([list(prod) for prod in itertools.product(positions, positions)])
sigma = 0.03 * jnp.eye(2)

def logprob_fn(x, *_):
    return lmbda * jsp.special.logsumexp(
        jax.scipy.stats.multivariate_normal.logpdf(x, mu, sigma)
    )

def sample_fn(rng_key):
    choose_key, sample_key = jax.random.split(rng_key)
    samples = jax.random.multivariate_normal(sample_key, mu, sigma)
    return jax.random.choice(choose_key, samples)

Let's plot the model's density; we will need this plot later to evaluate the samplers.
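
A sketch of the density plot, assuming matplotlib; the grid bounds are chosen to cover the 25 modes.

import matplotlib.pyplot as plt

grid = jnp.linspace(-5, 5, 200)
xx, yy = jnp.meshgrid(grid, grid)
points = jnp.stack([xx.ravel(), yy.ravel()], axis=-1)

# Evaluate the (unnormalized) density on the grid
density = jnp.exp(jax.vmap(logprob_fn)(points)).reshape(xx.shape)

fig, ax = plt.subplots()
ax.contourf(xx, yy, density, levels=50)
ax.set_title("Mixture of 25 Gaussians")
plt.show()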


Sample from the mixture of Gaussians

The sampling kernel must be able to alternate between sampling and optimization periods that are determined by the scheduler.

from typing import NamedTuple

import blackjax
import optax

from blackjax.types import PyTree
from optax import OptState

class CyclicalSGMCMCState(NamedTuple):
    """State of the Cyclical SGMCMC sampler."""
    position: PyTree
    opt_state: OptState

def cyclical_sgld(grad_estimator_fn, loglikelihood_fn):

    sgld = blackjax.sgld(grad_estimator_fn)
    sgd = optax.sgd(1.)

    def init_fn(position):
        """Initialize Cyclical SGLD's state."""
        opt_state = sgd.init(position)
        return CyclicalSGMCMCState(position, opt_state)

    def step_fn(
        rng_key,
        schedule_state: ScheduleState,
        state: CyclicalSGMCMCState,
        minibatch: PyTree,
    ):
        """Cyclical SGLD kernel.

        TODO: Organize the inputs to match the SGLD API better.

        Parameters
        ----------
        rng_key
            Key for JAX's pseudo-random number generator.
        schedule_state
            The current state of the scheduler. Indicates whether the kernel
            should be sampling or optimizing, and the current step size.
        state
            The current state of the Cyclical SGLD sampler.
        minibatch
            Not used in the mixture example, but this is where you would pass
            batches of data in any real application.
        """

        def step_with_sgld(current_state):
            rng_key, state, minibatch, step_size = current_state
            new_position = sgld(rng_key, state.position, minibatch, step_size)
            return CyclicalSGMCMCState(new_position, state.opt_state)

        def step_with_sgd(current_state):
            _, state, minibatch, step_size = current_state
            grads = grad_estimator_fn(state.position, minibatch)
            rescaled_grads = - 1. * step_size * grads
            updates, new_opt_state = sgd.update(rescaled_grads, state.opt_state, state.position)
            new_position = optax.apply_updates(state.position, updates)
            return CyclicalSGMCMCState(new_position, new_opt_state)

        new_state = jax.lax.cond(
            schedule_state.do_sample,
            step_with_sgld,
            step_with_sgd,
            (rng_key, state, minibatch, schedule_state.step_size),
        )

        return new_state

    return init_fn, step_fn


Let's start with SGLD:

import jax
from fastprogress import progress_bar

# 50k iterations
num_training_steps = 50000
schedule_fn = lambda k: 0.05 * k ** (-0.55)
# TODO: There is no need to pre-compute the schedule
schedule = [schedule_fn(i) for i in range(1, num_training_steps+1)]

grad_fn = lambda x, _: jax.grad(logprob_fn)(x)
sgld = blackjax.sgld(grad_fn)

rng_key = jax.random.PRNGKey(3)
init_position = -10 + 20 * jax.random.uniform(rng_key, shape=(2,))

position = init_position
sgld_samples = []
for i in progress_bar(range(num_training_steps)):
    _, rng_key = jax.random.split(rng_key)
    position = jax.jit(sgld)(rng_key, position, 0, schedule[i])
    sgld_samples.append(position)

Let's plot the trajectory:
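
A minimal sketch of the trajectory plot, assuming matplotlib and the sgld_samples collected above:

import matplotlib.pyplot as plt

sgld_trajectory = jnp.stack(sgld_samples)

fig, ax = plt.subplots()
ax.plot(sgld_trajectory[:, 0], sgld_trajectory[:, 1], lw=0.3, alpha=0.7)
ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)
ax.set_title("SGLD trajectory")
plt.show()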


Cyclical SGLD

Now let's sample using Cyclical SGLD.

import jax
from fastprogress import progress_bar

# 50k iterations
# M = 30
# initial step size = 0.09
# ratio exploration = 1/4
num_training_steps = 50000
schedule_fn = build_schedule(num_training_steps, 30, 0.09, 0.25)
# TODO: There is no need to pre-compute the schedule
schedule = [schedule_fn(i) for i in range(num_training_steps)]

grad_fn = lambda x, _: jax.grad(logprob_fn)(x)
init, step = cyclical_sgld(grad_fn, logprob_fn)

rng_key = jax.random.PRNGKey(3)
init_position = -10 + 20 * jax.random.uniform(rng_key, shape=(2,))
init_state = init(init_position)

state = init_state
cyclical_samples = []
for i in progress_bar(range(num_training_steps)):
    _, rng_key = jax.random.split(rng_key)
    state = jax.jit(step)(rng_key, schedule[i], state, 0)
    if schedule[i].do_sample:
        cyclical_samples.append(state.position)

The trajectory suggests that the distribution is explored much more thoroughly:
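
The same plotting sketch as for SGLD, applied to the cyclical samples (matplotlib again assumed):

import matplotlib.pyplot as plt

cyclical_trajectory = jnp.stack(cyclical_samples)

fig, ax = plt.subplots()
ax.plot(cyclical_trajectory[:, 0], cyclical_trajectory[:, 1], lw=0.3, alpha=0.7)
ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)
ax.set_title("Cyclical SGLD trajectory")
plt.show()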


Let's look at the distribution:
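
A sketch of the comparison, drawing side-by-side 2D histograms of both chains; the bin count and plot ranges are our choices here.

import matplotlib.pyplot as plt
import numpy as np

sgld_positions = np.asarray(jnp.stack(sgld_samples))
cyclical_positions = np.asarray(jnp.stack(cyclical_samples))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)
ax1.hist2d(sgld_positions[:, 0], sgld_positions[:, 1], bins=50, range=[[-5, 5], [-5, 5]])
ax1.set_title("SGLD")
ax2.hist2d(cyclical_positions[:, 0], cyclical_positions[:, 1], bins=50, range=[[-5, 5], [-5, 5]])
ax2.set_title("Cyclical SGLD")
plt.show()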


What's next

  • As Adrien Corenflos noted, Scott's rule for KDE treats the total number of points as the independent sample size, so it is not suited for correlated MCMC samples. We should instead pass the bandwidth manually as \(n_{\mathrm{eff}}^{-1/(d+4)}\) with \(d=2\); it should capture more modes (see the sketch after this list);
  • Compute the paper's mode-coverage metric: when the number of samples falling within a radius \(r\) of a mode center is larger than a threshold \(N\), we say the mode is covered;
  • Use it on a "real" problem: CIFAR-100 with a ResNet-18, for instance.
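
A sketch of the manual-bandwidth KDE mentioned in the first item; n_eff is a hypothetical effective sample size (in practice it would come from an ESS estimate, e.g. arviz.ess):

import numpy as np
from scipy.stats import gaussian_kde

samples = np.asarray(jnp.stack(cyclical_samples))  # shape (n, 2)

d = 2
n_eff = 1_000  # hypothetical: replace with an ESS estimate of the chain
bandwidth = n_eff ** (-1 / (d + 4))  # Scott's rule with n_eff in place of n

# gaussian_kde expects shape (d, n); a scalar bw_method is used directly as the factor
kde = gaussian_kde(samples.T, bw_method=bandwidth)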