def __init__(self, search_space=None):
    """Initialise the SpNas search algorithm from ``self.cfg``.

    Copies the sampling settings out of the config. When the config contains
    ``last_search_result``, it also seeds ``self.init_code`` from an
    architecture selected out of that earlier search's results.

    :param search_space: search space passed to the base class and the codec.
    :raises AssertionError: if the previous search-results file is missing
        or, with ``cfg.regnition`` set, if ``<arch>_imagenet.pth`` is missing.
        Both are looked up under ``<local_output_path>/<cfg.step_name>``.
    """
    super(SpNas, self).__init__(search_space)
    self.search_space = search_space
    self.codec = Codec(self.cfg.codec, search_space)
    self.sample_level = self.cfg.sample_level
    self.max_sample = self.cfg.max_sample
    self.max_optimal = self.cfg.max_optimal
    self._total_list_name = self.cfg.total_list
    self.serial_settings = self.cfg.serial_settings
    self._total_list = ListDict()
    self.sample_count = 0
    self.init_code = None
    # Folder of the step named in cfg.step_name; earlier results are looked up here.
    remote_output_path = FileOps.join_path(self.local_output_path, self.cfg.step_name)
    if 'last_search_result' in self.cfg:
        last_search_file = os.path.join(remote_output_path, self.cfg.last_search_result)
        assert FileOps.exists(last_search_file), "Not found serial results!"
        # Load the very file whose existence was just checked. Previously the
        # CSV was read from <local_output_path>/<file>, which only worked if a
        # download step (download_task_folder, now disabled) had copied it there.
        last_search_results = ListDict.load_csv(last_search_file)
        pre_worker_id, pre_arch = self.select_from_remote(self.max_optimal, last_search_results)
        if self.cfg.regnition:
            # Re-write the config template: set the backbone's 'reignition'
            # flag and point 'pretrained' at the validated checkpoint.
            self.codec.config_template['model']['backbone']['reignition'] = True
            pretrained_pth = os.path.join(remote_output_path, pre_arch + '_imagenet.pth')
            assert FileOps.exists(pretrained_pth), \
                "Not found {} pretrained .pth file!".format(pre_arch)
            self.codec.config_template['model']['pretrained'] = pretrained_pth
            # NOTE(review): presumably -1 means "no previous worker to reuse"
            # once pretrained weights are supplied — confirm with consumers.
            pre_worker_id = -1
        # Seed the search from the selected architecture.
        # NOTE(review): split('_')[1] assumes the arch name contains at least
        # one underscore — confirm the naming scheme of select_from_remote.
        self.init_code = dict(arch=pre_arch,
                              pre_arch=pre_arch.split('_')[1],
                              pre_worker_id=pre_worker_id)
    logging.info("inited SpNas {}-level search...".format(self.sample_level))
def _save_model_desc(self):
    """Emit one model description per row of the NAS Pareto front.

    Looks for ``pareto_front.csv`` in ``<local_output_path>/<step_name>`` of
    ``self.trainer``. If the file is absent, returns without doing anything.
    Otherwise, decodes the ``encoding`` column of each row with a
    ``QuantCodec`` and passes the resulting ``Config`` to
    ``self.trainer.output_model_desc`` together with the row index.
    """
    trainer = self.trainer
    csv_path = FileOps.join_path(trainer.local_output_path, trainer.step_name, "pareto_front.csv")
    if not FileOps.exists(csv_path):
        return
    with open(csv_path, "r") as handle:
        front_table = pd.read_csv(handle)
    encodings = front_table["encoding"].tolist()

    space = SearchSpace()
    quant_codec = QuantCodec('QuantCodec', space)
    for index, encoding_text in enumerate(encodings):
        # The encoding is stored as text wrapping comma-separated integers;
        # drop the first and last character (the brackets) before splitting.
        code = list(map(int, encoding_text[1:-1].split(',')))
        desc = Config()
        desc.modules = space.search_space.modules
        desc.backbone = quant_codec.decode(code)._desc.backbone
        trainer.output_model_desc(index, desc)