Ejemplo n.º 1
0
    def comput_PSNR_SSIM(self, pred, gt, shave_border=0):
        """Return the ``(PSNR, SSIM)`` pair for a prediction and its ground truth.

        Both images are cropped by ``shave_border`` pixels on every side (the
        crop window is derived from the shape of ``pred``) and then compared on
        a single plane: channel 0 of ``rgb2ycbcr`` for 3-channel inputs, or the
        only channel for 1-channel inputs.

        Args:
            pred: Predicted image. A ``torch.Tensor`` is converted with
                ``tensor2np`` (using ``self.opt.rgb_range``) and cast to
                float32; any other input is used as given.
            gt: Ground-truth image, handled the same way as ``pred``.
            shave_border: Number of pixels dropped from each spatial border
                before the metrics are computed.

        Returns:
            Tuple ``(psnr, ssim)`` as returned by ``calc_PSNR`` and
            ``calc_ssim``.

        Raises:
            ValueError: If the images are not both 1-channel or both 3-channel.
        """
        # Only tensors are converted; array inputs keep their dtype.
        if isinstance(pred, torch.Tensor):
            pred = tensor2np(pred, self.opt.rgb_range).astype(np.float32)
        if isinstance(gt, torch.Tensor):
            gt = tensor2np(gt, self.opt.rgb_range).astype(np.float32)

        # One crop window, taken from pred's size, is applied to both images.
        rows, cols = pred.shape[:2]
        window = (slice(shave_border, rows - shave_border),
                  slice(shave_border, cols - shave_border))
        pred = pred[window]
        gt = gt[window]

        pred_channels = pred.shape[2]
        if pred_channels == 3 and gt.shape[2] == 3:
            # Colour input: measure on the first plane of the YCbCr conversion.
            planes = (rgb2ycbcr(pred)[:, :, 0], rgb2ycbcr(gt)[:, :, 0])
        elif pred_channels == 1 and gt.shape[2] == 1:
            planes = (pred[:, :, 0], gt[:, :, 0])
        else:
            raise ValueError('Input or output channel is not 1 or 3!')

        pred_y, gt_y = planes
        psnr_ = calc_PSNR(pred_y, gt_y)
        ssim_ = calc_ssim(pred_y, gt_y)
        return psnr_, ssim_
Ejemplo n.º 2
0
def validation(img, name, save_imgs=False, save_dir=None):
    """Downscale and reconstruct one image batch, then score the first sample.

    The three module-level networks are switched to eval mode, ``img`` is
    downscaled (and quantized to integer levels in [0, 255]), reconstructed by
    ``upscale_net``, and the first sample of the batch is compared against the
    original.

    Args:
        img: Input tensor; the transpose ``(0, 2, 3, 1)`` below requires four
            dimensions. Values are scaled by 255 before conversion to uint8
            (presumably the input is in [0, 1] -- confirm against the caller).
        name: Base file name used for the optional PNG dumps.
        save_imgs: When true (and ``save_dir`` is set), write the original,
            downscaled and reconstructed images to ``save_dir``.
        save_dir: Destination directory for the PNG dumps.

    Returns:
        Tuple ``(psnr, ssim)``: ``utils.cal_psnr`` on the images with ``SCALE``
        pixels cropped from each border, and ``utils.calc_ssim`` on channel 0
        of ``rgb2ycbcr`` with the same crop.
    """
    kernel_generation_net.eval()
    downsampler_net.eval()
    upscale_net.eval()

    def _first_sample_uint8(tensor):
        # Channels-last uint8 copy of the batch, reduced to its first sample
        # with singleton axes squeezed out.
        batch = tensor.detach().cpu().numpy().transpose(0, 2, 3, 1)
        return np.uint8(batch)[0, ...].squeeze()

    # Pure inference: do not record an autograd graph for the forward passes.
    with torch.no_grad():
        kernels, offsets_h, offsets_v = kernel_generation_net(img)
        downscaled = downsampler_net(img, kernels, offsets_h, offsets_v,
                                     OFFSET_UNIT)
        # Quantize the low-resolution image to integer levels before it is
        # fed (rescaled to [0, 1]) to the upscaler.
        downscaled = torch.round(torch.clamp(downscaled, 0, 1) * 255)
        reconstructed = torch.clamp(upscale_net(downscaled / 255.0), 0, 1) * 255

    orig_img = _first_sample_uint8(img * 255)
    downscaled_img = _first_sample_uint8(downscaled)
    recon_img = _first_sample_uint8(reconstructed)

    if save_imgs and save_dir:
        dumps = (
            ('_orig.png', orig_img),
            ('_down.png', downscaled_img),
            ('_recon.png', recon_img),
        )
        for suffix, array in dumps:
            Image.fromarray(array).save(os.path.join(save_dir, name + suffix))

    psnr = utils.cal_psnr(orig_img[SCALE:-SCALE, SCALE:-SCALE, ...],
                          recon_img[SCALE:-SCALE, SCALE:-SCALE, ...],
                          benchmark=BENCHMARK)

    # SSIM is measured on channel 0 of the YCbCr conversion, same border crop.
    orig_img_y = rgb2ycbcr(orig_img)[:, :, 0][SCALE:-SCALE, SCALE:-SCALE, ...]
    recon_img_y = rgb2ycbcr(recon_img)[:, :, 0][SCALE:-SCALE, SCALE:-SCALE, ...]

    ssim = utils.calc_ssim(recon_img_y, orig_img_y)

    return psnr, ssim
