# Example no. 1
# 0
def define_G(opt):
    """Instantiate the generator selected by ``opt['network_G']['which_model_G']``.

    Supported names are ``'MSRResNet'`` and ``'RRDBNet'`` (image restoration)
    and ``'EDVR'`` (video restoration). Constructor arguments are read from
    ``opt['network_G']``; a missing entry raises ``KeyError``.

    Raises:
        NotImplementedError: if the model name is not one of the above.
    """
    net_opt = opt['network_G']
    model_name = net_opt['which_model_G']

    def build_image_net(arch_cls):
        # MSRResNet and RRDBNet are constructed with the same five options here.
        return arch_cls(in_nc=net_opt['in_nc'],
                        out_nc=net_opt['out_nc'],
                        nf=net_opt['nf'],
                        nb=net_opt['nb'],
                        upscale=net_opt['scale'])

    # image restoration
    if model_name == 'MSRResNet':
        return build_image_net(SRResNet_arch.MSRResNet)
    if model_name == 'RRDBNet':
        return build_image_net(RRDBNet_arch.RRDBNet)

    # video restoration
    if model_name == 'EDVR':
        edvr_option_names = ('nf', 'nframes', 'groups', 'front_RBs',
                             'back_RBs', 'center', 'predeblur', 'HR_in',
                             'w_TSA')
        return EDVR_arch.EDVR(
            **{name: net_opt[name] for name in edvr_option_names})

    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(model_name))
# Example no. 2
# 0
def define_G(opt):
    """Instantiate the generator named by ``opt['network_G']['which_model_G']``.

    Each supported name maps to a zero-argument builder. A builder reads its
    constructor options from ``opt['network_G']`` only when it is called, so
    the options and architecture module of unselected models are never touched.

    Raises:
        NotImplementedError: if the name has no builder.
        KeyError: if the selected model needs an option that is missing.
    """
    cfg = opt['network_G']
    model_name = cfg['which_model_G']

    builders = {
        'MSRResNet': lambda: SRResNet_arch.MSRResNet(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'],
            nb=cfg['nb'], upscale=cfg['scale']),
        'RCAN': lambda: RCAN_arch.RCAN(
            n_resblocks=cfg['n_resblocks'], n_feats=cfg['n_feats'],
            res_scale=cfg['res_scale'], n_colors=cfg['n_colors'],
            rgb_range=cfg['rgb_range'], scale=cfg['scale'],
            reduction=cfg['reduction'], n_resgroups=cfg['n_resgroups']),
        'CARN_M': lambda: CARN_arch.CARN_M(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'],
            scale=cfg['scale'], group=cfg['group']),
        'fsrcnn': lambda: FSRCNN_arch.FSRCNN_net(
            input_channels=cfg['in_nc'], upscale=cfg['scale'], d=cfg['d'],
            s=cfg['s'], m=cfg['m']),
        # The classSR_* models take only the channel counts from the config.
        'classSR_3class_fsrcnn_net':
            lambda: classSR_3class_arch.classSR_3class_fsrcnn_net(
                in_nc=cfg['in_nc'], out_nc=cfg['out_nc']),
        'classSR_3class_rcan': lambda: classSR_rcan_arch.classSR_3class_rcan(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc']),
        'classSR_3class_srresnet': lambda: classSR_srresnet_arch.ClassSR(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc']),
        'classSR_3class_carn': lambda: classSR_carn_arch.ClassSR(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc']),
    }

    if model_name not in builders:
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(model_name))
    return builders[model_name]()
# Example no. 3
# 0
def define_G(opt):
    """Create the generator chosen by ``opt['network_G']['which_model_G']``.

    Supported models: ``'PAN'``, ``'MSRResNet_PA'`` and ``'RCAN_PA'``. Their
    constructor arguments are read from ``opt['network_G']``.

    Raises:
        NotImplementedError: for any other model name.
    """
    settings = opt['network_G']
    arch = settings['which_model_G']

    def construct(factory, **arg_to_option):
        # Call ``factory`` with each keyword bound to the option of the given name.
        return factory(**{arg: settings[option]
                          for arg, option in arg_to_option.items()})

    # image restoration
    if arch == 'PAN':
        return construct(PAN_arch.PAN,
                         in_nc='in_nc', out_nc='out_nc', nf='nf',
                         unf='unf', nb='nb', scale='scale')
    if arch == 'MSRResNet_PA':
        return construct(SRResNet_arch.MSRResNet_PA,
                         in_nc='in_nc', out_nc='out_nc', nf='nf',
                         nb='nb', upscale='scale')
    if arch == 'RCAN_PA':
        return construct(RCAN_arch.RCAN_PA,
                         n_resgroups='n_resgroups',
                         n_resblocks='n_resblocks',
                         n_feats='n_feats',
                         res_scale='res_scale',
                         n_colors='n_colors',
                         rgb_range='rgb_range',
                         scale='scale')

    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(arch))
# Example no. 4
# 0
def define_G(opt):
    """Return the generator selected by ``opt['network_G']['which_model_G']``.

    ``'MSRResNet'`` and ``'RRDBNet'`` are the image-restoration models.
    ``'ORDSRNet'`` builds ``ORDSRModel``, whose ``N``/``S`` arguments come from
    the ``base_size``/``stride`` options.

    Raises:
        NotImplementedError: when the model name is not one of the three above.
    """
    settings = opt['network_G']
    choice = settings['which_model_G']

    # Pick the constructor first, then gather its keyword arguments.
    if choice == 'MSRResNet':
        factory = SRResNet_arch.MSRResNet
        kwargs = dict(in_nc=settings['in_nc'],
                      out_nc=settings['out_nc'],
                      nf=settings['nf'],
                      nb=settings['nb'],
                      upscale=settings['scale'])
    elif choice == 'RRDBNet':
        factory = RRDBNet_arch.RRDBNet
        kwargs = dict(in_nc=settings['in_nc'],
                      out_nc=settings['out_nc'],
                      nf=settings['nf'],
                      nb=settings['nb'])
    elif choice == 'ORDSRNet':
        factory = ORDSRModel
        kwargs = dict(in_nc=settings['in_nc'],
                      out_nc=settings['out_nc'],
                      N=settings['base_size'],
                      S=settings['stride'],
                      upscale=settings['scale'])
    else:
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(choice))

    return factory(**kwargs)
