Exemple #1
0
def main(json_path='options/val_tsms.json'):
    """Evaluate a stage-2 model on the configured validation data and log PSNR.

    Parses the option JSON (path overridable with ``-opt``), builds a
    ``DataLoader`` for the dataset(s) in ``opt['datasets']``, runs the model
    on every sample, saves each estimate under ``opt['path']['images']`` and
    logs per-image and average PSNR to ``<opt['path']['log']>/val_msmd_patch.log``.

    Args:
        json_path: default path of the option JSON file.

    Raises:
        ValueError: if ``opt['datasets']`` contains no dataset.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        default=json_path,
                        help='Path to option JSON file.')

    # NOTE(review): parsed with is_train=True although this is validation;
    # kept unchanged because option.parse may depend on it to set up paths.
    opt = option.parse(parser.parse_args().opt, is_train=True)

    logger_name = 'val_msmd_patch'
    utils_logger.logger_info(
        logger_name, os.path.join(opt['path']['log'], logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    # Every configured dataset overwrites test_loader, so only the last one is
    # evaluated. Start from None so that an empty configuration fails with a
    # clear error instead of a NameError further down.
    test_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        test_set = define_Dataset(phase, dataset_opt)
        test_loader = DataLoader(test_set,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=1,
                                 drop_last=False,
                                 pin_memory=True)
    if test_loader is None:
        raise ValueError("no dataset configured in opt['datasets']")

    model = define_Model(opt, stage2=True)
    model.load()
    avg_psnr = 0.0
    idx = 0

    for test_data in test_loader:
        idx += 1
        # NOTE(review): '.png' is appended to the basename of L_path, which may
        # already carry an extension (e.g. 'a.png' -> 'a.png.png'); confirm intent.
        image_name = os.path.basename(test_data['L_path'][0]) + '.png'
        save_img_path = os.path.join(opt['path']['images'], image_name)

        model.feed_data(test_data)
        model.test()

        visuals = model.current_visuals()
        E_img = util.tensor2uint(visuals['E'])
        H_img = util.tensor2uint(visuals['H'])

        # -----------------------
        # save estimated image E
        # -----------------------
        util.imsave(E_img, save_img_path)

        # -----------------------
        # calculate PSNR of the estimate against the ground truth H
        # -----------------------
        current_psnr = util.calculate_psnr(E_img, H_img, border=4)
        logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(
            idx, image_name, current_psnr))

        avg_psnr += current_psnr

    if idx == 0:
        # An empty loader would otherwise raise ZeroDivisionError below.
        logger.info('No test samples; average PSNR not computed.')
        return

    avg_psnr = avg_psnr / idx

    # testing log
    message_te = '\tVal_PSNR_avg: {:<.2f}dB'.format(avg_psnr)
    logger.info(message_te)
Exemple #2
0
 def prepare_visuals(self):
     """Store visuals of the first sample of the current batch in ``self.out_dict``.

     ``y``, ``dx`` and ``y_gt`` are converted with ``util.tensor2uint``;
     ``d`` is kept as a detached float CPU tensor; ``path`` is the first
     entry of ``self.path``.
     """
     def first_sample(batch):
         # sample 0 of the batch, detached from autograd, as float on the CPU
         return batch[0].detach().float().cpu()

     visuals = {}
     self.out_dict = visuals
     visuals['y'] = util.tensor2uint(first_sample(self.y))
     visuals['dx'] = util.tensor2uint(first_sample(self.dx))
     visuals['d'] = first_sample(self.d)
     visuals['y_gt'] = util.tensor2uint(first_sample(self.y_gt))
     visuals['path'] = self.path[0]
Exemple #3
0
def predict(img, n_channels, model, x8):
    """Enhance one image, save it to ``testresults/`` and print its BRISQUE score.

    Args:
        img: image reference handed to ``load_image`` together with ``n_channels``.
        n_channels: number of channels to load.
        model: network applied to the loaded tensor.
        x8: if true, run ``utils_model.test_mode(model, img_L, mode=3)``
            instead of a plain forward pass.
    """
    import torch  # local import: only needed for no_grad

    img_name, ext, img_L = load_image(img, n_channels)
    img_L = img_L.to(device)

    # Prediction: inference only, so do not build an autograd graph.
    with torch.no_grad():
        if x8:
            img_E = utils_model.test_mode(model, img_L, mode=3)
        else:
            img_E = model(img_L)

    img_E = util.tensor2uint(img_E)

    # Save image. Create the output directory first; otherwise the save (and
    # the BRISQUE read of the saved file below) can fail on a fresh checkout.
    out_path = 'testresults'
    os.makedirs(out_path, exist_ok=True)
    save_path = os.path.join(out_path, img_name + ext)
    util.imsave(img_E, save_path)
    print('Brisque Score of enhanced image : %f' % (brisq.get_score(save_path)))
    print('*-----------------------------------------------*')
Exemple #4
0
def main():
    """Benchmark USRNet/USRGAN on a test set under synthetic degradations.

    For every scale factor in ``test_sf`` and every kernel stored in
    ``kernels/kernels_12.mat``, each image of ``testsets/<testset_name>`` is
    blurred (circular convolution), downsampled, optionally corrupted with
    AWGN, restored by the model and compared with the original using PSNR.
    Images and logs are written to ``results/<testset_name>_<model_name>``.
    """

    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    model_name = 'usrnet'  # 'usrgan' | 'usrnet' | 'usrgan_tiny' | 'usrnet_tiny'
    testset_name = 'set5'  # test set,  'set5' | 'srbsd68'
    # scale factors, from {1,2,3,4}; GAN variants are only evaluated at x4
    test_sf = [4] if 'gan' in model_name else [2, 3, 4]

    show_img = False  # default: False
    save_L = True  # save LR image
    save_E = True  # save estimated image
    save_LEH = False  # save zoomed LR, E and H images

    # ----------------------------------------
    # load testing kernels
    # ----------------------------------------
    # kernels = hdf5storage.loadmat(os.path.join('kernels', 'kernels.mat'))['kernels']
    kernels = loadmat(os.path.join('kernels', 'kernels_12.mat'))['kernels']

    n_channels = 1 if 'gray' in model_name else 3  # 3 for color image, 1 for grayscale image
    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    noise_level_img = 0  # fixed: 0, noise level for LR image
    noise_level_model = noise_level_img  # fixed, noise level of model, default 0
    result_name = testset_name + '_' + model_name
    model_path = os.path.join(model_pool, model_name + '.pth')

    # ----------------------------------------
    # L_path = H_path, E_path, logger
    # ----------------------------------------
    # L_path and H_path, fixed, for Low-quality images
    L_path = os.path.join(testsets, testset_name)
    # E_path, fixed, for Estimated images
    E_path = os.path.join(results, result_name)
    util.mkdir(E_path)

    logger_name = result_name
    utils_logger.logger_info(logger_name,
                             log_path=os.path.join(E_path,
                                                   logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------
    # the tiny variants use fewer iterations and narrower feature maps
    if 'tiny' in model_name:
        n_iter, h_nc, nc = 6, 32, [16, 32, 64, 64]
    else:
        n_iter, h_nc, nc = 8, 64, [64, 128, 256, 512]
    model = net(n_iter=n_iter,
                h_nc=h_nc,
                in_nc=4,
                out_nc=3,
                nc=nc,
                nb=2,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")

    # map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
    # machine; the model is moved to `device` below.
    model.load_state_dict(torch.load(model_path, map_location='cpu'),
                          strict=True)
    model.eval()
    for v in model.parameters():
        v.requires_grad = False
    number_parameters = sum(p.numel() for p in model.parameters())
    model = model.to(device)

    logger.info('Model path: {:s}'.format(model_path))
    logger.info('Params number: {}'.format(number_parameters))
    logger.info('Model_name:{}, image sigma:{}'.format(model_name,
                                                       noise_level_img))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)
    if not L_paths:
        # Averages below would otherwise divide by zero.
        logger.info('No images found in {}; nothing to evaluate.'.format(L_path))
        return

    # --------------------------------
    # read images
    # --------------------------------
    test_results_ave = OrderedDict()
    test_results_ave['psnr_sf_k'] = []

    for sf in test_sf:

        for k_index in range(kernels.shape[1]):
            # 1-based kernel id, used consistently in file names and logs
            k_id = k_index + 1

            test_results = OrderedDict()
            test_results['psnr'] = []
            kernel = kernels[0, k_index].astype(np.float64)

            ## other kernels
            # kernel = utils_deblur.blurkernel_synthesis(h=25)  # motion kernel
            # kernel = utils_deblur.fspecial('gaussian', 25, 1.6) # Gaussian kernel
            # kernel = sr.shift_pixel(kernel, sf)  # pixel shift; optional
            # kernel /= np.sum(kernel)

            if show_img:
                util.surf(kernel)

            for idx, img in enumerate(L_paths, start=1):

                # --------------------------------
                # (1) classical degradation, img_L
                # --------------------------------
                img_name, ext = os.path.splitext(os.path.basename(img))
                img_H = util.imread_uint(
                    img, n_channels=n_channels)  # HR image, int8
                # crop so the size is a multiple of lcm(sf, 8)
                img_H = util.modcrop(img_H, np.lcm(sf, 8))

                # generate degraded LR image: circular blur, then the
                # standard s-fold downsampler
                img_L = ndimage.convolve(img_H,
                                         kernel[..., np.newaxis],
                                         mode='wrap')
                img_L = sr.downsample_np(img_L, sf, center=False)
                img_L = util.uint2single(img_L)

                np.random.seed(seed=0)  # for reproducibility
                img_L += np.random.normal(0, noise_level_img,
                                          img_L.shape)  # add AWGN

                if show_img:
                    util.imshow(util.single2uint(img_L))

                x = util.single2tensor4(img_L)
                k = util.single2tensor4(kernel[..., np.newaxis])
                sigma = torch.tensor(noise_level_model).float().view(
                    [1, 1, 1, 1])
                [x, k, sigma] = [el.to(device) for el in [x, k, sigma]]

                # --------------------------------
                # (2) inference
                # --------------------------------
                with torch.no_grad():
                    x = model(x, k, sf, sigma)

                # --------------------------------
                # (3) img_E
                # --------------------------------
                img_E = util.tensor2uint(x)

                # common file-name prefix for every image saved below
                prefix = img_name + '_x' + str(sf) + '_k' + str(k_id)

                if save_E:
                    util.imsave(
                        img_E,
                        os.path.join(E_path,
                                     prefix + '_' + model_name + '.png'))

                # --------------------------------
                # (4) img_LEH
                # --------------------------------
                img_L = util.single2uint(img_L)
                if save_LEH:
                    # kernel visualisation, enlarged 3x, pasted top-right
                    k_v = kernel / np.max(kernel) * 1.2
                    k_v = util.single2uint(
                        np.tile(k_v[..., np.newaxis], [1, 1, 3]))
                    k_v = cv2.resize(k_v, (3 * k_v.shape[1], 3 * k_v.shape[0]),
                                     interpolation=cv2.INTER_NEAREST)
                    # nearest-neighbour upscaled LR with the LR image top-left
                    img_I = cv2.resize(
                        img_L, (sf * img_L.shape[1], sf * img_L.shape[0]),
                        interpolation=cv2.INTER_NEAREST)
                    img_I[:k_v.shape[0], -k_v.shape[1]:, :] = k_v
                    img_I[:img_L.shape[0], :img_L.shape[1], :] = img_L
                    img_LEH = np.concatenate([img_I, img_E, img_H], axis=1)
                    if show_img:
                        util.imshow(img_LEH,
                                    title='LR / Recovered / Ground-truth')
                    util.imsave(img_LEH,
                                os.path.join(E_path, prefix + '_LEH.png'))

                if save_L:
                    util.imsave(img_L,
                                os.path.join(E_path, prefix + '_LR.png'))

                psnr = util.calculate_psnr(
                    img_E, img_H, border=sf**2)  # change with your own border
                test_results['psnr'].append(psnr)
                logger.info(
                    '{:->4d}--> {:>10s} -- x{:>2d} --k{:>2d} PSNR: {:.2f}dB'.
                    format(idx, img_name + ext, sf, k_id, psnr))

            ave_psnr_k = sum(test_results['psnr']) / len(test_results['psnr'])
            logger.info(
                '------> Average PSNR(RGB) of ({}) scale factor: ({}), kernel: ({}) sigma: ({}): {:.2f} dB'
                .format(testset_name, sf, k_id, noise_level_model,
                        ave_psnr_k))
            test_results_ave['psnr_sf_k'].append(ave_psnr_k)
    logger.info(test_results_ave['psnr_sf_k'])
Exemple #5
0
def main():
    """Interactively deblur images with a grid of spatially varying PSFs.

    Loads a (gx, gy) grid of PSFs, a patch-wise deblurring model and its
    per-grid-cell alpha/beta parameters. For every matching ``.bmp`` image,
    it restores the top-left ``num_patch`` block of grid cells and shows the
    padded input and the result in OpenCV windows. Press ``q`` to stop; any
    other key moves on to the next image.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    PSF_grid = np.load('./data/AC254-075-A-ML-Zemax(ZMX).npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)

    gx, gy = PSF_grid.shape[:2]
    # Normalise each cell's PSF (PSF_grid[xx, yy]) to unit sum over its first
    # two axes; vectorised form of the per-cell loop.
    PSF_grid = PSF_grid / np.sum(PSF_grid, axis=(2, 3), keepdims=True)

    # ----------------------------------------
    # load model
    # ----------------------------------------
    stage = 8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=stage,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")

    model_code = 'iter17000'
    # map_location='cpu' allows loading on machines without the GPU used for
    # saving; the model is moved to `device` afterwards.
    loaded_state = torch.load(
        '/home/xiu/databag/deblur/models/ZEMAX/uabcnet_{}.pth'.format(
            model_code),
        map_location='cpu')
    model.load_state_dict(loaded_state, strict=True)

    model.eval()
    for v in model.parameters():
        v.requires_grad = False
    model = model.to(device)

    # positional alpha-beta parameters for HQS, one set per grid cell. They do
    # not depend on the image, so they are loaded once, not once per image.
    ab_numpy = np.loadtxt(
        '/home/xiu/databag/deblur/models/ZEMAX/ab_{}.txt'.format(
            model_code)).astype(np.float32).reshape(gx, gy, stage * 2, 3)
    ab = torch.tensor(ab_numpy, device=device, requires_grad=False)

    # restore the num_patch[0] x num_patch[1] block of grid cells that starts
    # at cell (px_start, py_start)
    num_patch = [6, 8]
    px_start = 0
    py_start = 0

    img_names = sorted(glob.glob(
        '/home/xiu/databag/deblur/ICCV2021/suo_image/*/AC254-075-A-ML-Zemax(ZMX).bmp'
    ))
    for img_id, img_name in enumerate(img_names):
        img_L = cv2.imread(img_name)
        if img_L is None:
            # cv2.imread returns None (it does not raise) for unreadable files
            print('Cannot read {}, skipping'.format(img_name))
            continue
        img_L = img_L.astype(np.float32)
        # W is the row count, H the column count of the image
        W, H = img_L.shape[:2]

        t0 = time.time()

        PSF_patch = PSF_grid[px_start:px_start + num_patch[0],
                             py_start:py_start + num_patch[1]]
        patch_L = img_L[px_start * W // gx:(px_start + num_patch[0]) * W // gx,
                        py_start * H // gy:(py_start + num_patch[1]) * H //
                        gy, :]

        p_W, p_H = patch_L.shape[:2]
        # padding: the larger of half the PSF size and p_W // 16
        block_expand = max(PSF_grid.shape[2] // 2, p_W // 16)
        patch_L_wrap = util_deblur.wrap_boundary_liu(
            patch_L, (p_W + block_expand * 2, p_H + block_expand * 2))
        # centralize: move the trailing padding in front so the patch is centred
        patch_L_wrap = np.hstack((patch_L_wrap[:, -block_expand:, :],
                                  patch_L_wrap[:, :p_H + block_expand, :]))
        patch_L_wrap = np.vstack((patch_L_wrap[-block_expand:, :, :],
                                  patch_L_wrap[:p_W + block_expand, :, :]))
        x = util.uint2single(patch_L_wrap)
        x = util.single2tensor4(x)

        # per-cell kernels, stacked with w_ varying fastest
        k_all = []
        for h_ in range(num_patch[1]):
            for w_ in range(num_patch[0]):
                k_all.append(util.single2tensor4(PSF_patch[w_, h_]))
        k = torch.cat(k_all, dim=0)

        [x, k] = [el.to(device) for el in [x, k]]

        # per-cell alpha/beta, stacked in the same order as the kernels
        ab_patch = F.softplus(ab[px_start:px_start + num_patch[0],
                                 py_start:py_start + num_patch[1]])
        cd = []
        for h_ in range(num_patch[1]):
            for w_ in range(num_patch[0]):
                cd.append(ab_patch[w_:w_ + 1, h_])
        cd = torch.cat(cd, dim=0)

        with torch.no_grad():
            x_E = model.forward_patchwise(x, k, cd, num_patch,
                                          [W // gx, H // gy])
        # remove the padding again
        x_E = x_E[..., block_expand:block_expand + p_W,
                  block_expand:block_expand + p_H]

        patch_L = patch_L_wrap.astype(np.uint8)  # padded input, for display

        patch_E = util.tensor2uint(x_E)

        t1 = time.time()

        print('[{}/{}]: {} s per frame'.format(img_id, len(img_names),
                                               t1 - t0))

        cv2.imshow('res', patch_E.astype(np.uint8))
        cv2.imshow('input', patch_L.astype(np.uint8))

        key = cv2.waitKey(-1)
        if key == ord('q'):
            break

    cv2.destroyAllWindows()
Exemple #6
0
# Fragment of a test-time adaptation script. It relies on names defined
# outside this fragment (test_path, train_path, n_channels, sigma, model,
# device, lr) and continues with a training loop that is not shown here.

# load test image (clean reference)
x = util.imread_uint(test_path, n_channels=n_channels)
orig_im = x.squeeze()  # kept as the clean reference for PSNR
x = util.uint2single(x)
np.random.seed(seed=0)  # for reproducibility
y = x + np.random.normal(0, sigma/255., x.shape) # add gaussian noise; sigma is divided by 255, presumably because uint2single scales to [0, 1] — TODO confirm
y = util.single2tensor4(y)
y = y.to(device)

# denoise the image to compare PSNR before and after adaptation
# NOTE(review): the model's train/eval mode at this point is set outside this fragment
with torch.no_grad():
  x_ = model(y)

# compute PSNR of the pre-adaptation result against the clean image
denoised_im = util.tensor2uint(x_)
prev_psnr = util.calculate_psnr(denoised_im, orig_im, border=0)

# external adaptation: fine-tune the model on a separate training image

# open train image
x_train = util.imread_uint(train_path, n_channels=n_channels)
x_train = util.uint2single(x_train)
x_train_comp = util.single2tensor4(x_train)
x_train_comp = x_train_comp.to(device)

model.train()  # switch to training mode for the adaptation step
optimizer = optim.Adam(model.parameters(), lr=lr)

# training loop (continues beyond this fragment)
start_time = time.time()
Exemple #7
0
def main():
    """Run x4 blind super-resolution on the test sets and save the results.

    For each model in ``model_names`` and each test set in ``testset_Ls``,
    every image in ``testsets/<testset>`` is upscaled. If ``save_results``
    is true, the result is written to ``testsets/<testset>_results_x<sf>``.
    Progress is logged to ``blind_sr_log.log``.
    """

    utils_logger.logger_info('blind_sr_log', log_path='blind_sr_log.log')
    logger = logging.getLogger('blind_sr_log')

#    print(torch.__version__)               # pytorch version
#    print(torch.version.cuda)              # cuda version
#    print(torch.backends.cudnn.version())  # cudnn version

    testsets = 'testsets'       # fixed, set path of testsets
    testset_Ls = ['RealSRSet']  # ['RealSRSet','DPED']

    # other models: 'RRDB', 'ESRGAN', 'FSSR_DPED', 'FSSR_JPEG', 'RealSR_DPED', 'RealSR_JPEG'
    model_names = ['BSRGAN']

    save_results = True
    sf = 4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for model_name in model_names:

        model_path = os.path.join('model_zoo', model_name+'.pth')          # set model path
        logger.info('{:>16s} : {:s}'.format('Model Name', model_name))

        # torch.cuda.set_device(0)      # set GPU ID
        if device.type == 'cuda':
            logger.info('{:>16s} : {:<d}'.format('GPU ID', torch.cuda.current_device()))
        else:
            # torch.cuda.current_device() raises when CUDA is unavailable
            logger.info('{:>16s} : {:s}'.format('Device', 'cpu'))
        torch.cuda.empty_cache()

        # --------------------------------
        # define network and load model
        # --------------------------------
        model = net(in_nc=3, out_nc=3, nf=64, nb=23, gc=32)  # define network

        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
        # machine; the model is moved to `device` below.
        model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
        model.eval()
        for v in model.parameters():
            v.requires_grad = False
        model = model.to(device)
        torch.cuda.empty_cache()

        for testset_L in testset_Ls:

            L_path = os.path.join(testsets, testset_L)
            E_path = os.path.join(testsets, testset_L+'_results_x'+str(sf))
            util.mkdir(E_path)

            logger.info('{:>16s} : {:s}'.format('Input Path', L_path))
            logger.info('{:>16s} : {:s}'.format('Output Path', E_path))

            for idx, img in enumerate(util.get_image_paths(L_path), start=1):

                # --------------------------------
                # (1) img_L
                # --------------------------------
                img_name, ext = os.path.splitext(os.path.basename(img))
                logger.info('{:->4d} --> {:<s} --> x{:<d}--> {:<s}'.format(idx, model_name, sf, img_name+ext))

                img_L = util.imread_uint(img, n_channels=3)
                img_L = util.uint2tensor4(img_L)
                img_L = img_L.to(device)

                # --------------------------------
                # (2) inference
                # --------------------------------
                with torch.no_grad():
                    img_E = model(img_L)

                # --------------------------------
                # (3) img_E
                # --------------------------------
                img_E = util.tensor2uint(img_E)
                if save_results:
                    util.imsave(img_E, os.path.join(E_path, img_name+'_'+model_name+'.png'))
Exemple #8
0
def main():
    """Super-resolve real low-resolution images with a plug-and-play denoiser prior.

    Every image in ``testsets/<testset_name>`` is upscaled by ``sf`` over
    ``iter_num`` iterations. Each iteration alternates a data step
    (``sr.data_solution``) with a denoiser step (DRUNet or IRCNN). Results
    are written to ``results/<result_name>``.

    In real applications, set proper values for
    - ``noise_level_img``: from [3, 25]; 3 for a clean image, try 15 for very
      noisy LR images;
    - ``k`` (or ``kernel_width``): the blur kernel is very important;
      ``kernel_width`` from [0.6, 3.0]
    to get the best performance.

    Raises:
        ValueError: if ``model_name`` names neither a DRUNet nor an IRCNN model.
    """
    ##############################################################################

    testset_name = 'Set3C'  # set test set,  'set5' | 'srbsd68'
    noise_level_img = 3  # set noise level of image, from [3, 25], set 3 for clean image
    model_name = 'drunet_color'  # set denoiser, | 'drunet_color' | 'ircnn_gray' | 'drunet_gray' | 'ircnn_color'
    sf = 2  # set scale factor, 1, 2, 3, 4
    iter_num = 24  # set number of iterations, default: 24 for SISR

    # --------------------------------
    # set blur kernel
    # --------------------------------
    kernel_width_default_x1234 = [0.6, 0.9, 1.7, 2.2]  # Gaussian kernel widths for x1, x2, x3, x4
    noise_level_model = noise_level_img / 255.  # noise level of model
    kernel_width = kernel_width_default_x1234[sf - 1]
    # set your own kernel width here if the default does not fit, e.g.
    # kernel_width = 1.0

    k = utils_deblur.fspecial('gaussian', 25, kernel_width)
    k = sr.shift_pixel(k, sf)  # shift the kernel
    k /= np.sum(k)

    ##############################################################################

    show_img = False
    if show_img:
        util.surf(k)
    x8 = True  # default: False, x8 to boost performance
    modelSigma1 = 49  # set sigma_1, default: 49
    modelSigma2 = max(sf, noise_level_model * 255.)
    classical_degradation = True  # set classical degradation or bicubic degradation

    task_current = 'sr'  # 'sr' for super-resolution
    n_channels = 1 if 'gray' in model_name else 3  # fixed
    model_zoo = 'model_zoo'  # fixed
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    result_name = testset_name + '_realapplications_' + task_current + '_' + model_name
    model_path = os.path.join(model_zoo, model_name + '.pth')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache()

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------
    L_path = os.path.join(testsets, testset_name)  # L_path, for Low-quality images
    E_path = os.path.join(results, result_name)  # E_path, for Estimated images
    util.mkdir(E_path)

    logger_name = result_name
    utils_logger.logger_info(logger_name,
                             log_path=os.path.join(E_path,
                                                   logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    # ----------------------------------------
    # load model
    # ----------------------------------------
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
    # the model is moved to `device` afterwards.
    if 'drunet' in model_name:
        from models.network_unet import UNetRes as net
        model = net(in_nc=n_channels + 1,
                    out_nc=n_channels,
                    nc=[64, 128, 256, 512],
                    nb=4,
                    act_mode='R',
                    downsample_mode="strideconv",
                    upsample_mode="convtranspose")
        model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
        model.eval()
        for v in model.parameters():
            v.requires_grad = False
        model = model.to(device)
    elif 'ircnn' in model_name:
        from models.network_dncnn import IRCNN as net
        model = net(in_nc=n_channels, out_nc=n_channels, nc=64)
        # dict of state dicts keyed by str(idx); idx is derived from the
        # current sigma inside the iteration loop
        model25 = torch.load(model_path, map_location='cpu')
        # None guarantees the weights are loaded on the very first iteration
        former_idx = None
    else:
        raise ValueError('unsupported model_name: {}'.format(model_name))

    logger.info('model_name:{}, image sigma:{:.3f}, model sigma:{:.3f}'.format(
        model_name, noise_level_img, noise_level_model))
    logger.info('Model path: {:s}'.format(model_path))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)

    # --------------------------------
    # rhos and sigmas: the schedule does not depend on the image, so it is
    # computed once for all images
    # --------------------------------
    rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255 / 255.,
                                               noise_level_model),
                                     iter_num=iter_num,
                                     modelSigma1=modelSigma1,
                                     modelSigma2=modelSigma2,
                                     w=1)
    rhos, sigmas = torch.tensor(rhos).to(device), torch.tensor(sigmas).to(
        device)

    for idx, img in enumerate(L_paths):

        # --------------------------------
        # (1) get img_L
        # --------------------------------
        logger.info('Model path: {:s} Image: {:s}'.format(model_path, img))
        img_name, ext = os.path.splitext(os.path.basename(img))
        img_L = util.imread_uint(img, n_channels=n_channels)
        img_L = util.uint2single(img_L)
        img_L = util.modcrop(img_L, 8)  # modcrop

        # --------------------------------
        # (2) initialize x, and pre-calculation
        # --------------------------------
        x = cv2.resize(img_L, (img_L.shape[1] * sf, img_L.shape[0] * sf),
                       interpolation=cv2.INTER_CUBIC)

        if np.ndim(x) == 2:
            x = x[..., None]

        if classical_degradation:
            x = sr.shift_pixel(x, sf)
        x = util.single2tensor4(x).to(device)

        img_L_tensor, k_tensor = util.single2tensor4(
            img_L), util.single2tensor4(np.expand_dims(k, 2))
        [k_tensor, img_L_tensor] = util.todevice([k_tensor, img_L_tensor],
                                                 device)
        FB, FBC, F2B, FBFy = sr.pre_calculate(img_L_tensor, k_tensor, sf)

        # --------------------------------
        # (3) main iterations
        # --------------------------------
        for i in range(iter_num):

            print('Iter: {} / {}'.format(i, iter_num))

            # --------------------------------
            # step 1, FFT
            # --------------------------------
            tau = rhos[i].float().repeat(1, 1, 1, 1)
            x = sr.data_solution(x, FB, FBC, F2B, FBFy, tau, sf)

            if 'ircnn' in model_name:
                # builtin int: np.int was removed in NumPy 1.24
                current_idx = int(
                    np.ceil(sigmas[i].cpu().numpy() * 255. / 2.) - 1)

                if current_idx != former_idx:
                    model.load_state_dict(model25[str(current_idx)],
                                          strict=True)
                    model.eval()
                    for v in model.parameters():
                        v.requires_grad = False
                    model = model.to(device)
                former_idx = current_idx

            # --------------------------------
            # step 2, denoiser
            # --------------------------------
            if x8:
                x = util.augment_img_tensor4(x, i % 8)

            if 'drunet' in model_name:
                # DRUNet takes the noise level as an extra input channel
                x = torch.cat(
                    (x, sigmas[i].repeat(1, 1, x.shape[2], x.shape[3])), dim=1)
                x = utils_model.test_mode(model,
                                          x,
                                          mode=2,
                                          refield=64,
                                          min_size=256,
                                          modulo=16)
            elif 'ircnn' in model_name:
                x = model(x)

            if x8:
                # undo the augmentation; modes 3 and 5 are each other's inverse
                if i % 8 == 3 or i % 8 == 5:
                    x = util.augment_img_tensor4(x, 8 - i % 8)
                else:
                    x = util.augment_img_tensor4(x, i % 8)

        # --------------------------------
        # (4) img_E
        # --------------------------------
        img_E = util.tensor2uint(x)
        util.imsave(
            img_E,
            os.path.join(E_path, img_name + '_x' + str(sf) + '_' + model_name +
                         '.png'))
Exemple #9
0
def main():
    """Demosaic a Bayer-mosaicked test set with a plug-and-play denoiser prior.

    For each image in ``testsets/<testset_name>`` this function:

    - builds a Bayer CFA mosaic with ``utils_mosaic.mosaic_CFA_Bayer``;
    - initializes an RGB estimate, either with ``utils_mosaic.dm_matlab`` or
      with OpenCV's edge-aware Bayer conversion;
    - runs ``iter_num`` iterations that alternate a closed-form data step
      with a learned denoiser (DRUNet or IRCNN, selected by ``model_name``).

    Observed mosaic samples are copied back into the final estimate. PSNR
    against the ground truth is logged, and images are written to
    ``results/<result_name>``. All settings are hard-coded in the
    "Preparation" section below.
    """

    # ----------------------------------------
    # Preparation
    # ----------------------------------------

    noise_level_img = 0 / 255.0  # set AWGN noise level for LR image, default: 0
    noise_level_model = noise_level_img  # set noise level of model, default: 0
    model_name = 'ircnn_color'  # set denoiser, 'drunet_color' | 'ircnn_color'
    testset_name = 'Set18'  # set testing set,  'set18' | 'set24'
    x8 = True  # set PGSE to boost performance, default: True
    iter_num = 40  # set number of iterations, default: 40 for demosaicing
    modelSigma1 = 49  # set sigma_1, default: 49
    modelSigma2 = max(0.6, noise_level_model * 255.)  # set sigma_2, default
    matlab_init = True

    show_img = False  # default: False
    save_L = True  # save LR image
    save_E = True  # save estimated image
    save_LEH = False  # save zoomed LR, E and H images
    border = 10  # default 10 for demosaicing

    task_current = 'dm'  # 'dm' for demosaicing
    n_channels = 3  # fixed
    model_zoo = 'model_zoo'  # fixed
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    result_name = testset_name + '_' + task_current + '_' + model_name
    model_path = os.path.join(model_zoo, model_name + '.pth')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache()

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------

    L_path = os.path.join(testsets, testset_name)  # L_path, for Low-quality images
    E_path = os.path.join(results, result_name)  # E_path, for Estimated images
    util.mkdir(E_path)

    logger_name = result_name
    utils_logger.logger_info(logger_name,
                             log_path=os.path.join(E_path, logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    # ----------------------------------------
    # load model
    # ----------------------------------------

    if 'drunet' in model_name:
        from models.network_unet import UNetRes as net
        model = net(in_nc=n_channels + 1,
                    out_nc=n_channels,
                    nc=[64, 128, 256, 512],
                    nb=4,
                    act_mode='R',
                    downsample_mode="strideconv",
                    upsample_mode="convtranspose")
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        model.load_state_dict(torch.load(model_path, map_location=device),
                              strict=True)
        model.eval()
        for _, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
    elif 'ircnn' in model_name:
        from models.network_dncnn import IRCNN as net
        model = net(in_nc=n_channels, out_nc=n_channels, nc=64)
        # Maps str(bucket index) -> state_dict. The bucket index is derived
        # from the current sigma inside the iteration loop below.
        model25 = torch.load(model_path, map_location=device)
        # None (rather than 0) forces the first iteration to load weights even
        # when its bucket index is 0. Starting at 0 could leave the network
        # with its random initialization.
        former_idx = None

    logger.info('model_name:{}, image sigma:{:.3f}, model sigma:{:.3f}'.format(
        model_name, noise_level_img, noise_level_model))
    logger.info('Model path: {:s}'.format(model_path))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)
    if not L_paths:
        # Nothing to evaluate; averaging PSNR over zero images would divide by zero.
        logger.warning('No images found in {}'.format(L_path))
        return

    test_results = OrderedDict()
    test_results['psnr'] = []

    for idx, img in enumerate(L_paths, start=1):

        # --------------------------------
        # (1) get img_H and img_L
        # --------------------------------

        img_name, ext = os.path.splitext(os.path.basename(img))
        img_H = util.imread_uint(img, n_channels=n_channels)
        CFA, CFA4, mosaic, mask = utils_mosaic.mosaic_CFA_Bayer(img_H)

        # --------------------------------
        # (2) initialize x
        # --------------------------------

        if matlab_init:  # matlab demosaicing for initialization
            CFA4 = util.uint2tensor4(CFA4).to(device)
            x = utils_mosaic.dm_matlab(CFA4)
        else:
            x = cv2.cvtColor(CFA, cv2.COLOR_BAYER_BG2RGB_EA)
            x = util.uint2tensor4(x).to(device)

        img_L = util.tensor2uint(x)
        y = util.uint2tensor4(mosaic).to(device)

        if show_img:
            util.imshow(img_L)
        mask = util.single2tensor4(mask.astype(np.float32)).to(device)

        # --------------------------------
        # (3) get rhos and sigmas
        # --------------------------------

        rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255 / 255., noise_level_img),
                                         iter_num=iter_num,
                                         modelSigma1=modelSigma1,
                                         modelSigma2=modelSigma2,
                                         w=1.0)
        rhos = torch.tensor(rhos).to(device)
        sigmas = torch.tensor(sigmas).to(device)

        # --------------------------------
        # (4) main iterations
        # --------------------------------

        for i in range(iter_num):

            # step 1, closed-form solution: per-pixel blend of the mosaic y
            # and the current estimate, weighted by mask and rho_i.
            x = (y + rhos[i].float() * x).div(mask + rhos[i])

            # step 2, denoiser
            if 'ircnn' in model_name:
                # Built-in int(): np.int was removed in NumPy 1.24.
                current_idx = int(np.ceil(sigmas[i].cpu().numpy() * 255. / 2.) - 1)
                if current_idx != former_idx:
                    model.load_state_dict(model25[str(current_idx)], strict=True)
                    model.eval()
                    for _, v in model.named_parameters():
                        v.requires_grad = False
                    model = model.to(device)
                former_idx = current_idx

            x = torch.clamp(x, 0, 1)
            if x8:
                x = util.augment_img_tensor4(x, i % 8)

            if 'drunet' in model_name:
                # DRUNet takes the noise level as an extra constant input channel.
                x = torch.cat((x, sigmas[i].float().repeat(1, 1, x.shape[2], x.shape[3])),
                              dim=1)
                x = utils_model.test_mode(model,
                                          x,
                                          mode=2,
                                          refield=32,
                                          min_size=256,
                                          modulo=16)
            elif 'ircnn' in model_name:
                x = model(x)

            if x8:
                # Undo the augmentation; modes 3 and 5 are inverted by 8 - mode.
                if i % 8 == 3 or i % 8 == 5:
                    x = util.augment_img_tensor4(x, 8 - i % 8)
                else:
                    x = util.augment_img_tensor4(x, i % 8)

        # Keep the observed mosaic samples exactly.
        x[mask.to(torch.bool)] = y[mask.to(torch.bool)]

        # --------------------------------
        # (5) img_E
        # --------------------------------

        img_E = util.tensor2uint(x)
        psnr = util.calculate_psnr(img_E, img_H, border=border)
        test_results['psnr'].append(psnr)
        logger.info('{:->4d}--> {:>10s} -- PSNR: {:.2f}dB'.format(
            idx, img_name + ext, psnr))

        if save_E:
            util.imsave(img_E,
                        os.path.join(E_path, img_name + '_' + model_name + '.png'))

        if save_L:
            util.imsave(img_L, os.path.join(E_path, img_name + '_L.png'))

        if save_LEH:
            util.imsave(np.concatenate([img_L, img_E, img_H], axis=1),
                        os.path.join(E_path, img_name + model_name + '_LEH.png'))

    ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
    logger.info('------> Average PSNR(RGB) of ({}) is : {:.2f} dB'.format(
        testset_name, ave_psnr))
Exemple #10
0
def main():
    """Fine-tune a patch-wise restoration network with an L1 + LSGAN loss.

    The network (``net``) is initialized from ``./data/uabcnet_final.pth``. It is
    trained jointly with per-patch HQS weights ``ab`` (passed through softplus
    before use) against a ``gan.PatchDiscriminator``. A single PSF grid is drawn
    once from the kernels in ``./data`` and reused for every iteration.

    Training runs for ``N_maxiter`` iterations, or until 'q' is pressed in the
    preview window. Afterwards, the model weights and the raw (pre-softplus)
    ``ab`` values are written to ``./data/uabcnet_finetune.pth`` and
    ``./data/ab_finetune.txt``.

    Raises:
        FileNotFoundError: if the training directory contains no PNG images.
    """
    # 0. global config
    sf = 4  # scale factor
    stage = 8
    patch_size = [32, 32]
    patch_num = [2, 2]

    # 1. local PSF
    all_PSFs = load_kernels('./data')

    # 2. load model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=8,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                sf=sf,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load('./data/uabcnet_final.pth', map_location=device),
                          strict=True)
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    # 3. set up discriminator, GAN loss and a pool of past fake images
    model_D = gan.PatchDiscriminator(5).to(device)
    gan_loss = gan.GANLoss(mode='lsgan').to(device)
    fake_images = ImagePool(16)

    # Positional lambda, mu for HQS: one (2*stage, 3) table per patch.
    # Even rows start at 0.01 and odd rows at 0.1 (before softplus).
    ab_buffer = np.zeros((patch_num[0], patch_num[1], 2 * stage, 3))
    ab_buffer[:, :, ::2, :] = 0.01
    ab_buffer[:, :, 1::2, :] = 0.1
    ab = torch.tensor(ab_buffer,
                      dtype=torch.float32,
                      device=device,
                      requires_grad=True)

    # ab learns much faster than the pretrained network weights. Adam keeps
    # per-parameter state, so one group for all model weights is equivalent
    # to one group per weight.
    params = [
        {"params": [ab], "lr": 5e-4},
        {"params": list(model.parameters()), "lr": 1e-5},
    ]
    optimizer = torch.optim.Adam(params, lr=1e-4, betas=(0.9, 0.999))
    optimizer_D = torch.optim.Adam(list(model_D.parameters()),
                                   lr=1e-4,
                                   betas=(0.9, 0.999))

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1000,
                                                gamma=0.9)

    # 4. load training data
    train_dir = '/home/xiu/databag/deblur/images/DIV2K_train'
    imgs_H = sorted(glob.glob(train_dir + '/*.png', recursive=True))
    if not imgs_H:
        # np.random.randint(0) would otherwise fail with an opaque ValueError.
        raise FileNotFoundError('no *.png training images found in ' + train_dir)

    N_maxiter = 200000

    # The PSF grid is drawn once and shared by every iteration.
    PSF_grid = draw_random_kernel(all_PSFs)

    for i in range(N_maxiter):

        t0 = time.time()
        img_idx = np.random.randint(len(imgs_H))
        img_H = cv2.imread(imgs_H[img_idx])

        patch_L, patch_H, patch_psf = draw_training_pair(
            img_H, PSF_grid, sf, patch_num, patch_size)
        t_data = time.time() - t0

        x = util.single2tensor4(util.uint2single(patch_L))
        x_gt = util.single2tensor4(util.uint2single(patch_H))

        # Per-patch PSFs and HQS weights, flattened with w_ varying fastest.
        k = torch.cat([util.single2tensor4(patch_psf[w_, h_])
                       for h_ in range(patch_num[1])
                       for w_ in range(patch_num[0])], dim=0)
        x, x_gt, k = [el.to(device) for el in (x, x_gt, k)]

        ab_patch = F.softplus(ab)
        ab_patch_v = torch.cat([ab_patch[w_:w_ + 1, h_]
                                for h_ in range(patch_num[1])
                                for w_ in range(patch_num[0])], dim=0)

        x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                         [patch_size[0], patch_size[1]], sf)

        # Generator step: reconstruction + adversarial term.
        loss_l1 = F.l1_loss(x_E, x_gt)
        loss_gan = gan_loss(model_D(x_E), True)
        loss = loss_l1 + loss_gan
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Discriminator step. Its gradients from the generator step are
        # cleared by zero_grad before this backward.
        pred_real = model_D(x_gt)
        loss_D_real = gan_loss(pred_real, True)
        fake = fake_images.query(x_E)
        pred_fake = model_D(fake.detach())
        loss_D_fake = gan_loss(pred_fake, False)
        loss_D = (loss_D_fake + loss_D_real) * 0.5
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        scheduler.step()

        t_iter = time.time() - t0 - t_data

        print('[iter:{}] loss:{:.4f}, data_time:{:.2f}s, net_time:{:.2f}s'.
              format(i + 1, loss.item(), t_data, t_iter))

        # Preview: ground truth | nearest-upsampled input | estimate.
        patch_L = cv2.resize(patch_L,
                             dsize=None,
                             fx=sf,
                             fy=sf,
                             interpolation=cv2.INTER_NEAREST)
        patch_E = util.tensor2uint(x_E)
        show = np.hstack((patch_H, patch_L, patch_E))
        cv2.imshow('H,L,E', show)
        key = cv2.waitKey(1)

        if key == ord('q'):
            break

    ab_numpy = ab.detach().cpu().numpy().flatten()
    torch.save(model.state_dict(), './data/uabcnet_finetune.pth')
    np.savetxt('./data/ab_finetune.txt', ab_numpy)
