def main():
    """Run the pretrained DnCNN-3 network over a test set and save the outputs.

    Every image found under ``testsets/<testset_name>`` is read, passed
    through the network loaded from ``model_zoo/dncnn3.pth`` and written as a
    PNG into ``results/<testset_name>_dncnn3``. A log file with the same base
    name is created in that results folder.

    When ``n_channels`` is 3 the image is converted to YCbCr, only the first
    (Y) channel is fed to the single-channel network, and the network output
    replaces that channel before converting back to RGB.
    """

    # ----------------------------------------
    # Preparation
    # ----------------------------------------

    # 'dncnn3' can be used for blind Gaussian denoising, JPEG deblocking
    # (quality factor 5-100) and super-resolution (x234)
    model_name = 'dncnn3'

    # important!
    testset_name = 'bsd68'  # test set, low-quality grayscale/color JPEG images
    n_channels = 1  # set 1 for grayscale image, set 3 for color image

    x8 = False  # default: False, x8 to boost performance
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    result_name = testset_name + '_' + model_name  # fixed
    # L_path: Low-quality grayscale/Y-channel JPEG images
    L_path = os.path.join(testsets, testset_name)
    # E_path: Estimated images
    E_path = os.path.join(results, result_name)
    util.mkdir(E_path)

    model_pool = 'model_zoo'  # fixed
    model_path = os.path.join(model_pool, model_name + '.pth')
    logger_name = result_name
    utils_logger.logger_info(
        logger_name, log_path=os.path.join(E_path, logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------

    from models.network_dncnn import DnCNN as net
    model = net(in_nc=1, out_nc=1, nc=64, nb=20, act_mode='R')
    # map_location keeps loading working on CPU-only machines even if the
    # checkpoint was saved from a CUDA device; the model is moved to `device`
    # right after.
    model.load_state_dict(torch.load(model_path, map_location='cpu'),
                          strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))
    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))

    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)

    for idx, img_path in enumerate(L_paths):

        # ------------------------------------
        # (1) img_L
        # ------------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img_path))
        logger.info('{:->4d}--> {:>10s}'.format(idx + 1, img_name + ext))
        img_L = util.imread_uint(img_path, n_channels=n_channels)
        img_L = util.uint2single(img_L)
        if n_channels == 3:
            # Keep the full YCbCr image so the chroma channels can be reused
            # once the Y channel has been processed.
            ycbcr = util.rgb2ycbcr(img_L, False)
            img_L = ycbcr[..., 0:1]
        img_L = util.single2tensor4(img_L)
        img_L = img_L.to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------
        if x8:
            img_E = utils_model.test_mode(model, img_L, mode=3)
        else:
            img_E = model(img_L)

        img_E = util.tensor2single(img_E)
        if n_channels == 3:
            ycbcr[..., 0] = img_E
            img_E = util.ycbcr2rgb(ycbcr)
        img_E = util.single2uint(img_E)

        # ------------------------------------
        # save results
        # ------------------------------------
        util.imsave(img_E, os.path.join(E_path, img_name + '.png'))
