Example #1
def __init__(self,
             lower: Tensor,
             upper: Tensor,
             return_numpy: bool = False):
    self.lower = lower
    self.upper = upper
    self.dist = BoxUniform(lower, upper)
    self.return_numpy = return_numpy
Example #2
def test_process_matrix_observation():
    prior = BoxUniform(torch.zeros(4), torch.ones(4))
    observed_data = np.zeros((1, 2, 2))
    simulator = matrix_simulator

    observed_data, observation_dim = process_observed_data(
        observed_data, simulator, prior)
Example #3
def test_nograd_after_inference_train(inference_method) -> None:

    num_dim = 2
    prior_ = BoxUniform(-torch.ones(num_dim), torch.ones(num_dim))
    simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior_)

    inference = inference_method(
        prior,
        **(
            dict(classifier="resnet")
            if inference_method in [SNRE_A, SNRE_B]
            else dict(
                density_estimator=(
                    "mdn_snpe_a" if inference_method == SNPE_A else "maf"
                )
            )
        ),
        show_progress_bars=False,
    )

    theta, x = simulate_for_sbi(simulator, prior, 32)
    inference = inference.append_simulations(theta, x)

    posterior_estimator = inference.train(max_num_epochs=2)

    def check_no_grad(model):
        for p in model.parameters():
            assert p.grad is None

    check_no_grad(posterior_estimator)
    check_no_grad(inference._neural_net)
Example #4
def test_simulate_in_batches(
    num_sims,
    batch_size,
    simulator=diagonal_linear_gaussian,
    prior=BoxUniform(zeros(5), ones(5)),
):
    """Test combinations of num_sims and simulation_batch_size. """

    theta = prior.sample((num_sims,))
    simulate_in_batches(simulator, theta, batch_size)
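The num_sims and batch_size arguments above are injected by pytest; in the sbi test suite they are typically driven by pytest.mark.parametrize. A minimal sketch of such a decoration, with illustrative (assumed) parameter values and a hypothetical test name:

import pytest

@pytest.mark.parametrize("num_sims", (0, 10, 100))
@pytest.mark.parametrize("batch_size", (1, 10, 100))
def test_simulate_in_batches_sketch(num_sims, batch_size):
    # Each (num_sims, batch_size) pair becomes its own test case.
    assert num_sims >= 0 and batch_size >= 1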
Example #5
def test_reinterpreted_batch_dim_prior():
    """Test whether the right warning and error are raised for reinterpreted priors."""

    # Both must raise ValueError because we don't reinterpret batch dims anymore.
    with pytest.raises(ValueError):
        process_prior(Uniform(zeros(3), ones(3)))
    with pytest.raises(ValueError):
        process_prior(MultivariateNormal(zeros(2, 3), ones(3)))

    # This must pass without warnings or errors.
    process_prior(BoxUniform(zeros(3), ones(3)))
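The behavior under test: a multidimensional torch Uniform treats its dimensions as batch dimensions (one log prob per dimension), whereas BoxUniform reinterprets them as a single event (one joint log prob per sample), which is what sbi expects. A minimal sketch of the difference, assuming torch and sbi are installed:

import torch
from torch.distributions import Uniform
from sbi.utils import BoxUniform

u = Uniform(torch.zeros(3), torch.ones(3))     # batch_shape=(3,), event_shape=()
b = BoxUniform(torch.zeros(3), torch.ones(3))  # batch_shape=(), event_shape=(3,)

theta = b.sample((2,))                         # shape (2, 3)
print(u.log_prob(theta).shape)                 # torch.Size([2, 3]): per-dimension
print(b.log_prob(theta).shape)                 # torch.Size([2]): one joint log prob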
Example #6
def test_simulate_in_batches(
        num_sims,
        batch_size,
        simulator,
        prior=BoxUniform(zeros(5), ones(5)),
):
    """Test combinations of num_sims and simulation_batch_size. """

    simulator, prior = prepare_for_sbi(simulator, prior)
    theta = prior.sample((num_sims, ))
    simulate_in_batches(simulator, theta, batch_size)
Example #7
def test_nonlinearGaussian_based_on_mmd():
    simulator = non_linear_gaussian

    # Ground truth parameters as specified in 'Sequential Neural Likelihood' paper.
    ground_truth_parameters = torch.tensor([-0.7, -2.9, -1.0, -0.9, 0.6])
    # Ground truth observation using the same seed as the 'Sequential Neural Likelihood' paper.
    ground_truth_observation = torch.tensor([
        -0.97071232,
        -2.94612244,
        -0.44947218,
        -3.42318484,
        -0.13285634,
        -3.36401699,
        -0.85367595,
        -2.42716377,
    ])

    # assume batch dims
    parameter_dim = ground_truth_parameters.shape[0]
    observation_dim = ground_truth_observation.shape[0]

    prior = BoxUniform(
        low=-3 * torch.ones(parameter_dim),
        high=3 * torch.ones(parameter_dim),
    )

    infer = SnpeC(
        simulator=simulator,
        true_observation=ground_truth_observation[None, ],
        prior=prior,
        num_atoms=-1,
        z_score_obs=True,
        use_combined_loss=False,
        retrain_from_scratch_each_round=False,
        discard_prior_samples=False,
    )

    num_rounds, num_simulations_per_round = 2, 1000
    posterior = infer(num_rounds=num_rounds,
                      num_simulations_per_round=num_simulations_per_round)

    samples = posterior.sample(1000)

    # Sample from (analytically tractable) target distribution.
    target_samples = get_ground_truth_posterior_samples_nonlinear_gaussian(
        num_samples=1000)

    # Compute and check if MMD is larger than expected.
    mmd = utils.unbiased_mmd_squared(target_samples.float(), samples.float())

    max_mmd = 0.16  # mean mmd plus 2 stds.
    assert mmd < max_mmd, f"MMD={mmd} larger than mean plus 2 stds."
Example #8
class UserNumpyUniform:
    """User defined numpy uniform prior.

    Used for testing to mimic a user-defined prior with valid .sample and .log_prob
    methods.
    """

    def __init__(self, lower: Tensor, upper: Tensor, return_numpy: bool = False):
        self.lower = lower
        self.upper = upper
        self.dist = BoxUniform(lower, upper)
        self.return_numpy = return_numpy

    def sample(self, sample_shape=torch.Size([])):
        samples = self.dist.sample(sample_shape)
        return samples.numpy() if self.return_numpy else samples

    def log_prob(self, values):
        if self.return_numpy:
            values = torch.as_tensor(values)
        log_probs = self.dist.log_prob(values)
        return log_probs.numpy() if self.return_numpy else log_probs
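A rough usage sketch of the class above (assuming the class definition and a torch import are in scope; shapes in the comments assume a 3-dimensional prior):

prior = UserNumpyUniform(torch.zeros(3), torch.ones(3), return_numpy=True)
samples = prior.sample((10,))        # numpy.ndarray of shape (10, 3)
log_probs = prior.log_prob(samples)  # numpy.ndarray of shape (10,)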
Example #9
def test_simulate_in_batches(
        num_samples,
        batch_size,
        simulator=linear_gaussian,
        prior=BoxUniform(torch.zeros(5), torch.ones(5)),
):
    """Test combinations of num_samples and simulation_batch_size. """

    simulate_in_batches(
        simulator,
        lambda n: prior.sample((n, )),
        num_samples,
        batch_size,
        torch.Size([5]),
    )
Example #10
def process_pytorch_prior(
        prior: Distribution) -> Tuple[Distribution, int, bool]:
    """Return PyTorch prior adapted to the requirements for sbi.

    Args:
        prior: PyTorch distribution prior provided by the user.

    Raises:
        ValueError: If prior is defined over an unwrapped scalar variable.

    Returns:
        prior: PyTorch distribution prior.
        theta_numel: Number of parameters, i.e., the number of elements in a
            single sample from the prior.
        prior_returns_numpy: False.
    """

    # Turn off validation of input arguments to allow `log_prob()` on samples outside
    # of the support.
    prior.set_default_validate_args(False)

    # Reject unwrapped scalar priors.
    # This will reject Uniform priors with dimension larger than 1.
    if prior.sample().ndim == 0:
        raise ValueError(
            "Detected scalar prior. Please make sure to pass a PyTorch prior with "
            "`batch_shape=torch.Size([1])` or `event_shape=torch.Size([1])`.")
    # Cast 1D Uniform to BoxUniform to avoid shape error in mdn log prob.
    elif isinstance(prior, Uniform) and prior.batch_shape.numel() == 1:
        prior = BoxUniform(low=prior.low, high=prior.high)
        warnings.warn(
            "Casting 1D Uniform prior to BoxUniform to match sbi batch requirements."
        )

    check_prior_batch_behavior(prior)
    check_prior_batch_dims(prior)

    if prior.sample().dtype != float32:
        prior = PytorchReturnTypeWrapper(prior,
                                         return_type=float32,
                                         validate_args=False)

    # This will fail for float64 priors.
    check_prior_return_type(prior)

    theta_numel = prior.sample().numel()

    return prior, theta_numel, False
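A minimal usage sketch of the function above (assuming the surrounding imports and helper checks are available; the 3-dimensional BoxUniform is an illustrative choice):

import torch
from sbi.utils import BoxUniform

prior = BoxUniform(torch.zeros(3), torch.ones(3))
prior, theta_numel, returns_numpy = process_pytorch_prior(prior)
assert theta_numel == 3        # three parameters per sample
assert returns_numpy is False  # PyTorch priors always return tensors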
Example #11
def test_train_with_different_data_and_training_device(
    inference_method, data_device: str, training_device: str
) -> None:

    assert torch.cuda.is_available(), "this test requires that cuda is available."

    num_dim = 2
    prior_ = BoxUniform(
        -torch.ones(num_dim), torch.ones(num_dim), device=training_device
    )
    simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior_)

    inference = inference_method(
        prior,
        **(
            dict(classifier="resnet")
            if inference_method in [SNRE_A, SNRE_B]
            else dict(
                density_estimator=(
                    "mdn_snpe_a" if inference_method == SNPE_A else "maf"
                )
            )
        ),
        show_progress_bars=False,
        device=training_device,
    )

    theta, x = simulate_for_sbi(simulator, prior, 32)
    theta, x = theta.to(data_device), x.to(data_device)
    x_o = torch.zeros(x.shape[1])
    inference = inference.append_simulations(theta, x)

    posterior_estimator = inference.train(max_num_epochs=2)

    # Check for default device for inference object
    weights_device = next(inference._neural_net.parameters()).device
    assert torch.device(training_device) == weights_device

    _ = DirectPosterior(
        posterior_estimator=posterior_estimator, prior=prior
    ).set_default_x(x_o)
Example #12
    # Test log prob on batch of thetas.
    log_probs = prior.log_prob(theta)
    assert isinstance(log_probs, Tensor)
    assert log_probs.shape[0] == batch_size


@pytest.mark.parametrize(
    "prior",
    (
        pytest.param(Uniform(0.0, 1.0), marks=pytest.mark.xfail),
        pytest.param(Uniform(torch.zeros(3), torch.ones(3)),
                     marks=pytest.mark.xfail),
        pytest.param(Uniform(torch.zeros((1, 3)), torch.ones((1, 3))),
                     marks=pytest.mark.xfail),
        Uniform(torch.zeros(1), torch.ones(1)),
        BoxUniform(torch.zeros(3), torch.ones(3)),
        MultivariateNormal(torch.zeros(3), torch.eye(3)),
        UserNumpyUniform(torch.zeros(3), torch.ones(3), return_numpy=False),
        UserNumpyUniform(torch.zeros(3), torch.ones(3), return_numpy=True),
    ),
)
def test_process_prior(prior):

    prior, parameter_dim, numpy_simulator = process_prior(prior)

    batch_size = 2
    theta = prior.sample((batch_size, ))
    assert theta.shape == torch.Size(
        (batch_size, parameter_dim)
    ), "Number of sampled parameters must match batch size."
    assert prior.log_prob(theta).shape[0] == batch_size
Example #13
def matrix_simulator(theta):
    return theta.reshape(1, 2, 2)


@pytest.mark.parametrize(
    "wrapper, prior",
    (
        (
            CustomPytorchWrapper,
            UserNumpyUniform(zeros(3), ones(3), return_numpy=True),
        ),
        (ScipyPytorchWrapper, multivariate_normal()),
        (ScipyPytorchWrapper, uniform()),
        (ScipyPytorchWrapper, beta(a=1, b=1)),
        (
            PytorchReturnTypeWrapper,
            BoxUniform(zeros(3, dtype=torch.float64),
                       ones(3, dtype=torch.float64)),
        ),
    ),
)
def test_prior_wrappers(wrapper, prior):
    """Test prior wrappers to pytorch distributions."""
    prior = wrapper(prior)

    # use 2 here to test for minimal case >1
    batch_size = 2
    theta = prior.sample((batch_size, ))
    assert isinstance(theta, Tensor)
    assert theta.shape[0] == batch_size

    # Test log prob on batch of thetas.
    log_probs = prior.log_prob(theta)
Example #14
def test_training_and_mcmc_on_device(
    method,
    model,
    data_device,
    mcmc_method,
    training_device,
    prior_device,
    prior_type="gaussian",
):
    """Test training on devices.

    This test does not check training speeds.

    """

    num_dim = 2
    num_samples = 10
    num_simulations = 100
    max_num_epochs = 5

    x_o = zeros(1, num_dim).to(data_device)
    likelihood_shift = -1.0 * ones(num_dim).to(prior_device)
    likelihood_cov = 0.3 * eye(num_dim).to(prior_device)

    if prior_type == "gaussian":
        prior_mean = zeros(num_dim).to(prior_device)
        prior_cov = eye(num_dim).to(prior_device)
        prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
    else:
        prior = BoxUniform(
            low=-2 * torch.ones(num_dim),
            high=2 * torch.ones(num_dim),
            device=prior_device,
        )

    def simulator(theta):
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    training_device = process_device(training_device)

    if method in [SNPE_A, SNPE_C]:
        kwargs = dict(
            density_estimator=utils.posterior_nn(model=model, num_transforms=2)
        )
    elif method == SNLE:
        kwargs = dict(
            density_estimator=utils.likelihood_nn(model=model, num_transforms=2)
        )
    elif method in (SNRE_A, SNRE_B):
        kwargs = dict(classifier=utils.classifier_nn(model=model))
    else:
        raise ValueError()

    inferer = method(show_progress_bars=False, device=training_device, **kwargs)

    proposals = [prior]

    # Test for two rounds.
    for _ in range(2):
        theta, x = simulate_for_sbi(simulator, proposals[-1], num_simulations)
        theta, x = theta.to(data_device), x.to(data_device)

        estimator = inferer.append_simulations(theta, x).train(
            training_batch_size=100, max_num_epochs=max_num_epochs
        )
        if method == SNLE:
            potential_fn, theta_transform = likelihood_estimator_based_potential(
                estimator, prior, x_o
            )
        elif method == SNPE_A or method == SNPE_C:
            potential_fn, theta_transform = posterior_estimator_based_potential(
                estimator, prior, x_o
            )
        elif method == SNRE_A or method == SNRE_B:
            potential_fn, theta_transform = ratio_estimator_based_potential(
                estimator, prior, x_o
            )
        else:
            raise ValueError

        if mcmc_method == "rejection":
            posterior = RejectionPosterior(
                proposal=prior,
                potential_fn=potential_fn,
                device=training_device,
            )
        elif mcmc_method == "direct":
            posterior = DirectPosterior(
                posterior_estimator=estimator, prior=prior
            ).set_default_x(x_o)
        else:
            posterior = MCMCPosterior(
                potential_fn=potential_fn,
                theta_transform=theta_transform,
                proposal=prior,
                method=mcmc_method,
                device=training_device,
            )
        proposals.append(posterior)

    # Check for default device for inference object
    weights_device = next(inferer._neural_net.parameters()).device
    assert torch.device(training_device) == weights_device
    samples = proposals[-1].sample(sample_shape=(num_samples,))
    proposals[-1].potential(samples)