Example #1
0
def main(_args):
    """Run enumerated SVI for the ``gmm`` model on data from ``generate_data``.

    The model is traced once to guess its maximum plate nesting. It is then
    wrapped with ``config_enumerate`` and ``enum`` and optimized with ``SVI``
    (``gmm_guide``, ``Adam(0.1)``, ``RenyiELBO(-10.)``) for 100,000 jitted
    update steps. The current loss is shown in a tqdm progress bar.

    Args:
        _args: Accepted for command-line compatibility; not used.
    """
    data = generate_data()
    init_rng_key = PRNGKey(1273)
    # The seeded trace is only used to guess the plate nesting below.
    # NOTE(review): the same key is reused for svi.init; harmless because the
    # trace is not used for inference.
    seeded_gmm = seed(gmm, init_rng_key)
    model_trace = trace(seeded_gmm).get_trace(data)
    max_plate_nesting = _guess_max_plate_nesting(model_trace)
    # Presumably the first dimension free for enumeration, i.e. left of every
    # plate dimension -- see numpyro.contrib.funsor.enum.
    enum_gmm = enum(config_enumerate(gmm), -max_plate_nesting - 1)
    svi = SVI(enum_gmm, gmm_guide, Adam(0.1), RenyiELBO(-10.))
    svi_state = svi.init(init_rng_key, data)
    upd_fun = jax.jit(svi.update)
    with tqdm.trange(100_000) as pbar:
        for _ in pbar:
            svi_state, loss = upd_fun(svi_state, data)
            pbar.set_description(f"SVI {loss}", refresh=True)
Example #2
0
def test_svi_discrete_latent():
    """Check which ELBOs can infer discrete latents, and that the others warn.

    ``RenyiELBO``, ``Trace_ELBO`` and ``TraceMeanField_ELBO`` must report
    ``can_infer_discrete`` as false, and ``TraceGraph_ELBO`` as true. Running
    SVI with each continuous-only loss on a model with a Bernoulli latent site
    must emit a ``UserWarning`` that names the loss class.
    """
    import re

    cont_inf_only_cls = [RenyiELBO(), Trace_ELBO(), TraceMeanField_ELBO()]
    mixed_inf_cls = [TraceGraph_ELBO()]

    assert not any(c.can_infer_discrete for c in cont_inf_only_cls)
    assert all(c.can_infer_discrete for c in mixed_inf_cls)

    def model():
        numpyro.sample("x", dist.Bernoulli(0.5))

    def guide():
        probs = numpyro.param("probs", 0.2)
        numpyro.sample("x", dist.Bernoulli(probs))

    for elbo in cont_inf_only_cls:
        svi = SVI(model, guide, optim.Adam(1), elbo)
        s_name = type(elbo).__name__
        w_msg = f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
        # ``match`` is applied as a regular expression; escape the literal
        # message so it is matched verbatim.
        with pytest.warns(UserWarning, match=re.escape(w_msg)):
            svi.run(random.PRNGKey(0), 10)
Example #3
0
@pytest.mark.parametrize("optim_class, args, kwargs", optimizers)
def test_optim_multi_params(optim_class, args, kwargs):
    """Optimize a two-entry parameter dict through an optax-wrapped optimizer.

    The optax optimizer built from ``optim_class(*args, **kwargs)`` is wrapped
    with ``optax_to_numpyro`` and advanced 2000 times with ``step`` (defined
    elsewhere). Every resulting parameter must then be close to
    ``jnp.zeros(3)``.
    """
    params = {
        "x": jnp.array([1.0, 1.0, 1.0]),
        "y": jnp.array([-1.0, -1.0, -1.0]),
    }
    opt = optax_to_numpyro(optim_class(*args, **kwargs))
    opt_state = opt.init(params)
    for _ in range(2000):
        opt_state = step(opt_state, opt)
    # Only the parameter values matter here, not their names.
    for param in opt.get_params(opt_state).values():
        assert jnp.allclose(param, jnp.zeros(3))


@pytest.mark.parametrize(
    "elbo",
    [
        Trace_ELBO(),
        RenyiELBO(num_particles=10),
    ],
)
def test_beta_bernoulli(elbo):
    """Build a Beta-Bernoulli model/guide pair and an Adam optimizer.

    The data are eight ones followed by two zeros. The model draws the success
    probability from a ``Beta(1, 1)`` prior. The guide is a ``Beta`` whose two
    concentrations are positive-constrained ``numpyro.param`` sites.
    """
    # NOTE(review): as visible here the test ends after creating ``adam``;
    # ``elbo``, ``model``, ``guide`` and ``adam`` are unused and nothing is
    # asserted -- the SVI run and its assertion appear to be missing.
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        success_prob = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(success_prob), obs=data)

    def guide(data):
        positive = constraints.positive
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    adam = optax.adam(0.05)
Example #4
0
 def renyi_loss_fn(x):
     """Return the RenyiELBO loss of ``model``/``guide`` evaluated at ``x``.

     ``alpha``, ``model`` and ``guide`` come from the enclosing scope. The loss
     uses 10 particles, a fixed ``PRNGKey(0)`` and an empty ``{}`` as the
     parameter map.
     """
     renyi_elbo = RenyiELBO(alpha=alpha, num_particles=10)
     # A fixed key means that only ``x`` varies between calls.
     fixed_key = random.PRNGKey(0)
     return renyi_elbo.loss(fixed_key, {}, model, guide, x)
