Example #1
def sample_from_hull(hull_points: List[ConvexHullPoint],
                     convex_hull_eps: float,
                     sampling_max_try: int) -> ConvexHullPoint:
    front_points, eps_points, xs, ys = model_descs_on_front(
        hull_points, convex_hull_eps)

    logger.info(f'num models in pool: {len(hull_points)}')
    logger.info(f'num models on front: {len(front_points)}')
    logger.info(f'num models on front with eps: {len(eps_points)}')

    # sort by validation performance, best first
    eps_points.sort(reverse=True, key=lambda p: p.metrics.best_val_top1())

    # go through the sorted list of models near the convex hull and accept a
    # point with probability inversely proportional to how often it has
    # already been sampled
    for _ in range(sampling_max_try):
        for point in eps_points:
            p = 1.0 / (point.sampling_count + 1.0)
            should_select = np.random.binomial(1, p)
            if should_select == 1:
                point.sampling_count += 1
                return point

    # if here, sampling was not successful
    logger.warn('sampling was not successful, returning a random parent')

    # fall back to a uniformly random point from the pool
    sampled_point = random.choice(hull_points)
    sampled_point.sampling_count += 1

    return sampled_point
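The sampling loop above touches only two attributes of each point: metrics.best_val_top1() and a mutable sampling_count (the hull computation inside model_descs_on_front needs the full ConvexHullPoint, which is not shown here). A minimal sketch of those assumed attributes and a call site; _Metrics and _HullPoint are hypothetical stand-ins, not the project's real classes:

from dataclasses import dataclass

@dataclass
class _Metrics:                      # hypothetical stand-in for the real Metrics
    val_top1: float

    def best_val_top1(self) -> float:
        return self.val_top1

@dataclass
class _HullPoint:                    # hypothetical stand-in for ConvexHullPoint
    metrics: _Metrics
    sampling_count: int = 0

# usage sketch: pick a parent model from a pool of evaluated points
pool = [_HullPoint(_Metrics(acc)) for acc in (0.69, 0.71, 0.74)]
# parent = sample_from_hull(pool, convex_hull_eps=0.05, sampling_max_try=100)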
Example #2
    def _restore_checkpoint(self) -> bool:
        can_restore = self._checkpoint is not None \
                        and 'convex_hull_points' in self._checkpoint
        if can_restore:
            self._hull_points = self._checkpoint['convex_hull_points']
            logger.warn({'Hull restored': True})

        return can_restore
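For context, a hedged sketch of the save-side counterpart that would put the hull into the checkpoint under the key _restore_checkpoint reads back. The method name _save_checkpoint and the dict-style assignment on the checkpoint object are assumptions, not code from the project:

    def _save_checkpoint(self) -> None:
        # assumption: the checkpoint object supports item assignment like a dict,
        # mirroring the 'convex_hull_points' lookup in _restore_checkpoint above
        if self._checkpoint is not None:
            self._checkpoint['convex_hull_points'] = self._hull_points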
Example #3
    def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics:
        logger.pushd(self._title)

        self._metrics = Metrics(self._title, self._apex, logger_freq=self._logger_freq)

        # create optimizers and schedulers
        self._multi_optim = self.create_multi_optim(len(train_dl))
        # before checkpoint restore, convert to amp
        self.model = self._apex.to_amp(self.model, self._multi_optim,
                                       batch_size=train_dl.batch_size)

        self._lossfn = self._lossfn.to(self.get_device())

        self.pre_fit(train_dl, val_dl)

        # we need to restore checkpoint after all objects are created because
        # restoring checkpoint requires load_state_dict calls on these objects
        self._start_epoch = 0
        # check whether a checkpoint exists and actually contains trainer state
        checkpoint_avail = self._checkpoint is not None
        checkpoint_val = checkpoint_avail and 'trainer' in self._checkpoint
        resumed = False
        if checkpoint_val:
            # restore checkpoint
            resumed = True
            self.restore_checkpoint()
        elif checkpoint_avail: # TODO: bad checkpoint?
            self._checkpoint.clear()
        logger.warn({'resumed': resumed, 'checkpoint_avail': checkpoint_avail,
                     'checkpoint_val': checkpoint_val,
                     'start_epoch': self._start_epoch,
                     'total_epochs': self._epochs})
        logger.info({'aux_weight': self._aux_weight,
                     'grad_clip': self._grad_clip,
                     'drop_path_prob': self._drop_path_prob,
                     'validation_freq': self._validation_freq,
                     'batch_chunks': self.batch_chunks})

        if self._start_epoch >= self._epochs:
            logger.warn(f'fit done because start_epoch {self._start_epoch} >= epochs {self._epochs}')
            return self.get_metrics() # run already completed in a previous session; just return its metrics

        logger.pushd('epochs')
        for epoch in range(self._start_epoch, self._epochs):
            logger.pushd(epoch)
            self._set_epoch(epoch, train_dl, val_dl)
            self.pre_epoch(train_dl, val_dl)
            self._train_epoch(train_dl)
            self.post_epoch(train_dl, val_dl)
            logger.popd()
        logger.popd()
        self.post_fit(train_dl, val_dl)

        # make sure we don't keep references to the graph
        del self._multi_optim

        logger.popd()
        return self.get_metrics()
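The comment before the restore block explains the ordering constraint: state can only be loaded into objects that already exist, so the optimizer, AMP wrapper, and loss function are built before restore_checkpoint runs. A small self-contained illustration of that pattern with plain PyTorch (this is not the Trainer's own restore logic):

import torch
import torch.nn as nn

# create the objects first, with the same configuration as the original run
model = nn.Linear(10, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

# a checkpoint is just state dicts plus bookkeeping
ckpt = {'model': model.state_dict(), 'optim': optim.state_dict(), 'epoch': 5}

# ...later, after re-creating model and optimizer, restore into them...
model.load_state_dict(ckpt['model'])
optim.load_state_dict(ckpt['optim'])
start_epoch = ckpt['epoch'] + 1  # resume after the last completed epoch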
Example #4
    def __init__(self,
                 conf_train: Config,
                 model: nn.Module,
                 checkpoint: Optional[CheckPoint] = None) -> None:
        # region config vars
        self.conf_train = conf_train
        conf_lossfn = conf_train['lossfn']
        self._aux_weight = conf_train['aux_weight']
        self._grad_clip = conf_train['grad_clip']
        self._drop_path_prob = conf_train['drop_path_prob']
        self._logger_freq = conf_train['logger_freq']
        self._title = conf_train['title']
        self._epochs = conf_train['epochs']
        self.conf_optim = conf_train['optimizer']
        self.conf_sched = conf_train['lr_schedule']
        self.batch_chunks = conf_train['batch_chunks']
        conf_validation = conf_train['validation']
        conf_apex = conf_train['apex']
        self._validation_freq = 0 if conf_validation is None \
                                  else conf_validation['freq']
        # endregion

        logger.pushd(self._title + '__init__')

        self._apex = ApexUtils(conf_apex, logger)

        self._checkpoint = checkpoint
        self.model = model

        self._lossfn = ml_utils.get_lossfn(conf_lossfn)
        # a separate apex instance for the Tester is not possible because it
        # must use the same distributed model as the Trainer, so they share apex
        self._tester = Tester(conf_validation, model, self._apex) \
                        if conf_validation else None
        self._metrics: Optional[Metrics] = None

        self._droppath_module = self._get_droppath_module()
        if self._droppath_module is None and self._drop_path_prob > 0.0:
            logger.warn({'droppath_module': None})

        self._start_epoch = -1  # nothing is started yet

        logger.popd()
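The constructor reads a fixed set of keys from conf_train. Below is a placeholder sketch of a matching config written as a plain dict: the key names come from the code above, but every value and every nested schema (lossfn, optimizer, lr_schedule, validation, apex) is illustrative only and will differ from the project's real YAML:

conf_train = {
    'lossfn': {'type': 'CrossEntropyLoss'},     # nested schema assumed
    'aux_weight': 0.4,
    'grad_clip': 5.0,
    'drop_path_prob': 0.2,
    'logger_freq': 50,
    'title': 'arch_train',
    'epochs': 600,
    'optimizer': {'type': 'sgd', 'lr': 0.025},  # nested schema assumed
    'lr_schedule': {'type': 'cosine'},          # nested schema assumed
    'batch_chunks': 1,
    'validation': {'freq': 1},                  # may also be None to disable validation
    'apex': {'enabled': False},                 # nested schema assumed
}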
Example #5
    def restore_checkpoint(self, conf_search:Config, macro_combinations)\
            ->Tuple[int, Optional[SearchResult]]:

        conf_pareto = conf_search['pareto']
        pareto_summary_filename = conf_pareto['summary_filename']

        summary_filepath = utils.full_path(pareto_summary_filename)

        # if a checkpoint is available, resume from where the previous run left off
        checkpoint_avail = self._checkpoint is not None
        resumed, state = False, None
        start_macro_i, best_result = 0, None
        if checkpoint_avail:
            state = self._checkpoint.get('search', None)
            if state is not None:
                start_macro_i = state['start_macro_i']
                assert 0 <= start_macro_i < len(macro_combinations)

                best_result = yaml.load(state['best_result'],
                                        Loader=yaml.Loader)

                start_macro_i += 1  # resume after the last checkpoint
                resumed = True

        if not resumed:
            # erase previous file left over from run
            utils.zero_file(summary_filepath)

        logger.warn({
            'resumed': resumed,
            'checkpoint_avail': checkpoint_avail,
            'checkpoint_val': state is not None,
            'start_macro_i': start_macro_i,
            'total_macro_combinations': len(macro_combinations)
        })
        return start_macro_i, best_result
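For context, a hedged sketch of the save side that restore_checkpoint expects: after finishing macro combination macro_i, the searcher would store the index and a YAML dump of the best result under the 'search' key. The method name _record_checkpoint is hypothetical, the checkpoint object is assumed to support dict-style assignment, and best_result is assumed to be YAML-serializable:

    def _record_checkpoint(self, macro_i: int, best_result) -> None:
        if self._checkpoint is not None:
            self._checkpoint['search'] = {
                'start_macro_i': macro_i,
                # stored as YAML text to mirror the yaml.load on the restore path
                'best_result': yaml.dump(best_result),
            }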
Example #6
    def _restore_checkpoint(self, macro_combinations)\
            ->Tuple[int, Optional[SearchResult]]:
        checkpoint_avail = self._checkpoint is not None
        resumed, state = False, None
        start_macro, best_result = 0, None
        if checkpoint_avail:
            state = self._checkpoint.get('search', None)
            if state is not None:
                start_macro = state['start_macro']
                assert 0 <= start_macro < len(macro_combinations)
                best_result = yaml.load(state['best_result'], Loader=yaml.Loader)

                start_macro += 1 # resume after the last checkpoint
                resumed = True

        if not resumed:
            # erase previous file left over from run
            utils.zero_file(self._parito_filepath)

        logger.warn({'resumed': resumed, 'checkpoint_avail': checkpoint_avail,
                     'checkpoint_val': state is not None,
                     'start_macro': start_macro,
                     'total_macro_combinations': len(macro_combinations)})
        return start_macro, best_result