def clean_chain(self, scenario: Scenario, chain_ensemble_state: cdict) -> cdict:
    """Reduce the per-particle ``temperature`` and ``ess`` histories to one column each.

    Args:
        scenario: Scenario whose ``temperature`` attribute is overwritten with
            the temperature recorded at the final step of the chain.
        chain_ensemble_state: Stacked ensemble states. ``temperature`` and
            ``ess`` are indexed as ``[step, particle]`` (2-D indexing below).

    Returns:
        The same ``chain_ensemble_state`` object, mutated in place so that
        ``temperature`` and ``ess`` keep only particle column 0.
    """
    # Particle column 0 is kept. startup/adapt in this class fill these fields
    # with identical values across particles, so one column is assumed to be
    # representative.
    temperature_history = chain_ensemble_state.temperature[:, 0]
    chain_ensemble_state.temperature = temperature_history

    # Leave the scenario at the last temperature the chain reached.
    scenario.temperature = float(temperature_history[-1])

    chain_ensemble_state.ess = chain_ensemble_state.ess[:, 0]
    return chain_ensemble_state
def startup(self, scenario: Scenario, n: int, initial_state: cdict, initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Initialise an ensemble of ``n`` particles at temperature zero.

    After delegating to the parent ``startup``, this evaluates both the prior
    and likelihood potentials for every particle. It sets the working
    ``potential`` to the prior potential alone and resets the tempering
    bookkeeping (temperature, log-weights, ESS).

    Args:
        scenario: Target scenario. It must provide ``prior_sample``; its
            ``prior_potential`` and ``likelihood_potential`` are called per
            particle with a random key. Its ``temperature`` is set to 0.
        n: Number of particles in the ensemble.
        initial_state: Ensemble state passed through to the parent ``startup``.
        initial_extra: Extra state; its ``random_key`` is consumed and replaced.
        **kwargs: Forwarded unchanged to the parent ``startup``.

    Returns:
        ``(initial_state, initial_extra)`` with the new fields populated.

    Raises:
        TypeError: If ``scenario`` has no ``prior_sample`` attribute.
    """
    if not hasattr(scenario, 'prior_sample'):
        raise TypeError(
            f'Likelihood tempering requires scenario {scenario.name} to have prior_sample implemented'
        )

    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)

    # Split into 2n + 1 keys: n for the prior evaluations, n for the
    # likelihood evaluations, and one to carry forward as the new random key.
    keys = random.split(initial_extra.random_key, 2 * n + 1)
    prior_keys = keys[:n]
    likelihood_keys = keys[n:(2 * n)]
    initial_extra.random_key = keys[-1]

    evaluate_prior = vmap(scenario.prior_potential)
    evaluate_likelihood = vmap(scenario.likelihood_potential)
    initial_state.prior_potential = evaluate_prior(initial_state.value, prior_keys)
    initial_state.likelihood_potential = evaluate_likelihood(initial_state.value, likelihood_keys)

    # At temperature 0 the working potential is the prior potential alone;
    # the likelihood term is not yet included.
    initial_state.potential = initial_state.prior_potential

    # Per-particle bookkeeping: zero temperature, uniform (zero) log-weights,
    # and an ESS equal to the particle count.
    initial_state.temperature = jnp.zeros(n)
    initial_state.log_weight = jnp.zeros(n)
    initial_state.ess = jnp.zeros(n) + n

    scenario.temperature = 0.
    return initial_state, initial_extra
def adapt(self, previous_ensemble_state: cdict, previous_extra: cdict,
          new_ensemble_state: cdict, new_extra: cdict) -> Tuple[cdict, cdict]:
    """Advance the tempering schedule and update the importance weights.

    Args:
        previous_ensemble_state: Ensemble state before this step; its
            ``log_weight`` is the base that the incremental weight is added to.
        previous_extra: Extra state before this step.
        new_ensemble_state: Ensemble state after this step. Its
            ``temperature``, ``log_weight`` and ``ess`` are overwritten.
        new_extra: Extra state after this step; returned unchanged.

    Returns:
        ``(new_ensemble_state, new_extra)``.
    """
    num_particles = new_ensemble_state.value.shape[0]
    per_particle = jnp.ones(num_particles)

    # The next temperature is chosen from the new state *before* that state's
    # temperature field is overwritten.
    new_ensemble_state.temperature = per_particle * self.next_temperature(new_ensemble_state, new_extra)

    # The incremental weight is computed only after the new temperature is
    # stored, because self.log_weight receives new_ensemble_state.
    increment = self.log_weight(previous_ensemble_state, previous_extra,
                                new_ensemble_state, new_extra)
    new_ensemble_state.log_weight = previous_ensemble_state.log_weight + increment

    # Broadcast the scalar ESS (computed in log space) to every particle.
    log_ess = log_ess_log_weight(new_ensemble_state.log_weight)
    new_ensemble_state.ess = per_particle * jnp.exp(log_ess)

    return new_ensemble_state, new_extra