# Example no. 5
# 0
def define_C(opt):
    """Build the ``network_C`` model described by ``opt['network_C']``.

    Judging by ``DFN_Color_correction``, this is the color-correction network.
    ``opt['network_C']['which_model_C']`` selects the architecture:

    * ``'dfn'`` -> ``DynamicF.DFN_Color_correction()``
    * any name containing ``'ResNet'`` -> ``SRResNet_arch.ResNet_alpha_beta_multi_in``,
      constructed with the model name itself as its only argument.

    Raises:
        NotImplementedError: if ``which_model_C`` matches neither rule.
    """
    opt_net = opt['network_C']
    which_model = opt_net['which_model_C']

    if which_model == 'dfn':
        netC = DynamicF.DFN_Color_correction()
    elif 'ResNet' in which_model:
        # The full name (e.g. a ResNet_alpha_beta_* variant) is forwarded so the
        # architecture can pick its structure from it.
        netC = SRResNet_arch.ResNet_alpha_beta_multi_in(which_model)
    else:
        # The message previously said "Discriminator model", a copy-paste
        # leftover; this builder creates the color-correction (C) network.
        raise NotImplementedError(
            'Color correction model [{:s}] not recognized'.format(which_model))
    return netC
# Example no. 6
# 0
def define_G(opt):
    """Build the generator named by ``opt['network_G']['which_model_G']``.

    Models handled:

    * ``'MSRResNet'`` and ``'RRDBNet'`` (image restoration).
    * ``'EDVR'`` (video restoration, imported lazily from ``models.archs.EDVR_arch``).
    * ``'MGANet'`` (``Gen_Guided_UNet``).
    * ``'Unet'``, which delegates to ``repo.CycleGAN.networks.define_G``.

    Raises:
        NotImplementedError: for an unrecognised model name.
    """
    cfg = opt['network_G']
    name = cfg['which_model_G']

    # image restoration
    if name == 'MSRResNet':
        return SRResNet_arch.MSRResNet(in_nc=cfg['in_nc'],
                                       out_nc=cfg['out_nc'],
                                       nf=cfg['nf'],
                                       nb=cfg['nb'],
                                       upscale=cfg['scale'])
    if name == 'RRDBNet':
        return RRDBNet_arch.RRDBNet(in_nc=cfg['in_nc'],
                                    out_nc=cfg['out_nc'],
                                    nf=cfg['nf'],
                                    nb=cfg['nb'])

    # video restoration
    # (A commented-out 'EDVR_woDCN' branch used to live here; it was dead code.)
    if name == 'EDVR':
        import models.archs.EDVR_arch as EDVR_arch
        return EDVR_arch.EDVR(**{
            key: cfg[key]
            for key in ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs',
                        'center', 'predeblur', 'HR_in', 'w_TSA', 'w_GCB')
        })

    if name == 'MGANet':
        return Gen_Guided_UNet(input_size=cfg['input_size'])

    if name == 'Unet':
        import repo.CycleGAN.networks as unet_networks
        # Input/output channel counts are fixed here (2 * 3 in, 1 out) rather
        # than read from cfg.
        return unet_networks.define_G(2 * 3, 1, cfg['nf'], cfg['G_type'],
                                      cfg['norm'], cfg['dropout'],
                                      cfg['init_type'], cfg['init_gain'])

    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(name))