Ejemplo n.º 3
0
    def test_step(self, batch, batch_idx):
        """Score one test batch and dump the LR/teacher/student/GT images.

        Runs the student and teacher models on ``batch[0]``, converts the
        input, both predictions and the target ``batch[1]`` to images, computes
        PSNR/SSIM between the target and the student prediction, writes the
        four images as PNG under ``./results/<name>/`` and logs the metrics.

        Args:
            batch: Indexable batch; ``batch[0]`` is the model input and
                ``batch[1]`` the ground truth.
            batch_idx: Index of the batch, used as the file-name prefix.
        """
        lr_input = batch[0]
        hr_pred, _ = self.student_model(lr_input)
        hr_pred_teacher, _ = self.teacher_model(lr_input)
        hr_gt = batch[1]

        # [:, :, ::-1] reverses the last axis of every converted image.
        lr_numpy = torch2image(lr_input)[:, :, ::-1]
        hr_teacher_numpy = torch2image(hr_pred_teacher)[:, :, ::-1]
        hr_pred_numpy = torch2image(hr_pred)[:, :, ::-1]
        hr_gt_numpy = torch2image(hr_gt)[:, :, ::-1]

        psnr = calc_psnr(hr_gt_numpy, hr_pred_numpy)
        ssim = calc_ssim(hr_gt_numpy, hr_pred_numpy)

        # Best effort: any failure to get a key from the first logger falls
        # back to "tmp". ``Exception`` (not a bare except) lets
        # KeyboardInterrupt / SystemExit propagate.
        try:
            name = self.logger.experiment[0].get_key()
        except Exception:
            name = "tmp"

        # exist_ok avoids the check-then-create race when several processes
        # run the test loop at once; intermediate "./results" is created too.
        out_dir = "./results/{}".format(name)
        os.makedirs(out_dir, exist_ok=True)

        lr_image = Image.fromarray(lr_numpy)
        hr_teacher = Image.fromarray(hr_teacher_numpy)
        img_pred = Image.fromarray(hr_pred_numpy)
        gt_pred = Image.fromarray(hr_gt_numpy)

        img_pred.save("{}/{}_pred.png".format(out_dir, batch_idx),
                      format="PNG")
        hr_teacher.save("{}/{}_teacher.png".format(out_dir, batch_idx),
                        format="PNG")
        gt_pred.save("{}/{}_gt.png".format(out_dir, batch_idx), format="PNG")
        lr_image.save("{}/{}_lr.png".format(out_dir, batch_idx), format="PNG")

        self.log_dict({"psnr": psnr, "ssim": ssim})
