Example 1
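All of the snippets in this listing appear to target an early NumPyro release (the pre-0.2 functional API), in which svi(...) returns an (init, update, evaluate) triple and optimizers come from jax.experimental. A minimal shared import block under that assumption (a sketch; the exact module paths may differ between releases):

# Sketch of the imports the snippets below assume. Module paths follow the
# early NumPyro functional API and are a best guess, not verbatim from the
# original sources.
import jax.numpy as np
from jax import jit, lax, random, vmap
from jax.lax import fori_loop            # for snippets that call fori_loop unqualified
from jax.random import PRNGKey           # for snippets that call PRNGKey unqualified
from jax.experimental import optimizers
from numpy.testing import assert_allclose
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.svi import svi, elbo        # numpyro.svi is confirmed by the conf.py snippet (Example 11)
from numpyro.handlers import sample, param                # assumed location of the primitives
from numpyro.contrib.autoguide import AutoDiagonalNormal, AutoIAFNormal  # assumed location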
def test_beta_bernoulli():
    # eight successes and two failures, so the empirical success rate is 0.8
    data = np.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = sample("beta", dist.Beta(1., 1.))
        sample("obs", dist.Bernoulli(f), obs=data)

    def guide():
        alpha_q = param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = param("beta_q", 1.0, constraint=constraints.positive)
        sample("beta", dist.Beta(alpha_q, beta_q))

    opt_init, opt_update, get_params = optimizers.adam(0.05)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    rng_init, rng_train = random.split(random.PRNGKey(1))
    opt_state, constrain_fn = svi_init(rng_init, model_args=(data, ))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i,
                                            rng_,
                                            opt_state_,
                                            model_args=(data, ))
        return opt_state_, rng_

    opt_state, _ = fori_loop(0, 300, body_fn, (opt_state, rng_train))

    params = constrain_fn(get_params(opt_state))
    assert_allclose(params['alpha_q'] / (params['alpha_q'] + params['beta_q']),
                    0.8,
                    atol=0.05,
                    rtol=0.05)
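For reference, the exact posterior under the Beta(1, 1) prior is Beta(1 + 8, 1 + 2) = Beta(9, 3), with mean 9/12 = 0.75. The assertion compares the fitted posterior mean alpha_q / (alpha_q + beta_q) against the empirical rate 0.8; the combined atol/rtol tolerance (about ±0.09) is loose enough to accept values near either quantity.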
Example 2
def test_beta_bernoulli(auto_class):
    data = np.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = sample('beta', dist.Beta(np.ones(2), np.ones(2)))
        sample('obs', dist.Bernoulli(f), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = auto_class(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    opt_state, constrain_fn = svi_init(rng_init,
                                       model_args=(data, ),
                                       guide_args=(data, ))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i,
                                            rng_,
                                            opt_state_,
                                            model_args=(data, ),
                                            guide_args=(data, ))
        return opt_state_, rng_

    opt_state, _ = fori_loop(0, 1000, body_fn, (opt_state, rng_train))
    true_coefs = (np.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1),
                                               opt_state,
                                               sample_shape=(1000, ))
    assert_allclose(np.mean(posterior_samples['beta'], 0),
                    true_coefs,
                    atol=0.04)
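The reference values follow from Beta-Bernoulli conjugacy: with a Beta(1, 1) prior and s successes in n trials, the posterior is Beta(1 + s, 1 + n - s), whose mean is (1 + s) / (n + 2); that is exactly the true_coefs expression above, applied column-wise.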
Example 3
def test_uniform_normal():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))

    def model(data):
        alpha = sample('alpha', dist.Uniform(0, 1))
        loc = sample('loc', dist.Uniform(0, alpha))
        sample('obs', dist.Normal(loc, 0.1), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = AutoDiagonalNormal(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    opt_state, constrain_fn = svi_init(rng_init,
                                       model_args=(data, ),
                                       guide_args=(data, ))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i,
                                            rng_,
                                            opt_state_,
                                            model_args=(data, ),
                                            guide_args=(data, ))
        return opt_state_, rng_

    opt_state, _ = fori_loop(0, 1000, body_fn, (opt_state, rng_train))
    median = guide.median(opt_state)
    assert_allclose(median['loc'], true_coef, rtol=0.05)
    # test .quantiles method
    quantiles = guide.quantiles(opt_state, [0.2, 0.5])
    assert_allclose(quantiles['loc'][1], true_coef, rtol=0.1)
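As the indexing shows, quantiles returns a dict mapping each latent site to the requested quantiles stacked along the leading axis, so quantiles['loc'][1] is the 0.5 quantile, i.e. the median again.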
Example 4
def main(args):
    # Generate some data.
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Construct an SVI object so we can do variational inference on our
    # model/guide pair.
    opt_init, opt_update, get_params = optimizers.adam(args.learning_rate)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    rng = PRNGKey(0)
    opt_state = svi_init(rng, model_args=(data,))

    # Training loop
    rng, = random.split(rng, 1)

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, opt_state_, rng_, model_args=(data,))
        return opt_state_, rng_

    opt_state, _ = lax.fori_loop(0, args.num_steps, body_fn, (opt_state, rng))

    # Report the final values of the variational parameters
    # in the guide after training.
    params = get_params(opt_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # For this simple (conjugate) model we know the exact posterior. In
    # particular we know that the variational distribution should be
    # centered near 3.0. So let's check this explicitly.
    assert np.abs(params["guide_loc"] - 3.0) < 0.1
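In this snippet, model and guide (with a variational parameter named guide_loc) are defined elsewhere in the script. Note also that svi_init returns only the optimizer state here and svi_update takes (i, opt_state, rng) rather than (i, rng, opt_state), so this example likely predates the API used in the test snippets above.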
Example 5
def test_beta_bernoulli(auto_class, rtol):
    data = np.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = sample('beta', dist.Beta(1., 1.))
        sample('obs', dist.Bernoulli(f), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.08)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = auto_class(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    opt_state, constrain_fn = svi_init(rng_init,
                                       model_args=(data, ),
                                       guide_args=(data, ))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i,
                                            rng_,
                                            opt_state_,
                                            model_args=(data, ),
                                            guide_args=(data, ))
        return opt_state_, rng_

    opt_state, _ = lax.fori_loop(0, 300, body_fn, (opt_state, rng_train))
    median = guide.median(opt_state)
    assert_allclose(median['beta'], 0.8, rtol=rtol)
Example 6
def test_dynamic_supports():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))

    def actual_model(data):
        alpha = sample('alpha', dist.Uniform(0, 1))
        loc = sample('loc', dist.Uniform(0, alpha))
        sample('obs', dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = sample('alpha', dist.Uniform(0, 1))
        loc = sample('loc', dist.Uniform(0, 1)) * alpha
        sample('obs', dist.Normal(loc, 0.1), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)

    guide = AutoDiagonalNormal(rng_guide, actual_model, get_params)
    svi_init, _, svi_eval = svi(actual_model, guide, elbo, opt_init,
                                opt_update, get_params)
    opt_state, constrain_fn = svi_init(rng_init, (data, ), (data, ))
    actual_params = get_params(opt_state)
    actual_base_values = constrain_fn(actual_params)
    actual_values = guide.median(opt_state)
    actual_loss = svi_eval(random.PRNGKey(1), opt_state, (data, ), (data, ))

    guide = AutoDiagonalNormal(rng_guide, expected_model, get_params)
    svi_init, _, svi_eval = svi(expected_model, guide, elbo, opt_init,
                                opt_update, get_params)
    opt_state, constrain_fn = svi_init(rng_init, (data, ), (data, ))
    expected_params = get_params(opt_state)
    expected_base_values = constrain_fn(expected_params)
    expected_values = guide.median(opt_state)
    expected_loss = svi_eval(random.PRNGKey(1), opt_state, (data, ), (data, ))

    check_eq(actual_params, expected_params)
    check_eq(actual_base_values, expected_base_values)
    assert_allclose(actual_values['alpha'], expected_values['alpha'])
    assert_allclose(actual_values['loc'],
                    expected_values['alpha'] * expected_values['loc'])
    assert_allclose(actual_loss, expected_loss)
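The two models define the same joint density: if u ~ Uniform(0, 1), then alpha * u ~ Uniform(0, alpha). The test therefore expects identical unconstrained parameters and base values, medians that agree after undoing the alpha * loc reparameterization, and equal ELBO values, confirming that the autoguide correctly handles a support that depends on another latent variable.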
Example 7
def test_logistic_regression(auto_class):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = auto_class(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    opt_state, constrain_fn = svi_init(rng_init,
                                       model_args=(data, labels),
                                       guide_args=(data, labels))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i,
                                            rng_,
                                            opt_state_,
                                            model_args=(data, labels),
                                            guide_args=(data, labels))
        return opt_state_, rng_

    opt_state, _ = fori_loop(0, 1000, body_fn, (opt_state, rng_train))
    if auto_class is not AutoIAFNormal:
        # AutoIAFNormal does not support median/quantiles, hence the guard
        median = guide.median(opt_state)
        assert_allclose(median['coefs'], true_coefs, rtol=0.1)
        # test .quantiles method
        quantiles = guide.quantiles(opt_state, [0.2, 0.5])
        assert_allclose(quantiles['coefs'][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1),
                                               opt_state,
                                               sample_shape=(1000, ))
    assert_allclose(np.mean(posterior_samples['coefs'], 0),
                    true_coefs,
                    rtol=0.1)
Example 8
def test_dynamic_constraints():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))

    def model(data):
        # NB: the model's constraints have no effect here
        loc = param('loc', 0., constraint=constraints.interval(0, 0.5))
        sample('obs', dist.Normal(loc, 0.1), obs=data)

    def guide():
        alpha = param('alpha', 0.5, constraint=constraints.unit_interval)
        param('loc', 0., constraint=constraints.interval(0, alpha))

    opt_init, opt_update, get_params = optimizers.adam(0.05)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    rng_init, rng_train = random.split(random.PRNGKey(1))
    opt_state, constrain_fn = svi_init(rng_init, model_args=(data, ))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i,
                                            rng_,
                                            opt_state_,
                                            model_args=(data, ))
        return opt_state_, rng_

    opt_state, rng = fori_loop(0, 300, body_fn, (opt_state, rng_train))
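    # get_param is presumably a local test helper (defined elsewhere) that
    # recovers the constrained parameter values via constrain_fn.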
    params = get_param(opt_state,
                       model,
                       guide,
                       get_params,
                       constrain_fn,
                       rng,
                       guide_args=())
    assert_allclose(params['loc'], true_coef, atol=0.05)
Example 9
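This is the training driver of a VAE example; model, guide, encoder, decoder, binarize, load_dataset, MNIST, and RESULTS_DIR are defined elsewhere in the same script. Each epoch runs a jitted fori_loop over training minibatches, then evaluates the test-set ELBO and saves an original/reconstructed image pair.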
def main(args):
    encoder_init, encode = encoder(args.hidden_dim, args.z_dim)
    decoder_init, decode = decoder(args.hidden_dim, 28 * 28)
    opt_init, opt_update, get_params = optimizers.adam(args.learning_rate)
    svi_init, svi_update, svi_eval = svi(model, guide, elbo, opt_init, opt_update, get_params,
                                         encode=encode, decode=decode, z_dim=args.z_dim)
    svi_update = jit(svi_update)
    rng = PRNGKey(0)
    train_init, train_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='train')
    test_init, test_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='test')
    num_train, train_idx = train_init()
    rng, rng_enc, rng_dec = random.split(rng, 3)
    _, encoder_params = encoder_init(rng_enc, (args.batch_size, 28 * 28))
    _, decoder_params = decoder_init(rng_dec, (args.batch_size, args.z_dim))
    params = {'encoder': encoder_params, 'decoder': decoder_params}
    rng, sample_batch = binarize(rng, train_fetch(0, train_idx)[0])
    opt_state, constrain_fn = svi_init(rng, (sample_batch,), (sample_batch,), params)
    rng, = random.split(rng, 1)

    @jit
    def epoch_train(opt_state, rng):
        def body_fn(i, val):
            loss_sum, opt_state, rng = val
            rng, batch = binarize(rng, train_fetch(i, train_idx)[0])
            loss, opt_state, rng = svi_update(i, rng, opt_state, (batch,), (batch,),)
            loss_sum += loss
            return loss_sum, opt_state, rng

        return lax.fori_loop(0, num_train, body_fn, (0., opt_state, rng))

    @jit
    def eval_test(opt_state, rng):
        def body_fun(i, val):
            loss_sum, rng = val
            rng, = random.split(rng, 1)
            rng, batch = binarize(rng, test_fetch(i, test_idx)[0])
            loss = svi_eval(rng, opt_state, (batch,), (batch,)) / len(batch)
            loss_sum += loss
            return loss_sum, rng

        loss, _ = lax.fori_loop(0, num_test, body_fun, (0., rng))
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch):
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch)), img, cmap='gray')
        _, test_sample = binarize(rng, img)
        params = get_params(opt_state)
        z_mean, z_var = encode(params['encoder'], test_sample.reshape([1, -1]))
        z = dist.Normal(z_mean, z_var).sample(rng)
        img_loc = decode(params['decoder'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)), img_loc, cmap='gray')

    for i in range(args.num_epochs):
        t_start = time.time()
        num_train, train_idx = train_init()
        _, opt_state, rng = epoch_train(opt_state, rng)
        rng, rng_test = random.split(rng, 2)
        num_test, test_idx = test_init()
        test_loss = eval_test(opt_state, rng_test)
        reconstruct_img(i)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss, time.time() - t_start))
Example 10
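This example contrasts vanilla HMC on a dual-moon density with NeuTra HMC: an AutoIAFNormal guide is first trained with SVI, and its learned flow is then used to transform the potential so that HMC runs in the better-conditioned base space; mcmc, dual_moon_pe, dual_moon_model, make_transformed_pe, and summary are defined elsewhere in the script.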
def main(args):
    jax_config.update('jax_platform_name', args.device)

    print("Start vanilla HMC...")
    vanilla_samples = mcmc(args.num_warmup, args.num_samples, init_params=np.array([2., 0.]),
                           potential_fn=dual_moon_pe, progbar=True)

    opt_init, opt_update, get_params = optimizers.adam(0.001)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = AutoIAFNormal(rng_guide, dual_moon_model, get_params, hidden_dims=[args.num_hidden])
    svi_init, svi_update, _ = svi(dual_moon_model, guide, elbo, opt_init, opt_update, get_params)
    opt_state, _ = svi_init(rng_init)

    def body_fn(val, i):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_)
        return (opt_state_, rng_), loss

    print("Start training guide...")
    (last_state, _), losses = lax.scan(body_fn, (opt_state, rng_train), np.arange(args.num_iters))
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(random.PRNGKey(0), last_state,
                                           sample_shape=(args.num_samples,))

    transform = guide.get_transform(last_state)
    unpack_fn = guide.unpack_latent

    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), dual_moon_model)
    transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn)
    transformed_constrain_fn = lambda x: constrain_fn(unpack_fn(transform(x)))  # noqa: E731

    init_params = np.zeros(guide.latent_size)
    print("\nStart NeuTra HMC...")
    zs = mcmc(args.num_warmup, args.num_samples, init_params, potential_fn=transformed_potential_fn)
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    summary(tree_map(lambda x: x[None, ...], samples))

    # make plots

    # IAF guide samples (for plotting)
    iaf_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(0), (1000,))
    iaf_trans_samples = vmap(transformed_constrain_fn)(iaf_base_samples)['x']

    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.clip(np.exp(-dual_moon_pe(np.stack([X1, X2], axis=-1))), a_min=0.)

    fig = plt.figure(figsize=(12, 16), constrained_layout=True)
    gs = GridSpec(3, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])

    ax1.plot(np.log(losses[1000:]))
    ax1.set_title('Autoguide training log loss (after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples['x'][:, 0].copy(), guide_samples['x'][:, 1].copy(), n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using AutoIAFNormal guide')

    sns.scatterplot(iaf_base_samples[:, 0], iaf_base_samples[:, 1], ax=ax3, hue=iaf_trans_samples[:, 0] < 0.)
    ax3.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='AutoIAFNormal base samples (True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0].copy(), vanilla_samples[:, 1].copy(), n_levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using vanilla HMC sampler')

    sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples['x'][:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5],
            xlabel='x0', ylabel='x1', title='Samples from the warped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples['x'][:, 0].copy(), samples['x'][:, 1].copy(), n_levels=30, ax=ax6)
    ax6.plot(samples['x'][-50:, 0], samples['x'][-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()
Example 11
# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
sys.path.insert(0, os.path.abspath('../..'))

# HACK: This is to ensure that local functions are documented by sphinx.
from numpyro.mcmc import hmc  # noqa: E402
from numpyro.svi import svi  # noqa: E402

os.environ['SPHINX_BUILD'] = '1'
hmc(None, None)
svi(None, None, None, None, None, None)
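# (The dummy calls above presumably work because hmc and svi construct their
# worker functions as closures at call time; invoking them once lets autodoc
# discover those inner functions, and the SPHINX_BUILD flag likely tells the
# library to tolerate the None arguments.)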

# -- Project information -----------------------------------------------------

project = u'Numpyro'
copyright = u'2019, Uber Technologies, Inc'
author = u'Uber AI Labs'

# The short X.Y version
version = u'0.0'
# The full version, including alpha/beta/rc tags
release = u'0.0'

# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.