Esempio n. 1
0
def test(img_path, img_path2, crop_border, test_y_channel=False):
    """Compare PSNR/SSIM from the NumPy and PyTorch metric implementations.

    Both images are read with OpenCV. Metrics are printed for the NumPy path,
    the PyTorch path on CPU and, when CUDA is available, the PyTorch path on
    GPU (a single image and a batch of two identical copies).

    Args:
        img_path (str): Path to the first image.
        img_path2 (str): Path to the second image.
        crop_border (int): Crop border passed to every metric call.
        test_y_channel (bool): Passed to every metric call to select Y-channel
            evaluation. Default: False.

    Raises:
        FileNotFoundError: If OpenCV cannot read either image.
    """
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None instead of raising.
    for path, data in ((img_path, img), (img_path2, img2)):
        if data is None:
            raise FileNotFoundError(f'Cannot read image: {path}')

    # --------------------- Numpy ---------------------
    psnr = calculate_psnr(img,
                          img2,
                          crop_border=crop_border,
                          input_order='HWC',
                          test_y_channel=test_y_channel)
    ssim = calculate_ssim(img,
                          img2,
                          crop_border=crop_border,
                          input_order='HWC',
                          test_y_channel=test_y_channel)
    print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')

    # --------------------- PyTorch (CPU) ---------------------
    # Scale to [0, 1], convert BGR->RGB and add a leading batch dimension.
    img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
    img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0)

    psnr_pth = calculate_psnr_pt(img,
                                 img2,
                                 crop_border=crop_border,
                                 test_y_channel=test_y_channel)
    ssim_pth = calculate_ssim_pt(img,
                                 img2,
                                 crop_border=crop_border,
                                 test_y_channel=test_y_channel)
    print(
        f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}'
    )

    # --------------------- PyTorch (GPU) ---------------------
    # Calling .cuda() on a machine without CUDA raises; skip instead.
    if not torch.cuda.is_available():
        print('\tTensor (GPU) \tskipped: CUDA is not available.')
        return

    img = img.cuda()
    img2 = img2.cuda()
    psnr_pth = calculate_psnr_pt(img,
                                 img2,
                                 crop_border=crop_border,
                                 test_y_channel=test_y_channel)
    ssim_pth = calculate_ssim_pt(img,
                                 img2,
                                 crop_border=crop_border,
                                 test_y_channel=test_y_channel)
    print(
        f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}'
    )

    # Duplicate each tensor along dim 0 to exercise the batched code path.
    img_batch = torch.repeat_interleave(img, 2, dim=0)
    img2_batch = torch.repeat_interleave(img2, 2, dim=0)
    psnr_pth = calculate_psnr_pt(img_batch,
                                 img2_batch,
                                 crop_border=crop_border,
                                 test_y_channel=test_y_channel)
    ssim_pth = calculate_ssim_pt(img_batch,
                                 img2_batch,
                                 crop_border=crop_border,
                                 test_y_channel=test_y_channel)
    print(
        f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,'
        f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}')
