def test_chain(use_init_params, chain_method):
    """Run multi-chain NUTS on a synthetic logistic regression and check the output.

    Labels are drawn from a Bernoulli likelihood whose logits are a linear
    function of standard-normal features with known coefficients
    ``[1, 2, 3]``. The test asserts that:

    * the flattened samples have ``num_chains * num_samples`` draws,
    * the grouped-by-chain samples have leading shape ``(num_chains, num_samples)``,
    * the posterior mean of ``coefs`` is within ``atol=0.21`` of the true values.

    Args:
        use_init_params: when truthy, every chain starts with ``coefs`` set to
            all ones. Otherwise ``init_params=None`` is passed to ``run``.
        chain_method: assigned to ``mcmc.chain_method`` after the MCMC object
            is built.

    NOTE(review): another ``test_chain`` is defined right after this one. At
    import time the later definition replaces this one, so pytest only
    collects the later one. Confirm whether this variant should be renamed
    or removed.
    """
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000

    # Synthetic data: features ~ N(0, 1), labels ~ Bernoulli(sigmoid(features @ true_coefs)).
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model=model)
    # num_warmup / num_samples are passed by keyword. Newer NumPyro releases
    # make them keyword-only, and older releases accept keywords as well.
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples,
                num_chains=num_chains)
    # NOTE(review): chain_method is set as an attribute after construction
    # rather than passed to the constructor. This is presumably intentional;
    # confirm before changing.
    mcmc.chain_method = chain_method

    if use_init_params:
        # One all-ones initial value per chain: shape (num_chains, dim).
        init_params = {'coefs': np.tile(np.ones(dim), num_chains).reshape(num_chains, dim)}
    else:
        init_params = None

    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)

    samples_flat = mcmc.get_samples()
    assert samples_flat['coefs'].shape[0] == num_chains * num_samples
    samples = mcmc.get_samples(group_by_chain=True)
    assert samples['coefs'].shape[:2] == (num_chains, num_samples)
    assert_allclose(np.mean(samples_flat['coefs'], 0), true_coefs, atol=0.21)
def test_chain(use_init_params, chain_method):
    """Check multi-chain NUTS sampling on a synthetic logistic-regression problem.

    The ground-truth coefficients are ``[1, 2, 3]``. Features are drawn from
    ``random.normal`` and labels are sampled from a Bernoulli likelihood on
    the resulting linear logits. After sampling, the test verifies:

    * the number of flattened draws (``num_chains * num_samples``),
    * the leading shape of the grouped draws (``(num_chains, num_samples)``),
    * that the posterior mean of ``coefs`` is within ``atol=0.21`` of the truth,
    * that reshaping the flattened draws and fetching them with ``device_get``
      does not fail.

    Args:
        use_init_params: truthy means every chain is initialised with
            ``coefs`` equal to all ones. Otherwise no initial values are given.
        chain_method: stored on ``mcmc.chain_method`` once the MCMC object
            exists.
    """
    n_points, n_features = 3000, 3
    n_chains = 2
    warmup_steps = sample_steps = 5000

    features = random.normal(random.PRNGKey(0), (n_points, n_features))
    expected_coefs = jnp.arange(1.0, n_features + 1.0)
    generating_logits = jnp.sum(expected_coefs * features, axis=-1)
    labels = dist.Bernoulli(logits=generating_logits).sample(random.PRNGKey(1))

    def model(labels):
        # Independent standard-normal prior on each coefficient.
        prior = dist.Normal(jnp.zeros(n_features), jnp.ones(n_features))
        coefs = numpyro.sample("coefs", prior)
        linear = jnp.sum(coefs * features, axis=-1)
        # Also record the per-observation logits as a deterministic site.
        numpyro.deterministic("logits", linear)
        likelihood = dist.Bernoulli(logits=linear)
        return numpyro.sample("obs", likelihood, obs=labels)

    mcmc = MCMC(
        NUTS(model=model),
        num_warmup=warmup_steps,
        num_samples=sample_steps,
        num_chains=n_chains,
    )
    # NOTE(review): assigned after construction instead of via the
    # constructor. This is presumably deliberate; confirm before changing.
    mcmc.chain_method = chain_method

    if use_init_params:
        start_coefs = jnp.tile(jnp.ones(n_features), n_chains).reshape(n_chains, n_features)
        init_params = {"coefs": start_coefs}
    else:
        init_params = None

    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)

    flat = mcmc.get_samples()
    assert flat["coefs"].shape[0] == n_chains * sample_steps
    grouped = mcmc.get_samples(group_by_chain=True)
    assert grouped["coefs"].shape[:2] == (n_chains, sample_steps)
    assert_allclose(jnp.mean(flat["coefs"], 0), expected_coefs, atol=0.21)

    # Reshaping the flattened draws and pulling them to host must succeed.
    device_get(flat["coefs"].reshape(-1))