示例#1
0
def main():
    """Demosaic every image of the test set and log per-image and mean PSNR.

    For each ground-truth image the function builds a Bayer mosaic with
    ``utils_mosaic.mosaic_CFA_Bayer``, computes an initial estimate
    (``utils_mosaic.dm_matlab`` or OpenCV), and refines it for ``iter_num``
    iterations that alternate a closed-form data step with a denoiser
    (``'drunet'`` or ``'ircnn'``, selected by ``model_name``).  Estimated
    images and a log file are written to ``results/<testset>_dm_<model>``.

    All settings are hard-coded below; the function takes no arguments and
    returns nothing.
    """

    # ----------------------------------------
    # Preparation
    # ----------------------------------------

    noise_level_img = 0 / 255.0  # set AWGN noise level for LR image, default: 0
    noise_level_model = noise_level_img  # set noise level of model, default: 0
    model_name = 'ircnn_color'  # set denoiser, 'drunet_color' | 'ircnn_color'
    testset_name = 'Set18'  # set testing set,  'set18' | 'set24'
    x8 = True  # set PGSE to boost performance, default: True
    iter_num = 40  # set number of iterations, default: 40 for demosaicing
    modelSigma1 = 49  # set sigma_1, default: 49
    modelSigma2 = max(0.6, noise_level_model * 255.)  # set sigma_2, default
    matlab_init = True

    show_img = False  # default: False
    save_L = True  # save LR image
    save_E = True  # save estimated image
    save_LEH = False  # save zoomed LR, E and H images
    border = 10  # default 10 for demosaicing

    task_current = 'dm'  # 'dm' for demosaicing
    n_channels = 3  # fixed
    model_zoo = 'model_zoo'  # fixed
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    result_name = testset_name + '_' + task_current + '_' + model_name
    model_path = os.path.join(model_zoo, model_name + '.pth')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache()

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------

    L_path = os.path.join(testsets,
                          testset_name)  # L_path, for Low-quality images
    E_path = os.path.join(results, result_name)  # E_path, for Estimated images
    util.mkdir(E_path)

    logger_name = result_name
    utils_logger.logger_info(logger_name,
                             log_path=os.path.join(E_path,
                                                   logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    # ----------------------------------------
    # load model
    # ----------------------------------------

    # map_location keeps the CPU fallback usable for checkpoints that were
    # saved from a CUDA device.
    if 'drunet' in model_name:
        from models.network_unet import UNetRes as net
        model = net(in_nc=n_channels + 1,
                    out_nc=n_channels,
                    nc=[64, 128, 256, 512],
                    nb=4,
                    act_mode='R',
                    downsample_mode="strideconv",
                    upsample_mode="convtranspose")
        model.load_state_dict(torch.load(model_path, map_location=device),
                              strict=True)
        model.eval()
        for _, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
    elif 'ircnn' in model_name:
        from models.network_dncnn import IRCNN as net
        model = net(in_nc=n_channels, out_nc=n_channels, nc=64)
        # The checkpoint is indexed with str(current_idx) in the main loop,
        # i.e. it holds one state dict per noise-level index.
        model25 = torch.load(model_path, map_location=device)
        # None (instead of 0) guarantees that weights are loaded on the very
        # first iteration even when the first required index happens to be 0.
        former_idx = None

    logger.info('model_name:{}, image sigma:{:.3f}, model sigma:{:.3f}'.format(
        model_name, noise_level_img, noise_level_model))
    logger.info('Model path: {:s}'.format(model_path))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)

    test_results = OrderedDict()
    test_results['psnr'] = []

    for idx, img in enumerate(L_paths, start=1):

        # --------------------------------
        # (1) get img_H and img_L
        # --------------------------------

        img_name, ext = os.path.splitext(os.path.basename(img))
        img_H = util.imread_uint(img, n_channels=n_channels)
        CFA, CFA4, mosaic, mask = utils_mosaic.mosaic_CFA_Bayer(img_H)

        # --------------------------------
        # (2) initialize x
        # --------------------------------

        if matlab_init:  # matlab demosaicing for initialization
            CFA4 = util.uint2tensor4(CFA4).to(device)
            x = utils_mosaic.dm_matlab(CFA4)
        else:
            x = cv2.cvtColor(CFA, cv2.COLOR_BAYER_BG2RGB_EA)
            x = util.uint2tensor4(x).to(device)

        img_L = util.tensor2uint(x)
        y = util.uint2tensor4(mosaic).to(device)

        if show_img:
            util.imshow(img_L)
        mask = util.single2tensor4(mask.astype(np.float32)).to(device)

        # --------------------------------
        # (3) get rhos and sigmas
        # --------------------------------

        rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255 / 255.,
                                                   noise_level_img),
                                         iter_num=iter_num,
                                         modelSigma1=modelSigma1,
                                         modelSigma2=modelSigma2,
                                         w=1.0)
        rhos, sigmas = torch.tensor(rhos).to(device), torch.tensor(sigmas).to(
            device)

        # --------------------------------
        # (4) main iterations
        # --------------------------------

        for i in range(iter_num):

            # --------------------------------
            # step 1, closed-form solution
            # --------------------------------

            x = (y + rhos[i].float() * x).div(mask + rhos[i])

            # --------------------------------
            # step 2, denoiser
            # --------------------------------

            if 'ircnn' in model_name:
                # Builtin int() replaces np.int, which was removed from NumPy.
                current_idx = int(
                    np.ceil(sigmas[i].cpu().numpy() * 255. / 2.) - 1)
                # Swap in new weights only when the noise-level index changes.
                if current_idx != former_idx:
                    model.load_state_dict(model25[str(current_idx)],
                                          strict=True)
                    model.eval()
                    for _, v in model.named_parameters():
                        v.requires_grad = False
                    model = model.to(device)
                former_idx = current_idx

            x = torch.clamp(x, 0, 1)
            if x8:
                x = util.augment_img_tensor4(x, i % 8)

            if 'drunet' in model_name:
                # The current sigma is appended as an extra constant channel.
                x = torch.cat((x, sigmas[i].float().repeat(
                    1, 1, x.shape[2], x.shape[3])),
                              dim=1)
                x = utils_model.test_mode(model,
                                          x,
                                          mode=2,
                                          refield=32,
                                          min_size=256,
                                          modulo=16)
            elif 'ircnn' in model_name:
                x = model(x)

            if x8:
                # Undo the augmentation: modes 3 and 5 are reverted with mode
                # 8 - m, every other mode by applying the same mode again.
                if i % 8 == 3 or i % 8 == 5:
                    x = util.augment_img_tensor4(x, 8 - i % 8)
                else:
                    x = util.augment_img_tensor4(x, i % 8)

        # Put the observed mosaic samples back at the masked positions.
        known = mask.to(torch.bool)
        x[known] = y[known]

        # --------------------------------
        # (5) img_E
        # --------------------------------

        img_E = util.tensor2uint(x)
        psnr = util.calculate_psnr(img_E, img_H, border=border)
        test_results['psnr'].append(psnr)
        logger.info('{:->4d}--> {:>10s} -- PSNR: {:.2f}dB'.format(
            idx, img_name + ext, psnr))

        if save_E:
            util.imsave(
                img_E,
                os.path.join(E_path, img_name + '_' + model_name + '.png'))

        if save_L:
            util.imsave(img_L, os.path.join(E_path, img_name + '_L.png'))

        if save_LEH:
            util.imsave(
                np.concatenate([img_L, img_E, img_H], axis=1),
                os.path.join(E_path, img_name + model_name + '_LEH.png'))

    psnr_values = test_results['psnr']
    if psnr_values:
        ave_psnr = sum(psnr_values) / len(psnr_values)
        logger.info('------> Average PSNR(RGB) of ({}) is : {:.2f} dB'.format(
            testset_name, ave_psnr))
    else:
        # An empty test folder previously ended in a ZeroDivisionError.
        logger.warning('No images found in {}'.format(L_path))
示例#2
0
文件: test.py 项目: zjucmx/RFDN
def main():
    """Time RFDN inference on the DIV2K validation LR images (X4 folder).

    Loads ``trained_model/RFDN_AIM.pth``, runs the model on every image found
    in ``DIV2K/DIV2K_valid_LR_bicubic/X4``, measures each forward pass with
    CUDA events and logs the parameter count and the average runtime to the
    'AIM-track' logger.  The outputs are collected in ``img_SR``; writing them
    to disk is commented out.
    """

    utils_logger.logger_info('AIM-track', log_path='AIM-track.log')
    logger = logging.getLogger('AIM-track')

    # --------------------------------
    # basic settings
    # --------------------------------
    testsets = 'DIV2K'
    testset_L = 'DIV2K_valid_LR_bicubic'
    #testset_L = 'DIV2K_test_LR_bicubic'

    # Return value is unused.
    # NOTE(review): presumably only here to initialise CUDA early -- confirm.
    torch.cuda.current_device()
    torch.cuda.empty_cache()
    #torch.backends.cudnn.benchmark = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --------------------------------
    # load model
    # --------------------------------
    model_path = os.path.join('trained_model', 'RFDN_AIM.pth')
    model = RFDN()
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    # Inference only: switch off gradient tracking for every parameter.
    for k, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # number of parameters
    number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
    logger.info('Params number: {}'.format(number_parameters))

    # --------------------------------
    # read image
    # --------------------------------
    L_folder = os.path.join(testsets, testset_L, 'X4')
    E_folder = os.path.join(testsets, testset_L+'_results')
    util.mkdir(E_folder)

    # record PSNR, runtime
    test_results = OrderedDict()
    test_results['runtime'] = []

    logger.info(L_folder)
    logger.info(E_folder)
    idx = 0

    # CUDA events used to time each forward pass.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Every network output is kept in memory in this list.
    img_SR = []
    for img in util.get_image_paths(L_folder):

        # --------------------------------
        # (1) img_L
        # --------------------------------
        idx += 1
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx, img_name+ext))

        img_L = util.imread_uint(img, n_channels=3)
        img_L = util.uint2tensor4(img_L)
        img_L = img_L.to(device)

        start.record()
        img_E = model(img_L)
        end.record()
        # Wait for the GPU to finish before reading the elapsed time.
        torch.cuda.synchronize()
        test_results['runtime'].append(start.elapsed_time(end))  # milliseconds

        # --------------------------------
        # (2) img_E
        # --------------------------------
        img_E = util.tensor2uint(img_E)
        img_SR.append(img_E)

        # --------------------------------
        # (3) save results
        # --------------------------------
        #util.imsave(img_E, os.path.join(E_folder, img_name+ext))

    # Runtimes were recorded in milliseconds; report the mean in seconds.
    ave_runtime = sum(test_results['runtime']) / len(test_results['runtime']) / 1000.0
    logger.info('------> Average runtime of ({}) is : {:.6f} seconds'.format(L_folder, ave_runtime))

    # --------------------------------
    # (4) calculate psnr
    # --------------------------------
    '''
示例#3
0
def main():
    """Benchmark a super-resolution network on a folder of LR images.

    Logs the selected model, optionally its activation/FLOP/parameter
    statistics, then runs the model on every image of the test folder while
    timing each forward pass with CUDA events.  Outputs are saved next to the
    test set when ``save_results`` is set, and the mean runtime is logged.
    """

    log_name = 'efficientsr_challenge'
    utils_logger.logger_info(log_name, log_path='efficientsr_challenge.log')
    logger = logging.getLogger(log_name)

    # --------------------------------
    # basic settings
    # --------------------------------
    candidates = ['msrresnet', 'imdn']
    model_id = 1  # index into `candidates`
    model_name = candidates[model_id]
    logger.info('{:>16s} : {:s}'.format('Model Name', model_name))

    testsets = 'testsets'  # root folder of the test sets
    # other datasets used with this script: 'DIV2K_valid_LR', 'DIV2K_test_LR'
    testset_L = 'set12'

    save_results = True
    # set False when calculating `Max Memery` and `Runtime`
    print_modelsummary = True

    torch.cuda.set_device(0)  # GPU ID
    gpu_id = torch.cuda.current_device()
    logger.info('{:>16s} : {:<d}'.format('GPU ID', gpu_id))
    torch.cuda.empty_cache()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --------------------------------
    # define network and load model
    # --------------------------------
    if model_name == 'msrresnet':
        from models.network_msrresnet import MSRResNet1 as net
        model = net(in_nc=3, out_nc=3, nc=64, nb=16, upscale=4)
        model_path = os.path.join('model_zoo', 'msrresnet_x4_psnr.pth')
    elif model_name == 'imdn':
        from models.network_imdn import IMDN as net
        model = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4,
                    act_mode='L', upsample_mode='pixelshuffle')
        model_path = os.path.join('model_zoo', 'imdn_x4.pth')

    weights = torch.load(model_path)
    model.load_state_dict(weights, strict=True)
    model.eval()
    # inference only: no parameter needs gradients
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    # --------------------------------
    # print model summary
    # --------------------------------
    if print_modelsummary:
        from utils.utils_modelsummary import get_model_activation, get_model_flops
        input_dim = (3, 256, 256)  # input size used for the statistics

        activations, num_conv2d = get_model_activation(model, input_dim)
        logger.info('{:>16s} : {:<.4f} [M]'.format(
            '#Activations', activations / 10**6))
        logger.info('{:>16s} : {:<d}'.format('#Conv2d', num_conv2d))

        flops = get_model_flops(model, input_dim, False)
        logger.info('{:>16s} : {:<.4f} [G]'.format('FLOPs', flops / 10**9))

        num_parameters = sum(p.numel() for p in model.parameters())
        logger.info('{:>16s} : {:<.4f} [M]'.format(
            '#Params', num_parameters / 10**6))

    # --------------------------------
    # read image
    # --------------------------------
    L_path = os.path.join(testsets, testset_L)
    E_path = os.path.join(testsets, testset_L + '_' + model_name)
    util.mkdir(E_path)

    # per-image runtimes, in milliseconds
    test_results = OrderedDict()
    test_results['runtime'] = []

    logger.info('{:>16s} : {:s}'.format('Input Path', L_path))
    logger.info('{:>16s} : {:s}'.format('Output Path', E_path))

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for idx, img in enumerate(util.get_image_paths(L_path), start=1):

        # --------------------------------
        # (1) img_L
        # --------------------------------
        file_name = os.path.basename(img)
        img_name, ext = os.path.splitext(file_name)
        logger.info('{:->4d}--> {:>10s}'.format(idx, img_name + ext))

        img_L = util.uint2tensor4(util.imread_uint(img, n_channels=3))
        torch.cuda.empty_cache()
        img_L = img_L.to(device)

        # time only the forward pass
        start.record()
        img_E = model(img_L)
        end.record()
        torch.cuda.synchronize()  # make sure `end` has been reached
        elapsed_ms = start.elapsed_time(end)
        test_results['runtime'].append(elapsed_ms)

        # --------------------------------
        # (2) img_E
        # --------------------------------
        img_E = util.tensor2uint(img_E)

        if save_results:
            util.imsave(img_E, os.path.join(E_path, img_name + ext))

    runtimes = test_results['runtime']
    ave_runtime = sum(runtimes) / len(runtimes) / 1000.0  # ms -> s
    logger.info('------> Average runtime of ({}) is : {:.6f} seconds'.format(
        L_path, ave_runtime))
示例#4
0
def main():
    """Super-resolve every image of the configured test sets with each model.

    For every name in ``model_names`` the checkpoint
    ``model_zoo/<name>.pth`` is loaded into ``net`` and applied to all images
    of every folder in ``testset_Ls``.  When ``save_results`` is set, outputs
    are written as ``<image>_<model>.png`` into
    ``testsets/<testset>_results_x<sf>``.  Progress is logged to the
    'blind_sr_log' logger.

    Takes no arguments and returns nothing.
    """

    utils_logger.logger_info('blind_sr_log', log_path='blind_sr_log.log')
    logger = logging.getLogger('blind_sr_log')

    testsets = 'testsets'       # fixed, set path of testsets
    testset_Ls = ['RealSRSet']  # ['RealSRSet','DPED']

    # Other checkpoint names that were used with this script:
    # 'RRDB', 'ESRGAN', 'FSSR_DPED', 'FSSR_JPEG', 'RealSR_DPED', 'RealSR_JPEG'
    model_names = ['BSRGAN']

    save_results = True
    sf = 4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for model_name in model_names:

        model_path = os.path.join('model_zoo', model_name+'.pth')          # set model path
        logger.info('{:>16s} : {:s}'.format('Model Name', model_name))

        # torch.cuda.set_device(0)      # set GPU ID
        # Only query the GPU when one exists, so the CPU fallback selected
        # above does not fail on this purely informational log line.
        if torch.cuda.is_available():
            logger.info('{:>16s} : {:<d}'.format('GPU ID', torch.cuda.current_device()))
        else:
            logger.info('{:>16s} : {:s}'.format('GPU ID', 'none (CPU)'))
        torch.cuda.empty_cache()

        # --------------------------------
        # define network and load model
        # --------------------------------
        model = net(in_nc=3, out_nc=3, nf=64, nb=23, gc=32)  # define network

        # map_location lets checkpoints saved from a CUDA device load on CPU.
        model.load_state_dict(torch.load(model_path, map_location=device),
                              strict=True)
        model.eval()
        # Inference only: switch off gradient tracking for every parameter.
        for _, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        torch.cuda.empty_cache()

        for testset_L in testset_Ls:

            L_path = os.path.join(testsets, testset_L)
            E_path = os.path.join(testsets, testset_L+'_results_x'+str(sf))
            util.mkdir(E_path)

            logger.info('{:>16s} : {:s}'.format('Input Path', L_path))
            logger.info('{:>16s} : {:s}'.format('Output Path', E_path))

            for idx, img in enumerate(util.get_image_paths(L_path), start=1):

                # --------------------------------
                # (1) img_L
                # --------------------------------
                img_name, ext = os.path.splitext(os.path.basename(img))
                logger.info('{:->4d} --> {:<s} --> x{:<d}--> {:<s}'.format(idx, model_name, sf, img_name+ext))

                img_L = util.imread_uint(img, n_channels=3)
                img_L = util.uint2tensor4(img_L)
                img_L = img_L.to(device)

                # --------------------------------
                # (2) inference
                # --------------------------------
                img_E = model(img_L)

                # --------------------------------
                # (3) img_E
                # --------------------------------
                img_E = util.tensor2uint(img_E)
                if save_results:
                    util.imsave(img_E, os.path.join(E_path, img_name+'_'+model_name+'.png'))
示例#5
0
def main():
    """Super-resolve the configured source image(s) with an SRMD model.

    Reads its configuration from names defined outside this function
    (``noise_level_model``, ``sf``, ``model_pool``, ``sources``, ``results``,
    ``using_device``, ``n_channels``, ``nc``, ``nb``, ``srmd_pca_path``,
    ``picture_format``, ``x8``).  Each input is processed independently: a
    failure is printed with its traceback and the next image is tried.
    """

    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    prefix = 'srmdnf_x' if noise_level_model == -1 else 'srmd_x'
    model_name = prefix + str(sf)
    model_path = os.path.join(model_pool, model_name+'.pth')
    is_nf_model = 'nf' in model_name
    in_nc = 18 if is_nf_model else 19

    # ----------------------------------------
    # L_path, E_path
    # ----------------------------------------
    L_path = sources  # low-quality input: a single file or a folder
    E_path = results  # output: a file name (has an extension) or a folder
    if not os.path.splitext(E_path)[1]:
        util.mkdir(E_path)

    device = torch.device(using_device)

    # ----------------------------------------
    # load model
    # ----------------------------------------
    from utils.network_srmd import SRMD as net
    model = net(in_nc=in_nc, out_nc=n_channels, nc=nc, nb=nb,
                upscale=sf, act_mode='R', upsample_mode='pixelshuffle')
    state = torch.load(model_path)
    model.load_state_dict(state, strict=False)
    model.eval()
    # inference only: no parameter needs gradients
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    L_paths = [L_path] if os.path.isfile(L_path) else util.get_image_paths(L_path)

    # ----------------------------------------
    # kernel and PCA reduced feature
    # ----------------------------------------

    # Gaussian kernel of size 15 with width 0.01
    kernel = utils_deblur.fspecial('gaussian', 15, 0.01)

    P = loadmat(srmd_pca_path)['P']
    flat_kernel = np.reshape(kernel, -1, order="F")
    degradation_vector = np.dot(P, flat_kernel)
    if not is_nf_model:
        # models without the 'nf' tag also receive the noise level
        degradation_vector = np.append(
            degradation_vector, noise_level_model/255.)
    degradation_vector = torch.from_numpy(degradation_vector)
    degradation_vector = degradation_vector.view(1, -1, 1, 1).float()

    wants_alpha = picture_format == "png"

    for img in L_paths:
        file_name = os.path.basename(img)
        img_name = os.path.splitext(file_name)[0]
        try:
            # ------------------------------------
            # (1) img_L
            # ------------------------------------
            img_L, alpha = util.imread_uint_alpha(img, n_channels=n_channels)
            # The alpha channel is upscaled separately with bicubic
            # interpolation, and only when the output format is png.
            if alpha is not None and wants_alpha:
                alpha = torch.nn.functional.interpolate(
                    util.uint2tensor4(alpha),
                    scale_factor=sf, mode='bicubic', align_corners=False)
                alpha = torch.clamp(alpha.to(device), 0, 255)
                alpha = util.tensor2uint(alpha)
            img_L = util.uint2tensor4(img_L)
            # broadcast the degradation vector over the spatial dimensions
            height, width = img_L.size(-2), img_L.size(-1)
            degradation_map = degradation_vector.repeat(1, 1, height, width)
            img_L = torch.cat((img_L, degradation_map), dim=1).to(device)

            # ------------------------------------
            # (2) img_E
            # ------------------------------------
            if x8:
                img_E = utils_model.test_mode(model, img_L, mode=3, sf=sf)
            else:
                img_E = model(img_L)

            img_E = util.tensor2uint(img_E)
            if alpha is not None:
                if wants_alpha:
                    # append alpha as an extra last-axis channel
                    alpha = alpha.reshape((alpha.shape[0], alpha.shape[1], 1))
                    img_E = np.concatenate((img_E, alpha), axis=2)
                else:
                    print("Warning! You lost your alpha channel for this picture!")

            # ------------------------------------
            # save results
            # ------------------------------------
            if os.path.splitext(E_path)[1]:
                target = E_path
            else:
                target = os.path.join(E_path, img_name+'.' + picture_format)
            util.imsave(img_E, target)
            print(os.path.basename(img) + " successfully saved to disk!")
        except Exception:
            traceback.print_exc()
            print(os.path.basename(img) + " failed!")
示例#6
0
def main(model=None, model_path=None):
    """Time super-resolution inference on the DIV2K test LR images.

    Every image of ``DIV2K/DIV2K_test_LR_bicubic`` is run through the model,
    each forward pass is timed with CUDA events, the output is saved to
    ``<model_path>/results`` and the average runtime is logged and printed.

    Args:
        model: Network to evaluate.  When ``None``, an ``MSRResNet`` is built,
            loaded from ``MSRResNetx4_model/model.pth`` and moved to the
            selected device.  A caller-supplied model is used as is
            (NOTE(review): presumably it already lives on the right device --
            confirm against callers).
        model_path: Directory that receives ``AIM-track.log`` and the
            ``results`` folder.  Falls back to ``MSRResNetx4_model`` when
            ``None``; as before, it is also reset to that directory whenever
            ``model`` is ``None``.

    Returns:
        None.
    """

    default_dir = 'MSRResNetx4_model'
    # The default model_path=None used to crash in os.path.join right here.
    log_dir = default_dir if model_path is None else model_path
    utils_logger.logger_info('AIM-track',
                             log_path=os.path.join(log_dir, 'AIM-track.log'))
    logger = logging.getLogger('AIM-track')

    # --------------------------------
    # basic settings
    # --------------------------------
    testsets = 'DIV2K'  # DIV2K root path
    testset_L = 'DIV2K_test_LR_bicubic'  # test image folder name

    # Return value is unused; the CUDA-event timing below needs CUDA anyway.
    torch.cuda.current_device()
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --------------------------------
    # load model
    # --------------------------------
    if model is None:
        model_path = default_dir
        model = MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=4)
        model.load_state_dict(torch.load(os.path.join(model_path,
                                                      'model.pth'),
                                         map_location=device),
                              strict=True)
        # The inputs are moved to `device` below, so the freshly built model
        # has to live there as well.
        model = model.to(device)
    elif model_path is None:
        model_path = default_dir
    model.eval()
    # The parameters' requires_grad flags are deliberately left untouched (a
    # caller-supplied model may still be in use for training); gradient
    # tracking is disabled around the forward pass instead.

    # number of parameters
    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))
    print('Params number: {}'.format(number_parameters))

    # --------------------------------
    # read image
    # --------------------------------
    L_folder = os.path.join(testsets, testset_L)
    assert os.path.isdir(L_folder)  # check the test images path
    E_folder = os.path.join(model_path, 'results')
    util.mkdir(E_folder)

    # record runtime (milliseconds per image)
    test_results = OrderedDict()
    test_results['runtime'] = []

    logger.info(L_folder)
    logger.info(E_folder)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for idx, img in enumerate(util.get_image_paths(L_folder), start=1):

        # --------------------------------
        # (1) img_L
        # --------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx, img_name + ext))

        img_L = util.imread_uint(img, n_channels=3)
        img_L = util.uint2tensor4(img_L)
        img_L = img_L.to(device)

        start.record()
        # no_grad avoids building an autograd graph for trainable parameters.
        with torch.no_grad():
            img_E = model(img_L)
        end.record()
        torch.cuda.synchronize()  # wait until `end` has been reached
        test_results['runtime'].append(start.elapsed_time(end))  # milliseconds

        # --------------------------------
        # (2) img_E
        # --------------------------------
        img_E = util.tensor2uint(img_E)

        util.imsave(img_E, os.path.join(E_folder, img_name + ext))

    runtimes = test_results['runtime']
    if not runtimes:
        # An empty folder previously ended in a ZeroDivisionError.
        logger.warning('No images found in {}'.format(L_folder))
        print('No images found in {}'.format(L_folder))
        return

    ave_runtime = sum(runtimes) / len(runtimes) / 1000.0  # ms -> s
    logger.info('------> Average runtime of ({}) is : {:.6f} seconds'.format(
        L_folder, ave_runtime))
    print('------> Average runtime of ({}) is : {:.6f} seconds'.format(
        L_folder, ave_runtime))