# Example no. 7
# 0
def define_G(opt):
    """Construct the generator given by ``opt['network_G']['which_model_G']``.

    Each architecture module is imported inside the branch that needs it.
    ``'EDVR'`` and ``'DUF'`` read ``scale`` from the top-level ``opt``, not from
    ``opt['network_G']``. ``'DUF'`` uses the 16- or 28-layer variant when
    ``layers`` equals 16 or 28, and the 52-layer variant for any other value.

    Raises:
        NotImplementedError: for an unrecognised model name.
    """
    net_cfg = opt['network_G']
    kind = net_cfg['which_model_G']

    # image restoration
    if kind == 'MSRResNet':
        import models.archs.SRResNet_arch as SRResNet_arch
        return SRResNet_arch.MSRResNet(in_nc=net_cfg['in_nc'],
                                       out_nc=net_cfg['out_nc'],
                                       nf=net_cfg['nf'],
                                       nb=net_cfg['nb'],
                                       upscale=net_cfg['scale'])
    if kind == 'RRDBNet':
        import models.archs.RRDBNet_arch as RRDBNet_arch
        return RRDBNet_arch.RRDBNet(in_nc=net_cfg['in_nc'],
                                    out_nc=net_cfg['out_nc'],
                                    nf=net_cfg['nf'],
                                    nb=net_cfg['nb'])

    # video restoration
    if kind == 'EDVR':
        import models.archs.EDVR_arch as EDVR_arch
        edvr_keys = ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs',
                     'center', 'predeblur', 'HR_in', 'w_TSA')
        return EDVR_arch.EDVR(**{key: net_cfg[key] for key in edvr_keys},
                              scale=opt['scale'])
    if kind == 'DUF':
        import models.archs.DUF_arch as DUF_arch
        n_layers = net_cfg['layers']
        if n_layers == 16:
            duf_cls = DUF_arch.DUF_16L
        elif n_layers == 28:
            duf_cls = DUF_arch.DUF_28L
        else:
            duf_cls = DUF_arch.DUF_52L
        return duf_cls(scale=opt['scale'], adapt_official=True)
    if kind == 'TOF':
        import models.archs.TOF_arch as TOF_arch
        return TOF_arch.TOFlow(adapt_official=True)

    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(kind))
# Example no. 8
# 0
import os.path as osp
import sys
import torch
# x4 -> x3: initialise an x3 MSRResNet from the pretrained x4 MSRResNet weights
# and save the result as MSRResNetx3_ini.pth.

