Example #1
0
def main():
    """Evaluate an EDVR model on a folder of clips using tiled inference.

    Command-line arguments (all required):
        --input_path       folder holding one sub-folder of LQ frames per clip
        --gt_path          folder holding the GT sub-folders (matched by name)
        --output_path      root folder for the log file and restored frames
        --model_path       checkpoint handed to ``torch.load``
        --gpu_id           value exported as ``CUDA_VISIBLE_DEVICES``
        --screen_notation  JSON file forwarded to
                           ``data_util.index_generation_process_screen_change``
        --opt              option YAML file; its ``network_G`` section supplies
                           the EDVR constructor arguments

    For every input clip that has a GT sub-folder with the same name, the
    first five frames are restored: the selected LQ frames are
    replication-padded, cut into 320x180 tiles that each carry a ``PAD``-pixel
    context border, every tile is run through the model, the border is
    removed from the x4 output and the tile is pasted into the full-size
    canvas.  The canvas is written as PNG (when ``save_imgs``) and compared
    with the GT frame via ``util.calculate_psnr``.  Per-frame, per-clip and
    overall PSNR values are written to the ``base`` logger.
    """
    #################
    # configurations
    #################
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--gt_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--gpu_id", type=str, required=True)
    parser.add_argument("--screen_notation", type=str, required=True)
    parser.add_argument('--opt', type=str, required=True, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=False)

    PAD = 32  # context border (in LQ pixels) kept around every tile

    total_run_time = AverageMeter()

    # Export the GPU selection first so it is in place before this function
    # moves anything to the device.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    print('export CUDA_VISIBLE_DEVICES=' + str(args.gpu_id))
    device = torch.device('cuda')

    data_mode = 'sharp_bicubic'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False

    Input_folder = args.input_path
    GT_folder = args.gt_path
    Result_folder = args.output_path
    model_path = args.model_path

    # create results folder
    os.makedirs(Result_folder, exist_ok=True)

    ############################################################################
    #### model
    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    # The network hyper-parameters come from the option file.
    # NOTE(review): N_in above drives the frame selection while
    # opt['network_G']['nframes'] sizes the network -- presumably they must
    # agree; confirm against the option file in use.
    net_opt = opt['network_G']
    model = EDVR_arch.EDVR(nf=net_opt['nf'], nframes=net_opt['nframes'],
                           groups=net_opt['groups'], front_RBs=net_opt['front_RBs'],
                           back_RBs=net_opt['back_RBs'],
                           predeblur=net_opt['predeblur'], HR_in=net_opt['HR_in'],
                           w_TSA=net_opt['w_TSA'])

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        if stage == 1:
            test_dataset_folder = Input_folder
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = GT_folder

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = os.path.join(Result_folder, data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))

    # load screen change notation
    import json
    with open(args.screen_notation) as f:
        frame_notation = json.load(f)

    #### tiling parameters (LQ pixels); they do not change between frames
    split_lengthY = 180
    split_lengthX = 320
    scale = 4
    pader = torch.nn.ReplicationPad2d([PAD, PAD, PAD, PAD])

    for subfolder in subfolder_l:
        subfolder_name = osp.basename(subfolder)
        subfolder_GT = os.path.join(GT_dataset_folder, subfolder_name)

        # clips without a GT folder of the same name are not evaluated
        if not os.path.exists(subfolder_GT):
            continue

        print("Evaluate Folders: ", subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if max_idx == 0:
            # nothing to restore; averaging over zero frames would divide by zero
            logger.warning('Folder {} contains no frames, skipped.'.format(subfolder_name))
            continue

        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)
        if save_imgs:
            util.mkdirs(save_subfolder)

        # only the first five frames of every clip are restored and scored
        frames_to_eval = img_path_l[:5]

        #### read LQ and GT images
        # The whole LQ clip is read because the temporal window of a frame
        # may reach beyond the evaluated frames.
        imgs_LQ = data_util.read_img_seq(subfolder)  # presumably Num x 3 x H x W
        # Only the GT frames that are actually compared are loaded.
        gt_path_l = sorted(glob.glob(osp.join(subfolder_GT, '*')))[:len(frames_to_eval)]
        img_GT_l = [data_util.read_img(None, img_GT_path) for img_GT_path in gt_path_l]

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(frames_to_eval):
            # restart the clock for every frame so the reported time is the
            # cost of this frame only, not the time since the script started
            frame_start = time.time()
            img_name = osp.splitext(osp.basename(img_path))[0]

            # temporal window of this frame, taking screen changes into account
            select_idx = data_util.index_generation_process_screen_change(
                subfolder_name, frame_notation, img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).to(device)  # N C H W

            intHeight_ori, intWidth_ori = imgs_in.shape[-2], imgs_in.shape[-1]

            # pad right/bottom up to the next multiple of the tile size
            intPaddingRight_ = (-intWidth_ori) % split_lengthX
            intPaddingBottom_ = (-intHeight_ori) % split_lengthY
            pader0 = torch.nn.ReplicationPad2d([0, intPaddingRight_, 0, intPaddingBottom_])
            print("Init pad right/bottom " + str(intPaddingRight_) + " / " + str(intPaddingBottom_))

            # first make the frame tileable, then add the context border
            imgs_in = pader(pader0(imgs_in))

            # number of tiles covering the tileable (padded) frame
            split_numY = (intHeight_ori + intPaddingBottom_) // split_lengthY
            split_numX = (intWidth_ori + intPaddingRight_) // split_lengthX

            # HWC canvas sized from the input instead of a fixed 3840x2160
            y_all = np.zeros((split_numY * split_lengthY * scale,
                              split_numX * split_lengthX * scale, 3), dtype="float32")
            for split_j, split_i in itertools.product(range(split_numY), range(split_numX)):
                top = split_j * split_lengthY
                left = split_i * split_lengthX
                # tile plus PAD pixels of context on every side
                X0 = imgs_in[:, :,
                             top:top + split_lengthY + 2 * PAD,
                             left:left + split_lengthX + 2 * PAD]
                X0 = torch.unsqueeze(X0, 0)  # N C H W -> 1 N C H W

                if flip_test:
                    output = util.flipx4_forward(model, X0)
                else:
                    output = util.single_forward(model, X0)

                # drop the (upscaled) context border again
                output_depadded = output[0, :,
                                         PAD * scale:(PAD + split_lengthY) * scale,
                                         PAD * scale:(PAD + split_lengthX) * scale]
                output_depadded = output_depadded.squeeze(0)
                tile = util.tensor2img(output_depadded)

                y_all[top * scale:(top + split_lengthY) * scale,
                      left * scale:(left + split_lengthX) * scale, :] = \
                    np.round(tile).astype(np.uint8)

            # remove what the initial right/bottom padding added
            y_all = np.ascontiguousarray(
                y_all[:intHeight_ori * scale, :intWidth_ori * scale, :])

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), y_all)

            frame_time = time.time() - frame_start
            print("*****************current image process time \t " + str(
                frame_time) + "s ******************")
            total_run_time.update(frame_time, 1)

            # calculate PSNR
            y_all = y_all / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                y_all = data_util.bgr2ycbcr(y_all, only_y=True)

            y_all, GT = util.crop_border([y_all, GT], crop_border)
            crt_psnr = util.calculate_psnr(y_all * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        # at least one frame was processed (empty clips were skipped above)
        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        # short clips may have no center frame at all; mirror the border guard
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
                                                                   (N_center + N_border),
                                                                   avg_psnr_center, N_center,
                                                                   avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l,
                                                              avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
                                                     psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    if avg_psnr_l:
        # report the clips that were actually evaluated, not every input folder
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnr_l) / len(avg_psnr_l), len(avg_psnr_l),
                        sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                        sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
    else:
        logger.warning('No clip was evaluated; total average PSNR is undefined.')
Example #2
0
def main():
    """Evaluate a trained EDVR checkpoint on the ``ai4khdr_valid`` split.

    The model, dataset folders and evaluation settings are hard-coded below.
    LQ and GT clip folders are paired by sorted position.  Every frame of
    every clip is restored with ``util.single_forward`` (or
    ``util.flipx4_forward`` when ``flip_test`` is set), optionally written as
    PNG, and compared with its GT frame on the Y channel.  Per-clip and
    overall PSNR values are written to the ``base`` logger.
    """
    #################
    # configurations
    #################
    # Export the GPU selection first so it is in place before this function
    # moves anything to the device.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    device = torch.device('cuda')
    data_mode = 'ai4khdr_valid'
    flip_test = False

    ############################################################################
    #### model
    #################
    if data_mode == 'ai4khdr_valid':
        model_path = '../experiments/002_EDVR_lr4e-4_600k_AI4KHDR/models/4000_G.pth'
    else:
        raise NotImplementedError
    N_in = 5
    front_RBs = 5
    back_RBs = 10
    predeblur, HR_in = False, False
    model = EDVR_arch.EDVR(64, N_in, 8, front_RBs, back_RBs, predeblur=predeblur, HR_in=HR_in)

    ############################################################################
    #### dataset
    #################
    if data_mode == 'ai4khdr_valid':
        test_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/valid/540p_frames'
        GT_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/valid/4k_frames'
    else:
        raise NotImplementedError

    ############################################################################
    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'ai4khdr_valid':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}_{}'.format(data_mode, util.get_timestamp())
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    # NOTE(review): strict=False lets a checkpoint with missing or unexpected
    # keys load without an error -- confirm this is intended for this model.
    model.load_state_dict(torch.load(model_path), strict=False)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    if len(subfolder_l) != len(subfolder_GT_l):
        # zip() below silently drops the surplus, so make the mismatch visible
        logger.warning('Found {} LQ clips but {} GT clips; clips are paired by sorted position '
                       'and the surplus is ignored.'.format(len(subfolder_l),
                                                            len(subfolder_GT_l)))

    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if max_idx == 0:
            # averaging over zero frames would divide by zero
            logger.warning('Folder {} contains no frames, skipped.'.format(subfolder_name))
            continue

        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        gt_path_l = sorted(glob.glob(osp.join(subfolder_GT, '*')))
        img_GT_l = [data_util.read_img(None, img_GT_path) for img_GT_path in gt_path_l]

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # indices of the N_in frames fed to the network for this frame
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for ai4khdr_valid, evaluate on the Y channel
            if data_mode == 'ai4khdr_valid':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        # at least one frame was processed (empty clips were skipped above)
        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        # clips with at most 2 * border_frame frames have no center frame;
        # mirror the existing border guard instead of dividing by zero
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
                                                                   (N_center + N_border),
                                                                   avg_psnr_center, N_center,
                                                                   avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l,
                                                              avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
                                                     psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    if avg_psnr_l:
        # report the clips that were actually evaluated, not every input folder
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnr_l) / len(avg_psnr_l), len(avg_psnr_l),
                        sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                        sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
    else:
        logger.warning('No clip was evaluated; total average PSNR is undefined.')
Example #3
0
    img_LR = torch.from_numpy(
        np.ascontiguousarray(np.transpose(
            img_LR, (2, 0, 1)))).float().unsqueeze(0).cuda()

    with torch.no_grad():
        begin_time = time.time()
        frame, sr_base, output = model(img_LR)
        end_time = time.time()
        stat_time += (end_time - begin_time)
        #print(end_time-begin_time)

    output = util.tensor2img(output.squeeze(0))
    frame = util.tensor2img(frame.squeeze(0))
    sr_base = util.tensor2img(sr_base.squeeze(0))

    # save images
    save_path_name = osp.join(
        save_path, '{}_exp{}/{}.png'.format(dataset, exp_name, base_name))
    merge = np.concatenate((frame, sr_base, output), axis=1)
    util.save_img(merge, save_path_name)

    # calculate PSNR
    sr_img, gt_img = util.crop_border([output, img_GT], scale)
    PSNR_avg += util.calculate_psnr(sr_img, gt_img)
    SSIM_avg += util.calculate_ssim(sr_img, gt_img)

