コード例 #1
0
ファイル: bin_model.py プロジェクト: mjt1312/BIN
    def compute_current_psnr_ssim(self, save=False, name=None, save_path=None):
        """Score the current validation visuals with PSNR and SSIM.

        The first ``self.get_info()`` entries of the ``'rlt'`` and ``'GT'``
        visuals are converted with ``util.tensor2img`` and compared pairwise.

        Args:
            save: When equal to ``True``, each image pair is also written to
                disk with ``cv2.imwrite``.
            name: Tag inserted into the saved file names.
            save_path: Directory that receives ``rlt_<name>_<i>.png`` and
                ``gt_<name>_<i>.png``; only used when saving.

        Returns:
            A ``(psnr_list, ssim_list)`` tuple with one entry per sample.
        """
        sample_count = self.get_info()
        current = self.get_current_visuals()

        psnr_values, ssim_values = [], []
        for idx in range(sample_count):
            predicted = util.tensor2img(current['rlt'][idx])
            target = util.tensor2img(current['GT'][idx])

            sample_psnr = util.calculate_psnr(predicted, target)
            sample_ssim = util.calculate_ssim(predicted, target)
            psnr_values.append(sample_psnr)
            ssim_values.append(sample_ssim)

            # Strict comparison kept on purpose: only a value equal to True saves.
            if save == True:
                # Imported lazily so cv2 is only required when images are saved.
                import os.path as osp
                import cv2
                outputs = (('rlt_{}_{}.png', predicted), ('gt_{}_{}.png', target))
                for template, image in outputs:
                    cv2.imwrite(osp.join(save_path, template.format(name, idx)), image)

        return psnr_values, ssim_values
コード例 #2
0
                def _calculate_metrics(sr_vol, gt_vol, view='xy'):
                    """Average PSNR, SSIM and pdist over the paired slices of two volumes.

                    Every slice pair is border-cropped by ``round(opt['scale'])`` pixels,
                    cast to float64 and rescaled with ``/ 1500 * 255`` before scoring
                    (the metric helpers are assumed to expect a [0, 255] range). Slices
                    whose PSNR is infinite are logged but excluded from the averages.
                    The averages are stored under ``[patient_id][view]`` in the result
                    dictionaries of the enclosing scope.

                    Args:
                        sr_vol: Iterable of super-resolved slices, each indexable as
                            ``[rows, cols]`` (the volume layout is noted as [D,H,W]).
                        gt_vol: Iterable of the matching ground-truth slices.
                        view: Key under which the averages are stored.

                    Returns:
                        The ``(pnsr_results, ssim_results, pdist_results)`` dictionaries.
                        When no slice has a finite PSNR the stored averages are NaN.
                    """
                    sum_psnr = 0.
                    sum_ssim = 0.
                    sum_pdist = 0.
                    # psnr could be inf at xz or yz (near edges); such slices are not counted
                    num_val = 0

                    # The crop is identical for every slice, so build it once.
                    crop_size = round(opt['scale'])
                    # A zero crop must keep the whole slice: img[0:-0] would be empty.
                    border = slice(None) if crop_size == 0 else slice(crop_size, -crop_size)
                    use_pdist = opt['datasets']['val']['get_pdist']

                    for i, (sr_img, gt_img) in enumerate(zip(sr_vol, gt_vol)):
                        # range is assumed to be [0,255], so scale back from 1500 to 255 float64
                        cropped_sr_img = sr_img[border, border].astype(np.float64) / 1500. * 255.
                        cropped_gt_img = gt_img[border, border].astype(np.float64) / 1500. * 255.

                        psnr = util.calculate_psnr(cropped_sr_img, cropped_gt_img)
                        ssim = util.calculate_ssim(cropped_sr_img, cropped_gt_img)
                        if use_pdist:
                            pdist = util.calculate_pdist(pdist_model, cropped_sr_img, cropped_gt_img)
                        else:
                            pdist = float('nan')

                        if psnr != float('inf'):
                            num_val += 1
                            sum_psnr += psnr
                            sum_ssim += ssim
                            sum_pdist += pdist
                        logger.info('{:20s} - {:3d}- PSNR: {:.6f} dB; SSIM: {:.6f}; pdist: {:.6f}.'\
                                    .format(patient_id, i+1, psnr, ssim, pdist))

                    if num_val:
                        mean_psnr = sum_psnr / num_val
                        mean_ssim = sum_ssim / num_val
                        mean_pdist = sum_pdist / num_val
                    else:
                        # Nothing to average (empty volume or every PSNR infinite):
                        # report NaN instead of raising ZeroDivisionError.
                        mean_psnr = mean_ssim = mean_pdist = float('nan')
                        logger.info('{:20s} - no slice with finite PSNR in view {}; averages set to NaN.'
                                    .format(patient_id, view))

                    pnsr_results[patient_id][view] = mean_psnr
                    ssim_results[patient_id][view] = mean_ssim
                    pdist_results[patient_id][view] = mean_pdist
                    return pnsr_results, ssim_results, pdist_results
コード例 #3
0
def cascade_test_main(opt, logger, model, test_loader):
    """Run ``model`` over every batch of ``test_loader`` and log image quality.

    For each batch the model output (``visuals['SR']``) is converted to an
    image. When the dataset option ``dataroot_GT`` is set, PSNR and SSIM are
    computed against ``visuals['GT']`` after cropping ``opt['crop_border']``
    pixels (falling back to ``opt['scale']`` when that option is falsy). For
    3-channel images the same metrics are also computed on the Y channel.
    If ``model.netG.module.conv_trunk.nfe`` exists, its increase since the
    previous evaluated image is logged as ``NFE``.

    Args:
        opt: Options mapping; ``'crop_border'`` and ``'scale'`` are read.
        logger: Object with an ``info`` method.
        model: Model exposing ``feed_data``, ``test`` and
            ``get_current_visuals``.
        test_loader: Iterable of batches with a ``dataset.opt`` mapping.

    Returns:
        A list of ``(psnr, ssim, psnr_y, ssim_y)`` tuples, one per 3-channel
        image that was evaluated against ground truth.
    """
    dataset_opt = test_loader.dataset.opt
    logger.info('\nTesting [{:s}]...'.format(dataset_opt['name']))

    def _read_nfe():
        # Cumulative function-evaluation counter kept by the network trunk.
        return model.netG.module.conv_trunk.nfe

    def _strip_border(img, border, has_channels):
        # A zero border keeps the image untouched (img[0:-0] would be empty).
        if border == 0:
            return img
        window = slice(border, -border)
        if has_channels:
            return img[window, window, :]
        return img[window, window]

    try:
        nfe_so_far = _read_nfe()
    except AttributeError:
        nfe_so_far = None

    scores = []
    for batch in test_loader:
        with_gt = dataset_opt['dataroot_GT'] is not None
        model.feed_data(batch, need_GT=with_gt)
        source_path = batch['GT_path'][0] if with_gt else batch['LQ_path'][0]
        img_name = osp.splitext(osp.basename(source_path))[0]

        model.test()
        visuals = model.get_current_visuals(need_GT=with_gt)
        restored = util.tensor2img(visuals['SR'])  # uint8

        if not with_gt:
            # Nothing to compare against: only report the image name.
            logger.info(img_name)
            continue

        # Work in [0, 1]; the metric helpers are fed values scaled back by 255.
        reference = util.tensor2img(visuals['GT'])
        reference = reference / 255.
        restored = restored / 255.

        border = opt['crop_border'] or opt['scale']
        restored_crop = _strip_border(restored, border, True)
        reference_crop = _strip_border(reference, border, True)

        psnr = util.calculate_psnr(restored_crop * 255, reference_crop * 255)
        ssim = util.calculate_ssim(restored_crop * 255, reference_crop * 255)
        # niqe = util.calculate_niqe(restored_crop * 255)

        if nfe_so_far is None:
            nfe_delta = None
        else:
            nfe_now = _read_nfe()
            nfe_delta = nfe_now - nfe_so_far
            nfe_so_far = nfe_now

        if reference.shape[2] != 3:
            # Not an RGB image: no Y-channel metrics and nothing is collected.
            logger.info(
                '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; NFE: {}'.format(
                    img_name, psnr, ssim, nfe_delta))
            continue

        restored_y = _strip_border(bgr2ycbcr(restored, only_y=True), border, False)
        reference_y = _strip_border(bgr2ycbcr(reference, only_y=True), border, False)
        psnr_y = util.calculate_psnr(restored_y * 255, reference_y * 255)
        ssim_y = util.calculate_ssim(restored_y * 255, reference_y * 255)

        scores.append((psnr, ssim, psnr_y, ssim_y))
        logger.info(
            '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}; NFE: {}'
            .format(img_name, psnr, ssim, psnr_y, ssim_y, nfe_delta))

    return scores
コード例 #4
0
def _sorted_entries(folder):
    """Return the sorted paths of everything directly inside ``folder``."""
    return sorted(glob.glob(osp.join(folder, '*')))


def _read_frames(folder):
    """Read every entry of ``folder`` (in sorted order) with ``data_util.read_img``."""
    return [data_util.read_img(None, path) for path in _sorted_entries(folder)]


def _share_pretrained_layers(model, raw_model, N_in, nf=128):
    """Make ``model`` reuse the layers of the pretrained ``raw_model``.

    All layers are shared by reference except ``tsa_fusion.fea_fusion`` and
    ``tsa_fusion.sAtt_1``: those are deep-copied and their weights are cut down
    to the first ``N_in * nf`` input channels so they accept ``N_in`` frames.
    """
    model.nf = raw_model.nf
    model.center = N_in // 2
    model.is_predeblur = raw_model.is_predeblur
    model.HR_in = raw_model.HR_in
    model.w_TSA = raw_model.w_TSA

    if model.is_predeblur:
        shared = ['pre_deblur', 'conv_1x1']
    elif model.HR_in:
        shared = ['conv_first_1', 'conv_first_2', 'conv_first_3']
    else:
        shared = ['conv_first']
    shared += ['feature_extraction', 'fea_L2_conv1', 'fea_L2_conv2',
               'fea_L3_conv1', 'fea_L3_conv2', 'pcd_align']
    for layer_name in shared:
        setattr(model, layer_name, getattr(raw_model, layer_name))

    fusion, raw_fusion = model.tsa_fusion, raw_model.tsa_fusion
    fusion.center = model.center
    fusion.tAtt_1 = raw_fusion.tAtt_1
    fusion.tAtt_2 = raw_fusion.tAtt_2

    # These two convolutions take the concatenated features of all frames, so
    # their input-channel count depends on the number of frames.
    for layer_name in ('fea_fusion', 'sAtt_1'):
        raw_layer = getattr(raw_fusion, layer_name)
        layer = copy.deepcopy(raw_layer)
        layer.weight = copy.deepcopy(
            torch.nn.Parameter(raw_layer.weight[:, 0:N_in * nf, :, :]))
        setattr(fusion, layer_name, layer)

    for layer_name in ('maxpool', 'avgpool', 'sAtt_2', 'sAtt_3', 'sAtt_4',
                       'sAtt_5', 'sAtt_L1', 'sAtt_L2', 'sAtt_L3',
                       'sAtt_add_1', 'sAtt_add_2', 'lrelu'):
        setattr(fusion, layer_name, getattr(raw_fusion, layer_name))

    for layer_name in ('recon_trunk', 'upconv1', 'upconv2', 'pixel_shuffle',
                       'HRconv', 'conv_last', 'lrelu'):
        setattr(model, layer_name, getattr(raw_model, layer_name))


def _load_truncated_edvr(model_path, N_in, N_model_default, back_RBs,
                         predeblur, HR_in, device):
    """Build an ``N_in``-frame EDVR in eval mode on ``device`` from a checkpoint.

    The checkpoint is loaded into a model with ``N_model_default`` frames,
    whose layers are then shared with (or truncated for) the ``N_in`` model.
    """
    raw_model = EDVR_arch.EDVR(128, N_model_default, 8, 5, back_RBs,
                               predeblur=predeblur, HR_in=HR_in)
    model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs,
                           predeblur=predeblur, HR_in=HR_in)
    raw_model.load_state_dict(torch.load(model_path), strict=True)
    _share_pretrained_layers(model, raw_model, N_in)
    model.eval()
    return model.to(device)


def _restore_frame(model, imgs_LQ, img_idx, max_idx, N_in, padding,
                   flip_test, device):
    """Run the model on the frame window around ``img_idx``; return the image."""
    select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
    imgs_in = imgs_LQ.index_select(
        0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

    if flip_test:
        output = util.flipx4_forward(model, imgs_in)
    else:
        print("IMGS_IN SHAPE: ", imgs_in.shape)
        output = util.single_forward(model, imgs_in)
    return util.tensor2img(output.squeeze(0))


def _score_frame(output, GT, GT_a, crop_border, y_only):
    """Return ``(psnr, gt_ssim, aposterior_ssim)`` for one restored frame.

    ``output`` is divided by 255 first; ``GT`` and ``GT_a`` are used as read.
    With ``y_only`` all three images are first reduced to the Y channel.
    PSNR and the first SSIM compare against ``GT``, the second SSIM against
    ``GT_a``.
    """
    output = output / 255.
    GT = np.copy(GT)
    GT_a = np.copy(GT_a)
    if y_only:
        GT = data_util.bgr2ycbcr(GT, only_y=True)
        output = data_util.bgr2ycbcr(output, only_y=True)
        GT_a = data_util.bgr2ycbcr(GT_a, only_y=True)
    output_a = copy.deepcopy(output)

    output, GT = util.crop_border([output, GT], crop_border)
    crt_psnr = util.calculate_psnr(output * 255, GT * 255)
    crt_ssim = util.calculate_ssim(output * 255, GT * 255)

    output_a, GT_a = util.crop_border([output_a, GT_a], crop_border)
    crt_aposterior = util.calculate_ssim(output_a * 255, GT_a * 255)
    return crt_psnr, crt_ssim, crt_aposterior


def _record_metrics(results, subfolder_name, img_name, psnr, ssim, aposterior):
    """Append one set of metrics to the ``metrics_file`` entry of a frame."""
    clip_results = results[subfolder_name]
    entry = clip_results.get(img_name)
    if entry is None:
        entry = metrics_file(img_name)
        clip_results[img_name] = entry
    entry.add_psnr(psnr)
    entry.add_gt_ssim(ssim)
    entry.add_aposterior_ssim(aposterior)


def _write_results(results, save_folder='../results/'):
    """Dump every frame entry of ``results`` to ``<save_folder>/<clip>/<name>.json``."""
    for dir_name, frames in results.items():
        save_subfolder = osp.join(save_folder, dir_name)
        util.mkdirs(save_subfolder)
        for value in frames.values():
            json_path = osp.join(save_subfolder, '{}.json'.format(value.name))
            # ensure_ascii=False emits raw characters, so fix the encoding.
            with open(json_path, 'w', encoding='utf-8') as outfile:
                json.dump(value.__dict__, outfile, ensure_ascii=False, indent=4)


def _reds_model_path(data_mode, stage):
    """Return the pretrained checkpoint path for a REDS data mode and stage."""
    tags = {
        'sharp_bicubic': 'SR',
        'blur_bicubic': 'SRblur',
        'blur': 'deblur',
        'blur_comp': 'deblurcomp',
    }
    if data_mode not in tags:
        raise NotImplementedError
    suffix = 'L' if stage == 1 else 'Stage2'
    return '../experiments/pretrained_models/EDVR_REDS_{}_{}.pth'.format(
        tags[data_mode], suffix)


def _evaluate_vid4(device, flip_test):
    """Evaluate Vid4 with 1..7 input frames; return the per-frame metrics."""
    model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
    predeblur, HR_in = False, False
    back_RBs = 40
    N_model_default = 7

    test_dataset_folder = '../datasets/Vid4/BIx4'
    GT_dataset_folder = '../datasets/Vid4/GT'
    aposterior_GT_dataset_folder = '../datasets/Vid4/GT_7'
    crop_border = 0
    padding = 'new_info'

    vid4_results = {"calendar": {}, "city": {}, "foliage": {}, "walk": {}}

    for N_in in range(1, N_model_default + 1):
        model = _load_truncated_edvr(model_path, N_in, N_model_default,
                                     back_RBs, predeblur, HR_in, device)

        # Clips are paired by sorted position, not by name.
        clips = zip(_sorted_entries(test_dataset_folder),
                    _sorted_entries(GT_dataset_folder),
                    _sorted_entries(aposterior_GT_dataset_folder))
        for subfolder, subfolder_GT, subfolder_GT_a in clips:
            subfolder_name = osp.basename(subfolder)
            img_path_l = _sorted_entries(subfolder)
            max_idx = len(img_path_l)
            print("MAX_IDX: ", max_idx)

            imgs_LQ = data_util.read_img_seq(subfolder)
            img_GT_l = _read_frames(subfolder_GT)
            img_GT_a = _read_frames(subfolder_GT_a)

            for img_idx, img_path in enumerate(img_path_l):
                img_name = osp.splitext(osp.basename(img_path))[0]
                output = _restore_frame(model, imgs_LQ, img_idx, max_idx,
                                        N_in, padding, flip_test, device)
                # Vid4 is evaluated on the Y channel.
                scores = _score_frame(output, img_GT_l[img_idx],
                                      img_GT_a[img_idx], crop_border, y_only=True)
                _record_metrics(vid4_results, subfolder_name, img_name, *scores)

    return vid4_results


def _evaluate_reds4(device, flip_test):
    """Evaluate REDS4 with 1..5 input frames in two stages; return the metrics.

    Stage 1 writes the restored frames to disk; stage 2 reads them back as its
    input and is the only stage whose output is scored.
    """
    reds4_results = {"000": {}, "011": {}, "015": {}, "020": {}}
    data_mode = 'sharp_bicubic'
    N_model_default = 5

    GT_dataset_folder = '../datasets/REDS4/GT'
    aposterior_GT_dataset_folder = '../datasets/REDS4/GT_5'
    crop_border = 0
    save_imgs = True
    # temporal padding mode
    padding = 'new_info' if data_mode in ('Vid4', 'sharp_bicubic') else 'replicate'

    for N_in in range(1, N_model_default + 1):
        for stage in range(1, 3):
            model_path = _reds_model_path(data_mode, stage)

            predeblur, HR_in = False, False
            back_RBs = 40
            if data_mode == 'blur_bicubic':
                predeblur = True
            if data_mode in ('blur', 'blur_comp'):
                predeblur, HR_in = True, True
            if stage == 2:
                HR_in = True
                back_RBs = 20

            if stage == 1:
                test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
            else:
                test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
                print('You should modify the test_dataset_folder path for stage 2')

            # Stage 1 output goes to the folder that stage 2 reads from.
            if stage == 1 and data_mode != 'Vid4':
                save_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            else:
                save_folder = '../results/{}'.format(data_mode)
            util.mkdirs(save_folder)
            util.setup_logger('base', save_folder, 'test', level=logging.INFO,
                              screen=True, tofile=True)

            model = _load_truncated_edvr(model_path, N_in, N_model_default,
                                         back_RBs, predeblur, HR_in, device)

            # Clips are paired by sorted position, not by name.
            clips = zip(_sorted_entries(test_dataset_folder),
                        _sorted_entries(GT_dataset_folder),
                        _sorted_entries(aposterior_GT_dataset_folder))
            for subfolder, subfolder_GT, subfolder_GT_a in clips:
                subfolder_name = osp.basename(subfolder)
                save_subfolder = osp.join(save_folder, subfolder_name)

                img_path_l = _sorted_entries(subfolder)
                max_idx = len(img_path_l)
                print("MAX_IDX: ", max_idx)
                print("SAVE FOLDER::::::", save_folder)

                if save_imgs:
                    util.mkdirs(save_subfolder)

                imgs_LQ = data_util.read_img_seq(subfolder)
                # Ground truth is only needed when metrics are computed.
                if stage == 2:
                    img_GT_l = _read_frames(subfolder_GT)
                    img_GT_a = _read_frames(subfolder_GT_a)

                for img_idx, img_path in enumerate(img_path_l):
                    img_name = osp.splitext(osp.basename(img_path))[0]
                    output = _restore_frame(model, imgs_LQ, img_idx, max_idx,
                                            N_in, padding, flip_test, device)

                    if save_imgs and stage == 1:
                        cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)
                    if stage == 2:
                        # REDS is evaluated on all channels.
                        scores = _score_frame(output, img_GT_l[img_idx],
                                              img_GT_a[img_idx], crop_border,
                                              y_only=False)
                        _record_metrics(reds4_results, subfolder_name, img_name, *scores)

    return reds4_results


