def test_beta_bernoulli():
    from numpyro.contrib.tfp import distributions as dist

    warmup_steps, num_samples = (500, 2000)

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs)(rng_key=random.PRNGKey(1), sample_shape=(1000, 2))
    kernel = NUTS(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['p_latent'], 0), true_probs, atol=0.05)
def test_logistic_regression_x64(kernel_cls):
    N, dim = 3000, 3
    warmup_steps, num_samples = 1000, 8000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = kernel_cls(model=model, trajectory_length=8)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.22)
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['coefs'].dtype == np.float64
def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
    svi_key, mcmc_key = random.split(hmcecs_key)

    # find reference parameters for the second-order Taylor expansion used to
    # estimate the likelihood (taylor_proxy)
    optimizer = numpyro.optim.Adam(step_size=1e-3)
    guide = autoguide.AutoDelta(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    params, losses = svi.run(svi_key, args.num_svi_steps, data, obs, args.subsample_size)
    ref_params = {'theta': params['theta_auto_loc']}

    # taylor proxy estimates the log likelihood (ll) by
    #   taylor_expansion(ll, theta_curr) +
    #   sum_{i in subsample} ll_i(theta_curr) - taylor_expansion(ll_i, theta_curr)
    # around ref_params
    proxy = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(inner_kernel, num_blocks=args.num_blocks, proxy=proxy)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(mcmc_key, data, obs, args.subsample_size)
    mcmc.print_summary()
    return losses, mcmc.get_samples()
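# A minimal usage sketch for run_hmcecs above. Everything below is an
# illustrative assumption, not the original script: the logistic-regression
# `model`, the synthetic data, and the SimpleNamespace of arguments are
# stand-ins. The latent site must be named 'theta' so that the AutoDelta
# parameter 'theta_auto_loc' referenced above exists, and the imports used by
# run_hmcecs (SVI, Trace_ELBO, autoguide, HMCECS, MCMC) are assumed in scope.
from types import SimpleNamespace

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS


def model(data, obs, subsample_size):
    theta = numpyro.sample('theta', dist.Normal(jnp.zeros(2), jnp.ones(2)))
    with numpyro.plate('N', data.shape[0], subsample_size=subsample_size):
        batch = numpyro.subsample(data, event_dim=1)
        batch_obs = numpyro.subsample(obs, event_dim=0)
        numpyro.sample('obs', dist.Bernoulli(logits=batch @ theta), obs=batch_obs)


args = SimpleNamespace(num_svi_steps=1000, subsample_size=100, num_blocks=10,
                       num_warmup=200, num_samples=200)
data = random.normal(random.PRNGKey(0), (10_000, 2))
obs = dist.Bernoulli(logits=data @ jnp.array([1.0, -0.5])).sample(random.PRNGKey(1))
losses, samples = run_hmcecs(random.PRNGKey(2), args, data, obs, NUTS(model))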
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    true_mean, true_std = 1., 0.5
    warmup_steps, num_samples = (100000, 100000) if kernel_cls is SA else (1000, 8000)

    def potential_fn(z):
        return 0.5 * np.sum(((z - true_mean) / true_std) ** 2)

    init_params = np.array(0.)
    if kernel_cls is SA:
        kernel = SA(potential_fn=potential_fn, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=8,
                            dense_mass=dense_mass)
    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    hmc_states = mcmc.get_samples()
    assert_allclose(np.mean(hmc_states), true_mean, rtol=0.05)
    assert_allclose(np.std(hmc_states), true_std, rtol=0.05)
    if 'JAX_ENABLE_X64' in os.environ:
        assert hmc_states.dtype == np.float64
def test_correlated_mvn():
    # This requires dense mass matrix estimation.
    D = 5
    warmup_steps, num_samples = 5000, 8000
    true_mean = 0.
    a = np.tril(0.5 * np.fliplr(np.eye(D)) +
                0.1 * np.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = np.dot(a, a.T)
    true_prec = np.linalg.inv(true_cov)

    def potential_fn(z):
        return 0.5 * np.dot(z.T, np.dot(true_prec, z))

    init_params = np.zeros(D)
    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples), true_mean, atol=0.02)
    assert onp.sum(onp.abs(onp.cov(samples.T) - true_cov)) / D**2 < 0.02
def test_binomial_stable_x64(with_logits):
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    warmup_steps, num_samples = 200, 200

    def model(data):
        p = numpyro.sample('p', dist.Beta(1., 1.))
        if with_logits:
            logits = logit(p)
            numpyro.sample('obs', dist.Binomial(data['n'], logits=logits), obs=data['x'])
        else:
            numpyro.sample('obs', dist.Binomial(data['n'], probs=p), obs=data['x'])

    data = {'n': 5000000, 'x': 3849}
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['p'], 0), data['x'] / data['n'], rtol=0.05)
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p'].dtype == jnp.float64
def test_inference_data_constant_data(self):
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    x1 = 10
    x2 = 12
    y1 = np.random.randn(10)

    def model_constant_data(x, y1=None):
        _x = numpyro.sample("x", dist.Normal(1, 3))
        numpyro.sample("y1", dist.Normal(x * _x, 1), obs=y1)

    nuts_kernel = NUTS(model_constant_data)
    mcmc = MCMC(nuts_kernel, num_samples=10, num_warmup=2)
    mcmc.run(PRNGKey(0), x=x1, y1=y1)
    posterior = mcmc.get_samples()
    posterior_predictive = Predictive(model_constant_data, posterior)(PRNGKey(1), x1)
    predictions = Predictive(model_constant_data, posterior)(PRNGKey(2), x2)
    inference_data = from_numpyro(
        mcmc,
        posterior_predictive=posterior_predictive,
        predictions=predictions,
        constant_data={"x1": x1},
        predictions_constant_data={"x2": x2},
    )
    test_dict = {
        "posterior": ["x"],
        "posterior_predictive": ["y1"],
        "sample_stats": ["diverging"],
        "log_likelihood": ["y1"],
        "predictions": ["y1"],
        "observed_data": ["y1"],
        "constant_data": ["x1"],
        "predictions_constant_data": ["x2"],
    }
    fails = check_multiple_attrs(test_dict, inference_data)
    assert not fails
def main(args):
    _, fetch = load_dataset(LYNXHARE, shuffle=False)
    year, data = fetch()  # data is in hare -> lynx order

    # use dense_mass for better mixing rate
    mcmc = MCMC(
        NUTS(model, dense_mass=True),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(PRNGKey(1), N=data.shape[0], y=data)
    mcmc.print_summary()

    # predict populations
    pop_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])["y"]
    mu = jnp.mean(pop_pred, 0)
    pi = jnp.percentile(pop_pred, jnp.array([10, 90]), 0)
    plt.figure(figsize=(8, 6), constrained_layout=True)
    plt.plot(year, data[:, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67)
    plt.plot(year, data[:, 1], "bx", label="true lynx")
    plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
    plt.plot(year, mu[:, 1], "b--", label="pred lynx")
    plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2)
    plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3)
    plt.gca().set(ylim=(0, 160), xlabel="year", ylabel="population (in thousands)")
    plt.title("Posterior predictive (80% CI) with predator-prey pattern.")
    plt.legend()
    plt.savefig("ode_plot.pdf")
def test_beta_bernoulli():
    from tensorflow_probability.substrates.jax import distributions as tfd

    num_warmup, num_samples = (500, 2000)

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample("p_latent", tfd.Beta(alpha, beta))
        numpyro.sample("obs", tfd.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = tfd.Bernoulli(true_probs).sample(
        seed=random.PRNGKey(1), sample_shape=(1000, 2)
    )
    kernel = NUTS(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p_latent"], 0), true_probs, atol=0.05)
def test_logistic_regression():
    from numpyro.contrib.tfp import distributions as dist

    N, dim = 3000, 3
    num_warmup, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits)(rng_key=random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic('logits', jnp.sum(coefs * data, axis=-1))
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup, num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples['logits'].shape == (num_samples, N)
    assert_allclose(jnp.mean(samples['coefs'], 0), true_coefs, atol=0.22)
def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration))
        numpyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    if kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(model, trajectory_length=1.0, dense_mass=dense_mass)
    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p_latent"], 0), true_probs, atol=0.02)
    if "JAX_ENABLE_X64" in os.environ:
        assert samples["p_latent"].dtype == jnp.float64
def test_gaussian_mixture_model():
    K, N = 3, 1000

    def gmm(data):
        mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
        with numpyro.plate("num_clusters", K, dim=-1):
            cluster_means = numpyro.sample("cluster_means",
                                           dist.Normal(jnp.arange(K), 1.))
        with numpyro.plate("data", data.shape[0], dim=-1):
            assignments = numpyro.sample("assignments",
                                         dist.Categorical(mix_proportions))
            numpyro.sample("obs", dist.Normal(cluster_means[assignments], 1.),
                           obs=data)

    true_cluster_means = jnp.array([1., 5., 10.])
    true_mix_proportions = jnp.array([0.1, 0.3, 0.6])
    cluster_assignments = dist.Categorical(true_mix_proportions).sample(
        random.PRNGKey(0), (N,))
    data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample(
        random.PRNGKey(1))

    nuts_kernel = NUTS(gmm)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(samples["phi"].mean(0).sort(), true_mix_proportions, atol=0.05)
    assert_allclose(samples["cluster_means"].mean(0).sort(), true_cluster_means,
                    atol=0.2)
def run_inference(self, model, rng_key, X, Y, hypers, num_warmup=500,
                  num_chains=1, num_samples=1000):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup, num_samples, num_chains=num_chains,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(rng_key, X, Y, hypers)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
def test_logistic_regression_x64(kernel_cls):
    N, dim = 3000, 3
    if kernel_cls is SA:
        num_warmup, num_samples = (100000, 100000)
    elif kernel_cls is BarkerMH:
        num_warmup, num_samples = (2000, 12000)
    else:
        num_warmup, num_samples = (1000, 8000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    if kernel_cls is SA:
        kernel = SA(model=model, adapt_state_size=9)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model)
    else:
        kernel = kernel_cls(
            model=model, trajectory_length=8, find_heuristic_step_size=True
        )
    mcmc = MCMC(
        kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
    )
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples["logits"].shape == (num_samples, N)
    # those coefficients are found by doing MAP inference using AutoDelta
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples["coefs"], 0), expected_coefs, atol=0.1)
    if "JAX_ENABLE_X64" in os.environ:
        assert samples["coefs"].dtype == jnp.float64
def main(m, seed):
    rng_key = random.PRNGKey(seed)
    cutoff_up = 800
    cutoff_down = 100
    if m <= 10:
        seq, _ = estimate_beliefs(outcomes_data, responses_data,
                                  mask=mask_data, nu_max=m)
    else:
        seq, _ = estimate_beliefs(outcomes_data, responses_data,
                                  mask=mask_data, nu_max=10, nu_min=m - 10)

    model = generative_model
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
    seq = (seq['beliefs'][0][cutoff_down:cutoff_up],
           seq['beliefs'][1][cutoff_down:cutoff_up])
    rng_key, _rng_key = random.split(rng_key)
    mcmc.run(
        _rng_key,
        seq,
        y=responses_data[cutoff_down:cutoff_up],
        mask=mask_data[cutoff_down:cutoff_up].astype(bool),
        extra_fields=('potential_energy',)
    )
    samples = mcmc.get_samples()
    waic = log_pred_density(
        model,
        samples,
        seq,
        y=responses_data[cutoff_down:cutoff_up],
        mask=mask_data[cutoff_down:cutoff_up].astype(bool)
    )['waic']

    import numpy as onp  # jax.numpy provides no savez; fall back to NumPy's

    onp.savez('fit_waic_sample/dyn_fit_waic_sample_minf{}.npz'.format(m),
              samples=samples, waic=waic)
    print(mcmc.get_extra_fields()['potential_energy'].mean())
def test_neals_funnel_smoke():
    dim = 10

    guide = AutoIAFNormal(neals_funnel)
    svi = SVI(neals_funnel, guide, Adam(1e-10), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0), dim)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, dim)
        return svi_state

    svi_state = lax.fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    neutra = NeuTraReparam(guide, params)
    model = neutra.reparam(neals_funnel)
    nuts = NUTS(model)
    mcmc = MCMC(nuts, num_warmup=50, num_samples=50)
    mcmc.run(random.PRNGKey(1), dim)
    samples = mcmc.get_samples()
    transformed_samples = neutra.transform_sample(samples['auto_shared_latent'])
    assert 'x' in transformed_samples
    assert 'y' in transformed_samples
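# The smoke test above assumes a `neals_funnel` model in scope. A minimal
# sketch of such a model (an assumption, not necessarily the original
# definition; the site names 'x' and 'y' must match the test's assertions):
def neals_funnel(dim):
    y = numpyro.sample('y', dist.Normal(0., 3.))
    with numpyro.plate('D', dim):
        numpyro.sample('x', dist.Normal(0., jnp.exp(y / 2)))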
def test_binomial_stable_x64(with_logits):
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    num_warmup, num_samples = 200, 200

    def model(data):
        p = numpyro.sample("p", dist.Beta(1.0, 1.0))
        if with_logits:
            logits = logit(p)
            numpyro.sample(
                "obs", dist.Binomial(data["n"], logits=logits), obs=data["x"]
            )
        else:
            numpyro.sample("obs", dist.Binomial(data["n"], probs=p), obs=data["x"])

    data = {"n": 5000000, "x": 3849}
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p"], 0), data["x"] / data["n"], rtol=0.05)
    if "JAX_ENABLE_X64" in os.environ:
        assert samples["p"].dtype == jnp.float64
def test_logistic_regression():
    from tensorflow_probability.substrates.jax import distributions as tfd

    N, dim = 3000, 3
    num_warmup, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = tfd.Bernoulli(logits=logits).sample(seed=random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample("coefs", tfd.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
        return numpyro.sample("obs", tfd.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples["logits"].shape == (num_samples, N)
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples["coefs"], 0), expected_coefs, atol=0.22)
def test_beta_bernoulli_x64(kernel_cls):
    warmup_steps, num_samples = 500, 20000

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    kernel = kernel_cls(model=model, trajectory_length=1.)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.05)
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
def main(args):
    is_sphinxbuild = "NUMPYRO_SPHINXBUILD" in os.environ
    data = load_data()
    data_dict = make_birthdays_data_dict(data)
    mcmc = MCMC(
        NUTS(birthdays_model, init_strategy=init_to_median),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=(not is_sphinxbuild),
    )
    mcmc.run(jax.random.PRNGKey(0), **data_dict)
    if not is_sphinxbuild:
        mcmc.print_summary()
    if args.save_figure:
        samples = mcmc.get_samples()
        print(f"Saving figure at {args.save_figure}")
        fig = make_figure(data, samples)
        fig.savefig(args.save_figure)
        plt.close()
    return mcmc
def test_random_module_mcmc(backend):
    if backend == "flax":
        import flax

        linear_module = flax.nn.Dense.partial(features=1)
        bias_name = "bias"
        weight_name = "kernel"
        random_module = random_flax_module
    elif backend == "haiku":
        import haiku as hk

        linear_module = hk.transform(lambda x: hk.Linear(1)(x))
        bias_name = "linear.b"
        weight_name = "linear.w"
        random_module = random_haiku_module

    def model(data, labels):
        nn = random_module("nn", linear_module,
                           prior={bias_name: dist.Cauchy(), weight_name: dist.Normal()},
                           input_shape=(dim,))
        logits = nn(data).squeeze(-1)
        numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels)

    N, dim = 3000, 3
    warmup_steps, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data, labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert set(samples.keys()) == {"nn/{}".format(bias_name),
                                   "nn/{}".format(weight_name)}
    assert_allclose(np.mean(samples["nn/{}".format(weight_name)].squeeze(-1), 0),
                    true_coefs, atol=0.22)
def main(args):
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words, unsupervised_words) = simulate_data(
        random.PRNGKey(1),
        num_categories=args.num_categories,
        num_words=args.num_words,
        num_supervised_data=args.num_supervised,
        num_unsupervised_data=args.num_unsupervised,
    )

    print('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = NUTS(semi_supervised_hmm)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(rng_key, transition_prior, emission_prior, supervised_categories,
             supervised_words, unsupervised_words)
    samples = mcmc.get_samples()
    print_results(samples, transition_prob, emission_prob)
    print('\nMCMC elapsed time:', time.time() - start)

    # make plots
    fig, ax = plt.subplots(1, 1)

    x = onp.linspace(0, 1, 101)
    for i in range(transition_prob.shape[0]):
        for j in range(transition_prob.shape[1]):
            ax.plot(x, gaussian_kde(samples['transition_prob'][:, i, j])(x),
                    label="trans_prob[{}, {}], true value = {:.2f}"
                    .format(i, j, transition_prob[i, j]))
    ax.set(xlabel="Probability", ylabel="Frequency",
           title="Transition probability posterior")
    ax.legend()

    # apply the layout before saving, otherwise it has no effect on the file
    plt.tight_layout()
    plt.savefig("hmm_plot.pdf")
def test_dense_mass(kernel_cls, rho):
    num_warmup, num_samples = 20000, 10000
    true_cov = jnp.array([[10.0, rho], [rho, 0.1]])

    def model():
        numpyro.sample(
            "x", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=true_cov))

    if kernel_cls is HMC or kernel_cls is NUTS:
        kernel = kernel_cls(model, trajectory_length=2.0, dense_mass=True)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model, dense_mass=True)

    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0))

    # with dense mass adaptation, the mass matrix approximates the posterior
    # precision, so its inverse should recover true_cov
    mass_matrix_sqrt = mcmc.last_state.adapt_state.mass_matrix_sqrt
    if kernel_cls is HMC or kernel_cls is NUTS:
        mass_matrix_sqrt = mass_matrix_sqrt[("x",)]
    mass_matrix = jnp.matmul(mass_matrix_sqrt, jnp.transpose(mass_matrix_sqrt))
    estimated_cov = jnp.linalg.inv(mass_matrix)
    assert_allclose(estimated_cov, true_cov, rtol=0.10)

    samples = mcmc.get_samples()["x"]
    assert_allclose(jnp.mean(samples[:, 0]), jnp.array(0.0), atol=0.50)
    assert_allclose(jnp.mean(samples[:, 1]), jnp.array(0.0), atol=0.05)
    assert_allclose(jnp.mean(samples[:, 0] * samples[:, 1]), jnp.array(rho), atol=0.20)
    assert_allclose(jnp.var(samples, axis=0), jnp.array([10.0, 0.1]), rtol=0.20)
def inference(belief_sequences, obs, mask, rng_key):
    nuts_kernel = NUTS(model, dense_mass=True)
    mcmc = MCMC(
        nuts_kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        chain_method="vectorized",
        progress_bar=False
    )
    mcmc.run(
        rng_key,
        belief_sequences,
        obs,
        mask,
        extra_fields=('potential_energy',)
    )
    samples = mcmc.get_samples()
    potential_energy = mcmc.get_extra_fields()['potential_energy'].mean()
    # mcmc.print_summary()

    return samples, potential_energy
class NutsHandler(Handler):
    def __init__(
        self,
        model,
        num_warmup=2000,
        num_samples=10000,
        num_chains=1,
        rng_key=0,
        to_numpy: bool = True,
        *args,
        **kwargs,
    ):
        self.model = model
        self.num_warmup = num_warmup
        self.num_samples = num_samples
        self.num_chains = num_chains
        self.rng_key, self.rng_key_ = random.split(random.PRNGKey(rng_key))
        self.to_numpy = to_numpy
        self.kernel = NUTS(model, **kwargs)
        self.mcmc = MCMC(self.kernel, num_warmup, num_samples, num_chains=num_chains)

    def predict(self, *args, **kwargs):
        predictive = Predictive(self.model, self.posterior.data, **kwargs)
        self.predictive = Posterior(predictive(self.rng_key_, *args))

    def fit(self, *args, **kwargs):
        self.num_samples = kwargs.get("num_samples", self.num_samples)
        self.mcmc.run(self.rng_key_, *args, **kwargs)
        self.posterior = Posterior(self.mcmc.get_samples(), self.to_numpy)

    def summary(self, *args, **kwargs):
        self.mcmc.print_summary(*args, **kwargs)
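# Hypothetical usage of NutsHandler (a sketch: the coin-flip model and data
# below are illustrative, and the Handler/Posterior helpers are assumed to be
# defined elsewhere in this module):
def coin_model(flips=None):
    p = numpyro.sample("p", dist.Beta(2.0, 2.0))
    numpyro.sample("flips", dist.Bernoulli(p), obs=flips)

handler = NutsHandler(coin_model, num_warmup=500, num_samples=1000)
handler.fit(flips=jnp.array([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]))
handler.summary()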
def infer(self, num_warmup=1000, num_samples=1000, num_chains=1,
          rng_key=PRNGKey(1), **args):
    '''Fit using MCMC'''
    args = dict(self.args, **args)
    kernel = NUTS(self, init_strategy=numpyro.infer.util.init_to_median())
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples,
                num_chains=num_chains)
    mcmc.run(rng_key, **self.obs, **args)
    mcmc.print_summary()
    self.mcmc = mcmc
    self.mcmc_samples = mcmc.get_samples()
    return self.mcmc_samples
def test_estimate_likelihood(kernel_cls):
    data_key, tr_key, sub_key, rng_key = random.split(random.PRNGKey(0), 4)
    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = .1
    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        data_key, (10_000,))
    n, _ = data.shape
    num_warmup = 200
    num_samples = 200
    num_blocks = 20

    def model(data):
        mean = numpyro.sample(
            'mean', dist.Normal(ref_params, jnp.ones_like(ref_params)))
        with numpyro.plate('N', data.shape[0], subsample_size=100, dim=-2) as idx:
            numpyro.sample('obs', dist.Normal(mean, sigma), obs=data[idx])

    proxy_fn = HMCECS.taylor_proxy({'mean': ref_params})
    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn, num_blocks=num_blocks)
    mcmc = MCMC(kernel, num_warmup, num_samples)

    mcmc.run(random.PRNGKey(0), data, extra_fields=['hmc_state.potential_energy'])

    pes = mcmc.get_extra_fields()['hmc_state.potential_energy']
    samples = mcmc.get_samples()
    pes_full = vmap(lambda sample: log_density(
        model, (data,), {}, {**sample, **{'N': jnp.arange(n)}})[0])(samples)

    # exp(-pes - pes_full) is the ratio of the subsample-estimated joint
    # density to the exact one; a good proxy keeps its variance bounded
    assert jnp.var(jnp.exp(-pes - pes_full)) < 1.
def sample_posterior_gibbs(rng_key: random.PRNGKey,
                           model,
                           data: np.ndarray,
                           Nsamples: int = 1000,
                           alpha: float = 1,
                           sigma: float = 0,
                           T: int = 10,
                           gibbs_fn=None,
                           gibbs_sites=None):
    assert gibbs_fn is not None
    assert gibbs_sites is not None
    Npoints = len(data)
    inner_kernel = NUTS(model)
    kernel = HMCGibbs(inner_kernel, gibbs_fn=gibbs_fn, gibbs_sites=gibbs_sites)
    mcmc = MCMC(kernel, num_samples=Nsamples, num_warmup=NUM_WARMUP)
    mcmc.run(rng_key, data=data, alpha=alpha, sigma=sigma, T=T)
    samples = mcmc.get_samples()
    z = samples['z']
    assert z.shape == (Nsamples, Npoints)
    return z
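# Sketch of the gibbs_fn/gibbs_sites contract that HMCGibbs expects, using a
# toy two-component mixture. Everything here is illustrative and unrelated to
# the original model; note that sample_posterior_gibbs asserts the Gibbs site
# is named 'z', and NUM_WARMUP is assumed defined at module level as above.
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist

toy_data = jnp.concatenate([-2.0 + random.normal(random.PRNGKey(0), (50,)),
                            2.0 + random.normal(random.PRNGKey(1), (50,))])

def toy_model(data, alpha=1, sigma=0, T=2):
    mu = numpyro.sample('mu', dist.Normal(jnp.zeros(T), 10.0))
    with numpyro.plate('n', data.shape[0]):
        z = numpyro.sample('z', dist.Categorical(jnp.ones(T) / T))
        numpyro.sample('x', dist.Normal(mu[z], 1.0), obs=data)

def toy_gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    # resample each z_i from its conditional given the HMC-updated means;
    # with a uniform prior, the conditional logits are the likelihood terms
    mu = hmc_sites['mu']                                       # shape (T,)
    logits = dist.Normal(mu, 1.0).log_prob(toy_data[:, None])  # shape (n, T)
    return {'z': dist.Categorical(logits=logits).sample(rng_key)}

z = sample_posterior_gibbs(random.PRNGKey(2), toy_model, toy_data,
                           Nsamples=100, T=2,
                           gibbs_fn=toy_gibbs_fn, gibbs_sites=['z'])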
def run_inference(
    model: Callable,
    at_bats: jnp.ndarray,
    hits: jnp.ndarray,
    rng_key: jnp.ndarray,
    *,
    num_warmup: int = 1500,
    num_samples: int = 3000,
    num_chains: int = 1,
    algo_name: str = "NUTS",
) -> Dict[str, jnp.ndarray]:
    if algo_name == "NUTS":
        kernel = NUTS(model)
    elif algo_name == "HMC":
        kernel = HMC(model)
    elif algo_name == "SA":
        kernel = SA(model)
    else:
        raise ValueError("Unknown algorithm name")
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples,
                num_chains=num_chains)
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
def main() -> None:
    df = load_dataset()
    test_index = 80
    test_len = len(df) - test_index
    y_train = jnp.array(df.loc[:test_index, "value"], dtype=jnp.float32)

    # Inference
    kernel = NUTS(sgt)
    mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1)
    mcmc.run(random.PRNGKey(0), y_train, seasonality=38)
    mcmc.print_summary()
    posterior_samples = mcmc.get_samples()

    # Prediction
    predictive = Predictive(sgt, posterior_samples, return_sites=["y_forecast"])
    posterior_predictive = predictive(
        random.PRNGKey(1), y_train, seasonality=38, future=test_len
    )

    root = pathlib.Path("./data/time_series")
    root.mkdir(exist_ok=True)
    import numpy as onp  # jax.numpy provides no savez; fall back to NumPy's

    onp.savez(root / "posterior_samples.npz", **posterior_samples)
    onp.savez(root / "posterior_predictive.npz", **posterior_predictive)
    plot_results(
        df["time"].values,
        df["value"].values,
        posterior_samples,
        posterior_predictive,
        test_index,
        root,
    )