Ejemplo n.º 4
0
    def test(self, epoch=10):
        """Evaluate the model on ``self.loader_test`` for every upscale factor.

        For each test item the model output is scored against the target (when
        one is present) with ``utils.calc_psnr`` / ``utils.calc_ssim``; the
        per-scale averages and the elapsed time are written to the checkpoint
        log. Outside test-only mode the latest model is saved, and a
        ``<epoch>_best`` snapshot is saved when the score for the last upscale
        factor beats ``self.best_psnr``.

        Args:
            epoch: Epoch number recorded as ``self.best_epoch`` when a new best
                score is reached.
        """
        self.ckp.write_log('=> Evaluation...')
        timer_test = utils.timer()
        upscale = self.args.upscale
        avg_psnr = {scale: 0.0 for scale in upscale}
        avg_ssim = {scale: 0.0 for scale in upscale}

        # Bound before the loop so the summary below cannot hit a NameError
        # when the loader yields nothing.
        has_target = False

        for iteration, (lr_input, hr) in enumerate(self.loader_test, 1):

            # A list-typed ``hr`` is treated as "ground truth available".
            has_target = isinstance(hr, list)

            if has_target:
                lr_input, hr = self.prepare([lr_input, hr])
            else:
                lr_input = self.prepare([lr_input])[0]

            sr = self.model(lr_input)

            save_list = [*sr, lr_input]

            if has_target:
                save_list.extend(hr)

                psnr = {}
                ssim = {}
                for i, scale in enumerate(upscale):
                    psnr[scale] = utils.calc_psnr(hr[i], sr[i], int(scale))
                    ssim[scale] = utils.calc_ssim(hr[i], sr[i])
                    avg_psnr[scale] += psnr[scale]
                    avg_ssim[scale] += ssim[scale]

            if self.args.save:
                if has_target:
                    for scale in upscale:
                        self.ckp.write_log(
                            '=> Image{} PSNR_x{}: {:.4f}'.format(
                                iteration, scale, psnr[scale]))
                        self.ckp.write_log(
                            '=> Image{} SSIM_x{}: {:.4f}'.format(
                                iteration, scale, ssim[scale]))
                self.ckp.save_result(iteration, save_list)

        if has_target:
            n_items = len(self.loader_test)
            for scale, total_psnr in avg_psnr.items():
                self.ckp.write_log("=> PSNR_x{}: {:.4f}".format(
                    scale, total_psnr / n_items))
                self.ckp.write_log("=> SSIM_x{}: {:.4f}".format(
                    scale, avg_ssim[scale] / n_items))

        self.ckp.write_log("=> Total time: {:.1f}s".format(timer_test.toc()))

        if not self.args.test:
            self.ckp.save_model(self.model, 'latest')
            # NOTE(review): despite its name, avg_psnr holds the accumulated
            # sum (it is only divided when logged above), so best_psnr tracks
            # a sum rather than a mean -- confirm this is intended.
            cur_psnr = avg_psnr[upscale[-1]]
            if self.best_psnr < cur_psnr:
                self.best_psnr = cur_psnr
                self.best_epoch = epoch
                self.ckp.save_model(self.model,
                                    '{}_best'.format(self.best_epoch))
