Example #1
0
    def startup(self,
                abc_scenario: ABCScenario,
                n: int,
                initial_state: cdict,
                initial_extra: cdict,
                **kwargs) -> Tuple[cdict, cdict]:
        """Initialise the ABC SMC ensemble, filling in per-particle fields that are missing.

        ``SMCSampler.startup`` runs first. Afterwards, each of the following is set on
        ``initial_state`` only if it is not already present:

        * ``prior_potential`` -- only when ``abc_scenario.prior_potential`` is implemented;
        * ``simulated_data`` -- one ``likelihood_sample`` draw per particle;
        * ``distance`` -- ``distance_function`` applied to each simulated dataset;
        * ``threshold`` -- ``inf`` when there is no threshold schedule, else the
          schedule's first entry, repeated per particle;
        * ``ess`` -- the particle count, repeated per particle.

        Random draws consume ``initial_extra.random_key``, which is replaced by a fresh key
        after every split.

        Args:
            abc_scenario: Scenario providing the prior potential, simulator and distance.
            n: Requested ensemble size, forwarded to ``SMCSampler.startup``. The size
                actually used afterwards is re-read from ``initial_state.value``.
            initial_state: Ensemble state; augmented in place and returned.
            initial_extra: Extra sampler state; its ``random_key`` is advanced.
            **kwargs: Forwarded to ``SMCSampler.startup``.

        Returns:
            The ``(initial_state, initial_extra)`` pair.
        """
        initial_state, initial_extra = SMCSampler.startup(self, abc_scenario, n,
                                                          initial_state, initial_extra, **kwargs)
        num_particles = len(initial_state.value)

        def per_particle_keys():
            # One key per particle plus one extra key that replaces the stored key.
            split_keys = random.split(initial_extra.random_key, num_particles + 1)
            initial_extra.random_key = split_keys[-1]
            return split_keys[:num_particles]

        needs_prior = (not hasattr(initial_state, 'prior_potential')
                       and is_implemented(abc_scenario.prior_potential))
        if needs_prior:
            prior_keys = per_particle_keys()
            batched_prior = vmap(abc_scenario.prior_potential)
            initial_state.prior_potential = batched_prior(initial_state.value, prior_keys)

        if not hasattr(initial_state, 'simulated_data'):
            sim_keys = per_particle_keys()
            batched_simulator = vmap(abc_scenario.likelihood_sample)
            initial_state.simulated_data = batched_simulator(initial_state.value, sim_keys)

        if not hasattr(initial_state, 'distance'):
            batched_distance = vmap(abc_scenario.distance_function)
            initial_state.distance = batched_distance(initial_state.simulated_data)

        if not hasattr(initial_state, 'threshold'):
            # Keep the `zeros + value` form so the result dtype matches the float array.
            start_threshold = jnp.inf if self.threshold_schedule is None else self.threshold_schedule[0]
            initial_state.threshold = jnp.zeros(num_particles) + start_threshold

        if not hasattr(initial_state, 'ess'):
            initial_state.ess = jnp.zeros(num_particles) + num_particles

        return initial_state, initial_extra
Example #2
0
 def startup(self, abc_scenario: ABCScenario, sampler: MCMCSampler, n: int,
             initial_state: cdict, initial_extra: cdict,
             **kwargs) -> Tuple[cdict, cdict]:
     """Initialise ABC-MCMC tuning parameters on top of the parent class's startup.

     After ``super().startup``:

     * ``initial_extra.parameters.stepsize`` defaults to ``1.0`` when
       ``sampler.parameters.stepsize`` is ``None``.
     * ``initial_extra.parameters.threshold`` defaults to ``50.`` when
       ``sampler.parameters.threshold`` is ``None``. In that case the value is also
       copied to ``initial_state.threshold``.
     * The stepsize is broadcast to a length-``abc_scenario.dim`` vector and stored on
       both ``initial_extra.parameters`` and ``initial_state``.
     * ``initial_extra.post_mean`` is set to the current ``initial_state.value``.
     * ``initial_extra.diag_post_cov`` is set to ``stepsize * dim / 2.38**2``.
       This is presumably the inverse of the classic ``2.38**2 / d`` random-walk
       scaling; confirm against the adaptation code.

     Returns:
         The ``(initial_state, initial_extra)`` pair, both updated in place.
     """
     initial_state, initial_extra = super().startup(abc_scenario, sampler, n,
                                                    initial_state, initial_extra, **kwargs)
     sampler_params = sampler.parameters

     if sampler_params.stepsize is None:
         initial_extra.parameters.stepsize = 1.0

     if sampler_params.threshold is None:
         initial_extra.parameters.threshold = 50.
         initial_state.threshold = initial_extra.parameters.threshold

     dim = abc_scenario.dim
     initial_extra.parameters.stepsize = jnp.ones(dim) * initial_extra.parameters.stepsize
     stepsize_vec = initial_extra.parameters.stepsize
     initial_state.stepsize = stepsize_vec

     initial_extra.post_mean = initial_state.value
     initial_extra.diag_post_cov = stepsize_vec * dim / 2.38 ** 2
     return initial_state, initial_extra
