コード例 #1
0
def _get_test_or_val_set(conf, data_dir, fold_dir):
    """Gets validation or test dataset

    Each entry consists of input, the low resolution center crop of an image,
    and target, the center crop in high resolution. If `full_image` is set to
    true in conf, then the full image is used instead of a crop.

    Parameters
    ----------
    conf : Configuration
      Configuration expecting the following keys:
        - `test_crop_size`: Size of the crop to extract
        - `upscale_factor`: Downscale the crop by this factor
        - `scale_to_orig`: If true, scale the crop up to crop size again after
          downscaling to preserve same size between inputs and targets
        - `grayscale`: If true, converts images to grayscale
    data_dir : string
      Path to top level data folder
    fold_dir : string
      Subdirectory of the fold to use, relative to `data_dir/DATASET_DIR`

    Returns
    -------
    torch.utils.data.Dataset
      Instance of validation or test dataset
    """
    # Keep `fold_dir` meaning "relative subdirectory"; the resolved location
    # gets its own name instead of rebinding the parameter.
    fold_path = os.path.join(data_dir, DATASET_DIR, fold_dir)
    transform = get_sr_transform(conf, mode='test')
    return make_sr_dataset_from_folder(conf,
                                       fold_path,
                                       transform,
                                       name=DATASET_DIR)
コード例 #2
0
def get_train_set(conf, data_dir):
    """Build the training dataset.

    Every sample pairs an input -- a random crop brought down to low
    resolution -- with a target, the same random crop at high resolution.

    Parameters
    ----------
    conf : Configuration
      Configuration expecting the following keys:
        - `train_crop_size`: Size of the random crop to extract
        - `upscale_factor`: Factor by which the random crop is downscaled
        - `scale_to_orig`: If true, scale the downscaled crop back up to the
          crop size so that inputs and targets have the same size
        - `grayscale`: If true, converts images to grayscale
    data_dir : string
      Path to top level data folder

    Returns
    -------
    torch.utils.data.Dataset
      Instance of training dataset
    """
    # Arguments are evaluated left to right: folder path, then the
    # training-mode transform, then the dataset is assembled from both.
    return make_sr_dataset_from_folder(
        conf,
        os.path.join(data_dir, DATASET_DIR, TRAIN_DIR),
        get_sr_transform(conf, mode='train'),
        name=DATASET_DIR)
コード例 #3
0
ファイル: __init__.py プロジェクト: liminghao0914/srgan
def _get_test_or_val_set(conf, data_dir):
    """Gets validation or test dataset

    Each entry consists of input, the low resolution center crop of an image,
    and target, the center crop in high resolution. If `full_image` is set to
    true in conf, then the full image is used instead of a crop.

    For upscale factors from 2 to 4 (inclusive) the low resolution images
    shipped with the dataset are used as inputs. For any other factor no LR
    images are loaded and the transform is asked to downscale the high
    resolution images itself.

    Parameters
    ----------
    conf : Configuration
      Configuration expecting the following keys:
        - `test_crop_size`: Size of the crop to extract
        - `upscale_factor`: Downscale the crop by this factor
        - `scale_to_orig`: If true, scale the crop up to crop size again after
          downscaling to preserve same size between inputs and targets
        - `grayscale`: If true, converts images to grayscale (default: false)
        - `luma`: Forwarded to `SRDatasetFromImagePaths` (default: false);
          presumably selects the luma channel -- confirm
    data_dir : string
      Path to top level data folder

    Returns
    -------
    torch.utils.data.Dataset
      Instance of validation or test dataset
    """
    dataset_path = os.path.join(data_dir, DATASET_DIR)

    # NOTE(review): HR images always come from the scale-2 set, even when LR
    # images of scale 3 or 4 are loaded below. Confirm the HR images are
    # identical across scales for this dataset, otherwise sizes may not match.
    hr_images = _get_images(dataset_path, scale=2, mode='HR')

    if 2 <= conf.upscale_factor <= 4:
        # Use downscaled images from dataset, if available
        lr_images = _get_images(dataset_path,
                                scale=conf.upscale_factor,
                                mode='LR')
        transform = get_sr_transform(conf, mode='test', downscale=False)
    else:
        # Need to downscale images ourselves
        lr_images = None
        transform = get_sr_transform(conf, mode='test', downscale=True)

    dataset = SRDatasetFromImagePaths(hr_images,
                                      lr_images,
                                      grayscale=conf.get_attr('grayscale',
                                                              default=False),
                                      luma=conf.get_attr('luma',
                                                         default=False),
                                      transform=transform,
                                      name=DATASET_DIR)
    return dataset
コード例 #4
0
ファイル: eval.py プロジェクト: liminghao0914/srgan
def _parse_conf_overrides(entries):
  """Parse `key=value` command line strings into a dict of config overrides.

  Only the first `=` separates key from value, so values may themselves
  contain `=`. Values are kept as strings.

  Raises
  ------
  ValueError
    If an entry contains no `=`.
  """
  overrides = {}
  for entry in entries:
    key, sep, value = entry.partition('=')
    if not sep:
      raise ValueError('Invalid config override {!r}: expected key=value'
                       .format(entry))
    overrides[key] = value
  return overrides


