コード例 #1
0
def test_chain():
    """Two-chain MCMC on a simulated logistic regression recovers the true coefficients.

    Also checks that the draws of both chains are stacked along the leading axis
    of the returned ``'coefs'`` array.
    """
    num_chains = 2
    N, dim = 3000, 3
    num_warmup = num_samples = 5000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    true_logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=true_logits).sample(random.PRNGKey(1))

    def model(labels):
        # Independent standard-normal prior on every coefficient.
        prior = dist.Normal(np.zeros(dim), np.ones(dim))
        coefs = numpyro.sample('coefs', prior)
        linear = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=linear), obs=labels)

    # One PRNG key per chain.
    chain_keys = random.split(random.PRNGKey(2), num_chains)
    init_params, potential_fn, constrain_fn = initialize_model(chain_keys, model, labels)
    samples = mcmc(num_warmup, num_samples, init_params,
                   num_chains=num_chains,
                   potential_fn=potential_fn,
                   constrain_fn=constrain_fn)

    assert samples['coefs'].shape[0] == num_chains * num_samples
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.21)
コード例 #2
0
def run_inference(model, at_bats, hits, rng, args):
    """Run HMC on ``model`` given ``at_bats`` and ``hits`` and return the draws.

    When ``args.num_chains`` exceeds one, ``rng`` is first split into one key
    per chain. ``args`` must also provide ``num_warmup`` and ``num_samples``.
    """
    chain_rng = random.split(rng, args.num_chains) if args.num_chains > 1 else rng
    init_params, potential_fn, constrain_fn = initialize_model(chain_rng, model, at_bats, hits)
    return mcmc(args.num_warmup, args.num_samples, init_params,
                num_chains=args.num_chains,
                sampler='hmc',
                potential_fn=potential_fn,
                constrain_fn=constrain_fn)
コード例 #3
0
def test_logistic_regression(algo):
    """Posterior means of logistic-regression coefficients should match the true ones.

    Labels are simulated from a Bernoulli whose logits are ``sum([1, 2, 3] * data)``.
    The model is then fit with ``mcmc(sampler='hmc', algo=algo, trajectory_length=10)``.
    When 64-bit mode is enabled through the environment, the draws must also be
    float64.

    Args:
        algo: forwarded to ``mcmc`` as ``algo`` (presumably supplied by a pytest
            parametrization -- TODO confirm).
    """
    N, dim = 3000, 3
    warmup_steps, num_samples = 1000, 8000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim),
                                                    np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, labels)
    samples = mcmc(warmup_steps,
                   num_samples,
                   init_params,
                   sampler='hmc',
                   algo=algo,
                   potential_fn=potential_fn,
                   trajectory_length=10,
                   constrain_fn=constrain_fn)
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.21)

    # JAX reads its 64-bit switch from the upper-case ``JAX_ENABLE_X64`` variable.
    # The former ``'JAX_ENABLE_x64'`` key never matched on case-sensitive
    # platforms, which silently skipped this dtype check.
    if os.environ.get('JAX_ENABLE_X64', '').lower() in ('1', 'true', 'yes', 'y', 'on', 't'):
        assert samples['coefs'].dtype == np.float64
コード例 #4
0
ファイル: hmm.py プロジェクト: Anthonymcqueen21/numpyro
def main(args):
    """Simulate semi-supervised HMM data, fit ``semi_supervised_hmm`` with MCMC, and report.

    Args:
        args: parsed CLI namespace. This function reads ``device``, ``num_categories``,
            ``num_words``, ``num_supervised``, ``num_unsupervised``, ``num_chains``,
            ``num_warmup`` and ``num_samples``.
    """
    jax_config.update('jax_platform_name', args.device)
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words,
     unsupervised_words) = simulate_data(
         random.PRNGKey(1),
         num_categories=args.num_categories,
         num_words=args.num_words,
         num_supervised_data=args.num_supervised,
         num_unsupervised_data=args.num_unsupervised,
     )
    print('Starting inference...')
    rng = random.PRNGKey(2)
    if args.num_chains > 1:
        # One PRNG key per chain.
        rng = random.split(rng, args.num_chains)
    init_params, potential_fn, constrain_fn = initialize_model(
        rng,
        semi_supervised_hmm,
        transition_prior,
        emission_prior,
        supervised_categories,
        supervised_words,
        unsupervised_words,
    )
    # perf_counter is monotonic and high-resolution. Unlike time.time(), the
    # reported duration cannot be skewed by system clock adjustments.
    start = time.perf_counter()
    samples = mcmc(args.num_warmup,
                   args.num_samples,
                   init_params,
                   num_chains=args.num_chains,
                   potential_fn=potential_fn,
                   constrain_fn=constrain_fn,
                   progbar=True)
    print('\nMCMC elapsed time:', time.perf_counter() - start)
    # The true transition/emission probabilities are passed so the fitted
    # results can be reported against the simulation.
    print_results(samples, transition_prob, emission_prob)
