def startup(self,
            abc_scenario: ABCScenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Prepare the initial ABC-SMC ensemble, computing any per-particle fields it lacks.

    ``SMCSampler.startup`` runs first. Afterwards, each of the following fields is
    computed only if the returned state does not already have it:

    * ``prior_potential``: only when ``abc_scenario.prior_potential`` is implemented,
      vmapped over particles with one fresh random key each.
    * ``simulated_data``: one ``abc_scenario.likelihood_sample`` call per particle,
      each with its own fresh random key.
    * ``distance``: ``abc_scenario.distance_function`` vmapped over the simulated data.
    * ``threshold``: ``inf`` if ``self.threshold_schedule`` is None, otherwise the
      schedule's first entry, added to a zero vector of length num_particles.
    * ``ess``: the particle count, added to a zero vector of length num_particles.

    Args:
        abc_scenario: Supplies ``prior_potential``, ``likelihood_sample`` and
            ``distance_function``.
        n: Ensemble size forwarded to ``SMCSampler.startup``. The size used afterwards
            is re-read from ``len(initial_state.value)``.
        initial_state: Possibly partially populated ensemble state.
        initial_extra: Extra state. Its ``random_key`` is split and replaced each time
            fresh randomness is drawn.
        **kwargs: Forwarded unchanged to ``SMCSampler.startup``.

    Returns:
        The populated ``(initial_state, initial_extra)`` pair.
    """
    state, extra = SMCSampler.startup(self, abc_scenario, n, initial_state, initial_extra, **kwargs)
    num_particles = len(state.value)

    def per_particle_keys():
        # Split num_particles + 1 keys: the last becomes the new carried key, the rest
        # are handed out one per particle.
        keys = random.split(extra.random_key, num_particles + 1)
        extra.random_key = keys[-1]
        return keys[:num_particles]

    if not hasattr(state, 'prior_potential'):
        if is_implemented(abc_scenario.prior_potential):
            prior_keys = per_particle_keys()
            state.prior_potential = vmap(abc_scenario.prior_potential)(state.value, prior_keys)

    if not hasattr(state, 'simulated_data'):
        sim_keys = per_particle_keys()
        state.simulated_data = vmap(abc_scenario.likelihood_sample)(state.value, sim_keys)

    if not hasattr(state, 'distance'):
        state.distance = vmap(abc_scenario.distance_function)(state.simulated_data)

    if not hasattr(state, 'threshold'):
        # Without a schedule every simulation is accepted initially (threshold = inf).
        first_threshold = jnp.inf if self.threshold_schedule is None else self.threshold_schedule[0]
        state.threshold = jnp.zeros(num_particles) + first_threshold

    if not hasattr(state, 'ess'):
        state.ess = jnp.zeros(num_particles) + num_particles

    return state, extra
def startup(self,
            abc_scenario: ABCScenario,
            sampler: MCMCSampler,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Initialise tuning parameters and running statistics for an adaptive ABC-MCMC run.

    After the parent ``startup``:

    * ``initial_extra.parameters.stepsize`` defaults to ``1.0`` when
      ``sampler.parameters.stepsize`` is None.
    * ``initial_extra.parameters.threshold`` defaults to ``50.`` when
      ``sampler.parameters.threshold`` is None.
    * The threshold is mirrored onto ``initial_state.threshold``.
    * The stepsize is broadcast to a length-``abc_scenario.dim`` vector, stored back on
      ``initial_extra.parameters`` and on ``initial_state.stepsize``.
    * ``initial_extra.post_mean`` starts at ``initial_state.value``.
    * ``initial_extra.diag_post_cov`` is set so that
      ``stepsize == 2.38**2 / dim * diag_post_cov``. This is presumably the usual
      2.38^2/d random-walk scaling; TODO confirm against the proposal.

    Args:
        abc_scenario: Provides ``dim``, the parameter dimension.
        sampler: Its ``parameters`` decide whether defaults are applied.
        n: Forwarded to the parent ``startup``.
        initial_state: Initial chain state.
        initial_extra: Initial extra state holding ``parameters``.
        **kwargs: Forwarded to the parent ``startup``.

    Returns:
        The updated ``(initial_state, initial_extra)`` pair.
    """
    initial_state, initial_extra = super().startup(abc_scenario, sampler, n, initial_state, initial_extra,
                                                   **kwargs)

    # Defaults apply only when the user-facing sampler parameters were left unset.
    if sampler.parameters.stepsize is None:
        initial_extra.parameters.stepsize = 1.0
    if sampler.parameters.threshold is None:
        initial_extra.parameters.threshold = 50.

    dim = abc_scenario.dim
    initial_state.threshold = initial_extra.parameters.threshold

    # One stepsize per coordinate (diagonal adaptation).
    diag_stepsize = jnp.ones(dim) * initial_extra.parameters.stepsize
    initial_extra.parameters.stepsize = diag_stepsize
    initial_state.stepsize = diag_stepsize

    # Running statistics used for adaptation.
    initial_extra.post_mean = initial_state.value
    initial_extra.diag_post_cov = diag_stepsize * dim / 2.38 ** 2
    return initial_state, initial_extra
def adapt(self,
          previous_ensemble_state: cdict,
          previous_extra: cdict,
          new_ensemble_state: cdict,
          new_extra: cdict) -> Tuple[cdict, cdict]:
    """Update threshold, weights, ESS and acceptance statistics after an SMC step.

    Steps:

    1. Picks the next threshold via ``self.next_threshold`` and stores it both per
       particle on the new state and on ``new_extra.parameters``.
    2. Recomputes log-weights via ``self.log_weight``.
    3. Stores the ESS from ``ess_log_weight``, broadcast to every particle.
    4. Records in ``new_extra.alpha_mean`` the mean of ``new_ensemble_state.alpha`` over
       particles whose previous log-weight was finite.
    5. Hands off to ``self.adapt_mcmc_params``.

    Returns:
        The ``(new_ensemble_state, new_extra)`` pair produced by ``adapt_mcmc_params``.
    """
    num_particles = new_ensemble_state.value.shape[0]
    per_particle = jnp.ones(num_particles)

    threshold = self.next_threshold(new_ensemble_state, new_extra)
    new_ensemble_state.threshold = per_particle * threshold
    new_extra.parameters.threshold = threshold

    new_ensemble_state.log_weight = self.log_weight(previous_ensemble_state, previous_extra,
                                                    new_ensemble_state, new_extra)
    new_ensemble_state.ess = per_particle * ess_log_weight(new_ensemble_state.log_weight)

    # Average acceptance over particles that were alive (log-weight > -inf) before this step.
    # NOTE(review): if no particle was alive the denominator is 0, so alpha_mean becomes NaN.
    alive = previous_ensemble_state.log_weight > -jnp.inf
    new_extra.alpha_mean = (new_ensemble_state.alpha * alive).sum() / alive.sum()

    adapted_state, adapted_extra = self.adapt_mcmc_params(previous_ensemble_state, previous_extra,
                                                          new_ensemble_state, new_extra)
    return adapted_state, adapted_extra
def clean_chain(self, abc_scenario: ABCScenario, chain_ensemble_state: cdict) -> cdict:
    """Reduce the per-particle ``threshold`` and ``ess`` records to their first column.

    Elsewhere in this module both fields are stored per particle as broadcast copies of a
    single value. Keeping column 0 therefore keeps one value per recorded step. The
    ``[:, 0]`` indexing means each field must have rank >= 2; presumably axis 0 is the
    iteration and axis 1 the particle (TODO confirm). The state is modified in place and
    also returned. ``abc_scenario`` is not used.
    """
    for field_name in ('threshold', 'ess'):
        setattr(chain_ensemble_state, field_name, getattr(chain_ensemble_state, field_name)[:, 0])
    return chain_ensemble_state