Example #1
def pyro_centered_schools(data, draws, chains):
    """Centered eight schools implementation in Pyro.

    Note that Pyro does not really have a deterministic node, so it is not
    clear how to write a non-centered implementation.
    """
    import torch
    from pyro.infer.mcmc import MCMC, NUTS

    del chains  # the chains argument is unused in this implementation
    y = torch.Tensor(data["y"]).type(torch.Tensor)
    sigma = torch.Tensor(data["sigma"]).type(torch.Tensor)

    # _pyro_centered_model and _pyro_conditioned_model are module-level helpers
    # assumed to be defined alongside this function: the centered eight-schools
    # model and a wrapper that conditions it on the observed y.
    nuts_kernel = NUTS(_pyro_conditioned_model, adapt_step_size=True)
    posterior = MCMC(  # pylint:disable=not-callable
        nuts_kernel, num_samples=draws, warmup_steps=500
    ).run(_pyro_centered_model, sigma, y)

    # This block lets the posterior be pickled
    for trace in posterior.exec_traces:
        for node in trace.nodes.values():
            node.pop("fn", None)
    posterior.kernel = None
    posterior.run = None
    posterior.logger = None
    if hasattr(posterior, "sampler"):
        posterior.sampler = None
    return posterior
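
The attribute stripping above exists precisely so the returned object survives pickling. A minimal usage sketch follows; the eight-schools numbers are the standard Rubin dataset and are only illustrative, and the output path is hypothetical.

import pickle

import numpy as np

# Illustrative eight-schools data; any dict with equal-length "y" and "sigma"
# arrays works.
data = {
    "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
    "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
}
posterior = pyro_centered_schools(data, draws=500, chains=1)

# Pickling works only because the unpicklable attributes were removed above.
with open("centered_schools.pkl", "wb") as f:  # hypothetical output path
    pickle.dump(posterior, f, pickle.HIGHEST_PROTOCOL)
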
Example #2
import pickle

import torch
from pyro.infer.mcmc import MCMC, NUTS
from pyro.infer.mcmc import util  # assumed: util.initialize_model matches pyro.infer.mcmc.util

# init_guide and make_transformed_pe are project-level helpers assumed to be
# defined elsewhere (guide construction/loading and a transformed
# potential-energy wrapper); they are not part of Pyro itself.


def infer_sample(
    cond_model,
    n_steps,
    warmup_steps,
    n_chains=1,
    device="cpu",
    guidefile=None,
    guide_conf=None,
    mcmcfile=None,
):
    """Runs the NUTS HMC algorithm.

    Saves the samples and weights as well as a netcdf file for the run.

    Parameters
    ----------
    args : dict
        Command line arguments.
    cond_model : callable
        Model conditioned on an observed images.
    """
    initial_params, potential_fn, transforms, prototype_trace = util.initialize_model(
        cond_model)

    if guidefile is not None:
        guide = init_guide(cond_model,
                           guide_conf,
                           guidefile=guidefile,
                           device=device)
        sample = guide()
        for key in initial_params.keys():
            initial_params[key] = transforms[key](sample[key].detach())

    # FIXME: In the case of DiagonalNormal, results have to be mapped back onto unpacked latents
    if guide_conf["type"] == "DiagonalNormal":
        transform = guide.get_transform()
        unpack_fn = lambda u: guide.unpack_latent(u)
        potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn)
        initial_params = {"z": torch.zeros(guide.get_posterior().shape())}
        transforms = None

    # Thin passthrough wrapper around the potential function; it behaves
    # identically to potential_fn but provides a named hook (e.g. for debugging).
    def fun(*args, **kwargs):
        res = potential_fn(*args, **kwargs)
        return res

    nuts_kernel = NUTS(
        potential_fn=fun,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        full_mass=False,
        use_multinomial_sampling=True,
        jit_compile=False,
        max_tree_depth=10,
        transforms=transforms,
        step_size=1.0,
    )
    nuts_kernel.initial_params = initial_params

    # Run
    mcmc = MCMC(
        nuts_kernel,
        n_steps,
        warmup_steps=warmup_steps,
        initial_params=initial_params,
        num_chains=n_chains,
    )
    mcmc.run()

    # This block lets the posterior be pickled
    mcmc.sampler = None
    mcmc.kernel.potential_fn = None
    mcmc._cache = {}

    print(f"Saving MCMC object to {mcmcfile}")
    with open(mcmcfile, "wb") as f:
        pickle.dump(mcmc, f, pickle.HIGHEST_PROTOCOL)
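
Because the sampler and the kernel's potential_fn are stripped before saving, the pickled run cannot be resumed, but its stored draws remain accessible. A minimal reload sketch, assuming Pyro's MCMC.get_samples() API and a hypothetical file path:

import pickle

with open("mcmc_run.pkl", "rb") as f:  # stands in for the mcmcfile path
    mcmc = pickle.load(f)

samples = mcmc.get_samples()  # dict mapping site names to tensors of draws
for name, values in samples.items():
    print(name, tuple(values.shape))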