def pyro_centered_schools(data, draws, chains):
    """Fit the centered eight-schools model with Pyro's NUTS sampler.

    Pyro has no real deterministic node, so only the centered
    parameterization is implemented here (no non-centered variant).

    Parameters
    ----------
    data : mapping
        Must provide ``"y"`` and ``"sigma"`` entries that ``torch.Tensor``
        can convert.
    draws : int
        Number of posterior samples, forwarded to ``MCMC`` as ``num_samples``.
    chains : int
        Accepted for interface compatibility only; it is not forwarded to
        ``MCMC``.

    Returns
    -------
    The object returned by ``MCMC(...).run(...)``, with its traces' ``"fn"``
    entries and its ``kernel``/``run``/``logger`` (and ``sampler``, if
    present) attributes cleared so it can be pickled.
    """
    import torch
    from pyro.infer.mcmc import MCMC, NUTS

    del chains  # unused: not forwarded to MCMC

    observed = torch.Tensor(data["y"]).type(torch.Tensor)
    scales = torch.Tensor(data["sigma"]).type(torch.Tensor)

    sampler = MCMC(  # pylint:disable=not-callable
        NUTS(_pyro_conditioned_model, adapt_step_size=True),
        num_samples=draws,
        warmup_steps=500,
    )
    posterior = sampler.run(_pyro_centered_model, scales, observed)

    # Strip callables and handles so the posterior object can be pickled.
    for exec_trace in posterior.exec_traces:
        for site in exec_trace.nodes.values():
            site.pop("fn", None)
    for attr in ("kernel", "run", "logger"):
        setattr(posterior, attr, None)
    if hasattr(posterior, "sampler"):
        posterior.sampler = None
    return posterior
def infer_sample(
    cond_model,
    n_steps,
    warmup_steps,
    n_chains=1,
    device="cpu",
    guidefile=None,
    guide_conf=None,
    mcmcfile=None,
):
    """Run the NUTS HMC algorithm on a conditioned model.

    Optionally initializes the chain(s) from a previously fitted guide. The
    finished ``MCMC`` object is stripped of unpicklable members, pickled to
    ``mcmcfile`` when a path is given, and returned.

    Parameters
    ----------
    cond_model : callable
        Model conditioned on the observed images.
    n_steps : int
        Number of samples to draw, passed positionally to ``MCMC``.
    warmup_steps : int
        Number of warm-up (adaptation) steps, forwarded to ``MCMC``.
    n_chains : int, optional
        Number of chains, forwarded to ``MCMC`` as ``num_chains``.
    device : str, optional
        Device forwarded to ``init_guide`` when a guide is loaded.
    guidefile : path-like, optional
        If given, a guide is loaded with ``init_guide``. A sample drawn from
        it provides the initial parameters.
    guide_conf : mapping, optional
        Guide configuration forwarded to ``init_guide``. Its ``"type"`` entry
        is read, so it is required whenever ``guidefile`` is given.
    mcmcfile : path-like, optional
        Destination for the pickled ``MCMC`` object. If ``None``, nothing is
        written.

    Returns
    -------
    MCMC
        The finished sampler. Its ``sampler``, ``kernel.potential_fn`` and
        ``_cache`` members are cleared.
    """
    initial_params, potential_fn, transforms, _ = util.initialize_model(cond_model)

    if guidefile is not None:
        guide = init_guide(cond_model, guide_conf, guidefile=guidefile, device=device)
        sample = guide()
        # Initialize each site from the guide sample, mapped through the
        # per-site transform returned by initialize_model.
        for key in initial_params:
            initial_params[key] = transforms[key](sample[key].detach())

        # FIXME: In the case of DiagonalNormal, results have to be mapped back onto unpacked latents
        if guide_conf["type"] == "DiagonalNormal":
            # Sample a single packed latent "z" in the guide's transformed
            # space. The wrapped potential unpacks it onto the model's sites,
            # which is why the per-site transforms are dropped here.
            transform = guide.get_transform()
            potential_fn = make_transformed_pe(
                potential_fn, transform, lambda u: guide.unpack_latent(u)
            )
            initial_params = {"z": torch.zeros(guide.get_posterior().shape())}
            transforms = None

    nuts_kernel = NUTS(
        potential_fn=potential_fn,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        full_mass=False,
        use_multinomial_sampling=True,
        jit_compile=False,
        max_tree_depth=10,
        transforms=transforms,
        step_size=1.0,
    )
    nuts_kernel.initial_params = initial_params

    # Run
    mcmc = MCMC(
        nuts_kernel,
        n_steps,
        warmup_steps=warmup_steps,
        initial_params=initial_params,
        num_chains=n_chains,
    )
    mcmc.run()

    # This block lets the posterior be pickled
    mcmc.sampler = None
    mcmc.kernel.potential_fn = None
    mcmc._cache = {}

    # Without a destination path, skip saving instead of failing in open()
    # after the (potentially long) sampling run.
    if mcmcfile is not None:
        print(f"Saving MCMC object to {mcmcfile}")
        with open(mcmcfile, "wb") as f:
            pickle.dump(mcmc, f, pickle.HIGHEST_PROTOCOL)
    return mcmc