def test_functional_beta_bernoulli_x64(algo):
    """Check that the functional HMC interface recovers Beta-Bernoulli probabilities.

    Draws 1000x2 Bernoulli observations with known probabilities, samples the
    posterior of ``p_latent`` with ``hmc(..., algo=algo)``, and checks the
    posterior mean is within 0.05 of the true probabilities. When
    ``JAX_ENABLE_X64`` is present in the environment, it additionally checks
    that the samples are float64.
    """
    warmup_steps, sample_steps = 410, 100

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample("p_latent", dist.Beta(alpha, beta))
        numpyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    observations = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))

    # Only the first three fields returned by initialize_model are used here.
    init_params, potential_fn, constrain_fn, _ = initialize_model(
        random.PRNGKey(2), model, model_args=(observations,))

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    start_state = init_kernel(init_params, trajectory_length=1.0, num_warmup=warmup_steps)

    def to_constrained(state):
        # Map the unconstrained latent values back to the model's support.
        return constrain_fn(state.z)

    samples = fori_collect(0, sample_steps, sample_kernel, start_state,
                           transform=to_constrained)
    p_samples = samples["p_latent"]

    assert_allclose(jnp.mean(p_samples, 0), true_probs, atol=0.05)
    if "JAX_ENABLE_X64" in os.environ:
        assert p_samples.dtype == jnp.float64
