예제 #1
0
파일: test_utils.py 프로젝트: jyotikab/sbi
def get_normalization_uniform_prior(
    posterior: DirectPosterior,
    prior: Distribution,
    true_observation: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Evaluate the posterior density at one prior sample, with and without
    normalization, and query the posterior's leakage correction.

    Note (review): `log_prob` is called without an explicit `x`, so it presumably
    falls back to an observation stored on the posterior, whereas the leakage
    correction is evaluated at `true_observation` - confirm both refer to the
    same observation.

    Args:
        posterior: estimated posterior
        prior: prior distribution from which the evaluation point is drawn
        true_observation: observation passed to `posterior.leakage_correction`

    Returns:
        Tuple of (unnormalized posterior density, normalized posterior density,
        value returned by the leakage correction, used as acceptance probability).
    """
    # A single draw from the prior serves as the evaluation point for both densities.
    theta = prior.sample()

    # First the raw density (norm_posterior=False), then the normalized one
    # (norm_posterior=True); the generator keeps this evaluation order.
    density_unnorm, density_norm = (
        torch.exp(posterior.log_prob(theta, norm_posterior=normalize))
        for normalize in (False, True)
    )

    # The leakage correction at the given observation is reported as acceptance probability.
    acceptance = posterior.leakage_correction(x=true_observation)

    return density_unnorm, density_norm, acceptance
예제 #2
0
    def eval_posterior(
        posterior: DirectPosterior,
        data_real: to.Tensor,
        num_samples: int,
        calculate_log_probs: bool = True,
        normalize_posterior: bool = True,
        subrtn_sbi_sampling_hparam: Optional[dict] = None,
    ) -> Tuple[to.Tensor, Optional[to.Tensor]]:
        r"""
        Evaluates the posterior by sampling domain parameters given observed data and, if desired, computing the
        log-probabilities of these samples.

        :param posterior: posterior to evaluate, e.g. a normalizing flow, that samples domain parameters conditioned on
                          the provided data
        :param data_real: data from the real-world rollouts a.k.a. set of $x_o$ of shape
                          [num_iter, num_rollouts_per_iter * dim_feat]
        :param num_samples: number of samples to draw from the posterior
        :param calculate_log_probs: if `True`, the log-probabilities are computed, else `None` is returned
        :param normalize_posterior: if `True`, the normalization of the posterior density is enforced by sbi
        :param subrtn_sbi_sampling_hparam: keyword arguments forwarded to sbi's `DirectPosterior.sample()` function.
                                           They are merged on top of this function's default sampling settings; passing
                                           `None` is treated like passing an empty dict, i.e. only the defaults are used
        :return: domain parameters sampled from the posterior of shape [batch_size, num_samples, dim_domain_param], as
                 well as the log-probabilities of these domain parameters of shape [batch_size, num_samples], or `None`
                 if `calculate_log_probs` is `False`
        :raises pyrado.ShapeErr: if `data_real` is not a 2-dim tensor, or if the sampled domain parameters or the
                                 log-probabilities do not have the expected shape
        :raises pyrado.TypeErr: if `subrtn_sbi_sampling_hparam` is neither `None` nor a dict
        """
        if not isinstance(data_real, to.Tensor) or data_real.ndim != 2:
            # Non-tensor inputs (e.g. lists) may not have a shape attribute; look it up defensively such that the
            # intended shape error is raised instead of an AttributeError while formatting the message
            given_shape = getattr(data_real, "shape", None)
            raise pyrado.ShapeErr(
                msg=f"The data must be a 2-dim PyTorch tensor, but is of shape {given_shape}!"
            )

        batch_size, _ = data_real.shape

        # Validate the user-provided sampling hyper-parameters
        if subrtn_sbi_sampling_hparam is None:
            subrtn_sbi_sampling_hparam = dict()
        elif not isinstance(subrtn_sbi_sampling_hparam, dict):
            raise pyrado.TypeErr(given=subrtn_sbi_sampling_hparam,
                                 expected_type=dict)

        # Always start from the default sampling settings, also when no hyper-parameters were passed. The user-provided
        # entries are listed last such that they are expected to take precedence in the merge.
        default_sampling_hparam = dict(
            mcmc_method="slice_np_vectorized",
            mcmc_parameters=dict(warmup_steps=50,
                                 num_chains=100,
                                 init_strategy="sir"),  # default: slice_np, 20
        )
        sampling_hparam = merge_dicts(
            [default_sampling_hparam, subrtn_sbi_sampling_hparam])

        # Sample domain parameters from the posterior, once per observation (row of data_real), and stack the results
        # along a new leading batch dimension
        domain_params = to.stack(
            [
                posterior.sample((num_samples, ), x=x_o, **sampling_hparam)
                for x_o in data_real
            ],
            dim=0,
        )

        # Check shape
        if domain_params.ndim != 3 or domain_params.shape[:2] != (
                batch_size, num_samples):
            raise pyrado.ShapeErr(
                msg=
                f"The sampled domain parameters must be a 3-dim tensor where the 1st dimension is {batch_size} and "
                f"the 2nd dimension is {num_samples}, but it is of shape {domain_params.shape}!"
            )

        # Without log-probabilities there is nothing left to compute
        if not calculate_log_probs:
            return domain_params, None

        # Batch-wise computation and stacking: every set of samples is evaluated under the observation it was
        # conditioned on
        with completion_context("Evaluating posterior", color="w"):
            log_probs = to.stack(
                [
                    posterior.log_prob(
                        dp, x=x_o, norm_posterior=normalize_posterior)
                    for dp, x_o in zip(domain_params, data_real)
                ],
                dim=0,
            )

        # Check shape
        if log_probs.shape != (batch_size, num_samples):
            raise pyrado.ShapeErr(given=log_probs,
                                  expected_match=(batch_size, num_samples))

        return domain_params, log_probs