def test_beta_bernoulli(auto_class):
    """Fit an autoguide to two Bernoulli columns and compare the posterior
    mean of 'beta' against the exact conjugate Beta posterior mean."""
    data = np.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        # Beta(1, 1) prior over the two success probabilities.
        rate = numpyro.sample('beta', dist.Beta(np.ones(2), np.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(rate), obs=data)

    optimizer = optim.Adam(0.01)
    guide = auto_class(model)
    svi = SVI(model, guide, elbo, optimizer)
    state = svi.init(random.PRNGKey(1), model_args=(data,), guide_args=(data,))

    def step(i, current):
        updated, _ = svi.update(current, model_args=(data,), guide_args=(data,))
        return updated

    state = fori_loop(0, 2000, step, state)
    params = svi.get_params(state)
    # Conjugate posterior mean: (successes + 1) / (trials + 2).
    true_coefs = (np.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params,
                                               sample_shape=(1000,))
    assert_allclose(np.mean(posterior_samples['beta'], 0), true_coefs, atol=0.04)
def test_uniform_normal():
    """AutoDiagonalNormal should recover the location of a Normal
    likelihood whose prior support depends on another latent."""
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        # 'loc' has a support (0, alpha) that depends on latent 'alpha'.
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    optimizer = optim.Adam(0.01)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, elbo, optimizer)
    state = svi.init(random.PRNGKey(1), model_args=(data,), guide_args=(data,))

    def step(i, current):
        updated, _ = svi.update(current, model_args=(data,), guide_args=(data,))
        return updated

    state = fori_loop(0, 1000, step, state)
    params = svi.get_params(state)
    assert_allclose(guide.median(params)['loc'], true_coef, rtol=0.05)
    # test .quantile method
    quantiles = guide.quantiles(params, [0.2, 0.5])
    assert_allclose(quantiles['loc'][1], true_coef, rtol=0.1)
def main(args):
    """Run SVI on the example's model/guide pair and report the fitted
    variational parameters, checking the learned location against 3.0."""
    # Generate some data centered at 3.0.
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Construct an SVI object so we can do variational inference on our
    # model/guide pair.
    optimizer = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, elbo, optimizer)
    state = svi.init(PRNGKey(0), model_args=(data,))

    # Training loop.
    def step(i, current):
        updated, _ = svi.update(current, model_args=(data,))
        return updated

    state = fori_loop(0, args.num_steps, step, state)

    # Report the final values of the variational parameters in the guide
    # after training.
    params = svi.get_params(state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # For this simple (conjugate) model the exact posterior is known: the
    # variational distribution should be centered near 3.0.
    assert np.abs(params["guide_loc"] - 3.0) < 0.1
def test_beta_bernoulli():
    """Fit a hand-written Beta guide with SVI; the posterior mean of the
    success probability should approach the empirical rate 0.8."""
    data = np.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        # Uninformative Beta(1, 1) prior on the success probability.
        prob = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(prob), obs=data)

    def guide():
        # Variational Beta family with positive concentration parameters.
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    optimizer = optim.Adam(0.05)
    svi = SVI(model, guide, elbo, optimizer)
    state = svi.init(random.PRNGKey(1), model_args=(data,))
    # alpha_q starts at 1.0, i.e. 0.0 in the unconstrained (log) space.
    assert_allclose(optimizer.get_params(state.optim_state)['alpha_q'], 0.)

    def step(i, current):
        updated, _ = svi.update(current, model_args=(data,))
        return updated

    state = fori_loop(0, 300, step, state)
    params = svi.get_params(state)
    posterior_mean = params['alpha_q'] / (params['alpha_q'] + params['beta_q'])
    assert_allclose(posterior_mean, 0.8, atol=0.05, rtol=0.05)
def test_dynamic_supports():
    """A model with a latent-dependent support must give the same SVI
    initialization and loss as its manually reparameterized equivalent."""
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def actual_model(data):
        # 'loc' is sampled directly on the dynamic support (0, alpha).
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        # Same model, reparameterized onto the static support (0, 1).
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, 1)) * alpha
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_init = random.PRNGKey(1)

    def run(model):
        # Initialize SVI (no training) and collect everything to compare.
        guide = AutoDiagonalNormal(model)
        svi = SVI(model, guide, elbo, adam)
        state = svi.init(rng_init, (data,), (data,))
        opt_params = adam.get_params(state.optim_state)
        params = svi.get_params(state)
        values = guide.median(params)
        loss = svi.evaluate(state, (data,), (data,))
        return opt_params, params, values, loss

    actual_opt_params, actual_params, actual_values, actual_loss = run(actual_model)
    expected_opt_params, expected_params, expected_values, expected_loss = run(expected_model)

    # test auto_loc, auto_scale
    check_eq(actual_opt_params, expected_opt_params)
    check_eq(actual_params, expected_params)
    # test latent values
    assert_allclose(actual_values['alpha'], expected_values['alpha'])
    assert_allclose(actual_values['loc'],
                    expected_values['alpha'] * expected_values['loc'])
    assert_allclose(actual_loss, expected_loss)
