Example 1
    def pre_fit(self, train_dl: DataLoader,
                val_dl: Optional[DataLoader]) -> None:
        super().pre_fit(train_dl, val_dl)

        # optimizers and schedulers need to be recreated for each fit call
        # as they have state
        assert val_dl is not None
        lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device())

        self._xnas_optim = _XnasOptimizer(self._conf_alpha_optim, self.model,
                                          lossfn)
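
Example 2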
    def pre_fit(self, data_loaders: data.DataLoaders) -> None:
        super().pre_fit(data_loaders)

        # optimizers and schedulers need to be recreated for each fit call
        # as they have state
        assert data_loaders.val_dl is not None
        w_momentum = self._conf_w_optim['momentum']
        w_decay = self._conf_w_optim['decay']
        lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device())

        self._bilevel_optim = BilevelOptimizer(self._conf_alpha_optim,
                                               w_momentum, w_decay, self.model,
                                               lossfn, self.get_device(),
                                               self.batch_chunks)
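
The two variants above differ only in how the data is passed in: the first takes train_dl and val_dl as separate arguments, while the second receives a single data.DataLoaders bundle and reads val_dl from it. Below is a minimal sketch of such a bundle, assuming only the fields the method above actually accesses; the real data.DataLoaders class may carry more.

from dataclasses import dataclass
from typing import Optional

from torch.utils.data import DataLoader

@dataclass
class DataLoaders:
    # only the fields accessed by pre_fit above; the real class may differ
    train_dl: Optional[DataLoader] = None
    val_dl: Optional[DataLoader] = None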
Example 3
    def __init__(self,
                 conf_train: Config,
                 model: nn.Module,
                 checkpoint: Optional[CheckPoint] = None) -> None:
        # region config vars
        self.conf_train = conf_train
        conf_lossfn = conf_train['lossfn']
        self._aux_weight = conf_train['aux_weight']
        self._grad_clip = conf_train['grad_clip']
        self._drop_path_prob = conf_train['drop_path_prob']
        self._logger_freq = conf_train['logger_freq']
        self._title = conf_train['title']
        self._epochs = conf_train['epochs']
        self.conf_optim = conf_train['optimizer']
        self.conf_sched = conf_train['lr_schedule']
        self.batch_chunks = conf_train['batch_chunks']
        conf_validation = conf_train['validation']
        conf_apex = conf_train['apex']
        self._validation_freq = (0 if conf_validation is None
                                 else conf_validation['freq'])
        # endregion

        logger.pushd(self._title + '__init__')

        self._apex = ApexUtils(conf_apex, logger)

        self._checkpoint = checkpoint
        self.model = model

        self._lossfn = ml_utils.get_lossfn(conf_lossfn)
        # using a separate apex for the Tester is not possible because it
        # must use the same distributed model as the Trainer, hence they
        # must share apex
        self._tester = Tester(conf_validation, model, self._apex) \
                        if conf_validation else None
        self._metrics: Optional[Metrics] = None

        self._droppath_module = self._get_droppath_module()
        if self._droppath_module is None and self._drop_path_prob > 0.0:
            logger.warn({'droppath_module': None})

        self._start_epoch = -1  # nothing is started yet

        logger.popd()
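
For reference, here is a hypothetical conf_train mapping covering every key the constructor reads. The key names come directly from the code above, but all values and nested sub-keys are illustrative assumptions, not the library's actual schema.

# hypothetical config; key names match the constructor's reads,
# values and nested structure are assumptions for illustration
conf_train = {
    'lossfn': {'type': 'CrossEntropyLoss'},     # assumed lossfn spec
    'aux_weight': 0.4,
    'grad_clip': 5.0,
    'drop_path_prob': 0.2,
    'logger_freq': 50,
    'title': 'arch_train',
    'epochs': 600,
    'optimizer': {'type': 'sgd', 'lr': 0.025},  # assumed optimizer spec
    'lr_schedule': {'type': 'cosine'},          # assumed schedule spec
    'batch_chunks': 1,
    'validation': {'freq': 1},                  # None disables validation
    'apex': {'enabled': False},                 # assumed apex flags
}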