Example #3
0
    def adapt(self,
              previous_ensemble_state: cdict,
              previous_extra: cdict,
              new_ensemble_state: cdict,
              new_extra: cdict) -> Tuple[cdict, cdict]:
        """Update threshold, weights, ESS and MCMC tuning after an SMC iteration.

        Steps:

        1. ``self.next_threshold`` picks the next threshold. It is stored per particle on
           ``new_ensemble_state.threshold`` and as ``new_extra.parameters.threshold``.
        2. ``self.log_weight`` recomputes the log weights, and the ESS derived from them is
           stored per particle on ``new_ensemble_state.ess``.
        3. ``new_extra.alpha_mean`` is set to the mean of ``new_ensemble_state.alpha`` over
           particles whose *previous* log weight is finite (> -inf). It is NaN if no such
           particle exists (0 / 0).
        4. ``self.adapt_mcmc_params`` gets the final say on both states.

        Args:
            previous_ensemble_state: Ensemble state before this iteration.
            previous_extra: Extra state before this iteration.
            new_ensemble_state: Ensemble state after this iteration; updated in place.
            new_extra: Extra state after this iteration; updated in place.

        Returns:
            The ``(new_ensemble_state, new_extra)`` pair returned by ``adapt_mcmc_params``.
        """
        n = new_ensemble_state.value.shape[0]

        next_threshold = self.next_threshold(new_ensemble_state, new_extra)
        new_ensemble_state.threshold = jnp.ones(n) * next_threshold
        new_extra.parameters.threshold = next_threshold

        new_ensemble_state.log_weight = self.log_weight(previous_ensemble_state, previous_extra,
                                                        new_ensemble_state, new_extra)
        new_ensemble_state.ess = jnp.ones(n) * ess_log_weight(new_ensemble_state.log_weight)

        # Only particles that were alive going into this step count towards alpha_mean.
        # Mask with jnp.where rather than `alpha * mask`: a NaN/inf alpha on a dead
        # particle would survive multiplication by False (nan * 0 == nan) and poison the mean.
        alive_inds = previous_ensemble_state.log_weight > -jnp.inf
        alive_alpha = jnp.where(alive_inds, new_ensemble_state.alpha, 0.)
        new_extra.alpha_mean = alive_alpha.sum() / alive_inds.sum()

        new_ensemble_state, new_extra = self.adapt_mcmc_params(previous_ensemble_state, previous_extra,
                                                               new_ensemble_state, new_extra)

        return new_ensemble_state, new_extra
Example #4
0
 def clean_chain(self,
                 abc_scenario: ABCScenario,
                 chain_ensemble_state: cdict) -> cdict:
     """Reduce the stored ``threshold`` and ``ess`` arrays to their first column.

     Each field is replaced in place by ``field[:, 0]``. This assumes axis 0 indexes
     chain iterations and axis 1 indexes particles, and that these fields carry the same
     value for every particle. TODO confirm both assumptions against the chain layout.
     ``abc_scenario`` is not used.

     Returns:
         The same ``chain_ensemble_state`` object, modified in place.
     """
     for field_name in ('threshold', 'ess'):
         full_values = getattr(chain_ensemble_state, field_name)
         setattr(chain_ensemble_state, field_name, full_values[:, 0])
     return chain_ensemble_state