Exemple #11
0
def main():
    """Train a patch-wise restoration network from scratch with an L1 loss.

    ``net`` is built without loading any checkpoint. It is trained jointly with
    per-patch HQS weights ``ab`` (passed through softplus before use). A new
    PSF grid is drawn from the kernels in ``./data`` for every iteration.

    Keys in the preview window:
        's': write a checkpoint.
        'q': stop early.

    Checkpoints contain the network weights (``./logs/uabcnet.pth``) and the
    raw, pre-softplus ``ab`` table (``./logs/ab.txt``).

    Raises:
        FileNotFoundError: if the training directory contains no PNG images.
    """
    import os  # function-scope: only needed to create the checkpoint directory

    # 0. global config
    sf = 4  # scale factor
    stage = 8
    patch_size = [32, 32]
    patch_num = [3, 3]

    # Create the checkpoint directory up front so torch.save cannot fail
    # at the very end of a long training run.
    os.makedirs('./logs', exist_ok=True)

    # 1. local PSF, shape: gx, gy, kw, kw, 3
    all_PSFs = load_kernels('./data')

    # 2. local model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=8,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                sf=sf,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    # Positional lambda, mu for HQS, set as free trainable parameters here.
    ab_buffer = np.ones(
        (patch_num[0], patch_num[1], 2 * stage, 3), dtype=np.float32) * 0.1
    ab = torch.tensor(ab_buffer, device=device, requires_grad=True)

    # Adam keeps per-parameter state, so one group for all model weights is
    # equivalent to one group per weight.
    params = [
        {"params": [ab], "lr": 0.0005},
        {"params": list(model.parameters()), "lr": 0.0001},
    ]
    optimizer = torch.optim.Adam(params, lr=0.0001, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1000,
                                                gamma=0.9)

    def save_checkpoint():
        """Write the network weights and the raw (pre-softplus) ``ab`` table."""
        torch.save(model.state_dict(), './logs/uabcnet.pth')
        # ab is trained jointly with the network and is fed to
        # forward_patchwise_SR, so it must be kept alongside the weights.
        np.savetxt('./logs/ab.txt', ab.detach().cpu().numpy().flatten())

    # 3. load training data
    train_dir = '/home/xiu/databag/deblur/images/DIV2K_train'
    imgs_H = sorted(glob.glob(train_dir + '/*.png', recursive=True))
    if not imgs_H:
        # np.random.randint(0) would otherwise fail with an opaque ValueError.
        raise FileNotFoundError('no *.png training images found in ' + train_dir)

    N_maxiter = 200000

    for i in range(N_maxiter):

        t0 = time.time()
        # Draw a random image and a fresh random kernel grid.
        img_idx = np.random.randint(len(imgs_H))
        img_H = cv2.imread(imgs_H[img_idx])
        PSF_grid = draw_random_kernel(all_PSFs, patch_num)

        patch_L, patch_H, patch_psf = draw_training_pair(
            img_H, PSF_grid, sf, patch_num, patch_size)
        t_data = time.time() - t0

        x = util.single2tensor4(util.uint2single(patch_L))
        x_gt = util.single2tensor4(util.uint2single(patch_H))

        # Per-patch PSFs and HQS weights, flattened with w_ varying fastest.
        k = torch.cat([util.single2tensor4(patch_psf[w_, h_])
                       for h_ in range(patch_num[1])
                       for w_ in range(patch_num[0])], dim=0)
        x, x_gt, k = [el.to(device) for el in (x, x_gt, k)]

        ab_patch = F.softplus(ab)
        ab_patch_v = torch.cat([ab_patch[w_:w_ + 1, h_]
                                for h_ in range(patch_num[1])
                                for w_ in range(patch_num[0])], dim=0)

        x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                         [patch_size[0], patch_size[1]], sf)

        loss = F.l1_loss(x_E, x_gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        t_iter = time.time() - t0 - t_data

        print('[iter:{}] loss:{:.4f}, data_time:{:.2f}s, net_time:{:.2f}s'.
              format(i + 1, loss.item(), t_data, t_iter))

        # Preview: ground truth | nearest-upsampled input | estimate.
        patch_L = cv2.resize(patch_L,
                             dsize=None,
                             fx=sf,
                             fy=sf,
                             interpolation=cv2.INTER_NEAREST)
        patch_E = util.tensor2uint(x_E)
        show = np.hstack((patch_H, patch_L, patch_E))
        cv2.imshow('H,L,E', show)
        key = cv2.waitKey(1)

        if key == ord('q'):
            break
        if key == ord('s'):
            save_checkpoint()

    save_checkpoint()
Exemple #12
0
def main():
    """Benchmark plug-and-play single-image super-resolution on a test set.

    The loops run over every scale factor in ``test_sf`` and every kernel.
    Classical degradation uses the first 8 kernels of ``kernels/kernels_12.mat``;
    otherwise the matching entry of ``kernels/kernel_bicubicx234.mat`` is used.

    For each image in ``testsets/<testset_name>``:

    - the image is degraded and noised;
    - it is upsampled with bicubic interpolation;
    - it is refined for ``iter_num`` iterations that alternate
      ``sr.data_solution`` with a learned denoiser (DRUNet or IRCNN).

    PSNR (RGB and, for color models, the Y channel) is logged per image, per
    kernel, and overall. Images are written to ``results/<result_name>``.
    """

    # ----------------------------------------
    # Preparation
    # ----------------------------------------

    noise_level_img = 0/255.0            # set AWGN noise level for LR image, default: 0,
    noise_level_model = noise_level_img  # set noise level of model, default 0
    model_name = 'drunet_color'          # set denoiser, | 'drunet_color' | 'ircnn_gray' | 'drunet_gray' | 'ircnn_color'
    testset_name = 'srbsd68'             # set test set,  'set5' | 'srbsd68'
    x8 = True                            # default: False, x8 to boost performance
    test_sf = [2]                        # set scale factor, default: [2, 3, 4], [2], [3], [4]
    iter_num = 24                        # set number of iterations, default: 24 for SISR
    modelSigma1 = 49                     # set sigma_1, default: 49
    classical_degradation = True         # set classical degradation or bicubic degradation

    show_img = False                     # default: False
    save_L = True                        # save LR image
    save_E = True                        # save estimated image
    save_LEH = False                     # save zoomed LR, E and H images

    task_current = 'sr'                  # 'sr' for super-resolution
    n_channels = 1 if 'gray' in model_name else 3  # fixed
    model_zoo = 'model_zoo'              # fixed
    testsets = 'testsets'                # fixed
    results = 'results'                  # fixed
    result_name = testset_name + '_' + task_current + '_' + model_name
    model_path = os.path.join(model_zoo, model_name+'.pth')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache()

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------

    L_path = os.path.join(testsets, testset_name)  # L_path, for Low-quality images
    E_path = os.path.join(results, result_name)    # E_path, for Estimated images
    util.mkdir(E_path)

    logger_name = result_name
    utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
    logger = logging.getLogger(logger_name)

    # ----------------------------------------
    # load model
    # ----------------------------------------

    if 'drunet' in model_name:
        from models.network_unet import UNetRes as net
        model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4,
                    act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose")
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
        model.eval()
        for _, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
    elif 'ircnn' in model_name:
        from models.network_dncnn import IRCNN as net
        model = net(in_nc=n_channels, out_nc=n_channels, nc=64)
        # Maps str(bucket index) -> state_dict. The bucket index is derived
        # from the current sigma inside the iteration loop below.
        model25 = torch.load(model_path, map_location=device)
        # None (rather than 0) forces the first iteration to load weights even
        # when its bucket index is 0. Starting at 0 could leave the network
        # with its random initialization.
        former_idx = None

    logger.info('model_name:{}, image sigma:{:.3f}, model sigma:{:.3f}'.format(model_name, noise_level_img, noise_level_model))
    logger.info('Model path: {:s}'.format(model_path))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)
    if not L_paths:
        # Nothing to evaluate; averaging PSNR over zero images would divide by zero.
        logger.warning('No images found in {}'.format(L_path))
        return

    # --------------------------------
    # load kernel
    # --------------------------------

    if classical_degradation:
        kernels = hdf5storage.loadmat(os.path.join('kernels', 'kernels_12.mat'))['kernels']
    else:
        kernels = hdf5storage.loadmat(os.path.join('kernels', 'kernel_bicubicx234.mat'))['kernels']

    test_results_ave = OrderedDict()
    test_results_ave['psnr_sf_k'] = []
    test_results_ave['psnr_y_sf_k'] = []

    for sf in test_sf:
        border = sf
        modelSigma2 = max(sf, noise_level_model*255.)
        k_num = 8 if classical_degradation else 1

        for k_index in range(k_num):
            logger.info('--------- sf:{:>1d} --k:{:>2d} ---------'.format(sf, k_index))
            test_results = OrderedDict()
            test_results['psnr'] = []
            test_results['psnr_y'] = []

            if not classical_degradation:  # for bicubic degradation
                k_index = sf-2
            k = kernels[0, k_index].astype(np.float64)

            if show_img:
                util.surf(k)

            for idx, img in enumerate(L_paths):

                # --------------------------------
                # (1) get img_L
                # --------------------------------

                img_name, ext = os.path.splitext(os.path.basename(img))
                img_H = util.imread_uint(img, n_channels=n_channels)
                img_H = util.modcrop(img_H, sf)  # modcrop

                if classical_degradation:
                    img_L = sr.classical_degradation(img_H, k, sf)
                    if show_img:
                        util.imshow(img_L)
                    img_L = util.uint2single(img_L)
                else:
                    img_L = util.imresize_np(util.uint2single(img_H), 1/sf)

                np.random.seed(seed=0)  # for reproducibility
                img_L += np.random.normal(0, noise_level_img, img_L.shape)  # add AWGN

                # --------------------------------
                # (2) get rhos and sigmas
                # --------------------------------

                rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255/255., noise_level_model),
                                                 iter_num=iter_num, modelSigma1=modelSigma1,
                                                 modelSigma2=modelSigma2, w=1)
                rhos, sigmas = torch.tensor(rhos).to(device), torch.tensor(sigmas).to(device)

                # --------------------------------
                # (3) initialize x, and pre-calculation
                # --------------------------------

                x = cv2.resize(img_L, (img_L.shape[1]*sf, img_L.shape[0]*sf),
                               interpolation=cv2.INTER_CUBIC)
                if np.ndim(x) == 2:
                    # Restore the channel axis for grayscale inputs.
                    x = x[..., None]

                if classical_degradation:
                    x = sr.shift_pixel(x, sf)
                x = util.single2tensor4(x).to(device)

                img_L_tensor, k_tensor = util.single2tensor4(img_L), util.single2tensor4(np.expand_dims(k, 2))
                [k_tensor, img_L_tensor] = util.todevice([k_tensor, img_L_tensor], device)
                FB, FBC, F2B, FBFy = sr.pre_calculate(img_L_tensor, k_tensor, sf)

                # --------------------------------
                # (4) main iterations
                # --------------------------------

                for i in range(iter_num):

                    # step 1, FFT
                    tau = rhos[i].float().repeat(1, 1, 1, 1)
                    x = sr.data_solution(x.float(), FB, FBC, F2B, FBFy, tau, sf)

                    if 'ircnn' in model_name:
                        # Built-in int(): np.int was removed in NumPy 1.24.
                        current_idx = int(np.ceil(sigmas[i].cpu().numpy()*255./2.)-1)

                        if current_idx != former_idx:
                            model.load_state_dict(model25[str(current_idx)], strict=True)
                            model.eval()
                            for _, v in model.named_parameters():
                                v.requires_grad = False
                            model = model.to(device)
                        former_idx = current_idx

                    # step 2, denoiser
                    if x8:
                        x = util.augment_img_tensor4(x, i % 8)

                    if 'drunet' in model_name:
                        # DRUNet takes the noise level as an extra constant input channel.
                        x = torch.cat((x, sigmas[i].float().repeat(1, 1, x.shape[2], x.shape[3])), dim=1)
                        x = utils_model.test_mode(model, x, mode=2, refield=32, min_size=256, modulo=16)
                    elif 'ircnn' in model_name:
                        x = model(x)

                    if x8:
                        # Undo the augmentation; modes 3 and 5 are inverted by 8 - mode.
                        if i % 8 == 3 or i % 8 == 5:
                            x = util.augment_img_tensor4(x, 8 - i % 8)
                        else:
                            x = util.augment_img_tensor4(x, i % 8)

                # --------------------------------
                # (5) img_E
                # --------------------------------

                img_E = util.tensor2uint(x)

                if save_E:
                    util.imsave(img_E, os.path.join(E_path, img_name+'_x'+str(sf)+'_k'+str(k_index)+'_'+model_name+'.png'))

                if n_channels == 1:
                    img_H = img_H.squeeze()

                # --------------------------------
                # (6) img_LEH
                # --------------------------------

                img_L = util.single2uint(img_L).squeeze()

                if save_LEH:
                    # Zoomed LR image with the (normalized, 3x enlarged) kernel
                    # pasted into its top-right corner.
                    k_v = k/np.max(k)*1.0
                    if n_channels == 1:
                        k_v = util.single2uint(k_v)
                    else:
                        k_v = util.single2uint(np.tile(k_v[..., np.newaxis], [1, 1, n_channels]))
                    k_v = cv2.resize(k_v, (3*k_v.shape[1], 3*k_v.shape[0]), interpolation=cv2.INTER_NEAREST)
                    img_I = cv2.resize(img_L, (sf*img_L.shape[1], sf*img_L.shape[0]), interpolation=cv2.INTER_NEAREST)
                    img_I[:k_v.shape[0], -k_v.shape[1]:, ...] = k_v
                    img_I[:img_L.shape[0], :img_L.shape[1], ...] = img_L
                    if show_img:
                        util.imshow(np.concatenate([img_I, img_E, img_H], axis=1), title='LR / Recovered / Ground-truth')
                    util.imsave(np.concatenate([img_I, img_E, img_H], axis=1), os.path.join(E_path, img_name+'_x'+str(sf)+'_k'+str(k_index)+'_LEH.png'))

                if save_L:
                    util.imsave(img_L, os.path.join(E_path, img_name+'_x'+str(sf)+'_k'+str(k_index)+'_LR.png'))

                psnr = util.calculate_psnr(img_E, img_H, border=border)
                test_results['psnr'].append(psnr)
                logger.info('{:->4d}--> {:>10s} -- sf:{:>1d} --k:{:>2d} PSNR: {:.2f}dB'.format(idx+1, img_name+ext, sf, k_index, psnr))

                if n_channels == 3:
                    img_E_y = util.rgb2ycbcr(img_E, only_y=True)
                    img_H_y = util.rgb2ycbcr(img_H, only_y=True)
                    psnr_y = util.calculate_psnr(img_E_y, img_H_y, border=border)
                    test_results['psnr_y'].append(psnr_y)

            # --------------------------------
            # Average PSNR for this scale factor and kernel
            # --------------------------------

            ave_psnr_k = sum(test_results['psnr']) / len(test_results['psnr'])
            logger.info('------> Average PSNR(RGB) of ({}) scale factor: ({}), kernel: ({}) sigma: ({:.2f}): {:.2f} dB'.format(testset_name, sf, k_index, noise_level_model, ave_psnr_k))
            test_results_ave['psnr_sf_k'].append(ave_psnr_k)

            if n_channels == 3:  # RGB image
                ave_psnr_y_k = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
                logger.info('------> Average PSNR(Y) of ({}) scale factor: ({}), kernel: ({}) sigma: ({:.2f}): {:.2f} dB'.format(testset_name, sf, k_index, noise_level_model, ave_psnr_y_k))
                test_results_ave['psnr_y_sf_k'].append(ave_psnr_y_k)

    # ---------------------------------------
    # Average PSNR for all sf and kernels
    # ---------------------------------------

    ave_psnr_sf_k = sum(test_results_ave['psnr_sf_k']) / len(test_results_ave['psnr_sf_k'])
    logger.info('------> Average PSNR of ({}) {:.2f} dB'.format(testset_name, ave_psnr_sf_k))
    if n_channels == 3:
        ave_psnr_y_sf_k = sum(test_results_ave['psnr_y_sf_k']) / len(test_results_ave['psnr_y_sf_k'])
        # Label the Y-channel average distinctly; it was indistinguishable from the RGB line.
        logger.info('------> Average PSNR(Y) of ({}) {:.2f} dB'.format(testset_name, ave_psnr_y_sf_k))