コード例 #5
0
def test_dirichlet_categorical(algo, dense_mass):
    """Posterior mean of a Dirichlet-Categorical model should match the true probabilities.

    The prior is a uniform Dirichlet (all concentrations 1). Data are 2000
    categorical draws from ``[0.1, 0.6, 0.3]``. When 64-bit mode is enabled
    through the environment, the draws must also be float64.

    Args:
        algo: forwarded to ``mcmc`` as ``algo``.
        dense_mass: forwarded to ``mcmc`` as ``dense_mass``.
    """
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000, ))
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, data)
    samples = mcmc(warmup_steps,
                   num_samples,
                   init_params,
                   constrain_fn=constrain_fn,
                   progbar=False,
                   print_summary=False,
                   potential_fn=potential_fn,
                   algo=algo,
                   trajectory_length=1.,
                   dense_mass=dense_mass)
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.02)

    # JAX reads its 64-bit switch from the upper-case ``JAX_ENABLE_X64`` variable.
    # The former ``'JAX_ENABLE_x64'`` key never matched on case-sensitive
    # platforms, which silently skipped this dtype check.
    if os.environ.get('JAX_ENABLE_X64', '').lower() in ('1', 'true', 'yes', 'y', 'on', 't'):
        assert samples['p_latent'].dtype == np.float64
コード例 #6
0
ファイル: funnel.py プロジェクト: leej35/numpyro
def run_inference(model, args, rng):
    """Run MCMC on ``model`` (called with no data arguments) and return the draws.

    ``rng`` is split into one key per chain when ``args.num_chains > 1``.
    ``args`` must also provide ``num_warmup`` and ``num_samples``.
    """
    num_chains = args.num_chains
    keys = random.split(rng, num_chains) if num_chains > 1 else rng
    init_params, potential_fn, constrain_fn = initialize_model(keys, model)
    return mcmc(args.num_warmup, args.num_samples, init_params,
                num_chains=num_chains,
                potential_fn=potential_fn,
                constrain_fn=constrain_fn)
コード例 #7
0
ファイル: covtype.py プロジェクト: leej35/numpyro
def benchmark_hmc(args, features, labels):
    """Time an ``mcmc`` run (no warmup) on the module-level ``model`` and print the duration.

    Args:
        args: namespace providing ``num_steps``, ``num_chains`` and ``num_samples``.
        features: design matrix. Only ``features.shape[0]`` (its row count) feeds
            the step-size heuristic; it is also forwarded to the model.
        labels: forwarded to the model unchanged.
    """
    step_size = np.sqrt(0.5 / features.shape[0])
    # NOTE(review): ``step_size`` is used only to derive ``trajectory_length`` and is
    # not itself passed to ``mcmc``. Confirm whether the sampler should also be told
    # to use this step size instead of its own default or adapted one.
    trajectory_length = step_size * args.num_steps
    rng = random.PRNGKey(1)
    if args.num_chains > 1:
        # One PRNG key per chain.
        rng = random.split(rng, args.num_chains)
    init_params, potential_fn, _ = initialize_model(rng, model, features,
                                                    labels)

    # perf_counter is monotonic and high-resolution. Unlike time.time(), the
    # benchmark result cannot be skewed by system clock adjustments.
    start = time.perf_counter()
    mcmc(0,
         args.num_samples,
         init_params,
         num_chains=args.num_chains,
         potential_fn=potential_fn,
         trajectory_length=trajectory_length)
    print('\nMCMC elapsed time:', time.perf_counter() - start)
