Example #1
import os

import cv2

# bgr2ycbcr comes from the repository's data utilities (the exact import
# path may differ between forks).
from data.util import bgr2ycbcr


def worker(path, save_folder, mode, compression_level):
    img_name = os.path.basename(path)
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR
    if mode == 'gray':
        img_y = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        img_y = bgr2ycbcr(img, only_y=True)
    cv2.imwrite(os.path.join(save_folder, img_name), img_y,
                [cv2.IMWRITE_PNG_COMPRESSION, compression_level])
    return 'Processing {:s} ...'.format(img_name)
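The worker above returns a status string instead of printing, which suggests it is meant to be driven by a process pool. A minimal driver sketch (not part of the original; `run_all`, its defaults, and the `'*.png'` glob are illustrative assumptions):

import glob
import os.path as osp
from functools import partial
from multiprocessing import Pool

def run_all(input_folder, save_folder, mode='y', compression_level=3, n_thread=8):
    # Hypothetical driver: feed every PNG in input_folder through worker().
    paths = sorted(glob.glob(osp.join(input_folder, '*.png')))
    job = partial(worker, save_folder=save_folder, mode=mode,
                  compression_level=compression_level)
    with Pool(n_thread) as pool:
        for msg in pool.imap(job, paths):
            print(msg)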
Example #2
def worker(path, save_folder, mode, compression_level):
    img_name = os.path.basename(path)
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR
    # Mirror the source image's parent-folder name under save_folder.
    folder_name = os.path.basename(os.path.dirname(path))
    save_folder_new = os.path.join(save_folder, folder_name)
    os.makedirs(save_folder_new, exist_ok=True)
    if mode == 'gray':
        img_y = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        img_y = bgr2ycbcr(img, only_y=True)
    if int(img_name[:-4]) < 10:  # only keep frames whose numeric stem is 0-9
        cv2.imwrite(os.path.join(save_folder_new, img_name), img_y,
                    [cv2.IMWRITE_PNG_COMPRESSION, compression_level])
    return 'Processing {:s} ...'.format(img_name)
Example #3
def cascade_test_main(opt, logger, model, test_loader):
    test_set_name = test_loader.dataset.opt['name']
    logger.info('\nTesting [{:s}]...'.format(test_set_name))
    results = []

    # Some generator trunks expose an `nfe` counter (number of function
    # evaluations); fall back to None when it is absent.
    try:
        total_nfe = model.netG.module.conv_trunk.nfe
    except AttributeError:
        total_nfe = None

    for data in test_loader:
        need_GT = test_loader.dataset.opt['dataroot_GT'] is not None
        model.feed_data(data, need_GT=need_GT)
        img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
        img_name = osp.splitext(osp.basename(img_path))[0]

        model.test()
        visuals = model.get_current_visuals(need_GT=need_GT)
        sr_img = util.tensor2img(visuals['SR'])  # uint8

        # calculate PSNR and SSIM
        if need_GT:
            gt_img = util.tensor2img(visuals['GT'])
            gt_img = gt_img / 255.
            sr_img = sr_img / 255.

            crop_border = opt['crop_border'] if opt['crop_border'] else opt[
                'scale']
            if crop_border == 0:
                cropped_sr_img = sr_img
                cropped_gt_img = gt_img
            else:
                cropped_sr_img = sr_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]

            psnr = util.calculate_psnr(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            ssim = util.calculate_ssim(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            # niqe = util.calculate_niqe(cropped_sr_img * 255)

            if total_nfe is not None:
                last_nfe = model.netG.module.conv_trunk.nfe - total_nfe
                total_nfe = model.netG.module.conv_trunk.nfe
            else:
                last_nfe = None

            if gt_img.shape[2] == 3:  # RGB image
                sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                if crop_border == 0:
                    cropped_sr_img_y = sr_img_y
                    cropped_gt_img_y = gt_img_y
                else:
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                results.append((psnr, ssim, psnr_y, ssim_y))
                logger.info(
                    '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}; NFE: {}'
                    .format(img_name, psnr, ssim, psnr_y, ssim_y, last_nfe))
            else:
                logger.info(
                    '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; NFE: {}'.format(
                        img_name, psnr, ssim, last_nfe))
        else:
            logger.info(img_name)

    return results
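cascade_test_main returns the per-image (psnr, ssim, psnr_y, ssim_y) tuples; a small averaging helper of the kind a caller might use (an illustrative sketch, not part of the original):

def summarize(results):
    # Average each metric column across images; `results` is the list of
    # (psnr, ssim, psnr_y, ssim_y) tuples returned by cascade_test_main.
    if not results:
        return {}
    names = ('psnr', 'ssim', 'psnr_y', 'ssim_y')
    return {name: sum(col) / len(col) for name, col in zip(names, zip(*results))}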
Example #4
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)

    util.mkdir_and_rename(
        opt['path']['experiments_root'])  # rename the old experiments folder if it exists
    util.mkdirs(path for key, path in opt['path'].items()
                if key not in ('experiments_root', 'pretrain_model_G', 'pretrain_model_D'))
    option.save(opt)
    opt = option.dict_to_nonedict(
        opt)  # convert to NoneDict, which returns None for missing keys

    # print to file and std_out simultaneously
    sys.stdout = PrintLogger(opt['path']['log'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            print('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_dataset_opt = dataset_opt
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # Create model
    model = create_model(opt)
    # create logger
    logger = Logger(opt)

    current_step = 0
    start_time = time.time()
    print('---------- Start training -------------')
    for epoch in range(total_epochs):
        for i, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break  # only exits the inner loop; the epoch bound ends training

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            time_elapsed = time.time() - start_time
            start_time = time.time()

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = time_elapsed
                for k, v in logs.items():
                    print_rlt[k] = v
                print_rlt['lr'] = model.get_current_learning_rate()
                logger.print_format_results('train', print_rlt)

            # save models
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                print('Saving the model at the end of iter {:d}.'.format(
                    current_step))
                model.save(current_step)

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                print('---------- validation -------------')
                start_time = time.time()

                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    if opt['crop_scale'] is not None:
                        crop_size = opt['crop_scale']
                    else:
                        crop_size = opt['scale']
                    if crop_size <= 0:
                        cropped_sr_img = sr_img.copy()
                        cropped_gt_img = gt_img.copy()
                    else:
                        if len(gt_img.shape) < 3:
                            cropped_sr_img = sr_img[crop_size:-crop_size,
                                                    crop_size:-crop_size]
                            cropped_gt_img = gt_img[crop_size:-crop_size,
                                                    crop_size:-crop_size]
                        else:
                            cropped_sr_img = sr_img[crop_size:-crop_size,
                                                    crop_size:-crop_size, :]
                            cropped_gt_img = gt_img[crop_size:-crop_size,
                                                    crop_size:-crop_size, :]
                    # PSNR is computed on the Y channel only.
                    cropped_sr_img_y = bgr2ycbcr(cropped_sr_img, only_y=True)
                    cropped_gt_img_y = bgr2ycbcr(cropped_gt_img, only_y=True)
                    avg_psnr += util.psnr(cropped_sr_img_y, cropped_gt_img_y)

                avg_psnr = avg_psnr / idx
                time_elapsed = time.time() - start_time
                # Save to log
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = time_elapsed
                print_rlt['psnr'] = avg_psnr
                logger.print_format_results('val', print_rlt)
                print('-----------------------------------')

            # update learning rate
            model.update_learning_rate()

    print('Saving the final model.')
    model.save('latest')
    print('End of training.')
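The script above redirects sys.stdout to a PrintLogger imported from the repo's utilities. A minimal sketch of what such a tee logger has to do (the file name and constructor signature here are assumptions, not the repo's actual API):

import os
import sys

class PrintLogger(object):
    """Tee everything written to stdout into a log file as well."""
    def __init__(self, log_dir, filename='print_log.txt'):
        self.terminal = sys.stdout
        self.log = open(os.path.join(log_dir, filename), 'a')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()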
Example #5
def main():
    ####################
    # arguments parser #
    ####################
    # [format] dataset (vid4, REDS4), N (number of frames)
    # Argument parsing is disabled in this version; data_mode and N_in are
    # hard-coded below instead.


    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    #data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).


    # STAGE Vid4
    # Collecting results for Vid4

    model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False

    predeblur, HR_in = False, False
    back_RBs = 40

    N_model_default = 7
    data_mode = 'Vid4'

    vid4_results = {"calendar": {}, "city": {}, "foliage": {}, "walk": {}}

    # For each input length N_in, build a model whose TSA fusion takes N_in
    # frames and copy the matching weights from the pretrained 7-frame model.
    for N_in in range(1, N_model_default + 1):
        raw_model = EDVR_arch.EDVR(128, N_model_default, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
        model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
        aposterior_GT_dataset_folder = '../datasets/Vid4/GT_7'

        crop_border = 0
        border_frame = N_in // 2  # border frames when evaluate
        padding = 'new_info'

        save_imgs = False

        raw_model.load_state_dict(torch.load(model_path), strict=True)

        model.nf = raw_model.nf
        model.center = N_in // 2  # if center is None else center
        model.is_predeblur = raw_model.is_predeblur
        model.HR_in = raw_model.HR_in
        model.w_TSA = raw_model.w_TSA

        if model.is_predeblur:
            model.pre_deblur = raw_model.pre_deblur  # Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
            model.conv_1x1 = raw_model.conv_1x1  # nn.Conv2d(nf, nf, 1, 1, bias=True)
        else:
            if model.HR_in:
                model.conv_first_1 = raw_model.conv_first_1  # nn.Conv2d(3, nf, 3, 1, 1, bias=True)
                model.conv_first_2 = raw_model.conv_first_2  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
                model.conv_first_3 = raw_model.conv_first_3  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            else:
                model.conv_first = raw_model.conv_first  # nn.Conv2d(3, nf, 3, 1, 1, bias=True)
        model.feature_extraction = raw_model.feature_extraction  # arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
        model.fea_L2_conv1 = raw_model.fea_L2_conv1  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        model.fea_L2_conv2 = raw_model.fea_L2_conv2  # nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        model.fea_L3_conv1 = raw_model.fea_L3_conv1  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        model.fea_L3_conv2 = raw_model.fea_L3_conv2  # nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        model.pcd_align = raw_model.pcd_align  # PCD_Align(nf=nf, groups=groups)

        model.tsa_fusion.center = model.center

        model.tsa_fusion.tAtt_1 = raw_model.tsa_fusion.tAtt_1
        model.tsa_fusion.tAtt_2 = raw_model.tsa_fusion.tAtt_2

        # Keep only the input channels that correspond to N_in frames
        # (128 features per frame) when slicing the fusion weights.
        model.tsa_fusion.fea_fusion = copy.deepcopy(raw_model.tsa_fusion.fea_fusion)
        model.tsa_fusion.fea_fusion.weight = copy.deepcopy(torch.nn.Parameter(
            raw_model.tsa_fusion.fea_fusion.weight[:, 0:N_in * 128, :, :]))

        model.tsa_fusion.sAtt_1 = copy.deepcopy(raw_model.tsa_fusion.sAtt_1)
        model.tsa_fusion.sAtt_1.weight = copy.deepcopy(torch.nn.Parameter(raw_model.tsa_fusion.sAtt_1.weight[:, 0:N_in * 128, :, :]))

        model.tsa_fusion.maxpool = raw_model.tsa_fusion.maxpool
        model.tsa_fusion.avgpool = raw_model.tsa_fusion.avgpool
        model.tsa_fusion.sAtt_2 = raw_model.tsa_fusion.sAtt_2
        model.tsa_fusion.sAtt_3 = raw_model.tsa_fusion.sAtt_3
        model.tsa_fusion.sAtt_4 = raw_model.tsa_fusion.sAtt_4
        model.tsa_fusion.sAtt_5 = raw_model.tsa_fusion.sAtt_5
        model.tsa_fusion.sAtt_L1 = raw_model.tsa_fusion.sAtt_L1
        model.tsa_fusion.sAtt_L2 = raw_model.tsa_fusion.sAtt_L2
        model.tsa_fusion.sAtt_L3 = raw_model.tsa_fusion.sAtt_L3
        model.tsa_fusion.sAtt_add_1 = raw_model.tsa_fusion.sAtt_add_1
        model.tsa_fusion.sAtt_add_2 = raw_model.tsa_fusion.sAtt_add_2

        model.tsa_fusion.lrelu = raw_model.tsa_fusion.lrelu

        model.recon_trunk = raw_model.recon_trunk

        model.upconv1 = raw_model.upconv1
        model.upconv2 = raw_model.upconv2
        model.pixel_shuffle = raw_model.pixel_shuffle
        model.HRconv = raw_model.HRconv
        model.conv_last = raw_model.conv_last

        model.lrelu = raw_model.lrelu

        #####################################################

        model.eval()
        model = model.to(device)

        subfolder_name_l = []

        subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
        subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))

        subfolder_GT_a_l = sorted(glob.glob(osp.join(aposterior_GT_dataset_folder, "*")))
        # for each subfolder
        for subfolder, subfolder_GT, subfolder_GT_a in zip(subfolder_l, subfolder_GT_l, subfolder_GT_a_l):
            subfolder_name = osp.basename(subfolder)
            subfolder_name_l.append(subfolder_name)

            img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
            max_idx = len(img_path_l)

            print("MAX_IDX: ", max_idx)


            #### read LQ and GT images
            imgs_LQ = data_util.read_img_seq(subfolder)
            img_GT_l = []
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
                img_GT_l.append(data_util.read_img(None, img_GT_path))

            img_GT_a = []
            for img_GT_a_path in sorted(glob.glob(osp.join(subfolder_GT_a, '*'))):
                img_GT_a.append(data_util.read_img(None, img_GT_a_path))

            # process each image
            for img_idx, img_path in enumerate(img_path_l):
                img_name = osp.splitext(osp.basename(img_path))[0]
                select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)

                imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

                if flip_test:
                    output = util.flipx4_forward(model, imgs_in)
                else:
                    print("IMGS_IN SHAPE: ", imgs_in.shape)
                    output = util.single_forward(model, imgs_in)
                output = util.tensor2img(output.squeeze(0))

                # calculate PSNR
                output = output / 255.
                GT = np.copy(img_GT_l[img_idx])
                # Vid4 is evaluated on the Y channel, in [0, 1].
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)
                GT_a = np.copy(img_GT_a[img_idx])
                GT_a = data_util.bgr2ycbcr(GT_a, only_y=True)
                output_a = copy.deepcopy(output)

                output, GT = util.crop_border([output, GT], crop_border)
                crt_psnr = util.calculate_psnr(output * 255, GT * 255)
                crt_ssim = util.calculate_ssim(output * 255, GT * 255)

                output_a, GT_a = util.crop_border([output_a, GT_a], crop_border)

                crt_aposterior = util.calculate_ssim(output_a * 255, GT_a * 255)

                if img_name not in vid4_results[subfolder_name]:
                    vid4_results[subfolder_name][img_name] = metrics_file(img_name)
                entry = vid4_results[subfolder_name][img_name]
                entry.add_psnr(crt_psnr)
                entry.add_gt_ssim(crt_ssim)
                entry.add_aposterior_ssim(crt_aposterior)


    #### write Vid4 results
    save_folder = '../results/'
    for dir_name in ["calendar", "city", "foliage", "walk"]:
        save_subfolder = osp.join(save_folder, dir_name)
        util.mkdirs(save_subfolder)
        for value in vid4_results[dir_name].values():
            with open(osp.join(save_subfolder, '{}.json'.format(value.name)), 'w') as outfile:
                json.dump(value.__dict__, outfile, ensure_ascii=False, indent=4)

    # STAGE REDS

    reds4_results = {"000": {}, "011": {}, "015": {}, "020": {}}
    data_mode = 'sharp_bicubic'

    N_model_default = 5

    for N_in in range(1, N_model_default + 1):
        for stage in range(1, 3):  # stages 1 and 2

            flip_test = False

            if data_mode == 'sharp_bicubic':
                if stage == 1:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
                else:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
            elif data_mode == 'blur_bicubic':
                if stage == 1:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
                else:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
            elif data_mode == 'blur':
                if stage == 1:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
                else:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
            elif data_mode == 'blur_comp':
                if stage == 1:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
                else:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
            else:
                raise NotImplementedError

            predeblur, HR_in = False, False
            back_RBs = 40
            if data_mode == 'blur_bicubic':
                predeblur = True
            if data_mode == 'blur' or data_mode == 'blur_comp':
                predeblur, HR_in = True, True
            if stage == 2:
                HR_in = True
                back_RBs = 20

            if stage == 1:
                test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
            else:
                test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
                print('You should modify the test_dataset_folder path for stage 2')
            GT_dataset_folder = '../datasets/REDS4/GT'

            raw_model = EDVR_arch.EDVR(128, N_model_default, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
            model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

            crop_border = 0
            border_frame = N_in // 2  # border frames when evaluate
            # temporal padding mode
            if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
                padding = 'new_info'
            else:
                padding = 'replicate'
            save_imgs = True

            # Stage-1 outputs are saved where stage 2 expects to read them,
            # so temporarily swap data_mode when building save_folder.
            data_mode_t = copy.deepcopy(data_mode)
            if stage == 1 and data_mode_t != 'Vid4':
                data_mode = 'REDS-EDVR_REDS_SR_L_flipx4'
            save_folder = '../results/{}'.format(data_mode)
            data_mode = copy.deepcopy(data_mode_t)
            util.mkdirs(save_folder)
            util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)


            aposterior_GT_dataset_folder = '../datasets/REDS4/GT_5'


            raw_model.load_state_dict(torch.load(model_path), strict=True)

            model.nf = raw_model.nf
            model.center = N_in // 2  # if center is None else center
            model.is_predeblur = raw_model.is_predeblur
            model.HR_in = raw_model.HR_in
            model.w_TSA = raw_model.w_TSA

            if model.is_predeblur:
                model.pre_deblur = raw_model.pre_deblur  # Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
                model.conv_1x1 = raw_model.conv_1x1  # nn.Conv2d(nf, nf, 1, 1, bias=True)
            else:
                if model.HR_in:
                    model.conv_first_1 = raw_model.conv_first_1  # nn.Conv2d(3, nf, 3, 1, 1, bias=True)
                    model.conv_first_2 = raw_model.conv_first_2  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
                    model.conv_first_3 = raw_model.conv_first_3  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
                else:
                    model.conv_first = raw_model.conv_first  # nn.Conv2d(3, nf, 3, 1, 1, bias=True)
            model.feature_extraction = raw_model.feature_extraction  # arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
            model.fea_L2_conv1 = raw_model.fea_L2_conv1  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            model.fea_L2_conv2 = raw_model.fea_L2_conv2  # nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
            model.fea_L3_conv1 = raw_model.fea_L3_conv1  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            model.fea_L3_conv2 = raw_model.fea_L3_conv2  # nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

            model.pcd_align = raw_model.pcd_align  # PCD_Align(nf=nf, groups=groups)

            model.tsa_fusion.center = model.center

            model.tsa_fusion.tAtt_1 = raw_model.tsa_fusion.tAtt_1
            model.tsa_fusion.tAtt_2 = raw_model.tsa_fusion.tAtt_2

            # Slice the fusion weights down to N_in * 128 input channels,
            # matching the reduced number of input frames.
            model.tsa_fusion.fea_fusion = copy.deepcopy(raw_model.tsa_fusion.fea_fusion)
            model.tsa_fusion.fea_fusion.weight = copy.deepcopy(torch.nn.Parameter(
                raw_model.tsa_fusion.fea_fusion.weight[:, 0:N_in * 128, :, :]))

            model.tsa_fusion.sAtt_1 = copy.deepcopy(raw_model.tsa_fusion.sAtt_1)
            model.tsa_fusion.sAtt_1.weight = copy.deepcopy(torch.nn.Parameter(raw_model.tsa_fusion.sAtt_1.weight[:, 0:N_in * 128, :, :]))

            model.tsa_fusion.maxpool = raw_model.tsa_fusion.maxpool
            model.tsa_fusion.avgpool = raw_model.tsa_fusion.avgpool
            model.tsa_fusion.sAtt_2 = raw_model.tsa_fusion.sAtt_2
            model.tsa_fusion.sAtt_3 = raw_model.tsa_fusion.sAtt_3
            model.tsa_fusion.sAtt_4 = raw_model.tsa_fusion.sAtt_4
            model.tsa_fusion.sAtt_5 = raw_model.tsa_fusion.sAtt_5
            model.tsa_fusion.sAtt_L1 = raw_model.tsa_fusion.sAtt_L1
            model.tsa_fusion.sAtt_L2 = raw_model.tsa_fusion.sAtt_L2
            model.tsa_fusion.sAtt_L3 = raw_model.tsa_fusion.sAtt_L3
            model.tsa_fusion.sAtt_add_1 = raw_model.tsa_fusion.sAtt_add_1
            model.tsa_fusion.sAtt_add_2 = raw_model.tsa_fusion.sAtt_add_2

            model.tsa_fusion.lrelu = raw_model.tsa_fusion.lrelu

            model.recon_trunk = raw_model.recon_trunk

            model.upconv1 = raw_model.upconv1
            model.upconv2 = raw_model.upconv2
            model.pixel_shuffle = raw_model.pixel_shuffle
            model.HRconv = raw_model.HRconv
            model.conv_last = raw_model.conv_last

            model.lrelu = raw_model.lrelu

            #####################################################

            model.eval()
            model = model.to(device)

            subfolder_name_l = []

            subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
            subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))

            subfolder_GT_a_l = sorted(glob.glob(osp.join(aposterior_GT_dataset_folder, "*")))
            # for each subfolder
            for subfolder, subfolder_GT, subfolder_GT_a in zip(subfolder_l, subfolder_GT_l, subfolder_GT_a_l):

                subfolder_name = osp.basename(subfolder)
                subfolder_name_l.append(subfolder_name)
                save_subfolder = osp.join(save_folder, subfolder_name)

                img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
                max_idx = len(img_path_l)

                print("MAX_IDX: ", max_idx)

                print("SAVE FOLDER::::::", save_folder)

                if save_imgs:
                    util.mkdirs(save_subfolder)


            #### read LQ and GT images
                imgs_LQ = data_util.read_img_seq(subfolder)
                img_GT_l = []
                for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
                    img_GT_l.append(data_util.read_img(None, img_GT_path))

                img_GT_a = []
                for img_GT_a_path in sorted(glob.glob(osp.join(subfolder_GT_a, '*'))):
                    img_GT_a.append(data_util.read_img(None, img_GT_a_path))

            # process each image
                for img_idx, img_path in enumerate(img_path_l):
                    img_name = osp.splitext(osp.basename(img_path))[0]
                    select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)

                    imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

                    if flip_test:
                        output = util.flipx4_forward(model, imgs_in)
                    else:
                        print("IMGS_IN SHAPE: ", imgs_in.shape)
                        output = util.single_forward(model, imgs_in)
                    output = util.tensor2img(output.squeeze(0))

                    if save_imgs and stage == 1:
                        cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)
                    # calculate PSNR
                    if stage == 2:

                        output = output / 255.
                        GT = np.copy(img_GT_l[img_idx])
                        # REDS is evaluated on RGB channels (no Y conversion).
                        GT_a = np.copy(img_GT_a[img_idx])
                        output_a = copy.deepcopy(output)

                        output, GT = util.crop_border([output, GT], crop_border)
                        crt_psnr = util.calculate_psnr(output * 255, GT * 255)
                        crt_ssim = util.calculate_ssim(output * 255, GT * 255)

                        output_a, GT_a = util.crop_border([output_a, GT_a], crop_border)

                        crt_aposterior = util.calculate_ssim(output_a * 255, GT_a * 255)

                        if img_name not in reds4_results[subfolder_name]:
                            reds4_results[subfolder_name][img_name] = metrics_file(img_name)
                        entry = reds4_results[subfolder_name][img_name]
                        entry.add_psnr(crt_psnr)
                        entry.add_gt_ssim(crt_ssim)
                        entry.add_aposterior_ssim(crt_aposterior)



    #### write REDS4 results
    save_folder = '../results/'
    for dir_name in ["000", "011", "015", "020"]:
        save_subfolder = osp.join(save_folder, dir_name)
        util.mkdirs(save_subfolder)
        for value in reds4_results[dir_name].values():
            with open(osp.join(save_subfolder, '{}.json'.format(value.name)), 'w') as outfile:
                json.dump(value.__dict__, outfile, ensure_ascii=False, indent=4)
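metrics_file is defined elsewhere in the original script. Judging from the calls above, it is a small per-image accumulator whose __dict__ is dumped to JSON; a compatible sketch (the attribute names are inferred, not authoritative):

class metrics_file(object):
    def __init__(self, name):
        self.name = name
        self.psnr = []              # one value per N_in setting
        self.gt_ssim = []           # SSIM against the real GT
        self.aposterior_ssim = []   # SSIM against the full-model (GT_7 / GT_5) output

    def add_psnr(self, v):
        self.psnr.append(v)

    def add_gt_ssim(self, v):
        self.gt_ssim.append(v)

    def add_aposterior_ssim(self, v):
        self.aposterior_ssim.append(v)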
Example #6
            save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
        else:
            save_img_path = osp.join(dataset_dir, img_name + '.png')
        util.save_img(sr_img, save_img_path)

        # calculate PSNR and SSIM
        if need_GT:
            gt_img = util.tensor2img(visuals['GT'])
            sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
            psnr = util.calculate_psnr(sr_img, gt_img)
            ssim = util.calculate_ssim(sr_img, gt_img)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)

            if gt_img.shape[2] == 3:  # RGB image
                sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
                gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)

                psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
                ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
                logger.info(
                    '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                    .format(img_name, psnr, ssim, psnr_y, ssim_y))
            else:
                logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                    img_name, psnr, ssim))
        else:
            logger.info(img_name)
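These test loops lean on util.crop_border to trim the evaluation border from each image before computing metrics. A minimal sketch consistent with how it is called in these examples (not the repo's exact code):

def crop_border(img_list, crop_border):
    """Crop `crop_border` pixels from every side of each HxW(xC) image."""
    if crop_border == 0:
        return img_list
    return [v[crop_border:-crop_border, crop_border:-crop_border, ...]
            for v in img_list]
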
def main():
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            #model_path = '../experiments/pretrained_models/EDVR_REDS_SR_M.pth'
            model_path = '../experiments/002_EDVR_lr4e-4_600k_AI4KHDR/models/4000_G.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    elif data_mode == 'sharp_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
    elif data_mode == 'blur_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
    elif data_mode == 'blur':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
    elif data_mode == 'blur_comp':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    else:
        raise NotImplementedError

    # use N_in images to restore one HR image (5 for both Vid4 and REDS here)
    N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20
    # Note: nf=64, N_in=5 and back_RBs=10 are hard-coded here; the back_RBs
    # value computed above is not used.
    model = EDVR_arch.EDVR(64, 5, 8, 5, 10, predeblur=predeblur, HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/valid/540p_frames'
        GT_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/valid/4k_frames'
        #test_dataset_folder = '../datasets/Vid4/BIx4'
        #GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=False)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            #logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        avg_psnr_center = avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
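data_util.index_generation builds the list of N_in frame indices centered on the current frame, padding at clip borders according to the padding mode ('replicate' repeats the edge frame; 'new_info' shifts the window inward so only unseen frames are added). A sketch consistent with how it is called in these examples:

def index_generation(crt_i, max_n, N, padding='reflection'):
    """Return N frame indices centered on crt_i for a clip of max_n frames.
    padding: 'replicate' | 'reflection' | 'new_info' | 'circle'."""
    max_n = max_n - 1
    n_pad = N // 2
    return_l = []
    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
        if i < 0:
            if padding == 'replicate':
                add_idx = 0
            elif padding == 'reflection':
                add_idx = -i
            elif padding == 'new_info':
                add_idx = (crt_i + n_pad) + (-i)
            else:  # 'circle'
                add_idx = N + i
        elif i > max_n:
            if padding == 'replicate':
                add_idx = max_n
            elif padding == 'reflection':
                add_idx = max_n * 2 - i
            elif padding == 'new_info':
                add_idx = (crt_i - n_pad) - (i - max_n)
            else:  # 'circle'
                add_idx = i - N
        else:
            add_idx = i
        return_l.append(add_idx)
    return return_l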
Example #8
            cropped_sr_img = sr_img[crop_border:-crop_border,
                                    crop_border:-crop_border, :]
            cropped_gt_img = gt_img[crop_border:-crop_border,
                                    crop_border:-crop_border, :]

            psnr = util.calculate_psnr(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            ssim = util.calculate_ssim(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            if opt['val_lpips']:
                test_results['lpips'].append(lpips)  # `lpips` is computed earlier in the full script

            if gt_img.shape[2] == 3:  # RGB image
                sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                            crop_border:-crop_border]
                cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                            crop_border:-crop_border]
                psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
                if opt['val_lpips']:
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}; LPIPS: {:.3f}.'\
                        .format(img_name, psnr, ssim, psnr_y, ssim_y, lpips))
                else:
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))

def main():
    #################
    # configurations
    #################
    device = torch.device("cuda")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    data_mode = ("licensePlate_blur_bicubic"
                 )  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == "Vid4":
        if stage == 1:
            model_path = "../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth"
        else:
            raise ValueError("Vid4 does not support stage 2.")
    elif data_mode == "sharp_bicubic":
        if stage == 1:
            model_path = "../experiments/pretrained_models/EDVR_REDS_SR_L.pth"
        else:
            model_path = "../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth"
    elif data_mode == "blur_bicubic":
        if stage == 1:
            model_path = "../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth"
        else:
            model_path = "../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth"
    elif data_mode == "blur":
        if stage == 1:
            model_path = "../experiments/pretrained_models/EDVR_REDS_deblur_L.pth"
        else:
            model_path = "../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth"
    elif data_mode == "blur_comp":
        if stage == 1:
            model_path = "../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth"
        else:
            model_path = (
                "../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth"
            )
    elif data_mode == "licensePlate_blur_bicubic":
        model_path = ("/workspace/video_sr/EDVR/experiments/" +
                      "pretrained_models/EDVR_licensePlate_SRblur_L.pth")
    else:
        raise NotImplementedError

    if data_mode == "Vid4":
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 40
    if (data_mode == "blur_bicubic") or (data_mode
                                         == "licensePlate_blur_bicubic"):
        predeblur = True
    elif data_mode == "blur" or data_mode == "blur_comp":
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20
    model = EDVR_arch.EDVR(128,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### dataset
    if data_mode == "Vid4":
        test_dataset_folder = "../datasets/Vid4/BIx4"
        GT_dataset_folder = "../datasets/Vid4/GT"
    elif data_mode == "licensePlate_blur_bicubic":
        test_dataset_folder = "../datasets/license_plate2/BIx4"
        GT_dataset_folder = "../datasets/license_plate2/GT"
    else:
        if stage == 1:
            test_dataset_folder = "../datasets/REDS4/{}".format(data_mode)
        else:
            test_dataset_folder = "../results/REDS-EDVR_REDS_SR_L_flipx4"
            print("You should modify the test_dataset_folder path for stage 2")
        GT_dataset_folder = "../datasets/REDS4/GT"

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == "Vid4" or data_mode == "sharp_bicubic":
        padding = "new_info"
    else:
        padding = "replicate"
    save_imgs = True

    save_folder = "../results/{}".format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger("base",
                      save_folder,
                      "test",
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger("base")

    #### log info
    logger.info("Data: {} - {}".format(data_mode, test_dataset_folder))
    logger.info("Padding mode: {}".format(padding))
    logger.info("Model path: {}".format(model_path))
    logger.info("Save images: {}".format(save_imgs))
    logger.info("Flip test: {}".format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, "*")))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, "*")))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, "*")))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = test_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, "*"))):
            img_GT_l.append(test_util.read_img(img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = test_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = (imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device))

            if flip_test:
                output = test_util.flipx4_forward(model, imgs_in)
            else:
                output = test_util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, "{}.png".format(img_name)),
                    output)

            # calculate PSNR
            output = output / 255.0
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == "Vid4":  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = test_util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            logger.info("{:3d} - {:25} \tPSNR: {:.6f} dB".format(
                img_idx + 1, img_name, crt_psnr))

            if (img_idx >= border_frame
                    and img_idx < max_idx - border_frame):  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        avg_psnr_center = avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info("Folder {} - Average PSNR: {:.6f} dB for {} frames; "
                    "Center PSNR: {:.6f} dB for {} frames; "
                    "Border PSNR: {:.6f} dB for {} frames.".format(
                        subfolder_name,
                        avg_psnr,
                        (N_center + N_border),
                        avg_psnr_center,
                        N_center,
                        avg_psnr_border,
                        N_border,
                    ))

    logger.info("################ Tidy Outputs ################")
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info("Folder {} - Average PSNR: {:.6f} dB. "
                    "Center PSNR: {:.6f} dB. "
                    "Border PSNR: {:.6f} dB.".format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info("################ Final Results ################")
    logger.info("Data: {} - {}".format(data_mode, test_dataset_folder))
    logger.info("Padding mode: {}".format(padding))
    logger.info("Model path: {}".format(model_path))
    logger.info("Save images: {}".format(save_imgs))
    logger.info("Flip test: {}".format(flip_test))
    logger.info("Total Average PSNR: {:.6f} dB for {} clips. "
                "Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.".format(
                    sum(avg_psnr_l) / len(avg_psnr_l),
                    len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l),
                ))
    def __getitem__(self, index):
        # path_LQ = self.data_info['path_LQ'][index]
        # path_GT = self.data_info['path_GT'][index]
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        if self.cache_data:
            select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'],
                                               padding=self.opt['padding'])
            imgs_LQ = self.imgs_LQ[folder].index_select(0, torch.LongTensor(select_idx))
            img_GT = self.imgs_GT[folder][idx]


        else:
            select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'],
                                               padding=self.opt['padding'])

            imgs_LQ_path = []
            for v in select_idx:
                imgs_LQ_path.append(self.imgs_LQ[folder][v])
            img_GT_path = self.imgs_GT[folder][idx]

            # read gt
            img_GT = util.read_img(None, img_GT_path)
            img_LQ_l = []
            for v in imgs_LQ_path:
                img_LQ = util.read_img(None, v)
                img_LQ_l.append(img_LQ)

            # color conversion
            if self.opt['color'] == 'YUV':
                img_LQ_l = [util.bgr2ycbcr(v, only_y=False) for v in img_LQ_l]
                img_GT = util.bgr2ycbcr(img_GT, only_y=False)


            # stack LQ images to NHWC, N is the frame number
            img_LQs = np.stack(img_LQ_l, axis=0)
            # BGR to RGB
            img_GT = img_GT[:, :, [2, 1, 0]]  # HWC
            img_LQs = img_LQs[:, :, :, [2, 1, 0]]

            # HWC to CHW  numpy to tensor
            # if YUV -> VUY
            img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
            imgs_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,(0, 3, 1, 2)))).float()

        return {
            'LQs': imgs_LQ,
            'GT': img_GT,
            'folder': folder,
            'idx': self.data_info['idx'][index],
            'border': border
        }
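A minimal usage sketch for a dataset with this __getitem__ (construction of the dataset instance is repo-specific and omitted; `iterate` is an illustrative helper, not part of the original):

import torch.utils.data as data

def iterate(test_set):
    # Assuming `test_set` is an instance of the dataset class whose
    # __getitem__ is shown above.
    loader = data.DataLoader(test_set, batch_size=1, shuffle=False, num_workers=0)
    for batch in loader:
        # After batching, 'LQs' is (1, N, C, H, W) and 'GT' is (1, C, H, W).
        print(batch['folder'][0], batch['idx'][0],
              batch['LQs'].shape, batch['GT'].shape)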
Example #11
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options YAML file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)

    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))
    util.setup_logger('base',
                      opt['path']['log'],
                      'test_' + opt['name'],
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    #### Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    model = create_model(opt)
    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_GT = test_loader.dataset.opt['dataroot_GT'] is not None
            model.feed_data(data, need_GT=need_GT)
            img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
            img_name = osp.splitext(osp.basename(img_path))[0]

            model.test()
            visuals = model.get_current_visuals(need_GT=need_GT)

            sr_img = util.tensor2img(visuals['rlt'])  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = osp.join(dataset_dir,
                                         img_name + suffix + '.png')
            else:
                save_img_path = osp.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            # calculate PSNR and SSIM
            if need_GT:
                gt_img = util.tensor2img(visuals['GT'])
                sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                  opt['scale'])
                psnr = util.calculate_psnr(sr_img, gt_img)
                ssim = util.calculate_ssim(sr_img, gt_img)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)

                    psnr_y = util.calculate_psnr(sr_img_y * 255,
                                                 gt_img_y * 255)
                    ssim_y = util.calculate_ssim(sr_img_y * 255,
                                                 gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                            img_name, psnr, ssim))
            else:
                logger.info(img_name)

        if need_GT:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info(
                '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'
                .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info(
                    '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'
                    .format(ave_psnr_y, ave_ssim_y))
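For orientation, the parsed options dict this script consumes would look roughly like the following (hypothetical values; only keys the code above actually reads are shown, and the real schema has more fields):

opt = {
    'name': 'RRDB_test',   # hypothetical test run name
    'suffix': None,        # optional filename suffix for saved images
    'scale': 4,
    'datasets': {
        'test_1': {'name': 'Set5',
                   'dataroot_GT': '../datasets/Set5/GT',
                   'dataroot_LQ': '../datasets/Set5/LRx4'}},
    'path': {'log': '../experiments/test',
             'results_root': '../results'},
}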
Example #12
def main():
    #################
    # configurations
    #################
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'sharp_bicubic'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth')
        else:
            raise ValueError('Vid4 does not support stage 2.')
    elif data_mode == 'sharp_bicubic':
        if stage == 1:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_SR_L.pth')
        else:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth')
    elif data_mode == 'blur_bicubic':
        if stage == 1:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth')
        else:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth')
    elif data_mode == 'blur':
        if stage == 1:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth')
        else:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth')
    elif data_mode == 'blur_comp':
        if stage == 1:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth')
        else:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth')
    else:
        raise NotImplementedError
    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5
    predeblur, HR_in = False, False
    back_RBs = 40
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20
    model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = osp.join(root, '../datasets/Vid4/BIx4/*')
        GT_dataset_folder = osp.join(root, '../datasets/Vid4/GT/*')
    else:
        if stage == 1:
            test_dataset_folder = osp.join(root, f'../datasets/REDS4/{data_mode}/*')
        else:
            raise ValueError('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = osp.join(root, '../datasets/REDS4/GT/*')

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    device = torch.device('cuda')
    save_folder = f'../results/{data_mode}'
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info(f'Data: {data_mode} - {test_dataset_folder}')
    logger.info(f'Padding mode: {padding}')
    logger.info(f'Model path: {model_path}')
    logger.info(f'Save images: {save_imgs}')
    logger.info(f'Flip test: {flip_test}')

    def read_image(img_path):
        '''read one image from img_path
        Return img: HWC, BGR, [0,1], numpy
        '''
        img = cv2.imread(img_path)
        return img.astype(np.float32) / 255.

    def read_seq_imgs(img_seq_path):
        '''read a sequence of images'''
        img_path_l = sorted(glob.glob(img_seq_path + '/*'))
        img_l = [read_image(v) for v in img_path_l]
        # stack to TCHW, RGB, [0,1], torch
        imgs = np.stack(img_l, axis=0)
        imgs = imgs[:, :, :, [2, 1, 0]]
        imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
        return imgs

    def index_generation(crt_i, max_n, N, padding='reflection'):
        '''
        padding: replicate | reflection | new_info | circle
        '''
        max_n = max_n - 1
        n_pad = N // 2
        return_l = []

        for i in range(crt_i - n_pad, crt_i + n_pad + 1):
            if i < 0:
                if padding == 'replicate':
                    add_idx = 0
                elif padding == 'reflection':
                    add_idx = -i
                elif padding == 'new_info':
                    add_idx = (crt_i + n_pad) + (-i)
                elif padding == 'circle':
                    add_idx = N + i
                else:
                    raise ValueError('Wrong padding mode')
            elif i > max_n:
                if padding == 'replicate':
                    add_idx = max_n
                elif padding == 'reflection':
                    add_idx = max_n * 2 - i
                elif padding == 'new_info':
                    add_idx = (crt_i - n_pad) - (i - max_n)
                elif padding == 'circle':
                    add_idx = i - N
                else:
                    raise ValueError('Wrong padding mode')
            else:
                add_idx = i
            return_l.append(add_idx)
        return return_l
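    # illustration (not in the original script): with N=5 frames and a clip of
    # 10 images, index_generation(0, 10, 5, padding=...) yields
    #   replicate:  [0, 0, 0, 1, 2]
    #   reflection: [2, 1, 0, 1, 2]
    #   new_info:   [4, 3, 0, 1, 2]
    #   circle:     [3, 4, 0, 1, 2]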

    def single_forward(model, imgs_in):
        with torch.no_grad():
            model_output = model(imgs_in)
            if isinstance(model_output, list) or isinstance(model_output, tuple):
                output = model_output[0]
            else:
                output = model_output
        return output

    sub_folder_l = sorted(glob.glob(test_dataset_folder))
    sub_folder_GT_l = sorted(glob.glob(GT_dataset_folder))
    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    sub_folder_name_l = []

    # for each sub-folder
    for sub_folder, sub_folder_GT in zip(sub_folder_l, sub_folder_GT_l):
        sub_folder_name = sub_folder.split('/')[-1]
        sub_folder_name_l.append(sub_folder_name)
        save_sub_folder = osp.join(save_folder, sub_folder_name)

        img_path_l = sorted(glob.glob(sub_folder + '/*'))
        max_idx = len(img_path_l)

        if save_imgs:
            util.mkdirs(save_sub_folder)

        #### read LR images
        imgs = read_seq_imgs(sub_folder)
        #### read GT images
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(sub_folder_GT, '*'))):
            img_GT_l.append(read_image(img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
        cal_n_border, cal_n_center = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            c_idx = int(osp.splitext(osp.basename(img_path))[0])
            select_idx = index_generation(c_idx, max_idx, N_in, padding=padding)
            # get input images
            imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
            output = single_forward(model, imgs_in)
            output_f = output.data.float().cpu().squeeze(0)

            if flip_test:
                # flip W
                output = single_forward(model, torch.flip(imgs_in, (-1, )))
                output = torch.flip(output, (-1, ))
                output = output.data.float().cpu().squeeze(0)
                output_f = output_f + output
                # flip H
                output = single_forward(model, torch.flip(imgs_in, (-2, )))
                output = torch.flip(output, (-2, ))
                output = output.data.float().cpu().squeeze(0)
                output_f = output_f + output
                # flip both H and W
                output = single_forward(model, torch.flip(imgs_in, (-2, -1)))
                output = torch.flip(output, (-2, -1))
                output = output.data.float().cpu().squeeze(0)
                output_f = output_f + output

                output_f = output_f / 4
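                # (averages the four augmented predictions: identity, W-flip,
                # H-flip, and HW-flip; a standard x4 self-ensemble)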

            output = util.tensor2img(output_f)

            # save imgs
            if save_imgs:
                cv2.imwrite(osp.join(save_sub_folder, f'{c_idx:08d}.png'), output)

            #### calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT)
                output = data_util.bgr2ycbcr(output)
            if crop_border == 0:
                cropped_output = output
                cropped_GT = GT
            else:
                cropped_output = output[crop_border:-crop_border, crop_border:-crop_border]
                cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border]
            crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255)
            logger.info(f'{img_idx+1:3d} - {c_idx:25}.png \tPSNR: {crt_psnr:.6f} dB')

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                cal_n_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                cal_n_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border)
        avg_psnr_center = avg_psnr_center / cal_n_center
        if cal_n_border == 0:
            avg_psnr_border = 0
        else:
            avg_psnr_border = avg_psnr_border / cal_n_border

        logger.info(f'Folder {sub_folder_name} - Average PSNR: {avg_psnr:.6f} dB for {(cal_n_center + cal_n_border)} frames; '
                    f'Center PSNR: {avg_psnr_center:.6f} dB for {cal_n_center} frames; '
                    f'Border PSNR: {avg_psnr_border:.6f} dB for {cal_n_border} frames.')

        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

    logger.info('################ Tidy Outputs ################')
    for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l,
                                                    avg_psnr_center_l, avg_psnr_border_l):
        logger.info(f'Folder {name} - Average PSNR: {psnr:.6f} dB. '
                    f'Center PSNR: {psnr_center:.6f} dB. '
                    f'Border PSNR: {psnr_border:.6f} dB.')
    logger.info('################ Final Results ################')
    logger.info(f'Data: {data_mode} - {test_dataset_folder}')
    logger.info(f'Padding mode: {padding}')
    logger.info(f'Model path: {model_path}')
    logger.info(f'Save images: {save_imgs}')
    logger.info(f'Flip test: {flip_test}')
    logger.info(f'Total Average PSNR: {sum(avg_psnr_l) / len(avg_psnr_l):.6f} dB for {len(sub_folder_l)} clips. '
                f'Center PSNR: {sum(avg_psnr_center_l) / len(avg_psnr_center_l):.6f} dB. '
                f'Border PSNR: {sum(avg_psnr_border_l) / len(avg_psnr_border_l):.6f} dB.')
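Both evaluation loops above delegate to util.calculate_psnr. As a minimal sketch of what that helper presumably computes, assuming inputs scaled to [0, 255] (the real util module may differ):

import math
import numpy as np

def calculate_psnr(img1, img2):
    # PSNR in dB between two images in the [0, 255] range
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))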
Example #13
def cal_pnsr_ssim(sr_img, gt_img, lr_img, lrgt_img):
    # note: relies on globals defined by the caller (opt, dataset_dir, folder,
    # img_name, test_results, writer, logger)
    # save images
    suffix = opt['suffix']
    if suffix:
        save_img_path = osp.join(dataset_dir, folder, img_name + suffix + '.png')
    else:
        save_img_path = osp.join(dataset_dir, folder, img_name + '.png')
    util.save_img(sr_img, save_img_path)
    # (saving the GT image alongside the SR output is disabled)
    if suffix:
        save_img_path = osp.join(dataset_dir, folder, img_name + suffix + '_LR.png')
    else:
        save_img_path = osp.join(dataset_dir, folder, img_name + '_LR.png')
    util.save_img(lr_img, save_img_path)
    # (saving the LR reference image is disabled)

    # calculate PSNR and SSIM
    gt_img = gt_img / 255.
    sr_img = sr_img / 255.

    lr_img = lr_img / 255.
    lrgt_img = lrgt_img / 255.

    crop_border = opt['crop_border'] if opt['crop_border'] else opt['scale']
    if crop_border == 0:
        cropped_sr_img = sr_img
        cropped_gt_img = gt_img
    else:
        cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :]
        cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]

    psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
    ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)
    test_results['psnr'].append(psnr)
    test_results['ssim'].append(ssim)

    # PSNR and SSIM for LR
    psnr_lr = util.calculate_psnr(lr_img * 255, lrgt_img * 255)
    ssim_lr = util.calculate_ssim(lr_img * 255, lrgt_img * 255)
    test_results['psnr_lr'].append(psnr_lr)
    test_results['ssim_lr'].append(ssim_lr)

    if gt_img.shape[2] == 3:  # RGB image
        sr_img_y = bgr2ycbcr(sr_img, only_y=True)
        gt_img_y = bgr2ycbcr(gt_img, only_y=True)
        if crop_border == 0:
            cropped_sr_img_y = sr_img_y
            cropped_gt_img_y = gt_img_y
        else:
            cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border]
            cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border]
        psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
        ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
        test_results['psnr_y'].append(psnr_y)
        test_results['ssim_y'].append(ssim_y)

        lr_img_y = bgr2ycbcr(lr_img, only_y=True)
        lrgt_img_y = bgr2ycbcr(lrgt_img, only_y=True)
        psnr_y_lr = util.calculate_psnr(lr_img_y * 255, lrgt_img_y * 255)
        ssim_y_lr = util.calculate_ssim(lr_img_y * 255, lrgt_img_y * 255)
        test_results['psnr_y_lr'].append(psnr_y_lr)
        test_results['ssim_y_lr'].append(ssim_y_lr)

        writer.writerow([osp.join(folder, img_name), psnr_y, psnr_y_lr, ssim_y, ssim_y_lr])
        logger.info(
            '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}. LR PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'.
                format(osp.join(folder, img_name), psnr, ssim, psnr_y, ssim_y, psnr_lr, ssim_lr, psnr_y_lr, ssim_y_lr))
    else:
        writer.writerow([osp.join(folder, img_name), psnr, psnr_lr])
        logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}. LR PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
            osp.join(folder, img_name), psnr, ssim, psnr_lr, ssim_lr))

    return test_results
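The Y-channel metrics in these examples all go through bgr2ycbcr. A minimal sketch of its luma-only path, assuming HWC float input in [0, 1] with BGR channel order and the BT.601 coefficients such repositories conventionally use:

import numpy as np

def bgr2ycbcr(img, only_y=True):
    # sketch: img is HWC float32 in [0, 1], BGR channel order
    assert only_y, 'this sketch only covers the Y-channel path'
    y = img @ np.array([24.966, 128.553, 65.481]) + 16.0  # BT.601 luma
    return y / 255.0  # back to [0, 1]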
Example #14
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to options file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items() if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)
    chop2 = opt['chop']
    chop_patch_size = opt['chop_patch_size']
    multi_upscale = opt['multi_upscale']
    scale = opt['scale']

    util.setup_logger(None, opt['path']['log'], 'test.log', level=logging.INFO, screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))
    # Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)
        # (znorm is read per-dataset inside the test loop below)
            
    # Create model
    model = create_model(opt)
    
    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_HR = test_loader.dataset.opt['dataroot_HR'] is not None
            img_path = data['LR_path'][0]  # the test loader yields one image per batch
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            znorm = test_loader.dataset.opt['znorm']
            
            if chop2:
                lowres_img = data['LR'] #.to('cuda')
                if multi_upscale:  # upscale 8 rotated/flipped variants and average the results
                    # geometric augmentations (require PyTorch > 0.4.1)
                    LR_90 = lowres_img.transpose(2, 3).flip(2)
                    LR_180 = LR_90.transpose(2, 3).flip(2)
                    LR_270 = LR_180.transpose(2, 3).flip(2)
                    LR_f = lowres_img.flip(3)  # horizontal mirror, dim=3 of (B,C,H,W)
                    LR_90f = LR_90.flip(3)
                    LR_180f = LR_180.flip(3)
                    LR_270f = LR_270.flip(3)
                    
                    pred = chop_forward2(lowres_img, model, scale=scale, patch_size=chop_patch_size)
                    pred_90 = chop_forward2(LR_90, model, scale=scale, patch_size=chop_patch_size)
                    pred_180 = chop_forward2(LR_180, model, scale=scale, patch_size=chop_patch_size)
                    pred_270 = chop_forward2(LR_270, model, scale=scale, patch_size=chop_patch_size)
                    pred_f = chop_forward2(LR_f, model, scale=scale, patch_size=chop_patch_size)
                    pred_90f = chop_forward2(LR_90f, model, scale=scale, patch_size=chop_patch_size)
                    pred_180f = chop_forward2(LR_180f, model, scale=scale, patch_size=chop_patch_size)
                    pred_270f = chop_forward2(LR_270f, model, scale=scale, patch_size=chop_patch_size)
                    
                    #convert to numpy array
                    if znorm:  # image range is [-1, 1]
                        pred = util.tensor2img(pred, min_max=(-1, 1)).clip(0, 255)  # uint8
                        pred_90 = util.tensor2img(pred_90, min_max=(-1, 1)).clip(0, 255)
                        pred_180 = util.tensor2img(pred_180, min_max=(-1, 1)).clip(0, 255)
                        pred_270 = util.tensor2img(pred_270, min_max=(-1, 1)).clip(0, 255)
                        pred_f = util.tensor2img(pred_f, min_max=(-1, 1)).clip(0, 255)
                        pred_90f = util.tensor2img(pred_90f, min_max=(-1, 1)).clip(0, 255)
                        pred_180f = util.tensor2img(pred_180f, min_max=(-1, 1)).clip(0, 255)
                        pred_270f = util.tensor2img(pred_270f, min_max=(-1, 1)).clip(0, 255)
                    else:  # default: image range is [0, 1]
                        pred = util.tensor2img(pred).clip(0, 255)  # uint8
                        pred_90 = util.tensor2img(pred_90).clip(0, 255)  # uint8
                        pred_180 = util.tensor2img(pred_180).clip(0, 255)  # uint8
                        pred_270 = util.tensor2img(pred_270).clip(0, 255)  # uint8
                        pred_f = util.tensor2img(pred_f).clip(0, 255)  # uint8
                        pred_90f = util.tensor2img(pred_90f).clip(0, 255)  # uint8
                        pred_180f = util.tensor2img(pred_180f).clip(0, 255)  # uint8
                        pred_270f = util.tensor2img(pred_270f).clip(0, 255)  # uint8
                    
                    pred_90 = np.rot90(pred_90, 3)
                    pred_180 = np.rot90(pred_180, 2)
                    pred_270 = np.rot90(pred_270, 1)
                    pred_f = np.fliplr(pred_f)
                    pred_90f = np.rot90(np.fliplr(pred_90f), 3)
                    pred_180f = np.rot90(np.fliplr(pred_180f), 2)
                    pred_270f = np.rot90(np.fliplr(pred_270f), 1)
                    
                    # uint8 arrays would overflow when summed (values wrap past
                    # 255), so cast to float before averaging and convert the
                    # result back to uint8
                    sr_img = (pred.astype('float') + pred_90.astype('float') + pred_180.astype('float')
                              + pred_270.astype('float') + pred_f.astype('float') + pred_90f.astype('float')
                              + pred_180f.astype('float') + pred_270f.astype('float')) / 8.0
                    sr_img = sr_img.astype('uint8')
                    
                else:
                    highres_output = chop_forward2(lowres_img, model, scale=scale, patch_size=chop_patch_size)
                
                    #convert to numpy array
                    #highres_image = highres_output[0].permute(1, 2, 0).clamp(0.0, 1.0).cpu()
                    if znorm:  # image range is [-1, 1]
                        sr_img = util.tensor2img(highres_output, min_max=(-1, 1))  # uint8
                    else:  # default: image range is [0, 1]
                        sr_img = util.tensor2img(highres_output)  # uint8
            
            else: # will upscale each image in the batch without chopping 
                model.feed_data(data, need_HR=need_HR)
                model.test()  # test
                visuals = model.get_current_visuals(need_HR=need_HR)
                
                if znorm:  # image range is [-1, 1]
                    sr_img = util.tensor2img(visuals['SR'], min_max=(-1, 1))  # uint8
                else:  # default: image range is [0, 1]
                    sr_img = util.tensor2img(visuals['SR'])  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = os.path.join(dataset_dir, img_name + suffix + '.png')
            else:
                save_img_path = os.path.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            # calculate PSNR and SSIM
            if need_HR:
                if znorm:  # image range is [-1, 1]
                    gt_img = util.tensor2img(visuals['HR'], min_max=(-1, 1))  # uint8
                else:  # default: image range is [0, 1]
                    gt_img = util.tensor2img(visuals['HR'])  # uint8
                gt_img = gt_img / 255.
                sr_img = sr_img / 255.

                crop_border = test_loader.dataset.opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'\
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
            else:
                logger.info(img_name)

        if need_HR:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'\
                    .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'\
                    .format(ave_psnr_y, ave_ssim_y))
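Example #14 relies on a chop_forward2 helper that is not shown. A minimal sketch of the idea, under the assumption that it tiles the low-resolution input, upscales each tile, and stitches the results (the real helper likely overlaps tiles to hide seams, and may recurse):

import torch

def chop_forward2(x, model, scale=4, patch_size=128):
    # x: B x C x H x W low-resolution tensor (sketch, assumptions as above)
    b, c, h, w = x.shape
    out = torch.zeros(b, c, h * scale, w * scale)
    for top in range(0, h, patch_size):
        for left in range(0, w, patch_size):
            tile = x[:, :, top:top + patch_size, left:left + patch_size]
            with torch.no_grad():
                sr = model(tile)
            out[:, :,
                top * scale:(top + tile.shape[2]) * scale,
                left * scale:(left + tile.shape[3]) * scale] = sr.float().cpu()
    return out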
Example #15
def main():
    #################
    # configurations
    #################
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--gt_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--gpu_id", type=str, required=True)
    parser.add_argument("--screen_notation", type=str, required=True)
    parser.add_argument('--opt', type=str, required=True, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=False)

    PAD = 32

    total_run_time = AverageMeter()
    # print("GPU ", torch.cuda.device_count())

    device = torch.device('cuda')
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    print('export CUDA_VISIBLE_DEVICES=' + str(args.gpu_id))

    data_mode = 'sharp_bicubic'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False

    # Input_folder = "/DATA7_DB7/data/4khdr/data/Dataset/train_sharp_bicubic"
    # GT_folder = "/DATA7_DB7/data/4khdr/data/Dataset/train_4k"
    # Result_folder = "/DATA7_DB7/data/4khdr/data/Results"

    Input_folder = args.input_path
    GT_folder = args.gt_path
    Result_folder = args.output_path
    Model_path = args.model_path

    # create results folder
    os.makedirs(Result_folder, exist_ok=True)

    ############################################################################
    #### model
    # (the hard-coded per-data_mode pretrained-model paths are disabled here;
    # the checkpoint path comes from the --model_path argument instead)

    model_path = Model_path

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 40
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20

    model = EDVR_arch.EDVR(nf=opt['network_G']['nf'], nframes=opt['network_G']['nframes'],
                          groups=opt['network_G']['groups'], front_RBs=opt['network_G']['front_RBs'],
                          back_RBs=opt['network_G']['back_RBs'],
                          predeblur=opt['network_G']['predeblur'], HR_in=opt['network_G']['HR_in'],
                          w_TSA=opt['network_G']['w_TSA'])

    # model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
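    # for reference, an assumed shape of the options file's network_G section,
    # inferred from the keys read above (the actual YAML may differ):
    #   network_G: {nf: 64, nframes: 5, groups: 8, front_RBs: 5, back_RBs: 40,
    #               predeblur: false, HR_in: false, w_TSA: true}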

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        if stage == 1:
            # test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
            # test_dataset_folder = '/DATA/wangshen_data/REDS/val_sharp_bicubic/X4'
            test_dataset_folder = Input_folder
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        # GT_dataset_folder = '../datasets/REDS4/GT'
        # GT_dataset_folder = '/DATA/wangshen_data/REDS/val_sharp'
        GT_dataset_folder = GT_folder

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    # save_folder = '../results/{}'.format(data_mode)
    # save_folder = '/DATA/wangshen_data/REDS/results/{}'.format(data_mode)
    save_folder = os.path.join(Result_folder, data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    # for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):

    end = time.time()

    # load screen change notation
    import json
    with open(args.screen_notation) as f:
        frame_notation = json.load(f)

    for subfolder in subfolder_l:

        input_subfolder = os.path.split(subfolder)[1]

        subfolder_GT = os.path.join(GT_dataset_folder,input_subfolder)

        if not os.path.exists(subfolder_GT):
            continue

        print("Evaluate Folders: ", input_subfolder)


        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)  # Num x 3 x H x W
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l[:5]):  # note: only the first 5 frames are processed (likely a debug leftover)
            img_name = osp.splitext(osp.basename(img_path))[0]

            # todo here handle screen change
            select_idx = data_util.index_generation_process_screen_change(input_subfolder, frame_notation, img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)  # 960 x 540


            # here we split the input images 960x540 into 9 320x180 patch
            gtWidth = 3840
            gtHeight = 2160
            intWidth_ori = imgs_in.shape[4] # 960
            intHeight_ori = imgs_in.shape[3] # 540
            split_lengthY = 180
            split_lengthX = 320
            scale = 4

            intPaddingRight_ = int(float(intWidth_ori) / split_lengthX + 1) * split_lengthX - intWidth_ori
            intPaddingBottom_ = int(float(intHeight_ori) / split_lengthY + 1) * split_lengthY - intHeight_ori

            intPaddingRight_ = 0 if intPaddingRight_ == split_lengthX else intPaddingRight_
            intPaddingBottom_ = 0 if intPaddingBottom_ == split_lengthY else intPaddingBottom_
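            # illustration (assumed sizes): for a 960x540 input with 320x180
            # tiles, both dimensions divide evenly, so the two lines above
            # collapse the "+1 tile" padding back to 0; odd sizes would keep a
            # positive pad instead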

            pader0 = torch.nn.ReplicationPad2d([0, intPaddingRight_, 0, intPaddingBottom_])
            print("Init pad right/bottom " + str(intPaddingRight_) + " / " + str(intPaddingBottom_))

            intPaddingRight = PAD  # 32 (alternatives tried: 64 / 128 / 256)
            intPaddingLeft = PAD
            intPaddingTop = PAD
            intPaddingBottom = PAD

            pader = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom])

            imgs_in = torch.squeeze(imgs_in, 0)# N C H W

            imgs_in = pader0(imgs_in)  # N C 540 960

            imgs_in = pader(imgs_in)  # N C 604 1024

            assert (split_lengthY == int(split_lengthY) and split_lengthX == int(split_lengthX))
            split_lengthY = int(split_lengthY)
            split_lengthX = int(split_lengthX)
            split_numY = int(float(intHeight_ori) / split_lengthY )
            split_numX = int(float(intWidth_ori) / split_lengthX)
            splitsY = range(0, split_numY)
            splitsX = range(0, split_numX)

            intWidth = split_lengthX
            intWidth_pad = intWidth + intPaddingLeft + intPaddingRight
            intHeight = split_lengthY
            intHeight_pad = intHeight + intPaddingTop + intPaddingBottom

            # print("split " + str(split_numY) + ' , ' + str(split_numX))
            y_all = np.zeros((gtHeight, gtWidth, 3), dtype="float32")  # HWC
            for split_j, split_i in itertools.product(splitsY, splitsX):
                # print(str(split_j) + ", \t " + str(split_i))
                X0 = imgs_in[:, :,
                     split_j * split_lengthY:(split_j + 1) * split_lengthY + intPaddingBottom + intPaddingTop,
                     split_i * split_lengthX:(split_i + 1) * split_lengthX + intPaddingRight + intPaddingLeft]


                X0 = torch.unsqueeze(X0, 0)  # N C H W -> 1 N C H W

                if flip_test:
                    output = util.flipx4_forward(model, X0)
                else:
                    output = util.single_forward(model, X0)

                output_depadded = output[0, :, intPaddingTop * scale :(intPaddingTop+intHeight) * scale, intPaddingLeft * scale: (intPaddingLeft+intWidth)*scale]
                output_depadded = output_depadded.squeeze(0)
                output = util.tensor2img(output_depadded)


                y_all[split_j * split_lengthY * scale :(split_j + 1) * split_lengthY * scale,
                      split_i * split_lengthX * scale :(split_i + 1) * split_lengthX * scale, :] = \
                        np.round(output).astype(np.uint8)



            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), y_all)

            print("*****************current image process time \t " + str(
                time.time() - end) + "s ******************")
            total_run_time.update(time.time() - end, 1)

            # calculate PSNR
            y_all = y_all / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                y_all = data_util.bgr2ycbcr(y_all, only_y=True)

            y_all, GT = util.crop_border([y_all, GT], crop_border)
            crt_psnr = util.calculate_psnr(y_all * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        avg_psnr_center = avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
                                                                   (N_center + N_border),
                                                                   avg_psnr_center, N_center,
                                                                   avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l,
                                                              avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
                                                     psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
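Example #15 times each frame with an AverageMeter, which is not defined in these snippets. A minimal sketch consistent with the total_run_time.update(value, n) calls above:

class AverageMeter:
    """Tracks the latest value and a running average (sketch)."""
    def __init__(self):
        self.val = self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count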
Example #16
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, default='options/test/test_ppon.json', help='Path to options JSON file.')

    opt = option.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items() if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None, opt['path']['log'], 'test.log', level=logging.INFO, screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))
    # Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

            model.feed_data(data, need_HR=need_HR)
            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)
            
            img_c = util.tensor2img(visuals['img_c'])  # uint8
            img_s = util.tensor2img(visuals['img_s'])  # uint8
            img_p = util.tensor2img(visuals['img_p'])  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_c_img_path = os.path.join(dataset_dir, img_name + suffix + '_c.png')
                save_s_img_path = os.path.join(dataset_dir, img_name + suffix + '_s.png')
                save_p_img_path = os.path.join(dataset_dir, img_name + suffix + '_p.png')                
            else:
                save_c_img_path = os.path.join(dataset_dir, img_name + '_c.png')
                save_s_img_path = os.path.join(dataset_dir, img_name + '_s.png')
                save_p_img_path = os.path.join(dataset_dir, img_name + '_p.png')
            
            util.save_img(img_c, save_c_img_path)
            util.save_img(img_s, save_s_img_path)
            util.save_img(img_p, save_p_img_path)
            

            # calculate PSNR and SSIM
            if need_HR:
                gt_img = util.tensor2img(visuals['HR'])
                gt_img = gt_img / 255.
                sr_img = img_c / 255.

                crop_border = test_loader.dataset.opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'\
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
            else:
                logger.info(img_name)

        if need_HR:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'\
                    .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'\
                    .format(ave_psnr_y, ave_ssim_y))
Example #17
    def __getitem__(self, index):

        scale = self.opt['scale']
        GT_size = self.opt['GT_size']
        key = self.paths_GT[index]

        name_a = key.split('/')[-2]
        name_b = key.split('/')[-1][:-4]

        center_frame_idx = int(name_b)

        #### determine the neighbor frames
        interval = random.choice(self.interval_list)
        if self.opt['border_mode']:
            direction = 1  # 1: forward; 0: backward
            N_frames = self.opt['N_frames']
            if self.random_reverse and random.random() < 0.5:
                direction = random.choice([0, 1])
            if center_frame_idx + interval * (N_frames - 1) > 99:
                direction = 0
            elif center_frame_idx - interval * (N_frames - 1) < 0:
                direction = 1
            # get the neighbor list
            if direction == 1:
                neighbor_list = list(
                    range(center_frame_idx,
                          center_frame_idx + interval * N_frames, interval))
            else:
                neighbor_list = list(
                    range(center_frame_idx,
                          center_frame_idx - interval * N_frames, -interval))
            name_b = '{:05d}'.format(neighbor_list[0])
        else:
            # ensure not exceeding the borders

            while (center_frame_idx + self.half_N_frames * interval > 99) \
                    or (center_frame_idx - self.half_N_frames * interval < 0):
                center_frame_idx = random.randint(0, 99)
            # get the neighbor list
            neighbor_list = list(
                range(center_frame_idx - self.half_N_frames * interval,
                      center_frame_idx + self.half_N_frames * interval + 1,
                      interval))
            if self.random_reverse and random.random() < 0.5:
                neighbor_list.reverse()
            name_b = '{:05d}'.format(neighbor_list[self.half_N_frames])
            key = name_a + '_' + name_b  #todo
        assert len(neighbor_list) == self.opt[
            'N_frames'], 'Wrong length of neighbor list: {}'.format(
                len(neighbor_list))

        #### get the GT image (as the center frame)
        img_GT = util.read_img(None,
                               osp.join(self.GT_root, name_a, name_b + '.png'))

        #### get LQ images
        LQ_size_tuple = (3, 540, 960) if self.LR_input else (3, 2160, 3840)
        img_LQ_l = []
        for v in neighbor_list:
            img_LQ_path = osp.join(self.LQ_root, name_a,
                                   '{:05d}.png'.format(v))
            img_LQ = util.read_img(None, img_LQ_path)
            img_LQ_l.append(img_LQ)

        assert key == name_a + '_' + name_b

        if self.opt['phase'] == 'train':
            C, H, W = LQ_size_tuple  # LQ size
            # randomly crop
            if self.LR_input:
                LQ_size = GT_size // scale

                rnd_h = random.randint(0, max(0, H - LQ_size))
                rnd_w = random.randint(0, max(0, W - LQ_size))

                img_LQ_l = [
                    v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
                    for v in img_LQ_l
                ]
                rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
                img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size,
                                rnd_w_HR:rnd_w_HR + GT_size, :]
            else:
                rnd_h = random.randint(0, max(0, H - GT_size))
                rnd_w = random.randint(0, max(0, W - GT_size))
                img_LQ_l = [
                    v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
                    for v in img_LQ_l
                ]
                img_GT = img_GT[rnd_h:rnd_h + GT_size,
                                rnd_w:rnd_w + GT_size, :]

            # augmentation - flip, rotate
            img_LQ_l.append(img_GT)
            rlt = util.augment(img_LQ_l, self.opt['use_flip'],
                               self.opt['use_rot'])
            img_LQ_l = rlt[0:-1]
            img_GT = rlt[-1]

            #color
            if self.opt['color'] == 'YUV':
                img_LQ_l = [util.bgr2ycbcr(v, only_y=False) for v in img_LQ_l]
                img_GT = util.bgr2ycbcr(img_GT, only_y=False)

        # stack LQ images to NHWC, N is the frame number
        img_LQs = np.stack(img_LQ_l, axis=0)
        # BGR to RGB
        img_GT = img_GT[:, :, [2, 1, 0]]  # HWC
        img_LQs = img_LQs[:, :, :, [2, 1, 0]]

        # HWC to CHW, numpy to tensor
        # (if YUV, the BGR-to-RGB swap above makes the channel order VUY)
        img_GT = torch.from_numpy(
            np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
        img_LQs = torch.from_numpy(
            np.ascontiguousarray(np.transpose(img_LQs, (0, 3, 1, 2)))).float()

        return {'LQs': img_LQs, 'GT': img_GT, 'key': key}
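To make the temporal sampling above concrete (assumed values): with N_frames = 5 (so half_N_frames = 2), interval = 2 and a center frame of 50, the non-border branch selects

half_N_frames, interval, center = 2, 2, 50
neighbor_list = list(range(center - half_N_frames * interval,
                           center + half_N_frames * interval + 1,
                           interval))
print(neighbor_list)  # [46, 48, 50, 52, 54]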
Example #18
                    cropped_sr_img = sr_img[crop_border:-crop_border,
                                            crop_border:-crop_border]
                    cropped_gt_img = gt_img[crop_border:-crop_border,
                                            crop_border:-crop_border]
                else:
                    cropped_sr_img = sr_img[crop_border:-crop_border,
                                            crop_border:-crop_border, :]
                    cropped_gt_img = gt_img[crop_border:-crop_border,
                                            crop_border:-crop_border, :]
            psnr = util.psnr(cropped_sr_img, cropped_gt_img)
            ssim = util.ssim(cropped_sr_img, cropped_gt_img, multichannel=True)

            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            if len(gt_img.shape) >= 3:  # RGB image
                cropped_sr_img_y = bgr2ycbcr(cropped_sr_img, only_y=True)
                cropped_gt_img_y = bgr2ycbcr(cropped_gt_img, only_y=True)
                psnr_y = util.psnr(cropped_sr_img_y, cropped_gt_img_y)
                ssim_y = util.ssim(cropped_sr_img_y,
                                   cropped_gt_img_y,
                                   multichannel=False)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
                print('{:20s} - PSNR: {:.4f} dB; SSIM: {:.4f}; PSNR_Y: {:.4f} dB; SSIM_Y: {:.4f}.'\
                    .format(img_name, psnr, ssim, psnr_y, ssim_y))
            else:
                print('{:20s} - PSNR: {:.4f} dB; SSIM: {:.4f}.'.format(
                    img_name, psnr, ssim))
        else:
            print(img_name)
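
For reference, a hedged sketch of the metric used above; util.psnr is assumed here to follow the standard PSNR definition for images in a [0, max_val] range:

import numpy as np

def psnr(img1, img2, max_val=255.0):
    # standard PSNR; an assumption about what util.psnr computes
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(max_val / np.sqrt(mse))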
Exemple #19
def main():
    #################
    # configurations
    #################
    stage = 1
    device = torch.device("cuda")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    logger = logging.getLogger("base")

    # tensorboard
    # tb_logger = SummaryWriter(log_dir="../tb_logger/" + "vinhlong_040719_1212")

    data_mode = "licensePlate_blur_bicubic_EDVRstock"
    flip_test = False
    model_path = "/content/EDVR/EDVRstock.pth"
    nf = 64
    N_in = 5
    predeblur, HR_in = False, False
    back_RBs = 10
    test_dataset_folder = "/content/EDVR/datasets/license_plate5/BI_x4"
    GT_dataset_folder = "/content/EDVR/datasets/license_plate5/GT"

    if stage == 2:
        model_path = ("/content/EDVR/experiments/" +
                      "pretrained_models/EDVR_REDS_deblur_Stage2.pth")
        nf = 128
        predeblur, HR_in = True, True
        back_RBs = 20
        test_dataset_folder = "/content/EDVR/datasets/licensePlate_blur_bicubic"
        GT_dataset_folder = ""

    model = EDVR_arch.EDVR(nf,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluating
    # temporal padding mode
    if data_mode == "Vid4" or data_mode == "sharp_bicubic":
        padding = "new_info"
    else:
        padding = "replicate"
    save_imgs = True

    # Reconfigure the logger handlers
    save_folder = "../results/{}".format(data_mode)
    # remove all old handlers to avoid duplicate log lines
    for hdlr in logger.handlers[:]:
        logger.removeHandler(hdlr)
    util.mkdirs(save_folder)
    util.setup_logger("base",
                      save_folder,
                      "test",
                      level=logging.INFO,
                      screen=True,
                      tofile=True)

    #### log info
    logger.info("Data: {} - {}".format(data_mode, test_dataset_folder))
    logger.info("Padding mode: {}".format(padding))
    logger.info("Model path: {}".format(model_path))
    logger.info("Save images: {}".format(save_imgs))
    logger.info("Flip test: {}".format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, "*")))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, "*")))
    isGT = bool(GT_dataset_folder)

    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, "*")))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = test_util.read_img_seq(subfolder)
        img_GT_l = []
        if isGT:
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, "*"))):
                img_GT_l.append(test_util.read_img(img_GT_path))

            avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = test_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = (imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device))

            if flip_test:
                output = test_util.flipx4_forward(model, imgs_in)
            else:
                output = test_util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, "{}.png".format(img_name)),
                    output)

            if isGT:
                # calculate PSNR
                output = output / 255.0
                GT = np.copy(img_GT_l[img_idx])
                # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
                if data_mode == "Vid4":  # bgr2y, [0, 1]
                    GT = data_util.bgr2ycbcr(GT, only_y=True)
                    output = data_util.bgr2ycbcr(output, only_y=True)

                output, GT = test_util.crop_border([output, GT], crop_border)
                crt_psnr = util.calculate_psnr(output * 255, GT * 255)
                """ logger.info(
                    "{:3d} - {:25} \tPSNR: {:.6f} dB".format(
                        img_idx + 1, img_name, crt_psnr
                    )
                ) """

                if (img_idx >= border_frame
                        and img_idx < max_idx - border_frame):  # center frames
                    avg_psnr_center += crt_psnr
                    N_center += 1
                else:  # border frames
                    avg_psnr_border += crt_psnr
                    N_border += 1
            else:
                logger.info("{:3d} - {:25} is generated".format(
                    img_idx + 1, img_name))

        if isGT:
            avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center +
                                                              N_border)
            avg_psnr_center = avg_psnr_center / N_center
            avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
            avg_psnr_l.append(avg_psnr)
            avg_psnr_center_l.append(avg_psnr_center)
            avg_psnr_border_l.append(avg_psnr_border)

            logger.info("Folder {} - Average PSNR: {:.6f} dB for {} frames; "
                        "Center PSNR: {:.6f} dB for {} frames; "
                        "Border PSNR: {:.6f} dB for {} frames.".format(
                            subfolder_name,
                            avg_psnr,
                            (N_center + N_border),
                            avg_psnr_center,
                            N_center,
                            avg_psnr_border,
                            N_border,
                        ))

    logger.info("################ Tidy Outputs ################")
    if isGT:
        for subfolder_name, psnr, psnr_center, psnr_border in zip(
                subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
                avg_psnr_border_l):
            logger.info("Folder {} - Average PSNR: {:.6f} dB. "
                        "Center PSNR: {:.6f} dB. "
                        "Border PSNR: {:.6f} dB.".format(
                            subfolder_name, psnr, psnr_center, psnr_border))
    logger.info("################ Final Results ################")
    logger.info("Data: {} - {}".format(data_mode, test_dataset_folder))
    logger.info("Padding mode: {}".format(padding))
    logger.info("Model path: {}".format(model_path))
    logger.info("Save images: {}".format(save_imgs))
    logger.info("Flip test: {}".format(flip_test))
    logger.info("Is GT: {}".format(isGT))

    if isGT:
        logger.info("Total Average PSNR: {:.6f} dB for {} clips. "
                    "Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.".format(
                        sum(avg_psnr_l) / len(avg_psnr_l),
                        len(subfolder_l),
                        sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                        sum(avg_psnr_border_l) / len(avg_psnr_border_l),
                    ))
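
A small sketch of the center/border split used above: with N_in = 5, border_frame = 2, so in a 10-frame clip only indices 2..7 count as center frames:

N_in, max_idx = 5, 10
border_frame = N_in // 2
center = [i for i in range(max_idx)
          if border_frame <= i < max_idx - border_frame]
print(center)  # [2, 3, 4, 5, 6, 7]; the remaining indices are border frames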
Exemple #20
jpeg_compressor_8 = JPEG.JPEG(compress=True,
                              downsample_and_quantize=False,
                              chroma_mode=True,
                              block_size=8)
jpeg_compressor_16 = JPEG.JPEG(compress=True,
                               downsample_and_quantize=True,
                               downsample_only=True,
                               chroma_mode=True,
                               block_size=16)
jpeg_extractor = JPEG.JPEG(compress=False, chroma_mode=True, block_size=16)
jpeg_compressor_16.Set_Q_Table(torch.tensor(90))
jpeg_compressor_8.Set_Q_Table(torch.tensor(90))
jpeg_extractor.Set_Q_Table(torch.tensor(90))
rmse_NN, rmse_interp, rmse_NN_orig, rmse_JPEG = 0, 0, 0, 0
for im_name in tqdm(images_list):
    image = cv2.imread(os.path.join(dataset_folder, im_name))
    image = bgr2ycbcr(modcrop(image, 16), only_y=False).astype(float)
    im_shape = list(image.shape[:2])
    subsampled_chroma = np.array(image)[::2, ::2, 1:]
    recovered_image_NN = np.tile(
        np.expand_dims(np.expand_dims(subsampled_chroma, 2), 1),
        [1, 2, 1, 2, 1]).reshape(im_shape + [-1])
    recovered_image_NN = 255 * ycbcr2rgb(
        np.concatenate([np.expand_dims(image[..., 0], -1), recovered_image_NN],
                       -1) / 255)
    recovered_image_interp = cv2.resize(subsampled_chroma,
                                        tuple(im_shape[::-1]),
                                        interpolation=cv2.INTER_LINEAR)
    recovered_image_interp = 255 * ycbcr2rgb(
        np.concatenate(
            [np.expand_dims(image[..., 0], -1), recovered_image_interp], -1) /
        255)
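
A hedged, self-contained sketch of the 2x nearest-neighbor chroma upsampling trick above, on a dummy (H/2, W/2, 2) chroma plane:

import numpy as np

chroma = np.arange(8, dtype=float).reshape(2, 2, 2)  # (H/2, W/2, 2)
up = np.tile(chroma[:, None, :, None, :],
             [1, 2, 1, 2, 1]).reshape(4, 4, 2)       # (H, W, 2)
# every chroma sample is replicated over a 2x2 block of pixels
assert (up[0, 0] == chroma[0, 0]).all() and (up[1, 1] == chroma[0, 0]).all()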
Exemple #21
def main():
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'SDR_4bit'
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'SDR_4bit':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    else:
        raise NotImplementedError

    # use N_in input frames to restore one high bit-depth frame
    N_in = 5

    # predeblur: add a pre-deblur module for blurry input
    # HR_in: the input is already at HR size (downsampled inside the network)
    predeblur, HR_in = True, True
    back_RBs = 40
    if stage == 2:
        back_RBs = 20
    # EDVR(nf, nframes, deformable_groups, front_RBs, back_RBs,
    #      predeblur=..., HR_in=...)
    model = EDVR_arch.EDVR(128,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### dataset
    if stage == 1:
        test_dataset_folder = '../datasets/{}'.format(data_mode)
    else:
        test_dataset_folder = '../'
        print('You should modify the test_dataset_folder path for stage 2')
    GT_dataset_folder = '../datasets/SDR_10bit/'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluating
    # temporal padding mode
    padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LBD (low bit-depth) and GT images
        #### resize to avoid CUDA out-of-memory: 2160x3840 -> 720x1280
        imgs_LBD = data_util.read_img_seq(subfolder,
                                          scale=65535.,
                                          zoomout=(1280, 720))
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(
                data_util.read_img(None,
                                   img_GT_path,
                                   scale=65535.,
                                   zoomout=True))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # generate frame index
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LBD.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                # self-ensemble: flip the input in four different directions
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0), out_type=np.uint16)

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR
            # output = output / 255.
            output = output / 65535.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 65535, GT * 65535)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(
                img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        avg_psnr_center = avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
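
The PSNR computed above keeps the 16-bit convention by rescaling [0, 1] floats back to [0, 65535]; a hedged sketch (util.calculate_psnr is assumed to be range-agnostic):

import numpy as np

def psnr_16bit(a, b):
    # a, b: float images in [0, 1]; compared in the 16-bit range, as above
    mse = np.mean((a * 65535.0 - b * 65535.0) ** 2)
    return float('inf') if mse == 0 else 20 * np.log10(65535.0 / np.sqrt(mse))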
Exemple #22
def main():
    scale = 4
    N_ot = 5  # 3
    N_in = 1 + N_ot // 2
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # model
    # TODO: change your model path here
    model_path = '../experiments/pretrained_models/xiang2020zooming.pth'
    model = Sakuya_arch.LunaTokis(64, N_ot, 8, 5, 40)

    # dataset
    data_mode = 'Custom'  # 'Vid4' | 'SPMC' | 'Middlebury'

    if data_mode == 'Vid4':
        test_dataset_folder = '/data/xiang/SR/Vid4/LR/*'
    if data_mode == 'SPMC':
        test_dataset_folder = '/data/xiang/SR/spmc/*'
    if data_mode == 'Custom':
        test_dataset_folder = '../test_example/*'  # TODO: put your own data path here

    # evaluation
    flip_test = False  # or True
    crop_border = 0

    # temporal padding mode
    padding = 'replicate'
    save_imgs = False  # or True; forced to True for 'Custom' below
    if 'Custom' in data_mode:
        save_imgs = True
    ############################################################################
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')
    model_params = util.get_model_total_params(model)

    # log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Model parameters: {} M'.format(model_params))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip Test: {}'.format(flip_test))

    def single_forward(model, imgs_in):
        with torch.no_grad():
            # imgs_in.size(): [1,n,3,h,w]
            b, n, c, h, w = imgs_in.size()
            h_n = int(4 * np.ceil(h / 4))
            w_n = int(4 * np.ceil(w / 4))
            imgs_temp = imgs_in.new_zeros(b, n, c, h_n, w_n)
            imgs_temp[:, :, :, 0:h, 0:w] = imgs_in
            t0 = time.time()
            model_output = model(imgs_temp)
            t1 = time.time()
            logger.info('Single Test time: {}'.format(t1 - t0))
            # unwrap first, in case the model returns a list/tuple of outputs
            if isinstance(model_output, (list, tuple)):
                model_output = model_output[0]
            # model_output.size(): [1, N_ot, 3, scale*h_n, scale*w_n]
            output = model_output[:, :, :, 0:scale * h, 0:scale * w]
        return output
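    # Hedged example (not in the original): for h = 123, w = 250 the padded
    # sizes are h_n = 124, w_n = 252; the model runs on the padded tensor and
    # its 4x output is cropped back to (4 * 123, 4 * 250) above.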

    sub_folder_l = sorted(glob.glob(test_dataset_folder))

    model.load_state_dict(torch.load(model_path), strict=True)

    model.eval()
    model = model.to(device)

    avg_psnr_l = []
    avg_psnr_y_l = []
    sub_folder_name_l = []
    # total_time = []
    # for each sub-folder
    for sub_folder in sub_folder_l:
        gt_tested_list = []
        sub_folder_name = sub_folder.split('/')[-1]
        sub_folder_name_l.append(sub_folder_name)
        save_sub_folder = osp.join(save_folder, sub_folder_name)

        if data_mode == 'SPMC':
            sub_folder = sub_folder + '/LR/'
        img_LR_l = sorted(glob.glob(sub_folder + '/*'))

        if save_imgs:
            util.mkdirs(save_sub_folder)

        # read LR images
        imgs = util.read_seq_imgs(sub_folder)
        # read GT images
        img_GT_l = []
        if data_mode == 'SPMC':
            sub_folder_GT = osp.join(sub_folder.replace('/LR/', '/truth/'))
        else:
            sub_folder_GT = osp.join(sub_folder.replace('/LR/', '/HR/'))

        if 'Custom' not in data_mode:
            for img_GT_path in sorted(glob.glob(osp.join(sub_folder_GT, '*'))):
                img_GT_l.append(util.read_image(img_GT_path))

        avg_psnr, avg_psnr_sum, cal_n = 0, 0, 0
        avg_psnr_y, avg_psnr_sum_y = 0, 0

        skip = len(img_LR_l) == len(img_GT_l)

        if 'Custom' in data_mode:
            select_idx_list = util.test_index_generation(
                False, N_ot, len(img_LR_l))
        else:
            select_idx_list = util.test_index_generation(
                skip, N_ot, len(img_LR_l))
        # process each image
        for select_idxs in select_idx_list:
            # get input images
            select_idx = select_idxs[0]
            gt_idx = select_idxs[1]
            imgs_in = imgs.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            output = single_forward(model, imgs_in)
            print(select_idx, gt_idx, imgs_in.size(), output.size())
            outputs = output.data.float().cpu().squeeze(0)

            if flip_test:
                # flip W
                output = single_forward(model, torch.flip(imgs_in, (-1, )))
                output = torch.flip(output, (-1, ))
                output = output.data.float().cpu().squeeze(0)
                outputs = outputs + output
                # flip H
                output = single_forward(model, torch.flip(imgs_in, (-2, )))
                output = torch.flip(output, (-2, ))
                output = output.data.float().cpu().squeeze(0)
                outputs = outputs + output
                # flip both H and W
                output = single_forward(model, torch.flip(imgs_in, (-2, -1)))
                output = torch.flip(output, (-2, -1))
                output = output.data.float().cpu().squeeze(0)
                outputs = outputs + output

                outputs = outputs / 4

            # save imgs
            for idx, name_idx in enumerate(gt_idx):
                if name_idx in gt_tested_list:
                    continue
                gt_tested_list.append(name_idx)
                output_f = outputs[idx, :, :, :].squeeze(0)

                output = util.tensor2img(output_f)
                if save_imgs:
                    cv2.imwrite(
                        osp.join(save_sub_folder,
                                 '{:08d}.png'.format(name_idx + 1)), output)

                if 'Custom' not in data_mode:
                    # calculate PSNR
                    output = output / 255.

                    GT = np.copy(img_GT_l[name_idx])

                    if crop_border == 0:
                        cropped_output = output
                        cropped_GT = GT
                    else:
                        cropped_output = output[crop_border:-crop_border,
                                                crop_border:-crop_border, :]
                        cropped_GT = GT[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                    crt_psnr = util.calculate_psnr(cropped_output * 255,
                                                   cropped_GT * 255)
                    cropped_GT_y = data_util.bgr2ycbcr(cropped_GT, only_y=True)
                    cropped_output_y = data_util.bgr2ycbcr(cropped_output,
                                                           only_y=True)
                    crt_psnr_y = util.calculate_psnr(cropped_output_y * 255,
                                                     cropped_GT_y * 255)
                    logger.info(
                        '{:3d} - {:25}.png \tPSNR: {:.6f} dB  PSNR-Y: {:.6f} dB'
                        .format(name_idx + 1, name_idx + 1, crt_psnr,
                                crt_psnr_y))
                    avg_psnr_sum += crt_psnr
                    avg_psnr_sum_y += crt_psnr_y
                    cal_n += 1

        if 'Custom' not in data_mode:
            avg_psnr = avg_psnr_sum / cal_n
            avg_psnr_y = avg_psnr_sum_y / cal_n

            logger.info(
                'Folder {} - Average PSNR: {:.6f} dB PSNR-Y: {:.6f} dB for {} frames; '
                .format(sub_folder_name, avg_psnr, avg_psnr_y, cal_n))

            avg_psnr_l.append(avg_psnr)
            avg_psnr_y_l.append(avg_psnr_y)

    if 'Custom' not in data_mode:
        logger.info('################ Tidy Outputs ################')
        for name, psnr, psnr_y in zip(sub_folder_name_l, avg_psnr_l,
                                      avg_psnr_y_l):
            logger.info(
                'Folder {} - Average PSNR: {:.6f} dB PSNR-Y: {:.6f} dB. '.
                format(name, psnr, psnr_y))
        logger.info('################ Final Results ################')
        logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip Test: {}'.format(flip_test))
        logger.info(
            'Total Average PSNR: {:.6f} dB PSNR-Y: {:.6f} dB for {} clips. '.
            format(
                sum(avg_psnr_l) / len(avg_psnr_l),
                sum(avg_psnr_y_l) / len(avg_psnr_y_l), len(sub_folder_l)))
Exemple #23
def main():
    ####################
    # arguments parser #
    ####################
    #  [format] dataset(vid4, REDS4) N(number of frames)

    parser = argparse.ArgumentParser()

    parser.add_argument('dataset')
    parser.add_argument('n_frames')
    parser.add_argument('stage')

    args = parser.parse_args()

    data_mode = str(args.dataset)
    N_in = int(args.n_frames)
    stage = int(args.stage)

    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    #data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    #stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    elif data_mode == 'sharp_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
    elif data_mode == 'blur_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
    elif data_mode == 'blur':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
    elif data_mode == 'blur_comp':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    else:
        raise NotImplementedError

    predeblur, HR_in = False, False
    back_RBs = 40
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20

    #### dataset
    if data_mode == 'Vid4':
        N_model_default = 7
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        N_model_default = 5
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    raw_model = EDVR_arch.EDVR(128,
                               N_model_default,
                               8,
                               5,
                               back_RBs,
                               predeblur=predeblur,
                               HR_in=HR_in)
    model = EDVR_arch.EDVR(128,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluating
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    if stage == 1 and data_mode != 'Vid4':
        save_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
    else:
        save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    print([a for a in dir(model) if not callable(getattr(model, a))])

    # load the pretrained weights into the full-size model; they are
    # transplanted into the smaller model below
    raw_model.load_state_dict(torch.load(model_path), strict=True)

    #### change the model so it can work with fewer input frames

    model.nf = raw_model.nf
    model.center = N_in // 2
    model.is_predeblur = raw_model.is_predeblur
    model.HR_in = raw_model.HR_in
    model.w_TSA = raw_model.w_TSA
    #ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)

    #### extract features (for each frame)
    if model.is_predeblur:
        model.pre_deblur = raw_model.pre_deblur  #Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
        model.conv_1x1 = raw_model.conv_1x1  #nn.Conv2d(nf, nf, 1, 1, bias=True)
    else:
        if model.HR_in:
            model.conv_first_1 = raw_model.conv_first_1  #nn.Conv2d(3, nf, 3, 1, 1, bias=True)
            model.conv_first_2 = raw_model.conv_first_2  #nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            model.conv_first_3 = raw_model.conv_first_3  #nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        else:
            model.conv_first = raw_model.conv_first  # nn.Conv2d(3, nf, 3, 1, 1, bias=True)
    model.feature_extraction = raw_model.feature_extraction  #  arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
    model.fea_L2_conv1 = raw_model.fea_L2_conv1  #nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
    model.fea_L2_conv2 = raw_model.fea_L2_conv2  #nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
    model.fea_L3_conv1 = raw_model.fea_L3_conv1  #nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
    model.fea_L3_conv2 = raw_model.fea_L3_conv2  #nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

    model.pcd_align = raw_model.pcd_align  #PCD_Align(nf=nf, groups=groups)

    ######## Resize TSA

    model.tsa_fusion.center = model.center
    # temporal attention (before fusion conv)
    model.tsa_fusion.tAtt_1 = raw_model.tsa_fusion.tAtt_1
    model.tsa_fusion.tAtt_2 = raw_model.tsa_fusion.tAtt_2

    # fusion conv: using 1x1 to save parameters and computation

    #print(raw_model.tsa_fusion.fea_fusion.weight.shape)
    #print(raw_model.tsa_fusion.fea_fusion.weight[127][639].shape)
    #print("MAIN SHAPE(FEA): ", raw_model.tsa_fusion.fea_fusion.weight.shape)

    model.tsa_fusion.fea_fusion = copy.deepcopy(
        raw_model.tsa_fusion.fea_fusion)
    model.tsa_fusion.fea_fusion.weight = copy.deepcopy(
        torch.nn.Parameter(raw_model.tsa_fusion.fea_fusion.weight[:, 0:N_in *
                                                                  128, :, :]))
    #[:][] #nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
    #model.tsa_fusion.fea_fusion.bias = raw_model.tsa_fusion.fea_fusion.bias

    # spatial attention (after fusion conv)
    model.tsa_fusion.sAtt_1 = copy.deepcopy(raw_model.tsa_fusion.sAtt_1)
    model.tsa_fusion.sAtt_1.weight = copy.deepcopy(
        torch.nn.Parameter(raw_model.tsa_fusion.sAtt_1.weight[:, 0:N_in *
                                                              128, :, :]))
    #[:][] #nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
    #model.tsa_fusion.sAtt_1.bias = raw_model.tsa_fusion.sAtt_1.bias

    #print(N_in * 128)
    #print(raw_model.tsa_fusion.fea_fusion.weight[:, 0:N_in * 128, :, :].shape)
    print("MODEL TSA SHAPE: ", model.tsa_fusion.fea_fusion.weight.shape)

    model.tsa_fusion.maxpool = raw_model.tsa_fusion.maxpool
    model.tsa_fusion.avgpool = raw_model.tsa_fusion.avgpool
    model.tsa_fusion.sAtt_2 = raw_model.tsa_fusion.sAtt_2
    model.tsa_fusion.sAtt_3 = raw_model.tsa_fusion.sAtt_3
    model.tsa_fusion.sAtt_4 = raw_model.tsa_fusion.sAtt_4
    model.tsa_fusion.sAtt_5 = raw_model.tsa_fusion.sAtt_5
    model.tsa_fusion.sAtt_L1 = raw_model.tsa_fusion.sAtt_L1
    model.tsa_fusion.sAtt_L2 = raw_model.tsa_fusion.sAtt_L2
    model.tsa_fusion.sAtt_L3 = raw_model.tsa_fusion.sAtt_L3
    model.tsa_fusion.sAtt_add_1 = raw_model.tsa_fusion.sAtt_add_1
    model.tsa_fusion.sAtt_add_2 = raw_model.tsa_fusion.sAtt_add_2

    model.tsa_fusion.lrelu = raw_model.tsa_fusion.lrelu

    #if model.w_TSA:
    #    model.tsa_fusion = raw_model.tsa_fusion[:][:128 * N_in][:][:] #TSA_Fusion(nf=nf, nframes=nframes, center=self.center)
    #else:
    #    model.tsa_fusion = raw_model.tsa_fusion[:][:128 * N_in][:][:] #nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)

    #   print(self.tsa_fusion)

    #### reconstruction
    model.recon_trunk = raw_model.recon_trunk  # arch_util.make_layer(ResidualBlock_noBN_f, back_RBs)
    #### upsampling
    model.upconv1 = raw_model.upconv1  #nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
    model.upconv2 = raw_model.upconv2  #nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)
    model.pixel_shuffle = raw_model.pixel_shuffle  # nn.PixelShuffle(2)
    model.HRconv = raw_model.HRconv
    model.conv_last = raw_model.conv_last

    #### activation function
    model.lrelu = raw_model.lrelu

    #####################################################

    model.eval()
    model = model.to(device)

    avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)

        print("MAX_IDX: ", max_idx)

        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_ssim, avg_ssim_border, avg_ssim_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            if data_mode == "blur":
                select_idx = data_util.glarefree_index_generation(
                    img_idx, max_idx, N_in, padding=padding)
            else:
                select_idx = data_util.index_generation(
                    img_idx, max_idx, N_in, padding=padding)
            print("SELECT IDX: ", select_idx)

            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                print("IMGS_IN SHAPE: ", imgs_in.shape)  # check this
                output = util.single_forward(model, imgs_in)  # error here 1
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate SSIM
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_ssim = util.calculate_ssim(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tSSIM: {:.6f}'.format(
                img_idx + 1, img_name, crt_ssim))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_ssim_center += crt_ssim
                N_center += 1
            else:  # border frames
                avg_ssim_border += crt_ssim
                N_border += 1

        avg_ssim = (avg_ssim_center + avg_ssim_border) / (N_center + N_border)
        avg_ssim_center = avg_ssim_center / N_center
        avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border
        avg_ssim_l.append(avg_ssim)
        avg_ssim_center_l.append(avg_ssim_center)
        avg_ssim_border_l.append(avg_ssim_border)

        logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; '
                    'Center SSIM: {:.6f} for {} frames; '
                    'Border SSIM: {:.6f} for {} frames.'.format(
                        subfolder_name, avg_ssim, (N_center + N_border),
                        avg_ssim_center, N_center, avg_ssim_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, ssim, ssim_center, ssim_border in zip(
            subfolder_name_l, avg_ssim_l, avg_ssim_center_l,
            avg_ssim_border_l):
        logger.info('Folder {} - Average SSIM: {:.6f}. '
                    'Center SSIM: {:.6f}. '
                    'Border SSIM: {:.6f}.'.format(subfolder_name, ssim,
                                                  ssim_center, ssim_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average SSIM: {:.6f} for {} clips. '
                'Center SSIM: {:.6f}. Border SSIM: {:.6f}.'.format(
                    sum(avg_ssim_l) / len(avg_ssim_l), len(subfolder_l),
                    sum(avg_ssim_center_l) / len(avg_ssim_center_l),
                    sum(avg_ssim_border_l) / len(avg_ssim_border_l)))
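
A hedged sketch of the channel-slicing transplant performed above: a 1x1 fusion conv trained for 7 input frames (7 * 128 input channels) is shrunk to accept 5 frames by keeping only the first 5 * 128 input channels of its weight tensor (names below are illustrative):

import copy

import torch
import torch.nn as nn

nf, n_old, n_new = 128, 7, 5
fusion_old = nn.Conv2d(n_old * nf, nf, 1, 1, bias=True)
fusion_new = copy.deepcopy(fusion_old)
fusion_new.weight = nn.Parameter(
    fusion_old.weight[:, :n_new * nf, :, :].clone())
fusion_new.in_channels = n_new * nf
x = torch.randn(1, n_new * nf, 8, 8)
print(fusion_new(x).shape)  # torch.Size([1, 128, 8, 8])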
Exemple #24
def main(jsonPath):
    # options
    opt = option.parse(jsonPath, is_train=False)
    util.mkdirs((path for key, path in opt["path"].items()
                 if not key == "pretrain_model_G"))
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None,
                      opt["path"]["log"],
                      "test.log",
                      level=logging.INFO,
                      screen=True)
    logger = logging.getLogger("base")
    logger.info(option.dict2str(opt))
    # Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt["datasets"].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info("Number of test images in [{:s}]: {:d}".format(
            dataset_opt["name"], len(test_set)))
        test_loaders.append(test_loader)

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt["name"]
        logger.info("\nTesting [{:s}]...".format(test_set_name))
        # test_start_time = time.time()
        dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results["psnr"] = []
        test_results["ssim"] = []
        test_results["psnr_y"] = []
        test_results["ssim_y"] = []

        for data in test_loader:
            need_HR = test_loader.dataset.opt["dataroot_HR"] is not None

            model.feed_data(data, need_HR=need_HR)
            img_path = data["LR_path"][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)

            sr_img = util.tensor2img(visuals["SR"])  # uint8

            # save images
            suffix = opt["suffix"]
            if suffix:
                save_img_path = os.path.join(dataset_dir,
                                             img_name + suffix + ".png")
            else:
                save_img_path = os.path.join(dataset_dir, img_name + ".png")
            util.save_img(sr_img, save_img_path)

            # calculate PSNR and SSIM
            if need_HR:
                gt_img = util.tensor2img(visuals["HR"])
                gt_img = gt_img / 255.0
                sr_img = sr_img / 255.0

                crop_border = test_loader.dataset.opt["scale"]
                cropped_sr_img = sr_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                test_results["psnr"].append(psnr)
                test_results["ssim"].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    test_results["psnr_y"].append(psnr_y)
                    test_results["ssim_y"].append(ssim_y)
                    logger.info(
                        "{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}."
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info(
                        "{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.".format(
                            img_name, psnr, ssim))
            else:
                logger.info(img_name)

        if need_HR:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results["psnr"]) / len(test_results["psnr"])
            ave_ssim = sum(test_results["ssim"]) / len(test_results["ssim"])
            logger.info(
                "----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n"
                .format(test_set_name, ave_psnr, ave_ssim))
            if test_results["psnr_y"] and test_results["ssim_y"]:
                ave_psnr_y = sum(test_results["psnr_y"]) / len(
                    test_results["psnr_y"])
                ave_ssim_y = sum(test_results["ssim_y"]) / len(
                    test_results["ssim_y"])
                logger.info(
                    "----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n"
                    .format(ave_psnr_y, ave_ssim_y))
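
For reference, a hedged sketch of the Y-channel extraction behind PSNR_Y/SSIM_Y, assuming bgr2ycbcr uses the usual BT.601 coefficients on BGR float input in [0, 1]:

import numpy as np

def bgr2y(img):
    # weights are in (B, G, R) order; Y lands in the limited range [16/255, 235/255]
    return (np.dot(img * 255.0, [24.966, 128.553, 65.481]) / 255.0 + 16.0) / 255.0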
Exemple #25
def main():
    #################
    # configurations
    #################
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic (REDS)

    # model
    N_in = 7
    model_path = '../experiments/pretrained_models/TOF_official.pth'
    adapt_official = 'official' in model_path
    model = TOF_arch.TOFlow(adapt_official=adapt_official)

    # ### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4up_direct/*'
    else:
        test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode)

    # ### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluating
    # temporal padding mode
    padding = 'new_info'  # different from the official setting
    save_imgs = True
    # ###########################################################################
    device = torch.device('cuda')
    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    # ### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))

    def read_image(img_path):
        '''read one image from img_path
        Return img: HWC, BGR, [0,1], numpy
        '''
        img_GT = cv2.imread(img_path)
        img = img_GT.astype(np.float32) / 255.
        return img

    def read_seq_imgs(img_seq_path):
        '''read a sequence of images'''
        img_path_l = sorted(glob.glob(img_seq_path + '/*'))
        img_l = [read_image(v) for v in img_path_l]
        # stack to TCHW, RGB, [0,1], torch
        imgs = np.stack(img_l, axis=0)
        imgs = imgs[:, :, :, [2, 1, 0]]
        imgs = torch.from_numpy(
            np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
        return imgs

    def index_generation(crt_i, max_n, N, padding='reflection'):
        '''
        padding: replicate | reflection | new_info | circle
        '''
        max_n = max_n - 1
        n_pad = N // 2
        return_l = []

        for i in range(crt_i - n_pad, crt_i + n_pad + 1):
            if i < 0:
                if padding == 'replicate':
                    add_idx = 0
                elif padding == 'reflection':
                    add_idx = -i
                elif padding == 'new_info':
                    add_idx = (crt_i + n_pad) + (-i)
                elif padding == 'circle':
                    add_idx = N + i
                else:
                    raise ValueError('Wrong padding mode')
            elif i > max_n:
                if padding == 'replicate':
                    add_idx = max_n
                elif padding == 'reflection':
                    add_idx = max_n * 2 - i
                elif padding == 'new_info':
                    add_idx = (crt_i - n_pad) - (i - max_n)
                elif padding == 'circle':
                    add_idx = i - N
                else:
                    raise ValueError('Wrong padding mode')
            else:
                add_idx = i
            return_l.append(add_idx)
        return return_l
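    # Hedged usage example (not in the original): with max_n = 10 frames and
    # a 5-frame window at crt_i = 0,
    #   index_generation(0, 10, 5, padding='replicate') -> [0, 0, 0, 1, 2]
    #   index_generation(0, 10, 5, padding='new_info')  -> [4, 3, 0, 1, 2]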

    def single_forward(model, imgs_in):
        with torch.no_grad():
            model_output = model(imgs_in)
            if isinstance(model_output, list) or isinstance(
                    model_output, tuple):
                output = model_output[0]
            else:
                output = model_output
        return output

    sub_folder_l = sorted(glob.glob(test_dataset_folder))
    # ### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    sub_folder_name_l = []

    # for each sub-folder
    for sub_folder in sub_folder_l:
        sub_folder_name = sub_folder.split('/')[-1]
        sub_folder_name_l.append(sub_folder_name)
        save_sub_folder = osp.join(save_folder, sub_folder_name)

        img_path_l = sorted(glob.glob(sub_folder + '/*'))
        max_idx = len(img_path_l)

        if save_imgs:
            util.mkdirs(save_sub_folder)

        #### read LR images
        imgs = read_seq_imgs(sub_folder)
        #### read GT images
        img_GT_l = []
        if data_mode == 'Vid4':
            sub_folder_GT = osp.join(
                sub_folder.replace('/BIx4up_direct/', '/GT/'), '*')
        else:
            sub_folder_GT = osp.join(
                sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*')
        for img_GT_path in sorted(glob.glob(sub_folder_GT)):
            img_GT_l.append(read_image(img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
        cal_n_border, cal_n_center = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            c_idx = int(osp.splitext(osp.basename(img_path))[0])
            select_idx = index_generation(c_idx,
                                          max_idx,
                                          N_in,
                                          padding=padding)
            # get input images
            imgs_in = imgs.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
            output = single_forward(model, imgs_in)
            output_f = output.data.float().cpu().squeeze(0)

            output = util.tensor2img(output_f)

            # save imgs
            if save_imgs:
                cv2.imwrite(
                    osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)),
                    output)

            # ### calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT)
                output = data_util.bgr2ycbcr(output)
            if crop_border == 0:
                cropped_output = output
                cropped_GT = GT
            else:
                cropped_output = output[crop_border:-crop_border,
                                        crop_border:-crop_border]
                cropped_GT = GT[crop_border:-crop_border,
                                crop_border:-crop_border]
            crt_psnr = util.calculate_psnr(cropped_output * 255,
                                           cropped_GT * 255)
            logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(
                img_idx + 1, c_idx, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                cal_n_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                cal_n_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center +
                                                          cal_n_border)
        avg_psnr_center = avg_psnr_center / cal_n_center
        if cal_n_border == 0:
            avg_psnr_border = 0
        else:
            avg_psnr_border = avg_psnr_border / cal_n_border

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        sub_folder_name, avg_psnr,
                        (cal_n_center + cal_n_border), avg_psnr_center,
                        cal_n_center, avg_psnr_border, cal_n_border))

        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

    logger.info('################ Tidy Outputs ################')
    for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l,
                                                    avg_psnr_l,
                                                    avg_psnr_center_l,
                                                    avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center,
                                                     psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
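Both EDVR test loops in these examples rely on index_generation to pick the N_in neighbouring frames around the current index, using 'replicate' or 'new_info' temporal padding at clip boundaries. A minimal sketch of that helper, assuming the data_util signature and covering only the two padding modes the scripts here actually select:

def index_generation(crt_i, max_n, N, padding='replicate'):
    """Return N frame indices centred on crt_i for a clip of max_n frames."""
    max_n = max_n - 1  # last valid frame index
    n_pad = N // 2
    return_l = []
    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
        if i < 0:  # before the first frame
            add_idx = 0 if padding == 'replicate' else (crt_i + n_pad) + (-i)
        elif i > max_n:  # past the last frame
            add_idx = max_n if padding == 'replicate' else (crt_i - n_pad) - (i - max_n)
        else:
            add_idx = i
        return_l.append(add_idx)
    return return_l

# index_generation(0, 100, 7, padding='new_info') -> [6, 5, 4, 0, 1, 2, 3]:
# 'new_info' borrows not-yet-seen future frames instead of repeating frame 0.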
Exemple #26
0
def main():
    # Create an object for parsing command-line options
    parser = argparse.ArgumentParser(description="Test with EDVR; requires the path to a test dataset folder.")
    # Add an argument which takes the path to the test folder as input
    parser.add_argument("-i", "--input", type=str, help="Path to test folder")
    # Parse the command line arguments to an object
    args = parser.parse_args()
    # Exit safely if no input parameter has been given
    if not args.input:
        print("No input parameter has been given.")
        print("For help, type --help")
        exit()

    # Derive the folder name, tolerating a trailing slash in the input path
    folder_name = os.path.basename(args.input.rstrip('/'))

    #################
    # configurations
    #################
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # must be set before CUDA is initialized
    device = torch.device('cuda')
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2; use the two-stage strategy for the REDS dataset
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    else:
        raise NotImplementedError

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 40

    model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        # debug
        test_dataset_folder = os.path.join(args.input, 'BIx4')
        GT_dataset_folder = os.path.join(args.input, 'GT')
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # number of border frames when evaluating
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(folder_name)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        avg_psnr_center = avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
                                                                   (N_center + N_border),
                                                                   avg_psnr_center, N_center,
                                                                   avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l,
                                                              avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
                                                     psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
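When flip_test is enabled, the script replaces the plain forward pass with a flip self-ensemble: the input is flipped along H and/or W, each variant is run through the network, and the outputs are flipped back and averaged. A sketch of that routine, assuming util.single_forward runs one forward pass as in the loop above:

import torch

def flipx4_forward(model, inp):
    """Average the model output over the 4 flip variants of a (B, T, C, H, W) input."""
    output = single_forward(model, inp)
    # flip W, forward, flip the result back
    output = output + torch.flip(single_forward(model, torch.flip(inp, (-1,))), (-1,))
    # flip H
    output = output + torch.flip(single_forward(model, torch.flip(inp, (-2,))), (-2,))
    # flip H and W
    output = output + torch.flip(single_forward(model, torch.flip(inp, (-2, -1))), (-2, -1))
    return output / 4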
Exemple #27
0
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None,
                      opt['path']['log'],
                      'test.log',
                      level=logging.INFO,
                      screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))
    # Create test dataset and dataloader
    test_loaders = []
    znorm = False
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)
        # Temporary: turn znorm on for all datasets if any dataset requests it.
        # A per-dataset flag should be introduced and handled inside the loop later.
        if dataset_opt['znorm'] and not znorm:
            znorm = True

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

            model.feed_data(data, need_HR=need_HR)
            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)

            if znorm:  # image range is [-1, 1]; in testing each dataset can have any name, not just train/val
                sr_img = util.tensor2img(visuals['SR'],
                                         min_max=(-1, 1))  # uint8
            else:  # Default: Image range is [0,1]
                sr_img = util.tensor2img(visuals['SR'])  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = os.path.join(dataset_dir,
                                             img_name + suffix + '.png')
            else:
                save_img_path = os.path.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            # calculate PSNR and SSIM
            if need_HR:
                if znorm:  # image range is [-1, 1], as above
                    gt_img = util.tensor2img(visuals['HR'],
                                             min_max=(-1, 1))  # uint8
                else:  # Default: Image range is [0,1]
                    gt_img = util.tensor2img(visuals['HR'])  # uint8
                gt_img = gt_img / 255.
                sr_img = sr_img / 255.

                crop_border = test_loader.dataset.opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'\
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                            img_name, psnr, ssim))
            else:
                logger.info(img_name)

        if need_HR:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'\
                    .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'\
                    .format(ave_psnr_y, ave_ssim_y))
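The znorm branches above only change how a network output tensor is mapped back to a uint8 image: a [-1, 1] tensor has to be shifted and rescaled, while a [0, 1] tensor is only scaled. A simplified sketch of that conversion, assuming a CHW RGB tensor (the repo's util.tensor2img also handles 2-D and batched 4-D inputs):

import numpy as np

def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # map to [0, 1]
    img_np = tensor.numpy()
    if img_np.ndim == 3:
        img_np = np.transpose(img_np, (1, 2, 0))[:, :, ::-1]  # CHW RGB -> HWC BGR
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
    return img_np.astype(out_type)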
Exemple #28
0
def main():
    # Metric path
    metric_path = os.getcwd() + '/utils/metric'
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which returns None for missing keys.

    # train from scratch OR resume training
    if opt['path']['resume_state']:  # resuming training
        resume_state = torch.load(opt['path']['resume_state'])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(
            opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # configure loggers; logging does not work before this point
    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training

    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))

    eng = matlab.engine.connect_matlab()
    names = matlab.engine.find_matlab()
    print('matlab process name: ', names)
    eng.addpath(metric_path)

    for epoch in range(start_epoch, total_epochs):
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            model.update_learning_rate()

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                avg_psnr = 0.0
                scores = 0.0
                avg_pirm_rmse = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(\
                        img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_sr_img_y = bgr2ycbcr(cropped_sr_img, only_y=True)
                    cropped_gt_img_y = bgr2ycbcr(cropped_gt_img, only_y=True)
                    avg_psnr += util.calculate_psnr(cropped_sr_img_y * 255,
                                                    cropped_gt_img_y * 255)
                    immse = util.mse(cropped_sr_img_y * 255,
                                     cropped_gt_img_y * 255)
                    avg_pirm_rmse += immse
                    scores += eng.calc_NIQE(save_img_path, 4)

                avg_psnr = avg_psnr / idx
                scores = scores / idx
                avg_pirm_rmse = math.sqrt(avg_pirm_rmse / idx)

                # log
                logger.info(
                    '# Validation # PSNR: {:.4e}, NIQE: {:.4e}, pirm_rmse: {:.4e}'
                    .format(avg_psnr, scores, avg_pirm_rmse))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}, NIQE: {:.4e}, pirm_rmse: {:.4e}'
                    .format(epoch, current_step, avg_psnr, scores,
                            avg_pirm_rmse))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
                    tb_logger.add_scalar('NIQE', scores, current_step)
                    tb_logger.add_scalar('pirm_rmse', avg_pirm_rmse,
                                         current_step)

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
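The validation block above accumulates the per-image MSE of the Y channel and only takes the square root of the mean at the end, which yields a PIRM-style RMSE over the whole set; PSNR is the same MSE on a log scale. A sketch of the assumed util.mse / util.calculate_psnr pair for images in the [0, 255] range:

import math
import numpy as np

def mse(img1, img2):
    return np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)

def calculate_psnr(img1, img2):
    err = mse(img1, img2)
    if err == 0:
        return float('inf')  # identical images
    return 20 * math.log10(255.0 / math.sqrt(err))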
Exemple #29
0
import os
import numpy as np
from PIL import Image
from data.util import bgr2ycbcr
# assumed location of the PSNR/SSIM helpers in this repo
from utils.util import calculate_psnr, calculate_ssim

sr_path = '/media/4T/Dizzy/BasicSR-master/results/Test/DeviceVal20_2xd/'
hr_path = '/media/4T/Dizzy/SR_classical_DataSet/RealWorldDataSet/Device_degration_Data/City100_iPhoneX/HR_val/'
psnr_total, ssim_total = 0, 0
psnr_y_total, ssim_y_total = 0, 0
idx = 0
crop_border = 4
for name in os.listdir(hr_path):
    name = name.split('.')[0]
    # PIL loads RGB; reverse the channels to BGR so bgr2ycbcr computes the correct Y
    sr_img_np = np.array(Image.open(sr_path + name + '.png'))[:, :, ::-1] / 255
    hr_img_np = np.array(Image.open(hr_path + name + '.PNG'))[:, :, ::-1] / 255
    sr_img_np = sr_img_np[crop_border:-crop_border,
                          crop_border:-crop_border, :]
    hr_img_np = hr_img_np[crop_border:-crop_border,
                          crop_border:-crop_border, :]
    psnr = calculate_psnr(hr_img_np * 255, sr_img_np * 255)
    ssim_ = calculate_ssim(hr_img_np * 255, sr_img_np * 255)
    psnr_total += psnr
    ssim_total += ssim_
    sr_img_np_y = bgr2ycbcr(sr_img_np, only_y=True)
    hr_img_np_y = bgr2ycbcr(hr_img_np, only_y=True)
    psnr = calculate_psnr(sr_img_np_y * 255, hr_img_np_y * 255)
    ssim_ = calculate_ssim(sr_img_np_y * 255, hr_img_np_y * 255)
    psnr_y_total += psnr
    ssim_y_total += ssim_
    idx += 1

print('PSNR: ', psnr_total / idx, 'SSIM: ', ssim_total / idx)
print('PSNR_y: ', psnr_y_total / idx, 'SSIM_y: ', ssim_y_total / idx)
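For float BGR input in [0, 1], the only_y=True path of data.util.bgr2ycbcr reduces to a single BT.601 dot product plus an offset. A sketch under that assumption (bgr2ycbcr_y is a hypothetical name for just the Y path):

import numpy as np

def bgr2ycbcr_y(img):
    # img: float BGR image in [0, 1]; returns the BT.601 Y channel, also in [0, 1]
    return (np.dot(img[..., :3], [24.966, 128.553, 65.481]) + 16.0) / 255.0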
Exemple #30
0
def main():
    # Settings
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt)  # load and initialize settings

    util.mkdir_and_rename(opt['path']['experiments_root'])  # rename the old experiments folder if it exists
    util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and \
        not key == 'saved_model'))
    option.save(opt)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which returns None for missing keys.

    # Redirect all writes to the "txt" file
    sys.stdout = PrintLogger(opt['path']['log'])

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            print('Total epochs needed: {:d} for iters {:,d}'.format(total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # Create model
    model = create_model(opt)
    # Create logger
    logger = Logger(opt)

    current_step = 0
    start_time = time.time()
    print('---------- Start training -------------')
    for epoch in range(total_epochs):
        for i, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            time_elapsed = time.time() - start_time
            start_time = time.time()

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = time_elapsed
                for k, v in logs.items():
                    print_rlt[k] = v
                print_rlt['lr'] = model.get_current_learning_rate()
                logger.print_format_results('train', print_rlt)

            # save models
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                print('Saving the model at the end of iter {:d}.'.format(current_step))
                model.save(current_step)

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                print('---------- validation -------------')
                start_time = time.time()

                avg_psnr = 0.0
                avg_ssim = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(os.path.basename(val_data['GT_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    out_img = util.tensor2img(visuals['Output'])
                    gt_img = util.tensor2img(visuals['ground_truth'])  # uint8

                    # Save output images for reference
                    save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(\
                        img_name, current_step))
                    util.save_img(out_img, save_img_path)

                    # calculate PSNR
                    if len(gt_img.shape) == 2:
                        gt_img = np.expand_dims(gt_img, axis=2)
                        out_img = np.expand_dims(out_img, axis=2)
                    crop_border = opt['scale']
                    cropped_out_img = out_img[crop_border:-crop_border, crop_border:-crop_border, :]
                    cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]
                    if gt_img.shape[2] == 3:  # RGB image
                        cropped_out_img_y = bgr2ycbcr(cropped_out_img, only_y=True)
                        cropped_gt_img_y = bgr2ycbcr(cropped_gt_img, only_y=True)
                        avg_psnr += util.psnr(cropped_out_img_y, cropped_gt_img_y)
                        avg_ssim += util.ssim(cropped_out_img_y, cropped_gt_img_y, multichannel=False)
                    else:
                        avg_psnr += util.psnr(cropped_out_img, cropped_gt_img)
                        avg_ssim += util.ssim(cropped_out_img, cropped_gt_img, multichannel=True)

                avg_psnr = avg_psnr / idx
                avg_ssim = avg_ssim / idx
                time_elapsed = time.time() - start_time
                # Save to log
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = time_elapsed
                print_rlt['psnr'] = avg_psnr
                print_rlt['ssim'] = avg_ssim
                logger.print_format_results('val', print_rlt)
                print('-----------------------------------')

            # update learning rate
            model.update_learning_rate()

    print('Saving the final model.')
    model.save('latest')
    print('End of training.')
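Nearly every validation loop in these examples repeats the same sequence: crop the scale-dependent border, optionally convert to the Y channel, then compute PSNR/SSIM on the [0, 255] range. A consolidated sketch of that pattern (eval_pair is a hypothetical name; calculate_psnr, calculate_ssim and bgr2ycbcr are assumed to come from the repo's util modules):

def eval_pair(sr_img, gt_img, crop_border, test_y_channel=True):
    """sr_img, gt_img: float BGR images in [0, 1]. Returns (psnr, ssim)."""
    if crop_border > 0:
        sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border]
        gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border]
    if test_y_channel and gt_img.ndim == 3 and gt_img.shape[2] == 3:
        sr_img = bgr2ycbcr(sr_img, only_y=True)
        gt_img = bgr2ycbcr(gt_img, only_y=True)
    return (calculate_psnr(sr_img * 255, gt_img * 255),
            calculate_ssim(sr_img * 255, gt_img * 255))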