print('average PSNR: ', PSNR_avg / len(img_list))
print('average SSIM: ', SSIM_avg / len(img_list))
print('time: ', stat_time / len(img_list))
def main():
    ####################
    # arguments parser #
    ####################
    #  [format] dataset(vid4, REDS4) N(number of frames)

    parser = argparse.ArgumentParser()

    parser.add_argument('dataset')
    parser.add_argument('n_frames')
    parser.add_argument('stage')

    args = parser.parse_args()

    data_mode = str(args.dataset)
    N_in = int(args.n_frames)
    stage = int(args.stage)

    #if args.command == 'start':
    #    start(int(args.params[0]))
    #elif args.command == 'stop':
    #    stop(args.params[0], int(args.params[1]))
    #elif args.command == 'stop_all':
    #    stop_all(args.params[0])

    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    #data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    #stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    elif data_mode == 'sharp_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
    elif data_mode == 'blur_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
    elif data_mode == 'blur':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
    elif data_mode == 'blur_comp':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    else:
        raise NotImplementedError

    predeblur, HR_in = False, False
    back_RBs = 40
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20

    #### dataset
    if data_mode == 'Vid4':
        N_model_default = 7
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        N_model_default = 5
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    raw_model = EDVR_arch.EDVR(128,
                               N_model_default,
                               8,
                               5,
                               back_RBs,
                               predeblur=predeblur,
                               HR_in=HR_in)
    model = EDVR_arch.EDVR(128,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    data_mode_t = copy.deepcopy(data_mode)
    if stage == 1 and data_mode_t != 'Vid4':
        data_mode = 'REDS-EDVR_REDS_SR_L_flipx4'
    save_folder = '../results/{}'.format(data_mode)
    data_mode = copy.deepcopy(data_mode_t)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    print([a for a in dir(model)
           if not callable(getattr(model, a))])  # not a.startswith('__') and

    #model.load_state_dict(torch.load(model_path), strict=True)
    raw_model.load_state_dict(torch.load(model_path), strict=True)

    #   model.load_state_dict(torch.load(model_path), strict=True)

    #### change model so it can work with less input

    model.nf = raw_model.nf
    model.center = N_in // 2  #  if center is None else center
    model.is_predeblur = raw_model.is_predeblur
    model.HR_in = raw_model.HR_in
    model.w_TSA = raw_model.w_TSA
    #ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)

    #### extract features (for each frame)
    if model.is_predeblur:
        model.pre_deblur = raw_model.pre_deblur  #Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
        model.conv_1x1 = raw_model.conv_1x1  #nn.Conv2d(nf, nf, 1, 1, bias=True)
    else:
        if model.HR_in:
            model.conv_first_1 = raw_model.conv_first_1  #nn.Conv2d(3, nf, 3, 1, 1, bias=True)
            model.conv_first_2 = raw_model.conv_first_2  #nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            model.conv_first_3 = raw_model.conv_first_3  #nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        else:
            model.conv_first = raw_model.conv_first  # nn.Conv2d(3, nf, 3, 1, 1, bias=True)
    model.feature_extraction = raw_model.feature_extraction  #  arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
    model.fea_L2_conv1 = raw_model.fea_L2_conv1  #nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
    model.fea_L2_conv2 = raw_model.fea_L2_conv2  #nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
    model.fea_L3_conv1 = raw_model.fea_L3_conv1  #nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
    model.fea_L3_conv2 = raw_model.fea_L3_conv2  #nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

    model.pcd_align = raw_model.pcd_align  #PCD_Align(nf=nf, groups=groups)

    ######## Resize TSA

    model.tsa_fusion.center = model.center
    # temporal attention (before fusion conv)
    model.tsa_fusion.tAtt_1 = raw_model.tsa_fusion.tAtt_1
    model.tsa_fusion.tAtt_2 = raw_model.tsa_fusion.tAtt_2

    # fusion conv: using 1x1 to save parameters and computation

    #print(raw_model.tsa_fusion.fea_fusion.weight.shape)

    #print(raw_model.tsa_fusion.fea_fusion.weight.shape)
    #print(raw_model.tsa_fusion.fea_fusion.weight[127][639].shape)
    #print("MAIN SHAPE(FEA): ", raw_model.tsa_fusion.fea_fusion.weight.shape)

    model.tsa_fusion.fea_fusion = copy.deepcopy(
        raw_model.tsa_fusion.fea_fusion)
    model.tsa_fusion.fea_fusion.weight = copy.deepcopy(
        torch.nn.Parameter(raw_model.tsa_fusion.fea_fusion.weight[:, 0:N_in *
                                                                  128, :, :]))
    #[:][] #nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
    #model.tsa_fusion.fea_fusion.bias = raw_model.tsa_fusion.fea_fusion.bias

    # spatial attention (after fusion conv)
    model.tsa_fusion.sAtt_1 = copy.deepcopy(raw_model.tsa_fusion.sAtt_1)
    model.tsa_fusion.sAtt_1.weight = copy.deepcopy(
        torch.nn.Parameter(raw_model.tsa_fusion.sAtt_1.weight[:, 0:N_in *
                                                              128, :, :]))
    #[:][] #nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
    #model.tsa_fusion.sAtt_1.bias = raw_model.tsa_fusion.sAtt_1.bias

    #print(N_in * 128)
    #print(raw_model.tsa_fusion.fea_fusion.weight[:, 0:N_in * 128, :, :].shape)
    print("MODEL TSA SHAPE: ", model.tsa_fusion.fea_fusion.weight.shape)

    model.tsa_fusion.maxpool = raw_model.tsa_fusion.maxpool
    model.tsa_fusion.avgpool = raw_model.tsa_fusion.avgpool
    model.tsa_fusion.sAtt_2 = raw_model.tsa_fusion.sAtt_2
    model.tsa_fusion.sAtt_3 = raw_model.tsa_fusion.sAtt_3
    model.tsa_fusion.sAtt_4 = raw_model.tsa_fusion.sAtt_4
    model.tsa_fusion.sAtt_5 = raw_model.tsa_fusion.sAtt_5
    model.tsa_fusion.sAtt_L1 = raw_model.tsa_fusion.sAtt_L1
    model.tsa_fusion.sAtt_L2 = raw_model.tsa_fusion.sAtt_L2
    model.tsa_fusion.sAtt_L3 = raw_model.tsa_fusion.sAtt_L3
    model.tsa_fusion.sAtt_add_1 = raw_model.tsa_fusion.sAtt_add_1
    model.tsa_fusion.sAtt_add_2 = raw_model.tsa_fusion.sAtt_add_2

    model.tsa_fusion.lrelu = raw_model.tsa_fusion.lrelu

    #if model.w_TSA:
    #    model.tsa_fusion = raw_model.tsa_fusion[:][:128 * N_in][:][:] #TSA_Fusion(nf=nf, nframes=nframes, center=self.center)
    #else:
    #    model.tsa_fusion = raw_model.tsa_fusion[:][:128 * N_in][:][:] #nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)

    #   print(self.tsa_fusion)

    #### reconstruction
    model.recon_trunk = raw_model.recon_trunk  # arch_util.make_layer(ResidualBlock_noBN_f, back_RBs)
    #### upsampling
    model.upconv1 = raw_model.upconv1  #nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
    model.upconv2 = raw_model.upconv2  #nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)
    model.pixel_shuffle = raw_model.pixel_shuffle  # nn.PixelShuffle(2)
    model.HRconv = raw_model.HRconv
    model.conv_last = raw_model.conv_last

    #### activation function
    model.lrelu = raw_model.lrelu

    #####################################################

    model.eval()
    model = model.to(device)

    avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)

        print("MAX_IDX: ", max_idx)

        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_ssim, avg_ssim_border, avg_ssim_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            if data_mode == "blur":
                select_idx = data_util.glarefree_index_generation(
                    img_idx, max_idx, N_in, padding=padding)
            else:
                select_idx = data_util.index_generation(
                    img_idx, max_idx, N_in, padding=padding)  #  HERE GOTCHA
            print("SELECT IDX: ", select_idx)

            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                print("IMGS_IN SHAPE: ", imgs_in.shape)  # check this
                output = util.single_forward(model, imgs_in)  # error here 1
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate SSIM
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_ssim = util.calculate_ssim(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tSSIM: {:.6f} dB'.format(
                img_idx + 1, img_name, crt_ssim))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_ssim_center += crt_ssim
                N_center += 1
            else:  # border frames
                avg_ssim_border += crt_ssim
                N_border += 1

        avg_ssim = (avg_ssim_center + avg_ssim_border) / (N_center + N_border)
        avg_ssim_center = avg_ssim_center / N_center
        avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border
        avg_ssim_l.append(avg_ssim)
        avg_ssim_center_l.append(avg_ssim_center)
        avg_ssim_border_l.append(avg_ssim_border)

        logger.info('Folder {} - Average SSIM: {:.6f} dB for {} frames; '
                    'Center SSIM: {:.6f} dB for {} frames; '
                    'Border SSIM: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_ssim, (N_center + N_border),
                        avg_ssim_center, N_center, avg_ssim_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, ssim, ssim_center, ssim_border in zip(
            subfolder_name_l, avg_ssim_l, avg_ssim_center_l,
            avg_ssim_border_l):
        logger.info('Folder {} - Average SSIM: {:.6f} dB. '
                    'Center SSIM: {:.6f} dB. '
                    'Border SSIM: {:.6f} dB.'.format(subfolder_name, ssim,
                                                     ssim_center, ssim_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average SSIM: {:.6f} dB for {} clips. '
                'Center SSIM: {:.6f} dB. Border SSIM: {:.6f} dB.'.format(
                    sum(avg_ssim_l) / len(avg_ssim_l), len(subfolder_l),
                    sum(avg_ssim_center_l) / len(avg_ssim_center_l),
                    sum(avg_ssim_border_l) / len(avg_ssim_border_l)))
Example #5
0
def _load_truncated_edvr(model_path, n_frames_trained, N_in, back_RBs,
                         predeblur, HR_in, device):
    """Build an EDVR for ``N_in`` frames from a checkpoint trained on more.

    Two networks are constructed: ``raw_model`` with ``n_frames_trained``
    frames (the checkpoint at ``model_path`` is loaded into it with
    ``strict=True``) and ``model`` with ``N_in`` frames.  Every sub-module of
    ``raw_model`` is then attached to ``model`` by reference, except the two
    TSA-fusion convolutions ``fea_fusion`` and ``sAtt_1``: those are deep
    copied and their weights are cut down to the first ``N_in * 128`` input
    channels (dimension 1 of the weight tensor).

    NOTE(review): keeping the *first* ``N_in * 128`` channels presumably
    corresponds to the first ``N_in`` frames of the trained window -- confirm
    against ``TSA_Fusion.forward``.

    Returns:
        The ``N_in``-frame model, switched to eval mode and moved to
        ``device``.
    """
    nf = 128  # feature width used for both networks and for the weight slice
    raw_model = EDVR_arch.EDVR(nf, n_frames_trained, 8, 5, back_RBs,
                               predeblur=predeblur, HR_in=HR_in)
    model = EDVR_arch.EDVR(nf, N_in, 8, 5, back_RBs,
                           predeblur=predeblur, HR_in=HR_in)
    raw_model.load_state_dict(torch.load(model_path), strict=True)

    model.nf = raw_model.nf
    model.center = N_in // 2
    model.is_predeblur = raw_model.is_predeblur
    model.HR_in = raw_model.HR_in
    model.w_TSA = raw_model.w_TSA

    # Sub-modules whose shape does not depend on the number of frames are
    # shared with raw_model as-is.
    if model.is_predeblur:
        shared = ['pre_deblur', 'conv_1x1']
    elif model.HR_in:
        shared = ['conv_first_1', 'conv_first_2', 'conv_first_3']
    else:
        shared = ['conv_first']
    shared += [
        'feature_extraction', 'fea_L2_conv1', 'fea_L2_conv2', 'fea_L3_conv1',
        'fea_L3_conv2', 'pcd_align', 'recon_trunk', 'upconv1', 'upconv2',
        'pixel_shuffle', 'HRconv', 'conv_last', 'lrelu',
    ]
    for name in shared:
        setattr(model, name, getattr(raw_model, name))

    fusion, raw_fusion = model.tsa_fusion, raw_model.tsa_fusion
    fusion.center = model.center
    for name in ('tAtt_1', 'tAtt_2', 'maxpool', 'avgpool', 'sAtt_2', 'sAtt_3',
                 'sAtt_4', 'sAtt_5', 'sAtt_L1', 'sAtt_L2', 'sAtt_L3',
                 'sAtt_add_1', 'sAtt_add_2', 'lrelu'):
        setattr(fusion, name, getattr(raw_fusion, name))

    # The two convolutions that consume all frames at once need fewer input
    # channels; copy them so raw_model's own weights stay untouched.
    in_channels = N_in * nf
    for name in ('fea_fusion', 'sAtt_1'):
        raw_conv = getattr(raw_fusion, name)
        conv = copy.deepcopy(raw_conv)
        conv.weight = copy.deepcopy(
            torch.nn.Parameter(raw_conv.weight[:, 0:in_channels, :, :]))
        setattr(fusion, name, conv)

    model.eval()
    return model.to(device)


def _read_img_folder(folder):
    """Read every file in ``folder`` (sorted by path) with ``data_util.read_img``."""
    return [
        data_util.read_img(None, img_path)
        for img_path in sorted(glob.glob(osp.join(folder, '*')))
    ]


def _super_resolve_frame(model, imgs_LQ, img_idx, max_idx, N_in, padding,
                         flip_test, device):
    """Run ``model`` on the ``N_in``-frame window around ``img_idx``.

    The window indices come from ``data_util.index_generation``; the result
    of the forward pass is converted with ``util.tensor2img`` and returned.
    """
    select_idx = data_util.index_generation(
        img_idx, max_idx, N_in, padding=padding)
    imgs_in = imgs_LQ.index_select(
        0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

    if flip_test:
        output = util.flipx4_forward(model, imgs_in)
    else:
        print("IMGS_IN SHAPE: ", imgs_in.shape)
        output = util.single_forward(model, imgs_in)
    return util.tensor2img(output.squeeze(0))


def _record_metrics(clip_results, img_name, output, GT, GT_a, crop_border,
                    y_only):
    """Score one output frame and append the scores to ``clip_results``.

    PSNR and SSIM are computed against ``GT``; a second SSIM is computed
    against ``GT_a`` (the "a posteriori" reference).  With ``y_only`` set,
    all three images are first converted with
    ``data_util.bgr2ycbcr(..., only_y=True)``.

    ``clip_results`` maps a frame name to its ``metrics_file`` accumulator;
    the accumulator is created the first time a frame name is seen.
    """
    output = output / 255.
    GT = np.copy(GT)
    GT_a = np.copy(GT_a)
    if y_only:
        GT = data_util.bgr2ycbcr(GT, only_y=True)
        output = data_util.bgr2ycbcr(output, only_y=True)
        GT_a = data_util.bgr2ycbcr(GT_a, only_y=True)
    output_a = copy.deepcopy(output)

    output, GT = util.crop_border([output, GT], crop_border)
    crt_psnr = util.calculate_psnr(output * 255, GT * 255)
    crt_ssim = util.calculate_ssim(output * 255, GT * 255)

    output_a, GT_a = util.crop_border([output_a, GT_a], crop_border)
    crt_aposterior = util.calculate_ssim(output_a * 255, GT_a * 255)

    entry = clip_results.get(img_name)
    if entry is None:
        entry = clip_results[img_name] = metrics_file(img_name)
    entry.add_psnr(crt_psnr)
    entry.add_gt_ssim(crt_ssim)
    entry.add_aposterior_ssim(crt_aposterior)


def _dump_results(results, save_folder):
    """Write one ``<frame name>.json`` per accumulator under ``save_folder/<clip>``.

    ``results`` maps clip name -> {frame name -> accumulator}; each
    accumulator's ``__dict__`` is what gets serialised.
    """
    for clip_name, frames in results.items():
        clip_dir = osp.join(save_folder, clip_name)
        util.mkdirs(clip_dir)
        for entry in frames.values():
            json_path = osp.join(clip_dir, '{}.json'.format(entry.name))
            # ensure_ascii=False may emit non-ASCII text, so pin the encoding
            # instead of relying on the platform default.
            with open(json_path, 'w', encoding='utf-8') as outfile:
                json.dump(entry.__dict__, outfile, ensure_ascii=False, indent=4)


def _reds_model_path(data_mode, stage):
    """Return the pretrained REDS checkpoint path for ``data_mode`` and ``stage``.

    Raises:
        NotImplementedError: if ``data_mode`` is not one of the REDS modes.
    """
    tags = {
        'sharp_bicubic': 'SR',
        'blur_bicubic': 'SRblur',
        'blur': 'deblur',
        'blur_comp': 'deblurcomp',
    }
    if data_mode not in tags:
        raise NotImplementedError
    suffix = 'L' if stage == 1 else 'Stage2'
    return '../experiments/pretrained_models/EDVR_REDS_{}_{}.pth'.format(
        tags[data_mode], suffix)


def _run_vid4_stage(device):
    """Evaluate the Vimeo90K SR model on Vid4 for every N_in in 1..7.

    Metrics are computed on the Y channel and written as JSON files under
    ``../results/<clip>``.
    """
    model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
    flip_test = False
    predeblur, HR_in = False, False
    back_RBs = 40
    N_model_default = 7

    test_dataset_folder = '../datasets/Vid4/BIx4'
    GT_dataset_folder = '../datasets/Vid4/GT'
    aposterior_GT_dataset_folder = '../datasets/Vid4/GT_7'
    crop_border = 0
    padding = 'new_info'

    vid4_results = {"calendar": {}, "city": {}, "foliage": {}, "walk": {}}

    for N_in in range(1, N_model_default + 1):
        model = _load_truncated_edvr(model_path, N_model_default, N_in,
                                     back_RBs, predeblur, HR_in, device)

        subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
        subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
        subfolder_GT_a_l = sorted(
            glob.glob(osp.join(aposterior_GT_dataset_folder, "*")))

        # for each subfolder
        for subfolder, subfolder_GT, subfolder_GT_a in zip(
                subfolder_l, subfolder_GT_l, subfolder_GT_a_l):
            subfolder_name = osp.basename(subfolder)
            img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
            max_idx = len(img_path_l)
            print("MAX_IDX: ", max_idx)

            #### read LQ and GT images
            imgs_LQ = data_util.read_img_seq(subfolder)
            img_GT_l = _read_img_folder(subfolder_GT)
            img_GT_a = _read_img_folder(subfolder_GT_a)

            # process each image
            for img_idx, img_path in enumerate(img_path_l):
                img_name = osp.splitext(osp.basename(img_path))[0]
                output = _super_resolve_frame(model, imgs_LQ, img_idx,
                                              max_idx, N_in, padding,
                                              flip_test, device)
                _record_metrics(vid4_results[subfolder_name], img_name,
                                output, img_GT_l[img_idx], img_GT_a[img_idx],
                                crop_border, y_only=True)

    #### writing vid4 results
    _dump_results(vid4_results, '../results/')


def _run_reds4_stage(device):
    """Run the two-stage REDS4 evaluation for every N_in in 1..5.

    Stage 1 writes its outputs as PNG files; stage 2 reads its inputs from
    the stage-1 output folder and is the only stage that records metrics
    (on all colour channels).  Results are written as JSON files under
    ``../results/<clip>``.
    """
    reds4_results = {"000": {}, "011": {}, "015": {}, "020": {}}
    data_mode = 'sharp_bicubic'
    N_model_default = 5
    flip_test = False
    crop_border = 0
    save_imgs = True

    GT_dataset_folder = '../datasets/REDS4/GT'
    aposterior_GT_dataset_folder = '../datasets/REDS4/GT_5'

    for N_in in range(1, N_model_default + 1):
        for stage in range(1, 3):
            model_path = _reds_model_path(data_mode, stage)

            predeblur, HR_in = False, False
            back_RBs = 40
            if data_mode == 'blur_bicubic':
                predeblur = True
            if data_mode == 'blur' or data_mode == 'blur_comp':
                predeblur, HR_in = True, True
            if stage == 2:
                HR_in = True
                back_RBs = 20

            if stage == 1:
                test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
            else:
                test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
                print('You should modify the test_dataset_folder path for stage 2')

            # temporal padding mode
            if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
                padding = 'new_info'
            else:
                padding = 'replicate'

            # Stage-1 outputs go to the folder stage 2 reads from.
            if stage == 1 and data_mode != 'Vid4':
                save_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            else:
                save_folder = '../results/{}'.format(data_mode)
            util.mkdirs(save_folder)
            # NOTE(review): this runs once per (N_in, stage) pair, as before.
            # If util.setup_logger attaches fresh handlers on every call the
            # 'base' logger may emit duplicate lines -- confirm.
            util.setup_logger('base', save_folder, 'test',
                              level=logging.INFO, screen=True, tofile=True)

            model = _load_truncated_edvr(model_path, N_model_default, N_in,
                                         back_RBs, predeblur, HR_in, device)

            subfolder_l = sorted(
                glob.glob(osp.join(test_dataset_folder, '*')))
            subfolder_GT_l = sorted(
                glob.glob(osp.join(GT_dataset_folder, '*')))
            subfolder_GT_a_l = sorted(
                glob.glob(osp.join(aposterior_GT_dataset_folder, "*")))

            # for each subfolder; zip stops at the shortest of the three
            # listings, so extra entries in the test folder are ignored.
            for subfolder, subfolder_GT, subfolder_GT_a in zip(
                    subfolder_l, subfolder_GT_l, subfolder_GT_a_l):
                subfolder_name = osp.basename(subfolder)
                save_subfolder = osp.join(save_folder, subfolder_name)

                img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
                max_idx = len(img_path_l)

                print("MAX_IDX: ", max_idx)

                print("SAVE FOLDER::::::", save_folder)

                if save_imgs:
                    util.mkdirs(save_subfolder)

                #### read LQ images; GT is only needed when scoring (stage 2)
                imgs_LQ = data_util.read_img_seq(subfolder)
                if stage == 2:
                    img_GT_l = _read_img_folder(subfolder_GT)
                    img_GT_a = _read_img_folder(subfolder_GT_a)

                # process each image
                for img_idx, img_path in enumerate(img_path_l):
                    img_name = osp.splitext(osp.basename(img_path))[0]
                    output = _super_resolve_frame(model, imgs_LQ, img_idx,
                                                  max_idx, N_in, padding,
                                                  flip_test, device)

                    if stage == 1:
                        if save_imgs:
                            cv2.imwrite(
                                osp.join(save_subfolder,
                                         '{}.png'.format(img_name)), output)
                    else:
                        _record_metrics(reds4_results[subfolder_name],
                                        img_name, output, img_GT_l[img_idx],
                                        img_GT_a[img_idx], crop_border,
                                        y_only=False)

    #### writing reds4 results
    _dump_results(reds4_results, '../results/')


def main():
    """Collect PSNR/SSIM of frame-truncated EDVR models on Vid4 and REDS4.

    For each dataset the pretrained model is re-wired to accept N_in frames
    (N_in = 1..7 for Vid4, 1..5 for REDS4), run over every clip, and the
    per-frame metrics are accumulated and finally written as JSON files under
    ``../results/<clip>/<frame>.json``.
    """
    #################
    # configurations
    #################
    # Select the GPU before anything touches CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    device = torch.device('cuda')

    # STAGE Vid4
    _run_vid4_stage(device)

    # STAGE REDS
    _run_reds4_stage(device)
Example #6
0
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch', 'slurm'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)
    #pdb.set_trace()
    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist(args.launcher)
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    #pdb.set_trace()
    dataset_ratio = 1000  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)
            #### validation
            if opt['datasets'].get(
                    'val',
                    None) and current_step % opt['train']['val_freq'] == 0:
                if rank <= 0:
                    # does not support multi-GPU validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(
                            os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'],
                                               img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['SR'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # Save SR images for reference
                        save_img_path = os.path.join(
                            img_dir,
                            '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                          opt['scale'])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        pbar.update('Test {}'.format(img_name))

                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)
            if current_step % 50000 == 0:
                torch.cuda.empty_cache()

        torch.cuda.empty_cache()

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
Example #7
0
def main():
    """Evaluate a pretrained EDVR model on the ``SDR_4bit`` clips and log PSNR.

    All settings are hard-coded below. For every clip folder found under the
    test dataset folder, each frame is restored from ``N_in`` input frames,
    optionally written as a 16-bit PNG, and compared with the ground-truth
    frame at the same index. PSNR is logged per frame, per clip (split into
    center and border frames) and averaged over all evaluated clips.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'SDR_4bit'
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'SDR_4bit':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    else:
        raise NotImplementedError

    # use N_in images to restore one high bitdepth image
    N_in = 5

    # predeblur: predeblur for blurry input
    # HR_in: downsample high resolution input
    predeblur, HR_in = True, True
    back_RBs = 40
    if stage == 2:
        HR_in = True
        back_RBs = 20
    # EDVR(num_feature_map, num_input_frames, deformable_groups?, front_RBs,
    #      back_RBs, predeblur, HR_in)
    model = EDVR_arch.EDVR(128,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### dataset
    if stage == 1:
        test_dataset_folder = '../datasets/{}'.format(data_mode)
    else:
        test_dataset_folder = '../'
        print('You should modify the test_dataset_folder path for stage 2')
    GT_dataset_folder = '../datasets/SDR_10bit/'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder (input and GT folders are paired by sorted order)
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if max_idx == 0:
            # Nothing to evaluate; skipping also keeps the per-clip averages
            # below from dividing by zero.
            logger.warning(
                'Folder {} contains no frames, skipped.'.format(subfolder_name))
            continue
        subfolder_name_l.append(subfolder_name)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LBD and GT images
        #### resize to avoid cuda out of memory, 2160x3840->720x1280
        imgs_LBD = data_util.read_img_seq(subfolder,
                                          scale=65535.,
                                          zoomout=(1280, 720))
        img_GT_l = [
            data_util.read_img(None, img_GT_path, scale=65535., zoomout=True)
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))
        ]

        avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # generate frame index
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LBD.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                # self ensemble with flipping input at four different directions
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0), out_type=np.uint16)

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR on values rescaled to the 16-bit range
            output = output / 65535.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 65535, GT * 65535)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(
                img_idx + 1, img_name, crt_psnr))

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        # A clip with at most 2 * border_frame frames has no center frame, so
        # the center average needs the same zero guard as the border average.
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    if not avg_psnr_l:
        # No clip was evaluated, so there is nothing to average.
        logger.warning('No clips were evaluated; no total PSNR to report.')
        return
    # Report the number of clips actually evaluated: zip() above stops at the
    # shorter of the input/GT folder lists and empty folders are skipped.
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(avg_psnr_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
Example #8
0
def main():
    """Score pre-computed result frames against ground truth with PSNR/SSIM.

    No network is run here: for each clip, the frames matching ``*5.png.png``
    in the result folder are compared with the ground-truth frames matching
    ``*5.png``. Both are converted to uint8 before PSNR and SSIM are
    computed. Per-clip and overall averages are written to the ``base``
    logger. Dataset locations are derived from the command-line arguments.
    """
    #################
    # configurations
    #################
    flip_test = False
    save_imgs = False

    prog = argparse.ArgumentParser()
    prog.add_argument('--train_mode', '-t', type=str, default='REDS', help='train mode')
    prog.add_argument('--data_mode', '-m', type=str, default=None, help='data_mode')
    prog.add_argument('--degradation_mode', '-d', type=str, default='impulse',
                      choices=('impulse', 'bicubic', 'preset'),
                      help='degradation mode')
    prog.add_argument('--sigma_x', '-sx', type=float, default=1, help='sigma_x')
    prog.add_argument('--sigma_y', '-sy', type=float, default=0, help='sigma_y')
    prog.add_argument('--theta', '-th', type=float, default=0, help='theta')
    prog.add_argument('--model', type=str, default=None, help='name for subdirectory')

    args = prog.parse_args()

    train_data_mode = args.train_mode
    data_mode = args.data_mode
    if data_mode is None:
        if train_data_mode == 'Vimeo':
            data_mode = 'Vid4'
        elif train_data_mode == 'REDS':
            data_mode = 'REDS'
    degradation_mode = args.degradation_mode  # impulse | bicubic | preset
    sig_x, sig_y, the = args.sigma_x, args.sigma_y, args.theta
    if sig_y == 0:
        # sigma_y == 0 means "same as sigma_x"
        sig_y = sig_x

    if degradation_mode == 'preset':
        folder_subname = 'preset'
    else:
        folder_subname = '{}_{:.1f}_{:.1f}_{:.1f}'.format(
            degradation_mode, sig_x, sig_y, the)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../dataset/Vid4/LR_{}/X1_KG_ZSSR'.format(folder_subname)
        #test_dataset_folder = '../dataset/Vid4/LR_{}/X1_CF_DBPN'.format(folder_subname)
        GT_dataset_folder = '../dataset/Vid4/HR'
    elif data_mode == 'MM522':
        test_dataset_folder = '../dataset/MM522val/LR_bicubic/X1_KG_ZSSR'
        GT_dataset_folder = '../dataset/MM522val/HR'
    else:
        test_dataset_folder = '../dataset/REDS/train/LR_{}/X1_KG_ZSSR'.format(folder_subname)
        #test_dataset_folder = '../dataset/REDS/train/LR_{}/X1_CF_DBPN'.format(folder_subname)
        GT_dataset_folder = '../dataset/REDS/train/HR'

    #### evaluation
    crop_border = 0
    border_frame = 0  # every frame counts as a center frame
    # temporal padding mode (only logged; no frame window is built here)
    padding = 'new_info'

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    if data_mode == 'REDS':
        # keep only the GT clips whose path contains one of these ids
        reds_ids = ('000', '011', '015', '020')
        subfolder_GT_l = [k for k in subfolder_GT_l if any(tag in k for tag in reds_ids)]
    # for each subfolder (result and GT folders are paired by sorted order)
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder_GT, '*5.png')))  ## *5.png
        max_idx = len(img_path_l)
        if max_idx == 0:
            # Nothing to evaluate; skipping also keeps the per-clip averages
            # below from dividing by zero.
            logger.warning(
                'Folder {} contains no frames, skipped.'.format(subfolder_name))
            continue
        subfolder_name_l.append(subfolder_name)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        img_LQ_l = [
            data_util.read_img(None, img_LQ_path)
            for img_LQ_path in sorted(glob.glob(osp.join(subfolder, '*5.png.png')))
        ]
        img_GT_l = [data_util.read_img(None, img_GT_path) for img_GT_path in img_path_l]

        avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0
        avg_ssim_border, avg_ssim_center = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]

            # quantize both images to uint8 before computing the metrics
            output = (np.copy(img_LQ_l[img_idx]) * 255).round().astype('uint8')
            GT = (np.copy(img_GT_l[img_idx]) * 255).round().astype('uint8')

            if save_imgs:
                # Must come after ``output`` is built for the current frame:
                # saving earlier would hit an unassigned name on the first
                # frame and write the previous frame's pixels afterwards.
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # NOTE: metrics are computed on all colour channels; the Y-channel
            # conversion (data_util.bgr2ycbcr) is intentionally not applied.
            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output, GT)
            crt_ssim = util.calculate_ssim(output, GT)

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                avg_ssim_center += crt_ssim
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                avg_ssim_border += crt_ssim
                N_border += 1

        n_frames = N_center + N_border
        avg_psnr = (avg_psnr_center + avg_psnr_border) / n_frames
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        avg_ssim = (avg_ssim_center + avg_ssim_border) / n_frames
        avg_ssim_center = 0 if N_center == 0 else avg_ssim_center / N_center
        avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border
        avg_ssim_l.append(avg_ssim)
        avg_ssim_center_l.append(avg_ssim_center)
        avg_ssim_border_l.append(avg_ssim_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
                                                                   n_frames,
                                                                   avg_psnr_center, N_center,
                                                                   avg_psnr_border, N_border))

        logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; '
                    'Center SSIM: {:.6f} for {} frames; '
                    'Border SSIM: {:.6f} for {} frames.'.format(subfolder_name, avg_ssim,
                                                                n_frames,
                                                                avg_ssim_center, N_center,
                                                                avg_ssim_border, N_border))

    # The per-folder "Tidy Outputs" recap is disabled; the per-folder numbers
    # are already logged inside the loop above.
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    if avg_psnr_l:
        # Report the number of clips actually evaluated: zip() above stops at
        # the shorter folder list and empty folders are skipped.
        n_clips = len(avg_psnr_l)
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnr_l) / n_clips, n_clips,
                        sum(avg_psnr_center_l) / n_clips,
                        sum(avg_psnr_border_l) / n_clips))
        logger.info('Total Average SSIM: {:.6f} for {} clips. '
                    'Center SSIM: {:.6f}. Border SSIM: {:.6f}.'.format(
                        sum(avg_ssim_l) / n_clips, n_clips,
                        sum(avg_ssim_center_l) / n_clips,
                        sum(avg_ssim_border_l) / n_clips))
    else:
        logger.warning('No clips were evaluated; no total PSNR/SSIM to report.')

    print('\n\n\n')
Example #9
0
    def do_step(self, train_data):
        """Run one training iteration and any periodic work that is due.

        Increments ``self.current_step``, updates the learning rate, feeds
        ``train_data`` to the model and optimizes it. Afterwards, depending on
        the frequencies configured in ``self.opt``, it logs the current
        values, saves a checkpoint, runs image validation and runs the
        registered evaluators.

        Args:
            train_data: One batch from the training loader; passed unchanged
                to ``self.model.feed_data``.
        """
        profiling = self._profile
        if profiling:
            # The time since the previous step finished approximates the data
            # fetch time. There is no reference point on the first call, so
            # nothing is printed then. ``_t`` must be initialised here because
            # the later profiling prints read it.
            prev_end = getattr(self, '_profile_last_t', None)
            if prev_end is not None:
                print("Data fetch: %f" % (time() - prev_end))
            _t = time()

        opt = self.opt
        self.current_step += 1
        #### update learning rate
        self.model.update_learning_rate(
            self.current_step, warmup_iter=opt['train']['warmup_iter'])

        #### training
        if profiling:
            print("Update LR: %f" % (time() - _t))
            _t = time()
        self.model.feed_data(train_data, self.current_step)
        self.model.optimize_parameters(self.current_step)
        if profiling:
            print("Model feed + step: %f" % (time() - _t))
            _t = time()

        #### log (rank <= 0 only)
        if self.current_step % opt['logger'][
                'print_freq'] == 0 and self.rank <= 0:
            logs = self.model.get_current_log(self.current_step)
            message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                self.epoch, self.current_step)
            for v in self.model.get_current_learning_rate():
                message += '{:.3e},'.format(v)
            message += ')] '
            for k, v in logs.items():
                if 'histogram' in k:
                    self.tb_logger.add_histogram(k, v, self.current_step)
                elif isinstance(v, dict):
                    self.tb_logger.add_scalars(k, v, self.current_step)
                else:
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        self.tb_logger.add_scalar(k, v, self.current_step)
            if opt['wandb']:
                import wandb
                wandb.log(logs)
            self.logger.info(message)

        #### save models and training states
        if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
            if self.rank <= 0:
                self.logger.info('Saving models and training states.')
                self.model.save(self.current_step)
                self.model.save_training_state(self.epoch, self.current_step)
            # NOTE(review): this sync is not restricted to rank <= 0, so every
            # rank removes and re-copies the same directory - confirm intended.
            if 'alt_path' in opt['path']:
                import shutil
                print("Synchronizing tb_logger to alt_path..")
                alt_tblogger = os.path.join(opt['path']['alt_path'],
                                            "tb_logger")
                shutil.rmtree(alt_tblogger, ignore_errors=True)
                shutil.copytree(self.tb_logger_path, alt_tblogger)

        #### validation
        if opt['datasets'].get(
                'val',
                None) and self.current_step % opt['train']['val_freq'] == 0:
            if opt['model'] in [
                    'sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'
            ] and self.rank <= 0:  # image restoration validation
                psnr_sum = 0.
                psnr_count = 0  # images (not batches) that contributed a PSNR
                # Nothing in this method accumulates a feature loss, so the
                # logged value is always zero.
                avg_fea_loss = 0.
                for val_data in tqdm(self.val_loader):
                    # One forward pass serves every image of the batch.
                    self.model.feed_data(val_data, self.current_step)
                    self.model.test()
                    visuals = self.model.get_current_visuals()
                    if visuals is None:
                        continue

                    for b, hq_path in enumerate(val_data['HQ_path']):
                        img_name = os.path.splitext(
                            os.path.basename(hq_path))[0]
                        img_dir = os.path.join(opt['path']['val_images'],
                                               img_name)
                        util.mkdir(img_dir)

                        sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
                        # calculate PSNR
                        if self.val_compute_psnr:
                            gt_img = util.tensor2img(visuals['hq'][b])  # uint8
                            # NOTE(review): sr_img is rebound to the cropped
                            # image, so the file saved below is cropped only
                            # when PSNR is computed - confirm intended.
                            sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                              opt['scale'])
                            psnr_sum += util.calculate_psnr(sr_img, gt_img)
                            psnr_count += 1

                        # Save SR images for reference
                        img_base_name = '{:s}_{:d}.png'.format(
                            img_name, self.current_step)
                        save_img_path = os.path.join(img_dir, img_base_name)
                        util.save_img(sr_img, save_img_path)

                # PSNR is summed per image, so average over images; the guard
                # covers an empty loader or PSNR computation being disabled.
                avg_psnr = psnr_sum / psnr_count if psnr_count else 0.

                # log
                self.logger.info(
                    '# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(
                        avg_psnr, avg_fea_loss))

                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    self.tb_logger.add_scalar('val_psnr', avg_psnr,
                                              self.current_step)
                    self.tb_logger.add_scalar('val_fea', avg_fea_loss,
                                              self.current_step)

        #### evaluators
        if len(self.evaluators
               ) != 0 and self.current_step % opt['train']['val_freq'] == 0:
            eval_dict = {}
            for evaluator in self.evaluators:
                if evaluator.uses_all_ddp or self.rank <= 0:
                    eval_dict.update(evaluator.perform_eval())
            if self.rank <= 0:
                print("Evaluator results: ", eval_dict)
                for ek, ev in eval_dict.items():
                    self.tb_logger.add_scalar(ek, ev, self.current_step)
                if opt['wandb']:
                    import wandb
                    wandb.log(eval_dict)

        if profiling:
            # Reference point for the next call's "Data fetch" measurement.
            self._profile_last_t = time()
Example #10
0
def main():
    """Evaluate an EDVR model (built with ``w_TSA=False``) and log PSNR.

    All settings are hard-coded below. Clip names are read from ``val_txt``;
    for each clip every frame is restored from ``N_in`` input frames and
    compared with the ground-truth frame at the same index. PSNR is logged per
    clip (split into center and border frames) and averaged over the evaluated
    clips, together with a score equal to the average PSNR divided by 50.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    #### model
    data_mode = 'sharp_bicubic'
    if stage == 1:
        model_path = '../experiments/001_EDVRwoTSA_scratch_lr4e-4_600k_SR4K_LrCAR4S/models/200000_G.pth'
    else:
        model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'

    N_in = 3  # use N_in images to restore one HR image

    predeblur, HR_in = False, False
    back_RBs = 10
    if stage == 2:
        HR_in = True
        back_RBs = 20
    model = EDVR_arch.EDVR(64,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in,
                           w_TSA=False)

    #### dataset
    val_txt = '/home/mcc/4khdr/val.txt'
    # Defined for both stages: it is needed below regardless of the stage.
    GT_dataset_folder = '/home/mcc/4khdr/image/4k'
    if stage == 1:
        test_dataset_folder = '/home/mcc/4khdr/image/540p'
    else:
        test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
        print('You should modify the test_dataset_folder path for stage 2')

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = False

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    with open(val_txt, 'r') as f:
        # Drop blank lines: an empty name would join to the dataset root
        # folder itself and be treated as a clip.
        name_l = [line.strip() for line in f if line.strip()]
    subfolder_l = sorted(
        [osp.join(test_dataset_folder, name) for name in name_l])
    subfolder_GT_l = sorted(
        [osp.join(GT_dataset_folder, name) for name in name_l])
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if max_idx == 0:
            # Nothing to evaluate; skipping also keeps the per-clip averages
            # below from dividing by zero.
            logger.warning(
                'Folder {} contains no frames, skipped.'.format(subfolder_name))
            continue
        subfolder_name_l.append(subfolder_name)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = [
            data_util.read_img(None, img_GT_path)
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))
        ]

        avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # evaluate on RGB channels
            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)

            # 1-based progress so the counter ends at N/N
            sys.stdout.write('\r' +
                             '{:03d}/{:03d}'.format(img_idx + 1, max_idx))
            sys.stdout.flush()

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1
        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        # A clip with at most 2 * border_frame frames has no center frame, so
        # the center average needs the same zero guard as the border average.
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - frames {} -  Average PSNR: {:.6f} dB; '
                    'Center PSNR: {:.6f} dB; '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name,
                                                     (N_center + N_border),
                                                     avg_psnr, avg_psnr_center,
                                                     avg_psnr_border))
        # NOTE(review): only the first non-empty clip is evaluated because of
        # this break - confirm whether it is a debugging leftover.
        break

    logger.info('################ Tidy Outputs ################')
    for name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info(
            'Folder {} - Average PSNR: {:.6f} dB, Center PSNR: {:.6f} dB, Border PSNR: {:.6f} dB. '
            .format(name, psnr, psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    if not avg_psnr_l:
        # No clip was evaluated, so there is nothing to average.
        logger.warning('No clips were evaluated; no total PSNR to report.')
        return
    # Report the number of clips that actually went into the averages.
    n_clips = len(avg_psnr_l)
    total_psnr = sum(avg_psnr_l) / n_clips
    logger.info(
        'Total clips {}. Average PSNR: {:.6f} dB, Center PSNR: {:.6f} dB, Border PSNR: {:.6f} dB. '
        'Score: {:.6f}'.format(n_clips,
                               total_psnr,
                               sum(avg_psnr_center_l) / n_clips,
                               sum(avg_psnr_border_l) / n_clips,
                               total_psnr / 50))
Example #11
0
def main():
    """Evaluate an EDVR model clip by clip and log YUV and RGB PSNR.

    Command-line arguments:
        --input_path: folder holding one sub-folder of low-quality frames per clip.
        --gt_path: folder holding the ground-truth sub-folders (matched by name).
        --output_path: folder receiving the log file and the restored frames.
        --model_path: checkpoint loaded into the network (strict matching).
        --gpu_id: value exported as CUDA_VISIBLE_DEVICES.
        --opt: option YAML file; its ``network_G`` section configures EDVR.

    Clips without a matching ground-truth sub-folder are skipped.  Every frame
    is replication-padded by ``PAD`` pixels, passed through the network, cropped
    back to ``scale`` times its original size and compared with its ground
    truth.  Results are logged per frame, per clip and over all clips.
    """
    #################
    # configurations
    #################
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--gt_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--gpu_id", type=str, required=True)
    #parser.add_argument("--screen_notation", type=str, required=True)
    parser.add_argument('--opt', type=str, required=True, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=False)

    PAD = 32   # spatial replication padding (in input pixels) on every side
    scale = 4  # upscaling factor used when cropping the padded output

    total_run_time = AverageMeter()
    # Export the device selection before the first CUDA query so that it is
    # guaranteed to be honoured by the CUDA runtime.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    print("GPU ", torch.cuda.device_count())
    device = torch.device('cuda')

    data_mode = 'sharp_bicubic'
    flip_test = False

    Input_folder = args.input_path
    GT_folder = args.gt_path
    Result_folder = args.output_path
    model_path = args.model_path

    # create results folder
    os.makedirs(Result_folder, exist_ok=True)

    N_in = 5

    net_opt = opt['network_G']
    model = EDVR_arch.EDVR(nf=net_opt['nf'], nframes=net_opt['nframes'],
                           groups=net_opt['groups'], front_RBs=net_opt['front_RBs'],
                           back_RBs=net_opt['back_RBs'],
                           predeblur=net_opt['predeblur'], HR_in=net_opt['HR_in'],
                           w_TSA=net_opt['w_TSA'])

    #### dataset
    test_dataset_folder = Input_folder
    GT_dataset_folder = GT_folder

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    padding = 'new_info'
    save_imgs = True

    save_folder = os.path.join(Result_folder, data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    def _mean(total, count):
        # Average that yields 0 instead of raising when nothing was counted
        # (e.g. a clip too short to contain any center frame).
        return total / count if count else 0

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    avg_rgb_psnr_l, avg_rgb_psnr_center_l, avg_rgb_psnr_border_l = [], [], []

    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))

    # The padding layer does not depend on the frame, so build it once.
    pader = torch.nn.ReplicationPad2d([PAD, PAD, PAD, PAD])  # left, right, top, bottom

    for subfolder in subfolder_l:
        subfolder_name = osp.basename(subfolder)
        subfolder_GT = os.path.join(GT_dataset_folder, subfolder_name)

        # only clips that have a ground-truth counterpart can be scored
        if not os.path.exists(subfolder_GT):
            continue

        print("Evaluate Folders: ", subfolder_name)

        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images, notice we load yuv img here
        imgs_LQ = data_util.read_img_seq_yuv(subfolder)  # Num x 3 x H x W
        img_GT_l = [
            data_util.read_img_yuv(None, img_GT_path)
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))
        ]

        psnr_center_sum, psnr_border_sum = 0, 0
        rgb_psnr_center_sum, rgb_psnr_border_sum = 0, 0
        N_border, N_center = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            # restart the timer so the reported time covers this frame only
            end = time.time()

            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).to(device)  # N C H W

            intHeight_ori = imgs_in.shape[-2]
            intWidth_ori = imgs_in.shape[-1]

            # pad spatially, then add the batch dimension expected by the model
            X0 = torch.unsqueeze(pader(imgs_in), 0)
            if flip_test:
                output = util.flipx4_forward(model, X0)
            else:
                output = util.single_forward(model, X0)

            # remove the (upscaled) padding again
            output = output[0, :, PAD * scale:(PAD + intHeight_ori) * scale,
                                  PAD * scale:(PAD + intWidth_ori) * scale]

            output = util.tensor2img(output.squeeze(0))

            elapsed = time.time() - end
            print("*****************current image process time \t " + str(
                elapsed) + "s ******************")

            total_run_time.update(elapsed, 1)

            # calculate PSNR on YUV
            y_all = output / 255.
            GT = np.copy(img_GT_l[img_idx])

            y_all, GT = util.crop_border([y_all, GT], crop_border)
            crt_psnr = util.calculate_psnr(y_all * 255, GT * 255)
            logger.info('{:3d} - {:25} \tYUV_PSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            # here we also calculate PSNR on RGB
            y_all_rgb = data_util.ycbcr2rgb(output / 255.)
            GT_rgb = data_util.ycbcr2rgb(np.copy(img_GT_l[img_idx]))
            y_all_rgb, GT_rgb = util.crop_border([y_all_rgb, GT_rgb], crop_border)
            crt_rgb_psnr = util.calculate_psnr(y_all_rgb * 255, GT_rgb * 255)
            logger.info('{:3d} - {:25} \tRGB_PSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_rgb_psnr))

            if save_imgs:
                # Clip before the uint8 cast so that any out-of-range value
                # cannot wrap around (presumably the colour conversion may
                # overshoot [0, 1] slightly -- confirm against ycbcr2rgb).
                im_out = np.clip(np.round(y_all_rgb * 255.), 0, 255).astype(np.uint8)
                # notice here we got rgb img, but cv2 needs bgr when saving an img
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)),
                            cv2.cvtColor(im_out, cv2.COLOR_RGB2BGR))

            # for YUV and RGB, respectively
            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                psnr_center_sum += crt_psnr
                rgb_psnr_center_sum += crt_rgb_psnr
                N_center += 1
            else:  # border frames
                psnr_border_sum += crt_psnr
                rgb_psnr_border_sum += crt_rgb_psnr
                N_border += 1

        n_frames = N_center + N_border

        # for YUV
        avg_psnr = _mean(psnr_center_sum + psnr_border_sum, n_frames)
        avg_psnr_center = _mean(psnr_center_sum, N_center)
        avg_psnr_border = _mean(psnr_border_sum, N_border)
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average YUV PSNR: {:.6f} dB for {} frames; '
                    'Center YUV PSNR: {:.6f} dB for {} frames; '
                    'Border YUV PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
                                                                   n_frames,
                                                                   avg_psnr_center, N_center,
                                                                   avg_psnr_border, N_border))

        # for RGB
        avg_rgb_psnr = _mean(rgb_psnr_center_sum + rgb_psnr_border_sum, n_frames)
        avg_rgb_psnr_center = _mean(rgb_psnr_center_sum, N_center)
        avg_rgb_psnr_border = _mean(rgb_psnr_border_sum, N_border)
        avg_rgb_psnr_l.append(avg_rgb_psnr)
        avg_rgb_psnr_center_l.append(avg_rgb_psnr_center)
        avg_rgb_psnr_border_l.append(avg_rgb_psnr_border)

        logger.info('Folder {} - Average RGB PSNR: {:.6f} dB for {} frames; '
                    'Center RGB PSNR: {:.6f} dB for {} frames; '
                    'Border RGB PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_rgb_psnr,
                                                                   n_frames,
                                                                   avg_rgb_psnr_center, N_center,
                                                                   avg_rgb_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    # for YUV
    for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average YUV PSNR: {:.6f} dB. '
                    'Center YUV PSNR: {:.6f} dB. '
                    'Border YUV PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border))

    # for RGB
    for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_rgb_psnr_l, avg_rgb_psnr_center_l, avg_rgb_psnr_border_l):
        logger.info('Folder {} - Average RGB PSNR: {:.6f} dB. '
                    'Center RGB PSNR: {:.6f} dB. '
                    'Border RGB PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border))

    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    # Number of clips that were actually evaluated (skipped clips excluded).
    n_clips = len(avg_psnr_l)
    if n_clips == 0:
        logger.warning('No clip with a matching ground-truth folder was found; '
                       'nothing to average.')
        return

    logger.info('Total Average YUV PSNR: {:.6f} dB for {} clips. '
                'Center YUV PSNR: {:.6f} dB. Border YUV PSNR: {:.6f} dB.'.format(
        sum(avg_psnr_l) / n_clips, n_clips,
        sum(avg_psnr_center_l) / n_clips,
        sum(avg_psnr_border_l) / n_clips))
    logger.info('Total Average RGB PSNR: {:.6f} dB for {} clips. '
                'Center RGB PSNR: {:.6f} dB. Border RGB PSNR: {:.6f} dB.'.format(
        sum(avg_rgb_psnr_l) / n_clips, n_clips,
        sum(avg_rgb_psnr_center_l) / n_clips,
        sum(avg_rgb_psnr_border_l) / n_clips))
def main():
    """Evaluate a pretrained DUF model on Vid4 / MM522 / REDS and log PSNR/SSIM.

    Command-line arguments:
        --train_mode / -t: dataset the checkpoint was trained on; selects the
            checkpoint file name and, when --data_mode is omitted, the test set
            ('Vimeo' -> 'Vid4', 'REDS' -> 'REDS').
        --data_mode / -m: test set ('Vid4', 'MM522', anything else uses the
            REDS folders).
        --degradation_mode / -d: 'impulse', 'bicubic' or 'preset'; part of the
            low-resolution folder name.
        --sigma_x, --sigma_y, --theta: degradation parameters that are also
            part of the folder name; sigma_y falls back to sigma_x when 0.

    For every clip the frames are restored one by one from a temporal window
    of ``N_in`` frames and compared with the ground truth; PSNR and SSIM are
    logged per frame, per clip and averaged over all clips.
    """
    #################
    # configurations
    #################
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    save_imgs = False

    prog = argparse.ArgumentParser()
    prog.add_argument('--train_mode',
                      '-t',
                      type=str,
                      default='Vimeo',
                      help='train mode')
    prog.add_argument('--data_mode',
                      '-m',
                      type=str,
                      default=None,
                      help='data_mode')
    prog.add_argument('--degradation_mode',
                      '-d',
                      type=str,
                      default='impulse',
                      choices=('impulse', 'bicubic', 'preset'),
                      help='degradation mode')
    prog.add_argument('--sigma_x',
                      '-sx',
                      type=float,
                      default=1,
                      help='sigma_x')
    prog.add_argument('--sigma_y',
                      '-sy',
                      type=float,
                      default=0,
                      help='sigma_y')
    prog.add_argument('--theta', '-th', type=float, default=0, help='theta')

    args = prog.parse_args()

    train_mode = args.train_mode
    data_mode = args.data_mode
    if data_mode is None:
        if train_mode == 'Vimeo':
            data_mode = 'Vid4'
        elif train_mode == 'REDS':
            data_mode = 'REDS'
    degradation_mode = args.degradation_mode  # impulse | bicubic | preset
    sig_x, sig_y, the = args.sigma_x, args.sigma_y, args.theta
    if sig_y == 0:
        # a zero sigma_y means "same as sigma_x"
        sig_y = sig_x

    # Possible combinations: (2, 16), (3, 16), (4, 16), (4, 28), (4, 52)
    scale = 2
    layer = 16
    assert (scale, layer) in [(2, 16), (3, 16), (4, 16), (4, 28), (4, 52)
                              ], 'Unrecognized (scale, layer) combination'

    # model
    N_in = 7
    model_path = '../experiments/pretrained_models/DUF_{}L_{}_S{}.pth'.format(
        layer, train_mode, scale)
    adapt_official = True  # if 'official' in model_path else False
    DUF_downsampling = False  # True | False
    if layer == 16:
        model = DUF_arch.DUF_16L(scale=scale, adapt_official=adapt_official)
    elif layer == 28:
        model = DUF_arch.DUF_28L(scale=scale, adapt_official=adapt_official)
    elif layer == 52:
        model = DUF_arch.DUF_52L(scale=scale, adapt_official=adapt_official)

    #### dataset
    if degradation_mode == 'preset':
        folder_subname = 'preset'
    else:
        folder_subname = '{}_{:.1f}_{:.1f}_{:.1f}'.format(
            degradation_mode, sig_x, sig_y, the)

    if data_mode == 'Vid4':
        test_dataset_folder = '../dataset/Vid4/LR_{}/X{}'.format(
            folder_subname, scale)
        GT_dataset_folder = '../dataset/Vid4/HR'
    elif data_mode == 'MM522':
        test_dataset_folder = '../dataset/MM522val/LR_bicubic/X{}'.format(
            scale)
        GT_dataset_folder = '../dataset/MM522val/HR'
    else:
        test_dataset_folder = '../dataset/REDS/train/LR_{}/X{}'.format(
            folder_subname, scale)
        GT_dataset_folder = '../dataset/REDS/train/HR'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    padding = 'new_info'  # different from the official testing codes, which pads zeros.
    ############################################################################
    device = torch.device('cuda')
    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))

    def read_image(img_path):
        '''read one image from img_path
        Return img: HWC, BGR, [0,1], numpy
        Raises IOError when the file cannot be decoded as an image.
        '''
        img_GT = cv2.imread(img_path)
        if img_GT is None:
            # cv2.imread signals failure by returning None instead of raising
            raise IOError('Cannot read image: {}'.format(img_path))
        return img_GT.astype(np.float32) / 255.

    def read_seq_imgs(img_seq_path):
        '''read a sequence of images'''
        img_path_l = sorted(glob.glob(img_seq_path + '/*'))
        img_l = [read_image(v) for v in img_path_l]
        # stack to TCHW, RGB, [0,1], torch
        imgs = np.stack(img_l, axis=0)
        imgs = imgs[:, :, :, [2, 1, 0]]
        imgs = torch.from_numpy(
            np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
        return imgs

    def index_generation(crt_i, max_n, N, padding='reflection'):
        '''Return the N frame indices of the temporal window centred on crt_i.

        Indices falling outside [0, max_n - 1] are remapped according to
        padding: replicate | reflection | new_info | circle
        '''
        max_n = max_n - 1
        n_pad = N // 2
        return_l = []

        for i in range(crt_i - n_pad, crt_i + n_pad + 1):
            if i < 0:
                if padding == 'replicate':
                    add_idx = 0
                elif padding == 'reflection':
                    add_idx = -i
                elif padding == 'new_info':
                    add_idx = (crt_i + n_pad) + (-i)
                elif padding == 'circle':
                    add_idx = N + i
                else:
                    raise ValueError('Wrong padding mode')
            elif i > max_n:
                if padding == 'replicate':
                    add_idx = max_n
                elif padding == 'reflection':
                    add_idx = max_n * 2 - i
                elif padding == 'new_info':
                    add_idx = (crt_i - n_pad) - (i - max_n)
                elif padding == 'circle':
                    add_idx = i - N
                else:
                    raise ValueError('Wrong padding mode')
            else:
                add_idx = i
            return_l.append(add_idx)
        return return_l

    def single_forward(model, imgs_in):
        '''run the model without gradients; keep the first item of a list/tuple output'''
        with torch.no_grad():
            model_output = model(imgs_in)
            if isinstance(model_output, (list, tuple)):
                output = model_output[0]
            else:
                output = model_output
        return output

    def _mean(total, count):
        # Average that yields 0 instead of raising when nothing was counted.
        return total / count if count else 0

    sub_folder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    sub_folder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    if data_mode == 'REDS':
        # Restrict the evaluation to the four validation clips.  The same
        # filter is applied to BOTH lists: filtering only the GT list would
        # let zip() below pair LR and GT folders of different clips.
        reds4_tags = ('000', '011', '015', '020')

        def _is_reds4(path):
            name = osp.basename(path)
            return any(tag in name for tag in reds4_tags)

        sub_folder_l = [k for k in sub_folder_l if _is_reds4(k)]
        sub_folder_GT_l = [k for k in sub_folder_GT_l if _is_reds4(k)]
    if len(sub_folder_l) != len(sub_folder_GT_l):
        logger.warning('Found {} input folders but {} GT folders; only the '
                       'first matching pairs are evaluated.'.format(
                           len(sub_folder_l), len(sub_folder_GT_l)))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], []
    subfolder_name_l = []

    # for each sub-folder
    for sub_folder, sub_folder_GT in zip(sub_folder_l, sub_folder_GT_l):
        sub_folder_name = osp.basename(sub_folder)
        subfolder_name_l.append(sub_folder_name)
        save_sub_folder = osp.join(save_folder, sub_folder_name)

        img_path_l = sorted(glob.glob(sub_folder + '/*'))
        max_idx = len(img_path_l)

        if save_imgs:
            util.mkdirs(save_sub_folder)

        #### read LR images
        imgs = read_seq_imgs(sub_folder)
        #### read GT images
        img_GT_l = [
            read_image(img_GT_path)
            for img_GT_path in sorted(glob.glob(osp.join(sub_folder_GT, '*')))
        ]

        # When using the downsampling in DUF official code, we downsample the HR images
        if DUF_downsampling:
            sub_folder = sub_folder_GT
            img_path_l = sorted(glob.glob(sub_folder + '/*'))
            max_idx = len(img_path_l)
            imgs = read_seq_imgs(sub_folder)

        psnr_center_sum, psnr_border_sum = 0, 0
        ssim_center_sum, ssim_border_sum = 0, 0
        N_border, N_center = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # the numeric file name is used as the frame index of the window
            c_idx = int(img_name)
            select_idx = index_generation(c_idx,
                                          max_idx,
                                          N_in,
                                          padding=padding)
            # get input images
            imgs_in = imgs.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            # Downsample the HR images
            H, W = imgs_in.size(3), imgs_in.size(4)
            if DUF_downsampling:
                imgs_in = util.DUF_downsample(imgs_in, sigma=1.3, scale=scale)

            output = single_forward(model, imgs_in)

            # Crop to the original shape
            # NOTE(review): pad_h / pad_w are always in 1..3, so this always
            # crops; presumably it mirrors padding added elsewhere for
            # scale 3 -- confirm before enabling scale == 3.
            if scale == 3:
                pad_h = scale - (H % scale)
                pad_w = scale - (W % scale)
                if pad_h > 0:
                    output = output[:, :, :-pad_h, :]
                if pad_w > 0:
                    output = output[:, :, :, :-pad_w]
            output_f = output.data.float().cpu().squeeze(0)
            output = util.tensor2img(output_f)

            # save imgs
            if save_imgs:
                cv2.imwrite(
                    osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)),
                    output)

            #### calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # both images are compared as 8-bit images
            output = (output * 255).round().astype('uint8')
            GT = (GT * 255).round().astype('uint8')
            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output, GT)
            crt_ssim = util.calculate_ssim(output, GT)

            logger.info(
                '{:3d} - {:16} \tPSNR: {:.6f} dB \tSSIM: {:.6f}'.format(
                    img_idx + 1, img_name, crt_psnr, crt_ssim))

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                psnr_center_sum += crt_psnr
                ssim_center_sum += crt_ssim
                N_center += 1
            else:  # border frames
                psnr_border_sum += crt_psnr
                ssim_border_sum += crt_ssim
                N_border += 1

        n_frames = N_center + N_border

        avg_psnr = _mean(psnr_center_sum + psnr_border_sum, n_frames)
        avg_psnr_center = _mean(psnr_center_sum, N_center)
        avg_psnr_border = _mean(psnr_border_sum, N_border)
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        avg_ssim = _mean(ssim_center_sum + ssim_border_sum, n_frames)
        avg_ssim_center = _mean(ssim_center_sum, N_center)
        avg_ssim_border = _mean(ssim_border_sum, N_border)
        avg_ssim_l.append(avg_ssim)
        avg_ssim_center_l.append(avg_ssim_center)
        avg_ssim_border_l.append(avg_ssim_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        sub_folder_name, avg_psnr, n_frames,
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))

    # Number of clips that were actually evaluated.
    n_clips = len(avg_psnr_l)
    if n_clips == 0:
        logger.warning('No clip was evaluated; nothing to average.')
    else:
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnr_l) / n_clips, n_clips,
                        sum(avg_psnr_center_l) / n_clips,
                        sum(avg_psnr_border_l) / n_clips))
        logger.info('Total Average SSIM: {:.6f} for {} clips. '
                    'Center SSIM: {:.6f}. Border SSIM: {:.6f}.'.format(
                        sum(avg_ssim_l) / n_clips, n_clips,
                        sum(avg_ssim_center_l) / n_clips,
                        sum(avg_ssim_border_l) / n_clips))

    print('\n\n\n')