Example #5
0
        return ELBO().loss(random.PRNGKey(0), {}, model, guide, x)

    def renyi_loss_fn(x):
        """Return the RenyiELBO loss of ``model``/``guide`` evaluated at ``x``.

        ``alpha`` comes from the enclosing scope. The loss uses 10 particles,
        a fixed ``PRNGKey(0)`` and an empty ``{}`` as the parameter map.
        """
        renyi_elbo = RenyiELBO(alpha=alpha, num_particles=10)
        # A fixed key means that only ``x`` varies between calls.
        fixed_key = random.PRNGKey(0)
        return renyi_elbo.loss(fixed_key, {}, model, guide, x)

    elbo_loss, elbo_grad = value_and_grad(elbo_loss_fn)(2.)
    renyi_loss, renyi_grad = value_and_grad(renyi_loss_fn)(2.)
    assert_allclose(elbo_loss, renyi_loss, rtol=1e-6)
    assert_allclose(elbo_grad, renyi_grad, rtol=1e-6)


@pytest.mark.parametrize('elbo', [ELBO(), RenyiELBO(num_particles=10)])
def test_beta_bernoulli(elbo):
    """Build a Beta-Bernoulli model/guide pair for the given ``elbo``.

    The data are eight ones followed by two zeros. The model draws the success
    probability from a ``Beta(1, 1)`` prior. The guide is a ``Beta`` whose two
    concentrations are positive-constrained ``numpyro.param`` sites.
    """
    # NOTE(review): as visible here the test ends after defining ``guide``;
    # ``elbo``, ``model`` and ``guide`` are unused and nothing is asserted --
    # the SVI run and its assertion appear to be missing.
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        success_prob = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(success_prob), obs=data)

    def guide(data):
        positive = constraints.positive
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
Example #6
0
    def elbo_loss_fn(x):
        """Return the ``ELBO`` loss of ``model``/``guide`` evaluated at ``x``.

        Uses a fixed ``PRNGKey(0)`` and an empty ``{}`` as the parameter map.
        """
        standard_elbo = ELBO()
        fixed_key = random.PRNGKey(0)
        return standard_elbo.loss(fixed_key, {}, model, guide, x)

    def renyi_loss_fn(x):
        """Return the RenyiELBO loss of ``model``/``guide`` evaluated at ``x``.

        ``alpha`` comes from the enclosing scope. The loss uses 10 particles,
        a fixed ``PRNGKey(0)`` and an empty ``{}`` as the parameter map.
        """
        renyi_elbo = RenyiELBO(alpha=alpha, num_particles=10)
        fixed_key = random.PRNGKey(0)
        return renyi_elbo.loss(fixed_key, {}, model, guide, x)

    elbo_loss, elbo_grad = value_and_grad(elbo_loss_fn)(2.)
    renyi_loss, renyi_grad = value_and_grad(renyi_loss_fn)(2.)
    assert_allclose(elbo_loss, renyi_loss, rtol=1e-6)
    assert_allclose(elbo_grad, renyi_grad, rtol=1e-6)


@pytest.mark.parametrize('elbo', [
    ELBO(),
    pytest.param(
        RenyiELBO(num_particles=10),
        marks=pytest.mark.xfail(
            reason="https://github.com/pyro-ppl/numpyro/issues/414"),
    ),
])
def test_beta_bernoulli(elbo):
    """Build a Beta-Bernoulli model/guide pair for the given ``elbo``.

    The ``RenyiELBO`` case is marked xfail, referencing numpyro issue #414.
    The data are eight ones followed by two zeros. The model draws the success
    probability from a ``Beta(1, 1)`` prior. The guide is a ``Beta`` whose two
    concentrations are positive-constrained ``numpyro.param`` sites.
    """
    # NOTE(review): as visible here the test ends after defining ``guide``;
    # ``elbo``, ``model`` and ``guide`` are unused and nothing is asserted --
    # the SVI run and its assertion appear to be missing.
    data = np.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        success_prob = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(success_prob), obs=data)

    def guide(data):
        positive = constraints.positive
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))