예제 #1
0
    def startup(self,
                scenario: Scenario,
                n: int,
                initial_state: Union[None, cdict],
                initial_extra: cdict,
                **kwargs) -> Tuple[cdict, cdict]:
        """Apply keyword overrides, validate ``max_iter`` and seed ``initial_extra``.

        Args:
            scenario: Unused by this method.
            n: Unused by this method.
            initial_state: Returned unchanged.
            initial_extra: Auxiliary run information, modified in place: ``iter``
                is set to 0 if absent, and (when ``self`` has ``parameters``)
                every entry of ``self.parameters`` is copied into
                ``initial_extra.parameters`` unless already set to a non-None value.
            **kwargs: Attribute overrides. A key that already exists on ``self``
                is set on ``self``; a key that exists on ``self.parameters`` is
                also set there.

        Returns:
            The ``(initial_state, initial_extra)`` pair.

        Raises:
            AttributeError: If ``max_iter`` is missing or is not an integer
                (a Python int, a NumPy-style integer scalar, or a jax array of
                any integer dtype).
        """
        import numbers

        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
            # The same override may also target an entry of the parameters container.
            if hasattr(self, 'parameters') and hasattr(self.parameters, key):
                setattr(self.parameters, key, value)

        # A missing attribute is treated like an invalid value (None fails both checks).
        max_iter = getattr(self, 'max_iter', None)
        # numbers.Integral covers Python ints and NumPy integer scalars (e.g. np.int64).
        # For jax arrays accept any integer dtype: an exact 'int32' comparison
        # rejects int64 arrays, which jax produces when x64 mode is enabled.
        is_integer = isinstance(max_iter, numbers.Integral) or (
            isinstance(max_iter, jnp.ndarray)
            and jnp.issubdtype(max_iter.dtype, jnp.integer))
        if not is_integer:
            raise AttributeError(self.__repr__() + ' max_iter must be int')

        if not hasattr(initial_extra, 'iter'):
            initial_extra.iter = 0

        if hasattr(self, 'parameters'):
            if not hasattr(initial_extra, 'parameters'):
                initial_extra.parameters = cdict()
            # Fill in defaults from self.parameters without overwriting values the
            # caller already supplied; a None value counts as "not supplied".
            for key, value in self.parameters.__dict__.items():
                if getattr(initial_extra.parameters, key, None) is None:
                    setattr(initial_extra.parameters, key, value)
        return initial_state, initial_extra
예제 #2
0
파일: smc.py 프로젝트: SamDuffield/mocat
    def update(self, scenario: Scenario, ensemble_state: cdict,
               extra: cdict) -> Tuple[cdict, cdict]:
        """Advance the particle ensemble by one SMC iteration.

        Increments ``extra.iter`` and resamples the ensemble when
        ``self.resample_criterion`` returns true. It then moves every particle with
        ``self.forward_proposal`` (vectorised over the leading particle axis) and
        hands the result to ``self.adapt``.

        Args:
            scenario: Passed unchanged to every ``forward_proposal`` call.
            ensemble_state: Current ensemble; ``value`` has the particle count
                on its leading axis.
            extra: Auxiliary run information; ``iter`` and ``random_key`` are
                updated in place.

        Returns:
            The ``(state, extra)`` pair produced by ``self.adapt``.
        """
        extra.iter = extra.iter + 1
        num_particles = ensemble_state.value.shape[0]

        # The resampling decision is made before extra.random_key is replaced below.
        should_resample = self.resample_criterion(ensemble_state, extra)

        # num_particles + 2 keys: one per particle proposal, one for resampling,
        # and one carried forward in extra for the next iteration.
        keys = random.split(extra.random_key, num_particles + 2)
        proposal_keys = keys[:num_particles]
        resample_key = keys[-2]
        extra.random_key = keys[-1]

        def _resample_branch(state):
            return self.resample(state, resample_key)

        def _identity_branch(state):
            return state

        # Presumably jax.lax.cond: picks a branch from the (traced) boolean.
        state_to_advance = cond(should_resample,
                                _resample_branch,
                                _identity_branch,
                                ensemble_state)

        # Map over particles and keys; scenario and extra are shared by all particles.
        batched_proposal = vmap(self.forward_proposal, in_axes=(None, 0, None, 0))
        proposed_state = batched_proposal(scenario, state_to_advance, extra, proposal_keys)

        # The same extra object is passed as both the previous and the new extra,
        # since the proposal output is used only as the new state.
        new_state, new_extra = self.adapt(state_to_advance, extra, proposed_state, extra)
        return new_state, new_extra
예제 #3
0
파일: svgd.py 프로젝트: SamDuffield/mocat
    def update(self, scenario: Scenario, ensemble_state: cdict,
               extra: cdict) -> Tuple[cdict, cdict]:
        """Take one optimiser step that moves the whole particle ensemble.

        The kernelised update direction from ``self.kernelised_grad_matrix`` is
        fed, negated, to ``self.opt_update``. The new particle positions are read
        back with ``self.get_params``. Potentials and their gradients are then
        re-evaluated at the new positions, and the result goes through
        ``self.adapt``.

        Args:
            scenario: Supplies ``potential_and_grad``, vmapped over particles.
            ensemble_state: Current ensemble; ``value``, ``potential`` and
                ``grad_potential`` are overwritten in place.
            extra: Auxiliary run information; ``iter``, ``random_key`` and
                ``opt_state`` are updated in place, and
                ``parameters.kernel_params`` is read.

        Returns:
            The ``(ensemble_state, extra)`` pair produced by ``self.adapt``.
        """
        num_particles = ensemble_state.value.shape[0]
        extra.iter = extra.iter + 1

        # Keys: one per particle for the potential evaluation, one carried
        # forward in extra, one for drawing the batch indices.
        keys = random.split(extra.random_key, num_particles + 2)
        potential_keys = keys[:num_particles]
        next_key = keys[-2]
        batch_key = keys[-1]

        batch_inds = self.get_batch_inds(batch_key)
        extra.random_key = next_key

        stein_direction = self.kernelised_grad_matrix(ensemble_state.value,
                                                      ensemble_state.grad_potential,
                                                      extra.parameters.kernel_params,
                                                      batch_inds)

        # Negated before being handed to the optimiser -- presumably because
        # opt_update follows a gradient-descent convention.
        extra.opt_state = self.opt_update(extra.iter, -stein_direction, extra.opt_state)
        ensemble_state.value = self.get_params(extra.opt_state)

        evaluate_all = vmap(scenario.potential_and_grad)
        ensemble_state.potential, ensemble_state.grad_potential \
            = evaluate_all(ensemble_state.value, potential_keys)

        ensemble_state, extra = self.adapt(ensemble_state, extra)
        return ensemble_state, extra