def test_functional_beta_bernoulli_x64(algo):
    """Check that the functional `hmc` API recovers Bernoulli probabilities.

    A two-component Beta(1.1, 1.1) prior is placed on the success
    probability and conditioned on 1000 Bernoulli draws per component,
    generated with true probabilities ``[0.9, 0.1]``.  After 500 warmup
    steps, 20000 samples are collected and their mean must match the true
    probabilities within an absolute tolerance of 0.05.

    When ``JAX_ENABLE_X64`` is present in the environment, the collected
    samples are additionally required to be ``float64``.

    :param algo: value forwarded as the ``algo`` argument of `hmc`.
    """
    n_warmup = 500
    n_draws = 20000

    def model(data):
        concentration1 = np.array([1.1, 1.1])
        concentration0 = np.array([1.1, 1.1])
        prior = dist.Beta(concentration1, concentration0)
        p_latent = numpyro.sample('p_latent', prior)
        likelihood = dist.Bernoulli(p_latent)
        numpyro.sample('obs', likelihood, obs=data)
        return p_latent

    # Synthetic observations drawn from the known ground truth.
    true_probs = np.array([0.9, 0.1])
    data_key = random.PRNGKey(1)
    data = dist.Bernoulli(true_probs).sample(data_key, (1000, 2))

    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, model_args=(data,))
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    initial_state = init_kernel(
        init_params,
        trajectory_length=1.,
        num_warmup=n_warmup,
    )

    def to_constrained(state):
        # Map the sampler's `z` field through `constrain_fn` before storing.
        return constrain_fn(state.z)

    samples = fori_collect(
        0, n_draws, sample_kernel, initial_state, transform=to_constrained)

    posterior_mean = np.mean(samples['p_latent'], 0)
    assert_allclose(posterior_mean, true_probs, atol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
def test_reuse_mcmc_pe_gen():
    """Check that one kernel pair can be reused across different datasets.

    The model is initialized once with ``dynamic_args=True`` and the
    resulting potential-function generator is handed to `hmc` via
    ``potential_fn_gen``.  The same ``init_kernel`` / jitted sampling step
    is then run first on data centred at 3 and afterwards on data centred
    at -3; in both cases the mean of the collected ``mu`` samples must be
    within 0.1 of the respective centre.
    """
    # Both datasets are drawn up front, in this order, from numpy's global RNG.
    y1 = onp.random.normal(3, 0.1, (100, ))
    y2 = onp.random.normal(-3, 0.1, (100, ))

    def model(y_obs):
        mu = numpyro.sample('mu', dist.Normal(0., 1.))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(0), model, y1, dynamic_args=True)
    init_kernel, sample_kernel = hmc(potential_fn_gen=potential_fn)

    @jit
    def _sample(state_and_args):
        # The loop carry is (hmc_state, data); the data is passed through
        # unchanged so every step sees the same model arguments.
        hmc_state, model_args = state_and_args
        next_state = sample_kernel(hmc_state, (model_args, ))
        return next_state, model_args

    def collect_for(y_obs):
        # Warm up on `y_obs`, then collect 500 constrained samples.
        start_state = init_kernel(
            init_params, num_warmup=300, model_args=(y_obs, ))
        return fori_collect(
            0, 500, _sample, (start_state, y_obs),
            transform=lambda carry: constrain_fn(y_obs)(carry[0].z))

    first_run = collect_for(y1)
    assert_allclose(first_run['mu'].mean(), 3., atol=0.1)

    # Run on new data, re-using the same kernels - this should be much faster.
    second_run = collect_for(y2)
    assert_allclose(second_run['mu'].mean(), -3., atol=0.1)
def main(args):
    """Fit the stochastic volatility model to S&P 500 returns and plot it.

    Loads the ``SP500`` dataset, runs the NUTS variant of `hmc` on `model`,
    prints a summary through `print_results`, and writes a plot of the
    returns overlaid with the sampled volatility to
    ``stochastic_volatility_plot.pdf`` in the current working directory.

    :param args: object providing the attributes ``rng_seed``,
        ``num_warmup`` and ``num_samples``.
    """
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
    model_info = initialize_model(init_rng_key, model, model_args=(returns, ))
    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info, args.num_warmup,
                            rng_key=sample_rng_key)
    # NOTE(review): the collection window starts at `num_warmup`, presumably
    # so that warmup iterations are not stored - confirm against
    # `fori_collect`.
    hmc_states = fori_collect(
        args.num_warmup, args.num_warmup + args.num_samples,
        sample_kernel, hmc_state,
        transform=lambda hmc_state: model_info.postprocess_fn(hmc_state.z),
        # The progress bar is switched off when NUMPYRO_SPHINXBUILD is set.
        progbar="NUMPYRO_SPHINXBUILD" not in os.environ)
    print_results(hmc_states, dates)

    fig, ax = plt.subplots(1, 1)
    dates = mdates.num2date(mdates.datestr2num(dates))
    ax.plot(dates, returns, lw=0.5)
    # format the ticks
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax.xaxis.set_minor_locator(mdates.MonthLocator())

    ax.plot(dates, jnp.exp(hmc_states['s'].T), 'r', alpha=0.01)
    legend = ax.legend(['returns', 'volatility'], loc='upper right')
    # `legendHandles` was renamed to `legend_handles` in Matplotlib 3.7;
    # prefer the new attribute and fall back to the old one.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    # The volatility lines are drawn almost transparent; make the legend
    # entry readable.
    handles[1].set_alpha(0.6)
    ax.set(xlabel='time', ylabel='returns',
           title='Volatility of S&P500 over time')

    # Adjust the layout *before* saving so the written file benefits from it.
    fig.tight_layout()
    fig.savefig("stochastic_volatility_plot.pdf")
