示例#1
0
    def _build_mcmc_init_fn(
        self,
        prior: Any,
        potential_fn: Callable,
        init_strategy: str = "prior",
        **kwargs,
    ) -> Callable:
        """
        Return function that, when called, creates an initial parameter set for MCMC.

        Args:
            prior: Prior distribution.
            potential_fn: Potential function that the candidate samples are weighted
                with.
            init_strategy: Specifies the initialization method. Either of
                [`prior`|`sir`|`latest_sample`].
            kwargs: Passed on to init function. This way, init specific keywords can
                be set through `mcmc_parameters`. Unused arguments should be absorbed.

        Returns: Initialization function.
        """
        if init_strategy == "prior":
            return lambda: prior_init(prior, **kwargs)
        elif init_strategy == "sir":
            return lambda: sir(prior, potential_fn, **kwargs)
        elif init_strategy == "latest_sample":
            latest_sample = IterateParameters(self._mcmc_init_params, **kwargs)
            return latest_sample
        else:
            raise NotImplementedError
示例#2
0
    def _build_mcmc_init_fn(
        self,
        prior: Any,
        potential_fn: Callable,
        init_strategy: str = "prior",
        init_strategy_num_candidates: int = 10000,
        **kwargs,
    ) -> Callable:
        """
        Return function that, when called, creates an initial parameter set for MCMC.

        Args:
            prior: Prior distribution.
            potential_fn: Potential function that the candidate samples are weighted
                with.
            init_strategy: Specifies the initialization method. Either of
                [`prior`|`sir`].
            init_strategy_num_candidates: Number of candidate samples drawn.
            kwargs: Absorbs passed but unused arguments. E.g. in
                `DirectPosterior.sample()` we pass `mcmc_parameters` which might
                contain entries that are not used here.

        Returns: Initialization function.
        """

        if init_strategy == "prior":
            return lambda: prior_init(prior)
        elif init_strategy == "sir":
            return lambda: sir(
                prior,
                potential_fn,
                init_strategy_num_candidates,
            )
        elif init_strategy == "latest_sample":
            return lambda: self._mcmc_init_params
        else:
            raise NotImplementedError