示例#1
0
def main():
    """Run a pretrained EDVR model over a test set and log PSNR statistics.

    Everything is configured through the hard-coded values in the
    "configurations" section: ``data_mode`` selects the checkpoint, the
    dataset folders, the number of input frames and the temporal padding
    mode; ``stage`` selects between the stage-1 and stage-2 checkpoints.

    For every clip (sub-folder) the restored frames are optionally written
    as PNG files under ``../results/<data_mode>/<clip>/`` and the PSNR of
    each frame is accumulated separately for "center" frames and "border"
    frames (the first/last ``N_in // 2`` frames of the clip).

    Raises:
        ValueError: if ``data_mode == "Vid4"`` is combined with ``stage == 2``.
        NotImplementedError: if ``data_mode`` is not one of the known modes.
    """
    #################
    # configurations
    #################
    device = torch.device("cuda")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    data_mode = ("licensePlate_blur_bicubic"
                 )  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    # Pick the checkpoint that matches the data mode / stage combination.
    if data_mode == "Vid4":
        if stage == 1:
            model_path = "../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth"
        else:
            raise ValueError("Vid4 does not support stage 2.")
    elif data_mode == "sharp_bicubic":
        if stage == 1:
            model_path = "../experiments/pretrained_models/EDVR_REDS_SR_L.pth"
        else:
            model_path = "../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth"
    elif data_mode == "blur_bicubic":
        if stage == 1:
            model_path = "../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth"
        else:
            model_path = "../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth"
    elif data_mode == "blur":
        if stage == 1:
            model_path = "../experiments/pretrained_models/EDVR_REDS_deblur_L.pth"
        else:
            model_path = "../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth"
    elif data_mode == "blur_comp":
        if stage == 1:
            model_path = "../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth"
        else:
            model_path = (
                "../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth"
            )
    elif data_mode == "licensePlate_blur_bicubic":
        # The license-plate model has a single checkpoint regardless of stage.
        model_path = ("/workspace/video_sr/EDVR/experiments/" +
                      "pretrained_models/EDVR_licensePlate_SRblur_L.pth")
    else:
        raise NotImplementedError

    if data_mode == "Vid4":
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    # Architecture switches derived from the data mode / stage.
    predeblur, HR_in = False, False
    back_RBs = 40
    if (data_mode == "blur_bicubic") or (data_mode
                                         == "licensePlate_blur_bicubic"):
        predeblur = True
    elif data_mode == "blur" or data_mode == "blur_comp":
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20
    model = EDVR_arch.EDVR(128,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### dataset
    if data_mode == "Vid4":
        test_dataset_folder = "../datasets/Vid4/BIx4"
        GT_dataset_folder = "../datasets/Vid4/GT"
    elif data_mode == "licensePlate_blur_bicubic":
        test_dataset_folder = "../datasets/license_plate2/BIx4"
        GT_dataset_folder = "../datasets/license_plate2/GT"
    else:
        if stage == 1:
            test_dataset_folder = "../datasets/REDS4/{}".format(data_mode)
        else:
            test_dataset_folder = "../results/REDS-EDVR_REDS_SR_L_flipx4"
            print("You should modify the test_dataset_folder path for stage 2")
        GT_dataset_folder = "../datasets/REDS4/GT"

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == "Vid4" or data_mode == "sharp_bicubic":
        padding = "new_info"
    else:
        padding = "replicate"
    save_imgs = True

    save_folder = "../results/{}".format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger("base",
                      save_folder,
                      "test",
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger("base")

    #### log info
    logger.info("Data: {} - {}".format(data_mode, test_dataset_folder))
    logger.info("Padding mode: {}".format(padding))
    logger.info("Model path: {}".format(model_path))
    logger.info("Save images: {}".format(save_imgs))
    logger.info("Flip test: {}".format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    # Per-clip averages, collected for the summary at the end.
    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    # LQ and GT clips are paired by sorted order, not by name.
    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, "*")))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, "*")))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, "*")))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = test_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, "*"))):
            img_GT_l.append(test_util.read_img(img_GT_path))

        # Running PSNR sums and frame counts for this clip.
        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # Indices of the N_in neighbouring frames fed to the network.
            select_idx = test_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = (imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device))

            if flip_test:
                output = test_util.flipx4_forward(model, imgs_in)
            else:
                output = test_util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, "{}.png".format(img_name)),
                    output)

            # calculate PSNR
            # Rescale the saved image to [0, 1]; presumably this matches the
            # range of the GT images returned by read_img -- TODO confirm.
            output = output / 255.0
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == "Vid4":  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = test_util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            logger.info("{:3d} - {:25} \tPSNR: {:.6f} dB".format(
                img_idx + 1, img_name, crt_psnr))

            if (img_idx >= border_frame
                    and img_idx < max_idx - border_frame):  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        # A clip with at most 2 * border_frame frames has no center frames;
        # guard the division the same way the border average is guarded.
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info("Folder {} - Average PSNR: {:.6f} dB for {} frames; "
                    "Center PSNR: {:.6f} dB for {} frames; "
                    "Border PSNR: {:.6f} dB for {} frames.".format(
                        subfolder_name,
                        avg_psnr,
                        (N_center + N_border),
                        avg_psnr_center,
                        N_center,
                        avg_psnr_border,
                        N_border,
                    ))

    logger.info("################ Tidy Outputs ################")
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info("Folder {} - Average PSNR: {:.6f} dB. "
                    "Center PSNR: {:.6f} dB. "
                    "Border PSNR: {:.6f} dB.".format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info("################ Final Results ################")
    logger.info("Data: {} - {}".format(data_mode, test_dataset_folder))
    logger.info("Padding mode: {}".format(padding))
    logger.info("Model path: {}".format(model_path))
    logger.info("Save images: {}".format(save_imgs))
    logger.info("Flip test: {}".format(flip_test))
    # Totals are unweighted means of the per-clip averages.
    logger.info("Total Average PSNR: {:.6f} dB for {} clips. "
                "Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.".format(
                    sum(avg_psnr_l) / len(avg_psnr_l),
                    len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l),
                ))
示例#2
0
def main(opts):
    """Evaluate one or more EDVR-family checkpoints and log PSNR / SSIM.

    Args:
        opts: options object. The attributes read here are ``gpus``
            (comma-separated device list), ``cache`` (> 0 pre-loads every
            test/GT image once), ``model_dir`` (a checkpoint file, or a
            directory of checkpoints whose names start with ``<digits>_``),
            ``test_dir``, ``gt_dir``, ``mode`` (model variant, compared
            case-insensitively) and ``save_dir``.

    For each checkpoint a logger named after the checkpoint prefix is set up
    under ``<save_dir>/<prefix>``; every clip (sub-folder of ``test_dir``)
    is evaluated frame by frame, ``n_gpus`` frames per forward pass.

    Raises:
        IOError: if ``opts.model_dir`` is neither a file nor a directory.
        AssertionError: if a sub-folder of ``test_dir`` is missing in
            ``gt_dir``.
        TypeError: if ``opts.mode`` is not a known model variant.
    """
    ################## configurations #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpus
    cache_all_imgs = opts.cache > 0
    n_gpus = len(opts.gpus.split(','))
    flip_test, save_imgs = False, False
    N_in, nf = 5, 64
    back_RBs = 10
    w_TSA = False
    predeblur, HR_in = False, False
    crop_border = 0
    border_frame = N_in // 2
    padding = 'new_info'
    mode = opts.mode.lower()

    ################## model files ####################
    model_dir = opts.model_dir
    if osp.isfile(model_dir):
        model_names = [osp.basename(model_dir)]
        model_dir = osp.dirname(model_dir)
    elif osp.isdir(model_dir):
        # Keep only "<digits>_..." entries and evaluate them in numeric order.
        model_names = [
            x for x in os.listdir(model_dir) if str.isdigit(x.split('_')[0])
        ]
        model_names = sorted(model_names, key=lambda x: int(x.split("_")[0]))
    else:
        raise IOError('Invalid model_dir: {}'.format(model_dir))

    ################## dataset ########################
    test_subs = sorted(os.listdir(opts.test_dir))
    gt_subs = os.listdir(opts.gt_dir)
    valid_test_subs = [sub in gt_subs for sub in test_subs]
    assert (all(valid_test_subs)), 'Invalid sub folders exists in {}'.format(
        opts.test_dir)
    # The scale is parsed from the name of test_dir's parent directory with
    # its first character dropped (presumably names like "X4" -- TODO confirm).
    scale = float(os.path.basename(os.path.dirname(opts.test_dir))[1:])
    if cache_all_imgs:
        print('Cacheing all testing images ...')
        all_imgs = {}
        for sub in test_subs:
            print('Reading sub-folder: {} ...'.format(sub))
            test_sub_dir = osp.join(opts.test_dir, sub)
            gt_sub_dir = osp.join(opts.gt_dir, sub)
            all_imgs[sub] = {'test': [], 'gt': []}
            im_names = sorted(os.listdir(test_sub_dir))
            for name in im_names:
                test_im_path = osp.join(test_sub_dir, name)
                gt_im_path = osp.join(gt_sub_dir, name)
                # LQ: reverse channel order, HWC -> CHW, scale to [0, 1].
                test_im = cv2.imread(test_im_path,
                                     cv2.IMREAD_UNCHANGED)[:, :, (2, 1, 0)]
                test_im = test_im.astype(np.float32).transpose(
                    (2, 0, 1)) / 255.
                all_imgs[sub]['test'].append(test_im)
                # GT: kept as loaded (float32, not rescaled).
                gt_im = cv2.imread(gt_im_path,
                                   cv2.IMREAD_UNCHANGED).astype(np.float32)
                all_imgs[sub]['gt'].append(gt_im)

    all_psnrs = []
    for model_name in model_names:
        model_path = osp.join(model_dir, model_name)
        exp_name = model_name.split('_')[0]
        if 'meta' in mode:
            model = EDVR_arch.MetaEDVR(nf=nf,
                                       nframes=N_in,
                                       groups=8,
                                       front_RBs=5,
                                       center=None,
                                       back_RBs=back_RBs,
                                       predeblur=predeblur,
                                       HR_in=HR_in,
                                       w_TSA=w_TSA)
        elif mode == 'edvr':
            model = EDVR_arch.EDVR(nf=nf,
                                   nframes=N_in,
                                   groups=8,
                                   front_RBs=5,
                                   center=None,
                                   back_RBs=back_RBs,
                                   predeblur=predeblur,
                                   HR_in=HR_in,
                                   w_TSA=w_TSA)
        elif mode == 'upedvr':
            model = EDVR_arch.UPEDVR(nf=nf,
                                     nframes=N_in,
                                     groups=8,
                                     front_RBs=5,
                                     center=None,
                                     back_RBs=10,
                                     w_TSA=w_TSA,
                                     down_scale=True,
                                     align_target=True,
                                     ret_valid=True)
        elif mode == 'upcont1':
            model = EDVR_arch.UPControlEDVR(nf=nf,
                                            nframes=N_in,
                                            groups=8,
                                            front_RBs=5,
                                            center=None,
                                            back_RBs=10,
                                            w_TSA=w_TSA,
                                            down_scale=True,
                                            align_target=True,
                                            ret_valid=True,
                                            multi_scale_cont=False)
        elif mode == 'upcont3':
            model = EDVR_arch.UPControlEDVR(nf=nf,
                                            nframes=N_in,
                                            groups=8,
                                            front_RBs=5,
                                            center=None,
                                            back_RBs=10,
                                            w_TSA=w_TSA,
                                            down_scale=True,
                                            align_target=True,
                                            ret_valid=True,
                                            multi_scale_cont=True)
        elif mode == 'upcont2':
            model = EDVR_arch.UPControlEDVR(nf=nf,
                                            nframes=N_in,
                                            groups=8,
                                            front_RBs=5,
                                            center=None,
                                            back_RBs=10,
                                            w_TSA=w_TSA,
                                            down_scale=True,
                                            align_target=True,
                                            ret_valid=True,
                                            multi_scale_cont=True)
        else:
            raise TypeError('Unknown model mode: {}'.format(opts.mode))
        save_folder = osp.join(opts.save_dir, exp_name)
        util.mkdirs(save_folder)
        util.setup_logger(exp_name,
                          save_folder,
                          'test',
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger(exp_name)

        #### log info
        logger.info('Data: {}'.format(opts.test_dir))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))

        #### set up the models
        model.load_state_dict(torch.load(model_path), strict=True)
        model.eval()
        if n_gpus > 1:
            model = nn.DataParallel(model)
        model = model.to(device)

        # Per-clip averages, collected for the summary of this checkpoint.
        avg_psnrs, avg_psnr_centers, avg_psnr_borders = [], [], []
        avg_ssims, avg_ssim_centers, avg_ssim_borders = [], [], []
        evaled_subs = []

        # for each subfolder
        for sub in test_subs:
            evaled_subs.append(sub)
            test_sub_dir = osp.join(opts.test_dir, sub)
            gt_sub_dir = osp.join(opts.gt_dir, sub)
            img_names = sorted(os.listdir(test_sub_dir))
            max_idx = len(img_names)
            if save_imgs:
                save_subfolder = osp.join(save_folder, sub)
                util.mkdirs(save_subfolder)

            #### get LQ and GT images
            if not cache_all_imgs:
                img_LQs, img_GTs = [], []
                for name in img_names:
                    test_im_path = osp.join(test_sub_dir, name)
                    gt_im_path = osp.join(gt_sub_dir, name)
                    test_im = cv2.imread(test_im_path,
                                         cv2.IMREAD_UNCHANGED)[:, :, (2, 1, 0)]
                    test_im = test_im.astype(np.float32).transpose(
                        (2, 0, 1)) / 255.
                    gt_im = cv2.imread(gt_im_path,
                                       cv2.IMREAD_UNCHANGED).astype(np.float32)
                    img_LQs.append(test_im)
                    img_GTs.append(gt_im)
            else:
                img_LQs = all_imgs[sub]['test']
                img_GTs = all_imgs[sub]['gt']

            # Running PSNR / SSIM sums and frame counts for this clip.
            avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0
            avg_ssim, avg_ssim_border, avg_ssim_center = 0, 0, 0

            # process each image
            for i in range(0, max_idx, n_gpus):
                end = min(i + n_gpus, max_idx)
                select_idxs = [
                    data_util.index_generation(j,
                                               max_idx,
                                               N_in,
                                               padding=padding)
                    for j in range(i, end)
                ]
                imgs = []
                for select_idx in select_idxs:
                    im = torch.from_numpy(
                        np.stack([img_LQs[k] for k in select_idx]))
                    imgs.append(im)
                # Pad the last batch with zero samples so that it always
                # holds n_gpus entries; the padded outputs are never read.
                if (i + n_gpus) > max_idx:
                    for _ in range(max_idx, i + n_gpus):
                        imgs.append(torch.zeros_like(im))
                imgs = torch.stack(imgs, 0).to(device)

                # Exactly one forward helper must run. These branches were
                # previously `if meta` followed by a separate `if up/else`,
                # which let single_forward overwrite the meta model's output.
                if flip_test:
                    output = util.flipx4_forward(model, imgs)
                elif 'meta' in mode:
                    output = util.meta_single_forward(model, imgs, scale,
                                                      n_gpus)
                elif 'up' in mode:
                    output = util.up_single_forward(model, imgs, scale)
                else:
                    output = util.single_forward(model, imgs)
                output = [
                    util.tensor2img(x).astype(np.float32) for x in output
                ]

                if save_imgs:
                    for ii in range(i, end):
                        cv2.imwrite(
                            osp.join(save_subfolder,
                                     '{}.png'.format(img_names[ii])),
                            output[ii - i].astype(np.uint8))

                # calculate PSNR
                GT = np.copy(img_GTs[i:end])

                output = util.crop_border(output, crop_border)
                GT = util.crop_border(GT, crop_border)
                for m in range(i, end):
                    crt_psnr = util.calculate_psnr(output[m - i], GT[m - i])
                    crt_ssim = util.calculate_ssim(output[m - i], GT[m - i])
                    logger.info(
                        '{:3d} - {:25} \tPSNR: {:.6f} dB    SSIM: {:.6}'.
                        format(m + 1, img_names[m], crt_psnr, crt_ssim))

                    if m >= border_frame and m < max_idx - border_frame:  # center frames
                        avg_psnr_center += crt_psnr
                        avg_ssim_center += crt_ssim
                        N_center += 1
                    else:  # border frames
                        avg_psnr_border += crt_psnr
                        avg_ssim_border += crt_ssim
                        N_border += 1

            avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center +
                                                              N_border)
            # A clip with at most 2 * border_frame frames has no center
            # frames; guard the division like the border average.
            avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
            avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border

            avg_ssim = (avg_ssim_center + avg_ssim_border) / (N_center +
                                                              N_border)
            avg_ssim_center = 0 if N_center == 0 else avg_ssim_center / N_center
            avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border

            avg_psnrs.append(avg_psnr)
            avg_psnr_centers.append(avg_psnr_center)
            avg_psnr_borders.append(avg_psnr_border)

            avg_ssims.append(avg_ssim)
            avg_ssim_centers.append(avg_ssim_center)
            avg_ssim_borders.append(avg_ssim_border)

            logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                        'Center PSNR: {:.6f} dB for {} frames; '
                        'Border PSNR: {:.6f} dB for {} frames.'.format(
                            sub, avg_psnr, (N_center + N_border),
                            avg_psnr_center, N_center, avg_psnr_border,
                            N_border))
            logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; '
                        'Center SSIM: {:.6f} for {} frames; '
                        'Border SSIM: {:.6f} for {} frames.'.format(
                            sub, avg_ssim, (N_center + N_border),
                            avg_ssim_center, N_center, avg_ssim_border,
                            N_border))

        logger.info('################ Tidy Outputs ################')
        for sub_name, psnr, psnr_center, psnr_border, ssim, ssim_center, ssim_border in zip(
                evaled_subs, avg_psnrs, avg_psnr_centers, avg_psnr_borders,
                avg_ssims, avg_ssim_centers, avg_ssim_borders):
            logger.info(
                'Folder {} - Average PSNR: {:.6f} dB. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sub_name, psnr, psnr_center, psnr_border))
            logger.info('Folder {} - Average SSIM: {:.6f} '
                        'Center SSIM: {:.6f} Border SSIM: {:.6f} '.format(
                            sub_name, ssim, ssim_center, ssim_border))

        logger.info('################ Final Results ################')
        logger.info('Data: {}'.format(opts.test_dir))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))
        # Totals are unweighted means of the per-clip averages.
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnrs) / len(avg_psnrs), len(test_subs),
                        sum(avg_psnr_centers) / len(avg_psnr_centers),
                        sum(avg_psnr_borders) / len(avg_psnr_borders)))
        logger.info('Total Average SSIM: {:.6f} for {} clips. '
                    'Center SSIM: {:.6f} Border SSIM: {:.6f} '.format(
                        sum(avg_ssims) / len(avg_ssims), len(test_subs),
                        sum(avg_ssim_centers) / len(avg_ssim_centers),
                        sum(avg_ssim_borders) / len(avg_ssim_borders)))
