def _build_mcmc_init_fn(
    self,
    prior: Any,
    potential_fn: Callable,
    init_strategy: str = "prior",
    **kwargs: Any,
) -> Callable:
    """
    Return function that, when called, creates an initial parameter set for MCMC.

    Args:
        prior: Prior distribution.
        potential_fn: Potential function that the candidate samples are weighted
            with.
        init_strategy: Specifies the initialization method. Either of
            [`prior`|`sir`|`latest_sample`].
        kwargs: Passed on to init function. This way, init specific keywords can
            be set through `mcmc_parameters`. Unused arguments should be absorbed.

    Returns:
        Initialization function.

    Raises:
        NotImplementedError: If `init_strategy` is not one of `prior`, `sir` or
            `latest_sample`.
    """
    if init_strategy == "prior":
        # Deferred: `prior_init` runs each time the returned function is called.
        return lambda: prior_init(prior, **kwargs)
    if init_strategy == "sir":
        # Deferred: `sir` runs each time the returned function is called.
        return lambda: sir(prior, potential_fn, **kwargs)
    if init_strategy == "latest_sample":
        # Unlike the branches above, `self._mcmc_init_params` is read here, at
        # build time. NOTE(review): presumably `IterateParameters` is callable
        # and hands out the stored parameters on each call — confirm.
        return IterateParameters(self._mcmc_init_params, **kwargs)
    # Keep the exception type callers may rely on, but say what was rejected.
    raise NotImplementedError(
        f"init_strategy '{init_strategy}' is not implemented; choose one of "
        "'prior', 'sir' or 'latest_sample'."
    )
def _build_mcmc_init_fn(
    self,
    prior: Any,
    potential_fn: Callable,
    init_strategy: str = "prior",
    init_strategy_num_candidates: int = 10000,
    **kwargs: Any,
) -> Callable:
    """
    Return function that, when called, creates an initial parameter set for MCMC.

    Args:
        prior: Prior distribution.
        potential_fn: Potential function that the candidate samples are weighted
            with.
        init_strategy: Specifies the initialization method. Either of
            [`prior`|`sir`|`latest_sample`]. `latest_sample` returns
            `self._mcmc_init_params` as it is at the time the returned function
            is called.
        init_strategy_num_candidates: Number of candidate samples drawn. Only
            used by the `sir` strategy.
        kwargs: Absorbs passed but unused arguments. E.g. in
            `DirectPosterior.sample()` we pass `mcmc_parameters` which might
            contain entries that are not used here. They are ignored.

    Returns:
        Initialization function.

    Raises:
        NotImplementedError: If `init_strategy` is not one of `prior`, `sir` or
            `latest_sample`.
    """
    if init_strategy == "prior":
        return lambda: prior_init(prior)
    if init_strategy == "sir":
        # `init_strategy_num_candidates` is passed positionally, as the third
        # argument of `sir`.
        return lambda: sir(
            prior,
            potential_fn,
            init_strategy_num_candidates,
        )
    if init_strategy == "latest_sample":
        # The attribute is looked up lazily, so parameters stored after this
        # function is built are still picked up.
        return lambda: self._mcmc_init_params
    # Keep the exception type callers may rely on, but say what was rejected.
    raise NotImplementedError(
        f"init_strategy '{init_strategy}' is not implemented; choose one of "
        "'prior', 'sir' or 'latest_sample'."
    )