def run_inference(model, args, rng_key, X, Y):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary(exclude_deterministic=False)
    samples = mcmc.get_samples()
    summary_dict = summary(samples, group_by_chain=False)
    print("\nMCMC elapsed time:", time.time() - start)
    return summary_dict
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    true_mean, true_std = 1.0, 0.5
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * np.sum(((z - true_mean) / true_std) ** 2)

    init_params = np.array(0.0)
    kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    hmc_states = mcmc.get_samples()
    assert_allclose(np.mean(hmc_states), true_mean, rtol=0.05)
    assert_allclose(np.std(hmc_states), true_std, rtol=0.05)
    if 'JAX_ENABLE_X64' in os.environ:
        assert hmc_states.dtype == np.float64
def test_predictive_with_improper():
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with handlers.reparam(config={"loc": TransformReparam()}):
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(
                    dist.Uniform(0, 1).mask(False), AffineTransform(0, alpha)
                ),
            )
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    obs_pred = Predictive(model, samples)(random.PRNGKey(1), data=None)["obs"]
    assert_allclose(jnp.mean(obs_pred), true_coef, atol=0.05)
def test_improper_normal(max_tree_depth):
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(
                    dist.Uniform(0, 1).mask(False), AffineTransform(0, alpha)
                ),
            )
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    kernel = NUTS(model=model, max_tree_depth=max_tree_depth)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["loc"], 0), true_coef, atol=0.05)
def test_reuse_mcmc_run(jit_args, shape):
    y1 = np.random.normal(3, 0.1, (100,))
    y2 = np.random.normal(-3, 0.1, (shape,))

    def model(y_obs):
        mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.0))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    # Run MCMC on the first dataset.
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=300, num_samples=500, jit_model_args=jit_args)
    mcmc.run(random.PRNGKey(32), y1)

    # Re-run on new data; with jit_model_args=True this reuses the compiled
    # kernel and should be much faster.
    mcmc.run(random.PRNGKey(32), y2)
    assert_allclose(mcmc.get_samples()["mu"].mean(), -3.0, atol=0.1)
def test_predictive(parallel):
    model, data, true_probs = beta_bernoulli()
    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    predictive = Predictive(model, samples, parallel=parallel)
    predictive_samples = predictive(random.PRNGKey(1))
    assert predictive_samples.keys() == {"beta_sq", "obs"}

    predictive.return_sites = ["beta", "beta_sq", "obs"]
    predictive_samples = predictive(random.PRNGKey(1))
    # check shapes
    assert predictive_samples["beta"].shape == (100,) + true_probs.shape
    assert predictive_samples["beta_sq"].shape == (100,) + true_probs.shape
    assert predictive_samples["obs"].shape == (100,) + data.shape
    # check sample mean
    assert_allclose(
        predictive_samples["obs"].reshape((-1,) + true_probs.shape).mean(0),
        true_probs,
        rtol=0.1,
    )
def run_inference(model, args, rng_key, X, Y):
    start = time.time()
    # demonstrate how to use different HMC initialization strategies
    if args.init_strategy == "value":
        init_strategy = init_to_value(
            values={"kernel_var": 1.0, "kernel_noise": 0.05, "kernel_length": 0.5}
        )
    elif args.init_strategy == "median":
        init_strategy = init_to_median(num_samples=10)
    elif args.init_strategy == "feasible":
        init_strategy = init_to_feasible()
    elif args.init_strategy == "sample":
        init_strategy = init_to_sample()
    elif args.init_strategy == "uniform":
        init_strategy = init_to_uniform(radius=1)
    kernel = NUTS(model, init_strategy=init_strategy)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        thinning=args.thinning,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
class Sampler:
    def __init__(self, model, data=None):
        self.data = data
        self.num_warmup = 1000
        self.num_samples = 2000
        self.num_chains = 4
        self.mcmc = MCMC(
            NUTS(model),
            num_warmup=self.num_warmup,
            num_samples=self.num_samples,
            num_chains=self.num_chains,
        )

    def fit(self, data):
        self.data = data
        self.mcmc.run(random.PRNGKey(0), **data)
        self.post = self.mcmc.get_samples()
        return self.post  # posterior samples

    def predict(self, data):
        pass
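# A minimal usage sketch for the Sampler wrapper above. The model and data
# here are hypothetical stand-ins: fit() forwards the dict's keys as keyword
# arguments to the model, so any model whose argument names match would work.
def toy_model(y=None):
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("y", dist.Normal(mu, 1.0), obs=y)

posterior = Sampler(toy_model).fit({"y": np.random.normal(2.0, 1.0, (50,))})
print(posterior["mu"].mean())  # should be close to 2.0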
def test_discrete_gibbs_gmm_1d(modified, kernel, inner_kernel, kwargs):
    def model(probs, locs):
        c = numpyro.sample("c", dist.Categorical(probs))
        numpyro.sample("x", dist.Normal(locs[c], 0.5))

    probs = jnp.array([0.15, 0.3, 0.3, 0.25])
    locs = jnp.array([-2, 0, 2, 4])
    sampler = kernel(inner_kernel(model, trajectory_length=1.2), modified=modified, **kwargs)
    mcmc = MCMC(sampler, num_warmup=1000, num_samples=200000, progress_bar=False)
    mcmc.run(random.PRNGKey(0), probs, locs)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["x"]), 1.3, atol=0.1)
    assert_allclose(jnp.var(samples["x"]), 4.36, atol=0.4)
    assert_allclose(jnp.mean(samples["c"]), 1.65, atol=0.1)
    assert_allclose(jnp.var(samples["c"]), 1.03, atol=0.1)
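# A worked check (not part of the test) of the closed-form mixture moments
# asserted above, via the laws of total expectation and total variance:
probs = jnp.array([0.15, 0.3, 0.3, 0.25])
locs = jnp.array([-2.0, 0.0, 2.0, 4.0])
e_x = jnp.sum(probs * locs)                          # = 1.3
var_x = 0.5**2 + jnp.sum(probs * locs**2) - e_x**2   # = 4.36
cs = jnp.arange(4)
e_c = jnp.sum(probs * cs)                            # = 1.65
var_c = jnp.sum(probs * cs**2) - e_c**2              # = 1.0275, rounded to 1.03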
def test_beta_bernoulli():
    from numpyro.contrib.tfp import distributions as dist

    warmup_steps, num_samples = 500, 2000

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs)(rng_key=random.PRNGKey(1), sample_shape=(1000, 2))
    kernel = NUTS(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['p_latent'], 0), true_probs, atol=0.05)
def test_bernoulli_latent_model():
    @config_enumerate
    def model(data):
        y_prob = numpyro.sample("y_prob", dist.Beta(1.0, 1.0))
        with numpyro.plate("data", data.shape[0]):
            y = numpyro.sample("y", dist.Bernoulli(y_prob))
            z = numpyro.sample("z", dist.Bernoulli(0.65 * y + 0.1))
            numpyro.sample("obs", dist.Normal(2.0 * z, 1.0), obs=data)

    N = 2000
    y_prob = 0.3
    y = dist.Bernoulli(y_prob).sample(random.PRNGKey(0), (N,))
    z = dist.Bernoulli(0.65 * y + 0.1).sample(random.PRNGKey(1))
    data = dist.Normal(2.0 * z, 1.0).sample(random.PRNGKey(2))

    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()
    assert_allclose(samples["y_prob"].mean(0), y_prob, atol=0.05)
def test_uniform_normal():
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            # equivalent to Uniform(0, alpha); TransformReparam requires a
            # TransformedDistribution at the reparameterized site
            loc = numpyro.sample(
                'loc',
                dist.TransformedDistribution(
                    dist.Uniform(0, 1), AffineTransform(0, alpha)
                ),
            )
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    warmup_samples = mcmc.get_samples()
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()
    assert len(warmup_samples['loc']) == num_warmup
    assert len(samples['loc']) == num_samples
    assert_allclose(jnp.mean(samples['loc'], 0), true_coef, atol=0.05)
def test_beta_bernoulli_x64(kernel_cls):
    warmup_steps, num_samples = 500, 20000

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    kernel = kernel_cls(model=model, trajectory_length=1.)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.05)
    if 'JAX_ENABLE_X64' in os.environ:  # env var name is case-sensitive
        assert samples['p_latent'].dtype == np.float64
def sample(
    model,
    num_samples,
    num_warmup,
    num_chains=2,
    seed=0,
    chain_method="parallel",
    summary=True,
    **kwargs,
):
    """Run the No-U-Turn sampler.

    Args:
        model: a NumPyro model function
        num_samples: number of samples to draw in each chain
        num_warmup: number of warmup (tuning) steps per chain
        num_chains: number of chains to draw (default: {2})
        seed: random seed (default: {0})
        chain_method: one of NumPyro's sampling methods: "parallel",
            "sequential", or "vectorized" (default: {"parallel"})
        summary: print diagnostics, including the effective sample size
            and the Gelman-Rubin diagnostic (default: {True})
        **kwargs: other arguments to be passed to the model function

    Returns:
        mcmc: a fitted MCMC object
    """
    rng_key = random.PRNGKey(seed)
    kernel = NUTS(model)
    # Note: sampling more than one chain doesn't show a progress bar
    mcmc = MCMC(
        kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        chain_method=chain_method,
    )
    mcmc.run(rng_key, **kwargs)
    if summary:
        mcmc.print_summary()
    return mcmc
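# A hypothetical call of the sample() helper above; the model and data are
# stand-ins. `y` rides along through **kwargs to mcmc.run and is forwarded
# to the model function.
def simple_model(y=None):
    mu = numpyro.sample("mu", dist.Normal(0.0, 5.0))
    numpyro.sample("y", dist.Normal(mu, 1.0), obs=y)

fitted = sample(simple_model, num_samples=1000, num_warmup=500, num_chains=2,
                y=np.random.normal(0.5, 1.0, (100,)))
posterior = fitted.get_samples()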
def test_scan():
    def model(T=10, q=1, r=1, phi=0.0, beta=0.0):
        def transition(state, i):
            x0, mu0 = state
            x1 = numpyro.sample('x', dist.Normal(phi * x0, q))
            mu1 = beta * mu0 + x1
            y1 = numpyro.sample('y', dist.Normal(mu1, r))
            numpyro.deterministic('y2', y1 * 2)
            return (x1, mu1), (x1, y1)

        mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
        y0 = numpyro.sample('y_0', dist.Normal(mu0, r))
        _, xy = scan(transition, (x0, mu0), jnp.arange(T))
        x, y = xy
        return jnp.append(x0, x), jnp.append(y0, y)

    T = 10
    num_samples = 100
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples)
    mcmc.run(jax.random.PRNGKey(0), T=T)
    assert set(mcmc.get_samples()) == {'x', 'y', 'y2', 'x_0', 'y_0'}
    mcmc.print_summary()

    samples = mcmc.get_samples()
    x = samples.pop('x')[0]  # take 1 sample of x
    # this tests the composition of condition and substitute;
    # it also tests that we can use `vmap` for Predictive
    future = 5
    predictive = Predictive(
        numpyro.handlers.condition(model, {'x': x}),
        samples,
        return_sites=['x', 'y', 'y2'],
        parallel=True,
    )
    result = predictive(jax.random.PRNGKey(1), T=T + future)
    expected_shape = (num_samples, T + future)
    assert result['x'].shape == expected_shape
    assert result['y'].shape == expected_shape
    assert result['y2'].shape == expected_shape
    assert_allclose(result['x'][:, :T], jnp.broadcast_to(x, (num_samples, T)))
    assert_allclose(result['y'][:, :T], samples['y'])
def test_binomial_stable_x64(with_logits):
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    warmup_steps, num_samples = 200, 200

    def model(data):
        p = numpyro.sample('p', dist.Beta(1., 1.))
        if with_logits:
            logits = logit(p)
            numpyro.sample('obs', dist.Binomial(data['n'], logits=logits), obs=data['x'])
        else:
            numpyro.sample('obs', dist.Binomial(data['n'], probs=p), obs=data['x'])

    data = {'n': 5000000, 'x': 3849}
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['p'], 0), data['x'] / data['n'], rtol=0.05)
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p'].dtype == jnp.float64
def main(args):
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words, unsupervised_words) = simulate_data(
        random.PRNGKey(1),
        num_categories=args.num_categories,
        num_words=args.num_words,
        num_supervised_data=args.num_supervised,
        num_unsupervised_data=args.num_unsupervised,
    )

    print('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = NUTS(semi_supervised_hmm)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(rng_key, transition_prior, emission_prior, supervised_categories,
             supervised_words, unsupervised_words)
    samples = mcmc.get_samples()
    print('\nMCMC elapsed time:', time.time() - start)
    print_results(samples, transition_prob, emission_prob)
def test_correlated_mvn():
    # This requires dense mass matrix estimation.
    D = 5
    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.0
    a = jnp.tril(
        0.5 * jnp.fliplr(jnp.eye(D))
        + 0.1 * jnp.exp(random.normal(random.PRNGKey(0), shape=(D, D)))
    )
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    init_params = jnp.zeros(D)
    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02
def test_beta_bernoulli():
    from tensorflow_probability.substrates.jax import distributions as tfd

    num_warmup, num_samples = 500, 2000

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample("p_latent", tfd.Beta(alpha, beta))
        numpyro.sample("obs", tfd.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = tfd.Bernoulli(true_probs).sample(seed=random.PRNGKey(1), sample_shape=(1000, 2))
    kernel = NUTS(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p_latent"], 0), true_probs, atol=0.05)
def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
    svi_key, mcmc_key = random.split(hmcecs_key)

    # find reference parameters for second order taylor expansion
    # to estimate likelihood (taylor_proxy)
    optimizer = numpyro.optim.Adam(step_size=1e-3)
    guide = autoguide.AutoDelta(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    params, losses = svi.run(svi_key, args.num_svi_steps, data, obs, args.subsample_size)
    ref_params = {'theta': params['theta_auto_loc']}

    # taylor proxy estimates log likelihood (ll) around ref_params by
    #   taylor_expansion(ll, theta_curr)
    #     + sum_{i in subsample} [ll_i(theta_curr) - taylor_expansion(ll_i, theta_curr)]
    proxy = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(inner_kernel, num_blocks=args.num_blocks, proxy=proxy)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(mcmc_key, data, obs, args.subsample_size)
    mcmc.print_summary()
    return losses, mcmc.get_samples()
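# The `model` referenced in run_hmcecs is defined elsewhere. A sketch of a
# compatible definition, assuming a logistic-regression likelihood that uses
# numpyro.plate with subsample_size and numpyro.subsample (the pattern
# HMCECS's Taylor proxy targets); the site name "theta" matches the
# ref_params lookup above:
def model(data, obs, subsample_size):
    n, dim = data.shape
    theta = numpyro.sample("theta", dist.Normal(jnp.zeros(dim), jnp.ones(dim)).to_event(1))
    with numpyro.plate("N", n, subsample_size=subsample_size):
        batch = numpyro.subsample(data, event_dim=1)
        batch_obs = numpyro.subsample(obs, event_dim=0)
        numpyro.sample("obs", dist.Bernoulli(logits=batch @ theta), obs=batch_obs)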
def test_logistic_regression_x64(kernel_cls):
    N, dim = 3000, 3
    warmup_steps, num_samples = 1000, 8000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = kernel_cls(model=model, trajectory_length=8)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.22)
    if 'JAX_ENABLE_X64' in os.environ:  # env var name is case-sensitive
        assert samples['coefs'].dtype == np.float64
def test_mcmc_progbar():
    true_mean, true_std = 1., 2.
    num_warmup, num_samples = 10, 10

    def model(data):
        mean = numpyro.param('mean', 0.)
        std = numpyro.param('std', 1., constraint=constraints.positive)
        return numpyro.sample('obs', dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(2), data)
    mcmc.run(random.PRNGKey(3), data)

    mcmc1 = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False)
    mcmc1.run(random.PRNGKey(2), data)
    # a different seed should give different samples
    with pytest.raises(AssertionError):
        check_close(mcmc1.get_samples(), mcmc.get_samples(), atol=1e-4, rtol=1e-4)

    # with matching seeds, results should agree with or without the progress bar
    mcmc1.warmup(random.PRNGKey(2), data)
    mcmc1.run(random.PRNGKey(3), data)
    check_close(mcmc1.get_samples(), mcmc.get_samples(), atol=1e-4, rtol=1e-4)
    check_close(mcmc1._warmup_state, mcmc._warmup_state, atol=1e-4, rtol=1e-4)
def test_inference_data_constant_data(self):
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    x1 = 10
    x2 = 12
    y1 = np.random.randn(10)

    def model_constant_data(x, y1=None):
        _x = numpyro.sample("x", dist.Normal(1, 3))
        numpyro.sample("y1", dist.Normal(x * _x, 1), obs=y1)

    nuts_kernel = NUTS(model_constant_data)
    mcmc = MCMC(nuts_kernel, num_samples=10, num_warmup=2)
    mcmc.run(PRNGKey(0), x=x1, y1=y1)
    posterior = mcmc.get_samples()
    posterior_predictive = Predictive(model_constant_data, posterior)(PRNGKey(1), x1)
    predictions = Predictive(model_constant_data, posterior)(PRNGKey(2), x2)
    inference_data = from_numpyro(
        mcmc,
        posterior_predictive=posterior_predictive,
        predictions=predictions,
        constant_data={"x1": x1},
        predictions_constant_data={"x2": x2},
    )
    test_dict = {
        "posterior": ["x"],
        "posterior_predictive": ["y1"],
        "sample_stats": ["diverging"],
        "log_likelihood": ["y1"],
        "predictions": ["y1"],
        "observed_data": ["y1"],
        "constant_data": ["x1"],
        "predictions_constant_data": ["x2"],
    }
    fails = check_multiple_attrs(test_dict, inference_data)
    assert not fails
def test_structured_mass():
    def model(cov):
        w = numpyro.sample("w", dist.Normal(0, 1000).expand([2]).to_event(1))
        x = numpyro.sample("x", dist.Normal(0, 1000).expand([1]).to_event(1))
        y = numpyro.sample("y", dist.Normal(0, 1000).expand([1]).to_event(1))
        z = numpyro.sample("z", dist.Normal(0, 1000).expand([1]).to_event(1))
        wxyz = jnp.concatenate([w, x, y, z])
        numpyro.sample("obs", dist.MultivariateNormal(jnp.zeros(5), cov), obs=wxyz)

    w_cov = np.array([[1.5, 0.5], [0.5, 1.5]])
    xy_cov = np.array([[2.0, 1.0], [1.0, 3.0]])
    z_var = np.array([2.5])
    cov = np.zeros((5, 5))
    cov[:2, :2] = w_cov
    cov[2:4, 2:4] = xy_cov
    cov[4, 4] = z_var

    kernel = NUTS(model, dense_mass=[("w",), ("x", "y")])
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1)
    mcmc.run(random.PRNGKey(1), cov)
    inverse_mass_matrix = mcmc.last_state.adapt_state.inverse_mass_matrix
    assert_allclose(inverse_mass_matrix[("w",)], w_cov, atol=0.5, rtol=0.5)
    assert_allclose(inverse_mass_matrix[("x", "y")], xy_cov, atol=0.5, rtol=0.5)
    assert_allclose(inverse_mass_matrix[("z",)], z_var, atol=0.5, rtol=0.5)

    # swapping the order of sites within a dense block permutes the block accordingly
    kernel = NUTS(model, dense_mass=[("w",), ("y", "x")])
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1)
    mcmc.run(random.PRNGKey(1), cov)
    inverse_mass_matrix = mcmc.last_state.adapt_state.inverse_mass_matrix
    assert_allclose(inverse_mass_matrix[("w",)], w_cov, atol=0.5, rtol=0.5)
    assert_allclose(inverse_mass_matrix[("y", "x")], xy_cov[::-1, ::-1], atol=0.5, rtol=0.5)
    assert_allclose(inverse_mass_matrix[("z",)], z_var, atol=0.5, rtol=0.5)
def test_model_with_multiple_exec_paths(jit_args):
    def model(a=None, b=None, z=None):
        int_term = numpyro.sample("a", dist.Normal(0.0, 0.2))
        x_term, y_term = 0.0, 0.0
        if a is not None:
            x = numpyro.sample("x", dist.HalfNormal(0.5))
            x_term = a * x
        if b is not None:
            y = numpyro.sample("y", dist.HalfNormal(0.5))
            y_term = b * y
        sigma = numpyro.sample("sigma", dist.Exponential(1.0))
        mu = int_term + x_term + y_term
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=z)

    a = jnp.exp(np.random.randn(10))
    b = jnp.exp(np.random.randn(10))
    z = np.random.randn(10)

    # Each execution path (which arguments are None) yields a different set of latent sites.
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=20, num_samples=10, jit_model_args=jit_args)
    mcmc.run(random.PRNGKey(1), a, b=None, z=z)
    assert set(mcmc.get_samples()) == {"a", "x", "sigma"}
    mcmc.run(random.PRNGKey(2), a=None, b=b, z=z)
    assert set(mcmc.get_samples()) == {"a", "y", "sigma"}
    mcmc.run(random.PRNGKey(3), a=a, b=b, z=z)
    assert set(mcmc.get_samples()) == {"a", "x", "y", "sigma"}
def test_mcmc_progbar():
    true_mean, true_std = 1.0, 2.0
    num_warmup, num_samples = 10, 10

    def model(data):
        mean = numpyro.sample("mean", dist.Normal(0, 1).mask(False))
        std = numpyro.sample("std", dist.LogNormal(0, 1).mask(False))
        return numpyro.sample("obs", dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(2), data)
    mcmc.run(random.PRNGKey(3), data)

    mcmc1 = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False)
    mcmc1.run(random.PRNGKey(2), data)
    with pytest.raises(AssertionError):
        check_close(mcmc1.get_samples(), mcmc.get_samples(), atol=1e-4, rtol=1e-4)

    mcmc1.warmup(random.PRNGKey(2), data)
    mcmc1.run(random.PRNGKey(3), data)
    check_close(mcmc1.get_samples(), mcmc.get_samples(), atol=1e-4, rtol=1e-4)
    check_close(mcmc1.post_warmup_state, mcmc.post_warmup_state, atol=1e-4, rtol=1e-4)
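# A sketch of the warmup-reuse pattern checked above, reusing the names from
# the test: after a warmup() or run(), post_warmup_state can be assigned back
# so that a later run() continues sampling without repeating adaptation.
mcmc = MCMC(NUTS(model), num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(random.PRNGKey(0), data)
mcmc.post_warmup_state = mcmc.last_state
mcmc.run(mcmc.post_warmup_state.rng_key, data)  # continues from the last state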
def test_compile_warmup_run(num_chains, chain_method, progress_bar):
    def model():
        numpyro.sample("x", dist.Normal(0, 1))

    if num_chains == 1 and chain_method in ['sequential', 'vectorized']:
        pytest.skip('duplicated test')
    if num_chains > 1 and chain_method == 'parallel':
        pytest.skip('duplicated test')

    rng_key = random.PRNGKey(0)
    num_samples = 10
    mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples,
                num_chains=num_chains, chain_method=chain_method,
                progress_bar=progress_bar)

    mcmc.run(rng_key)
    expected_samples = mcmc.get_samples()["x"]

    mcmc._compile(rng_key)
    # no delay after compiling
    mcmc.warmup(rng_key)
    mcmc.run(mcmc._warmup_state.rng_key)
    actual_samples = mcmc.get_samples()["x"]
    assert_allclose(actual_samples, expected_samples)

    # test for reproducibility
    if num_chains > 1:
        mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples,
                    num_chains=1, progress_bar=progress_bar)
        rng_key = random.split(rng_key)[0]
        mcmc.run(rng_key)
        first_chain_samples = mcmc.get_samples()["x"]
        assert_allclose(actual_samples[:num_samples], first_chain_samples, atol=1e-5)
def main(args):
    _, fetch = load_dataset(LYNXHARE, shuffle=False)
    year, data = fetch()  # data is in hare -> lynx order

    # use dense_mass for better mixing rate
    mcmc = MCMC(
        NUTS(model, dense_mass=True),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(PRNGKey(1), N=data.shape[0], y=data)
    mcmc.print_summary()

    # predict populations
    pop_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])["y"]
    mu = jnp.mean(pop_pred, 0)
    pi = jnp.percentile(pop_pred, jnp.array([10, 90]), 0)
    plt.figure(figsize=(8, 6), constrained_layout=True)
    plt.plot(year, data[:, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67)
    plt.plot(year, data[:, 1], "bx", label="true lynx")
    plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
    plt.plot(year, mu[:, 1], "b--", label="pred lynx")
    plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2)
    plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3)
    plt.gca().set(ylim=(0, 160), xlabel="year", ylabel="population (in thousands)")
    plt.title("Posterior predictive (80% CI) with predator-prey pattern.")
    plt.legend()
    plt.savefig("ode_plot.pdf")
def test_gaussian_mixture_model():
    K, N = 3, 1000

    def gmm(data):
        mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
        with numpyro.plate("num_clusters", K, dim=-1):
            cluster_means = numpyro.sample("cluster_means", dist.Normal(jnp.arange(K), 1.0))
        with numpyro.plate("data", data.shape[0], dim=-1):
            assignments = numpyro.sample("assignments", dist.Categorical(mix_proportions))
            numpyro.sample("obs", dist.Normal(cluster_means[assignments], 1.0), obs=data)

    true_cluster_means = jnp.array([1.0, 5.0, 10.0])
    true_mix_proportions = jnp.array([0.1, 0.3, 0.6])
    cluster_assignments = dist.Categorical(true_mix_proportions).sample(random.PRNGKey(0), (N,))
    data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample(random.PRNGKey(1))

    nuts_kernel = NUTS(gmm)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(samples["phi"].mean(0).sort(), true_mix_proportions, atol=0.05)
    assert_allclose(samples["cluster_means"].mean(0).sort(), true_cluster_means, atol=0.2)
def test_logistic_regression():
    from numpyro.contrib.tfp import distributions as dist

    N, dim = 3000, 3
    num_warmup, num_samples = 1000, 1000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits)(rng_key=random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic('logits', jnp.sum(coefs * data, axis=-1))
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples['logits'].shape == (num_samples, N)
    assert_allclose(jnp.mean(samples['coefs'], 0), true_coefs, atol=0.22)