def _get_test_or_val_set(conf, data_dir, fold_dir):
    """Build the validation or test dataset for one fold.

    The fold folder is resolved as ``<data_dir>/<DATASET_DIR>/<fold_dir>``
    and wrapped with the 'test' mode super-resolution transform.

    Each entry consists of an input, the low resolution center crop of an
    image, and a target, the same center crop in high resolution. If
    `full_image` is set to true in conf, the full image is used instead of
    a crop.

    Parameters
    ----------
    conf : Configuration
        Configuration forwarded to `get_sr_transform` and
        `make_sr_dataset_from_folder`. Expected keys (consumed by those
        helpers, not read directly here):

        - `test_crop_size`: Size of the crop to extract
        - `upscale_factor`: Downscale the crop by this factor
        - `scale_to_orig`: If true, scale the crop up to crop size again
          after downscaling to preserve same size between inputs and
          targets
        - `grayscale`: If true, converts images to grayscale
    data_dir : string
        Path to top level data folder
    fold_dir : string
        Subdirectory of the fold to use

    Returns
    -------
    torch.utils.data.Dataset
        Instance of validation or test dataset
    """
    fold_path = os.path.join(data_dir, DATASET_DIR, fold_dir)
    return make_sr_dataset_from_folder(
        conf,
        fold_path,
        get_sr_transform(conf, mode='test'),
        name=DATASET_DIR,
    )
def get_train_set(conf, data_dir):
    """Build the training dataset.

    Images are read from ``<data_dir>/<DATASET_DIR>/<TRAIN_DIR>`` and
    wrapped with the 'train' mode super-resolution transform.

    Each entry consists of an input, a low resolution random crop, and a
    target, the same random crop in high resolution.

    Parameters
    ----------
    conf : Configuration
        Configuration forwarded to `get_sr_transform` and
        `make_sr_dataset_from_folder`. Expected keys (consumed by those
        helpers, not read directly here):

        - `train_crop_size`: Size of the random crop to extract
        - `upscale_factor`: Downscale the random crop by this factor
        - `scale_to_orig`: If true, scale random crop up to crop size
          again after downscaling to preserve same size between inputs
          and targets
        - `grayscale`: If true, converts images to grayscale
    data_dir : string
        Path to top level data folder

    Returns
    -------
    torch.utils.data.Dataset
        Instance of training dataset
    """
    image_folder = os.path.join(data_dir, DATASET_DIR, TRAIN_DIR)
    train_transform = get_sr_transform(conf, mode='train')
    dataset = make_sr_dataset_from_folder(
        conf, image_folder, train_transform, name=DATASET_DIR
    )
    return dataset
def _get_test_or_val_set(conf, data_dir):
    """Build the validation or test dataset from explicit image paths.

    High resolution images are always taken from the dataset folder. When
    `conf.upscale_factor` lies in the inclusive range 2..4, the low
    resolution images shipped with the dataset are used as inputs and the
    transform is created with ``downscale=False``. For any other factor
    no low resolution paths are passed and the transform is created with
    ``downscale=True`` so that inputs are produced by downscaling.

    Each entry consists of an input, the low resolution center crop of an
    image, and a target, the same center crop in high resolution. If
    `full_image` is set to true in conf, the full image is used instead of
    a crop.

    Parameters
    ----------
    conf : Configuration
        Configuration expecting the following keys:

        - `test_crop_size`: Size of the crop to extract
        - `upscale_factor`: Downscale the crop by this factor
        - `scale_to_orig`: If true, scale the crop up to crop size again
          after downscaling to preserve same size between inputs and
          targets
        - `grayscale`: If true, converts images to grayscale
          (optional, defaults to False)
        - `luma`: Optional flag forwarded to the dataset, defaults to
          False
    data_dir : string
        Path to top level data folder

    Returns
    -------
    torch.utils.data.Dataset
        `SRDatasetFromImagePaths` instance for validation or testing
    """
    root = os.path.join(data_dir, DATASET_DIR)

    # NOTE(review): the HR lookup always passes scale=2, independent of
    # conf.upscale_factor -- presumably HR images are identical across
    # scales; confirm against _get_images.
    hr_paths = _get_images(root, scale=2, mode='HR')

    # Downscaled images are only looked up for factors 2 to 4; outside that
    # range the transform has to do the downscaling itself.
    has_provided_lr = 2 <= conf.upscale_factor <= 4
    if has_provided_lr:
        lr_paths = _get_images(root, scale=conf.upscale_factor, mode='LR')
    else:
        lr_paths = None
    test_transform = get_sr_transform(
        conf, mode='test', downscale=not has_provided_lr
    )

    return SRDatasetFromImagePaths(
        hr_paths,
        lr_paths,
        grayscale=conf.get_attr('grayscale', default=False),
        luma=conf.get_attr('luma', default=False),
        transform=test_transform,
        name=DATASET_DIR,
    )
