Exemplo n.º 1
0
    def comput_PSNR_SSIM(self, pred, gt, shave_border=0):
        """Compute PSNR and SSIM between a prediction and its ground truth.

        ``torch.Tensor`` inputs are converted with
        ``tensor2np(x, self.opt.rgb_range)`` and cast to float32; numpy inputs
        are used as given. The arrays are indexed as ``(H, W, C)``; a 2-D
        ``(H, W)`` array is treated as a single-channel image. Both images are
        cropped by ``shave_border`` pixels on every side, using ``pred``'s
        height and width for both. Three-channel images are compared on
        channel 0 of ``rgb2ycbcr``; single-channel images are compared
        directly.

        Args:
            pred: predicted image (tensor or numpy array).
            gt: ground-truth image (tensor or numpy array).
            shave_border: number of border pixels discarded on each side.

        Returns:
            ``(psnr, ssim)`` as returned by ``calc_PSNR`` and ``calc_ssim``.

        Raises:
            ValueError: if the channel counts are not both 1 or both 3.
        """
        if isinstance(pred, torch.Tensor):
            pred = tensor2np(pred, self.opt.rgb_range).astype(np.float32)
        if isinstance(gt, torch.Tensor):
            gt = tensor2np(gt, self.opt.rgb_range).astype(np.float32)

        # Give 2-D (H, W) inputs an explicit channel axis so the channel
        # check below does not fail with an IndexError on shape[2].
        if pred.ndim == 2:
            pred = pred[:, :, np.newaxis]
        if gt.ndim == 2:
            gt = gt[:, :, np.newaxis]

        height, width = pred.shape[:2]
        rows = slice(shave_border, height - shave_border)
        cols = slice(shave_border, width - shave_border)
        pred = pred[rows, cols]
        gt = gt[rows, cols]

        if pred.shape[2] == 3 and gt.shape[2] == 3:
            pred_y = rgb2ycbcr(pred)[:, :, 0]
            gt_y = rgb2ycbcr(gt)[:, :, 0]
        elif pred.shape[2] == 1 and gt.shape[2] == 1:
            pred_y = pred[:, :, 0]
            gt_y = gt[:, :, 0]
        else:
            raise ValueError('Input or output channel is not 1 or 3!')

        return calc_PSNR(pred_y, gt_y), calc_ssim(pred_y, gt_y)
Exemplo n.º 2
0
    def test(self):
        """Evaluate the model on ``self.loader_test`` once per scale.

        For every entry of ``self.scale`` the summed ``utils.calc_PSNR``
        divided by ``len(self.loader_test)`` is stored in
        ``self.ckp.log_test[-1, idx_scale]`` and logged together with the best
        value seen so far (and the epoch it came from). Each output is passed
        to ``self.ckp.save_results``. Finally the checkpoint is saved, flagged
        as best when the best PSNR of the first scale belongs to this epoch.
        """
        current_epoch = self.scheduler.last_epoch + 1
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(torch.zeros(1, len(self.scale)), False)
        self.model.eval()

        def _forward(batch, scale):
            # Inference strategy: self-ensemble wins over chopping, which
            # wins over a plain forward pass.
            if self.args.self_ensemble:
                return utils.x8_forward(batch, self.model, self.args.precision)
            if self.args.chop_forward:
                return utils.chop_forward(batch, self.model, scale)
            return self.model(batch)

        timer_test = utils.timer()
        set_name = type(self.loader_test.dataset).__name__

        for idx_scale, scale in enumerate(self.scale):
            psnr_sum = 0
            self._scale_change(idx_scale, self.loader_test)
            for idx_img, (inputs, targets, _) in enumerate(self.loader_test):
                inputs, targets = self._prepare(inputs, targets, volatile=True)
                outputs = _forward(inputs, scale)
                psnr_sum += utils.calc_PSNR(
                    outputs, targets, set_name, self.args.rgb_range, scale)
                self.ckp.save_results(idx_img, inputs, outputs, targets, scale)

            self.ckp.log_test[-1, idx_scale] = psnr_sum / len(self.loader_test)
            best_values, best_epochs = self.ckp.log_test.max(0)
            performance = 'PSNR: {:.3f}'.format(self.ckp.log_test[-1, idx_scale])
            message = '[{} x{}]\t{} (Best: {:.3f} from epoch {})'.format(
                set_name, scale, performance,
                best_values[idx_scale], best_epochs[idx_scale] + 1)
            self.ckp.write_log(message)

        # bool() mirrors the truth test the comparison result would get in an
        # ``if`` statement.
        is_best = bool(best_epochs[0] + 1 == current_epoch)

        elapsed = timer_test.toc()
        self.ckp.write_log('Time: {:.2f}s\n'.format(elapsed), refresh=True)
        self.ckp.save(self, current_epoch, is_best=is_best)
