def main(entry: str, ddf_file):
    """Print a consistency report for one dataset described in a DDF file.

    For each of the ``train``, ``val`` and ``test`` splits of the dataset,
    report how many ground-truth (``hr``) and degraded (``lr``) items are
    found. When both sides are present, flag a count mismatch. For video
    data, also flag clips whose ``frames`` differ.

    Args:
        entry: Dataset name. It is upper-cased before lookup.
        ddf_file: Dataset description file, passed to ``load_datasets``.

    Raises:
        KeyError: If ``entry`` is not defined in the DDF.
    """
    entry = entry.upper()
    all_data = load_datasets(ddf_file)
    if entry not in all_data:
        raise KeyError(f"The dataset `{entry}` not found in the DDF")
    data = all_data[entry]
    print(f"Dataset: {data.name}")

    def _check(name: str):
        """Report on split ``name`` ('train' / 'val' / 'test') of ``data``."""
        print(f"\n========= CHECKING {name} =========\n")
        if name not in data or data[name] is None:
            print(f"{data.name} doesn't contain any {name} data.")
            return
        print(f"Found `{name}` set in \"{data.name}\":")
        _hr = data[name].hr
        _lr = data[name].lr
        # A split may lack ground truth (hr is None, e.g. LR-only test data).
        # Never dereference a missing side; fall back to lr to decide "video".
        probe = _hr if _hr is not None else _lr
        video_type = probe is not None and probe.as_video
        if video_type:
            print(f"\"{data.name}\" is video data")
        if _hr is not None:
            _hr = _hr.compile()
            print(f"Found {len(_hr)} ground-truth {name} data")
        if _lr is not None:
            _lr = _lr.compile()
            print(f"Found {len(_lr)} custom degraded {name} data")
        # Cross-checks only make sense when both ground truth and degraded
        # data exist.
        if _hr is None or _lr is None:
            return
        if len(_hr) != len(_lr):
            print(" [E] Ground-truth data and degraded data quantity not matched!!")
        elif video_type:
            for x, y in zip(_hr, _lr):
                if x.frames != y.frames:
                    print(f" [E] Video clip {x.name}|{y.name} quantity not matched!!")

    for split in ('train', 'val', 'test'):
        _check(split)
def main():
    """Benchmark or run inference with a model on one or more test datasets.

    Steps:

    1. Copy the command-line flags into a fresh ``Config`` and apply
       ``overwrite_from_env``.
    2. Merge model parameters from a parameter file.
    3. Build the model, optionally move it to CUDA, and load ``opt.pretrain``.

    Every name in ``opt.test`` is looked up, upper-cased, in the dataset
    description file. If any lookup fails, ``opt.test`` is passed to
    ``Dataset`` as LR-only input and run as a single dataset named
    ``'infer'``.

    Datasets that have ground truth (``test.hr``) are sent to
    ``t.benchmark``; the others go to ``t.infer``. In both cases results are
    handed to ``save_inference_images`` under
    ``<save_dir>/<model>[_<comment>]/<dataset name>``.

    Raises:
        FileNotFoundError: If the dataset config file does not exist.
    """
    flags, args = parser.parse_known_args()
    opt = Config()
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)
    overwrite_from_env(opt)

    data_config_file = Path(flags.data_config)
    if not data_config_file.exists():
        raise FileNotFoundError("dataset config file doesn't exist!")

    # An explicit parameter file is read once. Otherwise every existing
    # par/<backend>/<model>.{json,yaml,yml} is merged in that order (for
    # compat), so later formats override earlier ones.
    if opt.parameter:
        candidates = [Path(opt.parameter)]
    else:
        candidates = [Path(f'par/{BACKEND}/{opt.model}.{ext}')
                      for ext in ('json', 'yaml', 'yml')]
    for model_config_file in candidates:
        if model_config_file.exists():
            opt.update(compat_param(Config(str(model_config_file))))

    # Model parameters live under the model's name; extra CLI args override.
    model_params = opt.get(opt.model, {})
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)

    model = get_model(opt.model)(**model_params)
    if opt.cuda:
        model.cuda()
    if opt.pretrain:
        model.load(opt.pretrain)

    root = f'{opt.save_dir}/{opt.model}'
    if opt.comment:
        root += '_' + opt.comment
    root = Path(root)

    datasets = load_datasets(data_config_file)
    try:
        test_datas = [datasets[t.upper()] for t in opt.test] if opt.test else []
    except KeyError:
        # Not DDF names: presumably raw input paths for Dataset — LR only,
        # no ground truth, so these are inferred, not benchmarked.
        test_datas = [Config(test=Config(lr=Dataset(*opt.test)), name='infer')]
        if opt.video:
            test_datas[0].test.lr.use_like_video_()

    with model.get_executor(root) as t:
        for data in test_datas:
            run_benchmark = data.test.hr is not None
            if run_benchmark:
                ld = Loader(data.test.hr, data.test.lr, opt.scale,
                            threads=opt.threads)
            else:
                ld = Loader(data.test.hr, data.test.lr, threads=opt.threads)
            if opt.channel == 1:
                # convert data color space to grayscale
                ld.set_color_space('hr', 'L')
                ld.set_color_space('lr', 'L')
            config = t.query_config(opt)
            config.inference_results_hooks = [
                save_inference_images(root / data.name, opt.output_index,
                                      opt.auto_rename)
            ]
            if run_benchmark:
                t.benchmark(ld, config)
            else:
                t.infer(ld, config)
        if opt.export:
            t.export(opt.export)