Example #13
0
def main():
    """Train a model described by an option YAML file (``--opt``).

    Steps: parse the options, optionally locate and load the latest training
    state, set up the loggers (optionally tensorboard), seed the RNGs, build
    the train/val dataloaders and the model, then run the training loop with
    periodic logging, validation (PSNR for 'sr' / 'srgan' models) and
    checkpointing.  The final model is saved as 'latest'.

    Raises:
        ValueError: if ``resume_latest`` is set but the training-state folder
            holds no numbered state file, or if no 'train' dataset is
            configured.
        NotImplementedError: if a dataset phase other than 'train' / 'val' is
            configured.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### loading resume state if exists
    if 'resume_latest' in opt and opt['resume_latest']:
        state_dir = opt['path']['training_state']
        if os.path.isdir(state_dir):
            # state files are named '<iteration>.state'; ignore anything else
            state_nums = [
                int(stem) for stem in
                (name.split('.')[0] for name in os.listdir(state_dir))
                if stem.isdigit()
            ]
            if not state_nums:
                raise ValueError(
                    'resume_latest is set but no training state was found in '
                    '{}'.format(state_dir))
            opt['path']['resume_state'] = os.path.join(
                state_dir, str(max(state_nums)) + '.state')
    if opt['path'].get('resume_state', None):
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if resume_state is None:
        util.mkdir_and_rename(
            opt['path']['experiments_root'])  # rename experiment folder if exists
        util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    # tensorboard logger; stays None when disabled so that every later use
    # can be guarded instead of hitting an unbound name
    tb_logger = None
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        # only the first three characters are compared, e.g. '1.10' reads as
        # 1.1, which still selects the built-in writer
        version = float(torch.__version__[0:3])
        if version >= 1.1:  # PyTorch 1.1
            from torch.utils.tensorboard import SummaryWriter
        else:
            logger.info(
                'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
            from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='../tb_logger/' +
                                  opt['name'] + '_{}'.format(util.get_timestamp()))

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    train_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            # iterations per epoch
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            train_sampler = None
            train_loader = create_dataloader(
                train_set, dataset_opt, opt, train_sampler)
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    if train_loader is None:
        raise ValueError("No 'train' dataset is configured in the options.")

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    is_time = False  # set to True to measure data-loading and batch time

    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))

    if is_time:
        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')

    for epoch in range(start_epoch, total_epochs + 1):
        if current_step > total_iters:
            break

        if is_time:
            torch.cuda.synchronize()
            end = time.time()
        for train_data in train_loader:
            # adversarial training advances the step counter by 'm' per batch
            if 'adv_train' in opt:
                current_step += opt['adv_train']['m']
            else:
                current_step += 1
            if current_step > total_iters:
                break

            #### training
            model.feed_data(train_data)

            if is_time:
                torch.cuda.synchronize()
                data_time.update(time.time() - end)

            model.optimize_parameters(current_step)

            #### update learning rate
            model.update_learning_rate(
                current_step, warmup_iter=opt['train']['warmup_iter'])

            if is_time:
                torch.cuda.synchronize()
                batch_time.update(time.time() - end)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                # FIXME: debugging aid kept from the original code; releases
                # cached GPU memory every time the log is printed
                torch.cuda.empty_cache()
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if tb_logger is not None:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)
                if is_time:
                    logger.info(str(data_time))
                    logger.info(str(batch_time))

            #### validation
            if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
                if opt['model'] in ['sr', 'srgan']:  # image restoration validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(
                            os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(
                            opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # Save SR images for reference
                        save_img_path = os.path.join(img_dir,
                                                     '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        sr_img, gt_img = util.crop_border(
                            [sr_img, gt_img], opt['scale'])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        pbar.update('Test {}'.format(img_name))

                    # an empty validation set yields 0 instead of a crash
                    avg_psnr = avg_psnr / idx if idx else 0.

                    # log
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    # tensorboard logger
                    if tb_logger is not None:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)
                if tb_logger is not None:
                    tb_logger.flush()
            if is_time:
                torch.cuda.synchronize()
                end = time.time()

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
    if tb_logger is not None:
        tb_logger.close()
Example #14
0
def main(name_flag, input_path, gt_path, model_path, save_path, save_imgs,
         flip_test):
    """Evaluate a pretrained CNLRN checkpoint on an image folder.

    Every image found in ``input_path`` is restored by the network and
    compared with the ground-truth image at the same position in the sorted
    listing of ``gt_path``.  Per-image and average PSNR/SSIM are written to
    the ``'base'`` logger.

    Args:
        name_flag: Label of this run; used as the results sub-folder name
            and passed to ``util.setup_logger`` as the log name.
        input_path: Directory containing the low-quality input images.
        gt_path: Directory containing the ground-truth images.
        model_path: Checkpoint file loaded through ``torch.load``.
        save_path: Root output directory; results are written to
            ``save_path/name_flag``.
        save_imgs: When truthy, each restored image is written as a PNG.
        flip_test: When truthy, ``util.flipx8_forward`` is used instead of
            ``util.single_forward``.

    Returns:
        None.  If ``input_path`` contains no files a warning is logged and
        the function returns without computing any metric.
    """
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    save_path = os.path.join(save_path, name_flag)

    #### model
    model = CNLRN_arch.CNLRN(n_colors=3,
                             n_deblur_blocks=20,
                             n_nlrgs_body=6,
                             n_nlrgs_up1=2,
                             n_nlrgs_up2=2,
                             n_subgroups=2,
                             n_rcabs=4,
                             n_feats=64,
                             nonlocal_psize=(4, 4, 4, 4),
                             scale=4)
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    #### logger
    util.mkdirs(save_path)
    util.setup_logger('base',
                      save_path,
                      name_flag,
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    def log_settings():
        """Log the run settings (emitted at start-up and in the summary)."""
        logger.info('Evaluate: {}'.format(name_flag))
        logger.info('Input images path: {}'.format(input_path))
        logger.info('GT images path: {}'.format(gt_path))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Results save path: {}'.format(save_path))
        logger.info('Flip test: {}'.format(flip_test))
        logger.info('Save images: {}'.format(save_imgs))

    log_settings()

    #### Evaluation
    img_path_l = sorted(glob.glob(osp.join(input_path, '*')))
    if not img_path_l:
        # Without this guard the averages below divide by zero.
        logger.warning(
            'No input images found in {}; nothing to evaluate.'.format(
                input_path))
        return

    #### read LQ and GT images
    imgs_LQ = data_util.read_img_seq(input_path)
    img_GT_l = [
        data_util.read_img(None, img_GT_path)
        for img_GT_path in sorted(glob.glob(osp.join(gt_path, '*')))
    ]

    total_psnr_l = []
    total_ssim_l = []

    # process each image
    for img_idx, img_path in enumerate(img_path_l):
        img_name = osp.splitext(osp.basename(img_path))[0]
        # Slice (instead of index) to keep the leading batch dimension.
        imgs_in = imgs_LQ[img_idx:img_idx + 1].to(device)

        if flip_test:
            output = util.flipx8_forward(model, imgs_in)
        else:
            output = util.single_forward(model, imgs_in)

        output = util.tensor2img(output.squeeze(0))

        if save_imgs:
            cv2.imwrite(osp.join(save_path, '{}.png'.format(img_name)), output)

        # calculate PSNR / SSIM
        # NOTE(review): GT is presumably in [0, 1] as returned by
        # data_util.read_img; both images are rescaled by 255 for the metrics.
        output = output / 255.
        GT = np.copy(img_GT_l[img_idx])

        output, GT = util.crop_border([output, GT], crop_border=4)
        crt_psnr = util.calculate_psnr(output * 255, GT * 255)
        crt_ssim = util.ssim(output * 255, GT * 255)
        total_psnr_l.append(crt_psnr)
        total_ssim_l.append(crt_ssim)

        logger.info('{} \tPSNR: {:.3f} \tSSIM: {:.4f}'.format(
            img_name, crt_psnr, crt_ssim))

    logger.info('################ Final Results ################')
    log_settings()
    logger.info(
        'Total Average PSNR: {:.3f} SSIM: {:.4f} for {} images.'.format(
            sum(total_psnr_l) / len(total_psnr_l),
            sum(total_ssim_l) / len(total_ssim_l), len(img_path_l)))
def main():
    """Evaluate a pretrained EDVR checkpoint on degraded video clips.

    Command-line options:
        --train_mode / -t: dataset the checkpoint was trained on; selects the
            checkpoint file (default ``'REDS'``).
        --data_mode / -m: test set; when omitted it is derived from
            ``--train_mode`` (``'Vimeo'`` -> ``'Vid4'``, ``'REDS'`` ->
            ``'REDS'``).
        --degradation_mode / -d: ``impulse`` | ``bicubic`` | ``preset``;
            together with ``--sigma_x``, ``--sigma_y`` and ``--theta`` it
            forms the name of the low-resolution input folder.

    For every clip folder the frames are restored with a sliding window of
    ``N_in`` frames, PSNR is computed per frame, and averages are logged for
    all / center / border frames.  Border frames are the first and last
    ``N_in // 2`` frames of a clip.

    Raises:
        NotImplementedError: if no checkpoint is defined for the selected
            scale / mode combination.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    flip_test = False
    scale = 4
    N_in = 5
    predeblur, HR_in = False, False
    n_feats = 128
    back_RBs = 40

    save_imgs = False
    prog = argparse.ArgumentParser()
    prog.add_argument('--train_mode',
                      '-t',
                      type=str,
                      default='REDS',
                      help='train mode')
    prog.add_argument('--data_mode',
                      '-m',
                      type=str,
                      default=None,
                      help='data_mode')
    prog.add_argument('--degradation_mode',
                      '-d',
                      type=str,
                      default='impulse',
                      choices=('impulse', 'bicubic', 'preset'),
                      help='degradation mode: impulse | bicubic | preset')
    prog.add_argument('--sigma_x',
                      '-sx',
                      type=float,
                      default=1,
                      help='sigma_x')
    prog.add_argument('--sigma_y',
                      '-sy',
                      type=float,
                      default=0,
                      help='sigma_y')
    prog.add_argument('--theta', '-th', type=float, default=0, help='theta')

    args = prog.parse_args()

    train_data_mode = args.train_mode
    data_mode = args.data_mode
    if data_mode is None:
        if train_data_mode == 'Vimeo':
            data_mode = 'Vid4'
        elif train_data_mode == 'REDS':
            data_mode = 'REDS'
    degradation_mode = args.degradation_mode  # impulse | bicubic | preset
    sig_x, sig_y, the = args.sigma_x, args.sigma_y, args.theta
    if sig_y == 0:
        # sigma_y == 0 means "same as sigma_x".
        sig_y = sig_x

    ############################################################################
    #### model
    if scale == 2:
        if train_data_mode == 'Vimeo':
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_M_Scale2_FT.pth'
            # model_path = '../experiments/pretrained_models/EDVR_M_BLIND_V_FT_report.pth'
            # model_path = '../experiments/pretrained_models/2500_G.pth'
        elif train_data_mode == 'REDS':
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_M_Scale2.pth'
            # model_path = '../experiments/pretrained_models/EDVR_M_BLIND_R_FT_report.pth'
        elif train_data_mode == 'Both':
            model_path = '../experiments/pretrained_models/EDVR_REDS+Vimeo90K_SR_M_Scale2_FT.pth'
        elif train_data_mode == 'MM522':
            model_path = '../experiments/pretrained_models/EDVR_MM522_SR_M_Scale2_FT.pth'
        else:
            raise NotImplementedError
    else:
        if data_mode == 'Vid4':
            model_path = '../experiments/pretrained_models/EDVR_BLIND_Vimeo_SR_L.pth'
            # model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        elif data_mode == 'REDS':
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
            # model_path = '../experiments/pretrained_models/EDVR_BLIND_REDS_SR_L.pth'
        else:
            raise NotImplementedError

    model = EDVR_arch.EDVR(n_feats,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in,
                           scale=scale)

    # Name of the LR folder variant, e.g. 'impulse_1.0_1.0_0.0' or 'preset'.
    if degradation_mode == 'preset':
        folder_subname = 'preset'
    else:
        folder_subname = '{}_{:.1f}_{:.1f}_{:.1f}'.format(
            degradation_mode, sig_x, sig_y, the)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../dataset/Vid4/LR_{}/X{}'.format(
            folder_subname, scale)
        GT_dataset_folder = '../dataset/Vid4/HR'
    elif data_mode == 'MM522':
        test_dataset_folder = '../dataset/MM522val/LR_bicubic/X{}'.format(
            scale)
        GT_dataset_folder = '../dataset/MM522val/HR'
    else:
        test_dataset_folder = '../dataset/REDS/train/LR_{}/X{}'.format(
            folder_subname, scale)
        GT_dataset_folder = '../dataset/REDS/train/HR'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    padding = 'new_info'

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    def log_settings():
        """Log the run settings (emitted at start-up and in the summary)."""
        logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))

    #### log info
    log_settings()

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    if data_mode == 'REDS':
        # Keep only the four validation clips on the GT side.
        # NOTE(review): the LQ list is not filtered, and clips are paired by
        # position below -- this presumably relies on the LQ folder holding
        # exactly these four clips; confirm.
        subfolder_GT_l = [
            k for k in subfolder_GT_l if k.find('000') >= 0
            or k.find('011') >= 0 or k.find('015') >= 0 or k.find('020') >= 0
        ]

    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if max_idx == 0:
            # An empty entry would otherwise end in a division by zero.
            logger.warning(
                'Folder {} contains no frames; skipped.'.format(
                    subfolder_name))
            continue
        subfolder_name_l.append(subfolder_name)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = [
            data_util.read_img(None, img_GT_path)
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))
        ]

        avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0
        avg_ssim_border, avg_ssim_center = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # Indices of the N_in neighbouring frames fed to the network.
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR on uint8 images (all channels; the Y-channel
            # variant is disabled in this script)
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            output = (output * 255).round().astype('uint8')
            GT = (GT * 255).round().astype('uint8')
            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output, GT)
            # NOTE: SSIM is not computed here; 0.001 is a placeholder, so
            # every SSIM figure logged below is meaningless.
            crt_ssim = 0.001  # util.calculate_ssim(output, GT)

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                avg_ssim_center += crt_ssim
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                avg_ssim_border += crt_ssim
                N_border += 1

        n_frames = N_center + N_border
        # Clips shorter than 2 * border_frame + 1 have no center frames, so
        # N_center is guarded the same way as N_border.
        avg_psnr = (avg_psnr_center + avg_psnr_border) / n_frames
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        avg_ssim = (avg_ssim_center + avg_ssim_border) / n_frames
        avg_ssim_center = 0 if N_center == 0 else avg_ssim_center / N_center
        avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border
        avg_ssim_l.append(avg_ssim)
        avg_ssim_center_l.append(avg_ssim_center)
        avg_ssim_border_l.append(avg_ssim_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, n_frames,
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

        logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; '
                    'Center SSIM: {:.6f} for {} frames; '
                    'Border SSIM: {:.6f} for {} frames.'.format(
                        subfolder_name, avg_ssim, n_frames,
                        avg_ssim_center, N_center, avg_ssim_border, N_border))

    logger.info('################ Final Results ################')
    log_settings()
    n_clips = len(avg_psnr_l)  # clips actually evaluated
    if n_clips == 0:
        logger.warning('No clips were evaluated; no averages to report.')
    else:
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnr_l) / n_clips, n_clips,
                        sum(avg_psnr_center_l) / n_clips,
                        sum(avg_psnr_border_l) / n_clips))
        logger.info('Total Average SSIM: {:.6f} for {} clips. '
                    'Center SSIM: {:.6f}. Border SSIM: {:.6f}.'.format(
                        sum(avg_ssim_l) / n_clips, n_clips,
                        sum(avg_ssim_center_l) / n_clips,
                        sum(avg_ssim_border_l) / n_clips))

    print('\n\n\n')