示例#3
0
文件: train.py 项目: csbhr/CNLRN
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        default='./options/train/train_CNLRN.yml',
                        help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)
            #### validation
            if opt['datasets'].get(
                    'val',
                    None) and current_step % opt['train']['val_freq'] == 0:
                if opt['model'] in ['sr',
                                    'srgan']:  # image restoration validation
                    if opt['dist']:
                        # multi-GPU testing
                        psnr_rlt = {}  # with border and center frames
                        idx_d = 0
                        if rank == 0:
                            pbar = util.ProgressBar(len(val_set))
                        for idx in range(rank, len(val_set), world_size):
                            val_data = val_set[idx]
                            val_data['LQ'].unsqueeze_(0)
                            val_data['GT'].unsqueeze_(0)
                            img_name = os.path.splitext(
                                os.path.basename(val_data['LQ_path']))[0]
                            folder = 'valset'
                            idx_d += 1
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = torch.zeros(
                                    len(val_set),
                                    dtype=torch.float32,
                                    device='cuda')
                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8
                            # calculate PSNR
                            rlt_img, gt_img = util.crop_border(
                                [rlt_img, gt_img], opt['scale'])
                            psnr_rlt[folder][idx_d] = util.calculate_psnr(
                                rlt_img, gt_img)

                            if rank == 0:
                                for _ in range(world_size):
                                    pbar.update('Test {}'.format(img_name))
                        # # collect data
                        for _, v in psnr_rlt.items():
                            dist.reduce(v, 0)
                        dist.barrier()

                        if rank == 0:
                            psnr_rlt_avg = {}
                            psnr_total_avg = 0.
                            for k, v in psnr_rlt.items():
                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                                psnr_total_avg += psnr_rlt_avg[k]
                            psnr_total_avg /= len(psnr_rlt)
                            log_s = '# Validation # PSNR: {:.4e}:'.format(
                                psnr_total_avg)
                            for k, v in psnr_rlt_avg.items():
                                log_s += ' {}: {:.4e}'.format(k, v)
                            logger.info(log_s)
                            if opt['use_tb_logger'] and 'debug' not in opt[
                                    'name']:
                                tb_logger.add_scalar('psnr', psnr_total_avg,
                                                     current_step)
                    else:
                        # does not support multi-GPU validation
                        pbar = util.ProgressBar(len(val_loader))
                        avg_psnr = 0.
                        idx = 0
                        for val_data in val_loader:
                            idx += 1
                            img_name = os.path.splitext(
                                os.path.basename(val_data['LQ_path'][0]))[0]

                            model.feed_data(val_data)
                            model.test()

                            visuals = model.get_current_visuals()
                            sr_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8

                            # calculate PSNR
                            sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                              opt['scale'])
                            avg_psnr += util.calculate_psnr(sr_img, gt_img)
                            pbar.update('Test {}'.format(img_name))

                        avg_psnr = avg_psnr / idx

                        # log
                        logger.info(
                            '# Validation # PSNR: {:.4e}'.format(avg_psnr))
                        # tensorboard logger
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr', avg_psnr,
                                                 current_step)
                else:  # video restoration validation
                    if opt['dist']:
                        # multi-GPU testing
                        psnr_rlt = {}  # with border and center frames
                        if rank == 0:
                            pbar = util.ProgressBar(len(val_set))
                        for idx in range(rank, len(val_set), world_size):
                            val_data = val_set[idx]
                            val_data['LQs'].unsqueeze_(0)
                            val_data['GT'].unsqueeze_(0)
                            folder = val_data['folder']
                            idx_d, max_idx = val_data['idx'].split('/')
                            idx_d, max_idx = int(idx_d), int(max_idx)
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = torch.zeros(
                                    max_idx,
                                    dtype=torch.float32,
                                    device='cuda')
                            # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8
                            # calculate PSNR
                            psnr_rlt[folder][idx_d] = util.calculate_psnr(
                                rlt_img, gt_img)

                            if rank == 0:
                                for _ in range(world_size):
                                    pbar.update('Test {} - {}/{}'.format(
                                        folder, idx_d, max_idx))
                        # # collect data
                        for _, v in psnr_rlt.items():
                            dist.reduce(v, 0)
                        dist.barrier()

                        if rank == 0:
                            psnr_rlt_avg = {}
                            psnr_total_avg = 0.
                            for k, v in psnr_rlt.items():
                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                                psnr_total_avg += psnr_rlt_avg[k]
                            psnr_total_avg /= len(psnr_rlt)
                            log_s = '# Validation # PSNR: {:.4e}:'.format(
                                psnr_total_avg)
                            for k, v in psnr_rlt_avg.items():
                                log_s += ' {}: {:.4e}'.format(k, v)
                            logger.info(log_s)
                            if opt['use_tb_logger'] and 'debug' not in opt[
                                    'name']:
                                tb_logger.add_scalar('psnr_avg',
                                                     psnr_total_avg,
                                                     current_step)
                                for k, v in psnr_rlt_avg.items():
                                    tb_logger.add_scalar(k, v, current_step)
                    else:
                        pbar = util.ProgressBar(len(val_loader))
                        psnr_rlt = {}  # with border and center frames
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for val_data in val_loader:
                            folder = val_data['folder'][0]
                            idx_d = val_data['idx'][0]
                            # border = val_data['border'].item()
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = []

                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8

                            # calculate PSNR
                            psnr = util.calculate_psnr(rlt_img, gt_img)
                            psnr_rlt[folder].append(psnr)
                            pbar.update('Test {} - {}'.format(folder, idx_d))
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = sum(v) / len(v)
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
        tb_logger.close()
