def test_predictive_with_improper():
    """Check ``Predictive`` on a model with a masked, reparameterized site.

    Runs NUTS on ``true_coef`` plus ``random.normal`` noise, then asserts
    that the mean of the predicted ``obs`` site is within 0.05 of
    ``true_coef``.
    """
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with handlers.reparam(config={'loc': TransformReparam()}):
            # The base distribution is masked, so it adds no log-density term.
            masked_base = dist.Uniform(0, 1).mask(False)
            loc = numpyro.sample(
                'loc',
                dist.TransformedDistribution(masked_base,
                                             AffineTransform(0, alpha)))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000, ))
    data = true_coef + noise

    mcmc = MCMC(NUTS(model=model), num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    posterior = mcmc.get_samples()

    predictive = Predictive(model, posterior)
    obs_pred = predictive(random.PRNGKey(1), data=None)["obs"]
    assert_allclose(jnp.mean(obs_pred), true_coef, atol=0.05)
def main(args):
    """Simulate HMM data, run NUTS on ``semi_supervised_hmm`` and plot results.

    Args:
        args: namespace providing ``num_categories``, ``num_words``,
            ``num_supervised``, ``num_unsupervised``, ``num_warmup``,
            ``num_samples``, ``num_chains`` and ``unroll_loop``.

    Side effects:
        Prints progress and results, and writes ``hmm_plot.pdf``.
    """
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words,
     unsupervised_words) = simulate_data(
         random.PRNGKey(1),
         num_categories=args.num_categories,
         num_words=args.num_words,
         num_supervised_data=args.num_supervised,
         num_unsupervised_data=args.num_unsupervised,
     )

    print('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = NUTS(semi_supervised_hmm)
    # Sampler settings are passed by keyword: newer NumPyro releases make
    # them keyword-only, and the keyword names also work on older releases.
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set.
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, transition_prior, emission_prior,
             supervised_categories, supervised_words, unsupervised_words,
             args.unroll_loop)
    samples = mcmc.get_samples()
    print_results(samples, transition_prob, emission_prob)
    print('\nMCMC elapsed time:', time.time() - start)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

    x = np.linspace(0, 1, 101)
    num_rows, num_cols = transition_prob.shape[0], transition_prob.shape[1]
    for i in range(num_rows):
        for j in range(num_cols):
            # One KDE curve per transition-matrix entry.
            ax.plot(x,
                    gaussian_kde(samples['transition_prob'][:, i, j])(x),
                    label="trans_prob[{}, {}], true value = {:.2f}".format(
                        i, j, transition_prob[i, j]))
    ax.set(xlabel="Probability",
           ylabel="Frequency",
           title="Transition probability posterior")
    ax.legend()

    plt.savefig("hmm_plot.pdf")
def run_inference(design_matrix: jnp.ndarray,
                  outcome: jnp.ndarray,
                  rng_key: jnp.ndarray,
                  num_warmup: int,
                  num_samples: int,
                  num_chains: int,
                  interval_size: float = 0.95) -> None:
    """Estimate the effect size.

    Runs NUTS on ``model`` with the given design matrix and outcome, then
    passes the sampled ``coefficients`` site to ``print_results``.

    Args:
        design_matrix: first positional argument forwarded to ``model``.
        outcome: second positional argument forwarded to ``model``.
        rng_key: random key given to ``mcmc.run``.
        num_warmup: number of warmup iterations.
        num_samples: number of posterior samples.
        num_chains: number of MCMC chains.
        interval_size: forwarded to ``print_results``.
    """
    kernel = NUTS(model)
    # Sampler settings are passed by keyword: newer NumPyro releases make
    # them keyword-only, and the keyword names also work on older releases.
    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                num_chains=num_chains,
                # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set.
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, design_matrix, outcome)

    # 0th column is intercept (not getting called)
    # 1st column is effect of getting called
    # 2nd column is effect of gender (should be none since assigned at random)
    coef = mcmc.get_samples()['coefficients']
    print_results(coef, interval_size)
def test_prior_with_sample_shape():
    """Check that ``sample_shape`` on a latent site shapes the MCMC samples.

    The ``theta`` site is drawn with ``sample_shape=(J,)``; after sampling,
    its samples must have shape ``(num_samples, J)``.
    """
    data = {
        "J": 8,
        "y": jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
        "sigma": jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
    }

    def schools_model():
        mu = numpyro.sample('mu', dist.Normal(0, 5))
        tau = numpyro.sample('tau', dist.HalfCauchy(5))
        theta = numpyro.sample('theta',
                               dist.Normal(mu, tau),
                               sample_shape=(data['J'], ))
        numpyro.sample('obs',
                       dist.Normal(theta, data['sigma']),
                       obs=data['y'])

    num_samples = 500
    kernel = NUTS(schools_model)
    mcmc = MCMC(kernel, num_warmup=500, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0))

    theta_draws = mcmc.get_samples()['theta']
    assert theta_draws.shape == (num_samples, data['J'])
