def test_chain():
    """Run two chains on a Bernoulli-logit regression and check the pooled result.

    Synthetic features and labels are generated from fixed PRNG keys, the
    model places a ``Normal(0, 1)`` prior on each coefficient, and ``mcmc``
    is invoked with ``num_chains=2``.  The test asserts that the leading
    dimension of the returned ``'coefs'`` samples equals
    ``2 * num_samples`` and that their mean is within ``atol=0.21`` of the
    coefficients used to generate the data.
    """
    n_obs, n_features = 3000, 3
    n_warmup, n_draws = 5000, 5000
    n_chains = 2

    # Synthetic data: labels are drawn from the logits implied by the
    # ground-truth coefficients 1., 2., ..., n_features.
    features = random.normal(random.PRNGKey(0), (n_obs, n_features))
    expected_coefs = np.arange(1., n_features + 1.)
    generating_logits = np.sum(expected_coefs * features, axis=-1)
    observed = dist.Bernoulli(logits=generating_logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample(
            'coefs', dist.Normal(np.zeros(n_features), np.ones(n_features)))
        likelihood = dist.Bernoulli(logits=np.sum(coefs * features, axis=-1))
        return numpyro.sample('obs', likelihood, obs=labels)

    # One PRNG key per chain is handed to initialize_model.
    chain_keys = random.split(random.PRNGKey(2), n_chains)
    init_params, potential_fn, constrain_fn = initialize_model(
        chain_keys, model, observed)
    samples = mcmc(n_warmup, n_draws, init_params, num_chains=n_chains,
                   potential_fn=potential_fn, constrain_fn=constrain_fn)

    coef_draws = samples['coefs']
    assert coef_draws.shape[0] == n_chains * n_draws
    assert_allclose(np.mean(coef_draws, 0), expected_coefs, atol=0.21)
def run_inference(model, at_bats, hits, rng, args):
    """Initialize ``model`` on ``(at_bats, hits)`` and run ``mcmc`` with ``sampler='hmc'``.

    When ``args.num_chains`` is greater than one, ``rng`` is first split
    into ``args.num_chains`` keys; otherwise it is passed through as-is.

    Returns whatever ``mcmc`` returns.
    """
    keys = random.split(rng, args.num_chains) if args.num_chains > 1 else rng
    init_params, potential_fn, constrain_fn = initialize_model(
        keys, model, at_bats, hits)
    return mcmc(
        args.num_warmup,
        args.num_samples,
        init_params,
        num_chains=args.num_chains,
        sampler='hmc',
        potential_fn=potential_fn,
        constrain_fn=constrain_fn,
    )
def test_logistic_regression(algo):
    """Check coefficient recovery on a Bernoulli-logit regression.

    ``algo`` is forwarded unchanged to ``mcmc``.  The mean of the
    ``'coefs'`` samples must be within ``atol=0.21`` of the coefficients
    used to generate the labels.  When the ``JAX_ENABLE_x64`` key is present
    in ``os.environ`` the samples are additionally required to be float64.
    """
    n_obs, n_features = 3000, 3
    n_warmup, n_draws = 1000, 8000

    # Synthetic data generated from fixed PRNG keys.
    features = random.normal(random.PRNGKey(0), (n_obs, n_features))
    expected_coefs = np.arange(1., n_features + 1.)
    generating_logits = np.sum(expected_coefs * features, axis=-1)
    observed = dist.Bernoulli(logits=generating_logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample(
            'coefs', dist.Normal(np.zeros(n_features), np.ones(n_features)))
        likelihood = dist.Bernoulli(logits=np.sum(coefs * features, axis=-1))
        return numpyro.sample('obs', likelihood, obs=labels)

    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, observed)
    samples = mcmc(n_warmup, n_draws, init_params,
                   sampler='hmc', algo=algo,
                   potential_fn=potential_fn, trajectory_length=10,
                   constrain_fn=constrain_fn)

    coef_draws = samples['coefs']
    assert_allclose(np.mean(coef_draws, 0), expected_coefs, atol=0.21)

    if 'JAX_ENABLE_x64' in os.environ:
        assert coef_draws.dtype == np.float64
def main(args):
    """Simulate HMM data, run ``mcmc`` on ``semi_supervised_hmm`` and print results.

    Reads ``device``, ``num_categories``, ``num_words``, ``num_supervised``,
    ``num_unsupervised``, ``num_chains``, ``num_warmup`` and ``num_samples``
    from ``args``.  The wall-clock time of the ``mcmc`` call is printed
    before the results.
    """
    jax_config.update('jax_platform_name', args.device)

    print('Simulating data...')
    simulated = simulate_data(
        random.PRNGKey(1),
        num_categories=args.num_categories,
        num_words=args.num_words,
        num_supervised_data=args.num_supervised,
        num_unsupervised_data=args.num_unsupervised,
    )
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words, unsupervised_words) = simulated

    print('Starting inference...')
    rng = random.PRNGKey(2)
    if args.num_chains > 1:
        # One key per chain.
        rng = random.split(rng, args.num_chains)

    # Positional arguments forwarded to the model by initialize_model.
    model_args = (transition_prior, emission_prior, supervised_categories,
                  supervised_words, unsupervised_words)
    init_params, potential_fn, constrain_fn = initialize_model(
        rng, semi_supervised_hmm, *model_args)

    started_at = time.time()
    samples = mcmc(args.num_warmup, args.num_samples, init_params,
                   num_chains=args.num_chains,
                   potential_fn=potential_fn,
                   constrain_fn=constrain_fn,
                   progbar=True)
    print('\nMCMC elapsed time:', time.time() - started_at)

    print_results(samples, transition_prob, emission_prob)
