def clean_chain(self, scenario: Scenario, chain_ensemble_state: cdict) -> cdict:
    """Collapse per-particle temperature and ESS records to one value per iteration.

    The stored ``temperature`` and ``ess`` arrays are indexed
    ``[iteration, particle]``. Only column 0 is kept, on the expectation that
    every particle in an iteration carries the same broadcast value. The last
    recorded temperature is also written back onto ``scenario`` as a Python
    float.

    Args:
        scenario: Scenario whose ``temperature`` attribute is overwritten.
        chain_ensemble_state: Stored chain; modified in place.

    Returns:
        The same ``chain_ensemble_state`` object.
    """
    temperature_per_iteration = chain_ensemble_state.temperature[:, 0]
    chain_ensemble_state.temperature = temperature_per_iteration
    # Leave the scenario at the temperature the chain finished on.
    scenario.temperature = float(temperature_per_iteration[-1])
    chain_ensemble_state.ess = chain_ensemble_state.ess[:, 0]
    return chain_ensemble_state
def startup(self, scenario: Scenario, n: int, initial_state: cdict, initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Initialise a likelihood-tempered ensemble at temperature zero.

    After the parent startup, this evaluates the prior and likelihood
    potentials of every particle, each with its own random key. It then sets
    the tempered potential to the prior potential, zeroes temperatures and
    log-weights, sets the ESS to ``n`` and resets ``scenario.temperature``
    to 0.

    Args:
        scenario: Scenario providing ``prior_potential`` and
            ``likelihood_potential``; must have a ``prior_sample`` attribute.
        n: Number of particles.
        initial_state: Ensemble state; fields are filled in place.
        initial_extra: Extra state; its ``random_key`` is advanced.
        **kwargs: Forwarded to the parent ``startup``.

    Returns:
        ``(initial_state, initial_extra)``.

    Raises:
        TypeError: If ``scenario`` has no ``prior_sample`` attribute.
    """
    if not hasattr(scenario, 'prior_sample'):
        raise TypeError(
            f'Likelihood tempering requires scenario {scenario.name} to have prior_sample implemented'
        )

    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)

    # n keys for the prior potential, n for the likelihood, one carried forward.
    key_bank = random.split(initial_extra.random_key, 2 * n + 1)
    initial_extra.random_key = key_bank[-1]
    prior_keys = key_bank[:n]
    likelihood_keys = key_bank[n:(2 * n)]

    batched_prior = vmap(scenario.prior_potential)
    batched_likelihood = vmap(scenario.likelihood_potential)
    initial_state.prior_potential = batched_prior(initial_state.value, prior_keys)
    initial_state.likelihood_potential = batched_likelihood(initial_state.value, likelihood_keys)

    # At temperature zero the tempered potential reduces to the prior potential.
    initial_state.potential = initial_state.prior_potential
    initial_state.temperature = jnp.zeros(n)
    initial_state.log_weight = jnp.zeros(n)
    initial_state.ess = jnp.zeros(n) + n
    scenario.temperature = 0.
    return initial_state, initial_extra
def startup(self, scenario: Scenario, n: int, initial_state: cdict, initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Run the parent startup, then default any missing weighting fields.

    A missing ``log_weight`` becomes ``n`` zeros. A missing ``ess`` becomes
    ``n`` copies of ``n``. Fields that are already present are left untouched.

    Args:
        scenario: Scenario forwarded to the parent ``startup``.
        n: Number of particles.
        initial_state: Ensemble state; missing fields are added in place.
        initial_extra: Extra state forwarded to the parent ``startup``.
        **kwargs: Forwarded to the parent ``startup``.

    Returns:
        ``(initial_state, initial_extra)``.
    """
    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)

    # Defaults are built lazily so nothing is allocated for fields already set.
    fallbacks = (
        ('log_weight', lambda: jnp.zeros(n)),
        ('ess', lambda: jnp.zeros(n) + n),
    )
    for attr_name, build_default in fallbacks:
        if not hasattr(initial_state, attr_name):
            setattr(initial_state, attr_name, build_default())
    return initial_state, initial_extra
