def test_gaussian_hmm():
    """Smoke-test NUTS on a Gaussian HMM with discrete hidden states.

    Observations are simulated from a random ``dim``-state Markov chain with
    unit-scale Gaussian emissions centred at ``0, 1, ..., dim - 1``. The test
    only checks that sampling runs to completion.
    """
    dim = 4
    num_steps = 10

    def model(data):
        with numpyro.plate("states", dim):
            transition = numpyro.sample("transition", dist.Dirichlet(jnp.ones(dim)))
            emission_loc = numpyro.sample("emission_loc", dist.Normal(0, 1))
            emission_scale = numpyro.sample("emission_scale", dist.LogNormal(0, 1))

        trans_prob = numpyro.sample("initialize", dist.Dirichlet(jnp.ones(dim)))
        for t, y in markov(enumerate(data)):
            x = numpyro.sample("x_{}".format(t), dist.Categorical(trans_prob))
            numpyro.sample("y_{}".format(t), dist.Normal(emission_loc[x], emission_scale[x]), obs=y)
            # the next state's distribution is the transition row of the current state
            trans_prob = transition[x]

    def _generate_data():
        transition_probs = np.random.rand(dim, dim)
        transition_probs = transition_probs / transition_probs.sum(-1, keepdims=True)
        emissions_loc = np.arange(dim)
        emissions_scale = 1.
        # draw the initial state from all `dim` states (previously hard-coded to 3,
        # which made the last state unreachable as a starting point)
        state = np.random.choice(dim)
        obs = [np.random.normal(emissions_loc[state], emissions_scale)]
        for _ in range(num_steps - 1):
            state = np.random.choice(dim, p=transition_probs[state])
            obs.append(np.random.normal(emissions_loc[state], emissions_scale))
        return np.stack(obs)

    data = _generate_data()
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(0), data)
def test_change_point_x64():
    """Check the mode of the change-point posterior and, under x64, the dtypes.

    ``tau`` is a continuous fraction of the series length; its posterior mode
    (converted to a day index) is expected to be 44.
    """
    # Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    warmup_steps, num_samples = 500, 3000

    def model(data):
        alpha = 1 / np.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        # days before tau * len(data) use lambda1, the remaining days lambda2
        lambda12 = np.where(np.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = np.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57,
        11, 19, 29, 6, 19, 12, 22, 12, 18, 72, 32, 9, 7, 13,
        19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2,
        15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 18, 20,
        12, 35, 17, 23, 17, 4, 2, 31, 30, 13, 27, 0, 39, 37,
        5, 14, 13, 22,
    ])
    kernel = NUTS(model=model)
    # pass sizes by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(4), count_data)
    samples = mcmc.get_samples()
    tau_posterior = (samples['tau'] * len(count_data)).astype(np.int32)
    tau_values, counts = onp.unique(tau_posterior, return_counts=True)
    mode_ind = np.argmax(counts)
    mode = tau_values[mode_ind]
    assert mode == 44

    # Environment variable names are case-sensitive; JAX reads JAX_ENABLE_X64
    # (capital X). The previous 'JAX_ENABLE_x64' check never matched, so these
    # dtype assertions were silently skipped.
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['lambda1'].dtype == np.float64
        assert samples['lambda2'].dtype == np.float64
        assert samples['tau'].dtype == np.float64
def test_change_point():
    """Recover the two Poisson rates of the text-message change-point model.

    ``tau`` is a discrete Categorical site over 70 candidate days; the test
    checks the posterior means of ``lambda_1`` and ``lambda_2``.
    """
    def model(count_data):
        n_count_data = count_data.shape[0]
        alpha = 1 / jnp.mean(count_data)
        lambda_1 = numpyro.sample('lambda_1', dist.Exponential(alpha))
        lambda_2 = numpyro.sample('lambda_2', dist.Exponential(alpha))
        # this is the same as DiscreteUniform(0, 69)
        tau = numpyro.sample('tau', dist.Categorical(logits=jnp.zeros(70)))
        idx = jnp.arange(n_count_data)
        # days with index < tau use lambda_1, the rest use lambda_2
        lambda_ = jnp.where(tau > idx, lambda_1, lambda_2)
        with numpyro.plate("data", n_count_data):
            numpyro.sample('obs', dist.Poisson(lambda_), obs=count_data)

    # Daily text-message counts. Index 54 previously read 1, a typo for the
    # dataset's value of 18.
    count_data = jnp.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57,
        11, 19, 29, 6, 19, 12, 22, 12, 18, 72, 32, 9, 7, 13,
        19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2,
        15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 18, 20,
        12, 35, 17, 23, 17, 4, 2, 31, 30, 13, 27, 0, 39, 37,
        5, 14, 13, 22,
    ])
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(0), count_data)
    samples = mcmc.get_samples()
    assert_allclose(samples["lambda_1"].mean(0), 18., atol=1.)
    assert_allclose(samples["lambda_2"].mean(0), 23., atol=1.)