def test_improper_normal(max_tree_depth):
    """Check NUTS recovers ``loc`` for a masked, reparameterized site.

    Args:
        max_tree_depth: forwarded to the ``NUTS`` kernel.
    """
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            # The base distribution is masked, so it adds no log-density term.
            masked_base = dist.Uniform(0, 1).mask(False)
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(masked_base,
                                             AffineTransform(0, alpha)),
            )
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000, ))
    data = true_coef + noise

    mcmc = MCMC(NUTS(model=model, max_tree_depth=max_tree_depth),
                num_warmup=1000,
                num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)

    posterior = mcmc.get_samples()
    loc_mean = jnp.mean(posterior["loc"], 0)
    assert_allclose(loc_mean, true_coef, atol=0.05)
def test_bernoulli_latent_model():
    """Check NUTS recovers ``y_prob`` in a model with Bernoulli latent sites.

    Data are simulated from the same generative structure as the model with
    ``y_prob = 0.3``; the posterior mean of ``y_prob`` must be within 0.05.
    """

    def model(data):
        y_prob = numpyro.sample("y_prob", dist.Beta(1., 1.))
        with numpyro.plate("data", data.shape[0]):
            y = numpyro.sample("y", dist.Bernoulli(y_prob))
            z = numpyro.sample("z", dist.Bernoulli(0.65 * y + 0.1))
            numpyro.sample("obs", dist.Normal(2. * z, 1.), obs=data)

    num_points = 2000
    true_y_prob = 0.3

    # Simulate the latent chain y -> z -> data with separate keys.
    latent_y = dist.Bernoulli(true_y_prob).sample(random.PRNGKey(0),
                                                  (num_points, ))
    latent_z = dist.Bernoulli(0.65 * latent_y + 0.1).sample(random.PRNGKey(1))
    data = dist.Normal(2. * latent_z, 1.0).sample(random.PRNGKey(2))

    mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(3), data)

    posterior = mcmc.get_samples()
    assert_allclose(posterior["y_prob"].mean(0), true_y_prob, atol=0.05)
def test_initial_inverse_mass_matrix(dense_mass):
    """Check a user-supplied inverse mass matrix is kept without adaptation.

    Args:
        dense_mass: forwarded to ``NUTS``; when truthy, the expected matrix
            for site ``x`` is the diagonal matrix built from the supplied
            vector.
    """

    def model():
        numpyro.sample("x", dist.Normal(0, 1).expand([3]))
        numpyro.sample("z", dist.Normal(0, 1).expand([2]))

    x_key, z_key = ("x", ), ("z", )
    supplied_mm = jnp.arange(1, 4.0)

    kernel = NUTS(
        model,
        dense_mass=dense_mass,
        inverse_mass_matrix={x_key: supplied_mm},
        adapt_mass_matrix=False,
    )
    mcmc = MCMC(kernel, num_warmup=1, num_samples=1)
    mcmc.run(random.PRNGKey(0))

    inverse_mass_matrix = mcmc.last_state.adapt_state.inverse_mass_matrix
    assert set(inverse_mass_matrix.keys()) == {x_key, z_key}

    if dense_mass:
        expected_x = jnp.diag(supplied_mm)
    else:
        expected_x = supplied_mm
    assert_allclose(inverse_mass_matrix[x_key], expected_x)
    # Site "z" was not given a matrix, so it is expected to be all ones.
    assert_allclose(inverse_mass_matrix[z_key], jnp.ones(2))
def test_predictive(parallel):
    """Check ``Predictive`` return sites, shapes and sample mean.

    Args:
        parallel: forwarded to ``Predictive``.
    """
    model, data, true_probs = beta_bernoulli()

    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    mcmc.run(random.PRNGKey(0), data)
    posterior = mcmc.get_samples()

    predictive = Predictive(model, posterior, parallel=parallel)
    predictive_samples = predictive(random.PRNGKey(1))
    assert predictive_samples.keys() == {"beta_sq", "obs"}

    # Request the latent site explicitly and predict again.
    predictive.return_sites = ["beta", "beta_sq", "obs"]
    predictive_samples = predictive(random.PRNGKey(1))

    # check shapes
    prob_shape = (100, ) + true_probs.shape
    assert predictive_samples["beta"].shape == prob_shape
    assert predictive_samples["beta_sq"].shape == prob_shape
    assert predictive_samples["obs"].shape == (100, ) + data.shape

    # check sample mean
    flat_obs = predictive_samples["obs"].reshape((-1, ) + true_probs.shape)
    obs = flat_obs.astype(np.float32)
    assert_allclose(obs.mean(0), true_probs, rtol=0.1)