def main():
    """Calculate PSNR and SSIM for images.

    Configurations:
        folder_gt (str): Path to gt (Ground-Truth).
        folder_restored (str): Path to restored images.
        crop_border (int): Crop border for each side.
        suffix (str): Suffix for restored images.
        test_y_channel (bool): If True, test Y channel (In MatLab YCbCr format)
            If False, test RGB channels.

    Every image found recursively under ``folder_gt`` is paired with
    ``folder_restored/<basename><suffix><ext>``. Per-image scores and the
    averages are printed; nothing is returned. If no ground-truth images are
    found, a message is printed instead of the averages.
    """
    # Configurations
    # -------------------------------------------------------------------------
    folder_gt = 'datasets/val_set14/Set14'
    folder_restored = 'results/exp/visualization/val_set14'
    crop_border = 4
    suffix = '_expname'
    test_y_channel = False
    # -------------------------------------------------------------------------

    psnr_all = []
    ssim_all = []
    img_list = sorted(mmcv.scandir(folder_gt, recursive=True))

    if test_y_channel:
        print('Testing Y channel.')
    else:
        print('Testing RGB channels.')

    for i, img_path in enumerate(img_list):
        basename, ext = osp.splitext(osp.basename(img_path))
        # NOTE(review): only the basename is used to build the restored path,
        # so sub-directories under folder_gt are not mirrored in
        # folder_restored.
        img_gt = mmcv.imread(
            osp.join(folder_gt, img_path), flag='unchanged').astype(
                np.float32) / 255.
        img_restored = mmcv.imread(
            osp.join(folder_restored, basename + suffix + ext),
            flag='unchanged').astype(np.float32) / 255.

        # Y-channel conversion only applies to 3-channel images.
        if test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3:
            img_gt = mmcv.bgr2ycbcr(img_gt, y_only=True)
            img_restored = mmcv.bgr2ycbcr(img_restored, y_only=True)

        # calculate PSNR and SSIM (metrics are fed values in [0, 255])
        psnr = calculate_psnr(
            img_gt * 255,
            img_restored * 255,
            crop_border=crop_border,
            input_order='HWC')
        ssim = calculate_ssim(
            img_gt * 255,
            img_restored * 255,
            crop_border=crop_border,
            input_order='HWC')
        print(f'{i+1:3d}: {basename:25}. \tPSNR: {psnr:.6f} dB, '
              f'\tSSIM: {ssim:.6f}')
        psnr_all.append(psnr)
        ssim_all.append(ssim)

    # An empty folder would otherwise raise ZeroDivisionError below.
    if not psnr_all:
        print(f'No images found in {folder_gt}; cannot compute averages.')
        return
    print(f'Average: PSNR: {sum(psnr_all) / len(psnr_all):.6f} dB, '
          f'SSIM: {sum(ssim_all) / len(ssim_all):.6f}')
