import os
import time
from typing import Callable, Dict

import matplotlib.pyplot as plt
from numpy.testing import assert_allclose

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import (
    HMC,
    HMCECS,
    MCMC,
    NUTS,
    SA,
    SVI,
    BarkerMH,
    Trace_ELBO,
    init_to_value,
)
from numpyro.infer.autoguide import AutoBNAFNormal
from numpyro.infer.reparam import NeuTraReparam


def run_inference(model, at_bats, hits, rng_key, args):
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    elif args.algo == "SA":
        kernel = SA(model)
    else:
        raise ValueError(f"Unknown algorithm: {args.algo}")
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=not ("NUMPYRO_SPHINXBUILD" in os.environ or args.disable_progbar),
    )
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
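# The function above expects an `args` namespace with fields `algo`,
# `num_warmup`, `num_samples`, `num_chains`, and `disable_progbar`. A
# hypothetical argparse wiring (an assumption; the actual CLI setup is
# not shown in this excerpt) could look like:
import argparse

parser = argparse.ArgumentParser(description="Baseball batting-average inference")
parser.add_argument("--algo", default="NUTS", choices=["NUTS", "HMC", "SA"])
parser.add_argument("--num-warmup", type=int, default=1500)
parser.add_argument("--num-samples", type=int, default=3000)
parser.add_argument("--num-chains", type=int, default=1)
parser.add_argument("--disable-progbar", action="store_true")
args = parser.parse_args()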
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    true_mean, true_std = 1.0, 0.5
    warmup_steps, num_samples = (100000, 100000) if kernel_cls is SA else (1000, 8000)

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)

    init_params = jnp.array(0.0)
    if kernel_cls is SA:
        kernel = SA(potential_fn=potential_fn, dense_mass=dense_mass)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(potential_fn=potential_fn, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(
            potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass
        )
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    hmc_states = mcmc.get_samples()
    assert_allclose(jnp.mean(hmc_states), true_mean, rtol=0.07)
    assert_allclose(jnp.std(hmc_states), true_std, rtol=0.07)
    if "JAX_ENABLE_X64" in os.environ:
        assert hmc_states.dtype == jnp.float64
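# The kernel_cls / dense_mass arguments suggest the tests below are
# parametrized with pytest. A plausible decorator stack (an assumption;
# the actual markers are not shown in this excerpt) would be:
import pytest


@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH])
@pytest.mark.parametrize("dense_mass", [False, True])
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    ...  # body as above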
def test_logistic_regression_x64(kernel_cls):
    N, dim = 3000, 3
    if kernel_cls is SA:
        num_warmup, num_samples = (100000, 100000)
    elif kernel_cls is BarkerMH:
        num_warmup, num_samples = (2000, 12000)
    else:
        num_warmup, num_samples = (1000, 8000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    if kernel_cls is SA:
        kernel = SA(model=model, adapt_state_size=9)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model)
    else:
        kernel = kernel_cls(
            model=model, trajectory_length=8, find_heuristic_step_size=True
        )
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples["logits"].shape == (num_samples, N)
    # these coefficients are found by MAP inference using AutoDelta
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples["coefs"], 0), expected_coefs, atol=0.1)
    if "JAX_ENABLE_X64" in os.environ:
        assert samples["coefs"].dtype == jnp.float64
def run_inference(
    model: Callable,
    at_bats: jnp.ndarray,
    hits: jnp.ndarray,
    rng_key: jnp.ndarray,
    *,
    num_warmup: int = 1500,
    num_samples: int = 3000,
    num_chains: int = 1,
    algo_name: str = "NUTS",
) -> Dict[str, jnp.ndarray]:
    if algo_name == "NUTS":
        kernel = NUTS(model)
    elif algo_name == "HMC":
        kernel = HMC(model)
    elif algo_name == "SA":
        kernel = SA(model)
    else:
        raise ValueError(f"Unknown algorithm name: {algo_name!r}")
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains)
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
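# A minimal call site for the typed helper above. The model here is a
# hypothetical Beta-Binomial batting-average model, included only to show
# the API in use; it is not the model from the source.
def batting_model(at_bats, hits=None):
    # one success probability per player, with a flat Beta prior
    with numpyro.plate("players", at_bats.shape[0]):
        p = numpyro.sample("p", dist.Beta(1.0, 1.0))
        numpyro.sample("hits", dist.Binomial(total_count=at_bats, probs=p), obs=hits)


at_bats = jnp.array([45.0, 50.0, 60.0])
hits = jnp.array([18.0, 17.0, 25.0])
samples = run_inference(
    batting_model, at_bats, hits, random.PRNGKey(0), num_samples=1000, algo_name="NUTS"
)
print({name: val.shape for name, val in samples.items()})  # e.g. {"p": (1000, 3)}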
def test_beta_bernoulli_x64(kernel_cls):
    warmup_steps, num_samples = (100000, 100000) if kernel_cls is SA else (500, 20000)

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample("p_latent", dist.Beta(alpha, beta))
        numpyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    if kernel_cls is SA:
        kernel = SA(model=model)
    else:
        kernel = kernel_cls(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p_latent"], 0), true_probs, atol=0.05)
    if "JAX_ENABLE_X64" in os.environ:
        assert samples["p_latent"].dtype == jnp.float64
def benchmark_hmc(args, features, labels):
    rng_key = random.PRNGKey(1)
    start = time.time()
    # a MAP estimate from the following source:
    # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
    ref_params = {
        "coefs": jnp.array([
            +2.03420663e00, -3.53567265e-02, -1.49223924e-01, -3.07049364e-01,
            -1.00028366e-01, -1.46827862e-01, -1.64167881e-01, -4.20344204e-01,
            +9.47479829e-02, -1.12681836e-02, +2.64442056e-01, -1.22087866e-01,
            -6.00568838e-02, -3.79419506e-01, -1.06668741e-01, -2.97053963e-01,
            -2.05253899e-01, -4.69537191e-02, -2.78072730e-02, -1.43250525e-01,
            -6.77954629e-02, -4.34899796e-03, +5.90927452e-02, +7.23133609e-02,
            +1.38526391e-02, -1.24497898e-01, -1.50733739e-02, -2.68872194e-02,
            -1.80925727e-02, +3.47936489e-02, +4.03552800e-02, -9.98773426e-03,
            +6.20188080e-02, +1.15002751e-01, +1.32145107e-01, +2.69109547e-01,
            +2.45785132e-01, +1.19035013e-01, -2.59744357e-02, +9.94279515e-04,
            +3.39266285e-02, -1.44057125e-02, -6.95222765e-02, -7.52013028e-02,
            +1.21171586e-01, +2.29205526e-02, +1.47308692e-01, -8.34354162e-02,
            -9.34122875e-02, -2.97472421e-02, -3.03937674e-01, -1.70958012e-01,
            -1.59496680e-01, -1.88516974e-01, -1.20889175e00,
        ])
    }
    if args.algo == "HMC":
        step_size = jnp.sqrt(0.5 / features.shape[0])
        trajectory_length = step_size * args.num_steps
        kernel = HMC(
            model,
            step_size=step_size,
            trajectory_length=trajectory_length,
            adapt_step_size=False,
            dense_mass=args.dense_mass,
        )
        subsample_size = None
    elif args.algo == "NUTS":
        kernel = NUTS(model, dense_mass=args.dense_mass)
        subsample_size = None
    elif args.algo == "HMCECS":
        subsample_size = 1000
        inner_kernel = NUTS(
            model,
            init_strategy=init_to_value(values=ref_params),
            dense_mass=args.dense_mass,
        )
        # note: with num_blocks=100, we update 10 indices at each MCMC step,
        # so it takes 50,000 MCMC steps to iterate over the whole dataset
        kernel = HMCECS(
            inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(ref_params)
        )
    elif args.algo == "SA":
        # NB: this kernel requires a large num_warmup and num_samples,
        # and running on GPU is much faster than on CPU
        kernel = SA(
            model, adapt_state_size=1000, init_strategy=init_to_value(values=ref_params)
        )
        subsample_size = None
    elif args.algo == "FlowHMCECS":
        subsample_size = 1000
        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
        params, losses = svi_result.params, svi_result.losses
        plt.plot(losses)
        plt.show()
        neutra = NeuTraReparam(guide, params)
        neutra_model = neutra.reparam(model)
        neutra_ref_params = {"auto_shared_latent": jnp.zeros(55)}
        # no need to adapt the mass matrix if the flow does a good job
        inner_kernel = NUTS(
            neutra_model,
            init_strategy=init_to_value(values=neutra_ref_params),
            adapt_mass_matrix=False,
        )
        kernel = HMCECS(
            inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(neutra_ref_params)
        )
    else:
        raise ValueError(
            "Invalid algorithm; choose one of 'HMC', 'NUTS', 'HMCECS', 'SA', or 'FlowHMCECS'."
        )
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",))
    print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
    mcmc.print_summary(exclude_deterministic=False)
    print("\nMCMC elapsed time:", time.time() - start)
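# benchmark_hmc assumes a `model(features, labels, subsample_size)` is in
# scope. A minimal sketch consistent with the subsampled HMCECS calls above;
# the names and priors here are assumptions, not the source's definition.
def model(features, labels, subsample_size=None):
    # logistic regression with a standard-normal prior on the coefficients
    dim = features.shape[1]
    coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
    if subsample_size is None:
        logits = jnp.dot(features, coefs)
        numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)
    else:
        # HMCECS needs an explicit subsample plate over the data axis so the
        # likelihood is estimated from minibatches of the full dataset
        with numpyro.plate("N", features.shape[0], subsample_size=subsample_size):
            batch_feats = numpyro.subsample(features, event_dim=1)
            batch_labels = numpyro.subsample(labels, event_dim=0)
            logits = jnp.dot(batch_feats, coefs)
            numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=batch_labels)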