def test_iaf():
    # test for substitute logic for exposed methods `sample_posterior` and `get_transforms`
    # Strategy: ask the guide for a posterior sample and for its transform,
    # then rebuild the same IAF flow stack by hand from the stored
    # parameters and check both paths agree exactly.
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        # Logistic regression with a bounded intercept ('offset').
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        offset = numpyro.sample('offset', dist.Uniform(-1, 1))
        logits = offset + np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, elbo, adam)
    # Only initialize (no training): the test compares transforms, not fit.
    svi_state = svi.init(rng_init, model_args=(data, labels), guide_args=(data, labels))
    params = svi.get_params(svi_state)

    # One latent vector of size dim + 1 ('coefs' has dim entries, 'offset' one).
    x = random.normal(random.PRNGKey(0), (dim + 1, ))
    rng = random.PRNGKey(1)
    # Values produced through the guide's public API.
    actual_sample = guide.sample_posterior(rng, params)
    actual_output = guide.get_transform(params)(x)

    # Rebuild the flow stack manually from the stored ARN parameters.
    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            # Dimensions are reversed between consecutive flows.
            flows.append(constraints.PermuteTransform(
                np.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(
            dim + 1, [dim + 1, dim + 1],
            permutation=np.arange(dim + 1),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity)
        # Bind the i-th flow's learned parameters into the apply function.
        arn = partial(arn_apply, params['auto_arn__{}$params'.format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    transform = constraints.ComposeTransform(flows)

    # Split rng the same way sample_posterior presumably does internally —
    # the assertions below confirm the reconstruction matches.
    rng_seed, rng_sample = random.split(rng)
    expected_sample = guide.unpack_latent(
        transform(dist.Normal(np.zeros(dim + 1), 1).sample(rng_sample)))
    expected_output = transform(x)
    assert_allclose(actual_sample['coefs'], expected_sample['coefs'])
    # 'offset' must additionally be mapped into its (-1, 1) support.
    assert_allclose(
        actual_sample['offset'],
        constraints.biject_to(constraints.interval(-1, 1))(
            expected_sample['offset']))
    assert_allclose(actual_output, expected_output)
def test_param():
    """Param sites in both model and guide may carry composed transformed
    constraints; init values must round-trip and the ELBO must match the
    closed-form expectation."""
    keys = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval = -2
    c_maxval = -1
    a_init = np.exp(random.normal(keys[0])) + a_minval
    b_init = np.exp(random.normal(keys[1]))
    c_init = random.uniform(keys[2], minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(keys[3])
    obs = random.normal(keys[4])

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b), obs=obs)

    def guide():
        c = numpyro.param('c', c_init, constraint=constraints.interval(c_minval, c_maxval))
        d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
        numpyro.sample('y', dist.Normal(c, d), obs=obs)

    svi = SVI(model, guide, elbo, optim.Adam(0.01))
    state = svi.init(random.PRNGKey(0), (), ())
    params = svi.get_params(state)
    # Each param must come back exactly as initialized.
    for name, init in [('a', a_init), ('b', b_init), ('c', c_init), ('d', d_init)]:
        assert_allclose(params[name], init)

    actual_loss = svi.evaluate(state)
    assert np.isfinite(actual_loss)
    # ELBO = log q - log p with everything observed.
    expected_loss = (dist.Normal(c_init, d_init).log_prob(obs)
                     - dist.Normal(a_init, b_init).log_prob(obs))
    # not so precisely because we do transform / inverse transform stuffs
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
def test_logistic_regression(auto_class):
    """Every autoguide class should recover logistic-regression
    coefficients via its median/quantiles/sample_posterior methods."""
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    optimizer = optim.Adam(0.01)
    guide = auto_class(model)
    svi = SVI(model, guide, elbo, optimizer)
    state = svi.init(random.PRNGKey(1), model_args=(data, labels),
                     guide_args=(data, labels))

    def step(i, current):
        updated, _ = svi.update(current, model_args=(data, labels),
                                guide_args=(data, labels))
        return updated

    state = fori_loop(0, 2000, step, state)
    params = svi.get_params(state)
    if auto_class is not AutoIAFNormal:
        # AutoIAFNormal has no median/quantiles; skip those checks for it.
        assert_allclose(guide.median(params)['coefs'], true_coefs, rtol=0.1)
        # test .quantile method
        quantiles = guide.quantiles(params, [0.2, 0.5])
        assert_allclose(quantiles['coefs'][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params,
                                               sample_shape=(1000,))
    assert_allclose(np.mean(posterior_samples['coefs'], 0), true_coefs, rtol=0.1)
def test_param():
    """Model-side param sites with composed transformed constraints must
    round-trip their init values, and the ELBO must equal the closed form
    when the guide's latent is pinned to a known value."""
    keys = random.split(random.PRNGKey(0), 3)
    a_minval = 1
    a_init = np.exp(random.normal(keys[0])) + a_minval
    b_init = np.exp(random.normal(keys[1]))
    x_init = random.normal(keys[2])

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b))

    class _AutoGuide(AutoDiagonalNormal):
        # Pin the guide's latent draw to x_init so the loss is deterministic.
        def __call__(self, *args, **kwargs):
            pinned = substitute(super(_AutoGuide, self).__call__,
                                {'_auto_latent': x_init})
            return pinned(*args, **kwargs)

    guide = _AutoGuide(model)
    svi = SVI(model, guide, elbo, optim.Adam(0.01))
    state = svi.init(random.PRNGKey(1))
    params = svi.get_params(state)
    assert_allclose(params['a'], a_init)
    assert_allclose(params['b'], b_init)
    assert_allclose(params['auto_loc'], guide._init_latent)
    assert_allclose(params['auto_scale'], np.ones(1))

    actual_loss = svi.evaluate(state)
    assert np.isfinite(actual_loss)
    # ELBO = log q(x_init) - log p(x_init) for this single-latent model.
    expected_loss = (dist.Normal(guide._init_latent, 1).log_prob(x_init)
                     - dist.Normal(a_init, b_init).log_prob(x_init))
    assert_allclose(actual_loss, expected_loss)
def main(args):
    """NeuTra HMC demo on the dual-moon model.

    Trains an AutoIAFNormal guide with SVI, runs NUTS both on the raw
    potential and on the flow-warped potential, prints a summary, and
    saves comparison plots to ``neutra.pdf``.
    """
    jax_config.update('jax_platform_name', args.device)

    print("Start vanilla HMC...")
    nuts_kernel = NUTS(potential_fn=dual_moon_pe)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    mcmc.run(random.PRNGKey(11), init_params=np.array([2., 0.]))
    vanilla_samples = mcmc.get_samples()

    # Train the IAF autoguide with SVI.
    adam = optim.Adam(0.001)
    rng_init, rng_train = random.split(random.PRNGKey(1), 2)
    guide = AutoIAFNormal(dual_moon_model, hidden_dims=[args.num_hidden],
                          skip_connections=True)
    svi = SVI(dual_moon_model, guide, elbo, adam)
    svi_state = svi.init(rng_init)

    print("Start training guide...")
    # The zeros array only fixes the number of scan steps (args.num_iters).
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(random.PRNGKey(0), params,
                                           sample_shape=(args.num_samples,))

    # Build the warped potential so HMC runs in the flow's base space.
    transform = guide.get_transform(params)
    unpack_fn = guide.unpack_latent

    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), dual_moon_model)
    transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn)
    transformed_constrain_fn = lambda x: constrain_fn(unpack_fn(transform(x)))  # noqa: E731

    init_params = np.zeros(guide.latent_size)
    print("\nStart NeuTra HMC...")
    # TODO: explore why neutra samples are not good
    # Issue: https://github.com/pyro-ppl/numpyro/issues/256
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    mcmc.run(random.PRNGKey(10), init_params=init_params)
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    # Map base-space draws back through the flow into the model's space.
    samples = vmap(transformed_constrain_fn)(zs)
    summary(tree_map(lambda x: x[None, ...], samples))

    # make plots

    # IAF guide samples (for plotting)
    iaf_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(0), (1000,))
    iaf_trans_samples = vmap(transformed_constrain_fn)(iaf_base_samples)['x']

    # Density grid of the dual-moon target for the contour backgrounds.
    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.clip(np.exp(-dual_moon_pe(np.stack([X1, X2], axis=-1))), a_min=0.)

    fig = plt.figure(figsize=(12, 16), constrained_layout=True)
    gs = GridSpec(3, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])

    # Training curve (skip the noisy first 1000 steps).
    ax1.plot(np.log(losses[1000:]))
    ax1.set_title('Autoguide training log loss (after 1000 steps)')

    # Guide posterior density over the target contours.
    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples['x'][:, 0].copy(), guide_samples['x'][:, 1].copy(),
                n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using AutoIAFNormal guide')

    # Base-space samples colored by which moon they map to.
    sns.scatterplot(iaf_base_samples[:, 0], iaf_base_samples[:, 1], ax=ax3,
                    hue=iaf_trans_samples[:, 0] < 0.)
    ax3.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='AutoIAFNormal base samples (True=left moon; False=right moon)')

    # Vanilla HMC posterior plus the trajectory of its last 50 draws.
    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0].copy(), vanilla_samples[:, 1].copy(),
                n_levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using vanilla HMC sampler')

    # Warped-space (base) samples from NeuTra HMC.
    sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples['x'][:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5], xlabel='x0', ylabel='x1',
            title='Samples from the warped posterior - p(z)')

    # NeuTra HMC posterior plus the trajectory of its last 50 draws.
    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples['x'][:, 0].copy(), samples['x'][:, 1].copy(),
                n_levels=30, ax=ax6)
    ax6.plot(samples['x'][-50:, 0], samples['x'][-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()