def test_posterior_predictive():
    """Posterior predictive check via MCMC/NUTS.

    Draws Binomial(1000, 0.7) data, conditions the model on it, runs NUTS,
    and verifies that the posterior predictive mean of the model's return
    value is near the true expected count (1000 * 0.7 = 700 per dimension),
    within an absolute tolerance of 30.
    """
    success_prob = torch.ones(5) * 0.7
    trials = torch.ones(5) * 1000
    observed_counts = dist.Binomial(trials, success_prob).sample()
    # Clamp the model's "obs" site to the simulated data.
    model_given_data = poutine.condition(model, data={"obs": observed_counts})
    kernel = NUTS(model_given_data, adapt_step_size=True)
    posterior = MCMC(kernel, num_samples=1000, warmup_steps=200).run(trials)
    # Resample from the model using the posterior traces.
    predictive = TracePredictive(model, posterior, num_samples=10000).run(trials)
    predicted = predictive.marginal().empirical["_RETURN"]
    assert_equal(predicted.mean, torch.ones(5) * 700, prec=30)
def test_posterior_predictive_svi_auto_diag_normal_guide():
    """Posterior predictive check via SVI with an AutoDiagonalNormal guide.

    Same Binomial(1000, 0.7) setup as the MCMC variant, but the posterior is
    approximated with stochastic variational inference. The predictive mean
    of the model's return value should be within 5% of 700 per dimension.
    """
    success_prob = torch.ones(5) * 0.7
    trials = torch.ones(5) * 1000
    observed_counts = dist.Binomial(trials, success_prob).sample()
    # Clamp the model's "obs" site to the simulated data.
    model_given_data = poutine.condition(model, data={"obs": observed_counts})
    optimizer = optim.Adam(dict(lr=0.1))
    elbo = Trace_ELBO()
    guide = AutoDiagonalNormal(model_given_data)
    posterior = SVI(
        model_given_data, guide, optimizer, elbo, num_steps=1000, num_samples=100
    ).run(trials)
    # Resample from the model using the variational posterior.
    predictive = TracePredictive(model, posterior, num_samples=10000).run(trials)
    predicted = predictive.marginal().empirical["_RETURN"]
    assert_close(predicted.mean, torch.ones(5) * 700, rtol=0.05)
def test_posterior_predictive_svi_one_hot():
    """Posterior predictive check for a OneHotCategorical model via SVI.

    Samples 10000 one-hot class labels with true probabilities
    [0.15, 0.6, 0.25], fits an AutoDelta (MAP) guide, and verifies the
    predictive mean recovers the class probabilities within 10% relative
    tolerance.
    """
    prior_counts = torch.ones(3) * 0.1
    class_probs = torch.tensor([0.15, 0.6, 0.25])
    labels = dist.OneHotCategorical(class_probs).sample((10000, ))
    optimizer = optim.Adam(dict(lr=0.1))
    elbo = Trace_ELBO()
    guide = AutoDelta(one_hot_model)
    posterior = SVI(
        one_hot_model, guide, optimizer, elbo, num_steps=1000, num_samples=1000
    ).run(prior_counts, classes=labels)
    # Resample from the model using the fitted MAP point estimate.
    predictive = TracePredictive(one_hot_model, posterior, num_samples=10000).run(prior_counts)
    predicted = predictive.marginal().empirical["_RETURN"]
    assert_close(predicted.mean, class_probs.unsqueeze(0), rtol=0.1)