def startup(self, scenario: Scenario, n: int, initial_state: cdict, initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Initialise a likelihood-tempered ensemble of ``n`` particles at temperature zero.

    After the parent startup, the prior and likelihood potentials are evaluated
    for every particle, each with its own random key. The working ``potential``
    is set to the prior potential alone (the temperature is zero). Every particle
    starts with zero temperature, zero log-weight and an ESS of ``n``, and
    ``scenario.temperature`` is reset to 0.

    Raises:
        TypeError: if ``scenario`` has no ``prior_sample`` attribute.
    """
    if not hasattr(scenario, 'prior_sample'):
        raise TypeError(
            f'Likelihood tempering requires scenario {scenario.name} to have prior_sample implemented'
        )
    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)

    # 2n + 1 keys: n for the prior potentials, n for the likelihood potentials,
    # and a final one carried forward as the new random key.
    keys = random.split(initial_extra.random_key, 2 * n + 1)
    prior_keys, likelihood_keys, carry_key = keys[:n], keys[n:2 * n], keys[-1]
    initial_extra.random_key = carry_key

    values = initial_state.value
    initial_state.prior_potential = vmap(scenario.prior_potential)(values, prior_keys)
    initial_state.likelihood_potential = vmap(scenario.likelihood_potential)(values, likelihood_keys)

    # Zero temperature: the likelihood contributes nothing to the potential yet.
    initial_state.potential = initial_state.prior_potential
    initial_state.temperature = jnp.zeros(n)
    initial_state.log_weight = jnp.zeros(n)
    initial_state.ess = n + jnp.zeros(n)
    scenario.temperature = 0.
    return initial_state, initial_extra
def clean_chain_ar(self, abc_scenario: ABCScenario, chain_state: cdict):
    """Reject chain samples whose distance is not below the acceptance-rate quantile.

    The cutoff is the ``self.parameters.acceptance_rate`` quantile of
    ``chain_state.distance``. It is stored as a Python float in
    ``self.parameters.threshold``. Samples with distance strictly below the
    cutoff get log-weight 0; all others get ``-inf``. ``abc_scenario`` is not
    used here.

    Returns:
        ``chain_state`` itself, with ``log_weight`` set.
    """
    distances = chain_state.distance
    cutoff = jnp.quantile(distances, self.parameters.acceptance_rate)
    self.parameters.threshold = float(cutoff)
    accepted = distances < cutoff
    chain_state.log_weight = jnp.where(accepted, 0., -jnp.inf)
    return chain_state
def startup(self, scenario: Scenario, n: int, initial_state: cdict, initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Run the parent startup, then fill in any missing importance-weight fields.

    If ``initial_state`` lacks ``log_weight``, it is set to zeros (uniform
    weights). If it lacks ``ess``, every entry is set to ``n``. Fields that are
    already present are left untouched.
    """
    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)
    defaults = (
        ('log_weight', lambda: jnp.zeros(n)),
        ('ess', lambda: jnp.zeros(n) + n),
    )
    for attr_name, make_default in defaults:
        if not hasattr(initial_state, attr_name):
            setattr(initial_state, attr_name, make_default())
    return initial_state, initial_extra
def adapt(self, previous_ensemble_state: cdict, previous_extra: cdict,
          new_ensemble_state: cdict, new_extra: cdict) -> Tuple[cdict, cdict]:
    """Advance the tempering schedule and update the importance weights.

    The next temperature comes from ``self.next_temperature`` and is broadcast
    to every particle. The incremental log-weight from ``self.log_weight`` is
    added onto the previous log-weights. The resulting effective sample size is
    stored, repeated once per particle.
    """
    num_particles = new_ensemble_state.value.shape[0]
    per_particle = jnp.ones(num_particles)

    temperature = self.next_temperature(new_ensemble_state, new_extra)
    # Stored before the weight increment is computed. Presumably
    # self.log_weight reads the new temperature from new_ensemble_state.
    new_ensemble_state.temperature = per_particle * temperature

    prev_log_weight = previous_ensemble_state.log_weight
    increment = self.log_weight(previous_ensemble_state, previous_extra,
                                new_ensemble_state, new_extra)
    new_ensemble_state.log_weight = prev_log_weight + increment

    log_ess = log_ess_log_weight(new_ensemble_state.log_weight)
    new_ensemble_state.ess = per_particle * jnp.exp(log_ess)
    return new_ensemble_state, new_extra