Exemplo n.º 3
0
    def test_generator(self, test_num_patch = 200, test_num_image = 5, load = False):
        """Evaluate the generator on random patches and on full test images.

        Two passes are run, each printing the mean PSNR between the generator
        input and its output:

        1. ``test_num_patch`` patches, each cut with ``get_patch`` from a file
           chosen at random (``np.random.randint``) among the sorted matches of
           ``self.config.test_path_phone_patch``. Every 50th input/output pair
           is written to ``self.result_img_dir``.
        2. The first ``test_num_image`` sorted matches of
           ``self.config.test_path_phone_image``. Every input/output pair is
           written to ``self.sample_dir``.

        Args:
            test_num_patch: number of random patches to evaluate.
            test_num_image: number of full images to evaluate; clamped to the
                number of files available.
            load: unused; kept for interface compatibility.
        """
        self.generator.eval()

        def _enhance(img_hwc):
            # Run the generator on one preprocessed image and return
            # (input, output) as NHWC numpy arrays.
            # NOTE(review): the (2, 1, 0) transpose feeds the network a
            # (C, W, H) tensor and (0, 2, 3, 1) maps results back to
            # (N, W, H, C), i.e. spatially transposed w.r.t. the source file.
            # Kept as-is in case training used the same convention - confirm.
            x = torch.from_numpy(np.transpose(img_hwc, (2, 1, 0))).float().unsqueeze(0)
            if torch.cuda.is_available():
                x = x.cuda()
            with torch.no_grad():
                y = self.generator(x)
            # Convert each tensor to numpy exactly once (the previous code
            # called .cpu() on an already-converted numpy array).
            x_np = np.transpose(x.cpu().numpy(), (0, 2, 3, 1))
            y_np = np.transpose(y.detach().cpu().numpy(), (0, 2, 3, 1))
            return x_np, y_np

        # test for patches
        # NOTE(review): scipy.misc.imread was removed in SciPy 1.2; this
        # requires an older SciPy.
        start = time.time()
        test_list_phone = sorted(glob(self.config.test_path_phone_patch))
        PSNR_phone_enhanced_list = np.zeros([test_num_patch])

        for i in range(test_num_patch):
            index = np.random.randint(len(test_list_phone))
            test_img = scipy.misc.imread(test_list_phone[index], mode = "RGB").astype("float32")
            test_patch = preprocess(get_patch(test_img, self.config.patch_size))
            test_patch_phone, test_patch_enhanced = _enhance(test_patch)

            if i % 50 == 0:
                imageio.imwrite(("%s/phone_%d.png" %(self.result_img_dir, i)), postprocess(test_patch_phone[0]))
                imageio.imwrite(("%s/enhanced_%d.png" %(self.result_img_dir,i)), postprocess(test_patch_enhanced[0]))

            # Compare single HWC images on both sides (the phone side used to
            # be passed as the whole NHWC batch).
            PSNR_phone_enhanced_list[i] = calc_PSNR(
                postprocess(test_patch_enhanced[0]), postprocess(test_patch_phone[0]))

        if test_num_patch > 0:
            print("(runtime: %.3f s) Average test PSNR for %d random test image patches: phone-enhanced %.3f" %(time.time()-start, test_num_patch, np.mean(PSNR_phone_enhanced_list)))

        # test for images
        start = time.time()
        test_list_phone = sorted(glob(self.config.test_path_phone_image))
        # Avoid an IndexError when fewer test images exist than requested.
        test_num_image = min(test_num_image, len(test_list_phone))
        PSNR_phone_enhanced_list = np.zeros([test_num_image])

        for i in range(test_num_image):
            test_image = preprocess(scipy.misc.imread(test_list_phone[i], mode = "RGB").astype("float32"))
            test_image_phone, test_image_enhanced = _enhance(test_image)

            imageio.imwrite(("%s/phone_%d.png" %(self.sample_dir, i)), postprocess(test_image_phone[0]))
            imageio.imwrite(("%s/enhanced_%d.png" %(self.sample_dir, i)), postprocess(test_image_enhanced[0]))

            PSNR_phone_enhanced_list[i] = calc_PSNR(
                postprocess(test_image_enhanced[0]), postprocess(test_image_phone[0]))

        if test_num_image > 0:
            print("(runtime: %.3f s) Average test PSNR for %d random full test images: original-enhanced %.3f" %(time.time()-start, test_num_image, np.mean(PSNR_phone_enhanced_list)))
