def fetch_datasets(data_config_file, opt):
    """Pick the train, test and infer datasets selected by `opt`.

    Args:
        data_config_file: forwarded unchanged to `load_datasets`.
        opt: options object; `opt.dataset`, `opt.test`, `opt.infer` and
            `opt.mode` are read.

    Returns:
        A `(dataset, test_data, infer_data)` tuple. `test_data` falls back
        to `dataset` and `infer_data` falls back to `test_data` when the
        corresponding option is falsy.
    """
    all_datasets = load_datasets(data_config_file)
    dataset = all_datasets[opt.dataset.upper()]
    test_data = all_datasets[opt.test.upper()] if opt.test else dataset

    if opt.infer:
        infer_dir = Path(opt.infer)
        if infer_dir.exists():
            # `opt.infer` is a real path: infer the file itself, or every
            # entry of the directory.
            if infer_dir.is_file():
                images = [str(infer_dir)]
            else:
                images = list(infer_dir.glob('*'))
                if not images:
                    # Materialise the fallback so `Dataset` always gets a
                    # list, never a one-shot generator.
                    images = list(infer_dir.iterdir())
            infer_data = Dataset(infer=images, mode='pil-image1',
                                 modcrop=False)
        else:
            # Not a path on disk: treat it as the name of a known dataset.
            infer_data = all_datasets[opt.infer.upper()]
    else:
        infer_data = test_data

    if opt.mode:
        # Override the file mode on every returned dataset.
        dataset.mode = opt.mode
        test_data.mode = opt.mode
        infer_data.mode = opt.mode
    return dataset, test_data, infer_data
def test_no_shuffle(self):
    """Two iterators created with the third argument `False` must agree.

    Both iterators are expected to yield 16 items, and the `'hr'` entries
    of corresponding items must match within 1e-4.
    """
    d = Dataset('data/').include('*.png')
    data = d.compile()
    ld = Loader(data, data, threads=4)
    ld.cropper(CenterCrop(1))
    itr1 = ld.make_one_shot_iterator([1, 3, 16, 16], -1, False)
    ret1 = list(itr1)
    self.assertEqual(len(ret1), 16)
    self.assert_psnr(ret1)
    itr2 = ld.make_one_shot_iterator([1, 3, 16, 16], -1, False)
    ret2 = list(itr2)
    self.assertEqual(len(ret2), 16)
    self.assert_psnr(ret2)
    for x, y in zip(ret1, ret2):
        # Compare the magnitude of the difference: a signed comparison
        # would let arbitrarily large negative differences pass.
        self.assertTrue(np.all(np.abs(x['hr'] - y['hr']) < 1e-4))
def test_auto_deduce_shape(self):
    """Iterate `set5` with shape `[1, -1, -1, -1]` and expect 5 items.

    The -1 entries presumably ask the loader to deduce the shape from
    the data (per the test name) -- confirm against `Loader`.
    """
    source = Dataset('data').include_reg('set5')
    loader = Loader(source, scale=1)
    batches = list(loader.make_one_shot_iterator([1, -1, -1, -1], -1))
    self.assertEqual(len(batches), 5)
    self.assert_psnr(batches)
def test_memory_limit():
    """Iterate an augmented, randomly cropped loader twice with a limit.

    Each iterator is built with 10 steps and `data.capacity / 2` as the
    fourth argument (presumably a memory limit, per the test name) and
    must yield exactly 10 items.
    """
    d = Dataset('data/')
    d = d.include('*.png')
    data = d.compile()
    ld = Loader(data, data, threads=4)
    ld.cropper(RandomCrop(1))
    ld.image_augmentation()
    itr = ld.make_one_shot_iterator([4, 3, 16, 16], 10, True,
                                    data.capacity / 2)
    ret = list(itr)
    # Value equality, not identity: `is` on an int literal only works by
    # accident of small-int caching and raises a SyntaxWarning on 3.8+.
    assert len(ret) == 10
    assert_psnr(ret)
    itr = ld.make_one_shot_iterator([4, 3, 16, 16], 10, True,
                                    data.capacity / 2)
    ret = list(itr)
    assert len(ret) == 10
    assert_psnr(ret)
