Exemple #1
0
def test_functional_beta_bernoulli_x64(algo):
    """Check that the functional HMC API recovers Bernoulli probabilities.

    Two coins share an identical Beta(1.1, 1.1) prior and are fit to 1000
    draws each from Bernoulli([0.9, 0.1]).  The posterior mean of
    ``p_latent`` must land within 0.05 (absolute) of the true probabilities.
    When ``JAX_ENABLE_X64`` is set in the environment (only its presence is
    checked, not its value) the collected samples must also be float64.
    """
    warmup_steps = 500
    num_samples = 20000

    def model(data):
        # Same prior parameters for both coins.
        prior_a = np.array([1.1, 1.1])
        prior_b = np.array([1.1, 1.1])
        p = numpyro.sample('p_latent', dist.Beta(prior_a, prior_b))
        numpyro.sample('obs', dist.Bernoulli(p), obs=data)
        return p

    true_probs = np.array([0.9, 0.1])
    observations = dist.Bernoulli(true_probs).sample(random.PRNGKey(1),
                                                     (1000, 2))
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, model_args=(observations, ))
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    start_state = init_kernel(init_params,
                              trajectory_length=1.,
                              num_warmup=warmup_steps)

    def to_constrained(state):
        # Map each collected unconstrained position `z` back through
        # `constrain_fn` so the sample values live on the model's support.
        return constrain_fn(state.z)

    samples = fori_collect(0, num_samples, sample_kernel, start_state,
                           transform=to_constrained)
    posterior_mean = np.mean(samples['p_latent'], 0)
    assert_allclose(posterior_mean, true_probs, atol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
Exemple #2
0
def test_reuse_mcmc_pe_gen():
    """Reuse one jitted HMC step function across two different datasets.

    ``initialize_model`` is called with ``dynamic_args=True``, and the
    returned ``potential_fn`` / ``constrain_fn`` are used as generators that
    take the model arguments (``hmc(potential_fn_gen=...)``,
    ``constrain_fn(y)``).  The observed data is therefore threaded through
    each sampling step instead of being fixed at construction time.  The same
    jitted ``_sample`` is run on data centred at 3 and at -3, and the
    posterior mean of ``mu`` is checked for each.
    """
    # Draw both datasets up front (global NumPy RNG, same order as before).
    y1 = onp.random.normal(3, 0.1, (100, ))
    y2 = onp.random.normal(-3, 0.1, (100, ))

    def model(y_obs):
        mu = numpyro.sample('mu', dist.Normal(0., 1.))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(0), model, y1, dynamic_args=True)
    init_kernel, sample_kernel = hmc(potential_fn_gen=potential_fn)

    @jit
    def _sample(carry):
        # The carry is (hmc_state, observed_data); the data is passed back
        # unchanged so every step sees the same dataset.
        state, y_obs = carry
        return sample_kernel(state, (y_obs, )), y_obs

    def posterior_mu_mean(y_obs):
        # Warm up on `y_obs`, draw 500 samples with the shared jitted step,
        # and return the mean of the constrained `mu` draws.
        start = init_kernel(init_params, num_warmup=300, model_args=(y_obs, ))
        draws = fori_collect(
            0, 500, _sample, (start, y_obs),
            transform=lambda carry: constrain_fn(y_obs)(carry[0].z))
        return draws['mu'].mean()

    assert_allclose(posterior_mu_mean(y1), 3., atol=0.1)

    # Second dataset, re-using the already-jitted `_sample`; this run is
    # expected to be much faster since no recompilation should be needed.
    assert_allclose(posterior_mu_mean(y2), -3., atol=0.1)
Exemple #3
0
def main(args):
    """Fit the stochastic-volatility model to S&P 500 returns and plot it.

    Runs NUTS on the SP500 dataset, prints a summary via ``print_results``,
    then saves a plot of the raw returns overlaid with one faint line per
    posterior draw of ``exp(s)`` to ``stochastic_volatility_plot.pdf``.

    :param args: parsed command-line arguments; ``rng_seed``,
        ``num_warmup`` and ``num_samples`` are read.
    """
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
    model_info = initialize_model(init_rng_key, model, model_args=(returns, ))
    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info,
                            args.num_warmup,
                            rng_key=sample_rng_key)
    # Iterations before index `args.num_warmup` are presumably run as warmup
    # and not collected (numpyro `fori_collect`'s `lower` argument).
    hmc_states = fori_collect(
        args.num_warmup,
        args.num_warmup + args.num_samples,
        sample_kernel,
        hmc_state,
        transform=lambda hmc_state: model_info.postprocess_fn(hmc_state.z),
        # Hide the progress bar while building the docs.
        progbar="NUMPYRO_SPHINXBUILD" not in os.environ)
    print_results(hmc_states, dates)

    _, ax = plt.subplots(1, 1)
    dates = mdates.num2date(mdates.datestr2num(dates))
    ax.plot(dates, returns, lw=0.5)
    # format the ticks
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax.xaxis.set_minor_locator(mdates.MonthLocator())

    # One nearly transparent red line per posterior draw; exp(s) is shown
    # under the 'volatility' legend entry.
    ax.plot(dates, jnp.exp(hmc_states['s'].T), 'r', alpha=0.01)
    legend = ax.legend(['returns', 'volatility'], loc='upper right')
    # `Legend.legendHandles` was renamed `legend_handles` in Matplotlib 3.7
    # and the old name was removed in 3.9; support both.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    # Make the legend swatch visible despite the per-line alpha of 0.01.
    handles[1].set_alpha(0.6)
    ax.set(xlabel='time',
           ylabel='returns',
           title='Volatility of S&P500 over time')

    # Lay out the figure *before* saving: calling tight_layout() after
    # savefig() has no effect on the file that was already written.
    plt.tight_layout()
    plt.savefig("stochastic_volatility_plot.pdf")