Exemplo n.º 4
0
    def test(self):
        """Evaluate the model on ``self.loader_test`` for every scale.

        For each scale, the mean ``utils.calc_PSNR`` and the mean
        ``self._calc_loss_for_validation`` over the images that were processed
        successfully are stored in ``self.ckp.log_test`` and
        ``self.ckp.log_test_loss``, and every processed output is passed to
        ``self.ckp.save_results``. An image whose forward pass, scoring,
        saving or loss computation raises an ``Exception`` is skipped, and its
        index (``idx_img + 1 + self.args.n_train``) is written to
        ``TooLargeImages.log``. That file is overwritten on every call.
        Finally the checkpoint is saved, flagged as best when the first
        scale's best PSNR was reached in this epoch.
        """
        epoch = self.scheduler.last_epoch + 1
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(torch.zeros(1, len(self.scale)), False)
        self.model.eval()

        # We can use custom forward function
        def _test_forward(x, scale):
            if self.args.self_ensemble:
                return utils.x8_forward(x, self.model, self.args.precision)
            elif self.args.chop_forward:
                return utils.chop_forward(x, self.model, scale)
            else:
                return self.model(x)

        timer_test = utils.timer()
        set_name = type(self.loader_test.dataset).__name__

        # ``with`` closes the log even if something outside the per-image
        # ``try`` raises. Writing via ``file=`` avoids swapping sys.stdout,
        # which stayed redirected if ``print`` itself failed.
        with open("TooLargeImages.log", "w") as too_large_log:
            for idx_scale, scale in enumerate(self.scale):
                eval_acc = 0
                validation_loss = 0
                n_evaluated = 0
                self._scale_change(idx_scale, self.loader_test)
                for idx_img, (input_, target, _) in enumerate(self.loader_test):
                    input_, target = self._prepare(input_, target, volatile=True)
                    try:
                        output = _test_forward(input_, scale)
                        psnr = utils.calc_PSNR(output, target, set_name,
                                               self.args.rgb_range, scale)
                        self.ckp.save_results(idx_img, input_, output, target,
                                              scale)
                        loss = self._calc_loss_for_validation(output, target)
                    except Exception:
                        # Presumably CUDA out-of-memory on oversized images
                        # (hence the log name) - TODO confirm. Unlike a bare
                        # ``except``, this lets KeyboardInterrupt/SystemExit
                        # propagate.
                        # note: idx_img + 1 + self.args.n_train (used to be just idx_img)
                        print("image: ", idx_img + 1 + self.args.n_train,
                              file=too_large_log)
                        continue
                    # Accumulate only fully processed images so PSNR and loss
                    # are averaged over the same set.
                    eval_acc += psnr
                    validation_loss += loss
                    n_evaluated += 1

                # Skipped images must not count as zero-valued samples.
                denom = max(n_evaluated, 1)
                self.ckp.log_test_loss[-1, idx_scale] = validation_loss / denom
                self.ckp.log_test[-1, idx_scale] = eval_acc / denom
                best = self.ckp.log_test.max(0)
                performance = 'PSNR: {:.3f}'.format(self.ckp.log_test[-1,
                                                                      idx_scale])
                self.ckp.write_log(
                    '[{} x{}]\t{} (Best: {:.3f} from epoch {})'.format(
                        set_name, scale, performance, best[0][idx_scale],
                        best[1][idx_scale] + 1))

        is_best = bool(best[1][0] + 1 == epoch)

        self.ckp.write_log('Time: {:.2f}s\n'.format(timer_test.toc()),
                           refresh=True)
        # Saves the checkpoint (and, per the original note, also plots).
        self.ckp.save(self, epoch, is_best=is_best)
