def test_binomial_stable(with_logits):
    """Check the posterior mean of ``p`` for a Binomial with a very large count.

    ``p`` has a ``Beta(1, 1)`` prior and the single observation is
    ``x = 3849`` successes out of ``n = 5000000``.  The likelihood is built
    from ``logit(p)`` when ``with_logits`` is truthy and from ``p`` directly
    otherwise.  After 200 warmup steps, 200 samples are collected and the mean
    of ``p`` must match ``x / n`` within a relative tolerance of 0.05.
    """
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    warmup_steps = 200
    num_samples = 200
    data = {'n': 5000000, 'x': 3849}

    def model(data):
        p = numpyro.sample('p', dist.Beta(1., 1.))
        if not with_logits:
            numpyro.sample('obs', dist.Binomial(data['n'], probs=p), obs=data['x'])
        else:
            logits = logit(p)
            numpyro.sample('obs', dist.Binomial(data['n'], logits=logits), obs=data['x'])

    init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, data)
    init_kernel, sample_kernel = hmc(potential_fn)
    hmc_state = init_kernel(init_params, num_warmup=warmup_steps)

    def collect_site_values(state):
        # Each collected item is `constrain_fn` applied to the state's `z` field.
        return constrain_fn(state.z)

    samples = fori_collect(0, num_samples, sample_kernel, hmc_state,
                           transform=collect_site_values)

    p_samples = samples['p']
    expected_p = data['x'] / data['n']
    assert_allclose(np.mean(p_samples, 0), expected_p, rtol=0.05)

    # The dtype check only runs when this environment variable is present.
    if 'JAX_ENABLE_x64' in os.environ:
        assert p_samples.dtype == np.float64
def test_change_point():
    """Check that the most frequent sampled change-point index is 44.

    The model draws two Poisson rates (``lambda1``, ``lambda2``) from an
    Exponential prior and a switch fraction ``tau`` from ``Uniform(0, 1)``;
    observations whose index is below ``tau * len(data)`` use ``lambda1`` and
    the rest use ``lambda2``.  After 500 warmup steps, 3000 samples are
    collected, ``tau`` is scaled to an integer index, and the mode of those
    indices is asserted to equal 44.
    """
    # Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    warmup_steps = 500
    num_samples = 3000

    def model(data):
        alpha = 1 / np.mean(data)
        lambda1 = sample('lambda1', dist.Exponential(alpha))
        lambda2 = sample('lambda2', dist.Exponential(alpha))
        tau = sample('tau', dist.Uniform(0, 1))
        # Positions before the switch point take lambda1, the others lambda2.
        before_switch = np.arange(len(data)) < tau * len(data)
        lambda12 = np.where(before_switch, lambda1, lambda2)
        sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = np.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11,
        22, 22, 11, 57, 11, 19, 29, 6, 19, 12,
        22, 12, 18, 72, 32, 9, 7, 13, 19, 23,
        27, 20, 6, 17, 13, 10, 14, 6, 16, 15,
        7, 2, 15, 15, 19, 70, 49, 7, 53, 22,
        21, 31, 19, 11, 18, 20, 12, 35, 17, 23,
        17, 4, 2, 31, 30, 13, 27, 0, 39, 37,
        5, 14, 13, 22,
    ])

    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(4), model, count_data)
    init_kernel, sample_kernel = hmc(potential_fn)
    hmc_state = init_kernel(init_params, num_warmup=warmup_steps)

    def collect_site_values(state):
        return constrain_fn(state.z)

    samples = fori_collect(num_samples, sample_kernel, hmc_state,
                           transform=collect_site_values)

    # Convert the fractional switch point into an integer index into the data.
    tau_posterior = (samples['tau'] * len(count_data)).astype("int")
    tau_values, counts = onp.unique(tau_posterior, return_counts=True)
    mode = tau_values[np.argmax(counts)]
    assert mode == 44

    # The dtype checks only run when this environment variable is present.
    if 'JAX_ENABLE_x64' in os.environ:
        for site in ('lambda1', 'lambda2', 'tau'):
            assert samples[site].dtype == np.float64
def test_beta_bernoulli(algo):
    """Check the posterior mean of two Bernoulli probabilities under a Beta prior.

    Data of shape ``(1000, 2)`` are drawn from ``Bernoulli([0.9, 0.1])``.  The
    kernel is built with ``hmc(potential_fn, algo=algo)``, initialised with a
    trajectory length of 1 and 500 warmup steps, and 20000 samples are
    collected.  The mean of ``p_latent`` must be within 0.05 (absolute) of the
    true probabilities.
    """
    warmup_steps = 500
    num_samples = 20000
    true_probs = np.array([0.9, 0.1])

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        prior = dist.Beta(alpha, beta)
        p_latent = numpyro.sample('p_latent', prior)
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, data)
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params,
                            trajectory_length=1.,
                            num_warmup=warmup_steps,
                            progbar=False)

    def collect_site_values(state):
        return constrain_fn(state.z)

    samples = fori_collect(0, num_samples, sample_kernel, hmc_state,
                           transform=collect_site_values,
                           progbar=False)

    p_samples = samples['p_latent']
    assert_allclose(np.mean(p_samples, 0), true_probs, atol=0.05)

    # The dtype check only runs when this environment variable is present.
    if 'JAX_ENABLE_x64' in os.environ:
        assert p_samples.dtype == np.float64