def test_spire_model(self):
    """Run a short NUTS chain on the SPIRE model and sanity-check the draws.

    The second axis of the ``src_f`` samples must equal the number of sources
    declared by the first prior (``self.priors[0].nsrc``).
    """
    rng_key = random.PRNGKey(0)
    mcmc = MCMC(NUTS(SPIRE.spire_model), num_samples=100, num_warmup=100)
    mcmc.run(rng_key, self.priors)

    samples = mcmc.get_samples()
    self.assertIsNotNone(samples)
    expected_nsrc = self.priors[0].nsrc
    self.assertEqual(samples['src_f'].shape[1], expected_nsrc)
def run_inference(model, args, rng_key, X, Y, D_H):
    """Sample ``model`` with NUTS, print a summary and timing, return draws.

    ``args`` must provide ``num_warmup``, ``num_samples`` and ``num_chains``;
    ``X``, ``Y`` and ``D_H`` are forwarded unchanged to the model.
    """
    start = time.time()
    kernel = NUTS(model)
    # sizes passed by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
    )
    mcmc.run(rng_key, X, Y, D_H)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
def run_inference(model, at_bats, hits, rng_key, args):
    """Sample ``model`` with NUTS on ``at_bats``/``hits`` and return the draws.

    ``args`` must provide ``num_warmup``, ``num_samples`` and ``num_chains``.
    """
    kernel = NUTS(model)
    # sizes passed by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
    )
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
def test_compile_warmup_run(num_chains, chain_method, progress_bar):
    """``_compile`` + ``warmup`` + ``run`` must reproduce a plain ``run``.

    For multi-chain runs, the first chain must also match a single-chain run
    seeded with the first half of ``random.split`` of the same key.
    """
    def model():
        numpyro.sample("x", dist.Normal(0, 1))

    # these parameter combinations are skipped as duplicates
    is_duplicate = (
        (num_chains == 1 and chain_method in ["sequential", "vectorized"])
        or (num_chains > 1 and chain_method == "parallel")
    )
    if is_duplicate:
        pytest.skip("duplicated test")

    key = random.PRNGKey(0)
    n_draws = 10

    def build_mcmc(chains, **extra_kwargs):
        return MCMC(
            NUTS(model),
            num_warmup=10,
            num_samples=n_draws,
            num_chains=chains,
            progress_bar=progress_bar,
            **extra_kwargs,
        )

    mcmc = build_mcmc(num_chains, chain_method=chain_method)
    mcmc.run(key)
    reference_draws = mcmc.get_samples()["x"]

    mcmc._compile(key)
    # no delay after compiling
    mcmc.warmup(key)
    mcmc.run(mcmc.last_state.rng_key)
    replayed_draws = mcmc.get_samples()["x"]
    assert_allclose(replayed_draws, reference_draws)

    # test for reproducible
    if num_chains <= 1:
        return
    single_chain = build_mcmc(1)
    single_chain.run(random.split(key)[0])
    first_chain_draws = single_chain.get_samples()["x"]
    assert_allclose(replayed_draws[:n_draws], first_chain_draws, atol=1e-5)
def test_empty_model(num_chains, chain_method, progress_bar):
    """A model with no sample sites must produce an empty samples dict."""
    def model():
        return None

    sampler_options = dict(
        num_warmup=10,
        num_samples=10,
        num_chains=num_chains,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )
    mcmc = MCMC(NUTS(model), **sampler_options)
    mcmc.run(random.PRNGKey(0))
    assert mcmc.get_samples() == {}
def test_forward_mode_differentiation():
    """NUTS with forward-mode differentiation must handle a ``lax.while_loop``."""
    def model():
        x = numpyro.sample("x", dist.Normal(0, 1))
        y = lax.while_loop(lambda x: x < 10, lambda x: x + 1, x)
        numpyro.sample("obs", dist.Normal(y, 1), obs=1.)

    # this fails in reverse mode
    kernel = NUTS(model, forward_mode_differentiation=True)
    # sizes passed by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
    mcmc.run(random.PRNGKey(0))
