def init_loader_config(opt):
    """Build the loader configs used for training, benchmarking and inference.

    Args:
        opt: Option mapping. It is expanded with ``**opt`` into every config
            and read for ``test_batch``, ``channel``, ``output_color``,
            ``root``, ``output_index``, ``add_custom_callbacks``, ``lr`` and
            ``lr_decay``.

    Returns:
        The tuple ``(train_config, benchmark_config, infer_config)``.
    """
    train_config = Config(
        **opt,
        crop='random',
        feature_callbacks=[],
        label_callbacks=[],
    )
    benchmark_config = Config(
        **opt,
        crop=None,
        feature_callbacks=[],
        label_callbacks=[],
        output_callbacks=[],
    )
    infer_config = Config(
        **opt,
        feature_callbacks=[],
        label_callbacks=[],
        output_callbacks=[],
    )

    benchmark_config.batch = opt.test_batch or 1
    benchmark_config.steps_per_epoch = -1

    # Single-channel models work on gray images, everything else on RGB.
    single_channel = opt.channel == 1
    color_mode = 'gray' if single_channel else 'rgb'
    train_config.convert_to = color_mode
    benchmark_config.convert_to = color_mode

    if single_channel and opt.output_color == 'RGB':
        # Benchmark data is loaded as YUV, reduced to gray for the model,
        # and the output is converted back to RGB before being saved.
        benchmark_config.convert_to = 'yuv'
        benchmark_config.feature_callbacks = (
            train_config.feature_callbacks + [to_gray()])
        benchmark_config.label_callbacks = (
            train_config.label_callbacks + [to_gray()])
        benchmark_config.output_callbacks = [to_rgb()]

    # Both color modes finish by saving the outputs; inference then takes
    # over the benchmark settings.
    benchmark_config.output_callbacks += [
        save_image(opt.root, opt.output_index)]
    infer_config.update(benchmark_config)

    custom_names = opt.add_custom_callbacks
    if custom_names is not None:
        for fn in custom_names:
            # Custom callbacks are resolved by name from this module's globals.
            callback = globals()[fn]
            # NOTE(review): if Config.update copies references, infer_config
            # and benchmark_config share one feature_callbacks list at this
            # point, so the callback would land in it twice -- confirm.
            for config in (train_config, benchmark_config, infer_config):
                config.feature_callbacks += [callback]

    if opt.lr_decay:
        train_config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)

    # modcrop: a boolean specifying whether to crop the edges of images so
    # that their size is divisible by `scale`. It is turned off for inference,
    # which is useful to provide batches with their original shapes.
    infer_config.modcrop = False
    return train_config, benchmark_config, infer_config
def main():
    """Parse the command line, assemble the options and train the model.

    Known command-line flags seed the option set, which is then extended by
    the parameter files under ``Parameters/`` (or the file given by ``opt.p``).
    The selected model is built and fitted on the requested dataset, and is
    exported afterwards when ``opt.export`` is set.

    Raises:
        RuntimeError: If the dataset config file does not exist.
    """
    flags, args = parser.parse_known_args()

    opt = Config()
    for key, value in flags._get_kwargs():
        opt.setdefault(key, value)

    data_config_file = Path(flags.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")

    # Several file extensions are probed for compatibility.
    for _ext in ('json', 'yaml', 'yml'):
        # Two-stage configuration: the shared root file is applied first and
        # the model-specific file second, so the latter can override it.
        model_config_root = Path(f'Parameters/root.{_ext}')
        # opt.p and opt.model are re-read on every pass because a file merged
        # in an earlier pass may have changed them.
        if opt.p:
            model_config_file = Path(opt.p)
        else:
            model_config_file = Path(f'Parameters/{opt.model}.{_ext}')
        if model_config_root.exists():
            opt.update(Config(str(model_config_root)))
        if model_config_file.exists():
            opt.update(Config(str(model_config_file)))

    # Promote the model's own parameter section to the top level of opt.
    model_params = opt.get(opt.model, {})
    opt.update(model_params)
    suppress_opt_by_args(model_params, *args)

    model_cls = get_model(flags.model)
    model = model_cls(**model_params)
    if flags.cuda:
        model.cuda()

    root = f'{flags.save_dir}/{flags.model}'
    if flags.comment:
        root += '_' + flags.comment

    if flags.verbose:
        verbosity = logging.DEBUG
    else:
        verbosity = logging.INFO
    trainer = model.trainer

    datasets = load_datasets(data_config_file)
    dataset = datasets[flags.dataset.upper()]

    train_config = Config(
        crop=opt.train_data_crop,
        feature_callbacks=[],
        label_callbacks=[],
        convert_to='rgb',
        **opt,
    )
    if opt.channel == 1:
        train_config.convert_to = 'gray'
    if opt.lr_decay:
        train_config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)
    train_config.random_val = not opt.traced_val
    train_config.cuda = flags.cuda

    if opt.verbose:
        dump(opt)

    with trainer(model, root, verbosity, opt.pth) as t:
        if opt.seed is not None:
            t.set_seed(opt.seed)
        train_loader = QuickLoader(
            dataset, 'train', train_config, True, flags.thread)
        # The validation loader reuses the training config but overrides the
        # batch size, the crop mode and the number of steps.
        val_loader = QuickLoader(
            dataset, 'val', train_config, False, flags.thread,
            batch=1,
            crop=opt.val_data_crop,
            steps_per_epoch=opt.val_num,
        )
        t.fit([train_loader, val_loader], train_config)
        if opt.export:
            t.export(opt.export)