Example #1
import warnings
from typing import Dict, Optional, Sequence, Tuple

import torch
from scipy.stats._distn_infrastructure import rv_frozen
from scipy.stats._multivariate import multi_rv_frozen
from torch.distributions import Distribution

# `MultipleIndependent`, `ScipyPytorchWrapper`, `process_pytorch_prior`, and
# `process_custom_prior` are sbi helpers defined alongside this function.


def process_prior(
        prior,
        custom_prior_wrapper_kwargs: Optional[Dict] = None,
) -> Tuple[Distribution, int, bool]:
    """Return PyTorch distribution-like prior from user-provided prior.

    Args:
        prior: Prior object with `.sample()` and `.log_prob()` as provided by the user.
        custom_prior_wrapper_kwargs: kwargs to be passed to the class that wraps a
            custom prior into a PyTorch `Distribution`, e.g., bounds for a prior
            with bounded support (`lower_bound`, `upper_bound`) or argument
            constraints (`arg_constraints`); see `torch.distributions.Distribution`
            for details.

    Raises:
        AttributeError: If the prior object lacks `.sample()` or `.log_prob()`.

    Returns:
        prior: Prior that emits samples and evaluates log prob as PyTorch Tensors.
        theta_numel: Number of parameters, i.e., the number of elements in a
            single sample from the prior.
        prior_returns_numpy: Whether the return type of the prior was a Numpy array.
    """

    if custom_prior_wrapper_kwargs is None:
        custom_prior_wrapper_kwargs = {}

    # If prior is a sequence, treat the components as independent and process
    # the result as a PyTorch prior.
    if isinstance(prior, Sequence):
        warnings.warn(
            f"""Prior was provided as a sequence of {len(prior)} priors. They will be
            interpreted as independent of each other and matched in order to the
            components of the parameter.""")
        return process_pytorch_prior(MultipleIndependent(prior))

    if isinstance(prior, Distribution):
        return process_pytorch_prior(prior)

    # If the prior is a `scipy.stats` object, wrap it as a PyTorch distribution.
    elif isinstance(prior, (rv_frozen, multi_rv_frozen)):
        event_shape = torch.Size([prior.rvs().size])
        # Pass the default (empty) batch_shape explicitly.
        prior = ScipyPytorchWrapper(prior,
                                    batch_shape=torch.Size([]),
                                    event_shape=event_shape)
        return process_pytorch_prior(prior)

    # Otherwise it is a custom prior - check for `.sample()` and `.log_prob()`.
    else:
        return process_custom_prior(prior, custom_prior_wrapper_kwargs)
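# Usage sketch (illustrative, not from the source): wrapping a scipy prior via
# process_prior. Assumes this module and its helpers are importable; the Beta
# prior and the shapes noted below follow the docstring's contract.
from scipy import stats

wrapped_prior, theta_numel, prior_returns_numpy = process_prior(
    stats.beta(a=2.0, b=2.0)
)
theta = wrapped_prior.sample((100,))       # torch.Tensor of shape (100, 1)
log_probs = wrapped_prior.log_prob(theta)  # torch.Tensor of shape (100,)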
Example #2
import torch

from sbi.utils import MultipleIndependent  # assumed export path


def test_independent_joint_shapes_and_samples(dists):
    """Test return shapes and validity of samples by comparing to samples generated
    from underlying distributions individually."""

    # Fix the seed for reseeding within this test.
    seed = 0

    joint = MultipleIndependent(dists)

    # Check shape of single sample and log prob
    sample = joint.sample()
    assert sample.shape == torch.Size([joint.ndims])
    assert joint.log_prob(sample).shape == torch.Size([1])

    num_samples = 10

    # Seed sampling for later comparison.
    torch.manual_seed(seed)
    samples = joint.sample((num_samples,))
    log_probs = joint.log_prob(samples)

    # Check sample and log_prob return shapes.
    assert samples.shape == torch.Size(
        [num_samples, joint.ndims]
    ) or samples.shape == torch.Size([num_samples])
    assert log_probs.shape == torch.Size([num_samples])

    # Seed again to get same samples by hand.
    torch.manual_seed(seed)
    true_samples = []
    true_log_probs = []

    # Get samples and log probs by hand.
    for d in dists:
        sample = d.sample((num_samples,))
        true_samples.append(sample)
        true_log_probs.append(d.log_prob(sample).reshape(num_samples, -1))

    # Collect into tensors.
    true_samples = torch.cat(true_samples, dim=-1)
    true_log_probs = torch.cat(true_log_probs, dim=-1).sum(-1)

    # Check whether the independent joint samples equal the individual samples.
    assert (true_samples == samples).all()
    assert (true_log_probs == log_probs).all()

    # Check support attribute.
    within_support = joint.support.check(true_samples)
    assert within_support.all()
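# Illustrative sketch (not from the source): a `dists` list this test would
# receive via parametrization, mixing 1D and multivariate components.
from torch import eye, ones, zeros
from torch.distributions import Gamma, MultivariateNormal, Uniform

dists = [
    Gamma(ones(1), ones(1)),
    Uniform(zeros(1), ones(1)),
    MultivariateNormal(zeros(3), eye(3)),
]
joint = MultipleIndependent(dists)
assert joint.ndims == 5  # event dims concatenate: 1 + 1 + 3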
Example #3
import pytest
from torch import ones, zeros
from torch.distributions import Beta, Gamma, Uniform

from sbi.utils import MultipleIndependent  # assumed export path


def test_invalid_inputs():
    dists = [
        Gamma(ones(1), ones(1)),
        Uniform(zeros(1), ones(1)),
        Beta(ones(1), 2 * ones(1)),
    ]
    joint = MultipleIndependent(dists)

    # Test too many input dimensions.
    with pytest.raises(AssertionError):
        joint.log_prob(ones(10, 4))

    # Test nested construction.
    with pytest.raises(AssertionError):
        MultipleIndependent([joint])

    # Test 3D value.
    with pytest.raises(AssertionError):
        joint.log_prob(ones(10, 4, 1))
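# Illustrative sketch (not from the source): by contrast, a valid call passes a
# 2D value whose trailing dimension matches the joint's 3 event dimensions.
log_probs = joint.log_prob(0.5 * ones(10, 3))  # shape: torch.Size([10])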
Example #4
    _ = inference.append_simulations(theta, x).train(max_num_epochs=2)
    _ = inference.build_posterior()


@pytest.mark.parametrize(
    "dists",
    [
        pytest.param([Beta(ones(1), 2 * ones(1))],
                     marks=pytest.mark.xfail),  # single dist.
        pytest.param([Gamma(ones(2), ones(1))],
                     marks=pytest.mark.xfail),  # single batched dist.
        pytest.param(
            [
                Gamma(ones(2), ones(1)),
                MultipleIndependent(
                    [Uniform(zeros(1), ones(1)),
                     Uniform(zeros(1), ones(1))]),
            ],
            marks=pytest.mark.xfail,
        ),  # nested definition.
        pytest.param([Uniform(0, 1), Beta(1, 2)],
                     marks=pytest.mark.xfail),  # scalar dists.
        [Uniform(zeros(1), ones(1)),
         Uniform(zeros(1), ones(1))],
        (
            Gamma(ones(1), ones(1)),
            Uniform(zeros(1), ones(1)),
            Beta(ones(1), 2 * ones(1)),
        ),
        [MultivariateNormal(zeros(3), eye(3)),
         Gamma(ones(1), ones(1))],
Example #5
import torch
from torch import atleast_2d
from torch.distributions import Beta, Binomial, Gamma, InverseGamma

from sbi.inference import MCMCPosterior, MNLE
from sbi.utils import MultipleIndependent, mcmc_transform

# `PotentialFunctionProvider` (a potential for the true likelihood) and
# `check_c2st` are helpers defined elsewhere in the test suite; `sampler` is
# supplied via pytest parametrization (the body handles "mcmc" and "vi").


def test_mnle_accuracy(sampler):
    def mixed_simulator(theta):
        # Extract parameters: beta drives rts, ps drives choices.
        beta, ps = theta[:, :1], theta[:, 1:]

        # Sample choices and rts independently.
        choices = Binomial(probs=ps).sample()
        rts = InverseGamma(concentration=2 * torch.ones_like(beta),
                           rate=beta).sample()

        return torch.cat((rts, choices), dim=1)

    prior = MultipleIndependent(
        [
            Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
            Beta(torch.tensor([2.0]), torch.tensor([2.0])),
        ],
        validate_args=False,
    )

    num_simulations = 2000
    num_samples = 1000
    theta = prior.sample((num_simulations,))
    x = mixed_simulator(theta)

    # MNLE
    trainer = MNLE(prior)
    trainer.append_simulations(theta, x).train()
    posterior = trainer.build_posterior()

    mcmc_kwargs = dict(
        num_chains=10,
        warmup_steps=100,
        method="slice_np_vectorized",
        init_strategy="proposal",
    )

    for num_trials in [10]:
        theta_o = prior.sample((1,))
        x_o = mixed_simulator(theta_o.repeat(num_trials, 1))

        # True posterior samples
        transform = mcmc_transform(prior)
        true_posterior_samples = MCMCPosterior(
            PotentialFunctionProvider(prior, atleast_2d(x_o)),
            theta_transform=transform,
            proposal=prior,
            **mcmc_kwargs,
        ).sample((num_samples,), show_progress_bars=False)

        posterior = trainer.build_posterior(prior=prior, sample_with=sampler)
        posterior.set_default_x(x_o)
        if sampler == "vi":
            posterior.train()

        mnle_posterior_samples = posterior.sample(
            sample_shape=(num_samples,),
            show_progress_bars=False,
            **mcmc_kwargs if sampler == "mcmc" else {},
        )

        check_c2st(
            mnle_posterior_samples,
            true_posterior_samples,
            alg=f"MNLE with {sampler}",
        )
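# Shape sketch (illustrative, not from the source): how the mixed prior and
# simulator above fit together. Each theta row is (beta, ps); each x row is
# (rt, choice).
theta = prior.sample((5,))   # shape (5, 2): Gamma and Beta components
x = mixed_simulator(theta)   # shape (5, 2): continuous rt, binary choice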