コード例 #1
0
def test_posterior_predictive():
    """Check the NUTS-based posterior predictive against the data-generating rate.

    Binomial success counts are drawn with a known per-trial success
    probability. ``model`` is conditioned on those counts at the ``"obs"``
    site and sampled with NUTS. A ``TracePredictive`` is then built from the
    MCMC run. The test asserts that the mean of the empirical ``"_RETURN"``
    marginal is close to the expected count ``num_trials * true_probs``.
    """
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    nuts_kernel = NUTS(conditioned_model, adapt_step_size=True)
    mcmc_run = MCMC(nuts_kernel, num_samples=1000,
                    warmup_steps=200).run(num_trials)
    # The predictive uses the *unconditioned* model, so "obs" is re-sampled
    # under the posterior rather than clamped to the observed data.
    posterior_predictive = TracePredictive(model, mcmc_run,
                                           num_samples=10000).run(num_trials)
    marginal_return_vals = posterior_predictive.marginal().empirical["_RETURN"]
    # Derive the expected mean from the inputs instead of hard-coding 700,
    # so the assertion stays correct if either input is changed.
    expected_mean = num_trials * true_probs
    assert_equal(marginal_return_vals.mean, expected_mean, prec=30)
コード例 #2
0
def test_posterior_predictive_svi_auto_diag_normal_guide():
    """Check the SVI (AutoDiagonalNormal) posterior predictive against the true rate.

    Binomial success counts are drawn with a known per-trial success
    probability. ``model`` is conditioned on them at the ``"obs"`` site and
    fitted with SVI: an ``AutoDiagonalNormal`` guide, Adam (lr=0.1) and
    ``Trace_ELBO``. The test asserts that the posterior predictive mean of the
    ``"_RETURN"`` site is within 5% (relative) of ``num_trials * true_probs``.
    """
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    opt = optim.Adam(dict(lr=0.1))
    loss = Trace_ELBO()
    guide = AutoDiagonalNormal(conditioned_model)
    svi_run = SVI(conditioned_model,
                  guide,
                  opt,
                  loss,
                  num_steps=1000,
                  num_samples=100).run(num_trials)
    # The predictive uses the *unconditioned* model, so "obs" is re-sampled
    # under the fitted posterior rather than clamped to the observed data.
    posterior_predictive = TracePredictive(model, svi_run,
                                           num_samples=10000).run(num_trials)
    marginal_return_vals = posterior_predictive.marginal().empirical["_RETURN"]
    # Derive the expected mean from the inputs instead of hard-coding 700,
    # so the assertion stays correct if either input is changed.
    expected_mean = num_trials * true_probs
    assert_close(marginal_return_vals.mean, expected_mean, rtol=0.05)
コード例 #3
0
def test_posterior_predictive_svi_one_hot():
    """Check the posterior predictive of an SVI-fitted one-hot model against class frequencies.

    Draws 10000 one-hot class observations from a known probability vector.
    ``one_hot_model`` is fitted to them with an ``AutoDelta`` guide, Adam
    (lr=0.1) and ``Trace_ELBO``. The test then checks that the posterior
    predictive mean of the ``"_RETURN"`` site is within 10% (relative) of the
    true probabilities. Those probabilities are reshaped to ``(1, 3)`` to
    match the predictive output.
    """
    prior_counts = torch.ones(3) * 0.1
    class_probs = torch.tensor([0.15, 0.6, 0.25])
    observed_classes = dist.OneHotCategorical(class_probs).sample((10000, ))

    # Inference components, created in the same order as before.
    adam = optim.Adam({"lr": 0.1})
    elbo = Trace_ELBO()
    delta_guide = AutoDelta(one_hot_model)

    svi = SVI(one_hot_model, delta_guide, adam, elbo,
              num_steps=1000, num_samples=1000)
    svi_posterior = svi.run(prior_counts, classes=observed_classes)

    # Predict without passing ``classes``, so observations are re-sampled.
    predictive = TracePredictive(one_hot_model, svi_posterior,
                                 num_samples=10000)
    predictive_run = predictive.run(prior_counts)
    return_marginal = predictive_run.marginal().empirical["_RETURN"]

    expected = class_probs.unsqueeze(0)
    assert_close(return_marginal.mean, expected, rtol=0.1)