def search_model_desc(self, conf_search:Config, model_desc:ModelDesc,
                      trainer_class:TArchTrainer, finalizers:Finalizers)\
        ->Tuple[ModelDesc, Optional[Metrics]]:
    """Run architecture search starting from model_desc and return the
    finalized description together with the search metrics (None when no
    training took place)."""

    # algos like random search pass no trainer class: the given desc is
    # already the answer, so return it unchanged with no metrics
    if trainer_class is None:
        return model_desc, None

    logger.pushd('arch_search')

    trainer_conf = conf_search['trainer']
    loader_conf = conf_search['loader']

    # model built with droppath and affine disabled for the search phase
    search_model = Model(model_desc, droppath=False, affine=False)

    # data loaders for the search
    loaders = self.get_data(loader_conf)

    # train the architecture parameters
    arch_trainer = trainer_class(trainer_conf, search_model, checkpoint=None)
    metrics = arch_trainer.fit(loaders)

    # turn the searched super-model into a concrete description
    final_desc = self.finalize_model(search_model, finalizers)

    logger.popd()
    return final_desc, metrics
def train_model_desc(self, model_desc:ModelDesc, conf_train:Config)\
        ->Optional[ModelMetrics]:
    """Train given description"""

    # region conf vars
    trainer_conf = conf_train['trainer']
    loader_conf = conf_train['loader']
    title = trainer_conf['title']
    n_epochs = trainer_conf['epochs']
    drop_prob = trainer_conf['drop_path_prob']
    # endregion

    # zero (or negative) epochs means there is nothing to do; skip the
    # whole setup and signal "no training happened" with None
    if n_epochs <= 0:
        return None

    logger.pushd(title)

    # droppath is enabled only when a positive drop probability is configured
    trained_model = Model(model_desc, droppath=drop_prob > 0.0, affine=True)

    # data loaders for training
    loaders = self.get_data(loader_conf)

    model_trainer = Trainer(trainer_conf, trained_model, checkpoint=None)
    metrics = model_trainer.fit(loaders)

    logger.popd()

    return ModelMetrics(trained_model, metrics)
def finalize_model(self, model: Model, to_cpu=True, restore_device=True) -> ModelDesc:
    """Collect per-edge activation covariances for divnas, then delegate
    the actual finalization to the parent class."""

    logger.pushd('finalize')

    # get config and train data loader
    # TODO: confirm this is correct in case you get silent bugs
    conf = get_conf()
    conf_loader = conf['nas']['search']['loader']
    train_dl, val_dl, test_dl = get_data(conf_loader)

    # wrap every cell in the model with its divnas counterpart, keyed by
    # the cell object's identity
    self._divnas_cells: Dict[int, Divnas_Cell] = {
        id(cell): Divnas_Cell(cell) for cell in model.cells
    }

    # for every edge of divop type, start recording activations
    sigma = conf['nas']['search']['divnas']['sigma']
    for dcell in self._divnas_cells.values():
        dcell.collect_activations(DivOp, sigma)

    # run one pass over the training data to gather activations; done on
    # cpu to avoid running out of GPU memory (could be redone in pytorch
    # proper later). Afterwards each node in a cell holds the covariance
    # matrix of all incoming edges' ops.
    model = model.cpu()
    model.eval()
    with torch.no_grad():
        for x, _ in train_dl:
            model(x)

    # fold the recorded activations into node covariances in every cell
    for dcell in self._divnas_cells.values():
        dcell.update_covs()

    logger.popd()

    return super().finalize_model(model, to_cpu, restore_device)
def test_darts_zero_model():
    """Build the darts search model from config and check a forward pass
    yields a (1, 10) logits tensor with no aux output."""
    conf = common_init(config_filepath='confs/darts_cifar.yaml')

    search_conf = conf['nas']['search']
    desc_conf = search_conf['model_desc']

    builder = MacroBuilder(desc_conf, aux_tower=False)
    desc = builder.build()

    model = Model(desc, False, True)
    logits, aux = model(torch.rand((1, 3, 32, 32)))

    assert isinstance(logits, torch.Tensor)
    assert logits.shape == (1, 10)
    assert aux is None
def test_petridish_zero_model():
    """Build the petridish search model from config and check a forward
    pass yields a (1, 10) logits tensor with no aux output."""
    conf = common_init(config_filepath='confs/petridish_cifar.yaml')

    search_conf = conf['nas']['search']
    desc_conf = search_conf['model_desc']

    builder = ModelDescBuilder()
    desc = builder.build(desc_conf)

    model = Model(desc, False, True)
    logits, aux = model(torch.rand((1, 3, 32, 32)))

    assert isinstance(logits, torch.Tensor)
    assert logits.shape == (1, 10)
    assert aux is None
