Exemplo n.º 1
0
 def clean_chain(self, scenario: Scenario,
                 chain_ensemble_state: cdict) -> cdict:
     """Reduce the stacked temperature and ESS traces to their first column.

     Keeps only column 0 of ``chain_ensemble_state.temperature`` and
     ``chain_ensemble_state.ess``. This presumably relies on every particle
     holding the same value at each step (TODO confirm). The last retained
     temperature is written to ``scenario.temperature`` as a Python float.

     Args:
         scenario: Scenario whose ``temperature`` attribute is overwritten.
         chain_ensemble_state: Stacked ensemble states. ``temperature`` and
             ``ess`` are indexed with ``[:, 0]``, so they need at least two
             dimensions (presumably ``[iteration, particle]``).

     Returns:
         The same ``chain_ensemble_state`` object, modified in place.
     """
     # Equivalent to the ``[:, 0]`` subscript: every row, first column.
     first_column = (slice(None), 0)

     temperature_trace = chain_ensemble_state.temperature[first_column]
     chain_ensemble_state.temperature = temperature_trace
     scenario.temperature = float(temperature_trace[-1])

     chain_ensemble_state.ess = chain_ensemble_state.ess[first_column]
     return chain_ensemble_state
Exemplo n.º 2
0
    def startup(self, scenario: Scenario, n: int, initial_state: cdict,
                initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
        """Prepare the initial ensemble for likelihood tempering.

        Delegates to the parent ``startup``, then does the following:

        - evaluates each particle's prior and likelihood potential, each
          with its own random key;
        - sets the working ``potential`` to the prior potential;
        - sets temperature and log-weight to zero and ESS to ``n``;
        - resets ``scenario.temperature`` to 0.

        Args:
            scenario: Must expose ``prior_sample``. ``prior_potential`` and
                ``likelihood_potential`` are also called here.
            n: Number of particles. Keys are vmapped alongside
                ``initial_state.value``, so it is that array's leading size.
            initial_state: Forwarded to the parent ``startup``.
            initial_extra: Forwarded to the parent ``startup``.
            **kwargs: Forwarded to the parent ``startup``.

        Returns:
            The ``(state, extra)`` pair from the parent ``startup``, extended
            with the fields above. ``extra.random_key`` is replaced by a
            freshly split key.

        Raises:
            TypeError: If ``scenario`` has no ``prior_sample`` attribute.
        """
        if not hasattr(scenario, 'prior_sample'):
            raise TypeError(
                f'Likelihood tempering requires scenario {scenario.name} to have prior_sample implemented'
            )

        state, extra = super().startup(scenario, n, initial_state,
                                       initial_extra, **kwargs)

        # Split into 2n + 1 keys: n for the prior potentials, n for the
        # likelihood potentials, and one carried forward in ``extra``.
        keys = random.split(extra.random_key, 2 * n + 1)
        prior_keys = keys[:n]
        likelihood_keys = keys[n:(2 * n)]
        extra.random_key = keys[-1]

        batched_prior = vmap(scenario.prior_potential)
        batched_likelihood = vmap(scenario.likelihood_potential)
        state.prior_potential = batched_prior(state.value, prior_keys)
        state.likelihood_potential = batched_likelihood(state.value,
                                                        likelihood_keys)

        # Tempering starts at temperature 0. There, presumably, only the prior
        # contributes, so the working potential is the prior potential.
        state.potential = state.prior_potential
        state.temperature = jnp.zeros(n)
        state.log_weight = jnp.zeros(n)
        # All log-weights are equal (zero), so the ESS is the full count n.
        state.ess = n + jnp.zeros(n)

        scenario.temperature = 0.

        return state, extra
Exemplo n.º 3
0
 def adapt(self, previous_ensemble_state: cdict, previous_extra: cdict,
           new_ensemble_state: cdict,
           new_extra: cdict) -> Tuple[cdict, cdict]:
     """Advance the temperature and update weights and ESS for a new ensemble.

     The new temperature comes from ``self.next_temperature``. The log-weights
     are the previous ones plus the increment from ``self.log_weight``. The
     ESS is ``exp(log_ess_log_weight(log_weight))``. The helper
     ``log_ess_log_weight`` is defined elsewhere and presumably returns the
     log effective sample size. Temperature and ESS are broadcast to one entry
     per particle.

     Args:
         previous_ensemble_state: Ensemble before the step. Its ``log_weight``
             is the base for accumulation.
         previous_extra: Extra state before the step.
         new_ensemble_state: Ensemble after the step. It is modified in place.
         new_extra: Extra state after the step. It is returned unchanged.

     Returns:
         ``(new_ensemble_state, new_extra)``.
     """
     num_particles = new_ensemble_state.value.shape[0]
     per_particle = jnp.ones(num_particles)

     # The temperature must be stored before calling self.log_weight. That
     # call receives new_ensemble_state and presumably reads the updated
     # temperature.
     new_ensemble_state.temperature = per_particle * self.next_temperature(
         new_ensemble_state, new_extra)

     previous_log_weight = previous_ensemble_state.log_weight
     log_weight_increment = self.log_weight(previous_ensemble_state,
                                            previous_extra,
                                            new_ensemble_state,
                                            new_extra)
     new_ensemble_state.log_weight = previous_log_weight + log_weight_increment

     current_ess = jnp.exp(log_ess_log_weight(new_ensemble_state.log_weight))
     new_ensemble_state.ess = per_particle * current_ess

     return new_ensemble_state, new_extra