Example #16
0
def main():
    """Run the configured model on every test dataset and log PSNR/SSIM.

    The single command-line option ``-opt`` is the path to the options file
    parsed by ``option.parse``.  For each dataset listed under
    ``opt['datasets']`` the restored images are saved to
    ``opt['path']['results_root']/<dataset name>``.  When the dataset
    provides ground truth (``dataroot_GT`` is not ``None``), PSNR and SSIM
    are computed on the full image and, for 3-channel images, additionally
    on the Y channel.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options YAML file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)

    # Create every configured output directory except the experiment root,
    # pretrained-model paths and resume-state paths.
    util.mkdirs((path for key, path in opt['path'].items()
                 if key != 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))
    util.setup_logger('base',
                      opt['path']['log'],
                      'test_' + opt['name'],
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    #### Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    model = create_model(opt)
    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        # Determined once per dataset, before the batch loop, so the summary
        # below never reads an unbound or stale value when a loader is empty.
        need_GT = test_loader.dataset.opt['dataroot_GT'] is not None

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            model.feed_data(data, need_GT=need_GT)
            img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
            img_name = osp.splitext(osp.basename(img_path))[0]

            model.test()
            visuals = model.get_current_visuals(need_GT=need_GT)

            sr_img = util.tensor2img(visuals['rlt'])  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = osp.join(dataset_dir,
                                         img_name + suffix + '.png')
            else:
                save_img_path = osp.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            if not need_GT:
                logger.info(img_name)
                continue

            # calculate PSNR and SSIM
            gt_img = util.tensor2img(visuals['GT'])
            sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
            psnr = util.calculate_psnr(sr_img, gt_img)
            ssim = util.calculate_ssim(sr_img, gt_img)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)

            if gt_img.shape[2] == 3:  # RGB image
                sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
                gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)

                psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
                ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
                logger.info(
                    '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                    .format(img_name, psnr, ssim, psnr_y, ssim_y))
            else:
                logger.info(
                    '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                        img_name, psnr, ssim))

        # metrics -- skipped when nothing was measured, to avoid dividing by
        # zero on an empty dataset
        if need_GT and test_results['psnr']:
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info(
                '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'
                .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info(
                    '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'
                    .format(ave_psnr_y, ave_ssim_y))
Example #17
0
def main():
    """Evaluate a hard-coded EDVR checkpoint on a hard-coded test set.

    All settings (GPU id, test set, checkpoint path, output folder) are
    constants at the top of the function.  For every clip folder the frames
    are restored with a sliding window of ``N_in`` frames, saved as PNG when
    ``save_imgs`` is set, and compared with the ground truth by PSNR.
    Averages are logged for all / center / border frames, where border
    frames are the first and last ``N_in // 2`` frames of a clip.

    Raises:
        ValueError: if ``test_set`` names a dataset with no configured
            folders.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '6'
    test_set = 'AI4K_val'  # Vid4 | YouKu10 | REDS4 | AI4K_val | zhibo | AI4K_val_bic
    test_name = 'PCD_Vis_Test_35_ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW_A01xxx_900000_AI4K_5000'  #     'AI4K_val_Denoise_A02_420000'
    data_mode = 'sharp_bicubic'  # sharp_bicubic | blur_bicubic
    N_in = 5

    # load test set: name -> (LQ folder, GT folder)
    dataset_folders = {
        'Vid4': ('../datasets/Vid4/BIx4', '../datasets/Vid4/GT'),
        'YouKu10': ('../datasets/YouKu10/LR', '../datasets/YouKu10/HR'),
        'YouKu_val': ('/data0/yhliu/DATA/YouKuVid/valid/valid_lr_bmp',
                      '/data0/yhliu/DATA/YouKuVid/valid/valid_hr_bmp'),
        'REDS4': ('../datasets/REDS4/{}'.format(data_mode),
                  '../datasets/REDS4/GT'),
        'AI4K_val': ('/home/yhliu/AI4K/contest2/val2_LR_png/',
                     '/home/yhliu/AI4K/contest1/val1_HR_png/'),
        'AI4K_val_bic': ('/home/yhliu/AI4K/contest1/val1_LR_png_bic/',
                         '/home/yhliu/AI4K/contest1/val1_HR_png_bic/'),
        'zhibo': ('/data1/yhliu/SR_ZHIBO_VIDEO/Test_video_LR/',
                  '/data1/yhliu/SR_ZHIBO_VIDEO/Test_video_HR/'),
    }
    if test_set not in dataset_folders:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError('Unknown test_set: {}'.format(test_set))
    test_dataset_folder, GT_dataset_folder = dataset_folders[test_set]

    flip_test = False

    #model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
    #model_path = '../experiments/A01b/models/250000_G.pth'
    #model_path = '../experiments/A02_predenoise/models/415000_G.pth'
    model_path = '../experiments/A37_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd/models/5000_G.pth'

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True

    model = EDVR_arch.EDVR(64,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)
    #model = my_EDVR_arch.MYEDVR(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
    #model = my_EDVR_arch.MYEDVR_RES(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    # NOTE(review): 'Vid4' is a test_set value, not a data_mode value listed
    # above, so the first comparison looks unreachable -- confirm intent.
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True  #True | False

    save_folder = '../results/{}'.format(test_name)
    if test_set == 'zhibo':
        save_folder = '/data1/yhliu/SR_ZHIBO_VIDEO/SR_png_sample_150'
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    def log_settings():
        """Log the run settings (emitted at start-up and in the summary)."""
        logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))

    #### log info
    log_settings()

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    model = nn.DataParallel(model)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    print(subfolder_l)
    print(subfolder_GT_l)

    # for each subfolder (LQ and GT clips are paired by sorted position)
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        print(img_path_l)
        max_idx = len(img_path_l)
        if max_idx == 0:
            # An empty entry would otherwise end in a division by zero.
            logger.warning(
                'Folder {} contains no frames; skipped.'.format(
                    subfolder_name))
            continue
        subfolder_name_l.append(subfolder_name)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = [
            data_util.read_img(None, img_GT_path)
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))
        ]
        avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # Indices of the N_in neighbouring frames fed to the network.
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            # Input stays on the CPU; presumably nn.DataParallel moves it to
            # the GPU(s) -- TODO confirm.
            imgs_in = imgs_LQ.index_select(
                0,
                torch.LongTensor(select_idx)).unsqueeze(0).cpu()  #to(device)
            print(imgs_in.size())

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                start_time = time.time()
                output = util.single_forward(model, imgs_in)
                end_time = time.time()
                print('Forward One image:', end_time - start_time)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR (on all channels; the Y-channel variant is
            # disabled in this script)
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(
                img_idx + 1, img_name, crt_psnr))

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        n_frames = N_center + N_border
        # Clips shorter than 2 * border_frame + 1 have no center frames, so
        # N_center is guarded the same way as N_border.
        avg_psnr = (avg_psnr_center + avg_psnr_border) / n_frames
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, n_frames,
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    log_settings()
    n_clips = len(avg_psnr_l)  # clips actually evaluated
    if n_clips == 0:
        logger.warning('No clips were evaluated; no averages to report.')
    else:
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnr_l) / n_clips, n_clips,
                        sum(avg_psnr_center_l) / n_clips,
                        sum(avg_psnr_border_l) / n_clips))
