Ejemplo n.º 1
0
    def startup(self,
                abc_scenario: ABCScenario,
                n: int,
                initial_state: cdict,
                initial_extra: cdict,
                startup_correction: bool = True,
                **kwargs) -> Tuple[cdict, cdict]:
        """Run the parent startup, then fill in any missing ABC state fields.

        After delegating to the parent ``startup``, each of the following
        attributes is computed only if ``initial_state`` does not already
        have it:

        * ``prior_potential``: computed only when ``is_implemented`` reports
          ``abc_scenario.prior_potential`` as implemented. It is evaluated at
          ``initial_state.value`` with a freshly split subkey.
        * ``simulated_data``: drawn from ``abc_scenario.likelihood_sample`` at
          ``initial_state.value`` with a freshly split subkey.
        * ``distance``: ``abc_scenario.distance_function`` applied to
          ``initial_state.simulated_data``.

        Every random draw splits ``initial_extra.random_key`` and writes the
        carried-forward key back to ``initial_extra.random_key``.

        Args:
            abc_scenario: Scenario supplying ``prior_potential``,
                ``likelihood_sample`` and ``distance_function``.
            n: Passed unchanged to the parent ``startup``.
            initial_state: State container. Missing ABC attributes are set on
                it in place.
            initial_extra: Extra container that holds ``random_key``.
            startup_correction: Forwarded to the parent ``startup``.
                NOTE(review): assumes the parent accepts this keyword
                (explicitly or via ``**kwargs``); confirm against the parent
                class.
            **kwargs: Forwarded to the parent ``startup``.

        Returns:
            The ``(initial_state, initial_extra)`` pair returned by the parent
            ``startup``, with the ABC attributes above filled in.
        """
        # startup_correction was previously accepted and then dropped, so the
        # caller's choice never reached the parent. Forward it explicitly.
        initial_state, initial_extra = super().startup(abc_scenario, n,
                                                       initial_state,
                                                       initial_extra,
                                                       startup_correction=startup_correction,
                                                       **kwargs)

        if not hasattr(initial_state, 'prior_potential') and is_implemented(
                abc_scenario.prior_potential):
            # Carry one key forward and consume the other for this draw.
            initial_extra.random_key, subkey = random.split(
                initial_extra.random_key)
            initial_state.prior_potential = abc_scenario.prior_potential(
                initial_state.value, subkey)

        if not hasattr(initial_state, 'simulated_data'):
            initial_extra.random_key, subkey = random.split(
                initial_extra.random_key)
            initial_state.simulated_data = abc_scenario.likelihood_sample(
                initial_state.value, subkey)

        if not hasattr(initial_state, 'distance'):
            # Uses either the pre-existing or the just-simulated data.
            initial_state.distance = abc_scenario.distance_function(
                initial_state.simulated_data)
        return initial_state, initial_extra
Ejemplo n.º 2
0
    def startup(self,
                abc_scenario: ABCScenario,
                n: int,
                initial_state: cdict,
                initial_extra: cdict,
                **kwargs) -> Tuple[cdict, cdict]:
        """Run ``SMCSampler.startup``, then fill in missing per-particle ABC fields.

        The particle count is taken from ``len(state.value)`` after the base
        startup, not from ``n``. Each of the following attributes is set only
        if the state does not already have it:

        * ``prior_potential``: computed only when ``is_implemented`` reports
          the scenario's ``prior_potential`` as implemented. It is vmapped
          over particles, one key per particle.
        * ``simulated_data``: ``likelihood_sample`` vmapped over particles,
          one key per particle.
        * ``distance``: ``distance_function`` vmapped over ``simulated_data``.
        * ``threshold``: a per-particle array filled with ``inf`` when
          ``self.threshold_schedule`` is None, otherwise with
          ``self.threshold_schedule[0]``.
        * ``ess``: a per-particle array filled with the particle count.

        Args:
            abc_scenario: Scenario supplying ``prior_potential``,
                ``likelihood_sample`` and ``distance_function``.
            n: Passed to ``SMCSampler.startup`` only.
            initial_state: State container, extended in place.
            initial_extra: Extra container that holds ``random_key``.
            **kwargs: Forwarded to ``SMCSampler.startup``.

        Returns:
            The ``(state, extra)`` pair from ``SMCSampler.startup`` with the
            fields above filled in.
        """
        # Calls SMCSampler.startup directly rather than super(), so any
        # startup defined on classes between this one and SMCSampler is
        # skipped. Presumably intentional; confirm before changing.
        state, extra = SMCSampler.startup(self, abc_scenario, n,
                                          initial_state, initial_extra, **kwargs)

        num_particles = len(state.value)

        def _per_particle_keys():
            # Split off one key per particle, and keep the extra key as the
            # new carried-forward random_key.
            keys = random.split(extra.random_key, num_particles + 1)
            extra.random_key = keys[-1]
            return keys[:num_particles]

        if not hasattr(state, 'prior_potential') and is_implemented(abc_scenario.prior_potential):
            prior_keys = _per_particle_keys()
            state.prior_potential = vmap(abc_scenario.prior_potential)(state.value, prior_keys)

        if not hasattr(state, 'simulated_data'):
            sim_keys = _per_particle_keys()
            state.simulated_data = vmap(abc_scenario.likelihood_sample)(state.value, sim_keys)

        if not hasattr(state, 'distance'):
            state.distance = vmap(abc_scenario.distance_function)(state.simulated_data)

        if not hasattr(state, 'threshold'):
            first_threshold = jnp.inf if self.threshold_schedule is None else self.threshold_schedule[0]
            # zeros + value (rather than full) keeps the original dtype promotion.
            state.threshold = jnp.zeros(num_particles) + first_threshold

        if not hasattr(state, 'ess'):
            state.ess = jnp.zeros(num_particles) + num_particles

        return state, extra