示例#4
0
def main(jsonPath):
    """Run a trained model over every configured test set and report metrics.

    For each dataset listed under ``opt["datasets"]`` the model output is
    written as a PNG file to ``<results_root>/<dataset name>``. When the
    dataset declares ``dataroot_HR``, PSNR and SSIM against the ground truth
    are logged per image and averaged per dataset; for 3-channel images the
    same metrics are additionally computed on the output of
    ``bgr2ycbcr(..., only_y=True)``.

    Args:
        jsonPath: Path to the option file handed to ``option.parse``.
    """

    def _crop(img, border):
        """Strip ``border`` pixels from each spatial edge of ``img``.

        A falsy border returns the image untouched: the slice ``img[0:-0]``
        would otherwise be empty. Only the first two axes are sliced, so the
        helper works for both (H, W) and (H, W, C) arrays.
        """
        if not border:
            return img
        return img[border:-border, border:-border]

    # options
    opt = option.parse(jsonPath, is_train=False)
    util.mkdirs((path for key, path in opt["path"].items()
                 if not key == "pretrain_model_G"))
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None,
                      opt["path"]["log"],
                      "test.log",
                      level=logging.INFO,
                      screen=True)
    logger = logging.getLogger("base")
    logger.info(option.dict2str(opt))

    # Create test dataset and dataloader (ordered by phase name)
    test_loaders = []
    for _, dataset_opt in sorted(opt["datasets"].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info("Number of test images in [{:s}]: {:d}".format(
            dataset_opt["name"], len(test_set)))
        test_loaders.append(test_loader)

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        loader_opt = test_loader.dataset.opt
        test_set_name = loader_opt["name"]
        logger.info("\nTesting [{:s}]...".format(test_set_name))
        dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name)
        util.mkdir(dataset_dir)

        # Decide once per dataset whether ground truth is available. Doing
        # this inside the image loop left the name unbound (or stale from the
        # previous dataset) whenever a loader yielded no batches.
        need_HR = loader_opt["dataroot_HR"] is not None

        test_results = OrderedDict()
        test_results["psnr"] = []
        test_results["ssim"] = []
        test_results["psnr_y"] = []
        test_results["ssim_y"] = []

        for data in test_loader:
            model.feed_data(data, need_HR=need_HR)
            img_path = data["LR_path"][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)

            sr_img = util.tensor2img(visuals["SR"])  # uint8

            # save images; an empty/None suffix leaves the name unchanged
            suffix = opt["suffix"] or ""
            save_img_path = os.path.join(dataset_dir,
                                         img_name + suffix + ".png")
            util.save_img(sr_img, save_img_path)

            if not need_HR:
                logger.info(img_name)
                continue

            # calculate PSNR and SSIM
            gt_img = util.tensor2img(visuals["HR"])
            # Scale to [0, 1] for bgr2ycbcr; metrics are fed the values
            # multiplied back by 255.
            # NOTE(review): presumably bgr2ycbcr expects [0, 1] floats here -
            # confirm against its definition.
            gt_img = gt_img / 255.0
            sr_img = sr_img / 255.0

            crop_border = loader_opt["scale"]
            cropped_sr_img = _crop(sr_img, crop_border)
            cropped_gt_img = _crop(gt_img, crop_border)

            psnr = util.calculate_psnr(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            ssim = util.calculate_ssim(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            test_results["psnr"].append(psnr)
            test_results["ssim"].append(ssim)

            if gt_img.ndim == 3 and gt_img.shape[2] == 3:  # RGB image
                sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                cropped_sr_img_y = _crop(sr_img_y, crop_border)
                cropped_gt_img_y = _crop(gt_img_y, crop_border)
                psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                test_results["psnr_y"].append(psnr_y)
                test_results["ssim_y"].append(ssim_y)
                logger.info(
                    "{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}."
                    .format(img_name, psnr, ssim, psnr_y, ssim_y))
            else:
                logger.info(
                    "{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.".format(
                        img_name, psnr, ssim))

        # metrics; skipped when nothing was evaluated so the averages below
        # never divide by zero
        if need_HR and test_results["psnr"]:
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results["psnr"]) / len(test_results["psnr"])
            ave_ssim = sum(test_results["ssim"]) / len(test_results["ssim"])
            logger.info(
                "----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n"
                .format(test_set_name, ave_psnr, ave_ssim))
            if test_results["psnr_y"] and test_results["ssim_y"]:
                ave_psnr_y = sum(test_results["psnr_y"]) / len(
                    test_results["psnr_y"])
                ave_ssim_y = sum(test_results["ssim_y"]) / len(
                    test_results["ssim_y"])
                logger.info(
                    "----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n"
                    .format(ave_psnr_y, ave_ssim_y))
示例#5
0
文件: test_IKC.py 项目: zhwzhong/DAN
                gt_img = util.tensor2img(F_visuals["GT"])
                gt_img = gt_img / 255.0
                sr_img = sr_img / 255.0

                crop_border = (opt_P["crop_border"]
                               if opt_P["crop_border"] else opt_P["scale"])
                if crop_border == 0:
                    cropped_sr_img = sr_img
                    cropped_gt_img = gt_img
                else:
                    cropped_sr_img = sr_img[crop_border:-crop_border,
                                            crop_border:-crop_border, :]
                    cropped_gt_img = gt_img[crop_border:-crop_border,
                                            crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                # test_results['psnr'].append(psnr)
                # test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    if crop_border == 0:
                        cropped_sr_img_y = sr_img_y
                        cropped_gt_img_y = gt_img_y
                    else:
                        cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                                    crop_border:-crop_border]
                        cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
示例#6
0
def main():
    """Train the SFTMD network from an option file.

    Steps, in order: parse ``-opt_F`` / ``--launcher`` / ``--local_rank``,
    seed the RNGs, build a PCA matrix from randomly generated blur kernels,
    optionally initialise distributed training, load a resume state if one
    is configured, set up loggers (and a tensorboard writer on rank <= 0),
    create the train/val data loaders and the model, then run the training
    loop with periodic logging, validation (rank <= 0 only) and
    checkpointing.

    ``rank`` is -1 for non-distributed runs and the process rank otherwise,
    so ``rank <= 0`` selects the single process responsible for logging,
    validation and saving.
    """
    ###### SFTMD train ######
    #### setup options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt_F', type=str, help='Path to option YMAL file of SFTMD_Net.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt_F = option.parse(args.opt_F, is_train=True)

    # convert to NoneDict, which returns None for missing keys
    opt_F = option.dict_to_nonedict(opt_F)

    #### random seed
    seed = opt_F['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    util.set_random_seed(seed)

    # build the PCA matrix from a large batch of random kernels
    batch_ker = util.random_batch_kernel(batch=30000, l=21, sig_min=0.2, sig_max=4.0, rate_iso=1.0, scaling=3, tensor=False)
    print('batch kernel shape: {}'.format(batch_ker.shape))
    b = np.size(batch_ker, 0)
    batch_ker = batch_ker.reshape((b, -1))  # flatten each kernel to a row
    pca_matrix = util.PCA(batch_ker, k=10).float()
    print('PCA matrix shape: {}'.format(pca_matrix.shape))

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt_F['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt_F['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### loading resume state if exists
    if opt_F['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt_F['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt_F, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    # Stays None unless this process owns a tensorboard writer; every later
    # use (and the final close) is keyed on that.
    tb_logger = None
    if rank <= 0:
        if resume_state is None:
            util.mkdir_and_rename(
                opt_F['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt_F['path'].items() if not key == 'experiments_root'
                         and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base', opt_F['path']['log'], 'train_' + opt_F['name'], level=logging.INFO,
                          screen=True, tofile=True)
        util.setup_logger('val', opt_F['path']['log'], 'val_' + opt_F['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt_F))
        # tensorboard logger
        if opt_F['use_tb_logger'] and 'debug' not in opt_F['name']:
            # Compare (major, minor) as integers: parsing the first three
            # characters as a float would read "1.10" as 1.1.
            release = torch.__version__.split('+')[0]
            major_minor = tuple(int(part) for part in release.split('.')[:2])
            if major_minor >= (1, 1):  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(
                        torch.__version__))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt_F['name'])
    else:
        util.setup_logger('base', opt_F['path']['log'], 'train', level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    #### create train and val dataloader
    dataset_ratio = 200   # enlarge the size of each epoch
    # Pre-bind so a missing phase trips the assertions below instead of an
    # UnboundLocalError.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt_F['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt_F['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt_F['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
                total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt_F, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt_F, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None, 'No [train] dataset is configured.'
    assert val_loader is not None, 'No [val] dataset is configured.'

    #### create model
    model_F = create_model(opt_F)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model_F.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        # Once the iteration budget is used up, stop here rather than
        # spinning up the loader again only to discard its first batch.
        if current_step >= total_iters:
            break
        if opt_F['dist']:
            train_sampler.set_epoch(epoch)
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break
            #### preprocessing for LR_img and kernel map
            # NOTE(review): rebuilt on every iteration; could presumably be
            # hoisted if the constructor is stateless - confirm first.
            prepro = util.SRMDPreprocessing(opt_F['scale'], pca_matrix, para_input=10, kernel=21, noise=False, cuda=True,
                                                  sig_min=0.2, sig_max=4.0, rate_iso=1.0, scaling=3,
                                                  rate_cln=0.2, noise_high=0.0)
            LR_img, ker_map = prepro(train_data['GT'])

            #### update learning rate, schedulers
            model_F.update_learning_rate(current_step, warmup_iter=opt_F['train']['warmup_iter'])

            #### training
            model_F.feed_data(train_data, LR_img, ker_map)
            model_F.optimize_parameters(current_step)

            #### log
            if current_step % opt_F['logger']['print_freq'] == 0:
                logs = model_F.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model_F.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger (only exists on rank <= 0)
                    if tb_logger is not None:
                        tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            # validation
            if current_step % opt_F['train']['val_freq'] == 0 and rank <= 0:
                avg_psnr = 0.0
                num_val_images = 0
                for val_data in val_loader:
                    num_val_images += 1
                    #### preprocessing for LR_img and kernel map
                    # NOTE(review): para_input=15 here vs. 10 during training
                    # (and k=10 for the PCA matrix) - confirm this is intended.
                    prepro = util.SRMDPreprocessing(opt_F['scale'], pca_matrix, para_input=15, noise=False, cuda=True,
                                                    sig_min=0.2, sig_max=4.0, rate_iso=1.0, scaling=3,
                                                    rate_cln=0.2, noise_high=0.0)
                    LR_img, ker_map = prepro(val_data['GT'])

                    model_F.feed_data(val_data, LR_img, ker_map)
                    model_F.test()

                    visuals = model_F.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    # Save SR images for reference, grouped by iteration
                    img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt_F['path']['val_images'], str(current_step))
                    util.mkdir(img_dir)

                    save_img_path = os.path.join(img_dir,'{:s}_{:d}.png'.format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    crop_size = opt_F['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    if crop_size:
                        cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
                        cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    else:
                        # img[0:-0] would be empty, so leave the images uncropped
                        cropped_sr_img = sr_img
                        cropped_gt_img = gt_img
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)

                if num_val_images == 0:
                    # Nothing to average; avoid dividing by zero.
                    logger.warning('Validation loader is empty; skipping PSNR logging.')
                else:
                    avg_psnr = avg_psnr / num_val_images

                    # log
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(epoch, current_step, avg_psnr))
                    # tensorboard logger
                    if tb_logger is not None:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt_F['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model_F.save(current_step)
                    model_F.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model_F.save('latest')
        logger.info('End of SFTMD training.')
        # Close the writer now that training is finished.
        if tb_logger is not None:
            tb_logger.close()
示例#7
0
def main():
    """Train a model from an option file, with optional distributed execution.

    Steps, in order: parse the command line, optionally initialise
    distributed training, load a resume state if one is configured, set up
    loggers (and TensorBoard on rank <= 0), seed the RNGs, build the train/val
    dataloaders and the model, then run the training loop.

    Inside the loop the function logs every ``print_freq`` iterations,
    validates every ``val_freq`` iterations (PSNR plus three extra error
    metrics) and saves a checkpoint every ``save_checkpoint_freq``
    iterations. Validation, TensorBoard writes and checkpointing are
    restricted to rank <= 0.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YMAL file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1  # -1 marks the single-process case; `rank <= 0` means "main process"
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            # fresh run: set up the experiment folder and every configured path
            # except the root itself, pretrained models and resume entries
            util.mkdir_and_rename(
                opt['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                         and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            # pick the TensorBoard writer implementation based on the PyTorch version
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        # other ranks only get a screen logger
        util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    # -------------------------------------------- ADDED --------------------------------------------
    # extra validation metrics: low-pass filter plus L1 / MSE criteria
    filter_low = filters.FilterLow(gaussian=False)
    l1_loss = torch.nn.L1Loss()
    mse_loss = torch.nn.MSELoss()
    if torch.cuda.is_available():
        filter_low = filter_low.cuda()
        l1_loss = l1_loss.cuda()
        mse_loss = mse_loss.cuda()
    # -----------------------------------------------------------------------------------------------

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # The flag is spelled `benchmark`; a misspelled attribute is accepted
    # silently and leaves cuDNN autotuning disabled.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            # iterations per pass over the training set
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
                # the distributed sampler enlarges every epoch by dataset_ratio
                total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                # leaves only the inner loop; later epochs break on their first batch
                break
            #### update learning rate
            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            # validation
            if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
                avg_psnr = val_pix_err_f = val_pix_err_nf = val_mean_color_err = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    # one output folder per validation image, named after the LQ file
                    img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(img_dir,
                                                 '{:s}_{:d}.png'.format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR on images with a `scale`-pixel border removed
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)

                    # ----------------------------------------- ADDED -----------------------------------------
                    # L1 after low-pass filtering, plain L1, and MSE between the
                    # tensors averaged over dims 2 and 1
                    # (assumes visuals are 3-D (C, H, W) tensors, i.e. per-channel
                    # means are compared -- TODO confirm)
                    val_pix_err_f += l1_loss(filter_low(visuals['SR']), filter_low(visuals['GT']))
                    val_pix_err_nf += l1_loss(visuals['SR'], visuals['GT'])
                    val_mean_color_err += mse_loss(visuals['SR'].mean(2).mean(1), visuals['GT'].mean(2).mean(1))
                    # -----------------------------------------------------------------------------------------

                # average every metric over the number of validation images
                avg_psnr = avg_psnr / idx
                val_pix_err_f /= idx
                val_pix_err_nf /= idx
                val_mean_color_err /= idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                    epoch, current_step, avg_psnr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
                    tb_logger.add_scalar('val_pix_err_f', val_pix_err_f, current_step)
                    tb_logger.add_scalar('val_pix_err_nf', val_pix_err_nf, current_step)
                    tb_logger.add_scalar('val_mean_color_err', val_mean_color_err, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
 if LATENT_DISTRIBUTION in NON_ARBITRARY_Z_INPUTS or cur_channel_cur_Z == 0 or SAVE_AVG_METRICS_WHEN_LATENT:
     if z_sample_num == 0:
         quantized_image = util.tensor2img(
             model.Return_Compressed(
                 gt_im_YCbCr.to(model.device)),
             out_type=np.uint8,
             min_max=[0, 255],
             chroma_mode=chroma_mode)
         # quantized_image = util.tensor2img(model.jpeg_extractor(model.jpeg_compressor(data['Uncomp'])), out_type=np.uint8,min_max=[0, 255],chroma_mode=chroma_mode)
         if SAVE_QUANTIZED:
             util.save_img(
                 quantized_image,
                 os.path.join(dataset_dir + '_Quant',
                              img_name + suffix + '.png'))
         test_results['psnr_quantized'].append(
             util.calculate_psnr(quantized_image, gt_img))
         test_results['ssim_quantized'].append(
             util.calculate_ssim(quantized_image, gt_img))
     psnr = util.calculate_psnr(sr_img, gt_img)
     ssim = util.calculate_ssim(sr_img, gt_img)
     test_results['psnr'].append(psnr)
     test_results['ssim'].append(ssim)
 if SAVE_IMAGE_COLLAGE:
     if len(test_set) > 1:
         margins2crop = ((np.array(sr_img.shape[:2]) -
                          per_image_saved_patch) / 2).astype(
                              np.int32)
     else:
         margins2crop = [0, 0]
     image_collage[-1].append(
         np.clip(util.crop_center(sr_img, margins2crop), 0,
示例#9
0
def main():
    """Train a model from an option file, with sample dumps and validation.

    Steps, in order: parse ``-opt``, load a resume state or create the
    experiment folders, set up loggers (and TensorBoard when enabled), seed
    the RNGs, build the train/val dataloaders and the model, then run the
    training loop.

    Inside the loop the function logs every ``print_freq`` iterations,
    optionally writes training-sample image grids to TensorBoard, validates
    every ``val_freq`` iterations (PSNR and optionally LPIPS) and saves a
    checkpoint every ``save_checkpoint_freq`` iterations.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which return None for missing key.

    # train from scratch OR resume training
    if opt['path']['resume_state']:  # resuming training
        resume_state = torch.load(opt['path']['resume_state'])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(
            opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work

    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger
    # `tb_logger` exists only when this flag is truthy, so every later use of
    # it is guarded by the same flag.
    use_tb_logger = opt['use_tb_logger'] and 'debug' not in opt['name']
    if use_tb_logger:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(logdir='../../SRN_tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # The flag is spelled `benchmark`; a misspelled attribute is accepted
    # silently and leaves cuDNN autotuning disabled.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            # iterations per pass over the training set
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state,
                              opt['train'])  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            model.update_learning_rate()

            # training
            model.feed_data(train_data, True)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if use_tb_logger:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # training samples
            # The grids are only ever written to TensorBoard, so the whole dump
            # is skipped when the writer was not created (it would otherwise
            # fail on the undefined `tb_logger`).
            if use_tb_logger and opt['train']['save_tsamples'] \
                    and current_step % opt['train']['save_tsamples'] == 0:
                fake_LRs = os.listdir(
                    opt['datasets']['train']['dataroot_fake_LR'])
                real_LRs = os.listdir(
                    opt['datasets']['train']['dataroot_real_LR'])
                HRs = os.listdir(opt['datasets']['train']['dataroot_HR'])

                # NOTE(review): os.listdir returns entries in arbitrary order,
                # so one shared index is not guaranteed to select matching
                # files across the three folders -- confirm this is intended.
                for i in range(5):
                    random_index = np.random.choice(range(len(fake_LRs)))
                    fake_LR_path = os.path.join(
                        opt['datasets']['train']['dataroot_fake_LR'],
                        fake_LRs[random_index])
                    real_LR_path = os.path.join(
                        opt['datasets']['train']['dataroot_real_LR'],
                        real_LRs[random_index])
                    HR_path = os.path.join(
                        opt['datasets']['train']['dataroot_HR'],
                        HRs[random_index])
                    fake_LR = np.array(Image.open(fake_LR_path))
                    real_LR = np.array(Image.open(real_LR_path))
                    HR = np.array(Image.open(HR_path))

                    # centre crops: 128x128 for the LR images, 512x512 for HR
                    h, w, _ = fake_LR.shape
                    fake_LR = fake_LR[h // 2 - 64:h // 2 + 64,
                                      w // 2 - 64:w // 2 + 64, :]
                    h, w, _ = HR.shape
                    HR = HR[h // 2 - 64 * 4:h // 2 + 64 * 4,
                            w // 2 - 64 * 4:w // 2 + 64 * 4, :]

                    h, w, _ = real_LR.shape
                    real_LR = real_LR[h // 2 - 64:h // 2 + 64,
                                      w // 2 - 64:w // 2 + 64, :]

                    # HWC arrays -> 1xCxHxW float tensors scaled by 1/255
                    fake_LR = torch.from_numpy(
                        np.ascontiguousarray(np.transpose(
                            fake_LR, (2, 0, 1)))).float().unsqueeze(0) / 255
                    real_LR = torch.from_numpy(
                        np.ascontiguousarray(np.transpose(
                            real_LR, (2, 0, 1)))).float().unsqueeze(0) / 255
                    HR = torch.from_numpy(
                        np.ascontiguousarray(np.transpose(
                            HR, (2, 0, 1)))).float().unsqueeze(0) / 255
                    # batch of two: index 0 is the fake LR, index 1 the real LR
                    LR = torch.cat([fake_LR, real_LR], dim=0)

                    data = {'LR': LR, 'HR': HR}
                    model.feed_data(data, False)
                    model.test(tsamples=True)
                    visuals = model.get_current_visuals(tsamples=True)
                    fake_SR = visuals['SR'][0]
                    real_SR = visuals['SR'][1]
                    fake_hf = visuals['hf'][0]
                    real_hf = visuals['hf'][1]
                    HR = visuals['HR']
                    HR_hf = visuals['HR_hf'][0]

                    # image_1 = torch.cat([fake_LR[0], fake_SR[0]], dim=2)
                    # image_2 = torch.cat([real_LR[0], real_SR[0]], dim=2)
                    # Two rows joined along dim 1: SR outputs around HR, then
                    # the matching 'hf' outputs. torch.clamp keeps the values
                    # as tensors in [0, 1] without a NumPy round trip.
                    image_1 = torch.clamp(
                        torch.cat([fake_SR, HR, real_SR], dim=2), 0, 1)
                    image_2 = torch.clamp(
                        torch.cat([fake_hf, HR_hf, real_hf], dim=2), 0, 1)
                    image = torch.cat([image_1, image_2], dim=1)
                    tb_logger.add_image(
                        'train/train_samples_{}'.format(str(i)), image,
                        current_step)
                logger.info('Saved training Samples')

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                avg_psnr = 0.0
                idx = 0
                avg_lpips = 0.0
                for val_data in val_loader:
                    idx += 1
                    # one output folder per validation image, named after the LR file
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data, False)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    if 'HR' in opt['datasets']['val']['mode']:
                        gt_img = util.tensor2img(visuals['HR'])  # uint8
                    log_info = '{}'.format(
                        val_data['HR_path'][0].split('/')[-1])

                    if opt['val_lpips']:
                        lpips = visuals['LPIPS']
                        avg_lpips += lpips
                        log_info += '         LPIPS:{:.3f}'.format(
                            lpips.numpy())
                    if opt['use_domain_distance_map']:
                        ada_w = visuals['adaptive_weights']
                        log_info += '         Adaptive weights:{:.2f}'.format(
                            ada_w.numpy())
                        # logger.info('{} LPIPS: {:.3f}'.format(val_data['HR_path'][0].split('/')[-1], lpips.numpy()))
                        # print('img:', val_data['HR_path'][0].split('/')[-1], 'LPIPS: %.3f' % lpips.numpy())
                    # else:
                    #     print('img:', val_data['LR_path'][0].split('/')[-1])
                    logger.info(log_info)
                    # Save SR images for reference
                    save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(\
                        img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR (only when the val mode provides an HR
                    # reference), with a `scale`-pixel border removed
                    if 'HR' in opt['datasets']['val']['mode']:
                        crop_size = opt['scale']
                        gt_img = gt_img / 255.
                        sr_img = sr_img / 255.
                        cropped_sr_img = sr_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                        avg_psnr += util.calculate_psnr(
                            cropped_sr_img * 255, cropped_gt_img * 255)
                avg_psnr = avg_psnr / idx
                if opt['val_lpips']:
                    avg_lpips = avg_lpips / idx
                    print('Mean LPIPS:', avg_lpips.numpy())
                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                if opt['val_lpips']:
                    logger_val.info(
                        '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}, LPIPS: {:.4f}'
                        .format(epoch, current_step, avg_psnr, avg_lpips))
                else:
                    logger_val.info(
                        '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                            epoch, current_step, avg_psnr))
                # tensorboard logger
                if use_tb_logger:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
                    tb_logger.add_scalar('LPIPS', avg_lpips, current_step)

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
示例#10
0
def main():
    """Run a training session described by a JSON option file.

    Creates the experiment folders, redirects stdout to a PrintLogger, seeds
    the RNGs, builds the train/val dataloaders and the model, then loops over
    the training data with periodic logging, checkpointing and PSNR
    validation before saving the final model as 'latest'.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.')
    cli_args = parser.parse_args()
    opt = option.parse(cli_args.opt, is_train=True)

    # rename old experiments if exists, then create every other configured
    # path except the pretrained-model entries
    util.mkdir_and_rename(opt['path']['experiments_root'])
    skipped_keys = ('experiments_root', 'pretrain_model_G', 'pretrain_model_D')
    util.mkdirs((path for key, path in opt['path'].items() if key not in skipped_keys))
    option.save(opt)
    # Convert to NoneDict, which return None for missing key.
    opt = option.dict_to_nonedict(opt)

    # print to file and std_out simultaneously
    sys.stdout = PrintLogger(opt['path']['log'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            num_train = len(train_set)
            iters_per_epoch = int(math.ceil(num_train / dataset_opt['batch_size']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(num_train, iters_per_epoch))
            total_iters = int(opt['train']['niter'])
            num_epochs = int(math.ceil(total_iters / iters_per_epoch))
            print('Total epoches needed: {:d} for iters {:,d}'.format(num_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_dataset_opt = dataset_opt
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # Create model
    model = create_model(opt)
    # create logger
    logger = Logger(opt)

    print_freq = opt['logger']['print_freq']
    save_freq = opt['logger']['save_checkpoint_freq']
    val_freq = opt['train']['val_freq']

    current_step = 0
    start_time = time.time()
    print('---------- Start training -------------')
    for epoch in range(num_epochs):
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # wall-clock time spent on this iteration
            time_elapsed = time.time() - start_time
            start_time = time.time()

            # log
            if current_step % print_freq == 0:
                train_report = OrderedDict([
                    ('model', opt['model']),
                    ('epoch', epoch),
                    ('iters', current_step),
                    ('time', time_elapsed),
                ])
                train_report.update(model.get_current_log().items())
                train_report['lr'] = model.get_current_learning_rate()
                logger.print_format_results('train', train_report)

            # save models
            if current_step % save_freq == 0:
                print('Saving the model at the end of iter {:d}.'.format(current_step))
                model.save(current_step)

            # validation
            if current_step % val_freq == 0:
                print('---------- validation -------------')
                start_time = time.time()

                psnr_total = 0.0
                num_val = 0
                for num_val, val_data in enumerate(val_loader, start=1):
                    # one output folder per validation image, named after the LR file
                    lr_basename = os.path.basename(val_data['LR_path'][0])
                    img_name = os.path.splitext(lr_basename)[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    sr_filename = '{:s}_{:d}.png'.format(img_name, current_step)
                    util.save_img(sr_img, os.path.join(img_dir, sr_filename))

                    if opt['train']['which_state'] == 'b':
                        # also save the three intermediate outputs
                        loop_imgs = [util.tensor2img(visuals[key])  # uint8
                                     for key in ('SR1', 'SR2', 'SR3')]
                        for tag, loop_img in zip(('_loop1', '_loop2', '_loop3'), loop_imgs):
                            loop_filename = '{:s}_{:d}.png'.format(img_name + tag, current_step)
                            util.save_img(loop_img, os.path.join(img_dir, loop_filename))

                    # calculate PSNR with a `scale`-pixel border removed
                    border = opt['scale']
                    gt_unit = gt_img / 255.
                    sr_unit = sr_img / 255.
                    sr_inner = sr_unit[border:-border, border:-border, :]
                    gt_inner = gt_unit[border:-border, border:-border, :]
                    psnr_total += util.calculate_psnr(sr_inner * 255, gt_inner * 255)

                avg_psnr = psnr_total / num_val
                time_elapsed = time.time() - start_time
                # Save to log
                val_report = OrderedDict([
                    ('model', opt['model']),
                    ('epoch', epoch),
                    ('iters', current_step),
                    ('time', time_elapsed),
                    ('psnr', avg_psnr),
                ])
                logger.print_format_results('val', val_report)
                print('-----------------------------------')

            # update learning rate
            model.update_learning_rate()

    print('Saving the final model.')
    model.save('latest')
    print('End of training.')
示例#11
0
def main():
    """Run the training loop configured by the JSON option file given via ``-opt``.

    In order: parse the options, optionally load a resume state, set up the
    loggers and (when enabled) a tensorboardX writer, set the random seed,
    build the train/val dataloaders and the model, then iterate until
    ``opt["train"]["niter"]`` steps have run.  Every ``print_freq`` steps the
    current losses are logged, every ``val_freq`` steps the validation set is
    evaluated (PSNR), and every ``save_checkpoint_freq`` steps a checkpoint is
    written and the experiment folder is copied to a hard-coded backup path.

    Raises:
        NotImplementedError: if ``opt["datasets"]`` contains a phase other
            than ``"train"`` or ``"val"``.
        AssertionError: if no ``"train"`` dataset is configured.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument("-opt",
                        type=str,
                        required=True,
                        help="Path to option JSON file.")
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which return None for missing key.

    # train from scratch OR resume training
    if opt["path"]["resume_state"]:  # resuming training
        resume_state = torch.load(opt["path"]["resume_state"])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(
            opt["path"]["experiments_root"])  # rename old folder if exists
        util.mkdirs((path for key, path in opt["path"].items()
                     if not key == "experiments_root"
                     and "pretrain_model" not in key and "resume" not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None,
                      opt["path"]["log"],
                      "train",
                      level=logging.INFO,
                      screen=True)
    util.setup_logger("val", opt["path"]["log"], "val", level=logging.INFO)
    logger = logging.getLogger("base")

    if resume_state:
        logger.info("Resuming training from epoch: {}, iter: {}.".format(
            resume_state["epoch"], resume_state["iter"]))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))

    # tensorboard logger (skipped for experiments whose name contains "debug")
    use_tb_logger = opt["use_tb_logger"] and "debug" not in opt["name"]
    if use_tb_logger:
        from tensorboardX import SummaryWriter

        tb_logger = SummaryWriter(log_dir="../tb_logger/" + opt["name"])

    # random seed
    seed = opt["train"]["manual_seed"]
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info("Random seed: {}".format(seed))
    util.set_random_seed(seed)

    # The attribute was previously misspelled ("benckmark"), which only
    # created an unused attribute and left the cuDNN autotuner switched off.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    # Both start as None so that a missing phase is detected explicitly
    # instead of surfacing later as a NameError.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt["datasets"].items():
        if phase == "train":
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt["batch_size"]))
            logger.info("Number of train images: {:,d}, iters: {:,d}".format(
                len(train_set), train_size))
            total_iters = int(opt["train"]["niter"])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info("Total epochs needed: {:d} for iters {:,d}".format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == "val":
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info("Number of val images in [{:s}]: {:d}".format(
                dataset_opt["name"], len(val_set)))
        else:
            raise NotImplementedError(
                "Phase [{:s}] is not recognized.".format(phase))
    assert train_loader is not None, "A 'train' dataset must be configured."

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state["epoch"]
        current_step = resume_state["iter"]
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info("Start training from epoch: {:d}, iter: {:d}".format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            model.update_learning_rate()

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt["logger"]["print_freq"] == 0:
                logs = model.get_current_log()
                message = "<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> ".format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += "{:s}: {:.4e} ".format(k, v)
                    # tensorboard logger
                    if use_tb_logger:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # validation (only when a "val" dataset was configured)
            if (val_loader is not None
                    and current_step % opt["train"]["val_freq"] == 0):
                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data["LR_path"][0]))[0]
                    img_dir = os.path.join(opt["path"]["val_images"], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals["SR"])  # uint8
                    gt_img = util.tensor2img(visuals["HR"])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir,
                        "{:s}_{:d}.png".format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR on the images with a border of
                    # ``scale`` pixels removed on every side
                    crop_size = opt["scale"]
                    gt_img = gt_img / 255.0
                    sr_img = sr_img / 255.0
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)

                avg_psnr = avg_psnr / idx

                # log
                logger.info("# Validation # PSNR: {:.4e}".format(avg_psnr))
                logger_val = logging.getLogger("val")  # validation logger
                logger_val.info(
                    "<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}".format(
                        epoch, current_step, avg_psnr))
                # tensorboard logger
                if use_tb_logger:
                    tb_logger.add_scalar("psnr", avg_psnr, current_step)

            # save models and training states
            if current_step % opt["logger"]["save_checkpoint_freq"] == 0:
                logger.info("Saving models and training states.")
                model.save(current_step)
                model.save_training_state(epoch, current_step)
                # Mirror the whole experiment folder to a hard-coded
                # backup directory after every checkpoint.
                copy_tree(
                    opt["path"]["experiments_root"],
                    "/content/gdrive/My Drive/LVTN/SuperResolution/SR_models/"
                    + "-ESRGAN/experiments/" + opt["name"],
                )

    logger.info("Saving the final model.")
    model.save("latest")
    logger.info("End of training.")
def main():
    """Evaluate an EDVR model on a folder of clips, processing frames tile by tile.

    Command-line arguments give the low-quality input root (``--input_path``),
    the ground-truth root (``--gt_path``), the output root (``--output_path``),
    the checkpoint (``--model_path``), the GPU to expose (``--gpu_id``) and the
    option YAML (``--opt``) from which the network hyper-parameters are read.

    For every input sub-folder that has a same-named ground-truth sub-folder,
    each frame is replicate-padded, cut into ``split_lengthX`` x
    ``split_lengthY`` tiles (each with ``PAD`` pixels of context on all sides),
    run through the model, and the de-padded outputs are stitched into a
    fixed 3840x2160 canvas.  The stitched frame is optionally written to disk
    and its PSNR against the ground truth is logged; per-folder and overall
    averages are logged at the end.
    """
    #################
    # configurations
    #################
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--gt_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--gpu_id", type=str, required=True)
    # Still required on the command line, but its value is not read below.
    parser.add_argument("--screen_notation", type=str, required=True)
    parser.add_argument('--opt',
                        type=str,
                        required=True,
                        help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=False)

    PAD = 32  # context (in input pixels) replicated around every tile

    total_run_time = AverageMeter()

    # Restrict the visible GPUs BEFORE the first CUDA query; setting the
    # variable after CUDA has been queried is not guaranteed to take effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    print("GPU ", torch.cuda.device_count())

    device = torch.device('cuda')
    data_mode = 'sharp_bicubic'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False

    Input_folder = args.input_path
    GT_folder = args.gt_path
    Result_folder = args.output_path
    model_path = args.model_path

    # Tiling layout: the comments in the original code describe 960x540
    # inputs cut into 320x180 tiles and upscaled x4 into a 3840x2160 frame.
    # NOTE(review): the tile counts below use floor division and the output
    # canvas size is fixed, so other input sizes are not fully covered.
    gtWidth = 3840
    gtHeight = 2160
    split_lengthY = 180
    split_lengthX = 320
    scale = 4

    # create results folder
    os.makedirs(Result_folder, exist_ok=True)

    ############################################################################
    #### model
    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    # All network hyper-parameters come from the option file.
    net_opt = opt['network_G']
    model = EDVR_arch.EDVR(nf=net_opt['nf'],
                           nframes=net_opt['nframes'],
                           groups=net_opt['groups'],
                           front_RBs=net_opt['front_RBs'],
                           back_RBs=net_opt['back_RBs'],
                           predeblur=net_opt['predeblur'],
                           HR_in=net_opt['HR_in'],
                           w_TSA=net_opt['w_TSA'])

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        if stage == 1:
            test_dataset_folder = Input_folder
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = GT_folder

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = os.path.join(Result_folder, data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))

    # for each subfolder
    for subfolder in subfolder_l:
        subfolder_name = osp.basename(subfolder)
        subfolder_GT = os.path.join(GT_dataset_folder, subfolder_name)

        # Clips without a matching ground-truth folder cannot be scored.
        if not os.path.exists(subfolder_GT):
            continue

        print("Evaluate Folders: ", subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if max_idx == 0:
            # Nothing to evaluate; avoids a division by zero further down.
            logger.warning(
                'Folder {} contains no frames; skipped.'.format(subfolder_name))
            continue

        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)  # Num x 3 x H x W
        img_GT_l = [
            data_util.read_img(None, img_GT_path)
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))
        ]

        avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            # Restart the timer per frame so the reported time really is the
            # time of the current image rather than the time since start-up.
            end = time.time()

            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(
                    device)  # 960 x 540

            # here we split the input images 960x540 into 9 320x180 patch
            intWidth_ori = imgs_in.shape[4]  # 960
            intHeight_ori = imgs_in.shape[3]  # 540

            # Right/bottom padding needed to reach the next multiple of the
            # tile size (0 when the size already is a multiple).
            intPaddingRight_ = (-intWidth_ori) % split_lengthX
            intPaddingBottom_ = (-intHeight_ori) % split_lengthY

            pader0 = torch.nn.ReplicationPad2d(
                [0, intPaddingRight_, 0, intPaddingBottom_])
            print("Init pad right/bottom " + str(intPaddingRight_) + " / " +
                  str(intPaddingBottom_))

            intPaddingRight = PAD
            intPaddingLeft = PAD
            intPaddingTop = PAD
            intPaddingBottom = PAD

            pader = torch.nn.ReplicationPad2d([
                intPaddingLeft, intPaddingRight, intPaddingTop,
                intPaddingBottom
            ])

            imgs_in = torch.squeeze(imgs_in, 0)  # N C H W
            imgs_in = pader0(imgs_in)  # pad up to a multiple of the tile size
            imgs_in = pader(imgs_in)  # add PAD pixels of context on each side

            split_numY = int(float(intHeight_ori) / split_lengthY)
            split_numX = int(float(intWidth_ori) / split_lengthX)

            intWidth = split_lengthX
            intHeight = split_lengthY

            y_all = np.zeros((gtHeight, gtWidth, 3), dtype="float32")  # HWC
            for split_j, split_i in itertools.product(range(split_numY),
                                                      range(split_numX)):
                # Tile origin in the padded input; because of the PAD border
                # this window covers the tile plus its surrounding context.
                top = split_j * split_lengthY
                left = split_i * split_lengthX
                X0 = imgs_in[:, :,
                             top:top + split_lengthY + intPaddingBottom +
                             intPaddingTop,
                             left:left + split_lengthX + intPaddingRight +
                             intPaddingLeft]

                X0 = torch.unsqueeze(X0, 0)  # N C H W -> 1 N C H W

                if flip_test:
                    output = util.flipx4_forward(model, X0)
                else:
                    output = util.single_forward(model, X0)

                # Drop the (upscaled) context border from the model output.
                output_depadded = output[0, :,
                                         intPaddingTop * scale:
                                         (intPaddingTop + intHeight) * scale,
                                         intPaddingLeft * scale:
                                         (intPaddingLeft + intWidth) * scale]
                output_depadded = output_depadded.squeeze(0)
                output = util.tensor2img(output_depadded)

                # Paste the tile at its position in the full-size canvas.
                out_top = top * scale
                out_left = left * scale
                y_all[out_top:out_top + split_lengthY * scale,
                      out_left:out_left + split_lengthX * scale, :] = \
                    np.round(output).astype(np.uint8)

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)), y_all)

            elapsed = time.time() - end
            print("*****************current image process time \t " +
                  str(elapsed) + "s ******************")
            total_run_time.update(elapsed, 1)

            # calculate PSNR
            y_all = y_all / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                y_all = data_util.bgr2ycbcr(y_all, only_y=True)

            y_all, GT = util.crop_border([y_all, GT], crop_border)
            crt_psnr = util.calculate_psnr(y_all * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(
                img_idx + 1, img_name, crt_psnr))

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        # Very short clips have no center frames; guard both divisions alike.
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    if avg_psnr_l:
        # Report the number of clips actually evaluated, not the number of
        # input folders found (some may have been skipped above).
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnr_l) / len(avg_psnr_l), len(avg_psnr_l),
                        sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                        sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
    else:
        logger.info('No clips were evaluated.')
示例#13
0
def main():
    """Run the training loop configured by ``options/train/train_degnet.json``.

    In order: parse the options, optionally load a resume state, set up the
    loggers and (when enabled) a tensorboardX writer, set the random seed,
    build the train/val dataloaders and the model, then iterate until
    ``opt['train']['niter']`` steps have run.  Every ``print_freq`` steps the
    current losses are logged; every ``val_freq`` steps the generated
    ``'FLR'`` visual is saved and compared against the ``'LR'`` visual
    (PSNR and SSIM); every ``save_checkpoint_freq`` steps a checkpoint is
    written.

    Raises:
        NotImplementedError: if ``opt['datasets']`` contains a phase other
            than ``'train'`` or ``'val'``.
        AssertionError: if no ``'train'`` dataset is configured.
    """
    opt = option.parse("options/train/train_degnet.json", is_train=True)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which return None for missing key.

    # train from scratch OR resume training
    if opt['path']['resume_state']:  # resuming training
        resume_state = torch.load(opt['path']['resume_state'])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None, opt['path']['log'], 'train', level=logging.INFO, screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options
    logger.info(option.dict2str(opt))

    # tensorboard logger (skipped for experiments whose name contains 'debug')
    use_tb_logger = opt['use_tb_logger'] and 'debug' not in opt['name']
    if use_tb_logger:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='./tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # The attribute was previously misspelled ("benckmark"), which only
    # created an unused attribute and left the cuDNN autotuner switched off.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    # Both start as None so that a missing phase is detected explicitly
    # instead of surfacing later as a NameError.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'],
                                                                      len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None, "A 'train' dataset must be configured."

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # update learning rate
            model.update_learning_rate()

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    if use_tb_logger:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # validation (only when a 'val' dataset was configured)
            if val_loader is not None and current_step % opt['train']['val_freq'] == 0:
                avg_psnr = 0.0
                avg_ssim = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)
                    model.feed_data(val_data)
                    model.test()
                    visuals = model.get_current_visuals()
                    # Here the generated image is 'FLR' and the reference it
                    # is scored against is the 'LR' visual.
                    lr_img = util.tensor2img(visuals['FLR'])  # uint8
                    gt_img = util.tensor2img(visuals['LR'])  # uint8

                    # Save fake LR
                    save_LR_path = opt['path']['root'] + 'validations/' + opt['name'] + '/{:d}/'.format(current_step)
                    util.mkdir(save_LR_path)
                    save_LR_img_path = os.path.join(save_LR_path, '{:s}.png'.format(img_name))
                    util.save_img(lr_img, save_LR_img_path)

                    # calculate PSNR / SSIM on the images with a border of
                    # ``scale`` pixels removed on every side
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    lr_img = lr_img / 255.
                    cropped_lr_img = lr_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_lr_img * 255, cropped_gt_img * 255)
                    avg_ssim += util.calculate_ssim(cropped_lr_img * 255, cropped_gt_img * 255)

                avg_psnr = avg_psnr / idx
                avg_ssim = avg_ssim / idx

                # log
                logger.info('# Validation # PSNR: {:.4e} SSIM: {:.4e}'.format(avg_psnr, avg_ssim))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e} ssim: {:.4e}'.format(
                    epoch, current_step, avg_psnr, avg_ssim))
                if use_tb_logger:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
                    tb_logger.add_scalar('ssim', avg_ssim, current_step)

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)
    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