コード例 #8
0
def run_inference(dept, male, applications, admit, rng, args):
    """Run MCMC on ``glmm`` with the given ``dept``/``male``/``applications``/``admit`` data.

    ``rng`` is split into one key per chain when ``args.num_chains > 1``.
    Returns whatever ``mcmc`` returns (the posterior draws).
    """
    if args.num_chains <= 1:
        init_rng = rng
    else:
        # One PRNG key per chain.
        init_rng = random.split(rng, args.num_chains)
    init_params, potential_fn, constrain_fn = initialize_model(
        init_rng, glmm, dept, male, applications, admit)
    return mcmc(args.num_warmup, args.num_samples, init_params,
                num_chains=args.num_chains,
                potential_fn=potential_fn,
                constrain_fn=constrain_fn)
コード例 #9
0
def run_inference(model, args, rng, X, Y, hypers):
    """Run HMC on ``model``, print the elapsed sampling time, and return the draws.

    Args:
        model: model passed to ``initialize_model``.
        args: namespace providing ``num_chains``, ``num_warmup`` and ``num_samples``.
        rng: PRNG key, split into one key per chain when ``args.num_chains > 1``.
        X, Y, hypers: forwarded to ``model`` unchanged.
    """
    if args.num_chains > 1:
        rng = random.split(rng, args.num_chains)
    init_params, potential_fn, constrain_fn = initialize_model(rng, model, X, Y, hypers)
    # perf_counter is monotonic and high-resolution. Unlike time.time(), the
    # reported duration cannot be skewed by system clock adjustments.
    start = time.perf_counter()
    samples = mcmc(args.num_warmup, args.num_samples, init_params, num_chains=args.num_chains,
                   sampler='hmc', potential_fn=potential_fn, constrain_fn=constrain_fn)
    print('\nMCMC elapsed time:', time.perf_counter() - start)
    return samples
コード例 #10
0
ファイル: bnn.py プロジェクト: mcgrady20150318/numpyro
def run_inference(model, args, rng, X, Y, D_H):
    """Fit ``model`` to ``(X, Y)`` with HMC and return the posterior draws.

    ``D_H`` is forwarded to the model unchanged (presumably a hidden-layer width --
    TODO confirm). Unlike multi-chain variants, ``rng`` is passed to
    ``initialize_model`` as-is and ``num_chains`` is not given to ``mcmc``.
    """
    model_inputs = (X, Y, D_H)
    init_params, potential_fn, constrain_fn = initialize_model(rng, model, *model_inputs)
    return mcmc(args.num_warmup, args.num_samples, init_params,
                sampler='hmc',
                potential_fn=potential_fn,
                constrain_fn=constrain_fn)
コード例 #11
0
ファイル: test_mcmc.py プロジェクト: juvu/numpyro
def test_improper_prior():
    """MCMC over ``param`` sites recovers a Normal's mean and std within 5%.

    ``mean`` and ``std`` are declared with ``param`` rather than a prior
    distribution (the improper-prior case named by the test). ``std`` is
    constrained to be positive.
    """
    true_mean, true_std = 1., 2.
    num_warmup, num_samples = 1000, 8000

    def model(data):
        mu = param('mean', 0.)
        sigma = param('std', 1., constraint=constraints.positive)
        return sample('obs', dist.Normal(mu, sigma), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000,))
    init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, data)
    samples = mcmc(num_warmup, num_samples, init_params,
                   potential_fn=potential_fn,
                   constrain_fn=constrain_fn)
    for site, expected in (('mean', true_mean), ('std', true_std)):
        assert_allclose(np.mean(samples[site]), expected, rtol=0.05)
