def nuts(data, model, seed=None, iter=None, warmup=None, num_chains=None):
    assert type(data) == dict
    assert type(model) == Model
    assert seed is None or type(seed) == int
    iter, warmup, num_chains = apply_default_hmc_args(iter, warmup, num_chains)
    if seed is None:
        seed = np.random.randint(0, 2**32, dtype=np.uint32).astype(np.int32)
    rng = random.PRNGKey(seed)
    kernel = NUTS(model.fn)
    # TODO: We could use a way to avoid requiring users to set
    # `--xla_force_host_platform_device_count` manually when
    # `num_chains` > 1 in order to achieve parallel chains.
    mcmc = MCMC(kernel, warmup, iter, num_chains=num_chains)
    mcmc.run(rng, **data)
    samples = mcmc.get_samples()
    # Re-run the model on the samples in order to collect transformed
    # parameters (e.g. `b`, `mu`, etc.). These are made available via
    # the return value of the model.
    transformed_samples = run_model_on_samples_and_data(model.fn, samples, data)
    all_samples = dict(samples, **transformed_samples)
    loc = partial(location, data, samples, transformed_samples, model.fn)
    return Samples(all_samples, partial(get_param, all_samples), loc)
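# Example usage of the `nuts` wrapper above (a minimal sketch, not part of the
# original code; `Model` is whatever wrapper exposes the NumPyro model function
# as `model.fn`, and the data dict keys below are hypothetical):
#
#   model = Model(...)
#   data = {'x': x_obs, 'y': y_obs}   # forwarded as keyword arguments to `mcmc.run`
#   posterior = nuts(data, model, seed=0, iter=2000, warmup=1000, num_chains=1)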
def test_binomial_stable(with_logits):
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    warmup_steps, num_samples = 200, 200

    def model(data):
        p = numpyro.sample('p', dist.Beta(1., 1.))
        if with_logits:
            logits = logit(p)
            numpyro.sample('obs', dist.Binomial(data['n'], logits=logits), obs=data['x'])
        else:
            numpyro.sample('obs', dist.Binomial(data['n'], probs=p), obs=data['x'])

    data = {'n': 5000000, 'x': 3849}
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['p'], 0), data['x'] / data['n'], rtol=0.05)
    if 'JAX_ENABLE_x64' in os.environ:
        assert samples['p'].dtype == np.float64
def run_inference(model, args, rng):
    kernel = NUTS(model)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains)
    mcmc.run(rng)
    return mcmc.get_samples()
def run_inference(model, args, rng, X, Y, D_H):
    if args.num_chains > 1:
        rng = random.split(rng, args.num_chains)
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains)
    mcmc.run(rng, X, Y, D_H)
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
def benchmark_hmc(args, features, labels):
    step_size = np.sqrt(0.5 / features.shape[0])
    trajectory_length = step_size * args.num_steps
    rng = random.PRNGKey(1)
    start = time.time()
    kernel = NUTS(model, trajectory_length=trajectory_length)
    mcmc = MCMC(kernel, 0, args.num_samples)
    mcmc.run(rng, features, labels)
    print('\nMCMC elapsed time:', time.time() - start)
def run_inference(model, args, rng, X, Y):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains)
    mcmc.run(rng, X, Y)
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
def get_samples(rng, data, step_size, trajectory_length, target_accept_prob):
    kernel = kernel_cls(model, step_size=step_size, trajectory_length=trajectory_length,
                        target_accept_prob=target_accept_prob)
    mcmc = MCMC(kernel, warmup_steps, num_samples, num_chains=2, chain_method=chain_method)
    mcmc.run(rng, data)
    return mcmc.get_samples()
def test_improper_normal():
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.param('loc', 0., constraint=constraints.interval(0., alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['loc'], 0), true_coef, atol=0.05)
def test_improper_prior():
    true_mean, true_std = 1., 2.
    num_warmup, num_samples = 1000, 8000

    def model(data):
        mean = numpyro.param('mean', 0.)
        std = numpyro.param('std', 1., constraint=constraints.positive)
        return numpyro.sample('obs', dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup, num_samples)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['mean']), true_mean, rtol=0.05)
    assert_allclose(np.mean(samples['std']), true_std, rtol=0.05)
def test_predictive():
    model, data, true_probs = beta_bernoulli()
    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    predictive_samples = predictive(random.PRNGKey(1), model, samples)
    assert predictive_samples.keys() == {"obs"}

    predictive_samples = predictive(random.PRNGKey(1), model, samples,
                                    return_sites=["beta", "obs"])
    # check shapes
    assert predictive_samples["beta"].shape == (100, 5)
    assert predictive_samples["obs"].shape == (100, 1000, 5)
    # check sample mean
    assert_allclose(predictive_samples["obs"].reshape([-1, 5]).mean(0), true_probs, rtol=0.1)
def test_uniform_normal():
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data, collect_warmup=True,
             collect_fields=('z', 'num_steps', 'adapt_state.step_size'))
    samples = mcmc.get_samples()
    assert len(samples[0]['loc']) == num_warmup + num_samples
    assert_allclose(np.mean(samples[0]['loc'], 0), true_coef, atol=0.05)
def test_dirichlet_categorical(kernel_cls, dense_mass):
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    kernel = kernel_cls(model, trajectory_length=1., dense_mass=dense_mass)
    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.02)
    if 'JAX_ENABLE_x64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
def test_unnormalized_normal(kernel_cls, dense_mass):
    true_mean, true_std = 1., 2.
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * np.sum(((z - true_mean) / true_std) ** 2)

    init_params = np.array(0.)
    kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=9, dense_mass=dense_mass)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    hmc_states = mcmc.get_samples()
    assert_allclose(np.mean(hmc_states), true_mean, rtol=0.05)
    assert_allclose(np.std(hmc_states), true_std, rtol=0.05)
    if 'JAX_ENABLE_x64' in os.environ:
        assert hmc_states.dtype == np.float64
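# Added aside (not part of the original test): the potential above,
# 0.5 * sum(((z - true_mean) / true_std) ** 2), is the negative log-density of
# Normal(true_mean, true_std) up to an additive constant, so sampling from the
# unnormalized `potential_fn` (with an explicit `init_params`) should recover
# mean ~= 1.0 and std ~= 2.0, which is what the asserts check. For example,
# with z = 3.: 0.5 * ((3. - 1.) / 2.) ** 2 == 0.5, i.e. -log N(3. | 1., 2.)
# minus its normalizing constant 0.5 * log(2 * pi * 4).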
def test_prior_with_sample_shape():
    data = {
        "J": 8,
        "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
        "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
    }

    def schools_model():
        mu = numpyro.sample('mu', dist.Normal(0, 5))
        tau = numpyro.sample('tau', dist.HalfCauchy(5))
        theta = numpyro.sample('theta', dist.Normal(mu, tau), sample_shape=(data['J'],))
        numpyro.sample('obs', dist.Normal(theta, data['sigma']), obs=data['y'])

    num_samples = 500
    mcmc = MCMC(NUTS(schools_model), num_warmup=500, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0))
    assert mcmc.get_samples()['theta'].shape == (num_samples, data['J'])
def test_logistic_regression(kernel_cls):
    N, dim = 3000, 3
    warmup_steps, num_samples = 1000, 8000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = kernel_cls(model=model, trajectory_length=10)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.21)
    if 'JAX_ENABLE_x64' in os.environ:
        assert samples['coefs'].dtype == np.float64
def test_correlated_mvn():
    # This requires dense mass matrix estimation.
    D = 5
    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.
    a = np.tril(0.5 * np.fliplr(np.eye(D)) +
                0.1 * np.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = np.dot(a, a.T)
    true_prec = np.linalg.inv(true_cov)

    def potential_fn(z):
        return 0.5 * np.dot(z.T, np.dot(true_prec, z))

    init_params = np.zeros(D)
    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples), true_mean, atol=0.02)
    assert onp.sum(onp.abs(onp.cov(samples.T) - true_cov)) / D**2 < 0.02
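# Added aside (not part of the original test): the quadratic potential above,
# 0.5 * z.T @ true_prec @ z, is the negative log-density of a zero-mean
# multivariate normal with covariance `true_cov`, up to a constant. Because the
# off-diagonal structure of `true_cov` makes the coordinates strongly
# correlated, a diagonal mass matrix adapts poorly here, which is why the
# kernel is constructed with `dense_mass=True`.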
def test_diverging(kernel_cls, adapt_step_size):
    data = random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        loc = numpyro.sample('loc', dist.Normal(0., 1.))
        numpyro.sample('obs', dist.Normal(loc, 1), obs=data)

    kernel = kernel_cls(model, step_size=10., adapt_step_size=adapt_step_size,
                        adapt_mass_matrix=False)
    num_warmup = num_samples = 1000
    mcmc = MCMC(kernel, num_warmup, num_samples)
    mcmc.run(random.PRNGKey(1), data, collect_fields=('z', 'diverging'), collect_warmup=True)
    num_divergences = mcmc.get_samples()[1].sum()
    if adapt_step_size:
        assert num_divergences <= num_warmup
    else:
        assert_allclose(num_divergences, num_warmup + num_samples)
def main(args):
    jax_config.update('jax_platform_name', args.device)

    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words, unsupervised_words) = simulate_data(
        random.PRNGKey(1),
        num_categories=args.num_categories,
        num_words=args.num_words,
        num_supervised_data=args.num_supervised,
        num_unsupervised_data=args.num_unsupervised,
    )

    print('Starting inference...')
    rng = random.PRNGKey(2)
    start = time.time()
    kernel = NUTS(semi_supervised_hmm)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples)
    mcmc.run(rng, transition_prior, emission_prior, supervised_categories,
             supervised_words, unsupervised_words)
    samples = mcmc.get_samples()
    print('\nMCMC elapsed time:', time.time() - start)
    print_results(samples, transition_prob, emission_prob)
def test_beta_bernoulli(kernel_cls):
    warmup_steps, num_samples = 500, 20000

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    kernel = kernel_cls(model=model, trajectory_length=1.)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.05)
    if 'JAX_ENABLE_x64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
def test_chain(use_init_params, chain_method):
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup, num_samples, num_chains=num_chains)
    mcmc.chain_method = chain_method
    init_params = None if not use_init_params else \
        {'coefs': np.tile(np.ones(dim), num_chains).reshape(num_chains, dim)}
    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)
    samples = mcmc.get_samples()
    assert samples['coefs'].shape[0] == num_chains * num_samples
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.21)
def run_inference(dept, male, applications, admit, rng, args):
    kernel = NUTS(glmm)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, args.num_chains)
    mcmc.run(rng, dept, male, applications, admit)
    return mcmc.get_samples()
def main(args): jax_config.update('jax_platform_name', args.device) print("Start vanilla HMC...") nuts_kernel = NUTS(potential_fn=dual_moon_pe) mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples) mcmc.run(random.PRNGKey(11), init_params=np.array([2., 0.])) vanilla_samples = mcmc.get_samples() adam = optim.Adam(0.001) rng_init, rng_train = random.split(random.PRNGKey(1), 2) guide = AutoIAFNormal(dual_moon_model, hidden_dims=[args.num_hidden], skip_connections=True) svi = SVI(dual_moon_model, guide, elbo, adam) svi_state = svi.init(rng_init) print("Start training guide...") last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state, np.zeros(args.num_iters)) params = svi.get_params(last_state) print("Finish training guide. Extract samples...") guide_samples = guide.sample_posterior(random.PRNGKey(0), params, sample_shape=(args.num_samples,)) transform = guide.get_transform(params) unpack_fn = guide.unpack_latent _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), dual_moon_model) transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn) transformed_constrain_fn = lambda x: constrain_fn(unpack_fn(transform(x))) # noqa: E731 init_params = np.zeros(guide.latent_size) print("\nStart NeuTra HMC...") # TODO: exlore why neutra samples are not good # Issue: https://github.com/pyro-ppl/numpyro/issues/256 nuts_kernel = NUTS(potential_fn=transformed_potential_fn) mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples) mcmc.run(random.PRNGKey(10), init_params=init_params) zs = mcmc.get_samples() print("Transform samples into unwarped space...") samples = vmap(transformed_constrain_fn)(zs) summary(tree_map(lambda x: x[None, ...], samples)) # make plots # IAF guide samples (for plotting) iaf_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(0), (1000,)) iaf_trans_samples = vmap(transformed_constrain_fn)(iaf_base_samples)['x'] x1 = np.linspace(-3, 3, 100) x2 = np.linspace(-3, 3, 100) X1, X2 = np.meshgrid(x1, x2) P = np.clip(np.exp(-dual_moon_pe(np.stack([X1, X2], axis=-1))), a_min=0.) fig = plt.figure(figsize=(12, 16), constrained_layout=True) gs = GridSpec(3, 2, figure=fig) ax1 = fig.add_subplot(gs[0, 0]) ax2 = fig.add_subplot(gs[0, 1]) ax3 = fig.add_subplot(gs[1, 0]) ax4 = fig.add_subplot(gs[1, 1]) ax5 = fig.add_subplot(gs[2, 0]) ax6 = fig.add_subplot(gs[2, 1]) ax1.plot(np.log(losses[1000:])) ax1.set_title('Autoguide training log loss (after 1000 steps)') ax2.contourf(X1, X2, P, cmap='OrRd') sns.kdeplot(guide_samples['x'][:, 0].copy(), guide_samples['x'][:, 1].copy(), n_levels=30, ax=ax2) ax2.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1', title='Posterior using AutoIAFNormal guide') sns.scatterplot(iaf_base_samples[:, 0], iaf_base_samples[:, 1], ax=ax3, hue=iaf_trans_samples[:, 0] < 0.) 
ax3.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1', title='AutoIAFNormal base samples (True=left moon; False=right moon)') ax4.contourf(X1, X2, P, cmap='OrRd') sns.kdeplot(vanilla_samples[:, 0].copy(), vanilla_samples[:, 1].copy(), n_levels=30, ax=ax4) ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5) ax4.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1', title='Posterior using vanilla HMC sampler') sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples['x'][:, 0] < 0., s=30, alpha=0.5, edgecolor="none") ax5.set(xlim=[-5, 5], ylim=[-5, 5], xlabel='x0', ylabel='x1', title='Samples from the warped posterior - p(z)') ax6.contourf(X1, X2, P, cmap='OrRd') sns.kdeplot(samples['x'][:, 0].copy(), samples['x'][:, 1].copy(), n_levels=30, ax=ax6) ax6.plot(samples['x'][-50:, 0], samples['x'][-50:, 1], 'bo-', alpha=0.2) ax6.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1', title='Posterior using NeuTra HMC sampler') plt.savefig("neutra.pdf") plt.close()
def test_change_point():
    # Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    warmup_steps, num_samples = 500, 3000

    def model(data):
        alpha = 1 / np.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        lambda12 = np.where(np.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = np.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57, 11, 19, 29, 6, 19, 12,
        22, 12, 18, 72, 32, 9, 7, 13, 19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15,
        7, 2, 15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 18, 20, 12, 35, 17, 23,
        17, 4, 2, 31, 30, 13, 27, 0, 39, 37, 5, 14, 13, 22,
    ])

    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(4), count_data)
    samples = mcmc.get_samples()
    tau_posterior = (samples['tau'] * len(count_data)).astype(np.int32)
    tau_values, counts = onp.unique(tau_posterior, return_counts=True)
    mode_ind = np.argmax(counts)
    mode = tau_values[mode_ind]
    assert mode == 44

    if 'JAX_ENABLE_x64' in os.environ:
        assert samples['lambda1'].dtype == np.float64
        assert samples['lambda2'].dtype == np.float64
        assert samples['tau'].dtype == np.float64