Code Example #1
def fetch_datasets(data_config_file, opt):
    all_datasets = load_datasets(data_config_file)
    dataset = all_datasets[opt.dataset.upper()]
    if opt.test:
        test_data = all_datasets[opt.test.upper()]
    else:
        test_data = dataset
    if opt.infer:
        infer_dir = Path(opt.infer)
        if infer_dir.exists():
            # Infer from a single file, or from every file in a directory.
            if infer_dir.is_file():
                images = [str(infer_dir)]
            else:
                images = list(infer_dir.glob('*'))
                if not images:
                    # iterdir() is lazy; materialize it like the glob above.
                    images = list(infer_dir.iterdir())
            infer_data = Dataset(infer=images,
                                 mode='pil-image1',
                                 modcrop=False)
        else:
            infer_data = all_datasets[opt.infer.upper()]
    else:
        infer_data = test_data
    if opt.mode:
        dataset.mode = opt.mode
        test_data.mode = opt.mode
        infer_data.mode = opt.mode
    return dataset, test_data, infer_data
Code Example #2
    def test_no_shuffle(self):
        d = Dataset('data/').include('*.png')
        data = d.compile()
        ld = Loader(data, data, threads=4)
        ld.cropper(CenterCrop(1))
        itr1 = ld.make_one_shot_iterator([1, 3, 16, 16], -1, False)
        ret1 = list(itr1)
        self.assertEqual(len(ret1), 16)
        self.assert_psnr(ret1)
        itr2 = ld.make_one_shot_iterator([1, 3, 16, 16], -1, False)
        ret2 = list(itr2)
        self.assertEqual(len(ret2), 16)
        self.assert_psnr(ret2)
        # Without shuffling, two passes must yield identical crops.
        for x, y in zip(ret1, ret2):
            self.assertTrue(np.all(np.abs(x['hr'] - y['hr']) < 1e-4))
Code Example #3
    def test_auto_deduce_shape(self):
        d = Dataset('data').include_reg('set5')
        ld = Loader(d, scale=1)
        itr = ld.make_one_shot_iterator([1, -1, -1, -1], -1)
        ret = list(itr)
        self.assertEqual(len(ret), 5)
        self.assert_psnr(ret)
Code Example #4
def test_memory_limit():
    d = Dataset('data/')
    d = d.include('*.png')
    data = d.compile()
    ld = Loader(data, data, threads=4)
    ld.cropper(RandomCrop(1))
    ld.image_augmentation()
    # Cap loader memory at half of the dataset's capacity.
    itr = ld.make_one_shot_iterator([4, 3, 16, 16], 10, True,
                                    data.capacity / 2)
    ret = list(itr)
    assert len(ret) == 10
    assert_psnr(ret)
    itr = ld.make_one_shot_iterator([4, 3, 16, 16], 10, True,
                                    data.capacity / 2)
    ret = list(itr)
    assert len(ret) == 10
    assert_psnr(ret)
Code Example #5
    def test_no_shuffle_limit(self):
        d = Dataset('data/')
        d = d.include('*.png')
        data = d.compile()
        ld = Loader(data, data, threads=4)
        ld.cropper(RandomCrop(1))
        ld.image_augmentation()
        itr = ld.make_one_shot_iterator([4, 3, 16, 16], 10, False,
                                        data.capacity / 2)
        ret = list(itr)
        self.assertEqual(len(ret), 10)
        self.assert_psnr(ret)
        itr = ld.make_one_shot_iterator([4, 3, 16, 16], 10, False,
                                        data.capacity / 2)
        ret = list(itr)
        self.assertEqual(len(ret), 10)
        self.assert_psnr(ret)
Code Example #6
def fetch_datasets(data_config_file, opt):
    all_datasets = load_datasets(data_config_file)
    dataset = all_datasets[opt.dataset.upper()]
    if opt.test:
        test_data = all_datasets[opt.test.upper()]
    else:
        test_data = dataset
    if opt.infer:
        infer_dir = Path(opt.infer)
        if infer_dir.exists():
            # Infer from a single file, or from every file in a directory.
            if infer_dir.is_file():
                images = [str(infer_dir)]
            else:
                images = list(infer_dir.glob('*'))
                if not images:
                    # iterdir() is lazy; materialize it like the glob above.
                    images = list(infer_dir.iterdir())
            infer_data = Dataset(infer=images,
                                 mode='pil-image1',
                                 modcrop=False)
        else:
            infer_data = all_datasets[opt.infer.upper()]
    else:
        infer_data = test_data
    # TODO: temporary use; delete if not needed
    if opt.cifar:
        cifar_data, cifar_test = tf.keras.datasets.cifar10.load_data()
        dataset = Dataset(**dataset)
        dataset.mode = 'numpy'
        dataset.train = [cifar_data[0]]
        dataset.val = [cifar_test[0]]
        dataset.test = [cifar_test[0]]
        return dataset, dataset, infer_data
    return dataset, test_data, infer_data
Code Example #7
def test_load_empty_data():
    d = Dataset('not-found')
    ld = Loader(d, scale=1)
    itr = ld.make_one_shot_iterator([1, -1, -1, -1], -1)
    assert len(list(itr)) == 0
    itr = ld.make_one_shot_iterator([4, 3, 16, 16], 10)
    ret = list(itr)
    assert len(ret) == 10
    # Batches built from an empty dataset carry empty payloads.
    assert not ret[0]['hr']
    assert not ret[0]['lr']
    assert not ret[0]['name']
Code Example #8
    def test_load_empty_data(self):
        d = Dataset('not-found')
        ld = Loader(d, scale=1)
        itr = ld.make_one_shot_iterator([1, -1, -1, -1], -1)
        self.assertEqual(len(list(itr)), 0)
        itr = ld.make_one_shot_iterator([4, 3, 16, 16], 10)
        ret = list(itr)
        self.assertEqual(len(ret), 10)
        self.assertFalse(ret[0]['hr'])
        self.assertFalse(ret[0]['lr'])
        self.assertFalse(ret[0]['name'])
Code Example #9
def load_folder(path):
    """Load `path` into a Dataset."""
    if not Path(path).exists():
        raise ValueError("--input_dir can't be found")

    images = sorted(Path(path).glob('*'))
    if not images:
        # Fall back to a plain directory listing, kept sorted as well.
        images = sorted(Path(path).iterdir())
    D = Dataset(test=images)
    return D
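
A quick usage sketch (hedged: the folder path is hypothetical, and `Loader` accepts a Dataset directly, as in the tests above):

dataset = load_folder('data/set5_x2')  # hypothetical image folder
loader = Loader(dataset, scale=2)      # consume it like any other Dataset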
Code Example #10
    def test_simplest_loader(self):
        d = Dataset('data/set5_x2')
        ld = Loader(d, scale=2, threads=4)
        itr = ld.make_one_shot_iterator([4, 3, 4, 4], 10, True)
        self.assertEqual(len(itr), 10)
        ret = list(itr)
        self.assertEqual(len(ret), 10)
        itr = ld.make_one_shot_iterator([4, 3, 16, 16], 10, True)
        self.assertEqual(len(itr), 10)
        ret = list(itr)
        self.assertEqual(len(ret), 10)
        self.assert_psnr(ret)
Code Example #11
def test_simplest_loader():
    d = Dataset('data/set5_x2')
    ld = Loader(d, scale=2, threads=4)
    itr = ld.make_one_shot_iterator([4, 3, 4, 4], 10, True)
    assert len(itr) == 10
    ret = list(itr)
    assert len(ret) == 10
    itr = ld.make_one_shot_iterator([4, 3, 16, 16], 10, True)
    assert len(itr) == 10
    ret = list(itr)
    assert len(ret) == 10
    assert_psnr(ret)
Code Example #12
def _handle_path(path, sess):
    # Load precomputed activation statistics from a .npz file, or compute
    # them from the images found under `path`.
    if path.endswith('.npz'):
        f = np.load(path)
        m, s = f['mu'][:], f['sigma'][:]
        f.close()
    else:
        path = pathlib.Path(path)
        files = list(path.glob('*.jpg')) + list(path.glob('*.png')) + list(
            path.glob('*.bmp'))
        images = Dataset(train=files, patch_size=48, strides=48)
        loader = BatchLoader(50, images, 'train', 1, convert_to='RGB')
        x = np.concatenate([img[0] for img in loader])
        m, s = calculate_activation_statistics(x, sess)
    return m, s
Code Example #13
    def test_complex_loader(self):
        d = Dataset('data').use_like_video().include_reg('hr/xiuxian')
        hr = d.compile()
        d = Dataset('data').use_like_video().include_reg('lr/xiuxian')
        lr = d.compile()
        ld = Loader(hr, lr, threads=4)
        ld.image_augmentation()
        ld.cropper(RandomCrop(2))
        itr = ld.make_one_shot_iterator([4, 3, 3, 16, 16], 10, shuffle=True)
        ret = list(itr)
        self.assertEqual(len(ret), 10)
        self.assert_psnr(ret)
Code Example #14
    def predict(self, files, mode='pil-image1', depth=1, **kwargs):
        r"""Predict output for frames

        Args:
            files: a list of frames as inputs
            mode: specify file format. `pil-image1` for PIL supported images, or `NV12/YV12/RGB` for raw data
            depth: specify length of sequence of images. 1 for images, >1 for videos
        """

        sess = tf.get_default_session()
        ckpt_last = self._restore_model(sess)
        files = [Path(file) for file in to_list(files)]
        data = Dataset(test=files,
                       mode=mode,
                       depth=depth,
                       modcrop=False,
                       **kwargs)
        loader = QuickLoader(1,
                             data,
                             'test',
                             self.model.scale,
                             -1,
                             crop=None,
                             **kwargs)
        it = loader.make_one_shot_iterator()
        if len(it):
            print('===================================')
            print(f'Predicting model: {self.model.name} by {ckpt_last}')
            print('===================================')
        else:
            return
        for img in tqdm.tqdm(it, 'Infer', ascii=True):
            feature, label, name = img[self.fi], img[self.li], img[-1]
            tf.logging.debug('output: ' + str(name))
            for fn in self.feature_callbacks:
                feature = fn(feature, name=name)
            outputs = self.model.test_batch(feature, None)
            for fn in self.output_callbacks:
                outputs = fn(outputs,
                             input=img[self.fi],
                             label=img[self.li],
                             mode=loader.color_format,
                             name=name)
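
A minimal usage sketch for `predict`, assuming a trained trainer instance `t` whose checkpoint can be restored; the file names are hypothetical. `with tf.Session():` installs the default session that `predict` fetches internally:

with tf.Session():
    t.predict(['input/001.png', 'input/002.png'],
              mode='pil-image1',  # PIL-readable image files
              depth=1)            # 1 = still images, >1 = video sequences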
Code Example #15
def main():
  flags, args = parser.parse_known_args()
  opt = Config()
  for pair in flags._get_kwargs():
    opt.setdefault(*pair)
  data_config_file = Path(flags.data_config)
  if not data_config_file.exists():
    raise RuntimeError("dataset config file doesn't exist!")
  for _ext in ('json', 'yaml', 'yml'):  # for compat
    # Apply a 2-stage (master/slave) configuration: the master config
    # can be overridden by the slave config.
    model_config_root = Path('Parameters/root.{}'.format(_ext))
    if opt.p:
      model_config_file = Path(opt.p)
    else:
      model_config_file = Path('Parameters/{}.{}'.format(opt.model, _ext))
    if model_config_root.exists():
      opt.update(Config(str(model_config_root)))
    if model_config_file.exists():
      opt.update(Config(str(model_config_file)))

  model_params = opt.get(opt.model, {})
  suppress_opt_by_args(model_params, *args)
  opt.update(model_params)
  model = get_model(flags.model)(**model_params)
  if flags.cuda:
    model.cuda()
  root = f'{flags.save_dir}/{flags.model}'
  if flags.comment:
    root += '_' + flags.comment
  verbosity = logging.DEBUG if flags.verbose else logging.INFO
  trainer = model.trainer

  datasets = load_datasets(data_config_file)
  try:
    test_datas = [datasets[t.upper()] for t in flags.test]
    run_benchmark = True
  except KeyError:
    test_datas = []
    for pattern in flags.test:
      test_data = Dataset(test=_glob_absolute_pattern(pattern),
                          mode='pil-image1', modcrop=False)
      # Walk up from the pattern until an existing directory is reached.
      father = Path(pattern)
      while not father.is_dir():
        if father.parent == father:
          break
        father = father.parent
      test_data.name = father.stem
      test_datas.append(test_data)
    run_benchmark = False

  if opt.verbose:
    dump(opt)
  for test_data in test_datas:
    loader_config = Config(convert_to='rgb',
                           feature_callbacks=[], label_callbacks=[],
                           output_callbacks=[], **opt)
    loader_config.batch = 1
    loader_config.subdir = test_data.name
    loader_config.output_callbacks += [
      save_image(root, flags.output_index, flags.auto_rename)]
    if opt.channel == 1:
      loader_config.convert_to = 'gray'

    with trainer(model, root, verbosity, flags.pth) as t:
      if flags.seed is not None:
        t.set_seed(flags.seed)
      loader = QuickLoader(test_data, 'test', loader_config,
                           n_threads=flags.thread)
      loader_config.epoch = flags.epoch
      if run_benchmark:
        t.benchmark(loader, loader_config)
      else:
        t.infer(loader, loader_config)
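
The configuration loop above applies a 2-stage merge: a master file (Parameters/root.*) is loaded first, then the per-model file overrides it. A self-contained sketch of that rule using plain dicts (keys and values hypothetical):

master = {'lr': 1e-4, 'epochs': 50}   # e.g. Parameters/root.yaml
slave = {'epochs': 200, 'batch': 16}  # e.g. Parameters/<model>.yaml
merged = {**master, **slave}          # slave keys override master keys
assert merged == {'lr': 1e-4, 'epochs': 200, 'batch': 16}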
Code Example #16
def test_image_data():
    d = Dataset('data/set5_x2')
    data = d.compile()
    assert len(data) == 5
    assert data.capacity == 983040
Code Example #17
def test_include_exclude():
    d = Dataset('data')
    d.include_('xiuxian*')
    data1 = d.compile()
    d = Dataset('data')
    d.include_reg_('set5')
    data2 = d.compile()
    d = Dataset('data').include_reg('set5').exclude('png')
    data3 = d.compile()

    assert len(data1) == 6
    assert len(data2) == 5
    assert len(data3) == 0
Code Example #18
    def test_video_data(self):
        d = Dataset('data/video/custom_pair').use_like_video()
        data = d.compile()
        self.assertEqual(len(data), 2)
Code Example #19
def test_video_data():
    d = Dataset('data/video/custom_pair').use_like_video()
    data = d.compile()
    assert len(data) == 2
Code Example #20
    def test_include_exclude(self):
        d = Dataset('data')
        d.include_('xiuxian*')
        data1 = d.compile()
        d = Dataset('data')
        d.include_reg_('set5')
        data2 = d.compile()
        d = Dataset('data').include_reg('set5').exclude('png')
        data3 = d.compile()

        self.assertEqual(len(data1), 6)
        self.assertEqual(len(data2), 5)
        self.assertEqual(len(data3), 0)
Code Example #21
    def test_multi_url(self):
        d = Dataset('data/set5_x2', 'data/kitti_car')
        data = d.compile()
        self.assertEqual(len(data), 8)
Code Example #22
def test_multi_url():
    d = Dataset('data/set5_x2', 'data/kitti_car')
    data = d.compile()
    assert len(data) == 8
Code Example #23
    def test_image_data(self):
        d = Dataset('data/set5_x2')
        data = d.compile()
        self.assertEqual(len(data), 5)
        self.assertEqual(data.capacity, 983040)