def startup(self, scenario: Scenario, n: int, initial_state: Union[None, cdict], initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Apply keyword overrides, validate ``max_iter`` and seed ``initial_extra``.

    Each keyword argument whose name is an existing attribute of ``self`` overwrites
    that attribute.  Independently, if ``self`` has ``parameters`` and the name is
    also an attribute of ``self.parameters``, that is overwritten too.  Keyword
    arguments matching neither are ignored.

    Args:
        scenario: Not read by this implementation.
        n: Not read by this implementation.
        initial_state: Returned unchanged.
        initial_extra: Mutated in place: gains ``iter = 0`` if it has no ``iter``
            attribute.  When ``self`` has ``parameters``, it gains a ``parameters``
            cdict (if missing), and every entry of ``self.parameters`` that is absent
            or ``None`` there is filled in.
        **kwargs: Attribute overrides, as described above.

    Returns:
        The tuple ``(initial_state, initial_extra)``.

    Raises:
        AttributeError: If ``self`` has no ``max_iter``, or ``max_iter`` is neither a
            Python ``int`` nor an object whose ``dtype`` is an integer dtype.
    """
    for key, value in kwargs.items():
        if hasattr(self, key):
            setattr(self, key, value)
        # Not an elif: a name present on both self and self.parameters updates both.
        if hasattr(self, 'parameters') and hasattr(self.parameters, key):
            setattr(self.parameters, key, value)

    max_iter = getattr(self, 'max_iter', None)
    max_iter_dtype = getattr(max_iter, 'dtype', None)
    # Accept any integer dtype rather than only int32: with jax_enable_x64 integer
    # arrays default to int64, and numpy integer scalars/arrays are valid counts too.
    is_integer = isinstance(max_iter, int) or (
        max_iter_dtype is not None and jnp.issubdtype(max_iter_dtype, jnp.integer))
    if not is_integer:
        raise AttributeError(self.__repr__() + ' max_iter must be int')

    if not hasattr(initial_extra, 'iter'):
        initial_extra.iter = 0

    if hasattr(self, 'parameters'):
        if not hasattr(initial_extra, 'parameters'):
            initial_extra.parameters = cdict()
        # Only fill missing or None entries, so values already present in
        # initial_extra.parameters take precedence over the sampler defaults.
        for key, value in self.parameters.__dict__.items():
            if getattr(initial_extra.parameters, key, None) is None:
                setattr(initial_extra.parameters, key, value)
    return initial_state, initial_extra
def update(self, scenario: Scenario, ensemble_state: cdict, extra: cdict) -> Tuple[cdict, cdict]:
    """Run one iteration: optional resample, per-particle proposal, then adapt.

    Steps:
        1. Increment ``extra.iter`` in place.
        2. Ask ``self.resample_criterion`` whether to resample.
        3. Split ``extra.random_key`` into ``n + 2`` keys, where ``n`` is the leading
           dimension of ``ensemble_state.value``:
           - the first ``n`` keys drive the proposals;
           - one key drives resampling;
           - the last key replaces ``extra.random_key``.
        4. Resample via ``self.resample``, or pass the ensemble through unchanged,
           using ``cond``.
        5. Map ``self.forward_proposal`` over particles with ``vmap``.  Only the
           ensemble state and keys are batched; ``scenario`` and ``extra`` are
           broadcast.
        6. Hand the pre- and post-proposal ensembles to ``self.adapt`` and return
           what it produces.

    Returns:
        The ``(state, extra)`` pair obtained from ``self.adapt``.
    """
    extra.iter = extra.iter + 1
    num_particles = ensemble_state.value.shape[0]

    should_resample = self.resample_criterion(ensemble_state, extra)

    keys = random.split(extra.random_key, num_particles + 2)
    proposal_keys = keys[:num_particles]
    resample_key = keys[-2]
    extra.random_key = keys[-1]

    def _do_resample(state):
        return self.resample(state, resample_key)

    def _keep(state):
        return state

    pre_proposal_state = cond(should_resample, _do_resample, _keep, ensemble_state)

    batched_forward = vmap(self.forward_proposal, in_axes=(None, 0, None, 0))
    proposed_state = batched_forward(scenario, pre_proposal_state, extra, proposal_keys)

    # NOTE(review): the same (already mutated) `extra` is passed as both the
    # "previous" and "new" extra — confirm this is what `adapt` expects.
    out_state, out_extra = self.adapt(pre_proposal_state, extra, proposed_state, extra)
    return out_state, out_extra
def update(self, scenario: Scenario, ensemble_state: cdict, extra: cdict) -> Tuple[cdict, cdict]:
    """Take one optimiser step driven by ``self.kernelised_grad_matrix``.

    Steps:
        1. Increment ``extra.iter`` in place.
        2. Split ``extra.random_key`` into ``n + 2`` keys, where ``n`` is the leading
           dimension of ``ensemble_state.value``:
           - the last key selects batch indices via ``self.get_batch_inds``;
           - the second-to-last replaces ``extra.random_key``;
           - the first ``n`` are used for re-evaluating potentials.
        3. Compute an update matrix from the current values, gradients and
           ``extra.parameters.kernel_params``.
        4. Feed its negation to ``self.opt_update`` and read the new values back with
           ``self.get_params``.
        5. Refresh ``potential`` and ``grad_potential`` through a vmapped
           ``scenario.potential_and_grad``.
        6. Pass the result to ``self.adapt``.

    ``ensemble_state`` and ``extra`` are both mutated in place.

    Returns:
        The ``(ensemble_state, extra)`` pair obtained from ``self.adapt``.
    """
    num_particles = ensemble_state.value.shape[0]
    extra.iter = extra.iter + 1

    keys = random.split(extra.random_key, num_particles + 2)
    eval_keys, next_key, batch_key = keys[:num_particles], keys[-2], keys[-1]

    batch_inds = self.get_batch_inds(batch_key)
    extra.random_key = next_key

    direction = self.kernelised_grad_matrix(
        ensemble_state.value,
        ensemble_state.grad_potential,
        extra.parameters.kernel_params,
        batch_inds,
    )

    # Negated before being handed to the optimiser — presumably because opt_update
    # takes a descent-style gradient (jax optimizers convention); confirm.
    extra.opt_state = self.opt_update(extra.iter, -direction, extra.opt_state)
    ensemble_state.value = self.get_params(extra.opt_state)

    batched_potential_and_grad = vmap(scenario.potential_and_grad)
    ensemble_state.potential, ensemble_state.grad_potential = batched_potential_and_grad(
        ensemble_state.value, eval_keys)

    adapted_state, adapted_extra = self.adapt(ensemble_state, extra)
    return adapted_state, adapted_extra