Exemplo n.º 1
0
    def __getitem__(self, idx):
        """Load and preprocess the sample at position ``idx``.

        Reads the blurred image (and the matching sharp image when
        ``self.sharp_list`` is non-empty) as RGB. In ``'train'`` mode it crops
        patches and optionally augments them, adding noise to the blurred image
        only. It optionally builds a pyramid, then converts everything with
        ``common.np2tensor``.

        Returns:
            tuple: ``(blur, sharp, idx, relpath)``. ``sharp`` is ``False`` when
            no sharp image is available. ``relpath`` is the blur image path
            relative to ``self.subset_root``.
        """
        blur_path = self.blur_list[idx]
        images = [imageio.imread(blur_path, pilmode='RGB')]
        if len(self.sharp_list) > 0:
            images.append(imageio.imread(self.sharp_list[idx], pilmode='RGB'))

        # Only training samples are cropped/augmented; test images are delivered as is.
        if self.mode == 'train':
            images = common.crop(*images, ps=self.args.patch_size)
            if self.args.augment:
                rgb_range = self.args.rgb_range
                images = common.augment(*images, hflip=True, rot=True, shuffle=True,
                                        change_saturation=True, rgb_range=rgb_range)
                # Noise is injected into the blurred input only, never the target.
                images[0] = common.add_noise(images[0], sigma_sigma=2, rgb_range=rgb_range)

        if self.args.gaussain_pyramid:
            images = common.generate_pyramid(*images, n_scales=self.args.n_scales)

        tensors = common.np2tensor(*images)
        relpath = os.path.relpath(blur_path, self.subset_root)

        sharp = tensors[1] if len(tensors) > 1 else False
        return tensors[0], sharp, idx, relpath
Exemplo n.º 2
0
def main():
    """Evaluate a trained GRAD model on the DIV2K HDF5 test split and print PSNR.

    Steps:
    1. Parse command-line arguments with the module-level ``parser``.
    2. Prepare the test data when ``is_ready(args)`` is false.
    3. Load model weights from ``args.test_checkpoint``.
    4. Run inference on the 100 test pairs stored under keys ``'0'`` .. ``'99'``
       of ``{args.h5file_dir}/DIV2K_np_test_{args.test_sigma}.h5``.

    The average PSNR is printed. When ``args.test_save`` is set, each quantized
    output is also written to ``{args.test_result}/NNNN.png``.

    Side effects:
        Rebinds the module-level globals ``args`` and ``model``.

    Raises:
        RuntimeError: if ``args.gpu`` is requested but CUDA is unavailable.
    """
    global args, model

    args = parser.parse_args()
    print(args)

    if args.gpu and not torch.cuda.is_available():
        raise RuntimeError("No GPU found!")

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.test_result, exist_ok=True)

    if not is_ready(args):
        prepare_data(args)

    cudnn.benchmark = True
    # format() accepts gpu_id as str or int; plain '+' concatenation only worked for str.
    device = torch.device('cuda:{}'.format(args.gpu_id) if args.gpu else 'cpu')

    model = Grad_none.GRAD(feats=args.feats,
                           basic_conv=args.basic_conv,
                           tail_conv=args.tail_conv)
    # map_location lets a checkpoint saved on a GPU be evaluated on a CPU-only host.
    checkpoint_file = torch.load(args.test_checkpoint, map_location=device)
    model.load_state_dict(checkpoint_file['model'])
    model.eval()
    model = model.to(device)

    psnrs = AverageMeter()
    n_test = 100  # number of test pairs read from the HDF5 file
    h5_path = "{}/DIV2K_np_test_{}.h5".format(args.h5file_dir, args.test_sigma)

    # Open the HDF5 file once for the whole evaluation instead of once per image.
    with h5py.File(h5_path, 'r') as h5, tqdm(total=n_test) as t:
        t.set_description("test")

        for idx in range(n_test):
            key = str(idx)
            # Add a leading batch dimension and move to the evaluation device.
            l_image = np2tensor(h5['l'][key][()]).unsqueeze(0).to(device)
            h_image = np2tensor(h5['h'][key][()]).unsqueeze(0).to(device)

            with torch.no_grad():
                output = model(l_image)
                output = quantize(output, [0, 255])
                psnr = calc_psnr(output, h_image)
            psnrs.update(psnr.item(), 1)

            if args.test_save:
                save_image_path = "{}/{:04d}.png".format(args.test_result, idx)
                # Drop the batch dim and move channels last for PIL
                # (assumes the model output is channel-first — confirm).
                output = output.squeeze(0).permute(1, 2, 0)
                save_image = pil_image.fromarray(output.byte().cpu().numpy())
                save_image.save(save_image_path)

            t.update(1)

    print("PSNR: {:.4f}".format(psnrs.avg))