Esempio n. 1
0
 def img2HQimg(self):
     """Restore one image with the model and save it to ``self.output_path``.

     The input frames are read from the directory that contains
     ``self.input_path``.  The frame indices handed to the model come from
     ``data_util.index_generation(0, 1, 5, ...)`` with ``self.padding`` as the
     padding mode.
     """
     # Indices of the frames that make up the model input window.
     window = data_util.index_generation(0, 1, 5, padding=self.padding)
     frame_dir = os.path.dirname(self.input_path)
     frames = data_util.read_img_seq(frame_dir)
     # Select the window along axis 0 and add a leading axis of size 1.
     batch = frames.index_select(0, torch.LongTensor(window)).unsqueeze(0)
     restored = util.single_forward(self.model, batch.to(self.device))
     # Drop the leading axis again before converting to an image array.
     util.save_img(img=util.tensor2img(restored.squeeze(0)), img_path=self.output_path)
Esempio n. 2
0
def eval(images, N_in, model, device):
    """Run ``model`` once on the first ``N_in`` frames of ``images``.

    ``images`` is converted to float32, divided by 255 and its axes are
    reordered with ``(0, 3, 1, 2)`` before it becomes a torch tensor
    (presumably an (N, H, W, C) uint8 stack going to (N, C, H, W) -- confirm
    against the caller).

    Args:
        images: numpy array of frames.
        N_in: number of leading frames fed to the model.
        model: network passed to ``util.single_forward``.
        device: torch device the input is moved to.

    Returns:
        The result of ``util.tensor2img`` as a C-contiguous numpy array.
    """
    scaled = images.astype(np.float32) / 255.
    reordered = np.ascontiguousarray(np.transpose(scaled, (0, 3, 1, 2)))
    frames = torch.from_numpy(reordered).float()
    # Take frames 0 .. N_in-1 and add a leading axis of size 1.
    leading = torch.LongTensor(list(range(N_in)))
    batch = frames.index_select(0, leading).unsqueeze(0).to(device)
    restored = util.single_forward(model, batch)
    image = util.tensor2img(restored.squeeze(0))
    return np.ascontiguousarray(image)
Esempio n. 3
0
 def video2HQvideo(self):
     """Create a HQ video from the cached images and save it to ``self.output_path``.

     ``self.imgs_cache`` is processed in batches: one batch per directory
     entry, addressed by the numbers 100, 200, ... that are passed to
     ``self.read_batch``.  Each frame of each batch is restored by the model
     and appended to one ``mp4v`` video written with ``self.fps`` and
     ``self.img_shape``.

     The writer is released exactly once, after the last batch or when an
     error aborts the run.  Releasing it inside the batch loop would close
     the file after the first batch and drop all later frames.
     """
     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
     video_writer = cv2.VideoWriter(self.output_path, fourcc, self.fps, self.img_shape)
     try:
         for batch_num in range(100, (100 * len(os.listdir(self.imgs_cache))) + 100, 100):
             # presumably refreshes self.imgs / self.imgs_paths for this batch -- confirm
             self.read_batch(batch_num)
             n_frames = len(self.imgs_paths)
             bar = self.progress.ProgressBar(max_value=n_frames)
             for img_idx in range(n_frames):
                 select_idx = data_util.index_generation(img_idx, n_frames, self.N_in, padding=self.padding)
                 imgs_in = self.imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(self.device)
                 output = util.single_forward(self.model, imgs_in)
                 output = util.tensor2img(output.squeeze(0))
                 video_writer.write(output)
                 bar.update(img_idx)
     finally:
         # Close the container once all batches were written (or on failure).
         video_writer.release()
Esempio n. 4
0
 def imgs2HQimgs(self):
     """Create HQ images from the cached images and save them under ``self.output_path``.

     ``self.imgs_cache`` is processed in batches: one batch per directory
     entry, addressed by the numbers 100, 200, ... that are passed to
     ``self.read_batch``.  Each restored frame is written as
     ``<output_path>/<n>.png``, where ``n`` counts frames across all batches.
     A per-batch index would restart at 0 for every batch and make later
     batches overwrite the files of earlier ones.
     """
     util.mkdir(self.output_path)
     print(100 * len(os.listdir(self.imgs_cache)))
     frame_no = 0  # running output index over all batches
     for batch_num in range(100, (100 * len(os.listdir(self.imgs_cache))) + 100, 100):
         # presumably refreshes self.imgs / self.imgs_paths for this batch -- confirm
         self.read_batch(batch_num)
         n_frames = len(self.imgs_paths)
         bar = self.progress.ProgressBar(max_value=n_frames)
         for img_idx in range(n_frames):
             select_idx = data_util.index_generation(img_idx, n_frames, self.N_in, padding=self.padding)
             imgs_in = self.imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(self.device)
             output = util.single_forward(self.model, imgs_in)
             output = util.tensor2img(output.squeeze(0))
             print(output.shape)
             # os.path.join works whether or not output_path ends with a separator.
             cv2.imwrite(os.path.join(self.output_path, f'{frame_no}.png'), output)
             frame_no += 1
             bar.update(img_idx)
Esempio n. 5
0
 def single_inference(self, input_tensor):
     """Run the model tile by tile and stitch the results into one uint8 tensor.

     ``input_tensor`` is cut into chunks of ``self.split_H`` along dim -2 and
     each chunk into pieces of ``self.split_W`` along dim -1.  Every piece is
     forwarded on ``self.device``, moved back to the CPU, clamped to [0, 1],
     scaled by 255, rounded and cast to ``torch.uint8``.  The pieces are then
     concatenated along dim -1 and the resulting rows along dim -2.
     """
     rows = []
     for band in input_tensor.split(self.split_H, dim=-2):
         tiles = []
         for tile in band.split(self.split_W, dim=-1):
             # Original author's note: the model output is Tensor[B,1,C,H,W].
             restored = util.single_forward(self.model, tile.to(self.device))
             # Original author's note: after squeeze(1) this is Tensor[B,C,H,W].
             restored = restored.squeeze(1).float().to('cpu').clamp_(0, 1)
             tiles.append((restored * 255.0).round().type(torch.uint8))
         rows.append(torch.cat(tiles, dim=-1))
     return torch.cat(rows, dim=-2)
Esempio n. 6
0
def main():
    """Evaluate the pretrained EDVR Vid4 model on a dataset folder given on the CLI.

    The folder passed with ``-i/--input`` must contain a ``BIx4`` sub-folder
    (low-quality clips) and a ``GT`` sub-folder (ground truth), each holding
    one directory per clip.  Every frame of every clip is restored, written to
    ``../results/<input folder name>/<clip>/<frame>.png`` when ``save_imgs``
    is set, and scored with PSNR.  Per-clip and overall averages (all frames,
    centre frames, border frames) go to the ``base`` logger.

    Raises:
        SystemExit: when no input folder is given (exit status 0).
        ValueError: when ``stage`` is not 1 for the Vid4 setting.
        NotImplementedError: when ``data_mode`` is not ``'Vid4'``.
    """
    # Create object for parsing command-line options
    parser = argparse.ArgumentParser(description="Test with EDVR, requre path to test dataset folder.")
    # Path to the dataset folder that holds the 'BIx4' and 'GT' sub-folders
    parser.add_argument("-i", "--input", type=str, help="Path to test folder")
    # Parse the command line arguments to an object
    args = parser.parse_args()
    # Safety if no parameter have been given
    if not args.input:
        print("No input paramater have been given.")
        print("For help type --help")
        # Same exit status as a bare exit(), without relying on the `site` helper.
        raise SystemExit

    # Last path component, regardless of how many trailing separators were typed.
    folder_name = osp.basename(osp.normpath(args.input))

    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    else:
        raise NotImplementedError

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 40

    model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = os.path.join(args.input, 'BIx4')
        GT_dataset_folder = os.path.join(args.input, 'GT')
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(folder_name)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder; LQ and GT clips are paired by sorted position
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = [
            data_util.read_img(None, img_GT_path)
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))
        ]

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        # A clip with at most 2 * border_frame frames has no centre frames, and an
        # empty clip has no frames at all: guard every division, not only the
        # border one.
        n_frames = N_center + N_border
        avg_psnr = 0 if n_frames == 0 else (avg_psnr_center + avg_psnr_border) / n_frames
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
                                                                   n_frames,
                                                                   avg_psnr_center, N_center,
                                                                   avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l,
                                                              avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
                                                     psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    n_clips = len(avg_psnr_l)
    if n_clips == 0:
        # Nothing was evaluated (e.g. wrong input path): there is no mean to report.
        logger.warning('No clips were evaluated under {}.'.format(test_dataset_folder))
        return
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / n_clips, n_clips,
                    sum(avg_psnr_center_l) / n_clips,
                    sum(avg_psnr_border_l) / n_clips))
Esempio n. 7
0
def main():
    """Run EDVR inference on the HDR validation clips and save the restored frames.

    The configuration is hard-coded below (``data_mode = 'HDR'``, flip test
    enabled).  Only clips whose path contains ``'10675978'`` are processed;
    every restored frame is written to
    ``../results/<data_mode>_50000/<clip>/<frame>.png`` and its name is logged
    to the ``base`` logger.

    PSNR evaluation is disabled in this script, so ground-truth frames are
    not loaded: nothing here consumes them.  The ground-truth folder is still
    listed, because test and ground-truth clips are paired by sorted position
    and that pairing decides how many clips are visited.

    Raises:
        ValueError: for ``data_mode == 'Vid4'`` with ``stage != 1``.
        NotImplementedError: for an unknown ``data_mode``.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'HDR'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp | HDR
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = True
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    elif data_mode == 'sharp_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
    elif data_mode == 'blur_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
    elif data_mode == 'blur':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
    elif data_mode == 'blur_comp':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    elif data_mode == 'HDR':
        # Alternative checkpoints used during development:
        #   ../experiments/001_EDVR_scratch_lr4e-4_600k_HDR_LrCAR4S/models/20000_G.pth
        #   ../experiments/005_EDVRwoTSA_scratch_lr4e-4_600k_HDR_LrCAR4S/models/490000_G.pth
        model_path = '../experiments/pretrained_models/50000_G.pth'
    else:
        raise NotImplementedError

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if data_mode == 'HDR':
        predeblur = True
    if stage == 2:
        HR_in = True
        back_RBs = 20
    model = EDVR_arch.EDVR(64,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in,
                           w_TSA=True)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    elif data_mode == 'HDR':
        test_dataset_folder = '../datasets/HDR/valid/new_method/540p'
        GT_dataset_folder = '../datasets/HDR/valid/new_method/4k'
        # Alternative layout:
        #   ../datasets/HDR/valid/sequences_540 and ../datasets/HDR/valid/sequences_4k
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    #### inference settings
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}_50000'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder; the GT list only bounds how many clips are visited
    for subfolder, _subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        # NOTE: hard-coded filter, only the clip(s) with this id are processed.
        if '10675978' not in subfolder:
            print('pass')
            continue
        subfolder_name = osp.basename(subfolder)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ images
        imgs_LQ = data_util.read_img_seq(subfolder)

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)

            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # PSNR/SSIM scoring is intentionally not performed here; only the
            # frame name is logged.
            logger.info('{:3d} - {:25} '.format(img_idx + 1, img_name))
Esempio n. 8
0
def val(model_name, current_step, arch='EDVR'):
    """Evaluate a training checkpoint on the configured test set and return its mean PSNR.

    The checkpoint ``../experiments/<model_name>/models/<current_step>_G.pth``
    is loaded into the selected architecture, every frame of every clip of
    the test set is restored and scored against its ground truth, and the
    per-clip and overall averages are written to the ``base`` logger.

    Args:
        model_name: experiment directory name under ``../experiments/``.
        current_step: training step that identifies the checkpoint file.
        arch: ``'EDVR'`` or ``'MY_EDVR'``.

    Returns:
        The mean over clips of the per-clip average PSNR.

    Raises:
        NotImplementedError: for an unknown ``arch`` or ``test_set``.  This is
            raised up front instead of failing later on an unbound variable.
        ZeroDivisionError: if no clip is found in the dataset folders.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    #os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4'
    test_set = 'REDS4'  # Vid4 | YouKu10 | REDS4 | AI4K_val
    data_mode = 'sharp_bicubic'  # sharp_bicubic | blur_bicubic
    N_in = 5

    # load test set
    if test_set == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    elif test_set == 'YouKu10':
        test_dataset_folder = '../datasets/YouKu10/LR'
        GT_dataset_folder = '../datasets/YouKu10/HR'
    elif test_set == 'YouKu_val':
        test_dataset_folder = '/data0/yhliu/DATA/YouKuVid/valid/valid_lr_bmp'
        GT_dataset_folder = '/data0/yhliu/DATA/YouKuVid/valid/valid_hr_bmp'
    elif test_set == 'REDS4':
        test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        GT_dataset_folder = '../datasets/REDS4/GT'
    elif test_set == 'AI4K_val':
        test_dataset_folder = '/data0/yhliu/AI4K/contest1/val1_LR_png/'
        GT_dataset_folder = '/data0/yhliu/AI4K/contest1/val1_HR_png/'
    elif test_set == 'AI4K_val_small':
        test_dataset_folder = '/home/yhliu/AI4K/contest1/val1_LR_png_small/'
        GT_dataset_folder = '/home/yhliu/AI4K/contest1/val1_HR_png_small/'
    else:
        raise NotImplementedError('Unknown test set: {}'.format(test_set))

    flip_test = False

    #model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
    model_path = os.path.join('../experiments/', model_name,
                              'models/{}_G.pth'.format(current_step))

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True

    if arch == 'EDVR':
        model = EDVR_arch.EDVR(64,
                               N_in,
                               8,
                               5,
                               back_RBs,
                               predeblur=predeblur,
                               HR_in=HR_in)
    elif arch == 'MY_EDVR':
        model = my_EDVR_arch.MYEDVR(64,
                                    N_in,
                                    8,
                                    5,
                                    back_RBs,
                                    predeblur=predeblur,
                                    HR_in=HR_in)
    else:
        # Fail here with a clear message rather than later on an unbound `model`.
        raise NotImplementedError('Unknown architecture: {}'.format(arch))

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = False

    save_folder = '../validation/{}'.format(test_set)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    # NOTE(review): device ids are hard-coded, so this expects four visible GPUs.
    model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))

    # for each subfolder; LQ and GT clips are paired by sorted position
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = [
            data_util.read_img(None, img_GT_path)
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))
        ]
        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # PSNR is computed on all channels here; the Y-channel conversion
            # (data_util.bgr2ycbcr(..., only_y=True)) used for Vid4 is disabled.

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        # A clip with at most 2 * border_frame frames has no centre frames, and an
        # empty clip has no frames at all: guard every division, not only the
        # border one.
        n_frames = N_center + N_border
        avg_psnr = 0 if n_frames == 0 else (avg_psnr_center + avg_psnr_border) / n_frames
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, n_frames,
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    total_psnr = sum(avg_psnr_l) / len(avg_psnr_l)
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    total_psnr, len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))

    return total_psnr
