def define_G(opt):
    """Build the generator network selected by ``opt['network_G']``.

    ``opt['network_G']['which_model_G']`` names the architecture; the remaining
    entries of ``opt['network_G']`` are passed to its constructor. Only the
    options used by the selected architecture are read.

    Raises:
        KeyError: a required option is missing.
        NotImplementedError: the architecture name is not supported.
    """
    net_opt = opt['network_G']
    model_name = net_opt['which_model_G']

    # Builders are zero-argument callables so that options are only looked up
    # for the architecture that is actually requested.
    builders = {
        # image restoration
        'MSRResNet': lambda: SRResNet_arch.MSRResNet(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'], nf=net_opt['nf'],
            nb=net_opt['nb'], upscale=net_opt['scale']),
        'RRDBNet': lambda: RRDBNet_arch.RRDBNet(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'], nf=net_opt['nf'],
            nb=net_opt['nb'], upscale=net_opt['scale']),
        # video restoration
        'EDVR': lambda: EDVR_arch.EDVR(
            nf=net_opt['nf'], nframes=net_opt['nframes'], groups=net_opt['groups'],
            front_RBs=net_opt['front_RBs'], back_RBs=net_opt['back_RBs'],
            center=net_opt['center'], predeblur=net_opt['predeblur'],
            HR_in=net_opt['HR_in'], w_TSA=net_opt['w_TSA']),
    }

    build = builders.get(model_name)
    if build is None:
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(model_name))
    return build()
