Example #1
File: base_posterior.py Project: bkmi/sbi
    def __init__(
        self,
        potential_fn: Callable,
        theta_transform: Optional[TorchTransform] = None,
        device: Optional[str] = None,
        x_shape: Optional[torch.Size] = None,
    ):
        """
        Args:
            potential_fn: The potential function from which to draw samples.
            theta_transform: Transformation that will be applied during sampling.
                Allows to perform, e.g. MCMC in unconstrained space.
            device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
                `potential_fn.device` is used.
        """

        # Ensure device string.
        self._device = process_device(
            potential_fn.device if device is None else device)

        self.potential_fn = potential_fn

        if theta_transform is None:
            self.theta_transform = torch_tf.IndependentTransform(
                torch_tf.identity_transform, reinterpreted_batch_ndims=1)
        else:
            self.theta_transform = theta_transform

        self._map = None
        self._purpose = ""
        self._x_shape = x_shape

        # If the sampler interface (#573) is used, the user might already have
        # passed `x_o` to the potential function builder. If so, this `x_o` is
        # used as the default `x`.
        self._x = self.potential_fn.return_x_o()
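
The constructor relies on just two members of `potential_fn`: a `device` attribute and a `return_x_o()` method. Below is a minimal sketch of a toy potential object satisfying that contract; the `ToyPotential` class and its Gaussian log-density are hypothetical illustrations, not part of sbi:

import torch

class ToyPotential:
    """Hypothetical stand-in for an sbi potential-function object."""

    def __init__(self, x_o=None):
        self.device = "cpu"  # read by the constructor above
        self._x_o = x_o

    def __call__(self, theta):
        # Toy potential: unnormalized standard-Gaussian log-density.
        return -0.5 * (theta**2).sum(dim=-1)

    def return_x_o(self):
        # The returned value becomes the posterior's default `x`.
        return self._x_o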
Example #2
def mcmc_transform(
    prior: Distribution,
    num_prior_samples_for_zscoring: int = 1000,
    enable_transform: bool = True,
    device: str = "cpu",
    **kwargs,
) -> TorchTransform:
    """
    Builds a transform that is applied to parameters during MCMC.

    The resulting transform is defined such that the forward mapping maps from
    constrained to unconstrained space.

    It does two things:
    1) When the prior support is bounded, it transforms the parameters into unbounded
        space.
    2) It z-scores the parameters such that MCMC is performed in a z-scored space.

    Args:
        prior: The prior distribution.
        num_prior_samples_for_zscoring: The number of samples drawn from the prior
            to infer the `mean` and `stddev` of the prior used for z-scoring. Unused if
            the prior has bounded support or when the prior has `mean` and `stddev`
            attributes.
        enable_transform: Whether or not to use a transformation during MCMC.

    Returns: A transformation that transforms whose `forward()` maps from unconstrained
        (or z-scored) to constrained (or non-z-scored) space.
    """

    if enable_transform:
        # Some distributions have a support attribute, but accessing it raises
        # a NotImplementedError. We catch this case here.
        try:
            _ = prior.support
            has_support = True
        except (NotImplementedError, AttributeError):
            # NotImplementedError -> Distribution that inherits from torch dist but
            # does not implement support.
            # AttributeError -> Custom distribution that has no support attribute.
            warnings.warn(
                """The passed prior has no support property, transform will be
                constructed from mean and std. If the passed prior is supposed to be
                bounded consider implementing the prior.support property.""")
            has_support = False

        # If the distribution has a `support`, check if the support is bounded.
        # If it is not bounded, we want to z-score the space. This is not done
        # by `biject_to()`, so we have to deal with this case separately.
        if has_support:
            if hasattr(prior.support, "base_constraint"):
                constraint = prior.support.base_constraint  # type: ignore
            else:
                constraint = prior.support
            if isinstance(constraint, constraints._Real):
                support_is_bounded = False
            else:
                support_is_bounded = True
        else:
            support_is_bounded = False

        # Prior with bounded support, e.g., uniform priors.
        if has_support and support_is_bounded:
            transform = biject_to(prior.support)
        # For all other cases build affine transform with mean and std.
        else:
            if hasattr(prior, "mean") and hasattr(prior, "stddev"):
                prior_mean = prior.mean.to(device)
                prior_std = prior.stddev.to(device)
            else:
                theta = prior.sample(
                    torch.Size((num_prior_samples_for_zscoring, )))
                prior_mean = theta.mean(dim=0).to(device)
                prior_std = theta.std(dim=0).to(device)

            transform = torch_tf.AffineTransform(loc=prior_mean,
                                                 scale=prior_std)
    else:
        transform = torch_tf.identity_transform

    # PyTorch `transforms` do not sum the log-determinant over the parameter
    # dimensions. However, if the `transform` is explicitly an
    # `IndependentTransform`, it does. Since our `BoxUniform` is an
    # `Independent` distribution, it also automatically gets an
    # `IndependentTransform` wrapper in `biject_to`. Our solution here is to
    # wrap all transforms as `IndependentTransform`.
    if not isinstance(transform, torch_tf.IndependentTransform):
        transform = torch_tf.IndependentTransform(transform,
                                                  reinterpreted_batch_ndims=1)

    check_transform(prior, transform)  # type: ignore

    return transform.inv  # type: ignore
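
As a quick check of the direction convention, here is a hedged usage sketch with a box-shaped, torch-only prior (any bounded prior exposing `support` behaves the same way):

import torch
from torch.distributions import Independent, Uniform

prior = Independent(Uniform(torch.zeros(2), torch.ones(2)), 1)
tf = mcmc_transform(prior)

theta = prior.sample(torch.Size((5,)))
unconstrained = tf(theta)              # forward: constrained -> unconstrained
reconstructed = tf.inv(unconstrained)  # inverse: back into the prior support
assert torch.allclose(theta, reconstructed, atol=1e-5)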
Example #3
            with highest log-probability as the initial points for the optimization.
        learning_rate: Learning rate of the optimizer.
        save_best_every: The best log-probability is computed, saved in the
            `map`-attribute, and printed every `save_best_every`-th iteration.
            Computing the best log-probability creates a significant overhead
            (thus, the default is `10`).
        show_progress_bars: Whether or not to show a progress bar for the
            optimization.
        interruption_note: The message printed when the user interrupts the
            optimization.

    Returns:
        The `argmax` and `max` of the `potential_fn`.
    """

    if theta_transform is None:
        theta_transform = torch_tf.IndependentTransform(
            torch_tf.identity_transform, reinterpreted_batch_ndims=1)

    init_probs = potential_fn(inits).detach()

    # Pick the `num_to_optimize` best init locations.
    sort_indices = torch.argsort(init_probs, dim=0)
    sorted_inits = inits[sort_indices]
    optimize_inits = sorted_inits[-num_to_optimize:]

    # The `_overall` variables store data across the iterations, whereas the
    # `_iter` variables contain data exclusively extracted from the current
    # iteration.
    best_log_prob_iter = torch.max(init_probs)
    best_theta_iter = sorted_inits[-1]
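
The selection logic above is a plain argsort-and-slice; here is a toy illustration with made-up numbers (the values are purely illustrative):

import torch

init_probs = torch.tensor([0.1, 2.3, -1.0, 0.7])
inits = torch.arange(4.0).unsqueeze(1)  # four 1-D candidate parameters

sort_indices = torch.argsort(init_probs, dim=0)
sorted_inits = inits[sort_indices]
best_two = sorted_inits[-2:]  # the two candidates with highest potential
# best_two == tensor([[3.], [1.]]), since init_probs[3] and init_probs[1]
# are the two largest values.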
Example #4
from torch.distributions import transform_to, transforms

def _transform_to_independent(constraint):
    # Wrap the base transform so that the log-abs-det-Jacobian is summed
    # over the reinterpreted batch dimensions.
    base_transform = transform_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims)
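
Assuming this helper is registered for `constraints.independent` via the usual `@transform_to.register` decorator pattern, calling it directly looks like the sketch below; the unit-interval constraint and the shapes are illustrative assumptions:

import torch
from torch.distributions import constraints

# Unit-interval values, with the last dimension reinterpreted as an event dim.
constraint = constraints.independent(constraints.unit_interval, 1)
tf = _transform_to_independent(constraint)

x = torch.randn(3, 2)
y = tf(x)  # element-wise map from the reals into (0, 1)
assert ((y > 0) & (y < 1)).all()
# The IndependentTransform sums the log|det J| over the event dimension:
assert tf.log_abs_det_jacobian(x, y).shape == (3,)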