def main():
    """Collect PSNR/SSIM of EDVR on Vid4 and REDS4 for varying frame counts.

    For every input-frame count from 1 up to the pretrained model's count, a
    smaller EDVR is assembled from the pretrained checkpoint and run over the
    dataset. Each frame's PSNR, SSIM against ``GT`` and SSIM against the
    "aposterior" ground truth (``GT_7`` / ``GT_5``) are accumulated per frame
    and finally written as one JSON file per frame under
    ``../results/<clip>/``.
    """
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    flip_test = False

    # STAGE Vid4
    vid4_results = _evaluate_vid4(device, flip_test)
    _write_results(vid4_results)

    # STAGE REDS
    reds4_results = _evaluate_reds4(device, flip_test)
    _write_results(reds4_results)
コード例 #5
0
def main():
    """Evaluate a pretrained EDVR checkpoint with a reduced number of input frames.

    Command line (all positional): ``dataset n_frames stage``.

    * ``dataset``  -- ``Vid4``, ``sharp_bicubic``, ``blur_bicubic``, ``blur`` or
      ``blur_comp``.
    * ``n_frames`` -- number of input frames ``N_in`` fed to the network.
    * ``stage``    -- ``1`` selects the stage-1 checkpoint; any other value
      selects the ``Stage2`` checkpoint (``Vid4`` only supports stage 1).

    The checkpoint is loaded into ``raw_model`` (built with the default frame
    count).  A second network, ``model``, built for ``N_in`` frames, then
    reuses ``raw_model``'s sub-modules.  The two TSA convolutions whose input
    width depends on the frame count (``fea_fusion`` and ``sAtt_1``) are
    copied and their weights truncated to the first ``N_in * n_feat`` input
    channels.

    For every clip the restored frames are written below ``../results/`` and
    the per-frame, per-clip and overall SSIM values are logged.

    Raises:
        ValueError: ``Vid4`` requested with a stage other than 1.
        NotImplementedError: ``dataset`` is not one of the supported modes.
    """
    ####################
    # arguments parser #
    ####################
    # [format] dataset (Vid4 or one of the REDS4 modes), n_frames, stage
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset')
    parser.add_argument('n_frames')
    parser.add_argument('stage')
    args = parser.parse_args()

    data_mode = str(args.dataset)
    N_in = int(args.n_frames)
    stage = int(args.stage)

    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # data_mode: Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    # stage: 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    # First EDVR constructor argument (presumably the feature width ``nf``).
    # The TSA weight slicing below must use the same value.
    n_feat = 128

    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    elif data_mode == 'sharp_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
    elif data_mode == 'blur_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
    elif data_mode == 'blur':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
    elif data_mode == 'blur_comp':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    else:
        raise NotImplementedError

    predeblur, HR_in = False, False
    back_RBs = 40
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20

    #### dataset
    if data_mode == 'Vid4':
        N_model_default = 7
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        N_model_default = 5
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    # raw_model matches the checkpoint layout (default frame count);
    # model is the network actually evaluated, built for N_in frames.
    raw_model = EDVR_arch.EDVR(n_feat,
                               N_model_default,
                               8,
                               5,
                               back_RBs,
                               predeblur=predeblur,
                               HR_in=HR_in)
    model = EDVR_arch.EDVR(n_feat,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    # Stage-1 REDS results share one fixed output folder (the default stage-2
    # input path above); every other run is stored under its data_mode name.
    if stage == 1 and data_mode != 'Vid4':
        save_name = 'REDS-EDVR_REDS_SR_L_flipx4'
    else:
        save_name = data_mode
    save_folder = '../results/{}'.format(save_name)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    # Debug aid: list the non-callable attributes of the freshly built model.
    print([a for a in dir(model)
           if not callable(getattr(model, a))])

    # Only raw_model has the shapes stored in the checkpoint, so the weights
    # are loaded there and transplanted into model below.
    raw_model.load_state_dict(torch.load(model_path), strict=True)

    #### change model so it can work with less input
    model.nf = raw_model.nf
    model.center = N_in // 2
    model.is_predeblur = raw_model.is_predeblur
    model.HR_in = raw_model.HR_in
    model.w_TSA = raw_model.w_TSA

    #### extract features (for each frame)
    if model.is_predeblur:
        model.pre_deblur = raw_model.pre_deblur  # Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
        model.conv_1x1 = raw_model.conv_1x1  # nn.Conv2d(nf, nf, 1, 1, bias=True)
    else:
        if model.HR_in:
            model.conv_first_1 = raw_model.conv_first_1  # nn.Conv2d(3, nf, 3, 1, 1, bias=True)
            model.conv_first_2 = raw_model.conv_first_2  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            model.conv_first_3 = raw_model.conv_first_3  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        else:
            model.conv_first = raw_model.conv_first  # nn.Conv2d(3, nf, 3, 1, 1, bias=True)
    model.feature_extraction = raw_model.feature_extraction  # arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
    model.fea_L2_conv1 = raw_model.fea_L2_conv1  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
    model.fea_L2_conv2 = raw_model.fea_L2_conv2  # nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
    model.fea_L3_conv1 = raw_model.fea_L3_conv1  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
    model.fea_L3_conv2 = raw_model.fea_L3_conv2  # nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

    model.pcd_align = raw_model.pcd_align  # PCD_Align(nf=nf, groups=groups)

    ######## Resize TSA
    model.tsa_fusion.center = model.center
    # temporal attention (before fusion conv)
    model.tsa_fusion.tAtt_1 = raw_model.tsa_fusion.tAtt_1
    model.tsa_fusion.tAtt_2 = raw_model.tsa_fusion.tAtt_2

    # fusion conv, nn.Conv2d(nframes * nf, nf, 1, 1, bias=True): its input
    # width depends on the frame count, so keep only the first N_in * n_feat
    # input channels of the pretrained weight (the bias is left untouched).
    model.tsa_fusion.fea_fusion = copy.deepcopy(
        raw_model.tsa_fusion.fea_fusion)
    model.tsa_fusion.fea_fusion.weight = copy.deepcopy(
        torch.nn.Parameter(
            raw_model.tsa_fusion.fea_fusion.weight[:, 0:N_in * n_feat, :, :]))

    # spatial attention (after fusion conv): sAtt_1 is truncated the same way.
    model.tsa_fusion.sAtt_1 = copy.deepcopy(raw_model.tsa_fusion.sAtt_1)
    model.tsa_fusion.sAtt_1.weight = copy.deepcopy(
        torch.nn.Parameter(
            raw_model.tsa_fusion.sAtt_1.weight[:, 0:N_in * n_feat, :, :]))

    print("MODEL TSA SHAPE: ", model.tsa_fusion.fea_fusion.weight.shape)

    # The remaining TSA layers are shared with raw_model unchanged.
    model.tsa_fusion.maxpool = raw_model.tsa_fusion.maxpool
    model.tsa_fusion.avgpool = raw_model.tsa_fusion.avgpool
    model.tsa_fusion.sAtt_2 = raw_model.tsa_fusion.sAtt_2
    model.tsa_fusion.sAtt_3 = raw_model.tsa_fusion.sAtt_3
    model.tsa_fusion.sAtt_4 = raw_model.tsa_fusion.sAtt_4
    model.tsa_fusion.sAtt_5 = raw_model.tsa_fusion.sAtt_5
    model.tsa_fusion.sAtt_L1 = raw_model.tsa_fusion.sAtt_L1
    model.tsa_fusion.sAtt_L2 = raw_model.tsa_fusion.sAtt_L2
    model.tsa_fusion.sAtt_L3 = raw_model.tsa_fusion.sAtt_L3
    model.tsa_fusion.sAtt_add_1 = raw_model.tsa_fusion.sAtt_add_1
    model.tsa_fusion.sAtt_add_2 = raw_model.tsa_fusion.sAtt_add_2

    model.tsa_fusion.lrelu = raw_model.tsa_fusion.lrelu

    #### reconstruction
    model.recon_trunk = raw_model.recon_trunk  # arch_util.make_layer(ResidualBlock_noBN_f, back_RBs)
    #### upsampling
    model.upconv1 = raw_model.upconv1  # nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
    model.upconv2 = raw_model.upconv2  # nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)
    model.pixel_shuffle = raw_model.pixel_shuffle  # nn.PixelShuffle(2)
    model.HRconv = raw_model.HRconv
    model.conv_last = raw_model.conv_last

    #### activation function
    model.lrelu = raw_model.lrelu

    #####################################################

    model.eval()
    model = model.to(device)

    avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder (LQ and GT clips are paired by sorted order)
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)

        print("MAX_IDX: ", max_idx)

        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_ssim, avg_ssim_border, avg_ssim_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            if data_mode == "blur":
                select_idx = data_util.glarefree_index_generation(
                    img_idx, max_idx, N_in, padding=padding)
            else:
                select_idx = data_util.index_generation(
                    img_idx, max_idx, N_in, padding=padding)
            print("SELECT IDX: ", select_idx)

            # Gather the selected frames and add a leading batch dimension.
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                print("IMGS_IN SHAPE: ", imgs_in.shape)
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate SSIM
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_ssim = util.calculate_ssim(output * 255, GT * 255)
            # NOTE(review): SSIM is unitless; the "dB" suffix is kept only so
            # the log format stays unchanged for existing log consumers.
            logger.info('{:3d} - {:25} \tSSIM: {:.6f} dB'.format(
                img_idx + 1, img_name, crt_ssim))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_ssim_center += crt_ssim
                N_center += 1
            else:  # border frames
                avg_ssim_border += crt_ssim
                N_border += 1

        # Guard every average: a short clip (max_idx <= 2 * border_frame) has
        # no center frames, and N_in == 1 yields no border frames.
        N_total = N_center + N_border
        avg_ssim = 0 if N_total == 0 else (avg_ssim_center +
                                           avg_ssim_border) / N_total
        avg_ssim_center = 0 if N_center == 0 else avg_ssim_center / N_center
        avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border
        avg_ssim_l.append(avg_ssim)
        avg_ssim_center_l.append(avg_ssim_center)
        avg_ssim_border_l.append(avg_ssim_border)

        logger.info('Folder {} - Average SSIM: {:.6f} dB for {} frames; '
                    'Center SSIM: {:.6f} dB for {} frames; '
                    'Border SSIM: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_ssim, N_total,
                        avg_ssim_center, N_center, avg_ssim_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, ssim, ssim_center, ssim_border in zip(
            subfolder_name_l, avg_ssim_l, avg_ssim_center_l,
            avg_ssim_border_l):
        logger.info('Folder {} - Average SSIM: {:.6f} dB. '
                    'Center SSIM: {:.6f} dB. '
                    'Border SSIM: {:.6f} dB.'.format(subfolder_name, ssim,
                                                     ssim_center, ssim_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average SSIM: {:.6f} dB for {} clips. '
                'Center SSIM: {:.6f} dB. Border SSIM: {:.6f} dB.'.format(
                    sum(avg_ssim_l) / len(avg_ssim_l), len(subfolder_l),
                    sum(avg_ssim_center_l) / len(avg_ssim_center_l),
                    sum(avg_ssim_border_l) / len(avg_ssim_border_l)))
コード例 #6
0
ファイル: train.py プロジェクト: canhnht/BasicSR
def main():
    """Train a model from an options JSON file, with periodic validation.

    Command line: ``-opt <path to option JSON file>`` (required).

    Steps performed:

    1. Parse the options and either load a resume state (a ``.state`` file, or
       the last ``*.state`` file of a directory as ordered by
       ``util.sorted_nicely``) or create fresh experiment directories.
    2. Configure the loggers and, when enabled in the options and the
       experiment name does not contain ``debug``, a tensorboardX writer.
    3. Seed the random number generators and build the train / val loaders.
    4. Run the training loop until ``opt['train']['niter']`` iterations,
       logging, checkpointing and validating (PSNR, SSIM, LPIPS) at the
       frequencies given in the options.
    5. Save the final model as ``latest``.

    Raises:
        NotImplementedError: a dataset phase other than ``train`` / ``val``
            is present in the options.
        AssertionError: no ``train`` dataset is configured.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    # Convert to NoneDict, which return None for missing key.
    opt = option.dict_to_nonedict(opt)
    pytorch_ver = get_pytorch_ver()

    # train from scratch OR resume training
    if opt['path']['resume_state']:
        if os.path.isdir(opt['path']['resume_state']):
            # A directory was given: resume from its last *.state file.
            import glob
            resume_state_path = util.sorted_nicely(
                glob.glob(
                    os.path.normpath(opt['path']['resume_state']) +
                    '/*.state'))[-1]
        else:
            resume_state_path = opt['path']['resume_state']
        resume_state = torch.load(resume_state_path)
    else:  # training from scratch
        resume_state = None
        # rename old folder if exists
        util.mkdir_and_rename(opt['path']['experiments_root'])
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Set [resume_state] to ' + resume_state_path)
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))

    # tensorboard logger (evaluated once; reused for every add_scalar below)
    use_tb_logger = opt['use_tb_logger'] and 'debug' not in opt['name']
    if use_tb_logger:
        from tensorboardX import SummaryWriter
        try:
            # for version tensorboardX >= 1.7
            tb_logger = SummaryWriter(logdir='../tb_logger/' + opt['name'])
        except TypeError:
            # for version tensorboardX < 1.6 (keyword is spelled log_dir).
            # Only the keyword mismatch is caught so that KeyboardInterrupt
            # and unrelated failures are not swallowed.
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # Let cuDNN auto-tune its convolution algorithms.  The attribute was
    # previously misspelled ("benckmark"), which silently did nothing.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    # Pre-bind both loaders so that a missing phase is detected explicitly
    # instead of surfacing as an UnboundLocalError.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None, 'No [train] dataset is configured.'

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
        # updated schedulers in case JSON configuration has changed
        model.update_schedulers(opt['train'])
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for n, train_data in enumerate(train_loader, start=1):
            current_step += 1
            if current_step > total_iters:
                break

            if pytorch_ver == "pre":  # Order for PyTorch ver < 1.1.0
                # update learning rate
                model.update_learning_rate(current_step - 1)
                # training
                model.feed_data(train_data)
                model.optimize_parameters(current_step)
            elif pytorch_ver == "post":  # Order for PyTorch ver > 1.1.0
                # training
                model.feed_data(train_data)
                model.optimize_parameters(current_step)
                # update learning rate
                model.update_learning_rate(current_step - 1)
            else:
                print('Error identifying PyTorch version. ', torch.__version__)
                break

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if use_tb_logger:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # save models and training states (changed to save models before validation)
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                model.save(current_step)
                # When this was the last batch of the epoch, record the next
                # epoch index (the boolean adds 1) in the training state.
                model.save_training_state(epoch + (n >= len(train_loader)),
                                          current_step)
                logger.info('Models and training states saved.')

            # validation (skipped when no [val] dataset is configured)
            if val_loader is not None and \
                    current_step % opt['train']['val_freq'] == 0:
                avg_psnr = 0.0
                avg_ssim = 0.0
                avg_lpips = 0.0
                idx = 0
                val_sr_imgs_list = []
                val_gt_imgs_list = []
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()

                    if opt['datasets']['train'][
                            'znorm']:  # If the image range is [-1,1]
                        sr_img = util.tensor2img(visuals['SR'],
                                                 min_max=(-1, 1))  # uint8
                        gt_img = util.tensor2img(visuals['HR'],
                                                 min_max=(-1, 1))  # uint8
                    else:  # Default: Image range is [0,1]
                        sr_img = util.tensor2img(visuals['SR'])  # uint8
                        gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}.png'.format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR, SSIM and LPIPS distance
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.

                    # For training models with only one channel ndim==2, if RGB ndim==3, etc.
                    if gt_img.ndim == 2:
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size]
                    else:
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                    if sr_img.ndim == 2:
                        cropped_sr_img = sr_img[crop_size:-crop_size,
                                                crop_size:-crop_size]
                    else:  # Default: RGB images
                        cropped_sr_img = sr_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]

                    # Collected so LPIPS can be computed once over all images.
                    val_gt_imgs_list.append(cropped_gt_img)
                    val_sr_imgs_list.append(cropped_sr_img)

                    # LPIPS only works for RGB images
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)
                    avg_ssim += util.calculate_ssim(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)

                avg_psnr = avg_psnr / idx
                avg_ssim = avg_ssim / idx
                # LPIPS is calculated only once for all images
                avg_lpips = lpips.calculate_lpips(val_sr_imgs_list,
                                                  val_gt_imgs_list)

                # log
                logger.info(
                    '# Validation # PSNR: {:.5g}, SSIM: {:.5g}, LPIPS: {:.5g}'.
                    format(avg_psnr, avg_ssim, avg_lpips))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr: {:.5g}, ssim: {:.5g}, lpips: {:.5g}'
                    .format(epoch, current_step, avg_psnr, avg_ssim,
                            avg_lpips))
                # tensorboard logger
                if use_tb_logger:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
                    tb_logger.add_scalar('ssim', avg_ssim, current_step)
                    tb_logger.add_scalar('lpips', avg_lpips, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
コード例 #7
0
ファイル: test_vsr.py プロジェクト: ele38/Colab-BasicSR
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None,
                      opt['path']['log'],
                      'test.log',
                      level=logging.INFO,
                      screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    scale = opt.get('scale', 4)

    # Create test dataset and dataloader
    test_loaders = []
    znorm = False  #TMP
    # znorm_list = []
    '''
    video_list = os.listdir(cfg.testset_dir)
    for idx_video in range(len(video_list)):
        video_name = video_list[idx_video]
        # dataloader
        test_set = TestsetLoader(cfg, video_name)
        test_loader = DataLoader(test_set, num_workers=1, batch_size=1, shuffle=False)
    '''

    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)
        # Temporary, will turn znorm on for all the datasets. Will need to introduce a variable for each dataset and differentiate each one later in the loop.
        # if dataset_opt.get['znorm'] and znorm == False:
        #     znorm = True
        znorm = dataset_opt.get('znorm', False)
        # znorm_list.apped(znorm)

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_HR = False if test_loader.dataset.opt[
                'dataroot_HR'] is None else True

            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            # tmp_vis(data['LR'][:,1,:,:,:], True)

            if opt.get('chop_forward', None):
                # data
                if len(data['LR'].size()) == 4:
                    b, n_frames, h_lr, w_lr = data['LR'].size()
                    LR_y_cube = data['LR'].view(b, -1, 1, h_lr,
                                                w_lr)  # b, t, c, h, w
                elif len(data['LR'].size()
                         ) == 5:  #for networks that work with 3 channel images
                    _, n_frames, _, _, _ = data['LR'].size()
                    LR_y_cube = data['LR']  # b, t, c, h, w

                # print(LR_y_cube.shape)
                # print(data['LR_bicubic'].shape)

                # crop borders to ensure each patch can be divisible by 2
                #TODO: this is modcrop, not sure if really needed, check (the dataloader already does modcrop)
                _, _, _, h, w = LR_y_cube.size()
                h = int(h // 16) * 16
                w = int(w // 16) * 16
                LR_y_cube = LR_y_cube[:, :, :, :h, :w]
                if isinstance(data['LR_bicubic'], torch.Tensor):
                    # SR_cb = data['LR_bicubic'][:, 1, :, :][:, :, :h * scale, :w * scale]
                    SR_cb = data['LR_bicubic'][:, 1, :h * scale, :w * scale]
                    # SR_cr = data['LR_bicubic'][:, 2, :, :][:, :, :h * scale, :w * scale]
                    SR_cr = data['LR_bicubic'][:, 2, :h * scale, :w * scale]

                SR_y = chop_forward(LR_y_cube, model, scale,
                                    need_HR=need_HR).squeeze(0)
                # SR_y = np.array(SR_y.data.cpu())
                if test_loader.dataset.opt.get('srcolors', None):
                    print(SR_y.shape, SR_cb.shape, SR_cr.shape)
                    sr_img = ycbcr_to_rgb(torch.stack((SR_y, SR_cb, SR_cr),
                                                      -3))
                else:
                    sr_img = SR_y
            else:
                # data
                model.feed_data(data, need_HR=need_HR)
                # SR_y = net(LR_y_cube).squeeze(0)
                model.test()  # test
                visuals = model.get_current_visuals(need_HR=need_HR)
                # ds = torch.nn.AvgPool2d(2, stride=2, count_include_pad=False)
                # tmp_vis(ds(visuals['SR']), True)
                # tmp_vis(visuals['SR'], True)
                if test_loader.dataset.opt.get(
                        'y_only', None) and test_loader.dataset.opt.get(
                            'srcolors', None):
                    SR_cb = data['LR_bicubic'][:, 1, :, :]
                    SR_cr = data['LR_bicubic'][:, 2, :, :]
                    # tmp_vis(ds(SR_cb), True)
                    # tmp_vis(ds(SR_cr), True)
                    sr_img = ycbcr_to_rgb(
                        torch.stack((visuals['SR'], SR_cb, SR_cr), -3))
                else:
                    sr_img = visuals['SR']

            #if znorm the image range is [-1,1], Default: Image range is [0,1] # testing, each "dataset" can have a different name (not train, val or other)
            sr_img = tensor2np(sr_img, denormalize=znorm)  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = os.path.join(dataset_dir,
                                             img_name + suffix + '.png')
            else:
                save_img_path = os.path.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            #TODO: update to use metrics functions
            # calculate PSNR and SSIM
            if need_HR:
                #if znorm the image range is [-1,1], Default: Image range is [0,1] # testing, each "dataset" can have a different name (not train, val or other)
                gt_img = tensor2img(visuals['HR'], denormalize=znorm)  # uint8
                gt_img = gt_img / 255.
                sr_img = sr_img / 255.

                crop_border = test_loader.dataset.opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'\
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                            img_name, psnr, ssim))
            else:
                logger.info(img_name)

        #TODO: update to use metrics functions
        if need_HR:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'\
                    .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'\
                    .format(ave_psnr_y, ave_ssim_y))
コード例 #8
0
def main(opts):
    """Evaluate EDVR-family checkpoints on a directory of test clips.

    Every checkpoint found at ``opts.model_dir`` is loaded into the network
    selected by ``opts.mode`` and run over each sub-folder (clip) of
    ``opts.test_dir``.  Each output frame is compared with the same-named
    frame under ``opts.gt_dir``; per-frame PSNR/SSIM are logged together with
    per-clip and overall averages, split into "center" and "border" frames.

    Args:
        opts: Options object.  The attributes read here are ``gpus``
            (comma-separated device ids), ``cache`` (> 0 preloads all
            frames), ``model_dir`` (a checkpoint file or a directory of
            checkpoints), ``test_dir``, ``gt_dir``, ``save_dir`` and ``mode``.

    Raises:
        IOError: If ``opts.model_dir`` is neither a file nor a directory.
        TypeError: If ``opts.mode`` does not select a known model.
    """
    ################## configurations #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpus
    cache_all_imgs = opts.cache > 0
    n_gpus = len(opts.gpus.split(','))
    mode = opts.mode.lower()
    flip_test, save_imgs = False, False
    N_in, nf = 5, 64
    back_RBs = 10
    w_TSA = False
    predeblur, HR_in = False, False
    crop_border = 0
    border_frame = N_in // 2
    padding = 'new_info'

    def _build_model():
        """Instantiate a fresh network for the mode given in ``opts.mode``."""
        shared = dict(nf=nf,
                      nframes=N_in,
                      groups=8,
                      front_RBs=5,
                      center=None,
                      w_TSA=w_TSA)
        plain = dict(shared,
                     back_RBs=back_RBs,
                     predeblur=predeblur,
                     HR_in=HR_in)
        up = dict(shared,
                  back_RBs=10,
                  down_scale=True,
                  align_target=True,
                  ret_valid=True)
        if 'meta' in mode:
            return EDVR_arch.MetaEDVR(**plain)
        if mode == 'edvr':
            return EDVR_arch.EDVR(**plain)
        if mode == 'upedvr':
            return EDVR_arch.UPEDVR(**up)
        if mode == 'upcont1':
            return EDVR_arch.UPControlEDVR(multi_scale_cont=False, **up)
        if mode in ('upcont3', 'upcont2'):
            # Both modes are configured identically here.
            return EDVR_arch.UPControlEDVR(multi_scale_cont=True, **up)
        raise TypeError('Unknown model mode: {}'.format(opts.mode))

    def _read_clip(test_sub_dir, gt_sub_dir, names):
        """Load the input and ground-truth frames of one clip as float32."""
        lq_frames, gt_frames = [], []
        for name in names:
            # Input frame: reverse the channel axis, move channels first and
            # scale to [0, 1].
            lq = cv2.imread(osp.join(test_sub_dir, name),
                            cv2.IMREAD_UNCHANGED)[:, :, (2, 1, 0)]
            lq = lq.astype(np.float32).transpose((2, 0, 1)) / 255.
            # Ground truth keeps the layout and value range it was read with.
            gt = cv2.imread(osp.join(gt_sub_dir, name),
                            cv2.IMREAD_UNCHANGED).astype(np.float32)
            lq_frames.append(lq)
            gt_frames.append(gt)
        return lq_frames, gt_frames

    ################## model files ####################
    model_dir = opts.model_dir
    if osp.isfile(model_dir):
        model_names = [osp.basename(model_dir)]
        model_dir = osp.dirname(model_dir)
    elif osp.isdir(model_dir):
        # Only files named "<digits>_..." are used, ordered by that number.
        model_names = [
            x for x in os.listdir(model_dir) if str.isdigit(x.split('_')[0])
        ]
        model_names = sorted(model_names, key=lambda x: int(x.split("_")[0]))
    else:
        raise IOError('Invalid model_dir: {}'.format(model_dir))

    ################## dataset ########################
    test_subs = sorted(os.listdir(opts.test_dir))
    gt_subs = os.listdir(opts.gt_dir)
    valid_test_subs = [sub in gt_subs for sub in test_subs]
    assert (all(valid_test_subs)), 'Invalid sub folders exists in {}'.format(
        opts.test_dir)
    # The scale is parsed from the name of test_dir's parent directory,
    # skipping that name's first character.
    scale = float(os.path.basename(os.path.dirname(opts.test_dir))[1:])

    all_imgs = {}
    if cache_all_imgs:
        print('Cacheing all testing images ...')
        for sub in test_subs:
            print('Reading sub-folder: {} ...'.format(sub))
            test_sub_dir = osp.join(opts.test_dir, sub)
            gt_sub_dir = osp.join(opts.gt_dir, sub)
            lq_frames, gt_frames = _read_clip(
                test_sub_dir, gt_sub_dir, sorted(os.listdir(test_sub_dir)))
            all_imgs[sub] = {'test': lq_frames, 'gt': gt_frames}

    for model_name in model_names:
        model_path = osp.join(model_dir, model_name)
        exp_name = model_name.split('_')[0]
        model = _build_model()

        save_folder = osp.join(opts.save_dir, exp_name)
        util.mkdirs(save_folder)
        util.setup_logger(exp_name,
                          save_folder,
                          'test',
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger(exp_name)

        #### log info
        logger.info('Data: {}'.format(opts.test_dir))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))

        #### set up the models
        model.load_state_dict(torch.load(model_path), strict=True)
        model.eval()
        if n_gpus > 1:
            model = nn.DataParallel(model)
        model = model.to(device)

        avg_psnrs, avg_psnr_centers, avg_psnr_borders = [], [], []
        avg_ssims, avg_ssim_centers, avg_ssim_borders = [], [], []
        evaled_subs = []

        # for each subfolder
        for sub in test_subs:
            evaled_subs.append(sub)
            test_sub_dir = osp.join(opts.test_dir, sub)
            gt_sub_dir = osp.join(opts.gt_dir, sub)
            img_names = sorted(os.listdir(test_sub_dir))
            max_idx = len(img_names)
            if save_imgs:
                save_subfolder = osp.join(save_folder, sub)
                util.mkdirs(save_subfolder)

            #### get LQ and GT images
            if cache_all_imgs:
                img_LQs = all_imgs[sub]['test']
                img_GTs = all_imgs[sub]['gt']
            else:
                img_LQs, img_GTs = _read_clip(test_sub_dir, gt_sub_dir,
                                              img_names)

            psnr_center_sum = psnr_border_sum = 0
            ssim_center_sum = ssim_border_sum = 0
            N_center = N_border = 0

            # process the frames in batches of n_gpus
            for i in range(0, max_idx, n_gpus):
                end = min(i + n_gpus, max_idx)
                select_idxs = [
                    data_util.index_generation(j,
                                               max_idx,
                                               N_in,
                                               padding=padding)
                    for j in range(i, end)
                ]
                imgs = []
                for select_idx in select_idxs:
                    im = torch.from_numpy(
                        np.stack([img_LQs[k] for k in select_idx]))
                    imgs.append(im)
                # Pad the last, short batch with zeros up to n_gpus entries.
                if (i + n_gpus) > max_idx:
                    for _ in range(max_idx, i + n_gpus):
                        imgs.append(torch.zeros_like(im))
                imgs = torch.stack(imgs, 0).to(device)

                # Exactly one forward path is used. The branches must be
                # chained: with independent ifs, the 'meta' result was
                # overwritten by the plain single_forward fallback.
                if flip_test:
                    output = util.flipx4_forward(model, imgs)
                elif 'meta' in mode:
                    output = util.meta_single_forward(model, imgs, scale,
                                                      n_gpus)
                elif 'up' in mode:
                    output = util.up_single_forward(model, imgs, scale)
                else:
                    output = util.single_forward(model, imgs)
                output = [
                    util.tensor2img(x).astype(np.float32) for x in output
                ]

                if save_imgs:
                    for ii in range(i, end):
                        cv2.imwrite(
                            osp.join(save_subfolder,
                                     '{}.png'.format(img_names[ii])),
                            output[ii - i].astype(np.uint8))

                # calculate PSNR
                GT = np.copy(img_GTs[i:end])

                output = util.crop_border(output, crop_border)
                GT = util.crop_border(GT, crop_border)
                for m in range(i, end):
                    crt_psnr = util.calculate_psnr(output[m - i], GT[m - i])
                    crt_ssim = util.calculate_ssim(output[m - i], GT[m - i])
                    logger.info(
                        '{:3d} - {:25} \tPSNR: {:.6f} dB    SSIM: {:.6}'.
                        format(m + 1, img_names[m], crt_psnr, crt_ssim))

                    if border_frame <= m < max_idx - border_frame:
                        # center frames
                        psnr_center_sum += crt_psnr
                        ssim_center_sum += crt_ssim
                        N_center += 1
                    else:
                        # border frames
                        psnr_border_sum += crt_psnr
                        ssim_border_sum += crt_ssim
                        N_border += 1

            # An empty clip folder still fails here (division by zero).
            n_frames = N_center + N_border
            avg_psnr = (psnr_center_sum + psnr_border_sum) / n_frames
            avg_ssim = (ssim_center_sum + ssim_border_sum) / n_frames
            # Short clips can have no center frames at all; report 0 for an
            # empty group, as is done for the border group.
            avg_psnr_center = (0 if N_center == 0 else
                               psnr_center_sum / N_center)
            avg_ssim_center = (0 if N_center == 0 else
                               ssim_center_sum / N_center)
            avg_psnr_border = (0 if N_border == 0 else
                               psnr_border_sum / N_border)
            avg_ssim_border = (0 if N_border == 0 else
                               ssim_border_sum / N_border)

            avg_psnrs.append(avg_psnr)
            avg_psnr_centers.append(avg_psnr_center)
            avg_psnr_borders.append(avg_psnr_border)

            avg_ssims.append(avg_ssim)
            avg_ssim_centers.append(avg_ssim_center)
            avg_ssim_borders.append(avg_ssim_border)

            logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                        'Center PSNR: {:.6f} dB for {} frames; '
                        'Border PSNR: {:.6f} dB for {} frames.'.format(
                            sub, avg_psnr, n_frames, avg_psnr_center,
                            N_center, avg_psnr_border, N_border))
            logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; '
                        'Center SSIM: {:.6f} for {} frames; '
                        'Border SSIM: {:.6f} for {} frames.'.format(
                            sub, avg_ssim, n_frames, avg_ssim_center,
                            N_center, avg_ssim_border, N_border))

        logger.info('################ Tidy Outputs ################')
        for sub_name, psnr, psnr_center, psnr_border, ssim, ssim_center, ssim_border in zip(
                evaled_subs, avg_psnrs, avg_psnr_centers, avg_psnr_borders,
                avg_ssims, avg_ssim_centers, avg_ssim_borders):
            logger.info(
                'Folder {} - Average PSNR: {:.6f} dB. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sub_name, psnr, psnr_center, psnr_border))
            logger.info('Folder {} - Average SSIM: {:.6f} '
                        'Center SSIM: {:.6f} Border SSIM: {:.6f} '.format(
                            sub_name, ssim, ssim_center, ssim_border))

        logger.info('################ Final Results ################')
        logger.info('Data: {}'.format(opts.test_dir))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnrs) / len(avg_psnrs), len(test_subs),
                        sum(avg_psnr_centers) / len(avg_psnr_centers),
                        sum(avg_psnr_borders) / len(avg_psnr_borders)))
        logger.info('Total Average SSIM: {:.6f} for {} clips. '
                    'Center SSIM: {:.6f} Border SSIM: {:.6f} '.format(
                        sum(avg_ssims) / len(avg_ssims), len(test_subs),
                        sum(avg_ssim_centers) / len(avg_ssim_centers),
                        sum(avg_ssim_borders) / len(avg_ssim_borders)))
コード例 #9
0
ファイル: test.py プロジェクト: johnnylu305/MIMO-VRN
def cal_pnsr_ssim(sr_img, gt_img, lr_img, lrgt_img):
    """Save the SR and LR images and record their PSNR/SSIM scores.

    Besides its arguments the function uses ``opt``, ``dataset_dir``,
    ``folder``, ``img_name``, ``test_results``, ``writer``, ``logger``,
    ``util`` and ``bgr2ycbcr`` from the surrounding scope.

    Args:
        sr_img: Super-resolved output, compared with ``gt_img`` and written
            to disk.
        gt_img: Ground truth for ``sr_img``.
        lr_img: Low-resolution output, compared with ``lrgt_img`` and written
            to disk with an ``_LR`` tag.
        lrgt_img: Reference for ``lr_img``.

        All four are presumably H x W x C arrays in the [0, 255] range --
        TODO confirm against the caller.

    Returns:
        ``test_results`` after the new scores have been appended to its
        lists.
    """

    def _png_path(tag):
        # <dataset_dir>/<folder>/<img_name>[<suffix>]<tag>.png
        stem = img_name + suffix if suffix else img_name
        return osp.join(dataset_dir, folder, stem + tag + '.png')

    def _psnr_ssim(img_a, img_b):
        # Both metrics are evaluated on the inputs scaled by 255.
        return (util.calculate_psnr(img_a * 255, img_b * 255),
                util.calculate_ssim(img_a * 255, img_b * 255))

    # Write the outputs; only the SR and LR images are saved here.
    suffix = opt['suffix']
    util.save_img(sr_img, _png_path(''))
    util.save_img(lr_img, _png_path('_LR'))

    # Scale everything down by 255 (the metrics scale back up again).
    gt_img, sr_img, lr_img, lrgt_img = (
        img / 255. for img in (gt_img, sr_img, lr_img, lrgt_img))

    # NOTE(review): a configured crop_border of 0 is falsy, so it also falls
    # back to opt['scale'] -- confirm that this is intended.
    crop_border = opt['crop_border'] or opt['scale']
    inner = slice(crop_border, -crop_border) if crop_border != 0 else None

    # Full-colour metrics on the (optionally) border-cropped SR/GT pair.
    if inner is None:
        sr_core, gt_core = sr_img, gt_img
    else:
        sr_core = sr_img[inner, inner, :]
        gt_core = gt_img[inner, inner, :]
    psnr, ssim = _psnr_ssim(sr_core, gt_core)
    test_results['psnr'].append(psnr)
    test_results['ssim'].append(ssim)

    # PSNR and SSIM for LR (never cropped).
    psnr_lr, ssim_lr = _psnr_ssim(lr_img, lrgt_img)
    test_results['psnr_lr'].append(psnr_lr)
    test_results['ssim_lr'].append(ssim_lr)

    row_name = osp.join(folder, img_name)

    if gt_img.shape[2] != 3:
        # Not a 3-channel image: report the plain metrics only.
        writer.writerow([row_name, psnr, psnr_lr])
        logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}. LR PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
            row_name, psnr, ssim, psnr_lr, ssim_lr))
        return test_results

    # 3-channel image: repeat the measurements on the Y channel.
    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
    if inner is None:
        sr_y_core, gt_y_core = sr_img_y, gt_img_y
    else:
        sr_y_core = sr_img_y[inner, inner]
        gt_y_core = gt_img_y[inner, inner]
    psnr_y, ssim_y = _psnr_ssim(sr_y_core, gt_y_core)
    test_results['psnr_y'].append(psnr_y)
    test_results['ssim_y'].append(ssim_y)

    lr_img_y = bgr2ycbcr(lr_img, only_y=True)
    lrgt_img_y = bgr2ycbcr(lrgt_img, only_y=True)
    psnr_y_lr, ssim_y_lr = _psnr_ssim(lr_img_y, lrgt_img_y)
    test_results['psnr_y_lr'].append(psnr_y_lr)
    test_results['ssim_y_lr'].append(ssim_y_lr)

    writer.writerow([row_name, psnr_y, psnr_y_lr, ssim_y, ssim_y_lr])
    logger.info(
        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}. LR PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'.
            format(row_name, psnr, ssim, psnr_y, ssim_y, psnr_lr, ssim_lr, psnr_y_lr, ssim_y_lr))

    return test_results
コード例 #10
0
ファイル: train_diffuse.py プロジェクト: wwf-3/AdvMCDenoise
def main():
    """Train a model from an option file, with validation and checkpoints.

    The option file path comes from the required ``-opt`` command-line
    argument.  The function sets up the experiment folders and loggers,
    builds the train/val data loaders and the model, optionally resumes from
    a saved training state, and then runs the training loop.  PSNR, SSIM and
    MRSE are averaged over the validation set every
    ``opt['train']['val_freq']`` iterations, and the model plus training
    state are saved every ``opt['logger']['save_checkpoint_freq']``
    iterations.  The final model is saved as ``'latest'``.

    Raises:
        NotImplementedError: If ``opt['datasets']`` contains a phase other
            than ``'train'`` or ``'val'``.
        AssertionError: If no ``'train'`` phase is configured.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which return None for missing key.

    # train from scratch OR resume training
    if opt['path']['resume_state']:  # resuming training
        resume_state = torch.load(opt['path']['resume_state'])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(
            opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger (disabled for experiments whose name has 'debug')
    use_tb_logger = opt['use_tb_logger'] and 'debug' not in opt['name']
    if use_tb_logger:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # The attribute is spelled "benchmark"; the former misspelling only
    # created an unused attribute and left the setting at its default.
    torch.backends.cudnn.benchmark = True

    # create train and val dataloader
    # Pre-bind so the assert below reports a missing train phase instead of
    # failing on an unbound name.
    train_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            model.update_learning_rate()

            # training
            model.feed_data_diffuse(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if use_tb_logger:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                avg_psnr = 0.0
                avg_ssim = 0.0
                avg_mrse = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['NOISY_path'][0]))[0]
                    img_dir = opt['path']['val_images']
                    util.mkdir(img_dir)

                    model.feed_data_diffuse(val_data)
                    model.test()
                    if opt["image_type"] == "exr":
                        # NOTE(review): y is read from "x_offset" and x from
                        # "y_offset" -- confirm that this swap is intended.
                        y = val_data["x_offset"]
                        x = val_data["y_offset"]
                    visuals = model.get_current_visuals()
                    # MRSE is computed on the tensors before any image
                    # conversion below.
                    avg_mrse += util.calculate_mrse(
                        visuals["DENOISED"].numpy(), visuals["GT"].numpy())
                    # NOTE(review): lr_img is not used afterwards; the call
                    # is kept in case tensor2img modifies its argument in
                    # place (the NOISY tensor is read again for the EXR
                    # dump) -- confirm before removing.
                    lr_img = util.tensor2img(visuals['NOISY'])
                    sr_img = util.tensor2img(visuals['DENOISED'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    # PNG dumps of the validation images are disabled; only
                    # the metrics are computed on the converted images.
                    avg_psnr += util.calculate_psnr(sr_img, gt_img)
                    avg_ssim += util.calculate_ssim(sr_img, gt_img)

                    # Dump EXR versions every 10000 iterations.
                    if opt["image_type"] == "exr" and current_step % 10000 == 0:
                        sr_exr = util.tensor2exr(visuals['DENOISED'])
                        lr_exr = util.tensor2exr(visuals['NOISY'])
                        gt_exr = util.tensor2exr(visuals['GT'])

                        save_DENOISED_img_path = os.path.join(
                            img_dir, '{:s}_{:d}_1denoised.exr'.format(
                                img_name, current_step))
                        save_NOISY_img_path = os.path.join(
                            img_dir, '{:s}_{:d}_0noisy.exr'.format(
                                img_name, current_step))
                        save_GT_img_path = os.path.join(
                            img_dir,
                            '{:s}_{:d}_2gt.exr'.format(img_name, current_step))

                        util.saveEXRfromMatrix(save_DENOISED_img_path, sr_exr,
                                               (x, y))
                        util.saveEXRfromMatrix(save_NOISY_img_path, lr_exr,
                                               (x, y))
                        util.saveEXRfromMatrix(save_GT_img_path, gt_exr,
                                               (x, y))

                # idx is the number of validation samples seen.
                avg_psnr = avg_psnr / idx
                avg_ssim = avg_ssim / idx
                avg_mrse = avg_mrse / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger.info('# Validation # SSIM: {:.4e}'.format(avg_ssim))
                logger.info('# Validation # MRSE: {:.4e}'.format(avg_mrse))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e} ssim: {:.4e} mrse:  {:.4e}'
                    .format(epoch, current_step, avg_psnr, avg_ssim, avg_mrse))
                # tensorboard logger
                if use_tb_logger:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
                    tb_logger.add_scalar('ssim', avg_ssim, current_step)
                    tb_logger.add_scalar('mrse', avg_mrse, current_step)

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
コード例 #11
0
ファイル: train.py プロジェクト: cognaclee/OTSR
def main():
    """Train a model according to a YAML option file.

    Command-line arguments select the option file (``-opt``), the job
    launcher (``--launcher``) and the local rank (``--local_rank``). The
    function then sets up distributed state, optional resume state, loggers,
    the random seed, the train/val data loaders and the model, and runs the
    iteration loop with periodic logging, validation and checkpointing.

    Experiment folders, validation images, validation metrics and checkpoints
    are only produced by the process with ``rank <= 0``.

    Raises:
        NotImplementedError: if ``opt['datasets']`` contains a phase other
            than ``'train'`` or ``'val'``.
        AssertionError: if no ``'train'`` dataset is configured.
    """
    ymlPath = './options/df2k/train_bicubic_noise.yml'
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, default=ymlPath, help='Path to option YMAL file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    if opt['path'].get('resume_state', None):
        # Map the saved tensors onto the GPU this process is currently using.
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
        print('resume_state=', opt['path']['resume_state'])
    else:
        resume_state = None

    if rank <= 0:
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                         and 'pretrain_model' not in key and 'resume' not in key))

        util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))

        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            # Only the first three characters ("major.minor") of the version are compared.
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    # From here on, missing option keys read as None instead of raising.
    opt = option.dict_to_nonedict(opt)

    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # The attribute is spelled `benchmark`; the previous misspelling only created
    # an unused attribute and left cuDNN autotuning switched off.
    torch.backends.cudnn.benchmark = True

    dataset_ratio = 1  # enlarge the size of each epoch
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
                total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    model = create_model(opt)

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break

            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])

            if current_step % opt['train']['val_freq'] == 0 and rank <= 0 and val_loader is not None:
                # NOTE(review): val_pix_err_f, val_pix_err_nf and val_mean_color_err are
                # never accumulated below, so they are always reported as 0.0.
                avg_psnr = val_pix_err_f = val_pix_err_nf = val_mean_color_err = avg_ssim = 0.0
                idx = 0
                for val_data in val_loader:
                    if idx > 2:  # validate on at most three items from the loader
                        break
                    idx += 1
                    img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    print("!!!Now start to save Image!!!!")
                    save_img_path = os.path.join(img_dir,
                                                 '{:s}_{:d}.png'.format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # Drop a border of `scale` pixels on every side before scoring.
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
                    avg_ssim += util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)

                if idx == 0:
                    # An empty loader would otherwise divide by zero and abort training.
                    logger.warning('Validation loader returned no data; skipping validation.')
                else:
                    avg_psnr = avg_psnr / idx
                    avg_ssim = avg_ssim / idx
                    val_pix_err_f /= idx
                    val_pix_err_nf /= idx
                    val_mean_color_err /= idx

                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    logger.info('# Validation # SSIM: {:.4e}'.format(avg_ssim))
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                        epoch, current_step, avg_psnr))
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)
                        tb_logger.add_scalar('val_pix_err_f', val_pix_err_f, current_step)
                        tb_logger.add_scalar('val_pix_err_nf', val_pix_err_nf, current_step)
                        tb_logger.add_scalar('val_mean_color_err', val_mean_color_err,
                                             current_step)

            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