def define_G(opt):
    """Build the generator selected by ``opt['network_G']['which_model_G']``.

    Supported names: MSRResNet, RCAN, CARN_M, fsrcnn and the
    ``classSR_3class_*`` variants. Constructor arguments are read from
    ``opt['network_G']``; only the keys used by the chosen model are required.

    Raises:
        KeyError: a required option is missing.
        NotImplementedError: the model name is not supported.
    """
    cfg = opt['network_G']
    name = cfg['which_model_G']

    # Lazily-evaluated constructors, keyed by model name.
    factories = {
        # image restoration
        'MSRResNet': lambda: SRResNet_arch.MSRResNet(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'],
            nb=cfg['nb'], upscale=cfg['scale']),
        'RCAN': lambda: RCAN_arch.RCAN(
            n_resblocks=cfg['n_resblocks'], n_feats=cfg['n_feats'],
            res_scale=cfg['res_scale'], n_colors=cfg['n_colors'],
            rgb_range=cfg['rgb_range'], scale=cfg['scale'],
            reduction=cfg['reduction'], n_resgroups=cfg['n_resgroups']),
        'CARN_M': lambda: CARN_arch.CARN_M(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'],
            scale=cfg['scale'], group=cfg['group']),
        'fsrcnn': lambda: FSRCNN_arch.FSRCNN_net(
            input_channels=cfg['in_nc'], upscale=cfg['scale'],
            d=cfg['d'], s=cfg['s'], m=cfg['m']),
        # classSR_3class_* variants: only in_nc / out_nc are passed here.
        'classSR_3class_fsrcnn_net': lambda: classSR_3class_arch.classSR_3class_fsrcnn_net(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc']),
        'classSR_3class_rcan': lambda: classSR_rcan_arch.classSR_3class_rcan(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc']),
        'classSR_3class_srresnet': lambda: classSR_srresnet_arch.ClassSR(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc']),
        'classSR_3class_carn': lambda: classSR_carn_arch.ClassSR(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc']),
    }

    if name not in factories:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(name))
    return factories[name]()
def define_G(opt):
    """Create the generator named by ``opt['network_G']['which_model_G']``.

    Supported names: ``'PAN'``, ``'MSRResNet_PA'`` and ``'RCAN_PA'``. The
    constructor arguments come from ``opt['network_G']``.

    Raises:
        KeyError: a required option is missing.
        NotImplementedError: the model name is not supported.
    """
    net_cfg = opt['network_G']
    arch = net_cfg['which_model_G']

    # image restoration
    if arch == 'PAN':
        return PAN_arch.PAN(
            in_nc=net_cfg['in_nc'], out_nc=net_cfg['out_nc'], nf=net_cfg['nf'],
            unf=net_cfg['unf'], nb=net_cfg['nb'], scale=net_cfg['scale'])
    if arch == 'MSRResNet_PA':
        return SRResNet_arch.MSRResNet_PA(
            in_nc=net_cfg['in_nc'], out_nc=net_cfg['out_nc'], nf=net_cfg['nf'],
            nb=net_cfg['nb'], upscale=net_cfg['scale'])
    if arch == 'RCAN_PA':
        return RCAN_arch.RCAN_PA(
            n_resgroups=net_cfg['n_resgroups'], n_resblocks=net_cfg['n_resblocks'],
            n_feats=net_cfg['n_feats'], res_scale=net_cfg['res_scale'],
            n_colors=net_cfg['n_colors'], rgb_range=net_cfg['rgb_range'],
            scale=net_cfg['scale'])
    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(arch))
def define_G(opt):
    """Instantiate the generator chosen by ``opt['network_G']['which_model_G']``.

    Supported names: ``'MSRResNet'``, ``'RRDBNet'`` and ``'ORDSRNet'``. For
    ORDSRNet, the ``base_size`` and ``stride`` options are passed as ``N`` and
    ``S`` respectively.

    Raises:
        KeyError: a required option is missing.
        NotImplementedError: the model name is not supported.
    """
    params = opt['network_G']
    kind = params['which_model_G']

    constructors = {
        # image restoration
        'MSRResNet': lambda: SRResNet_arch.MSRResNet(
            in_nc=params['in_nc'], out_nc=params['out_nc'], nf=params['nf'],
            nb=params['nb'], upscale=params['scale']),
        'RRDBNet': lambda: RRDBNet_arch.RRDBNet(
            in_nc=params['in_nc'], out_nc=params['out_nc'], nf=params['nf'],
            nb=params['nb']),
        'ORDSRNet': lambda: ORDSRModel(
            in_nc=params['in_nc'], out_nc=params['out_nc'],
            N=params['base_size'], S=params['stride'], upscale=params['scale']),
    }

    make = constructors.get(kind)
    if make is None:
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(kind))
    return make()
def define_C(opt):
    """Instantiate the colour-correction network described by ``opt['network_C']``.

    ``opt['network_C']['which_model_C']`` selects the architecture:

    * ``'dfn'`` -> ``DynamicF.DFN_Color_correction()``
    * any name containing ``'ResNet'`` -> ``SRResNet_arch.ResNet_alpha_beta_multi_in``
      with the full model name passed as its first positional argument
      (presumably its ``structure`` argument -- confirm against the arch).

    Raises:
        KeyError: ``network_C`` or ``which_model_C`` is missing.
        NotImplementedError: the model name is not supported.
    """
    opt_net = opt['network_C']
    which_model = opt_net['which_model_C']
    if which_model == 'dfn':
        netC = DynamicF.DFN_Color_correction()
    elif 'ResNet' in which_model:
        netC = SRResNet_arch.ResNet_alpha_beta_multi_in(which_model)
    else:
        # This factory builds the colour-correction net (network_C), not a
        # discriminator; the error message must say so.
        raise NotImplementedError(
            'Color correction model [{:s}] not recognized'.format(which_model))
    return netC
def define_G(opt):
    """Construct the generator selected by ``opt['network_G']['which_model_G']``.

    Supported names: ``'MSRResNet'``, ``'RRDBNet'``, ``'EDVR'`` (with the
    ``w_GCB`` option), ``'MGANet'`` and ``'Unet'``. The EDVR and Unet modules
    are imported only when their branch is taken.

    Raises:
        KeyError: a required option is missing.
        NotImplementedError: the model name is not supported.
    """
    net_opt = opt['network_G']
    model_name = net_opt['which_model_G']

    # image restoration
    if model_name == 'MSRResNet':
        return SRResNet_arch.MSRResNet(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'], nf=net_opt['nf'],
            nb=net_opt['nb'], upscale=net_opt['scale'])
    if model_name == 'RRDBNet':
        return RRDBNet_arch.RRDBNet(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'], nf=net_opt['nf'],
            nb=net_opt['nb'])

    # video restoration
    if model_name == 'EDVR':
        import models.archs.EDVR_arch as EDVR_arch
        return EDVR_arch.EDVR(
            nf=net_opt['nf'], nframes=net_opt['nframes'], groups=net_opt['groups'],
            front_RBs=net_opt['front_RBs'], back_RBs=net_opt['back_RBs'],
            center=net_opt['center'], predeblur=net_opt['predeblur'],
            HR_in=net_opt['HR_in'], w_TSA=net_opt['w_TSA'], w_GCB=net_opt['w_GCB'])
    if model_name == 'MGANet':
        return Gen_Guided_UNet(input_size=net_opt['input_size'])
    if model_name == 'Unet':
        import repo.CycleGAN.networks as unet_networks
        # NOTE(review): 2 * 3 input channels / 1 output channel are fixed here;
        # presumably two stacked 3-channel images -> one map. Confirm with caller.
        return unet_networks.define_G(
            2 * 3, 1, net_opt['nf'], net_opt['G_type'], net_opt['norm'],
            net_opt['dropout'], net_opt['init_type'], net_opt['init_gain'])

    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(model_name))
