コード例 #1
0
def main(args):
    """Run NUTS over ``conditioned_model`` and print a posterior summary.

    ``args`` must provide ``jit``, ``num_samples``, ``warmup_steps`` and
    ``num_chains`` attributes (presumably an argparse namespace -- confirm
    against the caller). ``conditioned_model``, ``model`` and ``data`` are
    module-level names defined outside this function.
    """
    kernel = NUTS(conditioned_model, jit_compile=args.jit)
    sampler_options = dict(
        num_samples=args.num_samples,
        warmup_steps=args.warmup_steps,
        num_chains=args.num_chains,
    )
    sampler = MCMC(kernel, **sampler_options)
    # MCMC.run forwards these positional arguments to conditioned_model:
    # the unconditioned model followed by the sigma and y data.
    sampler.run(model, data.sigma, data.y)
    # Summarize with intervals covering 50% of the probability mass.
    sampler.summary(prob=0.5)
コード例 #2
0
def mcmc_solver(idx, *, num_samples=100, warmup_steps=100, prob=0.8):
    """Fit ``model`` with NUTS to the transitions in ``data_input[idx]``.

    Each ``(p, v, a)`` triple from ``data_input[idx]`` is passed through
    ``step`` (as tensors); the second value ``step`` returns is kept as
    ``v_next``. The rows ``[p, v, a, v_next]`` are stacked into one tensor,
    which is the single argument given to ``model`` via ``mcmc.run``.

    Args:
        idx: Index into the module-level ``data_input`` collection.
        num_samples: Posterior samples to draw (keyword-only; default 100).
        warmup_steps: Warm-up steps before sampling (keyword-only; default 100).
        prob: Interval probability mass for ``mcmc.summary``
            (keyword-only; default 0.8).

    Returns:
        The posterior samples from ``mcmc.get_samples()``, so callers can use
        the fitted posterior instead of only seeing the printed summary.
    """
    rows = []
    for p, v, a in data_input[idx]:
        # NOTE(review): given its name, the second output of step is presumably
        # the next velocity -- confirm against step's definition.
        _, v_next = step((torch.tensor(p), torch.tensor(v)), torch.tensor(a))
        rows.append([p, v, a, v_next])

    nuts_kernel = NUTS(model, jit_compile=False)
    mcmc = MCMC(
        nuts_kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        num_chains=1,
    )

    # NOTE(review): torch.tensor() on rows that mix Python values with the
    # v_next tensors only works if each v_next holds a single element --
    # confirm step's output shape.
    observations = torch.tensor(rows)
    mcmc.run(observations)
    mcmc.summary(prob=prob)
    return mcmc.get_samples()
コード例 #3
0
def test_save_params(save_params, Kernel, options):
    """Check that ``save_params`` limits both stored samples and diagnostics.

    ``save_params`` (an iterable of latent site names), ``Kernel`` and
    ``options`` come from the test's parametrization, which is not shown here.
    """
    save_params = list(save_params)

    def model():
        x = pyro.sample("x", dist.Normal(0, 1))
        with pyro.plate("plate", 2):
            y = pyro.sample("y", dist.Normal(x, 1))
            pyro.sample("obs", dist.Normal(y, 1), obs=torch.zeros(2))

    kernel = Kernel(model, **options)
    mcmc = MCMC(kernel, warmup_steps=2, num_samples=4, save_params=save_params)
    mcmc.run()

    # Only the requested latent sites should have stored samples.
    samples = mcmc.get_samples()
    assert set(samples) == set(save_params)

    # Keep just the latent sites "x" and "y" from the diagnostics. This uses
    # real set membership; a substring test against "xy" would also accept
    # "" and "xy".
    latent_sites = {"x", "y"}
    diagnostics = {
        k: v for k, v in mcmc.diagnostics().items() if k in latent_sites
    }
    assert set(diagnostics) == set(save_params)

    mcmc.summary()  # smoke test