# Common imports assumed by the snippets below (NumPyro/JAX public API;
# project-specific helpers such as `estimate_beliefs`, `logp_fn`, or the
# various `model` functions come from their respective codebases):
import json
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
from numpy.testing import assert_allclose

import arviz as az
import jax
import jax.numpy as jnp
from jax import random, vmap

import numpyro
import numpyro.distributions as dist
from numpyro.infer import (HMC, HMCECS, MCMC, NUTS, SA, SVI, Predictive,
                           Trace_ELBO, init_to_median, init_to_value)
from numpyro.infer.autoguide import AutoBNAFNormal
from numpyro.infer.reparam import NeuTraReparam
from numpyro.infer.util import log_density


def test_diverging(kernel_cls, adapt_step_size):
    data = random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        loc = numpyro.sample('loc', dist.Normal(0.0, 1.0))
        numpyro.sample('obs', dist.Normal(loc, 1), obs=data)

    kernel = kernel_cls(model, step_size=10.0, adapt_step_size=adapt_step_size,
                        adapt_mass_matrix=False)
    num_warmup = num_samples = 1000
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(1), data, extra_fields=['diverging'],
                collect_warmup=True)
    warmup_divergences = mcmc.get_extra_fields()['diverging'].sum()
    mcmc.run(random.PRNGKey(2), data, extra_fields=['diverging'])
    num_divergences = warmup_divergences + mcmc.get_extra_fields()['diverging'].sum()
    if adapt_step_size:
        assert num_divergences <= num_warmup
    else:
        assert_allclose(num_divergences, num_warmup + num_samples)
def _sample(current_state, seed):
    # per-parameter unit step sizes matching the shape of the initial state
    # (note: not passed to the kernel below)
    step_size = jax.tree_map(jax.numpy.ones_like, init_state)
    nuts_kernel = NUTS(
        potential_fn=lambda x: -logp_fn_jax(*x),
        # model=model,
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
    )
    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )
    pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
    samples = pmap_numpyro.get_samples(group_by_chain=True)
    leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
    return samples, leapfrogs_taken
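# The wrapper above closes over names from its enclosing scope (`init_state`,
# `logp_fn_jax`, `tune`, `draws`, `chains`, ...). A minimal self-contained
# sketch of the same potential_fn route, assuming a toy Gaussian target:
def toy_potential_fn_run():
    def logp_fn_jax(x):
        return -0.5 * jnp.sum(x ** 2)

    init_state = (jnp.zeros(3),)  # pytree of initial values, one leaf per variable
    kernel = NUTS(potential_fn=lambda state: -logp_fn_jax(*state))
    mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
    # with potential_fn (no model), init_params must be supplied to run()
    mcmc.run(random.PRNGKey(0), init_params=init_state,
             extra_fields=("num_steps",))
    return mcmc.get_samples()  # same tuple structure as init_state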
def main() -> None:
    # Data
    num = 8
    y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
    sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

    # Random key
    rng_key = random.PRNGKey(0)

    # Inference
    nuts_kernel = NUTS(model_noncentered)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
    mcmc.run(rng_key, num, sigma, y=y, extra_fields=("potential_energy",))
    mcmc.print_summary()  # prints the summary itself; returns None

    # Extra
    pe = mcmc.get_extra_fields()["potential_energy"]
    print(f"Expected log joint density: {np.mean(-pe):.2f}")

    # Prediction
    predictive = Predictive(model_pred, num_samples=100)
    samples = predictive(random.PRNGKey(1))
    print("prior", np.mean(samples["obs"]))
    predictive = Predictive(model_pred, mcmc.get_samples())
    samples = predictive(random.PRNGKey(1))
    print("posterior", np.mean(samples["obs"]))
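# `model_noncentered` and `model_pred` are not defined in the snippet above.
# A hypothetical reconstruction consistent with the call signature, using a
# non-centered parameterization of the eight-schools model; `model_pred` is
# assumed to be the same model with `num` and `sigma` bound (e.g. via
# functools.partial) so that `Predictive` can call it without arguments:
def model_noncentered(num, sigma, y=None):
    mu = numpyro.sample("mu", dist.Normal(0.0, 5.0))
    tau = numpyro.sample("tau", dist.HalfCauchy(5.0))
    with numpyro.plate("num", num):
        # non-centered: sample a standard normal, then shift and scale it
        theta_std = numpyro.sample("theta_std", dist.Normal(0.0, 1.0))
        theta = numpyro.deterministic("theta", mu + tau * theta_std)
        numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)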
def test_estimate_likelihood(kernel_cls):
    data_key, tr_key, sub_key, rng_key = random.split(random.PRNGKey(0), 4)
    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = 0.1
    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        data_key, (10_000,)
    )
    n, _ = data.shape
    num_warmup = 200
    num_samples = 200
    num_blocks = 20

    def model(data):
        mean = numpyro.sample(
            "mean", dist.Normal(ref_params, jnp.ones_like(ref_params))
        )
        with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-2) as idx:
            numpyro.sample("obs", dist.Normal(mean, sigma), obs=data[idx])

    proxy_fn = HMCECS.taylor_proxy({"mean": ref_params})
    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn, num_blocks=num_blocks)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)

    mcmc.run(random.PRNGKey(0), data, extra_fields=["hmc_state.potential_energy"])
    pes = mcmc.get_extra_fields()["hmc_state.potential_energy"]
    samples = mcmc.get_samples()
    pes_full = vmap(
        lambda sample: log_density(
            model, (data,), {}, {**sample, **{"N": jnp.arange(n)}}
        )[0]
    )(samples)
    # the variance of the (subsample / full-data) likelihood ratio should stay bounded
    assert jnp.var(jnp.exp(-pes - pes_full)) < 1.0
def test_extra_fields():
    def model():
        numpyro.sample('x', dist.Normal(0, 1), sample_shape=(5,))

    mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), extra_fields=('num_steps', 'adapt_state.step_size'))
    samples = mcmc.get_samples(group_by_chain=True)
    assert samples['x'].shape == (1, 1000, 5)
    stats = mcmc.get_extra_fields(group_by_chain=True)
    assert 'num_steps' in stats
    assert stats['num_steps'].shape == (1, 1000)
    assert 'adapt_state.step_size' in stats
    assert stats['adapt_state.step_size'].shape == (1, 1000)
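# The strings passed to `extra_fields` name fields of the sampler's state
# object (an `HMCState` namedtuple for HMC/NUTS); dotted names drill into
# nested state, e.g. "adapt_state.step_size". A small illustration reusing
# the model from the test above (`show_extra_fields` is a hypothetical name):
def show_extra_fields():
    def model():
        numpyro.sample('x', dist.Normal(0, 1), sample_shape=(5,))

    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    mcmc.run(random.PRNGKey(0),
             extra_fields=('potential_energy', 'energy', 'num_steps'))
    # each collected field has one entry per post-warmup sample
    print({k: v.shape for k, v in mcmc.get_extra_fields().items()})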
def _sample(*inputs):
    if op.nshared > 0:
        current_state = inputs[:-op.nshared]
        shared_inputs = tuple(op.fgraph.inputs[-op.nshared:])
    else:
        current_state = inputs
        shared_inputs = ()

    def log_fn_wrap(x):
        res = logp_fn(*(
            x
            # We manually obtain the shared values and add them
            # as arguments to our compiled "inner" function
            + tuple(
                v.get_value(borrow=True, return_internal_type=True)
                for v in shared_inputs)))
        if isinstance(res, (list, tuple)):
            # This handles the new JAX backend, which always returns a tuple
            res = res[0]
        return -res

    nuts_kernel = NUTS(
        potential_fn=log_fn_wrap,
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
    )
    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method="parallel",
        progress_bar=progress_bar,
    )
    pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
    samples = pmap_numpyro.get_samples(group_by_chain=True)
    leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
    return tuple(samples) + (leapfrogs_taken,)
def inference(belief_sequences, obs, mask, rng_key):
    nuts_kernel = NUTS(model, dense_mass=True)
    mcmc = MCMC(
        nuts_kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        chain_method="vectorized",
        progress_bar=False,
    )
    mcmc.run(
        rng_key,
        belief_sequences,
        obs,
        mask,
        extra_fields=('potential_energy',),
    )
    samples = mcmc.get_samples()
    potential_energy = mcmc.get_extra_fields()['potential_energy'].mean()
    # mcmc.print_summary()
    return samples, potential_energy
def main(m, seed):
    rng_key = random.PRNGKey(seed)
    cutoff_up = 800
    cutoff_down = 100

    if m <= 10:
        seq, _ = estimate_beliefs(outcomes_data, responses_data,
                                  mask=mask_data, nu_max=m)
    else:
        seq, _ = estimate_beliefs(outcomes_data, responses_data,
                                  mask=mask_data, nu_max=10, nu_min=m - 10)

    model = generative_model
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)

    seq = (seq['beliefs'][0][cutoff_down:cutoff_up],
           seq['beliefs'][1][cutoff_down:cutoff_up])

    rng_key, _rng_key = random.split(rng_key)
    mcmc.run(
        _rng_key,
        seq,
        y=responses_data[cutoff_down:cutoff_up],
        mask=mask_data[cutoff_down:cutoff_up].astype(bool),
        extra_fields=('potential_energy',),
    )
    samples = mcmc.get_samples()

    waic = log_pred_density(
        model,
        samples,
        seq,
        y=responses_data[cutoff_down:cutoff_up],
        mask=mask_data[cutoff_down:cutoff_up].astype(bool),
    )['waic']

    # jax.numpy has no savez; use NumPy to write the archive
    np.savez('fit_waic_sample/dyn_fit_waic_sample_minf{}.npz'.format(m),
             samples=samples, waic=waic)

    print(mcmc.get_extra_fields()['potential_energy'].mean())
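# Reading the archive back: np.savez stores the samples dict as a pickled
# 0-d object array, so loading needs allow_pickle=True (a sketch):
fit = np.load('fit_waic_sample/dyn_fit_waic_sample_minf{}.npz'.format(m),
              allow_pickle=True)
samples = fit['samples'].item()  # unwrap the dict from the object array
waic = fit['waic']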
def benchmark_hmc(args, features, labels):
    rng_key = random.PRNGKey(1)
    start = time.time()
    # a MAP estimate at the following source
    # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
    ref_params = {
        "coefs": jnp.array([
            +2.03420663e00, -3.53567265e-02, -1.49223924e-01, -3.07049364e-01,
            -1.00028366e-01, -1.46827862e-01, -1.64167881e-01, -4.20344204e-01,
            +9.47479829e-02, -1.12681836e-02, +2.64442056e-01, -1.22087866e-01,
            -6.00568838e-02, -3.79419506e-01, -1.06668741e-01, -2.97053963e-01,
            -2.05253899e-01, -4.69537191e-02, -2.78072730e-02, -1.43250525e-01,
            -6.77954629e-02, -4.34899796e-03, +5.90927452e-02, +7.23133609e-02,
            +1.38526391e-02, -1.24497898e-01, -1.50733739e-02, -2.68872194e-02,
            -1.80925727e-02, +3.47936489e-02, +4.03552800e-02, -9.98773426e-03,
            +6.20188080e-02, +1.15002751e-01, +1.32145107e-01, +2.69109547e-01,
            +2.45785132e-01, +1.19035013e-01, -2.59744357e-02, +9.94279515e-04,
            +3.39266285e-02, -1.44057125e-02, -6.95222765e-02, -7.52013028e-02,
            +1.21171586e-01, +2.29205526e-02, +1.47308692e-01, -8.34354162e-02,
            -9.34122875e-02, -2.97472421e-02, -3.03937674e-01, -1.70958012e-01,
            -1.59496680e-01, -1.88516974e-01, -1.20889175e00,
        ])
    }
    if args.algo == "HMC":
        step_size = jnp.sqrt(0.5 / features.shape[0])
        trajectory_length = step_size * args.num_steps
        kernel = HMC(
            model,
            step_size=step_size,
            trajectory_length=trajectory_length,
            adapt_step_size=False,
            dense_mass=args.dense_mass,
        )
        subsample_size = None
    elif args.algo == "NUTS":
        kernel = NUTS(model, dense_mass=args.dense_mass)
        subsample_size = None
    elif args.algo == "HMCECS":
        subsample_size = 1000
        inner_kernel = NUTS(
            model,
            init_strategy=init_to_value(values=ref_params),
            dense_mass=args.dense_mass,
        )
        # note: with num_blocks=100, only 10 of the 1000 subsample indices are
        # refreshed per MCMC step, so it takes ~50,000 steps to sweep the
        # whole dataset
        kernel = HMCECS(inner_kernel, num_blocks=100,
                        proxy=HMCECS.taylor_proxy(ref_params))
    elif args.algo == "SA":
        # NB: this kernel requires large num_warmup and num_samples,
        # and running on GPU is much faster than on CPU
        kernel = SA(model, adapt_state_size=1000,
                    init_strategy=init_to_value(values=ref_params))
        subsample_size = None
    elif args.algo == "FlowHMCECS":
        subsample_size = 1000
        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
        params, losses = svi_result.params, svi_result.losses
        plt.plot(losses)
        plt.show()

        neutra = NeuTraReparam(guide, params)
        neutra_model = neutra.reparam(model)
        neutra_ref_params = {"auto_shared_latent": jnp.zeros(55)}
        # no need to adapt the mass matrix if the flow does a good job
        inner_kernel = NUTS(
            neutra_model,
            init_strategy=init_to_value(values=neutra_ref_params),
            adapt_mass_matrix=False,
        )
        kernel = HMCECS(inner_kernel, num_blocks=100,
                        proxy=HMCECS.taylor_proxy(neutra_ref_params))
    else:
        raise ValueError(
            "Invalid algorithm; expected one of 'HMC', 'NUTS', 'HMCECS', "
            "'SA', or 'FlowHMCECS'.")

    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(rng_key, features, labels, subsample_size,
             extra_fields=("accept_prob",))
    print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
    mcmc.print_summary(exclude_deterministic=False)
    print("\nMCMC elapsed time:", time.time() - start)
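# The benchmark assumes a `model` compatible with
# `mcmc.run(rng_key, features, labels, subsample_size)`. A sketch in the
# style of NumPyro's covtype example (assumed here; the 55 coefficients in
# `ref_params` above correspond to 54 features plus an intercept column):
def model(data, labels, subsample_size=None):
    dim = data.shape[1]
    coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
    # subsampling plate: HMCECS re-draws `subsample_size` row indices in blocks
    with numpyro.plate("N", data.shape[0], subsample_size=subsample_size) as idx:
        logits = jnp.dot(data[idx], coefs)
        numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels[idx])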
def run_model(
    model_func,
    data,
    ep,
    num_samples=500,
    num_warmup=500,
    num_chains=4,
    target_accept=0.75,
    max_tree_depth=15,
    save_results=True,
    output_fname=None,
    model_kwargs=None,
    save_json=False,
    chain_method="parallel",
    heuristic_step_size=True,
):
    """
    Model run utility

    :param model_func: numpyro model
    :param data: PreprocessedData object
    :param ep: EpidemiologicalParameters object
    :param num_samples: number of samples
    :param num_warmup: number of warmup samples
    :param num_chains: number of chains
    :param target_accept: target acceptance probability
    :param max_tree_depth: maximum tree depth
    :param save_results: whether to save full results
    :param output_fname: output filename
    :param model_kwargs: model kwargs -- extra arguments for the model function
    :param save_json: whether to save json
    :param chain_method: Numpyro chain method to use
    :param heuristic_step_size: whether to find a heuristic step size
    :return: posterior_samples, warmup_samples, info_dict (dict with assorted
        diagnostics), Numpyro mcmc object
    """
    print(
        f"Running {num_chains} chains, {num_samples} samples per chain with {num_warmup} warmup steps"
    )

    nuts_kernel = NUTS(
        model_func,
        init_strategy=init_to_median,
        target_accept_prob=target_accept,
        max_tree_depth=max_tree_depth,
        find_heuristic_step_size=heuristic_step_size,
    )
    mcmc = MCMC(
        nuts_kernel,
        num_samples=num_samples,
        num_warmup=num_warmup,
        num_chains=num_chains,
        chain_method=chain_method,
    )
    rng_key = random.PRNGKey(0)

    # hmcstate = nuts_kernel.init(rng_key, 1, model_args=(data, ep))
    # nRVs = hmcstate.adapt_state.inverse_mass_matrix.size
    # inverse_mass_matrix = init_diag_inv_mass_mat * jnp.ones(nRVs)
    # mass_matrix_sqrt_inv = np.sqrt(inverse_mass_matrix)
    # mass_matrix_sqrt = 1. / mass_matrix_sqrt_inv
    # hmcstate = hmcstate._replace(adapt_state=hmcstate.adapt_state._replace(inverse_mass_matrix=inverse_mass_matrix))
    # hmcstate = hmcstate._replace(adapt_state=hmcstate.adapt_state._replace(mass_matrix_sqrt_inv=mass_matrix_sqrt_inv))
    # hmcstate = hmcstate._replace(adapt_state=hmcstate.adapt_state._replace(mass_matrix_sqrt=mass_matrix_sqrt))
    # mcmc.post_warmup_state = hmcstate

    info_dict = {
        "model_name": model_func.__name__,
    }

    start = time.time()

    if model_kwargs is None:
        model_kwargs = {}
    info_dict["model_kwargs"] = model_kwargs

    # also collect some extra information for better diagnostics!
    print(f"Warmup Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    mcmc.warmup(
        rng_key,
        data,
        ep,
        **model_kwargs,
        collect_warmup=True,
        extra_fields=["num_steps", "mean_accept_prob", "adapt_state"],
    )
    mcmc.get_extra_fields()["num_steps"].block_until_ready()

    info_dict["warmup"] = {}
    info_dict["warmup"]["num_steps"] = np.array(
        mcmc.get_extra_fields()["num_steps"]).tolist()
    info_dict["warmup"]["step_size"] = np.array(
        mcmc.get_extra_fields()["adapt_state"].step_size).tolist()
    info_dict["warmup"]["inverse_mass_matrix"] = {}

    all_mass_mats = jnp.array(
        jnp.array_split(
            mcmc.get_extra_fields()["adapt_state"].inverse_mass_matrix,
            num_chains,
            axis=0,
        ))
    print(all_mass_mats.shape)
    for i in range(num_chains):
        # store the final (post-warmup) inverse mass matrix of each chain
        info_dict["warmup"]["inverse_mass_matrix"][
            f"chain_{i}"] = all_mass_mats[i, -1, :].tolist()

    info_dict["warmup"]["mean_accept_prob"] = np.array(
        mcmc.get_extra_fields()["mean_accept_prob"]).tolist()

    warmup_samples = mcmc.get_samples()

    print(f"Sample Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    mcmc.run(
        rng_key,
        data,
        ep,
        **model_kwargs,
        # "diverging" is collected here because the divergence count is read below
        extra_fields=["num_steps", "mean_accept_prob", "adapt_state", "diverging"],
    )
    posterior_samples = mcmc.get_samples()
    # if you don't block here, the timer won't quite work properly
    posterior_samples[list(posterior_samples.keys())[0]].block_until_ready()
    print(f"Sample Finished: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    end = time.time()

    time_per_sample = float(end - start) / num_samples
    divergences = int(mcmc.get_extra_fields()["diverging"].sum())

    info_dict["time_per_sample"] = time_per_sample
    info_dict["total_runtime"] = float(end - start)
    info_dict["divergences"] = divergences

    info_dict["sample"] = {}
    info_dict["sample"]["num_steps"] = np.array(
        mcmc.get_extra_fields()["num_steps"]).tolist()
    info_dict["sample"]["mean_accept_prob"] = np.array(
        mcmc.get_extra_fields()["mean_accept_prob"]).tolist()
    info_dict["sample"]["step_size"] = np.array(
        mcmc.get_extra_fields()["adapt_state"].step_size).tolist()

    print(f"Sampling {num_samples} samples per chain took {end - start:.2f}s")
    print(f"There were {divergences} divergences.")

    grouped_posterior_samples = mcmc.get_samples(True)

    all_ess = np.array([])
    for k in grouped_posterior_samples.keys():
        ess = numpyro.diagnostics.effective_sample_size(
            np.asarray(grouped_posterior_samples[k]))
        all_ess = np.append(all_ess, ess)

    print(f"{np.sum(np.isnan(all_ess))} ESS were nan")
    all_ess = all_ess[np.logical_not(np.isnan(all_ess))]

    info_dict["ess"] = {
        "med": float(np.percentile(all_ess, 50)),
        "lower": float(np.percentile(all_ess, 2.5)),
        "upper": float(np.percentile(all_ess, 97.5)),
        "min": float(np.min(all_ess)),
        "max": float(np.max(all_ess)),
    }
    print(
        f"Mean ESS: {info_dict['ess']['med']:.2f} [{info_dict['ess']['lower']:.2f} ... {info_dict['ess']['upper']:.2f}]"
    )

    if num_chains > 1:
        all_rhat = np.array([])
        for k in grouped_posterior_samples.keys():
            rhat = numpyro.diagnostics.gelman_rubin(
                np.asarray(grouped_posterior_samples[k]))
            all_rhat = np.append(all_rhat, rhat)

        print(f"{np.sum(np.isnan(all_rhat))} Rhat were nan")
        all_rhat = all_rhat[np.logical_not(np.isnan(all_rhat))]

        info_dict["rhat"] = {
            "med": float(np.percentile(all_rhat, 50)),
            "upper": float(np.percentile(all_rhat, 97.5)),
            "lower": float(np.percentile(all_rhat, 2.5)),
            "min": float(np.min(all_rhat)),
            "max": float(np.max(all_rhat)),
        }
        print(
            f"Rhat: {info_dict['rhat']['med']:.2f} [{info_dict['rhat']['lower']:.2f} ... {info_dict['rhat']['upper']:.2f}]"
        )

    if save_results:
        print("Saving .netcdf")
        try:
            inf_data = az.from_numpyro(mcmc)
            if output_fname is None:
                output_fname = f'{model_func.__name__}-{datetime.now(tz=None).strftime("%d-%m;%H-%M-%S")}.netcdf'
            az.to_netcdf(inf_data, output_fname)
            json_fname = output_fname.replace(".netcdf", ".json")
            if save_json:
                print("Saving Json")
                with open(json_fname, "w") as f:
                    json.dump(info_dict, f, ensure_ascii=False, indent=4)
        except Exception as e:
            print(e)

    return posterior_samples, warmup_samples, info_dict, mcmc
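# A hypothetical driver for run_model; `my_model`, `data`, and `ep` are
# assumed to come from the surrounding project's model and preprocessing code:
posterior, warmup, info, mcmc = run_model(
    my_model,
    data,
    ep,
    num_samples=500,
    num_warmup=500,
    num_chains=4,
    save_results=False,
)
print(info["divergences"], info["ess"]["med"], info["total_runtime"])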