def test_loose_warning_for_missing_plate():
    """Observations with an un-plated leading batch dim must trigger a warning.

    The plate covers only the last dim (size 10); the leading dim of size 5 has
    no plate, so a "Missing a plate statement" UserWarning is expected.
    """
    def model():
        x = numpyro.sample("x", dist.Normal(0, 1))
        observed = jnp.ones((5, 10))
        with numpyro.plate("N", 10):
            numpyro.sample("obs", dist.Normal(x, 1), obs=observed)

    sampler = MCMC(NUTS(model), num_warmup=10, num_samples=10)
    with pytest.warns(UserWarning, match="Missing a plate statement"):
        sampler.run(random.PRNGKey(1))
def run_inference(model, args, rng_key, X, Y):
    """Sample ``model`` with NUTS, print a summary and timing, return draws.

    The progress bar is disabled when ``NUMPYRO_SPHINXBUILD`` is set in the
    environment. ``args`` must provide ``num_warmup``, ``num_samples`` and
    ``num_chains``.
    """
    start = time.time()
    kernel = NUTS(model)
    # sizes passed by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
def test_discrete_gibbs_bernoulli(random_walk, modified):
    """DiscreteHMCGibbs must recover P(c = 1) = 0.8 for a lone Bernoulli site."""
    def model():
        numpyro.sample("c", dist.Bernoulli(0.8))

    kernel = DiscreteHMCGibbs(NUTS(model), random_walk=random_walk, modified=modified)
    # sizes passed by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=200000, progress_bar=False)
    mcmc.run(random.PRNGKey(0))
    samples = mcmc.get_samples()["c"]
    assert_allclose(jnp.mean(samples), 0.8, atol=0.05)
def test_improper_uniform():
    """Smoke-test DiscreteHMCGibbs on a model with an ImproperUniform site."""
    def model():
        numpyro.sample("c", dist.Bernoulli(0.8))
        support = dist.constraints.unit_interval
        numpyro.sample("u", dist.ImproperUniform(support, (), ()))

    gibbs = DiscreteHMCGibbs(NUTS(model))
    runner = MCMC(gibbs, num_warmup=10, num_samples=10, progress_bar=False)
    runner.run(random.PRNGKey(0))
def test_model_with_lift_handler():
    """NUTS must run on a model whose ``param`` site ``c`` is lifted with a prior."""
    def model(data):
        c = numpyro.param("c", jnp.array(1.), constraint=dist.constraints.positive)
        return numpyro.sample("x", dist.LogNormal(c, 1.), obs=data)

    lifted_model = numpyro.handlers.lift(model, prior={"c": dist.Gamma(0.01, 0.01)})
    mcmc = MCMC(NUTS(lifted_model), num_warmup=10, num_samples=10)
    observations = jnp.exp(random.normal(random.PRNGKey(0), (1000,)))
    mcmc.run(random.PRNGKey(1), observations)
def test_sites_have_unique_names():
    """Reusing a site name for a deterministic site must raise AssertionError."""
    def model():
        alpha = numpyro.sample("alpha", dist.Normal())
        numpyro.deterministic("alpha", alpha * 2)

    expected_message = "all sites must have unique names but got `alpha` duplicated"
    sampler = MCMC(NUTS(model), num_chains=1, num_samples=10, num_warmup=10)
    with pytest.raises(AssertionError, match=expected_message):
        sampler.run(random.PRNGKey(0))