def main(args):
    """Fit the stochastic volatility model to S&P500 returns with NUTS and plot it.

    Args:
        args: parsed command-line namespace; this function reads
            ``rng_seed``, ``num_warmup`` and ``num_samples``.

    Prints a summary via ``print_results`` and writes the returns together with
    the sampled volatility paths to ``stochastic_volatility_plot.pdf``.
    """
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
    model_info = initialize_model(init_rng_key, model, model_args=(returns,))
    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info, args.num_warmup, rng_key=sample_rng_key)
    # Run warmup + sampling, keeping only the post-warmup iterations.
    hmc_states = fori_collect(
        args.num_warmup,
        args.num_warmup + args.num_samples,
        sample_kernel,
        hmc_state,
        transform=lambda hmc_state: model_info.postprocess_fn(hmc_state.z),
        # Progress bars are noise in docs builds.
        progbar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    print_results(hmc_states, dates)

    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    dates = mdates.num2date(mdates.datestr2num(dates))
    ax.plot(dates, returns, lw=0.5)
    # format the ticks
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax.xaxis.set_minor_locator(mdates.MonthLocator())

    # One faint red line per posterior draw of the latent log-volatility `s`.
    ax.plot(dates, jnp.exp(hmc_states['s'].T), 'r', alpha=0.01)
    legend = ax.legend(['returns', 'volatility'], loc='upper right')
    # `Legend.legendHandles` was deprecated in Matplotlib 3.7 and removed in 3.9
    # in favour of `legend_handles`; support both so the example runs everywhere.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    # Make the legend swatch visible despite the tiny per-line alpha above.
    handles[1].set_alpha(0.6)
    ax.set(xlabel='time', ylabel='returns', title='Volatility of S&P500 over time')

    fig.savefig("stochastic_volatility_plot.pdf")
def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'):
    """Deprecated alias of :func:`numpyro.infer.hmc.hmc`.

    Emits a ``DeprecationWarning`` and forwards all arguments, in the same
    order, to the implementation in ``numpyro.infer.hmc``. Returns whatever
    that function returns.
    """
    # Import under a distinct alias so it does not shadow this wrapper's name.
    from numpyro.infer.hmc import hmc as _hmc
    warnings.warn(
        "The functional interface `hmc` has been moved to `numpyro.infer.hmc` module.",
        DeprecationWarning,
        # Attribute the warning to the caller, not this shim. Otherwise the default
        # filters (which only show DeprecationWarning raised from __main__) hide it.
        stacklevel=2)
    return _hmc(potential_fn, potential_fn_gen, kinetic_fn, algo)
def test_construct():
    """Smoke-test converting a bellini Story graph into a NumPyro model and sampling it.

    The graph has two independent Normal water quantities, their sum, and a
    Normal centred on that sum which is marked observed. The graph is
    converted with ``graph_to_numpyro_model`` and sampled with NUTS. The
    samples are only printed; no assertion is made.
    """
    import bellini
    from bellini import Quantity, Species, Substance, Story
    import pint

    ureg = pint.UnitRegistry()
    story = Story()
    water = Species(name='water')  # NOTE(review): constructed but never used

    def water_normal(label):
        # Normal(3.0 mol, 0.01 mol) with the given node name.
        return bellini.distributions.Normal(
            loc=Quantity(3.0, ureg.mole),
            scale=Quantity(0.01, ureg.mole),
            name=label,
        )

    story.one_water_quantity = water_normal("first_normal")
    story.another_water_quantity = water_normal("second_normal")

    story.combined_water = story.one_water_quantity + story.another_water_quantity
    story.combined_water.name = "combined_water"

    # NOTE(review): "nose" looks like a typo for "noise"; the attribute name is
    # kept as-is because Story may key graph nodes on it.
    story.combined_water_with_nose = bellini.distributions.Normal(
        loc=story.combined_water,
        scale=Quantity(0.01, ureg.mole),
        name="combined_with_noise",
    )
    story.combined_water_with_nose.observed = True

    from bellini.api._numpyro import graph_to_numpyro_model
    model = graph_to_numpyro_model(story.g)

    from numpyro.infer.util import initialize_model
    import jax
    model_info = initialize_model(
        jax.random.PRNGKey(2666),
        model,
    )

    from numpyro.infer.hmc import hmc
    from numpyro.util import fori_collect

    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    start_state = init_kernel(model_info.param_info, trajectory_length=10, num_warmup=300)

    def to_constrained(state):
        return model_info.postprocess_fn(state.z)

    samples = fori_collect(0, 500, sample_kernel, start_state, transform=to_constrained)
    print(samples)
def test_functional_map(algo, map_fn):
    """Run two independent HMC chains through ``map_fn`` (e.g. ``pmap``).

    The target is an isotropic Gaussian with mean 1.0 and std 2.0, defined
    through its potential. Each chain's sample mean and standard deviation
    must match the target within 6% relative tolerance. When ``map_fn`` is
    ``pmap`` and only one device is present, the test is skipped.
    """
    if map_fn is pmap and jax.device_count() == 1:
        pytest.skip("pmap test requires device_count greater than 1.")

    true_mean, true_std = 1.0, 2.0
    num_warmup, num_samples = 1000, 8000
    num_chains = 2

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)

    def init_chain(init_param, rng_key):
        return init_kernel(init_param, trajectory_length=9,
                           num_warmup=num_warmup, rng_key=rng_key)

    def collect_chain(hmc_state):
        return fori_collect(
            0,
            num_samples,
            sample_kernel,
            hmc_state,
            transform=lambda state: state.z,
            progbar=False,
        )

    # One starting point and one PRNG key per chain.
    starting_points = jnp.array([0.0, -1.0])
    chain_keys = random.split(random.PRNGKey(0), num_chains)

    init_states = map_fn(init_chain)(starting_points, chain_keys)
    chain_samples = map_fn(collect_chain)(init_states)

    assert_allclose(jnp.mean(chain_samples, axis=1),
                    jnp.repeat(true_mean, num_chains), rtol=0.06)
    assert_allclose(jnp.std(chain_samples, axis=1),
                    jnp.repeat(true_std, num_chains), rtol=0.06)
# For the full list of configuration options, see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, as shown here.
#
# sys.path.insert(0, os.path.abspath('../..'))

# Set before `hmc(None, None)` below runs.
os.environ['SPHINX_BUILD'] = '1'

# HACK: This is to ensure that local functions are documented by sphinx.
# NOTE(review): presumably `hmc` checks SPHINX_BUILD and, when called, exposes
# its locally defined kernels for autodoc; confirm against numpyro.infer.hmc.
from numpyro.infer.hmc import hmc  # noqa: E402
hmc(None, None)

# -- Project information -----------------------------------------------------

project = u'NumPyro'
copyright = u'2019, Uber Technologies, Inc'
author = u'Uber AI Labs'

version = ''

if 'READTHEDOCS' not in os.environ:
    # if developing locally, use numpyro.__version__ as version
    from numpyro import __version__  # noqa: E402
    version = __version__

# release version