Пример #1
0
def test_functional_beta_bernoulli_x64(algo):
    """Check the functional ``hmc`` interface on a Beta-Bernoulli model.

    Draws 1000 rows of two Bernoulli columns with true success probabilities
    0.9 and 0.1, places Beta(1.1, 1.1) priors on them, samples with ``hmc``
    and asserts the posterior mean of ``p_latent`` is within 0.05 of the
    truth. When ``JAX_ENABLE_X64`` is set in the environment it also asserts
    the samples are ``float64``.

    :param algo: sampler name forwarded to ``hmc`` (e.g. ``'HMC'``/``'NUTS'``).
    """
    num_warmup, num_samples = 410, 100

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample("p_latent", dist.Beta(alpha, beta))
        numpyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    init_params, potential_fn, constrain_fn, _ = initialize_model(
        random.PRNGKey(2), model, model_args=(data, ))
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params,
                            trajectory_length=1.0,
                            num_warmup=num_warmup)
    # `init_kernel` was told about `num_warmup`; the sample kernel spends its
    # first `num_warmup` steps adapting. Collect only the draws after that.
    # Starting at 0 would return warmup draws only, since
    # num_samples < num_warmup.
    samples = fori_collect(num_warmup,
                           num_warmup + num_samples,
                           sample_kernel,
                           hmc_state,
                           transform=lambda x: constrain_fn(x.z))
    assert_allclose(jnp.mean(samples["p_latent"], 0), true_probs, atol=0.05)

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["p_latent"].dtype == jnp.float64
Пример #2
0
def main(args):
    """Fit the stochastic volatility model to S&P 500 returns and plot it.

    Runs NUTS through the functional ``hmc`` interface, prints a summary of
    the post-warmup draws, and saves a plot of the returns overlaid with the
    sampled volatility paths to ``stochastic_volatility_plot.pdf``.

    :param args: parsed command-line arguments; reads ``rng_seed``,
        ``num_warmup`` and ``num_samples``.
    """
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
    model_info = initialize_model(init_rng_key, model, model_args=(returns,))
    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info, args.num_warmup, rng_key=sample_rng_key)
    # Skip the first `num_warmup` iterations and keep the next `num_samples`.
    hmc_states = fori_collect(args.num_warmup, args.num_warmup + args.num_samples, sample_kernel, hmc_state,
                              transform=lambda hmc_state: model_info.postprocess_fn(hmc_state.z),
                              # No progress bar while building the docs.
                              progbar="NUMPYRO_SPHINXBUILD" not in os.environ)
    print_results(hmc_states, dates)

    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    dates = mdates.num2date(mdates.datestr2num(dates))
    ax.plot(dates, returns, lw=0.5)
    # format the ticks
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax.xaxis.set_minor_locator(mdates.MonthLocator())

    # Volatility is exp(s). Each column of `s.T` is drawn as a faint red line
    # (presumably one line per posterior draw).
    ax.plot(dates, jnp.exp(hmc_states['s'].T), 'r', alpha=0.01)
    legend = ax.legend(['returns', 'volatility'], loc='upper right')
    # Matplotlib 3.7 renamed `Legend.legendHandles` to `legend_handles`, and
    # 3.9 removed the old name; support both.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    # Keep the volatility legend entry readable despite the plot's alpha=0.01.
    handles[1].set_alpha(0.6)
    ax.set(xlabel='time', ylabel='returns', title='Volatility of S&P500 over time')

    fig.savefig("stochastic_volatility_plot.pdf")
Пример #3
0
def hmc(potential_fn=None,
        potential_fn_gen=None,
        kinetic_fn=None,
        algo='NUTS'):
    """Deprecated alias of :func:`numpyro.infer.hmc.hmc`.

    Emits a :class:`DeprecationWarning`, then forwards all arguments
    positionally to the relocated function and returns its result unchanged.
    """
    # Import locally under an alias so this wrapper's own name is not shadowed.
    from numpyro.infer.hmc import hmc as _hmc

    warnings.warn(
        "The functional interface `hmc` has been moved to `numpyro.infer.hmc` module.",
        DeprecationWarning,
        # Report the warning at the caller's line, not inside this shim.
        stacklevel=2)
    return _hmc(potential_fn, potential_fn_gen, kinetic_fn, algo)