示例#14
0
def main():
    """Run the model on every configured test set and report PSNR/SSIM.

    The option JSON (``-opt``, default ``options/test/test_ppon.json``) is
    parsed, a dataloader is built for each entry of ``opt['datasets']`` and,
    for every image, the three visuals ``img_c``, ``img_s`` and ``img_p`` are
    written to ``<results_root>/<dataset name>/``.  When the dataset provides
    ``dataroot_HR``, PSNR and SSIM of ``img_c`` against the ``HR`` visual are
    logged per image (additionally on the Y channel for 3-channel images),
    followed by the per-dataset averages.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, default='options/test/test_ppon.json', help='Path to options JSON file.')

    opt = option.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items() if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None, opt['path']['log'], 'test.log', level=logging.INFO, screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    # Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    # Create model
    model = create_model(opt)

    # Optional suffix inserted between the image name and the '_c/_s/_p' tag.
    suffix = opt['suffix']

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        # Constant for the whole dataset.  Computing it before the loop also
        # keeps it defined for the summary below when the loader is empty.
        need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

        for data in test_loader:
            model.feed_data(data, need_HR=need_HR)
            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)

            img_c = util.tensor2img(visuals['img_c'])  # uint8
            img_s = util.tensor2img(visuals['img_s'])  # uint8
            img_p = util.tensor2img(visuals['img_p'])  # uint8

            # save images
            save_stem = img_name + suffix if suffix else img_name
            util.save_img(img_c, os.path.join(dataset_dir, save_stem + '_c.png'))
            util.save_img(img_s, os.path.join(dataset_dir, save_stem + '_s.png'))
            util.save_img(img_p, os.path.join(dataset_dir, save_stem + '_p.png'))

            if not need_HR:
                logger.info(img_name)
                continue

            # calculate PSNR and SSIM (only img_c is scored)
            gt_img = util.tensor2img(visuals['HR'])
            gt_img = gt_img / 255.
            sr_img = img_c / 255.

            # A border of ``scale`` pixels is removed on every side.
            crop_border = test_loader.dataset.opt['scale']
            cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :]
            cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]

            psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
            ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)

            if gt_img.shape[2] == 3:  # RGB image
                sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border]
                cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border]
                psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
                logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'\
                    .format(img_name, psnr, ssim, psnr_y, ssim_y))
            else:
                logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))

        # metrics; skipped when nothing was scored to avoid dividing by zero
        if need_HR and test_results['psnr']:
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'\
                    .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'\
                    .format(ave_psnr_y, ave_ssim_y))
示例#15
0
def main():

    ############################################
    #
    #           set options
    #
    ############################################

    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    ############################################
    #
    #           distributed training settings
    #
    ############################################

    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        print("Rank:", rank)
        print("------------------DIST-------------------------")

    ############################################
    #
    #           loading resume state if exists
    #
    ############################################

    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    ############################################
    #
    #           mkdir and loggers
    #
    ############################################

    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists

            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)

        util.setup_logger('base_val',
                          opt['path']['log'],
                          'val_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)

        logger = logging.getLogger('base')
        logger_val = logging.getLogger('base_val')

        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_',
                          level=logging.INFO,
                          screen=True)

        print("set train log")

        util.setup_logger('base_val',
                          opt['path']['log'],
                          'val_',
                          level=logging.INFO,
                          screen=True)

        print("set val log")

        logger = logging.getLogger('base')

        logger_val = logging.getLogger('base_val')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    ############################################
    #
    #           create train and val dataloader
    #
    ############################################
    ####

    # dataset_ratio = 200  # enlarge the size of each epoch, todo: what it is
    dataset_ratio = 1  # enlarge the size of each epoch, todo: what it is
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            # total_iters = int(opt['train']['niter'])
            # total_epochs = int(math.ceil(total_iters / train_size))

            total_iters = train_size
            total_epochs = int(opt['train']['epoch'])

            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                # total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
                total_epochs = int(opt['train']['epoch'])
                if opt['train']['enable'] == False:
                    total_epochs = 1
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))

    assert train_loader is not None

    ############################################
    #
    #          create model
    #
    ############################################
    ####

    model = create_model(opt)

    print("Model Created! ")

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0
        print("Not Resume Training")

    ############################################
    #
    #          training
    #
    ############################################
    ####

    ####
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    Avg_train_loss = AverageMeter()  # total
    if (opt['train']['pixel_criterion'] == 'cb+ssim'):
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_ssim = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_ssim = AverageMeter()
        Avg_train_loss_vmaf = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'ssim'):
        Avg_train_loss_ssim = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'msssim'):
        Avg_train_loss_msssim = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_msssim = AverageMeter()

    saved_total_loss = 10e10
    saved_total_PSNR = -1

    for epoch in range(start_epoch, total_epochs):

        ############################################
        #
        #          Start a new epoch
        #
        ############################################

        # Turn into training mode
        #model = model.train()

        # reset total loss
        Avg_train_loss.reset()
        current_step = 0

        if (opt['train']['pixel_criterion'] == 'cb+ssim'):
            Avg_train_loss_pix.reset()
            Avg_train_loss_ssim.reset()
        elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
            Avg_train_loss_pix.reset()
            Avg_train_loss_ssim.reset()
            Avg_train_loss_vmaf.reset()
        elif (opt['train']['pixel_criterion'] == 'ssim'):
            Avg_train_loss_ssim = AverageMeter()
        elif (opt['train']['pixel_criterion'] == 'msssim'):
            Avg_train_loss_msssim = AverageMeter()
        elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
            Avg_train_loss_pix = AverageMeter()
            Avg_train_loss_msssim = AverageMeter()

        if opt['dist']:
            train_sampler.set_epoch(epoch)

        for train_idx, train_data in enumerate(train_loader):

            if 'debug' in opt['name']:

                img_dir = os.path.join(opt['path']['train_images'])
                util.mkdir(img_dir)

                LQ = train_data['LQs']
                GT = train_data['GT']

                GT_img = util.tensor2img(GT)  # uint8

                save_img_path = os.path.join(
                    img_dir, '{:4d}_{:s}.png'.format(train_idx, 'debug_GT'))
                util.save_img(GT_img, save_img_path)

                for i in range(5):
                    LQ_img = util.tensor2img(LQ[0, i, ...])  # uint8
                    save_img_path = os.path.join(
                        img_dir,
                        '{:4d}_{:s}_{:1d}.png'.format(train_idx, 'debug_LQ',
                                                      i))
                    util.save_img(LQ_img, save_img_path)

                if (train_idx >= 3):
                    break

            if opt['train']['enable'] == False:
                message_train_loss = 'None'
                break

            current_step += 1
            if current_step > total_iters:
                print("Total Iteration Reached !")
                break
            #### update learning rate
            if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
                pass
            else:
                model.update_learning_rate(
                    current_step, warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)

            # if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            #    model.optimize_parameters_without_schudlue(current_step)
            # else:
            model.optimize_parameters(current_step)

            if (opt['train']['pixel_criterion'] == 'cb+ssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
                Avg_train_loss_vmaf.update(model.log_dict['vmaf_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'ssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'msssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_msssim.update(model.log_dict['msssim_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_msssim.update(model.log_dict['msssim_loss'], 1)
            else:
                Avg_train_loss.update(model.log_dict['l_pix'], 1)

            # add total train loss
            if (opt['train']['pixel_criterion'] == 'cb+ssim'):
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(
                    Avg_train_loss_pix.avg)
                message_train_loss += ' ssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_ssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(
                    Avg_train_loss_pix.avg)
                message_train_loss += ' ssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_ssim.avg)
                message_train_loss += ' vmaf_avg_loss: {:.4e}'.format(
                    Avg_train_loss_vmaf.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'ssim'):
                message_train_loss = ' ssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_ssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'msssim'):
                message_train_loss = ' msssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_msssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(
                    Avg_train_loss_pix.avg)
                message_train_loss += ' msssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_msssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            else:
                message_train_loss = ' train_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)

                message += message_train_loss

                if rank <= 0:
                    logger.info(message)

        ############################################
        #
        #        end of one epoch, save epoch model
        #
        ############################################

        #### save models and training states
        # if current_step % opt['logger']['save_checkpoint_freq'] == 0:
        #     if rank <= 0:
        #         logger.info('Saving models and training states.')
        #         model.save(current_step)
        #         model.save('latest')
        #         # model.save_training_state(epoch, current_step)
        #         # todo delete previous weights
        #         previous_step = current_step - opt['logger']['save_checkpoint_freq']
        #         save_filename = '{}_{}.pth'.format(previous_step, 'G')
        #         save_path = os.path.join(opt['path']['models'], save_filename)
        #         if os.path.exists(save_path):
        #             os.remove(save_path)

        if epoch == 1:
            save_filename = '{:04d}_{}.pth'.format(0, 'G')
            save_path = os.path.join(opt['path']['models'], save_filename)
            if os.path.exists(save_path):
                os.remove(save_path)

        save_filename = '{:04d}_{}.pth'.format(epoch - 1, 'G')
        save_path = os.path.join(opt['path']['models'], save_filename)
        if os.path.exists(save_path):
            os.remove(save_path)

        if rank <= 0:
            logger.info('Saving models and training states.')
            save_filename = '{:04d}'.format(epoch)
            model.save(save_filename)
            # model.save('latest')
            # model.save_training_state(epoch, current_step)

        ############################################
        #
        #          end of one epoch, do validation
        #
        ############################################

        #### validation
        #if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
        if opt['datasets'].get('val', None):
            if opt['model'] in [
                    'sr', 'srgan'
            ] and rank <= 0:  # image restoration validation
                # does not support multi-GPU validation
                pbar = util.ProgressBar(len(val_loader))
                avg_psnr = 0.
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()
                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['rlt'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}.png'.format(img_name, current_step))
                    #util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                      opt['scale'])
                    avg_psnr += util.calculate_psnr(sr_img, gt_img)
                    pbar.update('Test {}'.format(img_name))

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
            else:  # video restoration validation
                if opt['dist']:
                    # todo : multi-GPU testing
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.

                    ssim_rlt = {}  # with border and center frames
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.

                    val_loss_rlt = {}
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.

                    if rank == 0:
                        pbar = util.ProgressBar(len(val_set))

                    for idx in range(rank, len(val_set), world_size):

                        print('idx', idx)

                        if 'debug' in opt['name']:
                            if (idx >= 3):
                                break

                        val_data = val_set[idx]
                        val_data['LQs'].unsqueeze_(0)
                        val_data['GT'].unsqueeze_(0)
                        folder = val_data['folder']
                        idx_d, max_idx = val_data['idx'].split('/')
                        idx_d, max_idx = int(idx_d), int(max_idx)

                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = torch.zeros(max_idx,
                                                           dtype=torch.float32,
                                                           device='cuda')

                        if ssim_rlt.get(folder, None) is None:
                            ssim_rlt[folder] = torch.zeros(max_idx,
                                                           dtype=torch.float32,
                                                           device='cuda')

                        if val_loss_rlt.get(folder, None) is None:
                            val_loss_rlt[folder] = torch.zeros(
                                max_idx, dtype=torch.float32, device='cuda')

                        # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
                        model.feed_data(val_data)
                        # model.test()
                        # model.test_stitch()

                        if opt['stitch'] == True:
                            model.test_stitch()
                        else:
                            model.test()  # large GPU memory

                        # visuals = model.get_current_visuals()
                        visuals = model.get_current_visuals(
                            save=True,
                            name='{}_{}'.format(folder, idx),
                            save_path=opt['path']['val_images'])

                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder][idx_d] = psnr

                        # calculate SSIM
                        ssim = util.calculate_ssim(rlt_img, gt_img)
                        ssim_rlt[folder][idx_d] = ssim

                        # calculate Val loss
                        val_loss = model.get_loss()
                        val_loss_rlt[folder][idx_d] = val_loss

                        logger.info(
                            '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format(
                                folder, idx, psnr, ssim))

                        if rank == 0:
                            for _ in range(world_size):
                                pbar.update('Test {} - {}/{}'.format(
                                    folder, idx_d, max_idx))

                    # # collect data
                    for _, v in psnr_rlt.items():
                        dist.reduce(v, 0)

                    for _, v in ssim_rlt.items():
                        dist.reduce(v, 0)

                    for _, v in val_loss_rlt.items():
                        dist.reduce(v, 0)

                    dist.barrier()

                    if rank == 0:
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)

                        # ssim
                        ssim_rlt_avg = {}
                        ssim_total_avg = 0.
                        for k, v in ssim_rlt.items():
                            ssim_rlt_avg[k] = torch.mean(v).cpu().item()
                            ssim_total_avg += ssim_rlt_avg[k]
                        ssim_total_avg /= len(ssim_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            ssim_total_avg)
                        for k, v in ssim_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)

                        # added
                        val_loss_rlt_avg = {}
                        val_loss_total_avg = 0.
                        for k, v in val_loss_rlt.items():
                            val_loss_rlt_avg[k] = torch.mean(v).cpu().item()
                            val_loss_total_avg += val_loss_rlt_avg[k]
                        val_loss_total_avg /= len(val_loss_rlt)
                        log_l = '# Validation # Loss: {:.4e}:'.format(
                            val_loss_total_avg)
                        for k, v in val_loss_rlt_avg.items():
                            log_l += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_l)

                        message = ''
                        for v in model.get_current_learning_rate():
                            message += '{:.5e}'.format(v)

                        logger_val.info(
                            'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}'
                            .format(epoch, message, psnr_total_avg,
                                    ssim_total_avg, message_train_loss,
                                    val_loss_total_avg))

                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)
                            # add val loss
                            tb_logger.add_scalar('val_loss_avg',
                                                 val_loss_total_avg,
                                                 current_step)
                            for k, v in val_loss_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)

                else:  # Todo: our function One GPU
                    pbar = util.ProgressBar(len(val_loader))
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.

                    ssim_rlt = {}  # with border and center frames
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.

                    val_loss_rlt = {}
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.

                    for val_inx, val_data in enumerate(val_loader):

                        if 'debug' in opt['name']:
                            if (val_inx >= 5):
                                break

                        folder = val_data['folder'][0]
                        # idx_d = val_data['idx'].item()
                        idx_d = val_data['idx']
                        # border = val_data['border'].item()
                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = []

                        if ssim_rlt.get(folder, None) is None:
                            ssim_rlt[folder] = []

                        if val_loss_rlt.get(folder, None) is None:
                            val_loss_rlt[folder] = []

                        # process the black blank [B N C H W]

                        print(val_data['LQs'].size())

                        H_S = val_data['LQs'].size(3)  # 540
                        W_S = val_data['LQs'].size(4)  # 960

                        print(H_S)
                        print(W_S)

                        blank_1_S = 0
                        blank_2_S = 0

                        print(val_data['LQs'][0, 2, 0, :, :].size())

                        for i in range(H_S):
                            if not sum(val_data['LQs'][0, 2, 0, i, :]) == 0:
                                blank_1_S = i - 1
                                # assert not sum(data_S[:, :, 0][i+1]) == 0
                                break

                        for i in range(H_S):
                            if not sum(val_data['LQs'][0, 2, 0, :,
                                                       H_S - i - 1]) == 0:
                                blank_2_S = (H_S - 1) - i - 1
                                # assert not sum(data_S[:, :, 0][blank_2_S-1]) == 0
                                break
                        print('LQ :', blank_1_S, blank_2_S)

                        if blank_1_S == -1:
                            print('LQ has no blank')
                            blank_1_S = 0
                            blank_2_S = H_S

                        # val_data['LQs'] = val_data['LQs'][:,:,:,blank_1_S:blank_2_S,:]

                        print("LQ", val_data['LQs'].size())

                        # end of process the black blank

                        model.feed_data(val_data)

                        if opt['stitch'] == True:
                            model.test_stitch()
                        else:
                            model.test()  # large GPU memory

                        # process blank

                        blank_1_L = blank_1_S << 2
                        blank_2_L = blank_2_S << 2
                        print(blank_1_L, blank_2_L)

                        print(model.fake_H.size())

                        if not blank_1_S == 0:
                            # model.fake_H = model.fake_H[:,:,blank_1_L:blank_2_L,:]
                            model.fake_H[:, :, 0:blank_1_L, :] = 0
                            model.fake_H[:, :, blank_2_L:H_S, :] = 0

                        # end of process blank

                        visuals = model.get_current_visuals(
                            save=True,
                            name='{}_{:02d}'.format(folder, val_inx),
                            save_path=opt['path']['val_images'])

                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder].append(psnr)

                        # calculate SSIM
                        ssim = util.calculate_ssim(rlt_img, gt_img)
                        ssim_rlt[folder].append(ssim)

                        # val loss
                        val_loss = model.get_loss()
                        val_loss_rlt[folder].append(val_loss.item())

                        logger.info(
                            '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format(
                                folder, val_inx, psnr, ssim))

                        pbar.update('Test {} - {}'.format(folder, idx_d))

                    # average PSNR
                    for k, v in psnr_rlt.items():
                        psnr_rlt_avg[k] = sum(v) / len(v)
                        psnr_total_avg += psnr_rlt_avg[k]
                    psnr_total_avg /= len(psnr_rlt)
                    log_s = '# Validation # PSNR: {:.4e}:'.format(
                        psnr_total_avg)
                    for k, v in psnr_rlt_avg.items():
                        log_s += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_s)

                    # average SSIM
                    for k, v in ssim_rlt.items():
                        ssim_rlt_avg[k] = sum(v) / len(v)
                        ssim_total_avg += ssim_rlt_avg[k]
                    ssim_total_avg /= len(ssim_rlt)
                    log_s = '# Validation # SSIM: {:.4e}:'.format(
                        ssim_total_avg)
                    for k, v in ssim_rlt_avg.items():
                        log_s += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_s)

                    # average VMAF

                    # average Val LOSS
                    for k, v in val_loss_rlt.items():
                        val_loss_rlt_avg[k] = sum(v) / len(v)
                        val_loss_total_avg += val_loss_rlt_avg[k]
                    val_loss_total_avg /= len(val_loss_rlt)
                    log_l = '# Validation # Loss: {:.4e}:'.format(
                        val_loss_total_avg)
                    for k, v in val_loss_rlt_avg.items():
                        log_l += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_l)

                    # total validation log

                    message = ''
                    for v in model.get_current_learning_rate():
                        message += '{:.5e}'.format(v)

                    logger_val.info(
                        'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}'
                        .format(epoch, message, psnr_total_avg, ssim_total_avg,
                                message_train_loss, val_loss_total_avg))

                    # end add

                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                             current_step)
                        for k, v in psnr_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)
                        # tb_logger.add_scalar('ssim_avg', ssim_total_avg, current_step)
                        # for k, v in ssim_rlt_avg.items():
                        #     tb_logger.add_scalar(k, v, current_step)
                        # add val loss
                        tb_logger.add_scalar('val_loss_avg',
                                             val_loss_total_avg, current_step)
                        for k, v in val_loss_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)

            ############################################
            #
            #          end of validation, save model
            #
            ############################################
            #
            logger.info("Finished an epoch, Check and Save the model weights")
            # we check the validation loss instead of training loss. OK~
            if saved_total_loss >= val_loss_total_avg:
                saved_total_loss = val_loss_total_avg
                #torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
                model.save('best')
                logger.info(
                    "Best Weights updated for decreased validation loss")

            else:
                logger.info(
                    "Weights Not updated for undecreased validation loss")

            if saved_total_PSNR <= psnr_total_avg:
                saved_total_PSNR = psnr_total_avg
                model.save('bestPSNR')
                logger.info(
                    "Best Weights updated for increased validation PSNR")

            else:
                logger.info(
                    "Weights Not updated for unincreased validation PSNR")

        ############################################
        #
        #          end of one epoch, schedule LR
        #
        ############################################

        # add scheduler  todo
        if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            for scheduler in model.schedulers:
                # scheduler.step(val_loss_total_avg)
                scheduler.step(val_loss_total_avg)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('last')
        logger.info('End of training.')
        tb_logger.close()
示例#16
0
def val(model_name, current_step, arch='EDVR'):
    """Evaluate a saved generator checkpoint and return its mean PSNR.

    The checkpoint ``../experiments/<model_name>/models/<current_step>_G.pth``
    is loaded into the selected architecture and run over every clip of the
    test set configured below. Each output frame is compared with its ground
    truth frame; per-clip and overall PSNR values are written to the ``base``
    logger (screen and a log file under ``../validation/<test_set>``).

    Args:
        model_name: Experiment directory name under ``../experiments/``.
        current_step: Value formatted into the checkpoint file name
            (``'{}_G.pth'``).
        arch: Network to instantiate, ``'EDVR'`` or ``'MY_EDVR'``.

    Returns:
        The mean, over all evaluated clips, of each clip's average PSNR.

    Raises:
        ValueError: If ``arch`` or the configured test set is not supported,
            or if no clip was evaluated.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    # Vid4 | YouKu10 | YouKu_val | REDS4 | AI4K_val | AI4K_val_small
    test_set = 'REDS4'
    data_mode = 'sharp_bicubic'  # sharp_bicubic | blur_bicubic
    N_in = 5

    # load test set
    if test_set == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    elif test_set == 'YouKu10':
        test_dataset_folder = '../datasets/YouKu10/LR'
        GT_dataset_folder = '../datasets/YouKu10/HR'
    elif test_set == 'YouKu_val':
        test_dataset_folder = '/data0/yhliu/DATA/YouKuVid/valid/valid_lr_bmp'
        GT_dataset_folder = '/data0/yhliu/DATA/YouKuVid/valid/valid_hr_bmp'
    elif test_set == 'REDS4':
        test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        GT_dataset_folder = '../datasets/REDS4/GT'
    elif test_set == 'AI4K_val':
        test_dataset_folder = '/data0/yhliu/AI4K/contest1/val1_LR_png/'
        GT_dataset_folder = '/data0/yhliu/AI4K/contest1/val1_HR_png/'
    elif test_set == 'AI4K_val_small':
        test_dataset_folder = '/home/yhliu/AI4K/contest1/val1_LR_png_small/'
        GT_dataset_folder = '/home/yhliu/AI4K/contest1/val1_HR_png_small/'
    else:
        # Fail here instead of with a NameError on the first folder access.
        raise ValueError('Unsupported test set: {}'.format(test_set))

    flip_test = False

    model_path = os.path.join('../experiments/', model_name,
                              'models/{}_G.pth'.format(current_step))

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True

    if arch == 'EDVR':
        model = EDVR_arch.EDVR(64,
                               N_in,
                               8,
                               5,
                               back_RBs,
                               predeblur=predeblur,
                               HR_in=HR_in)
    elif arch == 'MY_EDVR':
        model = my_EDVR_arch.MYEDVR(64,
                                    N_in,
                                    8,
                                    5,
                                    back_RBs,
                                    predeblur=predeblur,
                                    HR_in=HR_in)
    else:
        # Without this, `model` stays unbound and the failure surfaces later.
        raise ValueError('Unsupported arch: {}'.format(arch))

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    # NOTE(review): data_mode is never set to 'Vid4' in this function; the
    # check may have been meant for test_set - confirm before changing.
    if data_mode in ('Vid4', 'sharp_bicubic'):
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = False

    save_folder = '../validation/{}'.format(test_set)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    # Use at most four GPUs (the original fixed set), but never more than the
    # host actually exposes, so the function also runs on smaller machines.
    n_gpus = min(4, torch.cuda.device_count())
    model = nn.DataParallel(model, device_ids=list(range(n_gpus)))

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))

    # for each subfolder (LQ and GT clips are paired by sorted order)
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = [
            data_util.read_img(None, img_GT_path)
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))
        ]
        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # PSNR is computed on all channels; the Y-channel-only evaluation
            # for Vid4 (data_util.bgr2ycbcr(..., only_y=True)) is disabled.

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        # Clips with at most 2 * border_frame frames have no center frames;
        # report 0 for them, mirroring the existing guard for border frames.
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    if not avg_psnr_l:
        # Replace the bare ZeroDivisionError of the averages below with an
        # error that points at the likely cause.
        raise ValueError('No clips were evaluated; check {} and {}'.format(
            test_dataset_folder, GT_dataset_folder))

    total_avg_psnr = sum(avg_psnr_l) / len(avg_psnr_l)

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    total_avg_psnr, len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))

    return total_avg_psnr