def finalize_model(self, model: Model, to_cpu=True, restore_device=True) -> ModelDesc:
    """Produce a ModelDesc from a trained model by finalizing its stems,
    cells, pool and logits ops.

    Finalize serializes each op's parameters, so the model is moved to
    the CPU first to avoid duplicating that state in GPU RAM.
    """
    device_before = model.device_type()
    if to_cpu:
        model.cpu()

    # finalize will create copy of state and this can overflow GPU RAM
    # NOTE(review): this check also fires when to_cpu=False and the model
    # is still on GPU — presumably callers must pre-move it; confirm.
    assert model.device_type() == 'cpu'

    finalized_cells = self.finalize_cells(model)

    if restore_device:
        model.to(device_before, non_blocking=True)

    return ModelDesc(conf_model_desc=model.desc.conf_model_desc,
                     model_stems=[stem.finalize()[0] for stem in model.model_stems],
                     pool_op=model.pool_op.finalize()[0],
                     cell_descs=finalized_cells,
                     aux_tower_descs=model.desc.aux_tower_descs,
                     logits_op=model.logits_op.finalize()[0])
def _create_seed_jobs(self, conf_search:Config, model_desc_builder:ModelDescBuilder)->list:
    """Launch distributed pre-training jobs for every valid seed macro
    combination and return the list of ray future ids."""
    desc_conf = conf_search['model_desc']
    seed_train_conf = conf_search['seed_train']

    job_ids = []          # ray job IDs
    stats_for_plot = []   # seed model stats for visualization and debugging

    for reductions, cells, nodes in list(self.get_combinations(conf_search)):
        # skip combinations that cannot satisfy the N R N R N R layout
        if cells < reductions * 2 + 1:
            continue

        # create the seed model description for this combination
        desc = self.build_model_desc(model_desc_builder, desc_conf,
                                     reductions, cells, nodes)
        point = ConvexHullPoint(JobStage.SEED, 0, 0, desc,
                                (cells, reductions, nodes))

        # kick off distributed pre-training of the seed model
        job_id = SearcherPetridish.train_model_desc_dist.remote(
            self, seed_train_conf, point, common.get_state())
        job_ids.append(job_id)

        # instantiate a throw-away model just to capture its stats
        stats_model = Model(desc, droppath=True, affine=True)
        stats_for_plot.append(nas_utils.get_model_stats(stats_model))

    # persist the stats as a plot and tsv file so the spread on the
    # x-axis can be visualized
    expdir = common.get_expdir()
    assert expdir
    plot_seed_model_stats(stats_for_plot, expdir)

    return job_ids
def _get_alphas(model:Model)->Iterator[nn.Parameter]:
    """Return an iterator over the parameters of kind 'alphas' owned by model."""
    owned = model.all_owned()
    return owned.param_by_kind('alphas')
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Build the petridish eval model from a saved final_model_desc.yaml and
print its summary (params/FLOPs per layer) for a 14-cell configuration."""

from archai.nas.model_desc import ModelDesc
from archai.common.common import common_init
from archai.nas.model import Model
from archai.algos.petridish.petridish_model_desc_builder import PetridishModelBuilder
from archai.common.model_summary import summary

conf = common_init(config_filepath='confs/petridish_cifar.yaml',
                   param_args=['--common.experiment_name',
                               'petridish_run2_seed42_eval'])

conf_eval = conf['nas']['eval']
conf_model_desc = conf_eval['model_desc']

# override cell count for this summary run
conf_model_desc['n_cells'] = 14
# '$expdir' is expanded by ModelDesc.load against the experiment dir
template_model_desc = ModelDesc.load('$expdir/final_model_desc.yaml')

model_builder = PetridishModelBuilder()
model_desc = model_builder.build(conf_model_desc, template=template_model_desc)

# fix: a second PetridishModelBuilder ('mb') was instantiated here and
# never used — removed as dead code

# droppath/affine off: we only need the structure for the summary
model = Model(model_desc, droppath=False, affine=False)

summary(model, [64, 3, 32, 32])

exit(0)
def model_from_desc(self, model_desc) -> Model:
    """Instantiate a Model from the given description with droppath and
    affine enabled."""
    built = Model(model_desc, droppath=True, affine=True)
    return built