Exemplo n.º 1
0
def test_gamma_poisson(hyperpriors):
    """Run NUTS on the collapsed Gamma-Poisson model and check the posterior rate.

    Poisson counts are drawn from a known two-component rate, NUTS is run on
    ``collapse_conjugate(model)``, and the samples are passed through
    ``posterior_replay``. The test passes when the mean of the replayed
    ``"rate"`` site matches the true rate within ``prec=0.3``.

    Args:
        hyperpriors: When truthy, ``alpha`` and ``beta`` are sampled from
            ``HalfCauchy(1.0)`` inside the model; otherwise both are fixed
            to ``[1.0, 1.0]``.
    """

    def model(data):
        # One latent rate per column of ``data``.
        with pyro.plate("latent_dim", data.shape[1]):
            if hyperpriors:
                alpha = pyro.sample("alpha", dist.HalfCauchy(1.0))
                beta = pyro.sample("beta", dist.HalfCauchy(1.0))
            else:
                alpha = torch.tensor([1.0, 1.0])
                beta = torch.tensor([1.0, 1.0])

            pair = GammaPoissonPair()
            rate = pyro.sample("rate", pair.latent(alpha, beta))

            # One observation per row of ``data``.
            with pyro.plate("data", data.shape[0]):
                pyro.sample("obs", pair.conditional(rate), obs=data)

    true_rate = torch.tensor([3.0, 10.0])
    num_samples = 100

    # Synthetic observations: 100 draws for each of the two rates.
    likelihood = dist.Poisson(rate=true_rate)
    data = likelihood.sample(sample_shape=torch.Size((100,)))

    kernel = NUTS(
        collapse_conjugate(model),
        jit_compile=True,
        ignore_jit_warnings=True,
    )
    mcmc = MCMC(kernel, num_samples=num_samples, warmup_steps=50)
    mcmc.run(data)

    # Replay the original (uncollapsed) model against the MCMC samples.
    posterior = posterior_replay(
        model, mcmc.get_samples(), data, num_samples=num_samples
    )
    rate_mean = posterior["rate"].mean(0)
    assert_equal(rate_mean, true_rate, prec=0.3)
Exemplo n.º 2
0
def test_beta_binomial(hyperpriors):
    """Run NUTS on the collapsed Beta-Binomial model and check the posterior probs.

    Binomial counts are drawn from known 2x2 success probabilities and trial
    counts, NUTS is run on ``collapse_conjugate(model)``, and the samples are
    passed through ``posterior_replay``. The test passes when the mean of the
    replayed ``"probs"`` site matches the true probabilities within
    ``prec=0.05``.

    Args:
        hyperpriors: When truthy, ``alpha`` and ``beta`` are sampled from
            ``HalfCauchy(1.)`` inside the model; otherwise both are fixed
            to ``[1., 1.]``.
    """
    # Ground truth used to simulate the data; ``total_count`` is also read
    # by the nested model through its closure.
    true_probs = torch.tensor([[0.7, 0.4], [0.6, 0.4]])
    total_count = torch.tensor([[1000, 600], [400, 800]])
    num_samples = 80

    def model(data):
        # Plates are sized from the last, second-to-last, and first
        # dimensions of ``data``, in that nesting order.
        with pyro.plate("plate_0", data.shape[-1]):
            if hyperpriors:
                alpha = pyro.sample("alpha", dist.HalfCauchy(1.))
                beta = pyro.sample("beta", dist.HalfCauchy(1.))
            else:
                alpha = torch.tensor([1., 1.])
                beta = torch.tensor([1., 1.])

            pair = BetaBinomialPair()
            with pyro.plate("plate_1", data.shape[-2]):
                probs = pyro.sample("probs", pair.latent(alpha, beta))
                with pyro.plate("data", data.shape[0]):
                    likelihood = pair.conditional(
                        probs=probs, total_count=total_count
                    )
                    pyro.sample("binomial", likelihood, obs=data)

    # Synthetic observations: 10 draws for each cell of the 2x2 grid.
    generator = dist.Binomial(total_count=total_count, probs=true_probs)
    data = generator.sample(sample_shape=torch.Size((10,)))

    kernel = NUTS(
        collapse_conjugate(model),
        jit_compile=True,
        ignore_jit_warnings=True,
    )
    mcmc = MCMC(kernel, num_samples=num_samples, warmup_steps=50)
    mcmc.run(data)

    # Replay the original (uncollapsed) model against the MCMC samples.
    posterior = posterior_replay(
        model, mcmc.get_samples(), data, num_samples=num_samples
    )
    probs_mean = posterior["probs"].mean(0)
    assert_equal(probs_mean, true_probs, prec=0.05)