Ejemplo n.º 5
0
        # Take the first item of each batch as a numpy array and move its
        # leading axis to the end (transpose (1, 2, 0)).
        hr = hr[0].detach().cpu().numpy()
        sr = sr[0].detach().cpu().numpy()
        sr = np.transpose(sr, (1, 2, 0))
        hr = np.transpose(hr, (1, 2, 0))

    # Cast both images to uint8 and write them as PNG files named after
    # data[2]: the prediction under <data_root>/<args.model>/, the reference
    # under <data_root>/hr/.
    sr = sr.astype(np.uint8)
    Image.fromarray(sr).save(
        os.path.join(data_root, args.model, data[2] + '.png'))
    hr = hr.astype(np.uint8)
    Image.fromarray(hr).save(os.path.join(data_root, "hr", data[2] + '.png'))

    # Divide by 255 so the PSNR call below can be made with rgb_range=1.
    sr = sr / 255.0
    hr = hr / 255.0

    # Per-image metrics; both are accumulated, only PSNR is printed here.
    psnr = utils.calc_psnr(sr, hr, scale=int(args.scale[0]), rgb_range=1.)
    ssim = utils.calc_ssim(sr, hr)
    psnrs.append(psnr)
    ssims.append(ssim)
    print(psnr)
    #print(ssim)
    #print(psnr)

    # plt.subplot(121)
    # plt.imshow(lr.astype(np.uint8))
    #
    # plt.subplot(122)
    # plt.imshow(hr.astype(np.uint8))
    #
    # plt.show()
# Summary over all processed images: mean SSIM first, then mean PSNR.
print(np.mean(np.array(ssims)))
print(np.mean(np.array(psnrs)))
Ejemplo n.º 6
0
    # Load every image found under sr_path and hr_path (read with
    # n_channels=1) into two parallel lists.
    sr_img = []
    hr_img = []

    for img in utils.get_image_paths(sr_path):
        img = utils.imread_uint(img, n_channels=1)
        sr_img.append(img)
    
    for img in utils.get_image_paths(hr_path):
        img = utils.imread_uint(img, n_channels=1)
        hr_img.append(img)

    # NOTE(review): a count mismatch is only printed; execution continues and
    # the loop below indexes hr_img using sr_img's length.
    if len(sr_img) != len(hr_img):
        print('ERROR: The number is not equal!')

    # Accumulate per-image RMSE / PSNR / SSIM, logging each image (1-based).
    mean_rmse = 0
    mean_psnr = 0
    mean_ssim = 0
    for i in range(0, len(sr_img)):
        # Only the first value returned by calc_rmse is used.
        rmse, _ = utils.calc_rmse(sr_img[i], hr_img[i])
        psnr = utils.calc_psnr(sr_img[i], hr_img[i])
        ssim = utils.calc_ssim(sr_img[i], hr_img[i])

        logger.info('Image:{:03d} || RMSE:{} || PSNR:{} || SSIM:{}'.format(i+1, rmse, psnr, ssim))
        mean_rmse += rmse
        mean_psnr += psnr
        mean_ssim += ssim
    # Turn the sums into means; raises ZeroDivisionError when sr_img is empty.
    mean_rmse =  mean_rmse / len(sr_img)
    mean_psnr =  mean_psnr / len(sr_img)
    mean_ssim =  mean_ssim / len(sr_img)
    logger.info('AVG RMSE: {} || PSNR:{} || SSIM:{}'.format(mean_rmse, mean_psnr, mean_ssim))
                                               scale_factor=opt.scale,
                                               mode='bilinear')
                    # Move the batch to the evaluation device.
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # Forward pass without gradient tracking; the two MDSR
                    # variants also receive the scale factor.
                    with torch.no_grad():
                        if network == "MDSR":
                            preds = model(inputs, opt.scale)
                        elif network == "MDSR_blanced_attention":
                            preds = model(inputs, opt.scale)
                        else:
                            preds = model(inputs)

                    # Drop the leading axis, denormalize, and convert to the
                    # Y channel (inputs passed as 'chw').
                    preds = convert_rgb_to_y(denormalize(preds.squeeze(0)),
                                             dim_order='chw')
                    labels = convert_rgb_to_y(denormalize(labels.squeeze(0)),
                                              dim_order='chw')

                    # Crop opt.scale pixels from every border before scoring.
                    preds = preds[opt.scale:-opt.scale, opt.scale:-opt.scale]
                    labels = labels[opt.scale:-opt.scale, opt.scale:-opt.scale]

                    # Update the running meters, weighted by len(inputs).
                    epoch_psnr.update(calc_psnr(preds, labels), len(inputs))
                    epoch_ssim.update(calc_ssim(preds, labels), len(inputs))

                # Report the averaged metrics for this scale/dataset/model.
                print(
                    'scale:{}    dataset:{}   model:{}   eval psnr: {:.6f}   ssim: {:.4f}'
                    .format(str(scale), datasetfortest, network,
                            epoch_psnr.avg, epoch_ssim.avg))