コード例 #12
0
def valid(opt, val_loader, current_step, epoch, model, logger, vis, vis_plots):
    """Run the model over ``val_loader`` and report averaged validation metrics.

    The metric set depends on ``opt['dataset_type']``:

    * ``'reduced'``: PSNR and SSIM between the cropped ``'SR'`` and ``'HRx'``
      visuals; returns ``(avg_psnr, avg_ssim)``.
    * ``'full'``: the three values produced by ``util.qnr`` for the ``'SR'``,
      ``'LRx'`` and ``'LRp'`` visuals; returns
      ``(avg_D_lambda, avg_D_s, avg_qnr)``.
    * anything else: the loader is still run through the model, nothing is
      logged and ``None`` is returned.

    Averages are logged to ``logger`` and to the ``'val'`` logger, and pushed
    to visdom through ``util.update_vis`` when ``opt['use_vis_logger']`` is set
    and ``'debug'`` is not part of ``opt['name']``.
    """
    dyn_range = val_loader.dataset.dynamic_range
    border = opt['scale']
    dataset_type = opt['dataset_type']
    is_reduced = dataset_type == 'reduced'
    is_full = dataset_type == 'full'

    # Running sums; only the set matching the dataset type is created.
    if is_reduced:
        psnr_total = 0.0
        ssim_total = 0.0
    elif is_full:
        d_lambda_total = 0.0
        d_s_total = 0.0
        qnr_total = 0.0

    n_seen = 0
    for n_seen, batch in enumerate(val_loader, start=1):
        # forward
        model.feed_data(batch)
        model.test()
        outputs = model.get_current_visuals()

        # Generated image, with a `border`-pixel frame removed for scoring.
        sr_img = util.tensor2img(outputs['SR'], dynamic_range=dyn_range)
        sr_crop = sr_img[border:-border, border:-border, :]

        if is_reduced:
            # Reference image, cropped the same way.
            gt_img = util.tensor2img(outputs['HRx'], dynamic_range=dyn_range)
            gt_crop = gt_img[border:-border, border:-border, :]
            psnr_total += util.calculate_psnr(sr_crop, gt_crop, dyn_range)
            ssim_total += util.calculate_ssim(sr_crop, gt_crop, dyn_range)
        elif is_full:
            lr_x_img = util.tensor2img(outputs['LRx'], dynamic_range=dyn_range)
            lr_p_img = util.tensor2img(outputs['LRp'], dynamic_range=dyn_range)
            qnr_now, d_lambda_now, d_s_now = util.qnr(
                sr_img, lr_x_img, lr_p_img,
                satellite='QuickBird', scale=4, block_size=32,
                p=1, q=1, alpha=1, beta=1)
            d_lambda_total += d_lambda_now
            d_s_total += d_s_now
            qnr_total += qnr_now

    if is_reduced:
        mean_psnr = psnr_total / n_seen
        mean_ssim = ssim_total / n_seen
        logger.info('# Validation # PSNR: {:.4e} # SSIM: {:.4e}'.format(
            mean_psnr, mean_ssim))
        # dedicated validation logger
        logging.getLogger('val').info(
            '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e} ssim: {:.4e}'.format(
                epoch, current_step, mean_psnr, mean_ssim))
        # visdom
        if opt['use_vis_logger'] and 'debug' not in opt['name']:
            util.update_vis(vis, vis_plots['val_psnr'], current_step, mean_psnr)
            util.update_vis(vis, vis_plots['val_ssim'], current_step, mean_ssim)
        return mean_psnr, mean_ssim

    if is_full:
        mean_d_lambda = d_lambda_total / n_seen
        mean_d_s = d_s_total / n_seen
        mean_qnr = qnr_total / n_seen
        logger.info(
            '# Validation # D_lambda: {:.4e} # D_s: {:.4e} # QNR: {:.4e}'.
            format(mean_d_lambda, mean_d_s, mean_qnr))
        # dedicated validation logger
        logging.getLogger('val').info(
            '<epoch:{:3d}, iter:{:8,d}> D_lambda: {:.4e} D_s: {:.4e} QNR: {:.4e}'
            .format(epoch, current_step, mean_d_lambda, mean_d_s, mean_qnr))
        # visdom
        if opt['use_vis_logger'] and 'debug' not in opt['name']:
            util.update_vis(vis, vis_plots['val_no_ref'], current_step,
                            mean_d_lambda, mean_d_s, mean_qnr)
        return mean_d_lambda, mean_d_s, mean_qnr

    return None
