def _train_desc(self, model_desc:ModelDesc, conf_train:Config)->MetricsStats:
    """Train given description"""
    # region conf vars
    conf_trainer = conf_train['trainer']
    conf_loader = conf_train['loader']
    trainer_title = conf_trainer['title']
    epochs = conf_trainer['epochs']
    drop_path_prob = conf_trainer['drop_path_prob']
    # endregion

    logger.pushd(trainer_title)

    if epochs == 0:
        # nothing to pretrain, save time
        metrics_stats = MetricsStats(model_desc, None, None)
    else:
        model = nas_utils.model_from_desc(model_desc,
                                          droppath=drop_path_prob > 0.0,
                                          affine=True)

        # get data
        train_dl, val_dl = self.get_data(conf_loader)
        assert train_dl is not None

        trainer = Trainer(conf_trainer, model, checkpoint=None)
        train_metrics = trainer.fit(train_dl, val_dl)

        metrics_stats = Search._create_metrics_stats(model, train_metrics,
                                                     self.finalizers)

    logger.popd()
    return metrics_stats
def _search_desc(self, model_desc:ModelDesc, search_iter:int)->ModelDesc:
    logger.pushd('arch_search')

    nas_utils.build_cell(model_desc, self.cell_builder, search_iter)

    if self.trainer_class:
        model = nas_utils.model_from_desc(model_desc,
                                          droppath=False, affine=False)

        # get data
        train_dl, val_dl = self.get_data(self.conf_loader)
        assert train_dl is not None

        # search arch
        arch_trainer = self.trainer_class(self.conf_train, model, checkpoint=None)
        train_metrics = arch_trainer.fit(train_dl, val_dl)

        metrics_stats = Search._create_metrics_stats(model, train_metrics,
                                                     self.finalizers)
        found_desc = metrics_stats.model_desc
    else: # if no trainer needed, for example, for random search
        found_desc = model_desc

    logger.popd()
    return found_desc
def generate_pareto(self)->ModelDesc:
    macro_combinations = list(self._macro_combinations())
    start_macro, best_result = self._restore_checkpoint(macro_combinations)

    for macro_comb_i in range(start_macro, len(macro_combinations)):
        reductions, cells, nodes = macro_combinations[macro_comb_i]
        logger.pushd(f'r{reductions}.c{cells}.n{nodes}')

        model_desc = self._build_macro(reductions, cells, nodes)

        # prep seed model and train it
        model_desc = self._seed_model(model_desc, reductions, cells, nodes)

        model_desc, best_result = self._search_iters(model_desc, best_result,
                                                     reductions, cells, nodes)

        assert best_result is not None
        self._record_checkpoint(macro_comb_i, best_result)
        logger.popd() # reductions, cells, nodes

    assert best_result is not None
    best_result.model_desc().clear_trainables()
    logger.info({'best_macro_params': best_result.macro_params,
                 'best_metric': best_result.metrics_stats})
    best_result.model_desc().save(self.final_desc_filename)

    return best_result.model_desc()
def _train_epoch(self, train_dl: DataLoader) -> None:
    steps = len(train_dl)
    self.model.train()

    logger.pushd('steps')
    for step, (x, y) in enumerate(train_dl):
        logger.pushd(step)
        assert self.model.training # derived class might alter the mode

        # TODO: please check that no algorithm is invalidated by swapping pre_step with zero_grad
        self._multi_optim.zero_grad()
        self.pre_step(x, y)

        # divide batch into chunks if needed so it fits in GPU RAM
        if self.batch_chunks > 1:
            x_chunks, y_chunks = torch.chunk(x, self.batch_chunks), torch.chunk(y, self.batch_chunks)
        else:
            x_chunks, y_chunks = (x,), (y,)

        logits_chunks = []
        loss_sum, loss_count = 0.0, 0
        for xc, yc in zip(x_chunks, y_chunks):
            xc, yc = xc.to(self.get_device(), non_blocking=True), \
                     yc.to(self.get_device(), non_blocking=True)

            logits_c, aux_logits = self.model(xc), None
            tupled_out = isinstance(logits_c, Tuple) and len(logits_c) >= 2
            # if self._aux_weight: # TODO: some other way to validate?
            #     assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
            if tupled_out: # then we are using a model created by desc
                logits_c, aux_logits = logits_c[0], logits_c[1]
            loss_c = self.compute_loss(self._lossfn, yc, logits_c,
                                       self._aux_weight, aux_logits)

            self._apex.backward(loss_c, self._multi_optim)

            loss_sum += loss_c.item() * len(logits_c)
            loss_count += len(logits_c)
            logits_chunks.append(logits_c.detach().cpu())

        # TODO: original darts clips alphas as well but pt.darts doesn't
        self._apex.clip_grad(self._grad_clip, self.model, self._multi_optim)

        self._multi_optim.step()

        # TODO: we possibly need to sync so all replicas are up to date
        self._apex.sync_devices()

        self.post_step(x, y,
                       ml_utils.join_chunks(logits_chunks),
                       torch.tensor(loss_sum / loss_count),
                       steps)
        logger.popd()
    # end of step

    self._multi_optim.epoch()
    logger.popd()
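# The chunking loop above is essentially gradient accumulation: every chunk
# contributes its gradients before a single optimizer step, and the reported
# loss is a size-weighted average over chunks. A minimal standalone sketch of
# the same idea (hypothetical names; no apex/multi-optimizer wrappers) might
# look like this:
import torch

def train_step_chunked(model, lossfn, optimizer, x, y, batch_chunks: int, device):
    """Illustrative only: split a large batch into chunks that fit in GPU RAM."""
    optimizer.zero_grad()
    x_chunks = torch.chunk(x, batch_chunks) if batch_chunks > 1 else (x,)
    y_chunks = torch.chunk(y, batch_chunks) if batch_chunks > 1 else (y,)

    loss_sum, sample_count = 0.0, 0
    for xc, yc in zip(x_chunks, y_chunks):
        xc, yc = xc.to(device, non_blocking=True), yc.to(device, non_blocking=True)
        logits = model(xc)
        loss = lossfn(logits, yc)
        loss.backward()                 # gradients accumulate across chunks
        loss_sum += loss.item() * len(xc)
        sample_count += len(xc)

    optimizer.step()                    # one step for the whole logical batch
    return loss_sum / sample_count      # size-weighted average loss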
def eval_arch(conf_eval: Config, cell_builder: Optional[CellBuilder]):
    logger.pushd('eval_arch')

    # region conf vars
    conf_loader = conf_eval['loader']
    model_filename = conf_eval['model_filename']
    metric_filename = conf_eval['metric_filename']
    conf_checkpoint = conf_eval['checkpoint']
    resume = conf_eval['resume']
    conf_train = conf_eval['trainer']
    # endregion

    if cell_builder:
        cell_builder.register_ops()

    model = create_model(conf_eval)

    # get data
    train_dl, _, test_dl = data.get_data(conf_loader)
    assert train_dl is not None and test_dl is not None

    checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
    trainer = Trainer(conf_train, model, checkpoint)
    train_metrics = trainer.fit(train_dl, test_dl)
    train_metrics.save(metric_filename)

    # save model
    if model_filename:
        model_filename = utils.full_path(model_filename)
        ml_utils.save_model(model, model_filename)
        logger.info({'model_save_path': model_filename})

    logger.popd()
def train_model_desc(self, model_desc:ModelDesc, conf_train:Config)\
        ->Optional[ModelMetrics]:
    """Train given description"""
    # region conf vars
    conf_trainer = conf_train['trainer']
    conf_loader = conf_train['loader']
    trainer_title = conf_trainer['title']
    epochs = conf_trainer['epochs']
    drop_path_prob = conf_trainer['drop_path_prob']
    # endregion

    # if epochs <= 0 then there is nothing to train, so save time
    if epochs <= 0:
        return None

    logger.pushd(trainer_title)

    model = Model(model_desc, droppath=drop_path_prob > 0.0, affine=True)

    # get data
    data_loaders = self.get_data(conf_loader)

    trainer = Trainer(conf_trainer, model, checkpoint=None)
    train_metrics = trainer.fit(data_loaders)

    logger.popd()

    return ModelMetrics(model, train_metrics)
def evaluate(self, conf_eval: Config, model_desc_builder: ModelDescBuilder) -> EvalResult:
    logger.pushd('eval_arch')

    # region conf vars
    conf_checkpoint = conf_eval['checkpoint']
    resume = conf_eval['resume']
    model_filename = conf_eval['model_filename']
    metric_filename = conf_eval['metric_filename']
    # endregion

    model = self.create_model(conf_eval, model_desc_builder)

    checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
    train_metrics = self.train_model(conf_eval, model, checkpoint)
    train_metrics.save(metric_filename)

    # save model
    if model_filename:
        model_filename = utils.full_path(model_filename)
        ml_utils.save_model(model, model_filename)
        logger.info({'model_save_path': model_filename})

    logger.popd()

    return EvalResult(train_metrics)
def search_model_desc(self, conf_search:Config, model_desc:ModelDesc,
                      trainer_class:TArchTrainer, finalizers:Finalizers)\
        ->Tuple[ModelDesc, Optional[Metrics]]:

    # if no trainer is specified (e.g. for random search) we return the same desc
    if trainer_class is None:
        return model_desc, None

    logger.pushd('arch_search')

    conf_trainer = conf_search['trainer']
    conf_loader = conf_search['loader']

    model = Model(model_desc, droppath=False, affine=False)

    # get data
    data_loaders = self.get_data(conf_loader)

    # search arch
    arch_trainer = trainer_class(conf_trainer, model, checkpoint=None)
    search_metrics = arch_trainer.fit(data_loaders)

    # finalize
    found_desc = self.finalize_model(model, finalizers)

    logger.popd()

    return found_desc, search_metrics
def search(self, conf_search: Config, model_desc_builder: ModelDescBuilder,
           trainer_class: TArchTrainer, finalizers: Finalizers) -> SearchResult:

    # region config vars
    conf_model_desc = conf_search['model_desc']
    conf_post_train = conf_search['post_train']
    conf_checkpoint = conf_search['checkpoint']
    resume = conf_search['resume']
    # endregion

    self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

    macro_combinations = list(self.get_combinations(conf_search))
    start_macro_i, best_search_result = self.restore_checkpoint(
        conf_search, macro_combinations)
    best_macro_comb = -1, -1, -1 # reductions, cells, nodes

    for macro_comb_i in range(start_macro_i, len(macro_combinations)):
        reductions, cells, nodes = macro_combinations[macro_comb_i]
        logger.pushd(f'r{reductions}.c{cells}.n{nodes}')

        # build model description that we will search on
        model_desc = self.build_model_desc(model_desc_builder, conf_model_desc,
                                           reductions, cells, nodes)

        # perform search on model description
        model_desc, search_metrics = self.search_model_desc(
            conf_search, model_desc, trainer_class, finalizers)

        # train searched model for a few epochs to get some perf metrics
        model_metrics = self.train_model_desc(model_desc, conf_post_train)

        assert model_metrics is not None, \
            "'post_train' section in yaml should have non-zero epochs if running combinations search"

        # save result
        self.save_trained(conf_search, reductions, cells, nodes, model_metrics)

        # update the best result so far
        if self.is_better_metrics(best_search_result.search_metrics,
                                  model_metrics.metrics):
            best_search_result = SearchResult(model_desc, search_metrics,
                                              model_metrics.metrics)
            best_macro_comb = reductions, cells, nodes

        # checkpoint
        assert best_search_result is not None
        self.record_checkpoint(macro_comb_i, best_search_result)
        logger.popd() # reductions, cells, nodes

    assert best_search_result is not None
    self.clean_log_result(conf_search, best_search_result)
    logger.info({'best_macro_comb': best_macro_comb})

    return best_search_result
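# get_combinations() above yields (reductions, cells, nodes) macro parameters
# to sweep over. A minimal sketch of how such a generator could be built from
# candidate value lists (hypothetical helper; the real implementation reads
# the ranges from conf_search and may differ):
import itertools
from typing import Iterable, Iterator, Tuple

def macro_combinations(reductions_list: Iterable[int],
                       cells_list: Iterable[int],
                       nodes_list: Iterable[int]) -> Iterator[Tuple[int, int, int]]:
    """Yield every (reductions, cells, nodes) macro combination to search over."""
    for reductions, cells, nodes in itertools.product(reductions_list,
                                                      cells_list, nodes_list):
        yield reductions, cells, nodes

# example: 2 x 3 x 2 = 12 combinations
# list(macro_combinations([1, 2], [5, 8, 11], [3, 4]))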
def finalize_model(self, model: Model, to_cpu=True, restore_device=True) -> ModelDesc:
    logger.pushd('finalize')

    # get config and train data loader
    # TODO: confirm this is correct in case you get silent bugs
    conf = get_conf()
    conf_loader = conf['nas']['search']['loader']
    train_dl, val_dl, test_dl = get_data(conf_loader)

    # wrap all cells in the model
    self._divnas_cells: Dict[int, Divnas_Cell] = {}
    for _, cell in enumerate(model.cells):
        divnas_cell = Divnas_Cell(cell)
        self._divnas_cells[id(cell)] = divnas_cell

    # go through all edges in the DAG and if they are of divop
    # type then set them to collect activations
    sigma = conf['nas']['search']['divnas']['sigma']
    for _, dcell in enumerate(self._divnas_cells.values()):
        dcell.collect_activations(DivOp, sigma)

    # now we need to run one evaluation epoch to collect activations
    # we do it on cpu otherwise we might run into memory issues
    # later we can redo the whole logic in pytorch itself
    # at the end of this each node in a cell will have the covariance
    # matrix of all incoming edges' ops
    model = model.cpu()
    model.eval()
    with torch.no_grad():
        for _ in range(1):
            for _, (x, _) in enumerate(train_dl):
                _, _ = model(x), None
                # now you can go through and update the
                # node covariances in every cell
                for dcell in self._divnas_cells.values():
                    dcell.update_covs()

    logger.popd()

    return super().finalize_model(model, to_cpu, restore_device)
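# Divnas_Cell.collect_activations / update_covs (archai internals) gather op
# activations and maintain per-node covariance matrices. A rough standalone
# sketch of the underlying pattern -- register a forward hook on a module,
# stash its outputs, then accumulate a covariance estimate -- might look like
# this (illustrative names only, not the actual archai implementation):
import torch
import torch.nn as nn

class ActivationCovTracker:
    """Accumulate a covariance matrix over flattened activations of a module."""
    def __init__(self, module: nn.Module):
        self._acts = []
        module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        # keep activations on CPU, flattened per sample
        self._acts.append(output.detach().cpu().flatten(start_dim=1))

    def covariance(self) -> torch.Tensor:
        acts = torch.cat(self._acts, dim=0)           # (num_samples, features)
        acts = acts - acts.mean(dim=0, keepdim=True)  # center
        return acts.t() @ acts / max(1, acts.shape[0] - 1)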
def evaluate(self, conf_eval: Config, model_desc_builder: ModelDescBuilder) -> EvalResult:
    """Takes a folder of model descriptions output by the search process and
    trains them in a distributed manner using ray with 1 gpu"""

    logger.pushd('evaluate')

    final_desc_foldername: str = conf_eval['final_desc_foldername']

    # get list of model descs in the gallery folder
    final_desc_folderpath = utils.full_path(final_desc_foldername)
    files = [os.path.join(final_desc_folderpath, f)
             for f in glob.glob(os.path.join(final_desc_folderpath, 'model_desc_*.yaml'))
             if os.path.isfile(os.path.join(final_desc_folderpath, f))]
    logger.info({'model_desc_files': len(files)})

    # to avoid having every worker download the dataset individually, do it beforehand
    self._ensure_dataset_download(conf_eval)

    future_ids = []
    for model_desc_filename in files:
        future_id = EvaluaterPetridish._train_dist.remote(self, conf_eval,
                                                          model_desc_builder,
                                                          model_desc_filename,
                                                          common.get_state())
        future_ids.append(future_id)

    # wait for all eval jobs to finish
    ready_refs, remaining_refs = ray.wait(future_ids, num_returns=len(future_ids))

    # plot pareto curve of gallery of models
    hull_points = [ray.get(ready_ref) for ready_ref in ready_refs]
    save_hull(hull_points, common.get_expdir())
    plot_pool(hull_points, common.get_expdir())

    best_point = max(hull_points, key=lambda p: p.metrics.best_val_top1())
    logger.info({'best_val_top1': best_point.metrics.best_val_top1(),
                 'best_MAdd': best_point.model_stats.MAdd})

    logger.popd()

    return EvalResult(best_point.metrics)
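# _train_dist above is a ray remote function: each .remote(...) call returns an
# object ref immediately and the work runs on a worker (per the docstring, with
# 1 gpu). A minimal self-contained sketch of the same fan-out / fan-in pattern,
# with a placeholder task instead of the real training job:
import ray

@ray.remote(num_gpus=1)
def train_one(model_desc_filename: str) -> float:
    # placeholder for "load desc, train, return metric"
    return 0.0

def train_all(files):
    ray.init(ignore_reinit_error=True)
    refs = [train_one.remote(f) for f in files]
    # block until every job is done, then fetch results
    ready, _ = ray.wait(refs, num_returns=len(refs))
    return [ray.get(r) for r in ready]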
def __init__(self, conf_train: Config, model: nn.Module,
             checkpoint: Optional[CheckPoint] = None) -> None:
    # region config vars
    self.conf_train = conf_train
    conf_lossfn = conf_train['lossfn']
    self._aux_weight = conf_train['aux_weight']
    self._grad_clip = conf_train['grad_clip']
    self._drop_path_prob = conf_train['drop_path_prob']
    self._logger_freq = conf_train['logger_freq']
    self._title = conf_train['title']
    self._epochs = conf_train['epochs']
    self.conf_optim = conf_train['optimizer']
    self.conf_sched = conf_train['lr_schedule']
    self.batch_chunks = conf_train['batch_chunks']
    conf_validation = conf_train['validation']
    conf_apex = conf_train['apex']
    self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
    # endregion

    logger.pushd(self._title + '__init__')

    self._apex = ApexUtils(conf_apex, logger)

    self._checkpoint = checkpoint
    self.model = model

    self._lossfn = ml_utils.get_lossfn(conf_lossfn)
    # using a separate apex for Tester is not possible because we must use
    # the same distributed model as Trainer and hence they must share apex
    self._tester = Tester(conf_validation, model, self._apex) \
                   if conf_validation else None
    self._metrics: Optional[Metrics] = None

    self._droppath_module = self._get_droppath_module()
    if self._droppath_module is None and self._drop_path_prob > 0.0:
        logger.warn({'droppath_module': None})

    self._start_epoch = -1 # nothing is started yet

    logger.popd()
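# _drop_path_prob above controls DARTS-style drop path regularization: during
# training, a branch's output is zeroed with some probability per sample and
# rescaled otherwise. A common standalone formulation of the technique (not
# necessarily the exact module archai wires in via _get_droppath_module):
import torch
import torch.nn as nn

class DropPath(nn.Module):
    def __init__(self, p: float = 0.0):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p <= 0.0:
            return x
        keep = 1.0 - self.p
        # one Bernoulli draw per sample, broadcast over the remaining dims
        mask = torch.empty(x.shape[0], *([1] * (x.dim() - 1)),
                           device=x.device, dtype=x.dtype).bernoulli_(keep)
        return x * mask / keep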
def _search_iters(self, model_desc:ModelDesc, best_result:Optional[SearchResult],
                  reductions:int, cells:int, nodes:int)->\
        Tuple[ModelDesc, Optional[SearchResult]]:

    for search_iter in range(self.search_iters):
        logger.pushd(f'{search_iter}')

        # execute search iteration followed by training the model
        model_desc = self._search_desc(model_desc, search_iter)
        metrics_stats = self._train_desc(model_desc, self.conf_postsearch_train)
        model_desc = metrics_stats.model_desc

        # save result
        self._save_trained(reductions, cells, nodes, search_iter, metrics_stats)

        if metrics_stats.is_better(best_result.metrics_stats
                                   if best_result is not None else None):
            best_result = SearchResult(metrics_stats, (reductions, cells, nodes))

        logger.popd() # search_iter

    return model_desc, best_result
def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics:
    logger.pushd(self._title)

    self._metrics = Metrics(self._title, self._apex, logger_freq=self._logger_freq)

    # create optimizers and schedulers
    self._multi_optim = self.create_multi_optim(len(train_dl))
    # before checkpoint restore, convert to amp
    self.model = self._apex.to_amp(self.model, self._multi_optim,
                                   batch_size=train_dl.batch_size)

    self._lossfn = self._lossfn.to(self.get_device())

    self.pre_fit(train_dl, val_dl)

    # we need to restore checkpoint after all objects are created because
    # restoring checkpoint requires load_state_dict calls on these objects
    self._start_epoch = 0
    # do we have a checkpoint
    checkpoint_avail = self._checkpoint is not None
    checkpoint_val = checkpoint_avail and 'trainer' in self._checkpoint
    resumed = False
    if checkpoint_val:
        # restore checkpoint
        resumed = True
        self.restore_checkpoint()
    elif checkpoint_avail: # TODO: bad checkpoint?
        self._checkpoint.clear()
    logger.warn({'resumed': resumed, 'checkpoint_avail': checkpoint_avail,
                 'checkpoint_val': checkpoint_val,
                 'start_epoch': self._start_epoch,
                 'total_epochs': self._epochs})
    logger.info({'aux_weight': self._aux_weight,
                 'grad_clip': self._grad_clip,
                 'drop_path_prob': self._drop_path_prob,
                 'validation_freq': self._validation_freq,
                 'batch_chunks': self.batch_chunks})

    if self._start_epoch >= self._epochs:
        logger.warn(f'fit done because start_epoch {self._start_epoch}>={self._epochs}')
        return self.get_metrics() # we already finished the run, we might be checkpointed

    logger.pushd('epochs')
    for epoch in range(self._start_epoch, self._epochs):
        logger.pushd(epoch)
        self._set_epoch(epoch, train_dl, val_dl)
        self.pre_epoch(train_dl, val_dl)
        self._train_epoch(train_dl)
        self.post_epoch(train_dl, val_dl)
        logger.popd()
    logger.popd()
    self.post_fit(train_dl, val_dl)

    # make sure we don't keep references to the graph
    del self._multi_optim

    logger.popd()
    return self.get_metrics()
def search(self, conf_search: Config, model_desc_builder: ModelDescBuilder,
           trainer_class: TArchTrainer, finalizers: Finalizers) -> SearchResult:
    logger.pushd('search')

    # region config vars
    self.conf_search = conf_search
    conf_checkpoint = conf_search['checkpoint']
    resume = conf_search['resume']
    conf_post_train = conf_search['post_train']
    final_desc_foldername = conf_search['final_desc_foldername']

    conf_petridish = conf_search['petridish']

    # petridish distributed search related parameters
    self._convex_hull_eps = conf_petridish['convex_hull_eps']
    self._sampling_max_try = conf_petridish['sampling_max_try']
    self._max_madd = conf_petridish['max_madd']
    self._max_hull_points = conf_petridish['max_hull_points']
    self._checkpoints_foldername = conf_petridish['checkpoints_foldername']
    # endregion

    self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

    # parent models list
    self._hull_points: List[ConvexHullPoint] = []

    self._ensure_dataset_download(conf_search)

    # checkpoint will restore the hull we had
    is_restored = self._restore_checkpoint()

    # seed the pool with several seed models of different macro parameters
    # (number of cells, reductions, etc.) if the parent pool could not be
    # restored and/or this is the first time this job has been run
    future_ids = [] if is_restored else self._create_seed_jobs(conf_search,
                                                               model_desc_builder)

    while not self._is_search_done():
        logger.info(f'Ray jobs running: {len(future_ids)}')

        if future_ids:
            # get first completed job
            job_id_done, future_ids = ray.wait(future_ids)

            hull_point = ray.get(job_id_done[0])

            logger.info(f'Hull point id {hull_point.id} with stage '
                        f'{hull_point.job_stage.name} completed')

            if hull_point.is_trained_stage():
                self._update_convex_hull(hull_point)

                # sample a point and search
                sampled_point = sample_from_hull(self._hull_points,
                                                 self._convex_hull_eps,
                                                 self._sampling_max_try)

                future_id = SearcherPetridish.search_model_desc_dist.remote(self,
                    conf_search, sampled_point, model_desc_builder, trainer_class,
                    finalizers, common.get_state())
                future_ids.append(future_id)
                logger.info(f'Added sampled point {sampled_point.id} for search')
            elif hull_point.job_stage == JobStage.SEARCH:
                # create the job to train the searched model
                future_id = SearcherPetridish.train_model_desc_dist.remote(self,
                    conf_post_train, hull_point, common.get_state())
                future_ids.append(future_id)
                logger.info(f'Added sampled point {hull_point.id} for post-search training')
            else:
                raise RuntimeError(f'Job stage "{hull_point.job_stage}" is not expected in search loop')

    # cancel any remaining jobs to free up gpus for the eval phase
    for future_id in future_ids:
        ray.cancel(future_id, force=True) # without force, main process stops
        ray.wait([future_id])

    # plot and save the hull
    expdir = common.get_expdir()
    assert expdir
    plot_frontier(self._hull_points, self._convex_hull_eps, expdir)
    best_point = save_hull_frontier(self._hull_points, self._convex_hull_eps,
                                    final_desc_foldername, expdir)
    save_hull(self._hull_points, expdir)
    plot_pool(self._hull_points, expdir)

    # return best point as search result
    search_result = SearchResult(best_point.model_desc, search_metrics=None,
                                 train_metrics=best_point.metrics)
    self.clean_log_result(conf_search, search_result)

    logger.popd()

    return search_result
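# sample_from_hull / save_hull_frontier (archai internals) maintain a frontier
# of models trading off compute (e.g. multiply-adds) against accuracy. As a
# tiny illustration of the simpler Pareto-frontier idea -- keep only points no
# other point dominates on both axes -- consider the sketch below (hypothetical
# Point type; the real code uses an eps-relaxed convex hull, not this filter):
from typing import List, NamedTuple

class Point(NamedTuple):
    madds: float     # lower is better
    top1: float      # higher is better

def pareto_frontier(points: List[Point]) -> List[Point]:
    frontier: List[Point] = []
    # cheapest first; break madds ties by keeping the most accurate point first
    for p in sorted(points, key=lambda q: (q.madds, -q.top1)):
        # a point stays only if it beats every cheaper point kept so far on accuracy
        if not frontier or p.top1 > frontier[-1].top1:
            frontier.append(p)
    return frontier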