示例#1
0
def test_glob_absolute_pattern():
    """Exercise `_glob_absolute_pattern` against the ``./data`` fixture tree.

    The fixture tree (relative to the working directory) is expected to hold
    ``set5_x2`` with five ``img_00N_SRF_2_LR.png`` images, ``flying_chair``
    with a single ``.flo`` file, and ``kitti_car``; ten ``.png`` files exist
    below ``./data`` in total. Results are checked positionally, so the
    helper is expected to return entries in name order.
    """
    # Passing a directory is expected to expand to the files inside it.
    expected_images = (
        'img_001_SRF_2_LR.png',
        'img_002_SRF_2_LR.png',
        'img_003_SRF_2_LR.png',
        'img_004_SRF_2_LR.png',
        'img_005_SRF_2_LR.png',
    )
    images = _glob_absolute_pattern('./data/set5_x2')
    assert len(images) == len(expected_images)
    for found, wanted in zip(images, expected_images):
        assert found.match(wanted)

    # The data root is expected to expand to its sub-directories.
    expected_dirs = ('flying_chair', 'kitti_car', 'set5_x2')
    subdirs = _glob_absolute_pattern('./data')
    assert len(subdirs) == len(expected_dirs)
    for found, wanted in zip(subdirs, expected_dirs):
        assert found.match(wanted)

    # A single-level wildcard inside one directory.
    flows = _glob_absolute_pattern('./data/flying_chair/*.flo')
    assert len(flows) == 1
    assert flows[0].match('0-gt.flo')

    # A recursive wildcard spanning every sub-directory.
    all_pngs = _glob_absolute_pattern('./data/**/*.png')
    assert len(all_pngs) == 10
示例#2
0
def main():
  flags, args = parser.parse_known_args()
  opt = Config()
  for pair in flags._get_kwargs():
    opt.setdefault(*pair)
  data_config_file = Path(flags.data_config)
  if not data_config_file.exists():
    raise RuntimeError("dataset config file doesn't exist!")
  for _ext in ('json', 'yaml', 'yml'):  # for compat
    # apply a 2-stage (or master-slave) configuration, master can be
    # override by slave
    model_config_root = Path('Parameters/root.{}'.format(_ext))
    if opt.p:
      model_config_file = Path(opt.p)
    else:
      model_config_file = Path('Parameters/{}.{}'.format(opt.model, _ext))
    if model_config_root.exists():
      opt.update(Config(str(model_config_root)))
    if model_config_file.exists():
      opt.update(Config(str(model_config_file)))

  model_params = opt.get(opt.model, {})
  suppress_opt_by_args(model_params, *args)
  opt.update(model_params)
  model = get_model(flags.model)(**model_params)
  if flags.cuda:
    model.cuda()
  root = f'{flags.save_dir}/{flags.model}'
  if flags.comment:
    root += '_' + flags.comment
  verbosity = logging.DEBUG if flags.verbose else logging.INFO
  trainer = model.trainer

  datasets = load_datasets(data_config_file)
  try:
    test_datas = [datasets[t.upper()] for t in flags.test]
    run_benchmark = True
  except KeyError:
    test_datas = []
    for pattern in flags.test:
      test_data = Dataset(test=_glob_absolute_pattern(pattern),
                          mode='pil-image1', modcrop=False)
      father = Path(flags.test)
      while not father.is_dir():
        if father.parent == father:
          break
        father = father.parent
      test_data.name = father.stem
      test_datas.append(test_data)
    run_benchmark = False

  if opt.verbose:
    dump(opt)
  for test_data in test_datas:
    loader_config = Config(convert_to='rgb',
                           feature_callbacks=[], label_callbacks=[],
                           output_callbacks=[], **opt)
    loader_config.batch = 1
    loader_config.subdir = test_data.name
    loader_config.output_callbacks += [
      save_image(root, flags.output_index, flags.auto_rename)]
    if opt.channel == 1:
      loader_config.convert_to = 'gray'

    with trainer(model, root, verbosity, flags.pth) as t:
      if flags.seed is not None:
        t.set_seed(flags.seed)
      loader = QuickLoader(test_data, 'test', loader_config,
                           n_threads=flags.thread)
      loader_config.epoch = flags.epoch
      if run_benchmark:
        t.benchmark(loader, loader_config)
      else:
        t.infer(loader, loader_config)