Example #1
0
def test_binomial_stable(with_logits):
    """HMC on a Beta-Binomial model with a very large trial count, exercising
    both the ``probs`` and the ``logits`` parametrization of Binomial.

    Ref: https://github.com/pyro-ppl/pyro/issues/1706
    """
    warmup_steps, num_samples = 200, 200

    def model(data):
        p = numpyro.sample('p', dist.Beta(1., 1.))
        if with_logits:
            numpyro.sample('obs',
                           dist.Binomial(data['n'], logits=logit(p)),
                           obs=data['x'])
        else:
            numpyro.sample('obs',
                           dist.Binomial(data['n'], probs=p),
                           obs=data['x'])

    data = {'n': 5000000, 'x': 3849}
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, data)
    init_kernel, sample_kernel = hmc(potential_fn)
    state = init_kernel(init_params, num_warmup=warmup_steps)
    samples = fori_collect(0, num_samples, sample_kernel, state,
                           transform=lambda s: constrain_fn(s.z))

    # Posterior mean of p should match the empirical success rate x / n.
    assert_allclose(np.mean(samples['p'], 0), data['x'] / data['n'], rtol=0.05)

    # NOTE(review): env-var name uses lowercase 'x64'; JAX's own switch is
    # JAX_ENABLE_X64 -- confirm this matches what the CI actually exports.
    if 'JAX_ENABLE_x64' in os.environ:
        assert samples['p'].dtype == np.float64
Example #2
0
def test_change_point():
    """HMC on a Poisson change-point model (Bayesian-hackers style): the
    posterior mode of the switch index should land at 44.

    Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    """
    warmup_steps, num_samples = 500, 3000

    def model(data):
        alpha = 1 / np.mean(data)
        lambda1 = sample('lambda1', dist.Exponential(alpha))
        lambda2 = sample('lambda2', dist.Exponential(alpha))
        tau = sample('tau', dist.Uniform(0, 1))
        # Rate switches from lambda1 to lambda2 at index tau * len(data).
        rate = np.where(np.arange(len(data)) < tau * len(data),
                        lambda1, lambda2)
        sample('obs', dist.Poisson(rate), obs=data)

    count_data = np.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57,
        11, 19, 29, 6, 19, 12, 22, 12, 18, 72, 32, 9, 7, 13,
        19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2,
        15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 18, 20,
        12, 35, 17, 23, 17, 4, 2, 31, 30, 13, 27, 0, 39, 37,
        5, 14, 13, 22,
    ])
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(4), model, count_data)
    init_kernel, sample_kernel = hmc(potential_fn)
    state = init_kernel(init_params, num_warmup=warmup_steps)
    samples = fori_collect(num_samples, sample_kernel, state,
                           transform=lambda s: constrain_fn(s.z))
    # Most frequent change-point index across posterior draws.
    tau_idx = (samples['tau'] * len(count_data)).astype("int")
    values, freqs = onp.unique(tau_idx, return_counts=True)
    assert values[np.argmax(freqs)] == 44

    # NOTE(review): lowercase 'x64' in the env-var name; JAX's flag is
    # JAX_ENABLE_X64 -- confirm against the CI environment.
    if 'JAX_ENABLE_x64' in os.environ:
        assert samples['lambda1'].dtype == np.float64
        assert samples['lambda2'].dtype == np.float64
        assert samples['tau'].dtype == np.float64
Example #3
0
def test_beta_bernoulli(algo):
    """Beta-Bernoulli conjugate model: the posterior mean of ``p_latent``
    should approach the per-dimension true success probabilities."""
    warmup_steps, num_samples = 500, 20000

    def model(data):
        # Symmetric Beta(1.1, 1.1) prior on each dimension.
        prior = np.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(prior, prior))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, data)
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    state = init_kernel(init_params,
                        trajectory_length=1.,
                        num_warmup=warmup_steps,
                        progbar=False)
    samples = fori_collect(0, num_samples, sample_kernel, state,
                           transform=lambda s: constrain_fn(s.z),
                           progbar=False)
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.05)

    # NOTE(review): lowercase 'x64' -- JAX's flag is JAX_ENABLE_X64; confirm.
    if 'JAX_ENABLE_x64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
Example #4
0
def test_dirichlet_categorical(algo, dense_mass):
    """Dirichlet-Categorical model: checks the posterior mean of the latent
    simplex against the generating probabilities, optionally with a dense
    mass matrix."""
    warmup_steps, num_samples = 100, 20000

    def model(data):
        # Uniform Dirichlet prior over the 3-way simplex.
        p_latent = sample('p_latent', dist.Dirichlet(np.array([1.0, 1.0, 1.0])))
        sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, data)
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    state = init_kernel(init_params,
                        trajectory_length=1.,
                        num_warmup=warmup_steps,
                        progbar=False,
                        dense_mass=dense_mass)
    hmc_states = fori_collect(num_samples, sample_kernel, state,
                              transform=lambda s: constrain_fn(s.z),
                              progbar=False)
    assert_allclose(np.mean(hmc_states['p_latent'], 0), true_probs, atol=0.02)

    # NOTE(review): lowercase 'x64' -- JAX's flag is JAX_ENABLE_X64; confirm.
    if 'JAX_ENABLE_x64' in os.environ:
        assert hmc_states['p_latent'].dtype == np.float64
Example #5
0
def run_inference(dept, male, applications, admit, rng, args):
    """Run NUTS on the ``glmm`` model and return constrained posterior samples."""
    init_params, potential_fn, constrain_fn = initialize_model(
        rng, glmm, dept, male, applications, admit)
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    state = init_kernel(init_params, args.num_warmup_steps)
    return fori_collect(args.num_samples, sample_kernel, state,
                        transform=lambda s: constrain_fn(s.z))