Esempio n. 9
0
def main():
    """Run EDVR inference on every clip of the configured dataset and save the frames.

    The configuration is hard-coded below (``data_mode = 'Vid4'``, stage 1).
    For each clip directory under the test dataset folder, every frame is
    restored with ``util.single_forward`` and, when ``save_imgs`` is set,
    written to ``../results/<data_mode>/<clip>/<frame>.png``.  No ground truth
    is read and no metric is computed.

    Raises:
        ValueError: for ``data_mode == 'Vid4'`` with ``stage != 1``.
        NotImplementedError: for an unknown ``data_mode``.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False  # only logged; inference below always uses single_forward
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/cinepak_small2.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    elif data_mode == 'sharp_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
    elif data_mode == 'blur_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
    elif data_mode == 'blur':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
    elif data_mode == 'blur_comp':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    else:
        raise NotImplementedError

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20
    model = EDVR_arch.EDVR(64,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')

    #### inference settings
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    # for each subfolder
    for subfolder in subfolder_l:
        print('Processing video {:s}'.format(subfolder))
        subfolder_name = osp.basename(subfolder)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ images
        imgs_LQ = data_util.read_img_seq(subfolder)

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # Fill in the frame name; the placeholder used to be printed verbatim.
            print('\tProcessing frame {:s}'.format(img_name))
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            # Only write when the output directory was created above.
            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)),
                            output)
Esempio n. 10
0
def main():
    """Run EDVR inference on YUV frame folders and log YUV and RGB PSNR.

    Command-line arguments (all required):
        --input_path: folder with one sub-folder of low-quality frames per clip.
        --gt_path: folder with the ground-truth sub-folders, matched by name.
        --output_path: folder receiving the log file and the restored frames.
        --model_path: checkpoint loaded into the network with ``strict=True``.
        --gpu_id: value exported as ``CUDA_VISIBLE_DEVICES``.
        --opt: option YAML file; its ``network_G`` section configures EDVR.

    Input clips without a matching ground-truth folder are skipped. For every
    remaining clip each frame is replication-padded, restored, cropped back,
    compared against its ground truth in YUV and in RGB, and (when
    ``save_imgs`` is set) written as a PNG. Per-clip and overall averages are
    logged through the ``'base'`` logger.
    """

    def _mean(total, count):
        """Return ``total / count``, or 0 when nothing was accumulated."""
        return total / count if count else 0

    #################
    # configurations
    #################
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--gt_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--gpu_id", type=str, required=True)
    parser.add_argument('--opt', type=str, required=True, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=False)

    PAD = 32   # replication padding (in input pixels) added on every side
    scale = 4  # factor used to map input-space padding/size onto the output

    total_run_time = AverageMeter()
    # The GPU restriction has to be exported before CUDA is queried for the
    # first time, otherwise it may be ignored by the already-initialised runtime.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    print("GPU ", torch.cuda.device_count())
    device = torch.device('cuda')

    data_mode = 'sharp_bicubic'
    flip_test = False

    test_dataset_folder = args.input_path
    GT_dataset_folder = args.gt_path
    Result_folder = args.output_path
    model_path = args.model_path

    # create results folder
    os.makedirs(Result_folder, exist_ok=True)

    N_in = 5  # number of neighbouring frames fed to the network per output frame

    net_opt = opt['network_G']
    model = EDVR_arch.EDVR(nf=net_opt['nf'], nframes=net_opt['nframes'],
                           groups=net_opt['groups'], front_RBs=net_opt['front_RBs'],
                           back_RBs=net_opt['back_RBs'],
                           predeblur=net_opt['predeblur'], HR_in=net_opt['HR_in'],
                           w_TSA=net_opt['w_TSA'])

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    padding = 'new_info'
    save_imgs = True

    save_folder = os.path.join(Result_folder, data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    avg_rgb_psnr_l, avg_rgb_psnr_center_l, avg_rgb_psnr_border_l = [], [], []
    subfolder_name_l = []

    # The same spatial padding is applied to every frame, so build it once.
    pader = torch.nn.ReplicationPad2d([PAD, PAD, PAD, PAD])

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))

    for subfolder in subfolder_l:
        subfolder_name = osp.basename(subfolder)
        subfolder_GT = os.path.join(GT_dataset_folder, subfolder_name)

        # Only clips that have a ground-truth counterpart can be scored.
        if not os.path.exists(subfolder_GT):
            continue

        print("Evaluate Folders: ", subfolder_name)

        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images, notice we load yuv img here
        imgs_LQ = data_util.read_img_seq_yuv(subfolder)  # Num x 3 x H x W
        img_GT_l = [data_util.read_img_yuv(None, img_GT_path)
                    for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))]

        avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0
        avg_rgb_psnr_border, avg_rgb_psnr_center = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            # Time each frame on its own instead of accumulating since start-up.
            frame_start = time.time()

            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            intWidth_ori = imgs_in.shape[4]
            intHeight_ori = imgs_in.shape[3]

            # Pad the N x C x H x W stack spatially, then restore the batch axis.
            X0 = pader(torch.squeeze(imgs_in, 0)).unsqueeze(0)

            if flip_test:
                output = util.flipx4_forward(model, X0)
            else:
                output = util.single_forward(model, X0)

            # Remove the (upscaled) padding so only the original area remains.
            output = output[0, :, PAD * scale:(PAD + intHeight_ori) * scale,
                                  PAD * scale:(PAD + intWidth_ori) * scale]

            output = util.tensor2img(output.squeeze(0))

            elapsed = time.time() - frame_start
            print("*****************current image process time \t " + str(
                elapsed) + "s ******************")
            total_run_time.update(elapsed, 1)

            # calculate PSNR on YUV
            y_all = output / 255.
            GT = np.copy(img_GT_l[img_idx])

            y_all, GT = util.crop_border([y_all, GT], crop_border)
            crt_psnr = util.calculate_psnr(y_all * 255, GT * 255)
            logger.info('{:3d} - {:25} \tYUV_PSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            # here we also calculate PSNR on RGB
            y_all_rgb = data_util.ycbcr2rgb(output / 255.)
            GT_rgb = data_util.ycbcr2rgb(np.copy(img_GT_l[img_idx]))
            y_all_rgb, GT_rgb = util.crop_border([y_all_rgb, GT_rgb], crop_border)
            crt_rgb_psnr = util.calculate_psnr(y_all_rgb * 255, GT_rgb * 255)
            logger.info('{:3d} - {:25} \tRGB_PSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_rgb_psnr))

            if save_imgs:
                im_out = np.round(y_all_rgb * 255.).astype(np.uint8)
                # The image is RGB here, but cv2.imwrite expects BGR channel order.
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)),
                            cv2.cvtColor(im_out, cv2.COLOR_RGB2BGR))

            # for YUV and RGB, respectively
            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                avg_rgb_psnr_center += crt_rgb_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                avg_rgb_psnr_border += crt_rgb_psnr
                N_border += 1

        N_total = N_center + N_border

        # for YUV (short or empty clips may have no center frames at all)
        avg_psnr = _mean(avg_psnr_center + avg_psnr_border, N_total)
        avg_psnr_center = _mean(avg_psnr_center, N_center)
        avg_psnr_border = _mean(avg_psnr_border, N_border)
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average YUV PSNR: {:.6f} dB for {} frames; '
                    'Center YUV PSNR: {:.6f} dB for {} frames; '
                    'Border YUV PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
                                                                   N_total,
                                                                   avg_psnr_center, N_center,
                                                                   avg_psnr_border, N_border))

        # for RGB
        avg_rgb_psnr = _mean(avg_rgb_psnr_center + avg_rgb_psnr_border, N_total)
        avg_rgb_psnr_center = _mean(avg_rgb_psnr_center, N_center)
        avg_rgb_psnr_border = _mean(avg_rgb_psnr_border, N_border)
        avg_rgb_psnr_l.append(avg_rgb_psnr)
        avg_rgb_psnr_center_l.append(avg_rgb_psnr_center)
        avg_rgb_psnr_border_l.append(avg_rgb_psnr_border)

        logger.info('Folder {} - Average RGB PSNR: {:.6f} dB for {} frames; '
                    'Center RGB PSNR: {:.6f} dB for {} frames; '
                    'Border RGB PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_rgb_psnr,
                                                                   N_total,
                                                                   avg_rgb_psnr_center, N_center,
                                                                   avg_rgb_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    # for YUV
    for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average YUV PSNR: {:.6f} dB. '
                    'Center YUV PSNR: {:.6f} dB. '
                    'Border YUV PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border))

    # for RGB
    for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_rgb_psnr_l, avg_rgb_psnr_center_l, avg_rgb_psnr_border_l):
        logger.info('Folder {} - Average RGB PSNR: {:.6f} dB. '
                    'Center RGB PSNR: {:.6f} dB. '
                    'Border RGB PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border))

    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    # Count only the clips that were actually scored; skipped folders have no entry.
    n_clips = len(avg_psnr_l)
    if n_clips == 0:
        logger.warning('No clip with a matching ground-truth folder was found; '
                       'no PSNR to report.')
        return

    logger.info('Total Average YUV PSNR: {:.6f} dB for {} clips. '
                'Center YUV PSNR: {:.6f} dB. Border YUV PSNR: {:.6f} dB.'.format(
        sum(avg_psnr_l) / n_clips, n_clips,
        sum(avg_psnr_center_l) / n_clips,
        sum(avg_psnr_border_l) / n_clips))
    logger.info('Total Average RGB PSNR: {:.6f} dB for {} clips. '
                'Center RGB PSNR: {:.6f} dB. Border RGB PSNR: {:.6f} dB.'.format(
        sum(avg_rgb_psnr_l) / n_clips, n_clips,
        sum(avg_rgb_psnr_center_l) / n_clips,
        sum(avg_rgb_psnr_border_l) / n_clips))
Esempio n. 11
0
def main():
    """Run EDVR inference on frame folders and log PSNR against ground truth.

    Command-line arguments (all required):
        --input_path: folder with one sub-folder of low-quality frames per clip.
        --gt_path: folder with the ground-truth sub-folders, matched by name.
        --output_path: folder receiving the log file and the restored frames.
        --model_path: checkpoint loaded into the network with ``strict=True``.
        --gpu_id: value exported as ``CUDA_VISIBLE_DEVICES``.
        --opt: option YAML file; its ``network_G`` section configures EDVR.

    Input clips without a matching ground-truth folder are skipped. Each frame
    is padded up to a multiple of the split size, padded again by ``PAD`` on
    every side, restored, cropped back to the original area, optionally saved
    as PNG and scored. Per-clip and overall averages are logged through the
    ``'base'`` logger.
    """

    def _mean(total, count):
        """Return ``total / count``, or 0 when nothing was accumulated."""
        return total / count if count else 0

    #################
    # configurations
    #################
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--gt_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--gpu_id", type=str, required=True)
    parser.add_argument('--opt',
                        type=str,
                        required=True,
                        help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=False)

    PAD = 32  # replication padding (in input pixels) added on every side
    scale = 4  # factor used to map input-space padding/size onto the output
    split_lengthY = 180  # frames are first padded to multiples of this height
    split_lengthX = 320  # ... and of this width

    total_run_time = AverageMeter()
    # The GPU restriction has to be exported before CUDA is queried for the
    # first time, otherwise it may be ignored by the already-initialised runtime.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    print("GPU ", torch.cuda.device_count())
    device = torch.device('cuda')

    data_mode = 'sharp_bicubic'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False

    Input_folder = args.input_path
    GT_folder = args.gt_path
    Result_folder = args.output_path
    model_path = args.model_path

    # create results folder
    os.makedirs(Result_folder, exist_ok=True)

    # use N_in images to restore one HR image
    N_in = 7 if data_mode == 'Vid4' else 5

    # The network architecture is taken entirely from the option file.
    net_opt = opt['network_G']
    model = EDVR_arch.EDVR(nf=net_opt['nf'],
                           nframes=net_opt['nframes'],
                           groups=net_opt['groups'],
                           front_RBs=net_opt['front_RBs'],
                           back_RBs=net_opt['back_RBs'],
                           predeblur=net_opt['predeblur'],
                           HR_in=net_opt['HR_in'],
                           w_TSA=net_opt['w_TSA'])

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        if stage == 1:
            test_dataset_folder = Input_folder
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = GT_folder

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = os.path.join(Result_folder, data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    # The fixed border padding is identical for every frame, so build it once.
    pader = torch.nn.ReplicationPad2d([PAD, PAD, PAD, PAD])

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))

    for subfolder in subfolder_l:
        subfolder_name = osp.basename(subfolder)
        subfolder_GT = os.path.join(GT_dataset_folder, subfolder_name)

        # Only clips that have a ground-truth counterpart can be scored.
        if not os.path.exists(subfolder_GT):
            continue

        print("Evaluate Folders: ", subfolder_name)

        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)  # Num x 3 x H x W
        img_GT_l = [data_util.read_img(None, img_GT_path)
                    for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))]

        avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            # Time each frame on its own instead of accumulating since start-up.
            frame_start = time.time()

            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            intWidth_ori = imgs_in.shape[4]
            intHeight_ori = imgs_in.shape[3]

            # Padding needed to reach the next multiple of the split size
            # (0 when the frame already is a multiple).
            intPaddingRight_ = (-intWidth_ori) % split_lengthX
            intPaddingBottom_ = (-intHeight_ori) % split_lengthY

            pader0 = torch.nn.ReplicationPad2d(
                [0, intPaddingRight_, 0, intPaddingBottom_])
            print("Init pad right/bottom " + str(intPaddingRight_) + " / " +
                  str(intPaddingBottom_))

            # Work on the N x C x H x W stack: align to the split size, add the
            # fixed border, then restore the batch axis for the network.
            imgs_in = torch.squeeze(imgs_in, 0)
            imgs_in = pader(pader0(imgs_in))
            X0 = torch.unsqueeze(imgs_in, 0)

            if flip_test:
                output = util.flipx4_forward(model, X0)
            else:
                output = util.single_forward(model, X0)

            # Remove the (upscaled) padding so only the original area remains.
            output = output[0, :,
                            PAD * scale:(PAD + intHeight_ori) * scale,
                            PAD * scale:(PAD + intWidth_ori) * scale]

            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            elapsed = time.time() - frame_start
            print("*****************current image process time \t " +
                  str(elapsed) + "s ******************")
            total_run_time.update(elapsed, 1)

            # calculate PSNR
            y_all = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                y_all = data_util.bgr2ycbcr(y_all, only_y=True)

            y_all, GT = util.crop_border([y_all, GT], crop_border)
            crt_psnr = util.calculate_psnr(y_all * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(
                img_idx + 1, img_name, crt_psnr))

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        # Short or empty clips may have no center frames (or no frames) at all.
        N_total = N_center + N_border
        avg_psnr = _mean(avg_psnr_center + avg_psnr_border, N_total)
        avg_psnr_center = _mean(avg_psnr_center, N_center)
        avg_psnr_border = _mean(avg_psnr_border, N_border)
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, N_total,
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    # Count only the clips that were actually scored; skipped folders have no entry.
    n_clips = len(avg_psnr_l)
    if n_clips == 0:
        logger.warning('No clip with a matching ground-truth folder was found; '
                       'no PSNR to report.')
        return

    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / n_clips, n_clips,
                    sum(avg_psnr_center_l) / n_clips,
                    sum(avg_psnr_border_l) / n_clips))
Esempio n. 12
0
def main():
    """Restore every frame of every clip in the test folder with EDVR.

    All settings are hard-coded: the checkpoint path, the input folder, the
    GPU selection and the output folder ``../results/<data_mode>``. No ground
    truth is read, so nothing is scored; each restored frame is written as a
    PNG (when ``save_imgs`` is set) and its index and name are logged through
    the ``'base'`` logger.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False

    #### model
    data_mode = 'sharp'
    if stage == 1:
        model_path = '../experiments/001_EDVRwoTSA_scratch_lr4e-4_600k_SR4K_LrCAR4S_64_20_5/models/600000_G.pth'
    else:
        model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'

    N_in = 5  # use N_in images to restore one HR image

    predeblur = False
    HR_in = stage == 2  # the second stage is built with HR_in enabled
    back_RBs = 20
    model = EDVR_arch.EDVR(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in, w_TSA=True)

    #### dataset
    if stage == 1:
        test_dataset_folder = '/home/mcc/4khdr/image/540p_test'
    else:
        test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
        print('You should modify the test_dataset_folder path for stage 2')

    #### inference settings
    # temporal padding mode
    padding = 'new_info' if data_mode == 'sharp' else 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    # Pick the forward routine once; it does not change between frames.
    forward = util.flipx4_forward if flip_test else util.single_forward

    # for each subfolder
    for subfolder in sorted(glob.glob(osp.join(test_dataset_folder, '*'))):
        save_subfolder = osp.join(save_folder, osp.basename(subfolder))

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ images
        imgs_LQ = data_util.read_img_seq(subfolder)

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            output = util.tensor2img(forward(model, imgs_in).squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output,
                            [int(cv2.IMWRITE_PNG_COMPRESSION), 1])

            logger.info('{:3d} - {:25}'.format(img_idx + 1, img_name))