Example #18
0
def val(model_name, current_step, arch='EDVR'):
    """Evaluate a saved generator checkpoint and return its average PSNR.

    Loads ``../experiments/<model_name>/models/<current_step>_G.pth`` into the
    selected network, runs it frame by frame over every clip of the
    hard-coded test set, and logs per-clip and overall PSNR split into
    all / center / border frames.

    Args:
        model_name: Experiment folder name under ``../experiments/``.
        current_step: Iteration tag of the checkpoint file to load.
        arch: Network architecture, ``'EDVR'`` or ``'MY_EDVR'``.

    Returns:
        The mean over clips of the per-clip average PSNR.

    Raises:
        ValueError: If the configured test set or ``arch`` is not recognized.
        FileNotFoundError: If no clip folders are found for the test set.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    test_set = 'REDS4'  # one of the keys of dataset_folders below
    data_mode = 'sharp_bicubic'  # sharp_bicubic | blur_bicubic
    N_in = 5

    # (low-quality folder, ground-truth folder) for every supported test set
    dataset_folders = {
        'Vid4': ('../datasets/Vid4/BIx4', '../datasets/Vid4/GT'),
        'YouKu10': ('../datasets/YouKu10/LR', '../datasets/YouKu10/HR'),
        'YouKu_val': ('/data0/yhliu/DATA/YouKuVid/valid/valid_lr_bmp',
                      '/data0/yhliu/DATA/YouKuVid/valid/valid_hr_bmp'),
        'REDS4': ('../datasets/REDS4/{}'.format(data_mode),
                  '../datasets/REDS4/GT'),
        'AI4K_val': ('/data0/yhliu/AI4K/contest1/val1_LR_png/',
                     '/data0/yhliu/AI4K/contest1/val1_HR_png/'),
        'AI4K_val_small': ('/home/yhliu/AI4K/contest1/val1_LR_png_small/',
                           '/home/yhliu/AI4K/contest1/val1_HR_png_small/'),
    }
    if test_set not in dataset_folders:
        raise ValueError('Test set [{}] is not recognized.'.format(test_set))
    test_dataset_folder, GT_dataset_folder = dataset_folders[test_set]

    flip_test = False

    model_path = os.path.join('../experiments/', model_name,
                              'models/{}_G.pth'.format(current_step))

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True

    # Fail early on an unknown architecture instead of hitting an unbound
    # ``model`` further down.
    if arch == 'EDVR':
        net_cls = EDVR_arch.EDVR
    elif arch == 'MY_EDVR':
        net_cls = my_EDVR_arch.MYEDVR
    else:
        raise ValueError('Architecture [{}] is not recognized.'.format(arch))
    model = net_cls(64,
                    N_in,
                    8,
                    5,
                    back_RBs,
                    predeblur=predeblur,
                    HR_in=HR_in)

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode; 'Vid4' is a test-set name here, not a data mode
    if test_set == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = False

    save_folder = '../validation/{}'.format(test_set)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    # Use up to four GPUs, but never more than the host actually exposes.
    device_ids = list(range(min(4, torch.cuda.device_count())))
    model = nn.DataParallel(model, device_ids=device_ids)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    if not subfolder_l or not subfolder_GT_l:
        # Without this check the function only fails at the very end with a
        # bare ZeroDivisionError when averaging over zero clips.
        raise FileNotFoundError(
            'No clips found under [{}] / [{}].'.format(test_dataset_folder,
                                                       GT_dataset_folder))

    # for each subfolder (LQ and GT clips are paired by sorted order)
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = [
            data_util.read_img(None, img_GT_path)
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))
        ]
        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR on all channels (no Y-channel conversion is
            # applied for any test set)
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        # Clips with no more than 2 * border_frame frames have no center
        # frames; guard the division the same way as for the border average.
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))

    total_avg_psnr = sum(avg_psnr_l) / len(avg_psnr_l)
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    total_avg_psnr, len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))

    return total_avg_psnr
Example #19
0
def main():
    """Train a model as configured by an option YAML file.

    Parses the command line, optionally initializes distributed training,
    loads a resume state if one is configured, sets up the file /
    TensorBoard / wandb loggers, seeds the RNGs, builds the train and val
    dataloaders and the model, then runs the training loop with periodic
    logging, validation and checkpointing. The final model is saved as
    'latest' when training ends.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        # NOTE: world_size is only bound in this branch; every later use is
        # behind an opt['dist'] check.
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists
            # create every configured path except the experiment root,
            # pretrained-model paths, resume paths and the wandb run path
            util.mkdirs(
                (path for key, path in opt['path'].items() if
                 not key == 'experiments_root' and 'pretrain_model' not in key
                 and 'resume' not in key and 'wandb_load_run_path' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            # NOTE(review): only the first three characters of the version
            # string are parsed, so e.g. '1.10' is read as 1.1.
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
        # wandb logger
        # NOTE(review): opt is still a plain dict here (it is converted to a
        # NoneDict further down), so a missing 'use_wandb_logger' key raises
        # KeyError at this point.
        if opt['use_wandb_logger'] and 'debug' not in opt['name']:
            json_path = os.path.join(os.path.expanduser('~'),
                                     '.wandb_api_keys.json')
            if os.path.exists(json_path):
                # the API key is read from the hard-coded 'ryul99' entry
                with open(json_path, 'r') as j:
                    json_file = json.loads(j.read())
                    os.environ['WANDB_API_KEY'] = json_file['ryul99']
            wandb.init(project="mmsr", config=opt, sync_tensorboard=True)
    else:
        # other distributed ranks: set up the logger without tofile
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
        if opt['use_wandb_logger'] and 'debug' not in opt['name']:
            wandb.config.update({'random_seed': seed})
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            # iterations per pass over the training set
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                # in distributed mode the epoch count is scaled down by
                # dataset_ratio
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    # NOTE(review): train_loader is only bound in the 'train' branch above, so
    # without a train phase this line raises NameError, not AssertionError.
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            # NOTE: this break only leaves the inner dataloader loop; the
            # epoch loop keeps iterating until its range is exhausted.
            if current_step > total_iters:
                break
            #### update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data,
                            noise_mode=opt['datasets']['train']['noise_mode'],
                            noise_rate=opt['datasets']['train']['noise_rate'])
            model.optimize_parameters(current_step)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                    # wandb logger
                    if opt['use_wandb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            wandb.log({k: v}, step=current_step)
                if rank <= 0:
                    logger.info(message)
            #### validation
            if opt['datasets'].get(
                    'val',
                    None) and current_step % opt['train']['val_freq'] == 0:
                # NOTE(review): for 'sr'/'srgan' models on ranks > 0 this
                # condition is False, so those ranks fall through to the
                # video-validation branch below - confirm this is intended.
                if opt['model'] in [
                        'sr', 'srgan'
                ] and rank <= 0:  # image restoration validation
                    # does not support multi-GPU validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(
                            os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'],
                                               img_name)
                        util.mkdir(img_dir)

                        model.feed_data(
                            val_data,
                            noise_mode=opt['datasets']['val']['noise_mode'],
                            noise_rate=opt['datasets']['val']['noise_rate'])
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # Save SR images for reference
                        save_img_path = os.path.join(
                            img_dir,
                            '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                          opt['scale'])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        pbar.update('Test {}'.format(img_name))

                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)
                    # wandb logger
                    if opt['use_wandb_logger'] and 'debug' not in opt['name']:
                        wandb.log({'psnr': avg_psnr}, step=current_step)
                else:  # video restoration validation
                    if opt['dist']:
                        # multi-GPU testing
                        psnr_rlt = {}  # with border and center frames
                        if rank == 0:
                            pbar = util.ProgressBar(len(val_set))
                        # each rank takes every world_size-th sample,
                        # starting at its own rank
                        for idx in range(rank, len(val_set), world_size):
                            val_data = val_set[idx]
                            # add a leading batch dimension in place
                            val_data['LQs'].unsqueeze_(0)
                            val_data['GT'].unsqueeze_(0)
                            folder = val_data['folder']
                            # 'idx' is a string of the form '<index>/<count>'
                            idx_d, max_idx = val_data['idx'].split('/')
                            idx_d, max_idx = int(idx_d), int(max_idx)
                            if psnr_rlt.get(folder, None) is None:
                                # one PSNR slot per frame of this folder
                                psnr_rlt[folder] = torch.zeros(
                                    max_idx,
                                    dtype=torch.float32,
                                    device='cuda')
                            # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
                            model.feed_data(val_data,
                                            noise_mode=opt['datasets']['val']
                                            ['noise_mode'],
                                            noise_rate=opt['datasets']['val']
                                            ['noise_rate'])
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8
                            # calculate PSNR
                            psnr_rlt[folder][idx_d] = util.calculate_psnr(
                                rlt_img, gt_img)

                            if rank == 0:
                                # presumably advances the bar once per rank to
                                # account for frames handled by other ranks
                                for _ in range(world_size):
                                    pbar.update('Test {} - {}/{}'.format(
                                        folder, idx_d, max_idx))
                        # collect per-folder results on rank 0 (assumes `dist`
                        # is torch.distributed - confirm against file imports)
                        for _, v in psnr_rlt.items():
                            dist.reduce(v, 0)
                        dist.barrier()

                        if rank == 0:
                            psnr_rlt_avg = {}
                            psnr_total_avg = 0.
                            for k, v in psnr_rlt.items():
                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                                psnr_total_avg += psnr_rlt_avg[k]
                            # mean of the per-folder means
                            psnr_total_avg /= len(psnr_rlt)
                            log_s = '# Validation # PSNR: {:.4e}:'.format(
                                psnr_total_avg)
                            for k, v in psnr_rlt_avg.items():
                                log_s += ' {}: {:.4e}'.format(k, v)
                            logger.info(log_s)
                            if opt['use_tb_logger'] and 'debug' not in opt[
                                    'name']:
                                tb_logger.add_scalar('psnr_avg',
                                                     psnr_total_avg,
                                                     current_step)
                                for k, v in psnr_rlt_avg.items():
                                    tb_logger.add_scalar(k, v, current_step)
                            if opt['use_wandb_logger'] and 'debug' not in opt[
                                    'name']:
                                # visuals holds the last sample processed by
                                # this rank in the loop above
                                lq_img, rlt_img, gt_img = map(
                                    util.tensor2img, [
                                        visuals['LQ'], visuals['rlt'],
                                        visuals['GT']
                                    ])
                                wandb.log({'psnr_avg': psnr_total_avg},
                                          step=current_step)
                                wandb.log(psnr_rlt_avg, step=current_step)
                                # the last-axis channel order is reversed
                                # before upload (presumably BGR -> RGB)
                                wandb.log(
                                    {
                                        'Validation Image': [
                                            wandb.Image(lq_img[:, :,
                                                               [2, 1, 0]],
                                                        caption='LQ'),
                                            wandb.Image(rlt_img[:, :,
                                                                [2, 1, 0]],
                                                        caption='output'),
                                            wandb.Image(gt_img[:, :,
                                                               [2, 1, 0]],
                                                        caption='GT'),
                                        ]
                                    },
                                    step=current_step)
                    else:
                        # single-process video validation through val_loader
                        pbar = util.ProgressBar(len(val_loader))
                        psnr_rlt = {}  # with border and center frames
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for val_data in val_loader:
                            folder = val_data['folder'][0]
                            idx_d = val_data['idx'].item()
                            # border = val_data['border'].item()
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = []

                            model.feed_data(val_data,
                                            noise_mode=opt['datasets']['val']
                                            ['noise_mode'],
                                            noise_rate=opt['datasets']['val']
                                            ['noise_rate'])
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8

                            # calculate PSNR
                            psnr = util.calculate_psnr(rlt_img, gt_img)
                            psnr_rlt[folder].append(psnr)
                            pbar.update('Test {} - {}'.format(folder, idx_d))
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = sum(v) / len(v)
                            psnr_total_avg += psnr_rlt_avg[k]
                        # mean of the per-folder means
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)
                        if opt['use_wandb_logger'] and 'debug' not in opt[
                                'name']:
                            # visuals holds the last validation sample
                            lq_img, rlt_img, gt_img = map(
                                util.tensor2img,
                                [visuals['LQ'], visuals['rlt'], visuals['GT']])
                            wandb.log({'psnr_avg': psnr_total_avg},
                                      step=current_step)
                            wandb.log(psnr_rlt_avg, step=current_step)
                            # the last-axis channel order is reversed before
                            # upload (presumably BGR -> RGB)
                            wandb.log(
                                {
                                    'Validation Image': [
                                        wandb.Image(lq_img[:, :, [2, 1, 0]],
                                                    caption='LQ'),
                                        wandb.Image(rlt_img[:, :, [2, 1, 0]],
                                                    caption='output'),
                                        wandb.Image(gt_img[:, :, [2, 1, 0]],
                                                    caption='GT'),
                                    ]
                                },
                                step=current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            tb_logger.close()
Example #20
0
def main():

    ############################################
    #
    #           set options
    #
    ############################################

    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    ############################################
    #
    #           distributed training settings
    #
    ############################################

    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        print("Rank:", rank)
        print("------------------DIST-------------------------")

    ############################################
    #
    #           loading resume state if exists
    #
    ############################################

    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    ############################################
    #
    #           mkdir and loggers
    #
    ############################################

    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists

            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # configure the loggers; logging does not work before this point
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)

        util.setup_logger('base_val',
                          opt['path']['log'],
                          'val_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)

        logger = logging.getLogger('base')
        logger_val = logging.getLogger('base_val')

        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        # configure the loggers; logging does not work before this point
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_',
                          level=logging.INFO,
                          screen=True)

        print("set train log")

        util.setup_logger('base_val',
                          opt['path']['log'],
                          'val_',
                          level=logging.INFO,
                          screen=True)

        print("set val log")

        logger = logging.getLogger('base')

        logger_val = logging.getLogger('base_val')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    ############################################
    #
    #           create train and val dataloader
    #
    ############################################
    ####

    # dataset_ratio = 200  # enlarge the size of each epoch, todo: what it is
    dataset_ratio = 1  # enlarge the size of each epoch, todo: what it is
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            # total_iters = int(opt['train']['niter'])
            # total_epochs = int(math.ceil(total_iters / train_size))

            total_iters = train_size
            total_epochs = int(opt['train']['epoch'])

            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                # total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
                total_epochs = int(opt['train']['epoch'])
                if opt['train']['enable'] == False:
                    total_epochs = 1
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))

    assert train_loader is not None

    ############################################
    #
    #          create model
    #
    ############################################
    ####

    model = create_model(opt)

    print("Model Created! ")

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0
        print("Not Resume Training")

    ############################################
    #
    #          training
    #
    ############################################
    ####

    ####
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    Avg_train_loss = AverageMeter()  # total
    if (opt['train']['pixel_criterion'] == 'cb+ssim'):
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_ssim = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_ssim = AverageMeter()
        Avg_train_loss_vmaf = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'ssim'):
        Avg_train_loss_ssim = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'msssim'):
        Avg_train_loss_msssim = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_msssim = AverageMeter()

    saved_total_loss = 10e10
    saved_total_PSNR = -1

    for epoch in range(start_epoch, total_epochs):

        ############################################
        #
        #          Start a new epoch
        #
        ############################################

        # Turn into training mode
        #model = model.train()

        # reset total loss
        Avg_train_loss.reset()
        current_step = 0

        if (opt['train']['pixel_criterion'] == 'cb+ssim'):
            Avg_train_loss_pix.reset()
            Avg_train_loss_ssim.reset()
        elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
            Avg_train_loss_pix.reset()
            Avg_train_loss_ssim.reset()
            Avg_train_loss_vmaf.reset()
        elif (opt['train']['pixel_criterion'] == 'ssim'):
            Avg_train_loss_ssim = AverageMeter()
        elif (opt['train']['pixel_criterion'] == 'msssim'):
            Avg_train_loss_msssim = AverageMeter()
        elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
            Avg_train_loss_pix = AverageMeter()
            Avg_train_loss_msssim = AverageMeter()

        if opt['dist']:
            train_sampler.set_epoch(epoch)

        for train_idx, train_data in enumerate(train_loader):

            if 'debug' in opt['name']:

                img_dir = os.path.join(opt['path']['train_images'])
                util.mkdir(img_dir)

                LQ = train_data['LQs']
                GT = train_data['GT']

                GT_img = util.tensor2img(GT)  # uint8

                save_img_path = os.path.join(
                    img_dir, '{:4d}_{:s}.png'.format(train_idx, 'debug_GT'))
                util.save_img(GT_img, save_img_path)

                for i in range(5):
                    LQ_img = util.tensor2img(LQ[0, i, ...])  # uint8
                    save_img_path = os.path.join(
                        img_dir,
                        '{:4d}_{:s}_{:1d}.png'.format(train_idx, 'debug_LQ',
                                                      i))
                    util.save_img(LQ_img, save_img_path)

                if (train_idx >= 3):
                    break

            if opt['train']['enable'] == False:
                message_train_loss = 'None'
                break

            current_step += 1
            if current_step > total_iters:
                print("Total Iteration Reached !")
                break
            #### update learning rate
            if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
                pass
            else:
                model.update_learning_rate(
                    current_step, warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)

            # if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            #    model.optimize_parameters_without_schudlue(current_step)
            # else:
            model.optimize_parameters(current_step)

            if (opt['train']['pixel_criterion'] == 'cb+ssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
                Avg_train_loss_vmaf.update(model.log_dict['vmaf_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'ssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'msssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_msssim.update(model.log_dict['msssim_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_msssim.update(model.log_dict['msssim_loss'], 1)
            else:
                Avg_train_loss.update(model.log_dict['l_pix'], 1)

            # add total train loss
            if (opt['train']['pixel_criterion'] == 'cb+ssim'):
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(
                    Avg_train_loss_pix.avg)
                message_train_loss += ' ssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_ssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(
                    Avg_train_loss_pix.avg)
                message_train_loss += ' ssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_ssim.avg)
                message_train_loss += ' vmaf_avg_loss: {:.4e}'.format(
                    Avg_train_loss_vmaf.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'ssim'):
                message_train_loss = ' ssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_ssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'msssim'):
                message_train_loss = ' msssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_msssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(
                    Avg_train_loss_pix.avg)
                message_train_loss += ' msssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_msssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            else:
                message_train_loss = ' train_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)

                message += message_train_loss

                if rank <= 0:
                    logger.info(message)

        ############################################
        #
        #        end of one epoch, save epoch model
        #
        ############################################

        #### save models and training states
        # if current_step % opt['logger']['save_checkpoint_freq'] == 0:
        #     if rank <= 0:
        #         logger.info('Saving models and training states.')
        #         model.save(current_step)
        #         model.save('latest')
        #         # model.save_training_state(epoch, current_step)
        #         # todo delete previous weights
        #         previous_step = current_step - opt['logger']['save_checkpoint_freq']
        #         save_filename = '{}_{}.pth'.format(previous_step, 'G')
        #         save_path = os.path.join(opt['path']['models'], save_filename)
        #         if os.path.exists(save_path):
        #             os.remove(save_path)

        if epoch == 1:
            save_filename = '{:04d}_{}.pth'.format(0, 'G')
            save_path = os.path.join(opt['path']['models'], save_filename)
            if os.path.exists(save_path):
                os.remove(save_path)

        save_filename = '{:04d}_{}.pth'.format(epoch - 1, 'G')
        save_path = os.path.join(opt['path']['models'], save_filename)
        if os.path.exists(save_path):
            os.remove(save_path)

        if rank <= 0:
            logger.info('Saving models and training states.')
            save_filename = '{:04d}'.format(epoch)
            model.save(save_filename)
            # model.save('latest')
            # model.save_training_state(epoch, current_step)

        ############################################
        #
        #          end of one epoch, do validation
        #
        ############################################

        #### validation
        #if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
        if opt['datasets'].get('val', None):
            if opt['model'] in [
                    'sr', 'srgan'
            ] and rank <= 0:  # image restoration validation
                # does not support multi-GPU validation
                pbar = util.ProgressBar(len(val_loader))
                avg_psnr = 0.
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()
                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['rlt'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}.png'.format(img_name, current_step))
                    #util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                      opt['scale'])
                    avg_psnr += util.calculate_psnr(sr_img, gt_img)
                    pbar.update('Test {}'.format(img_name))

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
            else:  # video restoration validation
                if opt['dist']:
                    # todo : multi-GPU testing
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.

                    ssim_rlt = {}  # with border and center frames
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.

                    val_loss_rlt = {}
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.

                    if rank == 0:
                        pbar = util.ProgressBar(len(val_set))

                    for idx in range(rank, len(val_set), world_size):

                        print('idx', idx)

                        if 'debug' in opt['name']:
                            if (idx >= 3):
                                break

                        val_data = val_set[idx]
                        val_data['LQs'].unsqueeze_(0)
                        val_data['GT'].unsqueeze_(0)
                        folder = val_data['folder']
                        idx_d, max_idx = val_data['idx'].split('/')
                        idx_d, max_idx = int(idx_d), int(max_idx)

                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = torch.zeros(max_idx,
                                                           dtype=torch.float32,
                                                           device='cuda')

                        if ssim_rlt.get(folder, None) is None:
                            ssim_rlt[folder] = torch.zeros(max_idx,
                                                           dtype=torch.float32,
                                                           device='cuda')

                        if val_loss_rlt.get(folder, None) is None:
                            val_loss_rlt[folder] = torch.zeros(
                                max_idx, dtype=torch.float32, device='cuda')

                        # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
                        model.feed_data(val_data)
                        # model.test()
                        # model.test_stitch()

                        if opt['stitch'] == True:
                            model.test_stitch()
                        else:
                            model.test()  # large GPU memory

                        # visuals = model.get_current_visuals()
                        visuals = model.get_current_visuals(
                            save=True,
                            name='{}_{}'.format(folder, idx),
                            save_path=opt['path']['val_images'])

                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder][idx_d] = psnr

                        # calculate SSIM
                        ssim = util.calculate_ssim(rlt_img, gt_img)
                        ssim_rlt[folder][idx_d] = ssim

                        # calculate Val loss
                        val_loss = model.get_loss()
                        val_loss_rlt[folder][idx_d] = val_loss

                        logger.info(
                            '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format(
                                folder, idx, psnr, ssim))

                        if rank == 0:
                            for _ in range(world_size):
                                pbar.update('Test {} - {}/{}'.format(
                                    folder, idx_d, max_idx))

                    # # collect data
                    for _, v in psnr_rlt.items():
                        dist.reduce(v, 0)

                    for _, v in ssim_rlt.items():
                        dist.reduce(v, 0)

                    for _, v in val_loss_rlt.items():
                        dist.reduce(v, 0)

                    dist.barrier()

                    if rank == 0:
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)

                        # ssim
                        ssim_rlt_avg = {}
                        ssim_total_avg = 0.
                        for k, v in ssim_rlt.items():
                            ssim_rlt_avg[k] = torch.mean(v).cpu().item()
                            ssim_total_avg += ssim_rlt_avg[k]
                        ssim_total_avg /= len(ssim_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            ssim_total_avg)
                        for k, v in ssim_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)

                        # added
                        val_loss_rlt_avg = {}
                        val_loss_total_avg = 0.
                        for k, v in val_loss_rlt.items():
                            val_loss_rlt_avg[k] = torch.mean(v).cpu().item()
                            val_loss_total_avg += val_loss_rlt_avg[k]
                        val_loss_total_avg /= len(val_loss_rlt)
                        log_l = '# Validation # Loss: {:.4e}:'.format(
                            val_loss_total_avg)
                        for k, v in val_loss_rlt_avg.items():
                            log_l += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_l)

                        message = ''
                        for v in model.get_current_learning_rate():
                            message += '{:.5e}'.format(v)

                        logger_val.info(
                            'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}'
                            .format(epoch, message, psnr_total_avg,
                                    ssim_total_avg, message_train_loss,
                                    val_loss_total_avg))

                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)
                            # add val loss
                            tb_logger.add_scalar('val_loss_avg',
                                                 val_loss_total_avg,
                                                 current_step)
                            for k, v in val_loss_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)

                else:  # Todo: our function One GPU
                    pbar = util.ProgressBar(len(val_loader))
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.

                    ssim_rlt = {}  # with border and center frames
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.

                    val_loss_rlt = {}
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.

                    for val_inx, val_data in enumerate(val_loader):

                        if 'debug' in opt['name']:
                            if (val_inx >= 5):
                                break

                        folder = val_data['folder'][0]
                        # idx_d = val_data['idx'].item()
                        idx_d = val_data['idx']
                        # border = val_data['border'].item()
                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = []

                        if ssim_rlt.get(folder, None) is None:
                            ssim_rlt[folder] = []

                        if val_loss_rlt.get(folder, None) is None:
                            val_loss_rlt[folder] = []

                        # find the black (zero-sum) blank region of the LQ input [B N C H W]

                        print(val_data['LQs'].size())

                        H_S = val_data['LQs'].size(3)  # 540
                        W_S = val_data['LQs'].size(4)  # 960

                        print(H_S)
                        print(W_S)

                        blank_1_S = 0
                        blank_2_S = 0

                        print(val_data['LQs'][0, 2, 0, :, :].size())

                        for i in range(H_S):
                            if not sum(val_data['LQs'][0, 2, 0, i, :]) == 0:
                                blank_1_S = i - 1
                                # assert not sum(data_S[:, :, 0][i+1]) == 0
                                break

                        for i in range(H_S):
                            if not sum(val_data['LQs'][0, 2, 0, :,
                                                       H_S - i - 1]) == 0:
                                blank_2_S = (H_S - 1) - i - 1
                                # assert not sum(data_S[:, :, 0][blank_2_S-1]) == 0
                                break
                        print('LQ :', blank_1_S, blank_2_S)

                        if blank_1_S == -1:
                            print('LQ has no blank')
                            blank_1_S = 0
                            blank_2_S = H_S

                        # val_data['LQs'] = val_data['LQs'][:,:,:,blank_1_S:blank_2_S,:]

                        print("LQ", val_data['LQs'].size())

                        # end of black-blank detection

                        model.feed_data(val_data)

                        if opt['stitch'] == True:
                            model.test_stitch()
                        else:
                            model.test()  # large GPU memory

                        # process blank

                        blank_1_L = blank_1_S << 2
                        blank_2_L = blank_2_S << 2
                        print(blank_1_L, blank_2_L)

                        print(model.fake_H.size())

                        if not blank_1_S == 0:
                            # model.fake_H = model.fake_H[:,:,blank_1_L:blank_2_L,:]
                            model.fake_H[:, :, 0:blank_1_L, :] = 0
                            model.fake_H[:, :, blank_2_L:H_S, :] = 0

                        # end of blank processing

                        visuals = model.get_current_visuals(
                            save=True,
                            name='{}_{:02d}'.format(folder, val_inx),
                            save_path=opt['path']['val_images'])

                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder].append(psnr)

                        # calculate SSIM
                        ssim = util.calculate_ssim(rlt_img, gt_img)
                        ssim_rlt[folder].append(ssim)

                        # val loss
                        val_loss = model.get_loss()
                        val_loss_rlt[folder].append(val_loss.item())

                        logger.info(
                            '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format(
                                folder, val_inx, psnr, ssim))

                        pbar.update('Test {} - {}'.format(folder, idx_d))

                    # average PSNR
                    for k, v in psnr_rlt.items():
                        psnr_rlt_avg[k] = sum(v) / len(v)
                        psnr_total_avg += psnr_rlt_avg[k]
                    psnr_total_avg /= len(psnr_rlt)
                    log_s = '# Validation # PSNR: {:.4e}:'.format(
                        psnr_total_avg)
                    for k, v in psnr_rlt_avg.items():
                        log_s += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_s)

                    # average SSIM
                    for k, v in ssim_rlt.items():
                        ssim_rlt_avg[k] = sum(v) / len(v)
                        ssim_total_avg += ssim_rlt_avg[k]
                    ssim_total_avg /= len(ssim_rlt)
                    log_s = '# Validation # SSIM: {:.4e}:'.format(
                        ssim_total_avg)
                    for k, v in ssim_rlt_avg.items():
                        log_s += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_s)

                    # average VMAF

                    # average Val LOSS
                    for k, v in val_loss_rlt.items():
                        val_loss_rlt_avg[k] = sum(v) / len(v)
                        val_loss_total_avg += val_loss_rlt_avg[k]
                    val_loss_total_avg /= len(val_loss_rlt)
                    log_l = '# Validation # Loss: {:.4e}:'.format(
                        val_loss_total_avg)
                    for k, v in val_loss_rlt_avg.items():
                        log_l += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_l)

                    # total validation log

                    message = ''
                    for v in model.get_current_learning_rate():
                        message += '{:.5e}'.format(v)

                    logger_val.info(
                        'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}'
                        .format(epoch, message, psnr_total_avg, ssim_total_avg,
                                message_train_loss, val_loss_total_avg))

                    # end add

                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                             current_step)
                        for k, v in psnr_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)
                        # tb_logger.add_scalar('ssim_avg', ssim_total_avg, current_step)
                        # for k, v in ssim_rlt_avg.items():
                        #     tb_logger.add_scalar(k, v, current_step)
                        # add val loss
                        tb_logger.add_scalar('val_loss_avg',
                                             val_loss_total_avg, current_step)
                        for k, v in val_loss_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)

            ############################################
            #
            #          end of validation, save model
            #
            ############################################
            #
            logger.info("Finished an epoch, Check and Save the model weights")
            # we check the validation loss instead of training loss. OK~
            if saved_total_loss >= val_loss_total_avg:
                saved_total_loss = val_loss_total_avg
                #torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
                model.save('best')
                logger.info(
                    "Best Weights updated for decreased validation loss")

            else:
                logger.info(
                    "Weights Not updated for undecreased validation loss")

            if saved_total_PSNR <= psnr_total_avg:
                saved_total_PSNR = psnr_total_avg
                model.save('bestPSNR')
                logger.info(
                    "Best Weights updated for increased validation PSNR")

            else:
                logger.info(
                    "Weights Not updated for unincreased validation PSNR")

        ############################################
        #
        #          end of one epoch, schedule LR
        #
        ############################################

        # add scheduler  todo
        if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            for scheduler in model.schedulers:
                # scheduler.step(val_loss_total_avg)
                scheduler.step(val_loss_total_avg)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('last')
        logger.info('End of training.')
        tb_logger.close()
