Skip to content

Instantly share code, notes, and snippets.

@aflaxman
Last active January 29, 2022 16:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aflaxman/c2c1b343d1550bda20fc813407e1c81d to your computer and use it in GitHub Desktop.
Save aflaxman/c2c1b343d1550bda20fc813407e1c81d to your computer and use it in GitHub Desktop.
import numpy as np, pandas as pd, matplotlib.pyplot as plt
from collections import namedtuple
from jax import random, lax
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, util, init_to_sample
from numpyro.infer.mcmc import MCMCKernel
# State threaded through the ABC kernel between MCMC iterations:
#   z       -- the current sample (dict of site values); exposed to MCMC via `ABC.sample_field`
#   rng_key -- the PRNG key to be split for the next iteration's proposals
ABCState = namedtuple("ABCState", ["z", "rng_key"])
class ABC(MCMCKernel):
    """Rejection-style Approximate Bayesian Computation kernel for ``numpyro.infer.MCMC``.

    On every call to :meth:`sample` the kernel repeatedly draws a proposal
    from ``model`` (through ``Predictive`` with ``num_samples=1``) until
    ``summary_statistic(data, proposal) <= threshold`` or
    ``max_attempts_per_sample`` draws have been made.  If no draw is accepted,
    the previous state's values are kept.

    :param model: NumPyro model callable to draw proposals from.
    :param data: reference value passed as the first argument of
        ``summary_statistic``.
    :param threshold: largest distance at which a proposal is accepted.
    :param summary_statistic: callable ``(data, proposal) -> distance``.
    :param max_attempts_per_sample: upper bound on proposals drawn per
        MCMC iteration.
    """

    def __init__(self, model, data, threshold, summary_statistic, max_attempts_per_sample):
        self._model = model
        self._data = data
        # One draw per Predictive call; each proposal is a dict of site values.
        self._predictive = util.Predictive(self._model, num_samples=1)
        self._threshold = jnp.array(threshold)
        self._summary_statistic = summary_statistic
        self._max_attempts_per_sample = max_attempts_per_sample

    @property
    def sample_field(self):
        """Name of the ``ABCState`` field that MCMC collects as the sample."""
        return "z"

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        """Build the initial state from a single draw of the model.

        ``num_warmup`` and ``init_params`` are accepted for interface
        compatibility but are not used.  The initial draw is NOT checked
        against the acceptance threshold.
        """
        assert rng_key.ndim == 1, "only non-vectorized, for now"
        proposal = self._predictive(rng_key, *model_args, **model_kwargs)
        return ABCState(proposal, rng_key)

    def sample(self, state, model_args, model_kwargs):
        """Draw proposals until one is accepted or the attempt budget is spent.

        :param state: current ``ABCState``.
        :param model_args: positional arguments forwarded to the model.
        :param model_kwargs: keyword arguments forwarded to the model.
        :return: new ``ABCState`` holding the accepted proposal, or the
            previous values of every site if no proposal was accepted.
        """
        def while_condition_func(val):
            distance, _, _, n = val
            # Keep looping while the last proposal is rejected and budget remains.
            return jnp.logical_and(distance > self._threshold,
                                   n < self._max_attempts_per_sample)

        def while_body_func(val):
            _, rng_key, _, n = val
            rng_key, sample_key = random.split(rng_key)
            proposal = self._predictive(sample_key, *model_args, **model_kwargs)
            # FIXME: need to resample the values of the observed vars here
            distance = self._summary_statistic(self._data, proposal)
            return (distance, rng_key, proposal, n + 1)

        distance, rng_key, proposal, _ = lax.while_loop(
            while_condition_func,
            while_body_func,
            (jnp.inf,        # distance: forces at least one proposal
             state.rng_key,  # rng_key
             state.z,        # proposal: placeholder with the right structure
             0,              # iteration
             ))

        # The loop may exit because the attempt budget ran out, in which case
        # the last proposal was rejected.  Fall back to the previous state for
        # EVERY site (not only a site named 'theta'), and build a fresh dict
        # rather than mutating the loop output in place.
        accepted = distance <= self._threshold
        z = {
            name: jnp.where(accepted, proposal[name], state.z[name])
            for name in proposal
        }
        return ABCState(z, rng_key)
def my_model():
    """Example model: a site ``'theta'`` with a Uniform(-10, 10) prior inside a plate ``'I'`` of size 4."""
    prior = dist.Uniform(-10, 10)
    with numpyro.plate('I', 4):
        numpyro.sample('theta', prior)
def sum_exceeds_threshold(threshold, proposal):
    """Return 0 when the sum of ``proposal['theta']`` is strictly greater than ``threshold``, else ``inf``."""
    total = jnp.sum(proposal['theta'])
    is_above = total > threshold
    return jnp.where(is_above, 0, jnp.inf)
def my_run(model):
    """Run the ABC kernel on ``model`` and plot the trace of summed ``theta``.

    Uses a fixed PRNG seed, a lower bound of -1 on the sum of ``theta`` as the
    ABC "data", and collects 100 samples with no warmup.  The plot shows the
    per-sample sum of ``theta`` together with the lower bound.

    :param model: NumPyro model exposing a sample site named ``'theta'``.
    """
    rng_key = random.PRNGKey(12345)
    sum_lower_bound = jnp.array(-1)
    kernel = ABC(model,
                 data=sum_lower_bound, threshold=1,
                 summary_statistic=sum_exceeds_threshold,
                 max_attempts_per_sample=1_000)
    mcmc = MCMC(kernel, num_warmup=0, num_samples=100, thinning=1)
    mcmc.run(rng_key)
    posterior_samples = mcmc.get_samples()

    # Index 0 on axis 1 presumably drops the length-1 axis that
    # Predictive(num_samples=1) adds to each draw -- TODO confirm.
    plt.plot(posterior_samples['theta'][:, 0, :].sum(axis=1), label='trace')
    plt.ylabel('theta')
    plt.axhline(sum_lower_bound, linestyle='dashed', color='k', label='lower bound')
    # Without this call the `label=` values above are never displayed.
    plt.legend()
# Run the demo on the example model; this executes whenever the file/cell is run.
my_run(my_model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment