Example 1
    def __init__(
        self,
        prior,
        classifier: Union[str, Callable] = "resnet",
        device: str = "cpu",
        logging_level: Union[int, str] = "warning",
        summary_writer: Optional[SummaryWriter] = None,
        show_progress_bars: bool = True,
        **unused_args
    ):
        r"""Sequential Neural Ratio Estimation.

        We implement two inference methods in the respective subclasses.

        - SNRE_A / AALR is limited to `num_atoms=2`, but allows for density evaluation
          when training for one round.
        - SNRE_B / SRE can use more than two atoms, potentially boosting performance,
          but allows for posterior evaluation **only up to a normalizing constant**,
          even when training only one round.

        Args:
            classifier: Classifier trained to approximate likelihood ratios. If it is
                a string, use a pre-configured network of the provided type (one of
                linear, mlp, resnet). Alternatively, a function that builds a custom
                neural network can be provided. The function will be called with the
                first batch of simulations (theta, x), which can thus be used for shape
                inference and potentially for z-scoring. It needs to return a PyTorch
                `nn.Module` implementing the classifier.
            unused_args: Absorbs additional arguments. No entries will be used. If it
                is not empty, we warn. In future versions, when the new interface of
                0.14.0 is more mature, we will remove this argument.

        See docstring of `NeuralInference` class for all other arguments.
        """

        super().__init__(
            prior=prior,
            device=device,
            logging_level=logging_level,
            summary_writer=summary_writer,
            show_progress_bars=show_progress_bars,
            **unused_args
        )

        # As detailed in the docstring, `density_estimator` is either a string or
        # a callable. The function creating the neural network is attached to
        # `_build_neural_net`. It will be called in the first round and receive
        # thetas and xs as inputs, so that they can be used for shape inference and
        # potentially for z-scoring.
        check_estimator_arg(classifier)
        if isinstance(classifier, str):
            self._build_neural_net = utils.classifier_nn(model=classifier)
        else:
            self._build_neural_net = classifier

        # Ratio-based-specific summary_writer fields.
        self._summary.update({"mcmc_times": []})  # type: ignore
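The docstring above notes that `classifier` may be a callable instead of a string: it is invoked with the first batch of simulations and must return a PyTorch `nn.Module`. A minimal sketch of such a builder, assuming (as in this version of sbi) that the returned network is applied to the concatenated `(theta, x)` batch; the architecture itself is illustrative, not from the library:

import torch
from torch import nn


def build_custom_classifier(theta: torch.Tensor, x: torch.Tensor) -> nn.Module:
    # Called with the first batch, so input dimensions can be inferred from
    # the data (z-scoring statistics could also be computed here).
    input_dim = theta.shape[1] + x.shape[1]
    return nn.Sequential(
        nn.Linear(input_dim, 50),
        nn.ReLU(),
        nn.Linear(50, 1),  # single logit approximating the likelihood ratio
    )


# Usage (assuming a `prior` as in the tests below):
# inference = SNRE(prior, classifier=build_custom_classifier)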
Example 2
def test_training_and_mcmc_on_device(method, model, device):
    """Test training on devices.

    This test does not check training speeds.
    """
    device = process_device(device)

    num_dim = 2
    num_samples = 10
    num_simulations = 500
    max_num_epochs = 5

    x_o = zeros(1, num_dim)
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    prior_mean = zeros(num_dim)
    prior_cov = eye(num_dim)
    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)

    def simulator(theta):
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    if method == SNPE:
        kwargs = dict(density_estimator=utils.posterior_nn(model=model), )
        mcmc_kwargs = dict(
            sample_with_mcmc=True,
            mcmc_method="slice_np",
        )
    elif method == SNLE:
        kwargs = dict(density_estimator=utils.likelihood_nn(model=model), )
        mcmc_kwargs = dict(mcmc_method="slice")
    elif method == SNRE:
        kwargs = dict(classifier=utils.classifier_nn(model=model), )
        mcmc_kwargs = dict(mcmc_method="slice_np_vectorized", )
    else:
        raise ValueError()

    inferer = method(prior, show_progress_bars=False, device=device, **kwargs)

    proposals = [prior]

    # Test for two rounds.
    for _ in range(2):
        theta, x = simulate_for_sbi(simulator,
                                    proposal=prior,
                                    num_simulations=num_simulations)
        _ = inferer.append_simulations(theta,
                                       x).train(training_batch_size=100,
                                                max_num_epochs=max_num_epochs)
        posterior = inferer.build_posterior(**mcmc_kwargs).set_default_x(x_o)
        proposals.append(posterior)

    proposals[-1].sample(sample_shape=(num_samples, ), x=x_o, **mcmc_kwargs)
Example 3
def test_inference_with_2d_x(embedding, method):

    num_dim = 2
    num_samples = 10
    num_simulations = 100

    prior = utils.BoxUniform(zeros(num_dim), torch.ones(num_dim))

    simulator, prior = prepare_for_sbi(simulator_2d, prior)

    theta_o = torch.ones(1, num_dim)
    x_o = simulator(theta_o)

    if method == SNPE:
        kwargs = dict(
            density_estimator=utils.posterior_nn(
                model="mdn",
                embedding_net=embedding(),
            ),
            sample_with_mcmc=True,
        )
    elif method == SNLE:
        kwargs = dict(density_estimator=utils.likelihood_nn(
            model="mdn", embedding_net=embedding()))
    else:
        kwargs = dict(density_estimator=utils.classifier_nn(
            model="mlp",
            embedding_net_x=embedding(),
        ))

    infer = method(
        simulator,
        prior,
        1,
        1,
        show_progress_bars=False,
        mcmc_method="slice_np",
        **kwargs,
    )

    posterior = infer(num_simulations=num_simulations,
                      training_batch_size=100,
                      max_num_epochs=10).set_default_x(x_o)

    posterior.log_prob(
        posterior.sample((num_samples, ), show_progress_bars=False))
Example 4
def test_inference_with_2d_x(embedding, method):

    num_dim = 2
    num_samples = 10
    num_simulations = 100

    prior = utils.BoxUniform(zeros(num_dim), torch.ones(num_dim))

    simulator, prior = prepare_for_sbi(simulator_2d, prior)

    theta_o = torch.ones(1, num_dim)

    if method == SNPE:
        net_provider = utils.posterior_nn(
            model="mdn",
            embedding_net=embedding(),
        )
        sample_kwargs = {"sample_with_mcmc": True}
        num_trials = 1
    elif method == SNLE:
        net_provider = utils.likelihood_nn(model="mdn",
                                           embedding_net=embedding())
        sample_kwargs = {}
        num_trials = 2
    else:
        net_provider = utils.classifier_nn(
            model="mlp",
            embedding_net_x=embedding(),
        )
        sample_kwargs = {
            "mcmc_method": "slice_np_vectorized",
            "mcmc_parameters": {
                "num_chains": 2
            },
        }
        num_trials = 2

    inference = method(prior, net_provider, show_progress_bars=False)
    theta, x = simulate_for_sbi(simulator, prior, num_simulations)
    _ = inference.append_simulations(theta, x).train(training_batch_size=100,
                                                     max_num_epochs=10)
    x_o = simulator(theta_o.repeat(num_trials, 1))
    posterior = inference.build_posterior(**sample_kwargs).set_default_x(x_o)

    posterior.log_prob(
        posterior.sample((num_samples, ), show_progress_bars=False))
Example 5
def test_embedding_net_api(method, num_dim: int, embedding_net: str):
    """Tests the API when using a preconfigured embedding net."""

    x_o = zeros(1, num_dim)

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim))

    theta = prior.sample((1000, ))
    x = linear_gaussian(theta, likelihood_shift, likelihood_cov)

    if embedding_net == "mlp":
        embedding = FCEmbedding(input_dim=num_dim)
    else:
        raise NameError

    if method == "SNPE":
        density_estimator = posterior_nn("maf", embedding_net=embedding)
        inference = SNPE(prior,
                         density_estimator=density_estimator,
                         show_progress_bars=False)
    elif method == "SNLE":
        density_estimator = likelihood_nn("maf", embedding_net=embedding)
        inference = SNLE(prior,
                         density_estimator=density_estimator,
                         show_progress_bars=False)
    elif method == "SNRE":
        classifier = classifier_nn("resnet", embedding_net_x=embedding)
        inference = SNRE(prior,
                         classifier=classifier,
                         show_progress_bars=False)
    else:
        raise NameError

    _ = inference.append_simulations(theta, x).train(max_num_epochs=5)
    posterior = inference.build_posterior().set_default_x(x_o)

    s = posterior.sample((1, ))
    _ = posterior.potential(s)
Example 6
def test_sre_on_linearGaussian_api(num_dim: int):
    """Test inference API of SRE with linear Gaussian model. 

    Avoids intense computation for fast testing of API etc. 
    
    Args:
        num_dim: parameter dimension of the Gaussian model
    """

    simulator = linear_gaussian
    prior = distributions.MultivariateNormal(
        loc=torch.zeros(num_dim), covariance_matrix=torch.eye(num_dim))

    # XXX this breaks the test! (and #76 doesn't seem to fix)
    # true_observation = torch.zeros(1, num_dim)

    true_observation = torch.zeros(num_dim)

    classifier = utils.classifier_nn(
        "resnet",
        prior=prior,
        context=true_observation,
    )

    infer = SRE(
        simulator=simulator,
        prior=prior,
        true_observation=true_observation,
        classifier=classifier,
        simulation_batch_size=50,
        mcmc_method="slice-np",
    )

    posterior = infer(num_rounds=1, num_simulations_per_round=1000)

    samples = posterior.sample(num_samples=10, num_chains=2)
Example 7
def test_sre_posterior_correction(mcmc_method: str, prior_str: str):
    """Test leakage correction both for MCMC and rejection sampling.

    Args:
        mcmc_method: which MCMC method to use for sampling
        prior_str: one of "gaussian" or "uniform"
    """

    num_dim = 2
    if prior_str == "gaussian":
        prior = distributions.MultivariateNormal(
            loc=torch.zeros(num_dim), covariance_matrix=torch.eye(num_dim))
    else:
        prior = utils.BoxUniform(low=-1.0 * torch.ones(num_dim),
                                 high=torch.ones(num_dim))

    true_observation = torch.zeros(num_dim)

    classifier = utils.classifier_nn(
        "resnet",
        prior=prior,
        context=true_observation,
    )

    infer = SRE(
        simulator=linear_gaussian,
        prior=prior,
        true_observation=true_observation,
        classifier=classifier,
        simulation_batch_size=50,
        mcmc_method=mcmc_method,
    )

    posterior = infer(num_rounds=1, num_simulations_per_round=1000)

    samples = posterior.sample(num_samples=50)
Example 8
    def __init__(
        self,
        simulator: Callable,
        prior,
        num_workers: int = 1,
        simulation_batch_size: int = 1,
        classifier: Union[str, Callable] = "resnet",
        mcmc_method: str = "slice_np",
        mcmc_parameters: Optional[Dict[str, Any]] = None,
        device: str = "cpu",
        logging_level: Union[int, str] = "warning",
        summary_writer: Optional[SummaryWriter] = None,
        show_progress_bars: bool = True,
        show_round_summary: bool = False,
    ):
        r"""Sequential Neural Ratio Estimation.

        We implement two inference methods in the respective subclasses.

        - SNRE_A / AALR is limited to `num_atoms=2`, but allows for density evaluation
          when training for one round.
        - SNRE_B / SRE can use more than two atoms, potentially boosting performance,
          but allows for posterior evaluation **only up to a normalizing constant**,
          even when training only one round.

        Args:
            classifier: Classifier trained to approximate likelihood ratios. If it is
                a string, use a pre-configured network of the provided type (one of
                linear, mlp, resnet). Alternatively, a function that builds a custom
                neural network can be provided. The function will be called with the
                first batch of simulations (theta, x), which can thus be used for shape
                inference and potentially for z-scoring. It needs to return a PyTorch
                `nn.Module` implementing the classifier.
            mcmc_method: Method used for MCMC sampling, one of `slice_np`,
                `slice`, `hmc`, `nuts`. Currently defaults to `slice_np` for
                a custom numpy implementation of slice sampling; select
                `hmc`, `nuts` or `slice` for Pyro-based sampling.
            mcmc_parameters: Dictionary overriding the default parameters for MCMC.
                The following parameters are supported: `thin` to set the thinning
                factor for the chain, `warmup_steps` to set the initial number of
                samples to discard, `num_chains` for the number of chains, `init_strategy`
                for the initialisation strategy for chains; `prior` will draw init
                locations from prior, whereas `sir` will use Sequential-Importance-
                Resampling using `init_strategy_num_candidates` to find init
                locations.

        See docstring of `NeuralInference` class for all other arguments.
        """

        super().__init__(
            simulator=simulator,
            prior=prior,
            num_workers=num_workers,
            simulation_batch_size=simulation_batch_size,
            device=device,
            logging_level=logging_level,
            summary_writer=summary_writer,
            show_progress_bars=show_progress_bars,
            show_round_summary=show_round_summary,
        )

        # As detailed in the docstring, `density_estimator` is either a string or
        # a callable. The function creating the neural network is attached to
        # `_build_neural_net`. It will be called in the first round and receive
        # thetas and xs as inputs, so that they can be used for shape inference and
        # potentially for z-scoring.
        check_estimator_arg(classifier)
        if isinstance(classifier, str):
            self._build_neural_net = utils.classifier_nn(model=classifier)
        else:
            self._build_neural_net = classifier
        self._posterior = None
        self._sample_with_mcmc = True
        self._mcmc_method = mcmc_method
        self._mcmc_parameters = mcmc_parameters

        # Ratio-based-specific summary_writer fields.
        self._summary.update({"mcmc_times": []})  # type: ignore
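The `mcmc_parameters` argument documented above is a plain dictionary overriding MCMC defaults. A minimal sketch using the keys named in the docstring (the values are illustrative, and `simulator`/`prior` are assumed to exist as in the surrounding tests):

mcmc_parameters = dict(
    thin=10,  # thinning factor for each chain
    warmup_steps=50,  # initial number of samples to discard
    num_chains=2,  # number of independent chains
    init_strategy="sir",  # init via Sequential-Importance-Resampling
    init_strategy_num_candidates=1000,  # candidates used by "sir"
)

# inference = SNRE(simulator, prior, mcmc_method="slice_np",
#                  mcmc_parameters=mcmc_parameters)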
Example 9
def test_inference_with_2d_x(embedding, method):

    num_dim = 2
    num_samples = 10
    num_simulations = 100

    prior = utils.BoxUniform(zeros(num_dim), torch.ones(num_dim))

    simulator, prior = prepare_for_sbi(simulator_2d, prior)

    theta_o = torch.ones(1, num_dim)

    if method == SNPE:
        net_provider = utils.posterior_nn(
            model="mdn",
            embedding_net=embedding(),
        )
        num_trials = 1
    elif method == SNLE:
        net_provider = utils.likelihood_nn(model="mdn",
                                           embedding_net=embedding())
        num_trials = 2
    else:
        net_provider = utils.classifier_nn(
            model="mlp",
            z_score_theta="structured",  # Test that structured z-scoring works.
            embedding_net_x=embedding(),
        )
        num_trials = 2

    if method == SNRE:
        inference = method(classifier=net_provider, show_progress_bars=False)
    else:
        inference = method(density_estimator=net_provider,
                           show_progress_bars=False)
    theta, x = simulate_for_sbi(simulator, prior, num_simulations)
    estimator = inference.append_simulations(theta,
                                             x).train(training_batch_size=100,
                                                      max_num_epochs=10)
    x_o = simulator(theta_o.repeat(num_trials, 1))

    if method == SNLE:
        potential_fn, theta_transform = likelihood_estimator_based_potential(
            estimator, prior, x_o)
    elif method == SNPE:
        potential_fn, theta_transform = posterior_estimator_based_potential(
            estimator, prior, x_o)
    elif method == SNRE:
        potential_fn, theta_transform = ratio_estimator_based_potential(
            estimator, prior, x_o)
    else:
        raise NotImplementedError

    posterior = MCMCPosterior(
        potential_fn=potential_fn,
        theta_transform=theta_transform,
        proposal=prior,
        method="slice_np_vectorized",
        num_chains=2,
    )

    posterior.potential(
        posterior.sample((num_samples, ), show_progress_bars=False))
Example 10
File: sre.py Project: boyali/sbi
    def __init__(
        self,
        simulator: Callable,
        prior,
        true_observation: Tensor,
        classifier: nn.Module,
        num_atoms: int = -1,
        simulation_batch_size: int = 1,
        mcmc_method: str = "slice-np",
        summary_net: Optional[nn.Module] = None,
        classifier_loss: str = "sre",
        retrain_from_scratch_each_round: bool = False,
        summary_writer: Optional[SummaryWriter] = None,
        device: Optional[torch.device] = None,
    ):
        """Sequential Ratio Estimation

        As presented in _Likelihood-free MCMC with Amortized Approximate Likelihood Ratios_ by Hermans et al., Pre-print 2019, https://arxiv.org/abs/1903.04057

        See NeuralInference docstring for all other arguments.

        Args:
            classifier: Binary classifier.
            num_atoms: Number of atoms to use for classification. If -1, use
                all other parameters in the minibatch.
            retrain_from_scratch_each_round: Whether to retrain from scratch
                each round.
            summary_net: Optional network which may be used to produce feature
                vectors f(x) for high-dimensional observations.
            classifier_loss: `sre` implements the algorithm suggested in
                Durkan et al. 2019, whereas `aalr` implements the algorithm
                suggested in Hermans et al. 2019. `sre` can use more than two
                atoms, potentially boosting performance, but does not allow
                for exact posterior density evaluation (only up to a
                normalizing constant), even when training only one round.
                `aalr` is limited to `num_atoms=2`, but allows for density
                evaluation when training for one round.
        """

        super().__init__(
            simulator,
            prior,
            true_observation,
            simulation_batch_size,
            device,
            summary_writer,
        )

        self._classifier_loss = classifier_loss

        assert isinstance(num_atoms, int), "Number of atoms must be an integer."
        self._num_atoms = num_atoms

        if classifier is None:
            classifier = utils.classifier_nn(
                model="resnet", prior=self._prior, context=self._true_observation,
            )

        # create posterior object which can sample()
        self._neural_posterior = Posterior(
            algorithm_family=self._classifier_loss,
            neural_net=classifier,
            prior=prior,
            context=true_observation,
            mcmc_method=mcmc_method,
            get_potential_function=PotentialFunctionProvider(),
        )

        # XXX why not classifier.train(True)???
        self._neural_posterior.neural_net.train(True)

        # We may want to summarize high-dimensional observations.
        # This may be either a fixed or learned transformation.
        if summary_net is None:
            self._summary_net = nn.Identity()
        else:
            self._summary_net = summary_net

        # If we're retraining from scratch each round,
        # keep a copy of the original untrained model for reinitialization.
        self._retrain_from_scratch_each_round = retrain_from_scratch_each_round
        if self._retrain_from_scratch_each_round:
            self._untrained_classifier = deepcopy(classifier)
        else:
            self._untrained_classifier = None

        # SRE-specific summary_writer fields
        self._summary.update({"mcmc_times": []})
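Per the docstring above, `aalr` is limited to exactly two atoms, while `sre` can use all other parameters in the minibatch via `num_atoms=-1`; Example 11 below encodes the same rule as a one-liner. A sketch of the `aalr` configuration under this legacy API, assuming `linear_gaussian`, `prior`, and `torch` as in the surrounding tests:

classifier_loss = "aalr"
num_atoms = 2 if classifier_loss == "aalr" else -1  # -1: all other minibatch params

infer = SRE(
    simulator=linear_gaussian,
    prior=prior,
    true_observation=torch.zeros(2),
    classifier=None,  # falls back to the pre-configured "resnet" classifier above
    num_atoms=num_atoms,
    classifier_loss=classifier_loss,
)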
Example 11
def test_sre_on_linearGaussian_based_on_mmd(num_dim: int, prior_str: str,
                                            classifier_loss: str):
    """Test MMD accuracy of inference with SRE on linear Gaussian model. 

    NOTE: The mmd threshold is calculated based on a number of test runs and taking the mean plus 2 stds. 
    
    Args:
        num_dim: parameter dimension of the gaussian model
        prior_str: one of "gaussian" or "uniform"
    """

    true_observation = torch.zeros(num_dim)
    num_samples = 300

    if prior_str == "gaussian":
        prior = distributions.MultivariateNormal(
            loc=torch.zeros(num_dim), covariance_matrix=torch.eye(num_dim))
        target_samples = get_true_posterior_samples_linear_gaussian_mvn_prior(
            true_observation[None, ], num_samples=num_samples)
    else:
        prior = utils.BoxUniform(-1.0 * torch.ones(num_dim),
                                 torch.ones(num_dim))
        target_samples = get_true_posterior_samples_linear_gaussian_uniform_prior(
            true_observation[None, ], num_samples=num_samples, prior=prior)

    classifier = utils.classifier_nn(
        "resnet",
        prior=prior,
        context=true_observation,
    )

    num_atoms = 2 if classifier_loss == "aalr" else -1

    infer = SRE(
        simulator=linear_gaussian,
        prior=prior,
        true_observation=true_observation,
        num_atoms=num_atoms,
        classifier=classifier,
        classifier_loss=classifier_loss,
        simulation_batch_size=50,
        mcmc_method="slice-np",
    )

    posterior = infer(num_rounds=1, num_simulations_per_round=1000)

    samples = posterior.sample(num_samples=num_samples)

    # Check if mmd is larger than expected.
    mmd = utils.unbiased_mmd_squared(target_samples, samples)
    max_mmd = 0.045
    assert (mmd < max_mmd
            ), f"MMD={mmd} is more than 2 stds above the average performance."

    # Checks for log_prob()
    if prior_str == "gaussian" and classifier_loss == "aalr":
        # For the Gaussian prior, we compute the KLd between ground truth and
        # posterior. We can do this only if the classifier_loss was as
        # described in Hermans et al. 2019 ('aalr'), since the Durkan et al.
        # 2019 version only allows evaluation up to a constant.
        dkl = test_utils.get_dkl_gaussian_prior(posterior, true_observation,
                                                num_dim)

        max_dkl = 0.05 if num_dim == 1 else 0.8

        assert (
            dkl < max_dkl
        ), f"KLd={dkl} is more than 2 stds above the average performance."
    if prior_str == "uniform":
        # Check whether the returned probability outside of the support is zero.
        posterior_prob = test_utils.get_prob_outside_uniform_prior(
            posterior, num_dim)
        assert (
            posterior_prob == 0.0
        ), "The posterior probability outside of the prior support is not zero"
Example 12
def test_training_and_mcmc_on_device(
    method,
    model,
    data_device,
    mcmc_method,
    training_device,
    prior_device,
    prior_type="gaussian",
):
    """Test training on devices.

    This test does not check training speeds.
    """

    num_dim = 2
    num_samples = 10
    num_simulations = 100
    max_num_epochs = 5

    x_o = zeros(1, num_dim).to(data_device)
    likelihood_shift = -1.0 * ones(num_dim).to(prior_device)
    likelihood_cov = 0.3 * eye(num_dim).to(prior_device)

    if prior_type == "gaussian":
        prior_mean = zeros(num_dim).to(prior_device)
        prior_cov = eye(num_dim).to(prior_device)
        prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
    else:
        prior = BoxUniform(
            low=-2 * torch.ones(num_dim),
            high=2 * torch.ones(num_dim),
            device=prior_device,
        )

    def simulator(theta):
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    training_device = process_device(training_device)

    if method in [SNPE_A, SNPE_C]:
        kwargs = dict(
            density_estimator=utils.posterior_nn(model=model, num_transforms=2)
        )
    elif method == SNLE:
        kwargs = dict(
            density_estimator=utils.likelihood_nn(model=model, num_transforms=2)
        )
    elif method in (SNRE_A, SNRE_B):
        kwargs = dict(classifier=utils.classifier_nn(model=model))
    else:
        raise ValueError()

    inferer = method(show_progress_bars=False, device=training_device, **kwargs)

    proposals = [prior]

    # Test for two rounds.
    for _ in range(2):
        theta, x = simulate_for_sbi(simulator, proposals[-1], num_simulations)
        theta, x = theta.to(data_device), x.to(data_device)

        estimator = inferer.append_simulations(theta, x).train(
            training_batch_size=100, max_num_epochs=max_num_epochs
        )
        if method == SNLE:
            potential_fn, theta_transform = likelihood_estimator_based_potential(
                estimator, prior, x_o
            )
        elif method == SNPE_A or method == SNPE_C:
            potential_fn, theta_transform = posterior_estimator_based_potential(
                estimator, prior, x_o
            )
        elif method == SNRE_A or method == SNRE_B:
            potential_fn, theta_transform = ratio_estimator_based_potential(
                estimator, prior, x_o
            )
        else:
            raise ValueError

        if mcmc_method == "rejection":
            posterior = RejectionPosterior(
                proposal=prior,
                potential_fn=potential_fn,
                device=training_device,
            )
        elif mcmc_method == "direct":
            posterior = DirectPosterior(
                posterior_estimator=estimator, prior=prior
            ).set_default_x(x_o)
        else:
            posterior = MCMCPosterior(
                potential_fn=potential_fn,
                theta_transform=theta_transform,
                proposal=prior,
                method=mcmc_method,
                device=training_device,
            )
        proposals.append(posterior)

    # Check for default device for inference object
    weights_device = next(inferer._neural_net.parameters()).device
    assert torch.device(training_device) == weights_device
    samples = proposals[-1].sample(sample_shape=(num_samples,))
    proposals[-1].potential(samples)
Example 13
def test_training_and_mcmc_on_device(method, model, data_device, training_device):
    """Test training on devices.

    This test does not check training speeds.
    """
    training_device = process_device(training_device)

    num_dim = 2
    num_samples = 10
    num_simulations = 500
    max_num_epochs = 5

    x_o = zeros(1, num_dim).to(data_device)
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)

    prior_mean = zeros(num_dim)
    prior_cov = eye(num_dim)
    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)

    def simulator(theta):
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    if method in [SNPE_A, SNPE_C]:
        kwargs = dict(
            density_estimator=utils.posterior_nn(model=model),
        )
        mcmc_kwargs = (
            dict(
                sample_with="mcmc",
                mcmc_method="slice_np",
            )
            if method == SNPE_C
            else {}
        )
    elif method == SNLE:
        kwargs = dict(
            density_estimator=utils.likelihood_nn(model=model),
        )
        mcmc_kwargs = dict(sample_with="mcmc", mcmc_method="slice")
    elif method in (SNRE_A, SNRE_B):
        kwargs = dict(
            classifier=utils.classifier_nn(model=model),
        )
        mcmc_kwargs = dict(
            sample_with="mcmc",
            mcmc_method="slice_np_vectorized",
        )
    else:
        raise ValueError()

    inferer = method(prior, show_progress_bars=False, device=training_device, **kwargs)

    proposals = [prior]

    # Test for two rounds.
    for _ in range(2):
        theta, x = simulate_for_sbi(simulator, prior, num_simulations)
        theta, x = theta.to(data_device), x.to(data_device)

        _ = inferer.append_simulations(theta, x).train(
            training_batch_size=100, max_num_epochs=max_num_epochs
        )
        posterior = inferer.build_posterior(**mcmc_kwargs).set_default_x(x_o)
        proposals.append(posterior)

    # Check for default device for inference object
    weights_device = next(inferer._neural_net.parameters()).device
    assert torch.device(training_device) == weights_device
    samples = proposals[-1].sample(sample_shape=(num_samples,), x=x_o, **mcmc_kwargs)
    proposals[-1].log_prob(samples)