def test_dirichlet_categorical(algo, dense_mass):
    """Check recovery of categorical probabilities under a Dirichlet prior.

    ``algo`` and ``dense_mass`` are forwarded unchanged to ``mcmc``.  The
    mean of the ``'p_latent'`` samples must be within ``atol=0.02`` of the
    probabilities used to generate the data.  When the ``JAX_ENABLE_x64``
    key is present in ``os.environ`` the samples are additionally required
    to be float64.
    """
    n_warmup, n_draws = 100, 20000

    def model(data):
        prior = dist.Dirichlet(np.array([1.0, 1.0, 1.0]))
        p_latent = numpyro.sample('p_latent', prior)
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    expected_probs = np.array([0.1, 0.6, 0.3])
    observations = dist.Categorical(expected_probs).sample(
        random.PRNGKey(1), (2000, ))

    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, observations)
    samples = mcmc(n_warmup, n_draws, init_params,
                   constrain_fn=constrain_fn,
                   progbar=False, print_summary=False,
                   potential_fn=potential_fn,
                   algo=algo, trajectory_length=1.,
                   dense_mass=dense_mass)

    p_draws = samples['p_latent']
    assert_allclose(np.mean(p_draws, 0), expected_probs, atol=0.02)

    if 'JAX_ENABLE_x64' in os.environ:
        assert p_draws.dtype == np.float64
def run_inference(model, args, rng):
    """Initialize ``model`` (no extra arguments) and run ``mcmc`` on it.

    ``rng`` is split into ``args.num_chains`` keys only when more than one
    chain is requested.  Returns whatever ``mcmc`` returns.
    """
    n_chains = args.num_chains
    if n_chains > 1:
        rng = random.split(rng, n_chains)

    init_params, potential_fn, constrain_fn = initialize_model(rng, model)
    return mcmc(args.num_warmup, args.num_samples, init_params,
                num_chains=args.num_chains,
                potential_fn=potential_fn,
                constrain_fn=constrain_fn)