def main(argv):
    """Evaluate a model on datasets, or run inference on image folders.

    Steps, in order:

    1. Parse `argv` with the module-level `parser` and optionally select a
       GPU via `utils.set_cuda_env`.
    2. Load the JSON configuration, apply ``key=value`` overrides from
       `args.conf`, and seed the random number generators.
    3. Build the runner in 'test' mode and optionally restore a checkpoint.
    4. Build the list of datasets: the configured validation dataset when
       no positional inputs were given, otherwise one dataset per input.
    5. For each dataset either validate (printing the average metrics) or
       run inference, and optionally write images to disk.

    Parameters
    ----------
    argv : list of string
        Command line arguments, excluding the program name. The parsed
        namespace is expected to provide `cuda`, `config`, `conf`,
        `verbose`, `checkpoint`, `files_or_dirs`, `data_dir`, `fold`,
        `infer`, `dump` and `out_dir`.

    Returns
    -------
    None
        Returns early (after printing a message) when no GPU could be
        selected or the requested checkpoint does not exist.

    Raises
    ------
    ValueError
        If an entry of `args.conf` is not of the form ``key=value``.
    """
    args = parser.parse_args(argv)

    if args.cuda != '':
        try:
            args.cuda = utils.set_cuda_env(args.cuda)
        except Exception:
            print('No free GPU on this machine. Aborting run.')
            return
        print('Running on GPU {}'.format(args.cuda))

    # Load configuration
    conf = Configuration.from_json(args.config)
    conf.args = args
    if args.conf:
        new_conf_entries = {}
        for arg in args.conf:
            # Split on the first '=' only, so that values may themselves
            # contain '=' characters.
            key, sep, value = arg.partition('=')
            if not sep:
                raise ValueError(
                    "Configuration override '{}' is not of the form "
                    "key=value".format(arg))
            new_conf_entries[key] = value
        conf.update(new_conf_entries)
    if args.verbose:
        print(conf)

    utils.set_random_seeds(conf.seed)

    # Setup model
    runner = build_runner(conf, conf.runner_type, args.cuda, mode='test')

    # Handle resuming from checkpoint ('NONE' disables restoring)
    if args.checkpoint != 'NONE':
        if os.path.exists(args.checkpoint):
            _ = restore_checkpoint(args.checkpoint, runner, cuda=args.cuda)
            print('Restored checkpoint from {}'.format(args.checkpoint))
        else:
            print('Checkpoint {} to restore from not found'
                  .format(args.checkpoint))
            return

    # Evaluate on full image, not crops
    conf.full_image = True

    # Load datasets
    mode = 'dataset'
    if not args.files_or_dirs:
        datasets = [load_dataset(conf, args.data_dir,
                                 conf.validation_dataset, args.fold)]
    else:
        datasets = []
        for f in args.files_or_dirs:
            if is_dataset(f):
                dataset = load_dataset(conf, args.data_dir, f, args.fold)
                datasets.append(dataset)
            else:
                # A single non-dataset input switches the whole run to
                # image mode; every positional input is then treated as a
                # folder of images, replacing what was collected so far.
                mode = 'image'
                transform = get_sr_transform(conf, 'test', downscale=False)
                datasets = [make_sr_dataset_from_folder(conf, path, transform,
                                                        inference=True)
                            for path in args.files_or_dirs]

    num_workers = conf.get_attr('num_data_workers',
                                default=DEFAULT_NUM_WORKERS)

    # Evaluate all datasets
    for dataset in datasets:
        loader = DataLoader(dataset=dataset,
                            num_workers=num_workers,
                            batch_size=1,
                            shuffle=False)

        if mode == 'dataset':
            data, _, val_metrics = runner.validate(loader, len(loader))

            print('Average metrics for {}'.format(dataset.name))
            for metric_name, metric in val_metrics.items():
                print('     {}: {}'.format(metric_name, metric))
        else:
            data = runner.infer(loader)

        if args.infer or args.dump:
            if mode == 'dataset':
                output_dir = get_run_dir(args.out_dir, dataset.name)
                # Creates missing parents too and tolerates an existing
                # directory without a separate check.
                os.makedirs(output_dir, exist_ok=True)

            file_idx = 0
            for batch in data:
                if mode == 'image':
                    # In image mode results are written next to the source
                    # image instead of into a run directory.
                    output_dir = os.path.dirname(dataset.images[file_idx])

                named_batch = runner.get_named_outputs(batch)
                inputs = named_batch['input']
                predictions = named_batch['prediction']
                targets = named_batch['target']
                for (inp, target, prediction) in zip(inputs, targets,
                                                     predictions):
                    image_file = os.path.basename(dataset.images[file_idx])
                    name, _ = os.path.splitext(image_file)
                    file_idx += 1

                    if args.dump:
                        input_file = os.path.join(
                            output_dir, '{}_input.png'.format(name))
                        save_image(inp.data, input_file)
                        target_file = os.path.join(
                            output_dir, '{}_target.png'.format(name))
                        save_image(target.data, target_file)
                    # NOTE(review): the prediction is written whenever
                    # --infer or --dump is set; input and target only with
                    # --dump. Confirm this nesting matches the intent.
                    pred_file = os.path.join(
                        output_dir, '{}_pred.png'.format(name))
                    save_image(prediction.data, pred_file)