示例#1
0
def main():
    """Run the MSRResNet super-resolution test over one test set.

    Loads the checkpoint ``model_zoo/<model_name>.pth``, iterates over the
    paths returned by ``util.get_image_paths`` for ``testsets/<testset_name>``,
    optionally downsamples each image by the scale factor, runs the network,
    saves the estimate as a PNG under ``results/<testset_name>_<model_name>``
    and logs per-image and average PSNR/SSIM against the source image.

    Takes no arguments and returns nothing; all settings are the local
    constants at the top of the function.
    """

    # ----------------------------------------
    # Preparation
    # ----------------------------------------

    model_name = '1000_G'  # 'msrresnet_x4_gan' | 'msrresnet_x4_psnr'
    testset_name = 'set_real'  # test set,  'set5' | 'srbsd68'
    need_degradation = True  # default: True
    x8 = False  # default: False, x8 to boost performance
    show_img = False  # default: False

    # Scale factor. Taking the first run of digits in the model name only
    # works for names such as 'msrresnet_x4_gan'; for '1000_G' it yields 1000,
    # which then drives modcrop / resize / border with a meaningless value.
    # Read an explicit 'x<N>' tag instead and fall back to 4, the upscale this
    # script has always built the network with.
    scale_tag = re.search(r'x(\d+)', model_name)
    sf = int(scale_tag.group(1)) if scale_tag else 4

    task_current = 'sr'  # 'dn' for denoising | 'sr' for super-resolution
    n_channels = 3  # fixed
    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    noise_level_img = 0  # fixed: 0, noise level for LR image
    result_name = testset_name + '_' + model_name
    border = sf if task_current == 'sr' else 0  # shave border to calculate PSNR and SSIM
    model_path = os.path.join(model_pool, model_name + '.pth')

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------

    L_path = os.path.join(testsets,
                          testset_name)  # L_path, for Low-quality images
    H_path = L_path  # H_path, for High-quality images
    E_path = os.path.join(results, result_name)  # E_path, for Estimated images
    util.mkdir(E_path)

    # When the LR and HR folders are the same, the LR input has to be
    # synthesised from the HR image.
    if H_path == L_path:
        need_degradation = True
    logger_name = result_name
    utils_logger.logger_info(logger_name,
                             log_path=os.path.join(E_path,
                                                   logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    need_H = H_path is not None
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------

    from models.network_msrresnet import MSRResNet1 as net
    # Build the network with the same factor used for degradation and border.
    model = net(in_nc=n_channels, out_nc=n_channels, nc=64, nb=16, upscale=sf)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device),
                          strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))
    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))

    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []

    logger.info('model_name:{}, image sigma:{}'.format(model_name,
                                                       noise_level_img))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)
    H_paths = util.get_image_paths(H_path) if need_H else None

    for idx, img in enumerate(L_paths):

        # ------------------------------------
        # (1) img_L
        # ------------------------------------

        img_name, ext = os.path.splitext(os.path.basename(img))
        img_L = util.imread_uint(img, n_channels=n_channels)
        img_L = util.uint2single(img_L)

        # degradation process, bicubic downsampling
        if need_degradation:
            img_L = util.modcrop(img_L, sf)
            img_L = util.imresize_np(img_L, 1 / sf)

        if show_img:
            util.imshow(util.single2uint(img_L),
                        title='LR image with noise level {}'.format(
                            noise_level_img))

        img_L = util.single2tensor4(img_L)
        img_L = img_L.to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------

        # Inference only: no autograd bookkeeping is needed.
        with torch.no_grad():
            if not x8:
                img_E = model(img_L)
            else:
                img_E = utils_model.test_mode(model, img_L, mode=3, sf=sf)

        img_E = util.tensor2uint(img_E)

        if need_H:

            # --------------------------------
            # (3) img_H
            # --------------------------------

            img_H = util.imread_uint(H_paths[idx], n_channels=n_channels)
            img_H = img_H.squeeze()
            img_H = util.modcrop(img_H, sf)

            # --------------------------------
            # PSNR and SSIM
            # --------------------------------

            psnr = util.calculate_psnr(img_E, img_H, border=border)
            ssim = util.calculate_ssim(img_E, img_H, border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(
                img_name + ext, psnr, ssim))
            if show_img:
                util.imshow(np.concatenate([img_E, img_H], axis=1),
                            title='Recovered / Ground-truth')

            if np.ndim(img_H) == 3:  # RGB image
                img_E_y = util.rgb2ycbcr(img_E, only_y=True)
                img_H_y = util.rgb2ycbcr(img_H, only_y=True)
                psnr_y = util.calculate_psnr(img_E_y, img_H_y, border=border)
                ssim_y = util.calculate_ssim(img_E_y, img_H_y, border=border)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)

        # ------------------------------------
        # save results
        # ------------------------------------

        util.imsave(img_E, os.path.join(E_path, img_name + '.png'))

    # ----------------------------------------
    # averages
    # ----------------------------------------

    if not need_H:
        return

    # Guard against an empty test set: nothing to average, and img_H would
    # never have been bound.
    if not test_results['psnr']:
        logger.warning('No images found in {}'.format(L_path))
        return

    ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
    ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
    logger.info(
        'Average PSNR/SSIM(RGB) - {} - x{} --PSNR: {:.2f} dB; SSIM: {:.4f}'
        .format(result_name, sf, ave_psnr, ave_ssim))

    # Report the Y-channel average whenever any Y results were collected,
    # rather than depending on the shape of the last image processed.
    if test_results['psnr_y']:
        ave_psnr_y = sum(test_results['psnr_y']) / len(
            test_results['psnr_y'])
        ave_ssim_y = sum(test_results['ssim_y']) / len(
            test_results['ssim_y'])
        logger.info(
            'Average PSNR/SSIM( Y ) - {} - x{} - PSNR: {:.2f} dB; SSIM: {:.4f}'
            .format(result_name, sf, ave_psnr_y, ave_ssim_y))