def define_G(opt):
    """Construct the generator selected by ``opt['network_G']['which_model_G']``.

    Each architecture module is imported inside its own branch, so only the
    module for the requested model is loaded. MSRResNet takes its scale from
    ``opt['network_G']['scale']``; EDVR and DUF read the top-level
    ``opt['scale']``.

    Raises:
        KeyError: a required option is missing.
        NotImplementedError: the model name is not supported.
    """
    net_opt = opt['network_G']
    model_name = net_opt['which_model_G']

    # image restoration
    if model_name == 'MSRResNet':
        import models.archs.SRResNet_arch as SRResNet_arch
        return SRResNet_arch.MSRResNet(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'], nf=net_opt['nf'],
            nb=net_opt['nb'], upscale=net_opt['scale'])
    if model_name == 'RRDBNet':
        import models.archs.RRDBNet_arch as RRDBNet_arch
        return RRDBNet_arch.RRDBNet(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'], nf=net_opt['nf'],
            nb=net_opt['nb'])

    # video restoration
    if model_name == 'EDVR':
        import models.archs.EDVR_arch as EDVR_arch
        return EDVR_arch.EDVR(
            nf=net_opt['nf'], nframes=net_opt['nframes'], groups=net_opt['groups'],
            front_RBs=net_opt['front_RBs'], back_RBs=net_opt['back_RBs'],
            center=net_opt['center'], predeblur=net_opt['predeblur'],
            HR_in=net_opt['HR_in'], w_TSA=net_opt['w_TSA'], scale=opt['scale'])
    if model_name == 'DUF':
        import models.archs.DUF_arch as DUF_arch
        depth = net_opt['layers']
        if depth == 16:
            duf_cls = DUF_arch.DUF_16L
        elif depth == 28:
            duf_cls = DUF_arch.DUF_28L
        else:
            # Any other depth falls back to the 52-layer variant.
            duf_cls = DUF_arch.DUF_52L
        return duf_cls(scale=opt['scale'], adapt_official=True)
    if model_name == 'TOF':
        import models.archs.TOF_arch as TOF_arch
        return TOF_arch.TOFlow(adapt_official=True)

    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(model_name))
# Build an initial x3 MSRResNet checkpoint from a pretrained x4 MSRResNet.
#
# Every parameter whose name exists in both models is copied unchanged, except
# ``upconv1``, whose output-channel count differs between the two models
# (the slices below assume 256 rows in the x4 tensor and 576 in the x3 one).
# The x3 ``upconv1`` is filled by tiling the x4 weights/biases, halved.
import os.path as osp
import sys

import torch

# Make the ``models`` package two directories up importable. Import errors are
# deliberately allowed to propagate: nothing below can run without
# SRResNet_arch, and swallowing the error only turns it into a NameError later.
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
import models.archs.SRResNet_arch as SRResNet_arch  # noqa: E402

# Load onto the CPU so the tensors live on the same device as the freshly
# constructed (CPU) model below, and so the script runs without a GPU.
pretrained_net = torch.load(
    '../../experiments/pretrained_models/MSRResNetx4.pth', map_location='cpu')
crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
crt_net = crt_model.state_dict()

# Copy all shared parameters except upconv1 (shape differs between x4 and x3).
for k in crt_net:
    if k in pretrained_net and 'upconv1' not in k:
        crt_net[k] = pretrained_net[k]
        print('replace ... ', k)