def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
    """Check a Dirichlet-Categorical posterior mean and the x64 sample dtype.

    Args:
        kernel_cls: kernel class; called with ``model``,
            ``trajectory_length`` and ``dense_mass``.
        dense_mass: forwarded to the kernel.
    """
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000, ))
    kernel = kernel_cls(model, trajectory_length=1., dense_mass=dense_mass)
    # Sampler settings are passed by keyword: newer NumPyro releases make
    # them keyword-only, and the keyword names also work on older releases.
    mcmc = MCMC(kernel,
                num_warmup=warmup_steps,
                num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['p_latent'], 0), true_probs, atol=0.02)

    # The dtype is only checked when the x64 environment variable is set.
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == jnp.float64
def test_random_module_mcmc(backend, init):
    """Check NUTS over a one-layer linear module wrapped as a random module.

    Args:
        backend: ``"flax"`` or ``"haiku"``; selects the module library and
            the parameter names used as prior keys.
        init: ``"shape"`` to initialise from ``input_shape`` or ``"kwargs"``
            to initialise from the data passed under the backend's
            input keyword.
    """
    if backend == "flax":
        import flax

        linear_module = flax.linen.Dense(features=1)
        bias_name, weight_name = "bias", "kernel"
        random_module = random_flax_module
        kwargs_name = "inputs"
    elif backend == "haiku":
        import haiku as hk

        linear_module = hk.transform(lambda x: hk.Linear(1)(x))
        bias_name, weight_name = "linear.b", "linear.w"
        random_module = random_haiku_module
        kwargs_name = "x"

    N, dim = 3000, 3
    num_warmup, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1.0, dim + 1.0)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    if init == "shape":
        kwargs = {"input_shape": (3,)}
    elif init == "kwargs":
        kwargs = {kwargs_name: data}

    bias_site = "nn/{}".format(bias_name)
    weight_site = "nn/{}".format(weight_name)

    def model(data, labels):
        priors = {bias_name: dist.Cauchy(), weight_name: dist.Normal()}
        nn = random_module("nn", linear_module, priors, **kwargs)
        logits = nn(data).squeeze(-1)
        numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels)

    mcmc = MCMC(
        NUTS(model=model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(2), data, labels)
    mcmc.print_summary()

    samples = mcmc.get_samples()
    assert set(samples.keys()) == {bias_site, weight_site}

    weight_mean = np.mean(samples[weight_site].squeeze(-1), 0)
    assert_allclose(weight_mean, true_coefs, atol=0.22)
def run_inference(model, args, rng_key, X, Y):
    """Run NUTS on ``model`` with inputs ``X`` and ``Y`` and return samples.

    Args:
        model: model callable handed to ``NUTS``.
        args: namespace providing ``num_warmup``, ``num_samples`` and
            ``num_chains``.
        rng_key: random key given to ``mcmc.run``.
        X: first positional model argument.
        Y: second positional model argument.

    Returns:
        The result of ``mcmc.get_samples()``.

    Side effects:
        Prints the MCMC summary and the elapsed wall-clock time.
    """
    start = time.time()
    kernel = NUTS(model)
    # Sampler settings are passed by keyword: newer NumPyro releases make
    # them keyword-only, and the keyword names also work on older releases.
    mcmc = MCMC(kernel,
                num_warmup=args.num_warmup,
                num_samples=args.num_samples,
                num_chains=args.num_chains,
                # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set.
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
def init_mcmc(
    self,
    model,
    num_warmup: int = 1000,
    num_samples: int = 1000,
    num_chains: int = 1,
    sampler="NUTS",
    sampler_kwargs=None,
    **kwargs,
) -> MCMC:
    """Initialise the MCMC sampler.

    Args:
        model (callable): Model handed to the NUTS kernel when ``sampler``
            is given as a string.
        num_warmup (int): Number of warmup iterations. Defaults to 1000.
        num_samples (int): Number of samples to draw. Defaults to 1000.
        num_chains (int): Number of chains. Defaults to 1.
        sampler (str, or numpyro.infer.mcmc.MCMCKernel): Choose one of
            ['NUTS'] (case-insensitive), or pass a numpyro mcmc kernel,
            which is then used as-is.
        sampler_kwargs (dict, optional): Keyword arguments to pass to the
            chosen sampler. Only used when ``sampler`` is a string. The
            dict is copied, never modified. Missing entries default to
            ``target_accept_prob=0.98``, ``step_size=0.1`` and an
            ``init_to_median`` init strategy with 100 samples.
        **kwargs: Keyword arguments to pass to mcmc instance.

    Returns:
        MCMC: The configured (not yet run) MCMC instance.

    Raises:
        ValueError: If ``sampler`` is a string other than 'NUTS'.
    """
    self._mcmc_support_warnings()

    if isinstance(sampler, str):
        sampler = sampler.lower()
        if sampler != "nuts":
            raise ValueError(f"Sampler '{sampler}' not supported.")

        # Work on a private copy so the caller's dict is never mutated
        # and no default dict is shared between calls.
        nuts_kwargs = dict(sampler_kwargs) if sampler_kwargs else {}
        nuts_kwargs.setdefault("target_accept_prob", 0.98)
        nuts_kwargs.setdefault(
            "init_strategy",
            lambda site=None: init_to_median(site=site, num_samples=100),
        )
        nuts_kwargs.setdefault("step_size", 0.1)
        sampler = NUTS(model, **nuts_kwargs)

    return MCMC(
        sampler,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        **kwargs,
    )
class Sampler:
    """Small wrapper that runs NUTS through ``MCMC`` with fixed settings.

    The sampler is configured with 1000 warmup iterations, 2000 samples
    and 4 chains.
    """

    def __init__(self, model, data=None):
        """Build the MCMC object for ``model`` and optionally store ``data``."""
        self.data = data
        self.num_warmup = 1000
        self.num_samples = 2000
        self.num_chains = 4
        self.mcmc = MCMC(NUTS(model),
                         num_warmup=self.num_warmup,
                         num_samples=self.num_samples,
                         num_chains=self.num_chains)

    def fit(self, data):
        """Run MCMC and return the posterior samples.

        Args:
            data: mapping unpacked as keyword arguments to ``mcmc.run``.

        Returns:
            The result of ``mcmc.get_samples()``, also stored on
            ``self.post``.
        """
        self.data = data
        # The random key is fixed at PRNGKey(0) for every fit.
        self.mcmc.run(random.PRNGKey(0), **data)
        self.post = self.mcmc.get_samples()
        return self.post  # posterior samples

    def predict(self, data):
        """Placeholder; not implemented and currently returns ``None``."""
        pass
def test_discrete_gibbs_gmm_1d(modified, kernel, inner_kernel, kwargs):
    """Check moments of a 1-D Gaussian mixture sampled with a Gibbs kernel.

    Args:
        modified: forwarded to ``kernel``.
        kernel: outer kernel class wrapping the inner kernel.
        inner_kernel: inner kernel class, built with
            ``trajectory_length=1.2``.
        kwargs: extra keyword arguments for ``kernel``.
    """

    def model(probs, locs):
        c = numpyro.sample("c", dist.Categorical(probs))
        numpyro.sample("x", dist.Normal(locs[c], 0.5))

    probs = jnp.array([0.15, 0.3, 0.3, 0.25])
    locs = jnp.array([-2, 0, 2, 4])

    inner = inner_kernel(model, trajectory_length=1.2)
    sampler = kernel(inner, modified=modified, **kwargs)
    mcmc = MCMC(sampler,
                num_warmup=1000,
                num_samples=200000,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0), probs, locs)

    samples = mcmc.get_samples()
    x_draws, c_draws = samples["x"], samples["c"]
    assert_allclose(jnp.mean(x_draws), 1.3, atol=0.1)
    assert_allclose(jnp.var(x_draws), 4.36, atol=0.4)
    assert_allclose(jnp.mean(c_draws), 1.65, atol=0.1)
    assert_allclose(jnp.var(c_draws), 1.03, atol=0.1)
def test_beta_bernoulli_x64(kernel_cls):
    """Check a Beta-Bernoulli posterior mean and the x64 sample dtype.

    Args:
        kernel_cls: kernel class; called with ``model`` and
            ``trajectory_length``.
    """
    warmup_steps, num_samples = 500, 20000

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    kernel = kernel_cls(model=model, trajectory_length=1.)
    mcmc = MCMC(kernel,
                num_warmup=warmup_steps,
                num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.05)

    # The variable name is upper-case 'JAX_ENABLE_X64'; with the mixed-case
    # spelling this dtype check could never run.
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
def main(args):
    """Simulate HMM data, run NUTS on ``semi_supervised_hmm`` and print results.

    Args:
        args: namespace providing ``num_categories``, ``num_words``,
            ``num_supervised``, ``num_unsupervised``, ``num_warmup`` and
            ``num_samples``.

    Side effects:
        Prints progress, the results and the elapsed wall-clock time.
    """
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words,
     unsupervised_words) = simulate_data(
         random.PRNGKey(1),
         num_categories=args.num_categories,
         num_words=args.num_words,
         num_supervised_data=args.num_supervised,
         num_unsupervised_data=args.num_unsupervised,
     )

    print('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = NUTS(semi_supervised_hmm)
    # Sampler settings are passed by keyword: newer NumPyro releases make
    # them keyword-only, and the keyword names also work on older releases.
    mcmc = MCMC(kernel,
                num_warmup=args.num_warmup,
                num_samples=args.num_samples)
    mcmc.run(rng_key, transition_prior, emission_prior,
             supervised_categories, supervised_words, unsupervised_words)
    samples = mcmc.get_samples()
    print_results(samples, transition_prob, emission_prob)
    print('\nMCMC elapsed time:', time.time() - start)
def test_correlated_mvn():
    """Check NUTS with a dense mass matrix on a correlated Gaussian target.

    The target is given as a potential function built from the precision
    matrix of a random covariance; the sample mean and the sample
    covariance are compared against the true values.
    """
    # This requires dense mass matrix estimation.
    D = 5

    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.
    a = jnp.tril(0.5 * jnp.fliplr(jnp.eye(D)) +
                 0.1 * jnp.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        # Quadratic form 0.5 * z^T P z with P the true precision matrix.
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    init_params = jnp.zeros(D)
    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    # Sampler settings are passed by keyword: newer NumPyro releases make
    # them keyword-only, and the keyword names also work on older releases.
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    # Mean absolute error over all D*D covariance entries.
    assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02
def test_binomial_stable_x64(with_logits):
    """Check Binomial inference with a very large total count.

    Args:
        with_logits: when truthy, the Binomial is parameterized by
            ``logit(p)``; otherwise by ``probs=p``.
    """
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    warmup_steps, num_samples = 200, 200

    def model(data):
        p = numpyro.sample('p', dist.Beta(1., 1.))
        if with_logits:
            logits = logit(p)
            numpyro.sample('obs',
                           dist.Binomial(data['n'], logits=logits),
                           obs=data['x'])
        else:
            numpyro.sample('obs',
                           dist.Binomial(data['n'], probs=p),
                           obs=data['x'])

    data = {'n': 5000000, 'x': 3849}
    kernel = NUTS(model=model)
    # Sampler settings are passed by keyword: newer NumPyro releases make
    # them keyword-only, and the keyword names also work on older releases.
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()

    assert_allclose(jnp.mean(samples['p'], 0),
                    data['x'] / data['n'],
                    rtol=0.05)

    # The dtype is only checked when the x64 environment variable is set.
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p'].dtype == jnp.float64
def run_inference(model, capture_history, sex, rng_key, args):
    """Run NUTS or HMC on ``model`` and return the posterior samples.

    Args:
        model: model callable handed to the kernel.
        capture_history: first positional model argument.
        sex: second positional model argument.
        rng_key: random key given to ``mcmc.run``.
        args: namespace providing ``algo`` ("NUTS" or "HMC"),
            ``num_warmup``, ``num_samples`` and ``num_chains``.

    Returns:
        The result of ``mcmc.get_samples()``.

    Raises:
        ValueError: If ``args.algo`` is neither "NUTS" nor "HMC".
    """
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    else:
        # Without this branch `kernel` would be unbound below and the
        # failure would surface as an UnboundLocalError.
        raise ValueError(
            "Unsupported algo {!r}; expected 'NUTS' or 'HMC'.".format(
                args.algo))
    # Sampler settings are passed by keyword: newer NumPyro releases make
    # them keyword-only, and the keyword names also work on older releases.
    mcmc = MCMC(kernel,
                num_warmup=args.num_warmup,
                num_samples=args.num_samples,
                num_chains=args.num_chains,
                # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set.
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, capture_history, sex)
    mcmc.print_summary()
    return mcmc.get_samples()
def test_logistic_regression_x64(kernel_cls):
    """Check logistic-regression coefficients and the x64 sample dtype.

    Args:
        kernel_cls: kernel class; called with ``model`` and
            ``trajectory_length``.
    """
    N, dim = 3000, 3
    warmup_steps, num_samples = 1000, 8000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs',
                               dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = kernel_cls(model=model, trajectory_length=8)
    # Sampler settings are passed by keyword: newer NumPyro releases make
    # them keyword-only, and the keyword names also work on older releases.
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.22)

    # The variable name is upper-case 'JAX_ENABLE_X64'; with the mixed-case
    # spelling this dtype check could never run.
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['coefs'].dtype == np.float64
def sample(model,
           num_samples,
           num_warmup,
           num_chains,
           seed=0,
           chain_method="parallel",
           summary=True,
           **kwargs):
    """Run the No-U-Turn sampler.

    Args:
        model: model callable handed to ``NUTS``.
        num_samples: number of posterior samples.
        num_warmup: number of warmup iterations.
        num_chains: number of chains.
        seed: integer seed used to build the random key.
        chain_method: forwarded to ``MCMC``.
        summary: when truthy, print the MCMC summary after sampling.
        **kwargs: forwarded as keyword arguments to ``mcmc.run`` (that is,
            to the model).

    Returns:
        The fitted ``MCMC`` object.
    """
    rng_key = random.PRNGKey(seed)
    kernel = NUTS(model)
    # Note: sampling more than one chain doesn't show a progress bar
    # Sampler settings are passed by keyword: newer NumPyro releases make
    # them keyword-only, and the keyword names also work on older releases.
    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                num_chains=num_chains,
                chain_method=chain_method)
    mcmc.run(rng_key, **kwargs)
    if summary:
        mcmc.print_summary()

    # Return a fitted MCMC object
    return mcmc