# Make the ``models`` package (one directory above this script) importable.
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
# Import unconditionally. The old ``except ImportError: pass`` only deferred
# the failure to a confusing NameError on ``SRResNet_arch`` further down.
import models.archs.SRResNet_arch as SRResNet_arch  # noqa: E402

# map_location='cpu' lets the script run on a machine without the GPU the
# checkpoint may have been saved from, and keeps every saved tensor on the CPU.
pretrained_net = torch.load(
    '../../experiments/pretrained_models/MSRResNetx4.pth', map_location='cpu')
crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
crt_net = crt_model.state_dict()

# Copy every entry shared with the x4 checkpoint except ``upconv1``, which is
# filled separately below. Re-assigning existing keys while iterating is safe
# because the dict's size and key order do not change.
for k in crt_net:
    if k in pretrained_net and 'upconv1' not in k:
        crt_net[k] = pretrained_net[k]
        print('replace ... ', k)

# Fill the x3 ``upconv1`` output channels from the x4 ones, halved:
# rows 0:256 and 256:512 each receive a full copy, and rows 512:576 receive
# the first 64 x4 rows.
w_x4 = pretrained_net['upconv1.weight']
b_x4 = pretrained_net['upconv1.bias']
crt_net['upconv1.weight'][0:256, :, :, :] = w_x4 / 2
crt_net['upconv1.weight'][256:512, :, :, :] = w_x4 / 2
crt_net['upconv1.weight'][512:576, :, :, :] = w_x4[0:64, :, :, :] / 2
crt_net['upconv1.bias'][0:256] = b_x4 / 2
crt_net['upconv1.bias'][256:512] = b_x4 / 2
crt_net['upconv1.bias'][512:576] = b_x4[0:64] / 2

torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth')
# Example no. 9
# 0
def define_G(opt):
    """Build the generator named by ``opt['network_G']['which_model_G']``.

    The image models (``MSRResNet``, ``RRDBNet``) are handled explicitly. The
    EDVR family is described by a table that maps each model name to two
    things: a lazy lookup of its constructor in ``EDVR_arch``, and the ordered
    option names passed to it as keyword arguments.

    Raises:
        NotImplementedError: for an unrecognised model name.
    """
    net_opt = opt['network_G']
    which_model = net_opt['which_model_G']

    # image restoration
    if which_model == 'MSRResNet':
        return SRResNet_arch.MSRResNet(in_nc=net_opt['in_nc'],
                                       out_nc=net_opt['out_nc'],
                                       nf=net_opt['nf'],
                                       nb=net_opt['nb'],
                                       upscale=net_opt['scale'])
    if which_model == 'RRDBNet':
        return RRDBNet_arch.RRDBNet(in_nc=net_opt['in_nc'],
                                    out_nc=net_opt['out_nc'],
                                    nf=net_opt['nf'],
                                    nb=net_opt['nb'])

    # Option-name groups shared by several EDVR variants.
    edvr_core_keys = ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs',
                      'center')
    edvr_keys = edvr_core_keys + ('predeblur', 'HR_in', 'w_TSA')
    edvr_img_keys = ('nf', 'front_RBs', 'back_RBs', 'down_scale')
    up_keys = edvr_core_keys + ('w_TSA', 'down_scale', 'align_target',
                                'ret_valid')
    up_cont_keys = up_keys + ('multi_scale_cont',)

    # model name -> (zero-arg constructor lookup, option names)
    video_models = {
        # video restoration
        'EDVR': (lambda: EDVR_arch.EDVR, edvr_keys),
        'EDVR2X': (lambda: EDVR_arch.EDVR2X, edvr_keys),
        'EDVRImg': (lambda: EDVR_arch.EDVRImage, edvr_img_keys),
        'EDVR3D': (lambda: EDVR_arch.EDVR3D, edvr_img_keys),
        'UPEDVR': (lambda: EDVR_arch.UPEDVR, up_keys),
        'UPContEDVR': (lambda: EDVR_arch.UPControlEDVR, up_cont_keys),
        'FlowUPContEDVR': (lambda: EDVR_arch.FlowUPControlEDVR, up_cont_keys),
        # video SR for multiple target frames
        'MultiEDVR': (lambda: EDVR_arch.MultiEDVR, edvr_keys),
        # arbitrary magnification video super-resolution
        'MetaEDVR': (lambda: EDVR_arch.MetaEDVR, edvr_keys + ('fix_edvr',)),
    }

    if which_model in video_models:
        lookup, option_names = video_models[which_model]
        constructor = lookup()
        return constructor(**{name: net_opt[name] for name in option_names})

    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(which_model))