Example #6
0
def main(args):
    """Fetch SP500 returns, fit the model with NUTS, and print the results."""
    jax_config.update('jax_platform_name', args.device)
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    rng_init, rng_sample = random.split(random.PRNGKey(args.rng))
    init_params, potential_fn, constrain_fn = initialize_model(
        rng_init, model, returns)
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    state = init_kernel(init_params, args.num_warmup, rng=rng_sample)
    samples = fori_collect(0, args.num_samples, sample_kernel, state,
                           transform=lambda s: constrain_fn(s.z))
    print_results(samples, dates)
Example #7
0
def run_inference(transition_prior, emission_prior, supervised_categories,
                  supervised_words, unsupervised_words, rng, args):
    """Run NUTS on the semi-supervised HMM and return posterior samples."""
    model_args = (transition_prior, emission_prior, supervised_categories,
                  supervised_words, unsupervised_words)
    init_params, potential_fn, constrain_fn = initialize_model(
        rng, semi_supervised_hmm, *model_args)
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    state = init_kernel(init_params, args.num_warmup)
    return fori_collect(args.num_samples, sample_kernel, state,
                        transform=lambda s: constrain_fn(s.z))
Example #8
0
def test_unnormalized_normal(algo):
    """HMC driven by a raw potential function (unnormalized Gaussian negative
    log-density): the chain should recover the true mean and std."""
    true_mean, true_std = 1., 2.
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        # -log N(z | true_mean, true_std) up to an additive constant.
        return 0.5 * np.sum(((z - true_mean) / true_std) ** 2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    state = init_kernel(np.array(0.),
                        trajectory_length=9,
                        num_warmup=warmup_steps)
    draws = fori_collect(num_samples, sample_kernel, state,
                         transform=lambda s: s.z)
    assert_allclose(np.mean(draws), true_mean, rtol=0.05)
    assert_allclose(np.std(draws), true_std, rtol=0.05)
Example #9
0
def benchmark_hmc(args, features, labels):
    """Time HMC initialization and sampling on the module-level ``model``,
    reporting total leapfrog steps and average wall-clock time per step.

    NOTE(review): ``step_size`` and ``init_params`` are not defined in this
    function -- presumably module-level globals (the TODO below mentions
    ``init_params``); verify before running this in isolation.
    """
    trajectory_length = step_size * args.num_steps
    # Only the potential function is used here; init params come from elsewhere.
    _, potential_fn, _ = initialize_model(random.PRNGKey(1), model, features, labels)
    init_kernel, sample_kernel = hmc(potential_fn, algo=args.algo)
    t0 = time.time()
    # TODO: Use init_params from `initialize_model` instead of fixed params.
    hmc_state, _, _ = init_kernel(init_params, num_warmup=0, step_size=step_size,
                                  trajectory_length=trajectory_length,
                                  adapt_step_size=False, run_warmup=False)
    t1 = time.time()
    print("time for hmc_init: ", t1 - t0)

    # Keep both the sampled coefficients and the per-iteration leapfrog count.
    def transform(state): return {'coefs': state.z['coefs'],
                                  'num_steps': state.num_steps}

    hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state, transform=transform)
    num_leapfrogs = np.sum(hmc_states['num_steps'])
    print('number of leapfrog steps: ', num_leapfrogs)
    print('avg. time for each step: ', (time.time() - t1) / num_leapfrogs)
Example #10
0
def test_map(algo, map_fn):
    """Run two independent HMC chains through ``map_fn`` (vmap or pmap) and
    check that both recover the target mean and std."""
    if map_fn is pmap and xla_bridge.device_count() == 1:
        pytest.skip('pmap test requires device_count greater than 1.')

    true_mean, true_std = 1., 2.
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        # Unnormalized Gaussian negative log-density.
        return 0.5 * np.sum(((z - true_mean) / true_std) ** 2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = np.array([0., -1.])
    rngs = random.split(random.PRNGKey(0), 2)

    def init_one(init_param, rng):
        return init_kernel(init_param,
                           trajectory_length=9,
                           num_warmup=warmup_steps,
                           progbar=False,
                           rng=rng)

    init_states = map_fn(init_one)(init_params, rngs)

    def collect_one(state):
        return fori_collect(0, num_samples, sample_kernel, state,
                            transform=lambda s: s.z,
                            progbar=False)

    chain_samples = map_fn(collect_one)(init_states)

    # Each chain (axis 0) should match the target moments along its draws.
    assert_allclose(np.mean(chain_samples, axis=1),
                    np.repeat(true_mean, 2),
                    rtol=0.05)
    assert_allclose(np.std(chain_samples, axis=1),
                    np.repeat(true_std, 2),
                    rtol=0.05)
Example #11
0
# Sphinx configuration fragment for the Numpyro docs build.
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
sys.path.insert(0, os.path.abspath('../..'))

# HACK: This is to ensure that local functions are documented by sphinx.
from numpyro.mcmc import hmc  # noqa: E402
from numpyro.svi import svi  # noqa: E402

# Invoke hmc/svi once with dummy arguments so their inner functions are
# created and visible to autodoc.
# NOTE(review): calling them with None args presumably relies on the
# SPHINX_BUILD flag being checked inside numpyro.mcmc / numpyro.svi -- confirm.
os.environ['SPHINX_BUILD'] = '1'
hmc(None, None)
svi(None, None, None, None, None, None)

# -- Project information -----------------------------------------------------

project = u'Numpyro'
copyright = u'2019, Uber Technologies, Inc'
author = u'Uber AI Labs'

# The short X.Y version
version = u'0.0'
# The full version, including alpha/beta/rc tags
release = u'0.0'

# -- General configuration ---------------------------------------------------