Code example #1 (score: 0)
File: test_stein.py — Project: pyro-ppl/numpyro
def test_svgd_loss_and_grads():
    """SteinVI's internal SVGD loss for a single particle must equal the
    plain Trace_ELBO loss evaluated at the same (unconstrained) parameters."""
    true_coefs, data, model = uniform_normal()
    guide = AutoDelta(model)
    elbo = Trace_ELBO()
    # One particle's unconstrained parameters, shared by SteinVI and plain SVI.
    particle_uparams = {
        "alpha_auto_loc": np.array([-1.2]),
        "loc_base_auto_loc": np.array([1.53]),
    }
    stein = SteinVI(model, guide, Adam(0.1), elbo, RBFKernel())
    stein.init(random.PRNGKey(0), *data)
    svi = SVI(model, guide, Adam(0.1), elbo)
    svi.init(random.PRNGKey(0), *data)
    # Reference: the ordinary ELBO at the constrained parameter values.
    constrained_params = svi.constrain_fn(particle_uparams)
    expected_loss = elbo.loss(
        random.PRNGKey(1), constrained_params, model, guide, *data
    )
    actual_loss, _ = stein._svgd_loss_and_grads(
        random.PRNGKey(1), particle_uparams, *data
    )
    assert expected_loss == actual_loss
Code example #2 (score: 0)
File: test_handlers.py — Project: mhashemi0873/numpyro
def test_subsample_gradient(scale, subsample):
    """Check that subsampled ELBO losses and gradients, summed over the
    minibatches and divided by the number of minibatches, agree with the
    full-data ELBO (within Monte Carlo precision).

    Args:
        scale: multiplicative factor applied to model and guide densities
            via ``handlers.scale`` (skipped when 1.0).
        subsample: if truthy, run with minibatches of size 1; otherwise use
            the full data set in one pass.
    """
    # ``tree_multimap`` was removed from JAX; ``tree_util.tree_map`` is the
    # stable multi-tree API and behaves identically here.
    from jax import tree_util

    data = jnp.array([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    precision = 0.06 * scale

    def model(subsample):
        # Pin the plate's subsample indices so each call sees the chosen minibatch.
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size) as ind:
                x = data[ind]
                z = numpyro.sample("z", dist.Normal(0, 1))
                numpyro.sample("x", dist.Normal(z, 1), obs=x)

    def guide(subsample):
        scale = numpyro.param("scale", 1.)
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size):
                loc = numpyro.param("loc", jnp.zeros(len(data)), event_dim=0)
                numpyro.sample("z", dist.Normal(loc, scale))

    if scale != 1.:
        model = handlers.scale(model, scale=scale)
        guide = handlers.scale(guide, scale=scale)

    # Many particles keep the Monte Carlo error within ``precision``.
    num_particles = 50000
    optimizer = optim.Adam(0.1)
    elbo = Trace_ELBO(num_particles=num_particles)
    svi = SVI(model, guide, optimizer, loss=elbo)
    svi_state = svi.init(random.PRNGKey(0), None)
    params = svi.optim.get_params(svi_state.optim_state)

    def loss_fn(x, minibatch):
        # ELBO at unconstrained params ``x`` for the given subsample indices
        # (``None`` means no substitution, i.e. the full data set).
        return svi.loss.loss(
            svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, minibatch
        )

    normalizer = 2 if subsample else 1
    if subsample_size == 1:
        # Sum loss and gradients over the two single-element minibatches.
        loss1, grads1 = value_and_grad(lambda x: loss_fn(x, jnp.array([0])))(params)
        loss2, grads2 = value_and_grad(lambda x: loss_fn(x, jnp.array([1])))(params)
        grads = tree_util.tree_map(lambda a, b: a + b, grads1, grads2)
        loss = loss1 + loss2
    else:
        loss, grads = value_and_grad(lambda x: loss_fn(x, jnp.array([0, 1])))(params)

    actual_loss = loss / normalizer
    expected_loss, _ = value_and_grad(lambda x: loss_fn(x, None))(params)
    assert_allclose(actual_loss, expected_loss, rtol=precision, atol=precision)

    actual_grads = {name: grad / normalizer for name, grad in grads.items()}
    # Reference gradient values at the initial params (loc=0, scale=1),
    # scaled by the handler scale — presumably the analytic ELBO gradients
    # for this model; verify against the derivation if these change.
    expected_grads = {
        'loc': scale * jnp.array([0.5, -2.0]),
        'scale': scale * jnp.array([2.0])
    }
    assert actual_grads.keys() == expected_grads.keys()
    for name in expected_grads:
        assert_allclose(actual_grads[name],
                        expected_grads[name],
                        rtol=precision,
                        atol=precision)