def main(*args, **kwargs):
    """Resolve user callbacks from ``custom_api`` and dispatch on ``FLAGS.mode``.

    Callback names are gathered from ``FLAGS.f``, ``FLAGS.f2`` and ``FLAGS.f3``
    (in that order).  Anything from the first ``'#'`` onward is stripped from a
    name.  Each remaining name is looked up in the ``custom_api`` module, and
    the result is forwarded to ``Run.run`` as a keyword argument.

    Args:
      args: ``args[0][1:]`` is forwarded positionally to the run/eval entry
        points (presumably argv without the program name -- confirm with the
        caller).
      kwargs: unused.

    Returns:
      The result of ``Run.run`` in mode ``'run'``, or of the matching
      ``evaluate`` in mode ``'eval'``.  Otherwise ``None``.

    Raises:
      KeyError: if a requested callback is not defined in ``custom_api``.
    """
    additional_functions = {}
    # Copy into a fresh list: ``FLAGS.f or []`` yields the flag's own list
    # object, and ``+=`` on it would append f2/f3 into the parsed flag value.
    callbacks = list(FLAGS.f or [])
    callbacks += FLAGS.f2 or []
    callbacks += FLAGS.f3 or []
    if callbacks:
        m = import_module('custom_api')
        for fn_name in callbacks:
            # Text after '#' is dropped; without '#' the name is unchanged.
            fn_name = fn_name.split('#')[0]
            try:
                additional_functions[fn_name] = m.__dict__[fn_name]
            except KeyError:
                raise KeyError(
                    "Function [{}] couldn't be found in 'custom_api.py'".
                    format(fn_name))
    if FLAGS.mode == 'run':
        return Run.run(*args[0][1:], **additional_functions)
    if FLAGS.mode == 'eval':
        if FLAGS.checkpoint_dir:
            return EvalModelCheckpoint.evaluate(*args[0][1:])
        elif FLAGS.input_dir:
            return EvalDataDirectory.evaluate(*args[0][1:])
        print(("In mode 'eval', parse either '--checkpoint_dir' with '--model'"
               " or '--input_dir' to evaluate models, see details --helpfull"))
def main(*args, **kwargs):
    """Look up requested callbacks in ``custom_api`` and start ``Run.run``.

    Callback names come from ``FLAGS.f``, ``FLAGS.f2`` and ``FLAGS.f3``, in
    that order.  A name containing ``'#'`` is truncated at the first ``'#'``.
    Each resolved function is passed to ``Run.run`` as a keyword argument named
    after the (truncated) callback name.

    Args:
      args: ``args[0][1:]`` is forwarded positionally to ``Run.run``
        (presumably argv minus the program name -- confirm with the caller).
      kwargs: unused.

    Returns:
      Whatever ``Run.run`` returns.

    Raises:
      KeyError: if a requested callback is not defined in ``custom_api``.
    """
    requested = []
    for flag_value in (FLAGS.f, FLAGS.f2, FLAGS.f3):
        requested += flag_value or []

    extra_fns = {}
    if requested:
        api_namespace = import_module('custom_api').__dict__
        for entry in requested:
            try:
                if '#' in entry:
                    entry = entry.split('#')[0]
                extra_fns[entry] = api_namespace[entry]
            except KeyError:
                raise KeyError(
                    "Function [{}] couldn't be found in 'custom_api.py'".format(entry))

    forwarded = args[0][1:]
    return Run.run(*forwarded, **extra_fns)
def main(*args, **kwargs):
    """Resolve ``--add_custom_callbacks`` and dispatch on ``FLAGS.mode``.

    Every name in ``FLAGS.add_custom_callbacks`` must exist in the
    ``custom_api`` module.  The resolved functions are forwarded to
    ``Run.run`` as keyword arguments.  In ``'eval'`` mode, a checkpoint
    directory takes precedence over an input directory.  If neither is given,
    a usage hint is printed.

    Args:
      args, kwargs: unused.

    Returns:
      The result of the selected run/evaluate entry point, otherwise ``None``.

    Raises:
      KeyError: if a requested callback is not defined in ``custom_api``.
    """
    custom_fns = {}
    if FLAGS.add_custom_callbacks:
        api = import_module('custom_api')
        for name in FLAGS.add_custom_callbacks:
            try:
                custom_fns[name] = api.__dict__[name]
            except KeyError:
                raise KeyError(f"Function [{name}] couldn't be found in 'custom_api.py'")

    if FLAGS.mode == 'run':
        return Run.run(**custom_fns)
    if FLAGS.mode != 'eval':
        return None
    if FLAGS.checkpoint_dir:
        return EvalModelCheckpoint.evaluate()
    if FLAGS.input_dir:
        return EvalDataDirectory.evaluate()
    print(("In mode 'eval', parse either '--checkpoint_dir' with '--model'"
           " or '--input_dir' to evaluate models, see details --helpfull"))