def adapt(self, previous_ensemble_state: cdict, previous_extra: cdict,
          new_ensemble_state: cdict, new_extra: cdict) -> Tuple[cdict, cdict]:
    """Advance to the next temperature and update the importance weights and ESS.

    The temperature chosen by ``self.next_temperature`` is broadcast to every
    particle. The incremental log-weight from ``self.log_weight`` is added to
    the previous ensemble's log-weight. ``exp(log_ess_log_weight(...))`` of
    the result is broadcast as the ESS.

    Returns:
        ``(new_ensemble_state, new_extra)``. The state is modified in place;
        ``new_extra`` is returned unchanged.
    """
    num_particles = new_ensemble_state.value.shape[0]

    temperature = self.next_temperature(new_ensemble_state, new_extra)
    # Store the new temperature before calling self.log_weight, which
    # receives new_ensemble_state and may depend on it.
    new_ensemble_state.temperature = jnp.ones(num_particles) * temperature

    accumulated = previous_ensemble_state.log_weight
    increment = self.log_weight(previous_ensemble_state, previous_extra, new_ensemble_state, new_extra)
    new_ensemble_state.log_weight = accumulated + increment

    ess_value = jnp.exp(log_ess_log_weight(new_ensemble_state.log_weight))
    new_ensemble_state.ess = jnp.ones(num_particles) * ess_value
    return new_ensemble_state, new_extra
def adapt(self, previous_ensemble_state: cdict, previous_extra: cdict,
          new_ensemble_state: cdict, new_extra: cdict) -> Tuple[cdict, cdict]:
    """Advance the ABC threshold, reweight, record mean acceptance, adapt MCMC.

    The threshold from ``self.next_threshold`` is broadcast to every particle
    and also stored in ``new_extra.parameters.threshold``. Log-weights are
    replaced (not accumulated) by ``self.log_weight``, and the ESS from
    ``ess_log_weight`` is broadcast to every particle. ``new_extra.alpha_mean``
    is set to the mean of ``new_ensemble_state.alpha`` over particles whose
    previous log-weight was finite. The result of
    ``self.adapt_mcmc_params`` is returned.

    Returns:
        ``(new_ensemble_state, new_extra)`` as produced by ``adapt_mcmc_params``.
    """
    num_particles = new_ensemble_state.value.shape[0]

    threshold = self.next_threshold(new_ensemble_state, new_extra)
    new_ensemble_state.threshold = jnp.ones(num_particles) * threshold
    new_extra.parameters.threshold = threshold

    new_ensemble_state.log_weight = self.log_weight(previous_ensemble_state, previous_extra,
                                                    new_ensemble_state, new_extra)
    ess_value = ess_log_weight(new_ensemble_state.log_weight)
    new_ensemble_state.ess = jnp.ones(num_particles) * ess_value

    # Mean of `alpha` (presumably the MCMC acceptance probability) over particles
    # that had non-zero weight before this step.
    # NOTE(review): this is 0/0 if every previous log-weight is -inf.
    was_alive = previous_ensemble_state.log_weight > -jnp.inf
    alive_count = was_alive.sum()
    new_extra.alpha_mean = (new_ensemble_state.alpha * was_alive).sum() / alive_count

    adapted_state, adapted_extra = self.adapt_mcmc_params(previous_ensemble_state, previous_extra,
                                                          new_ensemble_state, new_extra)
    return adapted_state, adapted_extra