def test_dirichlet_categorical(algo, dense_mass):
    """Check the posterior mean of Categorical probabilities under a Dirichlet prior.

    2000 observations are drawn from ``Categorical([0.1, 0.6, 0.3])``.  The
    kernel is built with ``hmc(potential_fn, algo=algo)`` and initialised with
    a trajectory length of 1, 100 warmup steps and the given ``dense_mass``
    setting; 20000 samples are then collected.  The mean of ``p_latent`` must
    be within 0.02 (absolute) of the true probabilities.
    """
    warmup_steps = 100
    num_samples = 20000
    true_probs = np.array([0.1, 0.6, 0.3])

    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        prior = dist.Dirichlet(concentration)
        p_latent = sample('p_latent', prior)
        sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, data)
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params,
                            trajectory_length=1.,
                            num_warmup=warmup_steps,
                            progbar=False,
                            dense_mass=dense_mass)

    def collect_site_values(state):
        return constrain_fn(state.z)

    hmc_states = fori_collect(num_samples, sample_kernel, hmc_state,
                              transform=collect_site_values,
                              progbar=False)

    p_samples = hmc_states['p_latent']
    assert_allclose(np.mean(p_samples, 0), true_probs, atol=0.02)

    # The dtype check only runs when this environment variable is present.
    if 'JAX_ENABLE_x64' in os.environ:
        assert p_samples.dtype == np.float64
def run_inference(dept, male, applications, admit, rng, args):
    """Sample from the ``glmm`` model with the NUTS algorithm.

    :param dept, male, applications, admit: forwarded unchanged to
        ``initialize_model`` as the arguments of ``glmm``.
    :param rng: random key passed to ``initialize_model``.
    :param args: object providing ``num_warmup_steps`` and ``num_samples``.
    :return: the value returned by ``fori_collect``, where each collected item
        is ``constrain_fn`` applied to the sampler state's ``z`` field.
    """
    init_params, potential_fn, constrain_fn = initialize_model(
        rng, glmm, dept, male, applications, admit)
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    start_state = init_kernel(init_params, args.num_warmup_steps)

    def collect_site_values(hmc_state):
        return constrain_fn(hmc_state.z)

    return fori_collect(args.num_samples, sample_kernel, start_state,
                        transform=collect_site_values)
def main(args):
    """Fit ``model`` to the SP500 returns with NUTS and print the results.

    :param args: object providing ``device`` (JAX platform name), ``rng``
        (seed for the random key), ``num_warmup`` and ``num_samples``.
    """
    jax_config.update('jax_platform_name', args.device)

    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()

    # One key initialises the model, the other is handed to the sampler.
    root_key = random.PRNGKey(args.rng)
    init_rng, sample_rng = random.split(root_key)

    init_params, potential_fn, constrain_fn = initialize_model(init_rng, model, returns)
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    start_state = init_kernel(init_params, args.num_warmup, rng=sample_rng)

    def collect_site_values(hmc_state):
        return constrain_fn(hmc_state.z)

    hmc_states = fori_collect(0, args.num_samples, sample_kernel, start_state,
                              transform=collect_site_values)
    print_results(hmc_states, dates)
def run_inference(transition_prior, emission_prior,
                  supervised_categories, supervised_words,
                  unsupervised_words, rng, args):
    """Sample from the ``semi_supervised_hmm`` model with the NUTS algorithm.

    :param transition_prior, emission_prior, supervised_categories,
        supervised_words, unsupervised_words: forwarded unchanged to
        ``initialize_model`` as the arguments of ``semi_supervised_hmm``.
    :param rng: random key passed to ``initialize_model``.
    :param args: object providing ``num_warmup`` and ``num_samples``.
    :return: the value returned by ``fori_collect``, where each collected item
        is ``constrain_fn`` applied to the sampler state's ``z`` field.
    """
    model_args = (
        transition_prior,
        emission_prior,
        supervised_categories,
        supervised_words,
        unsupervised_words,
    )
    init_params, potential_fn, constrain_fn = initialize_model(
        rng, semi_supervised_hmm, *model_args)
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    start_state = init_kernel(init_params, args.num_warmup)

    def collect_site_values(state):
        return constrain_fn(state.z)

    return fori_collect(args.num_samples, sample_kernel, start_state,
                        transform=collect_site_values)
