def main(*args):
    """Train, validate, benchmark and run inference for the selected model.

    Every command-line flag is copied into a ``Config``. The parameter files
    are layered on top of it. Then the model and its trainer are built, and
    the trainer fits, benchmarks, infers and (optionally) exports.

    Args:
        *args: Unused. Presumably the argv forwarded by ``tf.app.run`` --
            TODO confirm.

    Raises:
        RuntimeError: if the dataset config file does not exist, if the
            parameter file given with ``-p`` does not exist, or if the merged
            configuration has no section named after ``opt.model``.
    """
    flags = tf.flags.FLAGS
    opt = Config()
    for key in flags:
        opt.setdefault(key, flags.get_flag_value(key, None))
    check_args(opt)
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")

    # Apply a 2-stage (master-slave) configuration. Every master file
    # (parameters/root.*) is applied before any slave (model) file, so a
    # slave always overrides its master. This holds even when the two stages
    # use different formats, e.g. root.yaml together with <model>.json.
    # Both json and yaml are accepted for compatibility.
    suffixes = ('json', 'yaml')
    master_files = [Path(f'parameters/root.{ext}') for ext in suffixes]
    if opt.p:
        # An explicitly requested parameter file must exist; load it once.
        slave_files = [Path(opt.p)]
        if not slave_files[0].exists():
            raise RuntimeError(f"parameter file {opt.p} doesn't exist!")
    else:
        slave_files = [Path(f'parameters/{opt.model}.{ext}')
                       for ext in suffixes]
    for config_file in master_files + slave_files:
        if config_file.exists():
            opt.update(Config(str(config_file)))

    model_params = opt.get(opt.model)
    if model_params is None:
        # Fail with a clear message rather than obscurely in
        # `opt.update(None)` / `**None` below.
        raise RuntimeError(
            f"no parameters for model '{opt.model}' found in the config files!")
    opt.update(model_params)
    model = get_model(opt.model)(**model_params)
    root = '{}/{}'.format(opt.save_dir, model.name)
    if opt.comment:
        root += '_' + opt.comment
    opt.root = root
    verbosity = tf.logging.DEBUG if opt.v else tf.logging.INFO
    # The model selects its own trainer (set via its `_trainer` attribute,
    # read back here through `model.trainer`), so no manual mapping is needed.
    trainer = model.trainer
    train_data, test_data, infer_data = fetch_datasets(data_config_file, opt)
    train_config, test_config, infer_config = init_loader_config(opt)
    test_config.subdir = test_data.name
    infer_config.subdir = 'infer'
    # Presumably records the final merged options -- see `dump`.
    dump(opt)
    with trainer(model, root, verbosity) as t:
        loader = partial(QuickLoader, n_threads=opt.threads)
        train_loader = loader(train_data, 'train', train_config,
                              augmentation=True)
        # Validation: centre crop, one step per epoch, no augmentation.
        val_loader = loader(train_data, 'val', train_config, crop='center',
                            steps_per_epoch=1)
        test_loader = loader(test_data, 'test', test_config)
        infer_loader = loader(infer_data, 'infer', infer_config)
        t.fit([train_loader, val_loader], train_config)
        t.benchmark(test_loader, test_config)
        t.infer(infer_loader, infer_config)
        if opt.export:
            t.export(opt.root + '/exported', opt.freeze)
def main(*args):
    """Train, validate, benchmark and run inference for the selected model.

    Every command-line flag is copied into a ``Config``. The parameter files
    under ``parameters/`` are layered on top of it. Then the model is built,
    its trainer is chosen, and the trainer fits, benchmarks, infers and
    (optionally) exports.

    Args:
        *args: Unused. Presumably the argv forwarded by ``tf.app.run`` --
            TODO confirm.

    Raises:
        RuntimeError: if the dataset config file does not exist, or if the
            merged configuration has no section named after ``opt.model``.
    """
    flags = tf.flags.FLAGS
    # Presumably lets flag values be read even if argv was not parsed
    # (e.g. main() called directly rather than through tf.app.run).
    flags.mark_as_parsed()
    opt = Config()
    for key in flags:
        opt.setdefault(key, flags.get_flag_value(key, None))
    check_args(opt)
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")

    # Apply a 2-stage (master-slave) configuration. Every master file
    # (parameters/root.*) is applied before any slave (parameters/<model>.*)
    # file, so a slave always overrides its master. This holds even when the
    # two stages use different formats (e.g. root.yaml + <model>.json).
    # The stem tuple is built once, so the model name is the one from the
    # flags, before any config file is applied.
    for stem in ('root', opt.model):
        for ext in ('json', 'yaml'):
            config_file = Path('parameters/{}.{}'.format(stem, ext))
            if config_file.exists():
                opt.update(Config(str(config_file)))

    model_params = opt.get(opt.model)
    if model_params is None:
        # Fail with a clear message rather than obscurely in
        # `opt.update(None)` / `**None` below.
        raise RuntimeError(
            "no parameters for model '{}' found in the config files!".format(
                opt.model))
    opt.update(model_params)
    model = get_model(opt.model)(**model_params)
    root = '{}/{}_sc{}_c{}'.format(opt.save_dir, model.name, opt.scale,
                                   opt.channel)
    if opt.comment:
        root += '_' + opt.comment
    opt.root = root
    verbosity = tf.logging.DEBUG if opt.v else tf.logging.INFO
    # Map the model to its trainer manually; every other model uses VSR.
    if opt.model == 'zssr':
        trainer = ZSSR
    elif opt.model == 'frvsr':
        trainer = FRVSR
    else:
        trainer = VSR
    train_data, test_data, infer_data = fetch_datasets(data_config_file, opt)
    train_config, test_config, infer_config = init_loader_config(opt)
    test_config.subdir = test_data.name
    with trainer(model, root, verbosity) as t:
        loader = partial(QuickLoader, n_threads=opt.threads)
        train_loader = loader(train_data, 'train', train_config,
                              augmentation=True)
        # Validation is a fixed centre crop with one step per epoch.
        # Augmenting it would make validation results vary between epochs,
        # so augmentation is applied to the training loader only.
        val_loader = loader(train_data, 'val', train_config, crop='center',
                            steps_per_epoch=1)
        test_loader = loader(test_data, 'test', test_config)
        infer_loader = loader(infer_data, 'infer', infer_config)
        t.fit([train_loader, val_loader], train_config)
        t.benchmark(test_loader, test_config)
        t.infer(infer_loader, infer_config)
        if opt.export:
            t.export(opt.root)