# python train.py --choose_net="IMDN_BLANCED_CBAM"
Ejemplo n.º 8
0
    def test(self):
        """Evaluate the model on ``self.loader_test`` and log PSNR / SSIM.

        Every batch of low-resolution frame sequences is flattened along its
        first two axes, run through the model with no kernel code, and clamped
        to [0, 1]. Unless ``self.args.real`` is set, the output is scored
        against the centre frame of the HR sequence. Images are saved on every
        30th batch when ``save_images`` is on, or on every batch in test-only
        mode. After the loop the averaged metrics are logged and, outside
        test-only mode, the checkpoint is saved (flagged as best when the best
        PSNR was reached in the current epoch).
        """
        self.ckp.write_log('\nEvaluation:')
        self.model.eval()
        self.ckp.start_log(train=False)
        self.ckp.start_log(train=False, key='ssim')
        with torch.no_grad():
            tqdm_test = tqdm(self.loader_test, ncols=80)
            for idx_img, data_pack in enumerate(tqdm_test):
                if self.args.real:
                    lr, filename = data_pack
                else:
                    # The kernels shipped with the batch are not used here.
                    lr, hr, _kernels, filename = data_pack
                # Keep the middle entry of the per-batch filename collection.
                filename = filename[len(filename) // 2]
                # lr: [batch_size, n_seq, 3, patch_size, patch_size]
                if self.args.n_colors == 1 and lr.size()[2] == 3:
                    # Single-colour model fed 3-channel data: keep channel 0.
                    lr = lr[:, :, 0:1, :, :]
                    if not self.args.real:
                        hr = hr[:, :, 0:1, :, :]

                # Divide LR frame sequence [N, n_sequence, n_colors, H, W]
                # -> N * [n_sequence, n_colors, H, W], then concatenate along
                # the first dimension so sequence order is preserved.
                lr = torch.cat(
                    [torch.squeeze(x.to(self.device), dim=0)
                     for x in torch.split(lr, 1, dim=0)],
                    dim=0)
                if not self.args.real:
                    # Only the centre frame of each HR sequence is needed for
                    # scoring, so only that frame is moved to the device.
                    center = self.args.n_sequence // 2
                    center_hr = torch.cat(
                        [x[:, center, :, :, :].to(self.device)
                         for x in torch.split(hr, 1, dim=0)],
                        dim=0)

                # No kernel code is supplied to the model at test time.
                sr, _, _ = self.model(lr, None)
                sr = torch.clamp(sr, min=0.0, max=1.0)
                if not self.args.real:
                    PSNR = utils.calc_psnr(self.args, sr, center_hr)
                    SSIM = utils.calc_ssim(self.args, sr, center_hr)
                    self.ckp.report_log(PSNR, train=False)
                    self.ckp.report_log(SSIM, train=False, key='ssim')

                # Parentheses spell out the original `and`/`or` precedence.
                if (self.args.save_images
                        and idx_img % 30 == 0) or self.args.test_only:
                    filename = filename[0]
                    self.ckp.save_images(filename, [sr], self.args.scale)

            self.ckp.end_log(len(self.loader_test), train=False)
            self.ckp.end_log(len(self.loader_test), train=False, key='ssim')
            best = self.ckp.psnr_log.max(0)
            self.ckp.write_log(
                '[{}]\taverage PSNR: {:.3f} , average SSIM: {:.3f} (Best: {:.3f} @epoch {})'
                .format(self.args.data_test, self.ckp.psnr_log[-1],
                        self.ckp.ssim_log[-1], best[0], best[1] + 1))
            if not self.args.test_only:
                self.ckp.save(self,
                              self.epoch,
                              is_best=(best[1] + 1 == self.epoch))