def test_inference_data_constant_data(self):
    """Check ``from_numpyro`` output groups when constant data is supplied.

    Fits a small model, builds posterior-predictive and prediction draws,
    converts everything with ``from_numpyro`` and verifies the expected
    groups and variables through ``check_multiple_attrs``.
    """
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    x1 = 10
    x2 = 12
    y1 = np.random.randn(10)

    def model_constant_data(x, y1=None):
        _x = numpyro.sample("x", dist.Normal(1, 3))
        numpyro.sample("y1", dist.Normal(x * _x, 1), obs=y1)

    mcmc = MCMC(NUTS(model_constant_data), num_samples=10, num_warmup=2)
    mcmc.run(PRNGKey(0), x=x1, y1=y1)
    posterior = mcmc.get_samples()

    # Same posterior, evaluated at the training input and at a new input.
    posterior_predictive = Predictive(model_constant_data,
                                      posterior)(PRNGKey(1), x1)
    predictions = Predictive(model_constant_data, posterior)(PRNGKey(2), x2)

    inference_data = from_numpyro(
        mcmc,
        posterior_predictive=posterior_predictive,
        predictions=predictions,
        constant_data={"x1": x1},
        predictions_constant_data={"x2": x2},
    )

    expected_groups = {
        "posterior": ["x"],
        "posterior_predictive": ["y1"],
        "sample_stats": ["diverging"],
        "log_likelihood": ["y1"],
        "predictions": ["y1"],
        "observed_data": ["y1"],
        "constant_data": ["x1"],
        "predictions_constant_data": ["x2"],
    }
    fails = check_multiple_attrs(expected_groups, inference_data)
    assert not fails