# x4 -> x3: tile the 256 x4 rows twice (rows 0-511) and use the first 64 x4
# rows for the remaining rows 512-575, halving every copied value.
crt_net['upconv1.weight'][
    0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2
crt_net['upconv1.weight'][
    256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2
crt_net['upconv1.weight'][
    512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2
crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2
crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2
crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2

torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth')
def define_G(opt):
    """Build the generator selected by ``opt['network_G']['which_model_G']``.

    Covers image SR (MSRResNet, RRDBNet), EDVR and its variants (EDVR2X,
    EDVRImg, EDVR3D, UPEDVR, UPContEDVR, FlowUPContEDVR), multi-target video
    SR (MultiEDVR) and arbitrary-magnification video SR (MetaEDVR).
    Constructor arguments come from ``opt['network_G']`` and are read in the
    same order as the constructor's keyword list; only the keys of the chosen
    model are needed.

    Raises:
        KeyError: a required option is missing.
        NotImplementedError: the model name is not supported.
    """
    net_opt = opt['network_G']
    model_name = net_opt['which_model_G']

    def pick(*keys):
        # Collect the named options, in order, as constructor kwargs.
        return {key: net_opt[key] for key in keys}

    # Shared keyword lists for the EDVR families.
    edvr_keys = ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs', 'center',
                 'predeblur', 'HR_in', 'w_TSA')
    up_keys = ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs', 'center',
               'w_TSA', 'down_scale', 'align_target', 'ret_valid')
    image_keys = ('nf', 'front_RBs', 'back_RBs', 'down_scale')

    builders = {
        # image restoration
        'MSRResNet': lambda: SRResNet_arch.MSRResNet(
            **pick('in_nc', 'out_nc', 'nf', 'nb'), upscale=net_opt['scale']),
        'RRDBNet': lambda: RRDBNet_arch.RRDBNet(**pick('in_nc', 'out_nc', 'nf', 'nb')),
        # video restoration
        'EDVR': lambda: EDVR_arch.EDVR(**pick(*edvr_keys)),
        'EDVR2X': lambda: EDVR_arch.EDVR2X(**pick(*edvr_keys)),
        'EDVRImg': lambda: EDVR_arch.EDVRImage(**pick(*image_keys)),
        'EDVR3D': lambda: EDVR_arch.EDVR3D(**pick(*image_keys)),
        'UPEDVR': lambda: EDVR_arch.UPEDVR(**pick(*up_keys)),
        'UPContEDVR': lambda: EDVR_arch.UPControlEDVR(
            **pick(*up_keys, 'multi_scale_cont')),
        'FlowUPContEDVR': lambda: EDVR_arch.FlowUPControlEDVR(
            **pick(*up_keys, 'multi_scale_cont')),
        # video SR for multiple target frames
        'MultiEDVR': lambda: EDVR_arch.MultiEDVR(**pick(*edvr_keys)),
        # arbitrary magnification video super-resolution
        'MetaEDVR': lambda: EDVR_arch.MetaEDVR(**pick(*edvr_keys, 'fix_edvr')),
    }

    build = builders.get(model_name)
    if build is None:
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(model_name))
    return build()
def define_G(opt):
    """Build the generator selected by ``opt['network_G']['which_model_G']``.

    Covers image SR (MSRResNet, RRDBNet), EDVR, the ``MY_EDVR_*`` variants
    from ``my_EDVR_arch`` and ``Recurr_ResBlocks``. Constructor arguments are
    read from ``opt['network_G']`` in the constructor's keyword order; only
    the keys of the chosen model are needed.

    Raises:
        KeyError: a required option is missing.
        NotImplementedError: the model name is not supported.
    """
    net_opt = opt['network_G']
    model_name = net_opt['which_model_G']

    def pick(*keys):
        # Collect the named options, in order, as constructor kwargs.
        return {key: net_opt[key] for key in keys}

    # Keyword list shared by EDVR and every MY_EDVR_* variant.
    edvr_keys = ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs', 'center',
                 'predeblur', 'HR_in', 'w_TSA')

    builders = {
        # image restoration
        'MSRResNet': lambda: SRResNet_arch.MSRResNet(
            **pick('in_nc', 'out_nc', 'nf', 'nb'), upscale=net_opt['scale']),
        'RRDBNet': lambda: RRDBNet_arch.RRDBNet(**pick('in_nc', 'out_nc', 'nf', 'nb')),
        # video restoration
        'EDVR': lambda: EDVR_arch.EDVR(**pick(*edvr_keys)),
        'MY_EDVR_FusionDenoise': lambda: my_EDVR_arch.MYEDVR_FusionDenoise(
            **pick(*edvr_keys)),
        'MY_EDVR_RES': lambda: my_EDVR_arch.MYEDVR_RES(**pick(*edvr_keys)),
        'MY_EDVR_PreEnhance': lambda: my_EDVR_arch.MYEDVR_PreEnhance(
            **pick(*edvr_keys)),
        'Recurr_ResBlocks': lambda: Recurr_arch.Recurr_ResBlocks(
            **pick('nf', 'N_RBs', 'N_flow_lv', 'pretrain_flow')),
    }

    if model_name not in builders:
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(model_name))
    return builders[model_name]()