def test_no_shuffle_limit(self):
    """Run the limited iterator twice with the third argument `False`.

    Each pass passes `capacity / 2` as the fourth argument and must
    yield exactly 10 items.
    """
    compiled = Dataset('data/').include('*.png').compile()
    loader = Loader(compiled, compiled, threads=4)
    loader.cropper(RandomCrop(1))
    loader.image_augmentation()
    # Two identical passes over freshly created iterators.
    for _ in range(2):
        batches = list(loader.make_one_shot_iterator(
            [4, 3, 16, 16], 10, False, compiled.capacity / 2))
        self.assertEqual(len(batches), 10)
        self.assert_psnr(batches)
def fetch_datasets(data_config_file, opt):
    """Pick the train, test and infer datasets selected by `opt`.

    Args:
        data_config_file: forwarded unchanged to `load_datasets`.
        opt: options object; `opt.dataset`, `opt.test`, `opt.infer` and
            `opt.cifar` are read.

    Returns:
        A `(dataset, test_data, infer_data)` tuple. When `opt.cifar` is
        truthy the first two entries are the same CIFAR-10 backed dataset.
    """
    all_datasets = load_datasets(data_config_file)
    dataset = all_datasets[opt.dataset.upper()]
    test_data = all_datasets[opt.test.upper()] if opt.test else dataset

    if opt.infer:
        infer_dir = Path(opt.infer)
        if infer_dir.exists():
            # `opt.infer` is a real path: infer the file itself, or every
            # entry of the directory.
            if infer_dir.is_file():
                images = [str(infer_dir)]
            else:
                images = list(infer_dir.glob('*'))
                if not images:
                    # Materialise the fallback so `Dataset` always gets a
                    # list, never a one-shot generator.
                    images = list(infer_dir.iterdir())
            infer_data = Dataset(infer=images, mode='pil-image1',
                                 modcrop=False)
        else:
            # Not a path on disk: treat it as the name of a known dataset.
            infer_data = all_datasets[opt.infer.upper()]
    else:
        infer_data = test_data

    # TODO temp use, delete if not need
    if opt.cifar:
        cifar_data, cifar_test = tf.keras.datasets.cifar10.load_data()
        # Build a fresh Dataset from the selected one, then swap in the
        # first element of each CIFAR-10 split as in-memory arrays.
        dataset = Dataset(**dataset)
        dataset.mode = 'numpy'
        dataset.train = [cifar_data[0]]
        dataset.val = [cifar_test[0]]
        dataset.test = [cifar_test[0]]
        return dataset, dataset, infer_data
    return dataset, test_data, infer_data
def test_load_empty_data():
    """Check loader behaviour for a dataset built from `'not-found'`.

    With steps of -1 the iterator is expected to be empty; with 10 steps
    it is expected to yield 10 items whose `'hr'`, `'lr'` and `'name'`
    entries are all falsy.
    """
    d = Dataset('not-found')
    ld = Loader(d, scale=1)
    itr = ld.make_one_shot_iterator([1, -1, -1, -1], -1)
    # Value equality, not identity: `is` on an int literal only works by
    # accident of small-int caching and raises a SyntaxWarning on 3.8+.
    assert len(list(itr)) == 0
    itr = ld.make_one_shot_iterator([4, 3, 16, 16], 10)
    ret = list(itr)
    assert len(ret) == 10
    assert not ret[0]['hr']
    assert not ret[0]['lr']
    assert not ret[0]['name']