def test_missing_plate(monkeypatch):
    """Check the error raised for a model with a missing plate statement.

    First expects an ``AssertionError`` matching "Missing plate statement".
    Then, with ``_validate_model`` patched to a no-op, expects the run to
    still raise and the handler stack to be left empty.
    """
    K, N = 3, 1000

    def gmm(data):
        mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
        # plate/to_event is missing here
        cluster_means = numpyro.sample("cluster_means",
                                       dist.Normal(jnp.arange(K), 1.0))

        with numpyro.plate("data", data.shape[0], dim=-1):
            assignments = numpyro.sample("assignments",
                                         dist.Categorical(mix_proportions))
            numpyro.sample("obs",
                           dist.Normal(cluster_means[assignments], 1.0),
                           obs=data)

    true_cluster_means = jnp.array([1.0, 5.0, 10.0])
    true_mix_proportions = jnp.array([0.1, 0.3, 0.6])

    assignment_dist = dist.Categorical(true_mix_proportions)
    cluster_assignments = assignment_dist.sample(random.PRNGKey(0), (N, ))
    data = dist.Normal(true_cluster_means[cluster_assignments],
                       1.0).sample(random.PRNGKey(1))

    mcmc = MCMC(NUTS(gmm), num_warmup=500, num_samples=500)

    with pytest.raises(AssertionError, match="Missing plate statement"):
        mcmc.run(random.PRNGKey(2), data)

    # Disable model validation and run again; some exception is still
    # expected, and the stack must be empty afterwards.
    monkeypatch.setattr(numpyro.infer.util, "_validate_model",
                        lambda model_trace: None)
    with pytest.raises(Exception):
        mcmc.run(random.PRNGKey(2), data)
    assert len(_PYRO_STACK) == 0
