Example #1
def define_C(opt):
    opt_net = opt['network_C']
    which_model = opt_net['which_model_C']

    if which_model == 'dfn':
        netC = DynamicF.DFN_Color_correction()
    elif 'ResNet' in which_model:
        netC = SRResNet_arch.ResNet_alpha_beta_multi_in(which_model)
    else:
        raise NotImplementedError(
            'Color correction model [{:s}] not recognized'.format(which_model))
    return netC
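
A minimal usage sketch for define_C (the opt layout is inferred from the lookups above; the structure name is borrowed from the test scripts below):

# hypothetical options dict holding the keys define_C reads
opt = {'network_C': {'which_model_C': 'ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW'}}
netC = define_C(opt)  # any name containing 'ResNet' selects the SRResNet branch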
Example #2
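Examples #2 and #3 below omit their import header; a plausible reconstruction (module paths assume the BasicSR/EDVR repository layout) is:

import glob
import logging
import os
import os.path as osp
import time
from collections import OrderedDict

import cv2
import numpy as np
import torch
import torch.nn as nn

import data.util as data_util
import models.archs.EDVR_arch as EDVR_arch
import models.archs.SRResNet_arch as SRResNet_arch
import utils.util as util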
def main():
    #################
    # configurations
    #################
    #torch.backends.cudnn.benchmark = True
    #torch.backends.cudnn.enabled = True

    os.environ['CUDA_VISIBLE_DEVICES'] = '5'  # must be set before any CUDA initialization
    device = torch.device('cuda')

    test_set = 'AI4K_test'  # Vid4 | YouKu10 | REDS4 | AI4K_test
    data_mode = 'sharp_bicubic'  # sharp_bicubic | blur_bicubic
    test_name = 'Contest2_Test18_A38_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd_165000'
    N_in = 5

    # load test set
    if test_set == 'AI4K_test':
        #test_dataset_folder =  '/data1/yhliu/AI4K/Corrected_TestA_Contest2_001_ResNet_alpha_beta_gaussian_65000/'     #'/data1/yhliu/AI4K/testA_LR_png/'
        test_dataset_folder = '/home/yhliu/AI4K/contest2/testA_LR_png/'

    flip_test = False

    #model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
    #model_path = '../experiments/002_EDVR_EDVRwoTSAIni_lr4e-4_600k_REDS_LrCAR4S_fixTSA50k_new/models/latest_G.pth'
    #model_path = '../experiments/A02_predenoise/models/415000_G.pth'

    model_path = '../experiments/A38_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd/models/165000_G.pth'

    color_model_path = '/home/yhliu/BasicSR/experiments/35_ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW_re_100k/models/220000_G.pth'

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True

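    # EDVR positional arguments (per the BasicSR EDVR constructor):
    # nf=64 feature channels, nframes=N_in, groups=8 deformable-conv groups,
    # front_RBs=5, and back_RBs residual blocks after reconstruction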
    model = EDVR_arch.EDVR(64,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)
    #model = my_EDVR_arch.MYEDVR_FusionDenoise(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in, deconv=False)

    color_model = SRResNet_arch.ResNet_alpha_beta_multi_in(
        structure='ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW')

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # number of border frames scored separately during evaluation
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(test_name)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    model = nn.DataParallel(model)

    #### set up the color model
    load_net = torch.load(color_model_path)
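    # the checkpoint was saved from the bare color network, while
    # ResNet_alpha_beta_multi_in registers it as a 'color_net' submodule,
    # so every key needs the prefix added below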
    load_net_clean = OrderedDict()  # add prefix 'color_net.'
    for k, v in load_net.items():
        k = 'color_net.' + k
        load_net_clean[k] = v

    color_model.load_state_dict(load_net_clean, strict=True)
    color_model.eval()
    color_model = color_model.to(device)
    color_model = nn.DataParallel(color_model)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    #print(subfolder_l)
    #exit()

    # for each subfolder
    for subfolder in subfolder_l:
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        #print(img_path_l)
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ images (this test set has no GT)
        imgs_LQ = data_util.read_img_seq(subfolder)

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
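            # select_idx is a window of N_in frame indices centered on img_idx;
            # with 'new_info' padding, out-of-range neighbors are filled with
            # valid frames from the other side of the window rather than replicated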
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).cpu()
            print(imgs_in.size())

            if flip_test:
                imgs_in = util.single_forward(color_model, imgs_in)
                output = util.flipx4_forward(model, imgs_in)
            else:
                start_time = time.time()
                imgs_in = util.single_forward(color_model, imgs_in)
                output = util.single_forward(model, imgs_in)
                end_time = time.time()
                print('Forward time for one image: {:.4f} s'.format(end_time - start_time))
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            logger.info('{:3d} - {:25}'.format(img_idx + 1, img_name))

    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
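
The helpers util.single_forward and util.flipx4_forward are not shown in these examples; a minimal sketch of their behavior (assuming BasicSR's usual no-grad forward and x4 flip self-ensemble) is:

def single_forward(model, inp):
    # forward pass without gradients; return a float tensor on the CPU
    with torch.no_grad():
        output = model(inp)
    return output.data.float().cpu()

def flipx4_forward(model, inp):
    # x4 self-ensemble: average predictions over H/W flips, un-flipping each output
    output = single_forward(model, inp)
    for dims in ((-1,), (-2,), (-2, -1)):
        out = single_forward(model, torch.flip(inp, dims))
        output = output + torch.flip(out, dims)
    return output / 4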
Example #3
def main():
    #################
    # configurations
    #################
    os.environ['CUDA_VISIBLE_DEVICES'] = '5'  # must be set before any CUDA initialization
    #os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4'
    device = torch.device('cuda')
    test_set = 'AI4K_val'  # Vid4 | YouKu10 | REDS4 | AI4K_val
    test_name = 'A38_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd_165000'
    data_mode = 'sharp_bicubic'  # sharp_bicubic | blur_bicubic
    N_in = 5

    # load test set
    if test_set == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    elif test_set == 'YouKu10':
        test_dataset_folder = '../datasets/YouKu10/LR'
        GT_dataset_folder = '../datasets/YouKu10/HR'
    elif test_set == 'YouKu_val':
        test_dataset_folder = '/data0/yhliu/DATA/YouKuVid/valid/valid_lr_bmp'
        GT_dataset_folder = '/data0/yhliu/DATA/YouKuVid/valid/valid_hr_bmp'
    elif test_set == 'REDS4':
        test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        GT_dataset_folder = '../datasets/REDS4/GT'
    elif test_set == 'AI4K_val':
        test_dataset_folder = '/home/yhliu/AI4K/contest2/val2_LR_png/'
        GT_dataset_folder = '/home/yhliu/AI4K/contest1/val1_HR_png/'
    elif test_set == 'AI4K_bic':
        test_dataset_folder = '/home/yhliu/AI4K/contest2/val2_LR_png_bic/'
        GT_dataset_folder = '/home/yhliu/AI4K/contest1/val1_HR_png_bic/'
    elif test_set == 'AI4K_testA':
        test_dataset_folder = '/data0/yhliu/AI4K/val2_LR_png'  #'/home/yhliu/AI4K/val2_LR_png/'
        GT_dataset_folder = '/data0/yhliu/AI4K/val1_HR_png/'

    flip_test = False

    #model_path = '../experiments/pretrained_models/A01xxx/900000_G.pth'
    model_path = '../experiments/A38_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd/models/165000_G.pth'

    color_model_path = '/home/yhliu/BasicSR/experiments/35_ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW_re_100k/models/220000_G.pth'

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True

    model = EDVR_arch.EDVR(64,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)
    color_model = SRResNet_arch.ResNet_alpha_beta_multi_in(
        structure='ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW')

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # number of border frames scored separately during evaluation
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(test_name)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    model = nn.DataParallel(model)

    #### set up the color model
    load_net = torch.load(color_model_path)
    load_net_clean = OrderedDict()  # add prefix 'color_net.'
    for k, v in load_net.items():
        k = 'color_net.' + k
        load_net_clean[k] = v

    color_model.load_state_dict(load_net_clean, strict=True)
    color_model.eval()
    color_model = color_model.to(device)
    color_model = nn.DataParallel(color_model)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    #print(subfolder_l)
    #print(subfolder_GT_l)
    #exit()

    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        #print(img_path_l)
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            #print(img_GT_path)
            img_GT_l.append(data_util.read_img(None, img_GT_path))
        #print(img_GT_l[0].shape)
        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).cpu()
            print(imgs_in.size())

            if flip_test:
                imgs_in = util.single_forward(color_model, imgs_in)
                output = util.flipx4_forward(model, imgs_in)
            else:
                imgs_in = util.single_forward(color_model, imgs_in)
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            '''
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)
            '''

            output, GT = util.crop_border([output, GT], crop_border)
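            # crop_border is 0 in this script, so the crop is a no-op; values
            # are rescaled back to [0, 255] for the 8-bit PSNR below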
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(
                img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
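
For reference, a PSNR routine consistent with the values logged above would look like the following sketch of the standard 8-bit definition (treat the exact behavior of util.calculate_psnr as an assumption):

import math
import numpy as np

def calculate_psnr(img1, img2):
    # PSNR between two images in [0, 255]: 20 * log10(255 / sqrt(MSE))
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))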