def finalize_model(self, model: Model, to_cpu=True, restore_device=True) -> ModelDesc:
    """Finalize ``model`` into a ``ModelDesc``.

    Each op's ``finalize()`` serializes its parameters, so the model is moved
    to CPU first to avoid leaving copies of those parameters on the GPU.
    The cells, stems, pool op and logits op are all finalized while the model
    is on CPU; the device is restored only after every op has been finalized.

    Args:
        model: model to finalize.
        to_cpu: if True, move ``model`` to CPU before finalizing. If False,
            ``model`` must already be on CPU (asserted).
        restore_device: if True, move ``model`` back to the device it was on
            when this method was called, once finalization is done (also
            done if finalization raises).

    Returns:
        ModelDesc assembled from the finalized stems, cells, pool op and
        logits op, plus the model's existing ``conf_model_desc`` and
        ``aux_tower_descs``.

    Raises:
        AssertionError: if ``model`` is not on CPU when finalization starts.
    """
    original = model.device_type()
    if to_cpu:
        model.cpu()

    # finalize will create copy of state and this can overflow GPU RAM
    assert model.device_type() == 'cpu'

    try:
        cell_descs = self.finalize_cells(model)
        # finalize the remaining ops before any device restore so that their
        # serialized parameters are also created on CPU
        model_stems = [op.finalize()[0] for op in model.model_stems]
        pool_op = model.pool_op.finalize()[0]
        logits_op = model.logits_op.finalize()[0]
    finally:
        if restore_device:
            model.to(original, non_blocking=True)

    return ModelDesc(
        conf_model_desc=model.desc.conf_model_desc,
        model_stems=model_stems,
        pool_op=pool_op,
        cell_descs=cell_descs,
        aux_tower_descs=model.desc.aux_tower_descs,
        logits_op=logits_op,
    )
def finalize_model(self, model: Model, to_cpu=True, restore_device=True) -> ModelDesc:
    """Collect DivNAS activation covariances, then finalize ``model``.

    Every cell of ``model`` is wrapped in a ``Divnas_Cell`` and set to collect
    activations of its ``DivOp`` edges. One no-grad pass is then run over the
    search train loader on CPU, and each wrapper's covariances are updated
    after every batch. Finally the base ``finalize_model`` builds the
    ``ModelDesc``.

    Args:
        model: model to finalize.
        to_cpu: forwarded to the base ``finalize_model``.
        restore_device: if True, ``model`` is moved back to the device it was
            on when this method was called (not merely the CPU it is on when
            the base method runs).

    Returns:
        The ``ModelDesc`` produced by the base ``finalize_model``.
    """
    # Record device and train/eval mode before the CPU evaluation pass below
    # changes them. Otherwise the base finalize_model would see 'cpu' as the
    # device to restore.
    original_device = model.device_type()
    was_training = model.training  # assumes Model is a torch nn.Module — TODO confirm

    logger.pushd('finalize')
    try:
        # get config and train data loader
        # TODO: confirm this is correct in case you get silent bugs
        conf = get_conf()
        conf_loader = conf['nas']['search']['loader']
        train_dl, _, _ = get_data(conf_loader)

        # wrap all cells in the model, keyed by the identity of the wrapped cell
        self._divnas_cells: Dict[int, Divnas_Cell] = {
            id(cell): Divnas_Cell(cell) for cell in model.cells
        }

        # go through all edges in the DAG and if they are of divop
        # type then set them to collect activations
        sigma = conf['nas']['search']['divnas']['sigma']
        for dcell in self._divnas_cells.values():
            dcell.collect_activations(DivOp, sigma)

        # Run one evaluation epoch to collect activations. This is done on CPU
        # to avoid memory issues. Afterwards each node in a cell holds the
        # covariance matrix of all incoming edges' ops.
        model = model.cpu()
        model.eval()
        with torch.no_grad():
            for x, _ in train_dl:
                model(x)  # output is unused; the forward pass feeds the collectors
                # update the node covariances in every cell after each batch
                for dcell in self._divnas_cells.values():
                    dcell.update_covs()
    finally:
        # keep the logger's push/pop stack balanced even if collection fails
        logger.popd()

    desc = super().finalize_model(model, to_cpu, restore_device)

    # undo the mode/device changes made by the collection pass above
    model.train(was_training)
    if restore_device:
        model.to(original_device, non_blocking=True)

    return desc