def _sample(current_state, seed):
    """Run NumPyro NUTS from ``current_state`` and return draws and step counts.

    Relies on names from the enclosing scope: ``logp_fn_jax``,
    ``target_accept``, ``tune``, ``draws``, ``chains``, ``chain_method``
    and ``progress_bar``.

    Args:
        current_state: initial parameters passed as ``init_params``.
        seed: random key passed to ``MCMC.run``.

    Returns:
        A tuple ``(samples, leapfrogs_taken)``, both grouped by chain;
        ``leapfrogs_taken`` is the ``num_steps`` extra field.
    """
    # The former `step_size` local was never used and was computed with the
    # deprecated `jax.tree_map`, so it has been removed.
    nuts_kernel = NUTS(
        # The potential is the negated log-probability function.
        potential_fn=lambda x: -logp_fn_jax(*x),
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
    )

    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )

    pmap_numpyro.run(seed,
                     init_params=current_state,
                     extra_fields=("num_steps", ))
    samples = pmap_numpyro.get_samples(group_by_chain=True)
    leapfrogs_taken = pmap_numpyro.get_extra_fields(
        group_by_chain=True)["num_steps"]
    return samples, leapfrogs_taken
def test_linear_model_sigma(kernel_cls,
                            N=90,
                            P=40,
                            sigma=0.07,
                            warmup_steps=500,
                            num_samples=500):
    """Check ``HMCGibbs`` on a linear model with a Gibbs update for ``beta``.

    Args:
        kernel_cls: inner kernel class wrapped by ``HMCGibbs``.
        N: number of observations.
        P: number of covariates.
        sigma: noise scale used to simulate ``Y``.
        warmup_steps: number of warmup iterations.
        num_samples: number of posterior samples.
    """
    np.random.seed(1)

    # Only the first covariate has a non-zero (unit) effect on Y.
    X = np.random.randn(N * P).reshape((N, P))
    XX = np.matmul(np.transpose(X), X)
    Y = X[:, 0] + sigma * np.random.randn(N)
    XY = np.sum(X * Y[:, None], axis=0)

    def model(X, Y):
        N, P = X.shape

        sigma = numpyro.sample("sigma", dist.HalfCauchy(1.0))
        beta = numpyro.sample("beta", dist.Normal(jnp.zeros(P), jnp.ones(P)))
        mean = jnp.sum(beta * X, axis=-1)

        numpyro.sample("obs", dist.Normal(mean, sigma), obs=Y)

    gibbs_fn = partial(_linear_regression_gibbs_fn, X, XX, XY, Y)

    hmc_kernel = kernel_cls(model)
    kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=['beta'])
    # Sampler settings are passed by keyword: newer NumPyro releases make
    # them keyword-only, and the keyword names also work on older releases.
    mcmc = MCMC(kernel,
                num_warmup=warmup_steps,
                num_samples=num_samples,
                progress_bar=False)

    mcmc.run(random.PRNGKey(0), X, Y)
    samples = mcmc.get_samples()

    beta_mean = np.mean(samples['beta'], axis=0)
    assert_allclose(beta_mean, np.array([1.0] + [0.0] * (P - 1)), atol=0.05)

    sigma_mean = np.mean(samples['sigma'], axis=0)
    assert_allclose(sigma_mean, sigma, atol=0.25)
