def search_model_desc_dist(searcher:'SearcherPetridish', conf_search:Config,
                           hull_point:ConvexHullPoint,
                           model_desc_builder:ModelDescBuilder,
                           trainer_class:TArchTrainer, finalizers:Finalizers,
                           common_state:CommonState)->ConvexHullPoint:
    # as this runs in a different process, initialize globals
    common.init_from(common_state)

    # register ops as we are in a different process now
    conf_model_desc = conf_search['model_desc']
    model_desc_builder.pre_build(conf_model_desc)

    assert hull_point.is_trained_stage()

    # cloning is strictly not needed, but if we ever run this function in the
    # same process it avoids mutating the caller's model desc
    model_desc = hull_point.model_desc.clone()
    searcher._add_node(model_desc, model_desc_builder)

    model_desc, search_metrics = searcher.search_model_desc(conf_search,
        model_desc, trainer_class, finalizers)

    cells, reductions, nodes = hull_point.cells_reductions_nodes
    new_point = ConvexHullPoint(JobStage.SEARCH, hull_point.id,
                                hull_point.sampling_count, model_desc,
                                (cells, reductions, nodes+1), # we added a node
                                metrics=search_metrics)
    return new_point
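# Illustrative sketch (not part of the original module): the function above is
# executed on a worker process via ray, so all "global" state must be shipped
# across the process boundary and re-initialized, which is what
# common.get_state() / common.init_from() accomplish. The toy below shows the
# same pattern using only the public ray API; `worker` and the state dict are
# hypothetical names, not Archai's.
def _demo_ray_state_dispatch()->str:
    import ray
    ray.init(ignore_reinit_error=True)

    @ray.remote
    def worker(state:dict)->str:
        # runs in a separate process: re-create whatever globals are needed
        # from the explicitly shipped state, mirroring common.init_from()
        return f"worker initialized with seed={state['seed']}"

    future = worker.remote({'seed': 42}) # schedule the job on the cluster
    return ray.get(future)               # block until the result is ready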
def train_model_desc_dist(searcher:'SearcherPetridish', conf_train:Config,
                          hull_point:ConvexHullPoint,
                          common_state:CommonState)->ConvexHullPoint:
    # as this runs in a different process, initialize globals
    common.init_from(common_state)

    assert not hull_point.is_trained_stage()
    model_metrics = searcher.train_model_desc(hull_point.model_desc, conf_train)
    model_stats = nas_utils.get_model_stats(model_metrics.model)

    new_point = ConvexHullPoint(hull_point.next_stage(), hull_point.id,
                                hull_point.sampling_count,
                                hull_point.model_desc,
                                hull_point.cells_reductions_nodes,
                                model_metrics.metrics, model_stats)
    return new_point
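# Illustrative sketch (not part of the original module): search and train jobs
# return ray futures, and the searcher's main loop is expected to consume them
# as they finish rather than in submission order. A hedged sketch of that
# pattern with ray.wait(); `future_ids` and `handle_result` are hypothetical
# names for the pending job list and the result callback.
def _demo_consume_futures(future_ids:list, handle_result)->None:
    import ray
    while future_ids:
        # wait for any one job to complete, keeping the rest pending
        ready_ids, future_ids = ray.wait(future_ids, num_returns=1)
        for ready_id in ready_ids:
            handle_result(ray.get(ready_id)) # e.g., update the convex hull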
def _update_convex_hull(self, new_point:ConvexHullPoint)->None:
    # only add models for which we have metrics and stats
    assert new_point.is_trained_stage()
    self._hull_points.append(new_point)

    if self._checkpoint is not None:
        self._checkpoint.new()
        self._checkpoint['convex_hull_points'] = self._hull_points
        self._checkpoint.commit()

    logger.info(f'Added to convex hull points: MAdd {new_point.model_stats.MAdd}, '
                f'num cells {len(new_point.model_desc.cell_descs())}, '
                f'num nodes in cell {len(new_point.model_desc.cell_descs()[0].nodes())}')
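# Illustrative sketch (not part of the original module): _hull_points backs a
# cost-vs-error convex hull from which Petridish samples parent models. Below
# is a minimal, self-contained lower-hull computation over (cost, error)
# pairs, assuming lower cost (e.g. MAdd) and lower error are both better;
# `lower_hull` is a hypothetical helper, not the library's implementation.
def lower_hull(points:list)->list:
    """Return the points on the lower convex hull of (cost, error) pairs."""
    def cross(o, a, b):
        # z-component of (a-o) x (b-o); positive means a counterclockwise turn
        return (a[0]-o[0])*(b[1]-o[1]) - (a[1]-o[1])*(b[0]-o[0])

    hull = []
    for p in sorted(points):
        # pop points that lie above the segment formed with the new point
        while len(hull) >= 2 and cross(hull[-2], hull[-1], p) <= 0:
            hull.pop()
        hull.append(p)
    return hull

# e.g. lower_hull([(1e8, 0.10), (2e8, 0.07), (3e8, 0.069), (4e8, 0.05)])
# returns [(1e8, 0.10), (2e8, 0.07), (4e8, 0.05)]: the third point is
# Pareto-optimal but sits above the hull segment, so it is excluded.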
def _create_seed_jobs(self, conf_search:Config,
                      model_desc_builder:ModelDescBuilder)->list:
    conf_model_desc = conf_search['model_desc']
    conf_seed_train = conf_search['seed_train']

    future_ids = []       # ray job IDs
    seed_model_stats = [] # seed model stats for visualization and debugging
    macro_combinations = list(self.get_combinations(conf_search))
    for reductions, cells, nodes in macro_combinations:
        # if N R N R N R cannot be satisfied, ignore the combination
        if cells < reductions * 2 + 1:
            continue

        # create seed model
        model_desc = self.build_model_desc(model_desc_builder, conf_model_desc,
                                           reductions, cells, nodes)

        hull_point = ConvexHullPoint(JobStage.SEED, 0, 0, model_desc,
                                     (cells, reductions, nodes))

        # pre-train the seed model
        future_id = SearcherPetridish.train_model_desc_dist.remote(self,
            conf_seed_train, hull_point, common.get_state())
        future_ids.append(future_id)

        # build a model so we can get its model stats
        temp_model = Model(model_desc, droppath=True, affine=True)
        seed_model_stats.append(nas_utils.get_model_stats(temp_model))

    # save the model stats in a plot and tsv file so we can
    # visualize the spread on the x-axis
    expdir = common.get_expdir()
    assert expdir
    plot_seed_model_stats(seed_model_stats, expdir)

    return future_ids
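# Illustrative sketch (not part of the original module): plot_seed_model_stats
# is imported from elsewhere in the codebase. A plausible minimal version is
# sketched below under the assumption that each stats object exposes an MAdd
# attribute (as used in _update_convex_hull above); the matplotlib calls are
# standard, but the function body is a guess, not the library's own code.
def _demo_plot_seed_model_stats(seed_model_stats:list, expdir:str)->None:
    import os
    import matplotlib.pyplot as plt

    madds = [stats.MAdd for stats in seed_model_stats] # assumed attribute
    plt.scatter(madds, [0]*len(madds))
    plt.xlabel('MAdd')
    plt.title('Spread of seed models along the compute axis')
    plt.savefig(os.path.join(expdir, 'seed_model_stats.png'))
    plt.close()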
def _train_dist(evaluater:Evaluater, conf_eval:Config,
                model_desc_builder:ModelDescBuilder,
                model_desc_filename:str, common_state)->ConvexHullPoint:
    """Train a model from the given model desc"""

    common.init_from(common_state)

    # region config vars
    conf_model_desc = conf_eval['model_desc']
    max_cells = conf_model_desc['n_cells']
    conf_checkpoint = conf_eval['checkpoint']
    resume = conf_eval['resume']
    conf_petridish = conf_eval['petridish']
    cell_count_scale = conf_petridish['cell_count_scale']
    # endregion

    # register ops as we are in a different process now
    model_desc_builder.pre_build(conf_model_desc)

    model_filename = utils.append_to_filename(model_desc_filename, '_model', '.pt')
    full_desc_filename = utils.append_to_filename(model_desc_filename, '_full', '.yaml')
    metrics_filename = utils.append_to_filename(model_desc_filename, '_metrics', '.yaml')
    model_stats_filename = utils.append_to_filename(model_desc_filename,
                                                    '_model_stats', '.yaml')

    # create checkpoint for this specific model desc by changing the config
    checkpoint = None
    if conf_checkpoint is not None:
        conf_checkpoint['filename'] = model_filename.split('.')[0] + '_checkpoint.pth'
        checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

        if checkpoint is not None and resume:
            if 'metrics_stats' in checkpoint:
                # return the output we had recorded in the checkpoint
                convex_hull_point = checkpoint['metrics_stats']
                return convex_hull_point

    # template model is what we used during the search
    template_model_desc = ModelDesc.load(model_desc_filename)

    # we first scale this model by number of cells, keeping the number of
    # reductions the same as during the search
    n_cells = math.ceil(len(template_model_desc.cell_descs()) * cell_count_scale)
    n_cells = min(n_cells, max_cells)

    conf_model_desc = copy.deepcopy(conf_model_desc)
    conf_model_desc['n_cells'] = n_cells
    conf_model_desc['n_reductions'] = n_reductions = \
        template_model_desc.cell_type_count(CellType.Reduction)

    model_desc = model_desc_builder.build(conf_model_desc,
                                          template=template_model_desc)
    # save desc for reference
    model_desc.save(full_desc_filename)

    model = evaluater.model_from_desc(model_desc)

    train_metrics = evaluater.train_model(conf_eval, model, checkpoint)
    train_metrics.save(metrics_filename)

    # get and save model stats
    model_stats = nas_utils.get_model_stats(model)
    with open(model_stats_filename, 'w') as f:
        yaml.dump(model_stats, f)

    # save model
    if model_filename:
        model_filename = utils.full_path(model_filename)
        ml_utils.save_model(model, model_filename)
        # TODO: causes logging error at random times; commented out as a stop-gap fix
        # logger.info({'model_save_path': model_filename})

    hull_point = ConvexHullPoint(JobStage.EVAL_TRAINED, 0, 0, model_desc,
                                 (n_cells, n_reductions,
                                  len(model_desc.cell_descs()[0].nodes())),
                                 metrics=train_metrics,
                                 model_stats=model_stats)

    if checkpoint:
        checkpoint.new()
        checkpoint['metrics_stats'] = hull_point
        checkpoint.commit()

    return hull_point
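# Illustrative sketch (not part of the original module): the resume logic in
# _train_dist stores the finished ConvexHullPoint in the checkpoint under the
# 'metrics_stats' key, so a restarted job returns the cached result instead of
# retraining. The same compute-once pattern is shown below with a dict-backed
# stand-in; `checkpoint` and `expensive_train` are hypothetical names.
def _demo_checkpoint_resume(checkpoint:dict, expensive_train)->object:
    if 'metrics_stats' in checkpoint:
        return checkpoint['metrics_stats'] # resume: skip retraining entirely
    result = expensive_train()             # otherwise pay the training cost once
    checkpoint['metrics_stats'] = result   # record the output for future resumes
    return result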