Example #1
def main():
    #################
    # configurations
    #################
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # must be set before any CUDA call
    device = torch.device('cuda')
    data_mode = 'ai4khdr_test'
    flip_test = False

    ############################################################################
    #### model
    if data_mode == 'ai4khdr_test':
        model_path = '../experiments/002_EDVR_lr4e-4_600k_AI4KHDR/models/4000_G.pth'
    else:
        raise NotImplementedError
    N_in = 5
    front_RBs = 5
    back_RBs = 10
    predeblur, HR_in = False, False
    model = EDVR_arch.EDVR(64, N_in, 8, front_RBs, back_RBs, predeblur=predeblur, HR_in=HR_in)

    ############################################################################
    #### dataset
    if data_mode == 'ai4khdr_test':
        test_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/test/540p_frames'
    else:
        raise NotImplementedError

    ############################################################################
    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # number of border frames when evaluating
    # temporal padding mode
    if data_mode == 'ai4khdr_test':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}_{}'.format(data_mode, util.get_timestamp())
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=False)
    model.eval()
    model = model.to(device)

    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    # for each subfolder
    for subfolder in subfolder_l:
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)
        logger.info('Folder {}'.format(subfolder_name))

    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
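
Note: `data_util.index_generation` above picks the N_in neighbouring frame indices for each target frame. The real helper lives in the EDVR codebase; the following is a minimal sketch of the two padding modes used in these scripts ('replicate' and 'new_info'), and its body is an assumption rather than the verbatim library code.

def index_generation(crt_i, max_n, N, padding='replicate'):
    """Return N frame indices centred on frame crt_i of a max_n-frame clip."""
    max_id = max_n - 1
    n_pad = N // 2
    indices = []
    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
        if i < 0:
            # 'replicate' repeats the border frame; 'new_info' reaches past
            # the window on the other side so each selected frame is distinct
            idx = 0 if padding == 'replicate' else crt_i + n_pad - i
        elif i > max_id:
            idx = max_id if padding == 'replicate' else crt_i - n_pad - (i - max_id)
        else:
            idx = i
        indices.append(idx)
    return indices

# index_generation(0, 100, 5, padding='replicate') -> [0, 0, 0, 1, 2]
# index_generation(0, 100, 5, padding='new_info')  -> [4, 3, 0, 1, 2]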
Example #2
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to options file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items() if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None, opt['path']['log'], 'test.log', level=logging.INFO, screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))
    # Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

            model.feed_data(data, need_HR=need_HR)
            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)
            
            img_c = util.tensor2img(visuals['img_c'])  # uint8
            img_s = util.tensor2img(visuals['img_s'])  # uint8
            img_p = util.tensor2img(visuals['img_p'])  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_c_img_path = os.path.join(dataset_dir, img_name + suffix + '_c.png')
                save_s_img_path = os.path.join(dataset_dir, img_name + suffix + '_s.png')
                save_p_img_path = os.path.join(dataset_dir, img_name + suffix + '_p.png')                
            else:
                save_c_img_path = os.path.join(dataset_dir, img_name + '_c.png')
                save_s_img_path = os.path.join(dataset_dir, img_name + '_s.png')
                save_p_img_path = os.path.join(dataset_dir, img_name + '_p.png')
            
            util.save_img(img_c, save_c_img_path)
            util.save_img(img_s, save_s_img_path)
            util.save_img(img_p, save_p_img_path)
            

            # calculate PSNR and SSIM
            if need_HR:
                gt_img = util.tensor2img(visuals['HR'])
                gt_img = gt_img / 255.
                sr_img = img_c / 255.

                crop_border = test_loader.dataset.opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'\
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
            else:
                logger.info(img_name)

        if need_HR:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'\
                    .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'\
                    .format(ave_psnr_y, ave_ssim_y))
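
For reference, `util.calculate_psnr` above is applied to images scaled back to [0, 255]. A minimal self-contained equivalent, assuming the standard PSNR definition (the project helper may differ in edge-case handling):

import numpy as np

def calculate_psnr(img1, img2):
    """PSNR in dB between two images with values in [0, 255]."""
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 20 * np.log10(255.0 / np.sqrt(mse))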
Example #3
def main():

    # Create object for parsing command-line options
    parser = argparse.ArgumentParser(
        description="Test with EDVR; requires the path to a test dataset folder.")
    # Add argument which takes the path to the test dataset folder as input
    parser.add_argument("-i", "--input", type=str, help="Path to test folder")
    # Parse the command line arguments to an object
    args = parser.parse_args()
    # Safety check in case no parameter has been given
    if not args.input:
        print("No input parameter has been given.")
        print("For help type --help")
        exit()

    # basename of the input folder, robust to a trailing slash
    folder_name = osp.basename(osp.normpath(args.input))

    #################
    # configurations
    #################
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # must be set before any CUDA call
    device = torch.device('cuda')
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    else:
        raise NotImplementedError

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 40

    model = EDVR_arch.EDVR(128,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        # debug
        test_dataset_folder = os.path.join(args.input, 'BIx4')
        GT_dataset_folder = os.path.join(args.input, 'GT')
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # number of border frames when evaluating
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(folder_name)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            # print(select_idx)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(
                img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        avg_psnr_center = avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
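
When `flip_test` is enabled, `util.flipx4_forward` runs the x4 flip self-ensemble: the clip is forwarded in four flip configurations and the de-flipped outputs are averaged. A sketch under the assumption that `single_forward` is simply a no-grad model call:

import torch

def single_forward(model, inp):
    with torch.no_grad():
        return model(inp)

def flipx4_forward(model, inp):
    """Average model outputs over identity, W-flip, H-flip and HW-flip inputs."""
    out = single_forward(model, inp)
    # flip width (last dim), forward, then un-flip the output
    out = out + torch.flip(single_forward(model, torch.flip(inp, (-1,))), (-1,))
    # flip height
    out = out + torch.flip(single_forward(model, torch.flip(inp, (-2,))), (-2,))
    # flip both
    out = out + torch.flip(single_forward(model, torch.flip(inp, (-2, -1))), (-2, -1))
    return out / 4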
Example #4
def create_test_png(model_path, device, gpu_id, opt, subfolder_l, save_folder, save_imgs,
                    frame_notation, N_in, PAD, flip_test, end, total_run_time, logger, padding):


    model = EDVR_arch.EDVR(nf=opt['network_G']['nf'], nframes=opt['network_G']['nframes'],
                           groups=opt['network_G']['groups'], front_RBs=opt['network_G']['front_RBs'],
                           back_RBs=opt['network_G']['back_RBs'],
                           predeblur=opt['network_G']['predeblur'], HR_in=opt['network_G']['HR_in'],
                           w_TSA=opt['network_G']['w_TSA'])

    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    #if (torch.cuda.is_available()):
    model = model.cuda(gpu_id)


    for subfolder in subfolder_l:

        input_subfolder = os.path.split(subfolder)[1]

        # subfolder_GT = os.path.join(GT_dataset_folder,input_subfolder)

        #if not os.path.exists(subfolder_GT):
        #    continue

        print("Evaluate Folders: ", input_subfolder)

        subfolder_name = osp.basename(subfolder)
        #subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)  # Num x 3 x H x W
        #img_GT_l = []
        #for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
        #    img_GT_l.append(data_util.read_img(None, img_GT_path))

        #avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):

            img_name = osp.splitext(osp.basename(img_path))[0]

            # TODO: handle screen changes here
            select_idx, log1, log2, nota = data_util.index_generation_process_screen_change_withlog_fixbug(
                input_subfolder, frame_notation, img_idx, max_idx, N_in, padding=padding)

            if log1 is not None:
                logger.info('screen change')
                logger.info(nota)
                logger.info(log1)
                logger.info(log2)



            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).cuda(gpu_id)  # 960 x 540


            # here we split the 960x540 input image into nine 320x180 patches
            gtWidth = 3840
            gtHeight = 2160
            intWidth_ori = imgs_in.shape[4] # 960
            intHeight_ori = imgs_in.shape[3] # 540
            split_lengthY = 180
            split_lengthX = 320
            scale = 4

            intPaddingRight_ = int(float(intWidth_ori) / split_lengthX + 1) * split_lengthX - intWidth_ori
            intPaddingBottom_ = int(float(intHeight_ori) / split_lengthY + 1) * split_lengthY - intHeight_ori

            intPaddingRight_ = 0 if intPaddingRight_ == split_lengthX else intPaddingRight_
            intPaddingBottom_ = 0 if intPaddingBottom_ == split_lengthY else intPaddingBottom_

            pader0 = torch.nn.ReplicationPad2d([0, intPaddingRight_, 0, intPaddingBottom_])
            print("Init pad right/bottom " + str(intPaddingRight_) + " / " + str(intPaddingBottom_))

            intPaddingRight = PAD # 32# 64# 128# 256
            intPaddingLeft = PAD  # 32#64 #128# 256
            intPaddingTop = PAD  # 32#64 #128#256
            intPaddingBottom = PAD  # 32#64 # 128# 256



            pader = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom])

            imgs_in = torch.squeeze(imgs_in, 0)# N C H W

            imgs_in = pader0(imgs_in)  # N C 540 960

            imgs_in = pader(imgs_in)  # N C 604 1024

            assert (split_lengthY == int(split_lengthY) and split_lengthX == int(split_lengthX))
            split_lengthY = int(split_lengthY)
            split_lengthX = int(split_lengthX)
            split_numY = int(float(intHeight_ori) / split_lengthY )
            split_numX = int(float(intWidth_ori) / split_lengthX)
            splitsY = range(0, split_numY)
            splitsX = range(0, split_numX)

            intWidth = split_lengthX
            intWidth_pad = intWidth + intPaddingLeft + intPaddingRight
            intHeight = split_lengthY
            intHeight_pad = intHeight + intPaddingTop + intPaddingBottom

            # print("split " + str(split_numY) + ' , ' + str(split_numX))
            y_all = np.zeros((gtHeight, gtWidth, 3), dtype="float32")  # HWC
            for split_j, split_i in itertools.product(splitsY, splitsX):
                # print(str(split_j) + ", \t " + str(split_i))
                X0 = imgs_in[:, :,
                     split_j * split_lengthY:(split_j + 1) * split_lengthY + intPaddingBottom + intPaddingTop,
                     split_i * split_lengthX:(split_i + 1) * split_lengthX + intPaddingRight + intPaddingLeft]

                # y_ = torch.FloatTensor()

                X0 = torch.unsqueeze(X0, 0)  # N C H W -> 1 N C H W
                #X0 = X0.cuda(gpu_id)

                if flip_test:
                    output = util.flipx4_forward(model, X0)
                else:
                    output = util.single_forward(model, X0)

                output_depadded = output[0, :, intPaddingTop * scale :(intPaddingTop+intHeight) * scale, intPaddingLeft * scale: (intPaddingLeft+intWidth)*scale]
                output_depadded = output_depadded.squeeze(0)
                output = util.tensor2img(output_depadded)


                y_all[split_j * split_lengthY * scale :(split_j + 1) * split_lengthY * scale,
                      split_i * split_lengthX * scale :(split_i + 1) * split_lengthX * scale, :] = \
                        np.round(output).astype(np.uint8)

                # plt.figure(0)
                # plt.title("pic")
                # plt.imshow(y_all)


            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), y_all)

            print("*****************current image process time \t " + str(
                time.time() - end) + "s ******************")
            total_run_time.update(time.time() - end, 1)

            # calculate PSNR
            #y_all = y_all / 255.
            #GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            #if data_mode == 'Vid4':  # bgr2y, [0, 1]
            #    GT = data_util.bgr2ycbcr(GT, only_y=True)
            #    y_all = data_util.bgr2ycbcr(y_all, only_y=True)

            #y_all, GT = util.crop_border([y_all, GT], crop_border)
            #crt_psnr = util.calculate_psnr(y_all * 255, GT * 255)
            #logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            logger.info('{} : {:3d} - {:25} \t'.format(input_subfolder, img_idx + 1, img_name))
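
The tiling above first pads the 960x540 input up to a multiple of the 320x180 patch size (the `intPaddingRight_`/`intPaddingBottom_` computation, with the `== split_length` correction mapping an exact fit to zero), then adds a PAD-pixel replication border around every patch so patches can be restored independently without seams. The alignment arithmetic can be checked standalone:

def tile_padding(width, height, tile_w=320, tile_h=180):
    """Extra right/bottom padding needed to make (width, height) a tile multiple."""
    pad_right = (-width) % tile_w
    pad_bottom = (-height) % tile_h
    return pad_right, pad_bottom

# 960x540 divides evenly into 3 x 3 tiles of 320x180, so no extra padding:
assert tile_padding(960, 540) == (0, 0)
# a 1000x550 input would need 280 px on the right and 170 px at the bottom:
assert tile_padding(1000, 550) == (280, 170)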
Example #5
def main():
    ###### SFTMD train ######
    #### setup options
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-opt_F",
        type=str,
        default="options/train/SFTMD/train_SFTMD_x4.yml",
        help="Path to option YMAL file of SFTMD_Net.",
    )
    parser.add_argument("--launcher",
                        choices=["none", "pytorch"],
                        default="none",
                        help="job launcher")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()
    opt_F = option.parse(args.opt_F, is_train=True)

    # convert to NoneDict, which returns None for missing keys
    opt_F = option.dict_to_nonedict(opt_F)

    #### random seed
    seed = opt_F["train"]["manual_seed"]
    if seed is None:
        seed = random.randint(1, 10000)
    util.set_random_seed(seed)

    # load the PCA matrix computed from a large set of kernels
    print("load PCA matrix")
    pca_matrix = torch.load(
        "../../../pca_matrix/IKC/pca_matrix.pth",
        map_location=lambda storage, loc: storage,
    )
    print("PCA matrix shape: {}".format(pca_matrix.shape))

    #### distributed training settings
    if args.launcher == "none":  # disabled distributed training
        opt_F["dist"] = False
        rank = -1
        print("Disabled distributed training.")
    else:
        opt_F["dist"] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### loading resume state if exists
    if opt_F["path"].get("resume_state", None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt_F["path"]["resume_state"],
            map_location=lambda storage, loc: storage.cuda(device_id),
        )
        option.check_resume(opt_F,
                            resume_state["iter"])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:
        if resume_state is None:
            util.mkdir_and_rename(
                opt_F["path"]
                ["experiments_root"])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt_F["path"].items()
                 if not key == "experiments_root"
                 and "pretrain_model" not in key and "resume" not in key))
            os.system("rm ./log")
            os.symlink(os.path.join(opt_F["path"]["experiments_root"], ".."),
                       "./log")

        # configure loggers; logging does not work before this point
        util.setup_logger(
            "base",
            opt_F["path"]["log"],
            "train_" + opt_F["name"],
            level=logging.INFO,
            screen=True,
            tofile=True,
        )
        util.setup_logger(
            "val",
            opt_F["path"]["log"],
            "val_" + opt_F["name"],
            level=logging.INFO,
            screen=True,
            tofile=True,
        )
        logger = logging.getLogger("base")
        logger.info(option.dict2str(opt_F))
        # tensorboard logger
        if opt_F["use_tb_logger"] and "debug" not in opt_F["name"]:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    "You are using PyTorch {}. Tensorboard will use [tensorboardX]"
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(
                log_dir="log/{}/tb_logger/".format(opt_F["name"]))
    else:
        util.setup_logger("base",
                          opt_F["path"]["log"],
                          "train",
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger("base")

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt_F["datasets"].items():
        if phase == "train":
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt["batch_size"]))
            total_iters = int(opt_F["train"]["niter"])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt_F["dist"]:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt_F,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    "Number of train images: {:,d}, iters: {:,d}".format(
                        len(train_set), train_size))
                logger.info("Total epochs needed: {:d} for iters {:,d}".format(
                    total_epochs, total_iters))
        elif phase == "val":
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt_F, None)
            if rank <= 0:
                logger.info("Number of val images in [{:s}]: {:d}".format(
                    dataset_opt["name"], len(val_set)))
        else:
            raise NotImplementedError(
                "Phase [{:s}] is not recognized.".format(phase))
    assert train_loader is not None
    assert val_loader is not None

    #### create model
    model_F = create_model(opt_F)

    #### resume training
    if resume_state:
        logger.info("Resuming training from epoch: {}, iter: {}.".format(
            resume_state["epoch"], resume_state["iter"]))

        start_epoch = resume_state["epoch"]
        current_step = resume_state["iter"]
        model_F.resume_training(
            resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info("Start training from epoch: {:d}, iter: {:d}".format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt_F["dist"]:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### preprocessing for LR_img and kernel map
            prepro = util.SRMDPreprocessing(
                opt_F["scale"],
                pca_matrix,
                random=True,
                para_input=opt_F["code_length"],
                kernel=opt_F["kernel_size"],
                noise=False,
                cuda=True,
                sig=opt_F["sig"],
                sig_min=opt_F["sig_min"],
                sig_max=opt_F["sig_max"],
                rate_iso=1.0,
                scaling=3,
                rate_cln=0.2,
                noise_high=0.0,
            )
            LR_img, ker_map = prepro(train_data["GT"])

            #### update learning rate, schedulers
            model_F.update_learning_rate(
                current_step, warmup_iter=opt_F["train"]["warmup_iter"])

            #### training
            model_F.feed_data(train_data, LR_img, ker_map)
            model_F.optimize_parameters(current_step)

            #### log
            if current_step % opt_F["logger"]["print_freq"] == 0:
                logs = model_F.get_current_log()
                message = "<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> ".format(
                    epoch, current_step, model_F.get_current_learning_rate())
                for k, v in logs.items():
                    message += "{:s}: {:.4e} ".format(k, v)
                    # tensorboard logger
                    if opt_F["use_tb_logger"] and "debug" not in opt_F["name"]:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            # validation
            if current_step % opt_F["train"]["val_freq"] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0
                for _, val_data in enumerate(val_loader):
                    idx += 1
                    #### preprocessing for LR_img and kernel map
                    prepro = util.SRMDPreprocessing(
                        opt_F["scale"],
                        pca_matrix,
                        random=True,
                        para_input=opt_F["code_length"],
                        kernel=opt_F["kernel_size"],
                        noise=False,
                        cuda=True,
                        sig=opt_F["sig"],
                        sig_min=opt_F["sig_min"],
                        sig_max=opt_F["sig_max"],
                        rate_iso=1.0,
                        scaling=3,
                        rate_cln=0.2,
                        noise_high=0.0,
                    )
                    LR_img, ker_map = prepro(val_data["GT"])

                    model_F.feed_data(val_data, LR_img, ker_map)
                    model_F.test()

                    visuals = model_F.get_current_visuals()
                    sr_img = util.tensor2img(visuals["SR"])  # uint8
                    gt_img = util.tensor2img(visuals["GT"])  # uint8

                    # Save SR images for reference
                    img_name = os.path.splitext(
                        os.path.basename(val_data["LQ_path"][0]))[0]
                    # img_dir = os.path.join(opt_F['path']['val_images'], img_name)
                    img_dir = os.path.join(opt_F["path"]["val_images"],
                                           str(current_step))
                    util.mkdir(img_dir)

                    save_img_path = os.path.join(
                        img_dir,
                        "{:s}_{:d}.png".format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    crop_size = opt_F["scale"]
                    gt_img = gt_img / 255.0
                    sr_img = sr_img / 255.0
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)

                avg_psnr = avg_psnr / idx

                # log
                logger.info("# Validation # PSNR: {:.4e}".format(avg_psnr))
                logger_val = logging.getLogger("val")  # validation logger
                logger_val.info(
                    "<epoch:{:3d}, iter:{:8,d}> psnr: {:.6f}".format(
                        epoch, current_step, avg_psnr))
                # tensorboard logger
                if opt_F["use_tb_logger"] and "debug" not in opt_F["name"]:
                    tb_logger.add_scalar("psnr", avg_psnr, current_step)

            #### save models and training states
            if current_step % opt_F["logger"]["save_checkpoint_freq"] == 0:
                if rank <= 0:
                    logger.info("Saving models and training states.")
                    model_F.save(current_step)
                    model_F.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info("Saving the final model.")
        model_F.save("latest")
        logger.info("End of SFTMD training.")
Example #6
    test_results['ssim_y'] = []

    time_total = 0.0
    time_cnt = 0
    for data in test_loader:
        need_GT = test_loader.dataset.opt['dataroot_GT'] is not None
        model.feed_data(data, need_GT=need_GT)
        video_name = data['key'][0].split('_')[1]
        category_name = data['key'][0].split('_')[0]

        model.test()

        visuals = model.get_current_visuals(need_GT=need_GT)

        sr_img = util.tensor2img(visuals['rlt'])  # uint8

        # save images
        suffix = opt['suffix']
        if suffix:
            save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
        else:
            os.makedirs(os.path.join(dataset_dir, category_name, video_name),
                        exist_ok=True)
            save_img_path = osp.join(dataset_dir, category_name, video_name,
                                     'im4.png')
        util.save_img(sr_img, save_img_path)

        # calculate PSNR and SSIM
        if need_GT:
            gt_img = util.tensor2img(visuals['GT'])
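
`util.tensor2img` converts a CHW float tensor in [0, 1] into the uint8 HWC BGR array that cv2 expects. A simplified sketch for the 3-channel case (the real helper also handles 1D/4D inputs and other output types):

import numpy as np
import torch

def tensor2img(tensor):
    """(C, H, W) float tensor in [0, 1] -> (H, W, C) uint8 BGR image."""
    img = tensor.detach().squeeze().float().cpu().clamp_(0, 1).numpy()
    img = np.transpose(img, (1, 2, 0))  # CHW -> HWC
    img = img[:, :, [2, 1, 0]]          # RGB -> BGR for cv2
    return (img * 255.0).round().astype(np.uint8)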
Example #7
        # Corrector test
        for step in range(1, opt_C['step'] + 1):
            # Test SFTMD to produce SR images
            model_F.feed_data(test_data, LR_img, est_ker_map)
            model_F.test()
            F_visuals = model_F.get_current_visuals()
            SR_img = F_visuals['Batch_SR']

            model_C.feed_data(SR_img, est_ker_map, ker_map)
            model_C.test()
            C_visuals = model_C.get_current_visuals()
            est_ker_map = C_visuals['Batch_est_ker_map']

            sr_img = util.tensor2img(F_visuals['SR'])  # uint8

            # save images
            suffix = opt_F['suffix']
            if suffix:
                save_img_path = os.path.join(dataset_dir, str(step),
                                             img_name + suffix + '.png')
            else:
                save_img_path = os.path.join(dataset_dir, str(step),
                                             img_name + '.png')
            util.save_img(sr_img, save_img_path)

            # calculate PSNR and SSIM
            if need_GT:
                gt_img = util.tensor2img(F_visuals['GT'])
                gt_img = gt_img / 255.
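
The Y-channel metrics used across these scripts (e.g. `bgr2ycbcr(..., only_y=True)` in Examples #2 and #3) rely on the ITU-R BT.601 conversion. A sketch for float BGR images in [0, 1]; the coefficient order matches BGR input, and the exact helper signature is an assumption:

import numpy as np

def bgr2ycbcr(img, only_y=True):
    """BT.601 luma from a float BGR image in [0, 1], returned in [0, 1]."""
    if not only_y:
        raise NotImplementedError('full YCbCr conversion omitted in this sketch')
    img = img.astype(np.float32) * 255.0
    # dot with (B, G, R) weights, then add the 16-level black offset
    y = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    return y / 255.0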
Example #8
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    opt['dist'] = False
    rank = -1
    print('Disabled distributed training.')

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        resume_state_path, _ = get_resume_paths(opt)

        # distributed resuming: all load into default GPU
        if resume_state_path is None:
            resume_state = None
        else:
            device_id = torch.cuda.current_device()
            resume_state = torch.load(resume_state_path,
                                      map_location=lambda storage, loc: storage.cuda(device_id))
            option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                         and 'pretrain_model' not in key and 'resume' not in key))

        # configure loggers; logging does not work before this point
        util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))

        # tensorboard logger
        if opt.get('use_tb_logger', False) and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
                from tensorboardX import SummaryWriter
            conf_name = basename(args.opt).replace(".yml", "")
            exp_dir = opt['path']['experiments_root']
            log_dir_train = os.path.join(exp_dir, 'tb', conf_name, 'train')
            log_dir_valid = os.path.join(exp_dir, 'tb', conf_name, 'valid')
            tb_logger_train = SummaryWriter(log_dir=log_dir_train)
            tb_logger_valid = SummaryWriter(log_dir=log_dir_valid)
    else:
        util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            print('Dataset created')
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    current_step = 0 if resume_state is None else resume_state['iter']
    model = create_model(opt, current_step)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    timer = Timer()
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    timerData = TickTock()

    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)

        timerData.tick()
        for _, train_data in enumerate(train_loader):
            timerData.tock()
            current_step += 1
            if current_step > total_iters:
                break

            #### training
            model.feed_data(train_data)

            #### update learning rate
            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])

            nll = None  # make sure nll is defined even if optimization fails
            try:
                nll = model.optimize_parameters(current_step)
            except RuntimeError as e:
                print("Skipping ERROR caught in nll = model.optimize_parameters(current_step): ")
                print(e)

            if nll is None:
                nll = 0

            #### log
            def eta(t_iter):
                return (t_iter * (opt['train']['niter'] - current_step)) / 3600

            if current_step % opt['logger']['print_freq'] == 0 \
                    or current_step - (resume_state['iter'] if resume_state else 0) < 25:
                avg_time = timer.get_average_and_reset()
                avg_data_time = timerData.get_average_and_reset()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.2e}, td:{:.2e}, eta:{:.2e}, nll:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate(), avg_time, avg_data_time,
                    eta(avg_time), nll)
                print(message)
            timer.tick()
            # Reduce number of logs
            if current_step % 5 == 0:
                tb_logger_train.add_scalar('loss/nll', nll, current_step)
                tb_logger_train.add_scalar('lr/base', model.get_current_learning_rate(), current_step)
                tb_logger_train.add_scalar('time/iteration', timer.get_last_iteration(), current_step)
                tb_logger_train.add_scalar('time/data', timerData.get_last_iteration(), current_step)
                tb_logger_train.add_scalar('time/eta', eta(timer.get_last_iteration()), current_step)
                for k, v in model.get_current_log().items():
                    tb_logger_train.add_scalar(k, v, current_step)

            # validation
            if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0
                nlls = []
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)

                    nll = model.test()
                    if nll is None:
                        nll = 0
                    nlls.append(nll)

                    visuals = model.get_current_visuals()

                    sr_img = None
                    # Save SR images for reference
                    if hasattr(model, 'heats'):
                        for heat in model.heats:
                            for i in range(model.n_sample):
                                sr_img = util.tensor2img(visuals['SR', heat, i])  # uint8
                                save_img_path = os.path.join(img_dir,
                                                             '{:s}_{:09d}_h{:03d}_s{:d}.png'.format(img_name,
                                                                                                    current_step,
                                                                                                    int(heat * 100), i))
                                util.save_img(sr_img, save_img_path)
                    else:
                        sr_img = util.tensor2img(visuals['SR'])  # uint8
                        save_img_path = os.path.join(img_dir,
                                                     '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)
                    assert sr_img is not None

                    # Save LQ images for reference
                    save_img_path_lq = os.path.join(img_dir,
                                                    '{:s}_LQ.png'.format(img_name))
                    if not os.path.isfile(save_img_path_lq):
                        lq_img = util.tensor2img(visuals['LQ'])  # uint8
                        util.save_img(
                            cv2.resize(lq_img, dsize=None, fx=opt['scale'], fy=opt['scale'],
                                       interpolation=cv2.INTER_NEAREST),
                            save_img_path_lq)

                    # Save GT images for reference
                    gt_img = util.tensor2img(visuals['GT'])  # uint8
                    save_img_path_gt = os.path.join(img_dir,
                                                    '{:s}_GT.png'.format(img_name))
                    if not os.path.isfile(save_img_path_gt):
                        util.save_img(gt_img, save_img_path_gt)

                    # calculate PSNR
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)

                avg_psnr = avg_psnr / idx
                avg_nll = sum(nlls) / len(nlls)

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                    epoch, current_step, avg_psnr))

                # tensorboard logger
                tb_logger_valid.add_scalar('loss/psnr', avg_psnr, current_step)
                tb_logger_valid.add_scalar('loss/nll', avg_nll, current_step)

                tb_logger_train.flush()
                tb_logger_valid.flush()

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

            timerData.tick()

    with open(os.path.join(opt['path']['root'], "TRAIN_DONE"), 'w') as f:
        f.write("TRAIN_DONE")

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
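
`Timer` and `TickTock` come from the SRFlow training utilities; only tick()/tock(), get_average_and_reset() and get_last_iteration() are used above. A minimal stand-in with that interface (an assumed re-implementation, not the original class):

import time

class TickTock:
    """Accumulate tick()/tock() interval lengths; consecutive tick() calls
    (Timer-style use) are treated as interval boundaries."""

    def __init__(self):
        self.start = None
        self.last = 0.0
        self.intervals = []

    def tick(self):
        if self.start is not None:  # Timer-style: a new tick closes the old interval
            self.tock()
        self.start = time.time()

    def tock(self):
        if self.start is not None:
            self.last = time.time() - self.start
            self.intervals.append(self.last)
            self.start = None

    def get_last_iteration(self):
        return self.last

    def get_average_and_reset(self):
        if not self.intervals:
            return 0.0
        avg = sum(self.intervals) / len(self.intervals)
        self.intervals = []
        return avg

Timer = TickTock  # assumed: the loop uses both names with the same behaviour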
Example #9
def test(model, test_data_loader, save=False):
    # pdb.set_trace()
    psnr_cal = PSNR()
    msssim_cal = MS_SSIM(data_range=1.0)
    ssim_cal = SSIM(data_range=1.0)
    psnr_meter_mono, psnr_meter_resL, psnr_meter_resR = AverageMeter(
    ), AverageMeter(), AverageMeter()
    msssim_meter_mono, msssim_meter_resL, msssim_meter_resR = AverageMeter(
    ), AverageMeter(), AverageMeter()
    ssim_meter_mono, ssim_meter_resL, ssim_meter_resR = AverageMeter(
    ), AverageMeter(), AverageMeter()

    with torch.no_grad():
        model.eval()
        for i, (left, right,
                original_shape) in enumerate(tqdm(test_data_loader)):
            batch_size = left.shape[0]
            assert batch_size == 1, 'Only a batch size of 1 is supported for now!'
            left = left.cuda(non_blocking=True)
            right = right.cuda(non_blocking=True)
            res_mono, res_l, res_r = model(left, right)
            original_shape = [x.item() for x in original_shape]

            def fun(x):
                x = test_data_loader.dataset.depad_tensor(x, original_shape)
                x = inverse_normalize(x).clamp(0.0, 1.0)
                return x

            left, right, res_mono, res_l, res_r = fun(left), fun(right), fun(
                res_mono), fun(res_l), fun(res_r)

            name = test_data_loader.dataset.frames[i][0].split('/')[-1].split(
                '.')[0]

            # pdb.set_trace()
            psnr_meter_mono.update(psnr_cal(res_mono, left), n=batch_size)
            psnr_meter_resL.update(psnr_cal(res_l, left), n=batch_size)
            psnr_meter_resR.update(psnr_cal(res_r, right), n=batch_size)
            msssim_meter_mono.update(msssim_cal(res_mono, left), n=batch_size)
            msssim_meter_resL.update(msssim_cal(res_l, left), n=batch_size)
            msssim_meter_resR.update(msssim_cal(res_r, right), n=batch_size)
            ssim_meter_mono.update(ssim_cal(res_mono, left), n=batch_size)
            ssim_meter_resL.update(ssim_cal(res_l, left), n=batch_size)
            ssim_meter_resR.update(ssim_cal(res_r, right), n=batch_size)

            if save:
                for x, last_fix in zip(
                    [res_mono, res_l, res_r],
                    ["_res_mono.png", "_res_l.png", "_res_r.png"]):
                    cv2.imwrite(
                        join(args.save_folder, args.data_name,
                             name + last_fix),
                        util.tensor2img(x)[..., ::-1] * 255)

    logger.info(
        '==>Mononized: \n'
        'PSNR: {psnr_meter_mono.avg:.2f}\n'
        'MS-SSIM: {msssim_meter_mono.avg:.2f}\n'
        'SSIM: {ssim_meter_mono.avg:.2f}\n'
        '==>Restored: \n'
        'PSNR: {psnr_meter_resL.avg:.2f}, {psnr_meter_resR.avg:.2f}\n'
        'MS-SSIM: {msssim_meter_resL.avg:.2f}, {msssim_meter_resR.avg:.2f}\n'
        'SSIM: {ssim_meter_resL.avg:.2f}, {ssim_meter_resR.avg:.2f}'.format(
            psnr_meter_mono=psnr_meter_mono,
            msssim_meter_mono=msssim_meter_mono,
            ssim_meter_mono=ssim_meter_mono,
            psnr_meter_resL=psnr_meter_resL,
            msssim_meter_resL=msssim_meter_resL,
            ssim_meter_resL=ssim_meter_resL,
            psnr_meter_resR=psnr_meter_resR,
            msssim_meter_resR=msssim_meter_resR,
            ssim_meter_resR=ssim_meter_resR))
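
`AverageMeter` is the conventional PyTorch-examples running-average tracker; the implementation assumed by the `update(value, n=batch_size)` calls above is:

class AverageMeter:
    """Keeps the most recent value and a running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count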
Example #10
            model.feed_data(data)
            if test_set_name == 'Vid4':
                folder = osp.split(osp.dirname(data['GT_path'][0][0]))[1]
            else:
                folder = ''
            util.mkdir(osp.join(dataset_dir, folder))

            model.test()
            visuals = model.get_current_visuals()

            if test_set_name == 'Vimeo90K':
                center = visuals['SR'].shape[0] // 2
                img_path = data['GT_path'][0]
                img_name = osp.splitext(osp.basename(img_path))[0]

                sr_img = util.tensor2img(visuals['SR'])  # uint8
                gt_img = util.tensor2img(visuals['GT'][center])  # uint8
                lr_img = util.tensor2img(visuals['LR'])  # uint8
                lrgt_img = util.tensor2img(visuals['LR_ref'][center])  # uint8

                test_results = cal_pnsr_ssim(sr_img, gt_img, lr_img, lrgt_img)

            else:
                t_step = visuals['SR'].shape[0]
                for i in range(t_step):
                    img_path = data['GT_path'][i][0]
                    img_name = osp.splitext(osp.basename(img_path))[0]

                    sr_img = util.tensor2img(visuals['SR'][i])  # uint8
                    gt_img = util.tensor2img(visuals['GT'][i])  # uint8
                    lr_img = util.tensor2img(visuals['LR'][i])  # uint8
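
`cal_pnsr_ssim` (the transposed "pnsr" is the codebase's own identifier) presumably bundles the metric bookkeeping for the SR/GT pair and the LR/LR-reference pair. A purely hypothetical shape for such a helper, assuming the same `util.calculate_psnr`/`util.calculate_ssim` helpers the other examples use:

from collections import OrderedDict

import util  # hypothetical: the project's metric helpers, as in the other examples

def cal_pnsr_ssim(sr_img, gt_img, lr_img, lrgt_img):
    """Collect PSNR/SSIM for SR vs. GT and for LR vs. the LR reference."""
    results = OrderedDict()
    results['psnr'] = util.calculate_psnr(sr_img, gt_img)
    results['ssim'] = util.calculate_ssim(sr_img, gt_img)
    results['psnr_lr'] = util.calculate_psnr(lr_img, lrgt_img)
    results['ssim_lr'] = util.calculate_ssim(lr_img, lrgt_img)
    return results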
Example #11
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which return None for missing key.
    pytorch_ver = get_pytorch_ver()

    # train from scratch OR resume training
    if opt['path']['resume_state']:
        if os.path.isdir(opt['path']['resume_state']):
            import glob
            resume_state_path = util.sorted_nicely(
                glob.glob(
                    os.path.normpath(opt['path']['resume_state']) +
                    '/*.state'))[-1]
        else:
            resume_state_path = opt['path']['resume_state']
        resume_state = torch.load(resume_state_path)
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(
            opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Set [resume_state] to ' + resume_state_path)
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        try:
            tb_logger = SummaryWriter(
                logdir='../tb_logger/' +
                opt['name'])  # tensorboardX >= 1.7 uses `logdir`
        except TypeError:
            tb_logger = SummaryWriter(
                log_dir='../tb_logger/' +
                opt['name'])  # tensorboardX < 1.7 uses `log_dir`

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # if the model does not change and input sizes remain the same during training then there may be benefit
    # from setting torch.backends.cudnn.benchmark = True, otherwise it may stall training
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
        model.update_schedulers(
            opt['train']
        )  # updated schedulers in case JSON configuration has changed
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for n, train_data in enumerate(train_loader, start=1):
            current_step += 1
            if current_step > total_iters:
                break

            if pytorch_ver == "pre":  #Order for PyTorch ver < 1.1.0
                # update learning rate
                model.update_learning_rate(current_step - 1)
                # training
                model.feed_data(train_data)
                model.optimize_parameters(current_step)
            elif pytorch_ver == "post":  #Order for PyTorch ver > 1.1.0
                # training
                model.feed_data(train_data)
                model.optimize_parameters(current_step)
                # update learning rate
                model.update_learning_rate(current_step - 1)
            else:
                print('Error identifying PyTorch version. ', torch.__version__)
                break

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # save models and training states (changed to save models before validation)
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                model.save(current_step)
                model.save_training_state(epoch + (n >= len(train_loader)),
                                          current_step)
                logger.info('Models and training states saved.')

            # validation
            if val_loader and current_step % opt['train']['val_freq'] == 0:
                avg_psnr_c = 0.0
                avg_psnr_s = 0.0
                avg_psnr_p = 0.0

                avg_ssim_c = 0.0
                avg_ssim_s = 0.0
                avg_ssim_p = 0.0

                idx = 0

                val_sr_imgs_list = []
                val_gt_imgs_list = []
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    if opt['datasets']['train'][
                            'znorm']:  # If the image range is [-1,1]
                        img_c = util.tensor2img(visuals['img_c'],
                                                min_max=(-1, 1))  # uint8
                        img_s = util.tensor2img(visuals['img_s'],
                                                min_max=(-1, 1))  # uint8
                        img_p = util.tensor2img(visuals['img_p'],
                                                min_max=(-1, 1))  # uint8
                        gt_img = util.tensor2img(visuals['HR'],
                                                 min_max=(-1, 1))  # uint8
                    else:  # Default: Image range is [0,1]
                        img_c = util.tensor2img(visuals['img_c'])  # uint8
                        img_s = util.tensor2img(visuals['img_s'])  # uint8
                        img_p = util.tensor2img(visuals['img_p'])  # uint8
                        gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_c_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}_c.png'.format(img_name, current_step))
                    save_s_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}_s.png'.format(img_name, current_step))
                    save_p_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}_d.png'.format(img_name, current_step))

                    util.save_img(img_c, save_c_img_path)
                    util.save_img(img_s, save_s_img_path)
                    util.save_img(img_p, save_p_img_path)

                    # calculate PSNR, SSIM and LPIPS distance
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    #sr_img = sr_img / 255. #ESRGAN
                    #PPON

                    sr_img_c = img_c / 255.  #C
                    sr_img_s = img_s / 255.  #S
                    sr_img_p = img_p / 255.  #D

                    # Single-channel models yield ndim == 2 arrays; RGB yields ndim == 3, etc.
                    if gt_img.ndim == 2:
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size]
                    else:  # gt_img.ndim == 3, # Default: RGB images
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                    # All 3 output images will have the same dimensions
                    if sr_img_c.ndim == 2:
                        cropped_sr_img_c = sr_img_c[crop_size:-crop_size,
                                                    crop_size:-crop_size]
                        cropped_sr_img_s = sr_img_s[crop_size:-crop_size,
                                                    crop_size:-crop_size]
                        cropped_sr_img_p = sr_img_p[crop_size:-crop_size,
                                                    crop_size:-crop_size]
                    else:  #sr_img_c.ndim == 3, # Default: RGB images
                        cropped_sr_img_c = sr_img_c[crop_size:-crop_size,
                                                    crop_size:-crop_size, :]
                        cropped_sr_img_s = sr_img_s[crop_size:-crop_size,
                                                    crop_size:-crop_size, :]
                        cropped_sr_img_p = sr_img_p[crop_size:-crop_size,
                                                    crop_size:-crop_size, :]

                    avg_psnr_c += util.calculate_psnr(cropped_sr_img_c * 255,
                                                      cropped_gt_img * 255)
                    avg_ssim_c += util.calculate_ssim(cropped_sr_img_c * 255,
                                                      cropped_gt_img * 255)

                    avg_psnr_s += util.calculate_psnr(cropped_sr_img_s * 255,
                                                      cropped_gt_img * 255)
                    avg_ssim_s += util.calculate_ssim(cropped_sr_img_s * 255,
                                                      cropped_gt_img * 255)

                    avg_psnr_p += util.calculate_psnr(cropped_sr_img_p * 255,
                                                      cropped_gt_img * 255)
                    avg_ssim_p += util.calculate_ssim(cropped_sr_img_p * 255,
                                                      cropped_gt_img * 255)

                    # LPIPS only works for RGB images
                    # Using only the final perceptual image to calculate LPIPS
                    if sr_img_c.ndim == 3:
                        #avg_lpips += lpips.calculate_lpips([cropped_sr_img], [cropped_gt_img]) # If calculating for each image
                        val_gt_imgs_list.append(
                            cropped_gt_img
                        )  # If calculating LPIPS only once for all images
                        val_sr_imgs_list.append(
                            cropped_sr_img_p
                        )  # If calculating LPIPS only once for all images

                # PSNR
                avg_psnr_c = avg_psnr_c / idx
                avg_psnr_s = avg_psnr_s / idx
                avg_psnr_p = avg_psnr_p / idx
                # SSIM
                avg_ssim_c = avg_ssim_c / idx
                avg_ssim_s = avg_ssim_s / idx
                avg_ssim_p = avg_ssim_p / idx
                # LPIPS
                #avg_lpips = avg_lpips / idx # If calculating for each image
                avg_lpips = lpips.calculate_lpips(
                    val_sr_imgs_list, val_gt_imgs_list
                )  # If calculating only once for all images

                # log
                # PSNR
                logger.info('# Validation # PSNR_c: {:.5g}'.format(avg_psnr_c))
                logger.info('# Validation # PSNR_s: {:.5g}'.format(avg_psnr_s))
                logger.info('# Validation # PSNR_p: {:.5g}'.format(avg_psnr_p))
                # SSIM
                logger.info('# Validation # SSIM_c: {:.5g}'.format(avg_ssim_c))
                logger.info('# Validation # SSIM_s: {:.5g}'.format(avg_ssim_s))
                logger.info('# Validation # SSIM_p: {:.5g}'.format(avg_ssim_p))
                # LPIPS
                logger.info('# Validation # LPIPS: {:.5g}'.format(avg_lpips))

                logger_val = logging.getLogger('val')  # validation logger
                # logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr_c: {:.5g}, psnr_s: {:.5g}, psnr_p: {:.5g}'.format(
                # epoch, current_step, avg_psnr_c, avg_psnr_s, avg_psnr_p))
                logger_val.info('<epoch:{:3d}, iter:{:8,d}>'.format(
                    epoch, current_step))
                logger_val.info(
                    'psnr_c: {:.5g}, psnr_s: {:.5g}, psnr_p: {:.5g}'.format(
                        avg_psnr_c, avg_psnr_s, avg_psnr_p))
                logger_val.info(
                    'ssim_c: {:.5g}, ssim_s: {:.5g}, ssim_p: {:.5g}'.format(
                        avg_ssim_c, avg_ssim_s, avg_ssim_p))
                logger_val.info('lpips: {:.5g}'.format(avg_lpips))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr_c', avg_psnr_c, current_step)
                    tb_logger.add_scalar('psnr_s', avg_psnr_s, current_step)
                    tb_logger.add_scalar('psnr_p', avg_psnr_p, current_step)
                    tb_logger.add_scalar('ssim_c', avg_ssim_c, current_step)
                    tb_logger.add_scalar('ssim_s', avg_ssim_s, current_step)
                    tb_logger.add_scalar('ssim_p', avg_ssim_p, current_step)
                    tb_logger.add_scalar('lpips', avg_lpips, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
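get_pytorch_ver() is used above but not shown in this excerpt. A minimal sketch consistent with the 'pre'/'post' branch in the training loop, which encodes the PyTorch 1.1.0 change in optimizer/scheduler step ordering (the version parsing here is an assumption):

import torch

def get_pytorch_ver():
    # 'pre':  PyTorch < 1.1.0, scheduler.step() before optimizer.step()
    # 'post': PyTorch >= 1.1.0, optimizer.step() before scheduler.step()
    try:
        major, minor = (int(v) for v in torch.__version__.split('.')[:2])
    except ValueError:
        return 'unknown'  # falls through to the error branch in the loop
    return 'post' if (major, minor) >= (1, 1) else 'pre'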
Example #12
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which returns None for missing keys.

    # train from scratch OR resume training
    if opt['path']['resume_state']:  # resuming training
        resume_state = torch.load(opt['path']['resume_state'])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(
            opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            # model.update_learning_rate()

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    # model.feed_data2(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(\
                        img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                        epoch, current_step, avg_psnr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)

            model.update_learning_rate()

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
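The validation loop above trims opt['scale'] pixels from every border before measuring PSNR, so upscaling artifacts at the image boundary do not skew the score. For reference, a self-contained PSNR sketch matching the [0, 255] convention used here (this body is an assumption about util.calculate_psnr, not the repo's exact code):

import numpy as np

def calculate_psnr(img1, img2):
    # Both images are expected in the [0, 255] range.
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 20 * np.log10(255.0 / np.sqrt(mse))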
Example #13
        np.ascontiguousarray(np.transpose(
            img_Ref, (2, 0, 1)))).float().unsqueeze(0).cuda()

    img_Ref_DUX4 = cv2.imread(osp.join(Ref_DUX4_path, use_name)) / 255.
    img_Ref_DUX4 = img_Ref_DUX4[:, :, [2, 1, 0]]
    img_Ref_DUX4 = torch.from_numpy(
        np.ascontiguousarray(np.transpose(
            img_Ref_DUX4, (2, 0, 1)))).float().unsqueeze(0).cuda()

    with torch.no_grad():
        begin_time = time.time()
        output = model(img_LR, img_LR_UX4, img_Ref, img_Ref_DUX4)
        end_time = time.time()
        stat_time += (end_time - begin_time)

    output = util.tensor2img(output.squeeze(0))

    # save images
    save_path_name = osp.join(
        save_path, '{}_exp{}/{}.png'.format(dataset, exp_name, base_name))
    util.save_img(output, save_path_name)

    # calculate PSNR
    sr_img, gt_img = util.crop_border([output, img_GT], scale)
    PSNR_avg += util.calculate_psnr(sr_img, gt_img)
    SSIM_avg += util.calculate_ssim(sr_img, gt_img)

print('average PSNR: ', PSNR_avg / len(img_list))
print('average SSIM: ', SSIM_avg / len(img_list))
print('time: ', stat_time / len(img_list))
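util.crop_border([output, img_GT], scale) above presumably trims scale pixels from each side of every image in the list before scoring, consistent with the cropping done manually in the earlier examples. A minimal sketch of that behavior (an assumption, not the repo's exact implementation):

def crop_border(img_list, border):
    # Trim `border` pixels from every side of each image; no-op for 0.
    if border == 0:
        return img_list
    return [img[border:-border, border:-border, ...] for img in img_list]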
Example #14
def main():

    ############################################
    #
    #           set options
    #
    ############################################

    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    ############################################
    #
    #           distributed training settings
    #
    ############################################

    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        print("Rank:", rank)
        print("------------------DIST-------------------------")

    ############################################
    #
    #           loading resume state if exists
    #
    ############################################

    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    ############################################
    #
    #           mkdir and loggers
    #
    ############################################

    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists

            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)

        util.setup_logger('base_val',
                          opt['path']['log'],
                          'val_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)

        logger = logging.getLogger('base')
        logger_val = logging.getLogger('base_val')

        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_',
                          level=logging.INFO,
                          screen=True)

        print("set train log")

        util.setup_logger('base_val',
                          opt['path']['log'],
                          'val_',
                          level=logging.INFO,
                          screen=True)

        print("set val log")

        logger = logging.getLogger('base')

        logger_val = logging.getLogger('base_val')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    ############################################
    #
    #           create train and val dataloader
    #
    ############################################
    ####

    # dataset_ratio = 200  # would enlarge each epoch by repeating samples
    dataset_ratio = 1  # epoch-enlarging factor; 1 keeps the epoch size unchanged
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            # total_iters = int(opt['train']['niter'])
            # total_epochs = int(math.ceil(total_iters / train_size))

            total_iters = train_size
            total_epochs = int(opt['train']['epoch'])

            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                # total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
                total_epochs = int(opt['train']['epoch'])
                if opt['train']['enable'] == False:
                    total_epochs = 1
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))

    assert train_loader is not None

    ############################################
    #
    #          create model
    #
    ############################################
    ####

    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0
        print("Not Resume Training")

    ############################################
    #
    #          training
    #
    ############################################
    ####

    ####
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    Avg_train_loss = AverageMeter()  # total
    Avg_train_psnr = AverageMeter()
    if opt['datasets']['train']['color'] == 'YUV':
        Avg_train_yuv_psnr = AverageMeter()
    if (opt['train']['pixel_criterion'] == 'cb+ssim'):
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_ssim = AverageMeter()

    saved_total_loss = 10e10
    saved_total_PSNR = -1

    for epoch in range(start_epoch, total_epochs):

        ############################################
        #
        #          Start a new epoch
        #
        ############################################

        # Turn into training mode
        #model = model.train()

        # reset total loss
        Avg_train_loss.reset()
        # reset psnr
        Avg_train_psnr.reset()
        if opt['datasets']['train']['color'] == 'YUV':
            Avg_train_yuv_psnr.reset()

        current_step = 0

        if (opt['train']['pixel_criterion'] == 'cb+ssim'):
            Avg_train_loss_pix.reset()
            Avg_train_loss_ssim.reset()

        if opt['dist']:
            train_sampler.set_epoch(epoch)

        for train_idx, train_data in enumerate(train_loader):

            if 'debug' in opt['name']:

                img_dir = os.path.join(opt['path']['train_images'])
                util.mkdir(img_dir)

                LQ = train_data['LQs']
                GT = train_data['GT']

                GT_img = util.tensor2img(GT)  # uint8  NCHW

                print('GT_img', GT_img.shape)
                print('LQ', LQ.shape)

                if opt['datasets']['train']['color'] == 'YUV':
                    GT_img = data_util.ycbcr2rgb(GT_img)

                save_img_path = os.path.join(
                    img_dir, '{:4d}_{:s}.png'.format(train_idx, 'debug_GT'))

                if opt['datasets']['train']['color'] == 'YUV':
                    util.save_img(GT_img, save_img_path, mode='RGB')
                else:
                    util.save_img(GT_img, save_img_path)

                for i in range(5):
                    LQ_img = util.tensor2img(LQ[:, i, ...])  # uint8
                    if opt['datasets']['train']['color'] == 'YUV':
                        LQ_img = data_util.ycbcr2rgb(LQ_img)
                    save_img_path = os.path.join(
                        img_dir,
                        '{:4d}_{:s}_{:1d}.png'.format(train_idx, 'debug_LQ',
                                                      i))
                    if opt['datasets']['train']['color'] == 'YUV':
                        util.save_img(LQ_img, save_img_path, mode='RGB')
                    else:
                        util.save_img(LQ_img, save_img_path)

                if (train_idx >= 10):
                    break

            if opt['train']['enable'] == False:
                message_train_loss = 'None'
                break

            current_step += 1
            if current_step > total_iters:
                print("Total Iteration Reached !")
                break
            #### update learning rate
            if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
                pass
            else:
                model.update_learning_rate(
                    current_step, warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)

            # if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            #    model.optimize_parameters_without_schudlue(current_step)
            # else:
            model.optimize_parameters(current_step)

            visuals = model.get_current_visuals(need_GT=True, save=False)

            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
            gt_img = util.tensor2img(visuals['GT'])  # uint8

            if opt['datasets']['train']['color'] == 'YUV':
                yuv_psnr = util.calculate_psnr(rlt_img, gt_img)
                rlt_img = data_util.ycbcr2rgb(rlt_img)
                gt_img = data_util.ycbcr2rgb(gt_img)

            # calculate PSNR
            psnr = util.calculate_psnr(rlt_img, gt_img)

            if opt['train']['pixel_criterion'] == 'cb+ssim':
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
                Avg_train_psnr.update(psnr, 1)
                if opt['datasets']['train']['color'] == 'YUV':
                    Avg_train_yuv_psnr.update(yuv_psnr, 1)
            else:
                Avg_train_loss.update(model.log_dict['l_pix'], 1)
                Avg_train_psnr.update(psnr, 1)
                if opt['datasets']['train']['color'] == 'YUV':
                    Avg_train_yuv_psnr.update(yuv_psnr, 1)

            # add total train loss
            if (opt['train']['pixel_criterion'] == 'cb+ssim'):
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(
                    Avg_train_loss_pix.avg)
                message_train_loss += ' ssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_ssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
                message_train_loss += ' psnr_inst : {:.2f}'.format(psnr)
                message_train_loss += ' psnr_avg : {:.2f}'.format(
                    Avg_train_psnr.avg)
            else:
                message_train_loss = ' train_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
                if opt['datasets']['train']['color'] == 'YUV':
                    message_train_loss += ' yuv_psnr_inst : {:.2f}'.format(
                        yuv_psnr)
                message_train_loss += ' psnr_inst : {:.2f}'.format(psnr)
                if opt['datasets']['train']['color'] == 'YUV':
                    message_train_loss += ' yuv_psnr_avg : {:.2f}'.format(
                        Avg_train_yuv_psnr.avg)
                message_train_loss += ' psnr_avg : {:.2f}'.format(
                    Avg_train_psnr.avg)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)

                # tensorboard logger - avg part (once per print, outside the per-key loop)
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    if rank <= 0:
                        tb_logger.add_scalar('train_avg_loss',
                                             Avg_train_loss.avg, current_step)
                        if opt['datasets']['train']['color'] == 'YUV':
                            tb_logger.add_scalar('yuv_psnr_avg',
                                                 Avg_train_yuv_psnr.avg,
                                                 current_step)
                        tb_logger.add_scalar('psnr_avg', Avg_train_psnr.avg,
                                             current_step)

                message += message_train_loss

                if rank <= 0:
                    logger.info(message)

        ############################################
        #
        #        end of one epoch, save epoch model
        #
        ############################################

        #### save models and training states (only the latest epoch checkpoint is kept)

        if epoch == 1:
            save_filename = '{:04d}_{}.pth'.format(0, 'G')
            save_path = os.path.join(opt['path']['models'], save_filename)
            if os.path.exists(save_path):
                os.remove(save_path)

        save_filename = '{:04d}_{}.pth'.format(epoch - 1, 'G')
        save_path = os.path.join(opt['path']['models'], save_filename)
        if os.path.exists(save_path):
            os.remove(save_path)

        if rank <= 0:
            logger.info('Saving models and training states.')
            save_filename = '{:04d}'.format(epoch)
            model.save(save_filename)
            # model.save('latest')
            # model.save_training_state(epoch, current_step)

        ############################################
        #
        #          end of one epoch, do validation
        #
        ############################################

        #### validation
        #if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
        if opt['datasets'].get('val', None):
            if opt['model'] in [
                    'sr', 'srgan'
            ] and rank <= 0:  # image restoration validation
                # does not support multi-GPU validation
                pbar = util.ProgressBar(len(val_loader))
                avg_psnr = 0.
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()
                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['rlt'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}.png'.format(img_name, current_step))
                    #util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                      opt['scale'])
                    avg_psnr += util.calculate_psnr(sr_img, gt_img)
                    pbar.update('Test {}'.format(img_name))

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
            else:  # video restoration validation
                if opt['dist']:
                    # todo : multi-GPU testing
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.

                    ssim_rlt = {}  # with border and center frames
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.

                    val_loss_rlt = {}
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.

                    if rank == 0:
                        pbar = util.ProgressBar(len(val_set))

                    for idx in range(rank, len(val_set), world_size):

                        if 'debug' in opt['name']:
                            print('idx', idx)
                        #     if (idx >= 3):
                        #         break

                        val_data = val_set[idx]
                        val_data['LQs'].unsqueeze_(0)
                        val_data['GT'].unsqueeze_(0)
                        folder = val_data['folder']
                        idx_d, max_idx = val_data['idx'].split('/')
                        idx_d, max_idx = int(idx_d), int(max_idx)

                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = torch.zeros(max_idx,
                                                           dtype=torch.float32,
                                                           device='cuda')

                        if ssim_rlt.get(folder, None) is None:
                            ssim_rlt[folder] = torch.zeros(max_idx,
                                                           dtype=torch.float32,
                                                           device='cuda')

                        if val_loss_rlt.get(folder, None) is None:
                            val_loss_rlt[folder] = torch.zeros(
                                max_idx, dtype=torch.float32, device='cuda')

                        # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
                        model.feed_data(val_data)
                        # model.test()
                        # model.test_stitch()

                        if opt['stitch'] == True:
                            model.test_stitch()
                        else:
                            model.test()  # large GPU memory

                        # visuals = model.get_current_visuals()
                        visuals = model.get_current_visuals(
                            save=True,
                            name='{}_{}'.format(folder, idx),
                            save_path=opt['path']['val_images'])

                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        if opt['datasets']['train']['color'] == 'YUV':
                            rlt_img = data_util.ycbcr2rgb(rlt_img)
                            gt_img = data_util.ycbcr2rgb(gt_img)

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder][idx_d] = psnr

                        # calculate SSIM
                        # ssim = util.calculate_ssim(rlt_img, gt_img)
                        ssim = 0  # SSIM computation skipped to save time
                        ssim_rlt[folder][idx_d] = ssim

                        # calculate Val loss
                        val_loss = model.get_loss()
                        val_loss_rlt[folder][idx_d] = val_loss

                        logger.info(
                            '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format(
                                folder, idx, psnr, ssim))

                        if rank == 0:
                            for _ in range(world_size):
                                pbar.update('Test {} - {}/{}'.format(
                                    folder, idx_d, max_idx))

                    # # collect data
                    for _, v in psnr_rlt.items():
                        dist.reduce(v, 0)

                    for _, v in ssim_rlt.items():
                        dist.reduce(v, 0)

                    for _, v in val_loss_rlt.items():
                        dist.reduce(v, 0)

                    dist.barrier()

                    if rank == 0:
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)

                        # ssim
                        ssim_rlt_avg = {}
                        ssim_total_avg = 0.
                        for k, v in ssim_rlt.items():
                            ssim_rlt_avg[k] = torch.mean(v).cpu().item()
                            ssim_total_avg += ssim_rlt_avg[k]
                        ssim_total_avg /= len(ssim_rlt)
                        log_s = '# Validation # SSIM: {:.4e}:'.format(
                            ssim_total_avg)
                        for k, v in ssim_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)

                        # added
                        val_loss_rlt_avg = {}
                        val_loss_total_avg = 0.
                        for k, v in val_loss_rlt.items():
                            val_loss_rlt_avg[k] = torch.mean(v).cpu().item()
                            val_loss_total_avg += val_loss_rlt_avg[k]
                        val_loss_total_avg /= len(val_loss_rlt)
                        log_l = '# Validation # Loss: {:.4e}:'.format(
                            val_loss_total_avg)
                        for k, v in val_loss_rlt_avg.items():
                            log_l += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_l)

                        message = ''
                        for v in model.get_current_learning_rate():
                            message += '{:.5e}'.format(v)

                        logger_val.info(
                            'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}'
                            .format(epoch, message, psnr_total_avg,
                                    ssim_total_avg, message_train_loss,
                                    val_loss_total_avg))

                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)
                            # add val loss
                            tb_logger.add_scalar('val_loss_avg',
                                                 val_loss_total_avg,
                                                 current_step)
                            for k, v in val_loss_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)

                else:  # single-GPU video restoration validation
                    pbar = util.ProgressBar(len(val_loader))
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.

                    ssim_rlt = {}  # with border and center frames
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.

                    val_loss_rlt = {}
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.

                    for val_inx, val_data in enumerate(val_loader):

                        # if 'debug' in opt['name']:
                        #     if (val_inx >= 5):
                        #         break

                        folder = val_data['folder'][0]
                        # idx_d = val_data['idx'].item()
                        idx_d = val_data['idx']
                        # border = val_data['border'].item()
                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = []

                        if ssim_rlt.get(folder, None) is None:
                            ssim_rlt[folder] = []

                        if val_loss_rlt.get(folder, None) is None:
                            val_loss_rlt[folder] = []

                        # disabled: detect and trim blank (all-zero) rows in the input [B N C H W]

                        # print(val_data['LQs'].size())

                        # H_S = val_data['LQs'].size(3)  # 540
                        # W_S = val_data['LQs'].size(4)  # 960

                        # print(H_S)
                        # print(W_S)

                        # blank_1_S = 0
                        # blank_2_S = 0

                        # print(val_data['LQs'][0,2,0,:,:].size())

                        # for i in range(H_S):
                        #     if not sum(val_data['LQs'][0,2,0,i,:]) == 0:
                        #         blank_1_S = i - 1
                        #         # assert not sum(data_S[:, :, 0][i+1]) == 0
                        #         break

                        # for i in range(H_S):
                        #     if not sum(val_data['LQs'][0,2,0,:,H_S - i - 1]) == 0:
                        #         blank_2_S = (H_S - 1) - i - 1
                        #         # assert not sum(data_S[:, :, 0][blank_2_S-1]) == 0
                        #         break
                        # print('LQ :', blank_1_S, blank_2_S)

                        # if blank_1_S == -1:
                        #     print('LQ has no blank')
                        #     blank_1_S = 0
                        #     blank_2_S = H_S

                        # val_data['LQs'] = val_data['LQs'][:,:,:,blank_1_S:blank_2_S,:]

                        # print("LQ",val_data['LQs'].size())

                        # end of process the black blank

                        model.feed_data(val_data)

                        if opt['stitch'] == True:
                            model.test_stitch()
                        else:
                            model.test()  # large GPU memory

                        # process blank

                        # blank_1_L = blank_1_S << 2
                        # blank_2_L = blank_2_S << 2
                        # print(blank_1_L, blank_2_L)

                        # print(model.fake_H.size())

                        # if not blank_1_S == 0:
                        #     # model.fake_H = model.fake_H[:,:,blank_1_L:blank_2_L,:]
                        #     model.fake_H[:, :,0:blank_1_L, :] = 0
                        #     model.fake_H[:, :,blank_2_L:H_S, :] = 0

                        # end of # process blank

                        visuals = model.get_current_visuals(
                            save=True,
                            name='{}_{:02d}'.format(folder, val_inx),
                            save_path=opt['path']['val_images'])

                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        if opt['datasets']['train']['color'] == 'YUV':
                            rlt_img = data_util.ycbcr2rgb(rlt_img)
                            gt_img = data_util.ycbcr2rgb(gt_img)

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder].append(psnr)

                        # calculate SSIM
                        # ssim = util.calculate_ssim(rlt_img, gt_img)
                        ssim = 0  # SSIM computation skipped to save time
                        ssim_rlt[folder].append(ssim)

                        # val loss
                        val_loss = model.get_loss()
                        val_loss_rlt[folder].append(val_loss.item())

                        logger.info(
                            '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format(
                                folder, val_inx, psnr, ssim))

                        pbar.update('Test {} - {}'.format(folder, idx_d))

                    # average PSNR
                    for k, v in psnr_rlt.items():
                        psnr_rlt_avg[k] = sum(v) / len(v)
                        psnr_total_avg += psnr_rlt_avg[k]
                    psnr_total_avg /= len(psnr_rlt)
                    log_s = '# Validation # PSNR: {:.4e}:'.format(
                        psnr_total_avg)
                    for k, v in psnr_rlt_avg.items():
                        log_s += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_s)

                    # average SSIM
                    for k, v in ssim_rlt.items():
                        ssim_rlt_avg[k] = sum(v) / len(v)
                        ssim_total_avg += ssim_rlt_avg[k]
                    ssim_total_avg /= len(ssim_rlt)
                    log_s = '# Validation # SSIM: {:.4e}:'.format(
                        ssim_total_avg)
                    for k, v in ssim_rlt_avg.items():
                        log_s += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_s)

                    # average Val LOSS
                    for k, v in val_loss_rlt.items():
                        val_loss_rlt_avg[k] = sum(v) / len(v)
                        val_loss_total_avg += val_loss_rlt_avg[k]
                    val_loss_total_avg /= len(val_loss_rlt)
                    log_l = '# Validation # Loss: {:.4e}:'.format(
                        val_loss_total_avg)
                    for k, v in val_loss_rlt_avg.items():
                        log_l += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_l)

                    # total validation log

                    message = ''
                    for v in model.get_current_learning_rate():
                        message += '{:.5e}'.format(v)

                    logger_val.info(
                        'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}'
                        .format(epoch, message, psnr_total_avg, ssim_total_avg,
                                message_train_loss, val_loss_total_avg))

                    # end add

                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                             current_step)
                        for k, v in psnr_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)
                        # tb_logger.add_scalar('ssim_avg', ssim_total_avg, current_step)
                        # for k, v in ssim_rlt_avg.items():
                        #     tb_logger.add_scalar(k, v, current_step)
                        # add val loss
                        tb_logger.add_scalar('val_loss_avg',
                                             val_loss_total_avg, current_step)
                        for k, v in val_loss_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)

            ############################################
            #
            #          end of validation, save model
            #
            ############################################
            #
            if rank <= 0:
                logger.info(
                    "Finished an epoch, Check and Save the model weights")
                # track the best checkpoint by validation loss rather than training loss
                if saved_total_loss >= val_loss_total_avg:
                    saved_total_loss = val_loss_total_avg
                    #torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
                    model.save('best')
                    logger.info(
                        "Best weights updated: validation loss decreased")

                else:
                    logger.info(
                        "Weights not updated: validation loss did not decrease")

                if saved_total_PSNR <= psnr_total_avg:
                    saved_total_PSNR = psnr_total_avg
                    model.save('bestPSNR')
                    logger.info(
                        "Best weights updated: validation PSNR increased")

                else:
                    logger.info(
                        "Weights not updated: validation PSNR did not increase")

        ############################################
        #
        #          end of one epoch, schedule LR
        #
        ############################################

        # step the plateau scheduler on the averaged validation loss
        if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            for scheduler in model.schedulers:
                scheduler.step(val_loss_total_avg)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('last')
        logger.info('End of training.')
        tb_logger.close()
Example #15
0
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    parser.add_argument('-single_GPU',
                        action='store_true',
                        help='Utilize only one GPU')
    if parser.parse_args().single_GPU:
        available_GPUs = util.Assign_GPU()
    else:
        available_GPUs = util.Assign_GPU(max_GPUs=None)
    opt = option.parse(parser.parse_args().opt,
                       is_train=True,
                       batch_size_multiplier=len(available_GPUs))

    if not opt['train']['resume']:
        util.mkdir_and_rename(
            opt['path']
            ['experiments_root'])  # Modify experiment name if exists
        util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and \
            not key == 'pretrained_model_G' and not key == 'pretrained_model_D'))
    option.save(opt)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which return None for missing key.
    # print to file and std_out simultaneously
    sys.stdout = PrintLogger(opt['path']['log'])
    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            max_accumulation_steps = max([
                opt['train']['grad_accumulation_steps_G'],
                opt['train']['grad_accumulation_steps_D']
            ])
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'] *
                              max_accumulation_steps)  #-current_step
            total_epoches = int(math.ceil(total_iters / train_size))
            print('Total epoches needed: {:d} for iters {:,d}'.format(
                total_epoches, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_dataset_opt = dataset_opt
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None
    # Create model
    if max_accumulation_steps != 1:
        model = create_model(opt, max_accumulation_steps)
    else:
        model = create_model(opt)
    # create logger
    logger = Logger(opt)
    # Save validation set results as image collage:
    SAVE_IMAGE_COLLAGE = True
    per_image_saved_patch = min(
        [min(im['HR'].shape[1:]) for im in val_loader.dataset]) - 2
    num_val_images = len(val_loader.dataset)
    val_images_collage_rows = int(np.floor(np.sqrt(num_val_images)))
    while val_images_collage_rows > 1:
        if np.round(num_val_images / val_images_collage_rows
                    ) == num_val_images / val_images_collage_rows:
            break
        val_images_collage_rows -= 1
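    # The loop above picks the largest divisor of num_val_images that is at
    # most sqrt(num_val_images), so the validation collage fills a complete
    # rows x cols grid.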
    start_time = time.time()
    min_accumulation_steps = min([
        opt['train']['grad_accumulation_steps_G'],
        opt['train']['grad_accumulation_steps_D']
    ])
    save_GT_HR = True
    lr_too_low = False
    print('---------- Start training -------------')
    last_saving_time = time.time()
    recently_saved_models = deque(maxlen=4)
    for epoch in range(int(math.floor(model.step / train_size)),
                       total_epoches):
        for i, train_data in enumerate(train_loader):
            gradient_step_num = model.step // max_accumulation_steps
            not_within_batch = model.step % max_accumulation_steps == (
                max_accumulation_steps - 1)
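            # `model.step` counts micro-batches; an optimizer step completes
            # every `max_accumulation_steps` micro-batches, and
            # `not_within_batch` is True only on the closing micro-batch of
            # an accumulation window.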
            saving_step = (
                (time.time() - last_saving_time) > 60 *
                opt['logger']['save_checkpoint_freq']) and not_within_batch
            if saving_step:
                last_saving_time = time.time()

            # save models
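            # Rotate checkpoints: the deque (maxlen=4) remembers recent save
            # paths, and once more than three are held the oldest generator
            # file (and its discriminator twin, if present) is deleted.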
            if lr_too_low or saving_step:
                recently_saved_models.append(model.save(gradient_step_num))
                model.save_log()
                if len(recently_saved_models) > 3:
                    model_2_delete = recently_saved_models.popleft()
                    os.remove(model_2_delete)
                    if model.D_exists:
                        os.remove(model_2_delete.replace('_G.', '_D.'))
                print('{}: Saving the model before iter {:d}.'.format(
                    datetime.now().strftime('%H:%M:%S'), gradient_step_num))
                if lr_too_low:
                    break

            if model.step > total_iters:
                break

            # training
            model.feed_data(train_data)
            model.optimize_parameters()
            if not model.D_exists:  #Avoid using the naive MultiLR scheduler when using adversarial loss
                for scheduler in model.schedulers:
                    scheduler.step(model.gradient_step_num)
            time_elapsed = time.time() - start_time
            if not_within_batch: start_time = time.time()

            # log
            if gradient_step_num % opt['logger'][
                    'print_freq'] == 0 and not_within_batch:
                logs = model.get_current_log()
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = gradient_step_num
                print_rlt['time'] = time_elapsed
                for k, v in logs.items():
                    print_rlt[k] = v
                print_rlt['lr'] = model.get_current_learning_rate()
                logger.print_format_results('train',
                                            print_rlt,
                                            keys_ignore_list=IGNORED_KEYS_LIST)
                model.display_log_figure()

            # validation
            if not_within_batch and (gradient_step_num) % opt['train'][
                    'val_freq'] == 0:  # and gradient_step_num>=opt['train']['D_init_iters']:
                print_rlt = OrderedDict()
                if model.generator_changed:
                    print('---------- validation -------------')
                    start_time = time.time()
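                    # The `if False and ...` guard below permanently disables
                    # the training-batch collage dump; drop `False and` to
                    # re-enable it.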
                    if False and SAVE_IMAGE_COLLAGE and model.gradient_step_num % opt[
                            'train'][
                                'val_save_freq'] == 0:  #Saving training images:
                        GT_image_collage = []
                        cur_train_results = model.get_current_visuals(
                            entire_batch=True)
                        train_psnrs = [
                            util.calculate_psnr(
                                util.tensor2img(
                                    cur_train_results['SR'][im_num],
                                    out_type=np.float32) * 255,
                                util.tensor2img(
                                    cur_train_results['HR'][im_num],
                                    out_type=np.float32) * 255)
                            for im_num in range(len(cur_train_results['SR']))
                        ]
                        #Save latest training batch output:
                        save_img_path = os.path.join(
                            os.path.join(opt['path']['val_images']),
                            '{:d}_Tr_PSNR{:.3f}.png'.format(
                                gradient_step_num, np.mean(train_psnrs)))
                        util.save_img(
                            np.clip(
                                np.concatenate(
                                    (np.concatenate([
                                        util.tensor2img(
                                            cur_train_results['HR'][im_num],
                                            out_type=np.float32) * 255
                                        for im_num in range(
                                            len(cur_train_results['SR']))
                                    ], 0),
                                     np.concatenate([
                                         util.tensor2img(
                                             cur_train_results['SR'][im_num],
                                             out_type=np.float32) * 255
                                         for im_num in range(
                                             len(cur_train_results['SR']))
                                     ], 0)), 1), 0, 255).astype(np.uint8),
                            save_img_path)
                    Z_latent = [0] + ([-1, 1] if
                                      opt['network_G']['latent_input'] else [])
                    print_rlt['psnr'] = 0
                    for cur_Z in Z_latent:
                        sr_images = model.perform_validation(
                            data_loader=val_loader,
                            cur_Z=cur_Z,
                            print_rlt=print_rlt,
                            save_GT_HR=save_GT_HR,
                            save_images=((model.gradient_step_num) %
                                         opt['train']['val_save_freq'] == 0)
                            or save_GT_HR)
                        if logger.use_tb_logger:
                            logger.tb_logger.log_images(
                                'validation_Z%.2f' % (cur_Z),
                                [im[:, :, [2, 1, 0]] for im in sr_images],
                                model.gradient_step_num)

                        if save_GT_HR:  # Save GT Uncomp images
                            save_GT_HR = False
                    model.log_dict['psnr_val'].append(
                        (gradient_step_num, print_rlt['psnr'] / len(Z_latent)))
                else:
                    print('Skipping validation because generator is unchanged')
                time_elapsed = time.time() - start_time
                # Save to log
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = gradient_step_num
                print_rlt['time'] = time_elapsed
                model.display_log_figure()
                logger.print_format_results('val',
                                            print_rlt,
                                            keys_ignore_list=IGNORED_KEYS_LIST)
                print('-----------------------------------')

            # update learning rate
            if not_within_batch:
                lr_too_low = model.update_learning_rate(gradient_step_num)
        if lr_too_low:
            print('Stopping training because LR is too low')
            break

    print('Saving the final model.')
    model.save(gradient_step_num)
    print('End of training.')
Example #16
0
            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            target = test_set1[0]['HR']
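            # The loop below chains two generators across devices: the
            # stage-1 output is detached and re-wrapped on device1 with
            # requires_grad=True, the distance loss (LPIPS-style, given the
            # normalize=True call) is backpropagated through stage 2, and the
            # resulting input gradient is pushed back into the stage-1 graph
            # via cur_out.backward(copied_cur_out.grad), so both optimizers
            # are driven by the single loss.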

            for i in range(5000):
                cur_out = model.netG(data['LR'], code_val_0, code_val_1,
                                     code_val_2)[-1]
                copied_cur_out = Variable(cur_out.detach().to(device1),
                                          requires_grad=True)
                output = model1.netG(copied_cur_out, code_val_3, None,
                                     None)[-1]
                dist = model1.loss_fn.forward(output, target, normalize=True)
                optimizer.zero_grad()
                optimizer1.zero_grad()
                dist.backward()
                cur_out.backward(copied_cur_out.grad)
                optimizer1.step()
                optimizer.step()
                if i % 10 == 0:
                    print('iter %d, dist %.3g' %
                          (i, dist.view(-1).data.cpu().numpy()[0]))
                if i % 100 == 0:
                    save_img_path = os.path.join(dataset_dir,
                                                 img_name + '_%d.png' % i)
                    sr_img = util.tensor2img(
                        output.detach()[0].float().cpu())  # uint8
                    print("saving: %s" % save_img_path)
                    util.save_img(sr_img, save_img_path)
Example #17
0
def main():
    #### setup options of three networks
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt_P',
                        type=str,
                        help='Path to option YAML file of Predictor.')
    parser.add_argument('-opt_C',
                        type=str,
                        help='Path to option YAML file of Corrector.')
    parser.add_argument('-opt_F',
                        type=str,
                        help='Path to option YAML file of SFTMD_Net.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt_P = option.parse(args.opt_P, is_train=True)
    opt_C = option.parse(args.opt_C, is_train=True)
    opt_F = option.parse(args.opt_F, is_train=True)

    # convert to NoneDict, which returns None for missing keys
    opt_P = option.dict_to_nonedict(opt_P)
    opt_C = option.dict_to_nonedict(opt_C)
    opt_F = option.dict_to_nonedict(opt_F)

    # choose small opt for SFTMD test, fill path of pre-trained model_F
    opt_F = opt_F['sftmd']

    #### set random seed
    seed = opt_P['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    util.set_random_seed(seed)

    # load the PCA projection matrix for blur kernels
    print('load PCA matrix')
    pca_matrix = torch.load('./pca_matrix.pth',
                            map_location=lambda storage, loc: storage)
    print('PCA matrix shape: {}'.format(pca_matrix.shape))
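    # The PCA matrix appears to project flattened blur kernels down to a
    # short code (length opt_P['code_length']); the Predictor regresses this
    # code from the LR image and the Corrector then refines it.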

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt_P['dist'] = False
        opt_F['dist'] = False
        opt_C['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt_P['dist'] = True
        opt_F['dist'] = True
        opt_C['dist'] = True
        init_dist()
        # number of processes in the current process group
        world_size = torch.distributed.get_world_size()
        # rank of the current process within that group
        rank = torch.distributed.get_rank()

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    ###### Predictor&Corrector train ######

    #### loading resume state if exists
    if opt_P['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt_P['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt_P,
                            resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0-7)
        if resume_state is None:
            # Predictor path
            util.mkdir_and_rename(
                opt_P['path']
                ['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt_P['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))
            # Corrector path
            util.mkdir_and_rename(
                opt_C['path']
                ['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt_C['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt_P['path']['log'],
                          'train_' + opt_P['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        util.setup_logger('val',
                          opt_P['path']['log'],
                          'val_' + opt_P['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt_P))
        logger.info(option.dict2str(opt_C))
        # tensorboard logger
        if opt_P['use_tb_logger'] and 'debug' not in opt_P['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt_P['name'])
    else:
        util.setup_logger('base',
                          opt_P['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt_P['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt_P['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt_P['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt_P,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt_P, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None
    assert val_loader is not None

    #### create model
    model_F = create_model(opt_F)  #load pretrained model of SFTMD
    model_P = create_model(opt_P)
    model_C = create_model(opt_C)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model_P.resume_training(
            resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt_P['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate, schedulers
            # model.update_learning_rate(current_step, warmup_iter=opt_P['train']['warmup_iter'])

            #### preprocessing for LR_img and kernel map
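            # SRMDPreprocessing appears to synthesise the training pair on
            # the fly: each GT batch is blurred with a randomly sampled
            # isotropic Gaussian kernel (sigma in [sig_min, sig_max]),
            # downscaled by `scale`, and returned together with the
            # PCA-reduced kernel code that serves as the regression target.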
            prepro = util.SRMDPreprocessing(opt_P['scale'],
                                            pca_matrix,
                                            random=True,
                                            para_input=opt_P['code_length'],
                                            kernel=opt_P['kernel_size'],
                                            noise=False,
                                            cuda=True,
                                            sig=opt_P['sig'],
                                            sig_min=opt_P['sig_min'],
                                            sig_max=opt_P['sig_max'],
                                            rate_iso=1.0,
                                            scaling=3,
                                            rate_cln=0.2,
                                            noise_high=0.0)
            LR_img, ker_map = prepro(train_data['GT'])

            #### training Predictor
            model_P.feed_data(LR_img, ker_map)
            model_P.optimize_parameters(current_step)
            P_visuals = model_P.get_current_visuals()
            est_ker_map = P_visuals['Batch_est_ker_map']

            #### log of model_P
            if current_step % opt_P['logger']['print_freq'] == 0:
                logs = model_P.get_current_log()
                message = 'Predictor <epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model_P.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt_P['use_tb_logger'] and 'debug' not in opt_P['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            #### training Corrector
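            # Alternating refinement: SFTMD (model_F) renders an SR image
            # from the current kernel-code estimate, then the Corrector
            # (model_C) updates that estimate from the SR result; this runs
            # opt_C['step'] times per batch.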
            for step in range(opt_C['step']):
                # test SFTMD for corresponding SR image
                model_F.feed_data(train_data, LR_img, est_ker_map)
                model_F.test()
                F_visuals = model_F.get_current_visuals()
                SR_img = F_visuals['Batch_SR']
                # Test SFTMD to produce SR images

                # train corrector given SR image and estimated kernel map
                model_C.feed_data(SR_img, est_ker_map, ker_map)
                model_C.optimize_parameters(current_step)
                C_visuals = model_C.get_current_visuals()
                est_ker_map = C_visuals['Batch_est_ker_map']

                #### log of model_C
                if current_step % opt_C['logger']['print_freq'] == 0:
                    logs = model_C.get_current_log()
                    message = 'Corrector <epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                        epoch, current_step,
                        model_C.get_current_learning_rate())
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                        # tensorboard logger
                        if opt_C['use_tb_logger'] and 'debug' not in opt_C[
                                'name']:
                            if rank <= 0:
                                tb_logger.add_scalar(k, v, current_step)
                    if rank <= 0:
                        logger.info(message)

            # validation, to produce ker_map_list(fake)
            if current_step % opt_P['train']['val_freq'] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0
                for _, val_data in enumerate(val_loader):
                    prepro = util.SRMDPreprocessing(
                        opt_P['scale'],
                        pca_matrix,
                        random=True,
                        para_input=opt_P['code_length'],
                        kernel=opt_P['kernel_size'],
                        noise=False,
                        cuda=True,
                        sig=opt_P['sig'],
                        sig_min=opt_P['sig_min'],
                        sig_max=opt_P['sig_max'],
                        rate_iso=1.0,
                        scaling=3,
                        rate_cln=0.2,
                        noise_high=0.0)
                    LR_img, ker_map = prepro(val_data['GT'])
                    single_img_psnr = 0.0
                    lr_img = util.tensor2img(
                        LR_img)  #save LR image for reference

                    # valid Predictor
                    model_P.feed_data(LR_img, ker_map)
                    model_P.test()
                    P_visuals = model_P.get_current_visuals()
                    est_ker_map = P_visuals['Batch_est_ker_map']

                    # Save images for reference
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt_P['path']['val_images'],
                                           img_name)
                    # img_dir = os.path.join(opt_F['path']['val_images'], str(current_step), '_', str(step))
                    util.mkdir(img_dir)
                    save_lr_path = os.path.join(img_dir,
                                                '{:s}_LR.png'.format(img_name))
                    util.save_img(lr_img, save_lr_path)

                    for step in range(opt_C['step']):
                        step += 1
                        idx += 1
                        model_F.feed_data(val_data, LR_img, est_ker_map)
                        model_F.test()
                        F_visuals = model_F.get_current_visuals()
                        SR_img = F_visuals['Batch_SR']
                        # Test SFTMD to produce SR images

                        model_C.feed_data(SR_img, est_ker_map, ker_map)
                        model_C.test()
                        C_visuals = model_C.get_current_visuals()
                        est_ker_map = C_visuals['Batch_est_ker_map']

                        sr_img = util.tensor2img(F_visuals['SR'])  # uint8
                        gt_img = util.tensor2img(F_visuals['GT'])  # uint8

                        save_img_path = os.path.join(
                            img_dir, '{:s}_{:d}_{:d}.png'.format(
                                img_name, current_step, step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        crop_size = opt_P['scale']
                        gt_img = gt_img / 255.
                        sr_img = sr_img / 255.
                        cropped_sr_img = sr_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                        step_psnr = util.calculate_psnr(
                            cropped_sr_img * 255, cropped_gt_img * 255)
                        logger.info(
                            '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> img:{:s}, psnr: {:.6f}'
                            .format(epoch, current_step, step, img_name,
                                    step_psnr))
                        single_img_psnr += step_psnr
                        avg_psnr += step_psnr  # reuse instead of recomputing

                    avg_single_img_psnr = single_img_psnr / step
                    logger.info(
                        '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> img:{:s}, average psnr: {:.6f}'
                        .format(epoch, current_step, step, img_name,
                                avg_single_img_psnr))

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.6f}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> psnr: {:.6f}'.
                    format(epoch, current_step, step, avg_psnr))
                # tensorboard logger
                if opt_P['use_tb_logger'] and 'debug' not in opt_P['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt_P['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model_P.save(current_step)
                    model_P.save_training_state(epoch, current_step)
                    model_C.save(current_step)
                    model_C.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model_P.save('latest')
        model_C.save('latest')
        logger.info('End of Predictor and Corrector training.')
    tb_logger.close()
Example #18
0
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        default='options/train/train_ESRCNN_S2L8_2.json',
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(opt)

    if opt['path']['resume_state']:
        resume_state = torch.load(opt['path']['resume_state'])
    else:
        resume_state = None
        util.mkdir_and_rename(opt['path']['experiments_root'])
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)

    logger.info(option.dict2str(opt))

    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='./tb_logger/' + opt['name'])

    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True

    # Setup TrainDataLoader
    trainloader = DataLoader(opt['datasets']['train']['dataroot'],
                             split='train')
    train_size = int(
        math.ceil(len(trainloader) / opt['datasets']['train']['batch_size']))
    logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
        len(trainloader), train_size))
    total_iters = int(opt['train']['niter'])
    total_epochs = int(math.ceil(total_iters / train_size))
    logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
        total_epochs, total_iters))
    TrainDataLoader = data.DataLoader(
        trainloader,
        batch_size=opt['datasets']['train']['batch_size'],
        num_workers=12,
        shuffle=True)
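    # NOTE: despite its name, `trainloader` above is a dataset object (a
    # project-local DataLoader class); the loader that actually batches it
    # is `TrainDataLoader`, built with `data.DataLoader`.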
    # Setup for validation
    valloader = DataLoader(opt['datasets']['train']['dataroot'], split='val')
    VALDataLoader = data.DataLoader(
        valloader,
        batch_size=opt['datasets']['train']['batch_size'] // 5,
        num_workers=1,
        shuffle=True)
    logger.info('Number of val images:{:d}'.format(len(valloader)))

    # Setup Model
    model = get_model('esrcnn_s2l8_2', opt)

    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)
    else:
        current_step = 0
        start_epoch = 0

    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for i, train_data in enumerate(TrainDataLoader):

            current_step += 1
            if current_step > total_iters:
                break

            model.update_learning_rate()
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}>'.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v[0])
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar(k, v[0], current_step)
                logger.info(message)

            if current_step % opt['train']['val_freq'] == 0:
                avg_psnr = 0.0
                idx = 0
                for i_val, val_data in enumerate(VALDataLoader):
                    idx += 1
                    img_name = val_data[3][0].split('.')[0]
                    model.feed_data(val_data)
                    model.val()

                    visuals = model.get_current_visuals()
                    pred_img = util.tensor2img(visuals['Pred'])
                    gt_img = util.tensor2img(visuals['label'])
                    avg_psnr += util.calculate_psnr(pred_img, gt_img)

                avg_psnr = avg_psnr / idx

                logger.info('# Validation #PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr:{:.4e}'.format(
                        epoch, current_step, avg_psnr))

                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)

            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training')
Example #19
0
def main():
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    elif data_mode == 'sharp_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
    elif data_mode == 'blur_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
    elif data_mode == 'blur':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
    elif data_mode == 'blur_comp':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    else:
        raise NotImplementedError

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 40
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20
    model = EDVR_arch.EDVR(128,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/test'
        GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    print('test_dataset_folder:', test_dataset_folder)
    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    print('list:', subfolder_l)
    subfolder_l = ['../datasets/test/dance_small']
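    # NOTE: the assignment above overrides the globbed subfolder list with a
    # single hard-coded clip; remove it to evaluate every subfolder.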
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder in subfolder_l:
        subfolder_name = osp.basename(subfolder)
        print(subfolder_name)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        # img_GT_l = []
        # for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
        #     img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            print('path:', img_path)
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
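            # index_generation picks the N_in frame indices centred on
            # img_idx; at sequence borders the padding mode controls how
            # missing neighbours are filled ('new_info' pulls in additional
            # nearby frames, 'replicate' repeats the edge frame).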
            print(select_idx)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                print('im_path:',
                      osp.join(save_subfolder, '{}.png'.format(img_name)))
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # # calculate PSNR
            # output = output / 255.
            # GT = np.copy(img_GT_l[img_idx])
            # # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            # if data_mode == 'Vid4':  # bgr2y, [0, 1]
            #     GT = data_util.bgr2ycbcr(GT, only_y=True)
            #     output = data_util.bgr2ycbcr(output, only_y=True)
            #
            # output, GT = util.crop_border([output, GT], crop_border)
            # crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            # logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))
            #
            # if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
            #     avg_psnr_center += crt_psnr
            #     N_center += 1
            # else:  # border frames
            #     avg_psnr_border += crt_psnr
            #     N_border += 1

        # guard against zero counts (the PSNR accumulation above is commented
        # out, so N_center and N_border may both be 0)
        total_frames = N_center + N_border
        avg_psnr = 0 if total_frames == 0 else (
            avg_psnr_center + avg_psnr_border) / total_frames
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
Example #20
0
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)

    util.mkdir_and_rename(
        opt['path']['experiments_root'])  # rename old experiments if exists
    util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and \
        not key == 'pretrain_model_G' and not key == 'pretrain_model_D'))
    option.save(opt)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which return None for missing key.

    # print to file and std_out simultaneously
    sys.stdout = PrintLogger(opt['path']['log'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epoches = int(math.ceil(total_iters / train_size))
            print('Total epoches needed: {:d} for iters {:,d}'.format(
                total_epoches, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_dataset_opt = dataset_opt
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # Create model
    model = create_model(opt)
    # create logger-
    logger = Logger(opt)

    current_step = 0
    start_time = time.time()
    print('---------- Start training -------------')
    for epoch in range(total_epoches):
        for i, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            time_elapsed = time.time() - start_time
            start_time = time.time()

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = time_elapsed
                for k, v in logs.items():
                    print_rlt[k] = v
                print_rlt['lr'] = model.get_current_learning_rate()
                logger.print_format_results('train', print_rlt)

            # save models
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                print('Saving the model at the end of iter {:d}.'.format(
                    current_step))
                model.save(current_step)

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                print('---------- validation -------------')
                start_time = time.time()

                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(\
                        img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    crop_size = opt['scale']
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    avg_psnr += util.psnr(cropped_sr_img, cropped_gt_img)
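                    # Border-cropped PSNR: for 8-bit images PSNR is
                    # 20 * log10(255 / sqrt(MSE)), and trimming `scale`
                    # pixels per side keeps upscaling boundary artifacts out
                    # of the metric.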

                avg_psnr = avg_psnr / idx
                time_elapsed = time.time() - start_time
                # Save to log
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = time_elapsed
                print_rlt['psnr'] = avg_psnr
                logger.print_format_results('val', print_rlt)
                print('-----------------------------------')

            # update learning rate
            model.update_learning_rate()

    print('Saving the final model.')
    model.save('latest')
    print('End of training.')
Example #21
0
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options YMAL file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)

    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))
    util.setup_logger('base',
                      opt['path']['log'],
                      'test_' + opt['name'],
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    #### Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    model = create_model(opt)
    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_GT = False if test_loader.dataset.opt[
                'dataroot_GT'] is None else True
            model.feed_data(data, need_GT=need_GT)
            img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
            img_name = osp.splitext(osp.basename(img_path))[0]

            model.test()
            visuals = model.get_current_visuals(need_GT=need_GT)

            sr_img = util.tensor2img(visuals['rlt'])  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = osp.join(dataset_dir,
                                         img_name + suffix + '.png')
            else:
                save_img_path = osp.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            # calculate PSNR and SSIM
            if need_GT:
                gt_img = util.tensor2img(visuals['GT'])
                sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                  opt['scale'])
                psnr = util.calculate_psnr(sr_img, gt_img)
                ssim = util.calculate_ssim(sr_img, gt_img)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)
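                    # bgr2ycbcr works on the [0, 1]-scaled images, so the Y
                    # channels are rescaled by 255 before PSNR/SSIM to match
                    # the 8-bit convention used above.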

                    psnr_y = util.calculate_psnr(sr_img_y * 255,
                                                 gt_img_y * 255)
                    ssim_y = util.calculate_ssim(sr_img_y * 255,
                                                 gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                            img_name, psnr, ssim))
            else:
                logger.info(img_name)

        if need_GT:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info(
                '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'
                .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info(
                    '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'
                    .format(ave_psnr_y, ave_ssim_y))
Example #22
0
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []

    for data in test_loader:
        need_GT = True
        model.feed_data(data, need_GT=need_GT)
        img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
        img_name = osp.splitext(osp.basename(img_path))[0]
        model.test()
        visuals = model.get_current_visuals(need_GT=need_GT)

        if which_model == 'RCAN':
            sr_img = util.tensor2img(visuals['rlt'],
                                     out_type=np.uint8,
                                     min_max=(0, 255))  # uint8
        else:
            sr_img = util.tensor2img(visuals['rlt'])  # uint8

        # save images
        suffix = opt['suffix']
        if suffix:
            save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
        else:
            save_img_path = osp.join(dataset_dir, img_name + '.png')
        util.save_img(sr_img, save_img_path)

        if need_GT:
            if which_model == 'RCAN':
                gt_img = util.tensor2img(visuals['GT'],
                                         out_type=np.uint8,
                                         min_max=(0, 255))  # uint8
            else:
                gt_img = util.tensor2img(visuals['GT'])  # uint8
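The `min_max` argument matters here because RCAN produces outputs on a [0, 255] scale, while the other models work in [0, 1]. A minimal sketch of what a `tensor2img`-style conversion does for a single CHW tensor (assumed behavior, modeled on the common BasicSR utility; the real helper also handles 4-D batches and grayscale):

import numpy as np
import torch

def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''Clamp a CHW RGB tensor to [min_max], rescale to [0, 1],
    and return an HWC BGR image (uint8 by default).'''
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
    img = tensor.numpy().transpose(1, 2, 0)[:, :, [2, 1, 0]]  # CHW RGB -> HWC BGR
    if out_type == np.uint8:
        img = (img * 255.0).round()
    return img.astype(out_type)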
Example #23
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
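        # typically launched with torch.distributed.launch, which supplies
        # the --local_rank argument parsed above, e.g. (illustrative command):
        #   python -m torch.distributed.launch --nproc_per_node=8 \
        #       train.py -opt options/train.yml --launcher pytorch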

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                         and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 1  # set to e.g. 200 to enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
                total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate
            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])

            #### training
            t0 = time.time()
            model.feed_data(train_data)
            model.optimize_parameters(current_step)
            t1 = time.time()

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8d}, speed:{:5.1f}, lr:('.format(
                    epoch, current_step,
                    opt['datasets']['train']['batch_size'] / (t1 - t0))
                for v in model.get_current_learning_rate():
                    message += '{:.5f},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4f} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)
                    logger.info('Saving models and training states.')

            #### validation
            if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
                if opt['model'] in ['sr', 'srgan'] and rank <= 0:  # image restoration validation
                    # does not support multi-GPU validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # Save SR images for reference
                        save_img_path = os.path.join(img_dir,
                                                     '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        pbar.update('Test {}'.format(img_name))

                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info('# Validation # PSNR: {:.4f}'.format(avg_psnr))
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)
                else:  # video restoration validation
                    if opt['dist']:
                        # multi-GPU testing
                        psnr_rlt = {}  # with border and center frames
                        if rank == 0:
                            pbar = util.ProgressBar(len(val_set))
                        for idx in range(rank, len(val_set), world_size):
                            val_data = val_set[idx]
                            val_data['LQs'].unsqueeze_(0)
                            val_data['GT'].unsqueeze_(0)
                            folder = val_data['folder']
                            idx_d, max_idx = val_data['idx'].split('/')
                            idx_d, max_idx = int(idx_d), int(max_idx)
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32,
                                                               device='cuda')
                            # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8
                            # calculate PSNR
                            psnr_rlt[folder][idx_d] = util.calculate_psnr(rlt_img, gt_img)

                            if rank == 0:
                                for _ in range(world_size):
                                    pbar.update('Test {} - {}/{}'.format(folder, idx_d, max_idx))
                        # collect the per-rank partial results into rank 0
                        for _, v in psnr_rlt.items():
                            dist.reduce(v, 0)
                        dist.barrier()

                        if rank == 0:
                            psnr_rlt_avg = {}
                            psnr_total_avg = 0.
                            for k, v in psnr_rlt.items():
                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                                psnr_total_avg += psnr_rlt_avg[k]
                            psnr_total_avg /= len(psnr_rlt)
                            log_s = '# Validation # PSNR: {:.4f}:'.format(psnr_total_avg)
                            for k, v in psnr_rlt_avg.items():
                                log_s += ' {}: {:.4f}'.format(k, v)
                            logger.info(log_s)
                            if opt['use_tb_logger'] and 'debug' not in opt['name']:
                                tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
                                for k, v in psnr_rlt_avg.items():
                                    tb_logger.add_scalar(k, v, current_step)
                    else:
                        pbar = util.ProgressBar(len(val_loader))
                        psnr_rlt = {}  # with border and center frames
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for val_data in val_loader:
                            folder = val_data['folder'][0]
                            idx_d = val_data['idx'].item()
                            # border = val_data['border'].item()
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = []

                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8

                            # calculate PSNR
                            psnr = util.calculate_psnr(rlt_img, gt_img)
                            psnr_rlt[folder].append(psnr)
                            pbar.update('Test {} - {}'.format(folder, idx_d))
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = sum(v) / len(v)
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4f}:'.format(psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4f}'.format(k, v)
                        logger.info(log_s)
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            tb_logger.close()
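One detail worth spelling out in the dataloader setup above is the epoch/iteration bookkeeping: training is budgeted in iterations (`niter`), so the number of epochs is derived from the dataset size and batch size. A quick worked example with hypothetical numbers:

import math

train_set_len, batch_size, niter = 32000, 16, 600000  # hypothetical values
train_size = int(math.ceil(train_set_len / batch_size))  # iterations per epoch
total_epochs = int(math.ceil(niter / train_size))
print(train_size, total_epochs)  # 2000 300

With a DistIterSampler and dataset_ratio > 1, each "epoch" is dataset_ratio times longer, so the epoch count shrinks accordingly (the `total_iters / (train_size * dataset_ratio)` line above).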
Example #24
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.')
    parser.add_argument('-single_GPU', action='store_true', help='Utilize only one GPU')
    parser.add_argument('-chroma', action='store_true', help='Training the chroma-channels generator')
    args = parser.parse_args()  # parse once instead of re-parsing per flag
    if args.single_GPU:
        available_GPUs = util.Assign_GPU(maxMemory=0.66)
    else:
        # available_GPUs = util.Assign_GPU(max_GPUs=None, maxMemory=0.8, maxLoad=0.8)
        available_GPUs = util.Assign_GPU(max_GPUs=None)
    opt = option.parse(args.opt, is_train=True, batch_size_multiplier=len(available_GPUs),
                       name='JPEG' + ('_chroma' if args.chroma else ''))

    if not opt['train']['resume']:
        util.mkdir_and_rename(opt['path']['experiments_root'])  # Modify experiment name if exists
        util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and \
            not key == 'pretrained_model_G' and not key == 'pretrained_model_D'))
    option.save(opt)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which return None for missing key.

    # print to file and std_out simultaneously
    sys.stdout = PrintLogger(opt['path']['log'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            max_accumulation_steps = max([opt['train']['grad_accumulation_steps_G'], opt['train']['grad_accumulation_steps_D']])
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size))
            total_iters = int(opt['train']['niter'] * max_accumulation_steps)
            total_epochs = int(math.ceil(total_iters / train_size))
            print('Total epochs needed: {:d} for iters {:,d}'.format(total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_dataset_opt = dataset_opt
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None
    DEBUG = False
    # Create model
    if DEBUG:
        from models.base_model import BaseModel
        model = BaseModel
        model.step = 0
    else:
        model = create_model(opt,max_accumulation_steps,chroma_mode=opt['name'][:len('JPEG/chroma')]=='JPEG/chroma')

    # create logger
    logger = Logger(opt)
    # Save validation set results as image collage:
    SAVE_IMAGE_COLLAGE = True
    start_time,start_time_gradient_step = time.time(),model.step // max_accumulation_steps
    save_GT_Uncomp = True
    lr_too_low = False
    print('---------- Start training -------------')
    last_saving_time = time.time()
    recently_saved_models = deque(maxlen=4)
    for epoch in range(int(math.floor(model.step / train_size)), total_epochs):
        for i, train_data in enumerate(train_loader):
            model.gradient_step_num = model.step // max_accumulation_steps
            not_within_batch = model.step % max_accumulation_steps == (max_accumulation_steps - 1)
            saving_step = ((time.time()-last_saving_time)>60*opt['logger']['save_checkpoint_freq']) and not_within_batch
            if saving_step:
                last_saving_time = time.time()

            # save models
            if lr_too_low or saving_step:
                model.save_log()
                recently_saved_models.append(model.save(model.gradient_step_num))
                if len(recently_saved_models)>3:
                    model_2_delete = recently_saved_models.popleft()
                    os.remove(model_2_delete)
                    if model.D_exists:
                        os.remove(model_2_delete.replace('_G.','_D.'))
                print('{}: Saving the model before iter {:d}.'.format(datetime.now().strftime('%H:%M:%S'),model.gradient_step_num))
                if lr_too_low:
                    break

            if model.step > total_iters:
                break

            # time_elapsed = time.time() - start_time
            # if not_within_batch:    start_time = time.time()
            # log
            if model.gradient_step_num % opt['logger']['print_freq'] == 0 and not_within_batch:
                logs = model.get_current_log()
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = model.gradient_step_num
                # time_elapsed = time.time() - start_time
                print_rlt['time'] = (time.time() - start_time)/np.maximum(1,model.gradient_step_num-start_time_gradient_step)
                start_time, start_time_gradient_step = time.time(), model.gradient_step_num
                for k, v in logs.items():
                    print_rlt[k] = v
                print_rlt['lr'] = model.get_current_learning_rate()
                logger.print_format_results('train', print_rlt,keys_ignore_list=['avg_est_err'])
                model.display_log_figure()

            # validation
            if (not_within_batch or i==0) and (model.gradient_step_num) % opt['train']['val_freq'] == 0: # and model.gradient_step_num>=opt['train']['D_init_iters']:
                print_rlt = OrderedDict()
                if model.generator_changed:
                    print('---------- validation -------------')
                    start_time = time.time()
                    if False and SAVE_IMAGE_COLLAGE and model.gradient_step_num % opt['train']['val_save_freq'] == 0:  # disabled: save a collage of the latest training batch
                        # GT_image_collage,quantized_image_collage = [],[]
                        cur_train_results = model.get_current_visuals(entire_batch=True)
                        train_psnrs = [util.calculate_psnr(util.tensor2img(cur_train_results['Decomp'][im_num], out_type=np.uint8,min_max=[0,255]),
                            util.tensor2img(cur_train_results['Uncomp'][im_num], out_type=np.uint8,min_max=[0,255])) for im_num in range(len(cur_train_results['Decomp']))]
                        #Save latest training batch output:
                        save_img_path = os.path.join(os.path.join(opt['path']['val_images']),
                                                     '{:d}_Tr_PSNR{:.3f}.png'.format(model.gradient_step_num, np.mean(train_psnrs)))
                        util.save_img(np.clip(np.concatenate((np.concatenate([util.tensor2img(cur_train_results['Uncomp'][im_num], out_type=np.uint8,min_max=[0,255]) for im_num in
                                 range(len(cur_train_results['Decomp']))],0), np.concatenate(
                                [util.tensor2img(cur_train_results['Decomp'][im_num], out_type=np.uint8,min_max=[0,255]) for im_num in range(len(cur_train_results['Decomp']))],
                                0)), 1), 0, 255).astype(np.uint8), save_img_path)
                    Z_latent = [0]+([-0.5,0.5] if opt['network_G']['latent_input'] else [])
                    print_rlt['psnr'] = 0
                    for cur_Z in Z_latent:
                        model.perform_validation(data_loader=val_loader,cur_Z=cur_Z,print_rlt=print_rlt,GT_and_quantized=save_GT_Uncomp,
                                                 save_images=((model.gradient_step_num) % opt['train']['val_save_freq'] == 0) or save_GT_Uncomp)
                    if save_GT_Uncomp:  # Save GT Uncomp images
                        save_GT_Uncomp = False
                    print_rlt['psnr'] /= len(Z_latent)
                    model.log_dict['psnr_val'].append((model.gradient_step_num,print_rlt['psnr']))
                else:
                    print('Skipping validation because generator is unchanged')
                # time_elapsed = time.time() - start_time
                # Save to log
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = model.gradient_step_num
                # print_rlt['time'] = time_elapsed
                print_rlt['time'] = (time.time() - start_time)/np.maximum(1,model.gradient_step_num-start_time_gradient_step)
                # model.display_log_figure()
                # model.generator_changed = False
                logger.print_format_results('val', print_rlt,keys_ignore_list=['avg_est_err'])
                print('-----------------------------------')

            model.feed_data(train_data,mixed_Y=True)
            model.optimize_parameters()


            # update learning rate
            if not_within_batch:
                lr_too_low = model.update_learning_rate(model.gradient_step_num)
            # current_step += 1
        if lr_too_low:
            print('Stopping training because LR is too low')
            break

    print('Saving the final model.')
    model.save(model.gradient_step_num)
    model.save_log()
    print('End of training.')
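The `max_accumulation_steps` bookkeeping above implements gradient accumulation: gradients from several mini-batches are summed before a single optimizer step, so `gradient_step_num` advances once per `max_accumulation_steps` data steps. A minimal self-contained sketch of the same pattern in plain PyTorch (illustrative model and data, not the repo's model class):

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
accum_steps = 4  # analogous to max_accumulation_steps

optimizer.zero_grad()
for step in range(16):
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    loss = torch.nn.functional.mse_loss(model(x), y) / accum_steps
    loss.backward()  # gradients accumulate across backward() calls
    if step % accum_steps == accum_steps - 1:  # same test as not_within_batch
        optimizer.step()  # one "gradient step" per accum_steps data steps
        optimizer.zero_grad()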
Example #25
    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []
    test_results['niqe'] = []
    test_results['niqe_gt'] = []

    for data in test_loader:
        need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

        model.feed_data(data, need_HR=need_HR)
        img_path = data['LR_path'][0]
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        model.test()  # test
        visuals = model.get_current_visuals(need_HR=need_HR)

        sr_img = util.tensor2img(visuals['SR'])  # uint8

        # save images
        suffix = opt['suffix']
        if suffix:
            save_img_path = os.path.join(dataset_dir,
                                         img_name + suffix + '.png')
        else:
            save_img_path = os.path.join(dataset_dir, img_name + '.png')
        util.save_img(sr_img, save_img_path)
        # print(save_img_path)
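For reference, `util.calculate_psnr` as used with these `test_results` lists boils down to mean squared error in double precision on the [0, 255] scale (a minimal sketch of the standard definition; the SSIM helper is more involved):

import numpy as np

def calculate_psnr(img1, img2):
    '''PSNR in dB between two uint8 images on the [0, 255] scale.'''
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(255.0 / np.sqrt(mse))

The per-image values are appended to the lists and averaged at the end, as in the earlier examples.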
Example #26
def main():
    #################
    # configurations
    #################
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic (REDS)

    # Possible combinations: (2, 16), (3, 16), (4, 16), (4, 28), (4, 52)
    scale = 4
    layer = 52
    assert (scale, layer) in [(2, 16), (3, 16), (4, 16), (4, 28), (4, 52)], \
        'Unrecognized (scale, layer) combination'

    # model
    N_in = 7
    model_path = '../experiments/pretrained_models/DUF_x{}_{}L_official.pth'.format(
        scale, layer)
    adapt_official = 'official' in model_path
    DUF_downsampling = True  # True | False
    if layer == 16:
        model = DUF_arch.DUF_16L(scale=scale, adapt_official=adapt_official)
    elif layer == 28:
        model = DUF_arch.DUF_28L(scale=scale, adapt_official=adapt_official)
    elif layer == 52:
        model = DUF_arch.DUF_52L(scale=scale, adapt_official=adapt_official)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4/*'
    else:  # sharp_bicubic (REDS)
        test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode)

    #### evaluation
    crop_border = 8
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    padding = 'new_info'  # different from the official testing codes, which pads zeros.
    save_imgs = True
    ############################################################################
    device = torch.device('cuda')
    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))

    def read_image(img_path):
        '''read one image from img_path
        Return img: HWC, BGR, [0,1], numpy
        '''
        img_GT = cv2.imread(img_path)
        img = img_GT.astype(np.float32) / 255.
        return img

    def read_seq_imgs(img_seq_path):
        '''read a sequence of images'''
        img_path_l = sorted(glob.glob(img_seq_path + '/*'))
        img_l = [read_image(v) for v in img_path_l]
        # stack to TCHW, RGB, [0,1], torch
        imgs = np.stack(img_l, axis=0)
        imgs = imgs[:, :, :, [2, 1, 0]]
        imgs = torch.from_numpy(
            np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
        return imgs

    def index_generation(crt_i, max_n, N, padding='reflection'):
        '''
        padding: replicate | reflection | new_info | circle
        '''
        max_n = max_n - 1
        n_pad = N // 2
        return_l = []

        for i in range(crt_i - n_pad, crt_i + n_pad + 1):
            if i < 0:
                if padding == 'replicate':
                    add_idx = 0
                elif padding == 'reflection':
                    add_idx = -i
                elif padding == 'new_info':
                    add_idx = (crt_i + n_pad) + (-i)
                elif padding == 'circle':
                    add_idx = N + i
                else:
                    raise ValueError('Wrong padding mode')
            elif i > max_n:
                if padding == 'replicate':
                    add_idx = max_n
                elif padding == 'reflection':
                    add_idx = max_n * 2 - i
                elif padding == 'new_info':
                    add_idx = (crt_i - n_pad) - (i - max_n)
                elif padding == 'circle':
                    add_idx = i - N
                else:
                    raise ValueError('Wrong padding mode')
            else:
                add_idx = i
            return_l.append(add_idx)
        return return_l
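    # Worked example (assumed values, traced through the logic above) for a
    # 7-frame window (N=7) at the first frame of a 100-frame clip
    # (crt_i=0, max_n=100):
    #   replicate : [0, 0, 0, 0, 1, 2, 3]
    #   reflection: [3, 2, 1, 0, 1, 2, 3]
    #   new_info  : [6, 5, 4, 0, 1, 2, 3]
    #   circle    : [4, 5, 6, 0, 1, 2, 3]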

    def single_forward(model, imgs_in):
        with torch.no_grad():
            model_output = model(imgs_in)
            if isinstance(model_output, list) or isinstance(
                    model_output, tuple):
                output = model_output[0]
            else:
                output = model_output
        return output

    sub_folder_l = sorted(glob.glob(test_dataset_folder))
    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    sub_folder_name_l = []

    # for each sub-folder
    for sub_folder in sub_folder_l:
        sub_folder_name = sub_folder.split('/')[-1]
        sub_folder_name_l.append(sub_folder_name)
        save_sub_folder = osp.join(save_folder, sub_folder_name)

        img_path_l = sorted(glob.glob(sub_folder + '/*'))
        max_idx = len(img_path_l)

        if save_imgs:
            util.mkdirs(save_sub_folder)

        #### read LR images
        imgs = read_seq_imgs(sub_folder)
        #### read GT images
        img_GT_l = []
        if data_mode == 'Vid4':
            sub_folder_GT = osp.join(sub_folder.replace('/BIx4/', '/GT/'), '*')
        else:
            sub_folder_GT = osp.join(
                sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*')
        for img_GT_path in sorted(glob.glob(sub_folder_GT)):
            img_GT_l.append(read_image(img_GT_path))

        # When using the downsampling in DUF official code, we downsample the HR images
        if DUF_downsampling:
            sub_folder = sub_folder_GT
            img_path_l = sorted(glob.glob(sub_folder))
            max_idx = len(img_path_l)
            imgs = read_seq_imgs(sub_folder[:-2])  # strip the trailing '/*'

        avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
        cal_n_border, cal_n_center = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            c_idx = int(osp.splitext(osp.basename(img_path))[0])
            select_idx = index_generation(c_idx,
                                          max_idx,
                                          N_in,
                                          padding=padding)
            # get input images
            imgs_in = imgs.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            # Downsample the HR images
            H, W = imgs_in.size(3), imgs_in.size(4)
            if DUF_downsampling:
                imgs_in = util.DUF_downsample(imgs_in, scale=scale)

            output = single_forward(model, imgs_in)

            # Crop to the original shape
            if scale == 3:
                pad_h = 3 - (H % 3)
                pad_w = 3 - (W % 3)
                if pad_h > 0:
                    output = output[:, :, :-pad_h, :]
                if pad_w > 0:
                    output = output[:, :, :, :-pad_w]
            output_f = output.data.float().cpu().squeeze(0)

            output = util.tensor2img(output_f)

            # save imgs
            if save_imgs:
                cv2.imwrite(
                    osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)),
                    output)

            #### calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT)
                output = data_util.bgr2ycbcr(output)
            if crop_border == 0:
                cropped_output = output
                cropped_GT = GT
            else:
                cropped_output = output[crop_border:-crop_border,
                                        crop_border:-crop_border]
                cropped_GT = GT[crop_border:-crop_border,
                                crop_border:-crop_border]
            crt_psnr = util.calculate_psnr(cropped_output * 255,
                                           cropped_GT * 255)
            logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(
                img_idx + 1, c_idx, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                cal_n_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                cal_n_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center +
                                                          cal_n_border)
        avg_psnr_center = avg_psnr_center / cal_n_center
        if cal_n_border == 0:
            avg_psnr_border = 0
        else:
            avg_psnr_border = avg_psnr_border / cal_n_border

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        sub_folder_name, avg_psnr,
                        (cal_n_center + cal_n_border), avg_psnr_center,
                        cal_n_center, avg_psnr_border, cal_n_border))

        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

    logger.info('################ Tidy Outputs ################')
    for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l,
                                                    avg_psnr_l,
                                                    avg_psnr_center_l,
                                                    avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center,
                                                     psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
Example #27
def main():
    #################
    # configurations
    #################
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    #parser.add_argument("--gt_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--gpu_id", type=str, required=True)
    parser.add_argument("--screen_notation", type=str, required=True)
    parser.add_argument("--use_screen_notation", type=int, default=1)
    parser.add_argument('--opt',
                        type=str,
                        required=True,
                        help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=False)

    PAD = 32

    total_run_time = AverageMeter()
    print("GPU ", torch.cuda.device_count())
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    device = torch.device('cuda')
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    data_mode = 'sharp_bicubic'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False

    # Input_folder = "/DATA7_DB7/data/4khdr/data/Dataset/train_sharp_bicubic"
    # GT_folder = "/DATA7_DB7/data/4khdr/data/Dataset/train_4k"
    # Result_folder = "/DATA7_DB7/data/4khdr/data/Results"

    Input_folder = args.input_path
    #GT_folder = args.gt_path
    Result_folder = args.output_path
    Model_path = args.model_path

    # create results folder
    os.makedirs(Result_folder, exist_ok=True)

    ############################################################################
    #### model
    # if data_mode == 'Vid4':
    #     if stage == 1:
    #         model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
    #     else:
    #         raise ValueError('Vid4 does not support stage 2.')
    # elif data_mode == 'sharp_bicubic':
    #     if stage == 1:
    #         # model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
    #     else:
    #         model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
    # elif data_mode == 'blur_bicubic':
    #     if stage == 1:
    #         model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
    #     else:
    #         model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
    # elif data_mode == 'blur':
    #     if stage == 1:
    #         model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
    #     else:
    #         model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
    # elif data_mode == 'blur_comp':
    #     if stage == 1:
    #         model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
    #     else:
    #         model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    # else:
    #     raise NotImplementedError

    model_path = Model_path

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 40
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20

    model = EDVR_arch.EDVR(nf=opt['network_G']['nf'],
                           nframes=opt['network_G']['nframes'],
                           groups=opt['network_G']['groups'],
                           front_RBs=opt['network_G']['front_RBs'],
                           back_RBs=opt['network_G']['back_RBs'],
                           predeblur=opt['network_G']['predeblur'],
                           HR_in=opt['network_G']['HR_in'],
                           w_TSA=opt['network_G']['w_TSA'])

    # model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        if stage == 1:
            # test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
            # test_dataset_folder = '/DATA/wangshen_data/REDS/val_sharp_bicubic/X4'
            test_dataset_folder = Input_folder
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        # GT_dataset_folder = '../datasets/REDS4/GT'
        # GT_dataset_folder = '/DATA/wangshen_data/REDS/val_sharp'
        # GT_dataset_folder = GT_folder

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    # save_folder = '../results/{}'.format(data_mode)
    # save_folder = '/DATA/wangshen_data/REDS/results/{}'.format(data_mode)
    save_folder = os.path.join(Result_folder, data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    #subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    # for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):

    end = time.time()

    # import json
    # with open(args.screen_notation) as f:
    #     frame_notation = json.load(f)
    frame_notation = None

    for subfolder in subfolder_l:

        input_subfolder = os.path.split(subfolder)[1]

        # subfolder_GT = os.path.join(GT_dataset_folder, input_subfolder)
        #
        # if not os.path.exists(subfolder_GT):
        #     continue

        print("Evaluate Folders: ", input_subfolder)

        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)  # Num x 3 x H x W
        # img_GT_l = []
        # for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
        #     img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            #select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)

            select_idx, log1, log2, nota = data_util.index_generation_process_screen_change_withlog_fixbug(
                input_subfolder,
                frame_notation,
                img_idx,
                max_idx,
                N_in,
                padding=padding,
                enable=args.use_screen_notation)

            if log1 is not None:
                logger.info('screen change')
                logger.info(nota)
                logger.info(log1)
                logger.info(log2)

            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(
                    device)  # 960 x 540

            gtWidth = 3840
            gtHeight = 2160
            intWidth_ori = imgs_in.shape[4]  # 960
            intHeight_ori = imgs_in.shape[3]  # 540
            split_lengthY = 180
            split_lengthX = 320
            scale = 4

            # pad width/height up to the next multiple of the split size
            # (reset to zero below when the dimension is already an exact multiple)
            intPaddingRight_ = int(float(intWidth_ori) / split_lengthX +
                                   1) * split_lengthX - intWidth_ori
            intPaddingBottom_ = int(float(intHeight_ori) / split_lengthY +
                                    1) * split_lengthY - intHeight_ori

            intPaddingRight_ = 0 if intPaddingRight_ == split_lengthX else intPaddingRight_
            intPaddingBottom_ = 0 if intPaddingBottom_ == split_lengthY else intPaddingBottom_

            pader0 = torch.nn.ReplicationPad2d(
                [0, intPaddingRight_, 0, intPaddingBottom_])
            print("Init pad right/bottom " + str(intPaddingRight_) + " / " +
                  str(intPaddingBottom_))

            intPaddingRight = PAD  # 32# 64# 128# 256
            intPaddingLeft = PAD  # 32#64 #128# 256
            intPaddingTop = PAD  # 32#64 #128#256
            intPaddingBottom = PAD  # 32#64 # 128# 256

            pader = torch.nn.ReplicationPad2d([
                intPaddingLeft, intPaddingRight, intPaddingTop,
                intPaddingBottom
            ])

            imgs_in = torch.squeeze(imgs_in, 0)  # N C H W

            #imgs_in = pader0(imgs_in)  # N C 540 960

            imgs_in = pader(imgs_in)  # N C 604 1024

            # assert (split_lengthY == int(split_lengthY) and split_lengthX == int(split_lengthX))
            # split_lengthY = int(split_lengthY)
            # split_lengthX = int(split_lengthX)
            # split_numY = int(float(intHeight_ori) / split_lengthY)
            # split_numX = int(float(intWidth_ori) / split_lengthX)
            # splitsY = range(0, split_numY)
            # splitsX = range(0, split_numX)

            # intWidth = split_lengthX
            # intWidth_pad = intWidth + intPaddingLeft + intPaddingRight
            # intHeight = split_lengthY
            # intHeight_pad = intHeight + intPaddingTop + intPaddingBottom

            # print("split " + str(split_numY) + ' , ' + str(split_numX))
            # y_all = np.zeros((gtHeight, gtWidth, 3), dtype="float32")  # HWC

            # todo: output 4k

            X0 = imgs_in
            X0 = torch.unsqueeze(X0, 0)
            if flip_test:
                output = util.flipx4_forward(model, X0)
            else:
                output = util.single_forward(model, X0)

            # todo remove padding
            output = output[0, :, intPaddingTop *
                            scale:(intPaddingTop + intHeight_ori) * scale,
                            intPaddingLeft *
                            scale:(intPaddingLeft + intWidth_ori) * scale]

            output = util.tensor2img(output.squeeze(0))

            # for split_j, split_i in itertools.product(splitsY, splitsX):
            #     # print(str(split_j) + ", \t " + str(split_i))
            #     X0 = imgs_in[:, :,
            #          split_j * split_lengthY:(split_j + 1) * split_lengthY + intPaddingBottom + intPaddingTop,
            #          split_i * split_lengthX:(split_i + 1) * split_lengthX + intPaddingRight + intPaddingLeft]
            #
            #     # y_ = torch.FloatTensor()
            #
            #     X0 = torch.unsqueeze(X0, 0)  # N C H W -> 1 N C H W
            #
            #     if flip_test:
            #         output = util.flipx4_forward(model, X0)
            #     else:
            #         output = util.single_forward(model, X0)
            #
            #     output_depadded = output[0, :, intPaddingTop * scale:(intPaddingTop + intHeight) * scale,
            #                       intPaddingLeft * scale: (intPaddingLeft + intWidth) * scale]
            #     output_depadded = output_depadded.squeeze(0)
            #     output = util.tensor2img(output_depadded)
            #
            #     y_all[split_j * split_lengthY * scale:(split_j + 1) * split_lengthY * scale,
            #     split_i * split_lengthX * scale:(split_i + 1) * split_lengthX * scale, :] = \
            #         np.round(output).astype(np.uint8)

            # plt.figure(0)
            # plt.title("pic")
            # plt.imshow(y_all)

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            print("*****************current image process time \t " +
                  str(time.time() - end) + "s ******************")
            total_run_time.update(time.time() - end, 1)

            logger.info('{} : {:3d} - {:25} \t'.format(input_subfolder,
                                                       img_idx + 1, img_name))

            # # calculate PSNR
            # y_all = y_all / 255.
            # GT = np.copy(img_GT_l[img_idx])
            # # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            # if data_mode == 'Vid4':  # bgr2y, [0, 1]
            #     GT = data_util.bgr2ycbcr(GT, only_y=True)
            #     y_all = data_util.bgr2ycbcr(y_all, only_y=True)

            # y_all, GT = util.crop_border([y_all, GT], crop_border)
            # crt_psnr = util.calculate_psnr(y_all * 255, GT * 255)
            # logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            # if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
            #     avg_psnr_center += crt_psnr
            #     N_center += 1
            # else:  # border frames
            #     avg_psnr_border += crt_psnr
            #     N_border += 1

        # avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        # avg_psnr_center = avg_psnr_center / N_center
        # avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        # avg_psnr_l.append(avg_psnr)
        # avg_psnr_center_l.append(avg_psnr_center)
        # avg_psnr_border_l.append(avg_psnr_border)

        # logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
        #             'Center PSNR: {:.6f} dB for {} frames; '
        #             'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
        #                                                            (N_center + N_border),
        #                                                            avg_psnr_center, N_center,
        #                                                            avg_psnr_border, N_border))

    # logger.info('################ Tidy Outputs ################')
    # for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l,
    #                                                           avg_psnr_center_l, avg_psnr_border_l):
    #     logger.info('Folder {} - Average PSNR: {:.6f} dB. '
    #                 'Center PSNR: {:.6f} dB. '
    #                 'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
    #                                                  psnr_border))
    # logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
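The replication padding and post-upscale cropping in this script follow a simple shape invariant: pad the 960x540 input by PAD pixels on each side, super-resolve by 4x, then crop PAD*scale from each border to land exactly on 3840x2160. A minimal sketch with a stand-in upscaler (torch.nn.functional.interpolate here only mimics the 4x shape change of the EDVR forward pass):

import torch

PAD, scale = 32, 4
x = torch.randn(1, 3, 540, 960)  # N C H W, one 960x540 frame
pad = torch.nn.ReplicationPad2d([PAD, PAD, PAD, PAD])
xp = pad(x)  # 1 x 3 x 604 x 1024
up = torch.nn.functional.interpolate(xp, scale_factor=scale)  # stand-in for the model
y = up[:, :, PAD * scale:(PAD + 540) * scale, PAD * scale:(PAD + 960) * scale]
print(y.shape)  # torch.Size([1, 3, 2160, 3840])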
Example #28
def main():
    #################
    # configurations
    #################
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'sharp_bicubic'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth')
        else:
            raise ValueError('Vid4 does not support stage 2.')
    elif data_mode == 'sharp_bicubic':
        if stage == 1:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_SR_L.pth')
        else:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth')
    elif data_mode == 'blur_bicubic':
        if stage == 1:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth')
        else:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth')
    elif data_mode == 'blur':
        if stage == 1:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth')
        else:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth')
    elif data_mode == 'blur_comp':
        if stage == 1:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth')
        else:
            model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth')
    else:
        raise NotImplementedError
    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5
    predeblur, HR_in = False, False
    back_RBs = 40
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20
    model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = osp.join(root, '../datasets/Vid4/BIx4/*')
        GT_dataset_folder = osp.join(root, '../datasets/Vid4/GT/*')
    else:
        if stage == 1:
            test_dataset_folder = osp.join(root, f'../datasets/REDS4/{data_mode}/*')
        else:
            raise ValueError('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = osp.join(root, '../datasets/REDS4/GT/*')

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    device = torch.device('cuda')
    save_folder = f'../results/{data_mode}'
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info(f'Data: {data_mode} - {test_dataset_folder}')
    logger.info(f'Padding mode: {padding}')
    logger.info(f'Model path: {model_path}')
    logger.info(f'Save images: {save_imgs}')
    logger.info(f'Flip Test: {flip_test}')

    def read_image(img_path):
        '''read one image from img_path
        Return img: HWC, BGR, [0,1], numpy
        '''
        img_GT = cv2.imread(img_path)
        img = img_GT.astype(np.float32) / 255.
        return img

    def read_seq_imgs(img_seq_path):
        '''read a sequence of images'''
        img_path_l = sorted(glob.glob(img_seq_path + '/*'))
        img_l = [read_image(v) for v in img_path_l]
        # stack to TCHW, RGB, [0,1], torch
        imgs = np.stack(img_l, axis=0)
        imgs = imgs[:, :, :, [2, 1, 0]]
        imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
        return imgs

    def index_generation(crt_i, max_n, N, padding='reflection'):
        '''
        padding: replicate | reflection | new_info | circle
        '''
        max_n = max_n - 1
        n_pad = N // 2
        return_l = []

        for i in range(crt_i - n_pad, crt_i + n_pad + 1):
            if i < 0:
                if padding == 'replicate':
                    add_idx = 0
                elif padding == 'reflection':
                    add_idx = -i
                elif padding == 'new_info':
                    add_idx = (crt_i + n_pad) + (-i)
                elif padding == 'circle':
                    add_idx = N + i
                else:
                    raise ValueError('Wrong padding mode')
            elif i > max_n:
                if padding == 'replicate':
                    add_idx = max_n
                elif padding == 'reflection':
                    add_idx = max_n * 2 - i
                elif padding == 'new_info':
                    add_idx = (crt_i - n_pad) - (i - max_n)
                elif padding == 'circle':
                    add_idx = i - N
                else:
                    raise ValueError('Wrong padding mode')
            else:
                add_idx = i
            return_l.append(add_idx)
        return return_l
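    # sanity check of the padding modes at the start of a clip (values follow
    # directly from the logic above, with crt_i=0, max_n=100, N=5):
    #   index_generation(0, 100, 5, padding='replicate')  -> [0, 0, 0, 1, 2]
    #   index_generation(0, 100, 5, padding='reflection') -> [2, 1, 0, 1, 2]
    #   index_generation(0, 100, 5, padding='new_info')   -> [4, 3, 0, 1, 2]
    #   index_generation(0, 100, 5, padding='circle')     -> [3, 4, 0, 1, 2]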

    def single_forward(model, imgs_in):
        with torch.no_grad():
            model_output = model(imgs_in)
            if isinstance(model_output, (list, tuple)):
                output = model_output[0]
            else:
                output = model_output
        return output
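    # some EDVR variants return a tuple such as (restored_frame, extra_outputs);
    # single_forward keeps only the first element, assumed to be the frame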

    sub_folder_l = sorted(glob.glob(test_dataset_folder))
    sub_folder_GT_l = sorted(glob.glob(GT_dataset_folder))
    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    sub_folder_name_l = []

    # for each sub-folder
    for sub_folder, sub_folder_GT in zip(sub_folder_l, sub_folder_GT_l):
        sub_folder_name = osp.basename(sub_folder)
        sub_folder_name_l.append(sub_folder_name)
        save_sub_folder = osp.join(save_folder, sub_folder_name)

        img_path_l = sorted(glob.glob(sub_folder + '/*'))
        max_idx = len(img_path_l)

        if save_imgs:
            util.mkdirs(save_sub_folder)

        #### read LR images
        imgs = read_seq_imgs(sub_folder)
        #### read GT images
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(sub_folder_GT, '*'))):
            img_GT_l.append(read_image(img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
        cal_n_border, cal_n_center = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            c_idx = int(osp.splitext(osp.basename(img_path))[0])
            select_idx = index_generation(c_idx, max_idx, N_in, padding=padding)
            # get input images
            imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
            output = single_forward(model, imgs_in)
            output_f = output.data.float().cpu().squeeze(0)

            if flip_test:
                # flip W
                output = single_forward(model, torch.flip(imgs_in, (-1, )))
                output = torch.flip(output, (-1, ))
                output = output.data.float().cpu().squeeze(0)
                output_f = output_f + output
                # flip H
                output = single_forward(model, torch.flip(imgs_in, (-2, )))
                output = torch.flip(output, (-2, ))
                output = output.data.float().cpu().squeeze(0)
                output_f = output_f + output
                # flip both H and W
                output = single_forward(model, torch.flip(imgs_in, (-2, -1)))
                output = torch.flip(output, (-2, -1))
                output = output.data.float().cpu().squeeze(0)
                output_f = output_f + output

                output_f = output_f / 4
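                # self-ensemble: average the plain pass with the W-, H- and
                # HW-flipped passes (4 forwards in total); this usually buys a
                # small PSNR gain at roughly 4x the inference cost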

            output = util.tensor2img(output_f)

            # save imgs
            if save_imgs:
                cv2.imwrite(osp.join(save_sub_folder, f'{c_idx:08d}.png'), output)

            #### calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT)
                output = data_util.bgr2ycbcr(output)
            if crop_border == 0:
                cropped_output = output
                cropped_GT = GT
            else:
                cropped_output = output[crop_border:-crop_border, crop_border:-crop_border]
                cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border]
            crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255)
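            # util.calculate_psnr presumably implements the standard
            # 10 * log10(255^2 / MSE); both images are scaled back to [0, 255]
            # here because they were normalized to [0, 1] for the Y conversion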
            logger.info(f'{img_idx + 1:3d} - {c_idx:08d}.png \tPSNR: {crt_psnr:.6f} dB')

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                cal_n_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                cal_n_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border)
        avg_psnr_center = avg_psnr_center / cal_n_center if cal_n_center > 0 else 0
        avg_psnr_border = avg_psnr_border / cal_n_border if cal_n_border > 0 else 0

        logger.info(f'Folder {sub_folder_name} - Average PSNR: {avg_psnr:.6f} dB for {(cal_n_center + cal_n_border)} frames; '
                    f'Center PSNR: {avg_psnr_center:.6f} dB for {cal_n_center} frames; '
                    f'Border PSNR: {avg_psnr_border:.6f} dB for {cal_n_border} frames.')

        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

    logger.info('################ Tidy Outputs ################')
    for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l,
                                                    avg_psnr_center_l, avg_psnr_border_l):
        logger.info(f'Folder {name} - Average PSNR: {psnr:.6f} dB. '
                    f'Center PSNR: {psnr_center:.6f} dB. '
                    f'Border PSNR: {psnr_border:.6f} dB.')
    logger.info('################ Final Results ################')
    logger.info(f'Data: {data_mode} - {test_dataset_folder}')
    logger.info(f'Padding mode: {padding}')
    logger.info(f'Model path: {model_path}')
    logger.info(f'Save images: {save_imgs}')
    logger.info(f'Flip Test: {flip_test}')
    logger.info(f'Total Average PSNR: {sum(avg_psnr_l) / len(avg_psnr_l):.6f} dB for {len(sub_folder_l)} clips. '
                f'Center PSNR: {sum(avg_psnr_center_l) / len(avg_psnr_center_l):.6f} dB. '
                f'Border PSNR: {sum(avg_psnr_border_l) / len(avg_psnr_border_l):.6f} dB.')
Example #29
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
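        # the map_location lambda moves every saved tensor onto this process's
        # current GPU, so each rank restores the checkpoint onto its own device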
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']['experiments_root'])  # rename experiment folder if it exists
            util.mkdirs((path for key, path in opt['path'].items()
                         if key != 'experiments_root'
                         and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        util.setup_logger('val',
                          opt['path']['log'],
                          'val_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            # torch.utils.tensorboard ships with PyTorch >= 1.1; parse the
            # version robustly (handles e.g. '1.10.0' and '1.13.0+cu117')
            version = tuple(int(x) for x in torch.__version__.split('+')[0].split('.')[:2])
            if version >= (1, 1):
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(torch.__version__))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True
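    # benchmark lets cuDNN auto-tune convolution algorithms for the fixed
    # input shapes used in training (faster, but not bit-reproducible);
    # enable the deterministic flag instead when reproducibility matters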

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
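    # with the distributed sampler below, a single "epoch" iterates the
    # dataset dataset_ratio times, so far fewer epoch boundaries are needed
    # to reach the target iteration count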
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    first_time = True
    save_count = 0
    max_psnr = 0
    avg_psnr = 0.0  # set by validation; initialized so the best-model check below cannot hit an undefined name
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            if first_time:
                start_time = time.time()
                first_time = False
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                end_time = time.time()
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, time:{:.3f}> '.format(
                    epoch, current_step, model.get_current_learning_rate(),
                    end_time - start_time)
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)
                start_time = time.time()

            # validation
            if (current_step % opt['train']['val_freq'] == 0 and rank <= 0
                    and current_step >= opt['train']['val_min_iter']):
                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['img_GT_bic4x'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}.png'.format(img_name, current_step))
                    # util.save_img(sr_img, save_img_path)  # image saving disabled during validation

                    # calculate PSNR
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)
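                    # the `scale`-pixel border crop above follows SR convention:
                    # the outermost pixels lack a full receptive field and are
                    # excluded from the metric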

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                        epoch, current_step, avg_psnr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0 and rank <= 0:
                model.save_training_state(epoch, current_step)
                if current_step >= opt['train']['val_min_iter']:
                    logger.info('Saving models and training states.')
                    save_count += 1
                    # keep the best-PSNR checkpoint under a fixed name,
                    # otherwise save under the current iteration number
                    if avg_psnr >= max_psnr:
                        max_psnr = avg_psnr
                        model.save('best')
                    else:
                        model.save(current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
Example #30
def main():
    need_chop = True
    cal_metrics = True
    save_images = False
    scale = 4
    # data_mode = 'KON_GRU_NONLOCAL_ratio=1'
    # test_dataset_folder = '../../datasets/KON/HR'
    # file_list = '../datasets/KON/test_100.txt'
    data_mode = 'Vimeo_GRU'
    test_dataset_folder = '../../datasets/vimeo_septuplet/sequences'
    file_list = '../datasets/Vimeo/sep_testlist.txt'
    # model path
    # model_path = '/mnt/02520c27-ec8e-4661-b88f-05aa2011ffa7/sxw/Zooming-Slow-Mo-CVPR-2020/experiments/KON_GRU/models/latest_G.pth'
    # model_path = '/mnt/02520c27-ec8e-4661-b88f-05aa2011ffa7/sxw/Zooming-Slow-Mo-CVPR-2020/experiments/KON_LSTM/models/latest_G.pth'
    # model_path = '/mnt/02520c27-ec8e-4661-b88f-05aa2011ffa7/sxw/Zooming-Slow-Mo-CVPR-2020/experiments/KON_LSTM_NONLOCAL/models/latest_G.pth'
    # model_path = '/mnt/02520c27-ec8e-4661-b88f-05aa2011ffa7/sxw/Zooming-Slow-Mo-CVPR-2020/experiments/KON_GRU_NONLOCAL_ratio=1/models/latest_G.pth'
    model_path = '/mnt/02520c27-ec8e-4661-b88f-05aa2011ffa7/sxw/Zooming-Slow-Mo-CVPR-2020/experiments/Vimeo_GRU/models/480000_G.pth'
    # model = Sakuya_arch.LunaTokisGRU(64, 7, 8, 5, 5)
    # model = Sakuya_arch.LunaTokis(64, 7, 8, 5, 5)
    # model = Sakuya_arch.NonLocalNet(64, 7, 8, 5, 5)
    model = Sakuya_arch.LunaTokisGRU(64, 7, 8, 5, 5)

    if torch.cuda.is_available() and os.environ.get('CUDA_VISIBLE_DEVICES', '') != '':
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    model_params = util.get_model_total_params(model)

    # log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Model parameters: {} M'.format(model_params))
    logger.info('Device: {}'.format(device))

    model.load_state_dict(torch.load(model_path), strict=True)
    model = model.to(device)
    model.eval()

    test_set = DatasetFromFolderTest(test_dataset_folder,
                                     7,
                                     scale,
                                     file_list,
                                     transform=transform())
    test_loader = DataLoader(test_set,
                             batch_size=1,
                             shuffle=False,
                             num_workers=8)
    test_num = len(test_loader)
    avg_psnr = 0.0
    avg_ssim = 0.0
    avg_time = 0.0
    with torch.no_grad():
        for data in test_loader:
            input = data['LQs'].to(device)
            target = data['GT'].to(device)
            info = os.path.join(data['INFO'][0].split('/')[-2],
                                data['INFO'][0].split('/')[-1])
            t0 = time.time()
            # when GPU memory is insufficient, split the full frame into patches and test them separately
            if need_chop:
                predictions = chop_forward(input, model, scale, device)
            else:
                predictions = model(input)
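            # chop_forward (defined elsewhere) is assumed to be the usual
            # memory-saving helper: split the frames into overlapping spatial
            # quadrants, run the model on each part (recursing if still too
            # large), then stitch the outputs back together at full resolution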
            predictions = predictions[0]  # batch
            t1 = time.time()
            pre_num = len(predictions)
            time_predicted = (t1 - t0) / pre_num

            psnr_predicted = 0.0
            ssim_predicted = 0.0
            for i in range(pre_num):
                # save images
                pre = util.tensor2img(predictions[i])
                if save_images:
                    img_path = os.path.join(save_folder, info)
                    if not os.path.exists(img_path):
                        os.makedirs(img_path)
                    img_path = os.path.join(img_path, 'im{}.jpg'.format(i + 1))
                    util.save_img(pre, img_path)
                if cal_metrics:
                    # calculate PSNR and SSIM
                    tar = util.tensor2img(target[0][i])
                    psnr_predicted += util.PSNR(pre, tar)
                    ssim_predicted += util.SSIM(pre, tar)

            psnr_predicted /= pre_num
            ssim_predicted /= pre_num

            avg_psnr += psnr_predicted
            avg_ssim += ssim_predicted
            avg_time += time_predicted
            logger.info(
                "Processing: %s || PSNR: %.4f || SSIM: %.4f || Avg Timer: %.4f sec."
                % (info, psnr_predicted, ssim_predicted, time_predicted))

    avg_time /= test_num
    avg_psnr /= test_num
    avg_ssim /= test_num
    logger.info(
        "Finished: %s || PSNR: %.4f || SSIM: %.4f || Avg Timer: %.4f sec." %
        (data_mode, avg_psnr, avg_ssim, avg_time))