def finalize_node(self, node:nn.ModuleList, node_index:int,
                  node_desc:NodeDesc, max_final_edges:int,
                  cov:np.ndarray, cell: Cell, node_id: int,
                  *args, **kwargs)->NodeDesc:
    # node is a list of edges
    assert len(node) >= max_final_edges

    # covariance matrix must be square 2-D
    assert len(cov.shape) == 2
    assert cov.shape[0] == cov.shape[1]

    # the number of primitive operators has to be greater than
    # or equal to the maximum number of final edges allowed
    assert cov.shape[0] >= max_final_edges

    # get the order and alpha of all ops other than 'none'
    in_ops = [(edge, op, alpha, i) for i, edge in enumerate(node) \
                for op, alpha in edge._op.ops() if not isinstance(op, Zero)]
    assert len(in_ops) >= max_final_edges

    # order all the ops by alpha
    in_ops_sorted = sorted(in_ops, key=lambda in_op: in_op[2], reverse=True)

    # keep only the top half of the ops under consideration
    # (but never fewer than max_final_edges)
    num_to_keep = max(max_final_edges, len(in_ops_sorted) // 2)
    top_ops = in_ops_sorted[:num_to_keep]

    # get the covariance submatrix of the top ops only
    cov_inds = []
    for edge, op, alpha, edge_num in top_ops:
        ind = self._divnas_cells[cell].node_num_to_node_op_to_cov_ind[node_id][op]
        cov_inds.append(ind)
    cov_top_ops = cov[np.ix_(cov_inds, cov_inds)]

    assert len(cov_inds) == len(top_ops)
    assert len(top_ops) >= max_final_edges
    assert cov_top_ops.shape[0] == cov_top_ops.shape[1]
    assert len(cov_top_ops.shape) == 2

    # run the brute-force set selection algorithm on the top ops only
    max_subset, max_mi = compute_brute_force_sol(cov_top_ops, max_final_edges)

    # note that elements of max_subset are indices into top_ops only
    selected_edges = []
    for ind in max_subset:
        edge, op, alpha, edge_num = top_ops[ind]
        op_desc, _ = op.finalize()
        new_edge = EdgeDesc(op_desc, edge.input_ids)
        logger.info(f'selected edge: {edge_num}, op: {op_desc.name}')
        selected_edges.append(new_edge)

    # save diagnostic information to disk
    expdir = get_expdir()
    sns.heatmap(cov_top_ops, annot=True, fmt='.1g', cmap='coolwarm')
    savename = os.path.join(
        expdir, f'cell_{cell.desc.id}_node_{node_id}_cov.png')
    plt.savefig(savename)
    plt.clf()  # clear the figure so later heatmaps do not overlay this one

    logger.info('')
    return NodeDesc(selected_edges, node_desc.conv_params)
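# For context, compute_brute_force_sol above exhaustively searches subsets of the
# covariance submatrix. The helper below is a minimal, hypothetical sketch of such a
# brute-force selection, assuming a log-determinant objective (a standard diversity /
# mutual-information surrogate for a Gaussian with the given covariance); the actual
# objective used by compute_brute_force_sol in the codebase may differ.
import itertools
import numpy as np

def brute_force_max_det_subset(cov: np.ndarray, k: int):
    """Return the size-k index subset whose covariance submatrix has the
    largest log-determinant, along with that score."""
    best_subset, best_score = None, -np.inf
    for subset in itertools.combinations(range(cov.shape[0]), k):
        sub = cov[np.ix_(subset, subset)]
        sign, logdet = np.linalg.slogdet(sub)  # numerically safer than log(det(.))
        score = logdet if sign > 0 else -np.inf
        if score > best_score:
            best_subset, best_score = subset, score
    return list(best_subset), best_score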
def search(self, conf_search: Config, model_desc_builder: ModelDescBuilder,
           trainer_class: TArchTrainer, finalizers: Finalizers) -> SearchResult:
    logger.pushd('search')

    # region config vars
    self.conf_search = conf_search
    conf_checkpoint = conf_search['checkpoint']
    resume = conf_search['resume']
    conf_post_train = conf_search['post_train']
    final_desc_foldername = conf_search['final_desc_foldername']
    conf_petridish = conf_search['petridish']

    # petridish distributed search related parameters
    self._convex_hull_eps = conf_petridish['convex_hull_eps']
    self._sampling_max_try = conf_petridish['sampling_max_try']
    self._max_madd = conf_petridish['max_madd']
    self._max_hull_points = conf_petridish['max_hull_points']
    self._checkpoints_foldername = conf_petridish['checkpoints_foldername']
    # endregion

    self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

    # parent models list
    self._hull_points: List[ConvexHullPoint] = []

    self._ensure_dataset_download(conf_search)

    # the checkpoint, if present, restores the hull we had
    is_restored = self._restore_checkpoint()

    # if the parent pool could not be restored (i.e. this is the first run of
    # this job), seed it with models of differing macro parameters such as
    # number of cells and reductions
    future_ids = [] if is_restored else self._create_seed_jobs(
        conf_search, model_desc_builder)

    while not self._is_search_done():
        logger.info(f'Ray jobs running: {len(future_ids)}')

        if future_ids:
            # wait for the first completed job
            job_id_done, future_ids = ray.wait(future_ids)

            hull_point = ray.get(job_id_done[0])

            logger.info(
                f'Hull point id {hull_point.id} with stage {hull_point.job_stage.name} completed')

            if hull_point.is_trained_stage():
                self._update_convex_hull(hull_point)

                # sample a point and search
                sampled_point = sample_from_hull(self._hull_points,
                                                 self._convex_hull_eps,
                                                 self._sampling_max_try)

                future_id = SearcherPetridish.search_model_desc_dist.remote(
                    self, conf_search, sampled_point, model_desc_builder,
                    trainer_class, finalizers, common.get_state())
                future_ids.append(future_id)
                logger.info(f'Added sampled point {sampled_point.id} for search')
            elif hull_point.job_stage == JobStage.SEARCH:
                # create the job to train the searched model
                future_id = SearcherPetridish.train_model_desc_dist.remote(
                    self, conf_post_train, hull_point, common.get_state())
                future_ids.append(future_id)
                logger.info(
                    f'Added sampled point {hull_point.id} for post-search training')
            else:
                raise RuntimeError(
                    f'Job stage "{hull_point.job_stage}" is not expected in search loop')

    # cancel any remaining jobs to free up GPUs for the eval phase
    for future_id in future_ids:
        ray.cancel(future_id, force=True)  # without force, the main process stops
        ray.wait([future_id])

    # plot and save the hull
    expdir = common.get_expdir()
    assert expdir
    plot_frontier(self._hull_points, self._convex_hull_eps, expdir)
    best_point = save_hull_frontier(self._hull_points, self._convex_hull_eps,
                                    final_desc_foldername, expdir)
    save_hull(self._hull_points, expdir)
    plot_pool(self._hull_points, expdir)

    # return the best point as the search result
    search_result = SearchResult(best_point.model_desc, search_metrics=None,
                                 train_metrics=best_point.metrics)
    self.clean_log_result(conf_search, search_result)

    logger.popd()

    return search_result
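# For context, sample_from_hull above draws the next parent from models near the
# accuracy/cost convex hull. The function below is a hypothetical, simplified sketch
# of that idea, assuming each point is a (cost, error) pair and that "near" means
# within a (1 + eps) band of the best error achievable at equal or lower cost; the
# real sample_from_hull implementation may weight or filter candidates differently.
import random
from typing import List, Tuple

def sample_near_frontier(points: List[Tuple[float, float]], eps: float) -> int:
    """Return the index of a randomly chosen point that lies close to the
    lower (cost, error) frontier of the given population."""
    order = sorted(range(len(points)), key=lambda i: points[i][0])  # ascending cost
    eligible, best_err = [], float('inf')
    for i in order:
        _, err = points[i]
        best_err = min(best_err, err)
        if err <= best_err * (1.0 + eps):
            eligible.append(i)
    return random.choice(eligible)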
def fit(self, data_loaders: data.DataLoaders) -> Metrics:
    logger.pushd(self._title)

    assert data_loaders.train_dl is not None

    self._metrics = Metrics(self._title, self._apex, logger_freq=self._logger_freq)

    # create optimizers and schedulers
    self._multi_optim = self.create_multi_optim(len(data_loaders.train_dl))
    # before checkpoint restore, convert to amp
    self.model = self._apex.to_amp(self.model, self._multi_optim,
                                   batch_size=data_loaders.train_dl.batch_size)

    self._lossfn = self._lossfn.to(self.get_device())

    self.pre_fit(data_loaders)

    # restore the checkpoint only after all objects are created because
    # restoring it requires load_state_dict calls on those objects
    self._start_epoch = 0
    # do we have a checkpoint?
    checkpoint_avail = self._checkpoint is not None
    checkpoint_val = checkpoint_avail and 'trainer' in self._checkpoint
    resumed = False
    if checkpoint_val:
        # restore checkpoint
        resumed = True
        self.restore_checkpoint()
    elif checkpoint_avail:  # TODO: bad checkpoint?
        self._checkpoint.clear()

    logger.warn({'resumed': resumed, 'checkpoint_avail': checkpoint_avail,
                 'checkpoint_val': checkpoint_val,
                 'start_epoch': self._start_epoch,
                 'total_epochs': self._epochs})
    logger.info({'aux_weight': self._aux_weight, 'grad_clip': self._grad_clip,
                 'drop_path_prob': self._drop_path_prob,
                 'validation_freq': self._validation_freq,
                 'batch_chunks': self.batch_chunks})

    if self._start_epoch >= self._epochs:
        logger.warn(f'fit done because start_epoch {self._start_epoch}>={self._epochs}')
        # the run already finished; we are resuming from a completed checkpoint
        return self.get_metrics()

    logger.pushd('epochs')
    for epoch in range(self._start_epoch, self._epochs):
        logger.pushd(epoch)
        self._set_epoch(epoch, data_loaders)
        self.pre_epoch(data_loaders)
        self._train_epoch(data_loaders.train_dl)
        self.post_epoch(data_loaders)
        logger.popd()
    logger.popd()

    self.post_fit(data_loaders)

    # make sure we don't keep references to the graph
    del self._multi_optim

    logger.popd()

    return self.get_metrics()
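# The fit loop above deliberately restores the checkpoint only after the model,
# optimizer and amp wrapper exist, because restoring means calling load_state_dict
# on each of them. The pair of helpers below is a minimal sketch of that
# save/restore ordering for plain PyTorch objects; it is not the Trainer's own
# checkpoint format, which stores additional state such as metrics.
import torch

def save_train_state(path, model, optimizer, epoch):
    # persist everything needed to resume: weights, optimizer state, last epoch
    torch.save({'model': model.state_dict(),
                'optim': optimizer.state_dict(),
                'last_epoch': epoch}, path)

def restore_train_state(path, model, optimizer):
    # the objects must already be constructed; load_state_dict mutates them in place
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optim'])
    return state['last_epoch'] + 1  # the epoch to resume from, i.e. start_epoch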
def _train_dist(evaluater: Evaluater, conf_eval: Config,
                model_desc_builder: ModelDescBuilder,
                model_desc_filename: str, common_state) -> ConvexHullPoint:
    """Train the model described by model_desc_filename (runs in a separate worker process)"""
    common.init_from(common_state)

    # region config vars
    conf_model_desc = conf_eval['model_desc']
    max_cells = conf_model_desc['n_cells']
    conf_checkpoint = conf_eval['checkpoint']
    resume = conf_eval['resume']
    conf_petridish = conf_eval['petridish']
    cell_count_scale = conf_petridish['cell_count_scale']
    # endregion

    # register ops again as we are now in a different process
    model_desc_builder.pre_build(conf_model_desc)

    model_filename = utils.append_to_filename(model_desc_filename, '_model', '.pt')
    full_desc_filename = utils.append_to_filename(model_desc_filename, '_full', '.yaml')
    metrics_filename = utils.append_to_filename(model_desc_filename, '_metrics', '.yaml')
    model_stats_filename = utils.append_to_filename(
        model_desc_filename, '_model_stats', '.yaml')

    # DEBUG
    print(f'received {model_desc_filename}')

    # create a checkpoint for this specific model desc by changing the config
    checkpoint = None
    if conf_checkpoint is not None:
        # use splitext so directory names containing dots are handled correctly
        conf_checkpoint['filename'] = os.path.splitext(model_filename)[0] + '_checkpoint.pth'
        checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

    if checkpoint is not None and resume:
        if 'metrics_stats' in checkpoint:
            # return the output we had recorded in the checkpoint
            convex_hull_point = checkpoint['metrics_stats']
            return convex_hull_point

    # the template model is what we used during the search
    template_model_desc = ModelDesc.load(model_desc_filename)

    # scale this model by number of cells, keeping reductions the same as in search
    n_cells = math.ceil(len(template_model_desc.cell_descs()) * cell_count_scale)
    n_cells = min(n_cells, max_cells)

    # DEBUG
    print(f'{model_desc_filename} has {len(template_model_desc.cell_descs())} cells, '
          f'scaling to {n_cells} cells via {cell_count_scale} factor')

    conf_model_desc = copy.deepcopy(conf_model_desc)
    conf_model_desc['n_cells'] = n_cells
    conf_model_desc['n_reductions'] = n_reductions = \
        template_model_desc.cell_type_count(CellType.Reduction)

    model_desc = model_desc_builder.build(conf_model_desc,
                                          template=template_model_desc)

    # save the desc for reference
    model_desc.save(full_desc_filename)

    model = evaluater.model_from_desc(model_desc)

    train_metrics = evaluater.train_model(conf_eval, model, checkpoint)
    train_metrics.save(metrics_filename)

    # get and save the model stats
    model_stats = nas_utils.get_model_stats(model)
    with open(model_stats_filename, 'w') as f:
        yaml.dump(model_stats, f)

    # save the model
    if model_filename:
        model_filename = utils.full_path(model_filename)
        ml_utils.save_model(model, model_filename)
        logger.info({'model_save_path': model_filename})

    hull_point = ConvexHullPoint(
        JobStage.EVAL_TRAINED, 0, 0, model_desc,
        (n_cells, n_reductions, len(model_desc.cell_descs()[0].nodes())),
        metrics=train_metrics, model_stats=model_stats)

    if checkpoint:
        checkpoint.new()
        checkpoint['metrics_stats'] = hull_point
        checkpoint.commit()

    return hull_point
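# For reference, the cell-count scaling above boils down to a one-line rule: grow the
# searched (template) model by cell_count_scale but cap it at the configured n_cells.
# The helper and examples below are a hypothetical standalone restatement of that
# arithmetic, not part of the codebase.
import math

def scaled_cell_count(n_template_cells: int, cell_count_scale: float, max_cells: int) -> int:
    return min(math.ceil(n_template_cells * cell_count_scale), max_cells)

# a 5-cell searched model with cell_count_scale=4.0 and max_cells=20 scales to 20 cells
assert scaled_cell_count(5, 4.0, 20) == 20
# with cell_count_scale=1.5 the same model scales to 8 cells
assert scaled_cell_count(5, 1.5, 20) == 8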