Waste-free Sequential Monte Carlo

If there are \(N\) particles, we only resample \(M = N/P\) of them. Each resampled particle is moved \(P-1\) times by the MCMC kernel, and every iterate, including the starting point, is kept to form a new sample of size \(N = MP\).

How can we integrate both standard and waste-free SMC in the same algorithm?

The implementation of sequential Monte Carlo starts with a resampling step:

resampling_index = resampling_fn(weights, rng_key, num_resampled)
particles = jax.tree_map(lambda x: x[resampling_index], particles)
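As a minimal sketch of this step, assume multinomial resampling (the notes do not say which scheme `resampling_fn` implements). The function `multinomial_resampling` below is a hypothetical stand-in that follows the same argument order and draws only \(M = N/P\) ancestor indices:

```python
import jax
import jax.numpy as jnp


def multinomial_resampling(weights, rng_key, num_resampled):
    # Hypothetical resampling_fn: draw ancestor indices with probability
    # given by the normalized weights.
    num_particles = weights.shape[0]
    return jax.random.choice(
        rng_key, num_particles, shape=(num_resampled,), p=weights
    )


N, P = 12, 4      # N particles, chains of length P
M = N // P        # only M = N / P particles are resampled

rng_key = jax.random.PRNGKey(0)
weights = jnp.ones(N) / N
# Particles may be an arbitrary pytree whose leaves share the leading axis N.
particles = {"x": jnp.arange(N, dtype=jnp.float32), "y": jnp.zeros((N, 2))}

resampling_index = multinomial_resampling(weights, rng_key, M)
resampled = jax.tree_util.tree_map(lambda x: x[resampling_index], particles)
```

After this step every leaf of `resampled` has leading dimension \(M\) rather than \(N\).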

Waste-free SMC may be recast as a standard SMC sampler that propagates and reweights particles that are Markov chains of length \(P\).

Indeed we have \(M\) chains of length \(P\).

Each particle is a chain \(z \in \mathcal{Z} = \Xi^P\), where \(z[p]\) denotes its \(p\)-th component. The potential functions are:

\begin{equation} G_{t}^{wf}(z) = \frac{1}{P} \sum_{p=1}^{P} G_{t}(z[p]) \end{equation}

The initial distribution is:

\begin{equation} \nu^{wf}(\mathrm{d}z) = \prod_{p=1}^{P} \nu(\mathrm{d}z[p]) \end{equation}
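To make the chain-level potential concrete, here is a small sketch. The potential `G_t` is a made-up example, and the chains are stored as an array of shape \((M, P)\); both are assumptions for illustration only:

```python
import jax
import jax.numpy as jnp


def G_t(x):
    # Hypothetical potential on a single state x.
    return jnp.exp(-0.5 * x**2)


def waste_free_potential(z):
    # z holds the P states of one chain; average G_t over its components.
    return jnp.mean(jax.vmap(G_t)(z))


M, P = 3, 4
chains = jnp.linspace(-1.0, 1.0, M * P).reshape(M, P)
G_wf = jax.vmap(waste_free_potential)(chains)   # one value per chain

# Averaging the chain potentials over the M chains gives the same number
# as averaging G_t over all N = M * P individual states.
flat_mean = jnp.mean(jax.vmap(G_t)(chains.reshape(-1)))
```

Because all chains have the same length, the mean of `G_wf` equals `flat_mean`, so both views give the same average potential.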

The transition kernel is:

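\begin{equation} M_{t}^{wf}(z_{t-1}, \mathrm{d}z_{t}) = \sum_{q=1}^{P} \frac{G_{t-1}(z_{t-1}[q])}{\sum_{p=1}^{P} G_{t-1}(z_{t-1}[p])} \, \delta_{z_{t-1}[q]}(\mathrm{d}z_{t}[1]) \prod_{r=2}^{P} M_{t}(z_{t}[r-1], \mathrm{d}z_{t}[r]) \end{equation}

where \(M_{t}\) denotes the MCMC kernel used at step \(t\). Together with resampling, sampling a new chain amounts to:
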
  1. Choose one chain of length \(P\) with probability \(\propto \sum_p G_{t-1}(z_{t-1}[p])\);
  2. Choose one component \(q\) randomly with probability \(\propto G_{t-1}(z_{t-1}[q])\) to be the starting point of the next chain;
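  3. Repeat \(M\) times, then move each starting point \(P-1\) times with the MCMC kernel to build the new chains of length \(P\).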
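Steps 1 and 2 together pick a single state among all \(MP\) states with probability proportional to its own potential, which is why one can resample directly from the flat set of \(N\) particles. A quick numerical check, using made-up potential values:

```python
import jax.numpy as jnp

M, P = 3, 4
# Hypothetical values of G_{t-1}(z[p]) for M chains of length P.
G = jnp.arange(1.0, M * P + 1.0).reshape(M, P)

# Step 1: probability of choosing each chain, proportional to sum_p G.
chain_prob = G.sum(axis=1) / G.sum()
# Step 2: probability of each component within the chosen chain.
component_prob = G / G.sum(axis=1, keepdims=True)

# Joint probability of (chain, component) under the two-stage scheme.
two_stage = chain_prob[:, None] * component_prob
# Probability of each state when resampling the N flat particles directly.
flat = G / G.sum()
```

The arrays `two_stage` and `flat` coincide, since the chain sums cancel in the product.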

This leads to Algorithm 2 in the paper:

  1. sampledidx = resample(M, weights)
  2. sampledparticles = particles[sampledidx]
  3. traces = jax.lax.scan(jax.vmap(mcmc)(keys, sampledparticles), P)
  4. particles = traces.flatten()
  5. weights = G(particles)
  6. return particles, weights
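The pseudocode above can be turned into a small runnable sketch. Everything problem-specific here is an assumption made for illustration: scalar particles, a hypothetical random-walk Metropolis kernel `rw_kernel` targeting a standard normal, multinomial resampling, and weights that are normalized before being returned. The scan runs \(P-1\) moves and the starting points are prepended, so each chain has length \(P\).

```python
import jax
import jax.numpy as jnp


def logprob(x):
    # Hypothetical target: standard normal log-density up to a constant.
    return -0.5 * x**2


def rw_kernel(key, x):
    # One random-walk Metropolis step on a scalar state.
    key_prop, key_acc = jax.random.split(key)
    proposal = x + 0.5 * jax.random.normal(key_prop)
    log_ratio = logprob(proposal) - logprob(x)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, proposal, x)


def waste_free_step(key, particles, weights, G, P):
    N = particles.shape[0]
    M = N // P
    key_resample, key_scan = jax.random.split(key)
    # Resample only M starting points (multinomial resampling, an assumption).
    idx = jax.random.choice(key_resample, N, shape=(M,), p=weights)
    start = particles[idx]

    def body(states, k):
        new_states = jax.vmap(rw_kernel)(jax.random.split(k, M), states)
        return new_states, new_states

    # Move every starting point P - 1 times and record each iterate.
    _, moved = jax.lax.scan(body, start, jax.random.split(key_scan, P - 1))
    traces = jnp.concatenate([start[None], moved])   # shape (P, M)
    new_particles = traces.reshape(-1)                # shape (N,)
    g = G(new_particles)
    return new_particles, g / g.sum()


N, P = 12, 4
key_init, key_step = jax.random.split(jax.random.PRNGKey(0))
particles = jax.random.normal(key_init, (N,))
weights = jnp.ones(N) / N
new_particles, new_weights = waste_free_step(
    key_step, particles, weights, lambda x: jnp.exp(-0.1 * x**2), P
)
```

The first \(M\) entries of `new_particles` are the resampled starting points themselves; the remaining entries are their successive MCMC iterates.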

The equivalent for Algorithm 1 would be:

  1. sampledidx = resample(N, weights)
  2. sampledparticles = particles[sampledidx]
  3. traces = jax.lax.scan(jax.vmap(mcmc)(keys, sampledparticles), P)
  4. particles = traces[-1, :]
  5. weights = G(particles)
  6. return particles, weights
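The only structural difference between the two variants is which iterates are kept. The sketch below makes this visible with a deterministic stand-in for the MCMC move (`move` simply adds 1), which is an assumption chosen so that the outputs are easy to read:

```python
import jax
import jax.numpy as jnp


def move(x, _):
    # Stand-in for one vectorized MCMC step; deterministic for readability.
    new_x = x + 1.0
    return new_x, new_x


N, P = 6, 3
particles = jnp.arange(N, dtype=jnp.float32)

# Standard SMC: move all N particles P times, keep only the last iterate.
_, traces = jax.lax.scan(move, particles, None, length=P)
standard = traces[-1]                      # shape (N,)

# Waste-free: move M = N / P starting points P - 1 times, keep everything.
start = particles[: N // P]                # stand-in for the M resampled particles
_, moved = jax.lax.scan(move, start, None, length=P - 1)
waste_free = jnp.concatenate([start[None], moved]).reshape(-1)   # shape (N,)
```

Both variants return \(N\) particles, but in this sketch the standard variant applies the move \(N \times P = 18\) times while the waste-free variant applies it \(M \times (P-1) = 4\) times.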


Let us now refactor the original smc code:

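from typing import Callable, Tuple

import jax
import jax.numpy as jnp

# PRNGKey, SMCState, SMCInfo and _normalize are assumed to be defined
# elsewhere in the original module.


def smc(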
    mcmc_kernel_factory: Callable,
    mcmc_state_generator: Callable,
    resampling_fn: Callable,
    num_mcmc_iterations: int,
    is_waste_free: bool = False,
) -> Callable:

    if is_waste_free:
        num_mcmc_iterations = num_mcmc_iterations - 1

    def kernel(
        rng_key: PRNGKey,
        state: SMCState,
        logprob_fn: Callable,
        log_weight_fn: Callable,
    ) -> Tuple[SMCState, SMCInfo]:

        weights, particles = state
        scan_key, resampling_key = jax.random.split(rng_key, 2)

        num_particles = weights.shape[0]
        if is_waste_free:
            sub_num_particles, remainder = divmod(
                num_particles, num_mcmc_iterations + 1
            )
            if remainder > 0:
                raise ValueError(
                    "`num_mcmc_iterations` must be a divisor "
                    f"of `num_particles`, {num_mcmc_iterations + 1} and "
                    f"{num_particles} were given"
                )
        else:
            sub_num_particles = num_particles

        resampling_index = resampling_fn(weights, resampling_key, sub_num_particles)
        particles = jax.tree_map(lambda x: x[resampling_index], particles)

        # First advance the particles using the MCMC kernel
        mcmc_kernel = mcmc_kernel_factory(logprob_fn)

        def mcmc_body_fn(carry, curr_key):
            curr_particles, n_accepted = carry
            keys = jax.random.split(curr_key, sub_num_particles)
            new_particles, mcmc_info = jax.vmap(mcmc_kernel, in_axes=(0, 0))(
                keys, curr_particles
            )
            n_accepted = n_accepted + mcmc_info.is_accepted
            return (new_particles, n_accepted), new_particles

        mcmc_state = jax.vmap(mcmc_state_generator, in_axes=(0, None))(
            particles, logprob_fn
        )

        keys = jax.random.split(scan_key, num_mcmc_iterations)
        (proposed_states, total_accepted), proposed_states_history = jax.lax.scan(
            mcmc_body_fn, (mcmc_state, jnp.zeros((sub_num_particles,))), keys
        )
        acceptance_rate = jnp.mean(total_accepted / num_mcmc_iterations)

        if is_waste_free:
            initial_position, tree_def = jax.tree_flatten(mcmc_state.position)
            chains_history, _ = jax.tree_flatten(proposed_states_history.position)

            # Prepend the starting points to the history, giving shape (P, M, ...)
            position_history = [
                jnp.concatenate([jnp.expand_dims(elem1, 0), elem2])
                for elem1, elem2 in zip(initial_position, chains_history)
            ]
            position_history = jax.tree_unflatten(tree_def, position_history)
            # Flatten the (P, M) leading axes into a single axis of size N
            proposed_particles = jax.tree_map(
                lambda z: jnp.reshape(z, (num_particles,) + z.shape[2:]),
                position_history,
            )
        else:
            proposed_particles = proposed_states.position

        # Compute and normalize the weights of the proposed particles
        log_weights = jax.vmap(log_weight_fn, in_axes=(0,))(proposed_particles)
        weights, log_likelihood_increment = _normalize(log_weights)

        state = SMCState(weights, proposed_particles)
        info = SMCInfo(resampling_index, log_likelihood_increment, acceptance_rate)
        return state, info

    return kernel
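The waste-free branch is the least obvious part of the kernel: it prepends the starting positions to the scan history and merges the chain and iteration axes. The sketch below reproduces that step on a made-up pytree of positions (the leaf names `mu` and `cov` are hypothetical), using a two-argument `tree_map` in place of the explicit flatten and unflatten calls:

```python
import jax
import jax.numpy as jnp

M, P = 2, 3
# Hypothetical pytree of positions: the initial states have leading axis M,
# the scan history has leading axes (P - 1, M).
initial = {"mu": jnp.zeros((M,)), "cov": jnp.zeros((M, 2, 2))}
history = {"mu": jnp.ones((P - 1, M)), "cov": jnp.ones((P - 1, M, 2, 2))}


def stack_and_flatten(first, rest):
    # Prepend the starting points, then merge the (P, M) axes into N = M * P.
    full = jnp.concatenate([jnp.expand_dims(first, 0), rest])   # (P, M, ...)
    return jnp.reshape(full, (M * P,) + full.shape[2:])          # (N, ...)


proposed_particles = jax.tree_util.tree_map(stack_and_flatten, initial, history)
```

With this layout the first \(M\) rows of every leaf are the resampled starting points and the following rows are the successive MCMC iterates.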

First, separate vanilla and waste-free SMC:

