Example #1
def define_G(opt):
    opt_net = opt['netG']
    net_type = opt_net['net_type']


    # ----------------------------------------
    # denoising task
    # ----------------------------------------

    # ----------------------------------------
    # DnCNN
    # ----------------------------------------
    if net_type == 'dncnn':
        from models.network_dncnn import DnCNN as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],  # total number of conv layers
                   act_mode=opt_net['act_mode'])

    # ----------------------------------------
    # Flexible DnCNN
    # ----------------------------------------
    elif net_type == 'fdncnn':
        from models.network_dncnn import FDnCNN as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],  # total number of conv layers
                   act_mode=opt_net['act_mode'])

    # ----------------------------------------
    # FFDNet
    # ----------------------------------------
    elif net_type == 'ffdnet':
        from models.network_ffdnet import FFDNet as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   act_mode=opt_net['act_mode'])

    # ----------------------------------------
    # others
    # ----------------------------------------

    # ----------------------------------------
    # super-resolution task
    # ----------------------------------------

    # ----------------------------------------
    # SRMD
    # ----------------------------------------
    elif net_type == 'srmd':
        from models.network_srmd import SRMD as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   upscale=opt_net['scale'],
                   act_mode=opt_net['act_mode'],
                   upsample_mode=opt_net['upsample_mode'])

    # ----------------------------------------
    # super-resolver prior of DPSR
    # ----------------------------------------
    elif net_type == 'dpsr':
        from models.network_dpsr import MSRResNet_prior as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   upscale=opt_net['scale'],
                   act_mode=opt_net['act_mode'],
                   upsample_mode=opt_net['upsample_mode'])

    # ----------------------------------------
    # modified SRResNet v0.0
    # ----------------------------------------
    elif net_type == 'msrresnet0':
        from models.network_msrresnet import MSRResNet0 as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   upscale=opt_net['scale'],
                   act_mode=opt_net['act_mode'],
                   upsample_mode=opt_net['upsample_mode'])

    # ----------------------------------------
    # modified SRResNet v0.1
    # ----------------------------------------
    elif net_type == 'msrresnet1':
        from models.network_msrresnet import MSRResNet1 as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   upscale=opt_net['scale'],
                   act_mode=opt_net['act_mode'],
                   upsample_mode=opt_net['upsample_mode'])

    # ----------------------------------------
    # RRDB
    # ----------------------------------------
    elif net_type == 'rrdb':  # RRDB
        from models.network_rrdb import RRDB as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   gc=opt_net['gc'],
                   upscale=opt_net['scale'],
                   act_mode=opt_net['act_mode'],
                   upsample_mode=opt_net['upsample_mode'])

    # ----------------------------------------
    # IMDN
    # ----------------------------------------
    elif net_type == 'imdn':  # IMDN
        from models.network_imdn import IMDN as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   upscale=opt_net['scale'],
                   act_mode=opt_net['act_mode'],
                   upsample_mode=opt_net['upsample_mode'])

    # ----------------------------------------
    # USRNet
    # ----------------------------------------
    elif net_type == 'usrnet':  # USRNet
        from models.network_usrnet import USRNet as net
        netG = net(n_iter=opt_net['n_iter'],
                   h_nc=opt_net['h_nc'],
                   in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   act_mode=opt_net['act_mode'],
                   downsample_mode=opt_net['downsample_mode'],
                   upsample_mode=opt_net['upsample_mode']
                   )

    # ----------------------------------------
    # others
    # ----------------------------------------
    # TODO

    else:
        raise NotImplementedError('netG [{:s}] is not found.'.format(net_type))

    # ----------------------------------------
    # initialize weights
    # ----------------------------------------
    if opt['is_train']:
        init_weights(netG,
                     init_type=opt_net['init_type'],
                     init_bn_type=opt_net['init_bn_type'],
                     gain=opt_net['init_gain'])

    return netG
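

# --- Usage sketch (not part of the original example) -------------------------
# A minimal, hypothetical 'opt' dictionary covering the keys define_G() reads
# above. The values are illustrative assumptions rather than KAIR defaults,
# and init_weights() is expected to come from the same KAIR module as
# define_G().
def _example_define_G():
    opt_example = {
        'is_train': True,
        'netG': {
            'net_type': 'dncnn',   # selects the DnCNN branch above
            'in_nc': 1,
            'out_nc': 1,
            'nc': 64,
            'nb': 17,              # total number of conv layers
            'act_mode': 'BR',
            'init_type': 'orthogonal',
            'init_bn_type': 'uniform',
            'init_gain': 0.2,
        },
    }
    return define_G(opt_example)
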
def main():

    # ----------------------------------------
    # Preparation
    # ----------------------------------------

    noise_level_img = 15  # noise level for noisy image
    noise_level_model = noise_level_img  # noise level for model
    model_name = 'fdncnn_gray'  # 'fdncnn_gray' | 'fdncnn_color' | 'fdncnn_color_clip' | 'fdncnn_gray_clip'
    testset_name = 'bsd68'  # test set,  'bsd68' | 'cbsd68' | 'set12'
    need_degradation = True  # default: True
    x8 = False  # default: False, x8 to boost performance
    show_img = False  # default: False

    task_current = 'dn'  # 'dn' for denoising | 'sr' for super-resolution
    sf = 1  # unused for denoising
    if 'color' in model_name:
        n_channels = 3  # 3 for color image
    else:
        n_channels = 1  # 1 for grayscale image
    if 'clip' in model_name:
        use_clip = True  # clip the intensities to the range [0, 1]
    else:
        use_clip = False
    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    result_name = testset_name + '_' + model_name
    border = sf if task_current == 'sr' else 0  # shave border to calculate PSNR and SSIM
    model_path = os.path.join(model_pool, model_name + '.pth')

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------

    L_path = os.path.join(testsets,
                          testset_name)  # L_path, for Low-quality images
    H_path = L_path  # H_path, for High-quality images
    E_path = os.path.join(results, result_name)  # E_path, for Estimated images
    util.mkdir(E_path)

    if H_path == L_path:
        need_degradation = True
    logger_name = result_name
    utils_logger.logger_info(logger_name,
                             log_path=os.path.join(E_path,
                                                   logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    need_H = H_path is not None
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------

    from models.network_dncnn import FDnCNN as net
    model = net(in_nc=n_channels + 1,
                out_nc=n_channels,
                nc=64,
                nb=20,
                act_mode='R')
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for k, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))
    number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
    logger.info('Params number: {}'.format(number_parameters))

    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []

    logger.info('model_name:{}, model sigma:{}, image sigma:{}'.format(
        model_name, noise_level_img, noise_level_model))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)
    H_paths = util.get_image_paths(H_path) if need_H else None

    for idx, img in enumerate(L_paths):

        # ------------------------------------
        # (1) img_L
        # ------------------------------------

        img_name, ext = os.path.splitext(os.path.basename(img))
        # logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext))
        img_L = util.imread_uint(img, n_channels=n_channels)
        img_L = util.uint2single(img_L)

        if need_degradation:  # degradation process
            np.random.seed(seed=0)  # for reproducibility
            img_L += np.random.normal(0, noise_level_img / 255., img_L.shape)
        if use_clip:
            img_L = util.uint2single(util.single2uint(img_L))

        util.imshow(util.single2uint(img_L),
                    title='Noisy image with noise level {}'.format(
                        noise_level_img)) if show_img else None

        img_L = util.single2tensor4(img_L)
        noise_level_map = torch.ones(
            (1, 1, img_L.size(2), img_L.size(3)),
            dtype=torch.float).mul_(noise_level_model / 255.)
        img_L = torch.cat((img_L, noise_level_map), dim=1)
        img_L = img_L.to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------

        if not x8:
            img_E = model(img_L)
        else:
            img_E = utils_model.test_mode(model, img_L, mode=3)

        img_E = util.tensor2uint(img_E)

        if need_H:

            # --------------------------------
            # (3) img_H
            # --------------------------------

            img_H = util.imread_uint(H_paths[idx], n_channels=n_channels)
            img_H = img_H.squeeze()

            # --------------------------------
            # PSNR and SSIM
            # --------------------------------

            psnr = util.calculate_psnr(img_E, img_H, border=border)
            ssim = util.calculate_ssim(img_E, img_H, border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(
                img_name + ext, psnr, ssim))
            util.imshow(np.concatenate([img_E, img_H], axis=1),
                        title='Recovered / Ground-truth') if show_img else None

        # ------------------------------------
        # save results
        # ------------------------------------

        util.imsave(img_E, os.path.join(E_path, img_name + ext))

    if need_H:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        logger.info(
            'Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.
            format(result_name, ave_psnr, ave_ssim))
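
# --- Script scaffolding (not part of the original example) -------------------
# main() above relies on module-level imports that this excerpt omits. The
# names below assume the KAIR repository layout (utils/utils_logger.py,
# utils/utils_model.py, utils/utils_image.py).
import logging
import os
from collections import OrderedDict

import numpy as np
import torch

from utils import utils_logger
from utils import utils_model
from utils import utils_image as util

if __name__ == '__main__':
    main()
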
Example #3
def denoising(noise_im,
              clean_im,
              LR=1e-2,
              sigma=5,
              rho=1,
              eta=0.5,
              total_step=20,
              prob1_iter=500,
              result_root=None,
              f=None):

    input_depth = 3
    latent_dim = 3

    en_net = Encoder(input_depth,
                     latent_dim,
                     down_sample_norm='batchnorm',
                     up_sample_norm='batchnorm').cuda()
    de_net = Decoder(latent_dim,
                     input_depth,
                     down_sample_norm='batchnorm',
                     up_sample_norm='batchnorm').cuda()

    # 'net' is not defined in this snippet; it is assumed to be KAIR's DnCNN,
    # which matches the blind color-denoising weights loaded below.
    from models.network_dncnn import DnCNN as net
    model = net(3, 3, nc=64, nb=20, act_mode='R')
    model_path = '/home/dihan/KAIR/model_zoo/dncnn_color_blind.pth'
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for k, v in model.named_parameters():
        v.requires_grad = False

    model = model.cuda()

    noise_im_torch = np_to_torch(noise_im)
    noise_im_torch = noise_im_torch.cuda()

    with torch.no_grad():
        r_dncnn_np = torch_to_np(model(noise_im_torch))

    psnr_dncnn = compare_psnr(clean_im.transpose(1, 2, 0),
                              r_dncnn_np.transpose(1, 2, 0), 1)
    ssim_dncnn = compare_ssim(r_dncnn_np.transpose(1, 2, 0),
                              clean_im.transpose(1, 2, 0),
                              multichannel=True,
                              data_range=1)

    print('PSNR_DNCNN: {}, SSIM_DNCNN: {}'.format(psnr_dncnn, ssim_dncnn),
          file=f,
          flush=True)

    parameters = list(en_net.parameters()) + list(de_net.parameters())
    optimizer = torch.optim.Adam(parameters, lr=LR)
    l2_loss = torch.nn.MSELoss(reduction='sum').cuda()

    i0 = np_to_torch(noise_im).cuda()
    Y = torch.zeros_like(noise_im_torch).cuda()

    i0_til_torch = np_to_torch(noise_im).cuda()

    diff_original_np = noise_im.astype(np.float32) - clean_im.astype(
        np.float32)
    diff_original_name = 'Original_dis.png'
    save_hist(diff_original_np, result_root + diff_original_name)

    best_psnr = 0
    best_ssim = 0

    for i in range(total_step):

        ############################### sub-problem 1 #################################

        for i_1 in range(prob1_iter):

            optimizer.zero_grad()

            mean, log_var = en_net(noise_im_torch)

            z = sample_z(mean, log_var)
            out = de_net(z)

            total_loss = 0.5 * l2_loss(out, noise_im_torch)
            total_loss += kl_loss(mean, log_var, i0, sigma)
            total_loss += (rho / 2) * l2_loss(i0 + Y, i0_til_torch)

            total_loss.backward()
            optimizer.step()

            with torch.no_grad():
                i0 = ((1 / sigma**2) * mean + rho *
                      (i0_til_torch - Y)) / ((1 / sigma**2) + rho)

        with torch.no_grad():

            ############################### sub-problem 2 #################################

            i0_til_torch = model(i0 + Y)

            ############################### sub-problem 3 #################################

            Y = Y + eta * (i0 - i0_til_torch)

            ###############################################################################

            i0_np = torch_to_np(i0)
            Y_np = torch_to_np(Y)

            denoise_obj_pil = np_to_pil((i0_np + Y_np).clip(0, 1))

            Y_norm_np = np.sqrt((Y_np * Y_np).sum(0))

            i0_pil = np_to_pil(i0_np)

            mean_np = torch_to_np(mean)

            mean_pil = np_to_pil(mean_np)

            out_np = torch_to_np(out)
            out_pil = np_to_pil(out_np)

            diff_np = mean_np - clean_im

            denoise_obj_name = 'denoise_obj_{:04d}'.format(i) + '.png'
            Y_name = 'Y_{:04d}'.format(i) + '.png'
            i0_name = 'i0_num_epoch_{:04d}'.format(i) + '.png'
            mean_i_name = 'Latent_im_num_epoch_{:04d}'.format(i) + '.png'
            out_name = 'res_of_dec_num_epoch_{:04d}'.format(i) + '.png'
            diff_name = 'Latent_dis_num_epoch_{:04d}'.format(i) + '.png'

            denoise_obj_pil.save(result_root + denoise_obj_name)
            save_heatmap(Y_norm_np, result_root + Y_name)
            i0_pil.save(result_root + i0_name)
            mean_pil.save(result_root + mean_i_name)
            out_pil.save(result_root + out_name)
            save_hist(diff_np, result_root + diff_name)

            i0_til_np = torch_to_np(i0_til_torch).clip(0, 1)

            psnr = compare_psnr(clean_im.transpose(1, 2, 0),
                                i0_til_np.transpose(1, 2, 0), 1)
            ssim = compare_ssim(clean_im.transpose(1, 2, 0),
                                i0_til_np.transpose(1, 2, 0),
                                multichannel=True,
                                data_range=1)
            i0_til_pil = np_to_pil(i0_til_np)
            i0_til_pil.save(os.path.join(result_root, '{}'.format(i) + '.png'))

            print('Iteration: %02d, VAE Loss: %f, PSNR: %f, SSIM: %f' %
                  (i, total_loss.item(), psnr, ssim),
                  file=f,
                  flush=True)

            if best_psnr < psnr:
                best_psnr = psnr
                best_ssim = ssim
            else:
                break

    return i0_til_np, best_psnr, best_ssim
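

# --- Usage sketch (not part of the original example) -------------------------
# denoising() expects float32 images in CHW layout with values in [0, 1]
# (compare_ssim is called with data_range=1), a result_root that ends with a
# path separator (file names are appended by string concatenation), and an
# open file object for the printed log. The image path, noise level and
# output directory below are illustrative assumptions.
def _example_denoising():
    import os
    import numpy as np
    from PIL import Image

    clean = np.array(Image.open('example_clean.png'), dtype=np.float32) / 255.
    clean = clean.transpose(2, 0, 1)  # HWC -> CHW
    noisy = clean + np.random.normal(0, 25. / 255., clean.shape).astype(np.float32)

    os.makedirs('results', exist_ok=True)
    with open('results/log.txt', 'w') as f:
        return denoising(noisy, clean, sigma=5, result_root='results/', f=f)
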
def main():

    # ----------------------------------------
    # Preparation
    # ----------------------------------------

    model_name = 'dncnn3'  # 'dncnn3' can be used for blind Gaussian denoising, JPEG deblocking (quality factor 5-100) and super-resolution (x2, x3, x4)

    # important!
    testset_name = 'bsd68'  # test set, low-quality grayscale/color JPEG images
    n_channels = 1  # set 1 for grayscale image, set 3 for color image

    x8 = False  # default: False, x8 to boost performance
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    result_name = testset_name + '_' + model_name  # fixed
    L_path = os.path.join(
        testsets, testset_name
    )  # L_path, for Low-quality grayscale/Y-channel JPEG images
    E_path = os.path.join(results, result_name)  # E_path, for Estimated images
    util.mkdir(E_path)

    model_pool = 'model_zoo'  # fixed
    model_path = os.path.join(model_pool, model_name + '.pth')
    logger_name = result_name
    utils_logger.logger_info(logger_name,
                             log_path=os.path.join(E_path,
                                                   logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------

    from models.network_dncnn import DnCNN as net
    model = net(in_nc=1, out_nc=1, nc=64, nb=20, act_mode='R')
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for k, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))
    number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
    logger.info('Params number: {}'.format(number_parameters))

    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)

    for idx, img in enumerate(L_paths):

        # ------------------------------------
        # (1) img_L
        # ------------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx + 1, img_name + ext))
        img_L = util.imread_uint(img, n_channels=n_channels)
        img_L = util.uint2single(img_L)
        if n_channels == 3:
            ycbcr = util.rgb2ycbcr(img_L, False)
            img_L = ycbcr[..., 0:1]
        img_L = util.single2tensor4(img_L)
        img_L = img_L.to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------
        if not x8:
            img_E = model(img_L)
        else:
            img_E = utils_model.test_mode(model, img_L, mode=3)

        img_E = util.tensor2single(img_E)
        if n_channels == 3:
            ycbcr[..., 0] = img_E
            img_E = util.ycbcr2rgb(ycbcr)
        img_E = util.single2uint(img_E)

        # ------------------------------------
        # save results
        # ------------------------------------
        util.imsave(img_E, os.path.join(E_path, img_name + '.png'))
Example #5
def main(args):

    # Model name
    if args['model'] is None:
        model_name = 'dncnn_50'  # 'dncnn_25' | 'dncnn_50' | 'dncnn_gray_blind' | 'dncnn_color_blind' | 'dncnn3'
    else:
        model_name = args['model']

    # RGB or Gray Scale mode
    if args['color'] == 'rgb':
        model_name = 'dncnn_color_blind'

    # Error handling for an unavailable model: os.path.join never raises for a
    # missing file, so check for the weight file explicitly
    model_path = os.path.join(model_pool, model_name + '.pth')
    if not os.path.exists(model_path):
        print('Model not found')
        exit()
    print("Using model %s" % model_name)
    print("---------------------------------")

    # Disabled for now: found to reduce output quality
    x8 = False  # default: False, x8 to boost performance

    if args['type'] is None:
        task_current = 'dn'  # 'dn' for denoising | 'sr' for super-resolution
    else:
        task_current = args['type']

    sf = 1  # unused for denoising

    # Identify if model used is color
    if 'color' in model_name:
        n_channels = 3  # fixed, 1 for grayscale image, 3 for color image
    else:
        n_channels = 1  # fixed for grayscale image

    if model_name in ['dncnn_gray_blind', 'dncnn_color_blind', 'dncnn3']:
        nb = 20  # fixed
    else:
        nb = 17  # fixed

    border = sf if task_current == 'sr' else 0  # shave border to calculate PSNR and SSIM

    need_H = False

    # Load Model
    model = net(in_nc=n_channels,
                out_nc=n_channels,
                nc=64,
                nb=nb,
                act_mode='R')
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for k, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # Load Image
    if (args['input'] is not None) and (args['batch'] is None):  # Single image
        img = args['input']
        predict(img, n_channels, model, x8)
    elif (args['input'] is None) and (args['batch'] is not None):  # Batch Mode
        # Load each image from directory and predict
        mypath = args['batch']
        # Check if the path exists and load the list of files
        if os.path.exists(mypath):
            files = [
                f for f in listdir(mypath)
                if os.path.isfile(os.path.join(mypath, f))
            ]
            print("Found %d files" % (len(files)))
            # Predict for each file and save results
            for item in files:
                try:
                    filepath = os.path.join(mypath, item)
                    predict(filepath, n_channels, model, x8)
                except Exception as e:
                    print("Error with %s: %s" % (item, e))
        else:
            print("Path does not exist")
            exit()
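

# --- CLI sketch (not part of the original example) ---------------------------
# main(args) expects a dict with the keys 'model', 'color', 'type', 'input'
# and 'batch', and relies on module-level names this excerpt omits (model_pool,
# device, net, predict, listdir). A hypothetical argparse entry point:
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='DnCNN denoising (sketch)')
    parser.add_argument('-m', '--model', default=None,
                        help="e.g. 'dncnn_25', 'dncnn_50', 'dncnn_color_blind'")
    parser.add_argument('-c', '--color', default='gray',
                        help="'rgb' switches to the color-blind model")
    parser.add_argument('-t', '--type', default=None,
                        help="'dn' for denoising, 'sr' for super-resolution")
    parser.add_argument('-i', '--input', default=None,
                        help='path to a single input image')
    parser.add_argument('-b', '--batch', default=None,
                        help='directory of images to process')
    main(vars(parser.parse_args()))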