def search_model_desc_dist(searcher:'SearcherPetridish', conf_search:Config, hull_point:ConvexHullPoint, model_desc_builder:ModelDescBuilder, trainer_class:TArchTrainer, finalizers:Finalizers, common_state:CommonState)\
        ->ConvexHullPoint:
    """Grow a trained hull point's model by one node and run arch search on it.

    Intended to run in a worker process: global state is re-initialized from
    common_state and ops are re-registered before any work is done. Returns a
    new ConvexHullPoint at JobStage.SEARCH carrying the searched model desc
    and the search metrics.
    """
    # Worker process: restore globals first.
    common.init_from(common_state)

    # Ops must be registered again inside this process.
    conf_model_desc = conf_search['model_desc']
    model_desc_builder.pre_build(conf_model_desc)

    assert hull_point.is_trained_stage()

    # Clone defensively: not strictly needed across processes, but avoids
    # surprising mutation of the caller's desc if ever run in-process.
    seed_desc = hull_point.model_desc.clone()
    searcher._add_node(seed_desc, model_desc_builder)

    searched_desc, search_metrics = searcher.search_model_desc(
        conf_search, seed_desc, trainer_class, finalizers)

    cells, reductions, nodes = hull_point.cells_reductions_nodes
    # nodes + 1 because _add_node inserted exactly one node above.
    return ConvexHullPoint(JobStage.SEARCH, hull_point.id,
                           hull_point.sampling_count,
                           searched_desc,
                           (cells, reductions, nodes + 1),
                           metrics=search_metrics)
def test_darts_zero_model():
    """Smoke-test: a DARTS model built from the default search config runs a
    forward pass on a CIFAR-sized batch and yields (1, 10) logits, no aux."""
    conf = common_init(config_filepath='confs/algos/darts.yaml')
    desc_conf = conf['nas']['search']['model_desc']

    built_desc = ModelDescBuilder().build(desc_conf)
    net = Model(built_desc, False, True)

    logits, aux_logits = net(torch.rand((1, 3, 32, 32)))
    assert isinstance(logits, torch.Tensor) and logits.shape == (1, 10) and aux_logits is None
def build_model_desc(self, model_desc_builder: ModelDescBuilder, conf_model_desc: Config, reductions: int, cells: int, nodes: int) -> ModelDesc:
    """Build a model desc with the given macro parameters.

    Overrides n_reductions/n_cells on a deep copy of conf_model_desc so the
    caller's config is left untouched, then builds without a template (this
    desc is meant for pre-training). `nodes` is not referenced here.
    """
    # Private copy: never mutate the caller's config.
    macro_conf = copy.deepcopy(conf_model_desc)
    macro_conf['n_reductions'] = reductions
    macro_conf['n_cells'] = cells

    # No template — build purely from the (modified) config.
    return model_desc_builder.build(macro_conf, template=None)
def create_model(self, conf_eval: Config, model_desc_builder: ModelDescBuilder, final_desc_filename=None, full_desc_filename=None) -> nn.Module:
    """Build the evaluation model from a finalized search desc.

    Loads the final desc (acts as template), expands it per conf_eval's
    model_desc section, saves the expanded desc for reference and returns
    the instantiated model.
    """
    assert model_desc_builder is not None, 'Default evaluater requires model_desc_builder'

    # region conf vars
    # An explicitly passed final_desc_filename wins over conf; only when it
    # is absent are BOTH filenames pulled from conf.
    # NOTE(review): passing final_desc_filename alone leaves
    # full_desc_filename as None — presumably ModelDesc.save tolerates a
    # None filename; confirm against ModelDesc.
    if not final_desc_filename:
        final_desc_filename = conf_eval['final_desc_filename']
        full_desc_filename = conf_eval['full_desc_filename']
    conf_model_desc = conf_eval['model_desc']
    # endregion

    # The finalized search output is the template for the eval model.
    template = ModelDesc.load(final_desc_filename)
    model_desc = model_desc_builder.build(conf_model_desc, template=template)

    # Keep the fully expanded desc around for reference/debugging.
    model_desc.save(full_desc_filename)

    model = self.model_from_desc(model_desc)

    logger.info({'model_factory': False,
                 'cells_len': len(model.desc.cell_descs()),
                 'init_node_ch': conf_model_desc['model_stems']['init_node_ch'],
                 'n_cells': conf_model_desc['n_cells'],
                 'n_reductions': conf_model_desc['n_reductions'],
                 'n_nodes': conf_model_desc['cell']['n_nodes']})
    return model