def test_unnormalized_normal(algo):
    """Check sample mean and std for a quadratic potential centred at 1 with scale 2.

    The potential is ``0.5 * sum(((z - 1) / 2) ** 2)``.  The kernel is built
    with ``hmc(potential_fn, algo=algo)`` and initialised at 0 with a
    trajectory length of 9 and 1000 warmup steps; 8000 samples of ``z`` are
    collected.  Their mean and standard deviation must match 1 and 2 within a
    relative tolerance of 0.05.
    """
    true_mean = 1.
    true_std = 2.
    warmup_steps = 1000
    num_samples = 8000

    def potential_fn(z):
        standardized = (z - true_mean) / true_std
        return 0.5 * np.sum(standardized**2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = np.array(0.)
    hmc_state = init_kernel(init_params, trajectory_length=9, num_warmup=warmup_steps)

    def position(state):
        return state.z

    hmc_states = fori_collect(num_samples, sample_kernel, hmc_state, transform=position)

    assert_allclose(np.mean(hmc_states), true_mean, rtol=0.05)
    assert_allclose(np.std(hmc_states), true_std, rtol=0.05)
def benchmark_hmc(args, features, labels):
    """Time kernel initialisation and sample collection, printing the results.

    :param args: object providing ``num_steps``, ``algo`` and ``num_samples``.
    :param features, labels: forwarded to ``initialize_model`` as the
        arguments of ``model``.

    NOTE(review): ``step_size`` and ``init_params`` are not defined in this
    function; presumably they are module-level names - confirm in the
    enclosing file.
    """
    trajectory_length = step_size * args.num_steps
    # Only the potential function is kept from the initialisation result.
    _, potential_fn, _ = initialize_model(random.PRNGKey(1), model, features, labels)
    init_kernel, sample_kernel = hmc(potential_fn, algo=args.algo)

    t0 = time.time()
    # TODO: Use init_params from `initialize_model` instead of fixed params.
    init_result = init_kernel(init_params,
                              num_warmup=0,
                              step_size=step_size,
                              trajectory_length=trajectory_length,
                              adapt_step_size=False,
                              run_warmup=False)
    # The call yields three values here; only the first is used.
    hmc_state, _, _ = init_result
    t1 = time.time()
    print("time for hmc_init: ", t1 - t0)

    def transform(state):
        collected = {'coefs': state.z['coefs']}
        collected['num_steps'] = state.num_steps
        return collected

    hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state, transform=transform)

    num_leapfrogs = np.sum(hmc_states['num_steps'])
    print('number of leapfrog steps: ', num_leapfrogs)
    # Elapsed time since t1 divided by the summed per-sample step counts.
    print('avg. time for each step: ', (time.time() - t1) / num_leapfrogs)
def test_map(algo, map_fn):
    """Run two chains through ``map_fn`` and check each chain's mean and std.

    The potential is ``0.5 * sum(((z - 1) / 2) ** 2)``.  Both kernel
    initialisation (start points 0 and -1, trajectory length 9, 1000 warmup
    steps, one random key per chain) and collection of 8000 samples of ``z``
    are wrapped with ``map_fn``.  Along axis 1, the mean and standard deviation
    must match 1 and 2 within a relative tolerance of 0.05.  The test is
    skipped when ``map_fn`` is ``pmap`` and only one device is reported.
    """
    if map_fn is pmap and xla_bridge.device_count() == 1:
        pytest.skip('pmap test requires device_count greater than 1.')

    num_chains = 2
    true_mean = 1.
    true_std = 2.
    warmup_steps = 1000
    num_samples = 8000

    def potential_fn(z):
        standardized = (z - true_mean) / true_std
        return 0.5 * np.sum(standardized**2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = np.array([0., -1.])
    rngs = random.split(random.PRNGKey(0), num_chains)

    def init_chain(init_param, rng):
        return init_kernel(init_param,
                           trajectory_length=9,
                           num_warmup=warmup_steps,
                           progbar=False,
                           rng=rng)

    init_states = map_fn(init_chain)(init_params, rngs)

    def collect_chain(hmc_state):
        return fori_collect(0, num_samples, sample_kernel, hmc_state,
                            transform=lambda x: x.z,
                            progbar=False)

    chain_samples = map_fn(collect_chain)(init_states)

    assert_allclose(np.mean(chain_samples, axis=1), np.repeat(true_mean, num_chains), rtol=0.05)
    assert_allclose(np.std(chain_samples, axis=1), np.repeat(true_std, num_chains), rtol=0.05)
# http://www.sphinx-doc.org/en/master/config # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # sys.path.insert(0, os.path.abspath('../..')) # HACK: This is to ensure that local functions are documented by sphinx. from numpyro.mcmc import hmc # noqa: E402 from numpyro.svi import svi # noqa: E402 os.environ['SPHINX_BUILD'] = '1' hmc(None, None) svi(None, None, None, None, None, None) # -- Project information ----------------------------------------------------- project = u'Numpyro' copyright = u'2019, Uber Technologies, Inc' author = u'Uber AI Labs' # The short X.Y version version = u'0.0' # The full version, including alpha/beta/rc tags release = u'0.0' # -- General configuration ---------------------------------------------------