Example #1
def test_independent_joint_shapes_and_samples(dists):
    """Test return shapes and validity of samples by comparing to samples generated
    from underlying distributions individually."""

    # Fix the seed for reseeding within this test.
    seed = 0

    joint = MultipleIndependent(dists)

    # Check shape of single sample and log prob
    sample = joint.sample()
    assert sample.shape == torch.Size([joint.ndims])
    assert joint.log_prob(sample).shape == torch.Size([1])

    num_samples = 10

    # Seed sampling for later comparison.
    torch.manual_seed(seed)
    samples = joint.sample((num_samples,))
    log_probs = joint.log_prob(samples)

    # Check sample and log_prob return shapes.
    assert samples.shape == torch.Size(
        [num_samples, joint.ndims]
    ) or samples.shape == torch.Size([num_samples])
    assert log_probs.shape == torch.Size([num_samples])

    # Seed again to get same samples by hand.
    torch.manual_seed(seed)
    true_samples = []
    true_log_probs = []

    # Get samples and log probs by hand.
    for d in dists:
        sample = d.sample((num_samples,))
        true_samples.append(sample)
        true_log_probs.append(d.log_prob(sample).reshape(num_samples, -1))

    # Collect into tensors.
    true_samples = torch.cat(true_samples, dim=-1)
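    # Under independence, the joint log prob is the sum of the marginal log probs.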
    true_log_probs = torch.cat(true_log_probs, dim=-1).sum(-1)

    # Check whether the independent joint samples equal the individual samples.
    assert (true_samples == samples).all()
    assert (true_log_probs == log_probs).all()

    # Check support attribute.
    within_support = joint.support.check(true_samples)
    assert within_support.all()
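
For context, dists is supplied by a test fixture. A minimal sketch of how it could be constructed and the test invoked is shown below; the import paths and fixture contents are assumptions for illustration, not part of the example above.

import torch
from torch.distributions import Beta, Gamma
from sbi.utils import MultipleIndependent  # assumed import path

# Hypothetical fixture value: a list of one-dimensional base distributions.
dists = [
    Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
    Beta(torch.tensor([2.0]), torch.tensor([2.0])),
]
test_independent_joint_shapes_and_samples(dists)
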
Example #2
def test_mnle_accuracy(sampler):
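    """Test MNLE posterior accuracy against reference posterior samples via C2ST."""
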
    def mixed_simulator(theta):
        # Split parameters: beta is the InverseGamma rate, ps the Binomial probabilities.
        beta, ps = theta[:, :1], theta[:, 1:]

        # Sample choices and rts independently.
        choices = Binomial(probs=ps).sample()
        rts = InverseGamma(concentration=2 * torch.ones_like(beta),
                           rate=beta).sample()

        return torch.cat((rts, choices), dim=1)

    prior = MultipleIndependent(
        [
            Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
            Beta(torch.tensor([2.0]), torch.tensor([2.0])),
        ],
        validate_args=False,
    )

    num_simulations = 2000
    num_samples = 1000
    theta = prior.sample((num_simulations,))
    x = mixed_simulator(theta)

    # Train MNLE and build the posterior.
    trainer = MNLE(prior)
    trainer.append_simulations(theta, x).train()
    posterior = trainer.build_posterior()

    mcmc_kwargs = dict(
        num_chains=10,
        warmup_steps=100,
        method="slice_np_vectorized",
        init_strategy="proposal",
    )

    for num_trials in [10]:
        theta_o = prior.sample((1,))
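        # Simulate num_trials i.i.d. observations from the same parameter.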
        x_o = mixed_simulator(theta_o.repeat(num_trials, 1))

        # True posterior samples
        transform = mcmc_transform(prior)
        true_posterior_samples = MCMCPosterior(
            PotentialFunctionProvider(prior, atleast_2d(x_o)),
            theta_transform=transform,
            proposal=prior,
            **mcmc_kwargs,
        ).sample((num_samples,), show_progress_bars=False)

        posterior = trainer.build_posterior(prior=prior, sample_with=sampler)
        posterior.set_default_x(x_o)
        if sampler == "vi":
            posterior.train()

        mnle_posterior_samples = posterior.sample(
            sample_shape=(num_samples,),
            show_progress_bars=False,
            **(mcmc_kwargs if sampler == "mcmc" else {}),
        )

        check_c2st(
            mnle_posterior_samples,
            true_posterior_samples,
            alg=f"MNLE with {sampler}",
        )