def test_load_empty_data(self):
    """Check loader behaviour for a dataset built from `'not-found'`.

    Steps of -1 must give an empty iterator; 10 steps must give 10 items
    whose first item has falsy `'hr'`, `'lr'` and `'name'` entries.
    """
    loader = Loader(Dataset('not-found'), scale=1)

    unbounded = loader.make_one_shot_iterator([1, -1, -1, -1], -1)
    self.assertEqual(len(list(unbounded)), 0)

    batches = list(loader.make_one_shot_iterator([4, 3, 16, 16], 10))
    self.assertEqual(len(batches), 10)
    first = batches[0]
    for key in ('hr', 'lr', 'name'):
        self.assertFalse(first[key])
def load_folder(path):
    """Wrap the entries of directory `path` in a `Dataset`.

    Args:
        path: location of the folder to load.

    Returns:
        `Dataset(test=images)` where `images` is the sorted result of
        globbing `*` in `path`, or the plain directory listing when the
        glob comes back empty.

    Raises:
        ValueError: if `path` does not exist.
    """
    folder = Path(path)
    if not folder.exists():
        raise ValueError("--input_dir can't be found")
    images = sorted(folder.glob('*'))
    if not images:
        images = list(folder.iterdir())
    return Dataset(test=images)
def test_simplest_loader(self):
    """Iterate `data/set5_x2` with two shapes, 10 steps each.

    For both shapes the iterator must report a length of 10 and yield
    10 items; the items from the second shape are passed to
    `assert_psnr`.
    """
    loader = Loader(Dataset('data/set5_x2'), scale=2, threads=4)
    batches = None
    for shape in ([4, 3, 4, 4], [4, 3, 16, 16]):
        iterator = loader.make_one_shot_iterator(shape, 10, True)
        self.assertEqual(len(iterator), 10)
        batches = list(iterator)
        self.assertEqual(len(batches), 10)
    # Only the output of the last (16x16) pass is checked.
    self.assert_psnr(batches)
def test_simplest_loader():
    """Iterate `data/set5_x2` with two shapes, 10 steps each.

    For both shapes the iterator must report a length of 10 and yield
    10 items; the items from the second shape are passed to
    `assert_psnr`.
    """
    d = Dataset('data/set5_x2')
    ld = Loader(d, scale=2, threads=4)
    itr = ld.make_one_shot_iterator([4, 3, 4, 4], 10, True)
    # Value equality, not identity: `is` on an int literal only works by
    # accident of small-int caching and raises a SyntaxWarning on 3.8+.
    assert len(itr) == 10
    ret = list(itr)
    assert len(ret) == 10
    itr = ld.make_one_shot_iterator([4, 3, 16, 16], 10, True)
    assert len(itr) == 10
    ret = list(itr)
    assert len(ret) == 10
    assert_psnr(ret)
def _handle_path(path, sess): if path.endswith('.npz'): f = np.load(path) m, s = f['mu'][:], f['sigma'][:] f.close() else: path = pathlib.Path(path) files = list(path.glob('*.jpg')) + list(path.glob('*.png')) + list( path.glob('*.bmp')) images = Dataset(train=files, patch_size=48, strides=48) loader = BatchLoader(50, images, 'train', 1, convert_to='RGB') x = np.concatenate([img[0] for img in loader]) m, s = calculate_activation_statistics(x, sess) return m, s
def test_complex_loader(self):
    """Load separately compiled `hr/xiuxian` and `lr/xiuxian` video data.

    The loader is augmented and randomly cropped, then iterated for 10
    steps with `shuffle=True`; exactly 10 items are expected.
    """
    hr = Dataset('data').use_like_video().include_reg('hr/xiuxian').compile()
    lr = Dataset('data').use_like_video().include_reg('lr/xiuxian').compile()
    loader = Loader(hr, lr, threads=4)
    loader.image_augmentation()
    loader.cropper(RandomCrop(2))
    batches = list(loader.make_one_shot_iterator(
        [4, 3, 3, 16, 16], 10, shuffle=True))
    self.assertEqual(len(batches), 10)
    self.assert_psnr(batches)
