예제 #1
0
파일: smc.py 프로젝트: SamDuffield/mocat
    def startup(self, scenario: Scenario, n: int, initial_state: cdict,
                initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
        """Initialise a likelihood-tempered ensemble at temperature zero.

        After delegating to the parent ``startup``, each particle gets its
        prior and likelihood potentials evaluated (one random key each). The
        working potential is set to the prior potential alone, temperatures
        and log-weights are zeroed, and the ESS is set to ``n``.

        Args:
            scenario: Scenario to sample; must have a ``prior_sample`` attribute.
            n: Number of particles.
            initial_state: Ensemble state, passed to the parent ``startup``.
            initial_extra: Extra state; its ``random_key`` is split and replaced.
            **kwargs: Forwarded to the parent ``startup``.

        Returns:
            The updated ``(initial_state, initial_extra)`` pair.

        Raises:
            TypeError: If ``scenario`` has no ``prior_sample`` attribute.
        """
        if not hasattr(scenario, 'prior_sample'):
            raise TypeError(
                f'Likelihood tempering requires scenario {scenario.name} to have prior_sample implemented'
            )

        initial_state, initial_extra = super().startup(scenario, n,
                                                       initial_state,
                                                       initial_extra, **kwargs)

        # n keys for the prior potentials, n for the likelihood potentials,
        # and one carried forward on the extra state.
        keys = random.split(initial_extra.random_key, 2 * n + 1)
        prior_keys = keys[:n]
        likelihood_keys = keys[n:(2 * n)]
        initial_extra.random_key = keys[-1]

        values = initial_state.value
        initial_state.prior_potential = vmap(scenario.prior_potential)(values, prior_keys)
        initial_state.likelihood_potential = vmap(scenario.likelihood_potential)(values,
                                                                                 likelihood_keys)

        # At temperature zero the working potential is the prior potential only.
        initial_state.potential = initial_state.prior_potential
        initial_state.temperature = jnp.zeros(n)
        initial_state.log_weight = jnp.zeros(n)
        initial_state.ess = jnp.zeros(n) + n

        scenario.temperature = 0.

        return initial_state, initial_extra
예제 #2
0
 def clean_chain_ar(self, abc_scenario: ABCScenario, chain_state: cdict):
     """Accept the closest fraction of a chain, as set by ``acceptance_rate``.

     The ``self.parameters.acceptance_rate`` quantile of the chain's distances
     becomes the threshold. It is stored as a Python float on
     ``self.parameters.threshold``. Samples whose distance is strictly below it
     get log-weight 0; all others get ``-inf``.

     Args:
         abc_scenario: Not used by this method.
         chain_state: Chain samples with a ``distance`` array; mutated in place.

     Returns:
         ``chain_state`` with ``log_weight`` set.
     """
     distances = chain_state.distance
     cutoff = jnp.quantile(distances, self.parameters.acceptance_rate)
     self.parameters.threshold = float(cutoff)
     accepted = distances < cutoff
     chain_state.log_weight = jnp.where(accepted, 0., -jnp.inf)
     return chain_state
예제 #3
0
파일: smc.py 프로젝트: SamDuffield/mocat
 def startup(self, scenario: Scenario, n: int, initial_state: cdict,
             initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
     """Run the parent startup, then fill in any missing SMC weight fields.

     If the resulting state has no ``log_weight``, it is set to zeros. If it
     has no ``ess``, it is set to ``n`` for every particle. Existing values are
     left untouched.

     Args:
         scenario: Scenario being sampled.
         n: Number of particles.
         initial_state: Ensemble state, passed to the parent ``startup``.
         initial_extra: Extra state, passed to the parent ``startup``.
         **kwargs: Forwarded to the parent ``startup``.

     Returns:
         The updated ``(initial_state, initial_extra)`` pair.
     """
     initial_state, initial_extra = super().startup(scenario, n,
                                                    initial_state,
                                                    initial_extra, **kwargs)
     # Defaults are built lazily so nothing is allocated for fields that exist.
     defaults = (('log_weight', lambda: jnp.zeros(n)),
                 ('ess', lambda: jnp.zeros(n) + n))
     for field_name, make_default in defaults:
         if not hasattr(initial_state, field_name):
             setattr(initial_state, field_name, make_default())
     return initial_state, initial_extra
예제 #4
0
파일: smc.py 프로젝트: SamDuffield/mocat
 def adapt(self, previous_ensemble_state: cdict, previous_extra: cdict,
           new_ensemble_state: cdict,
           new_extra: cdict) -> Tuple[cdict, cdict]:
     """Set the next temperature on the new ensemble and update its weights.

     The temperature from ``self.next_temperature`` is broadcast to every
     particle. The increment from ``self.log_weight`` is added to the previous
     ensemble's log-weights. The ESS is ``exp(log_ess_log_weight(...))`` of the
     new log-weights, broadcast per particle.

     Args:
         previous_ensemble_state: Ensemble before this step (supplies log-weights).
         previous_extra: Extra state before this step.
         new_ensemble_state: Ensemble after this step; mutated in place.
         new_extra: Extra state after this step.

     Returns:
         ``(new_ensemble_state, new_extra)``.
     """
     n_particles = new_ensemble_state.value.shape[0]

     temperature = self.next_temperature(new_ensemble_state, new_extra)
     new_ensemble_state.temperature = jnp.ones(n_particles) * temperature

     # The temperature is set before computing the increment, matching the
     # original order (self.log_weight may read it; not visible here).
     previous_log_weight = previous_ensemble_state.log_weight
     increment = self.log_weight(previous_ensemble_state, previous_extra,
                                 new_ensemble_state, new_extra)
     new_ensemble_state.log_weight = previous_log_weight + increment

     ess = jnp.exp(log_ess_log_weight(new_ensemble_state.log_weight))
     new_ensemble_state.ess = jnp.ones(n_particles) * ess
     return new_ensemble_state, new_extra
예제 #5
0
    def adapt(self,
              previous_ensemble_state: cdict,
              previous_extra: cdict,
              new_ensemble_state: cdict,
              new_extra: cdict) -> Tuple[cdict, cdict]:
        """Advance the ABC threshold, reweight, record mean acceptance, adapt MCMC.

        Steps:
          1. Get the next threshold from ``self.next_threshold``. Store it per
             particle on the new ensemble and on ``new_extra.parameters``.
          2. Replace the new ensemble's log-weights with ``self.log_weight(...)``.
             Set the per-particle ESS from ``ess_log_weight``.
          3. Set ``new_extra.alpha_mean`` to the mean of ``new_ensemble_state.alpha``
             over particles whose previous log-weight was finite.
          4. Delegate to ``self.adapt_mcmc_params``.

        Args:
            previous_ensemble_state: Ensemble before this step (defines which
                particles were alive).
            previous_extra: Extra state before this step.
            new_ensemble_state: Ensemble after this step; mutated in place.
            new_extra: Extra state after this step; mutated in place.

        Returns:
            The ``(new_ensemble_state, new_extra)`` pair returned by
            ``self.adapt_mcmc_params``.
        """
        n = new_ensemble_state.value.shape[0]
        next_threshold = self.next_threshold(new_ensemble_state, new_extra)
        new_ensemble_state.threshold = jnp.ones(n) * next_threshold
        new_extra.parameters.threshold = next_threshold
        new_ensemble_state.log_weight = self.log_weight(previous_ensemble_state, previous_extra,
                                                        new_ensemble_state, new_extra)
        new_ensemble_state.ess = jnp.ones(n) * ess_log_weight(new_ensemble_state.log_weight)

        # Mean acceptance over particles that were alive (finite log-weight).
        # Select with jnp.where instead of multiplying by the boolean mask:
        # a NaN/inf alpha on a dead particle would survive `alpha * False`
        # (NaN * 0 = NaN, inf * 0 = NaN) and poison the whole mean.
        # NOTE(review): if no particle is alive the denominator is 0 and the
        # result is NaN, as before.
        alive_inds = previous_ensemble_state.log_weight > -jnp.inf
        alive_alpha = jnp.where(alive_inds, new_ensemble_state.alpha, 0.)
        new_extra.alpha_mean = alive_alpha.sum() / alive_inds.sum()

        new_ensemble_state, new_extra = self.adapt_mcmc_params(previous_ensemble_state, previous_extra,
                                                               new_ensemble_state, new_extra)

        return new_ensemble_state, new_extra
예제 #6
0
파일: smc.py 프로젝트: SamDuffield/mocat
    def startup(self, scenario: Scenario, n: int, initial_state: cdict,
                initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
        """Move a freshly started ensemble to its first temperature, then start MCMC.

        Normalises ``self.mcmc_sampler.correction`` via ``check_correction`` and
        runs the parent ``startup``. Then:

        * picks the first temperature with ``self.next_temperature``;
        * stores it on the scenario and adds it to each particle's temperature;
        * sets ``potential = prior_potential + temp * likelihood_potential``;
        * sets ``log_weight = -temp * likelihood_potential``;
        * sets the per-particle ESS from those log-weights.

        Finally, ``self.mcmc_sampler.startup`` is vmapped over particles, and the
        first entry of the batched extra state is kept.

        Args:
            scenario: Scenario being sampled; its ``temperature`` is overwritten.
            n: Number of particles.
            initial_state: Ensemble state, passed to the parent ``startup``.
            initial_extra: Extra state, passed to the parent ``startup``.
            **kwargs: Forwarded to the parent ``startup``.

        Returns:
            The ``(initial_state, initial_extra)`` pair after MCMC startup.
        """
        self.mcmc_sampler.correction = check_correction(self.mcmc_sampler.correction)

        initial_state, initial_extra = super().startup(scenario, n,
                                                       initial_state,
                                                       initial_extra, **kwargs)

        temp = self.next_temperature(initial_state, initial_extra)
        scenario.temperature = temp
        initial_state.temperature += temp

        likelihood_pot = initial_state.likelihood_potential
        initial_state.potential = initial_state.prior_potential + temp * likelihood_pot
        initial_state.log_weight = -temp * likelihood_pot
        ess = jnp.exp(log_ess_log_weight(initial_state.log_weight))
        initial_state.ess = jnp.repeat(ess, n)

        # Closure reads self.mcmc_sampler and initial_extra when vmap traces it,
        # i.e. before initial_extra is rebound below.
        def per_particle_startup(state):
            return self.mcmc_sampler.startup(scenario, n, state, initial_extra)

        initial_state, batched_extra = vmap(per_particle_startup)(initial_state)
        # vmap batches the extra state along axis 0; keep the first entry.
        initial_extra = batched_extra[0]
        return initial_state, initial_extra
예제 #7
0
    def startup(self,
                abc_scenario: ABCScenario,
                n: int,
                initial_state: cdict = None,
                initial_extra: cdict = None,
                **kwargs) -> Tuple[cdict, cdict]:
        """Prepare an ABC sampler: build a starting point if needed, then weight it.

        If no ``initial_state`` is given, its ``value`` is drawn from
        ``abc_scenario.prior_sample`` when that is implemented. The draw consumes
        a split of ``initial_extra.random_key``. Otherwise the value is a zero
        vector of length ``abc_scenario.dim``. ``self.max_iter`` is set to
        ``n - 1``. After the parent ``startup``, the state's ``log_weight`` is
        set by ``self.log_weight``.

        NOTE(review): when ``initial_state`` is None and the prior is
        implemented, ``initial_extra`` must be provided despite its ``None``
        default, because its ``random_key`` is read.

        Args:
            abc_scenario: ABC scenario being sampled.
            n: Number of samples requested.
            initial_state: Optional starting state.
            initial_extra: Extra state carrying ``random_key``.
            **kwargs: Forwarded to the parent ``startup``.

        Returns:
            The updated ``(initial_state, initial_extra)`` pair.
        """
        if initial_state is None:
            if not is_implemented(abc_scenario.prior_sample):
                start_value = jnp.zeros(abc_scenario.dim)
            else:
                initial_extra.random_key, prior_key = random.split(initial_extra.random_key)
                start_value = abc_scenario.prior_sample(prior_key)
            initial_state = cdict(value=start_value)

        # Presumably the starting state counts as the first of the n samples.
        self.max_iter = n - 1

        initial_state, initial_extra = super().startup(abc_scenario, n,
                                                       initial_state,
                                                       initial_extra, **kwargs)

        initial_state.log_weight = self.log_weight(abc_scenario, initial_state, initial_extra)
        return initial_state, initial_extra
예제 #8
0
파일: smc.py 프로젝트: SamDuffield/mocat
 def adapt(self, previous_ensemble_state: cdict, previous_extra: cdict,
           new_ensemble_state: cdict,
           new_extra: cdict) -> Tuple[cdict, cdict]:
     """Accumulate log-weights by adding the previous ensemble's onto the new one's.

     Args:
         previous_ensemble_state: Ensemble before this step; read only.
         previous_extra: Unused.
         new_ensemble_state: Ensemble after this step; its ``log_weight`` is
             replaced in place by the running sum.
         new_extra: Returned unchanged.

     Returns:
         ``(new_ensemble_state, new_extra)``.
     """
     accumulated = previous_ensemble_state.log_weight + new_ensemble_state.log_weight
     new_ensemble_state.log_weight = accumulated
     return new_ensemble_state, new_extra