コード例 #13
0
            if crop_border == 0:
                cropped_sr_img = sr_img
                cropped_gt_img = gt_img
            else:
                cropped_sr_img = sr_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]

            psnr = util.calculate_psnr(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            psnr_ori_rgb = util.calculate_psnr(lr_img * 255,
                                               cropped_gt_img * 255)
            dpsnr_rgb = psnr - psnr_ori_rgb
            ssim = util.calculate_ssim(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            ssim_ori_rgb = util.calculate_ssim(lr_img * 255,
                                               cropped_gt_img * 255)
            dssim_rgb = ssim - ssim_ori_rgb
            test_results['psnr'].append(psnr)
            test_results['dpsnr_rgb'].append(dpsnr_rgb)
            test_results['ssim'].append(ssim)
            test_results['dssim_rgb'].append(dssim_rgb)

            if gt_img.ndim > 2:
                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    lq_img_y = bgr2ycbcr(lr_img, only_y=True)

                    if opt['enhance_uv']:
コード例 #14
0
ファイル: test_abpn.py プロジェクト: BlueAmulet/BasicSR
def _tensor_to_img(tensor, znorm):
    """Convert a tensor to an image array with ``util.tensor2img``.

    When ``znorm`` is truthy the conversion is told the value range is
    ``(-1, 1)``; otherwise ``util.tensor2img`` runs with its default range.
    """
    if znorm:
        return util.tensor2img(tensor, min_max=(-1, 1))
    return util.tensor2img(tensor)


def _self_ensemble_upscale(lowres_img, model, scale, patch_size, znorm):
    """Upscale eight rotated/mirrored copies of the input and average them.

    Each copy is passed through ``chop_forward2``, converted to an image,
    transformed back to the original orientation and accumulated in floating
    point (so the sum cannot wrap around like uint8 would). The mean is cast
    back to ``uint8``.
    """
    lr_90 = lowres_img.transpose(2, 3).flip(2)
    lr_180 = lr_90.transpose(2, 3).flip(2)
    lr_270 = lr_180.transpose(2, 3).flip(2)

    # (input tensor, mirrored along dim 3?, np.rot90 turns that undo the rotation)
    variants = [
        (lowres_img, False, 0),
        (lr_90, False, 3),
        (lr_180, False, 2),
        (lr_270, False, 1),
        (lowres_img.flip(3), True, 0),
        (lr_90.flip(3), True, 3),
        (lr_180.flip(3), True, 2),
        (lr_270.flip(3), True, 1),
    ]

    total = None
    for tensor, mirrored, turns in variants:
        pred = chop_forward2(tensor, model, scale=scale, patch_size=patch_size)
        pred = _tensor_to_img(pred, znorm).clip(0, 255)
        if mirrored:
            pred = np.fliplr(pred)
        if turns:
            pred = np.rot90(pred, turns)
        pred = pred.astype('float')
        total = pred if total is None else total + pred

    return (total / 8.0).astype('uint8')


def main():
    """Run a trained model over every configured test dataset.

    Reads the options file given by ``-opt``, builds one data loader per entry
    in ``opt['datasets']`` and the model, then for every image: upscales it
    (optionally patch-wise via ``chop_forward2``, optionally with the 8-way
    self-ensemble), saves the result as PNG under
    ``opt['path']['results_root']/<dataset name>`` and, when the dataset has a
    ``dataroot_HR``, logs PSNR/SSIM (plus Y-channel PSNR/SSIM for 3-channel
    images). Per-dataset averages are logged at the end of each dataset.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to options file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items() if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)
    chop2 = opt['chop']
    chop_patch_size = opt['chop_patch_size']
    multi_upscale = opt['multi_upscale']
    scale = opt['scale']

    util.setup_logger(None, opt['path']['log'], 'test.log', level=logging.INFO, screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    # Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        loader_opt = test_loader.dataset.opt
        test_set_name = loader_opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        # Per-dataset settings do not change between images, so read them once.
        # Doing it here also keeps `need_HR` defined when the loader is empty.
        need_HR = loader_opt['dataroot_HR'] is not None
        znorm = loader_opt['znorm']
        crop_border = loader_opt['scale']

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            img_path = data['LR_path'][0]  # first path of the batch
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            hr_tensor = None

            if chop2 == True:
                lowres_img = data['LR']
                if multi_upscale:
                    # Upscale 8 rotations/flips and average them into one image.
                    sr_img = _self_ensemble_upscale(
                        lowres_img, model, scale, chop_patch_size, znorm)
                else:
                    highres_output = chop_forward2(
                        lowres_img, model, scale=scale, patch_size=chop_patch_size)
                    sr_img = _tensor_to_img(highres_output, znorm)  # uint8
                if need_HR:
                    # This path never builds a `visuals` dict (reading it here used to
                    # raise NameError), so take the ground truth from the batch.
                    # Assumes the dataset exposes it under 'HR' -- TODO confirm.
                    hr_tensor = data['HR']
            else:  # will upscale each image in the batch without chopping
                model.feed_data(data, need_HR=need_HR)
                model.test()  # test
                visuals = model.get_current_visuals(need_HR=need_HR)
                sr_img = _tensor_to_img(visuals['SR'], znorm)  # uint8
                if need_HR:
                    hr_tensor = visuals['HR']

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = os.path.join(dataset_dir, img_name + suffix + '.png')
            else:
                save_img_path = os.path.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            # calculate PSNR and SSIM
            if need_HR:
                gt_img = _tensor_to_img(hr_tensor, znorm)  # uint8
                gt_img = gt_img / 255.
                sr_img = sr_img / 255.

                # Ignore a `scale`-pixel frame on every side when scoring.
                cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                        img_name, psnr, ssim))
            else:
                logger.info(img_name)

        # Average PSNR/SSIM results; skipped when nothing was scored so an empty
        # loader does not divide by zero.
        if need_HR and test_results['psnr']:
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info(
                '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'
                .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
                logger.info(
                    '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'
                    .format(ave_psnr_y, ave_ssim_y))
コード例 #15
0
def _format_folder_averages(label, per_folder):
    """Build the summary line for one metric collected per folder.

    Args:
        label: Text placed after ``'# Validation # '`` (e.g. ``'Bic PSNR'``).
        per_folder: Mapping of folder name -> list of per-sample values.

    Returns:
        ``'# Validation # <label>: <mean of folder means>: <folder>: <mean> ...'``
        with every number rendered as ``{:.4e}``, or a short notice when
        there is nothing to average.
    """
    # Folders without any value are skipped so neither division below can hit
    # zero (e.g. when the validation loader yielded no sample at all).
    folder_avg = {k: sum(v) / len(v) for k, v in per_folder.items() if v}
    if not folder_avg:
        return '# Validation # {}: no results'.format(label)

    total_avg = sum(folder_avg.values()) / len(folder_avg)
    log_s = '# Validation # {}: {:.4e}:'.format(label, total_avg)
    for k, v in folder_avg.items():
        log_s += ' {}: {:.4e}'.format(k, v)
    return log_s


def main():
    """Evaluate a model with per-sample inner-loop adaptation.

    For every validation sample the script:

    1. loads the weights at ``opt['path']['bicubic_G']`` into a scratch model,
       runs it, and (when ground truth is available) records PSNR/SSIM as
       the baseline (index 0 of ``psnr_rlt`` / ``ssim_rlt``);
    2. replaces the scratch networks with deep copies of ``model.netG`` and
       ``est_model.netE`` and performs ``opt['train']['maml']['adapt_iter']``
       optimizer steps on them;
    3. runs the adapted model, saves its output image below
       ``../test_results/<folder_name>/<name>/DynaVSR`` and records
       PSNR/SSIM (index 1).

    Per-sample metrics are also written to
    ``../test_results/<folder_name>/psnr_update.csv`` and per-folder averages
    are printed at the end.  Command-line arguments can override the
    degradation settings of the validation dataset.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--exp_name', type=str, default='temp')
    parser.add_argument('--degradation_type', type=str, default=None)
    parser.add_argument('--sigma_x', type=float, default=None)
    parser.add_argument('--sigma_y', type=float, default=None)
    parser.add_argument('--theta', type=float, default=None)
    args = parser.parse_args()
    if args.exp_name == 'temp':
        opt = option.parse(args.opt, is_train=False)
    else:
        opt = option.parse(args.opt, is_train=False, exp_name=args.exp_name)

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)
    val_opt = opt['datasets']['val']

    # Command-line overrides of the validation degradation settings.
    if args.degradation_type is not None:
        if args.degradation_type == 'preset':
            val_opt['degradation_mode'] = args.degradation_type
        else:
            val_opt['degradation_type'] = args.degradation_type
    if args.sigma_x is not None:
        val_opt['sigma_x'] = args.sigma_x
    if args.sigma_y is not None:
        val_opt['sigma_y'] = args.sigma_y
    if args.theta is not None:
        val_opt['theta'] = args.theta

    if 'degradation_mode' not in val_opt:
        degradation_name = ''
    elif val_opt['degradation_mode'] == 'set':
        degradation_name = '_' + str(val_opt['degradation_type'])\
                  + '_' + str(val_opt['sigma_x']) \
                  + '_' + str(val_opt['sigma_y'])\
                  + '_' + str(val_opt['theta'])
    else:
        degradation_name = '_' + val_opt['degradation_mode']
    # NOTE(review): degradation_name already starts with '_' (or is empty), so
    # this yields a doubled (or trailing) underscore.  Kept as-is because the
    # result is an on-disk folder name -- confirm whether it is intended.
    folder_name = opt['name'] + '_' + degradation_name

    if args.exp_name != 'temp':
        folder_name = args.exp_name

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    #### create val dataloader (the 'train' phase is deliberately ignored)
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            pass
        elif phase == 'val':
            if '+' in val_opt['name']:
                raise NotImplementedError('Do not use + signs in test mode')
            else:
                val_set = create_dataset(dataset_opt, scale=opt['scale'],
                                         kernel_size=opt['datasets']['train']['kernel_size'],
                                         model_name=opt['network_E']['which_model_E'])
                val_loader = create_dataloader(val_set, dataset_opt, opt, None)

            print('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))

    #### create model
    # Three independently created pairs: the source weights, the per-sample
    # scratch copies that get adapted, and a fixed estimator for the SLR loss.
    models = create_model(opt)
    assert len(models) == 2
    model, est_model = models[0], models[1]
    modelcp, est_modelcp = create_model(opt)
    _, est_model_fixed = create_model(opt)

    center_idx = (val_opt['N_frames']) // 2
    lr_alpha = opt['train']['maml']['lr_alpha']
    update_step = opt['train']['maml']['adapt_iter']
    with_GT = val_opt['mode'] != 'demo'

    pd_log = pd.DataFrame(columns=['PSNR_Bicubic', 'PSNR_Ours', 'SSIM_Bicubic', 'SSIM_Ours'])

    def crop(LR_seq, HR, num_patches_for_batch=4, patch_size=44):
        """
        Draw several patch pairs from one sample and stack them into a batch.

        Args:
            LR_seq: (B=1) x T x C x H x W
            HR: (B=1) x C x H x W
            num_patches_for_batch (int, optional): number of patch pairs to draw.
            patch_size (int, optional): half of this value is forwarded to
                ``preprocessing.common_crop`` as its ``patch_size``.

        Return:
            Tuple (cropped_lr, cropped_hr); the patches are stacked along a
            new leading dimension of size ``num_patches_for_batch``.
        """
        cropped_lr = []
        cropped_hr = []
        assert HR.size(0) == 1
        LR_seq_ = LR_seq[0]
        HR_ = HR[0]
        for _ in range(num_patches_for_batch):
            patch_lr, patch_hr = preprocessing.common_crop(LR_seq_, HR_, patch_size=patch_size // 2)
            cropped_lr.append(patch_lr)
            cropped_hr.append(patch_hr)

        cropped_lr = torch.stack(cropped_lr, dim=0)
        cropped_hr = torch.stack(cropped_hr, dim=0)

        return cropped_lr, cropped_hr

    # Single GPU
    # psnr_rlt / ssim_rlt: index 0 = baseline (weights from path.bicubic_G),
    # index 1 = after the inner-loop update.  Each maps folder -> list of values.
    psnr_rlt = [{}, {}]
    ssim_rlt = [{}, {}]

    pbar = util.ProgressBar(len(val_set))
    for val_data in val_loader:
        folder = val_data['folder'][0]
        idx_d = int(val_data['idx'][0].split('/')[0])
        if 'name' in val_data:
            name = val_data['name'][0][center_idx][0]
        else:
            name = folder

        train_folder = os.path.join('../test_results', folder_name, name)
        maml_train_folder = os.path.join(train_folder, 'DynaVSR')
        # Creates train_folder as well; exist_ok avoids the check-then-create race.
        os.makedirs(maml_train_folder, exist_ok=True)

        for per_folder in psnr_rlt:
            per_folder.setdefault(folder, [])
        for per_folder in ssim_rlt:
            per_folder.setdefault(folder, [])

        cropped_meta_train_data = {}
        meta_train_data = {}
        meta_test_data = {}

        # The center LQ frame is the target of the inner-loop (meta-train) task.
        meta_train_data['GT'] = val_data['LQs'][:, center_idx]
        meta_test_data['LQs'] = val_data['LQs'][0:1]
        meta_test_data['GT'] = val_data['GT'][0:1, center_idx] if with_GT else None
        # Check whether the batch size of each validation data is 1
        assert val_data['LQs'].size(0) == 1

        if opt['network_G']['which_model_G'] == 'TOF':
            # TOF is fed bicubically pre-upsampled frames.
            LQs = meta_test_data['LQs']
            B, T, C, H, W = LQs.shape
            LQs = LQs.reshape(B*T, C, H, W)
            Bic_LQs = F.interpolate(LQs, scale_factor=opt['scale'], mode='bicubic', align_corners=True)
            meta_test_data['LQs'] = Bic_LQs.reshape(B, T, C, H*opt['scale'], W*opt['scale'])

        ## Before start testing
        # Bicubic Model Results
        modelcp.load_network(opt['path']['bicubic_G'], modelcp.netG)
        modelcp.feed_data(meta_test_data, need_GT=with_GT)
        modelcp.test()

        if with_GT:
            model_start_visuals = modelcp.get_current_visuals(need_GT=True)
            hr_image = util.tensor2img(model_start_visuals['GT'], mode='rgb')
            start_image = util.tensor2img(model_start_visuals['rlt'], mode='rgb')
            psnr_rlt[0][folder].append(util.calculate_psnr(start_image, hr_image))
            ssim_rlt[0][folder].append(util.calculate_ssim(start_image, hr_image))

        # Start every sample from fresh copies of the source networks.
        modelcp.netG, est_modelcp.netE = deepcopy(model.netG), deepcopy(est_model.netE)

        ########## SLR LOSS Preparation ############
        est_model_fixed.load_network(opt['path']['fixed_E'], est_model_fixed.netE)

        optim_params = [v for _, v in modelcp.netG.named_parameters() if v.requires_grad]
        if not opt['train']['use_real']:
            # The estimator is adapted jointly unless real SuperLQs are used.
            optim_params.extend(
                v for _, v in est_modelcp.netE.named_parameters() if v.requires_grad)

        if opt['train']['maml']['optimizer'] == 'Adam':
            inner_optimizer = torch.optim.Adam(optim_params, lr=lr_alpha,
                                               betas=(
                                                   opt['train']['maml']['beta1'],
                                                   opt['train']['maml']['beta2']))
        elif opt['train']['maml']['optimizer'] == 'SGD':
            inner_optimizer = torch.optim.SGD(optim_params, lr=lr_alpha)
        else:
            raise NotImplementedError()

        # Inner Loop Update
        st = time.time()
        for _ in range(update_step):
            # Make SuperLR seq using UPDATED estimation model
            if not opt['train']['use_real']:
                est_modelcp.feed_data(val_data)
                est_modelcp.forward_without_optim()
                meta_train_data['LQs'] = est_modelcp.fake_L
            else:
                meta_train_data['LQs'] = val_data['SuperLQs']

            if opt['network_G']['which_model_G'] == 'TOF':
                # Bicubic upsample to match the size
                LQs = meta_train_data['LQs']
                B, T, C, H, W = LQs.shape
                LQs = LQs.reshape(B*T, C, H, W)
                Bic_LQs = F.interpolate(LQs, scale_factor=opt['scale'], mode='bicubic', align_corners=True)
                meta_train_data['LQs'] = Bic_LQs.reshape(B, T, C, H*opt['scale'], W*opt['scale'])

            # Update both modelcp + estmodelcp jointly
            inner_optimizer.zero_grad()
            if opt['train']['maml']['use_patch']:
                cropped_meta_train_data['LQs'], cropped_meta_train_data['GT'] = \
                    crop(meta_train_data['LQs'], meta_train_data['GT'],
                         opt['train']['maml']['num_patch'],
                         opt['train']['maml']['patch_size'])
                modelcp.feed_data(cropped_meta_train_data)
            else:
                modelcp.feed_data(meta_train_data)

            loss_train = modelcp.calculate_loss()

            ##################### SLR LOSS ###################
            # L1 penalty (weight 10) between the current SuperLR sequence and
            # the output of the fixed estimator.
            est_model_fixed.feed_data(val_data)
            est_model_fixed.test()
            slr_initialized = est_model_fixed.fake_L
            slr_initialized = slr_initialized.to('cuda')
            if opt['network_G']['which_model_G'] == 'TOF':
                # For TOF, compare the pre-upsampling frames (LQs from above).
                loss_train += 10 * F.l1_loss(LQs.to('cuda').squeeze(0), slr_initialized)
            else:
                loss_train += 10 * F.l1_loss(meta_train_data['LQs'].to('cuda'), slr_initialized)

            loss_train.backward()
            inner_optimizer.step()

        et = time.time()
        update_time = et - st

        modelcp.feed_data(meta_test_data, need_GT=with_GT)
        modelcp.test()

        model_update_visuals = modelcp.get_current_visuals(need_GT=False)
        update_image = util.tensor2img(model_update_visuals['rlt'], mode='rgb')
        # Save and calculate final image
        imageio.imwrite(os.path.join(maml_train_folder, '{:08d}.png'.format(idx_d)), update_image)

        if with_GT:
            psnr_rlt[1][folder].append(util.calculate_psnr(update_image, hr_image))
            ssim_rlt[1][folder].append(util.calculate_ssim(update_image, hr_image))

            # .loc assignment overwrites an existing row or appends a new one.
            name_df = '{}/{:08d}'.format(folder, idx_d)
            pd_log.loc[name_df] = [psnr_rlt[0][folder][-1],
                                   psnr_rlt[1][folder][-1],
                                   ssim_rlt[0][folder][-1], ssim_rlt[1][folder][-1]]

            # Rewritten after every sample so partial runs still leave a log.
            pd_log.to_csv(os.path.join('../test_results', folder_name, 'psnr_update.csv'))

            pbar.update('Test {} - {}: I: {:.3f}/{:.4f} \tF+: {:.3f}/{:.4f} \tTime: {:.3f}s'
                            .format(folder, idx_d,
                                    psnr_rlt[0][folder][-1], ssim_rlt[0][folder][-1],
                                    psnr_rlt[1][folder][-1], ssim_rlt[1][folder][-1],
                                    update_time
                                    ))
        else:
            pbar.update()

    if with_GT:
        # Per-folder means plus the mean of those means, baseline first.
        print(_format_folder_averages('Bic PSNR', psnr_rlt[0]))
        print(_format_folder_averages('PSNR', psnr_rlt[1]))
        print(_format_folder_averages('Bicubic SSIM', ssim_rlt[0]))
        print(_format_folder_averages('SSIM', ssim_rlt[1]))

    print('End of evaluation.')
コード例 #16
0
ファイル: test.py プロジェクト: shekdat789/-ESRGAN
def _crop_image_border(img, border):
    """Return ``img`` with ``border`` pixels removed from each spatial edge.

    Only the first two axes are cropped.  A border of 0 returns the image
    unchanged, because ``img[0:-0]`` would otherwise be an empty slice.
    """
    if border == 0:
        return img
    return img[border:-border, border:-border]


def main(jsonPath):
    """Run the model over every configured test set and report PSNR/SSIM.

    Args:
        jsonPath: Path of the option file handed to ``option.parse``.

    For each dataset the super-resolved images are saved below
    ``opt["path"]["results_root"]/<dataset name>``.  When the dataset has a
    ``dataroot_HR`` entry, PSNR/SSIM are logged per image (and additionally
    on the Y channel for 3-channel images), followed by the dataset averages.
    """
    # options
    opt = option.parse(jsonPath, is_train=False)
    util.mkdirs((path for key, path in opt["path"].items()
                 if key != "pretrain_model_G"))
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None,
                      opt["path"]["log"],
                      "test.log",
                      level=logging.INFO,
                      screen=True)
    logger = logging.getLogger("base")
    logger.info(option.dict2str(opt))

    # Create test dataset and dataloader (sorted by phase name)
    test_loaders = []
    for _, dataset_opt in sorted(opt["datasets"].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info("Number of test images in [{:s}]: {:d}".format(
            dataset_opt["name"], len(test_set)))
        test_loaders.append(test_loader)

    # Create model
    model = create_model(opt)

    suffix = opt["suffix"]
    for test_loader in test_loaders:
        loader_opt = test_loader.dataset.opt
        test_set_name = loader_opt["name"]
        logger.info("\nTesting [{:s}]...".format(test_set_name))
        dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results["psnr"] = []
        test_results["ssim"] = []
        test_results["psnr_y"] = []
        test_results["ssim_y"] = []

        # Decided once per dataset: it is also read after the image loop, so
        # it must be defined even when the loader yields nothing.
        need_HR = loader_opt["dataroot_HR"] is not None
        crop_border = loader_opt["scale"]

        for data in test_loader:
            model.feed_data(data, need_HR=need_HR)
            img_path = data["LR_path"][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)

            sr_img = util.tensor2img(visuals["SR"])  # uint8

            # save images
            if suffix:
                save_img_path = os.path.join(dataset_dir,
                                             img_name + suffix + ".png")
            else:
                save_img_path = os.path.join(dataset_dir, img_name + ".png")
            util.save_img(sr_img, save_img_path)

            if not need_HR:
                logger.info(img_name)
                continue

            # calculate PSNR and SSIM
            # Images are scaled by 1/255 before bgr2ycbcr and multiplied
            # by 255 again for the metric functions.
            gt_img = util.tensor2img(visuals["HR"])
            gt_img = gt_img / 255.0
            sr_img = sr_img / 255.0

            cropped_sr_img = _crop_image_border(sr_img, crop_border)
            cropped_gt_img = _crop_image_border(gt_img, crop_border)

            psnr = util.calculate_psnr(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            ssim = util.calculate_ssim(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            test_results["psnr"].append(psnr)
            test_results["ssim"].append(ssim)

            if gt_img.shape[2] == 3:  # RGB image
                sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                cropped_sr_img_y = _crop_image_border(sr_img_y, crop_border)
                cropped_gt_img_y = _crop_image_border(gt_img_y, crop_border)
                psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                test_results["psnr_y"].append(psnr_y)
                test_results["ssim_y"].append(ssim_y)
                logger.info(
                    "{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}."
                    .format(img_name, psnr, ssim, psnr_y, ssim_y))
            else:
                logger.info(
                    "{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.".format(
                        img_name, psnr, ssim))

        # Skip the summary when nothing was measured (avoids dividing by zero).
        if need_HR and test_results["psnr"]:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results["psnr"]) / len(test_results["psnr"])
            ave_ssim = sum(test_results["ssim"]) / len(test_results["ssim"])
            logger.info(
                "----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n"
                .format(test_set_name, ave_psnr, ave_ssim))
            if test_results["psnr_y"] and test_results["ssim_y"]:
                ave_psnr_y = sum(test_results["psnr_y"]) / len(
                    test_results["psnr_y"])
                ave_ssim_y = sum(test_results["ssim_y"]) / len(
                    test_results["ssim_y"])
                logger.info(
                    "----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n"
                    .format(ave_psnr_y, ave_ssim_y))
コード例 #17
0
ファイル: test_dynavsr.py プロジェクト: zhigaloff/DynaVSR
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--exp_name', type=str, default='temp')
    parser.add_argument('--degradation_type', type=str, default=None)
    parser.add_argument('--sigma_x', type=float, default=None)
    parser.add_argument('--sigma_y', type=float, default=None)
    parser.add_argument('--theta', type=float, default=None)
    args = parser.parse_args()
    if args.exp_name == 'temp':
        opt = option.parse(args.opt, is_train=True)
    else:
        opt = option.parse(args.opt, is_train=True, exp_name=args.exp_name)

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)
    inner_loop_name = opt['train']['maml']['optimizer'][0] + str(
        opt['train']['maml']['adapt_iter']) + str(
            math.floor(math.log10(opt['train']['maml']['lr_alpha'])))
    meta_loop_name = opt['train']['optim'][0] + str(
        math.floor(math.log10(opt['train']['lr_G'])))

    if args.degradation_type is not None:
        if args.degradation_type == 'preset':
            opt['datasets']['val']['degradation_mode'] = args.degradation_type
        else:
            opt['datasets']['val']['degradation_type'] = args.degradation_type
    if args.sigma_x is not None:
        opt['datasets']['val']['sigma_x'] = args.sigma_x
    if args.sigma_y is not None:
        opt['datasets']['val']['sigma_y'] = args.sigma_y
    if args.theta is not None:
        opt['datasets']['val']['theta'] = args.theta
    if opt['datasets']['val']['degradation_mode'] == 'set':
        degradation_name = str(opt['datasets']['val']['degradation_type'])\
                  + '_' + str(opt['datasets']['val']['sigma_x']) \
                  + '_' + str(opt['datasets']['val']['sigma_y'])\
                  + '_' + str(opt['datasets']['val']['theta'])
    else:
        degradation_name = opt['datasets']['val']['degradation_mode']
    patch_name = 'p{}x{}'.format(
        opt['train']['maml']['patch_size'], opt['train']['maml']
        ['num_patch']) if opt['train']['maml']['use_patch'] else 'full'
    use_real_flag = '_ideal' if opt['train']['use_real'] else ''
    folder_name = opt[
        'name'] + '_' + degradation_name  # + '_' + inner_loop_name + meta_loop_name + '_' + degradation_name + '_' + patch_name + use_real_flag

    if args.exp_name != 'temp':
        folder_name = args.exp_name

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            #util.mkdir_and_rename(
            #    opt['path']['experiments_root'])  # rename experiment folder if exists
            #util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
            #             and 'pretrain_model' not in key and 'resume' not in key))
            if not os.path.exists(opt['path']['experiments_root']):
                os.mkdir(opt['path']['experiments_root'])
                # raise ValueError('Path does not exists - check path')

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        #logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + folder_name)
    else:
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            pass
        elif phase == 'val':
            if '+' in opt['datasets']['val']['name']:
                val_set, val_loader = [], []
                valname_list = opt['datasets']['val']['name'].split('+')
                for i in range(len(valname_list)):
                    val_set.append(
                        create_dataset(
                            dataset_opt,
                            scale=opt['scale'],
                            kernel_size=opt['datasets']['train']
                            ['kernel_size'],
                            model_name=opt['network_E']['which_model_E'],
                            idx=i))
                    val_loader.append(
                        create_dataloader(val_set[-1], dataset_opt, opt, None))
            else:
                val_set = create_dataset(
                    dataset_opt,
                    scale=opt['scale'],
                    kernel_size=opt['datasets']['train']['kernel_size'],
                    model_name=opt['network_E']['which_model_E'])
                # val_set = loader.get_dataset(opt, train=False)
                val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))

    #### create model
    models = create_model(opt)
    assert len(models) == 2
    model, est_model = models[0], models[1]
    modelcp, est_modelcp = create_model(opt)
    _, est_model_fixed = create_model(opt)

    center_idx = (opt['datasets']['val']['N_frames']) // 2
    lr_alpha = opt['train']['maml']['lr_alpha']
    update_step = opt['train']['maml']['adapt_iter']

    pd_log = pd.DataFrame(
        columns=['PSNR_Bicubic', 'PSNR_Ours', 'SSIM_Bicubic', 'SSIM_Ours'])

    def crop(LR_seq, HR, num_patches_for_batch=4, patch_size=44):
        """
        Crop given patches.

        Args:
            LR_seq: (B=1) x T x C x H x W
            HR: (B=1) x C x H x W

            patch_size (int, optional):

        Return:
            B(=batch_size) x T x C x H x W
        """
        # Find the lowest resolution
        cropped_lr = []
        cropped_hr = []
        assert HR.size(0) == 1
        LR_seq_ = LR_seq[0]
        HR_ = HR[0]
        for _ in range(num_patches_for_batch):
            patch_lr, patch_hr = preprocessing.common_crop(
                LR_seq_, HR_, patch_size=patch_size // 2)
            cropped_lr.append(patch_lr)
            cropped_hr.append(patch_hr)

        cropped_lr = torch.stack(cropped_lr, dim=0)
        cropped_hr = torch.stack(cropped_hr, dim=0)

        return cropped_lr, cropped_hr

    # Single GPU
    # PSNR_rlt: psnr_init, psnr_before, psnr_after
    psnr_rlt = [{}, {}]
    # SSIM_rlt: ssim_init, ssim_after
    ssim_rlt = [{}, {}]
    pbar = util.ProgressBar(len(val_set))
    for val_data in val_loader:
        folder = val_data['folder'][0]
        idx_d = int(val_data['idx'][0].split('/')[0])
        if 'name' in val_data.keys():
            name = val_data['name'][0][center_idx][0]
        else:
            #name = '{}/{:08d}'.format(folder, idx_d)
            name = folder

        train_folder = os.path.join('../results_for_paper', folder_name, name)

        hr_train_folder = os.path.join(train_folder, 'hr')
        bic_train_folder = os.path.join(train_folder, 'bic')
        maml_train_folder = os.path.join(train_folder, 'maml')
        #slr_train_folder = os.path.join(train_folder, 'slr')

        # print(train_folder)
        if not os.path.exists(train_folder):
            os.makedirs(train_folder, exist_ok=False)
        if not os.path.exists(hr_train_folder):
            os.mkdir(hr_train_folder)
        if not os.path.exists(bic_train_folder):
            os.mkdir(bic_train_folder)
        if not os.path.exists(maml_train_folder):
            os.mkdir(maml_train_folder)
        #if not os.path.exists(slr_train_folder):
        #    os.mkdir(slr_train_folder)

        for i in range(len(psnr_rlt)):
            if psnr_rlt[i].get(folder, None) is None:
                psnr_rlt[i][folder] = []
        for i in range(len(ssim_rlt)):
            if ssim_rlt[i].get(folder, None) is None:
                ssim_rlt[i][folder] = []

        if idx_d % 10 != 5:
            #continue
            pass

        cropped_meta_train_data = {}
        meta_train_data = {}
        meta_test_data = {}

        # Make SuperLR seq using estimation model
        meta_train_data['GT'] = val_data['LQs'][:, center_idx]
        meta_test_data['LQs'] = val_data['LQs'][0:1]
        meta_test_data['GT'] = val_data['GT'][0:1, center_idx]
        # Check whether the batch size of each validation data is 1
        assert val_data['SuperLQs'].size(0) == 1

        if opt['network_G']['which_model_G'] == 'TOF':
            LQs = meta_test_data['LQs']
            B, T, C, H, W = LQs.shape
            LQs = LQs.reshape(B * T, C, H, W)
            Bic_LQs = F.interpolate(LQs,
                                    scale_factor=opt['scale'],
                                    mode='bicubic',
                                    align_corners=True)
            meta_test_data['LQs'] = Bic_LQs.reshape(B, T, C, H * opt['scale'],
                                                    W * opt['scale'])

        ## Before start training, first save the bicubic, real outputs
        # Bicubic
        modelcp.load_network(opt['path']['bicubic_G'], modelcp.netG)
        modelcp.feed_data(meta_test_data)
        modelcp.test()
        model_start_visuals = modelcp.get_current_visuals(need_GT=True)
        hr_image = util.tensor2img(model_start_visuals['GT'], mode='rgb')
        start_image = util.tensor2img(model_start_visuals['rlt'], mode='rgb')

        #####imageio.imwrite(os.path.join(hr_train_folder, '{:08d}.png'.format(idx_d)), hr_image)
        #####imageio.imwrite(os.path.join(bic_train_folder, '{:08d}.png'.format(idx_d)), start_image)
        psnr_rlt[0][folder].append(util.calculate_psnr(start_image, hr_image))
        ssim_rlt[0][folder].append(util.calculate_ssim(start_image, hr_image))

        modelcp.netG, est_modelcp.netE = deepcopy(model.netG), deepcopy(
            est_model.netE)

        ########## SLR LOSS Preparation ############
        est_model_fixed.load_network(opt['path']['fixed_E'],
                                     est_model_fixed.netE)

        optim_params = []
        for k, v in modelcp.netG.named_parameters():
            if v.requires_grad:
                optim_params.append(v)

        if not opt['train']['use_real']:
            for k, v in est_modelcp.netE.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)

        if opt['train']['maml']['optimizer'] == 'Adam':
            inner_optimizer = torch.optim.Adam(
                optim_params,
                lr=lr_alpha,
                betas=(opt['train']['maml']['beta1'],
                       opt['train']['maml']['beta2']))
        elif opt['train']['maml']['optimizer'] == 'SGD':
            inner_optimizer = torch.optim.SGD(optim_params, lr=lr_alpha)
        else:
            raise NotImplementedError()

        # Inner Loop Update
        st = time.time()
        for i in range(update_step):
            # Make SuperLR seq using UPDATED estimation model
            if not opt['train']['use_real']:
                est_modelcp.feed_data(val_data)
                # est_model.test()
                est_modelcp.forward_without_optim()
                superlr_seq = est_modelcp.fake_L
                meta_train_data['LQs'] = superlr_seq
            else:
                meta_train_data['LQs'] = val_data['SuperLQs']

            if opt['network_G']['which_model_G'] == 'TOF':
                # Bicubic upsample to match the size
                LQs = meta_train_data['LQs']
                B, T, C, H, W = LQs.shape
                LQs = LQs.reshape(B * T, C, H, W)
                Bic_LQs = F.interpolate(LQs,
                                        scale_factor=opt['scale'],
                                        mode='bicubic',
                                        align_corners=True)
                meta_train_data['LQs'] = Bic_LQs.reshape(
                    B, T, C, H * opt['scale'], W * opt['scale'])

            # Update both modelcp + estmodelcp jointly
            inner_optimizer.zero_grad()
            if opt['train']['maml']['use_patch']:
                cropped_meta_train_data['LQs'], cropped_meta_train_data['GT'] = \
                    crop(meta_train_data['LQs'], meta_train_data['GT'],
                         opt['train']['maml']['num_patch'],
                         opt['train']['maml']['patch_size'])
                modelcp.feed_data(cropped_meta_train_data)
            else:
                modelcp.feed_data(meta_train_data)

            loss_train = modelcp.calculate_loss()

            ##################### SLR LOSS ###################
            est_model_fixed.feed_data(val_data)
            est_model_fixed.test()
            slr_initialized = est_model_fixed.fake_L
            slr_initialized = slr_initialized.to('cuda')
            if opt['network_G']['which_model_G'] == 'TOF':
                loss_train += 10 * F.l1_loss(
                    LQs.to('cuda').squeeze(0), slr_initialized)
            else:
                loss_train += 10 * F.l1_loss(meta_train_data['LQs'].to('cuda'),
                                             slr_initialized)

            loss_train.backward()
            inner_optimizer.step()

        et = time.time()
        update_time = et - st

        modelcp.feed_data(meta_test_data)
        modelcp.test()

        model_update_visuals = modelcp.get_current_visuals(need_GT=False)
        update_image = util.tensor2img(model_update_visuals['rlt'], mode='rgb')
        # Save and calculate final image
        imageio.imwrite(
            os.path.join(maml_train_folder, '{:08d}.png'.format(idx_d)),
            update_image)
        psnr_rlt[1][folder].append(util.calculate_psnr(update_image, hr_image))
        ssim_rlt[1][folder].append(util.calculate_ssim(update_image, hr_image))

        name_df = '{}/{:08d}'.format(folder, idx_d)
        if name_df in pd_log.index:
            pd_log.at[name_df, 'PSNR_Bicubic'] = psnr_rlt[0][folder][-1]
            pd_log.at[name_df, 'PSNR_Ours'] = psnr_rlt[1][folder][-1]
            pd_log.at[name_df, 'SSIM_Bicubic'] = ssim_rlt[0][folder][-1]
            pd_log.at[name_df, 'SSIM_Ours'] = ssim_rlt[1][folder][-1]
        else:
            pd_log.loc[name_df] = [
                psnr_rlt[0][folder][-1], psnr_rlt[1][folder][-1],
                ssim_rlt[0][folder][-1], ssim_rlt[1][folder][-1]
            ]

        pd_log.to_csv(
            os.path.join('../results_for_paper', folder_name,
                         'psnr_update.csv'))

        pbar.update(
            'Test {} - {}: I: {:.3f}/{:.4f} \tF+: {:.3f}/{:.4f} \tTime: {:.3f}s'
            .format(folder, idx_d, psnr_rlt[0][folder][-1],
                    ssim_rlt[0][folder][-1], psnr_rlt[1][folder][-1],
                    ssim_rlt[1][folder][-1], update_time))

    psnr_rlt_avg = {}
    psnr_total_avg = 0.
    # Just calculate the final value of psnr_rlt(i.e. psnr_rlt[2])
    for k, v in psnr_rlt[0].items():
        psnr_rlt_avg[k] = sum(v) / len(v)
        psnr_total_avg += psnr_rlt_avg[k]
    psnr_total_avg /= len(psnr_rlt[0])
    log_s = '# Validation # Bic PSNR: {:.4e}:'.format(psnr_total_avg)
    for k, v in psnr_rlt_avg.items():
        log_s += ' {}: {:.4e}'.format(k, v)
    logger.info(log_s)

    psnr_rlt_avg = {}
    psnr_total_avg = 0.
    # Just calculate the final value of psnr_rlt(i.e. psnr_rlt[2])
    for k, v in psnr_rlt[1].items():
        psnr_rlt_avg[k] = sum(v) / len(v)
        psnr_total_avg += psnr_rlt_avg[k]
    psnr_total_avg /= len(psnr_rlt[1])
    log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
    for k, v in psnr_rlt_avg.items():
        log_s += ' {}: {:.4e}'.format(k, v)
    logger.info(log_s)

    ssim_rlt_avg = {}
    ssim_total_avg = 0.
    # Just calculate the final value of ssim_rlt(i.e. ssim_rlt[1])
    for k, v in ssim_rlt[0].items():
        ssim_rlt_avg[k] = sum(v) / len(v)
        ssim_total_avg += ssim_rlt_avg[k]
    ssim_total_avg /= len(ssim_rlt[0])
    log_s = '# Validation # Bicubic SSIM: {:.4e}:'.format(ssim_total_avg)
    for k, v in ssim_rlt_avg.items():
        log_s += ' {}: {:.4e}'.format(k, v)
    logger.info(log_s)

    ssim_rlt_avg = {}
    ssim_total_avg = 0.
    # Just calculate the final value of ssim_rlt(i.e. ssim_rlt[1])
    for k, v in ssim_rlt[1].items():
        ssim_rlt_avg[k] = sum(v) / len(v)
        ssim_total_avg += ssim_rlt_avg[k]
    ssim_total_avg /= len(ssim_rlt[1])
    log_s = '# Validation # SSIM: {:.4e}:'.format(ssim_total_avg)
    for k, v in ssim_rlt_avg.items():
        log_s += ' {}: {:.4e}'.format(k, v)
    logger.info(log_s)

    logger.info('End of evaluation.')
コード例 #18
0
def main():
    """Run a trained model over every configured test set and log PSNR/SSIM.

    The options file given with ``-opt`` is parsed in test mode. One
    dataloader is built per entry of ``opt['datasets']`` (in sorted key
    order). For each sample the model output is saved as a PNG under
    ``opt['path']['results_root']/<dataset name>``. When the dataset has a
    ``dataroot_GT`` configured, per-image PSNR/SSIM are logged (plus
    Y-channel metrics for 3-channel images), followed by the dataset
    averages.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options YMAL file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)

    # Create every output directory except the experiment root and the
    # pretrain/resume entries.
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))
    util.setup_logger('base',
                      opt['path']['log'],
                      'test_' + opt['name'],
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    #### Create test dataset and dataloader
    test_loaders = []
    for _, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    model = create_model(opt)
    # The file-name suffix does not depend on the dataset or the sample.
    suffix = opt['suffix']
    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        # GT availability is a property of the dataset, not of each sample.
        # Resolving it before the loop keeps it defined (and correct for this
        # loader) even when the loader yields no samples.
        need_GT = test_loader.dataset.opt['dataroot_GT'] is not None

        for data in test_loader:
            model.feed_data(data, need_GT=need_GT)
            img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
            img_name = osp.splitext(osp.basename(img_path))[0]

            model.test()
            visuals = model.get_current_visuals(need_GT=need_GT)

            sr_img = util.tensor2img(visuals['rlt'])  # uint8

            # save images
            if suffix:
                save_img_path = osp.join(dataset_dir,
                                         img_name + suffix + '.png')
            else:
                save_img_path = osp.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            # Without GT there is nothing to score; just record the name.
            if not need_GT:
                logger.info(img_name)
                continue

            # calculate PSNR and SSIM on border-cropped images
            gt_img = util.tensor2img(visuals['GT'])
            sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
            psnr = util.calculate_psnr(sr_img, gt_img)
            ssim = util.calculate_ssim(sr_img, gt_img)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)

            if gt_img.shape[2] == 3:  # RGB image
                # Metrics on the Y channel only; bgr2ycbcr is fed values
                # divided by 255 and its result is scaled back by 255.
                sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
                gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)

                psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
                ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
                logger.info(
                    '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                    .format(img_name, psnr, ssim, psnr_y, ssim_y))
            else:
                logger.info(
                    '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                        img_name, psnr, ssim))

        # Skip the summary when nothing was scored, which would otherwise
        # divide by zero.
        if need_GT and test_results['psnr']:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info(
                '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'
                .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info(
                    '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'
                    .format(ave_psnr_y, ave_ssim_y))
コード例 #19
0
ファイル: test.py プロジェクト: xian1234/SRBuildSeg
    # Permute the array axes with (2, 0, 1), make it contiguous, convert it to
    # a float tensor, add a leading dimension and move it to the GPU.
    # presumably HWC -> 1xCxHxW for the network; TODO confirm img_LR's layout
    img_LR = torch.from_numpy(
        np.ascontiguousarray(np.transpose(
            img_LR, (2, 0, 1)))).float().unsqueeze(0).cuda()

    with torch.no_grad():
        # Time the forward pass only; the elapsed time is added to stat_time.
        # NOTE(review): there is no torch.cuda.synchronize() around the timer,
        # so the measurement may not cover all queued GPU work; confirm.
        begin_time = time.time()
        frame, sr_base, output = model(img_LR)
        end_time = time.time()
        stat_time += (end_time - begin_time)
        #print(end_time-begin_time)

    # Drop the leading dimension and convert each of the three model outputs
    # to an image array.
    output = util.tensor2img(output.squeeze(0))
    frame = util.tensor2img(frame.squeeze(0))
    sr_base = util.tensor2img(sr_base.squeeze(0))

    # save images: the three outputs are concatenated along axis 1 into a
    # single image
    save_path_name = osp.join(
        save_path, '{}_exp{}/{}.png'.format(dataset, exp_name, base_name))
    merge = np.concatenate((frame, sr_base, output), axis=1)
    util.save_img(merge, save_path_name)

    # calculate PSNR and SSIM on border-cropped images and accumulate them
    sr_img, gt_img = util.crop_border([output, img_GT], scale)
    PSNR_avg += util.calculate_psnr(sr_img, gt_img)
    SSIM_avg += util.calculate_ssim(sr_img, gt_img)

# Report the per-image means of the accumulated metrics and forward time.
# NOTE(review): these divisions fail if img_list is empty.
print('average PSNR: ', PSNR_avg / len(img_list))
print('average SSIM: ', SSIM_avg / len(img_list))
print('time: ', stat_time / len(img_list))
コード例 #20
0
# Convert to NoneDict, which returns None for missing keys.
opt = option.dict_to_nonedict(opt)

dataset = DstlDataset(opt['datasets']['val'])

# Take the network out of its wrapper (`.module`) and load the weights into it.
# presumably netG is wrapped in (Distributed)DataParallel; TODO confirm
model = create_model(opt).netG.module
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.cuda()

print('sftgan testing...')

for i in range(len(dataset)):
    data = dataset[i]
    img_HR = data['HR']
    img_LR = data['LR']

    # Inference only: disabling autograd avoids building a graph for every
    # sample, so the result no longer needs to be detached through `.data`.
    with torch.no_grad():
        output = model(img_LR.cuda())

    # Bring both images to NumPy and undo the dataset normalisation before
    # computing the metrics.
    img_HR = denormalize(img_HR.numpy())
    output = denormalize(output.squeeze().cpu().numpy())

    psnr = util.calculate_psnr(img_HR, output)
    ssim = util.calculate_ssim(img_HR, output)
    print(psnr, ssim)

    util.save_img(img_HR, os.path.join('../results', '{}_hr.png'.format(i)))
    util.save_img(output, os.path.join('../results', '{}_fake.png'.format(i)))
コード例 #21
0
def main():
    """Train a model from an options file, with periodic validation.

    Command-line arguments:
        -opt: path to the option YAML file.
        --launcher: ``'none'`` (single process) or ``'pytorch'`` (distributed).
        --local_rank: integer, default 0; parsed but not read in this function.

    The function parses the options, sets up (optional) distributed training,
    loads a resume state if ``opt['path']['resume_state']`` is set, creates
    the output directories and loggers on rank <= 0, seeds the RNGs, builds
    the train/val dataloaders and the model, and then runs the training loop.
    Every ``print_freq`` iterations the current losses are logged, every
    ``val_freq`` iterations validation is run (image models: PSNR/SSIM on
    rank <= 0; other models: per-folder PSNR), and every
    ``save_checkpoint_freq`` iterations a checkpoint is saved. The final
    model is saved as ``'latest'``.

    Raises:
        NotImplementedError: if ``opt['datasets']`` has a phase other than
            ``'train'`` or ``'val'``.
        AssertionError: if no ``'train'`` dataset is configured.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    # Stays None unless a TensorBoard writer is actually created below, so the
    # final close() can be skipped when TensorBoard logging is disabled.
    tb_logger = None
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            # NOTE(review): only the first three characters of the version
            # string are parsed.
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    # Pre-bind so that the assertion below reports a missing 'train' phase
    # instead of failing with a NameError.
    train_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                # Each distributed epoch is counted as dataset_ratio times
                # longer, so fewer epochs are needed.
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)
            #### validation
            if opt['datasets'].get(
                    'val',
                    None) and current_step % opt['train']['val_freq'] == 0:
                # Image models are validated on rank <= 0 only. Other ranks
                # must skip validation rather than fall into the video branch,
                # which reads video-only keys and runs collectives that
                # rank 0 would never join.
                is_image_model = opt['model'] in ['sr', 'srgan']
                if is_image_model and rank <= 0:  # image restoration validation
                    # does not support multi-GPU validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.
                    avg_ssim = 0.
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(
                            os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'],
                                               img_name)
                        util.mkdir(img_dir)

                        # we don't need GT data to be run with the model
                        gt_img = util.tensor2img(val_data['GT'])  # uint8
                        val_data['GT'] = torch.tensor(0)

                        model.feed_data(val_data, need_GT=False)
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['rlt'])  # uint8

                        if "mask" in visuals:
                            mask = util.tensor2img(visuals['mask'])
                            # Save SR mask for reference
                            save_img_path = os.path.join(
                                img_dir, '{:s}_{:d}_mask.png'.format(
                                    img_name, current_step))
                            util.save_img(mask, save_img_path)

                        # Save SR images for reference
                        save_img_path = os.path.join(
                            img_dir,
                            '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR and SSIM on border-cropped images
                        sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                          opt['scale'])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        avg_ssim += util.calculate_ssim(sr_img, gt_img)
                        pbar.update('Test {}'.format(img_name))

                    avg_psnr = avg_psnr / idx
                    avg_ssim = avg_ssim / idx

                    # log
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    logger.info('# Validation # SSIM: {:.4e}'.format(avg_ssim))
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)
                        tb_logger.add_scalar('ssim', avg_ssim, current_step)

                elif not is_image_model:  # video restoration validation
                    if opt['dist']:
                        # multi-GPU testing: each rank handles every
                        # world_size-th sample, results are reduced to rank 0
                        psnr_rlt = {}  # with border and center frames
                        if rank == 0:
                            pbar = util.ProgressBar(len(val_set))
                        for idx in range(rank, len(val_set), world_size):
                            val_data = val_set[idx]
                            val_data['LQs'].unsqueeze_(0)
                            val_data['GT'].unsqueeze_(0)
                            folder = val_data['folder']
                            idx_d, max_idx = val_data['idx'].split('/')
                            idx_d, max_idx = int(idx_d), int(max_idx)
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = torch.zeros(
                                    max_idx,
                                    dtype=torch.float32,
                                    device='cuda')
                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8
                            # calculate PSNR
                            psnr_rlt[folder][idx_d] = util.calculate_psnr(
                                rlt_img, gt_img)

                            if rank == 0:
                                # Advance the bar once per rank, since every
                                # rank processes one sample per step.
                                for _ in range(world_size):
                                    pbar.update('Test {} - {}/{}'.format(
                                        folder, idx_d, max_idx))
                        # collect data
                        for _, v in psnr_rlt.items():
                            dist.reduce(v, 0)
                        dist.barrier()

                        if rank == 0:
                            psnr_rlt_avg = {}
                            psnr_total_avg = 0.
                            for k, v in psnr_rlt.items():
                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                                psnr_total_avg += psnr_rlt_avg[k]
                            psnr_total_avg /= len(psnr_rlt)
                            log_s = '# Validation # PSNR: {:.4e}:'.format(
                                psnr_total_avg)
                            for k, v in psnr_rlt_avg.items():
                                log_s += ' {}: {:.4e}'.format(k, v)
                            logger.info(log_s)
                            if opt['use_tb_logger'] and 'debug' not in opt[
                                    'name']:
                                tb_logger.add_scalar('psnr_avg',
                                                     psnr_total_avg,
                                                     current_step)
                                for k, v in psnr_rlt_avg.items():
                                    tb_logger.add_scalar(k, v, current_step)
                    else:
                        pbar = util.ProgressBar(len(val_loader))
                        psnr_rlt = {}  # with border and center frames
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for val_data in val_loader:
                            folder = val_data['folder'][0]
                            idx_d = val_data['idx'].item()
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = []

                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8

                            # calculate PSNR
                            psnr = util.calculate_psnr(rlt_img, gt_img)
                            psnr_rlt[folder].append(psnr)
                            pbar.update('Test {} - {}'.format(folder, idx_d))
                        # Average per folder first, then across folders.
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = sum(v) / len(v)
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
        # The writer exists only when TensorBoard logging was enabled.
        if tb_logger is not None:
            tb_logger.close()
