Exemplo n.º 1
0
def pyro_noncentered_schools(data, draws, chains):
    """Sample the non-centered eight schools model with Pyro's NUTS.

    Parameters
    ----------
    data
        Mapping providing ``"y"`` and ``"sigma"`` (objects accepted by
        ``torch.from_numpy``) and ``"J"``, which is passed to the model as-is.
    draws
        Used both as the number of posterior samples and as the number of
        warmup steps.
    chains
        Number of chains handed to ``MCMC``.

    Returns
    -------
    The ``MCMC`` object after ``run`` has completed, with its ``sampler``
    attribute cleared.
    """
    import torch
    from pyro.infer import MCMC, NUTS

    # Convert the observations to float tensors, "y" first and then "sigma".
    y, sigma = (
        torch.from_numpy(data[key]).float() for key in ("y", "sigma")
    )

    mcmc = MCMC(
        NUTS(_pyro_noncentered_model),
        num_samples=draws,
        warmup_steps=draws,
        num_chains=chains,
    )
    mcmc.run(data["J"], sigma, y)

    # Clearing the sampler is what lets the returned object be pickled.
    mcmc.sampler = None
    return mcmc
Exemplo n.º 2
0
def sample_model(chat,
                 mhat,
                 varpihat,
                 sigmac,
                 sigmam,
                 sigmavarpi,
                 dustco_c,
                 dustco_m,
                 theta_0_mcmc,
                 nsamples=100,
                 nwalkers=1):
    """Return an MCMC run for the model, reusing a pickled run when one exists.

    An ``Objective`` is built from the first eight arguments. The function
    then tries to unpickle a previous run from ``savemcmc_<ind>.pkl``; if that
    file cannot be opened or read (``IOError``), it runs NUTS on
    ``objective.logjoint`` and pickles the result to the same path.

    NOTE(review): ``ind`` is neither a parameter nor a local of this
    function, so it must be resolvable from an enclosing/module scope --
    confirm where it is defined.

    Parameters
    ----------
    chat, mhat, varpihat, sigmac, sigmam, sigmavarpi, dustco_c, dustco_m
        Forwarded positionally, in this order, to ``Objective``.
    theta_0_mcmc
        Passed to ``MCMC`` as ``initial_params``.
    nsamples
        Passed to ``MCMC`` as ``num_samples``.
    nwalkers
        Passed to ``MCMC`` as ``num_chains``.

    Returns
    -------
    tuple
        ``(mcmc, objective)``. When the run was freshly sampled, the returned
        ``mcmc`` has already had ``sampler`` and ``kernel.potential_fn`` set
        to ``None`` (done before pickling).
    """
    objective = Objective(chat, mhat, varpihat, sigmac, sigmam, sigmavarpi,
                          dustco_c, dustco_m)
    # The log-joint is evaluated once up front and the value discarded.
    # NOTE(review): presumably a smoke test or warm-up -- confirm before removing.
    objective.logjoint()

    cache_path = 'savemcmc_{}.pkl'.format(ind)

    try:
        # Unpickling runs arbitrary code: only trusted cache files belong here.
        with open(cache_path, 'rb') as cache_file:
            mcmc = pickle.load(cache_file)
    except IOError:
        # No readable cache: sample from scratch.
        kernel = NUTS(objective.logjoint,
                      jit_compile=True,
                      ignore_jit_warnings=False)
        run_options = {
            'num_samples': nsamples,
            'warmup_steps': 100,
            'num_chains': nwalkers,
            'initial_params': theta_0_mcmc,
            'mp_context': 'spawn',
        }
        mcmc = MCMC(kernel, **run_options)
        mcmc.run()

        with open(cache_path, 'wb') as cache_file:
            # These two attributes are cleared before dumping (the sampler and
            # the kernel's potential function), then the run is written out.
            mcmc.sampler = None
            mcmc.kernel.potential_fn = None
            pickle.dump(mcmc, cache_file)
    return mcmc, objective