def sample_from_hull(hull_points: List[ConvexHullPoint], convex_hull_eps: float,
                     sampling_max_try: int) -> ConvexHullPoint:
    front_points, eps_points, xs, ys = model_descs_on_front(hull_points,
                                                            convex_hull_eps)

    logger.info(f'num models in pool: {len(hull_points)}')
    logger.info(f'num models on front: {len(front_points)}')
    logger.info(f'num models on front with eps: {len(eps_points)}')

    # reverse sort by metrics performance
    eps_points.sort(reverse=True, key=lambda p: p.metrics.best_val_top1())

    # default choice in case sampling below does not pick a point
    sampled_point = random.choice(hull_points)

    # go through sorted list of models near convex hull, selecting each with
    # probability inversely proportional to how often it has been sampled
    for _ in range(sampling_max_try):
        for point in eps_points:
            p = 1.0 / (point.sampling_count + 1.0)
            should_select = np.random.binomial(1, p)
            if should_select == 1:
                point.sampling_count += 1
                return point

    # if here, sampling was not successful
    logger.warn('sampling was not successful, returning a random parent')

    sampled_point.sampling_count += 1

    return sampled_point
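# The selection rule above can be exercised in isolation. Below is a minimal,
# hypothetical sketch (not part of this codebase; pick_parent and its
# parameters are illustrative names) that reproduces the 1 / (sampling_count + 1)
# Bernoulli draw with plain counters, showing how frequently-sampled parents
# become progressively less likely to be picked again.
import numpy as np

def pick_parent(sampling_counts, max_try=10, rng=np.random):
    # counts are ordered best-first, mirroring the eps_points sort above
    for _ in range(max_try):
        for i, count in enumerate(sampling_counts):
            if rng.binomial(1, 1.0 / (count + 1.0)) == 1:
                sampling_counts[i] += 1
                return i
    return None  # caller falls back to a random parent

counts = [0, 3, 10]  # model 0 never sampled, model 2 sampled often
print(pick_parent(counts), counts)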
def _restore_checkpoint(self) -> bool:
    can_restore = self._checkpoint is not None \
                  and 'convex_hull_points' in self._checkpoint
    if can_restore:
        self._hull_points = self._checkpoint['convex_hull_points']
        logger.warn({'Hull restored': True})

    return can_restore
def fit(self, train_dl: DataLoader, val_dl: Optional[DataLoader]) -> Metrics:
    logger.pushd(self._title)

    self._metrics = Metrics(self._title, self._apex, logger_freq=self._logger_freq)

    # create optimizers and schedulers
    self._multi_optim = self.create_multi_optim(len(train_dl))
    # before checkpoint restore, convert to amp
    self.model = self._apex.to_amp(self.model, self._multi_optim,
                                   batch_size=train_dl.batch_size)

    self._lossfn = self._lossfn.to(self.get_device())

    self.pre_fit(train_dl, val_dl)

    # we need to restore checkpoint after all objects are created because
    # restoring checkpoint requires load_state_dict calls on these objects
    self._start_epoch = 0
    # do we have a checkpoint
    checkpoint_avail = self._checkpoint is not None
    checkpoint_val = checkpoint_avail and 'trainer' in self._checkpoint
    resumed = False
    if checkpoint_val:
        # restore checkpoint
        resumed = True
        self.restore_checkpoint()
    elif checkpoint_avail:  # TODO: bad checkpoint?
        self._checkpoint.clear()
    logger.warn({'resumed': resumed, 'checkpoint_avail': checkpoint_avail,
                 'checkpoint_val': checkpoint_val,
                 'start_epoch': self._start_epoch,
                 'total_epochs': self._epochs})
    logger.info({'aux_weight': self._aux_weight, 'grad_clip': self._grad_clip,
                 'drop_path_prob': self._drop_path_prob,
                 'validation_freq': self._validation_freq,
                 'batch_chunks': self.batch_chunks})

    if self._start_epoch >= self._epochs:
        logger.warn(f'fit done because start_epoch {self._start_epoch}>={self._epochs}')
        return self.get_metrics()  # we already finished the run, we might be checkpointed

    logger.pushd('epochs')
    for epoch in range(self._start_epoch, self._epochs):
        logger.pushd(epoch)
        self._set_epoch(epoch, train_dl, val_dl)
        self.pre_epoch(train_dl, val_dl)
        self._train_epoch(train_dl)
        self.post_epoch(train_dl, val_dl)
        logger.popd()
    logger.popd()
    self.post_fit(train_dl, val_dl)

    # make sure we don't keep references to the graph
    del self._multi_optim

    logger.popd()

    return self.get_metrics()
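# Minimal standalone sketch (decide_resume is a hypothetical helper, not part
# of this codebase) of the resume decision in fit() above: a checkpoint is
# only honored when it carries the 'trainer' section; an available but
# incomplete checkpoint is treated as bad and cleared so training starts fresh.
from typing import Optional

def decide_resume(checkpoint: Optional[dict]) -> bool:
    checkpoint_avail = checkpoint is not None
    checkpoint_val = checkpoint_avail and 'trainer' in checkpoint
    if checkpoint_val:
        return True            # caller restores state and resumes
    if checkpoint_avail:
        checkpoint.clear()     # stale/partial checkpoint: discard it
    return False               # start from epoch 0

assert decide_resume({'trainer': {}}) is True
assert decide_resume({'other': {}}) is False
assert decide_resume(None) is False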
def __init__(self, conf_train: Config, model: nn.Module,
             checkpoint: Optional[CheckPoint] = None) -> None:
    # region config vars
    self.conf_train = conf_train
    conf_lossfn = conf_train['lossfn']
    self._aux_weight = conf_train['aux_weight']
    self._grad_clip = conf_train['grad_clip']
    self._drop_path_prob = conf_train['drop_path_prob']
    self._logger_freq = conf_train['logger_freq']
    self._title = conf_train['title']
    self._epochs = conf_train['epochs']
    self.conf_optim = conf_train['optimizer']
    self.conf_sched = conf_train['lr_schedule']
    self.batch_chunks = conf_train['batch_chunks']
    conf_validation = conf_train['validation']
    conf_apex = conf_train['apex']
    self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
    # endregion

    logger.pushd(self._title + '__init__')

    self._apex = ApexUtils(conf_apex, logger)
    self._checkpoint = checkpoint
    self.model = model

    self._lossfn = ml_utils.get_lossfn(conf_lossfn)

    # using separate apex for Tester is not possible because we must use
    # same distributed model as Trainer and hence they must share apex
    self._tester = Tester(conf_validation, model, self._apex) \
                   if conf_validation else None
    self._metrics: Optional[Metrics] = None

    self._droppath_module = self._get_droppath_module()
    if self._droppath_module is None and self._drop_path_prob > 0.0:
        logger.warn({'droppath_module': None})

    self._start_epoch = -1  # nothing is started yet

    logger.popd()
def restore_checkpoint(self, conf_search: Config, macro_combinations) \
        -> Tuple[int, Optional[SearchResult]]:
    conf_pareto = conf_search['pareto']
    pareto_summary_filename = conf_pareto['summary_filename']

    summary_filepath = utils.full_path(pareto_summary_filename)

    # if checkpoint is available then restart from last combination we were running
    checkpoint_avail = self._checkpoint is not None
    resumed, state = False, None
    start_macro_i, best_result = 0, None
    if checkpoint_avail:
        state = self._checkpoint.get('search', None)
        if state is not None:
            start_macro_i = state['start_macro_i']
            assert start_macro_i >= 0 and start_macro_i < len(macro_combinations)
            best_result = yaml.load(state['best_result'], Loader=yaml.Loader)

            start_macro_i += 1  # resume after the last checkpoint
            resumed = True

    if not resumed:
        # erase previous file left over from run
        utils.zero_file(summary_filepath)

    logger.warn({'resumed': resumed, 'checkpoint_avail': checkpoint_avail,
                 'checkpoint_val': state is not None,
                 'start_macro_i': start_macro_i,
                 'total_macro_combinations': len(macro_combinations)})
    return start_macro_i, best_result
def _restore_checkpoint(self, macro_combinations) \
        -> Tuple[int, Optional[SearchResult]]:
    checkpoint_avail = self._checkpoint is not None
    resumed, state = False, None
    start_macro, best_result = 0, None
    if checkpoint_avail:
        state = self._checkpoint.get('search', None)
        if state is not None:
            start_macro = state['start_macro']
            assert start_macro >= 0 and start_macro < len(macro_combinations)
            best_result = yaml.load(state['best_result'], Loader=yaml.Loader)

            start_macro += 1  # resume after the last checkpoint
            resumed = True

    if not resumed:
        # erase previous file left over from run
        utils.zero_file(self._parito_filepath)

    logger.warn({'resumed': resumed, 'checkpoint_avail': checkpoint_avail,
                 'checkpoint_val': state is not None,
                 'start_macro': start_macro,
                 'total_macro_combinations': len(macro_combinations)})
    return start_macro, best_result
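# Toy sketch (hypothetical; a plain dict stands in for the checkpoint object)
# of the resume arithmetic shared by both restore methods above: the saved
# index is the last macro combination that completed, so the search resumes at
# saved index + 1 and never repeats finished work.
def resume_index(checkpoint, total_combinations: int) -> int:
    state = checkpoint.get('search', None) if checkpoint is not None else None
    if state is None:
        return 0                           # fresh run
    last_done = state['start_macro']
    assert 0 <= last_done < total_combinations
    return last_done + 1                   # continue after the last checkpoint

assert resume_index(None, 5) == 0
assert resume_index({'search': {'start_macro': 2}}, 5) == 3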