コード例 #22
0
ファイル: test_ppon.py プロジェクト: ngcthuong/BasicSR
def main():
    """Run the PPON test script over every configured test dataset.

    Parses the options JSON given by ``-opt``, builds one dataloader per entry
    in ``opt['datasets']`` and a single model, then for every image:

    * saves the three model outputs (``img_c``, ``img_s``, ``img_p``) as PNGs
      under ``<results_root>/<dataset name>/``;
    * when the dataset defines ``dataroot_HR``, logs PSNR/SSIM of ``img_c``
      against the HR image (and the Y-channel variants for 3-channel images).

    Per-dataset averages are logged after each dataset has been processed.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, default='options/test/test_ppon.json', help='Path to options JSON file.')

    opt = option.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items() if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None, opt['path']['log'], 'test.log', level=logging.INFO, screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    # Create test dataset and dataloader (iterated in sorted phase-name order)
    test_loaders = []
    for _, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    # Create model
    model = create_model(opt)

    # The optional file-name suffix is the same for every image.
    suffix = opt['suffix']

    for test_loader in test_loaders:
        loader_opt = test_loader.dataset.opt
        test_set_name = loader_opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        # Decide once per dataset whether ground truth is available. Doing this
        # inside the image loop left `need_HR` undefined (or stale from the
        # previous dataset) whenever a loader yielded no batches.
        need_HR = loader_opt['dataroot_HR'] is not None

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            model.feed_data(data, need_HR=need_HR)
            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)

            img_c = util.tensor2img(visuals['img_c'])  # uint8
            img_s = util.tensor2img(visuals['img_s'])  # uint8
            img_p = util.tensor2img(visuals['img_p'])  # uint8

            # save images
            out_name = img_name + suffix if suffix else img_name
            util.save_img(img_c, os.path.join(dataset_dir, out_name + '_c.png'))
            util.save_img(img_s, os.path.join(dataset_dir, out_name + '_s.png'))
            util.save_img(img_p, os.path.join(dataset_dir, out_name + '_p.png'))

            if not need_HR:
                logger.info(img_name)
                continue

            # calculate PSNR and SSIM (only `img_c` is compared against HR)
            gt_img = util.tensor2img(visuals['HR'])
            gt_img = gt_img / 255.
            sr_img = img_c / 255.

            # Strip a border of `scale` pixels on every side before scoring.
            crop_border = loader_opt['scale']
            cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :]
            cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]

            psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
            ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)

            if gt_img.shape[2] == 3:  # RGB image
                # The Y channel is extracted from the [0, 1]-scaled images and
                # rescaled to [0, 255] for the metric calls.
                sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border]
                cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border]
                psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
                logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'\
                    .format(img_name, psnr, ssim, psnr_y, ssim_y))
            else:
                logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))

        if not need_HR:
            continue

        # metrics
        if not test_results['psnr']:
            # Nothing was evaluated; averaging would divide by zero.
            logger.warning('No images were evaluated for {}; skipping averages.'.format(test_set_name))
            continue

        # Average PSNR/SSIM results
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'\
                .format(test_set_name, ave_psnr, ave_ssim))
        if test_results['psnr_y'] and test_results['ssim_y']:
            ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
            ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
            logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'\
                .format(ave_psnr_y, ave_ssim_y))
コード例 #23
0
         else:
             # normalized_pixel_STD = 0
             pixel_STD = 0
         # Save GT image for reference:
         util.save_img(
             (255 * util.tensor2img(
                 visuals['HR'], out_type=np.float32)).astype(
                     np.uint8),
             os.path.join(
                 dataset_dir,
                 img_name + '_HR_STD%.3f_SR_STD%.3f.png' %
                 (HR_STD, pixel_STD)))
 sr_img *= 255.
 if LATENT_DISTRIBUTION in NON_ARBITRARY_Z_INPUTS or cur_channel_cur_Z == 0:
     psnr = util.calculate_psnr(sr_img, gt_img)
     ssim = util.calculate_ssim(sr_img, gt_img)
     test_results['psnr'].append(psnr)
     test_results['ssim'].append(ssim)
 if SAVE_IMAGE_COLLAGE:
     if len(test_set) > 1:
         margins2crop = ((np.array(sr_img.shape[:2]) -
                          per_image_saved_patch) / 2).astype(
                              np.int32)
     else:
         margins2crop = [0, 0]
     image_collage[-1].append(
         np.clip(util.crop_center(sr_img, margins2crop), 0,
                 255).astype(np.uint8))
     if LATENT_DISTRIBUTION in NON_ARBITRARY_Z_INPUTS or cur_channel_cur_Z == 0:
         # Save GT HR images:
         GT_image_collage[-1].append(
コード例 #24
0
ファイル: train_degnet.py プロジェクト: FVL2020/SRDRL
def main():
    """Train the degradation network described by ``options/train/train_degnet.json``.

    Steps, in order:

    1. Parse the options and either load ``opt['path']['resume_state']`` or
       create a fresh experiment directory tree.
    2. Configure the ``base`` and ``val`` loggers and, when enabled and the
       experiment name does not contain ``'debug'``, a tensorboardX writer.
    3. Seed the RNGs, build the train/val dataloaders and the model.
    4. Run the training loop: optimise, update the learning rate, log every
       ``print_freq`` steps, validate every ``val_freq`` steps (saving the
       model's ``FLR`` output and scoring it against ``LR`` with PSNR/SSIM),
       and checkpoint every ``save_checkpoint_freq`` steps.
    5. Save the final model as ``'latest'``.

    Raises:
        ValueError: if ``opt['datasets']`` contains no ``'train'`` phase.
        NotImplementedError: if a dataset phase is neither ``'train'`` nor ``'val'``.
    """
    opt = option.parse("options/train/train_degnet.json", is_train=True)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which return None for missing key.

    # train from scratch OR resume training
    if opt['path']['resume_state']:  # resuming training
        resume_state = torch.load(opt['path']['resume_state'])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None, opt['path']['log'], 'train', level=logging.INFO, screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options
    logger.info(option.dict2str(opt))

    # tensorboard logger
    use_tb_logger = bool(opt['use_tb_logger'] and 'debug' not in opt['name'])
    tb_logger = None
    if use_tb_logger:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='./tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # Was misspelled `benckmark`, which only created an unused attribute and
    # never switched the cuDNN autotuner on.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'],
                                                                      len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    if train_loader is None:
        # Without the explicit initialisation above this case surfaced as a
        # NameError rather than a readable message.
        raise ValueError("No 'train' dataset is configured in opt['datasets'].")

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # update learning rate
            model.update_learning_rate()

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    if use_tb_logger:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # validation (skipped when no 'val' dataset was configured)
            if val_loader is not None and current_step % opt['train']['val_freq'] == 0:
                avg_psnr = 0.0
                avg_ssim = 0.0
                num_val = 0

                # One output directory per validation step. os.path.join keeps
                # the path valid whether or not `root` ends with a separator.
                save_LR_path = os.path.join(opt['path']['root'], 'validations', opt['name'],
                                            '{:d}'.format(current_step))
                util.mkdir(save_LR_path)

                # Border width stripped from each side before scoring.
                crop_size = opt['scale']

                for val_data in val_loader:
                    num_val += 1
                    img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
                    # NOTE: this per-image directory is still created to keep the
                    # on-disk layout unchanged, although nothing in this function
                    # writes into it (the degradation-map dump that did is disabled).
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()
                    visuals = model.get_current_visuals()
                    lr_img = util.tensor2img(visuals['FLR'])  # uint8
                    gt_img = util.tensor2img(visuals['LR'])  # uint8

                    # Save fake LR
                    save_LR_img_path = os.path.join(save_LR_path, '{:s}.png'.format(img_name))
                    util.save_img(lr_img, save_LR_img_path)

                    # calculate PSNR / SSIM on the border-cropped images
                    gt_img = gt_img / 255.
                    lr_img = lr_img / 255.
                    cropped_lr_img = lr_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_lr_img * 255, cropped_gt_img * 255)
                    avg_ssim += util.calculate_ssim(cropped_lr_img * 255, cropped_gt_img * 255)

                if num_val == 0:
                    # An empty loader would otherwise divide by zero below.
                    logger.warning('Validation loader yielded no images; skipping metrics.')
                else:
                    avg_psnr = avg_psnr / num_val
                    avg_ssim = avg_ssim / num_val

                    # log
                    logger.info('# Validation # PSNR: {:.4e} SSIM: {:.4e}'.format(avg_psnr, avg_ssim))
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e} ssim: {:.4e}'.format(
                        epoch, current_step, avg_psnr, avg_ssim))
                    if use_tb_logger:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)
                        tb_logger.add_scalar('ssim', avg_ssim, current_step)

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)
    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
    if tb_logger is not None:
        # Flush any buffered events before the process exits.
        tb_logger.close()
コード例 #25
0
ファイル: test.py プロジェクト: tianxiaguixin002/DASR
        if need_HR:
            gt_img = util.tensor2img(visuals['HR'])
            gt_img = gt_img / 255.
            sr_img = sr_img / 255.
            if opt['val_lpips']:
                lpips = visuals['LPIPS'].numpy()

            crop_border = test_loader.dataset.opt['scale']
            cropped_sr_img = sr_img[crop_border:-crop_border,
                                    crop_border:-crop_border, :]
            cropped_gt_img = gt_img[crop_border:-crop_border,
                                    crop_border:-crop_border, :]

            psnr = util.calculate_psnr(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            ssim = util.calculate_ssim(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            if opt['val_lpips']:
                test_results['lpips'].append(lpips)

            if gt_img.shape[2] == 3:  # RGB image
                sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                            crop_border:-crop_border]
                cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                            crop_border:-crop_border]
                psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
コード例 #26
0
         quantized_image = util.tensor2img(
             model.Return_Compressed(
                 gt_im_YCbCr.to(model.device)),
             out_type=np.uint8,
             min_max=[0, 255],
             chroma_mode=chroma_mode)
         # quantized_image = util.tensor2img(model.jpeg_extractor(model.jpeg_compressor(data['Uncomp'])), out_type=np.uint8,min_max=[0, 255],chroma_mode=chroma_mode)
         if SAVE_QUANTIZED:
             util.save_img(
                 quantized_image,
                 os.path.join(dataset_dir + '_Quant',
                              img_name + suffix + '.png'))
         test_results['psnr_quantized'].append(
             util.calculate_psnr(quantized_image, gt_img))
         test_results['ssim_quantized'].append(
             util.calculate_ssim(quantized_image, gt_img))
     psnr = util.calculate_psnr(sr_img, gt_img)
     ssim = util.calculate_ssim(sr_img, gt_img)
     test_results['psnr'].append(psnr)
     test_results['ssim'].append(ssim)
 if SAVE_IMAGE_COLLAGE:
     if len(test_set) > 1:
         margins2crop = ((np.array(sr_img.shape[:2]) -
                          per_image_saved_patch) / 2).astype(
                              np.int32)
     else:
         margins2crop = [0, 0]
     image_collage[-1].append(
         np.clip(util.crop_center(sr_img, margins2crop), 0,
                 255).astype(np.uint8))
     if LATENT_DISTRIBUTION in NON_ARBITRARY_Z_INPUTS or cur_channel_cur_Z == 0:
コード例 #27
0
ファイル: test.py プロジェクト: JMU2021/DESRGAN
        sr_img = util.tensor2img(visuals['rlt'])  # uint8

        # save images
        suffix = opt['suffix']
        if suffix:
            save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
        else:
            save_img_path = osp.join(dataset_dir, img_name + '.png')
        util.save_img(sr_img, save_img_path)

        # calculate PSNR and SSIM
        if need_GT:
            gt_img = util.tensor2img(visuals['GT'])
            sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
            psnr = util.calculate_psnr(sr_img, gt_img)
            ssim = util.calculate_ssim(sr_img, gt_img)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)

            if gt_img.shape[2] == 3:  # RGB image
                sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
                gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)

                psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
                ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
                logger.info(
                    '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                    .format(img_name, psnr, ssim, psnr_y, ssim_y))
            else:
コード例 #28
0
def main():

    ############################################
    #
    #           set options
    #
    ############################################

    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    ############################################
    #
    #           distributed training settings
    #
    ############################################

    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        print("Rank:", rank)
        print("------------------DIST-------------------------")

    ############################################
    #
    #           loading resume state if exists
    #
    ############################################

    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    ############################################
    #
    #           mkdir and loggers
    #
    ############################################

    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists

            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)

        util.setup_logger('base_val',
                          opt['path']['log'],
                          'val_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)

        logger = logging.getLogger('base')
        logger_val = logging.getLogger('base_val')

        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_',
                          level=logging.INFO,
                          screen=True)

        print("set train log")

        util.setup_logger('base_val',
                          opt['path']['log'],
                          'val_',
                          level=logging.INFO,
                          screen=True)

        print("set val log")

        logger = logging.getLogger('base')

        logger_val = logging.getLogger('base_val')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    ############################################
    #
    #           create train and val dataloader
    #
    ############################################
    ####

    # dataset_ratio = 200  # enlarge the size of each epoch, todo: what it is
    dataset_ratio = 1  # enlarge the size of each epoch, todo: what it is
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            # total_iters = int(opt['train']['niter'])
            # total_epochs = int(math.ceil(total_iters / train_size))

            total_iters = train_size
            total_epochs = int(opt['train']['epoch'])

            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                # total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
                total_epochs = int(opt['train']['epoch'])
                if opt['train']['enable'] == False:
                    total_epochs = 1
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))

    assert train_loader is not None

    ############################################
    #
    #          create model
    #
    ############################################
    ####

    model = create_model(opt)

    print("Model Created! ")

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0
        print("Not Resume Training")

    ############################################
    #
    #          training
    #
    ############################################
    ####

    ####
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    Avg_train_loss = AverageMeter()  # total
    if (opt['train']['pixel_criterion'] == 'cb+ssim'):
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_ssim = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_ssim = AverageMeter()
        Avg_train_loss_vmaf = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'ssim'):
        Avg_train_loss_ssim = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'msssim'):
        Avg_train_loss_msssim = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_msssim = AverageMeter()

    saved_total_loss = 10e10
    saved_total_PSNR = -1

    for epoch in range(start_epoch, total_epochs):

        ############################################
        #
        #          Start a new epoch
        #
        ############################################

        # Turn into training mode
        #model = model.train()

        # reset total loss
        Avg_train_loss.reset()
        current_step = 0

        if (opt['train']['pixel_criterion'] == 'cb+ssim'):
            Avg_train_loss_pix.reset()
            Avg_train_loss_ssim.reset()
        elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
            Avg_train_loss_pix.reset()
            Avg_train_loss_ssim.reset()
            Avg_train_loss_vmaf.reset()
        elif (opt['train']['pixel_criterion'] == 'ssim'):
            Avg_train_loss_ssim = AverageMeter()
        elif (opt['train']['pixel_criterion'] == 'msssim'):
            Avg_train_loss_msssim = AverageMeter()
        elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
            Avg_train_loss_pix = AverageMeter()
            Avg_train_loss_msssim = AverageMeter()

        if opt['dist']:
            train_sampler.set_epoch(epoch)

        for train_idx, train_data in enumerate(train_loader):

            if 'debug' in opt['name']:

                img_dir = os.path.join(opt['path']['train_images'])
                util.mkdir(img_dir)

                LQ = train_data['LQs']
                GT = train_data['GT']

                GT_img = util.tensor2img(GT)  # uint8

                save_img_path = os.path.join(
                    img_dir, '{:4d}_{:s}.png'.format(train_idx, 'debug_GT'))
                util.save_img(GT_img, save_img_path)

                for i in range(5):
                    LQ_img = util.tensor2img(LQ[0, i, ...])  # uint8
                    save_img_path = os.path.join(
                        img_dir,
                        '{:4d}_{:s}_{:1d}.png'.format(train_idx, 'debug_LQ',
                                                      i))
                    util.save_img(LQ_img, save_img_path)

                if (train_idx >= 3):
                    break

            if opt['train']['enable'] == False:
                message_train_loss = 'None'
                break

            current_step += 1
            if current_step > total_iters:
                print("Total Iteration Reached !")
                break
            #### update learning rate
            if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
                pass
            else:
                model.update_learning_rate(
                    current_step, warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)

            # if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            #    model.optimize_parameters_without_schudlue(current_step)
            # else:
            model.optimize_parameters(current_step)

            if (opt['train']['pixel_criterion'] == 'cb+ssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
                Avg_train_loss_vmaf.update(model.log_dict['vmaf_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'ssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'msssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_msssim.update(model.log_dict['msssim_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_msssim.update(model.log_dict['msssim_loss'], 1)
            else:
                Avg_train_loss.update(model.log_dict['l_pix'], 1)

            # add total train loss
            if (opt['train']['pixel_criterion'] == 'cb+ssim'):
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(
                    Avg_train_loss_pix.avg)
                message_train_loss += ' ssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_ssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(
                    Avg_train_loss_pix.avg)
                message_train_loss += ' ssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_ssim.avg)
                message_train_loss += ' vmaf_avg_loss: {:.4e}'.format(
                    Avg_train_loss_vmaf.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'ssim'):
                message_train_loss = ' ssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_ssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'msssim'):
                message_train_loss = ' msssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_msssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(
                    Avg_train_loss_pix.avg)
                message_train_loss += ' msssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_msssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            else:
                message_train_loss = ' train_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)

                message += message_train_loss

                if rank <= 0:
                    logger.info(message)

        ############################################
        #
        #        end of one epoch, save epoch model
        #
        ############################################

        #### save models and training states
        # if current_step % opt['logger']['save_checkpoint_freq'] == 0:
        #     if rank <= 0:
        #         logger.info('Saving models and training states.')
        #         model.save(current_step)
        #         model.save('latest')
        #         # model.save_training_state(epoch, current_step)
        #         # todo delete previous weights
        #         previous_step = current_step - opt['logger']['save_checkpoint_freq']
        #         save_filename = '{}_{}.pth'.format(previous_step, 'G')
        #         save_path = os.path.join(opt['path']['models'], save_filename)
        #         if os.path.exists(save_path):
        #             os.remove(save_path)

        if epoch == 1:
            save_filename = '{:04d}_{}.pth'.format(0, 'G')
            save_path = os.path.join(opt['path']['models'], save_filename)
            if os.path.exists(save_path):
                os.remove(save_path)

        save_filename = '{:04d}_{}.pth'.format(epoch - 1, 'G')
        save_path = os.path.join(opt['path']['models'], save_filename)
        if os.path.exists(save_path):
            os.remove(save_path)

        if rank <= 0:
            logger.info('Saving models and training states.')
            save_filename = '{:04d}'.format(epoch)
            model.save(save_filename)
            # model.save('latest')
            # model.save_training_state(epoch, current_step)

        ############################################
        #
        #          end of one epoch, do validation
        #
        ############################################

        #### validation
        #if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
        if opt['datasets'].get('val', None):
            if opt['model'] in [
                    'sr', 'srgan'
            ] and rank <= 0:  # image restoration validation
                # does not support multi-GPU validation
                pbar = util.ProgressBar(len(val_loader))
                avg_psnr = 0.
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()
                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['rlt'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}.png'.format(img_name, current_step))
                    #util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                      opt['scale'])
                    avg_psnr += util.calculate_psnr(sr_img, gt_img)
                    pbar.update('Test {}'.format(img_name))

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
            else:  # video restoration validation
                if opt['dist']:
                    # todo : multi-GPU testing
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.

                    ssim_rlt = {}  # with border and center frames
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.

                    val_loss_rlt = {}
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.

                    if rank == 0:
                        pbar = util.ProgressBar(len(val_set))

                    for idx in range(rank, len(val_set), world_size):

                        print('idx', idx)

                        if 'debug' in opt['name']:
                            if (idx >= 3):
                                break

                        val_data = val_set[idx]
                        val_data['LQs'].unsqueeze_(0)
                        val_data['GT'].unsqueeze_(0)
                        folder = val_data['folder']
                        idx_d, max_idx = val_data['idx'].split('/')
                        idx_d, max_idx = int(idx_d), int(max_idx)

                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = torch.zeros(max_idx,
                                                           dtype=torch.float32,
                                                           device='cuda')

                        if ssim_rlt.get(folder, None) is None:
                            ssim_rlt[folder] = torch.zeros(max_idx,
                                                           dtype=torch.float32,
                                                           device='cuda')

                        if val_loss_rlt.get(folder, None) is None:
                            val_loss_rlt[folder] = torch.zeros(
                                max_idx, dtype=torch.float32, device='cuda')

                        # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
                        model.feed_data(val_data)
                        # model.test()
                        # model.test_stitch()

                        if opt['stitch'] == True:
                            model.test_stitch()
                        else:
                            model.test()  # large GPU memory

                        # visuals = model.get_current_visuals()
                        visuals = model.get_current_visuals(
                            save=True,
                            name='{}_{}'.format(folder, idx),
                            save_path=opt['path']['val_images'])

                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder][idx_d] = psnr

                        # calculate SSIM
                        ssim = util.calculate_ssim(rlt_img, gt_img)
                        ssim_rlt[folder][idx_d] = ssim

                        # calculate Val loss
                        val_loss = model.get_loss()
                        val_loss_rlt[folder][idx_d] = val_loss

                        logger.info(
                            '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format(
                                folder, idx, psnr, ssim))

                        if rank == 0:
                            for _ in range(world_size):
                                pbar.update('Test {} - {}/{}'.format(
                                    folder, idx_d, max_idx))

                    # # collect data
                    for _, v in psnr_rlt.items():
                        dist.reduce(v, 0)

                    for _, v in ssim_rlt.items():
                        dist.reduce(v, 0)

                    for _, v in val_loss_rlt.items():
                        dist.reduce(v, 0)

                    dist.barrier()

                    if rank == 0:
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)

                        # ssim
                        ssim_rlt_avg = {}
                        ssim_total_avg = 0.
                        for k, v in ssim_rlt.items():
                            ssim_rlt_avg[k] = torch.mean(v).cpu().item()
                            ssim_total_avg += ssim_rlt_avg[k]
                        ssim_total_avg /= len(ssim_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            ssim_total_avg)
                        for k, v in ssim_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)

                        # added
                        val_loss_rlt_avg = {}
                        val_loss_total_avg = 0.
                        for k, v in val_loss_rlt.items():
                            val_loss_rlt_avg[k] = torch.mean(v).cpu().item()
                            val_loss_total_avg += val_loss_rlt_avg[k]
                        val_loss_total_avg /= len(val_loss_rlt)
                        log_l = '# Validation # Loss: {:.4e}:'.format(
                            val_loss_total_avg)
                        for k, v in val_loss_rlt_avg.items():
                            log_l += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_l)

                        message = ''
                        for v in model.get_current_learning_rate():
                            message += '{:.5e}'.format(v)

                        logger_val.info(
                            'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}'
                            .format(epoch, message, psnr_total_avg,
                                    ssim_total_avg, message_train_loss,
                                    val_loss_total_avg))

                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)
                            # add val loss
                            tb_logger.add_scalar('val_loss_avg',
                                                 val_loss_total_avg,
                                                 current_step)
                            for k, v in val_loss_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)

                else:  # Todo: our function One GPU
                    pbar = util.ProgressBar(len(val_loader))
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.

                    ssim_rlt = {}  # with border and center frames
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.

                    val_loss_rlt = {}
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.

                    for val_inx, val_data in enumerate(val_loader):

                        if 'debug' in opt['name']:
                            if (val_inx >= 5):
                                break

                        folder = val_data['folder'][0]
                        # idx_d = val_data['idx'].item()
                        idx_d = val_data['idx']
                        # border = val_data['border'].item()
                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = []

                        if ssim_rlt.get(folder, None) is None:
                            ssim_rlt[folder] = []

                        if val_loss_rlt.get(folder, None) is None:
                            val_loss_rlt[folder] = []

                        # process the black blank [B N C H W]

                        print(val_data['LQs'].size())

                        H_S = val_data['LQs'].size(3)  # 540
                        W_S = val_data['LQs'].size(4)  # 960

                        print(H_S)
                        print(W_S)

                        blank_1_S = 0
                        blank_2_S = 0

                        print(val_data['LQs'][0, 2, 0, :, :].size())

                        for i in range(H_S):
                            if not sum(val_data['LQs'][0, 2, 0, i, :]) == 0:
                                blank_1_S = i - 1
                                # assert not sum(data_S[:, :, 0][i+1]) == 0
                                break

                        for i in range(H_S):
                            if not sum(val_data['LQs'][0, 2, 0, :,
                                                       H_S - i - 1]) == 0:
                                blank_2_S = (H_S - 1) - i - 1
                                # assert not sum(data_S[:, :, 0][blank_2_S-1]) == 0
                                break
                        print('LQ :', blank_1_S, blank_2_S)

                        if blank_1_S == -1:
                            print('LQ has no blank')
                            blank_1_S = 0
                            blank_2_S = H_S

                        # val_data['LQs'] = val_data['LQs'][:,:,:,blank_1_S:blank_2_S,:]

                        print("LQ", val_data['LQs'].size())

                        # end of process the black blank

                        model.feed_data(val_data)

                        if opt['stitch'] == True:
                            model.test_stitch()
                        else:
                            model.test()  # large GPU memory

                        # process blank

                        blank_1_L = blank_1_S << 2
                        blank_2_L = blank_2_S << 2
                        print(blank_1_L, blank_2_L)

                        print(model.fake_H.size())

                        if not blank_1_S == 0:
                            # model.fake_H = model.fake_H[:,:,blank_1_L:blank_2_L,:]
                            model.fake_H[:, :, 0:blank_1_L, :] = 0
                            model.fake_H[:, :, blank_2_L:H_S, :] = 0

                        # end of # process blank

                        visuals = model.get_current_visuals(
                            save=True,
                            name='{}_{:02d}'.format(folder, val_inx),
                            save_path=opt['path']['val_images'])

                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder].append(psnr)

                        # calculate SSIM
                        ssim = util.calculate_ssim(rlt_img, gt_img)
                        ssim_rlt[folder].append(ssim)

                        # val loss
                        val_loss = model.get_loss()
                        val_loss_rlt[folder].append(val_loss.item())

                        logger.info(
                            '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format(
                                folder, val_inx, psnr, ssim))

                        pbar.update('Test {} - {}'.format(folder, idx_d))

                    # average PSNR
                    for k, v in psnr_rlt.items():
                        psnr_rlt_avg[k] = sum(v) / len(v)
                        psnr_total_avg += psnr_rlt_avg[k]
                    psnr_total_avg /= len(psnr_rlt)
                    log_s = '# Validation # PSNR: {:.4e}:'.format(
                        psnr_total_avg)
                    for k, v in psnr_rlt_avg.items():
                        log_s += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_s)

                    # average SSIM
                    for k, v in ssim_rlt.items():
                        ssim_rlt_avg[k] = sum(v) / len(v)
                        ssim_total_avg += ssim_rlt_avg[k]
                    ssim_total_avg /= len(ssim_rlt)
                    log_s = '# Validation # SSIM: {:.4e}:'.format(
                        ssim_total_avg)
                    for k, v in ssim_rlt_avg.items():
                        log_s += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_s)

                    # average VMAF

                    # average Val LOSS
                    for k, v in val_loss_rlt.items():
                        val_loss_rlt_avg[k] = sum(v) / len(v)
                        val_loss_total_avg += val_loss_rlt_avg[k]
                    val_loss_total_avg /= len(val_loss_rlt)
                    log_l = '# Validation # Loss: {:.4e}:'.format(
                        val_loss_total_avg)
                    for k, v in val_loss_rlt_avg.items():
                        log_l += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_l)

                    # toal validation log

                    message = ''
                    for v in model.get_current_learning_rate():
                        message += '{:.5e}'.format(v)

                    logger_val.info(
                        'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}'
                        .format(epoch, message, psnr_total_avg, ssim_total_avg,
                                message_train_loss, val_loss_total_avg))

                    # end add

                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                             current_step)
                        for k, v in psnr_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)
                        # tb_logger.add_scalar('ssim_avg', ssim_total_avg, current_step)
                        # for k, v in ssim_rlt_avg.items():
                        #     tb_logger.add_scalar(k, v, current_step)
                        # add val loss
                        tb_logger.add_scalar('val_loss_avg',
                                             val_loss_total_avg, current_step)
                        for k, v in val_loss_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)

            ############################################
            #
            #          end of validation, save model
            #
            ############################################
            #
            logger.info("Finished an epoch, Check and Save the model weights")
            # we check the validation loss instead of training loss. OK~
            if saved_total_loss >= val_loss_total_avg:
                saved_total_loss = val_loss_total_avg
                #torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
                model.save('best')
                logger.info(
                    "Best Weights updated for decreased validation loss")

            else:
                logger.info(
                    "Weights Not updated for undecreased validation loss")

            if saved_total_PSNR <= psnr_total_avg:
                saved_total_PSNR = psnr_total_avg
                model.save('bestPSNR')
                logger.info(
                    "Best Weights updated for increased validation PSNR")

            else:
                logger.info(
                    "Weights Not updated for unincreased validation PSNR")

        ############################################
        #
        #          end of one epoch, schedule LR
        #
        ############################################

        # add scheduler  todo
        if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            for scheduler in model.schedulers:
                # scheduler.step(val_loss_total_avg)
                scheduler.step(val_loss_total_avg)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('last')
        logger.info('End of training.')
        tb_logger.close()
コード例 #29
0
ファイル: test.py プロジェクト: BlueAmulet/BasicSR
def main():
    """Run the model over every configured test dataset and report metrics.

    Reads the options file given by the required ``-opt`` argument, builds
    one dataset/dataloader per entry of ``opt['datasets']``, creates the
    model, and then for each dataset:

    * saves every output image as ``<image name>[suffix].png`` under
      ``opt['path']['results_root']/<dataset name>``;
    * when the dataset has a ``dataroot_HR``, logs per-image PSNR/SSIM
      (plus Y-channel PSNR/SSIM for 3-channel images) and their averages.

    A dataset whose loader yields no images produces no average summary.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    # Every configured path except the pretrained-weights entry is passed to
    # util.mkdirs.
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None,
                      opt['path']['log'],
                      'test.log',
                      level=logging.INFO,
                      screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    # Create test dataset and dataloader
    test_loaders = []
    znorm = False
    for _phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)
        # Temporary: a single dataset requesting znorm turns it on for all
        # datasets. A per-dataset flag would be needed to differentiate them
        # later in the test loop.
        if dataset_opt['znorm']:
            znorm = True

    # Create model
    model = create_model(opt)

    # With znorm the image range is [-1, 1]; otherwise tensor2img is called
    # without min_max so its own default range applies.
    img_kwargs = {'min_max': (-1, 1)} if znorm else {}
    suffix = opt['suffix']

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        # Resolved once per dataset, before the loop, so the summary below
        # never reads an unbound name or a value left over from the previous
        # dataset when this loader yields nothing.
        need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

        for data in test_loader:
            model.feed_data(data, need_HR=need_HR)
            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)

            sr_img = util.tensor2img(visuals['SR'], **img_kwargs)  # uint8

            # save images
            if suffix:
                save_img_path = os.path.join(dataset_dir,
                                             img_name + suffix + '.png')
            else:
                save_img_path = os.path.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            if not need_HR:
                logger.info(img_name)
                continue

            # calculate PSNR and SSIM
            gt_img = util.tensor2img(visuals['HR'], **img_kwargs)  # uint8
            # Work on /255-scaled copies; they are scaled back by 255 right
            # before each metric call.
            gt_img = gt_img / 255.
            sr_img = sr_img / 255.

            # Trim `scale` pixels from every side before measuring.
            crop_border = test_loader.dataset.opt['scale']
            cropped_sr_img = sr_img[crop_border:-crop_border,
                                    crop_border:-crop_border, :]
            cropped_gt_img = gt_img[crop_border:-crop_border,
                                    crop_border:-crop_border, :]

            psnr = util.calculate_psnr(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            ssim = util.calculate_ssim(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)

            if gt_img.shape[2] == 3:  # RGB image
                # NOTE(review): bgr2ycbcr presumably expects the /255-scaled
                # input prepared above - confirm against its definition.
                sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                            crop_border:-crop_border]
                cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                            crop_border:-crop_border]
                psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
                logger.info(
                    '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                    .format(img_name, psnr, ssim, psnr_y, ssim_y))
            else:
                logger.info(
                    '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                        img_name, psnr, ssim))

        # metrics: only when at least one image was measured, so the
        # averages below never divide by zero.
        if need_HR and test_results['psnr']:
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info(
                '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'
                .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info(
                    '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'
                    .format(ave_psnr_y, ave_ssim_y))
コード例 #30
0
def main():
    """Score pre-computed result frames against ground truth with PSNR/SSIM.

    Command-line options select the dataset (``--data_mode``, or derived from
    ``--train_mode``) and the degradation setting that names the input folder.
    For every clip sub-folder, the frames matching ``*5.png.png`` in the test
    folder are paired, in sorted order, with the ground-truth frames matching
    ``*5.png``. Each pair is scaled to uint8, optionally border-cropped, and
    scored with ``util.calculate_psnr`` / ``util.calculate_ssim``. Per-clip and
    overall averages are written through the 'base' logger, which logs to the
    screen and to a file under ``../results/<data_mode>``.

    Raises:
        ValueError: if a clip has fewer test frames than ground-truth frames.
    """
    #################
    # configurations
    #################
    # NOTE(review): scale, N_in, predeblur, HR_in, n_feats and back_RBs are not
    # referenced in the visible part of this function; kept in case later code
    # relies on them.
    flip_test = False
    scale = 4
    N_in = 7
    predeblur, HR_in = False, False
    n_feats = 128
    back_RBs = 40

    save_imgs = False
    prog = argparse.ArgumentParser()
    prog.add_argument('--train_mode', '-t', type=str, default='REDS', help='train mode')
    prog.add_argument('--data_mode', '-m', type=str, default=None, help='data_mode')
    prog.add_argument('--degradation_mode', '-d', type=str, default='impulse',
                      choices=('impulse', 'bicubic', 'preset'),
                      help='degradation mode: impulse | bicubic | preset')
    prog.add_argument('--sigma_x', '-sx', type=float, default=1, help='sigma_x')
    prog.add_argument('--sigma_y', '-sy', type=float, default=0, help='sigma_y')
    prog.add_argument('--theta', '-th', type=float, default=0, help='theta')
    prog.add_argument('--model', type=str, default=None, help='name for subdirectory')

    args = prog.parse_args()

    train_data_mode = args.train_mode
    data_mode = args.data_mode
    if data_mode is None:
        # NOTE(review): any other train mode leaves data_mode as None, which
        # selects the REDS paths below and '../results/None' as output folder.
        if train_data_mode == 'Vimeo':
            data_mode = 'Vid4'
        elif train_data_mode == 'REDS':
            data_mode = 'REDS'
    degradation_mode = args.degradation_mode  # impulse | bicubic | preset
    sig_x, sig_y, the = args.sigma_x, args.sigma_y, args.theta
    if sig_y == 0:
        # sigma_y == 0 means "same as sigma_x"
        sig_y = sig_x

    if degradation_mode == 'preset':
        folder_subname = 'preset'
    else:
        folder_subname = '{}_{:.1f}_{:.1f}_{:.1f}'.format(degradation_mode, sig_x, sig_y, the)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../dataset/Vid4/LR_{}/X1_KG_ZSSR'.format(folder_subname)
        #test_dataset_folder = '../dataset/Vid4/LR_{}/X1_CF_DBPN'.format(folder_subname)
        GT_dataset_folder = '../dataset/Vid4/HR'
    elif data_mode == 'MM522':
        test_dataset_folder = '../dataset/MM522val/LR_bicubic/X1_KG_ZSSR'
        GT_dataset_folder = '../dataset/MM522val/HR'
    else:
        # test_dataset_folder = '../dataset/REDS4/LR_bicubic/X{}'.format(scale)
        test_dataset_folder = '../dataset/REDS/train/LR_{}/X1_KG_ZSSR'.format(folder_subname)
        #test_dataset_folder = '../dataset/REDS/train/LR_{}/X1_CF_DBPN'.format(folder_subname)
        GT_dataset_folder = '../dataset/REDS/train/HR'

    #### evaluation
    crop_border = 0
    border_frame = 0 # N_in // 2  # border frames when evaluate
    # temporal padding mode
    padding = 'new_info'

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    if data_mode == 'REDS':
        # Only these four clips are evaluated for REDS.
        # NOTE(review): the filter is applied to the GT list only, so the test
        # folder is presumably expected to hold exactly these clips - confirm.
        subfolder_GT_l = [k for k in subfolder_GT_l
                          if any(clip_id in k for clip_id in ('000', '011', '015', '020'))]
    if len(subfolder_l) != len(subfolder_GT_l):
        # zip() below silently truncates to the shorter list; make that visible.
        logger.warning('Found {} test folders but {} GT folders; only the first {} pairs '
                       'are evaluated.'.format(len(subfolder_l), len(subfolder_GT_l),
                                               min(len(subfolder_l), len(subfolder_GT_l))))

    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        save_subfolder = osp.join(save_folder, subfolder_name)

        # GT frames end in '5.png'; the test frames carry a doubled '.png.png' suffix.
        img_path_l = sorted(glob.glob(osp.join(subfolder_GT, '*5.png')))
        img_LQ_path_l = sorted(glob.glob(osp.join(subfolder, '*5.png.png')))
        max_idx = len(img_path_l)
        if max_idx == 0:
            # Nothing to score; averaging would otherwise divide by zero.
            logger.warning('Folder {} - no GT frames found in {}; skipped.'.format(
                subfolder_name, subfolder_GT))
            continue
        if len(img_LQ_path_l) < max_idx:
            raise ValueError('Folder {}: found {} test frames in {} but {} GT frames in {}.'.format(
                subfolder_name, len(img_LQ_path_l), subfolder, max_idx, subfolder_GT))

        subfolder_name_l.append(subfolder_name)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        img_LQ_l = [data_util.read_img(None, path) for path in img_LQ_path_l]
        img_GT_l = [data_util.read_img(None, path) for path in img_path_l]

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0
        avg_ssim, avg_ssim_border, avg_ssim_center = 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]

            # presumably read_img yields float images in [0, 1] (hence the
            # * 255 scaling) - TODO confirm. Metrics use all channels as read.
            output = (np.copy(img_LQ_l[img_idx]) * 255).round().astype('uint8')
            GT = (np.copy(img_GT_l[img_idx]) * 255).round().astype('uint8')

            # Save only after the current frame exists, so each file holds the
            # frame it is named after.
            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR / SSIM
            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output, GT)
            crt_ssim = util.calculate_ssim(output, GT)

            # logger.info('{:3d} - {:16} \tPSNR: {:.6f} dB \tSSIM: {:.6f}'.format(img_idx + 1, img_name, crt_psnr, crt_ssim))

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                avg_ssim_center += crt_ssim
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                avg_ssim_border += crt_ssim
                N_border += 1

        n_frames = N_center + N_border  # > 0 because max_idx > 0
        avg_psnr = (avg_psnr_center + avg_psnr_border) / n_frames
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        avg_ssim = (avg_ssim_center + avg_ssim_border) / n_frames
        avg_ssim_center = 0 if N_center == 0 else avg_ssim_center / N_center
        avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border
        avg_ssim_l.append(avg_ssim)
        avg_ssim_center_l.append(avg_ssim_center)
        avg_ssim_border_l.append(avg_ssim_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr,
                                                                   n_frames,
                                                                   avg_psnr_center, N_center,
                                                                   avg_psnr_border, N_border))

        logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; '
                    'Center SSIM: {:.6f} for {} frames; '
                    'Border SSIM: {:.6f} for {} frames.'.format(subfolder_name, avg_ssim,
                                                                   n_frames,
                                                                   avg_ssim_center, N_center,
                                                                   avg_ssim_border, N_border))

    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    n_clips = len(avg_psnr_l)  # clips that were actually evaluated
    if n_clips == 0:
        # No clip produced metrics; averaging would divide by zero.
        logger.warning('No clips were evaluated; check {} and {}.'.format(
            test_dataset_folder, GT_dataset_folder))
    else:
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnr_l) / n_clips, n_clips,
                        sum(avg_psnr_center_l) / n_clips,
                        sum(avg_psnr_border_l) / n_clips))
        logger.info('Total Average SSIM: {:.6f} for {} clips. '
                    'Center SSIM: {:.6f}. Border SSIM: {:.6f}.'.format(
                        sum(avg_ssim_l) / n_clips, n_clips,
                        sum(avg_ssim_center_l) / n_clips,
                        sum(avg_ssim_border_l) / n_clips))

    print('\n\n\n')