示例#17
0
def main():
    """Run the LunaTokis model over every clip of the configured test folder.

    For each sub-folder matched by the dataset glob, the frames are read,
    fed to the network in the index groups produced by
    ``util.test_index_generation``, and the resulting frames are optionally
    written as PNG files under ``../results/<data_mode>/<clip>``. For data
    modes other than ``'Custom'`` the outputs are also compared with ground
    truth frames and PSNR / PSNR-Y values are logged per frame, per clip and
    overall.

    Raises:
        ValueError: If ``data_mode`` is not one of the supported modes.
    """
    scale = 4
    N_ot = 5  # 3
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # model
    # TODO: change your model path here
    model_path = '../experiments/pretrained_models/xiang2020zooming.pth'
    model = Sakuya_arch.LunaTokis(64, N_ot, 8, 5, 40)

    # dataset
    data_mode = 'Custom'  # 'Vid4' #'SPMC'#'Middlebury'#

    if data_mode == 'Vid4':
        test_dataset_folder = '/data/xiang/SR/Vid4/LR/*'
    elif data_mode == 'SPMC':
        test_dataset_folder = '/data/xiang/SR/spmc/*'
    elif data_mode == 'Custom':
        test_dataset_folder = '../test_example/*'  # TODO: put your own data path here
    else:
        # Previously an unsupported mode left the folder unbound and failed
        # later with a NameError.
        raise ValueError('Unsupported data mode: {}'.format(data_mode))

    # evaluation
    flip_test = False  # True#
    crop_border = 0

    # temporal padding mode
    padding = 'replicate'
    save_imgs = False  # True#
    if 'Custom' in data_mode:
        save_imgs = True
    ############################################################################
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')
    model_params = util.get_model_total_params(model)

    # log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Model parameters: {} M'.format(model_params))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip Test: {}'.format(flip_test))

    def single_forward(model, imgs_in):
        """Run one forward pass and crop the result to ``scale`` x input size."""
        with torch.no_grad():
            # imgs_in.size(): [1,n,3,h,w]
            b, n, c, h, w = imgs_in.size()
            # Zero-pad bottom/right so both spatial sizes are multiples of 4.
            h_n = int(4 * np.ceil(h / 4))
            w_n = int(4 * np.ceil(w / 4))
            imgs_temp = imgs_in.new_zeros(b, n, c, h_n, w_n)
            imgs_temp[:, :, :, 0:h, 0:w] = imgs_in
            t0 = time.time()
            model_output = model(imgs_temp)
            t1 = time.time()
            logger.info('Single Test time: {}'.format(t1 - t0))
            # Unwrap list/tuple outputs BEFORE slicing: the tensor-style
            # slice below would raise TypeError on a list or tuple.
            if isinstance(model_output, (list, tuple)):
                model_output = model_output[0]
            # Drop the padded area; the five-index slice implies a 5-D output
            # (presumably [1, n_out, 3, 4h, 4w] - TODO confirm).
            output = model_output[:, :, :, 0:scale * h, 0:scale * w]
        return output

    def flipped_forward(imgs_in, dims):
        """Forward a flipped input, undo the flip, and return a CPU tensor."""
        flipped = single_forward(model, torch.flip(imgs_in, dims))
        return torch.flip(flipped, dims).data.float().cpu().squeeze(0)

    sub_folder_l = sorted(glob.glob(test_dataset_folder))

    # Load onto CPU first so CUDA-saved checkpoints also work with the CPU
    # fallback above; the model is moved to `device` right afterwards.
    model.load_state_dict(torch.load(model_path, map_location='cpu'),
                          strict=True)

    model.eval()
    model = model.to(device)

    avg_psnr_l = []
    avg_psnr_y_l = []
    sub_folder_name_l = []
    # for each sub-folder
    for sub_folder in sub_folder_l:
        gt_tested_list = []  # frame indices already saved / scored
        sub_folder_name = osp.basename(sub_folder)
        sub_folder_name_l.append(sub_folder_name)
        save_sub_folder = osp.join(save_folder, sub_folder_name)

        if data_mode == 'SPMC':
            sub_folder = sub_folder + '/LR/'
        img_LR_l = sorted(glob.glob(sub_folder + '/*'))

        if save_imgs:
            util.mkdirs(save_sub_folder)

        # read LR images
        imgs = util.read_seq_imgs(sub_folder)
        # read GT images
        img_GT_l = []
        if data_mode == 'SPMC':
            sub_folder_GT = osp.join(sub_folder.replace('/LR/', '/truth/'))
        else:
            sub_folder_GT = osp.join(sub_folder.replace('/LR/', '/HR/'))

        if 'Custom' not in data_mode:
            for img_GT_path in sorted(glob.glob(osp.join(sub_folder_GT, '*'))):
                img_GT_l.append(util.read_image(img_GT_path))

        avg_psnr, avg_psnr_sum, cal_n = 0, 0, 0
        avg_psnr_y, avg_psnr_sum_y = 0, 0

        # `skip` is set when there are as many LR frames as GT frames; its
        # meaning is defined by util.test_index_generation.
        skip = len(img_LR_l) == len(img_GT_l)

        if 'Custom' in data_mode:
            select_idx_list = util.test_index_generation(
                False, N_ot, len(img_LR_l))
        else:
            select_idx_list = util.test_index_generation(
                skip, N_ot, len(img_LR_l))
        # process each image
        for select_idxs in select_idx_list:
            # get input images: each entry is (input indices, output indices)
            select_idx = select_idxs[0]
            gt_idx = select_idxs[1]
            imgs_in = imgs.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            output = single_forward(model, imgs_in)
            print(select_idx, gt_idx, imgs_in.size(), output.size())
            outputs = output.data.float().cpu().squeeze(0)

            if flip_test:
                # Average the plain result with the W-, H- and HW-flipped
                # results (same order as before: W, H, both).
                for dims in ((-1, ), (-2, ), (-2, -1)):
                    outputs = outputs + flipped_forward(imgs_in, dims)

                outputs = outputs / 4

            # save imgs
            for idx, name_idx in enumerate(gt_idx):
                # A frame index can appear in several groups; handle it once.
                if name_idx in gt_tested_list:
                    continue
                gt_tested_list.append(name_idx)
                output_f = outputs[idx, :, :, :].squeeze(0)

                output = util.tensor2img(output_f)
                if save_imgs:
                    cv2.imwrite(
                        osp.join(save_sub_folder,
                                 '{:08d}.png'.format(name_idx + 1)), output)

                if 'Custom' not in data_mode:
                    # calculate PSNR
                    output = output / 255.

                    GT = np.copy(img_GT_l[name_idx])

                    if crop_border == 0:
                        cropped_output = output
                        cropped_GT = GT
                    else:
                        cropped_output = output[crop_border:-crop_border,
                                                crop_border:-crop_border, :]
                        cropped_GT = GT[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                    crt_psnr = util.calculate_psnr(cropped_output * 255,
                                                   cropped_GT * 255)
                    cropped_GT_y = data_util.bgr2ycbcr(cropped_GT, only_y=True)
                    cropped_output_y = data_util.bgr2ycbcr(cropped_output,
                                                           only_y=True)
                    crt_psnr_y = util.calculate_psnr(cropped_output_y * 255,
                                                     cropped_GT_y * 255)
                    logger.info(
                        '{:3d} - {:25}.png \tPSNR: {:.6f} dB  PSNR-Y: {:.6f} dB'
                        .format(name_idx + 1, name_idx + 1, crt_psnr,
                                crt_psnr_y))
                    avg_psnr_sum += crt_psnr
                    avg_psnr_sum_y += crt_psnr_y
                    cal_n += 1

        if 'Custom' not in data_mode:
            avg_psnr = avg_psnr_sum / cal_n
            avg_psnr_y = avg_psnr_sum_y / cal_n

            logger.info(
                'Folder {} - Average PSNR: {:.6f} dB PSNR-Y: {:.6f} dB for {} frames; '
                .format(sub_folder_name, avg_psnr, avg_psnr_y, cal_n))

            avg_psnr_l.append(avg_psnr)
            avg_psnr_y_l.append(avg_psnr_y)

    if 'Custom' not in data_mode:
        logger.info('################ Tidy Outputs ################')
        for name, psnr, psnr_y in zip(sub_folder_name_l, avg_psnr_l,
                                      avg_psnr_y_l):
            logger.info(
                'Folder {} - Average PSNR: {:.6f} dB PSNR-Y: {:.6f} dB. '.
                format(name, psnr, psnr_y))
        logger.info('################ Final Results ################')
        logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip Test: {}'.format(flip_test))
        logger.info(
            'Total Average PSNR: {:.6f} dB PSNR-Y: {:.6f} dB for {} clips. '.
            format(
                sum(avg_psnr_l) / len(avg_psnr_l),
                sum(avg_psnr_y_l) / len(avg_psnr_y_l), len(sub_folder_l)))
示例#18
0
             pixel_STD = np.mean(pixels_STDs[-1])
         else:
             # normalized_pixel_STD = 0
             pixel_STD = 0
         # Save GT image for reference:
         util.save_img(
             (255 * util.tensor2img(
                 visuals['HR'], out_type=np.float32)).astype(
                     np.uint8),
             os.path.join(
                 dataset_dir,
                 img_name + '_HR_STD%.3f_SR_STD%.3f.png' %
                 (HR_STD, pixel_STD)))
 sr_img *= 255.
 if LATENT_DISTRIBUTION in NON_ARBITRARY_Z_INPUTS or cur_channel_cur_Z == 0:
     psnr = util.calculate_psnr(sr_img, gt_img)
     ssim = util.calculate_ssim(sr_img, gt_img)
     test_results['psnr'].append(psnr)
     test_results['ssim'].append(ssim)
 if SAVE_IMAGE_COLLAGE:
     if len(test_set) > 1:
         margins2crop = ((np.array(sr_img.shape[:2]) -
                          per_image_saved_patch) / 2).astype(
                              np.int32)
     else:
         margins2crop = [0, 0]
     image_collage[-1].append(
         np.clip(util.crop_center(sr_img, margins2crop), 0,
                 255).astype(np.uint8))
     if LATENT_DISTRIBUTION in NON_ARBITRARY_Z_INPUTS or cur_channel_cur_Z == 0:
         # Save GT HR images:
示例#19
0
            # save images
            img_part_name = '_coef_{:.2f}.png'.format(coef)
            suffix = opt['suffix']
            if suffix:
                save_img_path = osp.join(img_dir,
                                         img_name + suffix + img_part_name)
            else:
                save_img_path = osp.join(img_dir, img_name + img_part_name)
            util.save_img(sr_img, save_img_path)

            # calculate PSNR and SSIM
            if need_GT:
                gt_img = util.tensor2img(visuals['GT'])
                sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                  opt['scale'])
                psnr = util.calculate_psnr(sr_img, gt_img)
                ssim = util.calculate_ssim(sr_img, gt_img)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)

                    psnr_y = util.calculate_psnr(sr_img_y * 255,
                                                 gt_img_y * 255)
                    ssim_y = util.calculate_ssim(sr_img_y * 255,
                                                 gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info(
def main():
    #################
    # configurations
    #################
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic (REDS)

    prog = argparse.ArgumentParser()
    prog.add_argument('--train_mode',
                      '-t',
                      type=str,
                      default='Vimeo',
                      help='train mode')
    prog.add_argument('--data_mode',
                      '-m',
                      type=str,
                      default=None,
                      help='data_mode')
    prog.add_argument('--degradation_mode',
                      '-d',
                      type=str,
                      default='impulse',
                      choices=('impulse', 'bicubic', 'preset'),
                      help='path to image output directory.')
    prog.add_argument('--sigma_x',
                      '-sx',
                      type=float,
                      default=1,
                      help='sigma_x')
    prog.add_argument('--sigma_y',
                      '-sy',
                      type=float,
                      default=0,
                      help='sigma_y')
    prog.add_argument('--theta', '-th', type=float, default=0, help='theta')

    args = prog.parse_args()

    train_mode = args.train_mode
    data_mode = args.data_mode
    if data_mode is None:
        if train_mode == 'Vimeo':
            data_mode = 'Vid4'
        elif train_mode == 'REDS':
            data_mode = 'REDS'
    degradation_mode = args.degradation_mode  # impulse | bicubic | preset
    sig_x, sig_y, the = args.sigma_x, args.sigma_y, args.theta
    if sig_y == 0:
        sig_y = sig_x

    # model
    scale = 2
    N_in = 7
    # model_path = '../experiments/pretrained_models/TOF_official.pth'
    # model_path = '../experiments/Pretrain_FT_TOF_V_S2/models/latest_G.pth'
    model_path = '../experiments/pretrained_models/TOF_{}_S2.pth'.format(
        train_mode)
    adapt_official = True  # if 'official' in model_path else False
    model = TOF_arch.TOFlow(adapt_official=adapt_official)

    #### dataset
    folder_subname = 'preset' if degradation_mode == 'preset' else degradation_mode + '_' + str(
        '{:.1f}'.format(sig_x)) + '_' + str(
            '{:.1f}'.format(sig_y)) + '_' + str('{:.1f}'.format(the))

    # folder_subname = degradation_mode + '_' + str('{:.1f}'.format(sig_x)) + '_' + str('{:.1f}'.format(sig_y)) + '_' + str('{:.1f}'.format(the))
    if data_mode == 'Vid4':
        # test_dataset_folder = '../dataset/Vid4/LR_bicubic/X{}'.format(scale)
        test_dataset_folder = '../dataset/Vid4/LR_{}/X{}'.format(
            folder_subname, scale)
        GT_dataset_folder = '../dataset/Vid4/HR'
    elif data_mode == 'MM522':
        test_dataset_folder = '../dataset/MM522val/LR_bicubic/X{}'.format(
            scale)
        GT_dataset_folder = '../dataset/MM522val/HR'
    else:
        # test_dataset_folder = '../dataset/REDS4/LR_bicubic/X{}'.format(scale)
        test_dataset_folder = '../dataset/REDS/train/LR_{}/X{}'.format(
            folder_subname, scale)
        GT_dataset_folder = '../dataset/REDS/train/HR'
    '''
    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../dataset/Vid4/LR_bicubic/X{}'.format(scale)
        GT_dataset_folder = '../dataset/Vid4/HR'
    elif data_mode == 'MM522':
        test_dataset_folder = '../dataset/MM522val/LR_bicubic/X{}'.format(scale)
        GT_dataset_folder = '../dataset/MM522val/HR'
    else:
        test_dataset_folder = '../dataset/REDS4/LR_bicubic/X{}'.format(scale)
        GT_dataset_folder = '../dataset/REDS4/HR'
    '''
    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    padding = 'new_info'  # different from the official setting
    save_imgs = False  #True
    ############################################################################
    device = torch.device('cuda')
    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))

    def read_image(img_path):
        '''read one image from img_path
        Return img: HWC, BGR, [0,1], numpy
        '''
        img_GT = cv2.imread(img_path)
        img = img_GT.astype(np.float32) / 255.
        return img

    def read_seq_imgs(img_seq_path):
        '''read a sequence of images'''
        img_path_l = sorted(glob.glob(img_seq_path + '/*'))
        img_l = [read_image(v) for v in img_path_l]
        # stack to TCHW, RGB, [0,1], torch
        imgs = np.stack(img_l, axis=0)
        imgs = imgs[:, :, :, [2, 1, 0]]
        imgs = torch.from_numpy(
            np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
        return imgs

    def index_generation(crt_i, max_n, N, padding='reflection'):
        '''
        padding: replicate | reflection | new_info | circle
        '''
        max_n = max_n - 1
        n_pad = N // 2
        return_l = []

        for i in range(crt_i - n_pad, crt_i + n_pad + 1):
            if i < 0:
                if padding == 'replicate':
                    add_idx = 0
                elif padding == 'reflection':
                    add_idx = -i
                elif padding == 'new_info':
                    add_idx = (crt_i + n_pad) + (-i)
                elif padding == 'circle':
                    add_idx = N + i
                else:
                    raise ValueError('Wrong padding mode')
            elif i > max_n:
                if padding == 'replicate':
                    add_idx = max_n
                elif padding == 'reflection':
                    add_idx = max_n * 2 - i
                elif padding == 'new_info':
                    add_idx = (crt_i - n_pad) - (i - max_n)
                elif padding == 'circle':
                    add_idx = i - N
                else:
                    raise ValueError('Wrong padding mode')
            else:
                add_idx = i
            return_l.append(add_idx)
        return return_l

    def single_forward(model, imgs_in):
        with torch.no_grad():
            model_output = model(imgs_in)
            if isinstance(model_output, list) or isinstance(
                    model_output, tuple):
                output = model_output[0]
            else:
                output = model_output
        return output

    sub_folder_l = sorted(glob.glob(test_dataset_folder + '/*'))
    sub_folder_GT_l = sorted(glob.glob(GT_dataset_folder + '/*'))
    if data_mode == 'REDS':
        sub_folder_GT_l = [
            k for k in sub_folder_GT_l if k.find('000') >= 0
            or k.find('011') >= 0 or k.find('015') >= 0 or k.find('020') >= 0
        ]
    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    print('Eval')
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], []

    sub_folder_name_l = []

    # for each sub-folder
    for sub_folder, sub_folder_GT in zip(sub_folder_l, sub_folder_GT_l):
        sub_folder_name = sub_folder.split('/')[-1]
        sub_folder_name_l.append(sub_folder_name)
        save_sub_folder = osp.join(save_folder, sub_folder_name)

        img_path_l = sorted(glob.glob(sub_folder + '/*'))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_sub_folder)

        #### read LR images
        imgs = read_seq_imgs(sub_folder)

        #### read GT images
        img_GT_l = []
        '''
        if data_mode == 'Vid4':
            sub_folder_GT = osp.join(sub_folder.replace('/BIx4up_direct/', '/GT/'), '*')
        else:
            sub_folder_GT = osp.join(sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*')
        '''
        for img_GT_path in sorted(glob.glob(sub_folder_GT + '/*')):
            img_GT_l.append(read_image(img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
        avg_ssim, avg_ssim_border, avg_ssim_center = 0, 0, 0
        cal_n_border, cal_n_center = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            c_idx = int(osp.splitext(osp.basename(img_path))[0])
            select_idx = index_generation(c_idx,
                                          max_idx,
                                          N_in,
                                          padding=padding)
            # get input images
            imgs_in = imgs.index_select(
                0, torch.LongTensor(select_idx)).to(device)
            # In TOFLOW we have to upscale the image to HR scale
            imgs_in = F.interpolate(imgs_in,
                                    scale_factor=scale,
                                    mode='bicubic',
                                    align_corners=False)
            #imgs_in = F.interpolate(imgs_in, scale_factor=scale, mode='nearest')
            imgs_in.unsqueeze_(0)
            output = single_forward(model, imgs_in)
            output_f = output.data.float().cpu().squeeze(0)

            output = util.tensor2img(output_f)

            # save imgs
            if save_imgs:
                cv2.imwrite(
                    osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)),
                    output)

            #### calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])

            # For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels
            '''
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT)
                output = data_util.bgr2ycbcr(output)
            if crop_border == 0:
                cropped_output = output
                cropped_GT = GT
            else:
                cropped_output = output[crop_border:-crop_border, crop_border:-crop_border]
                cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border]
            '''
            output = (output * 255).round().astype('uint8')
            GT = (GT * 255).round().astype('uint8')
            crt_psnr = util.calculate_psnr(output, GT)
            crt_ssim = 0.001  #util.calculate_ssim(output, GT)
            #crt_ssim = 0.05

            # logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(img_idx + 1, c_idx, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                avg_ssim_center += crt_ssim
                cal_n_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                avg_ssim_border += crt_ssim
                cal_n_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center +
                                                          cal_n_border)
        avg_psnr_center = avg_psnr_center / cal_n_center
        if cal_n_border == 0:
            avg_psnr_border = 0
        else:
            avg_psnr_border = avg_psnr_border / cal_n_border

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        sub_folder_name, avg_psnr,
                        (cal_n_center + cal_n_border), avg_psnr_center,
                        cal_n_center, avg_psnr_border, cal_n_border))

        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        avg_ssim = (avg_ssim_center + avg_ssim_border) / (cal_n_center +
                                                          cal_n_border)
        avg_ssim_center = avg_ssim_center / cal_n_center
        avg_ssim_border = 0 if cal_n_border == 0 else avg_ssim_border / cal_n_border
        avg_ssim_l.append(avg_ssim)
        avg_ssim_center_l.append(avg_ssim_center)
        avg_ssim_border_l.append(avg_ssim_border)

        logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; '
                    'Center SSIM: {:.6f} for {} frames; '
                    'Border SSIM: {:.6f} for {} frames.'.format(
                        sub_folder_name, avg_ssim,
                        (cal_n_center + cal_n_border), avg_ssim_center,
                        cal_n_center, avg_ssim_border, cal_n_border))
    '''
    logger.info('################ Tidy Outputs ################')
    for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l,
                                                    avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border))
    '''

    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
    logger.info('Total Average SSIM: {:.6f} for {} clips. '
                'Center SSIM: {:.6f}. Border PSNR: {:.6f}.'.format(
                    sum(avg_ssim_l) / len(avg_ssim_l), len(sub_folder_l),
                    sum(avg_ssim_center_l) / len(avg_ssim_center_l),
                    sum(avg_ssim_border_l) / len(avg_ssim_border_l)))