Exemple #13
0
def main():
    """Fine-tune the per-patch HQS weights ``ab`` (and, slowly, the network).

    The network is loaded from ``./logs/uabcnet_final.pth``. Every patch's
    ``ab`` table starts from the single pretrained table in
    ``./logs/ab_pretrain.txt``. ``ab`` is optimized with lr 1e-4 and the
    network weights with lr 1e-6. One PSF grid is drawn before training and
    reused for all ``N_maxiter`` iterations.

    Every 1000 iterations the E/L/H preview patches are written to
    ``./result/test``. At the end, the network weights and the raw
    (pre-softplus) ``ab`` values are saved to ``./data/uabcnet_finetune.pth``
    and ``./data/ab_finetune.txt``.

    Raises:
        FileNotFoundError: if the training directory contains no PNG images.
    """
    # 0. global config
    sf = 4  # scale factor
    stage = 5
    patch_size = [32, 32]
    patch_num = [2, 2]

    # 1. local PSF, shape: gx, gy, kw, kw, 3
    all_PSFs = load_kernels('./data')

    # 2. local model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=5, h_nc=64, in_nc=4, out_nc=3, nc=[64, 128, 256, 512],
                nb=2, sf=sf, act_mode="R", downsample_mode='strideconv',
                upsample_mode="convtranspose")
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load('./logs/uabcnet_final.pth', map_location=device),
                          strict=True)
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    # Positional lambda, mu for HQS, set as free trainable parameters here.
    # The single pretrained (2*stage, 3) table is replicated to every patch.
    ab_pretrain = np.loadtxt('./logs/ab_pretrain.txt').reshape(
        (1, 1, 2 * stage, 3)).astype(np.float32)
    ab_buffer = np.tile(ab_pretrain, (patch_num[0], patch_num[1], 1, 1))
    ab = torch.tensor(ab_buffer, device=device, requires_grad=True)

    # Adam keeps per-parameter state, so one group for all model weights is
    # equivalent to one group per weight.
    params = [
        {"params": [ab], "lr": 0.0001},
        {"params": list(model.parameters()), "lr": 1e-6},
    ]
    optimizer = torch.optim.Adam(params, lr=0.0001, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.9)

    # 3. load training data
    train_dir = './DIV2K_train'
    imgs_H = sorted(glob.glob(train_dir + '/*.png', recursive=True))
    if not imgs_H:
        # np.random.randint(0) would otherwise fail with an opaque ValueError.
        raise FileNotFoundError('no *.png training images found in ' + train_dir)

    # cv2.imwrite returns False instead of raising when the target directory
    # is missing, so create the preview directory up front.
    preview_dir = os.path.join('./result', 'test')
    os.makedirs(preview_dir, exist_ok=True)

    N_maxiter = 4000

    # The PSF grid is drawn once and shared by every iteration.
    PSF_grid = draw_random_kernel(all_PSFs, patch_num)

    for i in range(N_maxiter):

        t0 = time.time()
        # Draw a random image.
        img_idx = np.random.randint(len(imgs_H))
        img_H = cv2.imread(imgs_H[img_idx])

        patch_L, patch_H, patch_psf = draw_training_pair(img_H, PSF_grid, sf, patch_num, patch_size)
        t_data = time.time() - t0

        x = util.single2tensor4(util.uint2single(patch_L))
        x_gt = util.single2tensor4(util.uint2single(patch_H))

        # Per-patch PSFs and HQS weights, flattened with w_ varying fastest.
        k = torch.cat([util.single2tensor4(patch_psf[w_, h_])
                       for h_ in range(patch_num[1])
                       for w_ in range(patch_num[0])], dim=0)
        x, x_gt, k = [el.to(device) for el in (x, x_gt, k)]

        ab_patch = F.softplus(ab)
        ab_patch_v = torch.cat([ab_patch[w_:w_ + 1, h_]
                                for h_ in range(patch_num[1])
                                for w_ in range(patch_num[0])], dim=0)

        x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                         [patch_size[0], patch_size[1]], sf)

        loss = F.l1_loss(x_E, x_gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        t_iter = time.time() - t0 - t_data

        print('[iter:{}] loss:{:.4f}, data_time:{:.2f}s, net_time:{:.2f}s'.format(
            i + 1, loss.item(), t_data, t_iter))

        # Preview: ground truth | nearest-upsampled input | estimate.
        patch_L = cv2.resize(patch_L, dsize=None, fx=sf, fy=sf, interpolation=cv2.INTER_NEAREST)
        patch_E = util.tensor2uint(x_E)
        show = np.hstack((patch_H, patch_L, patch_E))
        cv2.imshow('H,L,E', show)
        cv2.waitKey(1)  # lets the preview window refresh

        if i % 1000 == 0:
            cv2.imwrite(os.path.join(preview_dir, 'resultE-{:04d}.png'.format(i + 1)), patch_E)
            cv2.imwrite(os.path.join(preview_dir, 'resultL-{:04d}.png'.format(i + 1)), patch_L)
            cv2.imwrite(os.path.join(preview_dir, 'resultH-{:04d}.png'.format(i + 1)), patch_H)

    ab_numpy = ab.detach().cpu().numpy().flatten()
    torch.save(model.state_dict(), './data/uabcnet_finetune.pth')
    np.savetxt('./data/ab_finetune.txt', ab_numpy)
Exemple #14
0
def main():
    """Evaluate the fine-tuned patch-wise USRNet on randomly drawn DIV2K images.

    Loads ``./data/uabcnet_finetune.pth`` and the per-patch HQS weights from
    ``./data/ab_finetune.txt``. For each of ``N_maxiter`` random images it:
    pads the image by 12 px, crops it into overlapping patches, synthesizes a
    degraded LR patch with the local PSF grid, reconstructs it, and merges the
    estimates. Estimated/reference images and the list of PSNR values are
    written to ``./result/finetune``.

    Raises:
        FileNotFoundError: if ``./DIV2K_train`` contains no ``.png`` images.
        IOError: if a selected image cannot be decoded by OpenCV.
    """
    # 0. global config
    sf = 4  # scale factor
    stage = 5  # number of HQS stages; also used as the network's n_iter
    patch_size = [32, 32]
    patch_num = [2, 2]

    # 1. local PSFs, shape: gx, gy, kw, kw, 3
    all_PSFs = load_kernels('./data')

    # 2. local model (frozen, inference only)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=stage,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                sf=sf,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    #loaded_state_dict=  torch.load('./data/uabcnet_final.pth')
    loaded_state_dict = torch.load('./data/uabcnet_finetune.pth')
    model.load_state_dict(loaded_state_dict, strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # Positional lambda/mu for HQS, shape (patch_num[0], patch_num[1], 2*stage, 3).
    # The file stores the raw (pre-softplus) parameters, so softplus must be
    # applied exactly once; applying it twice yields the wrong weights.
    ab_buffer = np.loadtxt('./data/ab_finetune.txt').reshape(
        (patch_num[0], patch_num[1], 2 * stage, 3)).astype(np.float32)
    ab = torch.tensor(ab_buffer, device=device, requires_grad=False)
    ab_patch = F.softplus(ab)
    # One row per patch, in the same (h outer, w inner) order as the PSFs
    # below. ab is fixed, so this is built once instead of per patch.
    ab_patch_v = torch.cat([ab_patch[w_:w_ + 1, h_]
                            for h_ in range(patch_num[1])
                            for w_ in range(patch_num[0])], dim=0)

    # 3. load evaluation images
    imgs_H = sorted(glob.glob('./DIV2K_train/*.png', recursive=True))
    if not imgs_H:
        raise FileNotFoundError('No .png images found in ./DIV2K_train')

    N_maxiter = 1000
    PSF_grid = using_AC254_lens(all_PSFs, patch_num)

    all_PSNR = []
    out_folder = 'finetune'
    out_dir = os.path.join('./result', out_folder)
    # cv2.imwrite does not create folders; it just returns False.
    os.makedirs(out_dir, exist_ok=True)

    for i in range(N_maxiter):
        # draw a random image
        img_idx = np.random.randint(len(imgs_H))
        img_H = cv2.imread(imgs_H[img_idx])
        if img_H is None:
            raise IOError('Failed to read image: {}'.format(imgs_H[img_idx]))
        img_H = np.pad(img_H, [(12, 12), (12, 12), (0, 0)])
        croppatch = imgpatch(img_H, 280, 280, 24)
        patches = croppatch.crop(img_H, 1)

        patch_E_list = []
        for piece in range(len(patches)):
            patch_L, _, patch_psf = draw_testing_pair(
                patches[piece], PSF_grid, sf, patch_num, patch_size)

            x = util.single2tensor4(util.uint2single(patch_L))
            # Stack the local PSFs in (h outer, w inner) order, matching ab_patch_v.
            k = torch.cat([util.single2tensor4(patch_psf[w_, h_])
                           for h_ in range(patch_num[1])
                           for w_ in range(patch_num[0])], dim=0)
            x, k = x.to(device), k.to(device)

            x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                             [patch_size[0], patch_size[1]],
                                             sf)

            patch_E = util.tensor2uint(x_E)
            patch_E_list.append(patch_E[np.newaxis, :])
            print(piece)

        img_E = croppatch.merge(patch_E_list)

        psnr, ssim = cal_psnrssim(img_E, img_H, 255)
        print(psnr)
        print(ssim)

        cv2.imwrite(os.path.join(out_dir, 'resultE-{:04d}.png'.format(i + 1)),
                    img_E)
        cv2.imwrite(os.path.join(out_dir, 'resultH-{:04d}.png'.format(i + 1)),
                    img_H)

        all_PSNR.append(psnr)

    np.savetxt(os.path.join(out_dir, 'psnr.txt'), all_PSNR)
Exemple #15
0
def main():
    """Benchmark an efficient x4 super-resolution model on a folder of LR images.

    Builds the model selected by ``model_id`` (MSRResNet or IMDN) and loads its
    weights from ``model_zoo``. Optionally it logs #Activations, #Conv2d, FLOPs
    and #Params. Every image in ``testsets/<testset_L>`` is then run through the
    network, with the forward pass timed by CUDA events. Outputs are optionally
    saved to ``testsets/<testset_L>_<model_name>``, and the average runtime in
    seconds is logged.

    Requires CUDA: GPU 0 is selected and CUDA events are used for timing.

    Raises:
        ValueError: if ``model_name`` is not one of the supported models.
    """
    utils_logger.logger_info('efficientsr_challenge',
                             log_path='efficientsr_challenge.log')
    logger = logging.getLogger('efficientsr_challenge')

    # --------------------------------
    # basic settings
    # --------------------------------
    model_names = ['msrresnet', 'imdn']
    model_id = 1  # set the model name
    model_name = model_names[model_id]
    logger.info('{:>16s} : {:s}'.format('Model Name', model_name))

    testsets = 'testsets'  # set path of testsets
    # set current testing dataset: 'DIV2K_valid_LR' | 'DIV2K_test_LR' | 'set12'
    testset_L = 'set12'

    save_results = True
    print_modelsummary = True  # set False when measuring `Max Memory` and `Runtime`

    torch.cuda.set_device(0)  # set GPU ID
    logger.info('{:>16s} : {:<d}'.format('GPU ID',
                                         torch.cuda.current_device()))
    torch.cuda.empty_cache()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --------------------------------
    # define network and load model
    # --------------------------------
    if model_name == 'msrresnet':
        from models.network_msrresnet import MSRResNet1 as net
        model = net(in_nc=3, out_nc=3, nc=64, nb=16, upscale=4)
        model_path = os.path.join('model_zoo', 'msrresnet_x4_psnr.pth')
    elif model_name == 'imdn':
        from models.network_imdn import IMDN as net
        model = net(in_nc=3,
                    out_nc=3,
                    nc=64,
                    nb=8,
                    upscale=4,
                    act_mode='L',
                    upsample_mode='pixelshuffle')
        model_path = os.path.join('model_zoo', 'imdn_x4.pth')
    else:
        raise ValueError('Unknown model name: {}'.format(model_name))

    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # --------------------------------
    # print model summary
    # --------------------------------
    if print_modelsummary:
        from utils.utils_modelsummary import get_model_activation, get_model_flops
        input_dim = (3, 256, 256)  # set the input dimension

        activations, num_conv2d = get_model_activation(model, input_dim)
        logger.info('{:>16s} : {:<.4f} [M]'.format('#Activations',
                                                   activations / 10**6))
        logger.info('{:>16s} : {:<d}'.format('#Conv2d', num_conv2d))

        flops = get_model_flops(model, input_dim, False)
        logger.info('{:>16s} : {:<.4f} [G]'.format('FLOPs', flops / 10**9))

        num_parameters = sum(p.numel() for p in model.parameters())
        logger.info('{:>16s} : {:<.4f} [M]'.format('#Params',
                                                   num_parameters / 10**6))

    # --------------------------------
    # read image
    # --------------------------------
    L_path = os.path.join(testsets, testset_L)
    E_path = os.path.join(testsets, testset_L + '_' + model_name)
    util.mkdir(E_path)

    # record runtime (milliseconds per image)
    test_results = OrderedDict()
    test_results['runtime'] = []

    logger.info('{:>16s} : {:s}'.format('Input Path', L_path))
    logger.info('{:>16s} : {:s}'.format('Output Path', E_path))

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for idx, img in enumerate(util.get_image_paths(L_path), start=1):
        # (1) img_L
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx, img_name + ext))

        img_L = util.uint2tensor4(util.imread_uint(img, n_channels=3))
        torch.cuda.empty_cache()
        img_L = img_L.to(device)

        start.record()
        img_E = model(img_L)
        end.record()
        # Events are recorded asynchronously; wait before reading elapsed time.
        torch.cuda.synchronize()
        test_results['runtime'].append(start.elapsed_time(end))  # milliseconds

        # (2) img_E
        img_E = util.tensor2uint(img_E)

        if save_results:
            util.imsave(img_E, os.path.join(E_path, img_name + ext))

    runtimes = test_results['runtime']
    if runtimes:
        ave_runtime = sum(runtimes) / len(runtimes) / 1000.0
        logger.info('------> Average runtime of ({}) is : {:.6f} seconds'.format(
            L_path, ave_runtime))
    else:
        # Guard against ZeroDivisionError on an empty input folder.
        logger.info('------> No images found in ({}); nothing to time.'.format(
            L_path))