Exemple #4
0
def main(args):
    """Run NUTS on the S&P 500 stochastic-volatility model and print results.

    :param args: parsed command-line arguments; ``rng_seed``,
        ``num_warmup`` and ``num_samples`` are read.
    """
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()

    # One key for model initialization, one for the sampler.
    seed_key = random.PRNGKey(args.rng_seed)
    init_rng_key, sample_rng_key = random.split(seed_key)

    init_params, potential_fn, constrain_fn = initialize_model(
        init_rng_key, model, returns)
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')

    # `args.num_warmup` is handed to the initializer, which presumably runs
    # the warmup phase before returning the starting state.
    start_state = init_kernel(init_params,
                              args.num_warmup,
                              rng_key=sample_rng_key)

    def to_constrained(state):
        # Map the unconstrained position back to the model's support.
        return constrain_fn(state.z)

    posterior = fori_collect(0, args.num_samples, sample_kernel, start_state,
                             transform=to_constrained)
    print_results(posterior, dates)
Exemple #5
0
def test_functional_map(algo, map_fn):
    """Run two HMC chains through `map_fn` on a 1-D Gaussian target.

    The target is N(true_mean, true_std**2).  Two chains are started from
    0. and -1. with independent RNG keys.  Each chain's empirical mean and
    standard deviation must match the truth within a 6% relative tolerance.
    `pmap` needs more than one device, so that case is skipped on a
    single-device host.
    """
    if map_fn is pmap and xla_bridge.device_count() == 1:
        pytest.skip('pmap test requires device_count greater than 1.')

    true_mean, true_std = 1., 2.
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        # Negative log-density of the Gaussian target, up to a constant.
        standardized = (z - true_mean) / true_std
        return 0.5 * np.sum(standardized**2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = np.array([0., -1.])
    rng_keys = random.split(random.PRNGKey(0), 2)

    def init_one_chain(init_param, rng_key):
        return init_kernel(init_param,
                           trajectory_length=9,
                           num_warmup=warmup_steps,
                           rng_key=rng_key)

    def collect_one_chain(hmc_state):
        return fori_collect(0, num_samples, sample_kernel, hmc_state,
                            transform=lambda x: x.z, progbar=False)

    init_states = map_fn(init_one_chain)(init_params, rng_keys)
    chain_samples = map_fn(collect_one_chain)(init_states)

    # Axis 0 indexes the two chains, axis 1 the draws within a chain.
    expected_mean = np.repeat(true_mean, 2)
    expected_std = np.repeat(true_std, 2)
    assert_allclose(np.mean(chain_samples, axis=1), expected_mean, rtol=0.06)
    assert_allclose(np.std(chain_samples, axis=1), expected_std, rtol=0.06)
Exemple #6
0
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# Put the directory two levels above the build's working directory first on
# the import path (presumably the repository root) so autodoc imports the
# local `numpyro` checkout rather than an installed copy.
sys.path.insert(0, os.path.abspath('../..'))

# Flag that a Sphinx build is in progress; presumably read elsewhere in the
# package to adjust behaviour during doc generation — TODO confirm.
os.environ['SPHINX_BUILD'] = '1'

# HACK: This is to ensure that local functions are documented by sphinx.
# `hmc` is invoked once with placeholder arguments; NOTE(review): presumably
# this makes the functions it defines internally visible to autodoc — confirm.
from numpyro.infer.mcmc import hmc  # noqa: E402
hmc(None, None)

# -- Project information -----------------------------------------------------

project = u'NumPyro'
copyright = u'2019, Uber Technologies, Inc'
author = u'Uber AI Labs'

# Left empty on Read the Docs; filled from the package when building locally.
version = ''

if 'READTHEDOCS' not in os.environ:
    # if developing locally, use numpyro.__version__ as version
    from numpyro import __version__  # noqa: E402
    version = __version__

# release version