示例#21
0
def main():
    """Evaluate a pretrained TOFlow model on Vid4 or on a REDS4 subset.

    Command line: ``dataset n_frames``.  ``dataset`` is ``Vid4`` or the name
    of a sub-directory of ``../datasets/REDS4``; ``n_frames`` is the number of
    frames passed to the model for each output frame.

    For every clip under the test dataset folder the model output is
    optionally saved as PNG under ``../results/<dataset>/<clip>/`` and its
    PSNR against the ground truth is logged per frame, per clip (split into
    center and border frames) and averaged over all clips.
    """
    #################
    # configurations
    #################

    ####################
    # arguments parser #
    ####################
    #  [format] dataset(vid4, REDS4) N(number of frames)

    parser = argparse.ArgumentParser()

    parser.add_argument('dataset')
    # Let argparse convert the value so that a non-integer argument produces
    # a usage error instead of a traceback.
    parser.add_argument('n_frames', type=int)

    args = parser.parse_args()

    data_mode = str(args.dataset)
    N_in = args.n_frames

    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # model
    model_path = '../experiments/pretrained_models/TOF_official.pth'
    adapt_official = 'official' in model_path
    model = TOF_arch.TOFlow(adapt_official=adapt_official)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4up_direct/*'
    else:
        test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode)

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    padding = 'new_info'  # different from the official setting
    save_imgs = True
    ############################################################################
    device = torch.device('cuda')
    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))

    def read_image(img_path):
        '''read one image from img_path
        Return img: HWC, BGR, [0,1], numpy
        '''
        img_GT = cv2.imread(img_path)
        img = img_GT.astype(np.float32) / 255.
        return img

    def read_seq_imgs(img_seq_path):
        '''read every image of the folder img_seq_path, in sorted file order
        Return imgs: TCHW, RGB, [0,1], torch float tensor
        '''
        img_path_l = sorted(glob.glob(img_seq_path + '/*'))
        img_l = [read_image(v) for v in img_path_l]
        # stack to TCHW, RGB, [0,1], torch
        imgs = np.stack(img_l, axis=0)
        imgs = imgs[:, :, :, [2, 1, 0]]  # BGR -> RGB
        imgs = torch.from_numpy(
            np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
        return imgs

    def index_generation(crt_i, max_n, N, padding='reflection'):
        '''Return the N frame indices centred on crt_i for a clip of max_n frames.

        Indices of the window that fall outside [0, max_n - 1] are replaced
        according to padding:
            replicate:  clamp to the first / last frame.
            reflection: mirror around the first / last frame.
            new_info:   use frames lying beyond the opposite end of the window.
            circle:     shift the index by N (N + i below 0, i - N past the end).

        Raises ValueError when an out-of-range index is met with any other
        padding mode.
        '''
        max_n = max_n - 1  # from here on: index of the last frame
        n_pad = N // 2
        return_l = []

        for i in range(crt_i - n_pad, crt_i + n_pad + 1):
            if i < 0:
                if padding == 'replicate':
                    add_idx = 0
                elif padding == 'reflection':
                    add_idx = -i
                elif padding == 'new_info':
                    add_idx = (crt_i + n_pad) + (-i)
                elif padding == 'circle':
                    add_idx = N + i
                else:
                    raise ValueError('Wrong padding mode')
            elif i > max_n:
                if padding == 'replicate':
                    add_idx = max_n
                elif padding == 'reflection':
                    add_idx = max_n * 2 - i
                elif padding == 'new_info':
                    add_idx = (crt_i - n_pad) - (i - max_n)
                elif padding == 'circle':
                    add_idx = i - N
                else:
                    raise ValueError('Wrong padding mode')
            else:
                add_idx = i
            return_l.append(add_idx)
        return return_l

    def single_forward(model, imgs_in):
        '''Run model on imgs_in without gradients.

        When the model returns a list or tuple only its first element is
        returned.
        '''
        with torch.no_grad():
            model_output = model(imgs_in)
            if isinstance(model_output, (list, tuple)):
                output = model_output[0]
            else:
                output = model_output
        return output

    sub_folder_l = sorted(glob.glob(test_dataset_folder))
    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    sub_folder_name_l = []

    # for each sub-folder
    for sub_folder in sub_folder_l:
        sub_folder_name = osp.basename(sub_folder)
        sub_folder_name_l.append(sub_folder_name)
        save_sub_folder = osp.join(save_folder, sub_folder_name)

        img_path_l = sorted(glob.glob(sub_folder + '/*'))
        max_idx = len(img_path_l)

        if save_imgs:
            util.mkdirs(save_sub_folder)

        #### read LR images
        imgs = read_seq_imgs(sub_folder)
        #### read GT images
        img_GT_l = []
        if data_mode == 'Vid4':
            sub_folder_GT = osp.join(
                sub_folder.replace('/BIx4up_direct/', '/GT/'), '*')
        else:
            sub_folder_GT = osp.join(
                sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*')
        for img_GT_path in sorted(glob.glob(sub_folder_GT)):
            img_GT_l.append(read_image(img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
        cal_n_border, cal_n_center = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            # The window centre is parsed from the file name, so file names
            # are presumably integers that index the sorted sequence
            # -- TODO confirm against the dataset layout.
            c_idx = int(osp.splitext(osp.basename(img_path))[0])
            select_idx = index_generation(c_idx,
                                          max_idx,
                                          N_in,
                                          padding=padding)
            # get input images
            imgs_in = imgs.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
            output = single_forward(model, imgs_in)
            output_f = output.data.float().cpu().squeeze(0)

            output = util.tensor2img(output_f)

            # save imgs
            if save_imgs:
                cv2.imwrite(
                    osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)),
                    output)

            #### calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT)
                output = data_util.bgr2ycbcr(output)
            if crop_border == 0:
                cropped_output = output
                cropped_GT = GT
            else:
                cropped_output = output[crop_border:-crop_border,
                                        crop_border:-crop_border]
                cropped_GT = GT[crop_border:-crop_border,
                                crop_border:-crop_border]
            crt_psnr = util.calculate_psnr(cropped_output * 255,
                                           cropped_GT * 255)
            logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format(
                img_idx + 1, c_idx, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                cal_n_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                cal_n_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center +
                                                          cal_n_border)
        # A clip with at most 2 * border_frame frames has no center frames;
        # report 0 instead of dividing by zero, as is done for the border.
        if cal_n_center == 0:
            avg_psnr_center = 0
        else:
            avg_psnr_center = avg_psnr_center / cal_n_center
        if cal_n_border == 0:
            avg_psnr_border = 0
        else:
            avg_psnr_border = avg_psnr_border / cal_n_border

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        sub_folder_name, avg_psnr,
                        (cal_n_center + cal_n_border), avg_psnr_center,
                        cal_n_center, avg_psnr_border, cal_n_border))

        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

    logger.info('################ Tidy Outputs ################')
    for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l,
                                                    avg_psnr_l,
                                                    avg_psnr_center_l,
                                                    avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center,
                                                     psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    if not avg_psnr_l:
        # The glob matched no clip, so there is nothing to average; say so
        # instead of failing with a ZeroDivisionError below.
        logger.warning('No clips found in {}.'.format(test_dataset_folder))
        return
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
示例#22
0
def SFTMD_test(opt_F):
    """Run the model described by ``opt_F`` on every test dataset.

    For each entry of ``opt_F['datasets']`` (visited in sorted key order) a
    dataset, a data loader and a model are created.  Every SR output is saved
    as ``<results_root>/<dataset name>/<image name>/<image name>.png`` and the
    running average PSNR over the images processed so far is logged after
    each image.

    Args:
        opt_F: option dictionary; the keys read here are ``path`` (``log``,
            ``results_root``), ``name``, ``datasets``, ``model`` and ``scale``.
    """
    # config loggers. Before it, the log will not work
    util.setup_logger('base',
                      opt_F['path']['log'],
                      'test_' + opt_F['name'],
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt_F))

    for _, dataset_opt in sorted(opt_F['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt, opt_F)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))

        model_F = create_model(opt_F)

        logger.info('\nTesting [{:s}]...'.format(opt_F['model']))
        dataset_dir = os.path.join(opt_F['path']['results_root'],
                                   dataset_opt['name'])
        util.mkdir(dataset_dir)

        # Keep the PSNR sum separate from the reported average: dividing the
        # accumulator itself inside the loop would corrupt every later
        # average.
        psnr_sum = 0.0
        for idx, data in enumerate(test_loader, start=1):
            model_F.feed_data(data)
            model_F.test()

            visuals = model_F.get_current_visuals()
            gt_img = util.tensor2img(visuals['GT'])  # uint8
            sr_img = util.tensor2img(visuals['SR'])  # uint8

            # Save SR images for reference
            img_name = os.path.splitext(os.path.basename(
                data['LQ_path'][0]))[0]
            img_dir = os.path.join(dataset_dir, img_name)
            util.mkdir(img_dir)

            save_img_path = os.path.join(img_dir, '{}.png'.format(img_name))
            util.save_img(sr_img, save_img_path)

            # calculate PSNR, ignoring a border of `scale` pixels
            crop_size = opt_F['scale']
            gt_img = gt_img / 255.
            sr_img = sr_img / 255.
            if crop_size:
                cropped_sr_img = sr_img[crop_size:-crop_size,
                                        crop_size:-crop_size, :]
                cropped_gt_img = gt_img[crop_size:-crop_size,
                                        crop_size:-crop_size, :]
            else:
                # [0:-0] would select nothing, so use the full images.
                cropped_sr_img = sr_img
                cropped_gt_img = gt_img
            psnr_sum += util.calculate_psnr(cropped_sr_img * 255,
                                            cropped_gt_img * 255)

            avg_psnr = psnr_sum / idx

            # log
            logger.info('# test # PSNR: {:.4e}'.format(avg_psnr))

    print('SR image of SFTMD has been saved to list')
示例#23
0
def main():
    """Evaluate the pretrained EDVR Vid4 model on a user-supplied test folder.

    The folder given with ``-i/--input`` is expected to contain a ``BIx4``
    directory with the low-quality clips and a ``GT`` directory with the
    ground-truth clips; the clips of both are paired in sorted order.  The
    model output for every frame is optionally saved under
    ``../results/<input folder name>/<clip>/`` and the PSNR is logged per
    frame, per clip (split into center and border frames) and averaged over
    all clips.
    """

    # Create object for parsing command-line options
    parser = argparse.ArgumentParser(description="Test with EDVR, requre path to test dataset folder.")
    # Add argument which takes the path to the test folder as an input
    parser.add_argument("-i", "--input", type=str, help="Path to test folder")
    # Parse the command line arguments to an object
    args = parser.parse_args()
    # Safety if no parameter have been given
    if not args.input:
        print("No input paramater have been given.")
        print("For help type --help")
        # Same effect as the interactive helper exit(), without relying on
        # the site module having installed it.
        raise SystemExit

    # Last path component, ignoring trailing or doubled slashes.
    folder_name = osp.basename(osp.normpath(args.input))

    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    else:
        # Only the Vid4 model is wired up in this script.
        raise NotImplementedError

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 40

    model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        # debug
        test_dataset_folder = os.path.join(args.input, 'BIx4')
        GT_dataset_folder = os.path.join(args.input, 'GT')
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(folder_name)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # indices of the N_in frames fed to the model for this frame
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        # A clip with at most 2 * border_frame frames has no center frames;
        # report 0 instead of dividing by zero, as is done for the border.
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
                                                                   (N_center + N_border),
                                                                   avg_psnr_center, N_center,
                                                                   avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l,
                                                              avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
                                                     psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    if not avg_psnr_l:
        # No LQ/GT clip pair was found, so there is nothing to average; say
        # so instead of failing with a ZeroDivisionError below.
        logger.warning('No clips found in {}.'.format(test_dataset_folder))
        return
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
示例#24
0
def main():
    """Train the ``esrcnn_s2l8_3`` model as configured by an option JSON file.

    The option file is selected with ``-opt``.  Training either starts from
    scratch or resumes from ``opt['path']['resume_state']``.  While training,
    losses are logged every ``opt['logger']['print_freq']`` iterations, the
    validation PSNR is computed every ``opt['train']['val_freq']`` iterations
    and checkpoints are written every
    ``opt['logger']['save_checkpoint_freq']`` iterations.  The final model is
    saved under the name ``latest``.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        default='options/train/train_ESRCNN_S2L8_3.json',
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(opt)

    if opt['path']['resume_state']:
        resume_state = torch.load(opt['path']['resume_state'])
    else:
        resume_state = None
        util.mkdir_and_rename(opt['path']['experiments_root'])
        # Create every configured path except the experiment root (handled
        # just above) and the pretrain / resume entries.
        util.mkdirs((path for key, path in opt['path'].items()
                     if key != 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)

    logger.info(option.dict2str(opt))

    # TensorBoard logging is enabled by the options and disabled for runs
    # whose name contains 'debug'.
    use_tb_logger = opt['use_tb_logger'] and 'debug' not in opt['name']
    if use_tb_logger:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='./tb_logger/' + opt['name'])

    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # The attribute is spelled `benchmark`; the former `benckmark` only
    # created an unused attribute and left cuDNN benchmark mode off.
    torch.backends.cudnn.benchmark = True

    # Setup TrainDataLoader
    train_batch_size = opt['datasets']['train']['batch_size']
    trainloader = DataLoader(opt['datasets']['train']['dataroot'],
                             split='train')
    train_size = int(math.ceil(len(trainloader) / train_batch_size))
    logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
        len(trainloader), train_size))
    total_iters = int(opt['train']['niter'])
    total_epochs = int(math.ceil(total_iters / train_size))
    logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
        total_epochs, total_iters))
    TrainDataLoader = data.DataLoader(
        trainloader,
        batch_size=train_batch_size,
        num_workers=12,
        shuffle=True)
    #Setup for validate
    valloader = DataLoader(opt['datasets']['train']['dataroot'], split='val')
    VALDataLoader = data.DataLoader(
        valloader,
        # A fifth of the train batch size, but never 0: a batch size of 0 is
        # rejected by the loader.
        batch_size=max(1, train_batch_size // 5),
        num_workers=1,
        shuffle=True)
    logger.info('Number of val images:{:d}'.format(len(valloader)))

    # Setup Model
    model = get_model('esrcnn_s2l8_3', opt)

    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)
    else:
        current_step = 0
        start_epoch = 0

    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for train_data in TrainDataLoader:

            current_step += 1
            if current_step > total_iters:
                break

            model.update_learning_rate()
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}>'.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v[0])
                    if use_tb_logger:
                        tb_logger.add_scalar(k, v[0], current_step)
                logger.info(message)

            if current_step % opt['train']['val_freq'] == 0:
                avg_psnr = 0.0
                idx = 0
                for val_data in VALDataLoader:
                    idx += 1
                    model.feed_data(val_data)
                    model.val()

                    visuals = model.get_current_visuals()
                    pred_img = util.tensor2img(visuals['Pred'])
                    gt_img = util.tensor2img(visuals['label'])
                    avg_psnr += util.calculate_psnr(pred_img, gt_img)

                # mean PSNR over the validation batches
                avg_psnr = avg_psnr / idx

                logger.info('# Validation #PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr:{:.4e}'.format(
                        epoch, current_step, avg_psnr))

                if use_tb_logger:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)

            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training')
示例#25
0
def _log_metric_summary(logger, label, per_folder):
    """Log the mean of a metric per folder and the mean over all folders.

    Args:
        logger: Logger that receives the summary line.
        label: Metric label inserted into the message (e.g. ``'PSNR'``).
        per_folder: Dict mapping a folder name to the list of per-frame
            metric values collected for that folder.
    """
    if not per_folder:
        # Nothing was evaluated; averaging would divide by zero.
        logger.warning(
            '# Validation # {}: no results collected.'.format(label))
        return
    folder_avg = {k: sum(v) / len(v) for k, v in per_folder.items()}
    # The total is the unweighted mean of the per-folder means.
    total_avg = sum(folder_avg.values()) / len(folder_avg)
    log_s = '# Validation # {}: {:.4e}:'.format(label, total_avg)
    for k, v in folder_avg.items():
        log_s += ' {}: {:.4e}'.format(k, v)
    logger.info(log_s)


def main():
    """Evaluate per-sample (MAML-style) adaptation on the validation set.

    For every validation sample the script first scores a generator loaded
    from ``opt['path']['bicubic_G']``, then fine-tunes copies of the
    generator (and, unless ``opt['train']['use_real']`` is set, of the
    estimator) for ``opt['train']['maml']['adapt_iter']`` inner steps and
    scores the adapted generator again. Adapted images and a CSV with the
    per-frame PSNR/SSIM are written below
    ``../results_for_paper/<folder_name>``; per-folder averages are logged
    at the end.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--exp_name', type=str, default='temp')
    parser.add_argument('--degradation_type', type=str, default=None)
    parser.add_argument('--sigma_x', type=float, default=None)
    parser.add_argument('--sigma_y', type=float, default=None)
    parser.add_argument('--theta', type=float, default=None)
    args = parser.parse_args()
    if args.exp_name == 'temp':
        opt = option.parse(args.opt, is_train=True)
    else:
        opt = option.parse(args.opt, is_train=True, exp_name=args.exp_name)

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    # Command-line degradation settings override the YAML validation config.
    val_opt = opt['datasets']['val']
    if args.degradation_type is not None:
        if args.degradation_type == 'preset':
            val_opt['degradation_mode'] = args.degradation_type
        else:
            val_opt['degradation_type'] = args.degradation_type
    for key in ('sigma_x', 'sigma_y', 'theta'):
        value = getattr(args, key)
        if value is not None:
            val_opt[key] = value

    if val_opt['degradation_mode'] == 'set':
        degradation_name = '_'.join(
            str(val_opt[key])
            for key in ('degradation_type', 'sigma_x', 'sigma_y', 'theta'))
    else:
        degradation_name = val_opt['degradation_mode']

    # Results go to <name>_<degradation> unless an explicit experiment name
    # was passed on the command line.
    folder_name = opt['name'] + '_' + degradation_name
    if args.exp_name != 'temp':
        folder_name = args.exp_name

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            os.makedirs(opt['path']['experiments_root'], exist_ok=True)

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + folder_name)
    else:
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    #### create val dataloader (the train phase is not used by this script)
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            continue
        if phase != 'val':
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))

        if '+' in opt['datasets']['val']['name']:
            # A '+'-separated name builds one dataset/loader per entry.
            val_set, val_loader = [], []
            valname_list = opt['datasets']['val']['name'].split('+')
            for i, _ in enumerate(valname_list):
                val_set.append(
                    create_dataset(
                        dataset_opt,
                        scale=opt['scale'],
                        kernel_size=opt['datasets']['train']['kernel_size'],
                        model_name=opt['network_E']['which_model_E'],
                        idx=i))
                val_loader.append(
                    create_dataloader(val_set[-1], dataset_opt, opt, None))
        else:
            val_set = create_dataset(
                dataset_opt,
                scale=opt['scale'],
                kernel_size=opt['datasets']['train']['kernel_size'],
                model_name=opt['network_E']['which_model_E'])
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
        if rank <= 0:
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))

    #### create model
    models = create_model(opt)
    assert len(models) == 2
    model, est_model = models[0], models[1]
    # Working copies that are adapted per sample, plus an estimator whose
    # weights are reloaded from opt['path']['fixed_E'] for the SLR loss.
    modelcp, est_modelcp = create_model(opt)
    _, est_model_fixed = create_model(opt)

    center_idx = (opt['datasets']['val']['N_frames']) // 2
    lr_alpha = opt['train']['maml']['lr_alpha']
    update_step = opt['train']['maml']['adapt_iter']

    pd_log = pd.DataFrame(
        columns=['PSNR_Bicubic', 'PSNR_Ours', 'SSIM_Bicubic', 'SSIM_Ours'])

    def crop(LR_seq, HR, num_patches_for_batch=4, patch_size=44):
        """Stack several crops of one sample into a batch.

        Args:
            LR_seq: (B=1) x T x C x H x W
            HR: (B=1) x C x H x W
            num_patches_for_batch (int, optional): number of crops to take.
            patch_size (int, optional): half of this value is forwarded to
                ``preprocessing.common_crop`` as its ``patch_size``.

        Return:
            Tuple ``(cropped_lr, cropped_hr)`` with the crops stacked along
            a new leading dimension of size ``num_patches_for_batch``.
        """
        cropped_lr = []
        cropped_hr = []
        assert HR.size(0) == 1
        LR_seq_ = LR_seq[0]
        HR_ = HR[0]
        for _ in range(num_patches_for_batch):
            patch_lr, patch_hr = preprocessing.common_crop(
                LR_seq_, HR_, patch_size=patch_size // 2)
            cropped_lr.append(patch_lr)
            cropped_hr.append(patch_hr)

        cropped_lr = torch.stack(cropped_lr, dim=0)
        cropped_hr = torch.stack(cropped_hr, dim=0)

        return cropped_lr, cropped_hr

    # Single GPU
    # psnr_rlt / ssim_rlt: index 0 holds the scores before adaptation,
    # index 1 the scores after adaptation; each maps folder -> list of values.
    psnr_rlt = [{}, {}]
    ssim_rlt = [{}, {}]
    pbar = util.ProgressBar(len(val_set))
    for val_data in val_loader:
        folder = val_data['folder'][0]
        idx_d = int(val_data['idx'][0].split('/')[0])
        if 'name' in val_data.keys():
            name = val_data['name'][0][center_idx][0]
        else:
            name = folder

        train_folder = os.path.join('../results_for_paper', folder_name, name)
        maml_train_folder = os.path.join(train_folder, 'maml')
        # 'hr' and 'bic' are still created so the output layout is unchanged,
        # although only the 'maml' images are written below.
        for sub_folder in ('hr', 'bic', 'maml'):
            os.makedirs(os.path.join(train_folder, sub_folder),
                        exist_ok=True)

        for per_folder in psnr_rlt + ssim_rlt:
            per_folder.setdefault(folder, [])

        cropped_meta_train_data = {}
        meta_train_data = {}
        meta_test_data = {}

        # The center LQ frame is the target of the inner-loop training pair;
        # its input sequence ('LQs') is filled in inside the inner loop.
        meta_train_data['GT'] = val_data['LQs'][:, center_idx]
        meta_test_data['LQs'] = val_data['LQs'][0:1]
        meta_test_data['GT'] = val_data['GT'][0:1, center_idx]
        # Check whether the batch size of each validation data is 1
        assert val_data['SuperLQs'].size(0) == 1

        if opt['network_G']['which_model_G'] == 'TOF':
            # TOF gets bicubically upsampled inputs.
            LQs = meta_test_data['LQs']
            B, T, C, H, W = LQs.shape
            LQs = LQs.reshape(B * T, C, H, W)
            Bic_LQs = F.interpolate(LQs,
                                    scale_factor=opt['scale'],
                                    mode='bicubic',
                                    align_corners=True)
            meta_test_data['LQs'] = Bic_LQs.reshape(B, T, C, H * opt['scale'],
                                                    W * opt['scale'])

        ## Before adapting, score the generator loaded from 'bicubic_G'
        modelcp.load_network(opt['path']['bicubic_G'], modelcp.netG)
        modelcp.feed_data(meta_test_data)
        modelcp.test()
        model_start_visuals = modelcp.get_current_visuals(need_GT=True)
        hr_image = util.tensor2img(model_start_visuals['GT'], mode='rgb')
        start_image = util.tensor2img(model_start_visuals['rlt'], mode='rgb')

        psnr_rlt[0][folder].append(util.calculate_psnr(start_image, hr_image))
        ssim_rlt[0][folder].append(util.calculate_ssim(start_image, hr_image))

        # Every sample starts its adaptation from fresh copies of the
        # networks held by `model` / `est_model`.
        modelcp.netG, est_modelcp.netE = deepcopy(model.netG), deepcopy(
            est_model.netE)

        ########## SLR LOSS Preparation ############
        est_model_fixed.load_network(opt['path']['fixed_E'],
                                     est_model_fixed.netE)

        optim_params = [
            v for _, v in modelcp.netG.named_parameters() if v.requires_grad
        ]
        if not opt['train']['use_real']:
            # The estimator copy is optimised jointly with the generator.
            optim_params.extend(
                v for _, v in est_modelcp.netE.named_parameters()
                if v.requires_grad)

        if opt['train']['maml']['optimizer'] == 'Adam':
            inner_optimizer = torch.optim.Adam(
                optim_params,
                lr=lr_alpha,
                betas=(opt['train']['maml']['beta1'],
                       opt['train']['maml']['beta2']))
        elif opt['train']['maml']['optimizer'] == 'SGD':
            inner_optimizer = torch.optim.SGD(optim_params, lr=lr_alpha)
        else:
            raise NotImplementedError()

        # Inner Loop Update
        st = time.time()
        for i in range(update_step):
            # Make SuperLR seq using UPDATED estimation model
            if not opt['train']['use_real']:
                est_modelcp.feed_data(val_data)
                est_modelcp.forward_without_optim()
                superlr_seq = est_modelcp.fake_L
                meta_train_data['LQs'] = superlr_seq
            else:
                meta_train_data['LQs'] = val_data['SuperLQs']

            if opt['network_G']['which_model_G'] == 'TOF':
                # Bicubic upsample to match the size
                LQs = meta_train_data['LQs']
                B, T, C, H, W = LQs.shape
                LQs = LQs.reshape(B * T, C, H, W)
                Bic_LQs = F.interpolate(LQs,
                                        scale_factor=opt['scale'],
                                        mode='bicubic',
                                        align_corners=True)
                meta_train_data['LQs'] = Bic_LQs.reshape(
                    B, T, C, H * opt['scale'], W * opt['scale'])

            # Update both modelcp + estmodelcp jointly
            inner_optimizer.zero_grad()
            if opt['train']['maml']['use_patch']:
                cropped_meta_train_data['LQs'], cropped_meta_train_data['GT'] = \
                    crop(meta_train_data['LQs'], meta_train_data['GT'],
                         opt['train']['maml']['num_patch'],
                         opt['train']['maml']['patch_size'])
                modelcp.feed_data(cropped_meta_train_data)
            else:
                modelcp.feed_data(meta_train_data)

            loss_train = modelcp.calculate_loss()

            ##################### SLR LOSS ###################
            # L1 penalty (weight 10) between the current super-LR sequence
            # and the output of the fixed estimator.
            est_model_fixed.feed_data(val_data)
            est_model_fixed.test()
            slr_initialized = est_model_fixed.fake_L
            slr_initialized = slr_initialized.to('cuda')
            if opt['network_G']['which_model_G'] == 'TOF':
                # For TOF, `LQs` is the sequence before bicubic upsampling.
                loss_train += 10 * F.l1_loss(
                    LQs.to('cuda').squeeze(0), slr_initialized)
            else:
                loss_train += 10 * F.l1_loss(meta_train_data['LQs'].to('cuda'),
                                             slr_initialized)

            loss_train.backward()
            inner_optimizer.step()

        et = time.time()
        update_time = et - st

        modelcp.feed_data(meta_test_data)
        modelcp.test()

        model_update_visuals = modelcp.get_current_visuals(need_GT=False)
        update_image = util.tensor2img(model_update_visuals['rlt'], mode='rgb')
        # Save and calculate final image
        imageio.imwrite(
            os.path.join(maml_train_folder, '{:08d}.png'.format(idx_d)),
            update_image)
        psnr_rlt[1][folder].append(util.calculate_psnr(update_image, hr_image))
        ssim_rlt[1][folder].append(util.calculate_ssim(update_image, hr_image))

        # `.loc` overwrites an existing row and appends a missing one.
        name_df = '{}/{:08d}'.format(folder, idx_d)
        pd_log.loc[name_df] = [
            psnr_rlt[0][folder][-1], psnr_rlt[1][folder][-1],
            ssim_rlt[0][folder][-1], ssim_rlt[1][folder][-1]
        ]

        # The CSV is rewritten after every sample so partial runs keep
        # their results.
        pd_log.to_csv(
            os.path.join('../results_for_paper', folder_name,
                         'psnr_update.csv'))

        pbar.update(
            'Test {} - {}: I: {:.3f}/{:.4f} \tF+: {:.3f}/{:.4f} \tTime: {:.3f}s'
            .format(folder, idx_d, psnr_rlt[0][folder][-1],
                    ssim_rlt[0][folder][-1], psnr_rlt[1][folder][-1],
                    ssim_rlt[1][folder][-1], update_time))

    _log_metric_summary(logger, 'Bic PSNR', psnr_rlt[0])
    _log_metric_summary(logger, 'PSNR', psnr_rlt[1])
    _log_metric_summary(logger, 'Bicubic SSIM', ssim_rlt[0])
    _log_metric_summary(logger, 'SSIM', ssim_rlt[1])

    logger.info('End of evaluation.')
示例#26
0
文件: train.py 项目: TheStarkor/IRN
def main():
    """Train a model from a YAML option file (single process only).

    Parses ``-opt``, creates the experiment folders and loggers, builds the
    train/val dataloaders and the model, then runs the training loop with
    periodic logging, validation (PSNR on border-cropped images, with the
    validation images saved to disk) and checkpointing.
    """
    ### parser
    ### diff : cannot support distribution
    parser: ArgumentParser = argparse.ArgumentParser()
    parser.add_argument("-opt", type=str, help="Path to YAML file.")
    args = parser.parse_args()
    opt: Dict[str, Any] = option.parse(args.opt, is_train=True)

    ### mkdir and loggers
    util.mkdir_and_rename(
        opt["path"]["experiments_root"]
    )  # rename experiment folder if exists
    util.mkdirs(
        (
            path
            for key, path in opt["path"].items()
            if not key == "experiments_root"
            and "pretrain_model" not in key
            and "resume" not in key
        )
    )

    util.setup_logger(
        "base", PATH, "train_" + NAME, level=logging.INFO, screen=True, tofile=False
    )
    util.setup_logger(
        "val", PATH, "val_" + NAME, level=logging.INFO, screen=True, tofile=False
    )

    logger: Logger = logging.getLogger("base")

    opt = option.dict_to_nonedict(opt)

    # TODO : tensorboard logger

    ### random seed
    seed: int = MANUAL_SEED
    logger.info(f"Random seed: {seed}")
    util.set_random_seed(seed)

    # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/2
    torch.backends.cudnn.benchmark = True

    ### create train and val dataloader
    phase: str
    dataset_opt: Dict[str, Any]
    for phase, dataset_opt in opt["datasets"].items():
        if phase == "train":
            train_set: Union[LQGTDataset] = create_dataset(dataset_opt)
            # Iterations per epoch, and epochs needed to reach `niter`.
            train_size: int = int(math.ceil(len(train_set) / dataset_opt["batch_size"]))
            total_iters: int = int(opt["train"]["niter"])
            total_epochs: int = int(math.ceil(total_iters / train_size))

            train_loader = create_dataloader(train_set, dataset_opt, opt, None)

            logger.info(
                "Number of train images: {:,d}, iters: {:,d}".format(
                    len(train_set), train_size
                )
            )
            logger.info(
                "Total epochs needed: {:d} for iters {:,d}".format(
                    total_epochs, total_iters
                )
            )
        elif phase == "val":
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
        else:
            raise NotImplementedError(f"Phase [{phase:s}] is not recognized")

    model = create_model(opt)

    current_step: int = 0
    start_epoch: int = 0

    # TODO : training
    logger.info(f"Start training from epoch: {start_epoch}, iter: {current_step}")
    for epoch in range(start_epoch, total_epochs + 1):
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break

            ### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            model.update_learning_rate(
                current_step, warmup_iter=opt["train"]["warmup_iter"]
            )

            if current_step % opt["logger"]["print_freq"] == 0:
                logs = model.get_current_log()
                # `.3e` (precision), not `3e` (minimum width): three
                # decimals, matching the other log lines.
                message = f"<epoch:{epoch:3d}, iter:{current_step:8d}, lr:{model.get_current_learning_rate():.3e}> "
                for k, v in logs.items():
                    message += f"{k:s}: {v:.4e} "

                    # TODO: tensorboard
                logger.info(message)

            # validation
            if current_step % opt["train"]["val_freq"] == 0:
                avg_psnr: float = 0.0
                idx: int = 0  # number of images that contributed to avg_psnr
                for val_data in val_loader:
                    img_name = os.path.splitext(
                        os.path.basename(val_data["LQ_path"][0])
                    )[0]
                    img_dir = os.path.join(opt["path"]["val_images"], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    # TODO: fix
                    # Equality (not identity) on purpose: only a falsy,
                    # False-equal return value skips the sample.
                    if model.test() == False:
                        continue
                    # Count the sample only once it is known to be
                    # evaluated, so skipped images do not deflate the mean.
                    idx += 1

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals["SR"])  # uint8
                    gt_img = util.tensor2img(visuals["GT"])  # uint8

                    lr_img = util.tensor2img(visuals["LR"])

                    gtl_img = util.tensor2img(visuals["LR_ref"])

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir, "{:s}_{:d}.png".format(img_name, current_step)
                    )
                    util.save_img(sr_img, save_img_path)

                    # Save LR images
                    save_img_path_L = os.path.join(
                        img_dir, "{:s}_forwLR_{:d}.png".format(img_name, current_step)
                    )
                    util.save_img(lr_img, save_img_path_L)

                    # Save ground truth (only at the first validation round)
                    if current_step == opt["train"]["val_freq"]:
                        save_img_path_gt = os.path.join(
                            img_dir, "{:s}_GT_{:d}.png".format(img_name, current_step)
                        )
                        util.save_img(gt_img, save_img_path_gt)
                        save_img_path_gtl = os.path.join(
                            img_dir,
                            "{:s}_LR_ref_{:d}.png".format(img_name, current_step),
                        )
                        util.save_img(gtl_img, save_img_path_gtl)

                    # calculate PSNR on images with a `scale`-pixel border
                    # removed on every side
                    crop_size = opt["scale"]
                    gt_img = gt_img / 255.0
                    sr_img = sr_img / 255.0
                    cropped_sr_img = sr_img[
                        crop_size:-crop_size, crop_size:-crop_size, :
                    ]
                    cropped_gt_img = gt_img[
                        crop_size:-crop_size, crop_size:-crop_size, :
                    ]
                    avg_psnr += util.calculate_psnr(
                        cropped_sr_img * 255, cropped_gt_img * 255
                    )

                if idx == 0:
                    # Every sample was skipped (or the loader was empty);
                    # there is no meaningful average to report.
                    logger.warning(
                        "# Validation # no image was evaluated; PSNR skipped."
                    )
                else:
                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info("# Validation # PSNR: {:.4e}.".format(avg_psnr))
                    logger_val = logging.getLogger("val")  # validation logger
                    logger_val.info(
                        "<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}.".format(
                            epoch, current_step, avg_psnr
                        )
                    )
                    # TODO: tensorboard

            if current_step % opt["logger"]["save_checkpoint_freq"] == 0:
                logger.info("Saving models and training states.")
                model.save(current_step)
                model.save_training_state(epoch, current_step)
示例#27
0
def main():
    """Train a super-resolution model with PSNR / NIQE / RMSE validation.

    Parses the option file given by ``-opt``, optionally resumes from
    ``opt['path']['resume_state']``, builds the dataloaders and the model
    and runs the training loop. Validation computes PSNR and MSE on the Y
    channel of border-cropped images and obtains NIQE by calling
    ``calc_NIQE`` through an already running, shared MATLAB session.
    """
    # Metric path (added to the MATLAB search path below)
    metric_path = os.getcwd() + '/utils/metric'
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which return None for missing key.

    # train from scratch OR resume training
    if opt['path']['resume_state']:  # resuming training
        resume_state = torch.load(opt['path']['resume_state'])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(
            opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # Spelled correctly: a misspelled attribute name is silently created
    # on the module and leaves cuDNN benchmark mode disabled.
    torch.backends.cudnn.benchmark = True

    # create train and val dataloader
    # Pre-bind the loaders so the assertion below can actually fire when
    # the options contain no 'train' dataset.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            # Iterations per epoch, and epochs needed to reach `niter`.
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training

    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))

    # Attach to a shared MATLAB session; it is used for NIQE in validation.
    eng = matlab.engine.connect_matlab()
    names = matlab.engine.find_matlab()
    print('matlab process name: ', names)
    eng.addpath(metric_path)

    for epoch in range(start_epoch, total_epochs):
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            model.update_learning_rate()

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                avg_psnr = 0.0
                scores = 0.0  # running sum of NIQE values
                avg_pirm_rmse = 0.0  # running sum of MSE; sqrt taken below
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir, '{:s}_{:d}.png'.format(img_name,
                                                        current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR / MSE on the Y channel with a
                    # `scale`-pixel border removed on every side
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_sr_img_y = bgr2ycbcr(cropped_sr_img, only_y=True)
                    cropped_gt_img_y = bgr2ycbcr(cropped_gt_img, only_y=True)
                    avg_psnr += util.calculate_psnr(cropped_sr_img_y * 255,
                                                    cropped_gt_img_y * 255)
                    immse = util.mse(cropped_sr_img_y * 255,
                                     cropped_gt_img_y * 255)
                    avg_pirm_rmse += immse
                    # NIQE is computed by MATLAB from the saved SR image.
                    scores += eng.calc_NIQE(save_img_path, 4)

                avg_psnr = avg_psnr / idx
                scores = scores / idx
                # Root of the mean MSE over all validation images.
                avg_pirm_rmse = math.sqrt(avg_pirm_rmse / idx)

                # log
                logger.info(
                    '# Validation # PSNR: {:.4e}, NIQE: {:.4e}, pirm_rmse: {:.4e}'
                    .format(avg_psnr, scores, avg_pirm_rmse))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}, NIQE: {:.4e}, pirm_rmse: {:.4e}'
                    .format(epoch, current_step, avg_psnr, scores,
                            avg_pirm_rmse))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
                    tb_logger.add_scalar('NIQE', scores, current_step)
                    tb_logger.add_scalar('pirm_rmse', avg_pirm_rmse,
                                         current_step)

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
示例#28
0
def main():
    """Evaluate a model on the configured dataset(s) and log the mean PSNR.

    Steps, in order:
      1. Parse ``-opt`` / ``--launcher`` / ``--local_rank`` and load the
         option file in test mode.
      2. Optionally initialise distributed execution and load a resume state.
      3. Create output directories and the 'base' / 'val' loggers.
      4. Build a dataset/dataloader for every entry in ``opt['datasets']``
         (only the last loader built is evaluated below).
      5. Run the model on every batch, save the SR image into a hard-coded
         output directory, and accumulate PSNR on border-cropped images.
      6. Log the average PSNR.

    Returns:
        None.  Results are written to disk and to the loggers.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        default='options/test/test_KPSAGAN.yml',
                        help='Path to option YMAL file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=False)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            # The experiment root is intentionally NOT renamed here; only the
            # remaining (non-pretrain, non-resume) paths are created.
            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        util.setup_logger('val',
                          opt['path']['log'],
                          'val_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
    else:
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    # Fixed typo: the attribute is `benchmark`; the misspelled name was
    # silently accepted and never enabled the cuDNN autotuner.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create val dataloader
    for _phase, dataset_opt in opt['datasets'].items():
        val_set = create_dataset(dataset_opt)
        val_loader = create_dataloader(val_set, dataset_opt, opt, None)
        if rank <= 0:
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))

    #### create model
    model = create_model(opt)

    avg_psnr = 0.0
    idx = 0
    dataset_dir = '/srv/wuyichao/Super-Resolution/KPSAGAN/BasicSR-master/BasicSR-master-c/result_600000/'
    util.mkdir(dataset_dir)
    for val_data in val_loader:
        idx += 1
        img_name = os.path.splitext(os.path.basename(
            val_data['LQ_path'][0]))[0]
        logger.info(img_name)

        model.feed_data(val_data)
        model.test()

        visuals = model.get_current_visuals()
        sr_img = util.tensor2img(visuals['SR'])  # uint8
        gt_img = util.tensor2img(visuals['GT'])  # uint8

        # save images; the suffix is hard-coded instead of read from
        # opt['suffix']
        suffix = 'cut'
        if suffix:
            save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
        else:
            save_img_path = osp.join(dataset_dir, img_name + '.png')
        util.save_img(sr_img, save_img_path)

        # calculate PSNR on images with `scale` pixels cropped off each border
        crop_size = opt['scale']
        gt_img = gt_img / 255.
        sr_img = sr_img / 255.
        cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
        cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
        avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                        cropped_gt_img * 255)

    if idx == 0:
        # Nothing was evaluated: report it instead of dividing by zero.
        logger.warning(
            'No validation images were processed; PSNR is undefined.')
        return

    avg_psnr = avg_psnr / idx

    # log
    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
    logger_val = logging.getLogger('val')  # validation logger
    logger_val.info('psnr: {:.4e}'.format(avg_psnr))
示例#29
0
    # NOTE(review): this excerpt starts mid-script; presumably the indented
    # part is the body of a per-image loop whose header (and the definitions
    # of img_LR, img_GT, model, save_path, ...) are not shown -- confirm.

    # Move the last axis to the front (assumes img_LR is an HWC numpy array --
    # TODO confirm), convert to a float tensor, add a leading batch dimension
    # and transfer it to the GPU.
    img_LR = torch.from_numpy(
        np.ascontiguousarray(np.transpose(
            img_LR, (2, 0, 1)))).float().unsqueeze(0).cuda()

    # Inference without gradient tracking; the wall-clock time of the forward
    # call is accumulated into stat_time.
    # NOTE(review): CUDA kernels launch asynchronously, so without a
    # torch.cuda.synchronize() this may under-measure GPU time -- confirm.
    with torch.no_grad():
        begin_time = time.time()
        frame, sr_base, output = model(img_LR)
        end_time = time.time()
        stat_time += (end_time - begin_time)
        #print(end_time-begin_time)

    # Drop the batch dimension and convert each of the three model outputs to
    # an image array.
    output = util.tensor2img(output.squeeze(0))
    frame = util.tensor2img(frame.squeeze(0))
    sr_base = util.tensor2img(sr_base.squeeze(0))

    # save images: the three outputs are concatenated along axis 1 into a
    # single picture.
    save_path_name = osp.join(
        save_path, '{}_exp{}/{}.png'.format(dataset, exp_name, base_name))
    merge = np.concatenate((frame, sr_base, output), axis=1)
    util.save_img(merge, save_path_name)

    # calculate PSNR / SSIM between the final output and the ground truth
    # after util.crop_border is applied with `scale`.
    sr_img, gt_img = util.crop_border([output, img_GT], scale)
    PSNR_avg += util.calculate_psnr(sr_img, gt_img)
    SSIM_avg += util.calculate_ssim(sr_img, gt_img)

# Report per-image averages over img_list.
print('average PSNR: ', PSNR_avg / len(img_list))
print('average SSIM: ', SSIM_avg / len(img_list))
print('time: ', stat_time / len(img_list))
示例#30
0
def main():
    """Train the SFTMD network and periodically validate / checkpoint it.

    Steps, in order:
      1. Parse ``-opt_F`` / ``--launcher`` / ``--local_rank`` and load the
         option file in training mode.
      2. Seed the RNGs and load the PCA matrix from a hard-coded relative path.
      3. Optionally initialise distributed training and load a resume state.
      4. On rank <= 0: create experiment directories, (re)create the ``./log``
         symlink, and configure the 'base' / 'val' loggers and tensorboard.
      5. Build the train and val dataloaders and the model.
      6. Run the training loop: for each batch, synthesise the LR image and
         kernel map from the GT image, update the learning rate, optimise,
         and at the configured frequencies log, validate (PSNR) and save
         checkpoints.
      7. Save the final model as 'latest'.

    Returns:
        None.  Results are written to disk and to the loggers.
    """
    ###### SFTMD train ######
    #### setup options
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-opt_F",
        type=str,
        default="options/train/SFTMD/train_SFTMD_x4.yml",
        help="Path to option YMAL file of SFTMD_Net.",
    )
    parser.add_argument("--launcher",
                        choices=["none", "pytorch"],
                        default="none",
                        help="job launcher")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()
    opt_F = option.parse(args.opt_F, is_train=True)

    # convert to NoneDict, which returns None for missing keys
    opt_F = option.dict_to_nonedict(opt_F)

    #### random seed
    seed = opt_F["train"]["manual_seed"]
    if seed is None:
        seed = random.randint(1, 10000)
    util.set_random_seed(seed)

    # load PCA matrix of enough kernel
    print("load PCA matrix")
    pca_matrix = torch.load(
        "../../../pca_matrix/IKC/pca_matrix.pth",
        map_location=lambda storage, loc: storage,
    )
    print("PCA matrix shape: {}".format(pca_matrix.shape))

    #### distributed training settings
    if args.launcher == "none":  # disabled distributed training
        opt_F["dist"] = False
        rank = -1
        print("Disabled distributed training.")
    else:
        opt_F["dist"] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### loading resume state if exists
    if opt_F["path"].get("resume_state", None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt_F["path"]["resume_state"],
            map_location=lambda storage, loc: storage.cuda(device_id),
        )
        option.check_resume(opt_F,
                            resume_state["iter"])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:
        if resume_state is None:
            util.mkdir_and_rename(
                opt_F["path"]
                ["experiments_root"])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt_F["path"].items()
                 if not key == "experiments_root"
                 and "pretrain_model" not in key and "resume" not in key))
            # Replace a stale ./log link without spawning a shell.  A missing
            # link is fine (first run); anything else is left to surface.
            try:
                os.remove("./log")
            except FileNotFoundError:
                pass
            os.symlink(os.path.join(opt_F["path"]["experiments_root"], ".."),
                       "./log")

        # config loggers. Before it, the log will not work
        util.setup_logger(
            "base",
            opt_F["path"]["log"],
            "train_" + opt_F["name"],
            level=logging.INFO,
            screen=True,
            tofile=True,
        )
        util.setup_logger(
            "val",
            opt_F["path"]["log"],
            "val_" + opt_F["name"],
            level=logging.INFO,
            screen=True,
            tofile=True,
        )
        logger = logging.getLogger("base")
        logger.info(option.dict2str(opt_F))
        # tensorboard logger
        if opt_F["use_tb_logger"] and "debug" not in opt_F["name"]:
            # NOTE(review): parsing only the first three characters is
            # fragile for unusual version strings -- kept as-is on purpose.
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    "You are using PyTorch {}. Tensorboard will use [tensorboardX]"
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(
                log_dir="log/{}/tb_logger/".format(opt_F["name"]))
    else:
        util.setup_logger("base",
                          opt_F["path"]["log"],
                          "train",
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger("base")

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    # Initialised so that the assertions below fail cleanly (instead of a
    # NameError) when a phase is missing from the option file.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt_F["datasets"].items():
        if phase == "train":
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt["batch_size"]))
            total_iters = int(opt_F["train"]["niter"])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt_F["dist"]:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt_F,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    "Number of train images: {:,d}, iters: {:,d}".format(
                        len(train_set), train_size))
                logger.info("Total epochs needed: {:d} for iters {:,d}".format(
                    total_epochs, total_iters))
        elif phase == "val":
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt_F, None)
            if rank <= 0:
                logger.info("Number of val images in [{:s}]: {:d}".format(
                    dataset_opt["name"], len(val_set)))
        else:
            raise NotImplementedError(
                "Phase [{:s}] is not recognized.".format(phase))
    assert train_loader is not None
    assert val_loader is not None

    #### create model
    model_F = create_model(opt_F)

    #### resume training
    if resume_state:
        logger.info("Resuming training from epoch: {}, iter: {}.".format(
            resume_state["epoch"], resume_state["iter"]))

        start_epoch = resume_state["epoch"]
        current_step = resume_state["iter"]
        model_F.resume_training(
            resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    def build_prepro():
        """Return a new LR-image / kernel-map preprocessor.

        A fresh instance is built for every batch (train and val alike), as
        the original code did; only the duplicated argument list is shared.
        """
        return util.SRMDPreprocessing(
            opt_F["scale"],
            pca_matrix,
            random=True,
            para_input=opt_F["code_length"],
            kernel=opt_F["kernel_size"],
            noise=False,
            cuda=True,
            sig=opt_F["sig"],
            sig_min=opt_F["sig_min"],
            sig_max=opt_F["sig_max"],
            rate_iso=1.0,
            scaling=3,
            rate_cln=0.2,
            noise_high=0.0,
        )

    #### training
    logger.info("Start training from epoch: {:d}, iter: {:d}".format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt_F["dist"]:
            train_sampler.set_epoch(epoch)
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break
            #### preprocessing for LR_img and kernel map
            prepro = build_prepro()
            LR_img, ker_map = prepro(train_data["GT"])

            #### update learning rate, schedulers
            model_F.update_learning_rate(
                current_step, warmup_iter=opt_F["train"]["warmup_iter"])

            #### training
            model_F.feed_data(train_data, LR_img, ker_map)
            model_F.optimize_parameters(current_step)

            #### log
            if current_step % opt_F["logger"]["print_freq"] == 0:
                logs = model_F.get_current_log()
                message = "<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> ".format(
                    epoch, current_step, model_F.get_current_learning_rate())
                for k, v in logs.items():
                    message += "{:s}: {:.4e} ".format(k, v)
                    # tensorboard logger
                    if opt_F["use_tb_logger"] and "debug" not in opt_F["name"]:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            # validation (only on rank <= 0)
            if current_step % opt_F["train"]["val_freq"] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    #### preprocessing for LR_img and kernel map
                    prepro = build_prepro()
                    LR_img, ker_map = prepro(val_data["GT"])

                    model_F.feed_data(val_data, LR_img, ker_map)
                    model_F.test()

                    visuals = model_F.get_current_visuals()
                    sr_img = util.tensor2img(visuals["SR"])  # uint8
                    gt_img = util.tensor2img(visuals["GT"])  # uint8

                    # Save SR images for reference, grouped in one directory
                    # per validation step.
                    img_name = os.path.splitext(
                        os.path.basename(val_data["LQ_path"][0]))[0]
                    img_dir = os.path.join(opt_F["path"]["val_images"],
                                           str(current_step))
                    util.mkdir(img_dir)

                    save_img_path = os.path.join(
                        img_dir,
                        "{:s}_{:d}.png".format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR on images with `scale` pixels cropped
                    # off each border
                    crop_size = opt_F["scale"]
                    gt_img = gt_img / 255.0
                    sr_img = sr_img / 255.0
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)

                if idx == 0:
                    # Nothing was evaluated: report it instead of dividing
                    # by zero and aborting the training run.
                    logger.warning(
                        "Validation loader is empty; skipping PSNR report.")
                else:
                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info(
                        "# Validation # PSNR: {:.4e}".format(avg_psnr))
                    logger_val = logging.getLogger(
                        "val")  # validation logger
                    logger_val.info(
                        "<epoch:{:3d}, iter:{:8,d}> psnr: {:.6f}".format(
                            epoch, current_step, avg_psnr))
                    # tensorboard logger
                    if opt_F["use_tb_logger"] and "debug" not in opt_F["name"]:
                        tb_logger.add_scalar("psnr", avg_psnr, current_step)

            #### save models and training states
            if current_step % opt_F["logger"]["save_checkpoint_freq"] == 0:
                if rank <= 0:
                    logger.info("Saving models and training states.")
                    model_F.save(current_step)
                    model_F.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info("Saving the final model.")
        model_F.save("latest")
        logger.info("End of SFTMD training.")