def benchmark_hmc(args, features, labels):
    """Time a single ``mcmc`` call on ``model`` and print the elapsed seconds.

    The trajectory length passed to ``mcmc`` is
    ``sqrt(0.5 / features.shape[0]) * args.num_steps``.  The result of
    ``mcmc`` is discarded; only the wall-clock duration is reported.
    """
    # NOTE(review): this step size is used only to derive trajectory_length;
    # it is not itself passed to mcmc.  Confirm that is intended.
    base_step = np.sqrt(0.5 / features.shape[0])
    traj_len = base_step * args.num_steps

    key = random.PRNGKey(1)
    if args.num_chains > 1:
        key = random.split(key, args.num_chains)

    # The third value returned by initialize_model is not needed here.
    init_params, potential_fn, _ = initialize_model(key, model, features, labels)

    t0 = time.time()
    # First positional argument is 0 (presumably the warmup count - confirm
    # against the mcmc signature).
    mcmc(0, args.num_samples, init_params,
         num_chains=args.num_chains,
         potential_fn=potential_fn,
         trajectory_length=traj_len)
    print('\nMCMC elapsed time:', time.time() - t0)
def run_inference(dept, male, applications, admit, rng, args):
    """Initialize ``glmm`` on the given data and run ``mcmc`` on it.

    ``rng`` is split into ``args.num_chains`` keys only when more than one
    chain is requested.  Returns whatever ``mcmc`` returns.
    """
    keys = rng
    if args.num_chains > 1:
        keys = random.split(rng, args.num_chains)

    observed = (dept, male, applications, admit)
    init_params, potential_fn, constrain_fn = initialize_model(
        keys, glmm, *observed)

    return mcmc(args.num_warmup, args.num_samples, init_params,
                num_chains=args.num_chains,
                potential_fn=potential_fn,
                constrain_fn=constrain_fn)
def run_inference(model, args, rng, X, Y, hypers):
    """Initialize ``model`` on ``(X, Y, hypers)``, run ``mcmc`` and report timing.

    ``rng`` is split into ``args.num_chains`` keys only when more than one
    chain is requested.  The wall-clock duration of the ``mcmc`` call
    (``sampler='hmc'``) is printed.  Returns whatever ``mcmc`` returns.
    """
    multi_chain = args.num_chains > 1
    keys = random.split(rng, args.num_chains) if multi_chain else rng

    init_params, potential_fn, constrain_fn = initialize_model(
        keys, model, X, Y, hypers)

    # Only the mcmc call itself is timed; initialization is excluded.
    t0 = time.time()
    draws = mcmc(args.num_warmup, args.num_samples, init_params,
                 num_chains=args.num_chains,
                 sampler='hmc',
                 potential_fn=potential_fn,
                 constrain_fn=constrain_fn)
    print('\nMCMC elapsed time:', time.time() - t0)
    return draws
def run_inference(model, args, rng, X, Y, D_H):
    """Initialize ``model`` on ``(X, Y, D_H)`` and run ``mcmc`` with ``sampler='hmc'``.

    ``rng`` is passed to ``initialize_model`` unchanged and no ``num_chains``
    argument is given to ``mcmc``.  Returns whatever ``mcmc`` returns.
    """
    model_args = (X, Y, D_H)
    init_params, potential_fn, constrain_fn = initialize_model(
        rng, model, *model_args)
    return mcmc(
        args.num_warmup,
        args.num_samples,
        init_params,
        sampler='hmc',
        potential_fn=potential_fn,
        constrain_fn=constrain_fn,
    )
def test_improper_prior():
    """Check recovery of a Normal's mean and std declared via ``param``.

    The model declares ``'mean'`` and ``'std'`` with ``param`` (``std``
    carries a positivity constraint) instead of sampling them from a prior.
    The means of the returned ``'mean'`` and ``'std'`` samples must each be
    within ``rtol=0.05`` of the values used to generate the data.
    """
    expected_mean, expected_std = 1., 2.
    n_warmup, n_draws = 1000, 8000

    def model(data):
        mean = param('mean', 0.)
        std = param('std', 1., constraint=constraints.positive)
        return sample('obs', dist.Normal(mean, std), obs=data)

    generating_dist = dist.Normal(expected_mean, expected_std)
    observations = generating_dist.sample(random.PRNGKey(1), (2000,))

    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, observations)
    samples = mcmc(n_warmup, n_draws, init_params,
                   potential_fn=potential_fn,
                   constrain_fn=constrain_fn)

    assert_allclose(np.mean(samples['mean']), expected_mean, rtol=0.05)
    assert_allclose(np.mean(samples['std']), expected_std, rtol=0.05)