def main():
    """Train a model on the dataset selected by ``opt.dataset``.

    Steps:

    1. Copy the command-line flags into a fresh ``Config``.
    2. Merge model parameters from a parameter file.
    3. Build the model, optionally move it to CUDA, and load ``opt.pretrain``.
    4. Build a training loader, with image augmentation and random crops.
    5. If the dataset has a ``val`` split, build a validation loader. It uses
       center crops when ``opt.traced_val`` is set and random crops otherwise.
    6. Fit inside the model's executor.

    Raises:
        FileNotFoundError: If the dataset config file does not exist.
    """
    flags, args = parser.parse_known_args()
    opt = Config()  # An EasyDict object
    # overwrite flag values into opt object
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)

    # fetch dataset descriptions
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise FileNotFoundError("dataset config file doesn't exist!")

    # An explicit parameter file is read once. Otherwise every existing
    # par/<backend>/<model>.{json,yaml,yml} is merged in that order (for
    # compat), so later formats override earlier ones.
    if opt.parameter:
        candidates = [Path(opt.parameter)]
    else:
        candidates = [Path(f'par/{BACKEND}/{opt.model}.{ext}')
                      for ext in ('json', 'yaml', 'yml')]
    for model_config_file in candidates:
        if model_config_file.exists():
            opt.update(compat_param(Config(str(model_config_file))))

    # Model parameters live under the model's name; extra CLI args override.
    model_params = opt.get(opt.model, {})
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)

    model = get_model(opt.model)(**model_params)
    if opt.cuda:
        model.cuda()
    if opt.pretrain:
        model.load(opt.pretrain)

    root = f'{opt.save_dir}/{opt.model}'
    if opt.comment:
        root += '_' + opt.comment

    dataset = load_datasets(data_config_file, opt.dataset)

    # Training loader: augmented, randomly cropped.
    lt = Loader(dataset.train.hr, dataset.train.lr, opt.scale, threads=opt.threads)
    lt.image_augmentation()

    # Validation loader is optional.
    lv = None
    if dataset.val is not None:
        lv = Loader(dataset.val.hr, dataset.val.lr, opt.scale, threads=opt.threads)

    lt.cropper(RandomCrop(opt.scale))
    if lv is not None:
        # Traced validation uses deterministic center crops.
        lv.cropper(CenterCrop(opt.scale) if opt.traced_val else RandomCrop(opt.scale))

    if opt.channel == 1:
        # convert data color space to grayscale
        for loader in (lt, lv):
            if loader is not None:
                loader.set_color_space('hr', 'L')
                loader.set_color_space('lr', 'L')

    with model.get_executor(root) as t:
        config = t.query_config(opt)
        if opt.lr_decay:
            config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)
        t.fit([lt, lv], config)
        if opt.export:
            t.export(opt.export)