Example #1
def nuts(data, model, seed=None, iter=None, warmup=None, num_chains=None):
    assert type(data) == dict
    assert type(model) == Model
    assert seed is None or type(seed) == int

    iter, warmup, num_chains = apply_default_hmc_args(iter, warmup, num_chains)

    if seed is None:
        seed = np.random.randint(0, 2**32, dtype=np.uint32).astype(np.int32)
    rng = random.PRNGKey(seed)

    kernel = NUTS(model.fn)
    # TODO: We could use a way to avoid requiring users to set
    # `--xla_force_host_platform_device_count` manually when
    # `num_chains` > 1 in order to run chains in parallel.
    mcmc = MCMC(kernel, warmup, iter, num_chains=num_chains)
    mcmc.run(rng, **data)
    samples = mcmc.get_samples()

    # Here we re-run the model on the samples in order to collect
    # transformed parameters (e.g. `b`, `mu`, etc.). These are made
    # available via the return value of the model.
    transformed_samples = run_model_on_samples_and_data(
        model.fn, samples, data)
    all_samples = dict(samples, **transformed_samples)

    loc = partial(location, data, samples, transformed_samples, model.fn)

    return Samples(all_samples, partial(get_param, all_samples), loc)
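The TODO in this snippet refers to XLA's host device count: on CPU, JAX exposes a single device unless told otherwise, so parallel chains need the `--xla_force_host_platform_device_count` XLA flag set before JAX initializes. A minimal sketch (the device count of 4 is an arbitrary choice for illustration):

import os
# Must run before jax is first imported/initialized in the process.
os.environ['XLA_FLAGS'] = (os.environ.get('XLA_FLAGS', '')
                           + ' --xla_force_host_platform_device_count=4')

import jax
print(jax.devices())  # should now list 4 host (CPU) devices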
Example #2
def test_binomial_stable(with_logits):
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    warmup_steps, num_samples = 200, 200

    def model(data):
        p = numpyro.sample('p', dist.Beta(1., 1.))
        if with_logits:
            logits = logit(p)
            numpyro.sample('obs',
                           dist.Binomial(data['n'], logits=logits),
                           obs=data['x'])
        else:
            numpyro.sample('obs',
                           dist.Binomial(data['n'], probs=p),
                           obs=data['x'])

    data = {'n': 5000000, 'x': 3849}
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['p'], 0), data['x'] / data['n'], rtol=0.05)

    if 'JAX_ENABLE_x64' in os.environ:
        assert samples['p'].dtype == np.float64
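A note on the x64 check above: JAX computes in 32-bit by default, so `samples['p']` only comes out as float64 when 64-bit mode was enabled before inference. A minimal sketch of enabling it via the config object (the import style matches example #18 below; the exact environment-variable spelling has varied across JAX versions, so the config route is shown):

from jax.config import config
# Must be set before any JAX arrays are created.
config.update('jax_enable_x64', True)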
Example #3
File: funnel.py Project: hdocmsu/numpyro
def run_inference(model, args, rng):
    kernel = NUTS(model)
    mcmc = MCMC(kernel,
                args.num_warmup,
                args.num_samples,
                num_chains=args.num_chains)
    mcmc.run(rng)
    return mcmc.get_samples()
Example #4
File: bnn.py Project: hdocmsu/numpyro
def run_inference(model, args, rng, X, Y, D_H):
    if args.num_chains > 1:
        rng = random.split(rng, args.num_chains)
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains)
    mcmc.run(rng, X, Y, D_H)
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
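This example splits the PRNG key when `num_chains > 1`, which is the standard JAX pattern: keys are explicit values, and `random.split` derives independent subkeys from one key deterministically. A small sketch:

from jax import random

rng = random.PRNGKey(0)
subkeys = random.split(rng, 4)  # one independent key per chain, shape (4, 2)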
Example #5
def benchmark_hmc(args, features, labels):
    step_size = np.sqrt(0.5 / features.shape[0])
    trajectory_length = step_size * args.num_steps
    rng = random.PRNGKey(1)
    start = time.time()
    kernel = NUTS(model, trajectory_length=trajectory_length)
    mcmc = MCMC(kernel, 0, args.num_samples)
    mcmc.run(rng, features, labels)
    print('\nMCMC elapsed time:', time.time() - start)
Example #6
def run_inference(model, args, rng, X, Y):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel,
                args.num_warmup,
                args.num_samples,
                num_chains=args.num_chains)
    mcmc.run(rng, X, Y)
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
Example #7
def test_improper_prior():
    true_mean, true_std = 1., 2.
    num_warmup, num_samples = 1000, 8000

    def model(data):
        mean = numpyro.param('mean', 0.)
        std = numpyro.param('std', 1., constraint=constraints.positive)
        return numpyro.sample('obs', dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000, ))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup, num_samples)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['mean']), true_mean, rtol=0.05)
    assert_allclose(np.mean(samples['std']), true_std, rtol=0.05)
Example #8
def test_improper_normal():
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.param('loc',
                            0.,
                            constraint=constraints.interval(0., alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['loc'], 0), true_coef, atol=0.05)
Example #9
def test_predictive():
    model, data, true_probs = beta_bernoulli()
    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()

    predictive_samples = predictive(random.PRNGKey(1), model, samples)
    assert predictive_samples.keys() == {"obs"}

    predictive_samples = predictive(random.PRNGKey(1), model, samples,
                                    return_sites=["beta", "obs"])

    # check shapes
    assert predictive_samples["beta"].shape == (100, 5)
    assert predictive_samples["obs"].shape == (100, 1000, 5)

    # check sample mean
    assert_allclose(predictive_samples["obs"].reshape([-1, 5]).mean(0), true_probs, rtol=0.1)
Example #10
def test_uniform_normal():
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2),
             data,
             collect_warmup=True,
             collect_fields=('z', 'num_steps', 'adapt_state.step_size'))
    samples = mcmc.get_samples()
    assert len(samples[0]['loc']) == num_warmup + num_samples
    assert_allclose(np.mean(samples[0]['loc'], 0), true_coef, atol=0.05)
Example #11
def test_prior_with_sample_shape():
    data = {
        "J": 8,
        "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
        "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
    }

    def schools_model():
        mu = numpyro.sample('mu', dist.Normal(0, 5))
        tau = numpyro.sample('tau', dist.HalfCauchy(5))
        theta = numpyro.sample('theta',
                               dist.Normal(mu, tau),
                               sample_shape=(data['J'], ))
        numpyro.sample('obs', dist.Normal(theta, data['sigma']), obs=data['y'])

    num_samples = 500
    mcmc = MCMC(NUTS(schools_model), num_warmup=500, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0))
    assert mcmc.get_samples()['theta'].shape == (num_samples, data['J'])
Example #12
File: hmm.py Project: hdocmsu/numpyro
def main(args):
    jax_config.update('jax_platform_name', args.device)
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words,
     unsupervised_words) = simulate_data(
         random.PRNGKey(1),
         num_categories=args.num_categories,
         num_words=args.num_words,
         num_supervised_data=args.num_supervised,
         num_unsupervised_data=args.num_unsupervised,
     )
    print('Starting inference...')
    rng = random.PRNGKey(2)
    start = time.time()
    kernel = NUTS(semi_supervised_hmm)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples)
    mcmc.run(rng, transition_prior, emission_prior, supervised_categories,
             supervised_words, unsupervised_words)
    samples = mcmc.get_samples()
    print('\nMCMC elapsed time:', time.time() - start)
    print_results(samples, transition_prob, emission_prob)
Example #13
def test_correlated_mvn():
    # This requires dense mass matrix estimation.
    D = 5

    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.
    a = np.tril(0.5 * np.fliplr(np.eye(D)) +
                0.1 * np.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = np.dot(a, a.T)
    true_prec = np.linalg.inv(true_cov)

    def potential_fn(z):
        # Potential energy = negative log density (up to a constant) of a
        # zero-mean multivariate normal with precision matrix true_prec.
        return 0.5 * np.dot(z.T, np.dot(true_prec, z))

    init_params = np.zeros(D)
    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples), true_mean, atol=0.02)
    assert onp.sum(onp.abs(onp.cov(samples.T) - true_cov)) / D**2 < 0.02
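As the opening comment says, this target has strong correlations between dimensions, which is why the kernel is built with `dense_mass=True`: warmup then adapts a full mass (covariance) matrix rather than the default diagonal one. A sketch of the two options (same constructor, different flag):

kernel_diag = NUTS(potential_fn=potential_fn)                    # diagonal mass matrix (default)
kernel_dense = NUTS(potential_fn=potential_fn, dense_mass=True)  # full mass matrix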
Example #14
def test_chain(use_init_params, chain_method):
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim),
                                                    np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup, num_samples, num_chains=num_chains)
    mcmc.chain_method = chain_method
    init_params = None if not use_init_params else \
        {'coefs': np.tile(np.ones(dim), num_chains).reshape(num_chains, dim)}
    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)
    samples = mcmc.get_samples()
    assert samples['coefs'].shape[0] == num_chains * num_samples
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.21)
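The test assigns `mcmc.chain_method` directly as an attribute; in ordinary use it is passed to the `MCMC` constructor, where, to my knowledge, the accepted values are 'parallel', 'sequential', and 'vectorized'. A sketch:

mcmc = MCMC(kernel, num_warmup, num_samples,
            num_chains=2, chain_method='sequential')  # run chains one after another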
Example #15
File: ucbadmit.py Project: hdocmsu/numpyro
def run_inference(dept, male, applications, admit, rng, args):
    kernel = NUTS(glmm)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, args.num_chains)
    mcmc.run(rng, dept, male, applications, admit)
    return mcmc.get_samples()
Example #16
def main(args):
    jax_config.update('jax_platform_name', args.device)

    print("Start vanilla HMC...")
    nuts_kernel = NUTS(potential_fn=dual_moon_pe)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    mcmc.run(random.PRNGKey(11), init_params=np.array([2., 0.]))
    vanilla_samples = mcmc.get_samples()

    adam = optim.Adam(0.001)
    rng_init, rng_train = random.split(random.PRNGKey(1), 2)
    guide = AutoIAFNormal(dual_moon_model, hidden_dims=[args.num_hidden], skip_connections=True)
    svi = SVI(dual_moon_model, guide, elbo, adam)
    svi_state = svi.init(rng_init)

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(random.PRNGKey(0), params,
                                           sample_shape=(args.num_samples,))

    transform = guide.get_transform(params)
    unpack_fn = guide.unpack_latent

    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), dual_moon_model)
    transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn)
    transformed_constrain_fn = lambda x: constrain_fn(unpack_fn(transform(x)))  # noqa: E731

    init_params = np.zeros(guide.latent_size)
    print("\nStart NeuTra HMC...")
    # TODO: explore why NeuTra samples are not good
    # Issue: https://github.com/pyro-ppl/numpyro/issues/256
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    mcmc.run(random.PRNGKey(10), init_params=init_params)
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    summary(tree_map(lambda x: x[None, ...], samples))

    # make plots

    # IAF guide samples (for plotting)
    iaf_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(0), (1000,))
    iaf_trans_samples = vmap(transformed_constrain_fn)(iaf_base_samples)['x']

    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.clip(np.exp(-dual_moon_pe(np.stack([X1, X2], axis=-1))), a_min=0.)

    fig = plt.figure(figsize=(12, 16), constrained_layout=True)
    gs = GridSpec(3, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])

    ax1.plot(np.log(losses[1000:]))
    ax1.set_title('Autoguide training log loss (after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples['x'][:, 0].copy(), guide_samples['x'][:, 1].copy(), n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using AutoIAFNormal guide')

    sns.scatterplot(iaf_base_samples[:, 0], iaf_base_samples[:, 1], ax=ax3, hue=iaf_trans_samples[:, 0] < 0.)
    ax3.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='AutoIAFNormal base samples (True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0].copy(), vanilla_samples[:, 1].copy(), n_levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using vanilla HMC sampler')

    sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples['x'][:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5],
            xlabel='x0', ylabel='x1', title='Samples from the warped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples['x'][:, 0].copy(), samples['x'][:, 1].copy(), n_levels=30, ax=ax6)
    ax6.plot(samples['x'][-50:, 0], samples['x'][-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()
Example #17
def test_change_point():
    # Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    warmup_steps, num_samples = 500, 3000

    def model(data):
        alpha = 1 / np.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        # Rate is lambda1 before the changepoint at index tau * len(data),
        # lambda2 after it.
        lambda12 = np.where(
            np.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = np.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57, 11, 19, 29,
        6, 19, 12, 22, 12, 18, 72, 32, 9, 7, 13, 19, 23, 27, 20, 6, 17,
        13, 10, 14, 6, 16, 15, 7, 2, 15, 15, 19, 70, 49, 7, 53, 22, 21,
        31, 19, 11, 18, 20, 12, 35, 17, 23, 17, 4, 2, 31, 30, 13, 27, 0,
        39, 37, 5, 14, 13, 22,
    ])
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(4), count_data)
    samples = mcmc.get_samples()
    tau_posterior = (samples['tau'] * len(count_data)).astype(np.int32)
    tau_values, counts = onp.unique(tau_posterior, return_counts=True)
    mode_ind = np.argmax(counts)
    mode = tau_values[mode_ind]
    assert mode == 44

    if 'JAX_ENABLE_x64' in os.environ:
        assert samples['lambda1'].dtype == np.float64
        assert samples['lambda2'].dtype == np.float64
        assert samples['tau'].dtype == np.float64
Example #18
File: baseball.py Project: hdocmsu/numpyro
import argparse
import os

import numpy as onp

from jax.config import config as jax_config
import jax.numpy as np
import jax.random as random
from jax.scipy.special import logsumexp

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.examples.datasets import BASEBALL, load_dataset
from numpyro.mcmc import MCMC, NUTS
"""
Original example from Pyro:
https://github.com/pyro-ppl/pyro/blob/dev/examples/baseball.py

This example has been adapted from [1]. It demonstrates how to do Bayesian
inference using NUTS (or HMC) in NumPyro, and the use of some common inference
utilities.

As in the Stan tutorial, this uses the small baseball dataset of Efron and
Morris [2] to estimate players' batting averages, i.e. the fraction of times
a player got a base hit out of the number of times they went up at bat.

The dataset separates the initial 45 at-bats statistics from the remaining season.
We use the hits data from the initial 45 at-bats to estimate the batting average
for each player. We then use the remaining season's data to validate the predictions
from our models.
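The docstring is cut off before the model definitions. As a hedged sketch of the simplest model it alludes to, here is a fully pooled model in which every player shares a single batting average `phi` (the name and structure follow the upstream Pyro example, but this block is reconstructed for illustration, not copied from this file):

def fully_pooled(at_bats, hits=None):
    # One success probability shared by all players.
    phi = numpyro.sample("phi", dist.Uniform(0., 1.))
    # Each player's hit count out of `at_bats` trials.
    return numpyro.sample("obs", dist.Binomial(at_bats, probs=phi), obs=hits)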