コード例 #12
0
def test_uniform_normal():
    """With nested Uniform priors, the posterior mean of ``loc`` lands near 0.9."""
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        # The upper bound of ``loc`` is itself a random variable.
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    # NOTE(review): the simulated noise is standard normal (scale 1), while the
    # likelihood above uses scale 0.1. Confirm the mismatch is intentional.
    noise = random.normal(random.PRNGKey(0), (1000, ))
    data = true_coef + noise
    init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, data)
    num_warmup = num_samples = 1000
    samples = mcmc(num_warmup, num_samples, init_params,
                   potential_fn=potential_fn,
                   constrain_fn=constrain_fn)
    assert_allclose(np.mean(samples['loc'], 0), true_coef, atol=0.05)
コード例 #13
0
def test_correlated_mvn():
    """Sample a 5-D zero-mean Gaussian with a correlated covariance using ``mcmc``.

    Recovering the covariance requires dense mass matrix estimation, hence
    ``dense_mass=True``.
    """
    D = 5
    warmup_steps, num_samples = 5000, 8000
    true_mean = 0.

    noise = random.normal(random.PRNGKey(0), shape=(D, D))
    lower = np.tril(0.5 * np.fliplr(np.eye(D)) + 0.1 * np.exp(noise))
    true_cov = np.dot(lower, lower.T)
    true_prec = np.linalg.inv(true_cov)

    def potential_fn(z):
        # Negative log density of N(0, true_cov), up to an additive constant.
        return 0.5 * np.dot(z.T, np.dot(true_prec, z))

    samples = mcmc(warmup_steps, num_samples, np.zeros(D),
                   potential_fn=potential_fn, dense_mass=True)
    assert_allclose(np.mean(samples), true_mean, atol=0.02)
    # Mean absolute elementwise error of the empirical covariance.
    cov_err = onp.sum(onp.abs(onp.cov(samples.T) - true_cov)) / D**2
    assert cov_err < 0.02
コード例 #14
0
ファイル: hmm.py プロジェクト: neerajprad/numpyro
def main(args):
    """Simulate semi-supervised HMM data, run ``mcmc`` on ``semi_supervised_hmm``, and print results.

    Args:
        args: parsed CLI namespace. This function reads ``device``, ``num_categories``,
            ``num_words``, ``num_supervised``, ``num_unsupervised``, ``num_warmup``
            and ``num_samples``.
    """
    jax_config.update('jax_platform_name', args.device)

    print('Simulating data...')
    simulated = simulate_data(
        random.PRNGKey(1),
        num_categories=args.num_categories,
        num_words=args.num_words,
        num_supervised_data=args.num_supervised,
        num_unsupervised_data=args.num_unsupervised,
    )
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words, unsupervised_words) = simulated

    print('Starting inference...')
    model_args = (transition_prior, emission_prior, supervised_categories,
                  supervised_words, unsupervised_words)
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), semi_supervised_hmm, *model_args)
    samples = mcmc(args.num_warmup, args.num_samples, init_params,
                   potential_fn=potential_fn,
                   constrain_fn=constrain_fn)
    # The true probabilities are passed alongside the draws for reporting.
    print_results(samples, transition_prob, emission_prob)