def test_uniform_normal():
    """Check the posterior mean of ``'loc'`` under nested Uniform priors.

    ``'alpha'`` is sampled from ``Uniform(0, 1)`` and ``'loc'`` from
    ``Uniform(0, alpha)``; observations are modelled as ``Normal(loc, 0.1)``.
    The mean of the ``'loc'`` samples must be within ``atol=0.05`` of the
    offset (0.9) added to the generated noise.
    """
    expected_loc = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        # The support of 'loc' depends on the sampled value of 'alpha'.
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000, ))
    observations = expected_loc + noise

    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, observations)
    samples = mcmc(1000, 1000, init_params,
                   potential_fn=potential_fn,
                   constrain_fn=constrain_fn)

    assert_allclose(np.mean(samples['loc'], 0), expected_loc, atol=0.05)
def test_correlated_mvn():
    """Sample from a correlated zero-mean Gaussian given only its potential.

    A lower-triangular factor is built from a scaled anti-diagonal plus
    positive noise; the covariance is that factor times its transpose.
    ``mcmc`` is called directly with a hand-written quadratic potential and
    ``dense_mass=True``.  The test asserts the overall sample mean is within
    ``atol=0.02`` of zero and that the mean absolute error of the empirical
    covariance is below 0.02.
    """
    # This requires dense mass matrix estimation.
    dim = 5
    n_warmup, n_draws = 5000, 8000
    expected_mean = 0.

    anti_diagonal = 0.5 * np.fliplr(np.eye(dim))
    positive_noise = 0.1 * np.exp(
        random.normal(random.PRNGKey(0), shape=(dim, dim)))
    factor = np.tril(anti_diagonal + positive_noise)
    expected_cov = np.dot(factor, factor.T)
    precision = np.linalg.inv(expected_cov)

    def potential_fn(z):
        # Quadratic form 0.5 * z^T P z with P the precision matrix.
        return 0.5 * np.dot(z.T, np.dot(precision, z))

    start_point = np.zeros(dim)
    samples = mcmc(n_warmup, n_draws, start_point,
                   potential_fn=potential_fn, dense_mass=True)

    assert_allclose(np.mean(samples), expected_mean, atol=0.02)
    cov_abs_error = onp.sum(onp.abs(onp.cov(samples.T) - expected_cov))
    assert cov_abs_error / dim**2 < 0.02
def main(args):
    """Simulate HMM data, run ``mcmc`` on ``semi_supervised_hmm`` and print results.

    Reads ``device``, ``num_categories``, ``num_words``, ``num_supervised``,
    ``num_unsupervised``, ``num_warmup`` and ``num_samples`` from ``args``.
    A single PRNG key is used for initialization and no ``num_chains``
    argument is passed to ``mcmc``.
    """
    jax_config.update('jax_platform_name', args.device)

    print('Simulating data...')
    simulated = simulate_data(
        random.PRNGKey(1),
        num_categories=args.num_categories,
        num_words=args.num_words,
        num_supervised_data=args.num_supervised,
        num_unsupervised_data=args.num_unsupervised,
    )
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words, unsupervised_words) = simulated

    print('Starting inference...')
    # Positional arguments forwarded to the model by initialize_model.
    model_args = (transition_prior, emission_prior, supervised_categories,
                  supervised_words, unsupervised_words)
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), semi_supervised_hmm, *model_args)

    samples = mcmc(args.num_warmup, args.num_samples, init_params,
                   potential_fn=potential_fn,
                   constrain_fn=constrain_fn)
    print_results(samples, transition_prob, emission_prob)