def test_hmcecs_normal_normal(kernel_cls, num_block, subsample_size):
    """Check ``HMCECS`` with a Taylor proxy on a normal-normal model.

    Args:
        kernel_cls: inner kernel class wrapped by ``HMCECS``.
        num_block: test parameter; not used in the body.
        subsample_size: subsample size for the ``batch`` plate.
    """
    # NOTE(review): `num_block` is accepted but never passed to HMCECS --
    # confirm whether `num_blocks=num_block` was intended.
    true_loc = jnp.array([0.3, 0.1, 0.9])
    num_warmup, num_samples = 200, 200
    data = true_loc + dist.Normal(jnp.zeros(3, ), jnp.ones(3, )).sample(
        random.PRNGKey(1), (10000, ))

    def model(data, subsample_size):
        mean = numpyro.sample('mean', dist.Normal().expand((3, )).to_event(1))
        with numpyro.plate('batch',
                           data.shape[0],
                           dim=-2,
                           subsample_size=subsample_size):
            sub_data = numpyro.subsample(data, 0)
            numpyro.sample("obs", dist.Normal(mean, 1), obs=sub_data)

    ref_params = {
        'mean': true_loc + dist.Normal(true_loc, 5e-2).sample(random.PRNGKey(0))
    }
    proxy_fn = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn)
    # Sampler settings are passed by keyword: newer NumPyro releases make
    # them keyword-only, and the keyword names also work on older releases.
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)

    mcmc.run(random.PRNGKey(0), data, subsample_size)

    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['mean'], axis=0), true_loc, atol=0.1)
    assert len(samples['mean']) == num_samples