def main():
    """Run colour correction + EDVR on a test set that has no ground truth.

    For every clip folder under ``test_dataset_folder``, each LR frame's
    temporal window is passed through the colour-correction network and then
    through EDVR. The result is written as ``<frame>.png`` under
    ``../results/<test_name>/<clip>/``. No PSNR is computed because this test
    set has no GT.

    Raises:
        ValueError: ``test_set`` is not one configured below.
    """
    #################
    # configurations
    #################
    # torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.enabled = True
    # Select the GPU before anything touches CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = '5'
    device = torch.device('cuda')
    test_set = 'AI4K_test'  # Vid4 | YouKu10 | REDS4 | AI4K_test
    data_mode = 'sharp_bicubic'  # sharp_bicubic | blur_bicubic
    test_name = 'Contest2_Test18_A38_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd_165000'
    N_in = 5

    #### load test set
    if test_set == 'AI4K_test':
        # alternative: '/data1/yhliu/AI4K/Corrected_TestA_Contest2_001_ResNet_alpha_beta_gaussian_65000/'
        test_dataset_folder = '/home/yhliu/AI4K/contest2/testA_LR_png/'
    else:
        # Only AI4K_test is configured; fail now instead of a NameError later.
        raise ValueError('Unsupported test set: {}'.format(test_set))
    flip_test = False

    model_path = '../experiments/A38_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd/models/165000_G.pth'
    color_model_path = '/home/yhliu/BasicSR/experiments/35_ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW_re_100k/models/220000_G.pth'

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    model = EDVR_arch.EDVR(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
    color_model = SRResNet_arch.ResNet_alpha_beta_multi_in(
        structure='ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW')

    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(test_name)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the EDVR model
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    model = nn.DataParallel(model)

    #### set up the colour-correction model
    load_net = torch.load(color_model_path)
    # The checkpoint keys lack the 'color_net.' prefix the model expects.
    load_net_clean = OrderedDict(('color_net.' + k, v) for k, v in load_net.items())
    color_model.load_state_dict(load_net_clean, strict=True)
    color_model.eval()
    color_model = color_model.to(device)
    color_model = nn.DataParallel(color_model)

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))

    # for each subfolder (clip)
    for subfolder in subfolder_l:
        subfolder_name = osp.basename(subfolder)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ images (no GT for this test set)
        imgs_LQ = data_util.read_img_seq(subfolder)

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).cpu()
            print(imgs_in.size())

            if flip_test:
                imgs_in = util.single_forward(color_model, imgs_in)
                output = util.flipx4_forward(model, imgs_in)
            else:
                start_time = time.time()
                imgs_in = util.single_forward(color_model, imgs_in)
                output = util.single_forward(model, imgs_in)
                end_time = time.time()
                print('Forward One image:', end_time - start_time)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)
            logger.info('{:3d} - {:25}'.format(img_idx + 1, img_name))

    logger.info('################ Tidy Outputs ################')
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
def main():
    """Evaluate colour correction + EDVR on a test set with ground truth.

    Each LR clip folder is paired, by sorted position, with a GT clip folder.
    Each frame's temporal window goes through the colour-correction network
    and then EDVR. The output is optionally saved as PNG, and per-frame PSNR
    against GT is logged. Per-clip averages are reported separately for
    "center" frames (at least ``N_in // 2`` frames from either end) and
    "border" frames, followed by totals over all evaluated clips.

    Raises:
        ValueError: ``test_set`` is not one configured below.
    """
    #################
    # configurations
    #################
    # Select the GPU(s) before anything touches CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = '5'
    # os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4'
    device = torch.device('cuda')
    test_set = 'AI4K_val'  # Vid4 | YouKu10 | REDS4 | AI4K_val
    test_name = 'A38_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd_165000'
    data_mode = 'sharp_bicubic'  # sharp_bicubic | blur_bicubic
    N_in = 5

    #### load test set
    if test_set == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    elif test_set == 'YouKu10':
        test_dataset_folder = '../datasets/YouKu10/LR'
        GT_dataset_folder = '../datasets/YouKu10/HR'
    elif test_set == 'YouKu_val':
        test_dataset_folder = '/data0/yhliu/DATA/YouKuVid/valid/valid_lr_bmp'
        GT_dataset_folder = '/data0/yhliu/DATA/YouKuVid/valid/valid_hr_bmp'
    elif test_set == 'REDS4':
        test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        GT_dataset_folder = '../datasets/REDS4/GT'
    elif test_set == 'AI4K_val':
        test_dataset_folder = '/home/yhliu/AI4K/contest2/val2_LR_png/'
        GT_dataset_folder = '/home/yhliu/AI4K/contest1/val1_HR_png/'
    elif test_set == 'AI4K_bic':
        test_dataset_folder = '/home/yhliu/AI4K/contest2/val2_LR_png_bic/'
        GT_dataset_folder = '/home/yhliu/AI4K/contest1/val1_HR_png_bic/'
    elif test_set == 'AI4K_testA':
        test_dataset_folder = '/data0/yhliu/AI4K/val2_LR_png'
        GT_dataset_folder = '/data0/yhliu/AI4K/val1_HR_png/'
    else:
        # Fail now instead of with a NameError once the folders are used.
        raise ValueError('Unsupported test set: {}'.format(test_set))
    flip_test = False

    # model_path = '../experiments/pretrained_models/A01xxx/900000_G.pth'
    model_path = '../experiments/A38_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd/models/165000_G.pth'
    color_model_path = '/home/yhliu/BasicSR/experiments/35_ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW_re_100k/models/220000_G.pth'

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    model = EDVR_arch.EDVR(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
    color_model = SRResNet_arch.ResNet_alpha_beta_multi_in(
        structure='ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW')

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(test_name)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the EDVR model
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    model = nn.DataParallel(model)

    #### set up the colour-correction model
    load_net = torch.load(color_model_path)
    # The checkpoint keys lack the 'color_net.' prefix the model expects.
    load_net_clean = OrderedDict(('color_net.' + k, v) for k, v in load_net.items())
    color_model.load_state_dict(load_net_clean, strict=True)
    color_model.eval()
    color_model = color_model.to(device)
    color_model = nn.DataParallel(color_model)

    def _mean(values):
        # Average of a list; 0 when nothing was evaluated.
        return sum(values) / len(values) if values else 0

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    if len(subfolder_l) != len(subfolder_GT_l):
        # LR and GT clips are paired by sorted position below; a count
        # mismatch means extras are dropped and pairs may be misaligned.
        logger.warning('Found {} LR folders but {} GT folders; only the first {} '
                       'pairs are evaluated.'.format(
                           len(subfolder_l), len(subfolder_GT_l),
                           min(len(subfolder_l), len(subfolder_GT_l))))

    # for each subfolder (clip)
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = [data_util.read_img(None, img_GT_path)
                    for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*')))]

        avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).cpu()
            print(imgs_in.size())

            imgs_in = util.single_forward(color_model, imgs_in)
            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR on all colour channels (Y-only evaluation is not used)
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if border_frame <= img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        # Guard every division: short clips may have no center frames and an
        # empty folder has no frames at all.
        N_total = N_center + N_border
        avg_psnr = (avg_psnr_center + avg_psnr_border) / N_total if N_total else 0
        avg_psnr_center = 0 if N_center == 0 else avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center,
                                                     psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    # Report the number of clips actually evaluated (zip may have truncated).
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    _mean(avg_psnr_l), len(avg_psnr_l),
                    _mean(avg_psnr_center_l), _mean(avg_psnr_border_l)))