Example #1
def test_obs_mask_ok(Elbo, mask, num_particles):
    data = np.array([7., 7., 7.])

    def model():
        x = numpyro.sample("x", dist.Normal(0., 1.))
        with numpyro.plate("plate", len(data)):
            y = numpyro.sample("y",
                               dist.Normal(x, 1.),
                               obs=data,
                               obs_mask=mask)
            if not_jax_tracer(y):
                assert ((y == data) == mask).all()

    def guide():
        loc = numpyro.param("loc", np.zeros(()))
        scale = numpyro.param("scale",
                              np.ones(()),
                              constraint=constraints.positive)
        x = numpyro.sample("x", dist.Normal(loc, scale))
        with numpyro.plate("plate", len(data)):
            with handlers.mask(mask=np.invert(mask)):
                numpyro.sample("y_unobserved", dist.Normal(x, 1.))

    elbo = Elbo(num_particles=num_particles)
    svi = SVI(model, guide, numpyro.optim.Adam(1), elbo)
    svi_state = svi.init(random.PRNGKey(0))
    svi.update(svi_state)
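
The test above is normally driven by pytest parameters; a minimal sketch of invoking it directly, assuming Trace_ELBO is imported from numpyro.infer and that mask flags which entries of data are observed (unmasked entries are imputed through the "y_unobserved" site):

# Hypothetical direct invocation of the parameterized test above.
mask = np.array([True, False, True])
test_obs_mask_ok(Trace_ELBO, mask, num_particles=1)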
Example #2
def test_obs_mask_multivariate_ok(Elbo, mask, num_particles):
    data = np.full((4, 3), 7.0)

    def model():
        x = numpyro.sample("x",
                           dist.MultivariateNormal(np.zeros(3), np.eye(3)))
        with numpyro.plate("plate", len(data)):
            y = numpyro.sample("y",
                               dist.MultivariateNormal(x, np.eye(3)),
                               obs=data,
                               obs_mask=mask)
            if not_jax_tracer(y):
                assert ((y == data).all(-1) == mask).all()

    def guide():
        loc = numpyro.param("loc", np.zeros(3))
        cov = numpyro.param("cov",
                            np.eye(3),
                            constraint=constraints.positive_definite)
        x = numpyro.sample("x", dist.MultivariateNormal(loc, cov))
        with numpyro.plate("plate", len(data)):
            with handlers.mask(mask=np.invert(mask)):
                numpyro.sample("y_unobserved",
                               dist.MultivariateNormal(x, np.eye(3)))

    elbo = Elbo(num_particles=num_particles)
    svi = SVI(model, guide, numpyro.optim.Adam(1), elbo)
    svi_state = svi.init(random.PRNGKey(0))
    svi.update(svi_state)
Example #3
def test_logistic_regression(auto_class, Elbo):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample("coefs",
                               dist.Normal(0, 1).expand([dim]).to_event())
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data,
                                                         axis=-1))
        with numpyro.plate("N", len(data)):
            return numpyro.sample("obs",
                                  dist.Bernoulli(logits=logits),
                                  obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Elbo())
    svi_state = svi.init(rng_key_init, data, labels)

    # smoke test if analytic KL is used
    if auto_class is AutoNormal and Elbo is TraceMeanField_ELBO:
        _, mean_field_loss = svi.update(svi_state, data, labels)
        svi.loss = Trace_ELBO()
        _, elbo_loss = svi.update(svi_state, data, labels)
        svi.loss = TraceMeanField_ELBO()
        assert abs(mean_field_loss - elbo_loss) > 0.5

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    if auto_class not in (AutoDAIS, AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median["coefs"], true_coefs, rtol=0.1)
        # test .quantile method
        if auto_class is not AutoDelta:
            quantiles = guide.quantiles(params, [0.2, 0.5])
            assert_allclose(quantiles["coefs"][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1),
                                               params,
                                               sample_shape=(1000, ))
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(posterior_samples["coefs"], 0),
                    expected_coefs,
                    rtol=0.1)
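
Once fitted, the guide and params above can also drive posterior-predictive sampling; a minimal sketch, assuming numpyro.infer.Predictive and passing labels=None so the "obs" site is sampled rather than conditioned on:

# Sketch: posterior-predictive draws from the fitted guide.
predictive = Predictive(model, guide=guide, params=params, num_samples=500)
predictions = predictive(random.PRNGKey(2), data, None)["obs"]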
Example #4
def fit_advi(model, num_iter, learning_rate=0.01, seed=0):
    """Automatic Differentiation Variational Inference using a Normal variational distribution
    with a diagonal covariance matrix.

    Args:
        model: a NumPyro's model function
        num_iter: number of iterations of gradient descent (Adam)
        learning_rate: the step size for the Adam algorithm (default: 0.01)
        seed: random seed (default: 0)

    Returns:
        a set of results of type ADVIResults
    """
    rng_key = random.PRNGKey(seed)
    adam = Adam(learning_rate)
    # Automatically create a variational distribution (aka "guide" in Pyro's terminology)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(rng_key)

    # Run optimization
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, np.zeros(num_iter))
    results = ADVIResults(svi=svi,
                          guide=guide,
                          state=last_state,
                          losses=losses)
    return results
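
A minimal usage sketch for the helper above; the toy model is hypothetical, and the variational parameters are read back from the returned state:

# Hypothetical toy model, for illustration only.
def toy_model():
    mu = numpyro.sample("mu", dist.Normal(0., 10.))
    numpyro.sample("obs", dist.Normal(mu, 1.), obs=np.array([1.2, 0.8, 1.1]))

results = fit_advi(toy_model, num_iter=2000)
params = results.svi.get_params(results.state)  # learned loc/scale values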
Example #5
def run_inference(model, inputs, method=None):
    if method is None:
        # NUTS
        num_samples = 5000
        logger.info('NUTS sampling')
        kernel = NUTS(model)
        mcmc = MCMC(kernel, num_warmup=300, num_samples=num_samples)
        rng_key = random.PRNGKey(0)
        mcmc.run(rng_key, **inputs, extra_fields=('potential_energy', ))
        logger.info('MCMC summary for: {}'.format(model.__name__))
        mcmc.print_summary(exclude_deterministic=False)
        samples = mcmc.get_samples()
    else:
        # SVI
        logger.info('Guide generation...')
        rng_key = random.PRNGKey(0)
        guide = AutoDiagonalNormal(model=model)
        logger.info('Optimizer generation...')
        optim = Adam(0.05)
        logger.info('SVI generation...')
        svi = SVI(model, guide, optim, AutoContinuousELBO(), **inputs)
        init_state = svi.init(rng_key)
        logger.info('Scan...')
        state, loss = lax.scan(lambda x, i: svi.update(x), init_state,
                               np.zeros(2000))
        params = svi.get_params(state)
        samples = guide.sample_posterior(random.PRNGKey(1), params, (1000, ))
        logger.info('SVI summary for: {}'.format(model.__name__))
        numpyro.diagnostics.print_summary(samples,
                                          prob=0.90,
                                          group_by_chain=False)
    return samples
Example #6
def test_logistic_regression(auto_class, Elbo):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs',
                               dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Elbo())
    svi_state = svi.init(rng_key_init, data, labels)

    # smoke test if analytic KL is used
    if auto_class is AutoNormal and Elbo is TraceMeanField_ELBO:
        _, mean_field_loss = svi.update(svi_state, data, labels)
        svi.loss = Trace_ELBO()
        _, elbo_loss = svi.update(svi_state, data, labels)
        svi.loss = TraceMeanField_ELBO()
        assert abs(mean_field_loss - elbo_loss) > 0.5

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    if auto_class not in (AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median['coefs'], true_coefs, rtol=0.1)
        # test .quantile method
        quantiles = guide.quantiles(params, [0.2, 0.5])
        assert_allclose(quantiles['coefs'][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1),
                                               params,
                                               sample_shape=(1000, ))
    assert_allclose(jnp.mean(posterior_samples['coefs'], 0),
                    true_coefs,
                    rtol=0.1)
Example #7
def test_collapse_beta_bernoulli():
    data = 0.

    def model():
        c = numpyro.sample("c", dist.Gamma(1, 1))
        with handlers.collapse():
            probs = numpyro.sample("probs", dist.Beta(c, 2))
            numpyro.sample("obs", dist.Bernoulli(probs), obs=data)

    def guide():
        a = numpyro.param("a", 1., constraint=constraints.positive)
        b = numpyro.param("b", 1., constraint=constraints.positive)
        numpyro.sample("c", dist.Gamma(a, b))

    svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    svi.update(svi_state)
Example #8
def test_collapse_beta_binomial_plate():
    data = np.array([0., 1., 5., 5.])

    def model():
        c = numpyro.sample("c", dist.Gamma(1, 1))
        with handlers.collapse():
            probs = numpyro.sample("probs", dist.Beta(c, 2))
            with numpyro.plate("plate", len(data)):
                numpyro.sample("obs", dist.Binomial(10, probs), obs=data)

    def guide():
        a = numpyro.param("a", 1., constraint=constraints.positive)
        b = numpyro.param("b", 1., constraint=constraints.positive)
        numpyro.sample("c", dist.Gamma(a, b))

    svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    svi.update(svi_state)
Example #9
def svi(model, guide, num_steps, lr, rng_key, X, Y):
    """
    Helper function for doing SVI inference.
    """
    svi = SVI(model, guide, optim.Adam(lr), ELBO(num_particles=1), X=X, Y=Y)

    svi_state = svi.init(rng_key)
    print('Optimizing...')
    state, loss = lax.scan(lambda x, i: svi.update(x), svi_state,
                           np.zeros(num_steps))

    return loss, svi.get_params(state)
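
A usage sketch pairing the helper with an autoguide; the regression model and data are hypothetical:

# Hypothetical linear-regression setup; X and Y flow through the
# SVI(..., X=X, Y=Y) call inside the helper.
X = random.normal(random.PRNGKey(1), (50, 3))
Y = X @ jnp.arange(1., 4.)

def model(X, Y):
    w = numpyro.sample("w", dist.Normal(jnp.zeros(3), 1.))
    numpyro.sample("Y", dist.Normal(X @ w, 1.), obs=Y)

guide = AutoDiagonalNormal(model)
losses, params = svi(model, guide, num_steps=1000, lr=0.01,
                     rng_key=random.PRNGKey(0), X=X, Y=Y)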
Example #10
def test_improper():
    y = random.normal(random.PRNGKey(0), (100,))

    def model(y):
        lambda1 = numpyro.sample('lambda1', dist.ImproperUniform(dist.constraints.real, (), ()))
        lambda2 = numpyro.sample('lambda2', dist.ImproperUniform(dist.constraints.real, (), ()))
        sigma = numpyro.sample('sigma', dist.ImproperUniform(dist.constraints.positive, (), ()))
        mu = numpyro.deterministic('mu', lambda1 + lambda2)
        numpyro.sample('y', dist.Normal(mu, sigma), obs=y)

    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(0.003), Trace_ELBO(), y=y)
    svi_state = svi.init(random.PRNGKey(2))
    lax.scan(lambda state, i: svi.update(state), svi_state, jnp.zeros(10000))
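
The scan above is a smoke test that discards its result; a sketch of keeping the final state and reading off the fitted approximation, using AutoDiagonalNormal's median method:

# Sketch: capture the final state to extract the learned parameters.
svi_state, losses = lax.scan(lambda state, i: svi.update(state),
                             svi_state, jnp.zeros(10000))
params = svi.get_params(svi_state)
median = guide.median(params)  # e.g. median["sigma"]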
Example #11
def test_module():
    x = random.normal(random.PRNGKey(0), (100, 10))
    y = random.normal(random.PRNGKey(1), (100,))

    def model(x, y):
        nn = numpyro.module("nn", Dense(1), (10,))
        mu = nn(x).squeeze(-1)
        sigma = numpyro.sample("sigma", dist.HalfNormal(1))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(0.003), Trace_ELBO(), x=x, y=y)
    svi_state = svi.init(random.PRNGKey(2))
    lax.scan(lambda state, i: svi.update(state), svi_state, jnp.zeros(1000))
Example #12
def test_collapse_beta_binomial():
    total_count = 10
    data = 3.

    def model1():
        c1 = numpyro.param("c1", 0.5, constraint=dist.constraints.positive)
        c0 = numpyro.param("c0", 1.5, constraint=dist.constraints.positive)
        with handlers.collapse():
            probs = numpyro.sample("probs", dist.Beta(c1, c0))
            numpyro.sample("obs", dist.Binomial(total_count, probs), obs=data)

    def model2():
        c1 = numpyro.param("c1", 0.5, constraint=dist.constraints.positive)
        c0 = numpyro.param("c0", 1.5, constraint=dist.constraints.positive)
        numpyro.sample("obs", dist.BetaBinomial(c1, c0, total_count), obs=data)

    trace1 = handlers.trace(model1).get_trace()
    trace2 = handlers.trace(model2).get_trace()
    assert "probs" in trace1
    assert "obs" not in trace1
    assert "probs" not in trace2
    assert "obs" in trace2

    svi1 = SVI(model1, lambda: None, numpyro.optim.Adam(1), Trace_ELBO())
    svi2 = SVI(model2, lambda: None, numpyro.optim.Adam(1), Trace_ELBO())
    svi_state1 = svi1.init(random.PRNGKey(0))
    svi_state2 = svi2.init(random.PRNGKey(0))
    params1 = svi1.get_params(svi_state1)
    params2 = svi2.get_params(svi_state2)
    assert_allclose(params1["c1"], params2["c1"])
    assert_allclose(params1["c0"], params2["c0"])

    params1 = svi1.get_params(svi1.update(svi_state1)[0])
    params2 = svi2.get_params(svi2.update(svi_state2)[0])
    assert_allclose(params1["c1"], params2["c1"])
    assert_allclose(params1["c0"], params2["c0"])
Example #13
def test_laplace_approximation_warning():
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
        mu = a + b[0] * x + b[1] * x ** 2 + b[2] * x ** 3
        numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)

    x = random.normal(random.PRNGKey(0), (3,))
    y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), x=x, y=y)
    init_state = svi.init(random.PRNGKey(0))
    svi_state = fori_loop(0, 10000, lambda i, val: svi.update(val)[0], init_state)
    params = svi.get_params(svi_state)
    with pytest.warns(UserWarning, match="Hessian of log posterior"):
        guide.sample_posterior(random.PRNGKey(1), params)
Example #14
def run_svi_inference(model, guide, rng_key, X, Y, optimizer, n_epochs=1_000):

    # initialize svi
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    # initialize state
    init_state = svi.init(rng_key, X, Y.squeeze())

    # Run the optimizer for n_epochs iterations.
    state, losses = jax.lax.scan(
        lambda state, i: svi.update(state, X, Y.squeeze()),
        init_state, None, length=n_epochs
    )

    # Extract the learned variational parameters.
    params = svi.get_params(state)

    return params
Example #15
def test_autoguide(deterministic):
    GLOBAL["count"] = 0
    guide = AutoDiagonalNormal(model)
    svi = SVI(model,
              guide,
              optim.Adam(0.1),
              Trace_ELBO(),
              deterministic=deterministic)
    svi_state = svi.init(random.PRNGKey(0))
    svi_state = lax.fori_loop(0, 100, lambda i, val: svi.update(val)[0],
                              svi_state)
    params = svi.get_params(svi_state)
    guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(100, ))

    if deterministic:
        assert GLOBAL["count"] == 5
    else:
        assert GLOBAL["count"] == 4
Example #16
def test_jitted_update_fn():
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    adam = optim.Adam(0.05)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)
    expected = svi.get_params(svi.update(svi_state, data)[0])

    actual = svi.get_params(jit(svi.update)(svi_state, data=data)[0])
    check_close(actual, expected, atol=1e-5)
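
The same jitted update scales to a training loop; a sketch that compiles the step once and reuses it, assuming data keeps a fixed shape across iterations:

# Sketch: jit-compile svi.update once, then iterate in Python.
jitted_update = jit(svi.update)
for _ in range(100):
    svi_state, loss = jitted_update(svi_state, data)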
Example #17
def fit_advi(model, num_iter, learning_rate=0.01, seed=0):
    """Automatic Differentiation Variational Inference using a Normal variational distribution
    with a diagonal covariance matrix.
    """
    rng_key = random.PRNGKey(seed)
    adam = Adam(learning_rate)
    # Automatically create a variational distribution (aka "guide" in Pyro's terminology)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(rng_key)

    # Run optimization
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, np.zeros(num_iter))
    results = ADVIResults(svi=svi,
                          guide=guide,
                          state=last_state,
                          losses=losses)
    return results
Example #18
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    guide = AutoBNAFNormal(
        dual_moon_model,
        hidden_factors=[args.hidden_factor, args.hidden_factor])
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), ELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, jnp.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(2), params,
        sample_shape=(args.num_samples, ))['x'].copy()

    print("\nStart NeuTra HMC...")
    neutra = NeuTraReparam(guide, params)
    neutra_model = neutra.reparam(dual_moon_model)
    nuts_kernel = NUTS(neutra_model)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(random.PRNGKey(3))
    mcmc.print_summary()
    zs = mcmc.get_samples(group_by_chain=True)["auto_shared_latent"]
    print("Transform samples into unwarped space...")
    samples = neutra.transform_sample(zs)
    print_summary(samples)
    zs = zs.reshape(-1, 2)
    samples = samples['x'].reshape(-1, 2).copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(jnp.zeros(2),
                                     1.).sample(random.PRNGKey(4), (1000, ))
    guide_trans_samples = neutra.transform_sample(guide_base_samples)['x']

    x1 = jnp.linspace(-3, 3, 100)
    x2 = jnp.linspace(-3, 3, 100)
    X1, X2 = jnp.meshgrid(x1, x2)
    P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    ax1.plot(losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nAutoBNAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0],
                    guide_base_samples[:, 1],
                    ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel='x0',
        ylabel='x1',
        title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)'
    )

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0],
                vanilla_samples[:, 1],
                n_levels=30,
                ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0],
             vanilla_samples[-50:, 1],
             'bo-',
             alpha=0.5)
    ax4.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nvanilla HMC sampler')

    sns.scatterplot(zs[:, 0],
                    zs[:, 1],
                    ax=ax5,
                    hue=samples[:, 0] < 0.,
                    s=30,
                    alpha=0.5,
                    edgecolor="none")
    ax5.set(xlim=[-5, 5],
            ylim=[-5, 5],
            xlabel='x0',
            ylabel='x1',
            title='Samples from the\nwarped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nNeuTra HMC sampler')

    plt.savefig("neutra.pdf")
Example #19
class SVIHandler(Handler):
    """
    Helper object that abstracts away some of NumPyro's complexities. Inspired
    by an implementation by Florian Wilhelm.

    :param model: A numpyro model.
    :param guide: A numpyro guide.
    :param loss: Loss function, defaults to Trace_ELBO(num_particles=1).
    :param optimizer: Optimizer class, defaults to ClippedAdam.
    :param lr: Learning rate, defaults to 0.001.
    :param lrd: Learning rate decay per step, defaults to 1.0 (no decay).
    :param rng_key: Random seed, defaults to 254.
    :param num_epochs: Number of epochs to train the model, defaults to 30000.
    :param num_samples: Number of posterior samples, defaults to 1000.
    :param log_func: Logging function, defaults to _print_consumer.
    :param log_freq: Frequency of logging, defaults to every 1000 steps.
    :param to_numpy: Convert the posterior distribution to numpy array(s),
        defaults to True.
    """
    def __init__(
        self,
        model: Model,
        guide: Guide,
        loss: Trace_ELBO = Trace_ELBO(num_particles=1),
        optimizer: optim.optimizers.optimizer = optim.ClippedAdam,
        lr: float = 0.001,
        lrd: float = 1.0,
        rng_key: int = 254,
        num_epochs: int = 30000,
        num_samples: int = 1000,
        log_func=_print_consumer,
        log_freq=1000,
        to_numpy: bool = True,
    ):
        self.model = model
        self.guide = guide
        self.loss = loss
        self.optimizer = optimizer(step_size=lambda x: lr * lrd**x)
        self.rng_key = random.PRNGKey(rng_key)

        self.svi = SVI(self.model, self.guide, self.optimizer, loss=self.loss)
        self.init_state = None

        self.log_func = log_func
        self.log_freq = log_freq
        self.num_epochs = num_epochs
        self.num_samples = num_samples

        # `self.loss` is reused to accumulate the loss history; the ELBO
        # object itself is already bound to `self.svi` above.
        self.loss = None
        self.to_numpy = to_numpy

    def _log(self, epoch, loss, n_digits=4):
        msg = f"epoch: {str(epoch).rjust(n_digits)} loss: {loss: 16.4f}"
        self.log_func(msg)

    def _fit(self, *args):
        def _step(state, i, *args):
            state = lax.cond(
                i % self.log_freq == 0,
                lambda _: host_callback.id_tap(self.log_func,
                                               (i, self.num_epochs),
                                               result=state),
                lambda _: state,
                operand=None,
            )
            return self.svi.update(state, *args)

        return lax.scan(
            lambda state, i: _step(state, i, *args),
            self.init_state,
            jnp.arange(self.num_epochs),
        )

    def _update_state(self, state, loss):
        self.state = state
        self.init_state = state
        self.loss = loss if self.loss is None else jnp.concatenate(
            [self.loss, loss])

    def fit(self, *args, **kwargs):
        self.num_epochs = kwargs.pop("num_epochs", self.num_epochs)
        predictive_kwargs = kwargs.pop("predictive_kwargs", {})

        if self.init_state is None:
            self.init_state = self.svi.init(self.rng_key, *args)

        state, loss = self._fit(*args)
        self._update_state(state, loss)
        self.params = self.svi.get_params(state)

        predictive = Predictive(
            self.model,
            guide=self.guide,
            params=self.params,
            num_samples=self.num_samples,
            **predictive_kwargs,
        )

        self.posterior = Posterior(predictive(self.rng_key, *args),
                                   self.to_numpy)

        return self

    def predict(self, *args, **kwargs):
        """kwargs -> Predictive, args -> predictive"""
        num_samples = kwargs.pop("num_samples", self.num_samples)
        rng_key = kwargs.pop("rng_key", self.rng_key)

        predictive = Predictive(
            self.model,
            guide=self.guide,
            params=self.params,
            num_samples=num_samples,
            **kwargs,
        )

        self.predictive = Posterior(predictive(rng_key, *args), self.to_numpy)

    def dump_params(self, file_name: str):
        assert self.params is not None, "'fit' needs to be called first"
        pickle.dump(self.params, open(file_name, "wb"))

    def load_params(self, file_name):
        self.params = pickle.load(open(file_name, "rb"))
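
A usage sketch for the handler above; my_model, my_guide, and data are hypothetical placeholders for user-defined NumPyro functions and an input array:

# Sketch only: fit, inspect the loss history, persist the parameters.
handler = SVIHandler(my_model, my_guide)
handler.fit(data)                      # runs num_epochs SVI steps
print(handler.loss[-1])                # accumulated loss history
handler.dump_params("svi_params.pkl")  # pickle the learned params
posterior = handler.posterior          # Posterior wrapper over the draws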
Example #20
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    adam = optim.Adam(0.01)
    # TODO: it is hard to find good hyperparameters such that IAF guide can learn this model.
    # We will use BNAF instead!
    guide = AutoIAFNormal(dual_moon_model,
                          num_flows=2,
                          hidden_dims=[args.num_hidden, args.num_hidden])
    svi = SVI(dual_moon_model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(0), params,
        sample_shape=(args.num_samples, ))['x'].copy()

    transform = guide.get_transform(params)
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2),
                                                     dual_moon_model)
    transformed_potential_fn = partial(transformed_potential_energy,
                                       potential_fn, transform)
    transformed_constrain_fn = lambda x: constrain_fn(transform(x))  # noqa: E731

    print("\nStart NeuTra HMC...")
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    init_params = np.zeros(guide.latent_size)
    mcmc.run(random.PRNGKey(3), init_params=init_params)
    mcmc.print_summary()
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    print_summary(tree_map(lambda x: x[None, ...], samples))
    samples = samples['x'].copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(np.zeros(2),
                                     1.).sample(random.PRNGKey(4), (1000, ))
    guide_trans_samples = vmap(transformed_constrain_fn)(
        guide_base_samples)['x']

    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.exp(DualMoonDistribution().log_prob(np.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 16), constrained_layout=True)
    gs = GridSpec(3, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])

    ax1.plot(np.log(losses[1000:]))
    ax1.set_title('Autoguide training log loss (after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using AutoIAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0],
                    guide_base_samples[:, 1],
                    ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel='x0',
        ylabel='x1',
        title='AutoIAFNormal base samples (True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0],
                vanilla_samples[:, 1],
                n_levels=30,
                ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0],
             vanilla_samples[-50:, 1],
             'bo-',
             alpha=0.5)
    ax4.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using vanilla HMC sampler')

    sns.scatterplot(zs[:, 0],
                    zs[:, 1],
                    ax=ax5,
                    hue=samples[:, 0] < 0.,
                    s=30,
                    alpha=0.5,
                    edgecolor="none")
    ax5.set(xlim=[-5, 5],
            ylim=[-5, 5],
            xlabel='x0',
            ylabel='x1',
            title='Samples from the warped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()
Example #21
class SVIHandler(Handler):
    def __init__(
        self,
        model: Model,
        guide: Guide,
        loss: Trace_ELBO = Trace_ELBO(num_particles=1),
        optimizer: optim.optimizers.optimizer = optim.Adam,
        lr: float = 0.001,
        rng_key: int = 254,
        num_epochs: int = 100000,
        num_samples: int = 5000,
        log_func=print,
        log_freq=0,
    ):
        self.model = model
        self.guide = guide
        self.loss = loss
        self.optimizer = optimizer(step_size=lr)
        self.rng_key = random.PRNGKey(rng_key)

        self.svi = SVI(self.model, self.guide, self.optimizer, loss=self.loss)
        self.init_state = None

        self.log_func = log_func
        self.log_freq = log_freq
        self.num_epochs = num_epochs
        self.num_samples = num_samples

        self.loss = None

    def _log(self, epoch, loss, n_digits=4):
        msg = f"epoch: {str(epoch).rjust(n_digits)} loss: {loss: 16.4f}"
        self.log_func(msg)

    def _fit(self, epochs, *args):
        return lax.scan(
            lambda state, i: self.svi.update(state, *args),
            self.init_state,
            jnp.arange(epochs),
        )

    def _update_state(self, state, loss):
        self.state = state
        self.init_state = state
        self.loss = loss if self.loss is None else jnp.concatenate([self.loss, loss])

    def fit(self, *args, **kwargs):
        num_epochs = kwargs.pop("num_epochs", self.num_epochs)
        log_freq = kwargs.pop("log_freq", self.log_freq)

        if self.init_state is None:
            self.init_state = self.svi.init(self.rng_key, *args)

        if log_freq <= 0:
            state, loss = self._fit(num_epochs, *args)
            self._update_state(state, loss)
        else:
            steps, rest = num_epochs // log_freq, num_epochs % log_freq

            for step in range(steps):
                state, loss = self._fit(log_freq, *args)
                self._log(log_freq * (step + 1), loss[-1])
                self._update_state(state, loss)

            if rest > 0:
                state, loss = self._fit(rest, *args)
                self._update_state(state, loss)

        self.params = self.svi.get_params(state)
        predictive = Predictive(
            self.model,
            guide=self.guide,
            params=self.params,
            num_samples=self.num_samples,
            **kwargs,
        )
        self.posterior = predictive(self.rng_key, *args)

    def get_posterior_predictive(self, *args, **kwargs):
        """kwargs -> Predictive, args -> predictive"""
        num_samples = kwargs.pop("num_samples", self.num_samples)

        predictive = Predictive(
            self.model,
            guide=self.guide,
            params=self.params,
            num_samples=num_samples,
            **kwargs,
        )
        self.posterior_predictive = predictive(self.rng_key, *args)
Example #22
class ModelHandler(object):
    def __init__(self,
                 model: Model,
                 guide: Guide,
                 rng_key: int = 0,
                 *,
                 loss: ELBO = ELBO(num_particles=1),
                 optim_builder: optim.optimizers.optimizer = optim.Adam):
        """Handling the model and guide for training and prediction

        Args:
            model: function holding the numpyro model
            guide: function holding the numpyro guide
            rng_key: random key as int
            loss: loss to optimize
            optim_builder: builder for an optimizer
        """
        self.model = model
        self.guide = guide
        self.rng_key = random.PRNGKey(rng_key)  # current random key
        self.loss = loss
        self.optim_builder = optim_builder
        self.svi = None
        self.svi_state = None
        self.optim = None
        self.log_func = print  # overwrite e.g. logger.info(...)

    def reset_svi(self):
        """Reset the current SVI state"""
        self.svi = None
        self.svi_state = None
        return self

    def init_svi(self, X: DeviceArray, *, lr: float, **kwargs):
        """Initialize the SVI state

        Args:
            X: input data
            lr: learning rate
            kwargs: other keyword arguments for optimizer
        """
        self.optim = self.optim_builder(lr, **kwargs)
        self.svi = SVI(self.model, self.guide, self.optim, self.loss)
        svi_state = self.svi.init(self.rng_key, X)
        if self.svi_state is None:
            self.svi_state = svi_state
        return self

    @property
    def optim_state(self) -> OptimizerState:
        """Current optimizer state"""
        assert self.svi_state is not None, "'init_svi' needs to be called first"
        return self.svi_state.optim_state

    @optim_state.setter
    def optim_state(self, state: OptimizerState):
        """Set current optimizer state"""
        self.svi_state = SVIState(state, self.rng_key)

    def dump_optim_state(self, fh: IO):
        """Pickle and dump optimizer state to file handle"""
        pickle.dump(
            optim.optimizers.unpack_optimizer_state(self.optim_state[1]), fh)
        return self

    def load_optim_state(self, fh: IO):
        """Read and unpickle optimizer state from file handle"""
        state = optim.optimizers.pack_optimizer_state(pickle.load(fh))
        iter0 = jnp.array(0)
        self.optim_state = (iter0, state)
        return self

    @property
    def optim_total_steps(self) -> int:
        """Returns the number of performed iterations in total"""
        return int(self.optim_state[0])

    def _fit(self, X: DeviceArray, n_epochs) -> float:
        @jit
        def train_epochs(svi_state, n_epochs):
            def train_one_epoch(_, val):
                loss, svi_state = val
                svi_state, loss = self.svi.update(svi_state, X)
                return loss, svi_state

            return lax.fori_loop(0, n_epochs, train_one_epoch, (0., svi_state))

        loss, self.svi_state = train_epochs(self.svi_state, n_epochs)
        return float(loss / X.shape[0])

    def _log(self, n_digits, epoch, loss):
        msg = f"epoch: {str(epoch).rjust(n_digits)} loss: {loss: 16.4f}"
        self.log_func(msg)

    def fit(self,
            X: DeviceArray,
            *,
            n_epochs: int,
            log_freq: int = 0,
            lr: float,
            **kwargs) -> float:
        """Train but log with a given frequency

        Args:
            X: input data
            n_epochs: total number of epochs
            log_freq: log loss every log_freq number of epochs
            lr: learning rate
            kwargs: parameters of `init_svi`

        Returns:
            final loss of last epoch
        """
        self.init_svi(X, lr=lr, **kwargs)
        if log_freq <= 0:
            self._fit(X, n_epochs)
        else:
            loss = self.svi.evaluate(self.svi_state, X) / X.shape[0]

            curr_epoch = 0
            n_digits = len(str(abs(n_epochs)))
            self._log(n_digits, curr_epoch, loss)

            for i in range(n_epochs // log_freq):
                curr_epoch += log_freq
                loss = self._fit(X, log_freq)
                self._log(n_digits, curr_epoch, loss)

            rest = n_epochs % log_freq
            if rest > 0:
                curr_epoch += rest

                loss = self._fit(X, rest)
                self._log(n_digits, curr_epoch, loss)

        loss = self.svi.evaluate(self.svi_state, X) / X.shape[0]
        self.rng_key = self.svi_state.rng_key
        return float(loss)

    @property
    def model_params(self) -> Optional[Dict[str, DeviceArray]]:
        """Gets model parameters

        Returns:
            dict of model parameters
        """
        if self.svi is not None:
            return self.svi.get_params(self.svi_state)
        else:
            return None

    def predict(self, X: DeviceArray, **kwargs) -> DeviceArray:
        """Predict the parameters of a model specified by `return_sites`

        Args:
            X: input data
            kwargs: keyword arguments for numpyro `Predictive`

        Returns:
            samples for all sample sites
        """
        self.init_svi(X, lr=0.)  # dummy initialization
        predictive = Predictive(self.model,
                                guide=self.guide,
                                params=self.model_params,
                                **kwargs)
        samples = predictive(self.rng_key, X)
        return samples
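
A usage sketch, again with hypothetical model, guide, and data placeholders:

# Sketch: train with periodic logging, checkpoint the optimizer state,
# then sample from the fitted guide.
handler = ModelHandler(my_model, my_guide, rng_key=0)
final_loss = handler.fit(X, n_epochs=5000, log_freq=500, lr=0.01)
with open("optim_state.pkl", "wb") as fh:
    handler.dump_optim_state(fh)
samples = handler.predict(X, num_samples=100)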
Example #23
    def _train_full_data(self,
                         x_data,
                         obs2sample,
                         n_epochs=20000,
                         lr=0.002,
                         progressbar=True,
                         random_seed=1):

        idx = np.arange(x_data.shape[0]).astype("int64")

        # move data to default device
        x_data = device_put(jnp.array(x_data))
        extra_data = {
            'idx': device_put(jnp.array(idx)),
            'obs2sample': device_put(jnp.array(obs2sample))
        }

        # initialise SVI inference method
        svi = SVI(
            self.model.forward,
            self.guide,
            # limit the gradient step from becoming too large
            optim.ClippedAdam(step_size=lr, clip_norm=200.),
            loss=Trace_ELBO())
        init_state = svi.init(random.PRNGKey(random_seed),
                              x_data=x_data,
                              **extra_data)
        self.state = init_state

        if not progressbar:
            # Training in one step
            epochs_iterator = tqdm(range(1))
            for e in epochs_iterator:
                state, losses = lax.scan(
                    lambda state_1, i: svi.update(
                        state_1, x_data=x_data, **extra_data),
                    # TODO for minibatch DataLoader goes here
                    init_state,
                    jnp.arange(n_epochs))
                # print(state)
                epochs_iterator.set_description(
                    'ELBO Loss: ' + '{:.4e}'.format(losses[-1]))

            self.state = state
            self.hist = losses

        else:
            # training using for-loop

            jit_step_update = jit(lambda state_1: svi.update(
                state_1, x_data=x_data, **extra_data))
            # TODO figure out minibatch static_argnums https://github.com/pyro-ppl/numpyro/issues/869

            # NOTE: updating in a Python for-loop is much slower than lax.scan
            epochs_iterator = tqdm(range(n_epochs))
            for e in epochs_iterator:
                self.state, loss = jit_step_update(self.state)
                self.hist.append(loss)
                epochs_iterator.set_description('ELBO Loss: ' +
                                                '{:.4e}'.format(loss))

        self.state_param = svi.get_params(self.state).copy()