def predict(self, files, mode='pil-image1', depth=1, **kwargs):
    r"""Run the restored model over `files` and feed the callbacks.

    Args:
        files: a frame path or a list of frame paths used as inputs.
        mode: file format. `pil-image1` for PIL supported images, or
          `NV12/YV12/RGB` for raw data.
        depth: length of the image sequence. 1 for images, >1 for videos.
        **kwargs: forwarded to both `Dataset` and `QuickLoader`.

    Returns:
        None. Results are only passed to `self.output_callbacks`.
    """
    session = tf.get_default_session()
    ckpt_last = self._restore_model(session)
    frame_paths = [Path(file) for file in to_list(files)]
    data = Dataset(test=frame_paths, mode=mode, depth=depth, modcrop=False,
                   **kwargs)
    loader = QuickLoader(1, data, 'test', self.model.scale, -1, crop=None,
                         **kwargs)
    it = loader.make_one_shot_iterator()

    # Nothing to predict: leave without printing the banner.
    if not len(it):
        return

    banner = '==================================='
    print(banner)
    print(f'Predicting model: {self.model.name} by {ckpt_last}')
    print(banner)

    for img in tqdm.tqdm(it, 'Infer', ascii=True):
        feature = img[self.fi]
        label = img[self.li]  # fetched but not used below
        name = img[-1]
        tf.logging.debug('output: ' + str(name))
        for feature_fn in self.feature_callbacks:
            feature = feature_fn(feature, name=name)
        outputs = self.model.test_batch(feature, None)
        # Output callbacks receive the untouched input and label from `img`.
        for output_fn in self.output_callbacks:
            outputs = output_fn(outputs, input=img[self.fi],
                                label=img[self.li],
                                mode=loader.color_format, name=name)
def main():
    """Evaluate a model on the datasets or file patterns in `flags.test`.

    Command-line flags are merged with the model configuration files, the
    model is built, and each test dataset is run through the trainer:
    `benchmark` when every entry of `flags.test` names a known dataset,
    `infer` when the entries are treated as file patterns.

    Raises:
        RuntimeError: if the dataset config file does not exist.
    """
    flags, args = parser.parse_known_args()
    opt = Config()
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)
    data_config_file = Path(flags.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")
    for _ext in ('json', 'yaml', 'yml'):  # for compat
        # apply a 2-stage (or master-slave) configuration, master can be
        # override by slave
        model_config_root = Path('Parameters/root.{}'.format(_ext))
        if opt.p:
            model_config_file = Path(opt.p)
        else:
            model_config_file = Path(
                'Parameters/{}.{}'.format(opt.model, _ext))
        if model_config_root.exists():
            opt.update(Config(str(model_config_root)))
        if model_config_file.exists():
            opt.update(Config(str(model_config_file)))

    model_params = opt.get(opt.model, {})
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)
    model = get_model(flags.model)(**model_params)
    if flags.cuda:
        model.cuda()
    root = f'{flags.save_dir}/{flags.model}'
    if flags.comment:
        root += '_' + flags.comment
    verbosity = logging.DEBUG if flags.verbose else logging.INFO
    trainer = model.trainer

    datasets = load_datasets(data_config_file)
    try:
        test_datas = [datasets[t.upper()] for t in flags.test]
        run_benchmark = True
    except KeyError:
        # At least one entry is not a known dataset name: treat every
        # entry as a file pattern and run plain inference instead.
        test_datas = []
        for pattern in flags.test:
            test_data = Dataset(test=_glob_absolute_pattern(pattern),
                                mode='pil-image1', modcrop=False)
            # Name the dataset after the nearest existing directory of
            # this pattern. `flags.test` is the whole collection, so the
            # path must be built from the current `pattern`.
            father = Path(pattern)
            while not father.is_dir():
                if father.parent == father:
                    break
                father = father.parent
            test_data.name = father.stem
            test_datas.append(test_data)
        run_benchmark = False

    if opt.verbose:
        dump(opt)
    for test_data in test_datas:
        loader_config = Config(convert_to='rgb',
                               feature_callbacks=[], label_callbacks=[],
                               output_callbacks=[], **opt)
        loader_config.batch = 1
        loader_config.subdir = test_data.name
        loader_config.output_callbacks += [
            save_image(root, flags.output_index, flags.auto_rename)]
        if opt.channel == 1:
            loader_config.convert_to = 'gray'

        with trainer(model, root, verbosity, flags.pth) as t:
            if flags.seed is not None:
                t.set_seed(flags.seed)
            loader = QuickLoader(test_data, 'test', loader_config,
                                 n_threads=flags.thread)
            loader_config.epoch = flags.epoch
            if run_benchmark:
                t.benchmark(loader, loader_config)
            else:
                t.infer(loader, loader_config)