# Example no. 10
# 0
def define_G(opt):
    """Instantiate the generator selected by ``opt['network_G']['which_model_G']``.

    Supported models:

    * ``'MSRResNet'`` and ``'RRDBNet'``.
    * The EDVR-like family (``'EDVR'`` and the ``'MY_EDVR_*'`` variants), which
      all take the same nine keyword options.
    * ``'Recurr_ResBlocks'``.

    Raises:
        NotImplementedError: for an unrecognised model name.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    # Every EDVR-like model below is called with exactly these options.
    edvr_option_names = ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs',
                         'center', 'predeblur', 'HR_in', 'w_TSA')
    # Constructor lookups are deferred so that unselected modules are never touched.
    edvr_family = {
        'EDVR': lambda: EDVR_arch.EDVR,
        'MY_EDVR_FusionDenoise': lambda: my_EDVR_arch.MYEDVR_FusionDenoise,
        'MY_EDVR_RES': lambda: my_EDVR_arch.MYEDVR_RES,
        'MY_EDVR_PreEnhance': lambda: my_EDVR_arch.MYEDVR_PreEnhance,
    }

    if which_model == 'MSRResNet':
        # image restoration
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'],
                                       out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'],
                                       nb=opt_net['nb'],
                                       upscale=opt_net['scale'])
    elif which_model == 'RRDBNet':
        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'],
                                    out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'],
                                    nb=opt_net['nb'])
    elif which_model in edvr_family:
        # video restoration
        edvr_cls = edvr_family[which_model]()
        netG = edvr_cls(**{name: opt_net[name] for name in edvr_option_names})
    elif which_model == 'Recurr_ResBlocks':
        netG = Recurr_arch.Recurr_ResBlocks(nf=opt_net['nf'],
                                            N_RBs=opt_net['N_RBs'],
                                            N_flow_lv=opt_net['N_flow_lv'],
                                            pretrain_flow=opt_net['pretrain_flow'])
    else:
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(which_model))

    return netG
# Example no. 11
# 0
def main():
    """Run color correction followed by EDVR on every clip of the test set.

    For each sub-folder of ``test_dataset_folder``:

    * The LQ frame sequence is read once.
    * For every frame, an ``N_in``-frame window is selected and passed through
      ``color_model``, then ``model``.
    * The result is written to
      ``../results/<test_name>/<sub-folder>/<frame name>.png``.

    No ground truth is read and no PSNR is computed.

    Raises:
        ValueError: if ``test_set`` has no dataset folder configured.
    """
    #################
    # configurations
    #################
    # Select the GPU before creating the CUDA device handle. Note that this
    # setting has no effect if CUDA was already initialised in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = '5'
    device = torch.device('cuda')

    test_set = 'AI4K_test'  # only 'AI4K_test' is configured below
    data_mode = 'sharp_bicubic'  # sharp_bicubic | blur_bicubic
    test_name = 'Contest2_Test18_A38_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd_165000'
    N_in = 5

    # load test set
    if test_set == 'AI4K_test':
        test_dataset_folder = '/home/yhliu/AI4K/contest2/testA_LR_png/'
    else:
        # Fail early. Previously an unknown test_set left test_dataset_folder
        # unbound and crashed later with an UnboundLocalError.
        raise ValueError('Unsupported test_set: {}'.format(test_set))

    flip_test = False

    model_path = '../experiments/A38_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd/models/165000_G.pth'
    color_model_path = '/home/yhliu/BasicSR/experiments/35_ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW_re_100k/models/220000_G.pth'

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True

    # NOTE(review): the positional args presumably map to
    # (nf, nframes, groups, front_RBs, back_RBs); confirm against EDVR_arch.EDVR.
    model = EDVR_arch.EDVR(64,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)

    color_model = SRResNet_arch.ResNet_alpha_beta_multi_in(
        structure='ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW')

    #### evaluation
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(test_name)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the SR model
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    model = nn.DataParallel(model)

    #### set up the color model
    # The checkpoint keys are loaded with a 'color_net.' prefix added, since
    # color_model is loaded strictly under the prefixed names.
    load_net = torch.load(color_model_path)
    load_net_clean = OrderedDict(
        ('color_net.' + k, v) for k, v in load_net.items())
    color_model.load_state_dict(load_net_clean, strict=True)
    color_model.eval()
    color_model = color_model.to(device)
    color_model = nn.DataParallel(color_model)

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))

    # for each subfolder
    for subfolder in subfolder_l:
        subfolder_name = osp.basename(subfolder)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ images (no GT is used here)
        imgs_LQ = data_util.read_img_seq(subfolder)

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).cpu()
            print(imgs_in.size())

            if flip_test:
                imgs_in = util.single_forward(color_model, imgs_in)
                output = util.flipx4_forward(model, imgs_in)
            else:
                start_time = time.time()
                imgs_in = util.single_forward(color_model, imgs_in)
                output = util.single_forward(model, imgs_in)
                end_time = time.time()
                print('Forward One image:', end_time - start_time)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            logger.info('{:3d} - {:25}'.format(img_idx + 1, img_name))

    logger.info('################ Tidy Outputs ################')

    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
def main():
    """Evaluate the colour-correction + EDVR pipeline on a test set with GT.

    For every clip (sub-folder) of the selected LQ test set, each frame
    window is passed through ``color_model`` and then ``model`` (EDVR). The
    output is optionally saved as PNG and compared against the matching GT
    frame with PSNR. Per-frame, per-clip (average / center / border) and
    overall PSNR values are logged to screen and to a log file in
    ``../results/<test_name>`` (via ``util.setup_logger``).

    All settings are hard-coded in the configuration section below.

    Raises:
        ValueError: if ``test_set`` is not a known test set, or if a GT
            clip has fewer frames than its LQ counterpart.
    """
    #################
    # configurations
    #################
    # CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialised.
    os.environ['CUDA_VISIBLE_DEVICES'] = '5'
    #os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4'
    device = torch.device('cuda')
    test_set = 'AI4K_val'  # Vid4 | YouKu10 | YouKu_val | REDS4 | AI4K_val | AI4K_bic | AI4K_testA
    test_name = 'A38_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd_165000'
    data_mode = 'sharp_bicubic'  # sharp_bicubic | blur_bicubic
    N_in = 5

    # load test set: test_set -> (LQ folder, GT folder)
    dataset_folders = {
        'Vid4': ('../datasets/Vid4/BIx4', '../datasets/Vid4/GT'),
        'YouKu10': ('../datasets/YouKu10/LR', '../datasets/YouKu10/HR'),
        'YouKu_val': ('/data0/yhliu/DATA/YouKuVid/valid/valid_lr_bmp',
                      '/data0/yhliu/DATA/YouKuVid/valid/valid_hr_bmp'),
        'REDS4': ('../datasets/REDS4/{}'.format(data_mode),
                  '../datasets/REDS4/GT'),
        'AI4K_val': ('/home/yhliu/AI4K/contest2/val2_LR_png/',
                     '/home/yhliu/AI4K/contest1/val1_HR_png/'),
        'AI4K_bic': ('/home/yhliu/AI4K/contest2/val2_LR_png_bic/',
                     '/home/yhliu/AI4K/contest1/val1_HR_png_bic/'),
        'AI4K_testA': ('/data0/yhliu/AI4K/val2_LR_png',  #'/home/yhliu/AI4K/val2_LR_png/'
                       '/data0/yhliu/AI4K/val1_HR_png/'),
    }
    if test_set not in dataset_folders:
        raise ValueError('Unknown test_set [{}]; expected one of {}'.format(
            test_set, sorted(dataset_folders)))
    test_dataset_folder, GT_dataset_folder = dataset_folders[test_set]

    flip_test = False

    #model_path = '../experiments/pretrained_models/A01xxx/900000_G.pth'
    model_path = '../experiments/A38_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd/models/165000_G.pth'

    color_model_path = '/home/yhliu/BasicSR/experiments/35_ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW_re_100k/models/220000_G.pth'

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True

    model = EDVR_arch.EDVR(64,
                           N_in,
                           8,
                           5,
                           back_RBs,
                           predeblur=predeblur,
                           HR_in=HR_in)
    color_model = SRResNet_arch.ResNet_alpha_beta_multi_in(
        structure='ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW')

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    # NOTE(review): data_mode is never 'Vid4' with the modes listed above;
    # this may have been meant as test_set == 'Vid4' -- confirm.
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(test_name)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the EDVR model
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    model = nn.DataParallel(model)

    #### set up the colour model
    # The checkpoint was saved without the 'color_net.' prefix that the
    # wrapper module's parameter names carry, so add it before loading.
    load_net = torch.load(color_model_path)
    load_net_clean = OrderedDict(
        ('color_net.' + k, v) for k, v in load_net.items())
    color_model.load_state_dict(load_net_clean, strict=True)
    color_model.eval()
    color_model = color_model.to(device)
    color_model = nn.DataParallel(color_model)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    if len(subfolder_l) != len(subfolder_GT_l):
        # zip() below pairs clips by sorted order and drops the surplus.
        logger.warning(
            'LQ/GT clip count mismatch: {} LQ vs {} GT folders; '
            'only the first {} pairs are evaluated.'.format(
                len(subfolder_l), len(subfolder_GT_l),
                min(len(subfolder_l), len(subfolder_GT_l))))

    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if max_idx == 0:
            logger.warning('Folder {} contains no frames; skipped.'.format(
                subfolder_name))
            continue
        # Append only for evaluated clips so it stays aligned with the
        # per-clip PSNR lists zipped together in "Tidy Outputs".
        subfolder_name_l.append(subfolder_name)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = [
            data_util.read_img(None, img_GT_path) for img_GT_path in sorted(
                glob.glob(osp.join(subfolder_GT, '*')))
        ]
        if len(img_GT_l) < max_idx:
            raise ValueError(
                'GT folder {} has {} frames but LQ folder {} has {}'.format(
                    subfolder_GT, len(img_GT_l), subfolder, max_idx))

        psnr_center_sum, psnr_border_sum, N_center, N_border = 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx,
                                                    max_idx,
                                                    N_in,
                                                    padding=padding)
            imgs_in = imgs_LQ.index_select(
                0,
                torch.LongTensor(select_idx)).unsqueeze(0).cpu()  #to(device)
            logger.debug('Input size: {}'.format(tuple(imgs_in.size())))

            # The colour model runs once in both modes; only the EDVR pass is
            # flip-augmented when flip_test is enabled.
            imgs_in = util.single_forward(color_model, imgs_in)
            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels. Y-channel evaluation (as
            # used for Vid4) is currently disabled:
            # if data_mode == 'Vid4':  # bgr2y, [0, 1]
            #     GT = data_util.bgr2ycbcr(GT, only_y=True)
            #     output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(
                img_idx + 1, img_name, crt_psnr))

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                psnr_center_sum += crt_psnr
                N_center += 1
            else:  # border frames
                psnr_border_sum += crt_psnr
                N_border += 1

        # max_idx > 0 here, so N_center + N_border > 0; either part may be 0
        # on its own (e.g. a clip shorter than 2 * border_frame + 1 frames).
        avg_psnr = (psnr_center_sum + psnr_border_sum) / (N_center + N_border)
        avg_psnr_center = 0 if N_center == 0 else psnr_center_sum / N_center
        avg_psnr_border = 0 if N_border == 0 else psnr_border_sum / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    n_clips = len(avg_psnr_l)
    if n_clips == 0:
        logger.warning('No clips were evaluated; no total PSNR to report.')
    else:
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnr_l) / n_clips, n_clips,
                        sum(avg_psnr_center_l) / n_clips,
                        sum(avg_psnr_border_l) / n_clips))