Esempio n. 3
0
def main(args):
    """Calculate PSNR and SSIM for images.

    Ground-truth images are collected recursively from ``args.gt``. When
    ``args.suffix`` is non-empty, each one is paired with
    ``args.restored/<basename><args.suffix><ext>``. When it is empty, each one
    is paired with the restored image at the same position in the sorted
    listing of ``args.restored``. Per-image scores and the averages are
    printed; nothing is returned.

    Args:
        args: Object exposing ``gt``, ``restored``, ``suffix``,
            ``crop_border``, ``test_y_channel`` and ``correct_mean_var``
            (presumably an ``argparse.Namespace`` -- confirm with the caller).

    Raises:
        ValueError: If ``args.suffix`` is empty and the two folders contain a
            different number of images.
        FileNotFoundError: If OpenCV cannot read an image.
    """

    def _read_float(path):
        # cv2.imread returns None instead of raising on unreadable files.
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise FileNotFoundError(f'Cannot read image: {path}')
        return img.astype(np.float32) / 255.

    psnr_all = []
    ssim_all = []
    img_list_gt = sorted(scandir(args.gt, recursive=True, full_path=True))
    img_list_restored = sorted(scandir(args.restored, recursive=True, full_path=True))

    # Positional pairing is only meaningful when both folders are the same
    # size; otherwise images would be silently mismatched (or IndexError).
    if args.suffix == '' and len(img_list_restored) != len(img_list_gt):
        raise ValueError(
            f'Found {len(img_list_gt)} GT images but {len(img_list_restored)} '
            'restored images; cannot pair them by position without a suffix.')

    if args.test_y_channel:
        print('Testing Y channel.')
    else:
        print('Testing RGB channels.')

    for i, img_path in enumerate(img_list_gt):
        basename, ext = osp.splitext(osp.basename(img_path))
        img_gt = _read_float(img_path)
        if args.suffix == '':
            img_path_restored = img_list_restored[i]
        else:
            img_path_restored = osp.join(args.restored, basename + args.suffix + ext)
        img_restored = _read_float(img_path_restored)

        if args.correct_mean_var:
            # Match the mean/std of each of the first three channels of the
            # restored image to the GT channel (indexing assumes HWC with at
            # least 3 channels). The correction is applied twice.
            for j in range(3):
                mean_gt = np.mean(img_gt[:, :, j])
                std_gt = np.std(img_gt[:, :, j])
                for _ in range(2):
                    channel = img_restored[:, :, j]
                    channel = channel - np.mean(channel) + mean_gt
                    std = np.std(channel)
                    # A constant channel has zero std; rescaling would
                    # produce NaN/inf, so keep the mean-shifted values.
                    if std > 0:
                        channel = channel / std * std_gt
                    img_restored[:, :, j] = channel

        if args.test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3:
            img_gt = bgr2ycbcr(img_gt, y_only=True)
            img_restored = bgr2ycbcr(img_restored, y_only=True)

        # calculate PSNR and SSIM (metrics are fed values in [0, 255])
        psnr = calculate_psnr(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
        ssim = calculate_ssim(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
        print(f'{i+1:3d}: {basename:25}. \tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
        psnr_all.append(psnr)
        ssim_all.append(ssim)
    print(args.gt)
    print(args.restored)
    # An empty GT folder would otherwise raise ZeroDivisionError below.
    if not psnr_all:
        print('No images found; cannot compute averages.')
        return
    print(f'Average: PSNR: {sum(psnr_all) / len(psnr_all):.6f} dB, SSIM: {sum(ssim_all) / len(ssim_all):.6f}')
def main():
    """Calculate PSNR and SSIM for images.

    Configurations:
        folder_gt (str): Path to gt (Ground-Truth).
        folder_restored (str): Path to restored images.
        crop_border (int): Crop border for each side.
        suffix (str): Suffix for restored images.
        test_y_channel (bool): If True, test Y channel (In MatLab YCbCr format)
            If False, test RGB channels.
        correct_mean_var (bool): If True, match the per-channel mean/std of
            the restored image to the ground truth before evaluation.

    Every image found recursively under ``folder_gt`` is paired with
    ``folder_restored/<basename><suffix><ext>``. Per-image scores and the
    averages are printed; nothing is returned.

    Raises:
        FileNotFoundError: If OpenCV cannot read an image.
    """
    # Configurations
    # -------------------------------------------------------------------------
    folder_gt = 'datasets/val_set14/Set14'
    folder_restored = 'results/exp/visualization/val_set14'
    crop_border = 4
    suffix = '_expname'
    test_y_channel = False
    correct_mean_var = False
    # -------------------------------------------------------------------------

    def _read_float(path):
        # cv2.imread returns None instead of raising on unreadable files.
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise FileNotFoundError(f'Cannot read image: {path}')
        return img.astype(np.float32) / 255.

    psnr_all = []
    ssim_all = []
    img_list = sorted(scandir(folder_gt, recursive=True, full_path=True))

    if test_y_channel:
        print('Testing Y channel.')
    else:
        print('Testing RGB channels.')

    for i, img_path in enumerate(img_list):
        basename, ext = osp.splitext(osp.basename(img_path))
        img_gt = _read_float(img_path)
        img_restored = _read_float(
            osp.join(folder_restored, basename + suffix + ext))

        if correct_mean_var:
            # Match the mean/std of each of the first three channels of the
            # restored image to the GT channel (indexing assumes HWC with at
            # least 3 channels). The correction is applied twice.
            for j in range(3):
                mean_gt = np.mean(img_gt[:, :, j])
                std_gt = np.std(img_gt[:, :, j])
                for _ in range(2):
                    channel = img_restored[:, :, j]
                    channel = channel - np.mean(channel) + mean_gt
                    std = np.std(channel)
                    # A constant channel has zero std; rescaling would
                    # produce NaN/inf, so keep the mean-shifted values.
                    if std > 0:
                        channel = channel / std * std_gt
                    img_restored[:, :, j] = channel

        if test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3:
            img_gt = bgr2ycbcr(img_gt, y_only=True)
            img_restored = bgr2ycbcr(img_restored, y_only=True)

        # calculate PSNR and SSIM (metrics are fed values in [0, 255])
        psnr = calculate_psnr(img_gt * 255,
                              img_restored * 255,
                              crop_border=crop_border,
                              input_order='HWC')
        ssim = calculate_ssim(img_gt * 255,
                              img_restored * 255,
                              crop_border=crop_border,
                              input_order='HWC')
        print(f'{i+1:3d}: {basename:25}. \tPSNR: {psnr:.6f} dB, '
              f'\tSSIM: {ssim:.6f}')
        psnr_all.append(psnr)
        ssim_all.append(ssim)
    print(folder_gt)
    print(folder_restored)
    # An empty GT folder would otherwise raise ZeroDivisionError below.
    if not psnr_all:
        print('No images found; cannot compute averages.')
        return
    print(f'Average: PSNR: {sum(psnr_all) / len(psnr_all):.6f} dB, '
          f'SSIM: {sum(ssim_all) / len(ssim_all):.6f}')