def main():
    """Evaluate a pretrained EDVR checkpoint on a test set and log PSNR/SSIM.

    Command-line options choose the training set of the checkpoint
    (``--train_mode``), the test set (``--data_mode``) and the degradation
    that names the low-resolution input folder (``--degradation_mode``,
    ``--sigma_x``, ``--sigma_y``, ``--theta``).  Every clip folder is restored
    frame by frame, compared with its ground truth, and per-clip as well as
    overall averages are written to the ``base`` logger.

    SSIM is currently a fixed placeholder (0.001) because the real SSIM call
    is disabled; metrics are computed on the images exactly as returned by the
    helpers (the Y-channel conversion is disabled).

    Raises:
        NotImplementedError: if no checkpoint is configured for the selected
            train/data mode at the hard-coded scale.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    flip_test = False
    scale = 4
    N_in = 5
    predeblur, HR_in = False, False
    n_feats = 128
    back_RBs = 40

    save_imgs = False
    prog = argparse.ArgumentParser()
    prog.add_argument('--train_mode',
                      '-t',
                      type=str,
                      default='REDS',
                      help='train mode')
    prog.add_argument('--data_mode',
                      '-m',
                      type=str,
                      default=None,
                      help='data_mode')
    prog.add_argument('--degradation_mode',
                      '-d',
                      type=str,
                      default='impulse',
                      choices=('impulse', 'bicubic', 'preset'),
                      help='degradation mode of the LR test inputs')
    prog.add_argument('--sigma_x',
                      '-sx',
                      type=float,
                      default=1,
                      help='sigma_x')
    prog.add_argument('--sigma_y',
                      '-sy',
                      type=float,
                      default=0,
                      help='sigma_y')
    prog.add_argument('--theta', '-th', type=float, default=0, help='theta')

    args = prog.parse_args()

    train_data_mode = args.train_mode
    data_mode = args.data_mode
    if data_mode is None:
        # Default test set follows the training set of the checkpoint.
        if train_data_mode == 'Vimeo':
            data_mode = 'Vid4'
        elif train_data_mode == 'REDS':
            data_mode = 'REDS'
    degradation_mode = args.degradation_mode  # impulse | bicubic | preset
    sig_x, sig_y, the = args.sigma_x, args.sigma_y, args.theta
    if sig_y == 0:
        # sigma_y == 0 means "same as sigma_x".
        sig_y = sig_x

    ############################################################################
    #### model
    if scale == 2:
        if train_data_mode == 'Vimeo':
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_M_Scale2_FT.pth'
            #model_path = '../experiments/pretrained_models/EDVR_M_BLIND_V_FT_report.pth'
            # model_path = '../experiments/pretrained_models/2500_G.pth'
        elif train_data_mode == 'REDS':
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_M_Scale2.pth'
            # model_path = '../experiments/pretrained_models/EDVR_M_BLIND_R_FT_report.pth'
        elif train_data_mode == 'Both':
            model_path = '../experiments/pretrained_models/EDVR_REDS+Vimeo90K_SR_M_Scale2_FT.pth'
        elif train_data_mode == 'MM522':
            model_path = '../experiments/pretrained_models/EDVR_MM522_SR_M_Scale2_FT.pth'

        else:
            raise NotImplementedError
    else:
        if data_mode == 'Vid4':
            model_path = '../experiments/pretrained_models/EDVR_BLIND_Vimeo_SR_L.pth'
            # model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'

        elif data_mode == 'REDS':
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
            # model_path = '../experiments/pretrained_models/EDVR_BLIND_REDS_SR_L.pth'
        else:
            raise NotImplementedError

    model = EDVR_arch.EDVR(n_feats,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in,
                           scale=scale)

    # Name of the LR folder: 'preset', or '<mode>_<sx>_<sy>_<theta>' with one
    # decimal each.
    if degradation_mode == 'preset':
        folder_subname = 'preset'
    else:
        folder_subname = '{}_{:.1f}_{:.1f}_{:.1f}'.format(
            degradation_mode, sig_x, sig_y, the)

    #### dataset
    if data_mode == 'Vid4':
        # test_dataset_folder = '../dataset/Vid4/LR_bicubic/X{}'.format(scale)
        test_dataset_folder = '../dataset/Vid4/LR_{}/X{}'.format(
            folder_subname, scale)
        GT_dataset_folder = '../dataset/Vid4/HR'
    elif data_mode == 'MM522':
        test_dataset_folder = '../dataset/MM522val/LR_bicubic/X{}'.format(
            scale)
        GT_dataset_folder = '../dataset/MM522val/HR'
    else:
        # test_dataset_folder = '../dataset/REDS4/LR_bicubic/X{}'.format(scale)
        test_dataset_folder = '../dataset/REDS/train/LR_{}/X{}'.format(
            folder_subname, scale)
        GT_dataset_folder = '../dataset/REDS/train/HR'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    padding = 'new_info'

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    def log_config():
        # The run configuration is reported both before and after evaluation.
        logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))

    def mean(values):
        return sum(values) / len(values)

    #### log info
    log_config()

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    if data_mode == 'REDS':
        # Keep only GT clips whose path contains one of these tags.
        # NOTE(review): the LQ folder list is not filtered the same way, so
        # the pairing below assumes the LQ folder already holds only these
        # clips — confirm against the dataset layout.
        clip_tags = ('000', '011', '015', '020')
        subfolder_GT_l = [
            k for k in subfolder_GT_l if any(tag in k for tag in clip_tags)
        ]

    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if max_idx == 0:
            # Nothing to evaluate; skipping avoids a division by zero below.
            logger.warning('Folder {} - no frames found, skipped.'.format(
                subfolder_name))
            continue
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = [
            data_util.read_img(None, img_GT_path)
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))
        ]

        # Running sums, split into center and border frames.
        psnr_center_sum, psnr_border_sum = 0, 0
        ssim_center_sum, ssim_border_sum = 0, 0
        N_border, N_center = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR on 8-bit images
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            output = (output * 255).round().astype('uint8')
            GT = (GT * 255).round().astype('uint8')
            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output, GT)
            crt_ssim = 0.001  #util.calculate_ssim(output, GT)

            # logger.info('{:3d} - {:16} \tPSNR: {:.6f} dB \tSSIM: {:.6f}'.format(img_idx + 1, img_name, crt_psnr, crt_ssim))

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                psnr_center_sum += crt_psnr
                ssim_center_sum += crt_ssim
                N_center += 1
            else:  # border frames
                psnr_border_sum += crt_psnr
                ssim_border_sum += crt_ssim
                N_border += 1

        n_frames = N_center + N_border  # == max_idx, hence > 0 here
        avg_psnr = (psnr_center_sum + psnr_border_sum) / n_frames
        # Short clips may have no center (or no border) frames at all.
        avg_psnr_center = 0 if N_center == 0 else psnr_center_sum / N_center
        avg_psnr_border = 0 if N_border == 0 else psnr_border_sum / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        avg_ssim = (ssim_center_sum + ssim_border_sum) / n_frames
        avg_ssim_center = 0 if N_center == 0 else ssim_center_sum / N_center
        avg_ssim_border = 0 if N_border == 0 else ssim_border_sum / N_border
        avg_ssim_l.append(avg_ssim)
        avg_ssim_center_l.append(avg_ssim_center)
        avg_ssim_border_l.append(avg_ssim_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, n_frames,
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

        logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; '
                    'Center SSIM: {:.6f} for {} frames; '
                    'Border SSIM: {:.6f} for {} frames.'.format(
                        subfolder_name, avg_ssim, n_frames,
                        avg_ssim_center, N_center, avg_ssim_border, N_border))

    logger.info('################ Final Results ################')
    log_config()

    # Report the number of clips that actually contributed to the averages.
    n_clips = len(avg_psnr_l)
    if n_clips == 0:
        logger.warning('No clips were evaluated; check the dataset folders.')
    else:
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        mean(avg_psnr_l), n_clips,
                        mean(avg_psnr_center_l),
                        mean(avg_psnr_border_l)))
        logger.info('Total Average SSIM: {:.6f} for {} clips. '
                    'Center SSIM: {:.6f}. Border SSIM: {:.6f}.'.format(
                        mean(avg_ssim_l), n_clips,
                        mean(avg_ssim_center_l),
                        mean(avg_ssim_border_l)))

    print('\n\n\n')
Esempio n. 14
0
def main(name_flag, input_path, gt_path, model_path, save_path, save_imgs,
         flip_test):
    """Run a pretrained CNLRN model over a folder of images and log PSNR/SSIM.

    Args:
        name_flag: run label; appended to ``save_path`` and used as the log
            file name.
        input_path: folder holding the low-quality input images.
        gt_path: folder holding the ground-truth images; after sorting, the
            i-th file is compared with the i-th input.
        model_path: checkpoint loaded into the network (strict key matching).
        save_path: parent folder for results.
        save_imgs: when true, write every restored image as
            ``<save_path>/<name_flag>/<image name>.png``.
        flip_test: when true, use ``util.flipx8_forward`` instead of a single
            forward pass.
    """
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    save_path = os.path.join(save_path, name_flag)

    #### model
    model = CNLRN_arch.CNLRN(n_colors=3,
                             n_deblur_blocks=20,
                             n_nlrgs_body=6,
                             n_nlrgs_up1=2,
                             n_nlrgs_up2=2,
                             n_subgroups=2,
                             n_rcabs=4,
                             n_feats=64,
                             nonlocal_psize=(4, 4, 4, 4),
                             scale=4)
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    #### logger
    util.mkdirs(save_path)
    util.setup_logger('base',
                      save_path,
                      name_flag,
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    def log_run_info():
        # The run configuration is reported both before and after evaluation.
        logger.info('Evaluate: {}'.format(name_flag))
        logger.info('Input images path: {}'.format(input_path))
        logger.info('GT images path: {}'.format(gt_path))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Results save path: {}'.format(save_path))
        logger.info('Flip test: {}'.format(flip_test))
        logger.info('Save images: {}'.format(save_imgs))

    log_run_info()

    #### Evaluation
    total_psnr_l = []
    total_ssim_l = []

    img_path_l = sorted(glob.glob(osp.join(input_path, '*')))

    #### read LQ and GT images
    imgs_LQ = data_util.read_img_seq(input_path)
    img_GT_l = [
        data_util.read_img(None, img_GT_path)
        for img_GT_path in sorted(glob.glob(osp.join(gt_path, '*')))
    ]

    # process each image
    for img_idx, img_path in enumerate(img_path_l):
        img_name = osp.splitext(osp.basename(img_path))[0]
        # Slice (rather than index) so the leading dimension is kept.
        imgs_in = imgs_LQ[img_idx:img_idx + 1].to(device)

        if flip_test:
            output = util.flipx8_forward(model, imgs_in)
        else:
            output = util.single_forward(model, imgs_in)

        output = util.tensor2img(output.squeeze(0))

        if save_imgs:
            cv2.imwrite(osp.join(save_path, '{}.png'.format(img_name)), output)

        # calculate PSNR / SSIM on a 4-pixel-cropped image pair
        output = output / 255.
        GT = np.copy(img_GT_l[img_idx])

        output, GT = util.crop_border([output, GT], crop_border=4)
        crt_psnr = util.calculate_psnr(output * 255, GT * 255)
        crt_ssim = util.ssim(output * 255, GT * 255)
        total_psnr_l.append(crt_psnr)
        total_ssim_l.append(crt_ssim)

        logger.info('{} \tPSNR: {:.3f} \tSSIM: {:.4f}'.format(
            img_name, crt_psnr, crt_ssim))

    logger.info('################ Final Results ################')
    log_run_info()

    if not total_psnr_l:
        # An empty input folder would otherwise end in a ZeroDivisionError.
        logger.warning(
            'No images found in {}; nothing to average.'.format(input_path))
        return

    logger.info(
        'Total Average PSNR: {:.3f} SSIM: {:.4f} for {} images.'.format(
            sum(total_psnr_l) / len(total_psnr_l),
            sum(total_ssim_l) / len(total_ssim_l), len(img_path_l)))
Esempio n. 15
0
def main(gpu_id, start_id, step):
    """Restore a strided subset of the AI4K test sequences with EDVR.

    Sequence folders under the test dataset folder are sorted and every
    ``step``-th one starting at ``start_id`` is processed (presumably so that
    several processes can share the work — confirm with the caller).  Each
    frame is restored from ``N_in`` neighbouring frames chosen with the
    per-sequence scene list and written as PNG under the results folder.

    Args:
        gpu_id: value for ``CUDA_VISIBLE_DEVICES``; converted with ``str`` so
            an integer id works as well.
        start_id: index of the first sequence folder to process.
        step: stride between processed sequence folders.

    Raises:
        NotImplementedError: if ``data_mode`` is not ``'AI4K'``.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    # os.environ only accepts strings.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
    data_mode = 'AI4K'

    stage = 1  # 1 or 2
    flip_test = True

    #### model
    if data_mode == 'AI4K':
        if stage == 1:
            model_path = '/home/zenghui/projects/4KHDR/experiments/pretrained_models/EDVR_L_G300k.pth'  # TODO: change path
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'  # TODO: change path
    else:
        raise NotImplementedError()

    N_in = 5  # use N_in images to restore one HR image
    predeblur, HR_in = False, False
    back_RBs = 40
    if stage == 2:
        HR_in = True
        back_RBs = 20
    model = EDVR_arch.EDVR(64,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### dataset
    if data_mode == 'AI4K':
        test_dataset_folder = '../datasets/SDR_540p_PNG_test'  # TODO: change path
    else:
        raise NotImplementedError()

    #### scene information
    scene_index_path = '../keys/test_scene_idx.pkl'  # TODO: change path
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open(scene_index_path, 'rb') as scene_file:
        scene_dict = pickle.load(scene_file)

    #### evaluation
    padding = 'replicate'  # temporal padding mode
    save_imgs = True
    save_folder = '../results_edvr_l_tsa_300k_2/{}'.format(
        data_mode)  # TODO: change path
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    # seq_id tracks the position of the current folder in the full sorted list.
    seq_id = start_id
    for subfolder in subfolder_l[start_id::step]:
        subfolder_name = osp.basename(subfolder)
        save_subfolder = osp.join(save_folder, subfolder_name)

        # `crop_edge` is not defined in this function; it is expected to be a
        # module-level table indexed by sequence id.  It is only logged here
        # (zeroing the output edges is disabled).
        logger.info(
            'Processing sequence: {}, seq_id = {}, crop_edge = {}'.format(
                subfolder_name, seq_id, crop_edge[seq_id]))

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))

        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ images
        imgs_LQ = data_util.read_img_seq(subfolder)

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation_with_scene_list(
                img_idx,
                max_idx,
                N_in,
                scene_dict[subfolder_name],
                padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

        seq_id += step
Esempio n. 16
0
def main():
    """Dump EDVR aligned-feature maps for the configured test set.

    All configuration is hard-coded at the top of the function (``data_mode``,
    ``stage``, ``flip_test``).  For each clip folder the first frames (indices
    0 to 50) are pushed through the network and every slice of the returned
    aligned features is written as ``<frame name>_<i>.png`` in the clip's
    results folder.  With ``flip_test`` enabled only the flip-ensemble output
    is computed and no features are written.

    The PSNR evaluation is currently commented out, so the PSNR lists stay
    empty and the final summary reports only the run configuration.

    Raises:
        ValueError: for ``data_mode == 'Vid4'`` with stage 2, or when
            ``stage`` is neither 1 nor 2.
        NotImplementedError: for an unknown ``data_mode``.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '7'
    data_mode = 'KWAI'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 2  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    elif data_mode == 'sharp_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
    elif data_mode == 'blur_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
    elif data_mode == 'blur':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
    elif data_mode == 'blur_comp':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    elif data_mode == 'KWAI':
        if stage == 1:
            model_path = '/home/web_server/zhouhuanxiang/disk/log/experiments/EDVR_KWAI40_M/models/latest_G.pth'
        else:
            model_path = '/home/web_server/zhouhuanxiang/disk/log/experiments/EDVR_KWAI40_M_2/models/latest_G.pth'
    else:
        raise NotImplementedError

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    # NOTE: predeblur / HR_in / back_RBs are derived here but the constructor
    # call below uses fixed values instead; they only matter for the
    # commented-out alternative constructor.
    predeblur, HR_in = False, False
    back_RBs = 40
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp' or data_mode == 'KWAI':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20
    # import models.archs.EDVR_arch as EDVR_arch
    # model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
    # import models.archs.EDVR_woDCN_arch as EDVR_arch
    import models.archs.EDVR_arch as EDVR_arch
    if stage not in (1, 2):
        # Previously `model` was silently left undefined for other stages.
        raise ValueError('stage must be 1 or 2, got {}'.format(stage))
    # The two stages differ only in whether the TSA module is enabled.
    model = EDVR_arch.EDVR(64,
                           N_in,
                           8,
                           5,
                           10,
                           predeblur=False,
                           HR_in=True,
                           w_TSA=(stage == 2))

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    elif data_mode == 'KWAI':
        test_dataset_folder = '/home/web_server/zhouhuanxiang/disk/data/HD_UGC_crf40_raw_test'
        GT_dataset_folder = '/home/web_server/zhouhuanxiang/disk/data/HD_UGC_raw_test'
        # test_dataset_folder = '/home/web_server/zhouhuanxiang/disk/data/TEMP_crf40_raw'
        # GT_dataset_folder = '/home/web_server/zhouhuanxiang/disk/data/TEMP_crf40_raw'
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '/home/web_server/zhouhuanxiang/disk/log/results/EDVR_{}_{}'.format(
        data_mode, stage)
    # save_folder = '/home/web_server/zhouhuanxiang/disk/log/results/TEMP'
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        # NOTE(review): feature PNGs below are written regardless of
        # save_imgs, but the folder is only created when it is true.
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
                # The flip ensemble does not return aligned features.
                aligned_fea = None
            else:
                # A single call yields both the output and the features; the
                # former extra util.single_forward pass was overwritten
                # immediately and has been dropped.
                output, aligned_fea = util.single_forward_all_results(
                    model, imgs_in)

            # save feature
            if aligned_fea is not None:
                aligned_fea = aligned_fea.squeeze(0)
                for i in range(aligned_fea.shape[0]):
                    fea = util.tensor2img(aligned_fea[i])
                    fea_path = osp.join(save_subfolder,
                                        '{}_{}.png'.format(img_name, i))
                    print(fea_path)
                    cv2.imwrite(fea_path, fea)

            # Debug limit: stop after the frame with index 50.
            if img_idx == 50:
                break

            # output = util.tensor2img(output.squeeze(0))

            # if save_imgs:
            #     cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR
            # output = output / 255.
            # GT = np.copy(img_GT_l[img_idx])
            # # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            # if data_mode == 'Vid4':  # bgr2y, [0, 1]
            #     GT = data_util.bgr2ycbcr(GT, only_y=True)
            #     output = data_util.bgr2ycbcr(output, only_y=True)

            # output, GT = util.crop_border([output, GT], crop_border)
            # crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            # logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            # if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
            #     avg_psnr_center += crt_psnr
            #     N_center += 1
            # else:  # border frames
            #     avg_psnr_border += crt_psnr
            #     N_border += 1

        # avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        # avg_psnr_center = avg_psnr_center / N_center
        # avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        # avg_psnr_l.append(avg_psnr)
        # avg_psnr_center_l.append(avg_psnr_center)
        # avg_psnr_border_l.append(avg_psnr_border)

        # logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
        #             'Center PSNR: {:.6f} dB for {} frames; '
        #             'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
        #                                                            (N_center + N_border),
        #                                                            avg_psnr_center, N_center,
        #                                                            avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    if avg_psnr_l:
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                        sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                        sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
    else:
        # With PSNR evaluation disabled the lists are empty; dividing by
        # their length would raise ZeroDivisionError.
        logger.info('PSNR evaluation is disabled; no averages to report.')
Esempio n. 17
0
def main():
    """Super-resolve every clip under ``<input>/BIx4`` with a pretrained EDVR model.

    The test folder comes from the ``-i/--input`` command-line option. Every
    sub-folder of ``<input>/BIx4`` is processed as one clip: each frame is
    restored from ``N_in`` frames selected by ``data_util.index_generation``
    and, when ``save_imgs`` is set, written as a PNG to
    ``../results/<input folder name>/<clip name>/``.

    Raises:
        SystemExit: with status 1 when no input folder was given.
        ValueError: if ``stage`` is not 1 (Vid4 has no second stage).
        NotImplementedError: if ``data_mode`` is not ``'Vid4'``.
    """
    # Create object for parsing command-line options
    parser = argparse.ArgumentParser(
        description="Test with EDVR, require path to test dataset folder.")
    parser.add_argument("-i", "--input", type=str, help="Path to test folder")
    args = parser.parse_args()
    if not args.input:
        print("No input paramater have been given.")
        print("For help type --help")
        # Report the usage error with a failure status; the `exit()` helper is
        # only injected by the `site` module and would exit with status 0.
        raise SystemExit(1)

    # normpath drops trailing/duplicate separators, so inputs such as
    # 'data/clip/' or 'data/clip//' still resolve to 'clip'.
    folder_name = osp.basename(osp.normpath(args.input))

    #################
    # configurations
    #################
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    device = torch.device('cuda')
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    save_imgs = True

    #### model
    if data_mode != 'Vid4':
        raise NotImplementedError
    if stage != 1:
        raise ValueError('Vid4 does not support stage 2.')
    model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'

    # use N_in images to restore one HR image
    N_in = 7 if data_mode == 'Vid4' else 5

    predeblur, HR_in = False, False
    back_RBs = 40

    model = EDVR_arch.EDVR(128,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### dataset
    test_dataset_folder = os.path.join(args.input, 'BIx4')

    #### evaluation
    # temporal padding mode
    padding = ('new_info'
               if data_mode in ('Vid4', 'sharp_bicubic') else 'replicate')

    save_folder = '../results/{}'.format(folder_name)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    # for each subfolder (one clip per subfolder)
    for subfolder in sorted(glob.glob(osp.join(test_dataset_folder, '*'))):
        save_subfolder = osp.join(save_folder, osp.basename(subfolder))

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ
        imgs_LQ = data_util.read_img_seq(subfolder)

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            forward = util.flipx4_forward if flip_test else util.single_forward
            output = util.tensor2img(forward(model, imgs_in).squeeze(0))

            if save_imgs:
                out_path = osp.join(save_subfolder,
                                    '{}.png'.format(img_name))
                # cv2.imwrite signals failure through its return value only.
                if not cv2.imwrite(out_path, output):
                    logger.warning(
                        'Failed to write image: {}'.format(out_path))
Esempio n. 18
0
def main():
    """Colour-correct and then super-resolve every clip of the AI4K test set.

    Each sub-folder of ``test_dataset_folder`` is one clip. For every frame,
    ``N_in`` frames chosen by ``data_util.index_generation`` are first passed
    through the colour model and the result is fed to the EDVR model; the
    output is written as a PNG to ``../results/<test_name>/<clip name>/``
    when ``save_imgs`` is set.

    Raises:
        NotImplementedError: if ``test_set`` has no dataset folder configured.
    """
    #################
    # configurations
    #################
    os.environ['CUDA_VISIBLE_DEVICES'] = '5'
    device = torch.device('cuda')

    test_set = 'AI4K_test'  # Vid4 | YouKu10 | REDS4 | AI4K_test
    data_mode = 'sharp_bicubic'  # sharp_bicubic | blur_bicubic
    test_name = 'Contest2_Test18_A38_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd_165000'
    N_in = 5
    flip_test = False
    save_imgs = True

    # load test set
    if test_set == 'AI4K_test':
        test_dataset_folder = '/home/yhliu/AI4K/contest2/testA_LR_png/'
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise NotImplementedError(
            'No dataset folder configured for test_set: {}'.format(test_set))

    model_path = '../experiments/A38_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd/models/165000_G.pth'
    color_model_path = '/home/yhliu/BasicSR/experiments/35_ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW_re_100k/models/220000_G.pth'

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode in ('blur', 'blur_comp'):
        predeblur, HR_in = True, True

    model = EDVR_arch.EDVR(64,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)
    color_model = SRResNet_arch.ResNet_alpha_beta_multi_in(
        structure='ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW')

    #### evaluation
    # temporal padding mode
    padding = ('new_info'
               if data_mode in ('Vid4', 'sharp_bicubic') else 'replicate')

    save_folder = '../results/{}'.format(test_name)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the EDVR model
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    model = nn.DataParallel(model)

    #### set up the colour model
    # The checkpoint keys are loaded under an added 'color_net.' prefix.
    color_state = OrderedDict(
        ('color_net.' + key, value)
        for key, value in torch.load(color_model_path).items())
    color_model.load_state_dict(color_state, strict=True)
    color_model.eval()
    color_model = color_model.to(device)
    color_model = nn.DataParallel(color_model)

    # for each subfolder (one clip per subfolder)
    for subfolder in sorted(glob.glob(osp.join(test_dataset_folder, '*'))):
        save_subfolder = osp.join(save_folder, osp.basename(subfolder))

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ images
        imgs_LQ = data_util.read_img_seq(subfolder)

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).cpu()
            print(imgs_in.size())

            if flip_test:
                imgs_in = util.single_forward(color_model, imgs_in)
                output = util.flipx4_forward(model, imgs_in)
            else:
                # time the colour + EDVR forward passes for one frame
                start_time = time.time()
                imgs_in = util.single_forward(color_model, imgs_in)
                output = util.single_forward(model, imgs_in)
                end_time = time.time()
                print('Forward One image:', end_time - start_time)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                out_path = osp.join(save_subfolder,
                                    '{}.png'.format(img_name))
                # cv2.imwrite signals failure through its return value only.
                if not cv2.imwrite(out_path, output):
                    logger.warning(
                        'Failed to write image: {}'.format(out_path))

            logger.info('{:3d} - {:25}'.format(img_idx + 1, img_name))

    logger.info('################ Tidy Outputs ################')

    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
Esempio n. 19
0
def _mean_or_zero(total, count):
    """Return ``total / count``, or 0 when ``count`` is 0."""
    return 0 if count == 0 else total / count


def _read_lq_gt_pair(test_im_path, gt_im_path):
    """Read one low-quality frame and its ground-truth frame.

    The LQ frame has its channel order reversed, is moved to channel-first
    layout and divided by 255; the GT frame is only converted to float32.
    """
    test_im = cv2.imread(test_im_path, cv2.IMREAD_UNCHANGED)[:, :, (2, 1, 0)]
    test_im = test_im.astype(np.float32).transpose((2, 0, 1)) / 255.
    gt_im = cv2.imread(gt_im_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
    return test_im, gt_im


def _load_clip(test_sub_dir, gt_sub_dir, names):
    """Load all frames named in ``names`` as parallel LQ and GT lists."""
    img_LQs, img_GTs = [], []
    for name in names:
        test_im, gt_im = _read_lq_gt_pair(osp.join(test_sub_dir, name),
                                          osp.join(gt_sub_dir, name))
        img_LQs.append(test_im)
        img_GTs.append(gt_im)
    return img_LQs, img_GTs


def main(opts):
    """Evaluate one or more EDVR-family checkpoints (PSNR/SSIM) on a test set.

    Reads from ``opts``: ``gpus`` (comma-separated ids), ``cache`` (> 0 to
    preload all images), ``model_dir`` (a checkpoint file, or a directory
    whose entries starting with ``<digits>_`` are evaluated in numeric
    order), ``test_dir`` / ``gt_dir`` (one sub-folder per clip), ``mode``
    (model variant) and ``save_dir`` (per-checkpoint log/output folder).

    Raises:
        IOError: if ``opts.model_dir`` is neither a file nor a directory.
        AssertionError: if a test sub-folder has no ground-truth counterpart.
        TypeError: if ``opts.mode`` names an unknown model variant.
    """
    ################## configurations #################
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpus
    device = torch.device('cuda')
    cache_all_imgs = opts.cache > 0
    n_gpus = len(opts.gpus.split(','))
    flip_test, save_imgs = False, False
    N_in, nf = 5, 64
    back_RBs = 10
    w_TSA = False
    predeblur, HR_in = False, False
    crop_border = 0
    border_frame = N_in // 2
    padding = 'new_info'
    mode = opts.mode.lower()

    ################## model files ####################
    model_dir = opts.model_dir
    if osp.isfile(model_dir):
        model_names = [osp.basename(model_dir)]
        model_dir = osp.dirname(model_dir)
    elif osp.isdir(model_dir):
        model_names = sorted(
            (x for x in os.listdir(model_dir)
             if str.isdigit(x.split('_')[0])),
            key=lambda x: int(x.split("_")[0]))
    else:
        raise IOError('Invalid model_dir: {}'.format(model_dir))

    ################## dataset ########################
    test_subs = sorted(os.listdir(opts.test_dir))
    gt_subs = set(os.listdir(opts.gt_dir))
    assert all(sub in gt_subs for sub in test_subs), \
        'Invalid sub folders exists in {}'.format(opts.test_dir)
    # The scale is parsed from the name of test_dir's parent folder, minus its
    # first character (presumably a name like 'x4' -- TODO confirm).
    scale = float(os.path.basename(os.path.dirname(opts.test_dir))[1:])

    all_imgs = {}
    if cache_all_imgs:
        print('Cacheing all testing images ...')
        for sub in test_subs:
            print('Reading sub-folder: {} ...'.format(sub))
            test_sub_dir = osp.join(opts.test_dir, sub)
            img_LQs, img_GTs = _load_clip(test_sub_dir,
                                          osp.join(opts.gt_dir, sub),
                                          sorted(os.listdir(test_sub_dir)))
            all_imgs[sub] = {'test': img_LQs, 'gt': img_GTs}

    for model_name in model_names:
        model_path = osp.join(model_dir, model_name)
        exp_name = model_name.split('_')[0]

        shared_kwargs = dict(nf=nf,
                             nframes=N_in,
                             groups=8,
                             front_RBs=5,
                             center=None,
                             w_TSA=w_TSA)
        up_kwargs = dict(shared_kwargs,
                         back_RBs=10,
                         down_scale=True,
                         align_target=True,
                         ret_valid=True)
        if 'meta' in mode:
            model = EDVR_arch.MetaEDVR(back_RBs=back_RBs,
                                       predeblur=predeblur,
                                       HR_in=HR_in,
                                       **shared_kwargs)
        elif mode == 'edvr':
            model = EDVR_arch.EDVR(back_RBs=back_RBs,
                                   predeblur=predeblur,
                                   HR_in=HR_in,
                                   **shared_kwargs)
        elif mode == 'upedvr':
            model = EDVR_arch.UPEDVR(**up_kwargs)
        elif mode == 'upcont1':
            model = EDVR_arch.UPControlEDVR(multi_scale_cont=False,
                                            **up_kwargs)
        elif mode in ('upcont2', 'upcont3'):
            # NOTE(review): both variants are built with identical arguments;
            # confirm that 'upcont2' is not meant to differ.
            model = EDVR_arch.UPControlEDVR(multi_scale_cont=True,
                                            **up_kwargs)
        else:
            raise TypeError('Unknown model mode: {}'.format(opts.mode))

        save_folder = osp.join(opts.save_dir, exp_name)
        util.mkdirs(save_folder)
        util.setup_logger(exp_name,
                          save_folder,
                          'test',
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger(exp_name)

        #### log info
        logger.info('Data: {}'.format(opts.test_dir))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))

        #### set up the models
        model.load_state_dict(torch.load(model_path), strict=True)
        model.eval()
        if n_gpus > 1:
            model = nn.DataParallel(model)
        model = model.to(device)

        avg_psnrs, avg_psnr_centers, avg_psnr_borders = [], [], []
        avg_ssims, avg_ssim_centers, avg_ssim_borders = [], [], []
        evaled_subs = []

        # for each subfolder
        for sub in test_subs:
            evaled_subs.append(sub)
            test_sub_dir = osp.join(opts.test_dir, sub)
            gt_sub_dir = osp.join(opts.gt_dir, sub)
            img_names = sorted(os.listdir(test_sub_dir))
            max_idx = len(img_names)
            if save_imgs:
                save_subfolder = osp.join(save_folder, sub)
                util.mkdirs(save_subfolder)

            #### get LQ and GT images
            if cache_all_imgs:
                img_LQs = all_imgs[sub]['test']
                img_GTs = all_imgs[sub]['gt']
            else:
                img_LQs, img_GTs = _load_clip(test_sub_dir, gt_sub_dir,
                                              img_names)

            psnr_center_sum, psnr_border_sum = 0, 0
            ssim_center_sum, ssim_border_sum = 0, 0
            N_center, N_border = 0, 0

            # process the frames in batches of n_gpus
            for i in range(0, max_idx, n_gpus):
                end = min(i + n_gpus, max_idx)
                imgs = []
                for j in range(i, end):
                    select_idx = data_util.index_generation(j,
                                                            max_idx,
                                                            N_in,
                                                            padding=padding)
                    im = torch.from_numpy(
                        np.stack([img_LQs[k] for k in select_idx]))
                    imgs.append(im)
                # Pad the last batch with zero inputs up to n_gpus entries;
                # the outputs for these entries are never read below.
                imgs.extend(
                    torch.zeros_like(im) for _ in range(max_idx, i + n_gpus))
                imgs = torch.stack(imgs, 0).to(device)

                # Exactly one forward routine runs per batch. A plain `if`
                # chain here previously let single_forward overwrite the
                # meta model's output.
                if flip_test:
                    output = util.flipx4_forward(model, imgs)
                elif 'meta' in mode:
                    output = util.meta_single_forward(model, imgs, scale,
                                                      n_gpus)
                elif 'up' in mode:
                    output = util.up_single_forward(model, imgs, scale)
                else:
                    output = util.single_forward(model, imgs)
                output = [
                    util.tensor2img(x).astype(np.float32) for x in output
                ]

                if save_imgs:
                    for ii in range(i, end):
                        # Drop the source extension so that 'a.png' is saved
                        # as 'a.png' rather than 'a.png.png'.
                        out_path = osp.join(
                            save_subfolder, '{}.png'.format(
                                osp.splitext(img_names[ii])[0]))
                        if not cv2.imwrite(out_path,
                                           output[ii - i].astype(np.uint8)):
                            logger.warning(
                                'Failed to write image: {}'.format(out_path))

                # calculate PSNR / SSIM
                GT = np.copy(img_GTs[i:end])

                output = util.crop_border(output, crop_border)
                GT = util.crop_border(GT, crop_border)
                for m in range(i, end):
                    crt_psnr = util.calculate_psnr(output[m - i], GT[m - i])
                    crt_ssim = util.calculate_ssim(output[m - i], GT[m - i])
                    logger.info(
                        '{:3d} - {:25} \tPSNR: {:.6f} dB    SSIM: {:.6}'.
                        format(m + 1, img_names[m], crt_psnr, crt_ssim))

                    if border_frame <= m < max_idx - border_frame:
                        # center frames
                        psnr_center_sum += crt_psnr
                        ssim_center_sum += crt_ssim
                        N_center += 1
                    else:  # border frames
                        psnr_border_sum += crt_psnr
                        ssim_border_sum += crt_ssim
                        N_border += 1

            # Clips shorter than 2 * border_frame + 1 frames have no center
            # frames, so every average is guarded against a zero count.
            n_frames = N_center + N_border
            avg_psnr = _mean_or_zero(psnr_center_sum + psnr_border_sum,
                                     n_frames)
            avg_psnr_center = _mean_or_zero(psnr_center_sum, N_center)
            avg_psnr_border = _mean_or_zero(psnr_border_sum, N_border)

            avg_ssim = _mean_or_zero(ssim_center_sum + ssim_border_sum,
                                     n_frames)
            avg_ssim_center = _mean_or_zero(ssim_center_sum, N_center)
            avg_ssim_border = _mean_or_zero(ssim_border_sum, N_border)

            avg_psnrs.append(avg_psnr)
            avg_psnr_centers.append(avg_psnr_center)
            avg_psnr_borders.append(avg_psnr_border)

            avg_ssims.append(avg_ssim)
            avg_ssim_centers.append(avg_ssim_center)
            avg_ssim_borders.append(avg_ssim_border)

            logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                        'Center PSNR: {:.6f} dB for {} frames; '
                        'Border PSNR: {:.6f} dB for {} frames.'.format(
                            sub, avg_psnr, n_frames, avg_psnr_center,
                            N_center, avg_psnr_border, N_border))
            logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; '
                        'Center SSIM: {:.6f} for {} frames; '
                        'Border SSIM: {:.6f} for {} frames.'.format(
                            sub, avg_ssim, n_frames, avg_ssim_center,
                            N_center, avg_ssim_border, N_border))

        logger.info('################ Tidy Outputs ################')
        per_clip = zip(evaled_subs, avg_psnrs, avg_psnr_centers,
                       avg_psnr_borders, avg_ssims, avg_ssim_centers,
                       avg_ssim_borders)
        for (sub_name, psnr, psnr_center, psnr_border, ssim, ssim_center,
             ssim_border) in per_clip:
            logger.info(
                'Folder {} - Average PSNR: {:.6f} dB. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sub_name, psnr, psnr_center, psnr_border))
            logger.info('Folder {} - Average SSIM: {:.6f} '
                        'Center SSIM: {:.6f} Border SSIM: {:.6f} '.format(
                            sub_name, ssim, ssim_center, ssim_border))

        logger.info('################ Final Results ################')
        logger.info('Data: {}'.format(opts.test_dir))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))
        n_clips = len(avg_psnrs)
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        _mean_or_zero(sum(avg_psnrs), n_clips),
                        len(test_subs),
                        _mean_or_zero(sum(avg_psnr_centers), n_clips),
                        _mean_or_zero(sum(avg_psnr_borders), n_clips)))
        logger.info('Total Average SSIM: {:.6f} for {} clips. '
                    'Center SSIM: {:.6f} Border SSIM: {:.6f} '.format(
                        _mean_or_zero(sum(avg_ssims), n_clips),
                        len(test_subs),
                        _mean_or_zero(sum(avg_ssim_centers), n_clips),
                        _mean_or_zero(sum(avg_ssim_borders), n_clips)))
def main():
    """Run an EDVR checkpoint over a test set and report per-clip PSNR.

    Each sub-folder of ``test_dataset_folder`` is paired, in sorted order,
    with a sub-folder of ``GT_dataset_folder``. Every frame is restored from
    ``N_in`` frames chosen by ``data_util.index_generation``, optionally
    saved as a PNG under ``../results/<data_mode>/<clip name>/``, and scored
    against its ground-truth frame (on the Y channel for ``'Vid4'``).
    Center and border frames are averaged separately.

    Raises:
        ValueError: if stage 2 is requested for ``'Vid4'``.
        NotImplementedError: if ``data_mode`` is not a known mode.
    """

    def mean_or_zero(total, count):
        # Empty or very short clips would otherwise divide by zero.
        return 0 if count == 0 else total / count

    #################
    # configurations
    #################
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    device = torch.device('cuda')
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    # data_mode -> (stage 1 checkpoint, stage 2 checkpoint or None)
    model_paths = {
        'Vid4':
        ('../experiments/002_EDVR_lr4e-4_600k_AI4KHDR/models/4000_G.pth',
         None),
        'sharp_bicubic':
        ('../experiments/pretrained_models/EDVR_REDS_SR_L.pth',
         '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'),
        'blur_bicubic':
        ('../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth',
         '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'),
        'blur':
        ('../experiments/pretrained_models/EDVR_REDS_deblur_L.pth',
         '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'),
        'blur_comp':
        ('../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth',
         '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'),
    }
    if data_mode not in model_paths:
        raise NotImplementedError
    model_path = model_paths[data_mode][0 if stage == 1 else 1]
    if model_path is None:
        raise ValueError('Vid4 does not support stage 2.')

    N_in = 5  # use N_in images to restore one HR image

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode in ('blur', 'blur_comp'):
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20
    # Build the network from the configured values. It previously hard-coded
    # 5 and 10, which silently ignored the stage-2 setting of back_RBs.
    model = EDVR_arch.EDVR(64,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/valid/540p_frames'
        GT_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/valid/4k_frames'
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    padding = ('new_info'
               if data_mode in ('Vid4', 'sharp_bicubic') else 'replicate')
    save_imgs = True

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    # NOTE(review): strict=False lets missing or unexpected checkpoint keys
    # pass unnoticed -- confirm this is intended for this checkpoint.
    model.load_state_dict(torch.load(model_path), strict=False)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = [
            data_util.read_img(None, img_GT_path)
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))
        ]

        psnr_center_sum, psnr_border_sum, N_border, N_center = 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            forward = util.flipx4_forward if flip_test else util.single_forward
            output = util.tensor2img(forward(model, imgs_in).squeeze(0))

            if save_imgs:
                out_path = osp.join(save_subfolder,
                                    '{}.png'.format(img_name))
                # cv2.imwrite signals failure through its return value only.
                if not cv2.imwrite(out_path, output):
                    logger.warning(
                        'Failed to write image: {}'.format(out_path))

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)

            if border_frame <= img_idx < max_idx - border_frame:
                # center frames
                psnr_center_sum += crt_psnr
                N_center += 1
            else:  # border frames
                psnr_border_sum += crt_psnr
                N_border += 1

        # Clips shorter than 2 * border_frame + 1 frames have no center
        # frames, so every average is guarded against a zero count.
        n_frames = N_center + N_border
        avg_psnr = mean_or_zero(psnr_center_sum + psnr_border_sum, n_frames)
        avg_psnr_center = mean_or_zero(psnr_center_sum, N_center)
        avg_psnr_border = mean_or_zero(psnr_border_sum, N_border)
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, n_frames, avg_psnr_center,
                        N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    n_clips = len(avg_psnr_l)
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    mean_or_zero(sum(avg_psnr_l), n_clips), len(subfolder_l),
                    mean_or_zero(sum(avg_psnr_center_l), n_clips),
                    mean_or_zero(sum(avg_psnr_border_l), n_clips)))
Esempio n. 21
0
        print('Done processing')
        break
    prep_frame =prepare_frame(frame)
    currentFrame += 1

    if currentFrame< 5:
        prep_frame_lst.append(prep_frame)
    else:
        prep_frame_lst.append(prep_frame)
         # stack to Torch tensor
        imgs_LQ = get_Torch_tensor_from_prep_frame_lst(prep_frame_lst)

        # process each image
        count +=1
        img_name =  "frame_%05i" % count
        select_idx = data_util.index_generation(3, 5, N_in, padding=padding)
        imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

        output = util.single_forward(model, imgs_in)
        output = util.tensor2img(output.squeeze(0))

        if save_imgs:
            cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)


    # To stop duplicate images
    

# When everything done, release the capture
cap.release()
Esempio n. 22
0
# Feature width (``nf``) used for every EDVR model built by main().
_EDVR_NF = 128

# Checkpoint stem per REDS data mode; '_L' (stage 1) or '_Stage2' (stage 2)
# plus '.pth' is appended to obtain the file name.
_REDS_MODEL_STEMS = {
    'sharp_bicubic': 'EDVR_REDS_SR',
    'blur_bicubic': 'EDVR_REDS_SRblur',
    'blur': 'EDVR_REDS_deblur',
    'blur_comp': 'EDVR_REDS_deblurcomp',
}

# Sub-modules whose shape does not depend on the number of input frames and
# that are therefore shared (by reference) with the pretrained model.
_SHARED_TRUNK_MODULES = (
    'feature_extraction',
    'fea_L2_conv1', 'fea_L2_conv2', 'fea_L3_conv1', 'fea_L3_conv2',
    'pcd_align',
    'recon_trunk',
    'upconv1', 'upconv2', 'pixel_shuffle', 'HRconv', 'conv_last',
    'lrelu',
)
_SHARED_TSA_MODULES = (
    'tAtt_1', 'tAtt_2',
    'maxpool', 'avgpool',
    'sAtt_2', 'sAtt_3', 'sAtt_4', 'sAtt_5',
    'sAtt_L1', 'sAtt_L2', 'sAtt_L3',
    'sAtt_add_1', 'sAtt_add_2',
    'lrelu',
)
# TSA convolutions whose input-channel count is (frames * nf); their weights
# are sliced down to the first N_in * nf input channels.
_SLICED_TSA_CONVS = ('fea_fusion', 'sAtt_1')


def _transplant_edvr_weights(model, raw_model, N_in, nf=_EDVR_NF):
    """Make `model` (built for N_in frames) reuse the layers of `raw_model`.

    Frame-count independent layers are shared by reference. The two TSA
    convolutions listed in `_SLICED_TSA_CONVS` are deep-copied and their
    weight is replaced by the first ``N_in * nf`` input channels of the
    pretrained weight.
    """
    model.nf = raw_model.nf
    model.center = N_in // 2
    model.is_predeblur = raw_model.is_predeblur
    model.HR_in = raw_model.HR_in
    model.w_TSA = raw_model.w_TSA

    # The front-end layers differ depending on the pre-deblur / HR-input flags.
    if model.is_predeblur:
        front_end = ('pre_deblur', 'conv_1x1')
    elif model.HR_in:
        front_end = ('conv_first_1', 'conv_first_2', 'conv_first_3')
    else:
        front_end = ('conv_first',)
    for name in front_end + _SHARED_TRUNK_MODULES:
        setattr(model, name, getattr(raw_model, name))

    tsa, raw_tsa = model.tsa_fusion, raw_model.tsa_fusion
    tsa.center = model.center
    for name in _SHARED_TSA_MODULES:
        setattr(tsa, name, getattr(raw_tsa, name))
    for name in _SLICED_TSA_CONVS:
        raw_conv = getattr(raw_tsa, name)
        conv = copy.deepcopy(raw_conv)
        conv.weight = copy.deepcopy(
            torch.nn.Parameter(raw_conv.weight[:, 0:N_in * nf, :, :]))
        setattr(tsa, name, conv)


def _build_truncated_edvr(model_path, N_model, N_in, back_RBs, predeblur,
                          HR_in, device):
    """Load the N_model-frame checkpoint and return an N_in-frame eval model."""
    raw_model = EDVR_arch.EDVR(_EDVR_NF, N_model, 8, 5, back_RBs,
                               predeblur=predeblur, HR_in=HR_in)
    model = EDVR_arch.EDVR(_EDVR_NF, N_in, 8, 5, back_RBs,
                           predeblur=predeblur, HR_in=HR_in)
    raw_model.load_state_dict(torch.load(model_path), strict=True)
    _transplant_edvr_weights(model, raw_model, N_in)
    model.eval()
    return model.to(device)


def _read_img_folder(folder):
    """Read every file of `folder` (sorted by path) with data_util.read_img."""
    return [data_util.read_img(None, path)
            for path in sorted(glob.glob(osp.join(folder, '*')))]


def _restore_frame(model, imgs_LQ, img_idx, max_idx, N_in, padding, device,
                   flip_test):
    """Run the model on the window around frame `img_idx`; return the image."""
    select_idx = data_util.index_generation(img_idx, max_idx, N_in,
                                            padding=padding)
    imgs_in = imgs_LQ.index_select(
        0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
    if flip_test:
        output = util.flipx4_forward(model, imgs_in)
    else:
        print("IMGS_IN SHAPE: ", imgs_in.shape)
        output = util.single_forward(model, imgs_in)
    return util.tensor2img(output.squeeze(0))


def _score_frame(output, gt, gt_a, crop_border, only_y):
    """Return (PSNR vs gt, SSIM vs gt, SSIM vs gt_a) for one restored frame.

    `output` is the image returned by util.tensor2img and is divided by 255
    here. When `only_y` is true all three images are first converted with
    data_util.bgr2ycbcr(..., only_y=True).
    """
    output = output / 255.
    gt = np.copy(gt)
    gt_a = np.copy(gt_a)
    if only_y:
        gt = data_util.bgr2ycbcr(gt, only_y=True)
        output = data_util.bgr2ycbcr(output, only_y=True)
        gt_a = data_util.bgr2ycbcr(gt_a, only_y=True)
    output_a = copy.deepcopy(output)

    output, gt = util.crop_border([output, gt], crop_border)
    crt_psnr = util.calculate_psnr(output * 255, gt * 255)
    crt_ssim = util.calculate_ssim(output * 255, gt * 255)

    output_a, gt_a = util.crop_border([output_a, gt_a], crop_border)
    crt_aposterior = util.calculate_ssim(output_a * 255, gt_a * 255)
    return crt_psnr, crt_ssim, crt_aposterior


def _record_metrics(clip_results, img_name, crt_psnr, crt_ssim,
                    crt_aposterior):
    """Append one set of metrics to the frame's entry, creating it if needed."""
    entry = clip_results.get(img_name)
    if entry is None:
        entry = metrics_file(img_name)
        clip_results[img_name] = entry
    entry.add_psnr(crt_psnr)
    entry.add_gt_ssim(crt_ssim)
    entry.add_aposterior_ssim(crt_aposterior)


def _dump_results(results, dir_names, save_folder='../results/'):
    """Write each frame entry of `results` to <save_folder>/<clip>/<name>.json."""
    for dir_name in dir_names:
        save_subfolder = osp.join(save_folder, dir_name)
        util.mkdirs(save_subfolder)
        for value in results[dir_name].values():
            json_path = osp.join(save_subfolder, '{}.json'.format(value.name))
            # ensure_ascii=False may emit non-ASCII text, so fix the encoding.
            with open(json_path, 'w', encoding='utf-8') as outfile:
                json.dump(value.__dict__, outfile, ensure_ascii=False, indent=4)


def main():
    """Evaluate pretrained EDVR models run with fewer input frames.

    Two stages are executed in sequence:

    * Vid4: for every N_in in 1..7 an N_in-frame model is derived from the
      7-frame Vimeo90K checkpoint and scored on the Y channel.
    * REDS4: for every N_in in 1..5 the stage-1 model writes its outputs as
      PNG files and the stage-2 model (reading the stage-1 results folder)
      is scored on the unconverted (BGR) images.

    For each frame PSNR and SSIM against the GT folder and SSIM against the
    a-posteriori GT folder ('GT_7' / 'GT_5') are recorded, one value per
    N_in, and finally dumped as one JSON file per frame under
    '../results/<clip>/'. All paths are relative to the working directory.
    """
    # Select the GPU before any CUDA object is created.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    device = torch.device('cuda')
    flip_test = False
    crop_border = 0

    ###########################################################################
    # STAGE Vid4
    ###########################################################################
    model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
    N_model_default = 7
    padding = 'new_info'
    test_dataset_folder = '../datasets/Vid4/BIx4'
    GT_dataset_folder = '../datasets/Vid4/GT'
    aposterior_GT_dataset_folder = '../datasets/Vid4/GT_7'
    vid4_results = {"calendar": {}, "city": {}, "foliage": {}, "walk": {}}

    for N_in in range(1, N_model_default + 1):
        model = _build_truncated_edvr(model_path, N_model_default, N_in,
                                      back_RBs=40, predeblur=False,
                                      HR_in=False, device=device)

        subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
        subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
        subfolder_GT_a_l = sorted(
            glob.glob(osp.join(aposterior_GT_dataset_folder, "*")))

        # for each subfolder (clips are paired by sorted position)
        for subfolder, subfolder_GT, subfolder_GT_a in zip(
                subfolder_l, subfolder_GT_l, subfolder_GT_a_l):
            subfolder_name = osp.basename(subfolder)
            img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
            max_idx = len(img_path_l)
            print("MAX_IDX: ", max_idx)

            #### read LQ and GT images
            imgs_LQ = data_util.read_img_seq(subfolder)
            img_GT_l = _read_img_folder(subfolder_GT)
            img_GT_a = _read_img_folder(subfolder_GT_a)

            # process each image
            for img_idx, img_path in enumerate(img_path_l):
                img_name = osp.splitext(osp.basename(img_path))[0]
                output = _restore_frame(model, imgs_LQ, img_idx, max_idx,
                                        N_in, padding, device, flip_test)
                # Vid4 is evaluated on the Y channel.
                scores = _score_frame(output, img_GT_l[img_idx],
                                      img_GT_a[img_idx], crop_border,
                                      only_y=True)
                _record_metrics(vid4_results[subfolder_name], img_name,
                                *scores)

    #### writing vid4 results
    _dump_results(vid4_results, ["calendar", "city", "foliage", "walk"])

    ###########################################################################
    # STAGE REDS
    ###########################################################################
    reds4_results = {"000": {}, "011": {}, "015": {}, "020": {}}
    data_mode = 'sharp_bicubic'
    N_model_default = 5
    save_imgs = True
    GT_dataset_folder = '../datasets/REDS4/GT'
    aposterior_GT_dataset_folder = '../datasets/REDS4/GT_5'

    if data_mode not in _REDS_MODEL_STEMS:
        raise NotImplementedError
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'

    for N_in in range(1, N_model_default + 1):
        for stage in range(1, 3):
            model_path = '../experiments/pretrained_models/{}_{}.pth'.format(
                _REDS_MODEL_STEMS[data_mode], 'L' if stage == 1 else 'Stage2')

            predeblur = data_mode in ('blur_bicubic', 'blur', 'blur_comp')
            HR_in = data_mode in ('blur', 'blur_comp') or stage == 2
            back_RBs = 20 if stage == 2 else 40

            if stage == 1:
                test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
            else:
                # Stage 2 consumes the images written by stage 1.
                test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
                print('You should modify the test_dataset_folder path for stage 2')

            if stage == 1 and data_mode != 'Vid4':
                save_folder = '../results/{}'.format('REDS-EDVR_REDS_SR_L_flipx4')
            else:
                save_folder = '../results/{}'.format(data_mode)
            util.mkdirs(save_folder)
            # NOTE(review): this runs once per (N_in, stage); if
            # util.setup_logger adds handlers on every call they accumulate
            # on the 'base' logger -- confirm against its implementation.
            util.setup_logger('base', save_folder, 'test', level=logging.INFO,
                              screen=True, tofile=True)

            model = _build_truncated_edvr(model_path, N_model_default, N_in,
                                          back_RBs=back_RBs,
                                          predeblur=predeblur, HR_in=HR_in,
                                          device=device)

            subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
            subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
            subfolder_GT_a_l = sorted(
                glob.glob(osp.join(aposterior_GT_dataset_folder, "*")))

            # for each subfolder (clips are paired by sorted position)
            for subfolder, subfolder_GT, subfolder_GT_a in zip(
                    subfolder_l, subfolder_GT_l, subfolder_GT_a_l):
                subfolder_name = osp.basename(subfolder)
                save_subfolder = osp.join(save_folder, subfolder_name)

                img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
                max_idx = len(img_path_l)
                print("MAX_IDX: ", max_idx)
                print("SAVE FOLDER::::::", save_folder)

                if save_imgs:
                    util.mkdirs(save_subfolder)

                #### read LQ images; GT images are only needed for scoring
                imgs_LQ = data_util.read_img_seq(subfolder)
                if stage == 2:
                    img_GT_l = _read_img_folder(subfolder_GT)
                    img_GT_a = _read_img_folder(subfolder_GT_a)

                # process each image
                for img_idx, img_path in enumerate(img_path_l):
                    img_name = osp.splitext(osp.basename(img_path))[0]
                    output = _restore_frame(model, imgs_LQ, img_idx, max_idx,
                                            N_in, padding, device, flip_test)

                    if save_imgs and stage == 1:
                        cv2.imwrite(
                            osp.join(save_subfolder,
                                     '{}.png'.format(img_name)), output)

                    # Metrics are only collected for the stage-2 output.
                    if stage == 2:
                        scores = _score_frame(output, img_GT_l[img_idx],
                                              img_GT_a[img_idx], crop_border,
                                              only_y=False)
                        _record_metrics(reds4_results[subfolder_name],
                                        img_name, *scores)

    #### writing reds4 results
    _dump_results(reds4_results, ["000", "011", "015", "020"])
Esempio n. 23
0
def _mean_or_zero(total, count):
    """Return total / count, or 0 when nothing was accumulated."""
    return 0 if count == 0 else total / count


def main():
    """Run an EDVR model over a test set and log per-clip and total PSNR.

    Every sub-folder of the low-quality folder is treated as one clip and is
    paired, by sorted position, with a sub-folder of the ground-truth folder.
    Each frame is restored from a window of N_in frames, optionally written
    as PNG, and compared with its ground-truth frame. PSNR is averaged
    separately over centre frames and border frames (the first and last
    N_in // 2 frames of a clip).

    Raises:
        NotImplementedError: if `test_set` is not one of the known sets.
    """
    #################
    # configurations
    #################
    # Select the GPU before any CUDA object is created.
    os.environ['CUDA_VISIBLE_DEVICES'] = '6'
    device = torch.device('cuda')
    test_set = 'AI4K_val'  # Vid4 | YouKu10 | YouKu_val | REDS4 | AI4K_val | zhibo | AI4K_val_bic
    test_name = 'PCD_Vis_Test_35_ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW_A01xxx_900000_AI4K_5000'  #     'AI4K_val_Denoise_A02_420000'
    data_mode = 'sharp_bicubic'  # sharp_bicubic | blur_bicubic
    N_in = 5

    # load test set: name -> (low-quality folder, ground-truth folder)
    dataset_folders = {
        'Vid4': ('../datasets/Vid4/BIx4',
                 '../datasets/Vid4/GT'),
        'YouKu10': ('../datasets/YouKu10/LR',
                    '../datasets/YouKu10/HR'),
        'YouKu_val': ('/data0/yhliu/DATA/YouKuVid/valid/valid_lr_bmp',
                      '/data0/yhliu/DATA/YouKuVid/valid/valid_hr_bmp'),
        'REDS4': ('../datasets/REDS4/{}'.format(data_mode),
                  '../datasets/REDS4/GT'),
        'AI4K_val': ('/home/yhliu/AI4K/contest2/val2_LR_png/',
                     '/home/yhliu/AI4K/contest1/val1_HR_png/'),
        'AI4K_val_bic': ('/home/yhliu/AI4K/contest1/val1_LR_png_bic/',
                         '/home/yhliu/AI4K/contest1/val1_HR_png_bic/'),
        'zhibo': ('/data1/yhliu/SR_ZHIBO_VIDEO/Test_video_LR/',
                  '/data1/yhliu/SR_ZHIBO_VIDEO/Test_video_HR/'),
    }
    if test_set not in dataset_folders:
        # Fail early instead of hitting an unbound variable further down.
        raise NotImplementedError('Unknown test_set: {}'.format(test_set))
    test_dataset_folder, GT_dataset_folder = dataset_folders[test_set]

    flip_test = False

    #model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
    #model_path = '../experiments/A01b/models/250000_G.pth'
    #model_path = '../experiments/A02_predenoise/models/415000_G.pth'
    model_path = '../experiments/A37_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd/models/5000_G.pth'

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True

    model = EDVR_arch.EDVR(64,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)
    #model = my_EDVR_arch.MYEDVR(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
    #model = my_EDVR_arch.MYEDVR_RES(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True  #True | False

    save_folder = '../results/{}'.format(test_name)
    if test_set == 'zhibo':
        save_folder = '/data1/yhliu/SR_ZHIBO_VIDEO/SR_png_sample_150'
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    model = nn.DataParallel(model)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    print(subfolder_l)
    print(subfolder_GT_l)

    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        print(img_path_l)
        max_idx = len(img_path_l)
        if max_idx == 0:
            # Nothing to average; skipping avoids a division by zero below.
            logger.warning('Folder {} - no frames found, skipped.'.format(
                subfolder_name))
            continue
        subfolder_name_l.append(subfolder_name)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))
        avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            # The input stays on the CPU; the model is wrapped in DataParallel.
            imgs_in = imgs_LQ.index_select(
                0,
                torch.LongTensor(select_idx)).unsqueeze(0).cpu()
            print(imgs_in.size())

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                start_time = time.time()
                output = util.single_forward(model, imgs_in)
                end_time = time.time()
                print('Forward One image:', end_time - start_time)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # PSNR is computed on the images as loaded; the Y-channel
            # conversion used for Vid4 is disabled in this script:
            #   GT = data_util.bgr2ycbcr(GT, only_y=True)
            #   output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(
                img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        # A short clip may have no centre frames (or no border frames).
        avg_psnr_center = _mean_or_zero(avg_psnr_center, N_center)
        avg_psnr_border = _mean_or_zero(avg_psnr_border, N_border)
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    # Count the clips that were actually evaluated (zip may have truncated
    # the folder list and empty clips are skipped).
    n_clips = len(avg_psnr_l)
    if n_clips == 0:
        logger.warning('No clips were evaluated; no total average available.')
        return
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / n_clips, n_clips,
                    sum(avg_psnr_center_l) / n_clips,
                    sum(avg_psnr_border_l) / n_clips))
