def nuts(data, model, seed=None, iter=None, warmup=None, num_chains=None):
    assert type(data) == dict
    assert type(model) == Model
    assert seed is None or type(seed) == int
    iter, warmup, num_chains = apply_default_hmc_args(iter, warmup, num_chains)

    if seed is None:
        seed = np.random.randint(0, 2**32, dtype=np.uint32).astype(np.int32)
    rng = random.PRNGKey(seed)

    kernel = NUTS(model.fn)
    # TODO: We could use a way to avoid requiring users to set
    # `--xla_force_host_platform_device_count` manually when
    # `num_chains` > 1 in order to achieve parallel chains.
    mcmc = MCMC(kernel, warmup, iter, num_chains=num_chains)
    mcmc.run(rng, **data)
    samples = mcmc.get_samples()

    # Here we re-run the model on the samples in order to collect
    # transformed parameters (e.g. `b`, `mu`, etc.). These are made
    # available via the return value of the model.
    transformed_samples = run_model_on_samples_and_data(model.fn, samples, data)

    all_samples = dict(samples, **transformed_samples)
    loc = partial(location, data, samples, transformed_samples, model.fn)
    return Samples(all_samples, partial(get_param, all_samples), loc)
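
# Hypothetical usage sketch for the `nuts` wrapper above (not from the source):
# `model` is assumed to be a `Model` instance built elsewhere in this package,
# and `data` a dict of arrays keyed by the model function's argument names.
#
#     samples = nuts(data, model, seed=0, iter=1000, warmup=500, num_chains=1)
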
def test_binomial_stable(with_logits):
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    warmup_steps, num_samples = 200, 200

    def model(data):
        p = numpyro.sample('p', dist.Beta(1., 1.))
        if with_logits:
            logits = logit(p)
            numpyro.sample('obs', dist.Binomial(data['n'], logits=logits), obs=data['x'])
        else:
            numpyro.sample('obs', dist.Binomial(data['n'], probs=p), obs=data['x'])

    data = {'n': 5000000, 'x': 3849}
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()

    assert_allclose(np.mean(samples['p'], 0), data['x'] / data['n'], rtol=0.05)

    if 'JAX_ENABLE_x64' in os.environ:
        assert samples['p'].dtype == np.float64
def run_inference(model, args, rng):
    kernel = NUTS(model)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains)
    mcmc.run(rng)
    return mcmc.get_samples()
def run_inference(model, args, rng, X, Y, D_H):
    if args.num_chains > 1:
        rng = random.split(rng, args.num_chains)
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains)
    mcmc.run(rng, X, Y, D_H)
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
def benchmark_hmc(args, features, labels):
    step_size = np.sqrt(0.5 / features.shape[0])
    trajectory_length = step_size * args.num_steps
    rng = random.PRNGKey(1)
    start = time.time()
    kernel = NUTS(model, trajectory_length=trajectory_length)
    mcmc = MCMC(kernel, 0, args.num_samples)
    mcmc.run(rng, features, labels)
    print('\nMCMC elapsed time:', time.time() - start)
def run_inference(model, args, rng, X, Y):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains)
    mcmc.run(rng, X, Y)
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
def test_improper_prior():
    true_mean, true_std = 1., 2.
    num_warmup, num_samples = 1000, 8000

    def model(data):
        mean = numpyro.param('mean', 0.)
        std = numpyro.param('std', 1., constraint=constraints.positive)
        return numpyro.sample('obs', dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup, num_samples)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['mean']), true_mean, rtol=0.05)
    assert_allclose(np.mean(samples['std']), true_std, rtol=0.05)
def test_improper_normal():
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.param('loc', 0., constraint=constraints.interval(0., alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['loc'], 0), true_coef, atol=0.05)
def test_predictive():
    model, data, true_probs = beta_bernoulli()
    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()

    predictive_samples = predictive(random.PRNGKey(1), model, samples)
    assert predictive_samples.keys() == {"obs"}

    predictive_samples = predictive(random.PRNGKey(1), model, samples,
                                    return_sites=["beta", "obs"])
    # check shapes
    assert predictive_samples["beta"].shape == (100, 5)
    assert predictive_samples["obs"].shape == (100, 1000, 5)
    # check sample mean
    assert_allclose(predictive_samples["obs"].reshape([-1, 5]).mean(0), true_probs, rtol=0.1)
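
# A plausible sketch (an assumption, not the actual fixture) of the `beta_bernoulli()`
# helper used above, reconstructed from the shape assertions: a 5-dimensional Beta
# prior over probabilities and 1000 Bernoulli observations per dimension.
def beta_bernoulli_sketch(N=1000):
    true_probs = np.array([0.2, 0.3, 0.5, 0.8, 0.9])  # illustrative values
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(0), (N,))

    def model(data=None):
        beta = numpyro.sample('beta', dist.Beta(np.ones(5), np.ones(5)))
        # With `data=None` (predictive mode) draw N replicates per dimension,
        # which yields the (num_samples, N, 5) shape checked in the test.
        numpyro.sample('obs', dist.Bernoulli(beta), obs=data,
                       sample_shape=() if data is not None else (N,))

    return model, data, true_probs
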
def test_uniform_normal():
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data, collect_warmup=True,
             collect_fields=('z', 'num_steps', 'adapt_state.step_size'))
    samples = mcmc.get_samples()
    assert len(samples[0]['loc']) == num_warmup + num_samples
    assert_allclose(np.mean(samples[0]['loc'], 0), true_coef, atol=0.05)
def test_prior_with_sample_shape():
    data = {
        "J": 8,
        "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
        "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
    }

    def schools_model():
        mu = numpyro.sample('mu', dist.Normal(0, 5))
        tau = numpyro.sample('tau', dist.HalfCauchy(5))
        theta = numpyro.sample('theta', dist.Normal(mu, tau), sample_shape=(data['J'],))
        numpyro.sample('obs', dist.Normal(theta, data['sigma']), obs=data['y'])

    num_samples = 500
    mcmc = MCMC(NUTS(schools_model), num_warmup=500, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0))
    assert mcmc.get_samples()['theta'].shape == (num_samples, data['J'])
def main(args):
    jax_config.update('jax_platform_name', args.device)

    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words, unsupervised_words) = simulate_data(
        random.PRNGKey(1),
        num_categories=args.num_categories,
        num_words=args.num_words,
        num_supervised_data=args.num_supervised,
        num_unsupervised_data=args.num_unsupervised,
    )

    print('Starting inference...')
    rng = random.PRNGKey(2)
    start = time.time()
    kernel = NUTS(semi_supervised_hmm)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples)
    mcmc.run(rng, transition_prior, emission_prior, supervised_categories,
             supervised_words, unsupervised_words)
    samples = mcmc.get_samples()
    print('\nMCMC elapsed time:', time.time() - start)

    print_results(samples, transition_prob, emission_prob)
def test_correlated_mvn():
    # This requires dense mass matrix estimation.
    D = 5
    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.
    a = np.tril(0.5 * np.fliplr(np.eye(D)) +
                0.1 * np.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = np.dot(a, a.T)
    true_prec = np.linalg.inv(true_cov)

    def potential_fn(z):
        return 0.5 * np.dot(z.T, np.dot(true_prec, z))

    init_params = np.zeros(D)
    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples), true_mean, atol=0.02)
    assert onp.sum(onp.abs(onp.cov(samples.T) - true_cov)) / D**2 < 0.02
def test_chain(use_init_params, chain_method):
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup, num_samples, num_chains=num_chains)
    mcmc.chain_method = chain_method
    init_params = None if not use_init_params else \
        {'coefs': np.tile(np.ones(dim), num_chains).reshape(num_chains, dim)}
    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)
    samples = mcmc.get_samples()
    assert samples['coefs'].shape[0] == num_chains * num_samples
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.21)
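
# Sketch: actually running `num_chains` chains in parallel on CPU needs multiple XLA
# host devices (cf. the TODO about `--xla_force_host_platform_device_count` in the
# `nuts` wrapper above). The flag has to be set before JAX initializes its backends:
import os
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + \
    ' --xla_force_host_platform_device_count=2'
# Any `import jax` / `import numpyro` must come after this line for the flag to take
# effect; with only one visible device, chains are typically drawn sequentially.
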
def run_inference(dept, male, applications, admit, rng, args):
    kernel = NUTS(glmm)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, args.num_chains)
    mcmc.run(rng, dept, male, applications, admit)
    return mcmc.get_samples()
def main(args): jax_config.update('jax_platform_name', args.device) print("Start vanilla HMC...") nuts_kernel = NUTS(potential_fn=dual_moon_pe) mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples) mcmc.run(random.PRNGKey(11), init_params=np.array([2., 0.])) vanilla_samples = mcmc.get_samples() adam = optim.Adam(0.001) rng_init, rng_train = random.split(random.PRNGKey(1), 2) guide = AutoIAFNormal(dual_moon_model, hidden_dims=[args.num_hidden], skip_connections=True) svi = SVI(dual_moon_model, guide, elbo, adam) svi_state = svi.init(rng_init) print("Start training guide...") last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state, np.zeros(args.num_iters)) params = svi.get_params(last_state) print("Finish training guide. Extract samples...") guide_samples = guide.sample_posterior(random.PRNGKey(0), params, sample_shape=(args.num_samples,)) transform = guide.get_transform(params) unpack_fn = guide.unpack_latent _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), dual_moon_model) transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn) transformed_constrain_fn = lambda x: constrain_fn(unpack_fn(transform(x))) # noqa: E731 init_params = np.zeros(guide.latent_size) print("\nStart NeuTra HMC...") # TODO: exlore why neutra samples are not good # Issue: https://github.com/pyro-ppl/numpyro/issues/256 nuts_kernel = NUTS(potential_fn=transformed_potential_fn) mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples) mcmc.run(random.PRNGKey(10), init_params=init_params) zs = mcmc.get_samples() print("Transform samples into unwarped space...") samples = vmap(transformed_constrain_fn)(zs) summary(tree_map(lambda x: x[None, ...], samples)) # make plots # IAF guide samples (for plotting) iaf_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(0), (1000,)) iaf_trans_samples = vmap(transformed_constrain_fn)(iaf_base_samples)['x'] x1 = np.linspace(-3, 3, 100) x2 = np.linspace(-3, 3, 100) X1, X2 = np.meshgrid(x1, x2) P = np.clip(np.exp(-dual_moon_pe(np.stack([X1, X2], axis=-1))), a_min=0.) fig = plt.figure(figsize=(12, 16), constrained_layout=True) gs = GridSpec(3, 2, figure=fig) ax1 = fig.add_subplot(gs[0, 0]) ax2 = fig.add_subplot(gs[0, 1]) ax3 = fig.add_subplot(gs[1, 0]) ax4 = fig.add_subplot(gs[1, 1]) ax5 = fig.add_subplot(gs[2, 0]) ax6 = fig.add_subplot(gs[2, 1]) ax1.plot(np.log(losses[1000:])) ax1.set_title('Autoguide training log loss (after 1000 steps)') ax2.contourf(X1, X2, P, cmap='OrRd') sns.kdeplot(guide_samples['x'][:, 0].copy(), guide_samples['x'][:, 1].copy(), n_levels=30, ax=ax2) ax2.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1', title='Posterior using AutoIAFNormal guide') sns.scatterplot(iaf_base_samples[:, 0], iaf_base_samples[:, 1], ax=ax3, hue=iaf_trans_samples[:, 0] < 0.) 
    ax3.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='AutoIAFNormal base samples (True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0].copy(), vanilla_samples[:, 1].copy(),
                n_levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using vanilla HMC sampler')

    sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples['x'][:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5], xlabel='x0', ylabel='x1',
            title='Samples from the warped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples['x'][:, 0].copy(), samples['x'][:, 1].copy(), n_levels=30, ax=ax6)
    ax6.plot(samples['x'][-50:, 0], samples['x'][-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()
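
# A minimal driver sketch for the `main(args)` function above; the argument names
# mirror the attributes that `main` reads (num_samples, num_warmup, num_hidden,
# num_iters, device), while the default values here are illustrative assumptions.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='NeuTra HMC example (sketch)')
    parser.add_argument('--num-samples', default=4000, type=int)
    parser.add_argument('--num-warmup', default=1000, type=int)
    parser.add_argument('--num-hidden', default=10, type=int)
    parser.add_argument('--num-iters', default=10000, type=int)
    parser.add_argument('--device', default='cpu', type=str, help='"cpu" or "gpu"')
    main(parser.parse_args())
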
def test_change_point():
    # Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    warmup_steps, num_samples = 500, 3000

    def model(data):
        alpha = 1 / np.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        lambda12 = np.where(np.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = np.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57, 11, 19, 29, 6, 19, 12,
        22, 12, 18, 72, 32, 9, 7, 13, 19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15,
        7, 2, 15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 18, 20, 12, 35, 17, 23,
        17, 4, 2, 31, 30, 13, 27, 0, 39, 37, 5, 14, 13, 22,
    ])

    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(4), count_data)
    samples = mcmc.get_samples()
    tau_posterior = (samples['tau'] * len(count_data)).astype(np.int32)
    tau_values, counts = onp.unique(tau_posterior, return_counts=True)
    mode_ind = np.argmax(counts)
    mode = tau_values[mode_ind]
    assert mode == 44

    if 'JAX_ENABLE_x64' in os.environ:
        assert samples['lambda1'].dtype == np.float64
        assert samples['lambda2'].dtype == np.float64
        assert samples['tau'].dtype == np.float64
import argparse
import os

import numpy as onp

from jax.config import config as jax_config
import jax.numpy as np
import jax.random as random
from jax.scipy.special import logsumexp

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.examples.datasets import BASEBALL, load_dataset
from numpyro.mcmc import MCMC, NUTS

"""
Original example from Pyro:
https://github.com/pyro-ppl/pyro/blob/dev/examples/baseball.py

Example has been adapted from [1]. It demonstrates how to do Bayesian inference
using NUTS (or HMC) in NumPyro, and the use of some common inference utilities.

As in the Stan tutorial, this uses the small baseball dataset of Efron and
Morris [2] to estimate players' batting average, which is the fraction of times
a player got a base hit out of the number of times they went up at bat.

The dataset separates the initial 45 at-bats statistics from the remaining
season. We use the hits data from the initial 45 at-bats to estimate the
batting average for each player. We then use the remaining season's data to
validate the predictions from our models.