Пример #4
0
def test_construct():
    """Smoke-test converting a bellini ``Story`` graph to NumPyro and sampling it.

    Builds two independent Normal water quantities, sums them, adds a Normal
    observation of the sum that is marked observed, converts the story graph
    with ``graph_to_numpyro_model`` and samples it with the functional NUTS
    interface. There are no assertions: the test passes if the pipeline runs,
    and the samples are printed for inspection.
    """
    import bellini
    from bellini import Quantity, Story
    import pint
    import jax
    from bellini.api._numpyro import graph_to_numpyro_model
    from numpyro.infer.hmc import hmc
    from numpyro.infer.util import initialize_model
    from numpyro.util import fori_collect

    num_warmup, num_samples = 300, 500
    ureg = pint.UnitRegistry()

    s = Story()
    s.one_water_quantity = bellini.distributions.Normal(
        loc=Quantity(3.0, ureg.mole),
        scale=Quantity(0.01, ureg.mole),
        name="first_normal",
    )
    s.another_water_quantity = bellini.distributions.Normal(
        loc=Quantity(3.0, ureg.mole),
        scale=Quantity(0.01, ureg.mole),
        name="second_normal",
    )
    s.combined_water = s.one_water_quantity + s.another_water_quantity
    s.combined_water.name = "combined_water"

    # The node is named "combined_with_noise"; the Story attribute keeps its
    # original spelling.
    s.combined_water_with_nose = bellini.distributions.Normal(
        loc=s.combined_water,
        scale=Quantity(0.01, ureg.mole),
        name="combined_with_noise")

    s.combined_water_with_nose.observed = True

    model = graph_to_numpyro_model(s.g)

    model_info = initialize_model(
        jax.random.PRNGKey(2666),
        model,
    )

    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info,
                            trajectory_length=10,
                            num_warmup=num_warmup)
    # Skip the warmup (adaptation) iterations and keep `num_samples` draws
    # after them.
    samples = fori_collect(
        num_warmup,
        num_warmup + num_samples,
        sample_kernel,
        hmc_state,
        transform=lambda state: model_info.postprocess_fn(state.z))

    print(samples)
Пример #5
0
def test_functional_map(algo, map_fn):
    """Run independent ``hmc`` chains in parallel through ``map_fn``.

    The target is a scalar Gaussian with mean 1.0 and std 2.0 (see
    ``potential_fn``). Each chain starts from its own entry of
    ``init_params`` with its own PRNG key. ``map_fn`` (e.g. ``vmap`` or
    ``pmap``) vectorizes both kernel initialization and sample collection.
    The test asserts each chain's sample mean and std are within 6% relative
    tolerance of the truth.

    :param algo: sampler name forwarded to ``hmc``.
    :param map_fn: mapping transform applied over the chain dimension.
    """
    num_chains = 2
    if map_fn is pmap and jax.device_count() < num_chains:
        pytest.skip("pmap test requires device_count greater than 1.")

    true_mean, true_std = 1.0, 2.0
    num_warmup, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std)**2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = jnp.array([0.0, -1.0])  # one starting point per chain
    rng_keys = random.split(random.PRNGKey(0), num_chains)

    init_kernel_map = map_fn(
        lambda init_param, rng_key: init_kernel(init_param,
                                                trajectory_length=9,
                                                num_warmup=num_warmup,
                                                rng_key=rng_key))
    init_states = init_kernel_map(init_params, rng_keys)

    # Collect only the draws after the `num_warmup` adaptation iterations, so
    # warmup draws do not bias the mean/std checks below.
    fori_collect_map = map_fn(lambda hmc_state: fori_collect(
        num_warmup,
        num_warmup + num_samples,
        sample_kernel,
        hmc_state,
        transform=lambda x: x.z,
        progbar=False,
    ))
    chain_samples = fori_collect_map(init_states)

    assert_allclose(jnp.mean(chain_samples, axis=1),
                    jnp.repeat(true_mean, num_chains),
                    rtol=0.06)
    assert_allclose(jnp.std(chain_samples, axis=1),
                    jnp.repeat(true_std, num_chains),
                    rtol=0.06)
Пример #6
0
# For the full list of configuration options, see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# Make the project root (two levels up; os.path.abspath resolves the path
# against the current working directory) importable so autodoc can find
# `numpyro`.
sys.path.insert(0, os.path.abspath('../..'))

# Flag for code that behaves differently during a docs build (presumably read
# somewhere in the package — verify).
os.environ['SPHINX_BUILD'] = '1'

# HACK: This is to ensure that local functions are documented by sphinx.
from numpyro.infer.hmc import hmc  # noqa: E402
hmc(None, None)

# -- Project information -----------------------------------------------------

project = 'NumPyro'
# Sphinx reads this exact name, even though it shadows the `copyright` builtin.
copyright = '2019, Uber Technologies, Inc'
author = 'Uber AI Labs'

version = ''

if 'READTHEDOCS' not in os.environ:
    # if developing locally, use numpyro.__version__ as version
    from numpyro import __version__  # noqa: E402
    version = __version__

# release version