def test_image_data():
    """Compile `data/set5_x2` and check its length and capacity."""
    d = Dataset('data/set5_x2')
    data = d.compile()
    # Value equality, not identity: `is` on an int literal only works by
    # accident of small-int caching and raises a SyntaxWarning on 3.8+.
    assert len(data) == 5
    assert data.capacity == 983040
def test_include_exclude():
    """Check dataset sizes after include / include_reg / exclude filters."""
    d = Dataset('data')
    d.include_('xiuxian*')
    data1 = d.compile()
    d = Dataset('data')
    d.include_reg_('set5')
    data2 = d.compile()
    d = Dataset('data').include_reg('set5').exclude('png')
    data3 = d.compile()
    # Value equality, not identity: `is` on an int literal only works by
    # accident of small-int caching and raises a SyntaxWarning on 3.8+.
    assert len(data1) == 6
    assert len(data2) == 5
    assert len(data3) == 0
def test_video_data(self):
    """Compile `data/video/custom_pair` as video and expect 2 entries."""
    compiled = Dataset('data/video/custom_pair').use_like_video().compile()
    self.assertEqual(len(compiled), 2)
def test_video_data():
    """Compile `data/video/custom_pair` as video and expect 2 entries."""
    d = Dataset('data/video/custom_pair').use_like_video()
    data = d.compile()
    # Value equality, not identity: `is` on an int literal only works by
    # accident of small-int caching and raises a SyntaxWarning on 3.8+.
    assert len(data) == 2
def test_include_exclude(self):
    """Check dataset sizes after include / include_reg / exclude filters."""
    by_glob = Dataset('data')
    by_glob.include_('xiuxian*')
    glob_data = by_glob.compile()

    by_regex = Dataset('data')
    by_regex.include_reg_('set5')
    regex_data = by_regex.compile()

    excluded_data = Dataset('data').include_reg('set5').exclude('png').compile()

    expectations = ((glob_data, 6), (regex_data, 5), (excluded_data, 0))
    for compiled, expected in expectations:
        self.assertEqual(len(compiled), expected)
def test_multi_url(self):
    """Compile a dataset built from two locations and expect 8 entries."""
    compiled = Dataset('data/set5_x2', 'data/kitti_car').compile()
    self.assertEqual(len(compiled), 8)
def test_multi_url():
    """Compile a dataset built from two locations and expect 8 entries."""
    d = Dataset('data/set5_x2', 'data/kitti_car')
    data = d.compile()
    # Value equality, not identity: `is` on an int literal only works by
    # accident of small-int caching and raises a SyntaxWarning on 3.8+.
    assert len(data) == 8
def test_image_data(self):
    """Compile `data/set5_x2` and check its length and capacity."""
    compiled = Dataset('data/set5_x2').compile()
    self.assertEqual(len(compiled), 5)
    self.assertEqual(compiled.capacity, 983040)