def init_loader_config(opt):
    """Build the loader configurations for training, benchmark and inference.

    Args:
        opt: option mapping. It is expanded with ``**opt`` into every
            ``Config`` created here, and the following attributes are read:
            ``test_batch``, ``channel``, ``output_color``, ``root``,
            ``output_index``, ``add_custom_callbacks``, ``lr_decay`` and
            ``lr``.

    Returns:
        A ``(train_config, benchmark_config, infer_config)`` tuple.

    Raises:
        KeyError: if a name listed in ``opt.add_custom_callbacks`` is not a
            module-level global of this file.
    """
    train_config = Config(
        **opt, crop='random', feature_callbacks=[], label_callbacks=[])
    benchmark_config = Config(
        **opt, crop=None, feature_callbacks=[], label_callbacks=[],
        output_callbacks=[])
    infer_config = Config(
        **opt, feature_callbacks=[], label_callbacks=[], output_callbacks=[])

    benchmark_config.batch = opt.test_batch or 1
    benchmark_config.steps_per_epoch = -1

    single_channel = opt.channel == 1
    color_mode = 'gray' if single_channel else 'rgb'
    train_config.convert_to = color_mode
    benchmark_config.convert_to = color_mode

    if single_channel and opt.output_color == 'RGB':
        # Single-channel model with RGB output requested: load as YUV, run
        # `to_gray()` on features and labels, and run `to_rgb()` on outputs
        # before anything else.
        benchmark_config.convert_to = 'yuv'
        benchmark_config.feature_callbacks = (
            train_config.feature_callbacks + [to_gray()])
        benchmark_config.label_callbacks = (
            train_config.label_callbacks + [to_gray()])
        benchmark_config.output_callbacks = [to_rgb()]

    # Common tail of both channel branches: save the outputs, then let the
    # inference config take over every benchmark setting.
    benchmark_config.output_callbacks += [
        save_image(opt.root, opt.output_index)]
    infer_config.update(benchmark_config)

    if opt.add_custom_callbacks is not None:
        for fn in opt.add_custom_callbacks:
            # Custom callbacks are referenced by the name of a module global.
            callback = globals()[fn]
            for config in (train_config, benchmark_config, infer_config):
                config.feature_callbacks += [callback]

    if opt.lr_decay:
        train_config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)

    # modcrop: a boolean specifying whether to crop the edges of images so
    # that they are divisible by `scale`. It is switched off for inference so
    # that batches keep their original shapes.
    infer_config.modcrop = False
    return train_config, benchmark_config, infer_config
def main():
    """Evaluate a model on every requested test set.

    Steps, in order:

    1. Parse the command line into a ``Config`` and apply
       ``overwrite_from_env``.
    2. Merge parameter files. For each of the ``json``/``yaml``/``yml``
       extensions, ``Parameters/root.<ext>`` is applied first and then the
       model file (``opt.parameter`` if given, else
       ``Parameters/<model>.<ext>``), so the model file overrides the root.
    3. Instantiate the model with its parameter section, after
       ``suppress_opt_by_args`` has been applied with the unparsed
       command-line arguments.
    4. Resolve ``opt.test`` against the dataset config. If any name is
       unknown, every entry is treated as a file pattern instead and the run
       uses ``infer`` rather than ``benchmark``.
    5. For each test set, build a loader config and run it inside the model's
       trainer context.

    Raises:
        RuntimeError: if the dataset config file does not exist.
    """
    flags, unparsed = parser.parse_known_args()
    opt = Config()
    for key, value in flags._get_kwargs():
        opt.setdefault(key, value)
    overwrite_from_env(opt)

    data_config_file = Path(flags.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")

    for _ext in ('json', 'yaml', 'yml'):  # for compat
        # Two-stage (master/slave) configuration: the root file is the
        # master and can be overridden by the model-specific file.
        root_params = Path('Parameters/root.{}'.format(_ext))
        if opt.parameter:
            model_specific = Path(opt.parameter)
        else:
            model_specific = Path('Parameters/{}.{}'.format(opt.model, _ext))
        for candidate in (root_params, model_specific):
            if candidate.exists():
                opt.update(Config(str(candidate)))

    model_params = opt.get(opt.model, {})
    suppress_opt_by_args(model_params, *unparsed)
    opt.update(model_params)
    model_factory = get_model(opt.model)
    model = model_factory(**model_params)
    if opt.cuda:
        model.cuda()

    root = f'{opt.save_dir}/{opt.model}'
    if opt.comment:
        root += '_' + opt.comment
    if opt.verbose:
        verbosity = logging.DEBUG
    else:
        verbosity = logging.INFO
    trainer = model.trainer

    datasets = load_datasets(data_config_file)
    try:
        test_datas = [datasets[name.upper()] for name in opt.test]
    except KeyError:
        # At least one entry is not a known dataset name: treat every entry
        # as a file pattern and build an ad-hoc dataset for it.
        run_benchmark = False
        test_datas = []
        for pattern in opt.test:
            custom = Dataset(test=_glob_absolute_pattern(pattern),
                             test_pair=_glob_absolute_pattern(pattern),
                             mode='pil-image1',
                             modcrop=False,
                             parser='custom_pairs')
            # Name the dataset after the nearest existing directory of the
            # pattern; stop climbing once the path is its own parent.
            folder = Path(pattern)
            while not folder.is_dir() and folder.parent != folder:
                folder = folder.parent
            custom.name = folder.stem
            test_datas.append(custom)
    else:
        run_benchmark = True

    if opt.verbose:
        dump(opt)

    for test_data in test_datas:
        loader_config = Config(convert_to='rgb',
                               feature_callbacks=[],
                               label_callbacks=[],
                               output_callbacks=[],
                               **opt)
        loader_config.batch = 1
        loader_config.subdir = test_data.name
        loader_config.output_callbacks += [
            save_image(root, opt.output_index, opt.auto_rename)
        ]
        if opt.channel == 1:
            loader_config.convert_to = 'gray'
            if opt.output_color == 'RGB':
                # Load as YUV, run `to_gray()` on features and labels, and
                # put `to_rgb()` ahead of the image saver in the output
                # callbacks.
                loader_config.convert_to = 'yuv'
                loader_config.feature_callbacks = [to_gray()]
                loader_config.label_callbacks = [to_gray()]
                loader_config.output_callbacks.insert(0, to_rgb())

        with trainer(model, root, verbosity, opt.pth) as session:
            if opt.seed is not None:
                session.set_seed(opt.seed)
            loader = QuickLoader(test_data, 'test', loader_config,
                                 n_threads=opt.thread)
            loader_config.epoch = opt.epoch
            run = session.benchmark if run_benchmark else session.infer
            run(loader, loader_config)