Code example #1
def test_mdn_with_1D_uniform_prior():
    """
    Note, we have this test because for 1D uniform priors, mdn log prob evaluation
    results in batch_size x batch_size return. This is probably because Uniform does
    not allow for event dimension > 1 and somewhere in pyknos it is used as if this was
    possible.
    Casting to BoxUniform solves it.
    """
    num_dim = 1
    x_o = torch.tensor([[1.0]])
    num_samples = 100

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    prior = Uniform(low=torch.zeros(num_dim), high=torch.ones(num_dim))

    def simulator(theta: Tensor) -> Tensor:
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    simulator, prior = prepare_for_sbi(simulator, prior)
    inference = SNPE(density_estimator="mdn")

    theta, x = simulate_for_sbi(simulator, prior, 100)
    posterior_estimator = inference.append_simulations(theta, x).train()
    posterior = DirectPosterior(posterior_estimator=posterior_estimator,
                                prior=prior)
    samples = posterior.sample((num_samples, ), x=x_o)
    log_probs = posterior.log_prob(samples, x=x_o)

    assert log_probs.shape == torch.Size([num_samples])
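The fix the docstring points to is visible directly in the event shapes of the two distributions. A minimal sketch, assuming only torch and sbi:

import torch
from torch.distributions import Uniform
from sbi.utils import BoxUniform

# A 1D Uniform has an empty event_shape, so log_prob broadcasts over both
# batch dimensions -- the source of the batch_size x batch_size return
# described above.
prior_uniform = Uniform(low=torch.zeros(1), high=torch.ones(1))
print(prior_uniform.event_shape)  # torch.Size([])

# BoxUniform wraps the Uniform in an Independent with one event dimension,
# so log_prob returns a single value per parameter.
prior_box = BoxUniform(low=torch.zeros(1), high=torch.ones(1))
print(prior_box.event_shape)  # torch.Size([1])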
Code example #2
File: posterior_nn_test.py Project: bkmi/sbi
def test_log_prob_with_different_x(snpe_method: type, x_o_batch_dim: bool):

    num_dim = 2

    prior = MultivariateNormal(loc=zeros(num_dim),
                               covariance_matrix=eye(num_dim))
    simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior)
    inference = snpe_method(prior=prior)
    theta, x = simulate_for_sbi(simulator, prior, 1000)
    posterior_estimator = inference.append_simulations(
        theta, x).train(max_num_epochs=3)

    if x_o_batch_dim == 0:
        x_o = ones(num_dim)
    elif x_o_batch_dim == 1:
        x_o = ones(1, num_dim)
    elif x_o_batch_dim == 2:
        x_o = ones(2, num_dim)
    else:
        raise NotImplementedError

    posterior = DirectPosterior(posterior_estimator=posterior_estimator,
                                prior=prior).set_default_x(x_o)
    samples = posterior.sample((10, ))
    _ = posterior.log_prob(samples)
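A short usage sketch of the conventions this test exercises, reusing the posterior built above; x_alt is a hypothetical alternative observation:

# Once a default x is set, sample() and log_prob() need no x argument;
# passing x explicitly overrides the default for that call only.
x_alt = ones(1, num_dim)                 # hypothetical alternative observation
samples = posterior.sample((10, ))       # evaluated at the default x_o
log_probs = posterior.log_prob(samples)  # same default x_o
log_probs_alt = posterior.log_prob(samples, x=x_alt)  # explicit override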
Code example #3
def test_inference_with_restriction_estimator():

    # likelihood_mean will be likelihood_shift+theta
    num_dim = 3
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)
    x_o = zeros(1, num_dim)
    num_samples = 500

    def linear_gaussian_nan(theta,
                            likelihood_shift=likelihood_shift,
                            likelihood_cov=likelihood_cov):
        condition = theta[:, 0] < 0.0
        x = linear_gaussian(theta, likelihood_shift, likelihood_cov)
        x[condition] = float("nan")

        return x

    prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim))
    target_samples = samples_true_posterior_linear_gaussian_uniform_prior(
        x_o,
        likelihood_shift=likelihood_shift,
        likelihood_cov=likelihood_cov,
        num_samples=num_samples,
        prior=prior,
    )

    simulator, prior = prepare_for_sbi(linear_gaussian_nan, prior)
    restriction_estimator = RestrictionEstimator(prior=prior)
    proposals = [prior]
    num_rounds = 2

    for r in range(num_rounds):
        theta, x = simulate_for_sbi(simulator, proposals[-1], 1000)
        restriction_estimator.append_simulations(theta, x)
        if r < num_rounds - 1:
            _ = restriction_estimator.train()
        proposals.append(restriction_estimator.restrict_prior())

    all_theta, all_x, _ = restriction_estimator.get_simulations()

    # Any method can be used in combination with the `RestrictionEstimator`.
    inference = SNPE_C(prior=prior)
    posterior_estimator = inference.append_simulations(all_theta,
                                                       all_x).train()

    # Build posterior.
    posterior = DirectPosterior(
        prior=prior,
        posterior_estimator=posterior_estimator).set_default_x(x_o)

    samples = posterior.sample((num_samples, ))

    # Compute the c2st and assert it is near chance level of 0.5.
    check_c2st(samples, target_samples, alg=f"{SNPE_C}")
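Condensed to a single round, the restriction pattern above reads as follows; a sketch reusing the simulator and prior prepared in the test:

# The estimator learns which parameter regions yield valid (non-NaN)
# simulations; restrict_prior() then returns a proposal confined to them.
restriction_estimator = RestrictionEstimator(prior=prior)
theta, x = simulate_for_sbi(simulator, prior, 1000)
restriction_estimator.append_simulations(theta, x)
_ = restriction_estimator.train()
restricted_prior = restriction_estimator.restrict_prior()
theta2, x2 = simulate_for_sbi(simulator, restricted_prior, 1000)  # fewer NaNs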
Code example #4
def mdn_inference_with_different_methods(method):

    num_dim = 2
    x_o = torch.tensor([[1.0, 0.0]])
    num_samples = 500
    num_simulations = 2000

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    prior_mean = zeros(num_dim)
    prior_cov = eye(num_dim)
    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
    gt_posterior = true_posterior_linear_gaussian_mvn_prior(
        x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov)
    target_samples = gt_posterior.sample((num_samples, ))

    def simulator(theta: Tensor) -> Tensor:
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    simulator, prior = prepare_for_sbi(simulator, prior)
    inference = method(density_estimator="mdn")

    theta, x = simulate_for_sbi(simulator, prior, num_simulations)
    estimator = inference.append_simulations(theta, x).train()
    if method == SNPE:
        posterior = DirectPosterior(posterior_estimator=estimator, prior=prior)
    else:
        potential_fn, theta_transform = likelihood_estimator_based_potential(
            likelihood_estimator=estimator, prior=prior, x_o=x_o)
        posterior = MCMCPosterior(potential_fn=potential_fn,
                                  theta_transform=theta_transform,
                                  proposal=prior)

    samples = posterior.sample((num_samples, ), x=x_o)

    # Compute the c2st and assert it is near chance level of 0.5.
    check_c2st(samples, target_samples, alg=f"{method}")
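For the non-SNPE branch, sbi also offers a one-call shortcut that wraps the potential construction; a sketch, assuming inference is the trained SNLE/SNRE object from above:

# build_posterior() picks a default posterior class appropriate for the
# method (an MCMC-based posterior for likelihood and ratio estimators).
posterior = inference.build_posterior().set_default_x(x_o)
samples = posterior.sample((num_samples, ))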
Code example #5
File: inference_on_device_test.py Project: bkmi/sbi
def test_train_with_different_data_and_training_device(
    inference_method, data_device: str, training_device: str
) -> None:

    assert torch.cuda.is_available(), "this test requires that cuda is available."

    num_dim = 2
    prior_ = BoxUniform(
        -torch.ones(num_dim), torch.ones(num_dim), device=training_device
    )
    simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior_)

    inference = inference_method(
        prior,
        **(
            dict(classifier="resnet")
            if inference_method in [SNRE_A, SNRE_B]
            else dict(
                density_estimator=(
                    "mdn_snpe_a" if inference_method == SNPE_A else "maf"
                )
            )
        ),
        show_progress_bars=False,
        device=training_device,
    )

    theta, x = simulate_for_sbi(simulator, prior, 32)
    theta, x = theta.to(data_device), x.to(data_device)
    x_o = torch.zeros(x.shape[1])
    inference = inference.append_simulations(theta, x)

    posterior_estimator = inference.train(max_num_epochs=2)

    # Check for default device for inference object
    weights_device = next(inference._neural_net.parameters()).device
    assert torch.device(training_device) == weights_device

    _ = DirectPosterior(
        posterior_estimator=posterior_estimator, prior=prior
    ).set_default_x(x_o)
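A condensed sketch of the device split this test checks, assuming CUDA is available (as the test asserts) and using SNPE as a representative method:

import torch
from sbi.inference import SNPE
from sbi.utils import BoxUniform

prior = BoxUniform(-torch.ones(2), torch.ones(2), device="cuda:0")
inference = SNPE(prior, device="cuda:0", show_progress_bars=False)

theta = prior.sample((32,))
x = theta + 0.1 * torch.randn_like(theta)  # stand-in for a simulator
theta, x = theta.to("cpu"), x.to("cpu")    # data may live on another device

# append_simulations/train move batches to the training device internally,
# so the trained weights end up on the device passed to the constructor.
_ = inference.append_simulations(theta, x).train(max_num_epochs=2)
assert next(inference._neural_net.parameters()).device == torch.device("cuda:0")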
Code example #6
File: linearGaussian_snpe_test.py Project: bkmi/sbi
def test_example_posterior(snpe_method: type):
    """Return an inferred `NeuralPosterior` for interactive examination."""
    num_dim = 2
    x_o = zeros(1, num_dim)

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    prior_mean = zeros(num_dim)
    prior_cov = eye(num_dim)
    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)

    if snpe_method == SNPE_A:
        extra_kwargs = dict(final_round=True)
    else:
        extra_kwargs = dict()

    simulator, prior = prepare_for_sbi(
        lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov),
        prior)
    inference = snpe_method(prior, show_progress_bars=False)

    theta, x = simulate_for_sbi(simulator,
                                prior,
                                1000,
                                simulation_batch_size=10,
                                num_workers=6)
    posterior_estimator = inference.append_simulations(theta,
                                                       x).train(**extra_kwargs)
    if snpe_method == SNPE_A:
        posterior_estimator = inference.correct_for_proposal()
    posterior = DirectPosterior(
        prior=prior,
        posterior_estimator=posterior_estimator).set_default_x(x_o)
    assert posterior is not None
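The SNPE_A branch above differs from the other methods in one extra step; a sketch isolating it, assuming the same theta, x, and prior:

# SNPE_A trains a proposal-aware estimator; after the final round it is
# analytically corrected for the proposal before building the posterior.
inference = SNPE_A(prior, show_progress_bars=False)
_ = inference.append_simulations(theta, x).train(final_round=True)
estimator = inference.correct_for_proposal()
posterior = DirectPosterior(posterior_estimator=estimator, prior=prior)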
Code example #7
File: inference_on_device_test.py Project: bkmi/sbi
def test_training_and_mcmc_on_device(
    method,
    model,
    data_device,
    mcmc_method,
    training_device,
    prior_device,
    prior_type="gaussian",
):
    """Test training on devices.

    This test does not check training speeds.

    """

    num_dim = 2
    num_samples = 10
    num_simulations = 100
    max_num_epochs = 5

    x_o = zeros(1, num_dim).to(data_device)
    likelihood_shift = -1.0 * ones(num_dim).to(prior_device)
    likelihood_cov = 0.3 * eye(num_dim).to(prior_device)

    if prior_type == "gaussian":
        prior_mean = zeros(num_dim).to(prior_device)
        prior_cov = eye(num_dim).to(prior_device)
        prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
    else:
        prior = BoxUniform(
            low=-2 * torch.ones(num_dim),
            high=2 * torch.ones(num_dim),
            device=prior_device,
        )

    def simulator(theta):
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    training_device = process_device(training_device)

    if method in [SNPE_A, SNPE_C]:
        kwargs = dict(
            density_estimator=utils.posterior_nn(model=model, num_transforms=2)
        )
    elif method == SNLE:
        kwargs = dict(
            density_estimator=utils.likelihood_nn(model=model, num_transforms=2)
        )
    elif method in (SNRE_A, SNRE_B):
        kwargs = dict(classifier=utils.classifier_nn(model=model))
    else:
        raise ValueError()

    inferer = method(show_progress_bars=False, device=training_device, **kwargs)

    proposals = [prior]

    # Test for two rounds.
    for _ in range(2):
        theta, x = simulate_for_sbi(simulator, proposals[-1], num_simulations)
        theta, x = theta.to(data_device), x.to(data_device)

        estimator = inferer.append_simulations(theta, x).train(
            training_batch_size=100, max_num_epochs=max_num_epochs
        )
        if method == SNLE:
            potential_fn, theta_transform = likelihood_estimator_based_potential(
                estimator, prior, x_o
            )
        elif method == SNPE_A or method == SNPE_C:
            potential_fn, theta_transform = posterior_estimator_based_potential(
                estimator, prior, x_o
            )
        elif method == SNRE_A or method == SNRE_B:
            potential_fn, theta_transform = ratio_estimator_based_potential(
                estimator, prior, x_o
            )
        else:
            raise ValueError

        if mcmc_method == "rejection":
            posterior = RejectionPosterior(
                proposal=prior,
                potential_fn=potential_fn,
                device=training_device,
            )
        elif mcmc_method == "direct":
            posterior = DirectPosterior(
                posterior_estimator=estimator, prior=prior
            ).set_default_x(x_o)
        else:
            posterior = MCMCPosterior(
                potential_fn=potential_fn,
                theta_transform=theta_transform,
                proposal=prior,
                method=mcmc_method,
                device=training_device,
            )
        proposals.append(posterior)

    # Check for default device for inference object
    weights_device = next(inferer._neural_net.parameters()).device
    assert torch.device(training_device) == weights_device
    samples = proposals[-1].sample(sample_shape=(num_samples,))
    proposals[-1].potential(samples)
Code example #8
File: linearGaussian_snpe_test.py Project: bkmi/sbi
def test_c2st_snpe_on_linearGaussian(snpe_method, num_dim: int, prior_str: str,
                                     num_trials: int):
    """Test whether SNPE infers well a simple example with available ground truth."""

    x_o = zeros(num_trials, num_dim)
    num_samples = 1000
    num_simulations = 2600

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    if prior_str == "gaussian":
        prior_mean = zeros(num_dim)
        prior_cov = eye(num_dim)
        prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
        gt_posterior = true_posterior_linear_gaussian_mvn_prior(
            x_o, likelihood_shift, likelihood_cov, prior_mean, prior_cov)
        target_samples = gt_posterior.sample((num_samples, ))
    else:
        prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim))
        target_samples = samples_true_posterior_linear_gaussian_uniform_prior(
            x_o,
            likelihood_shift,
            likelihood_cov,
            prior=prior,
            num_samples=num_samples)

    simulator, prior = prepare_for_sbi(
        lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov),
        prior)

    inference = snpe_method(prior, show_progress_bars=False)

    theta, x = simulate_for_sbi(simulator,
                                prior,
                                num_simulations,
                                simulation_batch_size=1000)
    posterior_estimator = inference.append_simulations(
        theta, x).train(training_batch_size=100)
    posterior = DirectPosterior(
        prior=prior,
        posterior_estimator=posterior_estimator).set_default_x(x_o)
    samples = posterior.sample((num_samples, ))

    # Compute the c2st and assert it is near chance level of 0.5.
    check_c2st(samples, target_samples, alg="snpe_c")

    map_ = posterior.map(num_init_samples=1_000, show_progress_bars=False)

    # Checks for log_prob()
    if prior_str == "gaussian":
        # For the Gaussian prior, we compute the KL divergence between the
        # ground-truth posterior and the inferred posterior.
        dkl = get_dkl_gaussian_prior(posterior, x_o[0], likelihood_shift,
                                     likelihood_cov, prior_mean, prior_cov)

        max_dkl = 0.15

        assert (
            dkl < max_dkl
        ), f"D-KL={dkl} is more than 2 stds above the average performance."

        assert ((map_ - gt_posterior.mean)**2).sum() < 0.5

    elif prior_str == "uniform":
        # Check whether the returned probability outside of the support is zero.
        posterior_prob = get_prob_outside_uniform_prior(
            posterior, prior, num_dim)
        assert (
            posterior_prob == 0.0
        ), "The posterior probability outside of the prior support is not zero"

        # Check whether normalization (i.e. scaling up the density due
        # to leakage into regions without prior support) scales up the density by the
        # correct factor.
        (
            posterior_likelihood_unnorm,
            posterior_likelihood_norm,
            acceptance_prob,
        ) = get_normalization_uniform_prior(posterior, prior, x=x_o)
        # The acceptance probability should be *exactly* the ratio of the unnormalized
        # and the normalized likelihood. However, we allow for an error margin of 1%,
        # since the estimation of the acceptance probability is random (based on
        # rejection sampling).
        assert (
            acceptance_prob * 0.99 < posterior_likelihood_unnorm /
            posterior_likelihood_norm < acceptance_prob * 1.01
        ), "Normalizing the posterior density using the acceptance probability failed."

        assert ((map_ - ones(num_dim))**2).sum() < 0.5
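The normalization check at the end encodes a simple identity; restated as a sketch (plain arithmetic, not sbi API):

# With rejection sampling against the prior support, the normalization
# constant of the leaky posterior equals the acceptance probability:
#     p_norm(theta | x) = p_unnorm(theta | x) / acceptance_prob
# The assertion above therefore bounds the ratio within 1% Monte-Carlo error.
ratio = posterior_likelihood_unnorm / posterior_likelihood_norm
assert 0.99 < ratio / acceptance_prob < 1.01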
Code example #9
File: linearGaussian_snpe_test.py Project: bkmi/sbi
def test_sample_conditional():
    """
    Test whether sampling from the conditional gives the same results as evaluating.

    This compares samples that get smoothed with a Gaussian kde to evaluating the
    conditional log-probability with `eval_conditional_density`.

    `eval_conditional_density` is itself tested in `sbiutils_test.py`. Here, we use
    a bimodal posterior to test the conditional.
    """

    num_dim = 3
    dim_to_sample_1 = 0
    dim_to_sample_2 = 2

    x_o = zeros(1, num_dim)

    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.1 * eye(num_dim)

    prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim))

    def simulator(theta):
        if torch.rand(1) > 0.5:
            return linear_gaussian(theta, likelihood_shift, likelihood_cov)
        else:
            return linear_gaussian(theta, -likelihood_shift, likelihood_cov)

    # Test whether SNPE works properly with structured z-scoring.
    net = utils.posterior_nn("maf", z_score_x="structured", hidden_features=20)

    simulator, prior = prepare_for_sbi(simulator, prior)

    inference = SNPE_C(prior, density_estimator=net, show_progress_bars=False)

    # We need a pretty big dataset to properly model the bimodality.
    theta, x = simulate_for_sbi(simulator, prior, 10000)
    posterior_estimator = inference.append_simulations(
        theta, x).train(max_num_epochs=60)

    posterior = DirectPosterior(
        prior=prior,
        posterior_estimator=posterior_estimator).set_default_x(x_o)
    samples = posterior.sample((50, ))

    # Evaluate the conditional density by drawing samples and smoothing with a
    # Gaussian kde.
    potential_fn, theta_transform = posterior_estimator_based_potential(
        posterior_estimator, prior=prior, x_o=x_o)
    (
        conditioned_potential_fn,
        restricted_tf,
        restricted_prior,
    ) = conditonal_potential(
        potential_fn=potential_fn,
        theta_transform=theta_transform,
        prior=prior,
        condition=samples[0],
        dims_to_sample=[dim_to_sample_1, dim_to_sample_2],
    )
    mcmc_posterior = MCMCPosterior(
        potential_fn=conditioned_potential_fn,
        theta_transform=restricted_tf,
        proposal=restricted_prior,
    )
    cond_samples = mcmc_posterior.sample((500, ))

    _ = analysis.pairplot(
        cond_samples,
        limits=[[-2, 2], [-2, 2], [-2, 2]],
        figsize=(2, 2),
        diag="kde",
        upper="kde",
    )

    limits = [[-2, 2], [-2, 2], [-2, 2]]

    density = gaussian_kde(cond_samples.numpy().T, bw_method="scott")

    X, Y = np.meshgrid(
        np.linspace(limits[0][0], limits[0][1], 50),
        np.linspace(limits[1][0], limits[1][1], 50),
    )
    positions = np.vstack([X.ravel(), Y.ravel()])
    sample_kde_grid = np.reshape(density(positions).T, X.shape)

    # Evaluate the conditional with eval_conditional_density.
    eval_grid = analysis.eval_conditional_density(
        posterior,
        condition=samples[0],
        dim1=dim_to_sample_1,
        dim2=dim_to_sample_2,
        limits=torch.tensor([[-2, 2], [-2, 2], [-2, 2]]),
    )

    # Compare the two densities.
    sample_kde_grid = sample_kde_grid / np.sum(sample_kde_grid)
    eval_grid = eval_grid / torch.sum(eval_grid)

    error = np.abs(sample_kde_grid - eval_grid.numpy())

    max_err = np.max(error)
    assert max_err < 0.0027
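The KDE-on-a-grid comparison in the second half can be reproduced standalone; a self-contained sketch using only numpy and scipy, with samples_2d standing in for cond_samples:

import numpy as np
from scipy.stats import gaussian_kde

samples_2d = np.random.randn(500, 2)  # placeholder for the conditional samples
density = gaussian_kde(samples_2d.T, bw_method="scott")  # kde expects (dim, n)

X, Y = np.meshgrid(np.linspace(-2, 2, 50), np.linspace(-2, 2, 50))
positions = np.vstack([X.ravel(), Y.ravel()])
grid = density(positions).reshape(X.shape)
grid = grid / grid.sum()  # normalize so grids compare as probability masses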
Code example #10
File: linearGaussian_snpe_test.py Project: bkmi/sbi
def test_c2st_multi_round_snpe_on_linearGaussian(method_str: str):
    """Test whether SNPE B/C infer well a simple example with available ground truth.
    .
    """

    num_dim = 2
    x_o = zeros((1, num_dim))
    num_samples = 1000

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    prior_mean = zeros(num_dim)
    prior_cov = eye(num_dim)
    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)

    gt_posterior = true_posterior_linear_gaussian_mvn_prior(
        x_o, likelihood_shift, likelihood_cov, prior_mean, prior_cov)
    target_samples = gt_posterior.sample((num_samples, ))

    if method_str == "snpe_c_non_atomic":
        # Test whether SNPE works properly with structured z-scoring.
        density_estimator = utils.posterior_nn("mdn",
                                               z_score_x="structured",
                                               num_components=5)
        method_str = "snpe_c"
    elif method_str == "snpe_a":
        density_estimator = "mdn_snpe_a"
    else:
        density_estimator = "maf"

    simulator, prior = prepare_for_sbi(
        lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov),
        prior)
    creation_args = dict(
        prior=prior,
        density_estimator=density_estimator,
        show_progress_bars=False,
    )

    if method_str == "snpe_b":
        inference = SNPE_B(**creation_args)
        theta, x = simulate_for_sbi(simulator,
                                    prior,
                                    500,
                                    simulation_batch_size=10)
        posterior_estimator = inference.append_simulations(theta, x).train()
        posterior1 = DirectPosterior(
            prior=prior,
            posterior_estimator=posterior_estimator).set_default_x(x_o)
        theta, x = simulate_for_sbi(simulator,
                                    posterior1,
                                    1000,
                                    simulation_batch_size=10)
        posterior_estimator = inference.append_simulations(
            theta, x, proposal=posterior1).train()
        posterior = DirectPosterior(
            prior=prior,
            posterior_estimator=posterior_estimator).set_default_x(x_o)
    elif method_str == "snpe_c":
        inference = SNPE_C(**creation_args)
        theta, x = simulate_for_sbi(simulator,
                                    prior,
                                    900,
                                    simulation_batch_size=50)
        posterior_estimator = inference.append_simulations(theta, x).train()
        posterior1 = DirectPosterior(
            prior=prior,
            posterior_estimator=posterior_estimator).set_default_x(x_o)
        theta = posterior1.sample((1000, ))
        x = simulator(theta)
        _ = inference.append_simulations(theta, x, proposal=posterior1).train()
        posterior = inference.build_posterior().set_default_x(x_o)
    elif method_str == "snpe_a":
        inference = SNPE_A(**creation_args)
        proposal = prior
        final_round = False
        num_rounds = 3
        for r in range(num_rounds):
            if r == 2:
                final_round = True
            theta, x = simulate_for_sbi(simulator,
                                        proposal,
                                        500,
                                        simulation_batch_size=50)
            inference = inference.append_simulations(theta,
                                                     x,
                                                     proposal=proposal)
            _ = inference.train(max_num_epochs=200, final_round=final_round)
            posterior = inference.build_posterior().set_default_x(x_o)
            proposal = posterior

    samples = posterior.sample((num_samples, ))

    # Compute the c2st and assert it is near chance level of 0.5.
    check_c2st(samples, target_samples, alg=method_str)
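All three branches follow the same multi-round loop; a condensed sketch of the shared pattern (variant-specific kwargs such as final_round omitted):

proposal = prior
for _ in range(num_rounds):
    # Simulate from the current proposal, train on the accumulated data,
    # and use the posterior at x_o as the next round's proposal.
    theta, x = simulate_for_sbi(simulator, proposal, 500)
    _ = inference.append_simulations(theta, x, proposal=proposal).train()
    posterior = inference.build_posterior().set_default_x(x_o)
    proposal = posterior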
Code example #11
File: linearGaussian_snpe_test.py Project: bkmi/sbi
def test_c2st_snpe_on_linearGaussian_different_dims():
    """Test whether SNPE B/C infer well a simple example with available ground truth.

    This example has different number of parameters theta than number of x. Also
    this implicitly tests simulation_batch_size=1. It also impleictly tests whether the
    prior can be `None` and whether we can stop and resume training.

    """

    theta_dim = 3
    x_dim = 2
    discard_dims = theta_dim - x_dim

    x_o = zeros(1, x_dim)
    num_samples = 1000

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(x_dim)
    likelihood_cov = 0.3 * eye(x_dim)

    prior_mean = zeros(theta_dim)
    prior_cov = eye(theta_dim)
    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
    target_samples = samples_true_posterior_linear_gaussian_mvn_prior_different_dims(
        x_o,
        likelihood_shift,
        likelihood_cov,
        prior_mean,
        prior_cov,
        num_discarded_dims=discard_dims,
        num_samples=num_samples,
    )

    simulator, prior = prepare_for_sbi(
        lambda theta: linear_gaussian(theta,
                                      likelihood_shift,
                                      likelihood_cov,
                                      num_discarded_dims=discard_dims),
        prior,
    )
    # Test whether prior can be `None`.
    inference = SNPE_C(prior=None,
                       density_estimator="maf",
                       show_progress_bars=False)

    # type: ignore
    theta, x = simulate_for_sbi(simulator,
                                prior,
                                2000,
                                simulation_batch_size=1)

    inference = inference.append_simulations(theta, x)
    posterior_estimator = inference.train(
        max_num_epochs=10)  # Test whether we can stop and resume.
    posterior_estimator = inference.train(resume_training=True,
                                          force_first_round_loss=True)
    posterior = DirectPosterior(
        prior=prior,
        posterior_estimator=posterior_estimator).set_default_x(x_o)
    samples = posterior.sample((num_samples, ))

    # Compute the c2st and assert it is near chance level of 0.5.
    check_c2st(samples, target_samples, alg="snpe_c")
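The stop-and-resume behavior the docstring mentions boils down to two train calls; a sketch, assuming the inference object with appended simulations from above:

# The first call stops early via max_num_epochs; the second continues from
# the stored training state instead of reinitializing the network.
_ = inference.train(max_num_epochs=10)
estimator = inference.train(resume_training=True, force_first_round_loss=True)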