def main():
    """Evaluate an EDVR checkpoint on a set of clips and log PSNR statistics.

    Every setting (GPU id, data mode, stage, checkpoint and dataset paths) is
    hard-coded in the configuration section below. For each clip sub-folder
    under ``test_dataset_folder`` the function restores every frame from
    ``N_in`` neighbouring low-quality frames, optionally writes the result as
    a PNG below ``../results/<data_mode>/``, and compares it against the
    ground-truth frame with the same index. PSNR is accumulated separately
    for "center" frames and "border" frames (the first/last ``N_in // 2``
    frames of a clip) and reported per clip and over all clips.

    Raises:
        ValueError: if ``data_mode`` is ``'Vid4'`` and ``stage`` is not 1.
        NotImplementedError: if ``data_mode`` is not one of the known modes.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            #model_path = '../experiments/pretrained_models/EDVR_REDS_SR_M.pth'
            model_path = '../experiments/002_EDVR_lr4e-4_600k_AI4KHDR/models/4000_G.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    elif data_mode == 'sharp_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
    elif data_mode == 'blur_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
    elif data_mode == 'blur':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
    elif data_mode == 'blur_comp':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    else:
        raise NotImplementedError

    # Number of input frames used to restore one HR image (the same value is
    # used for every data mode in this script).
    N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20
    # Build the network from the values computed above instead of repeating
    # the literals 5 and 10: with literals, stage 2 would silently build a
    # 10-block network, and strict=False below would hide the mismatch.
    model = EDVR_arch.EDVR(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/valid/540p_frames'
        GT_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/valid/4k_frames'
        #test_dataset_folder = '../datasets/Vid4/BIx4'
        #GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    # NOTE(review): strict=False ignores missing/unexpected checkpoint keys,
    # so an architecture mismatch will not raise here -- confirm intended.
    model.load_state_dict(torch.load(model_path), strict=False)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder (LQ and GT folders are paired by sorted order)
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            #logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        # Guard every average the same way: a clip with at most
        # 2 * border_frame frames has no center frames, and an empty folder
        # has no frames at all; both used to raise ZeroDivisionError.
        n_frames = N_center + N_border
        avg_psnr = 0 if n_frames == 0 else (avg_psnr_center + avg_psnr_border) / n_frames
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, n_frames,
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    # Count the clips that were actually evaluated: zip() above stops at the
    # shorter of the LQ / GT folder lists, so len(subfolder_l) can overstate.
    n_clips = len(avg_psnr_l)
    if n_clips == 0:
        logger.warning('No clips were evaluated; check {} and {}.'.format(
            test_dataset_folder, GT_dataset_folder))
        return
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / n_clips, n_clips,
                    sum(avg_psnr_center_l) / n_clips,
                    sum(avg_psnr_border_l) / n_clips))
Example #22
0
def main():
    """Evaluate the pretrained Vimeo90K EDVR model on a user-supplied dataset.

    The dataset root is given with ``-i/--input`` and must contain a ``BIx4``
    folder (low-quality clips) and a ``GT`` folder (ground-truth clips). For
    every clip, each frame is restored from ``N_in`` neighbouring frames,
    optionally saved as a PNG below ``../results/<input folder name>/``, and
    compared with the ground-truth frame of the same index on the Y channel.
    PSNR is accumulated separately for "center" frames and "border" frames
    (the first/last ``N_in // 2`` frames of a clip) and logged per frame,
    per clip and over all clips.

    Raises:
        SystemExit: with status 1 if no ``--input`` argument was given.
        ValueError: if ``data_mode`` is ``'Vid4'`` and ``stage`` is not 1.
        NotImplementedError: if ``data_mode`` is not ``'Vid4'``.
    """

    # Create object for parsing command-line options
    parser = argparse.ArgumentParser(description="Test with EDVR, requre path to test dataset folder.")
    # Add argument which takes the path to the test folder as an input
    parser.add_argument("-i", "--input", type=str, help="Path to test folder")
    # Parse the command line arguments to an object
    args = parser.parse_args()
    # Safety if no parameter have been given
    if not args.input:
        print("No input paramater have been given.")
        print("For help type --help")
        # The bare exit() helper is injected by the `site` module and is not
        # guaranteed to exist in programs; raise SystemExit directly, with a
        # non-zero status so callers can detect the usage error.
        raise SystemExit(1)

    # Name of the last path component; normpath drops trailing separators
    # (including repeated ones) and '.' segments before taking the basename.
    folder_name = osp.basename(osp.normpath(args.input))

    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    else:
        raise NotImplementedError

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 40

    model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        # debug
        test_dataset_folder = os.path.join(args.input, 'BIx4')
        GT_dataset_folder = os.path.join(args.input, 'GT')
    else:
        # Unreachable while the model section above rejects every mode other
        # than 'Vid4'; kept for parity with the REDS configuration.
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(folder_name)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder (LQ and GT folders are paired by sorted order)
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        # Guard every average the same way: a clip with at most
        # 2 * border_frame frames has no center frames, and an empty folder
        # has no frames at all; both used to raise ZeroDivisionError.
        n_frames = N_center + N_border
        avg_psnr = 0 if n_frames == 0 else (avg_psnr_center + avg_psnr_border) / n_frames
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
                                                                   n_frames,
                                                                   avg_psnr_center, N_center,
                                                                   avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l,
                                                              avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
                                                     psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    # Count the clips that were actually evaluated: zip() above stops at the
    # shorter of the LQ / GT folder lists, so len(subfolder_l) can overstate.
    n_clips = len(avg_psnr_l)
    if n_clips == 0:
        logger.warning('No clips were evaluated; check {} and {}.'.format(
            test_dataset_folder, GT_dataset_folder))
        return
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / n_clips, n_clips,
                    sum(avg_psnr_center_l) / n_clips,
                    sum(avg_psnr_border_l) / n_clips))
Example #23
0
        visuals = model.get_current_visuals(need_GT=need_GT)

        sr_img = util.tensor2img(visuals['rlt'])  # uint8

        # save images
        suffix = opt['suffix']
        if suffix:
            save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
        else:
            save_img_path = osp.join(dataset_dir, img_name + '.png')
        util.save_img(sr_img, save_img_path)

        # calculate PSNR and SSIM
        if need_GT:
            gt_img = util.tensor2img(visuals['GT'])
            sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
            psnr = util.calculate_psnr(sr_img, gt_img)
            ssim = util.calculate_ssim(sr_img, gt_img)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)

            if gt_img.shape[2] == 3:  # RGB image
                sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
                gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)

                psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
                ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
                logger.info(
                    '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
Example #24
0
def main(opts):
    ################## configurations #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpus
    cache_all_imgs = opts.cache > 0
    n_gpus = len(opts.gpus.split(','))
    flip_test, save_imgs = False, False
    scale = 4
    N_in, nf = 5, 64
    back_RBs = 10
    w_TSA = False
    predeblur, HR_in = False, False
    crop_border = 0
    border_frame = N_in // 2
    padding = 'new_info'

    ################## model files ####################
    model_dir = opts.model_dir
    if osp.isfile(model_dir):
        model_names = [osp.basename(model_dir)]
        model_dir = osp.dirname(model_dir)
    elif osp.isdir(model_dir):
        model_names = [
            x for x in os.listdir(model_dir) if str.isdigit(x.split('_')[0])
        ]
        model_names = sorted(model_names, key=lambda x: int(x.split("_")[0]))
    else:
        raise IOError('Invalid model_dir: {}'.format(model_dir))

    ################## dataset ########################
    test_subs = sorted(os.listdir(opts.test_dir))
    gt_subs = os.listdir(opts.gt_dir)
    valid_test_subs = [sub in gt_subs for sub in test_subs]
    assert (all(valid_test_subs)), 'Invalid sub folders exists in {}'.format(
        opts.test_dir)
    scale = float(os.path.basename(os.path.dirname(opts.test_dir))[1:])
    if cache_all_imgs:
        print('Cacheing all testing images ...')
        all_imgs = {}
        for sub in test_subs:
            print('Reading sub-folder: {} ...'.format(sub))
            test_sub_dir = osp.join(opts.test_dir, sub)
            gt_sub_dir = osp.join(opts.gt_dir, sub)
            all_imgs[sub] = {'test': [], 'gt': []}
            im_names = sorted(os.listdir(test_sub_dir))
            for i, name in enumerate(im_names):
                test_im_path = osp.join(test_sub_dir, name)
                gt_im_path = osp.join(gt_sub_dir, name)
                test_im = cv2.imread(test_im_path,
                                     cv2.IMREAD_UNCHANGED)[:, :, (2, 1, 0)]
                test_im = test_im.astype(np.float32).transpose(
                    (2, 0, 1)) / 255.
                all_imgs[sub]['test'].append(test_im)
                gt_im = cv2.imread(gt_im_path,
                                   cv2.IMREAD_UNCHANGED).astype(np.float32)
                all_imgs[sub]['gt'].append(gt_im)

    all_psnrs = []
    for model_name in model_names:
        model_path = osp.join(model_dir, model_name)
        exp_name = model_name.split('_')[0]
        if 'meta' in opts.mode.lower():
            model = EDVR_arch.MetaEDVR(nf=nf,
                                       nframes=N_in,
                                       groups=8,
                                       front_RBs=5,
                                       center=None,
                                       back_RBs=back_RBs,
                                       predeblur=predeblur,
                                       HR_in=HR_in,
                                       w_TSA=w_TSA)
        elif opts.mode.lower() == 'edvr':
            model = EDVR_arch.EDVR(nf=nf,
                                   nframes=N_in,
                                   groups=8,
                                   front_RBs=5,
                                   center=None,
                                   back_RBs=back_RBs,
                                   predeblur=predeblur,
                                   HR_in=HR_in,
                                   w_TSA=w_TSA)
        elif opts.mode.lower() == 'upedvr':
            model = EDVR_arch.UPEDVR(nf=nf,
                                     nframes=N_in,
                                     groups=8,
                                     front_RBs=5,
                                     center=None,
                                     back_RBs=10,
                                     w_TSA=w_TSA,
                                     down_scale=True,
                                     align_target=True,
                                     ret_valid=True)
        elif opts.mode.lower() == 'upcont1':
            model = EDVR_arch.UPControlEDVR(nf=nf,
                                            nframes=N_in,
                                            groups=8,
                                            front_RBs=5,
                                            center=None,
                                            back_RBs=10,
                                            w_TSA=w_TSA,
                                            down_scale=True,
                                            align_target=True,
                                            ret_valid=True,
                                            multi_scale_cont=False)
        elif opts.mode.lower() == 'upcont3':
            model = EDVR_arch.UPControlEDVR(nf=nf,
                                            nframes=N_in,
                                            groups=8,
                                            front_RBs=5,
                                            center=None,
                                            back_RBs=10,
                                            w_TSA=w_TSA,
                                            down_scale=True,
                                            align_target=True,
                                            ret_valid=True,
                                            multi_scale_cont=True)
        elif opts.mode.lower() == 'upcont2':
            model = EDVR_arch.UPControlEDVR(nf=nf,
                                            nframes=N_in,
                                            groups=8,
                                            front_RBs=5,
                                            center=None,
                                            back_RBs=10,
                                            w_TSA=w_TSA,
                                            down_scale=True,
                                            align_target=True,
                                            ret_valid=True,
                                            multi_scale_cont=True)
        else:
            raise TypeError('Unknown model mode: {}'.format(opts.mode))
        save_folder = osp.join(opts.save_dir, exp_name)
        util.mkdirs(save_folder)
        util.setup_logger(exp_name,
                          save_folder,
                          'test',
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger(exp_name)

        #### log info
        logger.info('Data: {}'.format(opts.test_dir))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))

        #### set up the models
        model.load_state_dict(torch.load(model_path), strict=True)
        model.eval()
        if n_gpus > 1:
            model = nn.DataParallel(model)
        model = model.to(device)

        avg_psnrs, avg_psnr_centers, avg_psnr_borders = [], [], []
        avg_ssims, avg_ssim_centers, avg_ssim_borders = [], [], []
        evaled_subs = []

        # for each subfolder
        for sub in test_subs:
            evaled_subs.append(sub)
            test_sub_dir = osp.join(opts.test_dir, sub)
            gt_sub_dir = osp.join(opts.gt_dir, sub)
            img_names = sorted(os.listdir(test_sub_dir))
            max_idx = len(img_names)
            if save_imgs:
                save_subfolder = osp.join(save_folder, sub)
                util.mkdirs(save_subfolder)

            #### get LQ and GT images
            if not cache_all_imgs:
                img_LQs, img_GTs = [], []
                for i, name in enumerate(img_names):
                    test_im_path = osp.join(test_sub_dir, name)
                    gt_im_path = osp.join(gt_sub_dir, name)
                    test_im = cv2.imread(test_im_path,
                                         cv2.IMREAD_UNCHANGED)[:, :, (2, 1, 0)]
                    test_im = test_im.astype(np.float32).transpose(
                        (2, 0, 1)) / 255.
                    gt_im = cv2.imread(gt_im_path,
                                       cv2.IMREAD_UNCHANGED).astype(np.float32)
                    img_LQs.append(test_im)
                    img_GTs.append(gt_im)
            else:
                img_LQs = all_imgs[sub]['test']
                img_GTs = all_imgs[sub]['gt']

            avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0
            avg_ssim, avg_ssim_border, avg_ssim_center = 0, 0, 0

            # process each image
            for i in range(0, max_idx, n_gpus):
                end = min(i + n_gpus, max_idx)
                select_idxs = [
                    data_util.index_generation(j,
                                               max_idx,
                                               N_in,
                                               padding=padding)
                    for j in range(i, end)
                ]
                imgs = []
                for select_idx in select_idxs:
                    im = torch.from_numpy(
                        np.stack([img_LQs[k] for k in select_idx]))
                    imgs.append(im)
                if (i + n_gpus) > max_idx:
                    for _ in range(max_idx, i + n_gpus):
                        imgs.append(torch.zeros_like(im))
                imgs = torch.stack(imgs, 0).to(device)

                if flip_test:
                    output = util.flipx4_forward(model, imgs)
                else:
                    if 'meta' in opts.mode.lower():
                        output = util.meta_single_forward(
                            model, imgs, scale, n_gpus)
                    if 'up' in opts.mode.lower():
                        output = util.up_single_forward(model, imgs, scale)
                    else:
                        output = util.single_forward(model, imgs)
                output = [
                    util.tensor2img(x).astype(np.float32) for x in output
                ]

                if save_imgs:
                    for ii in range(i, end):
                        cv2.imwrite(
                            osp.join(save_subfolder,
                                     '{}.png'.format(img_names[ii])),
                            output[ii - i].astype(np.uint8))

                # calculate PSNR
                GT = np.copy(img_GTs[i:end])

                output = util.crop_border(output, crop_border)
                GT = util.crop_border(GT, crop_border)
                for m in range(i, end):
                    crt_psnr = util.calculate_psnr(output[m - i], GT[m - i])
                    crt_ssim = util.calculate_ssim(output[m - i], GT[m - i])
                    logger.info(
                        '{:3d} - {:25} \tPSNR: {:.6f} dB    SSIM: {:.6}'.
                        format(m + 1, img_names[m], crt_psnr, crt_ssim))

                    if m >= border_frame and m < max_idx - border_frame:  # center frames
                        avg_psnr_center += crt_psnr
                        avg_ssim_center += crt_ssim
                        N_center += 1
                    else:  # border frames
                        avg_psnr_border += crt_psnr
                        avg_ssim_border += crt_ssim
                        N_border += 1

            avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center +
                                                              N_border)
            avg_psnr_center = avg_psnr_center / N_center
            avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border

            avg_ssim = (avg_ssim_center + avg_ssim_border) / (N_center +
                                                              N_border)
            avg_ssim_center = avg_ssim_center / N_center
            avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border

            avg_psnrs.append(avg_psnr)
            avg_psnr_centers.append(avg_psnr_center)
            avg_psnr_borders.append(avg_psnr_border)

            avg_ssims.append(avg_ssim)
            avg_ssim_centers.append(avg_ssim_center)
            avg_ssim_borders.append(avg_ssim_border)

            logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                        'Center PSNR: {:.6f} dB for {} frames; '
                        'Border PSNR: {:.6f} dB for {} frames.'.format(
                            sub, avg_psnr, (N_center + N_border),
                            avg_psnr_center, N_center, avg_psnr_border,
                            N_border))
            logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; '
                        'Center SSIM: {:.6f} for {} frames; '
                        'Border SSIM: {:.6f} for {} frames.'.format(
                            sub, avg_ssim, (N_center + N_border),
                            avg_ssim_center, N_center, avg_ssim_border,
                            N_border))

        logger.info('################ Tidy Outputs ################')
        for sub_name, psnr, psnr_center, psnr_border, ssim, ssim_center, ssim_border in zip(
                evaled_subs, avg_psnrs, avg_psnr_centers, avg_psnr_borders,
                avg_ssims, avg_ssim_centers, avg_ssim_borders):
            logger.info(
                'Folder {} - Average PSNR: {:.6f} dB. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sub_name, psnr, psnr_center, psnr_border))
            logger.info('Folder {} - Average SSIM: {:.6f} '
                        'Center SSIM: {:.6f} Border SSIM: {:.6f} '.format(
                            sub_name, ssim, ssim_center, ssim_border))

        logger.info('################ Final Results ################')
        logger.info('Data: {}'.format(opts.test_dir))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnrs) / len(avg_psnrs), len(test_subs),
                        sum(avg_psnr_centers) / len(avg_psnr_centers),
                        sum(avg_psnr_borders) / len(avg_psnr_borders)))
        logger.info('Total Average SSIM: {:.6f} for {} clips. '
                    'Center SSIM: {:.6f} Border SSIM: {:.6f} '.format(
                        sum(avg_ssims) / len(avg_ssims), len(test_subs),
                        sum(avg_ssim_centers) / len(avg_ssim_centers),
                        sum(avg_ssim_borders) / len(avg_ssim_borders)))
Example #25
0
def main():
    """Train a restoration model described by a YAML option file.

    Command-line arguments:
        -opt:          path to the option YAML file.
        --launcher:    "none" (single process) or "pytorch" (distributed).
        --local_rank:  accepted for the PyTorch launcher; not read directly
                       in this function.

    The function sets up (optionally distributed) training, resumes from
    ``opt["path"]["resume_state"]`` when present, builds the train/val
    loaders and the model, then runs the training loop with periodic
    logging, validation and checkpointing.

    Fixes relative to the previous version:
      * ``tb_logger`` is always bound (``None`` when TensorBoard logging is
        disabled), so the final ``close()`` no longer raises ``NameError``.
      * ``train_loader`` is pre-bound so the sanity check after the dataset
        loop actually fires instead of raising ``NameError``.
      * For "sr"/"srgan" models, ranks > 0 now skip validation instead of
        falling through to the video-validation branch.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument("-opt", type=str, help="Path to option YAML file.")
    parser.add_argument("--launcher",
                        choices=["none", "pytorch"],
                        default="none",
                        help="job launcher")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    # distributed training settings
    if args.launcher == "none":  # disabled distributed training
        opt["dist"] = False
        rank = -1
        print("Disabled distributed training.")
    else:
        opt["dist"] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    # loading resume state if exists
    if opt["path"].get("resume_state", None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt["path"]["resume_state"],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state["iter"])  # check resume options
    else:
        resume_state = None

    # Stays None unless TensorBoard logging is enabled on this process; every
    # later use is guarded on that, including the final close().
    tb_logger = None

    # mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt["path"]
                ["experiments_root"])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt["path"].items()
                 if not key == "experiments_root"
                 and "pretrain_model" not in key and "resume" not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger("base",
                          opt["path"]["log"],
                          "train_" + opt["name"],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger("base")
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt["use_tb_logger"] and "debug" not in opt["name"]:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info("You are using PyTorch {}. \
                            Tensorboard will use [tensorboardX]".format(
                    version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir="../tb_logger/" + opt["name"])
    else:
        util.setup_logger("base",
                          opt["path"]["log"],
                          "train",
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger("base")

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    # random seed
    seed = opt["train"]["manual_seed"]
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info("Random seed: {}".format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    train_loader = None  # pre-bound so the check below can actually fail cleanly
    for phase, dataset_opt in opt["datasets"].items():
        if phase == "train":
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt["batch_size"]))
            total_iters = int(opt["train"]["niter"])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt["dist"]:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    "Number of train images: {:,d}, iters: {:,d}".format(
                        len(train_set), train_size))
                logger.info("Total epochs needed: {:d} for iters {:,d}".format(
                    total_epochs, total_iters))
        elif phase == "val":
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info("Number of val images in [{:s}]: {:d}".format(
                    dataset_opt["name"], len(val_set)))
        else:
            raise NotImplementedError(
                "Phase [{:s}] is not recognized.".format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)
    print("Model created!")

    # resume training
    if resume_state:
        logger.info("Resuming training from epoch: {}, iter: {}.".format(
            resume_state["epoch"], resume_state["iter"]))

        start_epoch = resume_state["epoch"]
        current_step = resume_state["iter"]
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info("Start training from epoch: {:d}, iter: {:d}".format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt["dist"]:
            train_sampler.set_epoch(epoch)
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt["train"]["warmup_iter"])

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt["logger"]["print_freq"] == 0:
                logs = model.get_current_log()
                message = "[epoch:{:3d}, iter:{:8,d}, lr:(".format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += "{:.3e},".format(v)
                message += ")] "
                for k, v in logs.items():
                    message += "{:s}: {:.4e} ".format(k, v)
                    # tensorboard logger (only exists on rank <= 0)
                    if tb_logger is not None:
                        tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)
            # validation
            if opt["datasets"].get(
                    "val",
                    None) and current_step % opt["train"]["val_freq"] == 0:
                if opt["model"] in ["sr", "srgan"]:
                    # image restoration validation
                    # does not support multi-GPU validation, so only the
                    # main process runs it; other ranks skip validation
                    # rather than entering the video branch below.
                    if rank <= 0:
                        pbar = util.ProgressBar(len(val_loader))
                        avg_psnr = 0.0
                        idx = 0
                        for val_data in val_loader:
                            idx += 1
                            img_name = os.path.splitext(
                                os.path.basename(val_data["LQ_path"][0]))[0]
                            img_dir = os.path.join(opt["path"]["val_images"],
                                                   img_name)
                            util.mkdir(img_dir)

                            model.feed_data(val_data)
                            model.test()

                            visuals = model.get_current_visuals()
                            sr_img = util.tensor2img(visuals["rlt"])  # uint8
                            gt_img = util.tensor2img(visuals["GT"])  # uint8

                            # Save SR images for reference
                            save_img_path = os.path.join(
                                img_dir,
                                "{:s}_{:d}.png".format(img_name, current_step))
                            util.save_img(sr_img, save_img_path)

                            # calculate PSNR
                            sr_img, gt_img = util.crop_border(
                                [sr_img, gt_img], opt["scale"])
                            avg_psnr += util.calculate_psnr(sr_img, gt_img)
                            pbar.update("Test {}".format(img_name))

                        avg_psnr = avg_psnr / idx

                        # log
                        logger.info(
                            "# Validation # PSNR: {:.4e}".format(avg_psnr))
                        # tensorboard logger
                        if tb_logger is not None:
                            tb_logger.add_scalar("psnr", avg_psnr,
                                                 current_step)
                elif opt["dist"]:
                    # video restoration validation, multi-GPU testing:
                    # each rank evaluates every world_size-th sample and the
                    # per-folder PSNR tensors are summed onto rank 0.
                    psnr_rlt = {}  # with border and center frames
                    if rank == 0:
                        pbar = util.ProgressBar(len(val_set))
                    for idx in range(rank, len(val_set), world_size):
                        val_data = val_set[idx]
                        val_data["LQs"].unsqueeze_(0)
                        val_data["GT"].unsqueeze_(0)
                        folder = val_data["folder"]
                        idx_d, max_idx = val_data["idx"].split("/")
                        idx_d, max_idx = int(idx_d), int(max_idx)
                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = torch.zeros(
                                max_idx,
                                dtype=torch.float32,
                                device="cuda")
                        model.feed_data(val_data)
                        model.test()
                        visuals = model.get_current_visuals()
                        rlt_img = util.tensor2img(visuals["rlt"])  # uint8
                        gt_img = util.tensor2img(visuals["GT"])  # uint8
                        # calculate PSNR
                        psnr_rlt[folder][idx_d] = util.calculate_psnr(
                            rlt_img, gt_img)

                        if rank == 0:
                            # advance once per rank so the bar tracks the
                            # whole validation set, not just this rank's share
                            for _ in range(world_size):
                                pbar.update("Test {} - {}/{}".format(
                                    folder, idx_d, max_idx))
                    # collect data
                    for v in psnr_rlt.values():
                        dist.reduce(v, 0)
                    dist.barrier()

                    if rank == 0:
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.0
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = "# Validation # PSNR: {:.4e}:".format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += " {}: {:.4e}".format(k, v)
                        logger.info(log_s)
                        if tb_logger is not None:
                            tb_logger.add_scalar("psnr_avg", psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)
                else:
                    # video restoration validation, single process
                    pbar = util.ProgressBar(len(val_loader))
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.0
                    for val_data in val_loader:
                        folder = val_data["folder"][0]
                        idx_d, _ = val_data["idx"][0].split("/")
                        # border = val_data['border'].item()
                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = []

                        model.feed_data(val_data)
                        model.test()
                        visuals = model.get_current_visuals()
                        rlt_img = util.tensor2img(visuals["rlt"])  # uint8
                        gt_img = util.tensor2img(visuals["GT"])  # uint8
                        # NOTE(review): index 2 is hard-coded; presumably the
                        # centre frame of a 5-frame LQ window -- confirm.
                        lq_img = util.tensor2img(visuals["LQ"][2])  # uint8

                        img_dir = opt["path"]["val_images"]
                        util.mkdir(img_dir)
                        save_img_path = os.path.join(
                            img_dir, "{}.png".format(idx_d))
                        util.save_img(np.hstack((lq_img, rlt_img, gt_img)),
                                      save_img_path)

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder].append(psnr)
                        pbar.update("Test {} - {}".format(folder, idx_d))
                    for k, v in psnr_rlt.items():
                        psnr_rlt_avg[k] = sum(v) / len(v)
                        psnr_total_avg += psnr_rlt_avg[k]
                    psnr_total_avg /= len(psnr_rlt)
                    log_s = "# Validation # PSNR: {:.4e}:".format(
                        psnr_total_avg)
                    for k, v in psnr_rlt_avg.items():
                        log_s += " {}: {:.4e}".format(k, v)
                    logger.info(log_s)
                    if tb_logger is not None:
                        tb_logger.add_scalar("psnr_avg", psnr_total_avg,
                                             current_step)
                        for k, v in psnr_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)

            # save models and training states
            if current_step % opt["logger"]["save_checkpoint_freq"] == 0:
                if rank <= 0:
                    logger.info("Saving models and training states.")
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info("Saving the final model.")
        model.save("latest")
        logger.info("End of training.")
        # tb_logger is None when TensorBoard logging was never enabled.
        if tb_logger is not None:
            tb_logger.close()