def main(args):
    """Sample the stochastic volatility model on S&P 500 returns with NUTS.

    Loads the ``SP500`` dataset, initializes `model` on the returns, runs
    ``args.num_warmup`` warmup steps, collects ``args.num_samples`` samples
    (passed through ``constrain_fn``) and hands them, together with the
    dates, to `print_results`.

    :param args: object providing the attributes ``rng_seed``,
        ``num_warmup`` and ``num_samples``.
    """
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()

    # One key for model initialization, one for the sampler itself.
    base_key = random.PRNGKey(args.rng_seed)
    init_rng_key, sample_rng_key = random.split(base_key)

    init_params, potential_fn, constrain_fn = initialize_model(
        init_rng_key, model, returns)
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    start_state = init_kernel(
        init_params, args.num_warmup, rng_key=sample_rng_key)

    def to_constrained(state):
        # Store the constrained version of the sampler's `z` field.
        return constrain_fn(state.z)

    collected = fori_collect(
        0, args.num_samples, sample_kernel, start_state,
        transform=to_constrained)
    print_results(collected, dates)
def test_functional_map(algo, map_fn):
    """Check that kernel initialization and sampling work under `map_fn`.

    Two chains, started from 0 and -1, target the potential
    ``0.5 * sum(((z - 1) / 2) ** 2)``.  Both the initialization and the
    sample collection are wrapped in ``map_fn``; the per-chain mean and
    standard deviation of the collected samples must match 1 and 2 within
    a relative tolerance of 0.06.

    The test is skipped when ``map_fn`` is `pmap` and only one device is
    available.

    :param algo: value forwarded as the ``algo`` argument of `hmc`.
    :param map_fn: mapping transformation applied to the per-chain functions.
    """
    single_device = xla_bridge.device_count() == 1
    if map_fn is pmap and single_device:
        pytest.skip('pmap test requires device_count greater than 1.')

    true_mean = 1.
    true_std = 2.
    warmup_steps = 1000
    num_samples = 8000

    def potential_fn(z):
        standardized = (z - true_mean) / true_std
        return 0.5 * np.sum(standardized**2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)

    # One initial position and one PRNG key per chain.
    init_params = np.array([0., -1.])
    rng_keys = random.split(random.PRNGKey(0), 2)

    def init_chain(init_param, rng_key):
        return init_kernel(init_param, trajectory_length=9,
                           num_warmup=warmup_steps, rng_key=rng_key)

    def collect_chain(hmc_state):
        return fori_collect(0, num_samples, sample_kernel, hmc_state,
                            transform=lambda x: x.z, progbar=False)

    init_states = map_fn(init_chain)(init_params, rng_keys)
    chain_samples = map_fn(collect_chain)(init_states)

    expected_means = np.repeat(true_mean, 2)
    expected_stds = np.repeat(true_std, 2)
    assert_allclose(np.mean(chain_samples, axis=1), expected_means, rtol=0.06)
    assert_allclose(np.std(chain_samples, axis=1), expected_stds, rtol=0.06)
# full list see the documentation: # http://www.sphinx-doc.org/en/master/config # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # sys.path.insert(0, os.path.abspath('../..')) os.environ['SPHINX_BUILD'] = '1' # HACK: This is to ensure that local functions are documented by sphinx. from numpyro.infer.mcmc import hmc # noqa: E402 hmc(None, None) # -- Project information ----------------------------------------------------- project = u'NumPyro' copyright = u'2019, Uber Technologies, Inc' author = u'Uber AI Labs' version = '' if 'READTHEDOCS' not in os.environ: # if developing locally, use pyro.__version__ as version from numpyro import __version__ # noqaE402 version = __version__ # release version