def startup(self, scenario: Scenario, n: int, initial_state: cdict, initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Initialise the ensemble, jump to the first temperature, start the MCMC kernel.

    Steps:
    1. Normalise the MCMC sampler's ``correction`` with ``check_correction``.
    2. Run the parent startup.
    3. Choose the first temperature with ``self.next_temperature`` and apply
       it to the scenario and to every particle.
    4. Recompute each particle's tempered potential and log-weight, then
       broadcast the resulting ESS.
    5. Run ``mcmc_sampler.startup`` once per particle via ``vmap``, keeping
       only the first of the stacked extras.

    Returns:
        ``(initial_state, initial_extra)``.
    """
    self.mcmc_sampler.correction = check_correction(self.mcmc_sampler.correction)
    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)

    first_temp = self.next_temperature(initial_state, initial_extra)
    scenario.temperature = first_temp
    initial_state.temperature += first_temp

    prior_pot = initial_state.prior_potential
    lik_pot = initial_state.likelihood_potential
    initial_state.potential = prior_pot + first_temp * lik_pot
    # Incremental weight for moving from temperature 0 to first_temp.
    initial_state.log_weight = -first_temp * lik_pot
    shared_ess = jnp.exp(log_ess_log_weight(initial_state.log_weight))
    initial_state.ess = jnp.repeat(shared_ess, n)

    def start_one_particle(particle_state):
        # initial_extra is closed over, so it is not mapped across particles.
        return self.mcmc_sampler.startup(scenario, n, particle_state, initial_extra)

    initial_state, stacked_extra = vmap(start_one_particle)(initial_state)
    # vmap stacks one extra per particle; only the first is kept.
    initial_extra = stacked_extra[0]
    return initial_state, initial_extra
def startup(self, abc_scenario: ABCScenario, n: int, initial_state: cdict, initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Initialise an ABC SMC ensemble, filling in any missing per-particle fields.

    The call goes directly to ``SMCSampler.startup``. The following are
    computed only when absent from ``initial_state``:

    - ``prior_potential``: only if the scenario's ``prior_potential`` is
      implemented.
    - ``simulated_data``: via ``likelihood_sample``.
    - ``distance``: via ``distance_function``.
    - ``threshold``: the first entry of ``self.threshold_schedule``, or
      ``inf`` when there is no schedule.
    - ``ess``: ``n``.

    Args:
        abc_scenario: ABC scenario providing the simulation and distance functions.
        n: Requested particle count; replaced by ``len(initial_state.value)``.
        initial_state: Ensemble state; missing fields are added in place.
        initial_extra: Extra state; its ``random_key`` is advanced whenever
            random keys are consumed.
        **kwargs: Forwarded to ``SMCSampler.startup``.

    Returns:
        ``(initial_state, initial_extra)``.
    """
    initial_state, initial_extra = SMCSampler.startup(self, abc_scenario, n, initial_state,
                                                      initial_extra, **kwargs)
    # Particle count is taken from the returned ensemble (presumably in case the
    # parent startup changed it).
    n = len(initial_state.value)

    def draw_particle_keys():
        # Split off one key per particle and carry the leftover key forward.
        key_bank = random.split(initial_extra.random_key, n + 1)
        initial_extra.random_key = key_bank[-1]
        return key_bank[:n]

    if not hasattr(initial_state, 'prior_potential') and is_implemented(abc_scenario.prior_potential):
        prior_keys = draw_particle_keys()
        initial_state.prior_potential = vmap(abc_scenario.prior_potential)(initial_state.value, prior_keys)

    if not hasattr(initial_state, 'simulated_data'):
        sim_keys = draw_particle_keys()
        initial_state.simulated_data = vmap(abc_scenario.likelihood_sample)(initial_state.value, sim_keys)

    if not hasattr(initial_state, 'distance'):
        initial_state.distance = vmap(abc_scenario.distance_function)(initial_state.simulated_data)

    if not hasattr(initial_state, 'threshold'):
        start_threshold = jnp.inf if self.threshold_schedule is None else self.threshold_schedule[0]
        initial_state.threshold = jnp.zeros(n) + start_threshold

    if not hasattr(initial_state, 'ess'):
        initial_state.ess = jnp.zeros(n) + n

    return initial_state, initial_extra
def clean_chain(self, abc_scenario: ABCScenario, chain_ensemble_state: cdict) -> cdict:
    """Collapse per-particle threshold and ESS records to one value per iteration.

    ``threshold`` and ``ess`` are stored indexed ``[iteration, particle]``.
    Only column 0 is kept, on the expectation that each iteration carries one
    broadcast value. ``abc_scenario`` is not used.

    Args:
        abc_scenario: Unused; kept for interface compatibility.
        chain_ensemble_state: Stored chain; modified in place.

    Returns:
        The same ``chain_ensemble_state`` object.
    """
    for attr_name in ('threshold', 'ess'):
        per_particle = getattr(chain_ensemble_state, attr_name)
        setattr(chain_ensemble_state, attr_name, per_particle[:, 0])
    return chain_ensemble_state