def main(args):
    """Sample ``conditioned_model`` with NUTS and print a posterior summary.

    Args:
        args: Parsed command-line namespace; this function reads ``jit``,
            ``num_samples``, ``warmup_steps`` and ``num_chains`` from it.

    The module-level ``model`` and ``data`` objects are used as the sampler's
    inputs; ``data`` presumably exposes ``sigma`` and ``y`` tensors — confirm
    against where it is defined.
    """
    kernel = NUTS(conditioned_model, jit_compile=args.jit)
    sampler = MCMC(
        kernel,
        num_samples=args.num_samples,
        warmup_steps=args.warmup_steps,
        num_chains=args.num_chains,
    )
    # MCMC.run forwards these positional arguments to ``conditioned_model``,
    # so the unconditioned ``model`` itself is handed over as its first arg.
    sampler.run(model, data.sigma, data.y)
    # Summary statistics use a 50% credible interval.
    sampler.summary(prob=0.5)
def mcmc_solver(idx, *, num_samples=100, warmup_steps=100, num_chains=1, prob=0.8):
    """Run NUTS on ``model`` using transitions generated from ``data_input[idx]``.

    For every ``(p, v, a)`` triple in ``data_input[idx]`` the module-level
    ``step`` function is called on ``(p, v)`` with action ``a``; the second
    element of its result is recorded as ``v_next``. The rows
    ``[p, v, a, v_next]`` are stacked into one tensor and passed to
    ``model`` through ``MCMC.run``.

    Args:
        idx: Index into the module-level ``data_input`` collection.
        num_samples: Posterior samples drawn per chain (default 100).
        warmup_steps: Warmup/adaptation steps before sampling (default 100).
        num_chains: Number of MCMC chains (default 1).
        prob: Credible-interval probability used by the printed summary
            (default 0.8).

    Returns:
        The posterior samples dict from ``MCMC.get_samples()``. Returning a
        plain dict of tensors (rather than the MCMC object) keeps the result
        picklable, e.g. if this function is mapped over indices in a pool.
    """
    rows = []
    for p, v, a in data_input[idx]:
        # Only the second output of ``step`` is kept as the observed target.
        _, v_next = step((torch.tensor(p), torch.tensor(v)), torch.tensor(a))
        rows.append([p, v, a, v_next])

    nuts_kernel = NUTS(model, jit_compile=False)
    mcmc = MCMC(
        nuts_kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        num_chains=num_chains,
    )
    # NOTE(review): assumes ``v_next`` is a scalar tensor so the rows stack
    # into a 2-D tensor — confirm against ``step``.
    observations = torch.tensor(rows)
    mcmc.run(observations)
    mcmc.summary(prob=prob)
    return mcmc.get_samples()
def test_save_params(save_params, Kernel, options):
    """MCMC must keep exactly the sites listed in ``save_params``.

    Both ``get_samples()`` and the per-site entries of ``diagnostics()``
    are checked against ``save_params``; ``summary()`` is only smoke-tested.

    Args:
        save_params: Iterable of latent site names to keep (it is converted
            with ``list``, so a string like ``"xy"`` means ``["x", "y"]``).
        Kernel: MCMC kernel class to instantiate on the toy model.
        options: Keyword arguments forwarded to ``Kernel``.
    """
    save_params = list(save_params)

    def model():
        x = pyro.sample("x", dist.Normal(0, 1))
        with pyro.plate("plate", 2):
            y = pyro.sample("y", dist.Normal(x, 1))
            pyro.sample("obs", dist.Normal(y, 1), obs=torch.zeros(2))

    kernel = Kernel(model, **options)
    mcmc = MCMC(kernel, warmup_steps=2, num_samples=4, save_params=save_params)
    mcmc.run()

    samples = mcmc.get_samples()
    assert set(samples.keys()) == set(save_params)

    # Keep only the latent-site entries; diagnostics() may also contain
    # non-site keys. Use real set membership — ``k in "xy"`` would be a
    # substring test that also accepts "" and "xy".
    latent_sites = {"x", "y"}
    diagnostics = {k: v for k, v in mcmc.diagnostics().items() if k in latent_sites}
    assert set(diagnostics.keys()) == set(save_params)

    mcmc.summary()  # smoke test