Example #26
0
def main(opts):
    """Evaluate one or more checkpoints of a flow-guided recurrent VSR model.

    Args:
        opts: parsed options object. This function reads ``opts.gpus``,
            ``opts.cfg`` (YAML config path), ``opts.model_dir`` (a checkpoint
            file or a directory of ``<digits>_*`` checkpoints),
            ``opts.test_dir``, ``opts.gt_dir`` and ``opts.save_dir``.

    For every checkpoint, every sub-folder of ``opts.test_dir`` is processed
    frame by frame: the previous output is warped onto the current frame
    using the flow network and fed to the model together with the LQ window.
    Per-frame, per-folder and overall PSNR are logged, and a summary line per
    checkpoint is written to ``<save_dir>/all_psnrs.txt``.

    Raises:
        IOError: if ``opts.model_dir`` is neither a file nor a directory.

    Fixes relative to the previous version:
      * The config is read with ``yaml.safe_load``; ``yaml.load`` without a
        ``Loader`` is deprecated and rejected by PyYAML >= 6.
      * The centre-frame average is guarded against ``N_center == 0``
        (clips with no centre frames), like the border average already was.
      * ``all_psnrs.txt`` no longer assumes exactly four clips; output for
        four clips is unchanged.
      * Flow estimation and warping run under ``torch.no_grad()`` so no
        autograd graph is built at test time.
    """
    ################## configurations #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpus
    with open(opts.cfg, mode='r') as f:
        cfgs = yaml.safe_load(f)
    flip_test, save_imgs = cfgs['test']['flip_test'], cfgs['test']['save_imgs']
    scale = cfgs['datasets']['scale']
    crop_border = cfgs['test']['crop_border']
    nframe = cfgs['datasets']['N_frames']
    border_frame = nframe // 2
    half_nframe = nframe // 2
    padding = cfgs['datasets']['padding']
    flow_path = cfgs['path']['pretrain_model_F']
    cache_all_imgs = cfgs['datasets']['cache_data']

    def read_pair(lq_path, gt_path):
        """Read one LQ/GT frame pair.

        The LQ frame has its channel order reversed (OpenCV loads BGR), is
        moved to channel-first layout and scaled by 1/255; the GT frame is
        returned as loaded, cast to float32.
        """
        lq = cv2.imread(lq_path, cv2.IMREAD_UNCHANGED)[:, :, (2, 1, 0)]
        lq = lq.astype(np.float32).transpose((2, 0, 1)) / 255.
        gt = cv2.imread(gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
        return lq, gt

    ################## model files ####################
    model_dir = opts.model_dir
    if osp.isfile(model_dir):
        model_names = [osp.basename(model_dir)]
        model_dir = osp.dirname(model_dir)
    elif osp.isdir(model_dir):
        # keep only checkpoints named "<digits>_...", ordered by that number
        model_names = [
            x for x in os.listdir(model_dir) if str.isdigit(x.split('_')[0])
        ]
        model_names = sorted(model_names, key=lambda x: int(x.split("_")[0]))
    else:
        raise IOError('Invalid model_dir: {}'.format(model_dir))

    ################## dataset ########################
    test_subs = sorted(os.listdir(opts.test_dir))
    gt_subs = os.listdir(opts.gt_dir)
    valid_test_subs = [sub in gt_subs for sub in test_subs]
    assert (all(valid_test_subs)), 'Invalid sub folders exists in {}'.format(
        opts.test_dir)
    if cache_all_imgs:
        print('Cacheing all testing images ...')
        all_imgs = {}
        for sub in test_subs:
            print('Reading sub-folder: {} ...'.format(sub))
            test_sub_dir = osp.join(opts.test_dir, sub)
            gt_sub_dir = osp.join(opts.gt_dir, sub)
            all_imgs[sub] = {'test': [], 'gt': []}
            for name in sorted(os.listdir(test_sub_dir)):
                test_im, gt_im = read_pair(osp.join(test_sub_dir, name),
                                           osp.join(gt_sub_dir, name))
                all_imgs[sub]['test'].append(test_im)
                all_imgs[sub]['gt'].append(gt_im)

    #################### model ########################
    model = define_G(cfgs)
    netF = define_F(cfgs)
    netF.load_state_dict(torch.load(flow_path), strict=True)
    model = model.to(device)
    netF = netF.to(device)
    # NOTE(review): netF is never switched to eval mode here; confirm whether
    # it contains layers whose behaviour differs between train and eval.

    all_psnrs = []  # one (model_name, per-clip average PSNRs) entry per checkpoint
    for model_name in model_names:
        model_path = osp.join(model_dir, model_name)
        exp_name = model_name.split('_')[0]
        save_folder = osp.join(opts.save_dir, exp_name)
        util.mkdirs(save_folder)
        util.setup_logger(exp_name,
                          save_folder,
                          'test',
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger(exp_name)

        #### log info
        logger.info('Data: {}'.format(opts.test_dir))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))

        #### load model parameters
        model.load_state_dict(torch.load(model_path), strict=True)
        model.eval()

        avg_psnrs, avg_psnr_centers, avg_psnr_borders = [], [], []
        evaled_subs = []

        # for each subfolder
        for sub in test_subs:
            evaled_subs.append(sub)
            test_sub_dir = osp.join(opts.test_dir, sub)
            gt_sub_dir = osp.join(opts.gt_dir, sub)
            img_names = sorted(os.listdir(test_sub_dir))
            max_idx = len(img_names)
            if save_imgs:
                save_subfolder = osp.join(save_folder, sub)
                util.mkdirs(save_subfolder)

            #### get LQ and GT images
            if not cache_all_imgs:
                img_LQs, img_GTs = [], []
                for name in img_names:
                    test_im, gt_im = read_pair(osp.join(test_sub_dir, name),
                                               osp.join(gt_sub_dir, name))
                    img_LQs.append(test_im)
                    img_GTs.append(gt_im)
            else:
                img_LQs = all_imgs[sub]['test']
                img_GTs = all_imgs[sub]['gt']

            avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

            # process each image
            previous = None  # output of the previous frame within this clip
            for i in range(0, max_idx):
                select_idxs = data_util.index_generation(i,
                                                         max_idx,
                                                         nframe,
                                                         padding=padding)
                lqs = torch.from_numpy(
                    np.stack([img_LQs[k] for k in select_idxs]))
                lqs = lqs.unsqueeze(0).to(device)

                # Inference only: no gradients are needed for the upsampling,
                # the flow network, the warping or the model itself.
                with torch.no_grad():
                    upX = F.interpolate(lqs[:, half_nframe, :, :, :],
                                        scale_factor=scale,
                                        mode='bilinear',
                                        align_corners=False)
                    if i == 0:
                        # no previous output yet: use the upsampled centre frame
                        wrap_img = upX
                    else:
                        if previous is not None:
                            pre = previous.clone()
                        else:
                            # `previous` is assigned at the end of every
                            # iteration, so for i > 0 this fallback to the
                            # previous GT frame is not expected to run.
                            pre = img_GTs[i - 1][:, :, (2, 1, 0)].transpose(
                                (2, 0, 1)) / 255.
                            pre = torch.from_numpy(pre).unsqueeze(0).to(device)
                        first = F.interpolate(lqs[:, half_nframe - 1, :, :, :],
                                              scale_factor=scale,
                                              mode='bilinear',
                                              align_corners=False)
                        second = upX
                        b_flow = Flow_arch.estimate_flow(netF, second, first)
                        wrap_img, _ = Flow_arch.wraping(pre, b_flow)

                    output = model(lqs, wrap_img, scale)
                    output = output[0].detach()
                previous = output

                output = util.tensor2img(output).astype(np.float32)

                if save_imgs:
                    cv2.imwrite(
                        osp.join(save_subfolder,
                                 '{}.png'.format(img_names[i])),
                        output.astype(np.uint8))

                # calculate PSNR
                GT = np.copy(img_GTs[i])
                output = util.crop_border(output, crop_border)
                GT = util.crop_border(GT, crop_border)

                crt_psnr = util.calculate_psnr(output, GT)
                logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(
                    i + 1, img_names[i], crt_psnr))

                if border_frame <= i < max_idx - border_frame:  # center frames
                    avg_psnr_center += crt_psnr
                    N_center += 1
                else:  # border frames
                    avg_psnr_border += crt_psnr
                    N_border += 1

            avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center +
                                                              N_border)
            # Short clips can have no centre (or no border) frames at all.
            avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
            avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
            avg_psnrs.append(avg_psnr)
            avg_psnr_centers.append(avg_psnr_center)
            avg_psnr_borders.append(avg_psnr_border)

            logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                        'Center PSNR: {:.6f} dB for {} frames; '
                        'Border PSNR: {:.6f} dB for {} frames.'.format(
                            sub, avg_psnr, (N_center + N_border),
                            avg_psnr_center, N_center, avg_psnr_border,
                            N_border))

        logger.info('################ Tidy Outputs ################')
        for subfolder_name, psnr, psnr_center, psnr_border in zip(
                evaled_subs, avg_psnrs, avg_psnr_centers, avg_psnr_borders):
            logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                        'Center PSNR: {:.6f} dB. '
                        'Border PSNR: {:.6f} dB.'.format(
                            subfolder_name, psnr, psnr_center, psnr_border))
        logger.info('################ Final Results ################')
        logger.info('Data: {}'.format(opts.test_dir))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnrs) / len(avg_psnrs), len(test_subs),
                        sum(avg_psnr_centers) / len(avg_psnr_centers),
                        sum(avg_psnr_borders) / len(avg_psnr_borders)))

        all_psnrs.append((model_name, avg_psnrs))

    # One line per checkpoint: name, mean over clips, then each clip's PSNR.
    with open(osp.join(opts.save_dir, 'all_psnrs.txt'), 'w') as f:
        for name, clip_psnrs in all_psnrs:
            values = [sum(clip_psnrs) / len(clip_psnrs)] + list(clip_psnrs)
            f.write("{:>14s}: ".format(name) +
                    "".join("{:.6f} ".format(v) for v in values) + "\n")