def main():
    """Benchmark or run inference with a model on one or more test sets.

    Parsed command-line flags are copied into a ``Config`` object and
    ``overwrite_from_env`` is applied to it. Model hyper-parameters are
    loaded from ``opt.parameter`` when it is given. Otherwise they are loaded
    from every existing ``par/{BACKEND}/{opt.model}.{json,yaml,yml}`` file,
    merged in that order.

    For each test set, ``t.benchmark`` is used when HR images are present and
    ``t.infer`` otherwise. Inference outputs are saved under
    ``root / data.name``. If ``opt.export`` is set, the model is exported
    before the executor context closes.

    Raises:
        FileNotFoundError: if the dataset config file, or an explicitly
            requested model parameter file, does not exist.
    """
    flags, args = parser.parse_known_args()
    opt = Config()
    # Copy every parsed flag into the opt object.
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)
    overwrite_from_env(opt)
    data_config_file = Path(flags.data_config)
    if not data_config_file.exists():
        raise FileNotFoundError("dataset config file doesn't exist!")

    # Resolve which model parameter file(s) to load.
    # - An explicit --parameter file is loaded exactly once; the previous
    #   per-extension loop re-read it three times. It must exist, instead of
    #   being silently skipped.
    # - Otherwise each existing default file is merged in json -> yaml -> yml
    #   order, so later extensions still override earlier ones ("for compat").
    if opt.parameter:
        explicit_file = Path(opt.parameter)
        if not explicit_file.exists():
            raise FileNotFoundError("model parameter file doesn't exist!")
        candidates = [explicit_file]
    else:
        candidates = [
            Path(f'par/{BACKEND}/{opt.model}.{_ext}')
            for _ext in ('json', 'yaml', 'yml')
        ]
    for model_config_file in candidates:
        if model_config_file.exists():
            opt.update(compat_param(Config(str(model_config_file))))

    # Model parameters are stored under a key named after the model.
    # suppress_opt_by_args presumably lets leftover CLI args adjust them
    # (TODO confirm); the result is merged back into opt.
    model_params = opt.get(opt.model, {})
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)

    # construct model
    model = get_model(opt.model)(**model_params)
    if opt.cuda:
        model.cuda()
    if opt.pretrain:
        model.load(opt.pretrain)

    root = f'{opt.save_dir}/{opt.model}'
    if opt.comment:
        root += '_' + opt.comment
    root = Path(root)

    datasets = load_datasets(data_config_file)
    try:
        test_datas = [datasets[t.upper()] for t in opt.test] if opt.test else []
    except KeyError:
        # opt.test does not name known datasets: treat its values as the
        # arguments of an ad-hoc Dataset used as LR-only inference input.
        test_datas = [Config(test=Config(lr=Dataset(*opt.test)), name='infer')]
        # Only here is test_datas[0] guaranteed to exist and to wrap a
        # Dataset object.
        if opt.video:
            test_datas[0].test.lr.use_like_video_()

    # enter model executor environment
    with model.get_executor(root) as t:
        for data in test_datas:
            # Benchmark only when ground-truth HR images are available.
            run_benchmark = data.test.hr is not None
            if run_benchmark:
                ld = Loader(data.test.hr, data.test.lr, opt.scale,
                            threads=opt.threads)
            else:
                ld = Loader(data.test.hr, data.test.lr, threads=opt.threads)
            if opt.channel == 1:
                # convert data color space to grayscale
                ld.set_color_space('hr', 'L')
                ld.set_color_space('lr', 'L')
            config = t.query_config(opt)
            config.inference_results_hooks = [
                save_inference_images(root / data.name, opt.output_index,
                                      opt.auto_rename)
            ]
            if run_benchmark:
                t.benchmark(ld, config)
            else:
                t.infer(ld, config)
        # Export while the executor (t) is still active.
        if opt.export:
            t.export(opt.export)
def main():
    """Train a model on the dataset selected by ``opt.dataset``.

    Parsed command-line flags are copied into a ``Config`` object. Model
    hyper-parameters are loaded from ``opt.parameter`` when it is given.
    Otherwise they are loaded from every existing
    ``par/{BACKEND}/{opt.model}.{json,yaml,yml}`` file, merged in that order.

    A training loader (with image augmentation and random crops) is built,
    plus an optional validation loader when the dataset has a ``val`` split.
    Both are passed to ``t.fit``. If ``opt.export`` is set, the model is
    exported before the executor context closes.

    Raises:
        FileNotFoundError: if the dataset config file, or an explicitly
            requested model parameter file, does not exist.
    """
    flags, args = parser.parse_known_args()
    opt = Config()  # An EasyDict object
    # overwrite flag values into opt object
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)
    # fetch dataset descriptions
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise FileNotFoundError("dataset config file doesn't exist!")

    # Resolve which model parameter file(s) to load.
    # - An explicit --parameter file is loaded exactly once; the previous
    #   per-extension loop re-read it three times. It must exist, instead of
    #   being silently skipped.
    # - Otherwise each existing default file is merged in json -> yaml -> yml
    #   order, so later extensions still override earlier ones ("for compat").
    if opt.parameter:
        explicit_file = Path(opt.parameter)
        if not explicit_file.exists():
            raise FileNotFoundError("model parameter file doesn't exist!")
        candidates = [explicit_file]
    else:
        candidates = [
            Path(f'par/{BACKEND}/{opt.model}.{_ext}')
            for _ext in ('json', 'yaml', 'yml')
        ]
    for model_config_file in candidates:
        if model_config_file.exists():
            opt.update(compat_param(Config(str(model_config_file))))

    # Model parameters are stored under a key named after the model.
    # suppress_opt_by_args presumably lets leftover CLI args adjust them
    # (TODO confirm); the result is merged back into opt.
    model_params = opt.get(opt.model, {})
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)

    # construct model
    model = get_model(opt.model)(**model_params)
    if opt.cuda:
        model.cuda()
    if opt.pretrain:
        model.load(opt.pretrain)

    root = f'{opt.save_dir}/{opt.model}'
    if opt.comment:
        root += '_' + opt.comment

    dataset = load_datasets(data_config_file, opt.dataset)
    # construct data loader for training
    lt = Loader(dataset.train.hr, dataset.train.lr, opt.scale,
                threads=opt.threads)
    lt.image_augmentation()
    # construct data loader for validating (only when a val split exists)
    lv = None
    if dataset.val is not None:
        lv = Loader(dataset.val.hr, dataset.val.lr, opt.scale,
                    threads=opt.threads)
    lt.cropper(RandomCrop(opt.scale))
    if lv is not None:
        # Traced validation uses CenterCrop; otherwise validate on RandomCrop.
        lv.cropper(CenterCrop(opt.scale) if opt.traced_val
                   else RandomCrop(opt.scale))

    if opt.channel == 1:
        # convert data color space to grayscale
        lt.set_color_space('hr', 'L')
        lt.set_color_space('lr', 'L')
        if lv is not None:
            lv.set_color_space('hr', 'L')
            lv.set_color_space('lr', 'L')

    # enter model executor environment
    with model.get_executor(root) as t:
        config = t.query_config(opt)
        if opt.lr_decay:
            config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)
        # lv may be None when the dataset has no validation split.
        t.fit([lt, lv], config)
        # Export while the executor (t) is still active.
        if opt.export:
            t.export(opt.export)