def _load_datasets(conf, args):
  """Build the datasets to evaluate, each tagged with its evaluation mode.

  Returns a list of `(mode, dataset)` pairs. The mode is 'dataset' for
  entries recognized by `is_dataset` (evaluated with metrics) and 'image'
  for any other file or folder (inference only). Without explicit
  `files_or_dirs`, the configured validation dataset is used.
  """
  if not args.files_or_dirs:
    return [('dataset', load_dataset(conf, args.data_dir,
                                     conf.validation_dataset, args.fold))]

  datasets = []
  image_transform = None
  for f in args.files_or_dirs:
    if is_dataset(f):
      datasets.append(('dataset',
                       load_dataset(conf, args.data_dir, f, args.fold)))
    else:
      # The same test transform (no downscaling) is shared by all image inputs
      if image_transform is None:
        image_transform = get_sr_transform(conf, 'test', downscale=False)
      datasets.append(('image',
                       make_sr_dataset_from_folder(conf, f, image_transform,
                                                   inference=True)))
  return datasets


def _save_outputs(runner, data, dataset, mode, out_dir, dump):
  """Write predictions (and, if `dump`, inputs and targets) as PNG files.

  In 'dataset' mode all files go into the run directory returned by
  `get_run_dir(out_dir, dataset.name)`. In 'image' mode each file is written
  next to the source image it was computed from. Output files are named
  after the source image: `<name>_pred.png`, `<name>_input.png` and
  `<name>_target.png`.
  """
  output_dir = None
  if mode == 'dataset':
    output_dir = get_run_dir(out_dir, dataset.name)
    os.makedirs(output_dir, exist_ok=True)

  # Outputs arrive in loader order (shuffle=False), so a running index maps
  # each output back to its source path in `dataset.images`.
  file_idx = 0
  for batch in data:
    named_batch = runner.get_named_outputs(batch)
    inputs = named_batch['input']
    predictions = named_batch['prediction']
    targets = named_batch['target']
    for inp, target, prediction in zip(inputs, targets, predictions):
      image_path = dataset.images[file_idx]
      file_idx += 1
      if mode == 'image':
        output_dir = os.path.dirname(image_path)
      name, _ = os.path.splitext(os.path.basename(image_path))

      if dump:
        input_file = os.path.join(output_dir, '{}_input.png'.format(name))
        save_image(inp.data, input_file)
        target_file = os.path.join(output_dir, '{}_target.png'.format(name))
        save_image(target.data, target_file)
      pred_file = os.path.join(output_dir, '{}_pred.png'.format(name))
      save_image(prediction.data, pred_file)


def main(argv):
  """Evaluate a super-resolution model on datasets and/or image folders.

  Parses `argv` and loads the JSON configuration, applying `--conf`
  overrides. It then builds the runner and optionally restores a checkpoint.
  Known datasets are validated and their average metrics are printed. Plain
  image inputs only go through inference. With `--infer` or `--dump`,
  outputs are written as PNG files.

  Returns early (None) when no GPU is free or the checkpoint is missing.
  """
  args = parser.parse_args(argv)

  if args.cuda != '':
    try:
      args.cuda = utils.set_cuda_env(args.cuda)
    except Exception:
      print('No free GPU on this machine. Aborting run.')
      return
    print('Running on GPU {}'.format(args.cuda))

  # Load configuration
  conf = Configuration.from_json(args.config)
  conf.args = args
  if args.conf:
    conf.update(_parse_conf_overrides(args.conf))
  if args.verbose:
    print(conf)

  utils.set_random_seeds(conf.seed)

  # Setup model
  runner = build_runner(conf, conf.runner_type, args.cuda, mode='test')

  # Handle resuming from checkpoint
  if args.checkpoint != 'NONE':
    if not os.path.exists(args.checkpoint):
      print('Checkpoint {} to restore from not found'.format(args.checkpoint))
      return
    restore_checkpoint(args.checkpoint, runner, cuda=args.cuda)
    print('Restored checkpoint from {}'.format(args.checkpoint))

  # Evaluate on full image, not crops
  conf.full_image = True

  # Each dataset carries its own mode, so mixing known datasets and plain
  # image folders on the command line evaluates each one correctly.
  datasets = _load_datasets(conf, args)

  num_workers = conf.get_attr('num_data_workers', default=DEFAULT_NUM_WORKERS)

  # Evaluate all datasets
  for mode, dataset in datasets:
    loader = DataLoader(dataset=dataset,
                        num_workers=num_workers,
                        batch_size=1,
                        shuffle=False)

    if mode == 'dataset':
      data, _, val_metrics = runner.validate(loader, len(loader))

      print('Average metrics for {}'.format(dataset.name))
      for metric_name, metric in val_metrics.items():
        print('     {}: {}'.format(metric_name, metric))
    else:
      data = runner.infer(loader)

    if args.infer or args.dump:
      _save_outputs(runner, data, dataset, mode, args.out_dir, args.dump)