def test_random_module_mcmc(backend, init):
    """Fit a Bayesian logistic regression built from a single dense layer.

    Args:
        backend: "flax" or "haiku" — the library providing the linear layer.
        init: "shape" (initialise from ``input_shape``) or "kwargs"
            (initialise from the data passed as the module's input kwarg).

    Raises:
        ValueError: if ``backend`` or ``init`` is not one of the values above
            (previously this surfaced as an UnboundLocalError later on).
    """
    if backend == "flax":
        import flax

        linear_module = flax.linen.Dense(features=1)
        bias_name = "bias"
        weight_name = "kernel"
        random_module = random_flax_module
        kwargs_name = "inputs"
    elif backend == "haiku":
        import haiku as hk

        linear_module = hk.transform(lambda x: hk.Linear(1)(x))
        bias_name = "linear.b"
        weight_name = "linear.w"
        random_module = random_haiku_module
        kwargs_name = "x"
    else:
        raise ValueError("unsupported backend: {!r}".format(backend))

    N, dim = 3000, 3
    num_warmup, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1.0, dim + 1.0)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    if init == "shape":
        # input shape tracks the feature dimension of `data`
        kwargs = {"input_shape": (dim,)}
    elif init == "kwargs":
        kwargs = {kwargs_name: data}
    else:
        raise ValueError("unsupported init: {!r}".format(init))

    def model(data, labels):
        nn = random_module(
            "nn", linear_module, {bias_name: dist.Cauchy(), weight_name: dist.Normal()}, **kwargs
        )
        logits = nn(data).squeeze(-1)
        numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model=model)
    mcmc = MCMC(
        kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
    )
    mcmc.run(random.PRNGKey(2), data, labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert set(samples.keys()) == {
        "nn/{}".format(bias_name),
        "nn/{}".format(weight_name),
    }
    assert_allclose(
        np.mean(samples["nn/{}".format(weight_name)].squeeze(-1), 0),
        true_coefs,
        atol=0.22,
    )
def test_init_strategy_substituted_model():
    """A site fixed via ``substitute`` should emit a "skipping initialization" warning."""
    def model():
        numpyro.sample("x", dist.Normal(0, 1))
        numpyro.sample("y", dist.Normal(0, 1))

    fixed_x_model = numpyro.handlers.substitute(model, data={"x": 10.0})
    sampler = MCMC(NUTS(fixed_x_model), num_warmup=10, num_samples=10)
    with pytest.warns(UserWarning, match="skipping initialization"):
        sampler.run(random.PRNGKey(1))
def benchmark_hmc(args, features, labels):
    """Time a no-warmup NUTS run of the module-level ``model``.

    The step size is set from the data size, ``sqrt(0.5 / n_rows)``. No warmup
    is run, so step-size adaptation never gets a chance to correct it.
    """
    step_size = np.sqrt(0.5 / features.shape[0])
    trajectory_length = step_size * args.num_steps
    rng_key = random.PRNGKey(1)
    start = time.time()
    # The computed step size used to be dropped, leaving the kernel at its
    # default step size. NOTE(review): NUTS presumably builds its own
    # trajectories, so trajectory_length is likely only honoured by HMC — confirm.
    kernel = NUTS(model, step_size=step_size, trajectory_length=trajectory_length)
    # sizes passed by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(kernel, num_warmup=0, num_samples=args.num_samples)
    mcmc.run(rng_key, features, labels)
    print('\nMCMC elapsed time:', time.time() - start)
def run_inference(self, args, rng_key, x_train, y_train, num_hidden):
    """Sample ``self.bnn_model`` with NUTS and return the posterior draws.

    ``args`` is a mapping with ``num_warmup``, ``num_samples`` and
    ``num_chains`` keys. For more than one chain the key is split into one
    key per chain before running.
    """
    num_chains = args['num_chains']
    if num_chains > 1:
        rng_key = random.split(rng_key, num_chains)
    kernel = NUTS(self.bnn_model)
    # sizes passed by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(
        kernel,
        num_warmup=args['num_warmup'],
        num_samples=args['num_samples'],
        num_chains=num_chains,
    )
    mcmc.run(rng_key, x_train, y_train, num_hidden)
    return mcmc.get_samples()
def run_inference(dept, male, applications, admit, rng_key, args):
    """Sample the module-level ``glmm`` with NUTS and return posterior draws.

    The progress bar is disabled when ``NUMPYRO_SPHINXBUILD`` is set in the
    environment.
    """
    kernel = NUTS(glmm)
    # All three sizes were positional (num_chains as the 4th argument); pass
    # them by keyword because recent NumPyro releases make them keyword-only.
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, dept, male, applications, admit)
    return mcmc.get_samples()
def run_inference(model, capture_history, sex, rng_key, args):
    """Sample ``model`` with NUTS or HMC, print a summary, return the draws.

    Args:
        model: NumPyro model taking ``capture_history`` and ``sex``.
        capture_history, sex: data forwarded to the model.
        rng_key: JAX PRNG key for sampling.
        args: namespace with ``algo`` ("NUTS" or "HMC"), ``num_warmup``,
            ``num_samples`` and ``num_chains``.

    Raises:
        ValueError: if ``args.algo`` is neither "NUTS" nor "HMC" (previously
            this left ``kernel`` unbound and raised UnboundLocalError).
    """
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    else:
        raise ValueError("unsupported algo: {!r} (expected 'NUTS' or 'HMC')".format(args.algo))
    # sizes passed by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, capture_history, sex)
    mcmc.print_summary()
    return mcmc.get_samples()
def test_trivial_dirichlet(batch_shape):
    """A one-category Dirichlet must always sample exactly 1."""
    def model():
        x = numpyro.sample("x", dist.Dirichlet(jnp.ones(1)).expand(batch_shape))
        return numpyro.sample("y", dist.Normal(x, 1), obs=2)

    num_samples = 10
    # sizes passed by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0))
    # because event_shape of x is (1,), x should only take value 1
    assert_allclose(mcmc.get_samples()["x"], jnp.ones((num_samples,) + batch_shape + (1,)))