def _add_node(self, model_desc: ModelDesc, model_desc_builder: ModelDescBuilder) -> None:
    """Insert one new petridish node into every cell of model_desc, in place.

    For each cell the new node is inserted just BEFORE the last node, wired
    to the two stems and all preceding nodes, and the cell's post-op is
    rebuilt since its input count changed.
    """
    for ci, cell_desc in enumerate(model_desc.cell_descs()):
        reduction = (cell_desc.cell_type == CellType.Reduction)

        nodes = cell_desc.nodes()

        # petridish must seed the cell with at least one node
        assert len(nodes) > 0

        # input/output channels are the same for all nodes, so the first
        # node's conv params can be reused for the new node
        conv_params = nodes[0].conv_params

        # Candidate input IDs for the new node: stems s0/s1 are IDs 0 and 1,
        # then one ID per preceding node. Since the new node goes in before
        # the last one, it can see 2 stems + (len(nodes)-1) earlier nodes,
        # i.e. len(nodes)+1 inputs in total.
        input_ids = list(range(len(nodes) + 1))
        assert len(input_ids) >= 2  # 2 stem inputs

        op_desc = OpDesc('petridish_reduction_op' if reduction else 'petridish_normal_op',
                         params={
                             'conv': conv_params,
                             # per-input strides, handed down to each primitive:
                             # reduction cells stride the two stem inputs by 2
                             '_strides': [2 if reduction and j < 2 else 1
                                          for j in input_ids],
                         }, in_len=len(input_ids), trainables=None, children=None)
        edge = EdgeDesc(op_desc, input_ids=input_ids)
        new_node = NodeDesc(edges=[edge], conv_params=conv_params)
        # insert before the last node
        nodes.insert(len(nodes) - 1, new_node)

        # output shape of all nodes is the same, so clone the last shape
        node_shapes = cell_desc.node_shapes
        new_node_shape = copy.deepcopy(node_shapes[-1])
        node_shapes.insert(len(node_shapes) - 1, new_node_shape)

        # post op needs rebuilding because the number of inputs to it has
        # changed, so its input/output channels may be different
        post_op_shape, post_op_desc = model_desc_builder.build_cell_post_op(
            cell_desc.stem_shapes, node_shapes, cell_desc.conf_cell, ci)

        cell_desc.reset_nodes(nodes, node_shapes, post_op_desc, post_op_shape)
def _train_dist(evaluater: Evaluater, conf_eval: Config, model_desc_builder: ModelDescBuilder, model_desc_filename: str, common_state) -> ConvexHullPoint:
    """Train the model described by model_desc_filename and return it as a
    ConvexHullPoint at JobStage.EVAL_TRAINED.

    Runs in a worker process: re-initializes globals from common_state and
    re-registers ops. The template desc is scaled by cell_count_scale (capped
    at n_cells from config) before building the final model. Side effects:
    writes the full desc, metrics, model stats and model weights next to
    model_desc_filename; maintains a checkpoint when one is configured (an
    already-completed checkpointed run returns its recorded result directly).
    """
    import os  # local import: file's import block not in view

    # Worker process: restore globals first.
    common.init_from(common_state)

    # region config vars
    conf_model_desc = conf_eval['model_desc']
    max_cells = conf_model_desc['n_cells']
    conf_checkpoint = conf_eval['checkpoint']
    resume = conf_eval['resume']
    conf_petridish = conf_eval['petridish']
    cell_count_scale = conf_petridish['cell_count_scale']
    # endregion

    # register ops as we are in a different process now
    model_desc_builder.pre_build(conf_model_desc)

    # All artifacts live next to the desc file, distinguished by suffix.
    model_filename = utils.append_to_filename(model_desc_filename, '_model', '.pt')
    full_desc_filename = utils.append_to_filename(model_desc_filename, '_full', '.yaml')
    metrics_filename = utils.append_to_filename(model_desc_filename, '_metrics', '.yaml')
    model_stats_filename = utils.append_to_filename(
        model_desc_filename, '_model_stats', '.yaml')

    # create a checkpoint specific to this model desc by changing the config
    checkpoint = None
    if conf_checkpoint is not None:
        # Use splitext, not split('.')[0]: a dot anywhere earlier in the
        # path (e.g. './run.v2/model.pt') must not truncate the directory.
        conf_checkpoint['filename'] = os.path.splitext(model_filename)[0] + '_checkpoint.pth'
        checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

        if checkpoint is not None and resume:
            if 'metrics_stats' in checkpoint:
                # return the output we had recorded in the checkpoint
                convex_hull_point = checkpoint['metrics_stats']
                return convex_hull_point

    # template model is what we used during the search
    template_model_desc = ModelDesc.load(model_desc_filename)

    # scale this model by number of cells, keeping reductions same as in search
    n_cells = math.ceil(
        len(template_model_desc.cell_descs()) * cell_count_scale)
    n_cells = min(n_cells, max_cells)

    # mutate a copy of the config, not the shared one
    conf_model_desc = copy.deepcopy(conf_model_desc)
    conf_model_desc['n_cells'] = n_cells
    conf_model_desc[
        'n_reductions'] = n_reductions = template_model_desc.cell_type_count(
            CellType.Reduction)

    model_desc = model_desc_builder.build(conf_model_desc,
                                          template=template_model_desc)
    # save desc for reference
    model_desc.save(full_desc_filename)

    model = evaluater.model_from_desc(model_desc)

    train_metrics = evaluater.train_model(conf_eval, model, checkpoint)
    train_metrics.save(metrics_filename)

    # get and save model stats (flops/params etc. per nas_utils)
    model_stats = nas_utils.get_model_stats(model)
    with open(model_stats_filename, 'w') as f:
        yaml.dump(model_stats, f)

    # save model weights
    if model_filename:
        model_filename = utils.full_path(model_filename)
        ml_utils.save_model(model, model_filename)
        # TODO: Causes logging error at random times. Commenting out as stop-gap fix.
        # logger.info({'model_save_path': model_filename})

    hull_point = ConvexHullPoint(
        JobStage.EVAL_TRAINED, 0, 0, model_desc,
        (n_cells, n_reductions,
         len(model_desc.cell_descs()[0].nodes())),
        metrics=train_metrics,
        model_stats=model_stats)

    # record result so a resumed run can short-circuit (see top of function)
    if checkpoint:
        checkpoint.new()
        checkpoint['metrics_stats'] = hull_point
        checkpoint.commit()

    return hull_point
def model_desc_builder(self)->Optional[ModelDescBuilder]:
    """Return the model desc builder to use for this algorithm.

    The default builder produces nodes that carry no edges.
    """
    builder = ModelDescBuilder()
    return builder