def startup(self,
            abc_scenario: ABCScenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            startup_correction: bool = True,
            **kwargs) -> Tuple[cdict, cdict]:
    """Run the parent startup, then fill in any ABC quantities missing from the state.

    After delegating to ``super().startup``, each of the following attributes
    is set on the returned state only if it is not already present:

    * ``prior_potential`` -- ``abc_scenario.prior_potential`` evaluated at
      ``state.value`` with a fresh subkey, and only when that method is
      implemented (checked via ``is_implemented``);
    * ``simulated_data`` -- ``abc_scenario.likelihood_sample`` evaluated at
      ``state.value`` with a fresh subkey;
    * ``distance`` -- ``abc_scenario.distance_function`` applied to
      ``state.simulated_data``.

    Every keyed call above replaces ``extra.random_key`` with the first half
    of one ``random.split`` and uses the second half as its subkey.

    Args:
        abc_scenario: Scenario supplying the prior, simulator and distance.
        n: Forwarded unchanged to the parent ``startup``; not used otherwise.
        initial_state: Starting state, forwarded to the parent ``startup``.
        initial_extra: Starting extra, forwarded to the parent ``startup``.
            The extra it returns is expected to carry ``random_key`` whenever
            a random draw is needed.
        startup_correction: NOTE(review): accepted but neither read here nor
            forwarded to the parent -- presumably kept so overriding
            subclasses share this signature; confirm.
        **kwargs: Forwarded to the parent ``startup``.

    Returns:
        The ``(state, extra)`` pair returned by the parent, possibly augmented.
    """
    state, extra = super().startup(abc_scenario, n, initial_state, initial_extra, **kwargs)

    # The `and` short-circuits, so is_implemented is only consulted when the
    # caller did not already supply a prior potential.
    missing_prior = (not hasattr(state, 'prior_potential')
                     and is_implemented(abc_scenario.prior_potential))
    if missing_prior:
        extra.random_key, prior_key = random.split(extra.random_key)
        state.prior_potential = abc_scenario.prior_potential(state.value, prior_key)

    if not hasattr(state, 'simulated_data'):
        extra.random_key, sim_key = random.split(extra.random_key)
        state.simulated_data = abc_scenario.likelihood_sample(state.value, sim_key)

    # The distance is computed from simulated_data, so it must follow the step above.
    if not hasattr(state, 'distance'):
        state.distance = abc_scenario.distance_function(state.simulated_data)

    return state, extra
def startup(self,
            abc_scenario: ABCScenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Run ``SMCSampler.startup``, then fill in missing per-particle ABC quantities.

    The parent is invoked explicitly as ``SMCSampler.startup(self, ...)``
    rather than through ``super()``. NOTE(review): presumably this is done
    to bypass another base class's ``startup`` in the MRO -- confirm.

    After the parent call, the particle count is re-derived as
    ``len(state.value)``; the ``n`` argument is only forwarded. The per-particle
    calls below are ``vmap``-ed, so ``state.value`` is assumed to have a
    leading particle axis.

    Each of the following attributes is set only if it is not already present:

    * ``prior_potential`` -- vmapped ``abc_scenario.prior_potential``, one key
      per particle; only when that method is implemented (``is_implemented``);
    * ``simulated_data`` -- vmapped ``abc_scenario.likelihood_sample``, one key
      per particle;
    * ``distance`` -- vmapped ``abc_scenario.distance_function`` over
      ``state.simulated_data``;
    * ``threshold`` -- ``jnp.zeros(count)`` plus ``jnp.inf`` when
      ``self.threshold_schedule`` is ``None``, otherwise plus
      ``self.threshold_schedule[0]``;
    * ``ess`` -- ``jnp.zeros(count)`` plus the particle count.

    Each keyed step splits ``extra.random_key`` into ``count + 1`` keys. The
    first ``count`` keys go to the particles, and the last becomes the new
    ``extra.random_key``.

    Args:
        abc_scenario: Scenario supplying the prior, simulator and distance.
        n: Forwarded to ``SMCSampler.startup``; superseded afterwards.
        initial_state: Starting state, forwarded to ``SMCSampler.startup``.
        initial_extra: Starting extra, forwarded to ``SMCSampler.startup``.
        **kwargs: Forwarded to ``SMCSampler.startup``.

    Returns:
        The ``(state, extra)`` pair returned by the parent, possibly augmented.
    """
    state, extra = SMCSampler.startup(self, abc_scenario, n, initial_state, initial_extra, **kwargs)
    num_particles = len(state.value)

    def take_particle_keys():
        # One key per particle plus one extra key carried forward in `extra`.
        keys = random.split(extra.random_key, num_particles + 1)
        extra.random_key = keys[-1]
        return keys[:num_particles]

    if not hasattr(state, 'prior_potential') and is_implemented(abc_scenario.prior_potential):
        prior_keys = take_particle_keys()
        state.prior_potential = vmap(abc_scenario.prior_potential)(state.value, prior_keys)

    if not hasattr(state, 'simulated_data'):
        sim_keys = take_particle_keys()
        state.simulated_data = vmap(abc_scenario.likelihood_sample)(state.value, sim_keys)

    # The distance is computed from simulated_data, so it must follow the step above.
    if not hasattr(state, 'distance'):
        state.distance = vmap(abc_scenario.distance_function)(state.simulated_data)

    if not hasattr(state, 'threshold'):
        schedule = self.threshold_schedule
        first_threshold = jnp.inf if schedule is None else schedule[0]
        state.threshold = jnp.zeros(num_particles) + first_threshold

    if not hasattr(state, 'ess'):
        state.ess = jnp.zeros(num_particles) + num_particles

    return state, extra