# Example no. 2
# 0
def main():
    """Evaluate DPSR / DPSRGAN on blurred low-resolution images and log PSNR.

    For every kernel type in ``k_type`` the function reads the images in
    ``testsets/<testset>/x<sf>_<kernel type>`` together with a per-image
    ``.mat`` file holding the blur kernel (key ``'kernel'``). It then
    alternates ``iter_num`` times between an FFT-domain update (Eq. (9)) and
    the super-resolver network (Eq. (12)), saves the estimate into
    ``testsets/<testset>/x<sf>_<kernel type>_<suffix>`` and logs the PSNR
    against the ground-truth image found in ``testsets/<testset>/GT``.
    The average PSNR per kernel type is logged at the end; a folder without
    any test image produces a warning instead of a division by zero.
    """

    # --------------------------------
    # let's start!
    # --------------------------------
    utils_logger.logger_info('test_dpsr', log_path='test_dpsr.log')
    logger = logging.getLogger('test_dpsr')

    # basic setting
    # ================================================

    sf = 4  # scale factor
    noise_level_img = 0 / 255.0  # noise level of low quality image, default 0
    noise_level_model = noise_level_img  # noise level of model, default 0
    show_img = True

    use_srganplus = True  # 'True' for SRGAN+ (x4) and 'False' for SRResNet+ (x2,x3,x4)
    testsets = 'testsets'
    testset_current = 'BSD68'

    if use_srganplus and sf == 4:
        model_prefix = 'DPSRGAN'
        save_suffix = 'dpsrgan'
    else:
        model_prefix = 'DPSR'
        save_suffix = 'dpsr'

    model_path = os.path.join('DPSR_models', model_prefix + 'x%01d.pth' % (sf))

    iter_num = 15  # number of iterations, fixed
    n_channels = 3  # only color images, fixed
    border = sf**2  # shave border to calculate PSNR, fixed

    # k_type = ('d', 'm', 'g')
    # Trailing comma matters: ('m') is just the string 'm', not a tuple.
    k_type = ('m',)  # motion blur kernel

    # ================================================

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --------------------------------
    # load model
    # --------------------------------
    model = SRResNet(in_nc=4,
                     out_nc=3,
                     nc=96,
                     nb=16,
                     upscale=sf,
                     act_mode='R',
                     upsample_mode='pixelshuffle')
    # map_location keeps loading working on CPU-only machines even if the
    # checkpoint was saved from a CUDA device.
    model.load_state_dict(torch.load(model_path, map_location='cpu'),
                          strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    logger.info('Model path {:s}. Testing...'.format(model_path))

    # --------------------------------
    # rhos and sigmas
    # --------------------------------
    # The schedule depends only on the constants above, so it is computed
    # once instead of once per image.
    rhos, sigmas = utils_deblur.get_rho_sigma(
        sigma=max(0.255 / 255., noise_level_model), iter_num=iter_num)
    rhos = np.float32(rhos)
    sigmas = np.float32(sigmas)

    # --------------------------------
    # read image (img) and kernel (k)
    # --------------------------------
    test_results = OrderedDict()
    image_exts = ('.jpg', '.bmp', '.png')

    for kernel_name in k_type:

        # --1--> L_folder, folder of Low-quality images
        testsubset_current = 'x%01d_%01s' % (sf, kernel_name)
        L_folder = os.path.join(testsets, testset_current, testsubset_current)

        # --2--> E_folder, folder of Estimated images
        E_folder = os.path.join(testsets, testset_current,
                                testsubset_current + '_' + save_suffix)
        util.mkdir(E_folder)

        # --3--> H_folder, folder of High-quality images
        H_folder = os.path.join(testsets, testset_current, 'GT')

        psnr_key = 'psnr_' + kernel_name
        test_results[psnr_key] = []

        logger.info(L_folder)
        idx = 0

        # os.listdir returns entries in arbitrary order; sort so the log is
        # reproducible between runs.
        for im in sorted(os.listdir(L_folder)):
            if not im.endswith(image_exts):
                continue

            # --------------------------------
            # (1) img_L
            # --------------------------------
            idx += 1
            img_name, ext = os.path.splitext(im)
            img_L = util.imread_uint(os.path.join(L_folder, im),
                                     n_channels=n_channels)
            if show_img:
                util.imshow(img_L)

            np.random.seed(seed=0)  # for reproducibility
            img_L = util.unit2single(img_L) + np.random.normal(
                0, noise_level_img, img_L.shape)

            # --------------------------------
            # (2) kernel
            # --------------------------------
            k = loadmat(os.path.join(L_folder, img_name + '.mat'))['kernel']
            k = k.astype(np.float32)
            k /= np.sum(k)  # normalise the kernel to unit sum

            # --------------------------------
            # (3) get upperleft, denominator
            # --------------------------------
            upperleft, denominator = utils_deblur.get_uperleft_denominator(
                img_L, k)

            # --------------------------------
            # (4) main iteration
            # --------------------------------
            z = img_L

            for i in range(iter_num):

                # --------------------------------
                # step 1, Eq. (9) // FFT
                # --------------------------------
                rho = rhos[i]
                if i != 0:
                    # z is the network output from the previous iteration;
                    # resize it by 1/sf before the FFT step.
                    z = util.imresize_np(z, 1 / sf, True)

                z = np.real(
                    np.fft.ifft2(
                        (upperleft + rho * np.fft.fft2(z, axes=(0, 1))) /
                        (denominator + rho),
                        axes=(0, 1)))

                # --------------------------------
                # step 2, Eq. (12) // super-resolver
                # --------------------------------
                sigma = torch.from_numpy(np.array(sigmas[i]))
                net_input = util.single2tensor4(z)

                # Fourth input channel: constant map filled with sigma.
                noise_level_map = torch.ones(
                    (1, 1, net_input.size(2), net_input.size(3)),
                    dtype=torch.float).mul_(sigma)
                net_input = torch.cat((net_input, noise_level_map), dim=1)
                net_input = net_input.to(device)
                z = model(net_input)
                z = util.tensor2single(z)

            # --------------------------------
            # (5) img_E
            # --------------------------------
            img_E = util.single2uint(z)  # np.uint8((z * 255.0).round())

            # --------------------------------
            # (6) img_H
            # --------------------------------
            # The ground-truth file name is the first 7 characters of the
            # low-quality image name plus '.png'.
            img_H = util.imread_uint(os.path.join(H_folder,
                                                  img_name[:7] + '.png'),
                                     n_channels=n_channels)

            if show_img:
                util.imshow(np.concatenate([img_E, img_H], axis=1),
                            title='Recovered / Ground-truth')

            psnr = util.calculate_psnr(img_E, img_H, border=border)

            logger.info('{:->4d}--> {:>10s}, {:.2f}dB'.format(idx, im, psnr))
            test_results[psnr_key].append(psnr)

            util.imsave(img_E, os.path.join(E_folder, img_name + ext))

        psnr_values = test_results[psnr_key]
        if not psnr_values:
            # Nothing was processed; avoid dividing by zero below.
            logger.warning('No test images found in {}'.format(L_folder))
            continue

        ave_psnr = sum(psnr_values) / len(psnr_values)
        logger.info(
            '------> Average PSNR(RGB) of ({} - {}) is : {:.2f} dB'.format(
                testset_current, testsubset_current, ave_psnr))
def main():
    """Super-resolve a single real image with SRResNet+ / SRGAN+.

    Loads ``testsets/real_imgs/LR/<im>``, appends a constant noise-level map
    as a fourth input channel, runs the network loaded from
    ``DPSR_models/<prefix>x<sf>.pth`` and saves the result as
    ``<name>_x<sf><ext>`` in ``testsets/real_imgs/x<sf>_<suffix>``.
    Progress is logged to ``test_srresnetplus_real.log``.
    """

    # --------------------------------
    # let's start!
    # --------------------------------
    utils_logger.logger_info('test_srresnetplus_real',
                             log_path='test_srresnetplus_real.log')
    logger = logging.getLogger('test_srresnetplus_real')

    # basic setting
    # ================================================

    sf = 4  # from 2, 3 and 4
    noise_level_img = 14./255.  # noise level of low-quality image
    testsets = 'testsets'
    testset_current = 'real_imgs'
    use_srganplus = True  # 'True' for SRGAN+ (x4) and 'False' for SRResNet+ (x2,x3,x4)

    im = 'frog.png'  # frog.png

    # Per-image noise level override.
    if 'frog' in im:
        noise_level_img = 14./255.

    noise_level_model = noise_level_img  # noise level of model

    if use_srganplus and sf == 4:
        model_prefix = 'DPSRGAN'
        save_suffix = 'srganplus'
    else:
        model_prefix = 'DPSR'
        save_suffix = 'srresnet'

    model_path = os.path.join('DPSR_models', model_prefix+'x%01d.pth' % (sf))
    show_img = True
    n_channels = 3  # only color images, fixed

    # ================================================

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --------------------------------
    # (1) load trained model
    # --------------------------------

    model = SRResNet(in_nc=4, out_nc=3, nc=96, nb=16, upscale=sf,
                     act_mode='R', upsample_mode='pixelshuffle')
    # map_location keeps loading working on CPU-only machines even if the
    # checkpoint was saved from a CUDA device.
    model.load_state_dict(torch.load(model_path, map_location='cpu'),
                          strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    logger.info('Model path {:s}. Testing...'.format(model_path))

    # --------------------------------
    # (2) L_folder, E_folder
    # --------------------------------
    # --1--> L_folder, folder of Low-quality images
    L_folder = os.path.join(testsets, testset_current, 'LR')  # L: Low quality

    # --2--> E_folder, folder of Estimated images
    E_folder = os.path.join(testsets, testset_current,
                            'x{:01d}_'.format(sf)+save_suffix)
    util.mkdir(E_folder)

    logger.info(L_folder)

    # --------------------------------
    # (3) load low-resolution image
    # --------------------------------
    img_name, ext = os.path.splitext(im)
    img = util.imread_uint(os.path.join(L_folder, im), n_channels=n_channels)
    h, w = img.shape[:2]
    if show_img:
        util.imshow(img, title='Low-resolution image')
    img = util.uint2single(img)
    img_L = util.single2tensor4(img)

    # --------------------------------
    # (4) do super-resolution
    # --------------------------------
    # Fourth input channel: constant map filled with the model noise level.
    noise_level_map = torch.ones(
        (1, 1, img_L.size(2), img_L.size(3)),
        dtype=torch.float).mul_(noise_level_model)
    img_L = torch.cat((img_L, noise_level_map), dim=1)
    img_L = img_L.to(device)
    img_E = model(img_L)
    img_E = util.tensor2single(img_E)

    # --------------------------------
    # (5) img_E
    # --------------------------------
    # Crop to exactly sf times the input size before converting to uint8.
    img_E = util.single2uint(img_E[:h*sf, :w*sf])

    out_name = img_name+'_x{}'.format(sf)+ext
    logger.info('saving: sf = {}, {}.'.format(sf, out_name))
    util.imsave(img_E, os.path.join(E_folder, out_name))

    if show_img:
        util.imshow(img_E, title='Recovered image')
# Example no. 4
# 0
def main():
    """Evaluate SRResNet+ / SRGAN+ on a test set and log PSNR / SSIM.

    Each image in ``testsets/<testset>/x<sf>`` gets a constant noise-level
    map appended as a fourth channel, is passed through the network loaded
    from ``DPSR_models/<prefix>x<sf>.pth`` and is saved as
    ``<name>_x<sf><ext>`` in ``testsets/<testset>/x<sf>_<suffix>``.

    If the ground-truth folder ``testsets/<testset>/GT`` exists, PSNR and
    SSIM are computed per image (plus Y-channel variants for 3-dimensional
    images) and averaged at the end. If it does not exist, images are only
    super-resolved and saved; no ground-truth is read or displayed.
    """

    # --------------------------------
    # let's start!
    # --------------------------------
    utils_logger.logger_info('test_srresnetplus',
                             log_path='test_srresnetplus.log')
    logger = logging.getLogger('test_srresnetplus')

    # basic setting
    # ================================================

    sf = 4  # scale factor
    noise_level_img = 0 / 255.0  # noise level of L image
    noise_level_model = noise_level_img
    show_img = True

    use_srganplus = True  # 'True' for SRGAN+ (x4) and 'False' for SRResNet+ (x2,x3,x4)
    testsets = 'testsets'
    testset_current = 'Set5'
    n_channels = 3  # only color images, fixed
    border = sf  # shave border to calculate PSNR and SSIM

    if use_srganplus and sf == 4:
        model_prefix = 'DPSRGAN'
        save_suffix = 'dpsrgan'
    else:
        model_prefix = 'DPSR'
        save_suffix = 'dpsr'

    model_path = os.path.join('DPSR_models', model_prefix + 'x%01d.pth' % (sf))

    # --------------------------------
    # L_folder, E_folder, H_folder
    # --------------------------------
    # --1--> L_folder, folder of Low-quality images
    testsubset_current = 'x%01d' % (sf)
    L_folder = os.path.join(testsets, testset_current, testsubset_current)

    # --2--> E_folder, folder of Estimated images
    E_folder = os.path.join(testsets, testset_current,
                            testsubset_current + '_' + save_suffix)
    util.mkdir(E_folder)

    # --3--> H_folder, folder of High-quality images
    H_folder = os.path.join(testsets, testset_current, 'GT')

    need_H = os.path.exists(H_folder)

    # ================================================

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --------------------------------
    # load model
    # --------------------------------
    model = SRResNet(in_nc=4,
                     out_nc=3,
                     nc=96,
                     nb=16,
                     upscale=sf,
                     act_mode='R',
                     upsample_mode='pixelshuffle')
    # map_location keeps loading working on CPU-only machines even if the
    # checkpoint was saved from a CUDA device.
    model.load_state_dict(torch.load(model_path, map_location='cpu'),
                          strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    logger.info('Model path {:s}. \nTesting...'.format(model_path))

    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []

    idx = 0
    image_exts = ('.jpg', '.bmp', '.png')

    logger.info(L_folder)

    # os.listdir returns entries in arbitrary order; sort so the log is
    # reproducible between runs.
    for im in sorted(os.listdir(L_folder)):
        if not im.endswith(image_exts):
            continue

        # Count first so the progress line is 1-based.
        idx += 1
        if not need_H:
            logger.info('{:->4d}--> {:>10s}'.format(idx, im))

        # --------------------------------
        # (1) img_L
        # --------------------------------
        img_name, ext = os.path.splitext(im)
        img = util.imread_uint(os.path.join(L_folder, im),
                               n_channels=n_channels)

        np.random.seed(seed=0)  # for reproducibility
        img = util.unit2single(img) + np.random.normal(
            0, noise_level_img, img.shape)

        if show_img:
            util.imshow(img, title='Low-resolution image')

        img_L = util.single2tensor4(img)
        # Fourth input channel: constant map filled with the noise level.
        noise_level_map = torch.ones(
            (1, 1, img_L.size(2), img_L.size(3)),
            dtype=torch.float).mul_(noise_level_model)
        img_L = torch.cat((img_L, noise_level_map), dim=1)
        img_L = img_L.to(device)

        # --------------------------------
        # (2) img_E
        # --------------------------------
        img_E = model(img_L)
        img_E = util.tensor2single(img_E)
        img_E = util.single2uint(img_E)  # np.uint8((z * 255.0).round())

        img_H = None
        if need_H:

            # --------------------------------
            # (3) img_H
            # --------------------------------
            img_H = util.imread_uint(os.path.join(H_folder, im),
                                     n_channels=n_channels)
            img_H = util.modcrop(img_H, scale=sf)

            # --------------------------------
            # PSNR and SSIM
            # --------------------------------
            psnr = util.calculate_psnr(img_E, img_H, border=border)
            ssim = util.calculate_ssim(img_E, img_H, border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)

            if np.ndim(img_H) == 3:  # RGB image

                img_E_y = util.rgb2ycbcr(img_E, only_y=True)
                img_H_y = util.rgb2ycbcr(img_H, only_y=True)
                psnr_y = util.calculate_psnr(img_E_y,
                                             img_H_y,
                                             border=border)
                ssim_y = util.calculate_ssim(img_E_y,
                                             img_H_y,
                                             border=border)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)

                logger.info(
                    '{:->20s} - PSNR: {:.2f} dB; SSIM: {:.4f}; PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}.'
                    .format(im, psnr, ssim, psnr_y, ssim_y))
            else:
                logger.info(
                    '{:20s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(
                        im, psnr, ssim))

        # --------------------------------
        # save results
        # --------------------------------
        if show_img:
            if img_H is not None:
                util.imshow(np.concatenate([img_E, img_H], axis=1),
                            title='Recovered / Ground-truth')
            else:
                # No ground-truth available: show the estimate on its own
                # instead of referencing an undefined img_H.
                util.imshow(img_E, title='Recovered image')
        util.imsave(
            img_E,
            os.path.join(E_folder, img_name + '_x{}'.format(sf) + ext))

    # Only report averages when at least one image was actually scored.
    if need_H and test_results['psnr']:

        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        logger.info(
            'PSNR/SSIM(RGB) - {} - x{} -- PSNR: {:.2f} dB; SSIM: {:.4f}'.
            format(testset_current, sf, ave_psnr, ave_ssim))
        # Y-channel metrics exist only for images that were 3-dimensional.
        if test_results['psnr_y']:
            ave_psnr_y = sum(test_results['psnr_y']) / len(
                test_results['psnr_y'])
            ave_ssim_y = sum(test_results['ssim_y']) / len(
                test_results['ssim_y'])
            logger.info(
                'PSNR/SSIM( Y ) - {} - x{} -- PSNR: {:.2f} dB; SSIM: {:.4f}'.
                format(testset_current, sf, ave_psnr_y, ave_ssim_y))
# Example no. 5
# 0
def main():
    """Run the iterative DPSR scheme on every file of a folder of real images.

    Command-line arguments are parsed from the module-level ``parser`` and
    stored in the global ``arg``; this function reads ``arg.sf`` (scale
    factor), ``arg.load`` (input folder) and ``arg.save`` (output folder).

    Each input image is boundary-padded, then refined for ``iter_num``
    iterations that alternate an FFT-domain update (Eq. (9)) with the
    super-resolver network (Eq. (12)), using a fixed 5x5 Gaussian kernel.
    The result is cropped to ``sf`` times the original size and saved under
    the original file name in the output folder. Progress is logged to
    ``test_dpsr_real.log``.
    """

    # --------------------------------
    # let's start!
    # --------------------------------
    utils_logger.logger_info('test_dpsr_real', log_path='test_dpsr_real.log')
    logger = logging.getLogger('test_dpsr_real')
    global arg
    arg = parser.parse_args()

    # basic setting
    # ================================================
    sf = arg.sf
    show_img = False
    noise_level_img = 8. / 255.

    use_srganplus = False
    if use_srganplus and sf == 4:
        model_prefix = 'DPSRGAN'
        save_suffix = 'dpsrgan'
    else:
        model_prefix = 'DPSR'
        save_suffix = 'dpsr'

    model_path = os.path.join('DPSR_models', model_prefix + 'x%01d.pth' % (sf))

    n_channels = 3  # only color images, fixed

    # ================================================

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --------------------------------
    # (1) load trained model
    # --------------------------------

    model = SRResNet(in_nc=4,
                     out_nc=3,
                     nc=96,
                     nb=16,
                     upscale=sf,
                     act_mode='R',
                     upsample_mode='pixelshuffle')
    # map_location keeps loading working on CPU-only machines even if the
    # checkpoint was saved from a CUDA device.
    model.load_state_dict(torch.load(model_path, map_location='cpu'),
                          strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    logger.info('Model path {:s}. Testing...'.format(model_path))

    # --------------------------------
    # (2) L_folder, E_folder
    # --------------------------------
    # --1--> L_folder, folder of Low-quality images
    L_folder = os.path.join(arg.load)  # L: Low quality

    # --2--> E_folder, folder of Estimated images
    E_folder = os.path.join(arg.save)
    util.mkdir(E_folder)

    logger.info(L_folder)

    # --------------------------------
    # (3) blur kernel, rhos and sigmas
    # --------------------------------
    # NOTE: per-image kernels ('<name>_kernel.mat' / '<name>_kernel.png') are
    # not loaded; the same Gaussian kernel and a 5-iteration schedule are
    # used for every image. They depend on nothing image-specific, so they
    # are built once here rather than once per image. The earlier
    # `iter_num = 15` default was always overridden by 5 and has been removed.
    k = utils_deblur.fspecial('gaussian', 5, 0.25)
    iter_num = 5  # number of iterations

    rhos, sigmas = utils_deblur.get_rho_sigma(
        sigma=max(0.255 / 255.0, noise_level_img), iter_num=iter_num)
    rhos = np.float32(rhos)
    sigmas = np.float32(sigmas)

    # --------------------------------
    # (4) process every file of L_folder
    # --------------------------------
    # os.listdir returns entries in arbitrary order; sort so the log is
    # reproducible between runs.
    for im in sorted(os.listdir(L_folder)):
        # os.listdir yields bare file names, so the stem is the image name.
        img_name, ext = os.path.splitext(im)
        img = util.imread_uint(os.path.join(L_folder, im),
                               n_channels=n_channels)
        h, w = img.shape[:2]  # remembered to crop the padded result later
        if show_img:
            util.imshow(img, title='Low-resolution image')
        img = util.unit2single(img)

        # --------------------------------
        # (5) handle boundary
        # --------------------------------
        padded_size = utils_deblur.opt_fft_size(
            [img.shape[0] + k.shape[0] + 1,
             img.shape[1] + k.shape[1] + 1])
        img = utils_deblur.wrap_boundary_liu(img, padded_size)

        # --------------------------------
        # (6) get upperleft, denominator
        # --------------------------------
        upperleft, denominator = utils_deblur.get_uperleft_denominator(img, k)

        # --------------------------------
        # (7) main iteration
        # --------------------------------
        z = img

        for i in range(iter_num):

            logger.info('Iter: {:->4d}--> {}'.format(i + 1, im))
            # --------------------------------
            # step 1, Eq. (9) // FFT
            # --------------------------------
            rho = rhos[i]
            if i != 0:
                # z is the network output from the previous iteration;
                # resize it by 1/sf before the FFT step.
                z = util.imresize_np(z, 1 / sf, True)

            z = np.real(
                np.fft.ifft2((upperleft + rho * np.fft.fft2(z, axes=(0, 1))) /
                             (denominator + rho),
                             axes=(0, 1)))

            # --------------------------------
            # step 2, Eq. (12) // super-resolver
            # --------------------------------
            sigma = torch.from_numpy(np.array(sigmas[i]))
            net_input = util.single2tensor4(z)

            # Fourth input channel: constant map filled with sigma.
            noise_level_map = torch.ones(
                (1, 1, net_input.size(2), net_input.size(3)),
                dtype=torch.float).mul_(sigma)
            net_input = torch.cat((net_input, noise_level_map), dim=1)
            net_input = net_input.to(device)
            z = model(net_input)
            z = util.tensor2single(z)

        # --------------------------------
        # (8) img_E
        # --------------------------------
        # Drop the boundary padding: keep only sf times the original size.
        img_E = util.single2uint(z[:h * sf, :w * sf])

        # Log the file name that is actually written (no '_x<sf>' suffix).
        out_name = img_name + ext
        logger.info('saving: sf = {}, {}.'.format(sf, out_name))
        util.imsave(img_E, os.path.join(E_folder, out_name))

        if show_img:
            util.imshow(img_E, title='Recovered image')