Esempio n. 24
0
def main():
    """Evaluate an EDVR (``w_TSA=False``) checkpoint on the clips in ``val_txt``.

    All settings are hard-coded below. Each low-quality frame of an evaluated
    clip is restored from a temporal window of ``N_in`` frames and compared
    with the ground-truth frame at the same index. PSNR is accumulated
    separately for center frames and for the ``N_in // 2`` border frames at
    each end of the clip, then logged per clip and as overall averages through
    the 'base' logger (screen and file under ``save_folder``).
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    #### model
    data_mode = 'sharp_bicubic'
    if stage == 1:
        model_path = '../experiments/001_EDVRwoTSA_scratch_lr4e-4_600k_SR4K_LrCAR4S/models/200000_G.pth'
    else:
        model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'

    N_in = 3  # use N_in images to restore one HR image

    predeblur, HR_in = False, False
    back_RBs = 10
    if stage == 2:
        HR_in = True
        back_RBs = 20
    model = EDVR_arch.EDVR(64,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in,
                           w_TSA=False)

    #### dataset
    val_txt = '/home/mcc/4khdr/val.txt'
    # Ground truth is required for PSNR in both stages, so bind it
    # unconditionally; it used to be set only for stage 1, which left the
    # name undefined (NameError) when stage == 2.
    GT_dataset_folder = '/home/mcc/4khdr/image/4k'
    if stage == 1:
        test_dataset_folder = '/home/mcc/4khdr/image/540p'
    else:
        test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
        print('You should modify the test_dataset_folder path for stage 2')

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = False

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    with open(val_txt, 'r') as f:
        # Skip blank lines: an empty name would resolve to the dataset root
        # itself instead of a clip folder.
        name_l = [line.strip() for line in f if line.strip()]
    # Both lists are built from the same names and sorted the same way, so
    # zip() below pairs each LQ clip with its GT clip.
    subfolder_l = sorted(
        [osp.join(test_dataset_folder, name) for name in name_l])
    subfolder_GT_l = sorted(
        [osp.join(GT_dataset_folder, name) for name in name_l])
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # indices of the N_in frames fed to the model for this frame
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # evaluate on RGB channels
            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)

            # single-line progress indicator, overwritten in place
            sys.stdout.write('\r' +
                             '{:03d}/{:03d}'.format(img_idx, len(img_path_l)))
            sys.stdout.flush()

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1
        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        # A clip with at most 2 * border_frame frames has no center frames;
        # guard the division the same way the border average is guarded.
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - frames {} -  Average PSNR: {:.6f} dB; '
                    'Center PSNR: {:.6f} dB; '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name,
                                                     (N_center + N_border),
                                                     avg_psnr, avg_psnr_center,
                                                     avg_psnr_border))
        # NOTE(review): this stops after the first clip, so the totals below
        # describe a single clip even though 'Total clips' reports
        # len(subfolder_l). Looks like a debugging leftover - confirm before
        # removing.
        break

    logger.info('################ Tidy Outputs ################')
    for name, psnr, psnr_center, psnr_border in zip(subfolder_name_l,
                                                    avg_psnr_l,
                                                    avg_psnr_center_l,
                                                    avg_psnr_border_l):
        logger.info(
            'Folder {} - Average PSNR: {:.6f} dB, Center PSNR: {:.6f} dB, Border PSNR: {:.6f} dB. '
            .format(name, psnr, psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    # The trailing space after 'dB.' keeps 'Score' from running into the
    # previous sentence in the log line.
    logger.info(
        'Total clips {}. Average PSNR: {:.6f} dB, Center PSNR: {:.6f} dB, Border PSNR: {:.6f} dB. '
        'Score: {:.6f}'.format(len(subfolder_l),
                               sum(avg_psnr_l) / len(avg_psnr_l),
                               sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                               sum(avg_psnr_border_l) / len(avg_psnr_border_l),
                               sum(avg_psnr_l) / len(avg_psnr_l) / 50))
Esempio n. 25
0
def main():
    """Evaluate an EDVR checkpoint on the 'SDR_4bit' clips.

    All settings are hard-coded below. Input clips are read from
    ``../datasets/SDR_4bit`` (stage 1) and ground truth from
    ``../datasets/SDR_10bit/``; clip folders are paired by sorted order. Each
    frame is restored from a window of ``N_in`` frames, optionally saved as a
    PNG, and scored with PSNR on a 0-65535 value range. Per-frame, per-clip
    and overall results go to the 'base' logger.

    Raises:
        NotImplementedError: if ``data_mode`` is not 'SDR_4bit'.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'SDR_4bit'
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'SDR_4bit':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    else:
        raise NotImplementedError

    # use N_in images to restore one high bitdepth image
    N_in = 5

    # predeblur: predeblur for blurry input
    # HR_in: downsample high resolution input
    # Both are always enabled for this data mode; only the number of
    # reconstruction blocks depends on the stage.
    predeblur, HR_in = True, True
    back_RBs = 20 if stage == 2 else 40
    # EDVR(num_feature_map, num_input_frames, deformable_groups?, front_RBs,
    #      back_RBs, predeblur, HR_in)
    model = EDVR_arch.EDVR(128,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### dataset
    if stage == 1:
        test_dataset_folder = '../datasets/{}'.format(data_mode)
    else:
        test_dataset_folder = '../'
        print('You should modify the test_dataset_folder path for stage 2')
    GT_dataset_folder = '../datasets/SDR_10bit/'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LBD and GT images
        #### resize to avoid cuda out of memory, 2160x3840->720x1280
        imgs_LBD = data_util.read_img_seq(subfolder,
                                          scale=65535.,
                                          zoomout=(1280, 720))
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(
                data_util.read_img(None,
                                   img_GT_path,
                                   scale=65535.,
                                   zoomout=True))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # generate frame index
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LBD.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                # self ensemble with flipping input at four different directions
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0), out_type=np.uint16)

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR on the frames as produced (no luma conversion
            # applies to this data mode); uint16 output is rescaled by 65535
            output = output / 65535.
            GT = np.copy(img_GT_l[img_idx])

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 65535, GT * 65535)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(
                img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        # A clip with at most 2 * border_frame frames has no center frames;
        # guard the division the same way the border average is guarded.
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
Esempio n. 26
0
def main():
    """Run an EDVR checkpoint over the 'ai4khdr_test' clips (inference only).

    Every sub-folder of the test dataset folder is treated as one clip. Each
    frame is restored from a temporal window of ``N_in`` frames and, when
    ``save_imgs`` is set, written as ``<frame name>.png`` under a timestamped
    results folder. No ground truth is read and no metric is computed; the
    configuration is logged before and after the run.

    Raises:
        NotImplementedError: if ``data_mode`` is not 'ai4khdr_test'.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    data_mode = 'ai4khdr_test'
    flip_test = False

    # Only one data mode is wired up; reject anything else before touching
    # the model or the dataset.
    if data_mode != 'ai4khdr_test':
        raise NotImplementedError

    ############################################################################
    #### model
    #################
    model_path = '../experiments/002_EDVR_lr4e-4_600k_AI4KHDR/models/4000_G.pth'
    N_in = 5
    front_RBs = 5
    back_RBs = 10
    predeblur, HR_in = False, False
    model = EDVR_arch.EDVR(64, N_in, 8, front_RBs, back_RBs,
                           predeblur=predeblur, HR_in=HR_in)

    ############################################################################
    #### dataset
    #################
    test_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/test/540p_frames'

    ############################################################################
    #### evaluation
    # temporal padding mode
    padding = 'new_info'
    save_imgs = True

    save_folder = '../results/{}_{}'.format(data_mode, util.get_timestamp())
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')

    def log_settings():
        # The same five lines are emitted before and after the run.
        logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))

    #### log info
    log_settings()

    #### set up the models
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model = model.to(device)

    # The forward routine does not change between frames, so pick it once.
    forward = util.flipx4_forward if flip_test else util.single_forward

    # for each clip folder
    for clip_dir in sorted(glob.glob(osp.join(test_dataset_folder, '*'))):
        clip_name = osp.basename(clip_dir)
        clip_save_dir = osp.join(save_folder, clip_name)

        frame_paths = sorted(glob.glob(osp.join(clip_dir, '*')))
        n_frames = len(frame_paths)
        if save_imgs:
            util.mkdirs(clip_save_dir)

        #### read LQ images
        lq_frames = data_util.read_img_seq(clip_dir)

        # process each frame
        for frame_idx, frame_path in enumerate(frame_paths):
            frame_name, _ = osp.splitext(osp.basename(frame_path))
            window = data_util.index_generation(frame_idx, n_frames, N_in,
                                                padding=padding)
            window_idx = torch.LongTensor(window)
            imgs_in = lq_frames.index_select(0, window_idx)
            imgs_in = imgs_in.unsqueeze(0).to(device)

            restored = forward(model, imgs_in)
            restored = util.tensor2img(restored.squeeze(0))

            if save_imgs:
                out_path = osp.join(clip_save_dir,
                                    '{}.png'.format(frame_name))
                cv2.imwrite(out_path, restored)
        logger.info('Folder {}'.format(clip_name))

    logger.info('################ Final Results ################')
    log_settings()
def main():
    """Evaluate a pretrained EDVR model with a non-default number of input frames.

    Command line: ``dataset n_frames stage`` where ``dataset`` is one of
    Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp, ``n_frames`` is
    the number of frames fed to the network and ``stage`` is 1 or 2.

    The checkpoint is loaded into ``raw_model``, built with the frame count it
    is loaded for (7 for Vid4, 5 otherwise). A second model is built for
    ``n_frames`` inputs and reuses ``raw_model``'s sub-modules; the two
    fusion layers whose weights are sized by the frame count
    (``tsa_fusion.fea_fusion`` and ``tsa_fusion.sAtt_1``) are deep-copied and
    their weights truncated along dimension 1 to ``n_frames * nf`` entries.
    Each frame of each clip is then restored, optionally saved, and scored
    with SSIM against ground truth; results go to the 'base' logger.

    Raises:
        NotImplementedError: if ``dataset`` is not a supported data mode.
        ValueError: for Vid4 with stage 2, or if ``n_frames`` is not between
            1 and the checkpoint's frame count.
    """
    ####################
    # arguments parser #
    ####################
    #  [format] dataset(vid4, REDS4) N(number of frames)

    parser = argparse.ArgumentParser()

    parser.add_argument('dataset')
    # Let argparse convert and validate, so bad input yields a usage error
    # instead of a traceback from int().
    parser.add_argument('n_frames', type=int)
    parser.add_argument('stage', type=int, choices=(1, 2))

    args = parser.parse_args()

    data_mode = args.dataset
    N_in = args.n_frames
    stage = args.stage

    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # data_mode: Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    # stage: 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    nf = 128  # feature width used for both models and for the weight slicing
    ############################################################################
    #### model
    pretrained_dir = '../experiments/pretrained_models/'
    # data_mode -> (stage 1 checkpoint, stage 2 checkpoint)
    checkpoint_names = {
        'Vid4': ('EDVR_Vimeo90K_SR_L.pth', None),
        'sharp_bicubic': ('EDVR_REDS_SR_L.pth', 'EDVR_REDS_SR_Stage2.pth'),
        'blur_bicubic': ('EDVR_REDS_SRblur_L.pth',
                         'EDVR_REDS_SRblur_Stage2.pth'),
        'blur': ('EDVR_REDS_deblur_L.pth', 'EDVR_REDS_deblur_Stage2.pth'),
        'blur_comp': ('EDVR_REDS_deblurcomp_L.pth',
                      'EDVR_REDS_deblurcomp_Stage2.pth'),
    }
    if data_mode not in checkpoint_names:
        raise NotImplementedError
    checkpoint_name = checkpoint_names[data_mode][stage - 1]
    if checkpoint_name is None:
        raise ValueError('Vid4 does not support stage 2.')
    model_path = pretrained_dir + checkpoint_name

    predeblur, HR_in = False, False
    back_RBs = 40
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20

    #### dataset
    if data_mode == 'Vid4':
        N_model_default = 7
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        N_model_default = 5
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    # The fusion weights are truncated to N_in * nf below; asking for more
    # frames than the checkpoint provides cannot be satisfied by slicing and
    # would only fail later inside the forward pass.
    if not 1 <= N_in <= N_model_default:
        raise ValueError(
            'n_frames must be between 1 and {} for {}, got {}'.format(
                N_model_default, data_mode, N_in))

    raw_model = EDVR_arch.EDVR(nf,
                               N_model_default,
                               8,
                               5,
                               back_RBs,
                               predeblur=predeblur,
                               HR_in=HR_in)
    model = EDVR_arch.EDVR(nf,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    # Stage-1 REDS results are written to the folder that stage 2 reads from.
    if stage == 1 and data_mode != 'Vid4':
        save_name = 'REDS-EDVR_REDS_SR_L_flipx4'
    else:
        save_name = data_mode
    save_folder = '../results/{}'.format(save_name)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    print([a for a in dir(model)
           if not callable(getattr(model, a))])  # not a.startswith('__') and

    # Only raw_model matches the checkpoint's shapes, so the weights are
    # loaded there and transplanted into `model` below.
    raw_model.load_state_dict(torch.load(model_path), strict=True)

    #### change model so it can work with less input

    model.nf = raw_model.nf
    model.center = N_in // 2
    model.is_predeblur = raw_model.is_predeblur
    model.HR_in = raw_model.HR_in
    model.w_TSA = raw_model.w_TSA

    #### extract features (for each frame)
    if model.is_predeblur:
        model.pre_deblur = raw_model.pre_deblur
        model.conv_1x1 = raw_model.conv_1x1
    else:
        if model.HR_in:
            model.conv_first_1 = raw_model.conv_first_1
            model.conv_first_2 = raw_model.conv_first_2
            model.conv_first_3 = raw_model.conv_first_3
        else:
            model.conv_first = raw_model.conv_first
    model.feature_extraction = raw_model.feature_extraction
    model.fea_L2_conv1 = raw_model.fea_L2_conv1
    model.fea_L2_conv2 = raw_model.fea_L2_conv2
    model.fea_L3_conv1 = raw_model.fea_L3_conv1
    model.fea_L3_conv2 = raw_model.fea_L3_conv2

    model.pcd_align = raw_model.pcd_align

    ######## Resize TSA

    model.tsa_fusion.center = model.center
    # temporal attention (before fusion conv)
    model.tsa_fusion.tAtt_1 = raw_model.tsa_fusion.tAtt_1
    model.tsa_fusion.tAtt_2 = raw_model.tsa_fusion.tAtt_2

    # fusion conv: keep only the first N_in * nf entries along weight
    # dimension 1 (assumes an nn.Conv2d whose dim 1 is the input-channel
    # axis - TODO confirm against EDVR_arch). The layer is deep-copied first
    # so raw_model keeps its full-size weight.
    n_keep = N_in * nf
    model.tsa_fusion.fea_fusion = copy.deepcopy(
        raw_model.tsa_fusion.fea_fusion)
    model.tsa_fusion.fea_fusion.weight = copy.deepcopy(
        torch.nn.Parameter(
            raw_model.tsa_fusion.fea_fusion.weight[:, 0:n_keep, :, :]))

    # spatial attention (after fusion conv): same truncation
    model.tsa_fusion.sAtt_1 = copy.deepcopy(raw_model.tsa_fusion.sAtt_1)
    model.tsa_fusion.sAtt_1.weight = copy.deepcopy(
        torch.nn.Parameter(
            raw_model.tsa_fusion.sAtt_1.weight[:, 0:n_keep, :, :]))

    print("MODEL TSA SHAPE: ", model.tsa_fusion.fea_fusion.weight.shape)

    model.tsa_fusion.maxpool = raw_model.tsa_fusion.maxpool
    model.tsa_fusion.avgpool = raw_model.tsa_fusion.avgpool
    model.tsa_fusion.sAtt_2 = raw_model.tsa_fusion.sAtt_2
    model.tsa_fusion.sAtt_3 = raw_model.tsa_fusion.sAtt_3
    model.tsa_fusion.sAtt_4 = raw_model.tsa_fusion.sAtt_4
    model.tsa_fusion.sAtt_5 = raw_model.tsa_fusion.sAtt_5
    model.tsa_fusion.sAtt_L1 = raw_model.tsa_fusion.sAtt_L1
    model.tsa_fusion.sAtt_L2 = raw_model.tsa_fusion.sAtt_L2
    model.tsa_fusion.sAtt_L3 = raw_model.tsa_fusion.sAtt_L3
    model.tsa_fusion.sAtt_add_1 = raw_model.tsa_fusion.sAtt_add_1
    model.tsa_fusion.sAtt_add_2 = raw_model.tsa_fusion.sAtt_add_2

    model.tsa_fusion.lrelu = raw_model.tsa_fusion.lrelu

    #### reconstruction
    model.recon_trunk = raw_model.recon_trunk
    #### upsampling
    model.upconv1 = raw_model.upconv1
    model.upconv2 = raw_model.upconv2
    model.pixel_shuffle = raw_model.pixel_shuffle
    model.HRconv = raw_model.HRconv
    model.conv_last = raw_model.conv_last

    #### activation function
    model.lrelu = raw_model.lrelu

    #####################################################

    model.eval()
    model = model.to(device)

    avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)

        print("MAX_IDX: ", max_idx)

        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_ssim, avg_ssim_border, avg_ssim_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # the 'blur' mode uses its own frame-selection routine
            if data_mode == "blur":
                select_idx = data_util.glarefree_index_generation(
                    img_idx, max_idx, N_in, padding=padding)
            else:
                select_idx = data_util.index_generation(
                    img_idx, max_idx, N_in, padding=padding)
            print("SELECT IDX: ", select_idx)

            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                print("IMGS_IN SHAPE: ", imgs_in.shape)
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate SSIM
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_ssim = util.calculate_ssim(output * 255, GT * 255)
            # SSIM is a dimensionless index, so it is logged without a unit.
            logger.info('{:3d} - {:25} \tSSIM: {:.6f}'.format(
                img_idx + 1, img_name, crt_ssim))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_ssim_center += crt_ssim
                N_center += 1
            else:  # border frames
                avg_ssim_border += crt_ssim
                N_border += 1

        avg_ssim = (avg_ssim_center + avg_ssim_border) / (N_center + N_border)
        # A clip with at most 2 * border_frame frames has no center frames;
        # guard the division the same way the border average is guarded.
        avg_ssim_center = 0 if N_center == 0 else avg_ssim_center / N_center
        avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border
        avg_ssim_l.append(avg_ssim)
        avg_ssim_center_l.append(avg_ssim_center)
        avg_ssim_border_l.append(avg_ssim_border)

        logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; '
                    'Center SSIM: {:.6f} for {} frames; '
                    'Border SSIM: {:.6f} for {} frames.'.format(
                        subfolder_name, avg_ssim, (N_center + N_border),
                        avg_ssim_center, N_center, avg_ssim_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, ssim, ssim_center, ssim_border in zip(
            subfolder_name_l, avg_ssim_l, avg_ssim_center_l,
            avg_ssim_border_l):
        logger.info('Folder {} - Average SSIM: {:.6f}. '
                    'Center SSIM: {:.6f}. '
                    'Border SSIM: {:.6f}.'.format(subfolder_name, ssim,
                                                  ssim_center, ssim_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average SSIM: {:.6f} for {} clips. '
                'Center SSIM: {:.6f}. Border SSIM: {:.6f}.'.format(
                    sum(avg_ssim_l) / len(avg_ssim_l), len(subfolder_l),
                    sum(avg_ssim_center_l) / len(avg_ssim_center_l),
                    sum(avg_ssim_border_l) / len(avg_ssim_border_l)))
Esempio n. 28
0
def create_test_png(model_path, device, gpu_id, opt, subfolder_l, save_folder, save_imgs,
                    frame_notation, N_in, PAD, flip_test, end, total_run_time, logger, padding):
    """Run tiled EDVR inference over every frame of every clip folder and save PNGs.

    The network is built from ``opt['network_G']``, its weights are loaded from
    ``model_path`` (``strict=True``) and it is moved to ``device`` and then to
    CUDA device ``gpu_id``.  Each frame is restored from ``N_in`` neighbouring
    frames; the input is cut into 320x180 tiles, every tile is given ``PAD``
    replicated pixels of context on each side, sent through the network, and
    the de-padded result is pasted into a 3840x2160 canvas.

    NOTE(review): the canvas size (3840x2160), tile size (320x180) and scale
    (4) are hard-coded, so a complete output frame is only produced for
    960x540 inputs -- confirm before feeding other resolutions.

    Args:
        model_path: Path of the checkpoint passed to ``torch.load``.
        device: Device the model is first moved to.
        gpu_id: CUDA device used for the model and the input frames.
        opt: Parsed options; only ``opt['network_G']`` is read.
        subfolder_l: Iterable of clip folder paths to process.
        save_folder: Root folder; each clip is written to a sub-folder in it.
        save_imgs: When true, result frames are written as ``<name>.png``.
        frame_notation: Screen-change annotation forwarded to ``data_util``.
        N_in: Number of input frames used per restored frame.
        PAD: Context padding, in input pixels, added around every tile.
        flip_test: Use ``util.flipx4_forward`` instead of ``util.single_forward``.
        end: Timestamp (``time.time()``) the first frame's duration is measured
            from.  It is re-armed locally after every frame; the caller's
            value is not modified.
        total_run_time: Meter receiving one ``update(seconds, 1)`` per frame.
        logger: Logger used for per-frame and screen-change messages.
        padding: Temporal padding mode forwarded to ``data_util``.
    """
    net_opt = opt['network_G']
    model = EDVR_arch.EDVR(nf=net_opt['nf'], nframes=net_opt['nframes'],
                           groups=net_opt['groups'], front_RBs=net_opt['front_RBs'],
                           back_RBs=net_opt['back_RBs'],
                           predeblur=net_opt['predeblur'], HR_in=net_opt['HR_in'],
                           w_TSA=net_opt['w_TSA'])

    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    model = model.cuda(gpu_id)

    # Tiling geometry (loop invariant): 960x540 input -> 3x3 tiles of 320x180.
    gtWidth = 3840
    gtHeight = 2160
    split_lengthY = 180
    split_lengthX = 320
    scale = 4
    # Context border added around the whole frame so every tile can be cut
    # out together with PAD pixels on each of its four sides.
    pader = torch.nn.ReplicationPad2d([PAD, PAD, PAD, PAD])

    for subfolder in subfolder_l:
        subfolder_name = osp.basename(subfolder)
        print("Evaluate Folders: ", subfolder_name)

        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ images
        imgs_LQ = data_util.read_img_seq(subfolder)  # Num x 3 x H x W

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]

            # Neighbour selection that takes annotated screen changes into account.
            select_idx, log1, log2, nota = \
                data_util.index_generation_process_screen_change_withlog_fixbug(
                    subfolder_name, frame_notation, img_idx, max_idx, N_in, padding=padding)

            if log1 is not None:
                logger.info('screen change')
                logger.info(nota)
                logger.info(log1)
                logger.info(log2)

            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).cuda(gpu_id)  # N C H W
            intHeight_ori = imgs_in.shape[2]  # 540
            intWidth_ori = imgs_in.shape[3]  # 960

            # Pad right/bottom up to the next multiple of the tile size
            # (0 when the frame is already a multiple).
            intPaddingRight_ = (-intWidth_ori) % split_lengthX
            intPaddingBottom_ = (-intHeight_ori) % split_lengthY
            print("Init pad right/bottom " + str(intPaddingRight_) + " / " + str(intPaddingBottom_))

            pader0 = torch.nn.ReplicationPad2d([0, intPaddingRight_, 0, intPaddingBottom_])
            imgs_in = pader0(imgs_in)
            imgs_in = pader(imgs_in)

            # Only whole tiles of the *original* frame are processed.
            split_numY = intHeight_ori // split_lengthY
            split_numX = intWidth_ori // split_lengthX

            y_all = np.zeros((gtHeight, gtWidth, 3), dtype="float32")  # HWC
            for split_j, split_i in itertools.product(range(split_numY), range(split_numX)):
                top = split_j * split_lengthY
                left = split_i * split_lengthX

                # Tile plus PAD pixels of context on every side (offsets are in
                # the padded frame, so the tile itself starts PAD pixels in).
                X0 = imgs_in[:, :,
                             top:top + split_lengthY + 2 * PAD,
                             left:left + split_lengthX + 2 * PAD]
                X0 = torch.unsqueeze(X0, 0)  # N C H W -> 1 N C H W

                if flip_test:
                    output = util.flipx4_forward(model, X0)
                else:
                    output = util.single_forward(model, X0)

                # Drop the (upscaled) context border again.
                output_depadded = output[0, :,
                                         PAD * scale:(PAD + split_lengthY) * scale,
                                         PAD * scale:(PAD + split_lengthX) * scale]
                output_depadded = output_depadded.squeeze(0)
                output = util.tensor2img(output_depadded)

                y_all[top * scale:(top + split_lengthY) * scale,
                      left * scale:(left + split_lengthX) * scale, :] = \
                    np.round(output).astype(np.uint8)

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), y_all)

            # Per-frame timing: measure since the previous frame finished and
            # re-arm the timer, otherwise every sample would be the cumulative
            # time since the caller's start.  The first sample still includes
            # the model set-up above.
            elapsed = time.time() - end
            print("*****************current image process time \t " + str(elapsed) +
                  "s ******************")
            total_run_time.update(elapsed, 1)
            end = time.time()

            logger.info('{} : {:3d} - {:25} \t'.format(subfolder_name, img_idx + 1, img_name))
def main():
    """Command-line entry point: tiled EDVR inference on one shard of the clips.

    Parses the command line, builds the EDVR network from the option file
    (``--opt``), loads the checkpoint given by ``--model_path`` and restores
    every frame of the clip folders found under ``--input_path``.  The sorted
    folder list is split into ``--gpu_number`` contiguous shards and only shard
    ``--gpu_index`` is processed, so several processes can share one dataset.
    Results and the log file go to ``<output_path>/<data_mode>/``.

    Each frame is restored from ``N_in`` neighbouring frames (chosen with the
    screen-change annotation loaded from ``--screen_notation``); the input is
    cut into 320x180 tiles, each padded with ``PAD`` replicated pixels of
    context, and the de-padded network outputs are pasted into a 3840x2160
    canvas.

    NOTE(review): canvas size, tile size and scale are hard-coded, so a
    complete output frame is only produced for 960x540 inputs -- confirm
    before feeding other resolutions.
    """
    #################
    # configurations
    #################
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--gpu_id", type=str, required=True)
    parser.add_argument("--gpu_number", type=str, required=True)
    parser.add_argument("--gpu_index", type=str, required=True)
    parser.add_argument("--screen_notation", type=str, required=True)
    parser.add_argument('--opt',
                        type=str,
                        required=True,
                        help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=False)

    gpu_number = int(args.gpu_number)
    gpu_index = int(args.gpu_index)

    PAD = 32  # context pixels added around every tile

    total_run_time = AverageMeter()

    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    print('export CUDA_VISIBLE_DEVICES=' + str(args.gpu_id))

    data_mode = 'sharp_bicubic'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False

    Input_folder = args.input_path
    Result_folder = args.output_path
    model_path = args.model_path

    # create results folder
    os.makedirs(Result_folder, exist_ok=True)

    ############################################################################
    #### model
    # use N_in images to restore one HR image
    N_in = 7 if data_mode == 'Vid4' else 5

    # The architecture comes entirely from the option file.
    net_opt = opt['network_G']
    model = EDVR_arch.EDVR(nf=net_opt['nf'],
                           nframes=net_opt['nframes'],
                           groups=net_opt['groups'],
                           front_RBs=net_opt['front_RBs'],
                           back_RBs=net_opt['back_RBs'],
                           predeblur=net_opt['predeblur'],
                           HR_in=net_opt['HR_in'],
                           w_TSA=net_opt['w_TSA'])

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
    elif stage == 1:
        test_dataset_folder = Input_folder
    else:
        test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
        print('You should modify the test_dataset_folder path for stage 2')

    #### evaluation
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = os.path.join(Result_folder, data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))

    end = time.time()

    # load screen change notation
    import json
    with open(args.screen_notation) as f:
        frame_notation = json.load(f)

    # Keep only this worker's contiguous share of the clip folders.
    subfolder_n = len(subfolder_l)
    subfolder_l = subfolder_l[int(subfolder_n * gpu_index / gpu_number):
                              int(subfolder_n * (gpu_index + 1) / gpu_number)]

    # Tiling geometry (loop invariant): 960x540 input -> 3x3 tiles of 320x180.
    gtWidth = 3840
    gtHeight = 2160
    split_lengthY = 180
    split_lengthX = 320
    scale = 4
    # Context border added around the whole frame so every tile can be cut
    # out together with PAD pixels on each of its four sides.
    pader = torch.nn.ReplicationPad2d([PAD, PAD, PAD, PAD])

    for subfolder in subfolder_l:
        subfolder_name = osp.basename(subfolder)
        print("Evaluate Folders: ", subfolder_name)

        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ images
        imgs_LQ = data_util.read_img_seq(subfolder)  # Num x 3 x H x W

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]

            # Neighbour selection that takes annotated screen changes into account.
            select_idx, log1, log2, nota = \
                data_util.index_generation_process_screen_change_withlog_fixbug(
                    subfolder_name,
                    frame_notation,
                    img_idx,
                    max_idx,
                    N_in,
                    padding=padding)

            if log1 is not None:
                logger.info('screen change')
                logger.info(nota)
                logger.info(log1)
                logger.info(log2)

            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).to(device)  # N C H W
            intHeight_ori = imgs_in.shape[2]  # 540
            intWidth_ori = imgs_in.shape[3]  # 960

            # Pad right/bottom up to the next multiple of the tile size
            # (0 when the frame is already a multiple).
            intPaddingRight_ = (-intWidth_ori) % split_lengthX
            intPaddingBottom_ = (-intHeight_ori) % split_lengthY
            print("Init pad right/bottom " + str(intPaddingRight_) + " / " +
                  str(intPaddingBottom_))

            pader0 = torch.nn.ReplicationPad2d(
                [0, intPaddingRight_, 0, intPaddingBottom_])
            imgs_in = pader0(imgs_in)
            imgs_in = pader(imgs_in)

            # Only whole tiles of the *original* frame are processed.
            split_numY = intHeight_ori // split_lengthY
            split_numX = intWidth_ori // split_lengthX

            y_all = np.zeros((gtHeight, gtWidth, 3), dtype="float32")  # HWC
            for split_j, split_i in itertools.product(range(split_numY),
                                                      range(split_numX)):
                top = split_j * split_lengthY
                left = split_i * split_lengthX

                # Tile plus PAD pixels of context on every side (offsets are in
                # the padded frame, so the tile itself starts PAD pixels in).
                X0 = imgs_in[:, :,
                             top:top + split_lengthY + 2 * PAD,
                             left:left + split_lengthX + 2 * PAD]
                X0 = torch.unsqueeze(X0, 0)  # N C H W -> 1 N C H W

                if flip_test:
                    output = util.flipx4_forward(model, X0)
                else:
                    output = util.single_forward(model, X0)

                # Drop the (upscaled) context border again.
                output_depadded = output[0, :,
                                         PAD * scale:(PAD + split_lengthY) * scale,
                                         PAD * scale:(PAD + split_lengthX) * scale]
                output_depadded = output_depadded.squeeze(0)
                output = util.tensor2img(output_depadded)

                y_all[top * scale:(top + split_lengthY) * scale,
                      left * scale:(left + split_lengthX) * scale, :] = \
                    np.round(output).astype(np.uint8)

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)), y_all)

            # Per-frame timing: measure since the previous frame finished and
            # re-arm the timer, otherwise every sample would be the cumulative
            # time since the start of the run.
            elapsed = time.time() - end
            print("*****************current image process time \t " +
                  str(elapsed) + "s ******************")
            total_run_time.update(elapsed, 1)
            end = time.time()

            logger.info('{} : {:3d} - {:25} \t'.format(subfolder_name,
                                                       img_idx + 1, img_name))

    # Repeat the run configuration at the end of the log.
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
Esempio n. 30
0
def main():
    """Evaluate a fixed EDVR checkpoint on the AI4KHDR validation clips.

    For every clip folder under ``test_dataset_folder`` (paired by sorted
    order with the folders under ``GT_dataset_folder``) each frame is restored
    from ``N_in`` neighbouring low-quality frames, optionally saved as PNG, and
    compared with its ground-truth frame.  PSNR is computed on the Y channel
    and reported per clip and over all clips, separately for center frames and
    for the ``N_in // 2`` border frames at each end of a clip.

    Averages over an empty set (a clip without center frames, an empty clip,
    or no clips at all) are reported as 0 instead of raising
    ``ZeroDivisionError``.
    """

    def _mean_or_zero(total, count):
        """Return ``total / count``, or 0 when there is nothing to average."""
        return 0 if count == 0 else total / count

    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'ai4khdr_valid'
    flip_test = False

    ############################################################################
    #### model
    #################
    if data_mode == 'ai4khdr_valid':
        model_path = '../experiments/002_EDVR_lr4e-4_600k_AI4KHDR/models/4000_G.pth'
    else:
        raise NotImplementedError
    N_in = 5
    front_RBs = 5
    back_RBs = 10
    predeblur, HR_in = False, False
    model = EDVR_arch.EDVR(64, N_in, 8, front_RBs, back_RBs, predeblur=predeblur, HR_in=HR_in)

    ############################################################################
    #### dataset
    #################
    if data_mode == 'ai4khdr_valid':
        test_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/valid/540p_frames'
        GT_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/valid/4k_frames'
    else:
        raise NotImplementedError

    ############################################################################
    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'ai4khdr_valid':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}_{}'.format(data_mode, util.get_timestamp())
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    # NOTE(review): strict=False tolerates missing/unexpected checkpoint keys.
    model.load_state_dict(torch.load(model_path), strict=False)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder (LQ and GT folders are paired purely by sorted order)
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = [
            data_util.read_img(None, img_GT_path)
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))
        ]

        # Running PSNR sums and frame counts for this clip.
        psnr_sum_center, psnr_sum_border = 0, 0
        N_center, N_border = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # ai4khdr_valid is evaluated on the Y channel only
            if data_mode == 'ai4khdr_valid':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                psnr_sum_center += crt_psnr
                N_center += 1
            else:  # border frames
                psnr_sum_border += crt_psnr
                N_border += 1

        # Short clips may have no center frames at all, and an empty folder
        # has no frames, so every average is guarded against a zero count.
        avg_psnr = _mean_or_zero(psnr_sum_center + psnr_sum_border, N_center + N_border)
        avg_psnr_center = _mean_or_zero(psnr_sum_center, N_center)
        avg_psnr_border = _mean_or_zero(psnr_sum_border, N_border)
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
                                                                   (N_center + N_border),
                                                                   avg_psnr_center, N_center,
                                                                   avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l,
                                                              avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
                                                     psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    # The clip count is the number of clips actually evaluated, which can be
    # smaller than len(subfolder_l) when fewer GT folders exist.
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    _mean_or_zero(sum(avg_psnr_l), len(avg_psnr_l)), len(avg_psnr_l),
                    _mean_or_zero(sum(avg_psnr_center_l), len(avg_psnr_center_l)),
                    _mean_or_zero(sum(avg_psnr_border_l), len(avg_psnr_border_l))))