def _save_checkpoint(self, epoch, best=False):
    """Save model weights.

    :param epoch: current epoch
    :type epoch: int
    :param best: whether the current model is the best one so far
    :type best: bool
    """
    save_dir = os.path.join(self.worker_path, str(epoch))
    FileOps.make_dir(save_dir)
    for name in self.model.model_names:
        if isinstance(name, str):
            save_filename = '%s_net_%s.pth' % (epoch, name)
            save_path = FileOps.join_path(save_dir, save_filename)
            net = getattr(self.model, 'net' + name)
            best_file = FileOps.join_path(self.worker_path, "model_{}.pth".format(name))
            if vega.is_gpu_device() and torch.cuda.is_available():
                # On GPU the net exposes its weights under `.module` (wrapped model),
                # so save the inner module's state dict.
                torch.save(net.module.state_dict(), save_path)
                if best:
                    torch.save(net.module.state_dict(), best_file)
            elif vega.is_npu_device():
                torch.save(net.state_dict(), save_path)
                if best:
                    torch.save(net.state_dict(), best_file)
            else:
                torch.save(net.cpu().state_dict(), save_path)
                if best:
                    torch.save(net.cpu().state_dict(), best_file)
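# Sketch: a minimal, self-contained example of the save pattern above. When a net is
# wrapped (e.g. in torch.nn.DataParallel), its weights live under `.module`, so saving
# `net.module.state_dict()` keeps the checkpoint loadable by an unwrapped model.
# The function name and file path below are illustrative only.
import torch
import torch.nn as nn


def save_state_dict_sketch(net, save_path):
    """Save weights so they can be reloaded with or without a DataParallel wrapper."""
    state = net.module.state_dict() if isinstance(net, nn.DataParallel) else net.state_dict()
    torch.save(state, save_path)

# Usage: save_state_dict_sketch(nn.Linear(8, 2), "0_net_G.pth")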
def save_results(self):
    """Save the results of evolution, containing the information of the population and elitism."""
    _path = FileOps.join_path(self.local_output_path, General.step_name)
    FileOps.make_dir(_path)
    arch_file = FileOps.join_path(_path, 'arch.txt')
    arch_child = FileOps.join_path(_path, 'arch_child.txt')
    sel_arch_file = FileOps.join_path(_path, 'selected_arch.npy')
    sel_arch = []
    with open(arch_file, 'a') as fw_a, open(arch_child, 'a') as fw_ac:
        writer_a = csv.writer(fw_a, lineterminator='\n')
        writer_ac = csv.writer(fw_ac, lineterminator='\n')
        writer_ac.writerow(['Population Iteration: ' + str(self.evolution_count + 1)])
        for c in range(self.individual_num):
            writer_ac.writerow(
                self._log_data(net_info_type='active_only', pop=self.pop[c], value=self.pop[c].fitness))
        writer_a.writerow(['Population Iteration: ' + str(self.evolution_count + 1)])
        for c in range(self.elitism_num):
            writer_a.writerow(
                self._log_data(net_info_type='active_only', pop=self.elitism[c], value=self.elit_fitness[c]))
            sel_arch.append(self.elitism[c].gene)
    sel_arch = np.stack(sel_arch)
    np.save(sel_arch_file, sel_arch)
    if self.backup_base_path is not None:
        FileOps.copy_folder(self.local_output_path, self.backup_base_path)
def _get_model_desc(self):
    """Get the model description from the trainer or the model configuration."""
    model_desc = self.trainer.model_desc
    if not model_desc:
        if ModelConfig.model_desc_file is not None:
            desc_file = ModelConfig.model_desc_file
            desc_file = desc_file.replace("{local_base_path}", self.trainer.local_base_path)
            if ":" not in desc_file:
                desc_file = os.path.abspath(desc_file)
            if ":" in desc_file:
                local_desc_file = FileOps.join_path(
                    self.trainer.local_output_path, os.path.basename(desc_file))
                FileOps.copy_file(desc_file, local_desc_file)
                desc_file = local_desc_file
            model_desc = Config(desc_file)
            logger.info("net_desc: {}".format(model_desc))
        elif ModelConfig.model_desc is not None:
            model_desc = ModelConfig.model_desc
        elif ModelConfig.models_folder is not None:
            folder = ModelConfig.models_folder.replace("{local_base_path}", self.trainer.local_base_path)
            pattern = FileOps.join_path(folder, "desc_*.json")
            desc_file = glob.glob(pattern)[0]
            model_desc = Config(desc_file)
    return model_desc
def _save_pb_model(self, weight_file, model_id):
    """Freeze the TensorFlow graph and save it as a .pb model file."""
    from tensorflow.python.framework import graph_util
    valid_data = self.trainer.valid_loader.input_fn()
    iterator = valid_data.make_one_shot_iterator()
    one_element = iterator.get_next()
    with tf.Session() as sess:
        batch = sess.run(one_element)
        input_shape = batch[0].shape
    with tf.Graph().as_default():
        input_holder_shape = (None,) + tuple(input_shape[1:])
        input_holder = tf.placeholder(dtype=tf.float32, shape=input_holder_shape)
        self.trainer.model.training = False
        output = self.trainer.model(input_holder)
        if isinstance(output, tuple):
            output_name = [output[0].name.split(":")[0]]
        else:
            output_name = [output.name.split(":")[0]]
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            if weight_file is not None:
                saver = tf.train.Saver()
                last_weight_file = tf.train.latest_checkpoint(weight_file)
                if last_weight_file:
                    saver.restore(sess, last_weight_file)
            constant_graph = graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_name)
            output_graph = FileOps.join_path(weight_file, '{}.pb'.format(model_id))
            with tf.gfile.FastGFile(output_graph, mode='wb') as f:
                f.write(constant_graph.SerializeToString())
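# Sketch: a minimal, self-contained version of the freezing step in _save_pb_model:
# build a graph, initialize variables, fold them into constants with
# graph_util.convert_variables_to_constants, and serialize the result to a .pb file.
# Tensor names and the output path are illustrative; assumes TF 1.x-style graph APIs
# (available as tf.compat.v1 in TF 2.x).
import tensorflow as tf
from tensorflow.python.framework import graph_util


def freeze_graph_sketch(pb_path="frozen_sketch.pb"):
    with tf.Graph().as_default():
        x = tf.compat.v1.placeholder(tf.float32, shape=(None, 4), name="input")
        w = tf.compat.v1.get_variable("w", shape=(4, 2))
        tf.identity(tf.matmul(x, w), name="output")
        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            constant_graph = graph_util.convert_variables_to_constants(
                sess, sess.graph_def, ["output"])
            with tf.io.gfile.GFile(pb_path, "wb") as f:
                f.write(constant_graph.SerializeToString())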
def before_train(self, logs=None):
    """Be called before the training process."""
    super().before_train(logs)
    hpo_result = FileOps.load_pickle(
        FileOps.join_path(self.trainer.local_output_path, 'best_config.pickle'))
    logging.info("loading stage1_hpo_result \n{}".format(hpo_result))
    feature_interaction_score = hpo_result['feature_interaction_score']
    logging.info("feature_interaction_score: {}".format(feature_interaction_score))
    sorted_pairs = sorted(feature_interaction_score.items(), key=lambda x: abs(x[1]), reverse=True)
    if ModelConfig.model_desc:
        fis_ratio = ModelConfig.model_desc["custom"]["fis_ratio"]
    else:
        fis_ratio = 1.0
    top_k = int(len(feature_interaction_score) * min(1.0, fis_ratio))
    self.selected_pairs = list(map(lambda x: x[0], sorted_pairs[:top_k]))
    # add selected_pairs
    setattr(ModelConfig.model_desc['custom'], 'selected_pairs', self.selected_pairs)
@classmethod
def load_records_from_model_folder(cls, model_folder):
    """Transfer the json files in a model folder to report records."""
    if not model_folder or not os.path.exists(model_folder):
        logging.error("Failed to load records from model folder, folder={}".format(model_folder))
        return []
    records = []
    pattern = FileOps.join_path(model_folder, "desc_*.json")
    files = glob.glob(pattern)
    for _file in files:
        try:
            with open(_file) as f:
                worker_id = _file.split(".")[-2].split("_")[-1]
                weights_file = os.path.join(os.path.dirname(_file), "model_{}".format(worker_id))
                if vega.is_torch_backend():
                    weights_file = '{}.pth'.format(weights_file)
                elif vega.is_ms_backend():
                    weights_file = '{}.ckpt'.format(weights_file)
                if not os.path.exists(weights_file):
                    weights_file = None
                sample = dict(worker_id=worker_id, desc=json.load(f), weights_file=weights_file)
                record = ReportRecord().load_dict(sample)
                records.append(record)
        except Exception as ex:
            logging.info('Cannot read records from json because {}'.format(ex))
    return records
def save_report(self, records):
    """Save report to `reports.json`."""
    try:
        _file = FileOps.join_path(TaskOps().local_output_path, "reports.json")
        FileOps.make_base_dir(_file)
        data = {"_steps_": []}
        for step in self.step_names:
            if step in self.steps:
                data["_steps_"].append(self.steps[step])
            else:
                data["_steps_"].append({"step_name": step, "status": Status.unstarted})
        for record in records:
            if record.step_name in data:
                data[record.step_name].append(record.to_dict())
            else:
                data[record.step_name] = [record.to_dict()]
        with open(_file, "w") as f:
            json.dump(data, f, indent=4, cls=JsonEncoder)
    except Exception:
        logging.warning(traceback.format_exc())
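# Sketch: the report above is written with a custom encoder (cls=JsonEncoder). The
# class below is a hypothetical stand-in showing the usual pattern, not the project's
# actual JsonEncoder: convert numpy scalars/arrays and other non-serializable values
# to plain Python types in default().
import json
import numpy as np


class JsonEncoderSketch(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return str(obj)  # readable fallback for anything else

# Usage: json.dumps({"score": np.float32(0.91)}, indent=4, cls=JsonEncoderSketch)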
def after_valid(self, logs=None):
    """Be called after the validation process."""
    self.model = self.trainer.model
    feature_interaction_score = self.model.get_feature_interaction_score()
    logging.info("get feature_interaction_score: {}".format(feature_interaction_score))
    feature_interaction = []
    for feature in feature_interaction_score:
        if abs(feature_interaction_score[feature]) > 0:
            feature_interaction.append(feature)
    logging.info("get feature_interaction: {}".format(feature_interaction))
    curr_auc = float(self.trainer.valid_metrics.results['auc'])
    if curr_auc > self.best_score:
        best_config = {
            'score': curr_auc,
            'feature_interaction': feature_interaction
        }
        logging.info("BEST CONFIG IS\n{}".format(best_config))
        pickle_result_file = FileOps.join_path(self.trainer.local_output_path, 'best_config.pickle')
        logging.info("Saved to {}".format(pickle_result_file))
        FileOps.dump_pickle(best_config, pickle_result_file)
        self.best_score = curr_auc
def _get_current_step_records(self):
    """Get the records of the current pipeline step."""
    step_name = General.step_name
    models_folder = PipeStepConfig.pipe_step.get("models_folder")
    cur_index = PipelineConfig.steps.index(step_name)
    if cur_index >= 1 or models_folder:
        if not models_folder:
            models_folder = FileOps.join_path(
                TaskOps().local_output_path, PipelineConfig.steps[cur_index - 1])
        models_folder = models_folder.replace("{local_base_path}", TaskOps().local_base_path)
        records = ReportServer().load_records_from_model_folder(models_folder)
    else:
        records = self._load_single_model_records()
    final_records = []
    for record in records:
        if not record.weights_file:
            logger.error("Model file does not exist, id={}".format(record.worker_id))
        else:
            record.step_name = General.step_name
            final_records.append(record)
    logging.debug("Records: {}".format(final_records))
    return final_records
def _backup(self):
    """Backup result worker folder."""
    if self.need_backup is True and self.backup_base_path is not None:
        backup_worker_path = FileOps.join_path(self.backup_base_path, self.get_worker_subpath())
        FileOps.copy_folder(
            self.get_local_worker_path(self.step_name, self.worker_id), backup_worker_path)
def _output_records(self, step_name, records):
    """Dump records."""
    columns = ["worker_id", "performance", "desc"]
    outputs = []
    for record in records:
        record = record.serialize()
        _record = {}
        for key in columns:
            _record[key] = record[key]
        outputs.append(deepcopy(_record))
    data = pd.DataFrame(outputs)
    step_path = FileOps.join_path(TaskOps().local_output_path, step_name)
    FileOps.make_dir(step_path)
    _file = FileOps.join_path(step_path, "output.csv")
    try:
        data.to_csv(_file, index=False)
    except Exception:
        logging.error("Failed to save output file, file={}".format(_file))
    for record in outputs:
        worker_id = record["worker_id"]
        worker_path = TaskOps().get_local_worker_path(step_name, worker_id)
        outputs_globs = []
        outputs_globs += glob.glob(FileOps.join_path(worker_path, "desc_*.json"))
        outputs_globs += glob.glob(FileOps.join_path(worker_path, "hps_*.json"))
        outputs_globs += glob.glob(FileOps.join_path(worker_path, "model_*"))
        outputs_globs += glob.glob(FileOps.join_path(worker_path, "performance_*.json"))
        for _file in outputs_globs:
            if os.path.isfile(_file):
                FileOps.copy_file(_file, step_path)
            elif os.path.isdir(_file):
                FileOps.copy_folder(_file, FileOps.join_path(step_path, os.path.basename(_file)))
def get_pareto_list_size(self):
    """Get the size of the pareto front list."""
    pareto_list_size = 0
    pareto_file_locate = FileOps.join_path(self.local_base_path, "result", "pareto_front.csv")
    if os.path.exists(pareto_file_locate):
        pareto_front_df = pd.read_csv(pareto_file_locate)
        pareto_list_size = pareto_front_df.size
    return pareto_list_size
def _load_checkpoint(self):
    """Load checkpoint."""
    if vega.is_torch_backend():
        if hasattr(self.trainer.config, "checkpoint_path"):
            checkpoint_path = self.trainer.config.checkpoint_path
        else:
            checkpoint_path = self.trainer.get_local_worker_path()
        checkpoint_file = FileOps.join_path(checkpoint_path, self.trainer.checkpoint_file_name)
        if os.path.exists(checkpoint_file):
            try:
                logging.info("Load checkpoint file, file={}".format(checkpoint_file))
                checkpoint = torch.load(checkpoint_file)
                if self.trainer.multi_task:
                    self.trainer.model.load_state_dict(checkpoint["weight"], strict=False)
                    multi_task_checkpoint = torch.load(
                        FileOps.join_path(checkpoint_path, self.trainer.multi_task,
                                          self.trainer.checkpoint_file_name))
                    self.trainer.optimizer.load_state_dict(multi_task_checkpoint["optimizer"])
                    self.trainer.lr_scheduler.load_state_dict(multi_task_checkpoint["lr_scheduler"])
                else:
                    self.trainer.model.load_state_dict(checkpoint["weight"])
                    self.trainer.optimizer.load_state_dict(checkpoint["optimizer"])
                    self.trainer.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
                if self.trainer._resume_training:
                    self.trainer._start_epoch = checkpoint["epoch"]
                    logging.info("Resume fully train, change start epoch to {}".format(
                        self.trainer._start_epoch))
            except Exception as e:
                logging.info("Load checkpoint failed {}".format(e))
        else:
            logging.info("Skip loading checkpoint file that does not exist, {}".format(checkpoint_file))
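# Sketch: a minimal, self-contained example of the resume pattern used above. A
# checkpoint dict holding "epoch", "weight", "optimizer" and "lr_scheduler" is saved
# once and later restored into fresh objects. The file name is hypothetical.
import torch
import torch.nn as nn


def checkpoint_roundtrip_sketch(path="checkpoint.pth"):
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    torch.save({"epoch": 5,
                "weight": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": scheduler.state_dict()}, path)
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["weight"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["lr_scheduler"])
    return checkpoint["epoch"]  # training would resume from this epoch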
def _get_model_desc(self):
    """Get the model description of the current worker."""
    model_desc = self.model_desc
    self.saved_folder = self.get_local_worker_path(self.step_name, self.worker_id)
    if not model_desc:
        if os.path.exists(FileOps.join_path(self.saved_folder, 'desc_{}.json'.format(self.worker_id))):
            model_config = Config(
                FileOps.join_path(self.saved_folder, 'desc_{}.json'.format(self.worker_id)))
            if "type" not in model_config and "modules" not in model_config:
                model_config = ModelConfig.model_desc
            model_desc = model_config
        elif ModelConfig.model_desc_file is not None:
            desc_file = ModelConfig.model_desc_file
            desc_file = desc_file.replace("{local_base_path}", self.local_base_path)
            if ":" not in desc_file:
                desc_file = os.path.abspath(desc_file)
            if ":" in desc_file:
                local_desc_file = FileOps.join_path(
                    self.local_output_path, os.path.basename(desc_file))
                FileOps.copy_file(desc_file, local_desc_file)
                desc_file = local_desc_file
            model_desc = Config(desc_file)
            logger.info("net_desc: {}".format(model_desc))
        elif ModelConfig.model_desc is not None:
            model_desc = ModelConfig.model_desc
        elif ModelConfig.models_folder is not None:
            folder = ModelConfig.models_folder.replace("{local_base_path}", self.local_base_path)
            pattern = FileOps.join_path(folder, "desc_*.json")
            desc_file = glob.glob(pattern)[0]
            model_desc = Config(desc_file)
        elif PipeStepConfig.pipe_step.get("models_folder") is not None:
            folder = PipeStepConfig.pipe_step.get("models_folder").replace(
                "{local_base_path}", self.local_base_path)
            desc_file = FileOps.join_path(folder, "desc_{}.json".format(self.worker_id))
            model_desc = Config(desc_file)
            logger.info("Load model from model folder {}.".format(folder))
    return model_desc
def _copy_needed_file(self):
    """Copy the pareto front file and random file needed by this step."""
    if self.config.pareto_front_file is None:
        raise FileNotFoundError("Config item pareto_front_file not found in config file.")
    init_pareto_front_file = self.config.pareto_front_file.replace(
        "{local_base_path}", self.local_base_path)
    self.pareto_front_file = FileOps.join_path(self.local_output_path, self.step_name, "pareto_front.csv")
    FileOps.make_base_dir(self.pareto_front_file)
    FileOps.copy_file(init_pareto_front_file, self.pareto_front_file)

    if self.config.random_file is None:
        raise FileNotFoundError("Config item random_file not found in config file.")
    init_random_file = self.config.random_file.replace("{local_base_path}", self.local_base_path)
    self.random_file = FileOps.join_path(self.local_output_path, self.step_name, "random.csv")
    FileOps.copy_file(init_random_file, self.random_file)
def __init__(self, **kwargs):
    """Construct the Imagenet class."""
    Dataset.__init__(self, **kwargs)
    self.args.data_path = FileOps.download_dataset(self.args.data_path)
    split = 'train' if self.mode == 'train' else 'val'
    local_data_path = FileOps.join_path(self.args.data_path, split)
    delattr(self, 'loader')
    ImageFolder.__init__(self, root=local_data_path, transform=Compose(self.transforms.__transform__))
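# Sketch: the Imagenet dataset above delegates to torchvision's ImageFolder, which
# expects one sub-directory per class inside the split folder ("train" or "val").
# A minimal stand-alone usage, assuming such a folder exists at the hypothetical path:
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, ToTensor


def imagefolder_sketch(root="/data/imagenet/train"):
    dataset = ImageFolder(root=root, transform=Compose([Resize((224, 224)), ToTensor()]))
    image, label = dataset[0]  # a transformed image tensor and its class index
    return image.shape, label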
def set_trainer(self, trainer):
    """Set trainer object for current callback."""
    self.trainer = trainer
    self.trainer._train_loop = self._train_loop
    self.cfg = self.trainer.config
    self._worker_id = self.trainer._worker_id
    self.worker_path = self.trainer.get_local_worker_path()
    self.output_path = self.trainer.local_output_path
    self.best_model_name = "model_best"
    self.best_model_file = FileOps.join_path(
        self.worker_path, "model_{}.pth".format(self.trainer.worker_id))
def before_train(self, logs=None): """Be called before the whole train process.""" self.trainer.config.call_metrics_on_train = False self.cfg = self.trainer.config self.worker_id = self.trainer.worker_id self.local_base_path = self.trainer.local_base_path self.local_output_path = self.trainer.local_output_path self.result_path = FileOps.join_path(self.trainer.local_base_path, "result") FileOps.make_dir(self.result_path) self.logger_patch()
def _init_next_rung(self):
    """Init next rung to search."""
    next_rung_id = self.rung_id + 1
    if next_rung_id >= self.total_rungs:
        self.rung_id = self.rung_id + 1
        return
    for i in range(self.config_count):
        self.all_config_dict[i][next_rung_id] = self.all_config_dict[i][self.rung_id]
    current_score = []
    for i in range(self.config_count):
        current_score.append((i, self.best_score_dict[self.rung_id][i]))
    current_score.sort(key=lambda current_score: current_score[1])
    # Exploit: overwrite the 4 worst configs (and their checkpoints) with copies of the 4 best ones.
    for i in range(4):
        better_id = current_score[self.config_count - 1 - i][0]
        worse_id = current_score[i][0]
        better_worker_result_path = FileOps.join_path(
            self.local_base_path, 'cache', 'pba', str(better_id), 'checkpoint')
        FileOps.make_dir(better_worker_result_path)
        worse_worker_result_path = FileOps.join_path(
            self.local_base_path, 'cache', 'pba', str(worse_id), 'checkpoint')
        FileOps.make_dir(worse_worker_result_path)
        shutil.rmtree(worse_worker_result_path)
        shutil.copytree(better_worker_result_path, worse_worker_result_path)
        self.all_config_dict[worse_id] = self.all_config_dict[better_id]
        # Explore: perturb the copied policy for the next rung.
        policy_unchange = self.all_config_dict[worse_id][next_rung_id]
        policy_changed = self.explore(policy_unchange)
        self.all_config_dict[worse_id][next_rung_id] = policy_changed
    for id in range(self.config_count):
        self.best_score_dict[next_rung_id][id] = -1 * float('inf')
        tmp_row_data = {
            'config_id': id,
            'rung_id': next_rung_id,
            'status': StatusType.WAITTING
        }
        self._add_to_board(tmp_row_data)
    self.rung_id = self.rung_id + 1
def _get_search_space_list(self):
    """Get search space list from models folder."""
    models_folder = PipeStepConfig.pipe_step.get("models_folder")
    if not models_folder:
        self.search_space_list = None
        return
    self.search_space_list = []
    models_folder = models_folder.replace("{local_base_path}", TaskOps().local_base_path)
    pattern = FileOps.join_path(models_folder, "*.json")
    files = glob.glob(pattern)
    for file in files:
        with open(file) as f:
            self.search_space_list.append(json.load(f))
def after_train(self, logs=None):
    """Be called after the training process."""
    curr_auc = float(self.trainer.valid_metrics.results['auc'])
    self.sieve_board = self.sieve_board.append(
        {
            'selected_feature_pairs': self.selected_pairs,
            'score': curr_auc
        }, ignore_index=True)
    result_file = FileOps.join_path(
        self.trainer.local_output_path, '{}_result.csv'.format(self.trainer.__worker_id__))
    self.sieve_board.to_csv(result_file, sep='\t')
def load_model(self):
    """Load model."""
    self.saved_folder = self.get_local_worker_path(self.step_name, self.worker_id)
    if not self.model_desc:
        self.model_desc = self._get_model_desc()
    if not self.weights_file:
        if vega.is_torch_backend():
            self.weights_file = FileOps.join_path(self.saved_folder, 'model_{}.pth'.format(self.worker_id))
        elif vega.is_ms_backend():
            for file in os.listdir(self.saved_folder):
                if file.endswith(".ckpt"):
                    self.weights_file = FileOps.join_path(self.saved_folder, file)
        elif vega.is_tf_backend():
            self.weights_file = FileOps.join_path(self.saved_folder, 'model_{}'.format(self.worker_id))
    if self.weights_file is not None and os.path.exists(self.weights_file):
        self.model = ModelZoo.get_model(self.model_desc, self.weights_file)
    else:
        logger.info("Evaluate model without loading weights file.")
        self.model = ModelZoo.get_model(self.model_desc)
def logger_patch(self):
    """Patch the default logger."""
    worker_path = self.trainer.get_local_worker_path()
    worker_spec_log_file = FileOps.join_path(worker_path, 'current_worker.log')
    logger = logging.getLogger(__name__)
    for hdlr in logger.handlers:
        logger.removeHandler(hdlr)
    for hdlr in logging.root.handlers:
        logging.root.removeHandler(hdlr)
    logger.addHandler(logging.FileHandler(worker_spec_log_file))
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)
    # Rebind the root logger so that all subsequent logging of this worker
    # goes to the worker-specific log file.
    logging.root = logger
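# Sketch: the patch above routes all logging of a worker process into its own log
# file (and also rebinds logging.root, a global side effect). The same per-worker
# file handler setup with plain stdlib logging, using a hypothetical file name:
import logging


def per_worker_logger_sketch(log_file="current_worker.log"):
    logger = logging.getLogger("worker")
    for hdlr in list(logger.handlers):
        logger.removeHandler(hdlr)
    logger.addHandler(logging.FileHandler(log_file))
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)
    return logger

# Usage: per_worker_logger_sketch().info("worker started")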
def before_train(self, logs=None):
    """Be called before the training process."""
    super().before_train(logs)
    hpo_result = FileOps.load_pickle(
        FileOps.join_path(self.trainer.local_output_path, 'best_config.pickle'))
    logging.info("loading stage1_hpo_result \n{}".format(hpo_result))
    self.selected_pairs = hpo_result['feature_interaction']
    logging.info("feature_interaction: {}".format(self.selected_pairs))
    # add selected_pairs
    setattr(ModelConfig.model_desc['custom'], 'selected_pairs', self.selected_pairs)
def update(self, record):
    """Update current performance into hpo score board.

    :param record: record of the trainer, including worker_id, step_name and performance
    """
    super().update(record)
    config_id = str(record.get('worker_id'))
    step_name = record.get('step_name')
    worker_result_path = self.get_local_worker_path(step_name, config_id)
    new_worker_result_path = FileOps.join_path(self.local_base_path, 'cache', config_id, 'checkpoint')
    FileOps.make_dir(worker_result_path)
    FileOps.make_dir(new_worker_result_path)
    if os.path.exists(new_worker_result_path):
        shutil.rmtree(new_worker_result_path)
    shutil.copytree(worker_result_path, new_worker_result_path)
def _show_report(self):
    """Print a summary of the final results."""
    performance_file = FileOps.join_path(TaskOps().local_output_path, self.steps[-1].name, "output.csv")
    try:
        data = pd.read_csv(performance_file)
    except Exception:
        logging.info(" result file output.csv does not exist or is empty")
        return
    if data.shape[1] < 2 or data.shape[0] == 0:
        logging.info(" result file output.csv is empty")
        return
    logging.info(" result:")
    data = json.loads(data.to_json())
    for key in data["worker_id"].keys():
        logging.info(" {:>3s}: {}".format(str(data["worker_id"][key]), data["performance"][key]))
def _generate_init_model(self):
    """Generate init model by loading pretrained model.

    :return: initial model after loading pretrained model
    :rtype: torch.nn.Module
    """
    model_init = self._new_model_init()
    chn_node_mask = self._init_chn_node_mask()
    if vega.is_torch_backend():
        if vega.is_gpu_device():
            checkpoint = torch.load(self.config.init_model_file + '.pth')
            model_init.load_state_dict(checkpoint)
            model = PruneResnet(model_init).apply(chn_node_mask, self.base_net_desc.backbone.chn_mask)
            model.to(self.device)
        elif vega.is_npu_device():
            device = "npu:{}".format(os.environ.get('DEVICE_ID', 0))
            checkpoint = torch.load(self.config.init_model_file + '.pth',
                                    map_location=torch.device('{}'.format(device)))
            model_init.load_state_dict(checkpoint)
            model = PruneResnet(model_init).apply(chn_node_mask, self.base_net_desc.backbone.chn_mask)
            model.npu()
    elif vega.is_tf_backend():
        model = model_init
        with tf.compat.v1.Session(config=self.trainer._init_session_config()) as sess:
            saver = tf.compat.v1.train.import_meta_graph("{}.meta".format(self.config.init_model_file))
            saver.restore(sess, self.config.init_model_file)
            all_weight = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.VARIABLES)
            all_weight = [t for t in all_weight if not t.name.endswith('Momentum:0')]
            PruneResnet(all_weight).apply(chn_node_mask, self.base_net_desc.backbone.chn_mask)
            save_file = FileOps.join_path(self.trainer.get_local_worker_path(), 'prune_model')
            saver.save(sess, save_file)
    elif vega.is_ms_backend():
        parameter_dict = load_checkpoint(self.config.init_model_file)
        load_param_into_net(model_init, parameter_dict)
        model = PruneResnet(model_init).apply(chn_node_mask, self.base_net_desc.backbone.chn_mask)
    return model
def _new_model_init(self):
    """Init new model.

    :return: initial model after loading pretrained model
    :rtype: torch.nn.Module
    """
    init_model_file = self.config.init_model_file
    if ":" in init_model_file:
        local_path = FileOps.join_path(
            self.trainer.get_local_worker_path(), os.path.basename(init_model_file))
        FileOps.copy_file(init_model_file, local_path)
        self.config.init_model_file = local_path
    network_desc = copy.deepcopy(self.base_net_desc)
    network_desc.backbone.cfgs = network_desc.backbone.base_cfgs
    model_init = NetworkDesc(network_desc).to_model()
    return model_init
def _saved_multi_checkpoint(self, epoch):
    """Save multi tasks checkpoint."""
    FileOps.make_dir(self.trainer.get_local_worker_path(), self.trainer.multi_task)
    checkpoint_file = FileOps.join_path(
        self.trainer.get_local_worker_path(), self.trainer.multi_task,
        self.trainer.checkpoint_file_name)
    logging.debug("Start Save Multi Task Model, model_file=%s", self.trainer.model_pickle_file_name)
    if vega.is_torch_backend():
        ckpt = {
            'epoch': epoch,
            'weight': self.trainer.model.state_dict(),
            'optimizer': self.trainer.optimizer.state_dict(),
            'lr_scheduler': self.trainer.lr_scheduler.state_dict(),
        }
        torch.save(ckpt, checkpoint_file)
    self.trainer.checkpoint_file = checkpoint_file
def _save_descript(self):
    """Save the result description."""
    template_file = self.config.darts_template_file
    genotypes = self.search_alg.codec.calc_genotype(self._get_arch_weights())
    if template_file == "{default_darts_cifar10_template}":
        template = DartsNetworkTemplateConfig.cifar10
    elif template_file == "{default_darts_cifar100_template}":
        template = DartsNetworkTemplateConfig.cifar100
    elif template_file == "{default_darts_imagenet_template}":
        template = DartsNetworkTemplateConfig.imagenet
    else:
        dst = FileOps.join_path(self.trainer.get_local_worker_path(), os.path.basename(template_file))
        FileOps.copy_file(template_file, dst)
        template = Config(dst)
    model_desc = self._gen_model_desc(genotypes, template)
    self.trainer.config.codec = model_desc