Ejemplo n.º 1
0
def test_process_prior(prior):
    """Check that a processed prior samples and scores batches with the right shapes.

    The prior returned by ``process_prior`` must produce a ``(batch, dim)``
    sample tensor for a requested batch, and ``log_prob`` on that tensor must
    yield one value per sample.
    """
    wrapped_prior, theta_dim, _ = process_prior(prior)

    num_draws = 2
    samples = wrapped_prior.sample((num_draws,))
    expected_shape = torch.Size((num_draws, theta_dim))
    assert samples.shape == expected_shape, "Number of sampled parameters must match batch size."

    log_probs = wrapped_prior.log_prob(samples)
    assert log_probs.shape[0] == num_draws, "Number of log probs must match number of input values."
Ejemplo n.º 2
0
def test_process_simulator(simulator: Callable, prior: Distribution):
    """Check that a processed simulator maps a parameter batch to a Tensor batch.

    The prior is processed first, because ``process_simulator`` needs the
    processed prior and the flag saying whether it returns numpy arrays.
    """
    checked_prior, _theta_dim, returns_numpy = process_prior(prior)
    wrapped_simulator = process_simulator(simulator, checked_prior, returns_numpy)

    num_params = 1
    theta = checked_prior.sample((num_params,))
    x = wrapped_simulator(theta)

    assert isinstance(x, Tensor), "Processed simulator must return Tensor."
    assert x.shape[0] == num_params, (
        "Processed simulator must return as many data points as parameters in batch."
    )
Ejemplo n.º 3
0
def test_process_prior(prior):
    """Check shapes of samples and log-probs from a prior processed with wrapper bounds.

    ``custom_prior_wrapper_kwargs`` supplies a 3-dimensional unit box as lower
    and upper bounds for the prior wrapper.
    """
    wrapper_kwargs = {"lower_bound": zeros(3), "upper_bound": ones(3)}
    wrapped_prior, theta_dim, _ = process_prior(
        prior, custom_prior_wrapper_kwargs=wrapper_kwargs
    )

    num_draws = 2
    samples = wrapped_prior.sample((num_draws,))
    expected_shape = torch.Size((num_draws, theta_dim))
    assert samples.shape == expected_shape, "Number of sampled parameters must match batch size."

    log_probs = wrapped_prior.log_prob(samples)
    assert log_probs.shape[0] == num_draws, "Number of log probs must match number of input values."
Ejemplo n.º 4
0
def test_reinterpreted_batch_dim_prior():
    """Test that priors with non-empty batch dimensions are rejected.

    Batch dimensions are no longer reinterpreted as event dimensions. A prior
    with a non-empty batch shape must therefore make ``process_prior`` raise
    ``ValueError``. A ``BoxUniform`` over the same box must be accepted
    without raising and without emitting any warning.
    """
    import warnings

    # Build the distributions outside ``pytest.raises`` so that only
    # ``process_prior`` can satisfy the expected ``ValueError``. The previous
    # ``MultivariateNormal(zeros(2, 3), ones(3))`` already raised ValueError in
    # the torch constructor (covariance_matrix must be at least 2-D), so that
    # check passed without ever calling ``process_prior``.
    uniform_with_batch_dim = Uniform(zeros(3), ones(3))  # batch_shape (3,)
    mvn_with_batch_dim = MultivariateNormal(
        zeros(2, 3), ones(3).diag()
    )  # batch_shape (2,), event_shape (3,)

    # Both must raise ValueError because we don't reinterpret batch dims anymore.
    with pytest.raises(ValueError):
        process_prior(uniform_with_batch_dim)
    with pytest.raises(ValueError):
        process_prior(mvn_with_batch_dim)

    # This must pass without warnings or errors; turn warnings into errors so
    # the "without warnings" part is actually enforced.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        process_prior(BoxUniform(zeros(3), ones(3)))