def main(args):
    """Compare vanilla HMC with NeuTra HMC on ``dual_moon_model`` and save plots.

    Steps, in order:

    1. Run ``mcmc`` directly on ``dual_moon_pe`` starting from ``[2., 0.]``.
    2. Train an ``AutoIAFNormal`` guide with ``svi`` for ``args.num_iters``
       iterations using Adam with step size 0.001.
    3. Draw ``args.num_samples`` samples from the trained guide.
    4. Run ``mcmc`` on the potential transformed through the guide, starting
       from zeros of size ``guide.latent_size``.
    5. Map those samples back through the guide transform, print a summary,
       and write a 3x2 figure to ``neutra.pdf``.

    Reads ``device``, ``num_warmup``, ``num_samples``, ``num_hidden`` and
    ``num_iters`` from ``args``.
    """
    jax_config.update('jax_platform_name', args.device)

    print("Start vanilla HMC...")
    vanilla_samples = mcmc(args.num_warmup, args.num_samples, init_params=np.array([2., 0.]),
                           potential_fn=dual_moon_pe, progbar=True)

    # Set up the optimizer, the guide and the SVI init/update functions.
    opt_init, opt_update, get_params = optimizers.adam(0.001)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = AutoIAFNormal(rng_guide, dual_moon_model, get_params, hidden_dims=[args.num_hidden])
    svi_init, svi_update, _ = svi(dual_moon_model, guide, elbo, opt_init, opt_update, get_params)
    opt_state, _ = svi_init(rng_init)

    # One SVI step; the carry is (optimizer state, PRNG key) and the
    # per-step output collected by lax.scan is the loss.
    def body_fn(val, i):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_)
        return (opt_state_, rng_), loss

    print("Start training guide...")
    (last_state, _), losses = lax.scan(body_fn, (opt_state, rng_train), np.arange(args.num_iters))
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(random.PRNGKey(0), last_state,
                                           sample_shape=(args.num_samples,))

    transform = guide.get_transform(last_state)
    unpack_fn = guide.unpack_latent

    # Only the potential and constrain functions are needed; the initial
    # parameters returned by initialize_model are discarded.
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), dual_moon_model)
    transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn)
    # Composition: guide transform -> unpack -> constrain.
    transformed_constrain_fn = lambda x: constrain_fn(unpack_fn(transform(x)))  # noqa: E731

    init_params = np.zeros(guide.latent_size)
    print("\nStart NeuTra HMC...")
    zs = mcmc(args.num_warmup, args.num_samples, init_params, potential_fn=transformed_potential_fn)
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    # A leading axis of size 1 is added to every leaf before calling summary
    # (presumably a chain dimension - confirm against summary's contract).
    summary(tree_map(lambda x: x[None, ...], samples))

    # make plots

    # IAF guide samples (for plotting)
    iaf_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(0), (1000,))
    iaf_trans_samples = vmap(transformed_constrain_fn)(iaf_base_samples)['x']

    # Evaluate exp(-dual_moon_pe) on a 100x100 grid over [-3, 3] x [-3, 3],
    # clipped below at 0, for the contour backgrounds.
    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.clip(np.exp(-dual_moon_pe(np.stack([X1, X2], axis=-1))), a_min=0.)

    fig = plt.figure(figsize=(12, 16), constrained_layout=True)
    gs = GridSpec(3, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])

    # Panel 1: log of the training losses, skipping the first 1000 entries.
    ax1.plot(np.log(losses[1000:]))
    ax1.set_title('Autoguide training log loss (after 1000 steps)')

    # Panel 2: KDE of the guide samples over the density contours.
    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples['x'][:, 0].copy(), guide_samples['x'][:, 1].copy(), n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using AutoIAFNormal guide')

    # Panel 3: base samples, coloured by whether their transformed x0 is negative.
    sns.scatterplot(iaf_base_samples[:, 0], iaf_base_samples[:, 1], ax=ax3,
                    hue=iaf_trans_samples[:, 0] < 0.)
    ax3.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='AutoIAFNormal base samples (True=left moon; False=right moon)')

    # Panel 4: KDE of the vanilla samples plus a trace of the last 50 draws.
    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0].copy(), vanilla_samples[:, 1].copy(), n_levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using vanilla HMC sampler')

    # Panel 5: samples in the transformed space, coloured by the sign of
    # their constrained x0.
    sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples['x'][:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5],
            xlabel='x0', ylabel='x1', title='Samples from the warped posterior - p(z)')

    # Panel 6: KDE of the NeuTra samples plus a trace of the last 50 draws.
    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples['x'][:, 0].copy(), samples['x'][:, 1].copy(), n_levels=30, ax=ax6)
    ax6.plot(samples['x'][-50:, 0], samples['x'][-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()