def _encode_rgb_image(array, fmt='png', **save_kwargs):
    """Encode an RGB ``array`` with PIL into in-memory bytes of format ``fmt``."""
    with io.BytesIO() as fp:
        Image.fromarray(array, 'RGB').save(fp, format=fmt, **save_kwargs)
        return fp.getvalue()


def main(*args, **kwargs):
    """Export loader batches as (HR, LR, name, post-processed LR) TFRecords.

    All ``tf.flags.FLAGS`` values are copied into a ``Config``.  The training
    split described by ``--data_config`` is iterated through a ``QuickLoader``.
    For every sample, one record is written to
    ``<save_dir>/<dataset>.tfrecords`` containing:

    * ``image/hr``   -- PNG of the label (usually HR) image,
    * ``image/lr``   -- PNG of the feature (usually LR) image,
    * ``name``       -- the three squeezed name parts joined by ``_`` (bytes),
    * ``image/post`` -- ``process(lr, crf_matrix, sigma)`` encoded as JPEG when
      ``--jpeg_quality`` is set, otherwise as PNG.

    Args:
      args, kwargs: unused.

    Raises:
      RuntimeError: if the dataset config file does not exist.
    """
    flags = tf.flags.FLAGS
    check_args(flags)
    opt = Config()
    for key in flags:
        opt.setdefault(key, flags.get_flag_value(key, None))
    opt.steps_per_epoch = opt.num
    # set random seed at first
    np.random.seed(opt.seed)

    # Validate inputs before touching the filesystem, so a bad config does not
    # leave an empty .tfrecords file behind.
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")
    crf_matrix = np.load(opt.crf) if opt.crf else None
    sigma = (opt.sigma[0], opt.sigma[1])

    # check output dir
    output_dir = Path(flags.save_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # init loader config
    train_data, _, _ = Run.fetch_datasets(data_config_file, opt)
    train_config, _, _ = Run.init_loader_config(opt)
    loader = QuickLoader(train_data, opt.method, train_config,
                         n_threads=opt.threads, augmentation=opt.augment)
    it = loader.make_one_shot_iterator(opt.memory_limit, shuffle=True)

    writer = tf.io.TFRecordWriter(
        str(output_dir / "{}.tfrecords".format(opt.dataset)))
    try:
        with tqdm.tqdm(it, unit='batch', ascii=True) as r:
            for items in r:
                # label is usually HR image, feature is usually LR image
                label_batch, feature_batch, name_batch = items[:3]
                samples = zip(np.split(label_batch, label_batch.shape[0]),
                              np.split(feature_batch, feature_batch.shape[0]),
                              np.split(name_batch, name_batch.shape[0]))
                for hr, lr, name in samples:
                    hr = np.squeeze(hr)
                    lr = np.squeeze(lr)
                    name = np.squeeze(name)
                    hr_png = _encode_rgb_image(hr)
                    lr_png = _encode_rgb_image(lr)
                    lr_post = process(lr, crf_matrix, sigma)
                    if opt.jpeg_quality:
                        post_bytes = _encode_rgb_image(
                            lr_post, 'jpeg', quality=opt.jpeg_quality)
                    else:
                        post_bytes = _encode_rgb_image(lr_post)
                    record_name = "{}_{}_{}".format(*name).encode()
                    make_tensor_label_records(
                        [hr_png, lr_png, record_name, post_bytes],
                        ["image/hr", "image/lr", "name", "image/post"],
                        writer)
    finally:
        # Close explicitly (also on error) so buffered records are flushed and
        # the file handle is released.
        writer.close()