def _init_model(self):
    """Load model desc from save path and parse to model."""
    model = self.trainer.model
    if self.trainer.config.is_detection_trainer:
        model_desc = self.trainer.model_desc
    else:
        model_desc = self._get_model_desc()
    if model_desc:
        ModelConfig.model_desc = model_desc
    pretrained_model_file = self._get_pretrained_model_file()
    if not model:
        if not model_desc:
            raise Exception("Failed to init model, can not get model description.")
        model = ModelZoo.get_model(model_desc, pretrained_model_file)
    if model:
        if zeus.is_torch_backend():
            import torch
            if self.trainer.use_cuda:
                model = model.cuda()
            if General._parallel and General.devices_per_trainer > 1:
                # wrap the freshly built model, not self.trainer.model,
                # which may still be None at this point
                model = torch.nn.DataParallel(model)
        if zeus.is_tf_backend():
            if pretrained_model_file:
                model_folder = os.path.dirname(pretrained_model_file)
                FileOps.copy_folder(model_folder, self.trainer.get_local_worker_path())
    return model
def _do_horovod_fully_train(self):
    pwd_dir = os.path.dirname(os.path.abspath(__file__))
    cf_file = os.path.join(pwd_dir, 'cf.pickle')
    cf_content = {
        'registry': ClassFactory.__registry__,
        'general_config': General().to_json(),
        'pipe_step_config': PipeStepConfig().to_json()
    }
    with open(cf_file, 'wb') as f:
        pickle.dump(cf_content, f)
    cf_file_remote = os.path.join(self.task.local_base_path, 'cf.pickle')
    FileOps.copy_file(cf_file, cf_file_remote)
    if os.environ.get('DLS_TASK_NUMBER') is None:
        # local cluster
        worker_ips = '127.0.0.1'
        if General.cluster.master_ip is not None and General.cluster.master_ip != '127.0.0.1':
            worker_ips = General.cluster.master_ip
            for ip in General.cluster.slaves:
                worker_ips = worker_ips + ',' + ip
        cmd = ['bash', '{}/horovod/run_cluster_horovod_train.sh'.format(pwd_dir),
               str(self.world_device_size), cf_file_remote, worker_ips]
    else:
        # Roma
        cmd = ['bash', '{}/horovod/run_horovod_train.sh'.format(pwd_dir),
               str(self.world_device_size), cf_file_remote]
    proc = subprocess.Popen(cmd, env=os.environ)
    proc.wait()
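# The cf.pickle written above is the hand-off to each horovod worker process.
# A minimal restore sketch for the worker side, assuming the config classes
# expose a from_json counterpart to the to_json calls used above (the helper
# name _restore_worker_context is hypothetical):
def _restore_worker_context(cf_file):
    with open(cf_file, 'rb') as f:
        cf_content = pickle.load(f)
    ClassFactory.__registry__ = cf_content['registry']
    General.from_json(cf_content['general_config'])
    PipeStepConfig.from_json(cf_content['pipe_step_config'])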
def dataset_init(self):
    """Construct method.

    If both data_dir and label_dir are provided, use data_dir and label_dir;
    otherwise use root_dir and list_file.
    """
    if "data_dir" in self.args and "label_dir" in self.args:
        self.args.data_dir = FileOps.download_dataset(self.args.data_dir)
        self.args.label_dir = FileOps.download_dataset(self.args.label_dir)
        self.data_files = sorted(glob.glob(osp.join(self.args.data_dir, "*")))
        self.label_files = sorted(glob.glob(osp.join(self.args.label_dir, "*")))
    else:
        if "root_dir" not in self.args or "list_file" not in self.args:
            raise Exception("You must provide a root_dir and a list_file!")
        self.args.root_dir = FileOps.download_dataset(self.args.root_dir)
        with open(osp.join(self.args.root_dir, self.args.list_file)) as f:
            lines = f.readlines()
        self.data_files = [None] * len(lines)
        self.label_files = [None] * len(lines)
        for i, line in enumerate(lines):
            data_file_name, label_file_name = line.strip().split()
            self.data_files[i] = osp.join(self.args.root_dir, data_file_name)
            self.label_files[i] = osp.join(self.args.root_dir, label_file_name)
    datatype = self._get_datatype()
    if datatype == "image":
        self.read_fn = self._read_item_image
    else:
        self.read_fn = self._read_item_pickle
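# Example list_file layout consumed by the root_dir branch above: one
# whitespace-separated pair of relative paths per line (names illustrative):
#   images/0001.png  labels/0001.png
#   images/0002.png  labels/0002.png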
def _save_worker_record(cls, record):
    step_name = record.get('step_name')
    worker_id = record.get('worker_id')
    _path = TaskOps().get_local_worker_path(step_name, worker_id)
    for record_name in ["desc", "performance"]:
        _file_name = None
        _file = None
        record_value = record.get(record_name)
        if not record_value:
            continue
        try:
            # for cars/darts save multi-desc
            if isinstance(record_value, list) and record_name == "desc":
                for idx, value in enumerate(record_value):
                    _file_name = "desc_{}.json".format(idx)
                    _file = FileOps.join_path(_path, _file_name)
                    with open(_file, "w") as f:
                        json.dump(value, f)
            else:
                if record_name == "desc":
                    _file_name = "desc_{}.json".format(worker_id)
                if record_name == "performance":
                    _file_name = "performance_{}.json".format(worker_id)
                _file = FileOps.join_path(_path, _file_name)
                with open(_file, "w") as f:
                    json.dump(record_value, f)
        except Exception as ex:
            logging.error("Failed to save {}, file={}, desc={}, msg={}".format(
                record_name, _file, record_value, str(ex)))
def _append_record_to_csv(self, record_name=None, step_name=None, record=None, mode='a'):
    """Transfer record to csv file."""
    local_output_path = os.path.join(TaskOps().local_output_path, step_name)
    logging.debug("record to csv, local_output_path={}".format(local_output_path))
    if not record_name:
        # nothing to write without a record name; the output dir itself
        # is created below by make_base_dir
        return
    file_path = os.path.join(local_output_path, "{}.csv".format(record_name))
    FileOps.make_base_dir(file_path)
    try:
        for key in record:
            if isinstance(record[key], (dict, list)):
                record[key] = str(record[key])
        data = pd.DataFrame([record])
        if not os.path.exists(file_path):
            data.to_csv(file_path, index=False)
        elif os.path.exists(file_path) and os.path.getsize(file_path) and mode == 'a':
            data.to_csv(file_path, index=False, mode=mode, header=False)
        else:
            data.to_csv(file_path, index=False, mode=mode)
    except Exception as ex:
        logging.info('Can not transfer record to csv file. Error: {}'.format(ex))
def dataset_init(self):
    """Construct method, which loads some dataset information."""
    self.args.root_HR = FileOps.download_dataset(self.args.root_HR)
    self.args.root_LR = FileOps.download_dataset(self.args.root_LR)
    if self.args.subfile is not None:
        # lmdb format has no self.args.subfile
        with open(self.args.subfile) as f:
            file_names = sorted([line.rstrip('\n') for line in f])
        self.datatype = util.get_files_datatype(file_names)
        self.paths_HR = [os.path.join(self.args.root_HR, file_name)
                         for file_name in file_names]
        self.paths_LR = [os.path.join(self.args.root_LR, file_name)
                         for file_name in file_names]
    else:
        self.datatype = util.get_datatype(self.args.root_LR)
        self.paths_LR = util.get_paths_from_dir(self.args.root_LR)
        self.paths_HR = util.get_paths_from_dir(self.args.root_HR)
    if self.args.save_in_memory:
        self.imgs_LR = [self._read_img(path) for path in self.paths_LR]
        self.imgs_HR = [self._read_img(path) for path in self.paths_HR]
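# Example subfile layout consumed above: one file name per line, expected to
# exist under both root_HR and root_LR (names illustrative):
#   0001.png
#   0002.png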
def dump_model_visual_info(trainer, epoch, model, inputs):
    """Dump model to tensorboard event files.

    :param trainer: trainer.
    :type trainer: object that the class was inherited from DistributedWorker.
    :param epoch: current epoch.
    :type epoch: int.
    :param model: model.
    :type model: model.
    :param inputs: input data.
    :type inputs: data.
    """
    (_, visual, interval, title, worker_id, output_path) = _get_trainer_info(trainer)
    if visual is not True:
        return
    if epoch % interval != 0:
        return
    title = str(worker_id)
    _path = FileOps.join_path(output_path, title)
    FileOps.make_dir(_path)
    try:
        with SummaryWriter(_path) as writer:
            writer.add_graph(model, (inputs,))
    except Exception as e:
        logging.error("Failed to dump model visual info, worker id: {}, epoch: {}, error: {}".format(
            worker_id, epoch, str(e)))
def _get_model_desc(self):
    model_desc = self.trainer.model_desc
    if not model_desc or 'modules' not in model_desc:
        if ModelConfig.model_desc_file is not None:
            desc_file = ModelConfig.model_desc_file
            desc_file = desc_file.replace("{local_base_path}", self.trainer.local_base_path)
            if ":" not in desc_file:
                desc_file = os.path.abspath(desc_file)
            if ":" in desc_file:
                # paths containing ":" (e.g. remote URLs) are copied to a local file first
                local_desc_file = FileOps.join_path(
                    self.trainer.local_output_path, os.path.basename(desc_file))
                FileOps.copy_file(desc_file, local_desc_file)
                desc_file = local_desc_file
            model_desc = Config(desc_file)
            logger.info("net_desc:{}".format(model_desc))
        elif ModelConfig.model_desc is not None:
            model_desc = ModelConfig.model_desc
        elif ModelConfig.models_folder is not None:
            folder = ModelConfig.models_folder.replace(
                "{local_base_path}", self.trainer.local_base_path)
            pattern = FileOps.join_path(folder, "desc_*.json")
            desc_file = glob.glob(pattern)[0]
            model_desc = Config(desc_file)
        else:
            return None
    return model_desc
def dump(self):
    """Dump report to file."""
    try:
        _file = FileOps.join_path(TaskOps().step_path, "reports.csv")
        FileOps.make_base_dir(_file)
        data = self.all_records
        data_dict = {}
        for step in data:
            step_data = step.serialize().items()
            for k, v in step_data:
                if k in data_dict:
                    data_dict[k].append(v)
                else:
                    data_dict[k] = [v]
        data = pd.DataFrame(data_dict)
        data.to_csv(_file, index=False)
        _file = os.path.join(TaskOps().step_path, ".reports")
        _dump_data = [ReportServer._hist_records, ReportServer.__instances__]
        with open(_file, "wb") as f:
            pickle.dump(_dump_data, f, protocol=pickle.HIGHEST_PROTOCOL)
        self.backup_output_path()
    except Exception:
        logging.warning(traceback.format_exc())
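# Counterpart restore sketch for the ".reports" pickle written by dump()
# above, assuming the same two-element [_hist_records, __instances__] layout
# (the method name _restore is hypothetical):
@classmethod
def _restore(cls):
    _file = os.path.join(TaskOps().step_path, ".reports")
    if os.path.exists(_file):
        with open(_file, "rb") as f:
            data = pickle.load(f)
        ReportServer._hist_records = data[0]
        ReportServer.__instances__ = data[1]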
def _backup(self):
    """Backup result worker folder."""
    if self.need_backup is True and self.backup_base_path is not None:
        backup_worker_path = FileOps.join_path(
            self.backup_base_path, self.get_worker_subpath())
        FileOps.copy_folder(
            self.get_local_worker_path(self.step_name, self.worker_id),
            backup_worker_path)
def _save_checkpoint(self, epoch, best=False):
    """Save model weights.

    :param epoch: current epoch
    :type epoch: int
    :param best: whether to also save the weights as the best model so far
    :type best: bool
    """
    save_dir = os.path.join(self.worker_path, str(epoch))
    FileOps.make_dir(save_dir)
    for name in self.model.model_names:
        if isinstance(name, str):
            save_filename = '%s_net_%s.pth' % (epoch, name)
            save_path = FileOps.join_path(save_dir, save_filename)
            net = getattr(self.model, 'net' + name)
            best_file = FileOps.join_path(self.worker_path, "model_{}.pth".format(name))
            if self.cfg.cuda and torch.cuda.is_available():
                # torch.save(net.module.cpu().state_dict(), save_path)
                torch.save(net.module.state_dict(), save_path)
                # net.cuda()
                if best:
                    torch.save(net.module.state_dict(), best_file)
            else:
                torch.save(net.cpu().state_dict(), save_path)
                if best:
                    torch.save(net.cpu().state_dict(), best_file)
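# A hedged load counterpart mirroring the '%s_net_%s.pth' naming scheme used
# by _save_checkpoint above; a minimal sketch only (the helper name
# _load_networks is hypothetical):
def _load_networks(self, epoch):
    for name in self.model.model_names:
        if isinstance(name, str):
            load_filename = '%s_net_%s.pth' % (epoch, name)
            load_path = FileOps.join_path(self.worker_path, str(epoch), load_filename)
            net = getattr(self.model, 'net' + name)
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            net.load_state_dict(torch.load(load_path, map_location='cpu'))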
def search(self):
    """Search an id and hps from hpo."""
    sample = self.hpo.propose()
    if sample is None:
        return None
    re_hps = {}
    sample = copy.deepcopy(sample)
    sample_id = sample.get('config_id')
    trans_para = sample.get('configs')
    rung_id = sample.get('rung_id')
    all_para = sample.get('all_configs')
    re_hps['dataset.transforms'] = [{
        'type': 'PBATransformer',
        'para_array': trans_para,
        'all_para': all_para,
        'operation_names': self.operation_names
    }]
    checkpoint_path = FileOps.join_path(
        self.local_base_path, 'worker', 'cache', str(sample_id), 'checkpoint')
    FileOps.make_dir(checkpoint_path)
    if os.path.exists(checkpoint_path):
        re_hps['trainer.checkpoint_path'] = checkpoint_path
    if 'epoch' in sample:
        re_hps['trainer.epochs'] = sample.get('epoch')
    return dict(worker_id=sample_id, encoded_desc=re_hps, rung_id=rung_id)
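# Shape of one proposal consumed by search() above; the keys match the
# sample.get(...) calls, the values here are illustrative only:
# {
#     'config_id': 3,
#     'rung_id': 1,
#     'configs': {...},        # per-operation transform parameters
#     'all_configs': {...},
#     'epoch': 4,              # optional
# }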
def _save_best_model(self):
    """Save best model."""
    if zeus.is_torch_backend():
        torch.save(self.trainer.model.state_dict(), self.trainer.weights_file)
    elif zeus.is_tf_backend():
        worker_path = self.trainer.get_local_worker_path()
        model_id = "model_{}".format(self.trainer.worker_id)
        weights_folder = FileOps.join_path(worker_path, model_id)
        FileOps.make_dir(weights_folder)
        checkpoint_file = tf.train.latest_checkpoint(worker_path)
        ckpt_globs = glob.glob("{}.*".format(checkpoint_file))
        for _file in ckpt_globs:
            dst_file = model_id + os.path.splitext(_file)[-1]
            FileOps.copy_file(_file, FileOps.join_path(weights_folder, dst_file))
        FileOps.copy_file(FileOps.join_path(worker_path, 'checkpoint'), weights_folder)
    elif zeus.is_ms_backend():
        worker_path = self.trainer.get_local_worker_path()
        save_path = os.path.join(worker_path, "model_{}.ckpt".format(self.trainer.worker_id))
        for file in os.listdir(worker_path):
            if file.startswith("CKP") and file.endswith(".ckpt"):
                self.weights_file = FileOps.join_path(worker_path, file)
                os.rename(self.weights_file, save_path)
def before_train(self, logs=None):
    """Call before_train of the managed callbacks, before the training process."""
    super().before_train(logs)
    hpo_result = FileOps.load_pickle(
        FileOps.join_path(self.trainer.local_output_path, 'best_config.pickle'))
    logging.info("loading stage1_hpo_result \n{}".format(hpo_result))
    feature_interaction_score = hpo_result['feature_interaction_score']
    print('feature_interaction_score:', feature_interaction_score)
    sorted_pairs = sorted(feature_interaction_score.items(),
                          key=lambda x: abs(x[1]), reverse=True)
    if ModelConfig.model_desc:
        fis_ratio = ModelConfig.model_desc["custom"]["fis_ratio"]
    else:
        fis_ratio = 1.0
    top_k = int(len(feature_interaction_score) * min(1.0, fis_ratio))
    self.selected_pairs = list(map(lambda x: x[0], sorted_pairs[:top_k]))
    # add selected_pairs
    setattr(ModelConfig.model_desc['custom'], 'selected_pairs', self.selected_pairs)
def after_valid(self, logs=None):
    """Call after_valid of the managed callbacks."""
    self.model = self.trainer.model
    feature_interaction_score = self.model.get_feature_interaction_score()
    print('get feature_interaction_score', feature_interaction_score)
    feature_interaction = []
    for feature in feature_interaction_score:
        if abs(feature_interaction_score[feature]) > 0:
            feature_interaction.append(feature)
    print('get feature_interaction', feature_interaction)
    curr_auc = float(self.trainer.valid_metrics.results['auc'])
    if curr_auc > self.best_score:
        best_config = {
            'score': curr_auc,
            'feature_interaction': feature_interaction
        }
        logging.info("BEST CONFIG IS\n{}".format(best_config))
        pickle_result_file = FileOps.join_path(
            self.trainer.local_output_path, 'best_config.pickle')
        logging.info("Saved to {}".format(pickle_result_file))
        FileOps.dump_pickle(best_config, pickle_result_file)
        self.best_score = curr_auc
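# best_config.pickle written above is what the stage-2 before_train hook
# further below reads back; its payload is
# {'score': <float auc>, 'feature_interaction': [<selected feature pairs>]}.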
def __init__(self, **kwargs):
    """Construct the Imagenet class."""
    Dataset.__init__(self, **kwargs)
    self.args.data_path = FileOps.download_dataset(self.args.data_path)
    split = 'train' if self.mode == 'train' else 'val'
    local_data_path = FileOps.join_path(self.args.data_path, split)
    ImageFolder.__init__(self, root=local_data_path,
                         transform=Compose(self.transforms.__transform__))
def before_train(self, logs=None):
    """Be called before the whole train process."""
    self.trainer.config.call_metrics_on_train = False
    self.cfg = self.trainer.config
    self.worker_id = self.trainer.worker_id
    self.local_base_path = self.trainer.local_base_path
    self.local_output_path = self.trainer.local_output_path
    self.result_path = FileOps.join_path(self.trainer.local_base_path, "result")
    FileOps.make_dir(self.result_path)
    self.logger_patch()
def load_model(self):
    """Load model."""
    if not self.model_desc and not self.weights_file:
        saved_folder = self.get_local_worker_path(self.step_name, self.worker_id)
        self.weights_file = FileOps.join_path(
            saved_folder, 'model_{}.pth'.format(self.worker_id))
        self.model_desc = FileOps.join_path(
            saved_folder, 'desc_{}.json'.format(self.worker_id))
    if 'modules' not in self.model_desc:
        self.model_desc = ModelConfig.model_desc
    self.model = ModelZoo.get_model(self.model_desc, self.weights_file)
def save_results(self):
    """Save the results of evolution, containing the information of population and elitism."""
    _path = FileOps.join_path(self.local_output_path, General.step_name)
    FileOps.make_dir(_path)
    arch_file = FileOps.join_path(_path, 'arch.txt')
    arch_child = FileOps.join_path(_path, 'arch_child.txt')
    sel_arch_file = FileOps.join_path(_path, 'selected_arch.npy')
    sel_arch = []
    with open(arch_file, 'a') as fw_a, open(arch_child, 'a') as fw_ac:
        writer_a = csv.writer(fw_a, lineterminator='\n')
        writer_ac = csv.writer(fw_ac, lineterminator='\n')
        writer_ac.writerow(['Population Iteration: ' + str(self.evolution_count + 1)])
        for c in range(self.individual_num):
            writer_ac.writerow(self._log_data(net_info_type='active_only',
                                              pop=self.pop[c],
                                              value=self.pop[c].fitness))
        writer_a.writerow(['Population Iteration: ' + str(self.evolution_count + 1)])
        for c in range(self.elitism_num):
            writer_a.writerow(self._log_data(net_info_type='active_only',
                                             pop=self.elitism[c],
                                             value=self.elit_fitness[c]))
            sel_arch.append(self.elitism[c].gene)
    sel_arch = np.stack(sel_arch)
    np.save(sel_arch_file, sel_arch)
    if self.backup_base_path is not None:
        FileOps.copy_folder(self.local_output_path, self.backup_base_path)
def copy_pareto_output(self, step_name=None, worker_ids=[]):
    """Copy files related to pareto from worker to output."""
    taskops = TaskOps()
    local_output_path = os.path.join(taskops.local_output_path, step_name)
    if not (step_name and os.path.exists(local_output_path)):
        return
    for worker_id in worker_ids:
        dst_dir = os.path.join(local_output_path, str(worker_id))
        FileOps.make_dir(dst_dir)
        local_worker_path = taskops.get_worker_subpath(step_name, str(worker_id))
        src_dir = FileOps.join_path(taskops.local_base_path, local_worker_path)
        copy_search_file(src_dir, dst_dir)
def before_train(self, logs=None):
    """Call before_train of the managed callbacks, before the training process."""
    super().before_train(logs)
    hpo_result = FileOps.load_pickle(FileOps.join_path(
        self.trainer.local_output_path, 'best_config.pickle'))
    logging.info("loading stage1_hpo_result \n{}".format(hpo_result))
    self.selected_pairs = hpo_result['feature_interaction']
    logging.info('feature_interaction: {}'.format(self.selected_pairs))
    # add selected_pairs
    setattr(ModelConfig.model_desc['custom'], 'selected_pairs', self.selected_pairs)
def _init_dataloader(self):
    """Init dataloader from timm."""
    if self.distributed and hvd.local_rank() == 0 and 'remote_data_dir' in self.config.dataset:
        FileOps.copy_folder(self.config.dataset.remote_data_dir, self.config.dataset.data_dir)
    if self.distributed:
        hvd.join()
    args = self.config.dataset
    train_dir = os.path.join(self.config.dataset.data_dir, 'train')
    dataset_train = Dataset(train_dir)
    world_size, rank = None, None
    if self.distributed:
        world_size, rank = hvd.size(), hvd.rank()
    self.trainer.train_loader = create_loader(
        dataset_train,
        input_size=tuple(args.input_size),
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=self.config.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        rand_erase_count=args.recount,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation='random',
        mean=tuple(args.mean),
        std=tuple(args.std),
        num_workers=args.workers,
        distributed=self.distributed,
        world_size=world_size,
        rank=rank)
    valid_dir = os.path.join(self.config.dataset.data_dir, 'val')
    dataset_eval = Dataset(valid_dir)
    self.trainer.valid_loader = create_loader(
        dataset_eval,
        input_size=tuple(args.input_size),
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=self.config.prefetcher,
        interpolation=args.interpolation,
        mean=tuple(args.mean),
        std=tuple(args.std),
        num_workers=args.workers,
        distributed=self.distributed,
        world_size=world_size,
        rank=rank)
    self.trainer.batch_num_train = len(self.trainer.train_loader)
    self.trainer.batch_num_valid = len(self.trainer.valid_loader)
def _get_pretrained_model_file(self):
    if ModelConfig.pretrained_model_file:
        model_file = ModelConfig.pretrained_model_file
        model_file = model_file.replace("{local_base_path}", self.trainer.local_base_path)
        model_file = model_file.replace("{worker_id}", str(self.trainer.worker_id))
        if ":" not in model_file:
            model_file = os.path.abspath(model_file)
        if ":" in model_file:
            # paths containing ":" (e.g. remote URLs) are copied to a local file first
            local_model_file = FileOps.join_path(
                self.trainer.local_output_path, os.path.basename(model_file))
            FileOps.copy_file(model_file, local_model_file)
            model_file = local_model_file
        return model_file
    else:
        return None
def save_master_ip(ip_address, port, args):
    """Write the ip and port in a system path.

    :param str ip_address: the `ip_address` to write.
    :param str port: the `port` to write.
    :param argparse.ArgumentParser args: `args` should contain `init_method`, `rank` and `world_size`.
    """
    temp_folder = TaskOps().temp_path
    FileOps.make_dir(temp_folder)
    file_path = os.path.join(temp_folder, 'ip_address.txt')
    logging.info("write ip, file path={}".format(file_path))
    with open(file_path, 'w') as f:
        f.write(ip_address + "\n")
        f.write(port + "\n")
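# Counterpart reader for the two-line ip_address.txt written above; a minimal
# sketch (the function name load_master_ip is hypothetical):
def load_master_ip():
    file_path = os.path.join(TaskOps().temp_path, 'ip_address.txt')
    if not os.path.exists(file_path):
        return None, None
    with open(file_path, 'r') as f:
        ip_address = f.readline().strip()
        port = f.readline().strip()
    return ip_address, port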
def _load_checkpoint(self):
    """Load checkpoint."""
    if zeus.is_torch_backend():
        checkpoint_file = FileOps.join_path(
            self.trainer.get_local_worker_path(), self.trainer.checkpoint_file_name)
        if os.path.exists(checkpoint_file):
            try:
                logging.info("Load checkpoint file, file={}".format(checkpoint_file))
                checkpoint = torch.load(checkpoint_file)
                self.trainer.model.load_state_dict(checkpoint["weight"])
                self.trainer.optimizer.load_state_dict(checkpoint["optimizer"])
                self.trainer.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
                if self.trainer._resume_training:
                    self.trainer._start_epoch = checkpoint["epoch"]
                    logging.info("Resume fully train, change start epoch to {}".format(
                        self.trainer._start_epoch))
            except Exception as e:
                logging.info("Load checkpoint failed {}".format(e))
        else:
            logging.info('Use default model')
def load_records_from_model_folder(cls, model_folder):
    """Transfer json_file to records."""
    if not model_folder or not os.path.exists(model_folder):
        logging.error("Failed to load records from model folder, folder={}".format(model_folder))
        return []
    records = []
    pattern = FileOps.join_path(model_folder, "desc_*.json")
    files = glob.glob(pattern)
    for _file in files:
        try:
            with open(_file) as f:
                worker_id = _file.split(".")[-2].split("_")[-1]
                weights_file = os.path.join(
                    os.path.dirname(_file), "model_{}.pth".format(worker_id))
                if os.path.exists(weights_file):
                    sample = dict(worker_id=worker_id, desc=json.load(f),
                                  weights_file=weights_file)
                else:
                    sample = dict(worker_id=worker_id, desc=json.load(f))
                record = ReportRecord().load_dict(sample)
                records.append(record)
        except Exception as ex:
            logging.info('Can not read records from json because {}'.format(ex))
    return records
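# Folder layout that load_records_from_model_folder expects (worker ids
# illustrative): one desc_<worker_id>.json per model, plus an optional
# matching model_<worker_id>.pth next to it, e.g.
#   models/desc_12.json
#   models/model_12.pth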
def _generate_init_model(self):
    """Generate init model by loading pretrained model.

    :return: initial model after loading pretrained model
    :rtype: torch.nn.Module
    """
    model_init = self._new_model_init()
    chn_mask = self._init_chn_node_mask()
    if vega.is_torch_backend():
        checkpoint = torch.load(self.config.init_model_file + '.pth')
        model_init.load_state_dict(checkpoint)
        model = PruneMobileNet(model_init).apply(chn_mask)
        model.to(self.device)
    elif vega.is_tf_backend():
        model = model_init
        with tf.compat.v1.Session(config=self.trainer._init_session_config()) as sess:
            saver = tf.compat.v1.train.import_meta_graph(
                "{}.meta".format(self.config.init_model_file))
            saver.restore(sess, self.config.init_model_file)
            all_weight = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.VARIABLES)
            all_weight = [t for t in all_weight if not t.name.endswith('Momentum:0')]
            PruneMobileNet(all_weight).apply(chn_mask)
            save_file = FileOps.join_path(self.trainer.get_local_worker_path(), 'prune_model')
            saver.save(sess, save_file)
    elif vega.is_ms_backend():
        parameter_dict = load_checkpoint(self.config.init_model_file)
        load_param_into_net(model_init, parameter_dict)
        model = PruneMobileNet(model_init).apply(chn_mask)
    return model
def __init__(self, **kwargs):
    """Construct the Cifar10 class."""
    Dataset.__init__(self, **kwargs)
    self.args.data_path = FileOps.download_dataset(self.args.data_path)
    is_train = self.mode == 'train' or (self.mode == 'val' and self.args.train_portion < 1)
    self.base_folder = 'cifar-10-batches-py'
    self.transform = Compose(self.transforms.__transform__)
    if is_train:
        files_list = ["data_batch_1", "data_batch_2", "data_batch_3",
                      "data_batch_4", "data_batch_5"]
    else:
        files_list = ['test_batch']
    self.data = []
    self.targets = []
    # now load the pickled numpy arrays
    for file_name in files_list:
        file_path = os.path.join(self.args.data_path, self.base_folder, file_name)
        with open(file_path, 'rb') as f:
            entry = pickle.load(f, encoding='latin1')
            self.data.append(entry['data'])
            if 'labels' in entry:
                self.targets.extend(entry['labels'])
            else:
                self.targets.extend(entry['fine_labels'])
    self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
    self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
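# A minimal access sketch for the arrays loaded above, following the usual
# torchvision-style dataset protocol (assumes `from PIL import Image`; the
# method bodies are illustrative, not the project's own implementation):
def __getitem__(self, index):
    img, target = self.data[index], self.targets[index]
    img = Image.fromarray(img)  # HWC uint8 array -> PIL image
    if self.transform is not None:
        img = self.transform(img)
    return img, target

def __len__(self):
    return len(self.data)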
def _get_current_step_records(self):
    step_name = General.step_name
    models_folder = PipeStepConfig.pipe_step.get("models_folder")
    cur_index = PipelineConfig.steps.index(step_name)
    if cur_index >= 1 or models_folder:
        # records = Report().get_step_records(PipelineConfig.steps[cur_index - 1])
        if not models_folder:
            models_folder = FileOps.join_path(
                TaskOps().local_output_path, PipelineConfig.steps[cur_index - 1])
        models_folder = models_folder.replace("{local_base_path}", TaskOps().local_base_path)
        records = Report().load_records_from_model_folder(models_folder)
    else:
        records = self._load_single_model_records()
    final_records = []
    for record in records:
        if not record.weights_file:
            logger.error("Model file does not exist, id={}".format(record.worker_id))
        else:
            record.step_name = General.step_name
            final_records.append(record)
    logging.debug("Records: {}".format(final_records))
    return final_records
def load_model(self):
    """Load model."""
    self.saved_folder = self.get_local_worker_path(self.step_name, self.worker_id)
    if not self.model_desc:
        self.model_desc = FileOps.join_path(
            self.saved_folder, 'desc_{}.json'.format(self.worker_id))
    if not self.weights_file:
        if zeus.is_torch_backend():
            self.weights_file = FileOps.join_path(
                self.saved_folder, 'model_{}.pth'.format(self.worker_id))
        elif zeus.is_ms_backend():
            for file in os.listdir(self.saved_folder):
                if file.startswith("CKP") and file.endswith(".ckpt"):
                    self.weights_file = FileOps.join_path(self.saved_folder, file)
    if 'modules' not in self.model_desc:
        self.model_desc = ModelConfig.model_desc
    self.model = ModelZoo.get_model(self.model_desc, self.weights_file)