Example #1
def test_laplace_approximation_warning():
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
        mu = a + b[0] * x + b[1] * x ** 2 + b[2] * x ** 3
        numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)

    x = random.normal(random.PRNGKey(0), (3,))
    y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), x=x, y=y)
    init_state = svi.init(random.PRNGKey(0))
    svi_state = fori_loop(0, 10000, lambda i, val: svi.update(val)[0], init_state)
    params = svi.get_params(svi_state)
    with pytest.warns(UserWarning, match="Hessian of log posterior"):
        guide.sample_posterior(random.PRNGKey(1), params)
Example #2
def test_laplace_approximation_custom_hessian():
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10))
        mu = a + b * x
        with numpyro.plate("N", len(x)):
            numpyro.sample("y", dist.Normal(mu, 1), obs=y)

    x = random.normal(random.PRNGKey(0), (100,))
    y = 1 + 2 * x
    guide = AutoLaplaceApproximation(
        model, hessian_fn=lambda f, x: jacobian(jacobian(f))(x)
    )
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), x=x, y=y)
    svi_result = svi.run(random.PRNGKey(0), 10000, progress_bar=False)
    guide.get_transform(svi_result.params)
Example #3
def test_logistic_regression(auto_class, Elbo):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs',
                               dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Elbo())
    svi_state = svi.init(rng_key_init, data, labels)

    # smoke test if analytic KL is used
    if auto_class is AutoNormal and Elbo is TraceMeanField_ELBO:
        _, mean_field_loss = svi.update(svi_state, data, labels)
        svi.loss = Trace_ELBO()
        _, elbo_loss = svi.update(svi_state, data, labels)
        svi.loss = TraceMeanField_ELBO()
        assert abs(mean_field_loss - elbo_loss) > 0.5

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    if auto_class not in (AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median['coefs'], true_coefs, rtol=0.1)
        # test .quantiles method
        if auto_class is not AutoDelta:
            quantiles = guide.quantiles(params, [0.2, 0.5])
            assert_allclose(quantiles['coefs'][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1),
                                               params,
                                               sample_shape=(1000, ))
    assert_allclose(jnp.mean(posterior_samples['coefs'], 0),
                    true_coefs,
                    rtol=0.1)
Example #4
def test_reparam_log_joint(model, kwargs):
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, Adam(1e-10), Trace_ELBO(), **kwargs)
    svi_state = svi.init(random.PRNGKey(0))
    params = svi.get_params(svi_state)
    neutra = NeuTraReparam(guide, params)
    reparam_model = neutra.reparam(model)
    _, pe_fn, _, _ = initialize_model(random.PRNGKey(1), model, model_kwargs=kwargs)
    init_params, pe_fn_neutra, _, _ = initialize_model(
        random.PRNGKey(2), reparam_model, model_kwargs=kwargs
    )
    latent_x = list(init_params[0].values())[0]
    pe_transformed = pe_fn_neutra(init_params[0])
    latent_y = neutra.transform(latent_x)
    log_det_jacobian = neutra.transform.log_abs_det_jacobian(latent_x, latent_y)
    pe = pe_fn(guide._unpack_latent(latent_y))
    assert_allclose(pe_transformed, pe - log_det_jacobian, rtol=2e-7)
Example #5
def test_iaf():
    # test the substitute logic for the exposed methods `sample_posterior` and `get_transform`
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        offset = numpyro.sample('offset', dist.Uniform(-1, 1))
        logits = offset + jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data, labels)
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (dim + 1,))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide._unpack_latent(guide.get_transform(params)(x))

    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(transforms.PermuteTransform(jnp.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(dim + 1, [dim + 1, dim + 1],
                                               permutation=jnp.arange(dim + 1),
                                               skip_connections=guide._skip_connections,
                                               nonlinearity=guide._nonlinearity)
        arn = partial(arn_apply, params['auto_arn__{}$params'.format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(guide._unpack_latent)

    transform = transforms.ComposeTransform(flows)
    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(dist.Normal(jnp.zeros(dim + 1), 1).sample(rng_key_sample))
    expected_output = transform(x)
    assert_allclose(actual_sample['coefs'], expected_sample['coefs'])
    assert_allclose(actual_sample['offset'],
                    transforms.biject_to(constraints.interval(-1, 1))(expected_sample['offset']))
    check_eq(actual_output, expected_output)
Example #6
def test_collapse_beta_bernoulli():
    data = 0.

    def model():
        c = numpyro.sample("c", dist.Gamma(1, 1))
        with handlers.collapse():
            probs = numpyro.sample("probs", dist.Beta(c, 2))
            numpyro.sample("obs", dist.Bernoulli(probs), obs=data)

    def guide():
        a = numpyro.param("a", 1., constraint=constraints.positive)
        b = numpyro.param("b", 1., constraint=constraints.positive)
        numpyro.sample("c", dist.Gamma(a, b))

    svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    svi.update(svi_state)
Example #7
def test_improper():
    y = random.normal(random.PRNGKey(0), (100, ))

    def model(y):
        lambda1 = numpyro.sample(
            'lambda1', dist.ImproperUniform(dist.constraints.real, (), ()))
        lambda2 = numpyro.sample(
            'lambda2', dist.ImproperUniform(dist.constraints.real, (), ()))
        sigma = numpyro.sample(
            'sigma', dist.ImproperUniform(dist.constraints.positive, (), ()))
        mu = numpyro.deterministic('mu', lambda1 + lambda2)
        numpyro.sample('y', dist.Normal(mu, sigma), obs=y)

    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(0.003), Trace_ELBO(), y=y)
    svi_state = svi.init(random.PRNGKey(2))
    lax.scan(lambda state, i: svi.update(state), svi_state, jnp.zeros(10000))
Example #8
def run_svi_inference(model, guide, rng_key, X, Y, optimizer, n_epochs=1_000):

    # initialize svi
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    # initialize state
    init_state = svi.init(rng_key, X, Y.squeeze())

    # Run the optimizer for n_epochs iterations.
    state, losses = jax.lax.scan(
        lambda state, i: svi.update(state, X, Y.squeeze()),
        init_state,
        None,
        length=n_epochs,
    )

    # Extract surrogate posterior.
    params = svi.get_params(state)

    return params
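
For context, here is a minimal usage sketch for `run_svi_inference`; the model, guide, and data below are illustrative assumptions, not part of the original snippet:

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adam


def linear_model(X, Y=None):
    # toy Bayesian linear regression used only to exercise the helper
    w = numpyro.sample("w", dist.Normal(0.0, 1.0))
    b = numpyro.sample("b", dist.Normal(0.0, 1.0))
    with numpyro.plate("N", X.shape[0]):
        numpyro.sample("Y", dist.Normal(w * X + b, 0.1), obs=Y)


X = jnp.linspace(-1.0, 1.0, 50)
Y = 2.0 * X + 0.5
guide = AutoNormal(linear_model)
params = run_svi_inference(linear_model, guide, random.PRNGKey(0), X, Y, Adam(1e-2))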
Example #9
def test_beta_bernoulli(auto_class):
    data = jnp.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T
    N = len(data)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(jnp.ones(2), jnp.ones(2)).to_event())
        with numpyro.plate("N", N):
            numpyro.sample("obs", dist.Bernoulli(f).to_event(1), obs=data)

    adam = optim.Adam(0.01)
    if auto_class == AutoDAIS:
        guide = auto_class(model, init_loc_fn=init_strategy, base_dist="cholesky")
    else:
        guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(1000,)
    )
    posterior_mean = jnp.mean(posterior_samples["beta"], 0)
    assert_allclose(posterior_mean, true_coefs, atol=0.05)

    if auto_class not in [AutoDAIS, AutoDelta, AutoIAFNormal, AutoBNAFNormal]:
        quantiles = guide.quantiles(params, [0.2, 0.5, 0.8])
        assert quantiles["beta"].shape == (3, 2)

    # Predictive can be instantiated from posterior samples...
    predictive = Predictive(model, posterior_samples=posterior_samples)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, N, 2)

    # ... or from the guide + params
    predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, N, 2)
Example #10
def test_collapse_beta_binomial_plate():
    data = np.array([0., 1., 5., 5.])

    def model():
        c = numpyro.sample("c", dist.Gamma(1, 1))
        with handlers.collapse():
            probs = numpyro.sample("probs", dist.Beta(c, 2))
            with numpyro.plate("plate", len(data)):
                numpyro.sample("obs", dist.Binomial(10, probs), obs=data)

    def guide():
        a = numpyro.param("a", 1., constraint=constraints.positive)
        b = numpyro.param("b", 1., constraint=constraints.positive)
        numpyro.sample("c", dist.Gamma(a, b))

    svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    svi.update(svi_state)
Example #11
def test_autoguide(deterministic):
    GLOBAL["count"] = 0
    guide = AutoDiagonalNormal(model)
    svi = SVI(model,
              guide,
              optim.Adam(0.1),
              Trace_ELBO(),
              deterministic=deterministic)
    svi_state = svi.init(random.PRNGKey(0))
    svi_state = lax.fori_loop(0, 100, lambda i, val: svi.update(val)[0],
                              svi_state)
    params = svi.get_params(svi_state)
    guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(100, ))

    if deterministic:
        assert GLOBAL["count"] == 5
    else:
        assert GLOBAL["count"] == 4
Example #12
def test_run(progress_bar):
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", lambda key: random.normal(key),
                                constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", lambda key: random.exponential(key),
                               constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.05), Trace_ELBO())
    params, losses = svi.run(random.PRNGKey(1), 1000, data, progress_bar=progress_bar)
    assert losses.shape == (1000,)
    assert_allclose(params['alpha_q'] / (params['alpha_q'] + params['beta_q']), 0.8, atol=0.05, rtol=0.05)
Example #13
def test_auto_guide(auto_class, init_loc_fn, num_particles):
    latent_dim = 3

    def model(obs):
        a = numpyro.sample("a", Normal(0, 1))
        return numpyro.sample("obs", Bernoulli(logits=a), obs=obs)

    obs = Bernoulli(0.5).sample(random.PRNGKey(0), (10, latent_dim))

    rng_key = random.PRNGKey(0)
    guide_key, stein_key = random.split(rng_key)
    inner_guide = auto_class(model, init_loc_fn=init_loc_fn())

    with handlers.seed(rng_seed=guide_key), handlers.trace() as inner_guide_tr:
        inner_guide(obs)

    steinvi = SteinVI(
        model,
        auto_class(model, init_loc_fn=init_loc_fn()),
        Adam(1.0),
        Trace_ELBO(),
        RBFKernel(),
        num_particles=num_particles,
    )
    state = steinvi.init(stein_key, obs)
    init_params = steinvi.get_params(state)

    for name, site in inner_guide_tr.items():
        if site.get("type") == "param":
            assert name in init_params
            inner_param = site
            init_value = init_params[name]
            expected_shape = (num_particles, *np.shape(inner_param["value"]))
            assert init_value.shape == expected_shape
            if "auto_loc" in name or name == "b":
                assert np.all(init_value != np.zeros(expected_shape))
                assert np.unique(init_value).shape == init_value.reshape(-1).shape
            elif "scale" in name:
                assert_array_approx_equal(init_value,
                                          np.full(expected_shape, 0.1))
            else:
                assert_array_approx_equal(init_value,
                                          np.full(expected_shape, 0.0))
Example #14
def test_param():
    # this tests the validity of a model whose param sites
    # have composed transformed constraints
    rng_keys = random.split(random.PRNGKey(0), 3)
    a_minval = 1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    x_init = random.normal(rng_keys[2])

    def model():
        a = numpyro.param("a",
                          a_init,
                          constraint=constraints.greater_than(a_minval))
        b = numpyro.param("b", b_init, constraint=constraints.positive)
        numpyro.sample("x", dist.Normal(a, b))

    # this class is used to force init value of `x` to x_init
    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            return substitute(
                super(_AutoGuide, self).__call__,
                {"_auto_latent": x_init[None]})(*args, **kwargs)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init)

    params = svi.get_params(svi_state)
    assert_allclose(params["a"], a_init, rtol=1e-6)
    assert_allclose(params["b"], b_init, rtol=1e-6)
    assert_allclose(params["auto_loc"], guide._init_latent, rtol=1e-6)
    assert_allclose(params["auto_scale"],
                    jnp.ones(1) * guide._init_scale,
                    rtol=1e-6)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = dist.Normal(
        guide._init_latent, guide._init_scale).log_prob(x_init) - dist.Normal(
            a_init, b_init).log_prob(x_init)
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #15
def test_get_params(kernel, auto_guide, init_loc_fn, problem):
    _, data, model = problem()
    guide, optim, elbo = (
        auto_guide(model, init_loc_fn=init_loc_fn),
        Adam(1e-1),
        Trace_ELBO(),
    )

    stein = SteinVI(model, guide, optim, elbo, kernel)
    stein_params = stein.get_params(stein.init(random.PRNGKey(0), *data))

    svi = SVI(model, guide, optim, elbo)
    svi_params = svi.get_params(svi.init(random.PRNGKey(0), *data))
    assert svi_params.keys() == stein_params.keys()

    for name, svi_param in svi_params.items():
        assert (stein_params[name].shape == np.repeat(svi_param[None, ...],
                                                      stein.num_particles,
                                                      axis=0).shape)
Example #16
def test_subsample_guide(auto_class):

    # The model adapted from tutorial/source/easyguide.ipynb
    def model(batch, subsample, full_size):
        drift = numpyro.sample("drift", dist.LogNormal(-1, 0.5))
        with handlers.substitute(data={"data": subsample}):
            plate = numpyro.plate("data", full_size, subsample_size=len(subsample))
        assert plate.size == 50

        def transition_fn(z_prev, y_curr):
            with plate:
                z_curr = numpyro.sample("state", dist.Normal(z_prev, drift))
                y_curr = numpyro.sample("obs", dist.Bernoulli(logits=z_curr), obs=y_curr)
            return z_curr, y_curr

        _, result = scan(transition_fn, jnp.zeros(len(subsample)), batch, length=num_time_steps)
        return result

    def create_plates(batch, subsample, full_size):
        with handlers.substitute(data={"data": subsample}):
            return numpyro.plate("data", full_size, subsample_size=subsample.shape[0])

    guide = auto_class(model, create_plates=create_plates)

    full_size = 50
    batch_size = 20
    num_time_steps = 8
    with handlers.seed(rng_seed=0):
        data = model(None, jnp.arange(full_size), full_size)
    assert data.shape == (num_time_steps, full_size)

    svi = SVI(model, guide, optim.Adam(0.02), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0), data[:, :batch_size],
                         jnp.arange(batch_size), full_size=full_size)
    update_fn = jit(svi.update, static_argnums=(3,))
    for epoch in range(2):
        beg = 0
        while beg < full_size:
            end = min(full_size, beg + batch_size)
            subsample = jnp.arange(beg, end)
            batch = data[:, beg:end]
            beg = end
            svi_state, loss = update_fn(svi_state, batch, subsample, full_size)
Example #17
def train_model_no_dp(rng,
                      model,
                      guide,
                      data,
                      batch_size,
                      num_data,
                      num_epochs,
                      silent=False,
                      **kwargs):
    """ trains a given model using SVI (no DP!) and the globally defined parameters and data """

    optimizer = Adam(1e-3)

    svi = SVI(model, guide, optimizer, Trace_ELBO(), num_obs_total=num_data)

    import d3p.random.debug
    return _train_model(d3p.random.convert_to_jax_rng_key(rng),
                        d3p.random.debug, svi, data, batch_size, num_data,
                        num_epochs, silent)
Example #18
def test_param():
    # this tests the validity of model/guide sites whose param
    # constraints are composed transforms
    rng_keys = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval = -2
    c_maxval = -1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    c_init = random.uniform(rng_keys[2], minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(rng_keys[3])
    obs = random.normal(rng_keys[4])

    def model():
        a = numpyro.param("a",
                          a_init,
                          constraint=constraints.greater_than(a_minval))
        b = numpyro.param("b", b_init, constraint=constraints.positive)
        numpyro.sample("x", dist.Normal(a, b), obs=obs)

    def guide():
        c = numpyro.param("c",
                          c_init,
                          constraint=constraints.interval(c_minval, c_maxval))
        d = numpyro.param("d", d_init, constraint=constraints.unit_interval)
        numpyro.sample("y", dist.Normal(c, d), obs=obs)

    adam = optim.Adam(0.01)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0))

    params = svi.get_params(svi_state)
    assert_allclose(params["a"], a_init)
    assert_allclose(params["b"], b_init)
    assert_allclose(params["c"], c_init)
    assert_allclose(params["d"], d_init)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = dist.Normal(c_init, d_init).log_prob(obs) - dist.Normal(
        a_init, b_init).log_prob(obs)
    # not exact because of the transform / inverse-transform round trips
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #19
def test_jitted_update_fn():
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    adam = optim.Adam(0.05)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)
    expected = svi.get_params(svi.update(svi_state, data)[0])

    actual = svi.get_params(jit(svi.update)(svi_state, data=data)[0])
    check_close(actual, expected, atol=1e-5)
Example #20
def test_calc_particle_info(num_params, num_particles):
    seed = random.PRNGKey(nrandom.randint(0, 10_000))
    sizes = Poisson(5).sample(seed, (100, nrandom.randint(0, 10))) + 1

    uparam = tuple(np.empty(tuple(size)) for size in sizes)
    uparams = {string.ascii_lowercase[i]: uparam for i in range(num_params)}

    par_param_size = sum(map(lambda size: size.prod(), sizes)) // num_particles
    expected_start_end = zip(
        par_param_size * np.arange(num_params),
        par_param_size * np.arange(1, num_params + 1),
    )
    expected_pinfo = dict(
        zip(string.ascii_lowercase[:num_params], expected_start_end))

    stein = SteinVI(id, id, Adam(1.0), Trace_ELBO(), RBFKernel())
    pinfo = stein._calc_particle_info(uparams, num_particles)

    for k in pinfo.keys():
        assert pinfo[k] == expected_pinfo[k], f"Failed for seed {seed}"
Example #21
def test_svi_discrete_latent():
    cont_inf_only_cls = [RenyiELBO(), Trace_ELBO(), TraceMeanField_ELBO()]
    mixed_inf_cls = [TraceGraph_ELBO()]

    assert not any([c.can_infer_discrete for c in cont_inf_only_cls])
    assert all([c.can_infer_discrete for c in mixed_inf_cls])

    def model():
        numpyro.sample("x", dist.Bernoulli(0.5))

    def guide():
        probs = numpyro.param("probs", 0.2)
        numpyro.sample("x", dist.Bernoulli(probs))

    for elbo in cont_inf_only_cls:
        svi = SVI(model, guide, optim.Adam(1), elbo)
        s_name = type(elbo).__name__
        w_msg = f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
        with pytest.warns(UserWarning, match=w_msg):
            svi.run(random.PRNGKey(0), 10)
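
By contrast, `TraceGraph_ELBO` reports `can_infer_discrete=True` (asserted above), so the same model/guide pair should run without the warning; a small hedged sketch reusing the definitions from the test:

for elbo in mixed_inf_cls:
    svi = SVI(model, guide, optim.Adam(1), elbo)
    svi.run(random.PRNGKey(0), 10)  # expected to complete without the discrete-latent warning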
Example #22
def test_elbo_dynamic_support():
    x_prior = dist.TransformedDistribution(
        dist.Normal(), [AffineTransform(0, 2), SigmoidTransform(), AffineTransform(0, 3)])
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    adam = optim.Adam(0.01)
    x = 2.
    guide = substitute(guide, data={'x': x})
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = x_guide.log_prob(x) - x_prior.log_prob(x)
    assert_allclose(actual_loss, expected_loss)
Example #23
def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
    svi_key, mcmc_key = random.split(hmcecs_key)

    # find reference parameters for the second-order Taylor expansion used to
    # estimate the likelihood (taylor_proxy)
    optimizer = numpyro.optim.Adam(step_size=1e-3)
    guide = autoguide.AutoDelta(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    params, losses = svi.run(svi_key, args.num_svi_steps, data, obs, args.subsample_size)
    ref_params = {'theta': params['theta_auto_loc']}

    # taylor_proxy estimates the log likelihood (ll) as
    #   taylor_expansion(ll, theta_curr) +
    #       sum_{i in subsample} ll_i(theta_curr) - taylor_expansion(ll_i, theta_curr),
    # with the expansions taken around ref_params
    proxy = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(inner_kernel, num_blocks=args.num_blocks, proxy=proxy)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)

    mcmc.run(mcmc_key, data, obs, args.subsample_size)
    mcmc.print_summary()
    return losses, mcmc.get_samples()
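
To make the proxy comment concrete, here is an illustrative sketch of the difference estimator behind `HMCECS.taylor_proxy` (a hand-rolled restatement of the formula, not NumPyro's internal implementation; `per_datum_ll` (a per-datum log-likelihood function), `theta` (a flat parameter vector), `data`, and `subsample_idx` are assumptions of the sketch):

import jax
import jax.numpy as jnp


def taylor_proxy_ll_estimate(per_datum_ll, theta, theta_ref, data, subsample_idx):
    # second-order Taylor expansion of one datum's log likelihood around theta_ref
    def taylor_i(datum, th):
        d = th - theta_ref
        g = jax.grad(per_datum_ll)(theta_ref, datum)
        h = jax.hessian(per_datum_ll)(theta_ref, datum)
        return per_datum_ll(theta_ref, datum) + g @ d + 0.5 * d @ (h @ d)

    n, m = data.shape[0], subsample_idx.shape[0]
    # full-data Taylor term; the real proxy precomputes this from summed
    # reference statistics instead of re-evaluating every datum
    full_taylor = jnp.sum(jax.vmap(lambda datum: taylor_i(datum, theta))(data))
    # subsample correction ll_i(theta) - taylor_i(theta), rescaled by n / m
    # to keep the estimator unbiased
    sub = data[subsample_idx]
    corr = jax.vmap(
        lambda datum: per_datum_ll(theta, datum) - taylor_i(datum, theta)
    )(sub)
    return full_taylor + (n / m) * jnp.sum(corr)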
Example #24
def test_autoguide_deterministic(auto_class):
    def model(y=None):
        n = y.size if y is not None else 1

        mu = numpyro.sample("mu", dist.Normal(0, 5))
        sigma = numpyro.param("sigma", 1, constraint=constraints.positive)

        with numpyro.plate("N", len(y)):
            y = numpyro.sample("y",
                               dist.Normal(mu, sigma).expand((n, )),
                               obs=y)
        numpyro.deterministic("z", (y - mu) / sigma)

    mu, sigma = 2, 3
    y = mu + sigma * random.normal(random.PRNGKey(0), shape=(300, ))
    y_train = y[:200]
    y_test = y[200:]

    guide = auto_class(model)
    optimiser = numpyro.optim.Adam(step_size=0.01)
    svi = SVI(model, guide, optimiser, Trace_ELBO())

    svi_result = svi.run(random.PRNGKey(0), num_steps=500, y=y_train)
    params = svi_result.params
    posterior_samples = guide.sample_posterior(random.PRNGKey(0),
                                               params,
                                               sample_shape=(1000, ))

    predictive = Predictive(model, posterior_samples, params=params)
    predictive_samples = predictive(random.PRNGKey(0), y_test)

    assert predictive_samples["y"].shape == (1000, 100)
    assert predictive_samples["z"].shape == (1000, 100)
    assert_allclose(
        (predictive_samples["y"] - posterior_samples["mu"][..., None]) /
        params["sigma"],
        predictive_samples["z"],
        atol=0.05,
    )
Example #25
def test_neals_funnel_smoke():
    dim = 10

    guide = AutoIAFNormal(neals_funnel)
    svi = SVI(neals_funnel, guide, Adam(1e-10), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0), dim)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, dim)
        return svi_state

    svi_state = lax.fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    neutra = NeuTraReparam(guide, params)
    model = neutra.reparam(neals_funnel)
    nuts = NUTS(model)
    mcmc = MCMC(nuts, num_warmup=50, num_samples=50)
    mcmc.run(random.PRNGKey(1), dim)
    samples = mcmc.get_samples()
    transformed_samples = neutra.transform_sample(samples['auto_shared_latent'])
    assert 'x' in transformed_samples
    assert 'y' in transformed_samples
Example #26
def test_beta_bernoulli(auto_class):
    data = jnp.array([[1.0] * 8 + [0.0] * 2,
                     [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = numpyro.sample('beta', dist.Beta(jnp.ones(2), jnp.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(f), obs=data)

    adam = optim.Adam(0.01)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['beta'], 0), true_coefs, atol=0.05)
Example #27
@pytest.mark.parametrize("optim_class, args, kwargs", optimizers)
def test_optim_multi_params(optim_class, args, kwargs):
    params = {
        "x": jnp.array([1.0, 1.0, 1.0]),
        "y": jnp.array([-1, -1.0, -1.0])
    }
    opt = optax_to_numpyro(optim_class(*args, **kwargs))
    opt_state = opt.init(params)
    for i in range(2000):
        opt_state = step(opt_state, opt)
    for _, param in opt.get_params(opt_state).items():
        assert jnp.allclose(param, jnp.zeros(3))


@pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
def test_beta_bernoulli(elbo):
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q",
                                1.0,
                                constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    adam = optax.adam(0.05)
    # the snippet is truncated here; a minimal completion following the pattern
    # of Example #12 (sketch):
    svi = SVI(model, guide, optax_to_numpyro(adam), elbo)
    params, losses = svi.run(random.PRNGKey(1), 2000, data)
    assert_allclose(params['alpha_q'] / (params['alpha_q'] + params['beta_q']),
                    0.8, atol=0.05, rtol=0.05)
Example #28
def main(args):
    encoder_nn = encoder(args.hidden_dim, args.z_dim)
    decoder_nn = decoder(args.hidden_dim, 28 * 28)
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model,
              guide,
              adam,
              Trace_ELBO(),
              hidden_dim=args.hidden_dim,
              z_dim=args.z_dim)
    rng_key = PRNGKey(0)
    train_init, train_fetch = load_dataset(MNIST,
                                           batch_size=args.batch_size,
                                           split='train')
    test_init, test_fetch = load_dataset(MNIST,
                                         batch_size=args.batch_size,
                                         split='test')
    num_train, train_idx = train_init()
    rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)
    sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])
    svi_state = svi.init(rng_key_init, sample_batch)

    @jit
    def epoch_train(svi_state, rng_key):
        def body_fn(i, val):
            loss_sum, svi_state = val
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
            svi_state, loss = svi.update(svi_state, batch)
            loss_sum += loss
            return loss_sum, svi_state

        return lax.fori_loop(0, num_train, body_fn, (0., svi_state))

    @jit
    def eval_test(svi_state, rng_key):
        def body_fun(i, loss_sum):
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
            # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
            loss = svi.evaluate(svi_state, batch) / len(batch)
            loss_sum += loss
            return loss_sum

        loss = lax.fori_loop(0, num_test, body_fun, 0.)
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch, rng_key):
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR,
                                'original_epoch={}.png'.format(epoch)),
                   img,
                   cmap='gray')
        rng_key_binarize, rng_key_sample = random.split(rng_key)
        test_sample = binarize(rng_key_binarize, img)
        params = svi.get_params(svi_state)
        z_mean, z_var = encoder_nn[1](params['encoder$params'],
                                      test_sample.reshape([1, -1]))
        z = dist.Normal(z_mean, z_var).sample(rng_key_sample)
        img_loc = decoder_nn[1](params['decoder$params'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR,
                                'recons_epoch={}.png'.format(epoch)),
                   img_loc,
                   cmap='gray')

    for i in range(args.num_epochs):
        rng_key, rng_key_train, rng_key_test, rng_key_reconstruct = random.split(
            rng_key, 4)
        t_start = time.time()
        num_train, train_idx = train_init()
        _, svi_state = epoch_train(svi_state, rng_key_train)
        num_test, test_idx = test_init()
        test_loss = eval_test(svi_state, rng_key_test)
        reconstruct_img(i, rng_key_reconstruct)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss,
                                                       time.time() - t_start))
Example #29
def test_cond():
    def model():
        def true_fun(_):
            x = numpyro.sample("x", dist.Normal(4.0))
            numpyro.deterministic("z", x - 4.0)

        def false_fun(_):
            x = numpyro.sample("x", dist.Normal(0.0))
            numpyro.deterministic("z", x)

        cluster = numpyro.sample("cluster", dist.Normal())
        cond(cluster > 0, true_fun, false_fun, None)

    def guide():
        m1 = numpyro.param("m1", 2.0)
        s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive)
        m2 = numpyro.param("m2", 2.0)
        s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive)

        def true_fun(_):
            numpyro.sample("x", dist.Normal(m1, s1))

        def false_fun(_):
            numpyro.sample("x", dist.Normal(m2, s2))

        cluster = numpyro.sample("cluster", dist.Normal())
        cond(cluster > 0, true_fun, false_fun, None)

    svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100))
    svi_result = svi.run(random.PRNGKey(0), num_steps=2500)
    params = svi_result.params

    predictive = Predictive(
        model,
        guide=guide,
        params=params,
        num_samples=1000,
        return_sites=["cluster", "x", "z"],
    )
    result = predictive(random.PRNGKey(0))

    assert result["cluster"].shape == (1000,)
    assert result["x"].shape == (1000,)
    assert result["z"].shape == (1000,)

    mcmc = MCMC(
        NUTS(model),
        num_warmup=500,
        num_samples=2500,
        num_chains=4,
        chain_method="sequential",
    )
    mcmc.run(random.PRNGKey(0))

    x = mcmc.get_samples()["x"]
    assert x.shape == (10_000,)
    assert_allclose(
        [x[x > 2.0].mean(), x[x > 2.0].std(), x[x < 2.0].mean(), x[x < 2.0].std()],
        [4.01, 0.965, -0.01, 0.965],
        atol=0.1,
    )
    assert_allclose([x.mean(), x.std()], [2.0, jnp.sqrt(5.0)], atol=0.5)
Example #30
def benchmark_hmc(args, features, labels):
    rng_key = random.PRNGKey(1)
    start = time.time()
    # a MAP estimate taken from the following source
    # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
    ref_params = {
        "coefs":
        jnp.array([
            +2.03420663e00,
            -3.53567265e-02,
            -1.49223924e-01,
            -3.07049364e-01,
            -1.00028366e-01,
            -1.46827862e-01,
            -1.64167881e-01,
            -4.20344204e-01,
            +9.47479829e-02,
            -1.12681836e-02,
            +2.64442056e-01,
            -1.22087866e-01,
            -6.00568838e-02,
            -3.79419506e-01,
            -1.06668741e-01,
            -2.97053963e-01,
            -2.05253899e-01,
            -4.69537191e-02,
            -2.78072730e-02,
            -1.43250525e-01,
            -6.77954629e-02,
            -4.34899796e-03,
            +5.90927452e-02,
            +7.23133609e-02,
            +1.38526391e-02,
            -1.24497898e-01,
            -1.50733739e-02,
            -2.68872194e-02,
            -1.80925727e-02,
            +3.47936489e-02,
            +4.03552800e-02,
            -9.98773426e-03,
            +6.20188080e-02,
            +1.15002751e-01,
            +1.32145107e-01,
            +2.69109547e-01,
            +2.45785132e-01,
            +1.19035013e-01,
            -2.59744357e-02,
            +9.94279515e-04,
            +3.39266285e-02,
            -1.44057125e-02,
            -6.95222765e-02,
            -7.52013028e-02,
            +1.21171586e-01,
            +2.29205526e-02,
            +1.47308692e-01,
            -8.34354162e-02,
            -9.34122875e-02,
            -2.97472421e-02,
            -3.03937674e-01,
            -1.70958012e-01,
            -1.59496680e-01,
            -1.88516974e-01,
            -1.20889175e00,
        ])
    }
    if args.algo == "HMC":
        step_size = jnp.sqrt(0.5 / features.shape[0])
        trajectory_length = step_size * args.num_steps
        kernel = HMC(
            model,
            step_size=step_size,
            trajectory_length=trajectory_length,
            adapt_step_size=False,
            dense_mass=args.dense_mass,
        )
        subsample_size = None
    elif args.algo == "NUTS":
        kernel = NUTS(model, dense_mass=args.dense_mass)
        subsample_size = None
    elif args.algo == "HMCECS":
        subsample_size = 1000
        inner_kernel = NUTS(
            model,
            init_strategy=init_to_value(values=ref_params),
            dense_mass=args.dense_mass,
        )
        # note: if num_blocks=100, we'll update 10 indices at each MCMC step,
        # so it takes 50,000 MCMC steps to iterate over the whole dataset
        kernel = HMCECS(inner_kernel,
                        num_blocks=100,
                        proxy=HMCECS.taylor_proxy(ref_params))
    elif args.algo == "SA":
        # NB: this kernel requires large num_warmup and num_samples
        # and running on GPU is much faster than on CPU
        kernel = SA(model,
                    adapt_state_size=1000,
                    init_strategy=init_to_value(values=ref_params))
        subsample_size = None
    elif args.algo == "FlowHMCECS":
        subsample_size = 1000
        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
        params, losses = svi_result.params, svi_result.losses
        plt.plot(losses)
        plt.show()

        neutra = NeuTraReparam(guide, params)
        neutra_model = neutra.reparam(model)
        neutra_ref_params = {"auto_shared_latent": jnp.zeros(55)}
        # no need to adapt mass matrix if the flow does a good job
        inner_kernel = NUTS(
            neutra_model,
            init_strategy=init_to_value(values=neutra_ref_params),
            adapt_mass_matrix=False,
        )
        kernel = HMCECS(inner_kernel,
                        num_blocks=100,
                        proxy=HMCECS.taylor_proxy(neutra_ref_params))
    else:
        raise ValueError(
            "Invalid algorithm; must be one of 'HMC', 'NUTS', 'HMCECS', 'SA', or 'FlowHMCECS'."
        )
    mcmc = MCMC(kernel,
                num_warmup=args.num_warmup,
                num_samples=args.num_samples)
    mcmc.run(rng_key,
             features,
             labels,
             subsample_size,
             extra_fields=("accept_prob", ))
    print("Mean accept prob:",
          jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
    mcmc.print_summary(exclude_deterministic=False)
    print("\nMCMC elapsed time:", time.time() - start)