コード例 #15
0
def main(args):
    """Compare vanilla HMC with NeuTra HMC on the dual-moon model and plot the results.

    Steps, in order:
      1. Run plain HMC directly on ``dual_moon_pe``, starting at ``[2., 0.]``.
      2. Fit an ``AutoIAFNormal`` guide with SVI (Adam, step size 0.001) for
         ``args.num_iters`` steps inside ``lax.scan``.
      3. Run HMC on a potential built from the trained guide's transform (NeuTra).
         Then map the draws back through the transform and ``constrain_fn``.
      4. Draw a 3x2 grid of diagnostic plots and save it to ``neutra.pdf``.

    Args:
        args: parsed CLI namespace. This function reads ``device``, ``num_warmup``,
            ``num_samples``, ``num_hidden`` and ``num_iters``.
    """
    jax_config.update('jax_platform_name', args.device)

    # Baseline: HMC run directly on the dual-moon potential, starting from (2, 0).
    print("Start vanilla HMC...")
    vanilla_samples = mcmc(args.num_warmup, args.num_samples, init_params=np.array([2., 0.]),
                           potential_fn=dual_moon_pe, progbar=True)

    # Adam optimizer with step size 0.001 for SVI.
    opt_init, opt_update, get_params = optimizers.adam(0.001)
    # Separate keys for guide construction, SVI initialization and training.
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = AutoIAFNormal(rng_guide, dual_moon_model, get_params, hidden_dims=[args.num_hidden])
    svi_init, svi_update, _ = svi(dual_moon_model, guide, elbo, opt_init, opt_update, get_params)
    opt_state, _ = svi_init(rng_init)

    def body_fn(val, i):
        # One SVI step per scan iteration. The carry is (optimizer state, PRNG key)
        # and the per-step output is the loss.
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_)
        return (opt_state_, rng_), loss

    print("Start training guide...")
    # ``losses`` stacks the per-iteration loss emitted by ``body_fn``.
    (last_state, _), losses = lax.scan(body_fn, (opt_state, rng_train), np.arange(args.num_iters))
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(random.PRNGKey(0), last_state,
                                           sample_shape=(args.num_samples,))

    # Presumably ``transform`` maps the IAF base space to the flat latent vector,
    # and ``unpack_fn`` splits that vector into named sites.
    # TODO confirm against AutoIAFNormal.
    transform = guide.get_transform(last_state)
    unpack_fn = guide.unpack_latent

    # Only the potential and constrain functions are needed; init params are discarded.
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), dual_moon_model)
    # Potential re-expressed through the guide transform, so HMC runs in the
    # warped space (presumably; ``make_transformed_pe`` is defined elsewhere).
    transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn)
    # Maps a warped-space point back to constrained model sites.
    transformed_constrain_fn = lambda x: constrain_fn(unpack_fn(transform(x)))  # noqa: E731

    # The NeuTra chain starts at the origin of the warped space.
    init_params = np.zeros(guide.latent_size)
    print("\nStart NeuTra HMC...")
    zs = mcmc(args.num_warmup, args.num_samples, init_params, potential_fn=transformed_potential_fn)
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    # Add a leading axis to every array before summarizing (presumably the chain
    # dimension ``summary`` expects -- TODO confirm).
    summary(tree_map(lambda x: x[None, ...], samples))

    # make plots

    # IAF guide samples (for plotting)
    # NOTE(review): the base dimension is hard-coded to 2 here, while the NeuTra chain
    # uses ``guide.latent_size``. Confirm they agree for this model.
    iaf_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(0), (1000,))
    iaf_trans_samples = vmap(transformed_constrain_fn)(iaf_base_samples)['x']

    # Unnormalized target density exp(-potential) on a 100x100 grid over [-3, 3]^2.
    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.clip(np.exp(-dual_moon_pe(np.stack([X1, X2], axis=-1))), a_min=0.)

    fig = plt.figure(figsize=(12, 16), constrained_layout=True)
    gs = GridSpec(3, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])

    # NOTE(review): assumes args.num_iters > 1000; otherwise this slice is empty.
    ax1.plot(np.log(losses[1000:]))
    ax1.set_title('Autoguide training log loss (after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples['x'][:, 0].copy(), guide_samples['x'][:, 1].copy(), n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using AutoIAFNormal guide')

    # Color each base sample by whether its transformed x0 is negative (left moon).
    sns.scatterplot(iaf_base_samples[:, 0], iaf_base_samples[:, 1], ax=ax3, hue=iaf_trans_samples[:, 0] < 0.)
    ax3.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='AutoIAFNormal base samples (True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0].copy(), vanilla_samples[:, 1].copy(), n_levels=30, ax=ax4)
    # Overlay the last 50 vanilla HMC draws as a connected path.
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using vanilla HMC sampler')

    sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples['x'][:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5],
            xlabel='x0', ylabel='x1', title='Samples from the warped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples['x'][:, 0].copy(), samples['x'][:, 1].copy(), n_levels=30, ax=ax6)
    # Overlay the last 50 NeuTra draws (mapped back to x) as a connected path.
    ax6.plot(samples['x'][-50:, 0], samples['x'][-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()