def main():
    """Denoise a test set with a pretrained IRCNN denoiser and report PSNR/SSIM.

    ``model_zoo/<model_name>.pth`` is loaded as a mapping keyed by
    ``str(index)``; the entry chosen from ``noise_level_img`` is loaded into a
    single IRCNN. When ``need_degradation`` is set, seeded Gaussian noise is
    added to each clean image before denoising. Estimated images are saved to
    ``results/<testset>_<model_name>`` and per-image / average metrics logged.
    """
    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    noise_level_img = 50             # noise level for noisy image
    model_name = 'ircnn_gray'        # 'ircnn_gray' | 'ircnn_color'
    testset_name = 'set12'           # test set, 'bsd68' | 'set12'
    need_degradation = True          # default: True
    x8 = False                       # default: False, x8 to boost performance
    show_img = False                 # default: False
    # Index of the denoiser to load (the current_idx+1 th denoiser), in [0, 24].
    # np.int was removed in NumPy 1.24, so use the builtin int; clamp at 0 so
    # noise levels below 2 do not select the non-existent key '-1'.
    current_idx = max(0, min(24, int(np.ceil(noise_level_img / 2)) - 1))

    task_current = 'dn'       # fixed, 'dn' for denoising | 'sr' for super-resolution
    sf = 1                    # unused for denoising
    n_channels = 3 if 'color' in model_name else 1  # 3 for color, 1 for grayscale

    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'     # fixed
    results = 'results'       # fixed
    result_name = testset_name + '_' + model_name     # fixed
    border = sf if task_current == 'sr' else 0        # pixels shaved before PSNR/SSIM
    model_path = os.path.join(model_pool, model_name + '.pth')

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------
    L_path = os.path.join(testsets, testset_name)  # L_path, for Low-quality images
    H_path = L_path                                # H_path, for High-quality images
    E_path = os.path.join(results, result_name)    # E_path, for Estimated images
    util.mkdir(E_path)

    # Inputs and ground truth share a folder, so the noisy input must be synthesized.
    if H_path == L_path:
        need_degradation = True
    logger_name = result_name
    utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    need_H = H_path is not None
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------
    model25 = torch.load(model_path)
    from models.network_dncnn import IRCNN as net
    model = net(in_nc=n_channels, out_nc=n_channels, nc=64)
    model.load_state_dict(model25[str(current_idx)], strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))
    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))

    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []

    logger.info('model_name:{}, image sigma:{}'.format(model_name, noise_level_img))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)
    H_paths = util.get_image_paths(H_path) if need_H else None

    for idx, img in enumerate(L_paths):

        # (1) img_L
        img_name, ext = os.path.splitext(os.path.basename(img))
        img_L = util.imread_uint(img, n_channels=n_channels)
        img_L = util.uint2single(img_L)

        if need_degradation:  # degradation process
            np.random.seed(seed=0)  # for reproducibility
            img_L += np.random.normal(0, noise_level_img / 255., img_L.shape)

        if show_img:
            util.imshow(util.single2uint(img_L),
                        title='Noisy image with noise level {}'.format(noise_level_img))

        img_L = util.single2tensor4(img_L).to(device)

        # (2) img_E
        if not x8:
            img_E = model(img_L)
        else:
            img_E = utils_model.test_mode(model, img_L, mode=3)

        img_E = util.tensor2uint(img_E)

        if need_H:
            # (3) img_H, PSNR and SSIM
            img_H = util.imread_uint(H_paths[idx], n_channels=n_channels).squeeze()

            psnr = util.calculate_psnr(img_E, img_H, border=border)
            ssim = util.calculate_ssim(img_E, img_H, border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name + ext, psnr, ssim))
            if show_img:
                util.imshow(np.concatenate([img_E, img_H], axis=1),
                            title='Recovered / Ground-truth')

        # save results
        util.imsave(img_E, os.path.join(E_path, img_name + ext))

    # Skip the summary on an empty test set to avoid ZeroDivisionError.
    if need_H and test_results['psnr']:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        logger.info('Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, ave_psnr, ave_ssim))
Exemple #17
0
def main():
    """Run RFDN (AIM track) on DIV2K x4 bicubic LR images and log the runtime.

    Loads ``trained_model/RFDN_AIM.pth``, super-resolves every image under
    ``DIV2K/<testset_L>/X4`` while timing each forward pass with CUDA events,
    keeps the estimated images in ``img_SR``, and logs the average runtime in
    seconds. Requires CUDA for the timing events.
    """
    utils_logger.logger_info('AIM-track', log_path='AIM-track.log')
    logger = logging.getLogger('AIM-track')

    # --------------------------------
    # basic settings
    # --------------------------------
    testsets = 'DIV2K'
    testset_L = 'DIV2K_valid_LR_bicubic'
    #testset_L = 'DIV2K_test_LR_bicubic'

    torch.cuda.current_device()
    torch.cuda.empty_cache()
    #torch.backends.cudnn.benchmark = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --------------------------------
    # load model
    # --------------------------------
    model_path = os.path.join('trained_model', 'RFDN_AIM.pth')
    model = RFDN()
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # number of parameters
    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))

    # --------------------------------
    # read image
    # --------------------------------
    L_folder = os.path.join(testsets, testset_L, 'X4')
    E_folder = os.path.join(testsets, testset_L + '_results')
    util.mkdir(E_folder)

    # record runtime (milliseconds per image)
    test_results = OrderedDict()
    test_results['runtime'] = []

    logger.info(L_folder)
    logger.info(E_folder)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    img_SR = []  # estimated images, kept in memory
    for idx, img in enumerate(util.get_image_paths(L_folder), start=1):

        # (1) img_L
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx, img_name + ext))

        img_L = util.uint2tensor4(util.imread_uint(img, n_channels=3))
        img_L = img_L.to(device)

        start.record()
        img_E = model(img_L)
        end.record()
        # Events are recorded asynchronously; wait before reading elapsed time.
        torch.cuda.synchronize()
        test_results['runtime'].append(start.elapsed_time(end))  # milliseconds

        # (2) img_E
        img_E = util.tensor2uint(img_E)
        img_SR.append(img_E)

        # (3) save results
        #util.imsave(img_E, os.path.join(E_folder, img_name+ext))

    runtimes = test_results['runtime']
    if runtimes:
        ave_runtime = sum(runtimes) / len(runtimes) / 1000.0
        logger.info('------> Average runtime of ({}) is : {:.6f} seconds'.format(L_folder, ave_runtime))
    else:
        # Guard against ZeroDivisionError on an empty input folder.
        logger.info('------> No images found in ({}); nothing to time.'.format(L_folder))

    # --------------------------------
    # (4) calculate psnr
    # --------------------------------
    '''
def main():
    """Super-resolve a single real low-resolution image with USRNet/USRGAN.

    The model, test image, scale factor, noise level and Gaussian kernel width
    are hard-coded below. The LR image is boundary-wrapped before inference,
    and the estimate is cropped back to ``sf`` times the original LR size.
    Outputs go to ``results/<testset_name>_<model_name>``:

    - the estimated image, when ``save_E`` is set;
    - a side-by-side "zoomed LR with kernel overlay | estimate" image, when
      ``save_LE`` is set.

    NOTE(review): ``net`` is not defined in this function; it must come from a
    module-level import that is not visible here.
    """

    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    model_name = 'usrnet'      # 'usrgan' | 'usrnet' | 'usrgan_tiny' | 'usrnet_tiny'
    testset_name = 'set_real'  # test set,  'set_real'
    test_image = 'chip.png'    # 'chip.png', 'comic.png'
    #test_image = 'comic.png'

    sf = 4                     # scale factor, only from {1, 2, 3, 4}
    show_img = False           # default: False
    save_E = True              # save estimated image
    save_LE = True             # save zoomed LR, Estimated images

    # ----------------------------------------
    # set noise level and kernel
    # ----------------------------------------
    if 'chip' in test_image:
        noise_level_img = 15       # noise level for LR image, 15 for chip
        kernel_width_default_x1234 = [0.6, 0.9, 1.7, 2.2] # Gaussian kernel widths for x1, x2, x3, x4
    else:
        noise_level_img = 2       # noise level for LR image, 0.5~3 for clean images
        kernel_width_default_x1234 = [0.4, 0.7, 1.5, 2.0] # default Gaussian kernel widths of clean/sharp images for x1, x2, x3, x4

    noise_level_model = noise_level_img/255.  # noise level of model
    kernel_width = kernel_width_default_x1234[sf-1]

    # set your own kernel width
    # kernel_width = 2.2

    # 25x25 Gaussian kernel, shifted via sr.shift_pixel for scale sf, then
    # renormalized so its entries sum to 1.
    k = utils_deblur.fspecial('gaussian', 25, kernel_width)
    k = sr.shift_pixel(k, sf)  # shift the kernel
    k /= np.sum(k)
    util.surf(k) if show_img else None
    # scio.savemat('kernel_realapplication.mat', {'kernel':k})

    # load approximated bicubic kernels
    #kernels = hdf5storage.loadmat(os.path.join('kernels', 'kernel_bicubicx234.mat'))['kernels']
#    kernels = loadmat(os.path.join('kernels', 'kernel_bicubicx234.mat'))['kernels']
#    kernel = kernels[0, sf-2].astype(np.float64)

    # Add a trailing channel axis and convert with util.single2tensor4 for the model.
    kernel = util.single2tensor4(k[..., np.newaxis])


    n_channels = 1 if 'gray' in  model_name else 3  # 3 for color image, 1 for grayscale image
    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'     # fixed
    results = 'results'       # fixed
    result_name = testset_name + '_' + model_name
    model_path = os.path.join(model_pool, model_name+'.pth')

    # ----------------------------------------
    # L_path, E_path
    # ----------------------------------------
    L_path = os.path.join(testsets, testset_name) # L_path, fixed, for Low-quality images
    E_path = os.path.join(results, result_name)   # E_path, fixed, for Estimated images
    util.mkdir(E_path)

    logger_name = result_name
    utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
    logger = logging.getLogger(logger_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------
    # The '_tiny' variants use fewer iterations and narrower channels.
    if 'tiny' in model_name:
        model = net(n_iter=6, h_nc=32, in_nc=4, out_nc=3, nc=[16, 32, 64, 64],
                    nb=2, act_mode="R", downsample_mode='strideconv', upsample_mode="convtranspose")
    else:
        model = net(n_iter=8, h_nc=64, in_nc=4, out_nc=3, nc=[64, 128, 256, 512],
                    nb=2, act_mode="R", downsample_mode='strideconv', upsample_mode="convtranspose")

    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for key, v in model.named_parameters():
        v.requires_grad = False

    number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
    logger.info('Params number: {}'.format(number_parameters))
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))

    logger.info('model_name:{}, image sigma:{}'.format(model_name, noise_level_img))
    logger.info(L_path)

    img = os.path.join(L_path, test_image)
    # ------------------------------------
    # (1) img_L
    # ------------------------------------
    img_name, ext = os.path.splitext(os.path.basename(img))
    img_L = util.imread_uint(img, n_channels=n_channels)
    img_L = util.uint2single(img_L)

    util.imshow(img_L) if show_img else None
    # NOTE: `w` holds shape[0] (rows) and `h` holds shape[1] (columns), which is
    # swapped relative to the usual width/height meaning. The code below uses
    # them consistently: cv2.resize takes dsize as (columns, rows).
    w, h = img_L.shape[:2]
    logger.info('{:>10s}--> ({:>4d}x{:<4d})'.format(img_name+ext, w, h))

    # boundary handling
    # Steps:
    #   1. Nearest-neighbour upsample the LR image by sf.
    #   2. Extend it with utils_deblur.wrap_boundary_liu to
    #      ceil(sf*dim/boarder + 2) * boarder along each axis.
    #   3. Downsample back by sf.
    #   4. Paste the original LR image into the top-left corner.
    # Presumably this reduces boundary artifacts — TODO confirm. The estimate
    # is cropped back to [:sf*w, :sf*h] after inference.
    boarder = 8     # default setting for kernel size 25x25
    img = cv2.resize(img_L, (sf*h, sf*w), interpolation=cv2.INTER_NEAREST)
    img = utils_deblur.wrap_boundary_liu(img, [int(np.ceil(sf*w/boarder+2)*boarder), int(np.ceil(sf*h/boarder+2)*boarder)])
    img_wrap = sr.downsample_np(img, sf, center=False)
    img_wrap[:w, :h, :] = img_L
    img_L = img_wrap

    util.imshow(util.single2uint(img_L), title='LR image with noise level {}'.format(noise_level_img)) if show_img else None

    img_L = util.single2tensor4(img_L)
    img_L = img_L.to(device)

    # ------------------------------------
    # (2) img_E
    # ------------------------------------
    # Noise level reshaped to [1, 1, 1, 1]; input, kernel and sigma are all
    # moved to the model's device.
    sigma = torch.tensor(noise_level_model).float().view([1, 1, 1, 1])
    [img_L, kernel, sigma] = [el.to(device) for el in [img_L, kernel, sigma]]

    img_E = model(img_L, kernel, sf, sigma)

    # Drop the wrapped boundary: keep only the sf-times upscaled original area.
    img_E = util.tensor2uint(img_E)[:sf*w, :sf*h, ...]

    if save_E:
        util.imsave(img_E, os.path.join(E_path, img_name+'_x'+str(sf)+'_'+model_name+'.png'))

    # --------------------------------
    # (3) save img_LE
    # --------------------------------
    if save_LE:
        # Kernel visualization: scale so its peak becomes 1.2, convert with
        # util.single2uint (presumably clipped to the uint8 range — TODO
        # confirm), replicate to 3 channels and enlarge k_factor times.
        k_v = k/np.max(k)*1.2
        k_v = util.single2uint(np.tile(k_v[..., np.newaxis], [1, 1, 3]))
        k_factor = 3
        k_v = cv2.resize(k_v, (k_factor*k_v.shape[1], k_factor*k_v.shape[0]), interpolation=cv2.INTER_NEAREST)
        # Crop the wrapped LR back to its original size, upsample it sf times
        # (nearest neighbour) and overlay the kernel in the top-left corner.
        img_L = util.tensor2uint(img_L)[:w, :h, ...]
        img_I = cv2.resize(img_L, (sf*img_L.shape[1], sf*img_L.shape[0]), interpolation=cv2.INTER_NEAREST)
        img_I[:k_v.shape[0], :k_v.shape[1], :] = k_v
        util.imshow(np.concatenate([img_I, img_E], axis=1), title='LR / Recovered') if show_img else None
        util.imsave(np.concatenate([img_I, img_E], axis=1), os.path.join(E_path, img_name+'_x'+str(sf)+'_'+model_name+'_LE.png'))
def main():
    """Evaluate an SRMD super-resolution model on a test set and log PSNR/SSIM.

    Each HR image in ``testsets/<testset_name>`` is modcropped, degraded with
    ``sr.srmd_degradation`` using a near-delta Gaussian kernel, and perturbed
    with seeded Gaussian noise of level ``noise_level_img``.

    The kernel is projected with the PCA matrix from
    ``kernels/srmd_pca_matlab.mat``. For non-``nf`` models the noise level is
    appended to that vector. The resulting degradation vector is broadcast to
    a spatial map and concatenated to the LR input channels.

    Estimated images are saved to ``results/<testset>_<model_name>``. RGB and
    (for 3-channel images) Y-channel average metrics are logged.
    """
    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    noise_level_img = 0  # default: 0, noise level for LR image
    noise_level_model = noise_level_img  # noise level for model
    model_name = 'srmdnf_x4'  # 'srmd_x2' | 'srmd_x3' | 'srmd_x4' | 'srmdnf_x2' | 'srmdnf_x3' | 'srmdnf_x4'
    testset_name = 'set5'  # test set,  'set5' | 'srbsd68'
    sf = int(re.findall(r'\d+', model_name)[0])  # scale factor parsed from model_name
    x8 = False  # default: False, x8 to boost performance
    need_degradation = True  # default: True, use degradation model to generate LR image
    show_img = False  # default: False

    srmd_pca_path = os.path.join('kernels', 'srmd_pca_matlab.mat')
    task_current = 'sr'  # 'dn' for denoising | 'sr' for super-resolution
    n_channels = 3  # fixed
    # input channels = image channels + degradation map (noise level appended unless 'nf')
    in_nc = 18 if 'nf' in model_name else 19
    nc = 128  # fixed, number of channels
    nb = 12  # fixed, number of conv layers
    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    result_name = testset_name + '_' + model_name
    border = sf if task_current == 'sr' else 0  # pixels shaved before PSNR/SSIM
    model_path = os.path.join(model_pool, model_name + '.pth')

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------
    L_path = os.path.join(testsets, testset_name)  # L_path, for Low-quality images
    H_path = L_path  # H_path, for High-quality images
    E_path = os.path.join(results, result_name)  # E_path, for Estimated images
    util.mkdir(E_path)

    # Inputs and ground truth share a folder, so the LR input must be synthesized.
    if H_path == L_path:
        need_degradation = True
    logger_name = result_name
    utils_logger.logger_info(logger_name,
                             log_path=os.path.join(E_path,
                                                   logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    need_H = H_path is not None
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------
    from models.network_srmd import SRMD as net
    model = net(in_nc=in_nc,
                out_nc=n_channels,
                nc=nc,
                nb=nb,
                upscale=sf,
                act_mode='R',
                upsample_mode='pixelshuffle')
    # NOTE(review): strict=False silently ignores mismatched checkpoint keys — confirm intended.
    model.load_state_dict(torch.load(model_path), strict=False)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))
    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))

    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []

    # Arguments follow the placeholder order: model sigma, then image sigma.
    logger.info('model_name:{}, model sigma:{}, image sigma:{}'.format(
        model_name, noise_level_model, noise_level_img))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)
    H_paths = util.get_image_paths(H_path) if need_H else None

    # ----------------------------------------
    # kernel and PCA reduced feature
    # ----------------------------------------
    # kernel = sr.anisotropic_Gaussian(ksize=15, theta=np.pi, l1=4, l2=4)
    kernel = utils_deblur.fspecial('gaussian', 15,
                                   0.01)  # Gaussian kernel, delta kernel 0.01

    P = loadmat(srmd_pca_path)['P']
    # Column-major flatten of the kernel before projecting with P.
    degradation_vector = np.dot(P, np.reshape(kernel, (-1), order="F"))
    if 'nf' not in model_name:  # non noise-free models also take the noise level
        degradation_vector = np.append(degradation_vector,
                                       noise_level_model / 255.)
    degradation_vector = torch.from_numpy(degradation_vector).view(
        1, -1, 1, 1).float()

    for idx, img in enumerate(L_paths):

        # (1) img_L
        img_name, ext = os.path.splitext(os.path.basename(img))
        img_L = util.imread_uint(img, n_channels=n_channels)
        img_L = util.uint2single(img_L)

        # degradation process, blur + bicubic downsampling + Gaussian noise
        if need_degradation:
            img_L = util.modcrop(img_L, sf)
            # equivalent to bicubic degradation if kernel is a delta kernel
            img_L = sr.srmd_degradation(img_L, kernel, sf)
            np.random.seed(seed=0)  # for reproducibility
            img_L += np.random.normal(0, noise_level_img / 255., img_L.shape)

        if show_img:
            util.imshow(util.single2uint(img_L),
                        title='LR image with noise level {}'.format(
                            noise_level_img))

        img_L = util.single2tensor4(img_L)
        # Broadcast the degradation vector over the LR spatial size and append it as channels.
        degradation_map = degradation_vector.repeat(1, 1, img_L.size(-2),
                                                    img_L.size(-1))
        img_L = torch.cat((img_L, degradation_map), dim=1)
        img_L = img_L.to(device)

        # (2) img_E
        if not x8:
            img_E = model(img_L)
        else:
            img_E = utils_model.test_mode(model, img_L, mode=3, sf=sf)

        img_E = util.tensor2uint(img_E)

        if need_H:
            # (3) img_H
            img_H = util.imread_uint(H_paths[idx], n_channels=n_channels)
            img_H = img_H.squeeze()
            img_H = util.modcrop(img_H, sf)

            # PSNR and SSIM
            psnr = util.calculate_psnr(img_E, img_H, border=border)
            ssim = util.calculate_ssim(img_E, img_H, border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(
                img_name + ext, psnr, ssim))
            if show_img:
                util.imshow(np.concatenate([img_E, img_H], axis=1),
                            title='Recovered / Ground-truth')

            if np.ndim(img_H) == 3:  # RGB image
                img_E_y = util.rgb2ycbcr(img_E, only_y=True)
                img_H_y = util.rgb2ycbcr(img_H, only_y=True)
                psnr_y = util.calculate_psnr(img_E_y, img_H_y, border=border)
                ssim_y = util.calculate_ssim(img_E_y, img_H_y, border=border)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)

        # save results
        util.imsave(img_E, os.path.join(E_path, img_name + '.png'))

    # Skip the summary on an empty test set (avoids ZeroDivisionError and the
    # NameError from reading img_H after a loop that never ran).
    if need_H and test_results['psnr']:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        logger.info(
            'Average PSNR/SSIM(RGB) - {} - x{} --PSNR: {:.2f} dB; SSIM: {:.4f}'
            .format(result_name, sf, ave_psnr, ave_ssim))
        if test_results['psnr_y']:  # Y-channel metrics exist only for RGB images
            ave_psnr_y = sum(test_results['psnr_y']) / len(
                test_results['psnr_y'])
            ave_ssim_y = sum(test_results['ssim_y']) / len(
                test_results['ssim_y'])
            logger.info(
                'Average PSNR/SSIM( Y ) - {} - x{} - PSNR: {:.2f} dB; SSIM: {:.4f}'
                .format(result_name, sf, ave_psnr_y, ave_ssim_y))
Exemple #20
0
def main():
    """Evaluate a pretrained USRNet/USRGAN model on bicubic x`sf` super-resolution.

    Every image under ``testsets/<testset_name>`` is (optionally) bicubically
    downsampled, padded with ``utils_deblur.wrap_boundary_liu``, super-resolved
    with the approximated bicubic kernel for scale ``sf`` and saved under
    ``results/<result_name>``.  When ground truth is available, per-image and
    average PSNR/SSIM (RGB and, for 3-D images, the Y channel) are logged.
    """

    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    model_name = 'usrnet'  # 'usrgan' | 'usrnet' | 'usrgan_tiny' | 'usrnet_tiny'
    testset_name = 'set5'  # test set,  'set5' | 'srbsd68'
    need_degradation = True  # default: True
    sf = 4  # scale factor, only from {2, 3, 4}
    show_img = False  # default: False
    save_L = True  # save LR image
    save_E = True  # save estimated image

    # load approximated bicubic kernels; column sf - 2 holds the kernel for
    # scale factor sf (the file covers x2, x3 and x4)
    kernels = loadmat(os.path.join('kernels',
                                   'kernels_bicubicx234.mat'))['kernels']
    kernel = kernels[0, sf - 2].astype(np.float64)
    kernel = util.single2tensor4(kernel[..., np.newaxis])

    task_current = 'sr'  # fixed, 'sr' for super-resolution
    n_channels = 3  # fixed, 3 for color image
    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    noise_level_img = 0  # fixed: 0, noise level for LR image
    noise_level_model = noise_level_img  # fixed, noise level of model, default 0
    result_name = testset_name + '_' + model_name + '_bicubic'
    border = sf if task_current == 'sr' else 0  # border shaved before PSNR/SSIM
    model_path = os.path.join(model_pool, model_name + '.pth')

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------
    L_path = os.path.join(testsets, testset_name)  # low-quality inputs
    H_path = L_path  # None | L_path, high-quality references
    E_path = os.path.join(results, result_name)  # estimated outputs
    util.mkdir(E_path)

    # When the inputs double as ground truth, the LR input must be synthesized.
    if H_path == L_path:
        need_degradation = True
    logger_name = result_name
    utils_logger.logger_info(logger_name,
                             log_path=os.path.join(E_path,
                                                   logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    need_H = H_path is not None
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------
    from models.network_usrnet import USRNet as net  # for pytorch version <= 1.7.1
    # from models.network_usrnet_v1 import USRNet as net  # for pytorch version >=1.8.1

    # The tiny variants only differ in depth and width.
    if 'tiny' in model_name:
        n_iter, h_nc, nc = 6, 32, [16, 32, 64, 64]
    else:
        n_iter, h_nc, nc = 8, 64, [64, 128, 256, 512]
    model = net(n_iter=n_iter,
                h_nc=h_nc,
                in_nc=4,
                out_nc=3,
                nc=nc,
                nb=2,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")

    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
    # hosts; the model is moved to `device` a few lines below.
    model.load_state_dict(torch.load(model_path, map_location='cpu'),
                          strict=True)
    model.eval()
    for v in model.parameters():
        v.requires_grad = False

    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))

    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []

    logger.info('model_name:{}, image sigma:{}'.format(model_name,
                                                       noise_level_img))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)
    H_paths = util.get_image_paths(H_path) if need_H else None

    # Kernel and noise level are identical for every image: move them once.
    sigma = torch.tensor(noise_level_model).float().view([1, 1, 1, 1])
    kernel = kernel.to(device)
    sigma = sigma.to(device)

    for idx, img in enumerate(L_paths):

        # ------------------------------------
        # (1) img_L
        # ------------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx + 1, img_name + ext))
        img_L = util.imread_uint(img, n_channels=n_channels)
        img_L = util.uint2single(img_L)

        # degradation process, bicubic downsampling
        if need_degradation:
            img_L = util.modcrop(img_L, sf)
            img_L = util.imresize_np(img_L, 1 / sf)

        # NB: `w` is the row count and `h` the column count here.
        w, h = img_L.shape[:2]

        if save_L:
            util.imsave(
                util.single2uint(img_L),
                os.path.join(E_path, img_name + '_LR_x' + str(sf) + '.png'))

        # Pad the LR image: nearest-upsample by sf, wrap the boundary out to a
        # multiple of 8 (with extra margin), downsample by sf again and paste
        # the original LR pixels back into the top-left corner.
        img = cv2.resize(img_L, (sf * h, sf * w),
                         interpolation=cv2.INTER_NEAREST)
        img = utils_deblur.wrap_boundary_liu(img, [
            int(np.ceil(sf * w / 8 + 2) * 8),
            int(np.ceil(sf * h / 8 + 2) * 8)
        ])
        img_wrap = sr.downsample_np(img, sf, center=False)
        img_wrap[:w, :h, :] = img_L
        img_L = img_wrap

        if show_img:
            util.imshow(util.single2uint(img_L),
                        title='LR image with noise level {}'.format(
                            noise_level_img))

        img_L = util.single2tensor4(img_L).to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------
        with torch.no_grad():
            img_E = model(img_L, kernel, sf, sigma)

        img_E = util.tensor2uint(img_E)
        # Drop the padding added above.
        img_E = img_E[:sf * w, :sf * h, :]

        if need_H:

            # --------------------------------
            # (3) img_H
            # --------------------------------
            img_H = util.imread_uint(H_paths[idx], n_channels=n_channels)
            img_H = img_H.squeeze()
            img_H = util.modcrop(img_H, sf)

            # --------------------------------
            # PSNR and SSIM
            # --------------------------------
            psnr = util.calculate_psnr(img_E, img_H, border=border)
            ssim = util.calculate_ssim(img_E, img_H, border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(
                img_name + ext, psnr, ssim))
            if show_img:
                util.imshow(np.concatenate([img_E, img_H], axis=1),
                            title='Recovered / Ground-truth')

            if np.ndim(img_H) == 3:  # RGB image
                img_E_y = util.rgb2ycbcr(img_E, only_y=True)
                img_H_y = util.rgb2ycbcr(img_H, only_y=True)
                psnr_y = util.calculate_psnr(img_E_y, img_H_y, border=border)
                ssim_y = util.calculate_ssim(img_E_y, img_H_y, border=border)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)

        # ------------------------------------
        # save results
        # ------------------------------------
        if save_E:
            util.imsave(
                img_E,
                os.path.join(
                    E_path,
                    img_name + '_x' + str(sf) + '_' + model_name + '.png'))

    # Guard against an empty test set: averaging [] would divide by zero and
    # `img_H` would never have been bound.
    if need_H and test_results['psnr']:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        logger.info(
            'Average PSNR/SSIM(RGB) - {} - x{} --PSNR: {:.2f} dB; SSIM: {:.4f}'
            .format(result_name, sf, ave_psnr, ave_ssim))
        if test_results['psnr_y']:
            ave_psnr_y = sum(test_results['psnr_y']) / len(
                test_results['psnr_y'])
            ave_ssim_y = sum(test_results['ssim_y']) / len(
                test_results['ssim_y'])
            logger.info(
                'Average PSNR/SSIM( Y ) - {} - x{} - PSNR: {:.2f} dB; SSIM: {:.4f}'
                .format(result_name, sf, ave_psnr_y, ave_ssim_y))
Exemple #21
0
def main():
    """Fine-tune USRNet for spatially-varying deblurring with a ZEMAX PSF grid.

    Each iteration crops a random DIV2K image, blurs it with
    ``util_deblur.blockConv2d`` using a ``patch_num`` sub-grid of PSFs, and
    trains the network together with a learnable per-grid-position HQS
    parameter table ``ab``.  The loss is an L1 loss between prediction and
    ground truth, both divided by a per-PSF weight from
    ``util_deblur.get_inv_spatial_weight``.  A preview window is refreshed each
    iteration; press ``q`` in it to stop.  Checkpoints go to ``./ZEMAX_model``
    every 100 iterations and once more on exit.
    """
    import os  # only needed to create the checkpoint directory

    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    PSF_grid = np.load('./data/AC254-075-A-ML-Zemax(ZMX).npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)
    gx, gy = PSF_grid.shape[:2]

    # Normalize every PSF to unit sum over its two spatial axes and stack them
    # into one tensor; grid position (xx, yy) is stored at index xx + yy * gx.
    k_tensor = []
    for yy in range(gy):
        for xx in range(gx):
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))
            k_tensor.append(util.single2tensor4(PSF_grid[xx, yy]))

    k_tensor = torch.cat(k_tensor, dim=0)
    # Rows follow the same ordering as k_tensor.
    inv_weight = util_deblur.get_inv_spatial_weight(k_tensor)

    def grid_index(xx, yy):
        """Return the row of k_tensor / inv_weight holding PSF (xx, yy)."""
        return xx + yy * gx

    # ----------------------------------------
    # load model
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=8,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    # map_location='cpu' lets GPU-saved weights load on CPU-only hosts; the
    # model is moved to `device` right after.
    model.proj.load_state_dict(torch.load('./data/usrnet_pretrain.pth',
                                          map_location='cpu'),
                               strict=True)
    model.train()
    for v in model.parameters():
        v.requires_grad = True
    model = model.to(device)

    # ----------------------------------------
    # load training data
    # ----------------------------------------
    imgs = sorted(glob.glob('./DIV2K_train/*.png', recursive=True))
    if not imgs:
        raise FileNotFoundError('No training images found in ./DIV2K_train')

    # ----------------------------------------
    # positional lambda\mu for HQS
    # ----------------------------------------
    stage = 8
    # One (2 * stage) x 3 table per PSF grid position, passed through
    # softplus before being handed to the network.
    ab_buffer = np.ones((gx, gy, 2 * stage, 3), dtype=np.float32) * 0.1
    ab = torch.tensor(ab_buffer, device=device, requires_grad=True)

    # ----------------------------------------
    # build optimizer
    # ----------------------------------------
    params = [
        {"params": [ab], "lr": 0.0005},
        {"params": list(model.parameters()), "lr": 0.0001},
    ]
    optimizer = torch.optim.Adam(params, lr=0.0001, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1000,
                                                gamma=0.9)

    patch_size = [128, 128]
    expand = PSF_grid.shape[2] // 2  # blur margin around the sharp crop
    patch_num = [2, 2]
    # Size of the sharp crop without the blur margin.
    crop = [patch_size[0] * patch_num[0], patch_size[1] * patch_num[1]]

    os.makedirs('./ZEMAX_model', exist_ok=True)

    global_iter = 0

    while True:
        img = imgs[np.random.randint(len(imgs))]
        img_H = cv2.imread(img)
        if img_H is None:
            raise FileNotFoundError('Cannot read image: {}'.format(img))
        w, h = img_H.shape[:2]  # NB: w = rows, h = columns
        if w < crop[0] + expand * 2 or h < crop[1] + expand * 2:
            raise ValueError('Image {} ({}x{}) is smaller than the {}x{} '
                             'training crop'.format(img, w, h,
                                                    crop[0] + expand * 2,
                                                    crop[1] + expand * 2))

        # Pick a patch_num sub-grid of PSFs; modes 0-3 pin it to one edge of
        # the PSF grid so border PSFs are sampled more often.
        mode = np.random.randint(5)
        px_start = np.random.randint(0, gx - patch_num[0] + 1)
        py_start = np.random.randint(0, gy - patch_num[1] + 1)
        if mode == 0:
            px_start = 0
        elif mode == 1:
            px_start = gx - patch_num[0]
        elif mode == 2:
            py_start = 0
        elif mode == 3:
            py_start = gy - patch_num[1]

        x_start = np.random.randint(0, w - crop[0] - expand * 2 + 1)
        y_start = np.random.randint(0, h - crop[1] - expand * 2 + 1)
        PSF_patch = PSF_grid[px_start:px_start + patch_num[0],
                             py_start:py_start + patch_num[1]]

        patch_H = img_H[x_start:x_start + crop[0] + expand * 2,
                        y_start:y_start + crop[1] + expand * 2]
        patch_L = util_deblur.blockConv2d(patch_H, PSF_patch, expand)

        block_expand = max(patch_size[0] // 8, expand)

        # Wrap the blurred patch's boundary, then circularly shift it by
        # block_expand on both axes so the original pixels start at offset
        # block_expand (the prediction is cropped with the same offset).
        patch_L_wrap = util_deblur.wrap_boundary_liu(
            patch_L, (crop[0] + block_expand * 2, crop[1] + block_expand * 2))
        patch_L_wrap = np.hstack(
            (patch_L_wrap[:, -block_expand:, :],
             patch_L_wrap[:, :crop[1] + block_expand, :]))
        patch_L_wrap = np.vstack(
            (patch_L_wrap[-block_expand:, :, :],
             patch_L_wrap[:crop[0] + block_expand, :, :]))
        x = util.single2tensor4(util.uint2single(patch_L_wrap))

        # Explicit end indices stay correct even when expand == 0.
        sharp = patch_H[expand:expand + crop[0], expand:expand + crop[1]]
        x_gt = util.single2tensor4(util.uint2single(sharp))
        inv_weight_patch = torch.ones_like(x_gt)

        k_local = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                # Sub-patch (w_, h_) was blurred with PSF_patch[w_, h_], i.e.
                # global PSF (px_start + w_, py_start + h_): use that kernel
                # and weight, not the local index into the full grid.
                g = grid_index(px_start + w_, py_start + h_)
                rows = slice(w_ * patch_size[0], (w_ + 1) * patch_size[0])
                cols = slice(h_ * patch_size[1], (h_ + 1) * patch_size[1])
                for c in range(3):
                    inv_weight_patch[0, c, rows, cols] = inv_weight[g, c]
                k_local.append(k_tensor[g:g + 1])

        k = torch.cat(k_local, dim=0)
        x, x_gt, k, inv_weight_patch = [
            el.to(device) for el in (x, x_gt, k, inv_weight_patch)
        ]
        ab_patch = F.softplus(ab[px_start:px_start + patch_num[0],
                                 py_start:py_start + patch_num[1]])
        # Same ordering as k_local: h_ outer, w_ inner.
        cd = torch.cat([
            ab_patch[w_:w_ + 1, h_]
            for h_ in range(patch_num[1])
            for w_ in range(patch_num[0])
        ], dim=0)
        x_E = model.forward_patchwise(x, k, cd, patch_num, patch_size)

        predict = x_E[..., block_expand:block_expand + crop[0],
                      block_expand:block_expand + crop[1]]
        loss = F.l1_loss(predict.div(inv_weight_patch),
                         x_gt.div(inv_weight_patch))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        print('iter:{},loss {}'.format(global_iter + 1, loss.item()))

        # Preview: sharp | blurred input | estimate.
        patch_L = patch_L_wrap.astype(np.uint8)
        patch_E = util.tensor2uint(x_E)[block_expand:-block_expand,
                                        block_expand:-block_expand]
        show = np.hstack((sharp,
                          patch_L[block_expand:-block_expand,
                                  block_expand:-block_expand], patch_E))

        cv2.imshow('HL', show)
        key = cv2.waitKey(1)

        global_iter += 1

        # change the save period here
        if global_iter % 100 == 0:
            ab_numpy = ab.detach().cpu().numpy().flatten()
            torch.save(
                model.state_dict(),
                './ZEMAX_model/usrnet_ZEMAX_iter{}.pth'.format(global_iter))
            np.savetxt('./ZEMAX_model/ab_ZEMAX_iter{}.txt'.format(global_iter),
                       ab_numpy)
        if key == ord('q'):
            break

    ab_numpy = ab.detach().cpu().numpy().flatten()
    torch.save(model.state_dict(), './ZEMAX_model/usrnet_ZEMAX.pth')
    np.savetxt('./ZEMAX_model/ab_ZEMAX.txt', ab_numpy)
Exemple #22
0
def main():
    """Deblur a photo with the trained spatially-varying USRNet model.

    Loads the PSF grid in ``./data/Heide_PSF_plano_small.npz``, the
    ``uabcnet_iter800`` checkpoint and its matching HQS parameter table, runs
    one patch-wise forward pass over the whole (cropped, half-size) image and
    shows the estimate next to the boundary-wrapped input.  Press ``q`` to
    stop; any other key runs the next iteration.  Every iteration currently
    reads the same input file; substitute a per-``img_id`` path to process a
    sequence.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    PSF_grid = np.load('./data/Heide_PSF_plano_small.npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)

    # Normalize every PSF to unit sum over its two spatial axes.
    gx, gy = PSF_grid.shape[:2]
    for xx in range(gx):
        for yy in range(gy):
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))

    # ----------------------------------------
    # load model
    # ----------------------------------------
    stage = 8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=stage,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")

    model_code = 'iter800'
    # map_location='cpu' lets GPU-saved weights load on CPU-only hosts; the
    # model is moved to `device` below.
    loaded_state = torch.load(
        '/home/xiu/databag/deblur/models/plano/uabcnet_{}.pth'.format(
            model_code),
        map_location='cpu')
    model.load_state_dict(loaded_state, strict=True)

    model.eval()
    for v in model.parameters():
        v.requires_grad = False
    model = model.to(device)

    # ----------------------------------------
    # loop invariants: the whole PSF grid is processed in one pass, so the
    # kernels and HQS parameters are the same for every image
    # ----------------------------------------
    print(gx, gy)
    num_patch = [gx, gy]
    px_start = 0
    py_start = 0

    # positional alpha-beta parameters for HQS
    ab_numpy = np.loadtxt(
        '/home/xiu/databag/deblur/models/plano/ab_{}.txt'.format(
            model_code)).astype(np.float32).reshape(gx, gy, stage * 2, 3)
    ab = torch.tensor(ab_numpy, device=device, requires_grad=False)

    PSF_patch = PSF_grid[px_start:px_start + num_patch[0],
                         py_start:py_start + num_patch[1]]
    # Kernels and HQS parameters share one ordering: h_ outer, w_ inner.
    k = torch.cat([
        util.single2tensor4(PSF_patch[w_, h_])
        for h_ in range(num_patch[1])
        for w_ in range(num_patch[0])
    ], dim=0).to(device)
    ab_patch = F.softplus(ab[px_start:px_start + num_patch[0],
                             py_start:py_start + num_patch[1]])
    cd = torch.cat([
        ab_patch[w_:w_ + 1, h_]
        for h_ in range(num_patch[1])
        for w_ in range(num_patch[0])
    ], dim=0)

    for img_id in range(1, 237):
        img_path = '/home/xiu/databag/deblur/ICCV2021/MPI_data/drain/blurry.jpg'
        img_L = cv2.imread(img_path)
        if img_L is None:
            raise FileNotFoundError('Cannot read image: {}'.format(img_path))
        img_L = img_L.astype(np.float32)
        img_L = img_L[38:-39, 74:-74]
        img_L = cv2.resize(img_L, dsize=None, fx=0.5, fy=0.5)

        W, H = img_L.shape[:2]  # NB: W = rows, H = columns

        t0 = time.time()

        patch_L = img_L[px_start * W // gx:(px_start + num_patch[0]) * W // gx,
                        py_start * H // gy:(py_start + num_patch[1]) * H //
                        gy, :]

        p_W, p_H = patch_L.shape[:2]
        expand = max(PSF_grid.shape[2] // 2, p_W // 16)
        block_expand = expand
        patch_L_wrap = util_deblur.wrap_boundary_liu(
            patch_L, (p_W + block_expand * 2, p_H + block_expand * 2))
        # centralize: circularly shift by block_expand on both axes so the
        # original pixels start at offset block_expand
        patch_L_wrap = np.hstack((patch_L_wrap[:, -block_expand:, :],
                                  patch_L_wrap[:, :p_H + block_expand, :]))
        patch_L_wrap = np.vstack((patch_L_wrap[-block_expand:, :, :],
                                  patch_L_wrap[:p_W + block_expand, :, :]))
        x = util.single2tensor4(util.uint2single(patch_L_wrap)).to(device)

        x_E = model.forward_patchwise(x, k, cd, num_patch, [W // gx, H // gy])
        x_E = x_E[..., block_expand:block_expand + p_W,
                  block_expand:block_expand + p_H]

        patch_L = patch_L_wrap.astype(np.uint8)
        patch_E = util.tensor2uint(x_E)

        t1 = time.time()
        print(t1 - t0)

        xk = patch_E.astype(np.uint8)
        cv2.imshow('xx', xk)
        cv2.imshow('img_L', patch_L.astype(np.uint8))
        key = cv2.waitKey(-1)
        if key == ord('q'):
            break
Exemple #23
0
def main(json_path='options/train_msrresnet_psnr.json'):
    """Train a model configured by a JSON option file.

    The option file (``-opt`` on the command line, ``json_path`` by default)
    supplies paths, datasets, the network and the training schedule.  The
    starting step and ``pretrained_netG`` come from
    ``option.find_last_checkpoint`` on ``opt['path']['models']`` (presumably
    the newest saved G checkpoint, if any).  Training logs every
    ``checkpoint_print`` steps, saves every ``checkpoint_save`` steps and, when
    a ``test`` dataset is configured, reports PSNR every ``checkpoint_test``
    steps.

    Args:
        json_path: Default path of the option JSON file.

    Raises:
        NotImplementedError: If a dataset phase other than 'train' or 'test'
            is configured.
        ValueError: If no 'train' dataset is configured.
    """

    # ----------------------------------------
    # Step--1 (prepare opt)
    # ----------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        default=json_path,
                        help='Path to option JSON file.')

    opt = option.parse(parser.parse_args().opt, is_train=True)
    util.mkdirs(
        (path for key, path in opt['path'].items() if 'pretrained' not in key))

    # ----------------------------------------
    # update opt: resume from the last checkpoint
    # ----------------------------------------
    init_iter, init_path_G = option.find_last_checkpoint(opt['path']['models'],
                                                         net_type='G')
    opt['path']['pretrained_netG'] = init_path_G
    current_step = init_iter

    border = opt['scale']  # pixels shaved before computing PSNR

    # ----------------------------------------
    # save opt to  a '../option.json' file
    # ----------------------------------------
    option.save(opt)

    # ----------------------------------------
    # return None for missing key
    # ----------------------------------------
    opt = option.dict_to_nonedict(opt)

    # ----------------------------------------
    # configure logger
    # ----------------------------------------
    logger_name = 'train'
    utils_logger.logger_info(
        logger_name, os.path.join(opt['path']['log'], logger_name + '.log'))
    logger = logging.getLogger(logger_name)
    logger.info(option.dict2str(opt))

    # ----------------------------------------
    # seed
    # ----------------------------------------
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # ----------------------------------------
    # Step--2 (create dataloaders for train and test)
    # ----------------------------------------
    train_loader = None
    test_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = define_Dataset(dataset_opt)
            train_size = int(
                math.ceil(
                    len(train_set) / dataset_opt['dataloader_batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            train_loader = DataLoader(
                train_set,
                batch_size=dataset_opt['dataloader_batch_size'],
                shuffle=dataset_opt['dataloader_shuffle'],
                num_workers=dataset_opt['dataloader_num_workers'],
                drop_last=True,
                pin_memory=True)
        elif phase == 'test':
            test_set = define_Dataset(dataset_opt)
            test_loader = DataLoader(test_set,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=1,
                                     drop_last=False,
                                     pin_memory=True)
        else:
            raise NotImplementedError("Phase [%s] is not recognized." % phase)
    if train_loader is None:
        raise ValueError("No 'train' dataset configured in opt['datasets'].")

    # ----------------------------------------
    # Step--3 (initialize model)
    # ----------------------------------------
    model = define_Model(opt)
    model.init_train()
    logger.info(model.info_network())
    logger.info(model.info_params())

    # ----------------------------------------
    # Step--4 (main training)
    # ----------------------------------------
    for epoch in range(100):  # keep running
        for train_data in train_loader:

            current_step += 1

            # 1) update learning rate
            model.update_learning_rate(current_step)

            # 2) feed patch pairs
            model.feed_data(train_data)

            # 3) optimize parameters
            model.optimize_parameters(current_step)

            # 4) training information
            if current_step % opt['train']['checkpoint_print'] == 0:
                logs = model.current_log()  # such as loss
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.current_learning_rate())
                for k, v in logs.items():  # merge log information into message
                    message += '{:s}: {:.3e} '.format(k, v)
                logger.info(message)

            # 5) save model
            if current_step % opt['train']['checkpoint_save'] == 0:
                logger.info('Saving the model.')
                model.save(current_step)

            # 6) testing (skipped when no 'test' dataset is configured)
            if (test_loader is not None
                    and current_step % opt['train']['checkpoint_test'] == 0):

                avg_psnr = 0.0
                idx = 0

                for test_data in test_loader:
                    idx += 1
                    image_name_ext = os.path.basename(test_data['L_path'][0])
                    img_name, ext = os.path.splitext(image_name_ext)

                    img_dir = os.path.join(opt['path']['images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(test_data)
                    model.test()

                    visuals = model.current_visuals()
                    E_img = util.tensor2uint(visuals['E'])
                    H_img = util.tensor2uint(visuals['H'])

                    # save estimated image E
                    save_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}.png'.format(img_name, current_step))
                    util.imsave(E_img, save_img_path)

                    # calculate PSNR
                    current_psnr = util.calculate_psnr(E_img,
                                                       H_img,
                                                       border=border)

                    logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(
                        idx, image_name_ext, current_psnr))

                    avg_psnr += current_psnr

                # An empty test set would otherwise divide by zero.
                if idx == 0:
                    logger.warning('Test set is empty; no PSNR to report.')
                else:
                    avg_psnr = avg_psnr / idx

                    # testing log
                    logger.info(
                        '<epoch:{:3d}, iter:{:8,d}, Average PSNR : {:<.2f}dB\n'.
                        format(epoch, current_step, avg_psnr))

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
Exemple #24
0
def main():
    """Deblur a frame sequence block by block with a grid of local PSFs.

    Loads the ZEMAX AC254-075-A PSF grid and normalizes every local kernel to
    unit sum. Frames ``{:08d}.bmp`` 0-99 are then restored one block at a
    time with ``model.forward_patchwise``. Each block covers
    ``num_patch x num_patch`` PSF-grid cells of ``p_size`` pixels. The block
    outputs are normalized by how many blocks covered each pixel, and the
    result is shown next to the input in OpenCV windows.

    NOTE(review): paths, the 6x8 weight grid in ``ab_ZEMAX.txt`` and the
    128-pixel cell size are hard-coded for one lens/camera setup.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    # Other PSF sets used with this script: ./data/Schuler_PSF01.npz,
    # ./data/Schuler_PSF_facade.npz, ./data/Schuler_PSF03.npz, ./data/PSF.npz
    PSF_grid = np.load('./data/ZEMAX-AC254-075-A-new.npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)

    # Normalize each local kernel PSF_grid[x, y] to unit mass over its two
    # spatial axes (axes 2 and 3 of the grid). keepdims lets the per-cell
    # (and per-channel, if present) sums broadcast back onto the kernels.
    PSF_grid /= np.sum(PSF_grid, axis=(2, 3), keepdims=True)

    # ----------------------------------------
    # load model
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=8, h_nc=64, in_nc=4, out_nc=3, nc=[64, 128, 256, 512],
                nb=2, act_mode="R", downsample_mode='strideconv',
                upsample_mode="convtranspose")

    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    loaded_state = torch.load('./usrnet_ZEMAX.pth', map_location=device)
    model.load_state_dict(loaded_state, strict=True)

    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # ----------------------------------------
    # per-run constants (loop-invariant, so loaded once)
    # ----------------------------------------
    patch_size = 2 * 128              # side of one restored block, in pixels
    num_patch = 2                     # PSF-grid cells per block side
    p_size = patch_size // num_patch  # pixels per PSF-grid cell
    expand = PSF_grid.shape[2]        # border padding = kernel width
    grid_x, grid_y = 6, 8             # grid size of the weights in ab_ZEMAX.txt

    # Per-cell weights (presumably the HQS lambda/mu); softplus is applied per
    # block below.
    ab_numpy = np.loadtxt('ab_ZEMAX.txt').astype(np.float32).reshape(
        grid_x, grid_y, 16, 3)
    ab = torch.tensor(ab_numpy, device=device, requires_grad=False)

    for img_id in range(100):
        img_path = '/home/xiu/workspace/UABC/ICCV2021/video/{:08d}.bmp'.format(
            img_id)
        img_L = cv2.imread(img_path)
        if img_L is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError('Cannot read frame: {}'.format(img_path))
        img_L = img_L.astype(np.float32)

        # Accumulated final estimate, and how many blocks covered each pixel.
        img_E_denoise = np.zeros_like(img_L)
        weight_E = np.zeros_like(img_L)

        t0 = time.time()
        for px_start in range(0, grid_x - num_patch + 1, num_patch):
            for py_start in range(0, grid_y - num_patch + 1, num_patch):
                PSF_patch = PSF_grid[px_start:px_start + num_patch,
                                     py_start:py_start + num_patch]
                rows = slice(px_start * p_size,
                             (px_start + num_patch) * p_size)
                cols = slice(py_start * p_size,
                             (py_start + num_patch) * p_size)
                patch_L = img_L[rows, cols, :]

                if expand > 0:
                    # Pad to patch_size + 2*expand with wrap_boundary_liu,
                    # then move the last `expand` cols/rows to the front so
                    # the original patch sits in the centre.
                    patch_L_wrap = util_deblur.wrap_boundary_liu(
                        patch_L,
                        (patch_size + expand * 2, patch_size + expand * 2))
                    patch_L_wrap = np.hstack(
                        (patch_L_wrap[:, -expand:, :],
                         patch_L_wrap[:, :patch_size + expand, :]))
                    patch_L_wrap = np.vstack(
                        (patch_L_wrap[-expand:, :, :],
                         patch_L_wrap[:patch_size + expand, :, :]))
                else:
                    patch_L_wrap = patch_L

                x = util.single2tensor4(util.uint2single(patch_L_wrap))

                # Local kernels, stacked with the first grid index varying
                # fastest.
                k = torch.cat([util.single2tensor4(PSF_patch[h_, w_])
                               for w_ in range(num_patch)
                               for h_ in range(num_patch)], dim=0)
                x, k = x.to(device), k.to(device)

                cd = F.softplus(ab[px_start:px_start + num_patch,
                                   py_start:py_start + num_patch])
                # NOTE(review): view() flattens with the SECOND grid index
                # varying fastest, while k above varies the FIRST index
                # fastest. If forward_patchwise pairs k[i] with cd[i], the
                # off-diagonal cells get swapped weights; confirm against
                # the training script before changing.
                cd = cd.view(num_patch ** 2, 2 * 8, 3)

                x_E = model.forward_patchwise(x, k, cd,
                                              [num_patch, num_patch],
                                              [p_size, p_size])
                patch_E_all = [util.tensor2uint(pp) for pp in x_E]

                # Only the last output (patch_E_all[-1]) is displayed, so only
                # it is accumulated. Crop away the `expand` padding; the
                # explicit end index keeps this valid even when expand == 0.
                crop = slice(expand, expand + patch_size)
                img_E_denoise[rows, cols, :] += patch_E_all[-1][crop, crop]
                weight_E[rows, cols, :] += 1.0

        t1 = time.time()
        print(t1 - t0)

        # Pixels that no block covered have weight 0; clamp so they stay 0
        # instead of becoming 0/0 NaNs.
        xk = (img_E_denoise / np.maximum(weight_E, 1.0)).astype(np.uint8)
        # To save instead of display:
        # cv2.imwrite('/home/xiu/workspace/UABC/ICCV2021/video_deblur/{:08d}.png'.format(img_id), xk)
        cv2.imshow('xx', xk)
        cv2.imshow('img_L', img_L.astype(np.uint8))
        cv2.waitKey(-1)
Exemple #25
0
def main():
    """Evaluate a pretrained FFDNet denoiser on a test set with synthetic AWGN.

    Each image under ``testsets/<testset_name>`` is read as float in [0, 1].
    It is corrupted with Gaussian noise of std ``noise_level_img / 255``
    (seed 0), optionally clipped, and denoised by the model conditioned on
    ``noise_level_model / 255``. The result is saved under
    ``results/<result_name>`` and scored with PSNR/SSIM against the clean
    image. Per-image scores and the averages are written to the log.
    """
    # ----------------------------------------
    # Preparation
    # ----------------------------------------

    noise_level_img = 30  # noise level for noisy image
    noise_level_model = noise_level_img  # noise level for model
    model_name = 'ffdnet_color'  # 'ffdnet_gray' | 'ffdnet_color' | 'ffdnet_color_clip' | 'ffdnet_gray_clip'
    testset_name = 'CBSD68'  # test set,  'bsd68' | 'cbsd68' | 'set12'
    need_degradation = True  # default: True
    show_img = False  # default: False

    task_current = 'dn'  # 'dn' for denoising | 'sr' for super-resolution
    sf = 1  # unused for denoising
    if 'color' in model_name:
        n_channels, nc, nb = 3, 96, 12  # settings for color images
    else:
        n_channels, nc, nb = 1, 64, 15  # settings for grayscale images
    use_clip = 'clip' in model_name  # clip the noisy intensities into [0, 1]
    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    result_name = testset_name + '_' + model_name
    border = sf if task_current == 'sr' else 0  # pixels shaved before PSNR/SSIM
    model_path = os.path.join(model_pool, model_name + '.pth')

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------

    L_path = os.path.join(testsets, testset_name)  # Low-quality (input) images
    H_path = L_path  # High-quality images; same folder -> noise is synthesized
    E_path = os.path.join(results, result_name)  # Estimated (output) images
    util.mkdir(E_path)

    if H_path == L_path:
        need_degradation = True
    logger_name = result_name
    utils_logger.logger_info(logger_name,
                             log_path=os.path.join(E_path,
                                                   logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    need_H = H_path is not None
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------

    from models.network_ffdnet import FFDNet as net
    model = net(in_nc=n_channels,
                out_nc=n_channels,
                nc=nc,
                nb=nb,
                act_mode='R')
    # map_location lets GPU-saved weights load on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location=device),
                          strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))

    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []

    # Arguments follow the labels: model sigma first, then image sigma.
    logger.info('model_name:{}, model sigma:{}, image sigma:{}'.format(
        model_name, noise_level_model, noise_level_img))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)
    H_paths = util.get_image_paths(H_path) if need_H else None

    for idx, img in enumerate(L_paths):

        # ------------------------------------
        # (1) img_L
        # ------------------------------------

        img_name, ext = os.path.splitext(os.path.basename(img))
        img_L = util.uint2single(util.imread_uint(img, n_channels=n_channels))

        if need_degradation:  # degradation process
            np.random.seed(seed=0)  # same noise realization on every run
            img_L += np.random.normal(0, noise_level_img / 255., img_L.shape)
            if use_clip:
                # Round-trip through uint8 to clip into [0, 1].
                img_L = util.uint2single(util.single2uint(img_L))

        if show_img:
            util.imshow(util.single2uint(img_L),
                        title='Noisy image with noise level {}'.format(
                            noise_level_img))

        img_L = util.single2tensor4(img_L).to(device)
        # FFDNet receives the noise level as a separate (1, 1, 1, 1) input.
        sigma = torch.full((1, 1, 1, 1),
                           noise_level_model / 255.).type_as(img_L)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------

        img_E = util.tensor2uint(model(img_L, sigma))

        if need_H:

            # --------------------------------
            # (3) img_H, PSNR and SSIM
            # --------------------------------
            img_H = util.imread_uint(H_paths[idx], n_channels=n_channels)
            img_H = img_H.squeeze()

            psnr = util.calculate_psnr(img_E, img_H, border=border)
            ssim = util.calculate_ssim(img_E, img_H, border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(
                img_name + ext, psnr, ssim))
            if show_img:
                util.imshow(np.concatenate([img_E, img_H], axis=1),
                            title='Recovered / Ground-truth')

        # ------------------------------------
        # save results
        # ------------------------------------

        util.imsave(img_E, os.path.join(E_path, img_name + ext))

    # Guard against an empty test set (would otherwise divide by zero).
    if need_H and test_results['psnr']:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        logger.info(
            'Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.
            format(result_name, ave_psnr, ave_ssim))
Exemple #26
0
def main():
    """Evaluate the pretrained UABC SR model on random DIV2K training pairs.

    For ``N_maxiter`` randomly drawn high-resolution images:
    - build (LR patch, HR patch, local PSFs) with ``draw_training_pair`` using
      the AC254 lens PSF grid;
    - restore with ``model.forward_patchwise_SR`` using the same pretrained
      per-patch weight table for every patch;
    - record the PSNR of the estimate.

    Every 250th comparison strip (HR | upscaled LR | estimate), plus the last
    one, is written to ``./result/finetune`` together with ``psnr.txt``.

    Raises:
        FileNotFoundError: if no training images are found or one cannot be
            read.
    """
    # 0. global config
    sf = 4          # scale factor
    stage = 5       # HQS iterations; also sizes the (2 * stage, 3) weight table
    patch_size = [32, 32]
    patch_num = [2, 2]
    N_maxiter = 1000
    out_dir = os.path.join('./result', 'finetune')

    # 1. local PSF
    # shape: gx,gy,kw,kw,3
    all_PSFs = load_kernels('./data')

    # 2. local model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=stage,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                sf=sf,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    # Alternative weights: './data/uabcnet_finetune.pth'
    # map_location lets GPU-saved weights load on a CPU-only host.
    loaded_state_dict = torch.load('./logs/uabcnet_final.pth',
                                   map_location=device)
    model.load_state_dict(loaded_state_dict, strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # Positional lambda, mu for HQS: one pretrained table copied to every patch.
    ab_pretrain = np.loadtxt('./logs/ab_pretrain.txt').reshape(
        (1, 1, 2 * stage, 3)).astype(np.float32)
    ab_buffer = np.tile(ab_pretrain, (patch_num[0], patch_num[1], 1, 1))
    ab = torch.tensor(ab_buffer, device=device, requires_grad=False)

    # The weights are constant, so apply softplus and flatten them once. The
    # flattening order (first index fastest) matches the kernels built below.
    ab_patch = F.softplus(ab)
    ab_patch_v = torch.cat([ab_patch[w_:w_ + 1, h_]
                            for h_ in range(patch_num[1])
                            for w_ in range(patch_num[0])], dim=0)

    # 3. load training data
    imgs_H = sorted(glob.glob('./DIV2K_train/*.png', recursive=True))
    if not imgs_H:
        raise FileNotFoundError('No .png images found in ./DIV2K_train')

    PSF_grid = using_AC254_lens(all_PSFs, patch_num)

    # cv2.imwrite returns False (no exception) when the folder is missing.
    os.makedirs(out_dir, exist_ok=True)

    all_PSNR = []
    show = None

    for i in range(N_maxiter):

        # draw random image.
        img_path = imgs_H[np.random.randint(len(imgs_H))]
        img_H = cv2.imread(img_path)
        if img_H is None:
            raise FileNotFoundError('Cannot read image: {}'.format(img_path))

        patch_L, patch_H, patch_psf = draw_training_pair(
            img_H, PSF_grid, sf, patch_num, patch_size)

        x = util.single2tensor4(util.uint2single(patch_L))

        k = torch.cat([util.single2tensor4(patch_psf[w_, h_])
                       for h_ in range(patch_num[1])
                       for w_ in range(patch_num[0])], dim=0)
        x, k = x.to(device), k.to(device)

        x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                         [patch_size[0], patch_size[1]], sf)

        # Nearest-neighbour upscaling so the LR input lines up with HR/estimate.
        patch_L = cv2.resize(patch_L,
                             dsize=None,
                             fx=sf,
                             fy=sf,
                             interpolation=cv2.INTER_NEAREST)
        patch_E = util.tensor2uint(x_E)

        all_PSNR.append(cv2.PSNR(patch_E, patch_H))

        show = np.hstack((patch_H, patch_L, patch_E))
        if i % 250 == 0:
            cv2.imwrite(
                os.path.join(out_dir, 'result-{:04d}.png'.format(i + 1)),
                show)

    # Last comparison strip; `show` is None only if N_maxiter == 0.
    if show is not None:
        cv2.imwrite(
            os.path.join(out_dir, 'result-{:04d}.png'.format(N_maxiter)),
            show)
    np.savetxt(os.path.join(out_dir, 'psnr.txt'), all_PSNR)
Exemple #27
0
def main(json_path='options/train_sr.json'):
    """Train the stage-1 super-resolution model (``G1``) from a JSON option file.

    Steps:
    1. Parse options (``-opt`` on the command line overrides *json_path*).
    2. Resume from the latest ``G1`` checkpoint, then set up logging and seeds.
    3. Build the train/val data loaders.
    4. Train ``model_1``: save every ``checkpoint_save`` steps, and every
       ``checkpoint_test`` steps log the losses plus the average validation
       PSNR (when a 'val' dataset is configured).

    Raises:
        NotImplementedError: for a dataset phase other than 'train' or 'val'.
        ValueError: if the options define no 'train' dataset.
    """
    # ----------------------------------------
    # Step--1 (prepare opt)
    # ----------------------------------------

    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        default=json_path,
                        help='Path to option JSON file.')

    opt = option.parse(parser.parse_args().opt, is_train=True)
    util.mkdirs(
        (path for key, path in opt['path'].items() if 'pretrained' not in key))

    # Resume: latest G1 checkpoint and the iteration it was saved at.
    init_iter, init_path_G = option.find_last_checkpoint(opt['path']['models'],
                                                         net_type='G1')
    opt['path']['pretrained_netG1'] = init_path_G
    current_step = init_iter

    border = opt['scale']  # pixels shaved before computing PSNR

    # save opt to a '../option.json' file
    option.save(opt)

    # return None for missing keys
    opt = option.dict_to_nonedict(opt)

    # ----------------------------------------
    # configure logger
    # ----------------------------------------
    logger_name = 'train'
    utils_logger.logger_info(
        logger_name, os.path.join(opt['path']['log'], logger_name + '.log'))
    logger = logging.getLogger(logger_name)
    logger.info(option.dict2str(opt))

    # ----------------------------------------
    # seed
    # ----------------------------------------
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # ----------------------------------------
    # Step--2 (create dataloaders for train and val)
    # ----------------------------------------
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = define_Dataset(phase, dataset_opt)
            train_size = int(
                math.ceil(
                    len(train_set) / dataset_opt['dataloader_batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            train_loader = DataLoader(
                train_set,
                batch_size=dataset_opt['dataloader_batch_size'],
                shuffle=dataset_opt['dataloader_shuffle'],
                num_workers=dataset_opt['dataloader_num_workers'],
                drop_last=True,
                pin_memory=True)
        elif phase == 'val':
            val_set = define_Dataset(phase, dataset_opt)
            val_loader = DataLoader(val_set,
                                    batch_size=1,
                                    shuffle=False,
                                    num_workers=1,
                                    drop_last=False,
                                    pin_memory=True)
        else:
            raise NotImplementedError("Phase [%s] is not recognized." % phase)

    if train_loader is None:
        raise ValueError("No 'train' dataset is defined in the options.")

    # ----------------------------------------
    # Step--3 (model_1)
    # ----------------------------------------
    model_1 = define_Model(opt, stage1=True)
    model_1.init_train()

    for epoch in range(100000):
        for train_data in train_loader:

            current_step += 1

            model_1.update_learning_rate(current_step)
            model_1.feed_data(train_data)
            model_1.optimize_parameters(current_step)

            if current_step % opt['train']['checkpoint_save'] == 0:
                model_1.save(current_step)

            # -------------------------------
            # model_1 testing
            # -------------------------------
            if current_step % opt['train']['checkpoint_test'] == 0:
                # training info, e.g. losses
                logs = model_1.current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model_1.current_learning_rate())
                for k, v in logs.items():
                    message += '\t{:s}: {:.3e}'.format(k, v)

                if val_loader is not None:
                    psnr_sum = 0.0
                    idx = 0
                    for val_data in val_loader:
                        idx += 1

                        model_1.feed_data(val_data)
                        model_1.test()

                        visuals = model_1.current_visuals()
                        E_img = util.tensor2uint(visuals['E'])
                        H_img = util.tensor2uint(visuals['H'])
                        psnr_sum += util.calculate_psnr(E_img,
                                                        H_img,
                                                        border=border)

                    # An empty val set reports nan instead of dividing by 0.
                    avg_psnr = psnr_sum / idx if idx else float('nan')
                    message += '\tStage SR Val_PSNR_avg: {:<.2f}dB'.format(
                        avg_psnr)

                logger.info(message)

    logger.info('End of Stage SR training.')
Exemple #28
0
def main():
    """Evaluate a pretrained DRUNet denoiser on a test set with synthetic AWGN.

    Each clean image under ``testsets/<testset_name>`` receives unclipped
    Gaussian noise of std ``noise_level_img / 255`` (seed 0). A constant
    noise-level map (``noise_level_model / 255``) is concatenated as an extra
    input channel before running the model. The estimate is saved under
    ``results/<result_name>`` and scored with PSNR/SSIM against the clean
    image.
    """
    # ----------------------------------------
    # Preparation
    # ----------------------------------------

    noise_level_img = 15  # set AWGN noise level for noisy image
    noise_level_model = noise_level_img  # set noise level for model
    model_name = 'drunet_gray'  # set denoiser model, 'drunet_gray' | 'drunet_color'
    testset_name = 'bsd68'  # set test set,  'bsd68' | 'cbsd68' | 'set12'
    x8 = False  # default: False, x8 to boost performance
    show_img = False  # default: False
    border = 0  # pixels shaved before PSNR/SSIM

    n_channels = 3 if 'color' in model_name else 1  # color vs grayscale

    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    task_current = 'dn'  # 'dn' for denoising
    result_name = testset_name + '_' + task_current + '_' + model_name

    model_path = os.path.join(model_pool, model_name + '.pth')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache()

    # ----------------------------------------
    # L_path, E_path
    # ----------------------------------------

    L_path = os.path.join(testsets, testset_name)  # Low-quality (input) images
    E_path = os.path.join(results, result_name)  # Estimated (output) images
    util.mkdir(E_path)

    logger_name = result_name
    utils_logger.logger_info(logger_name,
                             log_path=os.path.join(E_path,
                                                   logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    # ----------------------------------------
    # load model
    # ----------------------------------------

    from models.network_unet import UNetRes as net
    model = net(in_nc=n_channels + 1,  # +1 for the noise-level map
                out_nc=n_channels,
                nc=[64, 128, 256, 512],
                nb=4,
                act_mode='R',
                downsample_mode="strideconv",
                upsample_mode="convtranspose")
    # map_location lets GPU-saved weights load on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location=device),
                          strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))
    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))

    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []

    # Arguments follow the labels: model sigma first, then image sigma.
    logger.info('model_name:{}, model sigma:{}, image sigma:{}'.format(
        model_name, noise_level_model, noise_level_img))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)

    for img in L_paths:

        # ------------------------------------
        # (1) img_L
        # ------------------------------------

        img_name, ext = os.path.splitext(os.path.basename(img))
        img_H = util.imread_uint(img, n_channels=n_channels)
        img_L = util.uint2single(img_H)

        # Add noise without clipping
        np.random.seed(seed=0)  # same noise realization on every run
        img_L += np.random.normal(0, noise_level_img / 255., img_L.shape)

        if show_img:
            util.imshow(util.single2uint(img_L),
                        title='Noisy image with noise level {}'.format(
                            noise_level_img))

        img_L = util.single2tensor4(img_L)
        noise_level_map = torch.full((1, 1, img_L.size(2), img_L.size(3)),
                                     noise_level_model / 255.).type_as(img_L)
        img_L = torch.cat((img_L, noise_level_map), dim=1).to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------
        # A direct forward pass needs H and W divisible by 8 (presumably the
        # three 2x downsamplings across the four `nc` levels). The previous
        # test used `// 8 == 0`, which is only true for sizes below 8, so the
        # direct path was never taken for real images.
        if x8:
            img_E = utils_model.test_mode(model, img_L, mode=3)
        elif img_L.size(2) % 8 == 0 and img_L.size(3) % 8 == 0:
            img_E = model(img_L)
        else:
            img_E = utils_model.test_mode(model, img_L, refield=64, mode=5)

        img_E = util.tensor2uint(img_E)

        # --------------------------------
        # PSNR and SSIM
        # --------------------------------

        if n_channels == 1:
            img_H = img_H.squeeze()  # match the squeezed grayscale estimate
        psnr = util.calculate_psnr(img_E, img_H, border=border)
        ssim = util.calculate_ssim(img_E, img_H, border=border)
        test_results['psnr'].append(psnr)
        test_results['ssim'].append(ssim)
        logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(
            img_name + ext, psnr, ssim))

        # ------------------------------------
        # save results
        # ------------------------------------

        util.imsave(img_E, os.path.join(E_path, img_name + ext))

    # Guard against an empty test set (would otherwise divide by zero).
    if test_results['psnr']:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        logger.info(
            'Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(
                result_name, ave_psnr, ave_ssim))
Exemple #29
0
def main():
    """Evaluate a pretrained DPSR super-resolution model on a test set.

    The scale factor ``sf`` is parsed from the first number in ``model_name``.
    Each image under ``testsets/<testset_name>`` is mod-cropped to a multiple
    of ``sf``, bicubically downscaled by ``1/sf`` with
    ``util.imresize_np``, and given Gaussian noise of std
    ``noise_level_img / 255`` (seed 0). It is then super-resolved with a
    constant noise-level map appended as an extra channel. Results go to
    ``results/<result_name>`` as PNG. PSNR/SSIM are reported on RGB and, for
    3-D images, on the Y channel.
    """
    # ----------------------------------------
    # Preparation
    # ----------------------------------------

    noise_level_img = 0  # default: 0, noise level for LR image
    noise_level_model = noise_level_img  # noise level for model
    model_name = 'dpsr_x4_gan'  # 'dpsr_x2' | 'dpsr_x3' | 'dpsr_x4' | 'dpsr_x4_gan'
    testset_name = 'set5'  # test set,  'set5' | 'srbsd68'
    need_degradation = True  # default: True
    x8 = False  # default: False, x8 to boost performance
    sf = int(re.findall(r'\d+', model_name)[0])  # scale factor from the name
    show_img = False  # default: False

    task_current = 'sr'  # 'dn' for denoising | 'sr' for super-resolution
    n_channels = 3  # fixed
    nc = 96  # fixed, number of channels
    nb = 16  # fixed, number of conv layers
    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    result_name = testset_name + '_' + model_name
    border = sf if task_current == 'sr' else 0  # pixels shaved before PSNR/SSIM
    model_path = os.path.join(model_pool, model_name + '.pth')

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------

    L_path = os.path.join(testsets, testset_name)  # Low-quality (input) images
    H_path = L_path  # High-quality images; same folder -> LR is synthesized
    E_path = os.path.join(results, result_name)  # Estimated (output) images
    util.mkdir(E_path)

    if H_path == L_path:
        need_degradation = True
    logger_name = result_name
    utils_logger.logger_info(logger_name,
                             log_path=os.path.join(E_path,
                                                   logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    need_H = H_path is not None
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------

    from models.network_dpsr import MSRResNet_prior as net
    model = net(in_nc=n_channels + 1,  # +1 for the noise-level map
                out_nc=n_channels,
                nc=nc,
                nb=nb,
                upscale=sf,
                act_mode='R',
                upsample_mode='pixelshuffle')
    # strict=False is kept as in the original script (presumably tolerating
    # checkpoint key differences). map_location allows CPU-only hosts.
    model.load_state_dict(torch.load(model_path, map_location=device),
                          strict=False)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))
    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))

    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []

    # Arguments follow the labels: model sigma first, then image sigma.
    logger.info('model_name:{}, model sigma:{}, image sigma:{}'.format(
        model_name, noise_level_model, noise_level_img))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)
    H_paths = util.get_image_paths(H_path) if need_H else None

    for idx, img in enumerate(L_paths):

        # ------------------------------------
        # (1) img_L
        # ------------------------------------

        img_name, ext = os.path.splitext(os.path.basename(img))
        img_L = util.uint2single(util.imread_uint(img, n_channels=n_channels))

        # degradation process, bicubic downsampling + Gaussian noise
        if need_degradation:
            img_L = util.modcrop(img_L, sf)
            img_L = util.imresize_np(img_L, 1 / sf)
            np.random.seed(seed=0)  # same noise realization on every run
            img_L += np.random.normal(0, noise_level_img / 255., img_L.shape)

        if show_img:
            util.imshow(util.single2uint(img_L),
                        title='LR image with noise level {}'.format(
                            noise_level_img))

        img_L = util.single2tensor4(img_L)
        noise_level_map = torch.full((1, 1, img_L.size(2), img_L.size(3)),
                                     noise_level_model / 255.).type_as(img_L)
        img_L = torch.cat((img_L, noise_level_map), dim=1).to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------

        if x8:
            img_E = utils_model.test_mode(model, img_L, mode=3, sf=sf)
        else:
            img_E = model(img_L)
        img_E = util.tensor2uint(img_E)

        if need_H:

            # --------------------------------
            # (3) img_H, PSNR and SSIM
            # --------------------------------

            img_H = util.imread_uint(H_paths[idx], n_channels=n_channels)
            img_H = util.modcrop(img_H.squeeze(), sf)

            psnr = util.calculate_psnr(img_E, img_H, border=border)
            ssim = util.calculate_ssim(img_E, img_H, border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(
                img_name + ext, psnr, ssim))
            if show_img:
                util.imshow(np.concatenate([img_E, img_H], axis=1),
                            title='Recovered / Ground-truth')

            if np.ndim(img_H) == 3:  # RGB image: also score the Y channel
                img_E_y = util.rgb2ycbcr(img_E, only_y=True)
                img_H_y = util.rgb2ycbcr(img_H, only_y=True)
                psnr_y = util.calculate_psnr(img_E_y, img_H_y, border=border)
                ssim_y = util.calculate_ssim(img_E_y, img_H_y, border=border)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)

        # ------------------------------------
        # save results
        # ------------------------------------

        util.imsave(img_E, os.path.join(E_path, img_name + '.png'))

    # Guard against an empty test set: averages would divide by zero and
    # img_H would be unbound.
    if need_H and test_results['psnr']:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        logger.info(
            'Average PSNR/SSIM(RGB) - {} - x{} --PSNR: {:.2f} dB; SSIM: {:.4f}'
            .format(result_name, sf, ave_psnr, ave_ssim))
        if test_results['psnr_y']:  # Y scores exist only for RGB images
            ave_psnr_y = sum(test_results['psnr_y']) / len(
                test_results['psnr_y'])
            ave_ssim_y = sum(test_results['ssim_y']) / len(
                test_results['ssim_y'])
            logger.info(
                'Average PSNR/SSIM( Y ) - {} - x{} - PSNR: {:.2f} dB; SSIM: {:.4f}'
                .format(result_name, sf, ave_psnr_y, ave_ssim_y))
Exemple #30
0
def main():
    """Evaluate plug-and-play deblurring on a test set for every Levin09 kernel.

    For each blur kernel stored in ``kernels/Levin09.mat`` and each image in
    ``testsets/<testset_name>``, the ground-truth image is blurred by circular
    convolution (``mode='wrap'``) and corrupted with additive Gaussian noise.
    It is then restored by ``iter_num`` iterations that alternate a
    closed-form data step (``sr.data_solution``) with a learned denoiser
    (DRUNet or IRCNN, chosen by ``model_name``).

    Per-image PSNR and the per-kernel average PSNR are logged. The restored
    image (and optionally the degraded image and an LR / E / H side-by-side)
    is written to ``results/<result_name>``. All settings are hard-coded in
    the "Preparation" section below.
    """

    # ----------------------------------------
    # Preparation
    # ----------------------------------------

    noise_level_img = 7.65 / 255.0  # default: 0, noise level for LR image
    noise_level_model = noise_level_img  # noise level of model, default 0
    model_name = 'drunet_gray'  # 'drunet_gray' | 'drunet_color' | 'ircnn_gray' | 'ircnn_color'
    testset_name = 'Set3C'  # test set,  'set5' | 'srbsd68'
    x8 = True  # default: False, x8 to boost performance
    iter_num = 8  # number of iterations
    # Both sigmas go to pnp.get_rho_sigma; presumably the start / end of the
    # denoiser sigma schedule on the 0-255 scale (TODO confirm in pnp).
    modelSigma1 = 49
    modelSigma2 = noise_level_model * 255.

    show_img = False  # default: False
    save_L = True  # save LR image
    save_E = True  # save estimated image
    save_LEH = False  # save zoomed LR, E and H images
    border = 0  # forwarded to util.calculate_psnr as its `border` argument

    # --------------------------------
    # load kernel
    # --------------------------------

    kernels = hdf5storage.loadmat(os.path.join('kernels',
                                               'Levin09.mat'))['kernels']

    sf = 1  # scale factor; 1 because this task only deblurs (no upsampling)
    task_current = 'deblur'  # 'deblur' for deblurring
    n_channels = 3 if 'color' in model_name else 1  # fixed
    model_zoo = 'model_zoo'  # fixed
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    result_name = testset_name + '_' + task_current + '_' + model_name
    model_path = os.path.join(model_zoo, model_name + '.pth')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache()

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------

    L_path = os.path.join(testsets,
                          testset_name)  # L_path, for Low-quality images
    E_path = os.path.join(results, result_name)  # E_path, for Estimated images
    util.mkdir(E_path)

    logger_name = result_name
    utils_logger.logger_info(logger_name,
                             log_path=os.path.join(E_path,
                                                   logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    # ----------------------------------------
    # load model
    # ----------------------------------------

    if 'drunet' in model_name:
        from models.network_unet import UNetRes as net
        model = net(in_nc=n_channels + 1,  # +1: the noise-level map channel
                    out_nc=n_channels,
                    nc=[64, 128, 256, 512],
                    nb=4,
                    act_mode='R',
                    downsample_mode="strideconv",
                    upsample_mode="convtranspose")
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
        # hosts; the model is moved to `device` right below.
        model.load_state_dict(torch.load(model_path, map_location='cpu'),
                              strict=True)
        model.eval()
        for _, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
    elif 'ircnn' in model_name:
        from models.network_dncnn import IRCNN as net
        model = net(in_nc=n_channels, out_nc=n_channels, nc=64)
        # One state dict per noise level, keyed by str(index) (see the
        # per-iteration reload below).
        model25 = torch.load(model_path, map_location='cpu')
        # None (not 0) so the first iteration always loads weights, even when
        # its noise-level index happens to be 0.
        former_idx = None

    logger.info('model_name:{}, image sigma:{:.3f}, model sigma:{:.3f}'.format(
        model_name, noise_level_img, noise_level_model))
    logger.info('Model path: {:s}'.format(model_path))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)

    test_results_ave = OrderedDict()
    test_results_ave['psnr'] = []  # record average PSNR for each kernel

    for k_index in range(kernels.shape[1]):

        logger.info('-------k:{:>2d} ---------'.format(k_index))
        test_results = OrderedDict()
        test_results['psnr'] = []
        k = kernels[0, k_index].astype(np.float64)
        if show_img:
            util.imshow(k)

        for idx, img in enumerate(L_paths):

            # --------------------------------
            # (1) get img_L
            # --------------------------------

            img_name, ext = os.path.splitext(os.path.basename(img))
            img_H = util.imread_uint(img, n_channels=n_channels)
            img_H = util.modcrop(img_H, 8)  # modcrop

            # Circular blur: the kernel gets a singleton channel axis so the
            # same 2-D kernel is applied to every channel independently.
            img_L = ndimage.convolve(img_H,
                                     np.expand_dims(k, axis=2),
                                     mode='wrap')
            if show_img:
                util.imshow(img_L)
            img_L = util.uint2single(img_L)

            np.random.seed(seed=0)  # for reproducibility
            img_L += np.random.normal(0, noise_level_img,
                                      img_L.shape)  # add AWGN

            # --------------------------------
            # (2) get rhos and sigmas
            # --------------------------------

            rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255 / 255.,
                                                       noise_level_model),
                                             iter_num=iter_num,
                                             modelSigma1=modelSigma1,
                                             modelSigma2=modelSigma2,
                                             w=1.0)
            rhos, sigmas = torch.tensor(rhos).to(device), torch.tensor(
                sigmas).to(device)

            # --------------------------------
            # (3) initialize x, and pre-calculation
            # --------------------------------

            x = util.single2tensor4(img_L).to(device)

            img_L_tensor, k_tensor = util.single2tensor4(
                img_L), util.single2tensor4(np.expand_dims(k, 2))
            [k_tensor, img_L_tensor] = util.todevice([k_tensor, img_L_tensor],
                                                     device)
            # Quantities that depend only on the observation and kernel are
            # computed once and reused by every data step below.
            FB, FBC, F2B, FBFy = sr.pre_calculate(img_L_tensor, k_tensor, sf)

            # --------------------------------
            # (4) main iterations
            # --------------------------------

            for i in range(iter_num):

                # --------------------------------
                # step 1, FFT
                # --------------------------------

                tau = rhos[i].float().repeat(1, 1, 1, 1)
                x = sr.data_solution(x, FB, FBC, F2B, FBFy, tau, sf)

                if 'ircnn' in model_name:
                    # Map the current sigma to an IRCNN checkpoint index.
                    # Built-in int() replaces np.int, removed in NumPy 1.24.
                    current_idx = int(
                        np.ceil(sigmas[i].cpu().numpy() * 255. / 2.) - 1)

                    if current_idx != former_idx:
                        model.load_state_dict(model25[str(current_idx)],
                                              strict=True)
                        model.eval()
                        for _, v in model.named_parameters():
                            v.requires_grad = False
                        model = model.to(device)
                    former_idx = current_idx

                # --------------------------------
                # step 2, denoiser
                # --------------------------------

                # x8: cycle through the 8 augmentation modes across
                # iterations (presumably flips/rotations), undoing each one
                # after denoising.
                if x8:
                    x = util.augment_img_tensor4(x, i % 8)

                if 'drunet' in model_name:
                    # DRUNet takes the noise level as an extra constant
                    # input channel.
                    x = torch.cat((x, sigmas[i].float().repeat(
                        1, 1, x.shape[2], x.shape[3])),
                                  dim=1)
                    x = utils_model.test_mode(model,
                                              x,
                                              mode=2,
                                              refield=32,
                                              min_size=256,
                                              modulo=16)
                elif 'ircnn' in model_name:
                    x = model(x)

                if x8:
                    # Modes 3 and 5 invert each other; every other mode is
                    # its own inverse.
                    if i % 8 == 3 or i % 8 == 5:
                        x = util.augment_img_tensor4(x, 8 - i % 8)
                    else:
                        x = util.augment_img_tensor4(x, i % 8)

            # --------------------------------
            # (5) img_E
            # --------------------------------

            img_E = util.tensor2uint(x)
            if n_channels == 1:
                img_H = img_H.squeeze()

            if save_E:
                util.imsave(
                    img_E,
                    os.path.join(
                        E_path, img_name + '_k' + str(k_index) + '_' +
                        model_name + '.png'))

            # --------------------------------
            # (6) img_LEH
            # --------------------------------

            if save_LEH:
                # Separate uint8 copy: img_L must stay float for save_L below.
                img_L_uint = util.single2uint(img_L)
                k_v = util.single2uint(k / np.max(k))
                if n_channels == 1:
                    # Match the 2-D img_E / img_H (cv2.resize would drop the
                    # singleton channel anyway).
                    img_L_uint = np.squeeze(img_L_uint, axis=2)
                else:
                    k_v = np.tile(k_v[..., np.newaxis], [1, 1, 3])
                k_v = cv2.resize(k_v, (3 * k_v.shape[1], 3 * k_v.shape[0]),
                                 interpolation=cv2.INTER_NEAREST)
                img_I = cv2.resize(
                    img_L_uint,
                    (sf * img_L_uint.shape[1], sf * img_L_uint.shape[0]),
                    interpolation=cv2.INTER_NEAREST)
                img_I[:img_L_uint.shape[0], :img_L_uint.shape[1]] = img_L_uint
                # Draw the kernel inset last so it stays visible even when
                # sf == 1 and the LR image covers the whole canvas.
                img_I[:k_v.shape[0], -k_v.shape[1]:] = k_v
                if show_img:
                    util.imshow(np.concatenate([img_I, img_E, img_H], axis=1),
                                title='LR / Recovered / Ground-truth')
                util.imsave(
                    np.concatenate([img_I, img_E, img_H], axis=1),
                    os.path.join(E_path,
                                 img_name + '_k' + str(k_index) + '_LEH.png'))

            if save_L:
                util.imsave(
                    util.single2uint(img_L),
                    os.path.join(E_path,
                                 img_name + '_k' + str(k_index) + '_LR.png'))

            psnr = util.calculate_psnr(
                img_E, img_H, border=border)  # change with your own border
            test_results['psnr'].append(psnr)
            logger.info('{:->4d}--> {:>10s} --k:{:>2d} PSNR: {:.2f}dB'.format(
                idx + 1, img_name + ext, k_index, psnr))

        # --------------------------------
        # Average PSNR
        # --------------------------------

        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        logger.info(
            '------> Average PSNR of ({}), kernel: ({}) sigma: ({:.2f}): {:.2f} dB'
            .format(testset_name, k_index, noise_level_model, ave_psnr))
        test_results_ave['psnr'].append(ave_psnr)