示例#2
0
def main():
    """Benchmark a x4 super-resolution model's forward-pass runtime.

    Selects one of the models in ``model_names`` (MSRResNet or IMDN), loads
    its checkpoint from ``model_zoo``, optionally logs activation / Conv2d /
    FLOPs / parameter counts, then runs the model on each path returned by
    ``util.get_image_paths`` for ``testsets/<testset_L>``. The forward pass of
    every image is timed in milliseconds, outputs are optionally saved to
    ``testsets/<testset_L>_<model_name>``, and the average runtime in seconds
    is logged at the end.

    Takes no arguments and returns nothing; all settings are the local
    constants inside the function.

    Raises:
        ValueError: if ``model_name`` is not one of the supported models.
    """
    import time  # local import: only needed for the CPU timing fallback

    utils_logger.logger_info('efficientsr_challenge',
                             log_path='efficientsr_challenge.log')
    logger = logging.getLogger('efficientsr_challenge')

    # --------------------------------
    # basic settings
    # --------------------------------
    model_names = ['msrresnet', 'imdn']
    model_id = 1  # set the model name
    model_name = model_names[model_id]
    logger.info('{:>16s} : {:s}'.format('Model Name', model_name))

    testsets = 'testsets'  # set path of testsets
    # set current testing dataset; 'DIV2K_valid_LR' | 'DIV2K_test_LR' | 'set12'
    testset_L = 'set12'

    save_results = True
    print_modelsummary = True  # set False when calculating `Max Memory` and `Runtime`

    # The device below falls back to CPU, so every CUDA-only call has to be
    # guarded as well; otherwise the script crashes on a machine without CUDA.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(0)  # set GPU ID
        logger.info('{:>16s} : {:<d}'.format('GPU ID',
                                             torch.cuda.current_device()))
        torch.cuda.empty_cache()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # --------------------------------
    # define network and load model
    # --------------------------------
    if model_name == 'msrresnet':
        from models.network_msrresnet import MSRResNet1 as net
        model = net(in_nc=3, out_nc=3, nc=64, nb=16,
                    upscale=4)  # define network
        model_path = os.path.join('model_zoo',
                                  'msrresnet_x4_psnr.pth')  # set model path
    elif model_name == 'imdn':
        from models.network_imdn import IMDN as net
        model = net(in_nc=3,
                    out_nc=3,
                    nc=64,
                    nb=8,
                    upscale=4,
                    act_mode='L',
                    upsample_mode='pixelshuffle')  # define network
        model_path = os.path.join('model_zoo', 'imdn_x4.pth')  # set model path
    else:
        # Fail with a clear message instead of a NameError on `model` below.
        raise ValueError('Unknown model name: {}'.format(model_name))

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device),
                          strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    # --------------------------------
    # print model summary
    # --------------------------------
    if print_modelsummary:
        from utils.utils_modelsummary import get_model_activation, get_model_flops
        input_dim = (3, 256, 256)  # set the input dimension

        activations, num_conv2d = get_model_activation(model, input_dim)
        logger.info('{:>16s} : {:<.4f} [M]'.format('#Activations',
                                                   activations / 10**6))
        logger.info('{:>16s} : {:<d}'.format('#Conv2d', num_conv2d))

        flops = get_model_flops(model, input_dim, False)
        logger.info('{:>16s} : {:<.4f} [G]'.format('FLOPs', flops / 10**9))

        num_parameters = sum(p.numel() for p in model.parameters())
        logger.info('{:>16s} : {:<.4f} [M]'.format('#Params',
                                                   num_parameters / 10**6))

    # --------------------------------
    # read image
    # --------------------------------
    L_path = os.path.join(testsets, testset_L)
    E_path = os.path.join(testsets, testset_L + '_' + model_name)
    util.mkdir(E_path)

    # record runtime
    test_results = OrderedDict()
    test_results['runtime'] = []

    logger.info('{:>16s} : {:s}'.format('Input Path', L_path))
    logger.info('{:>16s} : {:s}'.format('Output Path', E_path))

    if use_cuda:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

    for idx, img in enumerate(util.get_image_paths(L_path), start=1):

        # --------------------------------
        # (1) img_L
        # --------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx, img_name + ext))

        img_L = util.imread_uint(img, n_channels=3)
        img_L = util.uint2tensor4(img_L)
        if use_cuda:
            torch.cuda.empty_cache()
        img_L = img_L.to(device)

        # Inference only: no autograd bookkeeping is needed.
        with torch.no_grad():
            if use_cuda:
                # CUDA events time the work queued on the GPU; synchronize
                # before reading the elapsed time so both events have
                # actually been recorded.
                start.record()
                img_E = model(img_L)
                # To report peak memory instead, log
                # torch.cuda.max_memory_allocated(torch.cuda.current_device()) / 1024**2
                # here (in MiB).
                end.record()
                torch.cuda.synchronize()
                runtime_ms = start.elapsed_time(end)  # milliseconds
            else:
                tic = time.perf_counter()
                img_E = model(img_L)
                runtime_ms = (time.perf_counter() - tic) * 1000.0
        test_results['runtime'].append(runtime_ms)  # milliseconds

        # --------------------------------
        # (2) img_E
        # --------------------------------
        img_E = util.tensor2uint(img_E)

        if save_results:
            util.imsave(img_E, os.path.join(E_path, img_name + ext))

    # Guard against an empty test set (division by zero).
    if not test_results['runtime']:
        logger.warning('No images found in {}'.format(L_path))
        return

    ave_runtime = sum(test_results['runtime']) / len(
        test_results['runtime']) / 1000.0
    logger.info('------> Average runtime of ({}) is : {:.6f} seconds'.format(
        L_path, ave_runtime))