def adapt(self, previous_ensemble_state: cdict, previous_extra: cdict,
          new_ensemble_state: cdict, new_extra: cdict) -> Tuple[cdict, cdict]:
    """Shrink the ABC threshold, reweight the particles and adapt the MCMC kernel.

    Steps:
      1. Choose the next threshold via ``self.next_threshold``. Store it per
         particle in ``new_ensemble_state.threshold`` and in
         ``new_extra.parameters.threshold``.
      2. Set ``new_ensemble_state.log_weight`` from ``self.log_weight``. This
         is not accumulated onto the previous log-weights. Then record the ESS
         via ``ess_log_weight``.
      3. Store in ``new_extra.alpha_mean`` the mean of
         ``new_ensemble_state.alpha`` over particles whose previous log-weight
         was finite (presumably per-particle MCMC acceptance probabilities;
         confirm).
      4. Delegate to ``self.adapt_mcmc_params``.
    """
    n = new_ensemble_state.value.shape[0]
    next_threshold = self.next_threshold(new_ensemble_state, new_extra)
    new_ensemble_state.threshold = jnp.ones(n) * next_threshold
    new_extra.parameters.threshold = next_threshold

    new_ensemble_state.log_weight = self.log_weight(previous_ensemble_state, previous_extra,
                                                    new_ensemble_state, new_extra)
    new_ensemble_state.ess = jnp.ones(n) * ess_log_weight(new_ensemble_state.log_weight)

    # Average alpha over particles that were alive before this step. Select
    # with jnp.where instead of multiplying by the mask: a nan/inf alpha on a
    # dead particle would otherwise poison the sum (nan * 0 == nan).
    alive_inds = previous_ensemble_state.log_weight > -jnp.inf
    alive_alpha = jnp.where(alive_inds, new_ensemble_state.alpha, 0.)
    new_extra.alpha_mean = alive_alpha.sum() / alive_inds.sum()

    new_ensemble_state, new_extra = self.adapt_mcmc_params(previous_ensemble_state, previous_extra,
                                                           new_ensemble_state, new_extra)
    return new_ensemble_state, new_extra
def startup(self, scenario: Scenario, n: int, initial_state: cdict, initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Initialise tempered SMC: pick the first temperature and start the MCMC kernel.

    The MCMC sampler's ``correction`` is first normalised with
    ``check_correction``, then the parent startup runs. The first temperature
    comes from ``self.next_temperature`` and is applied to the scenario and
    added to every particle's temperature. The tempered potential becomes
    ``prior_potential + first_temp * likelihood_potential``, the log-weight
    becomes ``-first_temp * likelihood_potential``, and the ESS is repeated
    ``n`` times. Finally ``self.mcmc_sampler.startup`` is vmapped over the
    particles, and only the first particle's extra is returned.
    """
    self.mcmc_sampler.correction = check_correction(self.mcmc_sampler.correction)
    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)

    first_temp = self.next_temperature(initial_state, initial_extra)
    scenario.temperature = first_temp
    initial_state.temperature += first_temp

    likelihood = initial_state.likelihood_potential
    initial_state.potential = initial_state.prior_potential + first_temp * likelihood
    initial_state.log_weight = -first_temp * likelihood
    ess = jnp.exp(log_ess_log_weight(initial_state.log_weight))
    initial_state.ess = jnp.repeat(ess, n)

    def per_particle_startup(state):
        # The same (closed-over) initial_extra is passed to every particle's startup.
        return self.mcmc_sampler.startup(scenario, n, state, initial_extra)

    batched_state, batched_extra = vmap(per_particle_startup)(initial_state)
    # vmap yields one extra per particle; keep only the first.
    return batched_state, batched_extra[0]
def startup(self, abc_scenario: ABCScenario, n: int, initial_state: cdict = None,
            initial_extra: cdict = None, **kwargs) -> Tuple[cdict, cdict]:
    """Prepare an ABC MCMC run of ``n`` samples.

    If ``initial_state`` is None, it is built in one of two ways. When
    ``abc_scenario.prior_sample`` is implemented, it comes from a prior draw
    using a key split off ``initial_extra.random_key``. Otherwise it is a
    zero vector of length ``abc_scenario.dim``. The method then sets
    ``self.max_iter = n - 1``, runs the parent startup, and stores the initial
    log-weight from ``self.log_weight``.

    Raises:
        ValueError: if a prior draw is needed but ``initial_extra`` is None,
            so there is no random key to split. Previously this surfaced as an
            opaque AttributeError on ``None``.
    """
    if initial_state is None:
        if is_implemented(abc_scenario.prior_sample):
            if initial_extra is None:
                raise ValueError('initial_extra with a random_key is required to sample '
                                 'the initial state from the prior')
            initial_extra.random_key, sub_key = random.split(initial_extra.random_key)
            init_vals = abc_scenario.prior_sample(sub_key)
        else:
            init_vals = jnp.zeros(abc_scenario.dim)
        initial_state = cdict(value=init_vals)
    # NOTE(review): presumably the starting state counts as the first of the n samples.
    self.max_iter = n - 1
    initial_state, initial_extra = super().startup(abc_scenario, n, initial_state, initial_extra, **kwargs)
    initial_state.log_weight = self.log_weight(abc_scenario, initial_state, initial_extra)
    return initial_state, initial_extra
def adapt(self, previous_ensemble_state: cdict, previous_extra: cdict,
          new_ensemble_state: cdict, new_extra: cdict) -> Tuple[cdict, cdict]:
    """Accumulate log-weights by adding the previous ensemble's onto the new one.

    ``new_ensemble_state.log_weight`` is assumed to hold this step's incremental
    log-weight, and it is overwritten with the running total.
    ``previous_extra`` is unused, and ``new_extra`` is returned unchanged.
    """
    running_total = previous_ensemble_state.log_weight + new_ensemble_state.log_weight
    new_ensemble_state.log_weight = running_total
    return new_ensemble_state, new_extra