Exemplo n.º 5
0
def _guess_image_ext(filenames, default='.png'):
    """Return the extension of the first recognised image in *filenames*.

    Names are inspected in sorted order. ``.jpeg``, ``.jpg``, ``.bmp`` and
    ``.png`` are recognised case-insensitively, and the extension is returned
    with its original casing. ``default`` is returned when nothing matches.
    """
    for name in sorted(filenames):
        ext = os.path.splitext(name)[1]
        if ext.lower() in ('.jpeg', '.jpg', '.bmp', '.png'):
            return ext
    return default


def validation(valid_path, result_path, model, scale):
    """Super-resolve every validation image and return mean PSNR / SSIM.

    HR images are read from ``valid_path.split('_LR')[0]``. The matching LR
    image of ``<name><hr_ext>`` is ``<valid_path>/X<scale>/<name>x<scale><lr_ext>``.
    Each extension is taken from the first recognised image file (sorted) in
    its directory, and only HR files with that extension are evaluated.

    Each LR image goes through ``common.set_channel``/``common.np2Tensor`` and
    ImageNet mean/std normalization. It is then run through ``model`` on the
    device of the model's parameters (CUDA if the model has none) and
    un-normalized. The result is scored against the HR image with
    ``utils.calc_PSNR``/``utils.calc_SSIM`` (``shave=scale``) and saved as
    ``<result_path>/<name>_DeFiAN_x<scale>.png``.

    Args:
        valid_path: LR validation directory (its name must contain ``_LR``).
        result_path: directory that receives the super-resolved PNGs.
        model: network mapping a normalized LR batch to an SR batch.
        scale: upscaling factor (selects ``X<scale>`` and the shave border).

    Returns:
        ``(avg_psnr, avg_ssim)`` over all evaluated images.

    Raises:
        ValueError: if no HR image with the detected extension is found.
    """
    model.eval()

    # ImageNet RGB statistics. The inverse maps x -> x * std + mean, written
    # as Normalize(mean=-mean/std, std=1/std) and derived exactly here rather
    # than from rounded constants.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)
    unnormalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std])

    hr_dir = valid_path.split('_LR')[0]
    lr_dir = os.path.join(valid_path, 'X' + str(scale))
    lr_type = _guess_image_ext(os.listdir(lr_dir))
    hr_entries = os.listdir(hr_dir)
    hr_type = _guess_image_ext(hr_entries)
    # Skip non-image entries (e.g. hidden/system files) instead of crashing.
    hr_names = sorted(n for n in hr_entries if os.path.splitext(n)[1] == hr_type)
    if not hr_names:
        raise ValueError('no %s images found in %s' % (hr_type, hr_dir))

    # Run on wherever the model lives rather than assuming CUDA.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cuda')

    psnr_sum = 0
    ssim_sum = 0
    with torch.no_grad():
        for hr_name in tqdm(hr_names):
            img_name = os.path.splitext(hr_name)[0]
            img_hr_rgb = imageio.imread(os.path.join(hr_dir, hr_name))
            img_lr_rgb = imageio.imread(os.path.join(
                lr_dir, img_name + 'x' + str(scale) + lr_type))
            img_lr_rgb, img_hr_rgb = common.set_channel(
                img_lr_rgb, img_hr_rgb, 3)
            img_lr_rgb, img_hr_rgb = common.np2Tensor(
                img_lr_rgb, img_hr_rgb, 255)

            # Only the LR input is normalized and batched; HR stays (C, H, W)
            # on the CPU for scoring.
            lr_batch = normalize(img_lr_rgb).unsqueeze(0).to(device)
            SR = unnormalize(model(lr_batch)[0].cpu())

            psnr_sum += utils.calc_PSNR(SR, img_hr_rgb, rgb_range=255,
                                        shave=scale)
            ssim_sum += utils.calc_SSIM(SR, img_hr_rgb, rgb_range=255,
                                        shave=scale)

            result = SR.mul(255).clamp(0, 255).round()
            result = result.numpy().astype(np.uint8).transpose((1, 2, 0))
            Image.fromarray(result).save(os.path.join(
                result_path, img_name + '_DeFiAN_x' + str(scale) + '.png'))

    count = len(hr_names)
    return psnr_sum / count, ssim_sum / count