def _copy_needed_file(self):
    """Copy the initial pareto-front and random-sample files into the local output path.

    Reads ``pareto_front_file`` and ``random_file`` from ``self.cfg``, expands the
    ``{local_base_path}`` placeholder in each, and copies the files to
    ``<local_output_path>/<step_name>/pareto_front.csv`` and ``.../random.csv``,
    recording the destination paths on ``self``.

    :raises FileNotFoundError: if either config item is missing or ``None``.
    """
    if "pareto_front_file" not in self.cfg or self.cfg.pareto_front_file is None:
        # Fixed typo in the message: "paretor_front_file" -> "pareto_front_file".
        raise FileNotFoundError("Config item pareto_front_file not found in config file.")
    init_pareto_front_file = self.cfg.pareto_front_file.replace(
        "{local_base_path}", self.local_base_path)
    self.pareto_front_file = FileOps.join_path(
        self.local_output_path, self.cfg.step_name, "pareto_front.csv")
    # Create the destination directory before the first copy; the random file
    # below lands in the same directory, so one call covers both copies.
    FileOps.make_base_dir(self.pareto_front_file)
    FileOps.copy_file(init_pareto_front_file, self.pareto_front_file)
    if "random_file" not in self.cfg or self.cfg.random_file is None:
        raise FileNotFoundError("Config item random_file not found in config file.")
    init_random_file = self.cfg.random_file.replace(
        "{local_base_path}", self.local_base_path)
    self.random_file = FileOps.join_path(
        self.local_output_path, self.cfg.step_name, "random.csv")
    FileOps.copy_file(init_random_file, self.random_file)
def dataset_init(self):
    """Initialize dataset.

    If the HR directory is missing locally, fetch the data archive from remote
    storage and extract it. Then build sorted LR (``Y_paths``) and HR
    (``HR_paths``) file lists, truncate both lists around the first file whose
    name contains "0401", and set up the normalization and (when training)
    augmentation/crop transforms.
    """
    if not os.path.exists(self.args.HR_dir):
        logging.info("Moving data from s3 to local")
        FileOps.copy_file(self.args.remote_data_file, self.args.local_data_file)
        # Extract the tarball into the local data root, then delete the archive.
        os.system('tar -xf %s -C %s && rm -rf %s' % (self.args.local_data_file,
                                                     self.args.local_data_root,
                                                     self.args.local_data_file))
        # os.system('unzip -d -q %s && rm -rf %s' % (local_save_path, local_save_path))
        logging.info('Move done!')
    # if not "train" in data_args.keys(): raise KeyError("Train data config is must!")
    # A missing dir config yields None; NOTE(review): the loops below call
    # len() on these lists, so both dirs are presumably always set -- confirm.
    self.Y_paths = sorted(self.make_dataset(
        self.args.LR_dir, float("inf"))) if self.args.LR_dir is not None else None
    self.HR_paths = sorted(
        self.make_dataset(
            self.args.HR_dir, float("inf"))) if self.args.HR_dir is not None else None
    self.trans_norm = transforms.Compose(
        [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    # Split at the first filename containing "0401": HR keeps patches before
    # it, LR keeps patches from it onward. NOTE(review): looks like this carves
    # disjoint partitions out of one numbered patch sequence -- verify against
    # the data layout before changing.
    for i in range(len(self.HR_paths)):
        file_name = os.path.basename(self.HR_paths[i])
        if (file_name.find("0401") >= 0):
            logging.info(
                "We find the possion of NO. 401 in the HR patch NO. {}".
                format(i))
            self.HR_paths = self.HR_paths[:i]
            break
    for i in range(len(self.Y_paths)):
        file_name = os.path.basename(self.Y_paths[i])
        if (file_name.find("0401") >= 0):
            logging.info(
                "We find the possion of NO. 401 in the LR patch NO. {}".
                format(i))
            self.Y_paths = self.Y_paths[i:]
            break
    self.Y_size = len(self.Y_paths)
    if self.train:
        self.load_size = self.args.load_size
        self.crop_size = self.args.crop_size
        self.upscale = self.args.upscale
        # Random flips for augmentation; crops keep the HR/LR scale ratio:
        # the HR crop is crop_size * upscale, the LR crop is crop_size.
        self.augment_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip()
        ])
        self.HR_transform = transforms.RandomCrop(
            int(self.crop_size * self.upscale))
        self.LR_transform = transforms.RandomCrop(self.crop_size)
def _new_model_init(self, model_prune):
    """Init new model.

    :param model_prune: searched pruned model
    :type model_prune: torch.nn.Module
    :return: initial model after loading pretrained model
    :rtype: torch.nn.Module
    """
    pretrained_file = self.config.init_model_file
    # A ':' in the path marks a remote location; fetch it into the worker's
    # local directory and point the config at the local copy.
    if ":" in pretrained_file:
        worker_dir = self.trainer.get_local_worker_path()
        fetched = FileOps.join_path(worker_dir, os.path.basename(pretrained_file))
        FileOps.copy_file(pretrained_file, fetched)
        self.config.init_model_file = fetched
    # Clone the base description, reset channels to their base values, adopt
    # the pruned model's encoding, then materialize the network.
    desc = copy.deepcopy(self.base_net_desc)
    backbone = desc.backbone
    backbone.chn = backbone.base_chn
    backbone.chn_node = backbone.base_chn_node
    backbone.encoding = model_prune.encoding
    return NetworkDesc(desc).to_model()
"""Entry script for Horovod fully-train: restores ClassFactory state from a
pickle produced by the launcher, then runs the trainer on this rank."""
import argparse
import logging
import os
import pickle

import horovod.torch as hvd

from vega.core.common.class_factory import ClassFactory
from vega.core.common.user_config import UserConfig
from vega.core.common.file_ops import FileOps

parser = argparse.ArgumentParser(description='Horovod Fully Train')
parser.add_argument('--cf_file', type=str, help='ClassFactory pickle file')
args = parser.parse_args()

# Run launcher-supplied environment setup code before initializing horovod.
# NOTE(review): exec of an env var is only safe because the variable is set by
# the trusted launcher, never by external input.
if 'VEGA_INIT_ENV' in os.environ:
    exec(os.environ.copy()['VEGA_INIT_ENV'])
logging.info('start horovod setting')
hvd.init()
try:
    # moxing only exists on ModelArts clusters; elsewhere this is a no-op.
    # Narrowed from a bare except, which would also swallow KeyboardInterrupt.
    import moxing as mox
    mox.file.set_auth(obs_client_log=False)
except Exception:
    pass
FileOps.copy_file(args.cf_file, './cf_file.pickle')
# Barrier: wait for every rank to finish copying before reading the pickle.
hvd.join()
with open('./cf_file.pickle', 'rb') as f:
    cf_content = pickle.load(f)
ClassFactory.__configs__ = cf_content.get('configs')
ClassFactory.__registry__ = cf_content.get('registry')
UserConfig().__data__ = cf_content.get('data')
cls_trainer = ClassFactory.get_cls('trainer')
trainer = cls_trainer(None, 0)
trainer.train_process()