Example #1
def test_invalid_inputs():

    dists = [
        Gamma(ones(1), ones(1)),
        Uniform(zeros(1), ones(1)),
        Beta(ones(1), 2 * ones(1)),
    ]
    joint = MultipleIndependent(dists)

    # Test a value with too many parameter dimensions (4 columns vs. the joint's 3).
    with pytest.raises(AssertionError):
        joint.log_prob(ones(10, 4))

    # Test nested construction.
    with pytest.raises(AssertionError):
        MultipleIndependent([joint])

    # Test 3D value.
    with pytest.raises(AssertionError):
        joint.log_prob(ones(10, 4, 1))
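For contrast, here is a minimal sketch of the valid call shapes these assertions guard against; the `from sbi.utils import MultipleIndependent` path is an assumption (it is how recent sbi versions export the class):

from torch import ones, zeros
from torch.distributions import Beta, Gamma, Uniform

# Assumed import path for MultipleIndependent.
from sbi.utils import MultipleIndependent

dists = [
    Gamma(ones(1), ones(1)),
    Uniform(zeros(1), ones(1)),
    Beta(ones(1), 2 * ones(1)),
]
joint = MultipleIndependent(dists)  # flat list of 1D distributions, not nested

theta = joint.sample((10,))         # 2D batch of shape (10, 3), one column per component
log_probs = joint.log_prob(theta)   # expects 2D values with exactly 3 columns -> shape (10,)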
Example #2
def process_prior(prior) -> Tuple[Distribution, int, bool]:
    """Return PyTorch distribution-like prior from user-provided prior.

    Args:
        prior: Prior object with `.sample()` and `.log_prob()` as provided by the user.

    Raises:
        AttributeError: If the prior object lacks `.sample()` or `.log_prob()`.

    Returns:
        prior: Prior that emits samples and evaluates log prob as PyTorch Tensors.
        theta_numel: Number of parameters, i.e., the number of elements in a single sample from the prior.
        prior_returns_numpy: Whether the return type of the prior was a NumPy array.
    """

    # If prior is a sequence, assume independent components and check as PyTorch prior.
    if isinstance(prior, Sequence):
        warnings.warn(
            f"""Prior was provided as a sequence of {len(prior)} priors. They will be
            interpreted as independent of each other and matched in order to the
            components of the parameter."""
        )
        return process_pytorch_prior(MultipleIndependent(prior))

    if isinstance(prior, Distribution):
        return process_pytorch_prior(prior)

    # If prior is given as `scipy.stats` object, wrap as PyTorch.
    elif isinstance(prior, (rv_frozen, multi_rv_frozen)):
        event_shape = torch.Size([prior.rvs().size])
        # batch_shape is left empty (the default); event_shape is inferred from one draw.
        prior = ScipyPytorchWrapper(
            prior, batch_shape=torch.Size([]), event_shape=event_shape
        )
        return process_pytorch_prior(prior)

    # Otherwise it is a custom prior - check for `.sample()` and `.log_prob()`.
    else:
        return process_custom_prior(prior)
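A minimal usage sketch of the sequence branch above; the import path is an assumption (in sbi this function is defined in `sbi.utils.user_input_checks`):

from torch import ones, zeros
from torch.distributions import Gamma, Uniform

# Assumed import path for process_prior.
from sbi.utils.user_input_checks import process_prior

# Passing a sequence triggers the warning above: the two 1D priors are wrapped
# in a MultipleIndependent joint and then processed as a PyTorch prior.
prior, theta_numel, prior_returns_numpy = process_prior(
    [Gamma(ones(1), ones(1)), Uniform(zeros(1), ones(1))]
)

theta = prior.sample((10,))        # torch.Tensor with theta_numel (= 2) columns
log_probs = prior.log_prob(theta)  # torch.Tensor of shape (10,)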
Example #3
def test_independent_joint_shapes_and_samples(dists):
    """Test return shapes and validity of samples by comparing to samples generated
    from underlying distributions individually."""

    # Fix the seed for reseeding within this test.
    seed = 0

    joint = MultipleIndependent(dists)

    # Check shape of single sample and log prob
    sample = joint.sample()
    assert sample.shape == torch.Size([joint.ndims])
    assert joint.log_prob(sample).shape == torch.Size([1])

    num_samples = 10

    # Seed sampling for later comparison.
    torch.manual_seed(seed)
    samples = joint.sample((num_samples, ))
    log_probs = joint.log_prob(samples)

    # Check sample and log_prob return shapes.
    assert samples.shape == torch.Size([
        num_samples, joint.ndims
    ]) or samples.shape == torch.Size([num_samples])
    assert log_probs.shape == torch.Size([num_samples])

    # Seed again to get same samples by hand.
    torch.manual_seed(seed)
    true_samples = []
    true_log_probs = []

    # Get samples and log probs by hand.
    for d in dists:
        sample = d.sample((num_samples, ))
        true_samples.append(sample)
        true_log_probs.append(d.log_prob(sample).reshape(num_samples, -1))

    # Collect the per-distribution results into tensors.
    true_samples = torch.cat(true_samples, dim=-1)
    true_log_probs = torch.cat(true_log_probs, dim=-1).sum(-1)

    # Check that the joint's samples and log probs match those computed by hand.
    assert (true_samples == samples).all()
    assert (true_log_probs == log_probs).all()
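The same shape contract holds when components have different event sizes, as in the last parametrization of Example #4 below; a brief sketch under the same import assumption as above:

from torch import eye, ones, zeros
from torch.distributions import Gamma, MultivariateNormal

# Assumed import path, as in the sketch after Example #1.
from sbi.utils import MultipleIndependent

joint = MultipleIndependent(
    [MultivariateNormal(zeros(3), eye(3)), Gamma(ones(1), ones(1))]
)
samples = joint.sample((10,))        # shape (10, 4): 3 dims from the MVN + 1 from the Gamma
log_probs = joint.log_prob(samples)  # shape (10,): per-component log probs summed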
Example #4
    # Run inference.
    _ = infer(num_rounds=1, num_simulations_per_round=100, max_num_epochs=2)


@pytest.mark.parametrize(
    "dists",
    [
        pytest.param([Beta(ones(1), 2 * ones(1))],
                     marks=pytest.mark.xfail),  # single dist.
        pytest.param([Gamma(ones(2), ones(1))],
                     marks=pytest.mark.xfail),  # single batched dist.
        pytest.param(
            [
                Gamma(ones(2), ones(1)),
                MultipleIndependent(
                    [Uniform(zeros(1), ones(1)),
                     Uniform(zeros(1), ones(1))]),
            ],
            marks=pytest.mark.xfail,
        ),  # nested definition.
        pytest.param([Uniform(0, 1), Beta(1, 2)],
                     marks=pytest.mark.xfail),  # scalar dists.
        [Uniform(zeros(1), ones(1)),
         Uniform(zeros(1), ones(1))],
        (
            Gamma(ones(1), ones(1)),
            Uniform(zeros(1), ones(1)),
            Beta(ones(1), 2 * ones(1)),
        ),
        [MultivariateNormal(zeros(3), eye(3)),
         Gamma(ones(1), ones(1))],