def main(args):
    """Load the JSB chorales training data and run MCMC on the chosen model.

    Args:
        args: namespace providing ``model`` (key into ``models``),
            ``num_sequences``, ``truncate``, ``kernel`` ('nuts' or 'hmc'),
            ``num_warmup``, ``num_samples`` and ``num_chains``. The whole
            namespace is also forwarded to the model as ``args``.

    Side effects:
        Logs progress and elapsed time, and prints the MCMC summary.
    """
    model = models[args.model]

    _, fetch = load_dataset(JSB_CHORALES, split='train', shuffle=False)
    lengths, sequences = fetch()
    if args.num_sequences:
        sequences = sequences[0:args.num_sequences]
        lengths = lengths[0:args.num_sequences]

    logger.info('-' * 40)
    logger.info('Training {} on {} sequences'.format(
        model.__name__, len(sequences)))

    # find all the notes that are present at least once in the training set
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes with default args)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clip(0, args.truncate)
        sequences = sequences[:, :args.truncate]

    logger.info('Each sequence has shape {}'.format(sequences[0].shape))
    logger.info('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = {'nuts': NUTS, 'hmc': HMC}[args.kernel](model)
    # Sampler settings are passed by keyword: newer NumPyro releases make
    # them keyword-only, and the keyword names also work on older releases.
    mcmc = MCMC(kernel,
                num_warmup=args.num_warmup,
                num_samples=args.num_samples,
                num_chains=args.num_chains,
                # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set.
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, sequences, lengths, args=args)
    mcmc.print_summary()
    logger.info('\nMCMC elapsed time: {}'.format(time.time() - start))
def test_linear_model_log_sigma(
    kernel_cls, N=100, P=50, sigma=0.11, num_warmup=500, num_samples=500
):
    """Check ``HMCGibbs`` on a linear model parameterized by ``log_sigma``.

    Args:
        kernel_cls: inner kernel class wrapped by ``HMCGibbs``.
        N: number of observations.
        P: number of covariates.
        sigma: noise scale used to simulate ``Y``.
        num_warmup: number of warmup iterations.
        num_samples: number of posterior samples.
    """
    np.random.seed(0)

    # Only the first covariate has a non-zero (unit) effect on Y.
    X = np.random.randn(N * P).reshape((N, P))
    XX = np.matmul(np.transpose(X), X)
    Y = X[:, 0] + sigma * np.random.randn(N)
    XY = np.sum(X * Y[:, None], axis=0)

    def model(X, Y):
        N, P = X.shape

        log_sigma = numpyro.sample("log_sigma", dist.Normal(1.0))
        sigma = jnp.exp(log_sigma)
        beta = numpyro.sample("beta", dist.Normal(jnp.zeros(P), jnp.ones(P)))
        mean = jnp.sum(beta * X, axis=-1)
        numpyro.deterministic("mean", mean)

        numpyro.sample("obs", dist.Normal(mean, sigma), obs=Y)

    gibbs_fn = partial(_linear_regression_gibbs_fn, X, XX, XY, Y)
    gibbs_kernel = HMCGibbs(
        kernel_cls(model), gibbs_fn=gibbs_fn, gibbs_sites=["beta"]
    )
    mcmc = MCMC(
        gibbs_kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(0), X, Y)

    expected_beta = np.array([1.0] + [0.0] * (P - 1))
    beta_mean = np.mean(mcmc.get_samples()["beta"], axis=0)
    assert_allclose(beta_mean, expected_beta, atol=0.05)

    # Exponentiate the mean of log_sigma to compare on the sigma scale.
    log_sigma_mean = np.mean(mcmc.get_samples()["log_sigma"], axis=0)
    sigma_mean = np.exp(log_sigma_mean)
    assert_allclose(sigma_mean, sigma, atol=0.25)
def test_estimate_likelihood(kernel_cls):
    """Check the variance of the ``HMCECS`` likelihood estimate.

    Runs ``HMCECS`` with a Taylor proxy, collects the potential energies
    recorded during sampling, recomputes the log density on the full data
    for every sample and asserts that the variance of
    ``exp(-pes - pes_full)`` is below 1.

    Args:
        kernel_cls: inner kernel class wrapped by ``HMCECS``.
    """
    # Only the first of the four split keys is used.
    data_key = random.split(random.PRNGKey(0), 4)[0]

    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = 0.1
    noise = dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(data_key, (10_000,))
    data = ref_params + noise
    n, _ = data.shape

    num_warmup = 200
    num_samples = 200
    num_blocks = 20

    def model(data):
        mean = numpyro.sample(
            "mean", dist.Normal(ref_params, jnp.ones_like(ref_params))
        )
        with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-2) as idx:
            numpyro.sample("obs", dist.Normal(mean, sigma), obs=data[idx])

    proxy_fn = HMCECS.taylor_proxy({"mean": ref_params})
    ecs_kernel = HMCECS(kernel_cls(model), proxy=proxy_fn, num_blocks=num_blocks)
    mcmc = MCMC(ecs_kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), data, extra_fields=["hmc_state.potential_energy"])

    pes = mcmc.get_extra_fields()["hmc_state.potential_energy"]
    samples = mcmc.get_samples()

    def full_log_density(sample):
        # Substitute the full index range for the subsample site "N".
        params = {**sample, "N": jnp.arange(n)}
        return log_density(model, (data,), {}, params)[0]

    pes_full = vmap(full_log_density)(samples)

    assert jnp.var(jnp.exp(-pes - pes_full)) < 1.0
def run_inference(model, args, rng_key):
    """Run NUTS on ``model`` (called with no arguments) and return samples.

    Args:
        model: model callable handed to ``NUTS``.
        args: namespace providing ``num_warmup``, ``num_samples`` and
            ``num_chains``.
        rng_key: random key given to ``mcmc.run``.

    Returns:
        The result of ``mcmc.get_samples()``.

    Side effects:
        Prints the MCMC summary.
    """
    kernel = NUTS(model)
    # Sampler settings are passed by keyword: newer NumPyro releases make
    # them keyword-only, and the keyword names also work on older releases.
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set.
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key)
    mcmc.print_summary()
    return mcmc.get_samples()