def __init__(self, model, data=None, num_warmup=1000, num_samples=2000, num_chains=4):
    """Store ``data`` and build an MCMC sampler driven by NUTS on ``model``.

    Args:
        model: NumPyro model function passed to ``NUTS``.
        data: optional data, stored as ``self.data``.
        num_warmup: number of warmup iterations (default 1000, as before).
        num_samples: number of retained samples (default 2000, as before).
        num_chains: number of chains (default 4, as before).
    """
    # `data` used to be assigned twice; once is enough
    self.data = data
    self.num_warmup = num_warmup
    self.num_samples = num_samples
    self.num_chains = num_chains
    self.mcmc = MCMC(
        NUTS(model),
        num_warmup=self.num_warmup,
        num_samples=self.num_samples,
        num_chains=self.num_chains,
    )
def run_inference(model, at_bats, hits, rng_key, args):
    """Sample ``model`` with NUTS on ``at_bats``/``hits`` and return the draws.

    The progress bar is disabled when ``NUMPYRO_SPHINXBUILD`` is set in the
    environment.
    """
    kernel = NUTS(model)
    # sizes passed by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
def test_model_with_mask_false():
    """With its only likelihood masked out, ``x`` should follow its N(0, 1) prior."""
    def model():
        x = numpyro.sample("x", dist.Normal())
        with numpyro.handlers.mask(mask=False):
            numpyro.sample("y", dist.Normal(x), obs=1)

    sampler = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=1)
    sampler.run(random.PRNGKey(1))
    x_mean = sampler.get_samples()['x'].mean()
    assert_allclose(x_mean, 0., atol=0.1)
def test_mcmc_parallel_chain(deterministic):
    """Check the ``GLOBAL["count"]`` tally after a two-chain NUTS run.

    NOTE(review): assumes the module-level ``model`` increments
    ``GLOBAL["count"]`` each time it executes — confirm against its definition.
    """
    GLOBAL["count"] = 0
    # sizes passed by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100, num_chains=2)
    mcmc.run(random.PRNGKey(0), deterministic=deterministic)
    mcmc.get_samples()
    if deterministic:
        assert GLOBAL["count"] == 4
    else:
        assert GLOBAL["count"] == 3
def test_discrete_gibbs_multiple_sites():
    """DiscreteHMCGibbs must recover the means of Bernoulli and Binomial sites."""
    def model():
        numpyro.sample("x", dist.Bernoulli(0.7).expand([3]))
        numpyro.sample("y", dist.Binomial(10, 0.3))

    kernel = DiscreteHMCGibbs(NUTS(model))
    # sizes passed by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000, progress_bar=False)
    mcmc.run(random.PRNGKey(0))
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["x"], 0), 0.7 * jnp.ones(3), atol=0.01)
    assert_allclose(jnp.mean(samples["y"], 0), 0.3 * 10, atol=0.1)
def run_inference(model, args, rng_key):
    """Sample ``model`` with NUTS, print a summary including deterministic sites.

    The progress bar is disabled when ``NUMPYRO_SPHINXBUILD`` is set in the
    environment. Returns the posterior samples dict.
    """
    kernel = NUTS(model)
    # sizes passed by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key)
    mcmc.print_summary(exclude_deterministic=False)
    return mcmc.get_samples()
def run_hmc(rng_key, model, data, num_mix_comp, args, bvm_init_locs):
    """Run NUTS initialised at ``bvm_init_locs`` and return posterior samples.

    The tree depth is capped at 7. The model is called with ``data``,
    ``len(data)`` and ``num_mix_comp``.
    """
    init_strategy = init_to_value(values=bvm_init_locs)
    nuts = NUTS(model, init_strategy=init_strategy, max_tree_depth=7)
    sampler = MCMC(nuts, num_samples=args.num_samples, num_warmup=args.num_warmup)
    sampler.run(rng_key, data, len(data), num_mix_comp)
    sampler.print_summary()
    return sampler.get_samples()
def test_enum_subsample_smoke():
    """Smoke-test HMCECS on a model with a discrete site and a subsampled plate."""
    def model(data):
        x = numpyro.sample("x", dist.Bernoulli(0.5))
        with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-1):
            batch = numpyro.subsample(data, event_dim=0)
            numpyro.sample("obs", dist.Normal(x, 1), obs=batch)

    data = random.normal(random.PRNGKey(0), (10000, )) + 1
    kernel = HMCECS(NUTS(model), num_blocks=10)
    # sizes passed by keyword: recent NumPyro releases make them keyword-only
    mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
    mcmc.run(random.PRNGKey(0), data)