def main(_args):
    """Fit ``gmm`` to generated data with SVI and show the running loss.

    The model is traced once with a seeded RNG so that its maximum plate
    nesting can be guessed; that depth decides where ``enum`` may start
    placing enumerated discrete dimensions. The enumerated model is then
    optimised against ``gmm_guide`` with ``Adam(0.1)`` and
    ``RenyiELBO(-10.)`` for 100_000 jitted ``svi.update`` steps.

    ``_args`` is accepted for the entry-point signature but never read.
    """
    observations = generate_data()
    rng_key = PRNGKey(1273)

    # The trace is only handed to _guess_max_plate_nesting; its sampled
    # values are not otherwise used.
    # NOTE(review): the same key also seeds svi.init below; consider
    # jax.random.split if independent randomness is ever needed.
    structural_trace = trace(seed(gmm, rng_key)).get_trace(observations)
    plate_depth = _guess_max_plate_nesting(structural_trace)

    # Enumeration starts one dim to the left of the deepest plate dim.
    enumerated_gmm = enum(config_enumerate(gmm), -plate_depth - 1)

    svi = SVI(enumerated_gmm, gmm_guide, Adam(0.1), RenyiELBO(-10.))
    state = svi.init(rng_key, observations)
    jitted_update = jax.jit(svi.update)

    with tqdm.trange(100_000) as progress:
        for _ in progress:
            state, loss = jitted_update(state, observations)
            progress.set_description(f"SVI {loss}", True)
def test_svi_discrete_latent():
    """Continuous-only ELBOs must warn when the model has a discrete latent.

    Also checks the ``can_infer_discrete`` flag on each ELBO class: the three
    continuous-only losses report ``False`` and ``TraceGraph_ELBO`` reports
    ``True``.
    """
    continuous_only = [RenyiELBO(), Trace_ELBO(), TraceMeanField_ELBO()]
    discrete_capable = [TraceGraph_ELBO()]
    assert not any(loss.can_infer_discrete for loss in continuous_only)
    assert all(loss.can_infer_discrete for loss in discrete_capable)

    def model():
        numpyro.sample("x", dist.Bernoulli(0.5))

    def guide():
        probs = numpyro.param("probs", 0.2)
        numpyro.sample("x", dist.Bernoulli(probs))

    for loss in continuous_only:
        # Build SVI outside the `warns` block so that only `run` is
        # required to emit the warning.
        svi = SVI(model, guide, optim.Adam(1), loss)
        expected = (
            f"Currently, SVI with {type(loss).__name__} loss does not support "
            "models with discrete latent variables"
        )
        with pytest.warns(UserWarning, match=expected):
            svi.run(random.PRNGKey(0), 10)
@pytest.mark.parametrize("optim_class, args, kwargs", optimizers)
def test_optim_multi_params(optim_class, args, kwargs):
    """A wrapped optax optimizer should bring a dict of parameters to zero.

    Two parameter arrays start at +1 and -1. After 2000 ``step`` calls, every
    resulting parameter must be ``allclose`` to ``jnp.zeros(3)``.
    """
    initial = {
        "x": jnp.array([1.0, 1.0, 1.0]),
        "y": jnp.array([-1, -1.0, -1.0]),
    }
    optimizer = optax_to_numpyro(optim_class(*args, **kwargs))
    state = optimizer.init(initial)
    for _ in range(2000):
        state = step(state, optimizer)
    for value in optimizer.get_params(state).values():
        assert jnp.allclose(value, jnp.zeros(3))


@pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
def test_beta_bernoulli(elbo):
    # NOTE(review): this body looks truncated in this copy. `elbo` and `adam`
    # are never used below, so nothing is fitted or asserted. Restore the
    # remaining steps from the original source before relying on this test.
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    adam = optax.adam(0.05)
def renyi_loss_fn(x):
    """Return the 10-particle RenyiELBO loss of ``model``/``guide`` at ``x``.

    ``alpha``, ``model`` and ``guide`` are free names resolved from an
    enclosing scope that is not visible here.
    """
    renyi = RenyiELBO(alpha=alpha, num_particles=10)
    # A fixed key makes repeated calls use the same particles. The `{}` is
    # passed as the second argument of `loss` (presumably an empty
    # param map — confirm).
    return renyi.loss(random.PRNGKey(0), {}, model, guide, x)
return ELBO().loss(random.PRNGKey(0), {}, model, guide, x) def renyi_loss_fn(x): return RenyiELBO(alpha=alpha, num_particles=10).loss(random.PRNGKey(0), {}, model, guide, x) elbo_loss, elbo_grad = value_and_grad(elbo_loss_fn)(2.) renyi_loss, renyi_grad = value_and_grad(renyi_loss_fn)(2.) assert_allclose(elbo_loss, renyi_loss, rtol=1e-6) assert_allclose(elbo_grad, renyi_grad, rtol=1e-6) @pytest.mark.parametrize('elbo', [ ELBO(), RenyiELBO(num_particles=10), ]) def test_beta_bernoulli(elbo): data = jnp.array([1.0] * 8 + [0.0] * 2) def model(data): f = numpyro.sample("beta", dist.Beta(1., 1.)) numpyro.sample("obs", dist.Bernoulli(f), obs=data) def guide(data): alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive) beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive) numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
def elbo_loss_fn(x): return ELBO().loss(random.PRNGKey(0), {}, model, guide, x) def renyi_loss_fn(x): return RenyiELBO(alpha=alpha, num_particles=10).loss(random.PRNGKey(0), {}, model, guide, x) elbo_loss, elbo_grad = value_and_grad(elbo_loss_fn)(2.) renyi_loss, renyi_grad = value_and_grad(renyi_loss_fn)(2.) assert_allclose(elbo_loss, renyi_loss, rtol=1e-6) assert_allclose(elbo_grad, renyi_grad, rtol=1e-6) @pytest.mark.parametrize('elbo', [ ELBO(), pytest.param(RenyiELBO(num_particles=10), marks=pytest.mark.xfail(reason="https://github.com/pyro-ppl/numpyro/issues/414")) ]) def test_beta_bernoulli(elbo): data = np.array([1.0] * 8 + [0.0] * 2) def model(data): f = numpyro.sample("beta", dist.Beta(1., 1.)) numpyro.sample("obs", dist.Bernoulli(f), obs=data) def guide(data): alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive) beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive) numpyro.sample("beta", dist.Beta(alpha_q, beta_q))