# Imports assumed by the scripts in this file (module paths follow the usual
# EDVR/mmsr repo layout; adjust to your tree):
import argparse
import copy
import glob
import itertools
import json
import logging
import os
import os.path as osp
import time

import cv2
import numpy as np
import torch

import options.options as option
import utils.util as util
import data.util as data_util
import models.archs.EDVR_arch as EDVR_arch


def main():
    #################
    # configurations
    #################
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--gt_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--gpu_id", type=str, required=True)
    parser.add_argument("--screen_notation", type=str, required=True)
    parser.add_argument('--opt', type=str, required=True, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=False)

    PAD = 32

    total_run_time = AverageMeter()
    # print("GPU ", torch.cuda.device_count())
    device = torch.device('cuda')
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    print('export CUDA_VISIBLE_DEVICES=' + str(args.gpu_id))

    data_mode = 'sharp_bicubic'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use the two-stage strategy for the REDS dataset.
    flip_test = False

    # Input_folder = "/DATA7_DB7/data/4khdr/data/Dataset/train_sharp_bicubic"
    # GT_folder = "/DATA7_DB7/data/4khdr/data/Dataset/train_4k"
    # Result_folder = "/DATA7_DB7/data/4khdr/data/Results"

    Input_folder = args.input_path
    GT_folder = args.gt_path
    Result_folder = args.output_path
    Model_path = args.model_path

    # create results folder
    if not os.path.exists(Result_folder):
        os.makedirs(Result_folder, exist_ok=True)

    ############################################################################
    #### model
    model_path = Model_path

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 40
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20

    model = EDVR_arch.EDVR(nf=opt['network_G']['nf'], nframes=opt['network_G']['nframes'],
                           groups=opt['network_G']['groups'], front_RBs=opt['network_G']['front_RBs'],
                           back_RBs=opt['network_G']['back_RBs'], predeblur=opt['network_G']['predeblur'],
                           HR_in=opt['network_G']['HR_in'], w_TSA=opt['network_G']['w_TSA'])
    # model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        if stage == 1:
            # test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
            # test_dataset_folder = '/DATA/wangshen_data/REDS/val_sharp_bicubic/X4'
            test_dataset_folder = Input_folder
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        # GT_dataset_folder = '../datasets/REDS4/GT'
        # GT_dataset_folder = '/DATA/wangshen_data/REDS/val_sharp'
        GT_dataset_folder = GT_folder

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluating
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    # save_folder = '../results/{}'.format(data_mode)
    # save_folder = '/DATA/wangshen_data/REDS/results/{}'.format(data_mode)
    save_folder = os.path.join(Result_folder, data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the model
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))

    end = time.time()

    # load the screen-change notation
    with open(args.screen_notation) as f:
        frame_notation = json.load(f)

    # for each subfolder
    for subfolder in subfolder_l:
        input_subfolder = os.path.split(subfolder)[1]
        subfolder_GT = os.path.join(GT_dataset_folder, input_subfolder)
        if not os.path.exists(subfolder_GT):
            continue
        print("Evaluate Folders: ", input_subfolder)

        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)  # Num x 3 x H x W
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image (NOTE: only the first 5 frames of each clip are evaluated)
        for img_idx, img_path in enumerate(img_path_l[:5]):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # handle screen changes via the frame notation
            select_idx = data_util.index_generation_process_screen_change(
                input_subfolder, frame_notation, img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)  # 960 x 540

            # split the 960x540 input into 320x180 patches
            gtWidth = 3840
            gtHeight = 2160
            intWidth_ori = imgs_in.shape[4]  # 960
            intHeight_ori = imgs_in.shape[3]  # 540
            split_lengthY = 180
            split_lengthX = 320
            scale = 4

            intPaddingRight_ = int(float(intWidth_ori) / split_lengthX + 1) * split_lengthX - intWidth_ori
            intPaddingBottom_ = int(float(intHeight_ori) / split_lengthY + 1) * split_lengthY - intHeight_ori
            intPaddingRight_ = 0 if intPaddingRight_ == split_lengthX else intPaddingRight_
            intPaddingBottom_ = 0 if intPaddingBottom_ == split_lengthY else intPaddingBottom_

            pader0 = torch.nn.ReplicationPad2d([0, intPaddingRight_, 0, intPaddingBottom_])
            print("Init pad right/bottom " + str(intPaddingRight_) + " / " + str(intPaddingBottom_))

            intPaddingRight = PAD  # 32 / 64 / 128 / 256
            intPaddingLeft = PAD
            intPaddingTop = PAD
            intPaddingBottom = PAD

            pader = torch.nn.ReplicationPad2d(
                [intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom])

            imgs_in = torch.squeeze(imgs_in, 0)  # N C H W
            imgs_in = pader0(imgs_in)  # N C 540 960
            imgs_in = pader(imgs_in)   # N C 604 1024

            assert split_lengthY == int(split_lengthY) and split_lengthX == int(split_lengthX)
            split_lengthY = int(split_lengthY)
            split_lengthX = int(split_lengthX)
            split_numY = int(float(intHeight_ori) / split_lengthY)
            split_numX = int(float(intWidth_ori) / split_lengthX)
            splitsY = range(0, split_numY)
            splitsX = range(0, split_numX)

            intWidth = split_lengthX
            intWidth_pad = intWidth + intPaddingLeft + intPaddingRight
            intHeight = split_lengthY
            intHeight_pad = intHeight + intPaddingTop + intPaddingBottom

            # print("split " + str(split_numY) + ' , ' + str(split_numX))
            y_all = np.zeros((gtHeight, gtWidth, 3), dtype="float32")  # HWC

            for split_j, split_i in itertools.product(splitsY, splitsX):
                # print(str(split_j) + ", \t " + str(split_i))
                X0 = imgs_in[:, :,
                             split_j * split_lengthY:(split_j + 1) * split_lengthY + intPaddingBottom + intPaddingTop,
                             split_i * split_lengthX:(split_i + 1) * split_lengthX + intPaddingRight + intPaddingLeft]
                X0 = torch.unsqueeze(X0, 0)  # N C H W -> 1 N C H W

                if flip_test:
                    output = util.flipx4_forward(model, X0)
                else:
                    output = util.single_forward(model, X0)

                output_depadded = output[0, :,
                                         intPaddingTop * scale:(intPaddingTop + intHeight) * scale,
                                         intPaddingLeft * scale:(intPaddingLeft + intWidth) * scale]
                output_depadded = output_depadded.squeeze(0)
                output = util.tensor2img(output_depadded)

                y_all[split_j * split_lengthY * scale:(split_j + 1) * split_lengthY * scale,
                      split_i * split_lengthX * scale:(split_i + 1) * split_lengthX * scale, :] = \
                    np.round(output).astype(np.uint8)

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), y_all)

            print("*****************current image process time \t " +
                  str(time.time() - end) + "s ******************")
            total_run_time.update(time.time() - end, 1)
            end = time.time()  # reset the per-image timer

            # calculate PSNR
            y_all = y_all / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                y_all = data_util.bgr2ycbcr(y_all, only_y=True)

            y_all, GT = util.crop_border([y_all, GT], crop_border)
            crt_psnr = util.calculate_psnr(y_all * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        avg_psnr_center = avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
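# The script above keeps per-image runtimes in an `AverageMeter`, which is not
# defined in this file. A minimal sketch of the usual helper (assumption: the
# repo ships an equivalent in its util module):
class AverageMeter(object):
    """Tracks the last value, sum, count, and running average of a metric."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count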
def main():
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'ai4khdr_valid'
    flip_test = False

    ############################################################################
    #### model
    #################
    if data_mode == 'ai4khdr_valid':
        model_path = '../experiments/002_EDVR_lr4e-4_600k_AI4KHDR/models/4000_G.pth'
    else:
        raise NotImplementedError
    N_in = 5
    front_RBs = 5
    back_RBs = 10
    predeblur, HR_in = False, False
    model = EDVR_arch.EDVR(64, N_in, 8, front_RBs, back_RBs, predeblur=predeblur, HR_in=HR_in)

    ############################################################################
    #### dataset
    #################
    if data_mode == 'ai4khdr_valid':
        test_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/valid/540p_frames'
        GT_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/valid/4k_frames'
    else:
        raise NotImplementedError

    ############################################################################
    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluating
    # temporal padding mode
    if data_mode == 'ai4khdr_valid':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True
    save_folder = '../results/{}_{}'.format(data_mode, util.get_timestamp())
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the model
    model.load_state_dict(torch.load(model_path), strict=False)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for ai4khdr_valid, evaluate on the Y channel
            if data_mode == 'ai4khdr_valid':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)
            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            # logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        avg_psnr_center = avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
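# `flip_test` in these scripts toggles flip self-ensembling. For reference, a
# sketch of the two forward helpers as they appear in EDVR-style codebases
# (the repo's own util module is assumed to provide the real versions):
def single_forward_sketch(model, inp):
    with torch.no_grad():
        output = model(inp).data.float().cpu()
    return output


def flipx4_forward_sketch(model, inp):
    # average predictions over the identity, W-flip, H-flip, and HW-flip inputs
    output_f = single_forward_sketch(model, inp)
    output = single_forward_sketch(model, torch.flip(inp, (-1,)))
    output_f = output_f + torch.flip(output, (-1,))
    output = single_forward_sketch(model, torch.flip(inp, (-2,)))
    output_f = output_f + torch.flip(output, (-2,))
    output = single_forward_sketch(model, torch.flip(inp, (-2, -1)))
    output_f = output_f + torch.flip(output, (-2, -1))
    return output_f / 4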
    # (fragment: body of the per-image evaluation loop over img_list)
        img_LR = torch.from_numpy(
            np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float().unsqueeze(0).cuda()

        with torch.no_grad():
            begin_time = time.time()
            frame, sr_base, output = model(img_LR)
            end_time = time.time()
            stat_time += (end_time - begin_time)
            # print(end_time - begin_time)
            output = util.tensor2img(output.squeeze(0))
            frame = util.tensor2img(frame.squeeze(0))
            sr_base = util.tensor2img(sr_base.squeeze(0))

        # save images
        save_path_name = osp.join(save_path, '{}_exp{}/{}.png'.format(dataset, exp_name, base_name))
        merge = np.concatenate((frame, sr_base, output), axis=1)
        util.save_img(merge, save_path_name)

        # calculate PSNR
        sr_img, gt_img = util.crop_border([output, img_GT], scale)
        PSNR_avg += util.calculate_psnr(sr_img, gt_img)
        SSIM_avg += util.calculate_ssim(sr_img, gt_img)

    print('average PSNR: ', PSNR_avg / len(img_list))
    print('average SSIM: ', SSIM_avg / len(img_list))
    print('time: ', stat_time / len(img_list))
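# `util.calculate_psnr` above compares images in the [0, 255] range; in
# EDVR-style codebases it computes 20 * log10(255 / sqrt(MSE)). A minimal
# equivalent for reference:
def calculate_psnr_sketch(img1, img2):
    # img1, img2: arrays in [0, 255] with identical shapes
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(255.0 / np.sqrt(mse))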
def main():
    ####################
    # arguments parser #
    ####################
    # [format] dataset(vid4, REDS4) N(number of frames)
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset')
    parser.add_argument('n_frames')
    parser.add_argument('stage')
    args = parser.parse_args()

    data_mode = str(args.dataset)
    N_in = int(args.n_frames)
    stage = int(args.stage)

    # if args.command == 'start':
    #     start(int(args.params[0]))
    # elif args.command == 'stop':
    #     stop(args.params[0], int(args.params[1]))
    # elif args.command == 'stop_all':
    #     stop_all(args.params[0])

    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # data_mode: Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    # stage: 1 or 2, use the two-stage strategy for the REDS dataset.
    flip_test = False

    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    elif data_mode == 'sharp_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
    elif data_mode == 'blur_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
    elif data_mode == 'blur':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
    elif data_mode == 'blur_comp':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    else:
        raise NotImplementedError

    predeblur, HR_in = False, False
    back_RBs = 40
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20

    #### dataset
    if data_mode == 'Vid4':
        N_model_default = 7
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        N_model_default = 5
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    raw_model = EDVR_arch.EDVR(128, N_model_default, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
    model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluating
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    data_mode_t = copy.deepcopy(data_mode)
    if stage == 1 and data_mode_t != 'Vid4':
        data_mode = 'REDS-EDVR_REDS_SR_L_flipx4'
    save_folder = '../results/{}'.format(data_mode)
    data_mode = copy.deepcopy(data_mode_t)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    print([a for a in dir(model) if not callable(getattr(model, a))])  # non-callable attributes
    # model.load_state_dict(torch.load(model_path), strict=True)
    raw_model.load_state_dict(torch.load(model_path), strict=True)

    #### change the model so it can work with fewer input frames: copy every
    #### N-independent module from the pretrained raw_model, then slice the
    #### N-dependent TSA fusion convolutions down to N_in * nf input channels
    model.nf = raw_model.nf
    model.center = N_in // 2  # if center is None else center
    model.is_predeblur = raw_model.is_predeblur
    model.HR_in = raw_model.HR_in
    model.w_TSA = raw_model.w_TSA
    # ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)

    #### feature extraction (per frame)
    if model.is_predeblur:
        model.pre_deblur = raw_model.pre_deblur  # Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
        model.conv_1x1 = raw_model.conv_1x1      # nn.Conv2d(nf, nf, 1, 1, bias=True)
    else:
        if model.HR_in:
            model.conv_first_1 = raw_model.conv_first_1  # nn.Conv2d(3, nf, 3, 1, 1, bias=True)
            model.conv_first_2 = raw_model.conv_first_2  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            model.conv_first_3 = raw_model.conv_first_3  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        else:
            model.conv_first = raw_model.conv_first  # nn.Conv2d(3, nf, 3, 1, 1, bias=True)
    model.feature_extraction = raw_model.feature_extraction  # arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
    model.fea_L2_conv1 = raw_model.fea_L2_conv1  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
    model.fea_L2_conv2 = raw_model.fea_L2_conv2  # nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
    model.fea_L3_conv1 = raw_model.fea_L3_conv1  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
    model.fea_L3_conv2 = raw_model.fea_L3_conv2  # nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
    model.pcd_align = raw_model.pcd_align  # PCD_Align(nf=nf, groups=groups)

    ######## resize TSA
    model.tsa_fusion.center = model.center
    # temporal attention (before fusion conv)
    model.tsa_fusion.tAtt_1 = raw_model.tsa_fusion.tAtt_1
    model.tsa_fusion.tAtt_2 = raw_model.tsa_fusion.tAtt_2
    # fusion conv: uses 1x1 to save parameters and computation
    # print("MAIN SHAPE(FEA): ", raw_model.tsa_fusion.fea_fusion.weight.shape)
    model.tsa_fusion.fea_fusion = copy.deepcopy(raw_model.tsa_fusion.fea_fusion)
    model.tsa_fusion.fea_fusion.weight = copy.deepcopy(
        torch.nn.Parameter(raw_model.tsa_fusion.fea_fusion.weight[:, 0:N_in * 128, :, :]))
    # model.tsa_fusion.fea_fusion.bias = raw_model.tsa_fusion.fea_fusion.bias
    # spatial attention (after fusion conv)
    model.tsa_fusion.sAtt_1 = copy.deepcopy(raw_model.tsa_fusion.sAtt_1)
    model.tsa_fusion.sAtt_1.weight = copy.deepcopy(
        torch.nn.Parameter(raw_model.tsa_fusion.sAtt_1.weight[:, 0:N_in * 128, :, :]))
    # model.tsa_fusion.sAtt_1.bias = raw_model.tsa_fusion.sAtt_1.bias
    print("MODEL TSA SHAPE: ", model.tsa_fusion.fea_fusion.weight.shape)
    model.tsa_fusion.maxpool = raw_model.tsa_fusion.maxpool
    model.tsa_fusion.avgpool = raw_model.tsa_fusion.avgpool
    model.tsa_fusion.sAtt_2 = raw_model.tsa_fusion.sAtt_2
    model.tsa_fusion.sAtt_3 = raw_model.tsa_fusion.sAtt_3
    model.tsa_fusion.sAtt_4 = raw_model.tsa_fusion.sAtt_4
    model.tsa_fusion.sAtt_5 = raw_model.tsa_fusion.sAtt_5
    model.tsa_fusion.sAtt_L1 = raw_model.tsa_fusion.sAtt_L1
    model.tsa_fusion.sAtt_L2 = raw_model.tsa_fusion.sAtt_L2
    model.tsa_fusion.sAtt_L3 = raw_model.tsa_fusion.sAtt_L3
    model.tsa_fusion.sAtt_add_1 = raw_model.tsa_fusion.sAtt_add_1
    model.tsa_fusion.sAtt_add_2 = raw_model.tsa_fusion.sAtt_add_2
    model.tsa_fusion.lrelu = raw_model.tsa_fusion.lrelu
    # if model.w_TSA:
    #     model.tsa_fusion = raw_model.tsa_fusion[:][:128 * N_in][:][:]  # TSA_Fusion(nf=nf, nframes=nframes, center=self.center)

    #### reconstruction
    model.recon_trunk = raw_model.recon_trunk  # arch_util.make_layer(ResidualBlock_noBN_f, back_RBs)
    #### upsampling
    model.upconv1 = raw_model.upconv1  # nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
    model.upconv2 = raw_model.upconv2  # nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)
    model.pixel_shuffle = raw_model.pixel_shuffle  # nn.PixelShuffle(2)
    model.HRconv = raw_model.HRconv
    model.conv_last = raw_model.conv_last
    #### activation function
    model.lrelu = raw_model.lrelu
    #####################################################

    model.eval()
    model = model.to(device)

    avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        print("MAX_IDX: ", max_idx)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_ssim, avg_ssim_border, avg_ssim_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            if data_mode == "blur":
                select_idx = data_util.glarefree_index_generation(img_idx, max_idx, N_in, padding=padding)
            else:
                select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            print("SELECT IDX: ", select_idx)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                print("IMGS_IN SHAPE: ", imgs_in.shape)
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate SSIM
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)
            output, GT = util.crop_border([output, GT], crop_border)
            crt_ssim = util.calculate_ssim(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tSSIM: {:.6f}'.format(img_idx + 1, img_name, crt_ssim))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_ssim_center += crt_ssim
                N_center += 1
            else:  # border frames
                avg_ssim_border += crt_ssim
                N_border += 1

        avg_ssim = (avg_ssim_center + avg_ssim_border) / (N_center + N_border)
        avg_ssim_center = avg_ssim_center / N_center
        avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border
        avg_ssim_l.append(avg_ssim)
        avg_ssim_center_l.append(avg_ssim_center)
        avg_ssim_border_l.append(avg_ssim_border)

        logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; '
                    'Center SSIM: {:.6f} for {} frames; '
                    'Border SSIM: {:.6f} for {} frames.'.format(
                        subfolder_name, avg_ssim, (N_center + N_border),
                        avg_ssim_center, N_center, avg_ssim_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, ssim, ssim_center, ssim_border in zip(
            subfolder_name_l, avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l):
        logger.info('Folder {} - Average SSIM: {:.6f}. '
                    'Center SSIM: {:.6f}. '
                    'Border SSIM: {:.6f}.'.format(subfolder_name, ssim, ssim_center, ssim_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average SSIM: {:.6f} for {} clips. '
                'Center SSIM: {:.6f}. Border SSIM: {:.6f}.'.format(
                    sum(avg_ssim_l) / len(avg_ssim_l), len(subfolder_l),
                    sum(avg_ssim_center_l) / len(avg_ssim_center_l),
                    sum(avg_ssim_border_l) / len(avg_ssim_border_l)))
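# `data_util.index_generation` picks the N neighbor-frame indices for each
# target frame. A simplified sketch of the EDVR-style logic for the two padding
# modes used in these scripts ('replicate' clamps at the clip ends; 'new_info'
# shifts the window inward so it only uses frames that exist):
def index_generation_sketch(crt_i, max_n, N, padding='replicate'):
    max_n = max_n - 1  # index of the last frame
    n_pad = N // 2
    return_l = []
    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
        if i < 0:
            add_idx = 0 if padding == 'replicate' else (crt_i + n_pad) + (-i)
        elif i > max_n:
            add_idx = max_n if padding == 'replicate' else (crt_i - n_pad) - (i - max_n)
        else:
            add_idx = i
        return_l.append(add_idx)
    return return_l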
def main():
    ####################
    # arguments parser #
    ####################
    # [format] dataset(vid4, REDS4) N(number of frames)
    # data_mode = str(args.dataset)
    # N_in = int(args.n_frames)
    # metrics = str(args.metrics)
    # output_format = str(args.output_format)

    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # data_mode: Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).

    ###################################################################################
    # STAGE Vid4
    # Collecting results for Vid4
    model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
    stage = 1  # 1 or 2, use the two-stage strategy for the REDS dataset.
    flip_test = False
    predeblur, HR_in = False, False
    back_RBs = 40
    N_model_default = 7
    data_mode = 'Vid4'

    # vid4_dir_map = {"calendar": 0, "city": 1, "foliage": 2, "walk": 3}
    vid4_results = {"calendar": {}, "city": {}, "foliage": {}, "walk": {}}

    for N_in in range(1, N_model_default + 1):
        raw_model = EDVR_arch.EDVR(128, N_model_default, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
        model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
        aposterior_GT_dataset_folder = '../datasets/Vid4/GT_7'

        crop_border = 0
        border_frame = N_in // 2  # border frames when evaluating
        padding = 'new_info'
        save_imgs = False

        raw_model.load_state_dict(torch.load(model_path), strict=True)

        #### change the model so it can work with fewer input frames
        model.nf = raw_model.nf
        model.center = N_in // 2  # if center is None else center
        model.is_predeblur = raw_model.is_predeblur
        model.HR_in = raw_model.HR_in
        model.w_TSA = raw_model.w_TSA

        if model.is_predeblur:
            model.pre_deblur = raw_model.pre_deblur  # Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
            model.conv_1x1 = raw_model.conv_1x1      # nn.Conv2d(nf, nf, 1, 1, bias=True)
        else:
            if model.HR_in:
                model.conv_first_1 = raw_model.conv_first_1  # nn.Conv2d(3, nf, 3, 1, 1, bias=True)
                model.conv_first_2 = raw_model.conv_first_2  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
                model.conv_first_3 = raw_model.conv_first_3  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            else:
                model.conv_first = raw_model.conv_first  # nn.Conv2d(3, nf, 3, 1, 1, bias=True)
        model.feature_extraction = raw_model.feature_extraction  # arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
        model.fea_L2_conv1 = raw_model.fea_L2_conv1  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        model.fea_L2_conv2 = raw_model.fea_L2_conv2  # nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        model.fea_L3_conv1 = raw_model.fea_L3_conv1  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        model.fea_L3_conv2 = raw_model.fea_L3_conv2  # nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        model.pcd_align = raw_model.pcd_align  # PCD_Align(nf=nf, groups=groups)

        model.tsa_fusion.center = model.center
        model.tsa_fusion.tAtt_1 = raw_model.tsa_fusion.tAtt_1
        model.tsa_fusion.tAtt_2 = raw_model.tsa_fusion.tAtt_2
        model.tsa_fusion.fea_fusion = copy.deepcopy(raw_model.tsa_fusion.fea_fusion)
        model.tsa_fusion.fea_fusion.weight = copy.deepcopy(
            torch.nn.Parameter(raw_model.tsa_fusion.fea_fusion.weight[:, 0:N_in * 128, :, :]))
        model.tsa_fusion.sAtt_1 = copy.deepcopy(raw_model.tsa_fusion.sAtt_1)
        model.tsa_fusion.sAtt_1.weight = copy.deepcopy(
            torch.nn.Parameter(raw_model.tsa_fusion.sAtt_1.weight[:, 0:N_in * 128, :, :]))
        model.tsa_fusion.maxpool = raw_model.tsa_fusion.maxpool
        model.tsa_fusion.avgpool = raw_model.tsa_fusion.avgpool
        model.tsa_fusion.sAtt_2 = raw_model.tsa_fusion.sAtt_2
        model.tsa_fusion.sAtt_3 = raw_model.tsa_fusion.sAtt_3
        model.tsa_fusion.sAtt_4 = raw_model.tsa_fusion.sAtt_4
        model.tsa_fusion.sAtt_5 = raw_model.tsa_fusion.sAtt_5
        model.tsa_fusion.sAtt_L1 = raw_model.tsa_fusion.sAtt_L1
        model.tsa_fusion.sAtt_L2 = raw_model.tsa_fusion.sAtt_L2
        model.tsa_fusion.sAtt_L3 = raw_model.tsa_fusion.sAtt_L3
        model.tsa_fusion.sAtt_add_1 = raw_model.tsa_fusion.sAtt_add_1
        model.tsa_fusion.sAtt_add_2 = raw_model.tsa_fusion.sAtt_add_2
        model.tsa_fusion.lrelu = raw_model.tsa_fusion.lrelu

        model.recon_trunk = raw_model.recon_trunk
        model.upconv1 = raw_model.upconv1
        model.upconv2 = raw_model.upconv2
        model.pixel_shuffle = raw_model.pixel_shuffle
        model.HRconv = raw_model.HRconv
        model.conv_last = raw_model.conv_last
        model.lrelu = raw_model.lrelu
        #####################################################

        model.eval()
        model = model.to(device)

        subfolder_name_l = []
        subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
        subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
        subfolder_GT_a_l = sorted(glob.glob(osp.join(aposterior_GT_dataset_folder, "*")))

        # for each subfolder
        for subfolder, subfolder_GT, subfolder_GT_a in zip(subfolder_l, subfolder_GT_l, subfolder_GT_a_l):
            subfolder_name = osp.basename(subfolder)
            subfolder_name_l.append(subfolder_name)

            img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
            max_idx = len(img_path_l)
            print("MAX_IDX: ", max_idx)

            #### read LQ and GT images
            imgs_LQ = data_util.read_img_seq(subfolder)
            img_GT_l = []
            for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
                img_GT_l.append(data_util.read_img(None, img_GT_path))
            img_GT_a = []
            for img_GT_a_path in sorted(glob.glob(osp.join(subfolder_GT_a, '*'))):
                img_GT_a.append(data_util.read_img(None, img_GT_a_path))

            # process each image
            for img_idx, img_path in enumerate(img_path_l):
                img_name = osp.splitext(osp.basename(img_path))[0]
                select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
                imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

                if flip_test:
                    output = util.flipx4_forward(model, imgs_in)
                else:
                    print("IMGS_IN SHAPE: ", imgs_in.shape)
                    output = util.single_forward(model, imgs_in)
                output = util.tensor2img(output.squeeze(0))

                # calculate PSNR / SSIM (for Vid4, on the Y channel)
                output = output / 255.
                GT = np.copy(img_GT_l[img_idx])
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

                GT_a = np.copy(img_GT_a[img_idx])
                GT_a = data_util.bgr2ycbcr(GT_a, only_y=True)
                output_a = copy.deepcopy(output)

                output, GT = util.crop_border([output, GT], crop_border)
                crt_psnr = util.calculate_psnr(output * 255, GT * 255)
                crt_ssim = util.calculate_ssim(output * 255, GT * 255)
                output_a, GT_a = util.crop_border([output_a, GT_a], crop_border)
                crt_aposterior = util.calculate_ssim(output_a * 255, GT_a * 255)

                t = vid4_results[subfolder_name].get(str(img_name))
                if t is not None:
                    vid4_results[subfolder_name][img_name].add_psnr(crt_psnr)
                    vid4_results[subfolder_name][img_name].add_gt_ssim(crt_ssim)
                    vid4_results[subfolder_name][img_name].add_aposterior_ssim(crt_aposterior)
                else:
                    vid4_results[subfolder_name].update({img_name: metrics_file(img_name)})
                    vid4_results[subfolder_name][img_name].add_psnr(crt_psnr)
                    vid4_results[subfolder_name][img_name].add_gt_ssim(crt_ssim)
                    vid4_results[subfolder_name][img_name].add_aposterior_ssim(crt_aposterior)

    #### writing Vid4 results
    util.mkdirs('../results/calendar')
    util.mkdirs('../results/city')
    util.mkdirs('../results/foliage')
    util.mkdirs('../results/walk')
    save_folder = '../results/'
    for i, dir_name in enumerate(["calendar", "city", "foliage", "walk"]):
        save_subfolder = osp.join(save_folder, dir_name)
        for j, value in vid4_results[dir_name].items():
            with open(osp.join(save_subfolder, '{}.json'.format(value.name)), 'w') as outfile:
                json.dump(value.__dict__, outfile, ensure_ascii=False, indent=4)
            # cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

    ###################################################################################
    # STAGE REDS
    reds4_results = {"000": {}, "011": {}, "015": {}, "020": {}}
    data_mode = 'sharp_bicubic'
    N_model_default = 5

    for N_in in range(1, N_model_default + 1):
        for stage in range(1, 3):
            flip_test = False

            if data_mode == 'sharp_bicubic':
                if stage == 1:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
                else:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
            elif data_mode == 'blur_bicubic':
                if stage == 1:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
                else:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
            elif data_mode == 'blur':
                if stage == 1:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
                else:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
            elif data_mode == 'blur_comp':
                if stage == 1:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
                else:
                    model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
            else:
                raise NotImplementedError

            predeblur, HR_in = False, False
            back_RBs = 40
            if data_mode == 'blur_bicubic':
                predeblur = True
            if data_mode == 'blur' or data_mode == 'blur_comp':
                predeblur, HR_in = True, True
            if stage == 2:
                HR_in = True
                back_RBs = 20

            if stage == 1:
                test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
            else:
                test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
                print('You should modify the test_dataset_folder path for stage 2')
            GT_dataset_folder = '../datasets/REDS4/GT'

            raw_model = EDVR_arch.EDVR(128, N_model_default, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
            model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

            crop_border = 0
            border_frame = N_in // 2  # border frames when evaluating
            # temporal padding mode
            if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
                padding = 'new_info'
            else:
                padding = 'replicate'
            save_imgs = True

            data_mode_t = copy.deepcopy(data_mode)
            if stage == 1 and data_mode_t != 'Vid4':
                data_mode = 'REDS-EDVR_REDS_SR_L_flipx4'
            save_folder = '../results/{}'.format(data_mode)
            data_mode = copy.deepcopy(data_mode_t)
            util.mkdirs(save_folder)
            util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)

            aposterior_GT_dataset_folder = '../datasets/REDS4/GT_5'

            raw_model.load_state_dict(torch.load(model_path), strict=True)

            #### change the model so it can work with fewer input frames
            model.nf = raw_model.nf
            model.center = N_in // 2  # if center is None else center
            model.is_predeblur = raw_model.is_predeblur
            model.HR_in = raw_model.HR_in
            model.w_TSA = raw_model.w_TSA

            if model.is_predeblur:
                model.pre_deblur = raw_model.pre_deblur  # Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
                model.conv_1x1 = raw_model.conv_1x1      # nn.Conv2d(nf, nf, 1, 1, bias=True)
            else:
                if model.HR_in:
                    model.conv_first_1 = raw_model.conv_first_1  # nn.Conv2d(3, nf, 3, 1, 1, bias=True)
                    model.conv_first_2 = raw_model.conv_first_2  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
                    model.conv_first_3 = raw_model.conv_first_3  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
                else:
                    model.conv_first = raw_model.conv_first  # nn.Conv2d(3, nf, 3, 1, 1, bias=True)
            model.feature_extraction = raw_model.feature_extraction  # arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
            model.fea_L2_conv1 = raw_model.fea_L2_conv1  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            model.fea_L2_conv2 = raw_model.fea_L2_conv2  # nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
            model.fea_L3_conv1 = raw_model.fea_L3_conv1  # nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            model.fea_L3_conv2 = raw_model.fea_L3_conv2  # nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
            model.pcd_align = raw_model.pcd_align  # PCD_Align(nf=nf, groups=groups)

            model.tsa_fusion.center = model.center
            model.tsa_fusion.tAtt_1 = raw_model.tsa_fusion.tAtt_1
            model.tsa_fusion.tAtt_2 = raw_model.tsa_fusion.tAtt_2
            model.tsa_fusion.fea_fusion = copy.deepcopy(raw_model.tsa_fusion.fea_fusion)
            model.tsa_fusion.fea_fusion.weight = copy.deepcopy(
                torch.nn.Parameter(raw_model.tsa_fusion.fea_fusion.weight[:, 0:N_in * 128, :, :]))
            model.tsa_fusion.sAtt_1 = copy.deepcopy(raw_model.tsa_fusion.sAtt_1)
            model.tsa_fusion.sAtt_1.weight = copy.deepcopy(
                torch.nn.Parameter(raw_model.tsa_fusion.sAtt_1.weight[:, 0:N_in * 128, :, :]))
            model.tsa_fusion.maxpool = raw_model.tsa_fusion.maxpool
            model.tsa_fusion.avgpool = raw_model.tsa_fusion.avgpool
            model.tsa_fusion.sAtt_2 = raw_model.tsa_fusion.sAtt_2
            model.tsa_fusion.sAtt_3 = raw_model.tsa_fusion.sAtt_3
            model.tsa_fusion.sAtt_4 = raw_model.tsa_fusion.sAtt_4
            model.tsa_fusion.sAtt_5 = raw_model.tsa_fusion.sAtt_5
            model.tsa_fusion.sAtt_L1 = raw_model.tsa_fusion.sAtt_L1
            model.tsa_fusion.sAtt_L2 = raw_model.tsa_fusion.sAtt_L2
            model.tsa_fusion.sAtt_L3 = raw_model.tsa_fusion.sAtt_L3
            model.tsa_fusion.sAtt_add_1 = raw_model.tsa_fusion.sAtt_add_1
            model.tsa_fusion.sAtt_add_2 = raw_model.tsa_fusion.sAtt_add_2
            model.tsa_fusion.lrelu = raw_model.tsa_fusion.lrelu

            model.recon_trunk = raw_model.recon_trunk
            model.upconv1 = raw_model.upconv1
            model.upconv2 = raw_model.upconv2
            model.pixel_shuffle = raw_model.pixel_shuffle
            model.HRconv = raw_model.HRconv
            model.conv_last = raw_model.conv_last
            model.lrelu = raw_model.lrelu
            #####################################################

            model.eval()
            model = model.to(device)

            subfolder_name_l = []
            subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
            subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
            subfolder_GT_a_l = sorted(glob.glob(osp.join(aposterior_GT_dataset_folder, "*")))

            # for each subfolder
            for subfolder, subfolder_GT, subfolder_GT_a in zip(subfolder_l, subfolder_GT_l,
                                                               subfolder_GT_a_l):
                subfolder_name = osp.basename(subfolder)
                subfolder_name_l.append(subfolder_name)
                save_subfolder = osp.join(save_folder, subfolder_name)

                img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
                max_idx = len(img_path_l)
                print("MAX_IDX: ", max_idx)
                print("SAVE FOLDER::::::", save_folder)
                if save_imgs:
                    util.mkdirs(save_subfolder)

                #### read LQ and GT images
                imgs_LQ = data_util.read_img_seq(subfolder)
                img_GT_l = []
                for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
                    img_GT_l.append(data_util.read_img(None, img_GT_path))
                img_GT_a = []
                for img_GT_a_path in sorted(glob.glob(osp.join(subfolder_GT_a, '*'))):
                    img_GT_a.append(data_util.read_img(None, img_GT_a_path))

                # process each image
                for img_idx, img_path in enumerate(img_path_l):
                    img_name = osp.splitext(osp.basename(img_path))[0]
                    select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
                    imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

                    if flip_test:
                        output = util.flipx4_forward(model, imgs_in)
                    else:
                        print("IMGS_IN SHAPE: ", imgs_in.shape)
                        output = util.single_forward(model, imgs_in)
                    output = util.tensor2img(output.squeeze(0))

                    if save_imgs and stage == 1:
                        cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

                    # calculate PSNR / SSIM (for REDS, on RGB channels; stage-2 outputs only)
                    if stage == 2:
                        output = output / 255.
                        GT = np.copy(img_GT_l[img_idx])
                        GT_a = np.copy(img_GT_a[img_idx])
                        output_a = copy.deepcopy(output)

                        output, GT = util.crop_border([output, GT], crop_border)
                        crt_psnr = util.calculate_psnr(output * 255, GT * 255)
                        crt_ssim = util.calculate_ssim(output * 255, GT * 255)
                        output_a, GT_a = util.crop_border([output_a, GT_a], crop_border)
                        crt_aposterior = util.calculate_ssim(output_a * 255, GT_a * 255)

                        t = reds4_results[subfolder_name].get(str(img_name))
                        if t is not None:
                            reds4_results[subfolder_name][img_name].add_psnr(crt_psnr)
                            reds4_results[subfolder_name][img_name].add_gt_ssim(crt_ssim)
                            reds4_results[subfolder_name][img_name].add_aposterior_ssim(crt_aposterior)
                        else:
                            reds4_results[subfolder_name].update({img_name: metrics_file(img_name)})
                            reds4_results[subfolder_name][img_name].add_psnr(crt_psnr)
                            reds4_results[subfolder_name][img_name].add_gt_ssim(crt_ssim)
                            reds4_results[subfolder_name][img_name].add_aposterior_ssim(crt_aposterior)

    #### writing REDS4 results
    util.mkdirs('../results/000')
    util.mkdirs('../results/011')
    util.mkdirs('../results/015')
    util.mkdirs('../results/020')
    save_folder = '../results/'
    for i, dir_name in enumerate(["000", "011", "015", "020"]):
        save_subfolder = osp.join(save_folder, dir_name)
        for j, value in reds4_results[dir_name].items():
            with open(osp.join(save_subfolder, '{}.json'.format(value.name)), 'w') as outfile:
                json.dump(value.__dict__, outfile, ensure_ascii=False, indent=4)
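# `metrics_file` is referenced above but not defined in this file. From its
# usage (a `.name` field, add_psnr / add_gt_ssim / add_aposterior_ssim calls,
# and `__dict__` dumped to JSON), a minimal compatible sketch would be:
class metrics_file(object):
    def __init__(self, name):
        self.name = name
        self.psnr = []
        self.gt_ssim = []
        self.aposterior_ssim = []

    def add_psnr(self, v):
        self.psnr.append(v)

    def add_gt_ssim(self, v):
        self.gt_ssim.append(v)

    def add_aposterior_ssim(self, v):
        self.aposterior_ssim.append(v)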
# Additional imports for the training entry point below (assumed per the usual
# EDVR/mmsr layout; init_dist, the torch.distributed initializer, is defined
# elsewhere in the repo's training utilities):
import math
import random

from data import create_dataloader, create_dataset
from data.data_sampler import DistIterSampler
from models import create_model


def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)
    # pdb.set_trace()

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist(args.launcher)
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if it exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all processes load into the default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(opt['path']['experiments_root'])  # rename experiment folder if it exists
            util.mkdirs((path for key, path in opt['path'].items()
                         if not key == 'experiments_root' and 'pretrain_model' not in key
                         and 'resume' not in key))

        # config loggers; before this, logging does not work
        util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'],
                          level=logging.INFO, screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1+
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info('You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloaders
    dataset_ratio = 1000  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
                total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate
            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            #### validation
            if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
                if rank <= 0:  # does not support multi-GPU validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['SR'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # save SR images for reference
                        save_img_path = os.path.join(img_dir,
                                                     '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        pbar.update('Test {}'.format(img_name))

                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)
            if current_step % 50000 == 0:
                torch.cuda.empty_cache()
        torch.cuda.empty_cache()

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
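# Entry point (assumed; this excerpt does not show the original guard). Typical
# invocations, with an illustrative YAML path:
#   single GPU:  python train.py -opt options/train/train_EDVR_M.yml --launcher none
#   distributed: python -m torch.distributed.launch --nproc_per_node=8 train.py \
#                    -opt options/train/train_EDVR_M.yml --launcher pytorch
if __name__ == '__main__':
    main()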
def main(): ################# # configurations ################# device = torch.device('cuda') os.environ['CUDA_VISIBLE_DEVICES'] = '0' data_mode = 'SDR_4bit' stage = 1 # 1 or 2, use two stage strategy for REDS dataset. flip_test = False ############################################################################ #### model if data_mode == 'SDR_4bit': if stage == 1: model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth' else: model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth' else: raise NotImplementedError # use N_in images to restore one high bitdepth image N_in = 5 # predeblur: predeblur for blurry input # HR_in: downsample high resolution input predeblur, HR_in = False, False back_RBs = 40 predeblur = True HR_in = True if data_mode == 'SDR_4bit': # predeblur, HR_in = False, True pass if stage == 2: HR_in = True back_RBs = 20 # EDVR(num_feature_map, num_input_frames, deformable_groups?, front_RBs, # back_RBs, predeblur, HR_in) model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in) #### dataset if stage == 1: test_dataset_folder = '../datasets/{}'.format(data_mode) else: test_dataset_folder = '../' print('You should modify the test_dataset_folder path for stage 2') GT_dataset_folder = '../datasets/SDR_10bit/' #### evaluation crop_border = 0 border_frame = N_in // 2 # border frames when evaluate # temporal padding mode padding = 'replicate' save_imgs = True save_folder = '../results/{}'.format(data_mode) util.mkdirs(save_folder) util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') #### log info logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test)) #### set up the models model.load_state_dict(torch.load(model_path), strict=True) model.eval() model = model.to(device) avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] subfolder_name_l = [] subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*'))) subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*'))) # for each subfolder for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l): subfolder_name = osp.basename(subfolder) subfolder_name_l.append(subfolder_name) save_subfolder = osp.join(save_folder, subfolder_name) img_path_l = sorted(glob.glob(osp.join(subfolder, '*'))) max_idx = len(img_path_l) if save_imgs: util.mkdirs(save_subfolder) #### read LBD and GT images #### resize to avoid cuda out of memory, 2160x3840->720x1280 imgs_LBD = data_util.read_img_seq(subfolder, scale=65535., zoomout=(1280, 720)) img_GT_l = [] for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))): img_GT_l.append( data_util.read_img(None, img_GT_path, scale=65535., zoomout=True)) avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0 # process each image for img_idx, img_path in enumerate(img_path_l): img_name = osp.splitext(osp.basename(img_path))[0] # generate frame index select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding) imgs_in = imgs_LBD.index_select( 0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) if flip_test: # self ensemble with fipping input at four different directions output = util.flipx4_forward(model, imgs_in) else: output = util.single_forward(model, imgs_in) output = 
util.tensor2img(output.squeeze(0), out_type=np.uint16) if save_imgs: cv2.imwrite( osp.join(save_subfolder, '{}.png'.format(img_name)), output) # calculate PSNR # output = output / 255. output = output / 65535. GT = np.copy(img_GT_l[img_idx]) # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel if data_mode == 'Vid4': # bgr2y, [0, 1] GT = data_util.bgr2ycbcr(GT, only_y=True) output = data_util.bgr2ycbcr(output, only_y=True) output, GT = util.crop_border([output, GT], crop_border) crt_psnr = util.calculate_psnr(output * 65535, GT * 65535) logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format( img_idx + 1, img_name, crt_psnr)) if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames avg_psnr_center += crt_psnr N_center += 1 else: # border frames avg_psnr_border += crt_psnr N_border += 1 avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border) avg_psnr_center = avg_psnr_center / N_center avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border avg_psnr_l.append(avg_psnr) avg_psnr_center_l.append(avg_psnr_center) avg_psnr_border_l.append(avg_psnr_border) logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; ' 'Center PSNR: {:.6f} dB for {} frames; ' 'Border PSNR: {:.6f} dB for {} frames.'.format( subfolder_name, avg_psnr, (N_center + N_border), avg_psnr_center, N_center, avg_psnr_border, N_border)) logger.info('################ Tidy Outputs ################') for subfolder_name, psnr, psnr_center, psnr_border in zip( subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l): logger.info('Folder {} - Average PSNR: {:.6f} dB. ' 'Center PSNR: {:.6f} dB. ' 'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border)) logger.info('################ Final Results ################') logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test)) logger.info('Total Average PSNR: {:.6f} dB for {} clips. ' 'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format( sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l), sum(avg_psnr_center_l) / len(avg_psnr_center_l), sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
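# ---------------------------------------------------------------------------
# This script evaluates 4-bit -> 10-bit restoration, so images are read and
# written at 16-bit precision. The scale= and zoomout= arguments to
# data_util.read_img_seq above are extensions to the stock utilities; their
# assumed behavior is sketched below (signatures are hypothetical):
import cv2
import numpy as np

def read_img_16bit(path, zoomout=(1280, 720)):
    # Load a 16-bit PNG unchanged, normalize to [0, 1], and downscale
    # 3840x2160 inputs to 1280x720 to avoid CUDA out-of-memory.
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 65535.
    return cv2.resize(img, zoomout, interpolation=cv2.INTER_AREA) if zoomout else img

def save_img_16bit(img01, path):
    # Quantize a [0, 1] float image to uint16 before writing, mirroring
    # tensor2img(..., out_type=np.uint16) followed by cv2.imwrite above.
    cv2.imwrite(path, (img01 * 65535.).round().astype(np.uint16))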
def main(): ################# # configurations ################# flip_test = False scale = 4 N_in = 7 predeblur, HR_in = False, False n_feats = 128 back_RBs = 40 save_imgs = False prog = argparse.ArgumentParser() prog.add_argument('--train_mode', '-t', type=str, default='REDS', help='train mode') prog.add_argument('--data_mode', '-m', type=str, default=None, help='data_mode') prog.add_argument('--degradation_mode', '-d', type=str, default='impulse', choices=('impulse', 'bicubic', 'preset'), help='path to image output directory.') prog.add_argument('--sigma_x', '-sx', type=float, default=1, help='sigma_x') prog.add_argument('--sigma_y', '-sy', type=float, default=0, help='sigma_y') prog.add_argument('--theta', '-th', type=float, default=0, help='theta') prog.add_argument('--model', type=str, default=None, help='name for subdirectory') args = prog.parse_args() train_data_mode = args.train_mode data_mode = args.data_mode if data_mode is None: if train_data_mode == 'Vimeo': data_mode = 'Vid4' elif train_data_mode == 'REDS': data_mode = 'REDS' degradation_mode = args.degradation_mode # impulse | bicubic | preset sig_x, sig_y, the = args.sigma_x, args.sigma_y, args.theta if sig_y == 0: sig_y = sig_x folder_subname = 'preset' if degradation_mode == 'preset' else degradation_mode + '_' + str( '{:.1f}'.format(sig_x)) + '_' + str('{:.1f}'.format(sig_y)) + '_' + str('{:.1f}'.format(the)) #### dataset if data_mode == 'Vid4': test_dataset_folder = '../dataset/Vid4/LR_{}/X1_KG_ZSSR'.format(folder_subname) #test_dataset_folder = '../dataset/Vid4/LR_{}/X1_CF_DBPN'.format(folder_subname) GT_dataset_folder = '../dataset/Vid4/HR' elif data_mode == 'MM522': test_dataset_folder = '../dataset/MM522val/LR_bicubic/X1_KG_ZSSR' GT_dataset_folder = '../dataset/MM522val/HR' else: # test_dataset_folder = '../dataset/REDS4/LR_bicubic/X{}'.format(scale) test_dataset_folder = '../dataset/REDS/train/LR_{}/X1_KG_ZSSR'.format(folder_subname) #test_dataset_folder = '../dataset/REDS/train/LR_{}/X1_CF_DBPN'.format(folder_subname) GT_dataset_folder = '../dataset/REDS/train/HR' #### evaluation crop_border = 0 border_frame = 0 # N_in // 2 # border frames when evaluate # temporal padding mode padding = 'new_info' save_folder = '../results/{}'.format(data_mode) util.mkdirs(save_folder) util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') #### log info logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test)) avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], [] subfolder_name_l = [] subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*'))) subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*'))) if data_mode == 'REDS': subfolder_GT_l = [k for k in subfolder_GT_l if k.find('000') >= 0 or k.find('011') >= 0 or k.find('015') >= 0 or k.find('020') >= 0] # for each subfolder for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l): subfolder_name = osp.basename(subfolder) subfolder_name_l.append(subfolder_name) save_subfolder = osp.join(save_folder, subfolder_name) img_path_l = sorted(glob.glob(osp.join(subfolder_GT, '*5.png'))) ## *5.png max_idx = len(img_path_l) if save_imgs: util.mkdirs(save_subfolder) #### read LQ and GT images img_LQ_l = [] img_GT_l = [] for img_LQ_path in sorted(glob.glob(osp.join(subfolder, 
'*5.png.png'))): img_LQ_l.append(data_util.read_img(None, img_LQ_path)) for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*5.png'))): ### *5.png img_GT_l.append(data_util.read_img(None, img_GT_path)) avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0 avg_ssim, avg_ssim_border, avg_ssim_center = 0, 0, 0 # process each image: the "LQ" frames here are already-restored outputs # (e.g. ZSSR/DBPN), so they are scored against GT directly for img_idx, img_path in enumerate(img_path_l): img_name = osp.splitext(osp.basename(img_path))[0] output = np.copy(img_LQ_l[img_idx]) GT = np.copy(img_GT_l[img_idx]) ''' output_tensor = torch.from_numpy(np.copy(output[:,:,::-1])).permute(2,0,1) GT_tensor = torch.from_numpy(np.copy(GT[:,:,::-1])).permute(2,0,1).type_as(output_tensor) torch.save(output_tensor.cpu(), '../results/sr_test.pt') torch.save(GT_tensor.cpu(), '../results/hr_test.pt') my_psnr = utility.calc_psnr(output_tensor, GT_tensor) GT_tensor = GT_tensor.cpu().numpy().transpose(1,2,0) imageio.imwrite('../results/hr_test.png', GT_tensor) print('saved', my_psnr) ''' ''' # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel if data_mode in ('Vid4', 'sharp_bicubic', 'MM522'): # bgr2y, [0, 1] GT = data_util.bgr2ycbcr(GT, only_y=True) output = data_util.bgr2ycbcr(output, only_y=True) ''' # calculate PSNR output = (output * 255).round().astype('uint8') GT = (GT * 255).round().astype('uint8') # save the restored frame once it is defined and quantized if save_imgs: cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output) output, GT = util.crop_border([output, GT], crop_border) crt_psnr = util.calculate_psnr(output, GT) crt_ssim = util.calculate_ssim(output, GT) # logger.info('{:3d} - {:16} \tPSNR: {:.6f} dB \tSSIM: {:.6f}'.format(img_idx + 1, img_name, crt_psnr, crt_ssim)) if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames avg_psnr_center += crt_psnr avg_ssim_center += crt_ssim N_center += 1 else: # border frames avg_psnr_border += crt_psnr avg_ssim_border += crt_ssim N_border += 1 avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border) avg_psnr_center = avg_psnr_center / N_center avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border avg_psnr_l.append(avg_psnr) avg_psnr_center_l.append(avg_psnr_center) avg_psnr_border_l.append(avg_psnr_border) avg_ssim = (avg_ssim_center + avg_ssim_border) / (N_center + N_border) avg_ssim_center = avg_ssim_center / N_center avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border avg_ssim_l.append(avg_ssim) avg_ssim_center_l.append(avg_ssim_center) avg_ssim_border_l.append(avg_ssim_border) logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; ' 'Center PSNR: {:.6f} dB for {} frames; ' 'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr, (N_center + N_border), avg_psnr_center, N_center, avg_psnr_border, N_border)) logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; ' 'Center SSIM: {:.6f} for {} frames; ' 'Border SSIM: {:.6f} for {} frames.'.format(subfolder_name, avg_ssim, (N_center + N_border), avg_ssim_center, N_center, avg_ssim_border, N_border)) ''' logger.info('################ Tidy Outputs ################') for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l): logger.info('Folder {} - Average PSNR: {:.6f} dB. ' 'Center PSNR: {:.6f} dB. 
' 'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border)) for subfolder_name, ssim, ssim_center, ssim_border in zip(subfolder_name_l, avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l): logger.info('Folder {} - Average SSIM: {:.6f}. ' 'Center SSIM: {:.6f}. ' 'Border SSIM: {:.6f}.'.format(subfolder_name, ssim, ssim_center, ssim_border)) ''' logger.info('################ Final Results ################') logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test)) logger.info('Total Average PSNR: {:.6f} dB for {} clips. ' 'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format( sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l), sum(avg_psnr_center_l) / len(avg_psnr_center_l), sum(avg_psnr_border_l) / len(avg_psnr_border_l))) logger.info('Total Average SSIM: {:.6f} for {} clips. ' 'Center SSIM: {:.6f}. Border PSNR: {:.6f}.'.format( sum(avg_ssim_l) / len(avg_ssim_l), len(subfolder_l), sum(avg_ssim_center_l) / len(avg_ssim_center_l), sum(avg_ssim_border_l) / len(avg_ssim_border_l))) print('\n\n\n')
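# ---------------------------------------------------------------------------
# util.calculate_ssim is called above but not shown. A compact sketch of the
# standard Wang et al. SSIM it is assumed to implement (11x11 Gaussian window,
# sigma 1.5, over uint8 inputs); the reference implementation uses a valid
# convolution that crops the window border, so values here are approximate:
import cv2
import numpy as np

def ssim_sketch(img1, img2):
    C1, C2 = (0.01 * 255) ** 2, (0.03 * 255) ** 2
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    blur = lambda x: cv2.GaussianBlur(x, (11, 11), 1.5)
    mu1, mu2 = blur(img1), blur(img2)
    mu1_sq, mu2_sq, mu1_mu2 = mu1 * mu1, mu2 * mu2, mu1 * mu2
    sigma1_sq = blur(img1 * img1) - mu1_sq
    sigma2_sq = blur(img2 * img2) - mu2_sq
    sigma12 = blur(img1 * img2) - mu1_mu2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
               ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()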
def do_step(self, train_data): if self._profile: # time since the previous step finished approximates the dataloader # fetch time; getattr guards the first step, where no timestamp exists print("Data fetch: %f" % (time() - getattr(self, '_profile_ts', time()))) _t = time() opt = self.opt self.current_step += 1 #### update learning rate self.model.update_learning_rate( self.current_step, warmup_iter=opt['train']['warmup_iter']) #### training if self._profile: print("Update LR: %f" % (time() - _t)) _t = time() self.model.feed_data(train_data, self.current_step) self.model.optimize_parameters(self.current_step) if self._profile: print("Model feed + step: %f" % (time() - _t)) self._profile_ts = time() _t = time() #### log if self.current_step % opt['logger'][ 'print_freq'] == 0 and self.rank <= 0: logs = self.model.get_current_log(self.current_step) message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format( self.epoch, self.current_step) for v in self.model.get_current_learning_rate(): message += '{:.3e},'.format(v) message += ')] ' for k, v in logs.items(): if 'histogram' in k: self.tb_logger.add_histogram(k, v, self.current_step) elif isinstance(v, dict): self.tb_logger.add_scalars(k, v, self.current_step) else: message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: self.tb_logger.add_scalar(k, v, self.current_step) if opt['wandb'] and self.rank <= 0: import wandb wandb.log(logs) self.logger.info(message) #### save models and training states if self.current_step % opt['logger']['save_checkpoint_freq'] == 0: if self.rank <= 0: self.logger.info('Saving models and training states.') self.model.save(self.current_step) self.model.save_training_state(self.epoch, self.current_step) if 'alt_path' in opt['path'].keys(): import shutil print("Synchronizing tb_logger to alt_path..") alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger") shutil.rmtree(alt_tblogger, ignore_errors=True) shutil.copytree(self.tb_logger_path, alt_tblogger) #### validation if opt['datasets'].get( 'val', None) and self.current_step % opt['train']['val_freq'] == 0: if opt['model'] in [ 'sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer' ] and self.rank <= 0: # image restoration validation avg_psnr = 0. avg_fea_loss = 0. 
idx = 0 val_tqdm = tqdm(self.val_loader) for val_data in val_tqdm: idx += 1 for b in range(len(val_data['HQ_path'])): img_name = os.path.splitext( os.path.basename(val_data['HQ_path'][b]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir) self.model.feed_data(val_data, self.current_step) self.model.test() visuals = self.model.get_current_visuals() if visuals is None: continue sr_img = util.tensor2img(visuals['rlt'][b]) # uint8 # calculate PSNR if self.val_compute_psnr: gt_img = util.tensor2img(visuals['hq'][b]) # uint8 sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) avg_psnr += util.calculate_psnr(sr_img, gt_img) # Save SR images for reference img_base_name = '{:s}_{:d}.png'.format( img_name, self.current_step) save_img_path = os.path.join(img_dir, img_base_name) util.save_img(sr_img, save_img_path) avg_psnr = avg_psnr / idx avg_fea_loss = avg_fea_loss / idx # log self.logger.info( '# Validation # PSNR: {:.4e} Fea: {:.4e}'.format( avg_psnr, avg_fea_loss)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt[ 'name'] and self.rank <= 0: self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step) self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step) if len(self.evaluators ) != 0 and self.current_step % opt['train']['val_freq'] == 0: eval_dict = {} for eval in self.evaluators: if eval.uses_all_ddp or self.rank <= 0: eval_dict.update(eval.perform_eval()) if self.rank <= 0: print("Evaluator results: ", eval_dict) for ek, ev in eval_dict.items(): self.tb_logger.add_scalar(ek, ev, self.current_step) if opt['wandb']: import wandb wandb.log(eval_dict)
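# ---------------------------------------------------------------------------
# update_learning_rate is called once per step with a warmup_iter option.
# The scheduler itself lives in the model's base class and is not shown here;
# a minimal sketch of the linear warm-up it presumably applies:
def warmup_lr(base_lr, cur_step, warmup_iter):
    # Ramp linearly from 0 to the scheduled LR over warmup_iter steps,
    # then hand control back to the regular scheduler.
    if 0 < cur_step < warmup_iter:
        return base_lr * cur_step / warmup_iter
    return base_lr

# e.g. base_lr=4e-4, warmup_iter=2000: step 500 -> 1e-4, step >= 2000 -> 4e-4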
def main(): ################# # configurations ################# device = torch.device('cuda') os.environ['CUDA_VISIBLE_DEVICES'] = '1' stage = 1 # 1 or 2, use two stage strategy for REDS dataset. flip_test = False #### model data_mode = 'sharp_bicubic' if stage == 1: model_path = '../experiments/001_EDVRwoTSA_scratch_lr4e-4_600k_SR4K_LrCAR4S/models/200000_G.pth' else: model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth' N_in = 3 # use N_in images to restore one HR image predeblur, HR_in = False, False back_RBs = 10 if stage == 2: HR_in = True back_RBs = 20 model = EDVR_arch.EDVR(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in, w_TSA=False) #### dataset val_txt = '/home/mcc/4khdr/val.txt' if stage == 1: test_dataset_folder = '/home/mcc/4khdr/image/540p' GT_dataset_folder = '/home/mcc/4khdr/image/4k' else: test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4' print('You should modify the test_dataset_folder path for stage 2') #### evaluation crop_border = 0 border_frame = N_in // 2 # border frames when evaluate # temporal padding mode if data_mode == 'sharp_bicubic': padding = 'new_info' else: padding = 'replicate' save_imgs = False save_folder = '../results/{}'.format(data_mode) util.mkdirs(save_folder) util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') #### log info logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test)) #### set up the models model.load_state_dict(torch.load(model_path), strict=True) model.eval() model = model.to(device) avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] subfolder_name_l = [] with open(val_txt, 'r') as f: split_name_l = [] name_l = [x.strip() for x in f.readlines()] # for name in name_l: # for i in range(4): # split_name_l.append(name + 'x{}'.format(i)) subfolder_l = sorted( [osp.join(test_dataset_folder, name) for name in name_l]) subfolder_GT_l = sorted( [osp.join(GT_dataset_folder, name) for name in name_l]) # for each subfolder for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l): subfolder_name = osp.basename(subfolder) subfolder_name_l.append(subfolder_name) save_subfolder = osp.join(save_folder, subfolder_name) img_path_l = sorted(glob.glob(osp.join(subfolder, '*'))) max_idx = len(img_path_l) if save_imgs: util.mkdirs(save_subfolder) #### read LQ and GT images imgs_LQ = data_util.read_img_seq(subfolder) img_GT_l = [] for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))): img_GT_l.append(data_util.read_img(None, img_GT_path)) avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0 # process each image for img_idx, img_path in enumerate(img_path_l): img_name = osp.splitext(osp.basename(img_path))[0] select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding) imgs_in = imgs_LQ.index_select( 0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) if flip_test: output = util.flipx4_forward(model, imgs_in) else: output = util.single_forward(model, imgs_in) output = util.tensor2img(output.squeeze(0)) if save_imgs: cv2.imwrite( osp.join(save_subfolder, '{}.png'.format(img_name)), output) # calculate PSNR output = output / 255. 
GT = np.copy(img_GT_l[img_idx]) # evaluate on RGB channels output, GT = util.crop_border([output, GT], crop_border) crt_psnr = util.calculate_psnr(output * 255, GT * 255) sys.stdout.write('\r' + '{:03d}/{:03d}'.format(img_idx, len(img_path_l))) sys.stdout.flush() if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames avg_psnr_center += crt_psnr N_center += 1 else: # border frames avg_psnr_border += crt_psnr N_border += 1 avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border) avg_psnr_center = avg_psnr_center / N_center avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border avg_psnr_l.append(avg_psnr) avg_psnr_center_l.append(avg_psnr_center) avg_psnr_border_l.append(avg_psnr_border) logger.info('Folder {} - frames {} - Average PSNR: {:.6f} dB; ' 'Center PSNR: {:.6f} dB; ' 'Border PSNR: {:.6f} dB.'.format(subfolder_name, (N_center + N_border), avg_psnr, avg_psnr_center, avg_psnr_border)) break logger.info('################ Tidy Outputs ################') for i in range(len(subfolder_name_l)): logger.info( 'Folder {} - Average PSNR: {:.6f} dB, Center PSNR: {:.6f} dB, Border PSNR: {:.6f} dB. ' .format(subfolder_name_l[i], avg_psnr_l[i], avg_psnr_center_l[i], avg_psnr_border_l[i])) logger.info('################ Final Results ################') logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test)) logger.info( 'Total clips {}. Average PSNR: {:.6f} dB, Center PSNR: {:.6f} dB, Border PSNR: {:.6f} dB.' 'Score: {:.6f}'.format(len(subfolder_l), sum(avg_psnr_l) / len(avg_psnr_l), sum(avg_psnr_center_l) / len(avg_psnr_center_l), sum(avg_psnr_border_l) / len(avg_psnr_border_l), sum(avg_psnr_l) / len(avg_psnr_l) / 50))
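# ---------------------------------------------------------------------------
# The center/border split used above charges the first and last border_frame
# frames of each clip to a separate average, since their temporal windows rely
# on padded neighbors. A worked example with this script's settings:
N_in = 3
border_frame = N_in // 2   # = 1
max_idx = 5                # a 5-frame clip
center = [i for i in range(max_idx) if border_frame <= i < max_idx - border_frame]
border = [i for i in range(max_idx) if i not in center]
print(center, border)      # [1, 2, 3] [0, 4]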
def main(): ################# # configurations ################# parser = argparse.ArgumentParser() parser.add_argument("--input_path", type=str, required=True) parser.add_argument("--gt_path", type=str, required=True) parser.add_argument("--output_path", type=str, required=True) parser.add_argument("--model_path", type=str, required=True) parser.add_argument("--gpu_id", type=str, required=True) #parser.add_argument("--screen_notation", type=str, required=True) parser.add_argument('--opt', type=str, required=True, help='Path to option YAML file.') args = parser.parse_args() opt = option.parse(args.opt, is_train=False) PAD = 32 total_run_time = AverageMeter() print("GPU ", torch.cuda.device_count()) os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) device = torch.device('cuda') data_mode = 'sharp_bicubic' flip_test = False Input_folder = args.input_path GT_folder = args.gt_path Result_folder = args.output_path Model_path = args.model_path # create results folder if not os.path.exists(Result_folder): os.makedirs(Result_folder, exist_ok=True) model_path = Model_path N_in = 5 model = EDVR_arch.EDVR(nf=opt['network_G']['nf'], nframes=opt['network_G']['nframes'], groups=opt['network_G']['groups'], front_RBs=opt['network_G']['front_RBs'], back_RBs=opt['network_G']['back_RBs'], predeblur=opt['network_G']['predeblur'], HR_in=opt['network_G']['HR_in'], w_TSA=opt['network_G']['w_TSA']) #### dataset test_dataset_folder = Input_folder GT_dataset_folder = GT_folder #### evaluation crop_border = 0 border_frame = N_in // 2 # border frames when evaluate # temporal padding mode padding = 'new_info' save_imgs = True save_folder = os.path.join(Result_folder, data_mode) util.mkdirs(save_folder) util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') #### log info logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test)) #### set up the models model.load_state_dict(torch.load(model_path), strict=True) model.eval() model = model.to(device) avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] avg_rgb_psnr_l, avg_rgb_psnr_center_l, avg_rgb_psnr_border_l = [], [], [] subfolder_name_l = [] subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*'))) subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*'))) end = time.time() for subfolder in subfolder_l: input_subfolder = os.path.split(subfolder)[1] subfolder_GT = os.path.join(GT_dataset_folder, input_subfolder) if not os.path.exists(subfolder_GT): continue print("Evaluate Folders: ", input_subfolder) subfolder_name = osp.basename(subfolder) subfolder_name_l.append(subfolder_name) save_subfolder = osp.join(save_folder, subfolder_name) img_path_l = sorted(glob.glob(osp.join(subfolder, '*'))) max_idx = len(img_path_l) if save_imgs: util.mkdirs(save_subfolder) #### read LQ and GT images, notice we load yuv img here imgs_LQ = data_util.read_img_seq_yuv(subfolder) # Num x 3 x H x W img_GT_l = [] for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))): img_GT_l.append(data_util.read_img_yuv(None, img_GT_path)) avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0 avg_rgb_psnr, avg_rgb_psnr_border, avg_rgb_psnr_center = 0, 0, 0 # process each image for img_idx, img_path in enumerate(img_path_l): img_name = 
osp.splitext(osp.basename(img_path))[0] select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding) imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) # 960 x 540 input; replication-pad it before feeding the model gtWidth = 3840 gtHeight = 2160 intWidth_ori = imgs_in.shape[4] # 960 intHeight_ori = imgs_in.shape[3] # 540 scale = 4 intPaddingRight = PAD intPaddingLeft = PAD intPaddingTop = PAD intPaddingBottom = PAD # pad by PAD (= 32) pixels on each side pader = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom]) imgs_in = torch.squeeze(imgs_in, 0) # N C H W imgs_in = pader(imgs_in) # N C 604 1024 X0 = imgs_in X0 = torch.unsqueeze(X0, 0) if flip_test: output = util.flipx4_forward(model, X0) else: output = util.single_forward(model, X0) # remove the padding from the 4x upscaled output output = output[0, :, intPaddingTop * scale:(intPaddingTop + intHeight_ori) * scale, intPaddingLeft * scale: (intPaddingLeft + intWidth_ori) * scale] output = util.tensor2img(output.squeeze(0)) print('current image process time: {:.3f}s'.format(time.time() - end)) total_run_time.update(time.time() - end, 1) end = time.time() # reset the per-frame timer # calculate PSNR on YUV y_all = output / 255. GT = np.copy(img_GT_l[img_idx]) y_all, GT = util.crop_border([y_all, GT], crop_border) crt_psnr = util.calculate_psnr(y_all * 255, GT * 255) logger.info('{:3d} - {:25} \tYUV_PSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr)) # also calculate PSNR on RGB y_all_rgb = data_util.ycbcr2rgb(output / 255.) GT_rgb = data_util.ycbcr2rgb(np.copy(img_GT_l[img_idx])) y_all_rgb, GT_rgb = util.crop_border([y_all_rgb, GT_rgb], crop_border) crt_rgb_psnr = util.calculate_psnr(y_all_rgb * 255, GT_rgb * 255) logger.info('{:3d} - {:25} \tRGB_PSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_rgb_psnr)) if save_imgs: im_out = np.round(y_all_rgb * 255.).astype(np.uint8) # ycbcr2rgb returned an RGB image, but cv2.imwrite expects BGR cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), cv2.cvtColor(im_out, cv2.COLOR_RGB2BGR)) # accumulate YUV and RGB PSNR, respectively if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames avg_psnr_center += crt_psnr avg_rgb_psnr_center += crt_rgb_psnr N_center += 1 else: # border frames avg_psnr_border += crt_psnr avg_rgb_psnr_border += crt_rgb_psnr N_border += 1 # for YUV avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border) avg_psnr_center = avg_psnr_center / N_center avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border avg_psnr_l.append(avg_psnr) avg_psnr_center_l.append(avg_psnr_center) avg_psnr_border_l.append(avg_psnr_border) logger.info('Folder {} - Average YUV PSNR: {:.6f} dB for {} frames; ' 'Center YUV PSNR: {:.6f} dB for {} frames; ' 'Border YUV PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr, (N_center + N_border), avg_psnr_center, N_center, avg_psnr_border, N_border)) # for RGB avg_rgb_psnr = (avg_rgb_psnr_center + avg_rgb_psnr_border) / (N_center + N_border) avg_rgb_psnr_center = avg_rgb_psnr_center / N_center avg_rgb_psnr_border = 0 if N_border == 0 else avg_rgb_psnr_border / N_border avg_rgb_psnr_l.append(avg_rgb_psnr) avg_rgb_psnr_center_l.append(avg_rgb_psnr_center) avg_rgb_psnr_border_l.append(avg_rgb_psnr_border) logger.info('Folder {} - Average RGB PSNR: {:.6f} dB for {} frames; ' 'Center RGB PSNR: {:.6f} dB for {} frames; ' 'Border RGB PSNR: {:.6f} dB for {} 
frames.'.format(subfolder_name, avg_rgb_psnr, (N_center + N_border), avg_rgb_psnr_center, N_center, avg_rgb_psnr_border, N_border)) logger.info('################ Tidy Outputs ################') # for YUV for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l): logger.info('Folder {} - Average YUV PSNR: {:.6f} dB. ' 'Center YUV PSNR: {:.6f} dB. ' 'Border YUV PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border)) # for RGB for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_rgb_psnr_l, avg_rgb_psnr_center_l, avg_rgb_psnr_border_l): logger.info('Folder {} - Average RGB PSNR: {:.6f} dB. ' 'Center RGB PSNR: {:.6f} dB. ' 'Border RGB PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border)) logger.info('################ Final Results ################') logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test)) logger.info('Total Average YUV PSNR: {:.6f} dB for {} clips. ' 'Center YUV PSNR: {:.6f} dB. Border YUV PSNR: {:.6f} dB.'.format( sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l), sum(avg_psnr_center_l) / len(avg_psnr_center_l), sum(avg_psnr_border_l) / len(avg_psnr_border_l))) logger.info('Total Average RGB PSNR: {:.6f} dB for {} clips. ' 'Center RGB PSNR: {:.6f} dB. Border RGB PSNR: {:.6f} dB.'.format( sum(avg_rgb_psnr_l) / len(avg_rgb_psnr_l), len(subfolder_l), sum(avg_rgb_psnr_center_l) / len(avg_rgb_psnr_center_l), sum(avg_rgb_psnr_border_l) / len(avg_rgb_psnr_border_l)))
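# ---------------------------------------------------------------------------
# The RGB metrics above go through data_util.ycbcr2rgb, which is not shown.
# A minimal sketch of the ITU-R BT.601 limited-range inverse conversion it is
# assumed to perform (float HWC input and output in [0, 1]; coefficients are
# the usual approximations):
import numpy as np

def ycbcr2rgb_sketch(img):
    y, cb, cr = img[..., 0] * 255., img[..., 1] * 255., img[..., 2] * 255.
    r = 1.164 * (y - 16) + 1.596 * (cr - 128)
    g = 1.164 * (y - 16) - 0.392 * (cb - 128) - 0.813 * (cr - 128)
    b = 1.164 * (y - 16) + 2.017 * (cb - 128)
    rgb = np.stack([r, g, b], axis=-1) / 255.
    return np.clip(rgb, 0., 1.)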
def main(): ################# # configurations ################# os.environ['CUDA_VISIBLE_DEVICES'] = '0' save_imgs = False prog = argparse.ArgumentParser() prog.add_argument('--train_mode', '-t', type=str, default='Vimeo', help='train mode') prog.add_argument('--data_mode', '-m', type=str, default=None, help='data_mode') prog.add_argument('--degradation_mode', '-d', type=str, default='impulse', choices=('impulse', 'bicubic', 'preset'), help='path to image output directory.') prog.add_argument('--sigma_x', '-sx', type=float, default=1, help='sigma_x') prog.add_argument('--sigma_y', '-sy', type=float, default=0, help='sigma_y') prog.add_argument('--theta', '-th', type=float, default=0, help='theta') args = prog.parse_args() train_mode = args.train_mode data_mode = args.data_mode if data_mode is None: if train_mode == 'Vimeo': data_mode = 'Vid4' elif train_mode == 'REDS': data_mode = 'REDS' degradation_mode = args.degradation_mode # impulse | bicubic | preset sig_x, sig_y, the = args.sigma_x, args.sigma_y, args.theta if sig_y == 0: sig_y = sig_x # Possible combinations: (2, 16), (3, 16), (4, 16), (4, 28), (4, 52) scale = 2 layer = 16 assert (scale, layer) in [(2, 16), (3, 16), (4, 16), (4, 28), (4, 52) ], 'Unrecognized (scale, layer) combination' # model N_in = 7 # model_path = '../experiments/pretrained_models/DUF_{}L_BLIND_{}_FT_report.pth'.format(layer, train_mode[0]) model_path = '../experiments/pretrained_models/DUF_{}L_{}_S{}.pth'.format( layer, train_mode, scale) # model_path = '../experiments/pretrained_models/DUF_x2_16L_official.pth' adapt_official = True # if 'official' in model_path else False DUF_downsampling = False # True | False if layer == 16: model = DUF_arch.DUF_16L(scale=scale, adapt_official=adapt_official) elif layer == 28: model = DUF_arch.DUF_28L(scale=scale, adapt_official=adapt_official) elif layer == 52: model = DUF_arch.DUF_52L(scale=scale, adapt_official=adapt_official) #### dataset folder_subname = 'preset' if degradation_mode == 'preset' else degradation_mode + '_' + str( '{:.1f}'.format(sig_x)) + '_' + str( '{:.1f}'.format(sig_y)) + '_' + str('{:.1f}'.format(the)) # folder_subname = degradation_mode + '_' + str('{:.1f}'.format(sig_x)) + '_' + str('{:.1f}'.format(sig_y)) + '_' + str('{:.1f}'.format(the)) if data_mode == 'Vid4': # test_dataset_folder = '../dataset/Vid4/LR_bicubic/X{}'.format(scale) test_dataset_folder = '../dataset/Vid4/LR_{}/X{}'.format( folder_subname, scale) GT_dataset_folder = '../dataset/Vid4/HR' elif data_mode == 'MM522': test_dataset_folder = '../dataset/MM522val/LR_bicubic/X{}'.format( scale) GT_dataset_folder = '../dataset/MM522val/HR' else: # test_dataset_folder = '../dataset/REDS4/LR_bicubic/X{}'.format(scale) test_dataset_folder = '../dataset/REDS/train/LR_{}/X{}'.format( folder_subname, scale) GT_dataset_folder = '../dataset/REDS/train/HR' #### evaluation crop_border = 0 border_frame = N_in // 2 # border frames when evaluate # temporal padding mode padding = 'new_info' # different from the official testing codes, which pads zeros. 
############################################################################ device = torch.device('cuda') save_folder = '../results/{}'.format(data_mode) util.mkdirs(save_folder) util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') #### log info logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) def read_image(img_path): '''read one image from img_path Return img: HWC, BGR, [0,1], numpy ''' img_GT = cv2.imread(img_path) img = img_GT.astype(np.float32) / 255. return img def read_seq_imgs(img_seq_path): '''read a sequence of images''' img_path_l = sorted(glob.glob(img_seq_path + '/*')) img_l = [read_image(v) for v in img_path_l] # stack to TCHW, RGB, [0,1], torch imgs = np.stack(img_l, axis=0) imgs = imgs[:, :, :, [2, 1, 0]] imgs = torch.from_numpy( np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() return imgs def index_generation(crt_i, max_n, N, padding='reflection'): ''' padding: replicate | reflection | new_info | circle ''' max_n = max_n - 1 n_pad = N // 2 return_l = [] for i in range(crt_i - n_pad, crt_i + n_pad + 1): if i < 0: if padding == 'replicate': add_idx = 0 elif padding == 'reflection': add_idx = -i elif padding == 'new_info': add_idx = (crt_i + n_pad) + (-i) elif padding == 'circle': add_idx = N + i else: raise ValueError('Wrong padding mode') elif i > max_n: if padding == 'replicate': add_idx = max_n elif padding == 'reflection': add_idx = max_n * 2 - i elif padding == 'new_info': add_idx = (crt_i - n_pad) - (i - max_n) elif padding == 'circle': add_idx = i - N else: raise ValueError('Wrong padding mode') else: add_idx = i return_l.append(add_idx) return return_l def single_forward(model, imgs_in): with torch.no_grad(): model_output = model(imgs_in) if isinstance(model_output, list) or isinstance( model_output, tuple): output = model_output[0] else: output = model_output return output sub_folder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*'))) sub_folder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*'))) if data_mode == 'REDS': sub_folder_GT_l = [ k for k in sub_folder_GT_l if k.find('000') >= 0 or k.find('011') >= 0 or k.find('015') >= 0 or k.find('020') >= 0 ] #### set up the models model.load_state_dict(torch.load(model_path), strict=True) model.eval() model = model.to(device) avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], [] subfolder_name_l = [] # for each sub-folder for sub_folder, sub_folder_GT in zip(sub_folder_l, sub_folder_GT_l): sub_folder_name = sub_folder.split('/')[-1] subfolder_name_l.append(sub_folder_name) save_sub_folder = osp.join(save_folder, sub_folder_name) img_path_l = sorted(glob.glob(sub_folder + '/*')) max_idx = len(img_path_l) if save_imgs: util.mkdirs(save_sub_folder) #### read LR images imgs = read_seq_imgs(sub_folder) #### read GT images img_GT_l = [] for img_GT_path in sorted(glob.glob(osp.join(sub_folder_GT, '*'))): img_GT_l.append(read_image(img_GT_path)) # When using the downsampling in DUF official code, we downsample the HR images if DUF_downsampling: sub_folder = sub_folder_GT img_path_l = sorted(glob.glob(sub_folder + '/*')) max_idx = len(img_path_l) imgs = read_seq_imgs(sub_folder) avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0 avg_ssim, avg_ssim_border, 
avg_ssim_center = 0, 0, 0 # process each image num_images = len(img_path_l) for img_idx, img_path in enumerate(img_path_l): img_name = osp.splitext(osp.basename(img_path))[0] c_idx = int(osp.splitext(osp.basename(img_path))[0]) select_idx = index_generation(c_idx, max_idx, N_in, padding=padding) # get input images imgs_in = imgs.index_select( 0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) # Downsample the HR images H, W = imgs_in.size(3), imgs_in.size(4) if DUF_downsampling: imgs_in = util.DUF_downsample(imgs_in, sigma=1.3, scale=scale) output = single_forward(model, imgs_in) # Crop to the original shape if scale == 3: pad_h = scale - (H % scale) pad_w = scale - (W % scale) if pad_h > 0: output = output[:, :, :-pad_h, :] if pad_w > 0: output = output[:, :, :, :-pad_w] output_f = output.data.float().cpu().squeeze(0) output = util.tensor2img(output_f) # save imgs if save_imgs: cv2.imwrite( osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output) #### calculate PSNR output = output / 255. GT = np.copy(img_GT_l[img_idx]) output = (output * 255).round().astype('uint8') GT = (GT * 255).round().astype('uint8') output, GT = util.crop_border([output, GT], crop_border) crt_psnr = util.calculate_psnr(output, GT) crt_ssim = util.calculate_ssim(output, GT) logger.info( '{:3d} - {:16} \tPSNR: {:.6f} dB \tSSIM: {:.6f}'.format( img_idx + 1, img_name, crt_psnr, crt_ssim)) if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames avg_psnr_center += crt_psnr avg_ssim_center += crt_ssim N_center += 1 else: # border frames avg_psnr_border += crt_psnr avg_ssim_border += crt_ssim N_border += 1 avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border) avg_psnr_center = avg_psnr_center / N_center avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border avg_psnr_l.append(avg_psnr) avg_psnr_center_l.append(avg_psnr_center) avg_psnr_border_l.append(avg_psnr_border) avg_ssim = (avg_ssim_center + avg_ssim_border) / (N_center + N_border) avg_ssim_center = avg_ssim_center / N_center avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border avg_ssim_l.append(avg_ssim) avg_ssim_center_l.append(avg_ssim_center) avg_ssim_border_l.append(avg_ssim_border) logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; ' 'Center PSNR: {:.6f} dB for {} frames; ' 'Border PSNR: {:.6f} dB for {} frames.'.format( sub_folder_name, avg_psnr, (N_center + N_border), avg_psnr_center, N_center, avg_psnr_border, N_border)) ''' logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; ' 'Center SSIM: {:.6f} for {} frames; ' 'Border SSIM: {:.6f} for {} frames.'.format(sub_folder_name, avg_ssim, (N_center + N_border), avg_ssim_center, N_center, avg_ssim_border, N_border)) ''' ''' logger.info('################ Tidy Outputs ################') for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l): logger.info('Folder {} - Average PSNR: {:.6f} dB. ' 'Center PSNR: {:.6f} dB. ' 'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border)) for subfolder_name, ssim, ssim_center, ssim_border in zip(subfolder_name_l, avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l): logger.info('Folder {} - Average SSIM: {:.6f}. ' 'Center SSIM: {:.6f}. 
' 'Border SSIM: {:.6f}.'.format(subfolder_name, ssim, ssim_center, ssim_border)) ''' logger.info('################ Final Results ################') logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Total Average PSNR: {:.6f} dB for {} clips. ' 'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format( sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l), sum(avg_psnr_center_l) / len(avg_psnr_center_l), sum(avg_psnr_border_l) / len(avg_psnr_border_l))) logger.info('Total Average SSIM: {:.6f} for {} clips. ' 'Center SSIM: {:.6f}. Border PSNR: {:.6f}.'.format( sum(avg_ssim_l) / len(avg_ssim_l), len(sub_folder_l), sum(avg_ssim_center_l) / len(avg_ssim_center_l), sum(avg_ssim_border_l) / len(avg_ssim_border_l))) print('\n\n\n')
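# ---------------------------------------------------------------------------
# For reference, the temporal windows the inline index_generation above
# produces at the start of a 30-frame clip with N = 7 (computed directly from
# its definition):
#
#   index_generation(0, 30, 7, padding='replicate')  -> [0, 0, 0, 0, 1, 2, 3]
#   index_generation(0, 30, 7, padding='reflection') -> [3, 2, 1, 0, 1, 2, 3]
#   index_generation(0, 30, 7, padding='new_info')   -> [6, 5, 4, 0, 1, 2, 3]
#
# 'new_info' substitutes further-future frames for the missing past ones
# instead of repeating frame 0, which is why this script prefers it over the
# official zero-padding.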
def main(): #### options parser = argparse.ArgumentParser() parser.add_argument('--opt', type=str, help='Path to option YAML file.') args = parser.parse_args() opt = option.parse(args.opt, is_train=True) #### loading resume state if exists if 'resume_latest' in opt and opt['resume_latest'] == True: if os.path.isdir(opt['path']['training_state']): name_state_files = os.listdir(opt['path']['training_state']) if len(name_state_files) > 0: latest_state_num = 0 for name_state_file in name_state_files: state_num = int(name_state_file.split('.')[0]) if state_num > latest_state_num: latest_state_num = state_num opt['path']['resume_state'] = os.path.join( opt['path']['training_state'], str(latest_state_num)+'.state') else: raise ValueError if opt['path'].get('resume_state', None): device_id = torch.cuda.current_device() resume_state = torch.load(opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options else: resume_state = None #### mkdir and loggers if resume_state is None: util.mkdir_and_rename( opt['path']['experiments_root']) # rename experiment folder if exists util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]'.format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'] + '_{}'.format(util.get_timestamp())) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### create train and val dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt['batch_size'])) total_iters = int(opt['train']['niter']) total_epochs = int(math.ceil(total_iters / train_size)) train_sampler = None train_loader = create_dataloader( train_set, dataset_opt, opt, train_sampler) logger.info('Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None #### create model model = create_model(opt) #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 #### training is_time = False logger.info('Start training from epoch: {:d}, iter: {:d}'.format( start_epoch, current_step)) if is_time: batch_time = AverageMeter('Time', ':6.3f') data_time = AverageMeter('Data', ':6.3f') for epoch in range(start_epoch, total_epochs + 1): if current_step > total_iters: break if is_time: torch.cuda.synchronize() end = time.time() for _, train_data in enumerate(train_loader): if 'adv_train' in opt: current_step += opt['adv_train']['m'] else: current_step += 1 if current_step > total_iters: break #### training model.feed_data(train_data) if is_time: torch.cuda.synchronize() data_time.update(time.time() - end) model.optimize_parameters(current_step) #### update learning rate model.update_learning_rate( current_step, warmup_iter=opt['train']['warmup_iter']) if is_time: torch.cuda.synchronize() batch_time.update(time.time() - end) #### log if current_step % opt['logger']['print_freq'] == 0: # FIXME remove debug debug = True if debug: torch.cuda.empty_cache() logs = model.get_current_log() message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format( epoch, current_step) for v in model.get_current_learning_rate(): message += '{:.3e},'.format(v) message += ')] ' for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar(k, v, current_step) logger.info(message) if is_time: logger.info(str(data_time)) logger.info(str(batch_time)) #### validation if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: if opt['model'] in ['sr', 'srgan']: # image restoration validation pbar = 
util.ProgressBar(len(val_loader)) avg_psnr = 0. idx = 0 for val_data in val_loader: idx += 1 img_name = os.path.splitext( os.path.basename(val_data['LQ_path'][0]))[0] img_dir = os.path.join( opt['path']['val_images'], img_name) util.mkdir(img_dir) model.feed_data(val_data) model.test() visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # Save SR images for reference save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) util.save_img(sr_img, save_img_path) # calculate PSNR sr_img, gt_img = util.crop_border( [sr_img, gt_img], opt['scale']) avg_psnr += util.calculate_psnr(sr_img, gt_img) pbar.update('Test {}'.format(img_name)) avg_psnr = avg_psnr / idx # log logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr', avg_psnr, current_step) #### save models and training states if current_step % opt['logger']['save_checkpoint_freq'] == 0: logger.info('Saving models and training states.') model.save(current_step) model.save_training_state(epoch, current_step) tb_logger.flush() if is_time: torch.cuda.synchronize() end = time.time() logger.info('Saving the final model.') model.save('latest') logger.info('End of training.') tb_logger.close()
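# ---------------------------------------------------------------------------
# The timing code above instantiates AverageMeter('Time', ':6.3f') and logs
# str(meter); the class is imported elsewhere. A minimal sketch in the style
# of the standard PyTorch ImageNet example it appears to follow:
class AverageMeter(object):
    """Tracks the latest value and a running average."""

    def __init__(self, name='', fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = 0.
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)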
def main(name_flag, input_path, gt_path, model_path, save_path, save_imgs, flip_test): device = torch.device('cuda') os.environ['CUDA_VISIBLE_DEVICES'] = '0' save_path = os.path.join(save_path, name_flag) #### model model = CNLRN_arch.CNLRN(n_colors=3, n_deblur_blocks=20, n_nlrgs_body=6, n_nlrgs_up1=2, n_nlrgs_up2=2, n_subgroups=2, n_rcabs=4, n_feats=64, nonlocal_psize=(4, 4, 4, 4), scale=4) model.load_state_dict(torch.load(model_path), strict=True) model.eval() model = model.to(device) #### logger util.mkdirs(save_path) util.setup_logger('base', save_path, name_flag, level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger.info('Evaluate: {}'.format(name_flag)) logger.info('Input images path: {}'.format(input_path)) logger.info('GT images path: {}'.format(gt_path)) logger.info('Model path: {}'.format(model_path)) logger.info('Results save path: {}'.format(save_path)) logger.info('Flip test: {}'.format(flip_test)) logger.info('Save images: {}'.format(save_imgs)) #### Evaluation total_psnr_l = [] total_ssim_l = [] img_path_l = sorted(glob.glob(osp.join(input_path, '*'))) #### read LQ and GT images imgs_LQ = data_util.read_img_seq(input_path) img_GT_l = [] for img_GT_path in sorted(glob.glob(osp.join(gt_path, '*'))): img_GT_l.append(data_util.read_img(None, img_GT_path)) # process each image for img_idx, img_path in enumerate(img_path_l): img_name = osp.splitext(osp.basename(img_path))[0] imgs_in = imgs_LQ[img_idx:img_idx + 1].to(device) if flip_test: output = util.flipx8_forward(model, imgs_in) else: output = util.single_forward(model, imgs_in) output = util.tensor2img(output.squeeze(0)) if save_imgs: cv2.imwrite(osp.join(save_path, '{}.png'.format(img_name)), output) # calculate PSNR output = output / 255. GT = np.copy(img_GT_l[img_idx]) output, GT = util.crop_border([output, GT], crop_border=4) crt_psnr = util.calculate_psnr(output * 255, GT * 255) crt_ssim = util.ssim(output * 255, GT * 255) total_psnr_l.append(crt_psnr) total_ssim_l.append(crt_ssim) logger.info('{} \tPSNR: {:.3f} \tSSIM: {:.4f}'.format( img_name, crt_psnr, crt_ssim)) logger.info('################ Final Results ################') logger.info('Evaluate: {}'.format(name_flag)) logger.info('Input images path: {}'.format(input_path)) logger.info('GT images path: {}'.format(gt_path)) logger.info('Model path: {}'.format(model_path)) logger.info('Results save path: {}'.format(save_path)) logger.info('Flip test: {}'.format(flip_test)) logger.info('Save images: {}'.format(save_imgs)) logger.info( 'Total Average PSNR: {:.3f} SSIM: {:.4f} for {} images.'.format( sum(total_psnr_l) / len(total_psnr_l), sum(total_ssim_l) / len(total_ssim_l), len(img_path_l)))
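# ---------------------------------------------------------------------------
# Several scripts here call util.flipx4_forward / util.flipx8_forward for
# self-ensemble testing. A sketch of the assumed flip ensemble: run the model
# on the input plus its horizontally, vertically, and doubly flipped variants,
# un-flip each output, and average. flipx8 presumably adds the four transposed
# (rotated) variants on top of these.
import torch

def flipx4_forward_sketch(model, inp):
    def forward(x):
        with torch.no_grad():
            out = model(x)
        return out[0] if isinstance(out, (list, tuple)) else out

    acc = forward(inp)
    for dims in [(-1,), (-2,), (-2, -1)]:
        acc = acc + torch.flip(forward(torch.flip(inp, dims)), dims)
    return acc / 4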
def main(): ################# # configurations ################# device = torch.device('cuda') os.environ['CUDA_VISIBLE_DEVICES'] = '1' flip_test = False scale = 4 N_in = 5 predeblur, HR_in = False, False n_feats = 128 back_RBs = 40 save_imgs = False prog = argparse.ArgumentParser() prog.add_argument('--train_mode', '-t', type=str, default='REDS', help='train mode') prog.add_argument('--data_mode', '-m', type=str, default=None, help='data_mode') prog.add_argument('--degradation_mode', '-d', type=str, default='impulse', choices=('impulse', 'bicubic', 'preset'), help='path to image output directory.') prog.add_argument('--sigma_x', '-sx', type=float, default=1, help='sigma_x') prog.add_argument('--sigma_y', '-sy', type=float, default=0, help='sigma_y') prog.add_argument('--theta', '-th', type=float, default=0, help='theta') args = prog.parse_args() train_data_mode = args.train_mode data_mode = args.data_mode if data_mode is None: if train_data_mode == 'Vimeo': data_mode = 'Vid4' elif train_data_mode == 'REDS': data_mode = 'REDS' degradation_mode = args.degradation_mode # impulse | bicubic | preset sig_x, sig_y, the = args.sigma_x, args.sigma_y, args.theta if sig_y == 0: sig_y = sig_x ############################################################################ #### model if scale == 2: if train_data_mode == 'Vimeo': model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_M_Scale2_FT.pth' #model_path = '../experiments/pretrained_models/EDVR_M_BLIND_V_FT_report.pth' # model_path = '../experiments/pretrained_models/2500_G.pth' elif train_data_mode == 'REDS': model_path = '../experiments/pretrained_models/EDVR_REDS_SR_M_Scale2.pth' # model_path = '../experiments/pretrained_models/EDVR_M_BLIND_R_FT_report.pth' elif train_data_mode == 'Both': model_path = '../experiments/pretrained_models/EDVR_REDS+Vimeo90K_SR_M_Scale2_FT.pth' elif train_data_mode == 'MM522': model_path = '../experiments/pretrained_models/EDVR_MM522_SR_M_Scale2_FT.pth' else: raise NotImplementedError else: if data_mode == 'Vid4': model_path = '../experiments/pretrained_models/EDVR_BLIND_Vimeo_SR_L.pth' # model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth' elif data_mode == 'REDS': model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth' # model_path = '../experiments/pretrained_models/EDVR_BLIND_REDS_SR_L.pth' else: raise NotImplementedError model = EDVR_arch.EDVR(n_feats, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in, scale=scale) folder_subname = 'preset' if degradation_mode == 'preset' else degradation_mode + '_' + str( '{:.1f}'.format(sig_x)) + '_' + str( '{:.1f}'.format(sig_y)) + '_' + str('{:.1f}'.format(the)) #### dataset if data_mode == 'Vid4': # test_dataset_folder = '../dataset/Vid4/LR_bicubic/X{}'.format(scale) test_dataset_folder = '../dataset/Vid4/LR_{}/X{}'.format( folder_subname, scale) GT_dataset_folder = '../dataset/Vid4/HR' elif data_mode == 'MM522': test_dataset_folder = '../dataset/MM522val/LR_bicubic/X{}'.format( scale) GT_dataset_folder = '../dataset/MM522val/HR' else: # test_dataset_folder = '../dataset/REDS4/LR_bicubic/X{}'.format(scale) test_dataset_folder = '../dataset/REDS/train/LR_{}/X{}'.format( folder_subname, scale) GT_dataset_folder = '../dataset/REDS/train/HR' #### evaluation crop_border = 0 border_frame = N_in // 2 # border frames when evaluate # temporal padding mode padding = 'new_info' save_folder = '../results/{}'.format(data_mode) util.mkdirs(save_folder) util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) 
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    if data_mode == 'REDS':
        subfolder_GT_l = [
            k for k in subfolder_GT_l
            if k.find('000') >= 0 or k.find('011') >= 0 or k.find('015') >= 0
            or k.find('020') >= 0
        ]

    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0
        avg_ssim, avg_ssim_border, avg_ssim_center = 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            '''
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode in ('Vid4', 'sharp_bicubic', 'MM522'):  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)
            '''
            output = (output * 255).round().astype('uint8')
            GT = (GT * 255).round().astype('uint8')
            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output, GT)
            crt_ssim = util.calculate_ssim(output, GT)  # was stubbed to 0.001; compute the real SSIM
            # logger.info('{:3d} - {:16} \tPSNR: {:.6f} dB \tSSIM: {:.6f}'.format(img_idx + 1, img_name, crt_psnr, crt_ssim))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                avg_ssim_center += crt_ssim
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                avg_ssim_border += crt_ssim
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        avg_psnr_center = avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        avg_ssim = (avg_ssim_center + avg_ssim_border) / (N_center + N_border)
        avg_ssim_center = avg_ssim_center / N_center
        avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border
        avg_ssim_l.append(avg_ssim)
        avg_ssim_center_l.append(avg_ssim_center)
        avg_ssim_border_l.append(avg_ssim_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))
        logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; '
                    'Center SSIM: {:.6f} for {} frames; '
                    'Border SSIM: {:.6f} for {} frames.'.format(
                        subfolder_name, avg_ssim, (N_center + N_border),
                        avg_ssim_center, N_center, avg_ssim_border, N_border))

    '''
    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border))
    for subfolder_name, ssim, ssim_center, ssim_border in zip(
            subfolder_name_l, avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l):
        logger.info('Folder {} - Average SSIM: {:.6f}. '
                    'Center SSIM: {:.6f}. '
                    'Border SSIM: {:.6f}.'.format(subfolder_name, ssim, ssim_center, ssim_border))
    '''

    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
    logger.info('Total Average SSIM: {:.6f} for {} clips. '
                'Center SSIM: {:.6f}. Border SSIM: {:.6f}.'.format(
                    sum(avg_ssim_l) / len(avg_ssim_l), len(subfolder_l),
                    sum(avg_ssim_center_l) / len(avg_ssim_center_l),
                    sum(avg_ssim_border_l) / len(avg_ssim_border_l)))
    print('\n\n\n')
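# ----------------------------------------------------------------------------
# The test loops above pick the N_in neighbouring frames for each target frame
# via data_util.index_generation(). A minimal sketch of that kind of temporal
# window selection follows, as an illustration of the 'replicate' and
# 'new_info' padding modes under the assumption that the clip has at least N
# frames; it is not a verbatim copy of the library function.
def _sketch_index_generation(crt_i, max_n, N, padding='replicate'):
    """Return N frame indices centred on crt_i for a clip of max_n frames."""
    n_pad = N // 2
    max_i = max_n - 1
    idx_l = []
    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
        if i < 0:
            # 'replicate' clamps to the first frame; 'new_info' substitutes
            # unseen frames from the other side of the window instead
            idx_l.append(0 if padding == 'replicate' else crt_i + n_pad - i)
        elif i > max_i:
            idx_l.append(max_i if padding == 'replicate' else crt_i - n_pad - (i - max_i))
        else:
            idx_l.append(i)
    return idx_l
# e.g. _sketch_index_generation(0, 100, 5, 'new_info') -> [4, 3, 0, 1, 2]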
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)

    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'experiments_root' and 'pretrain_model' not in key
                 and 'resume' not in key))
    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    #### Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    model = create_model(opt)
    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
            model.feed_data(data, need_GT=need_GT)
            img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
            img_name = osp.splitext(osp.basename(img_path))[0]

            model.test()
            visuals = model.get_current_visuals(need_GT=need_GT)
            sr_img = util.tensor2img(visuals['rlt'])  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
            else:
                save_img_path = osp.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            # calculate PSNR and SSIM
            if need_GT:
                gt_img = util.tensor2img(visuals['GT'])
                sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
                psnr = util.calculate_psnr(sr_img, gt_img)
                ssim = util.calculate_ssim(sr_img, gt_img)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)
                    psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
                    ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; '
                                'PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'.format(
                                    img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                        img_name, psnr, ssim))
            else:
                logger.info(img_name)

        if need_GT:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; '
                        'SSIM: {:.6f}\n'.format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; '
                            'SSIM_Y: {:.6f}\n'.format(ave_psnr_y, ave_ssim_y))
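# ----------------------------------------------------------------------------
# The Y-channel metrics above go through bgr2ycbcr(..., only_y=True). A small
# sketch of that conversion, assuming BT.601 coefficients applied to float BGR
# input in [0, 1] (treat the exact constants as an assumption, not a copy of
# the library code):
import numpy as np

def _sketch_bgr2y(img_bgr):
    """float32 HxWx3 BGR in [0, 1] -> luma Y in [0, 1] (TV range 16-235)."""
    y = np.dot(img_bgr, [24.966, 128.553, 65.481]) + 16.0  # Y in [16, 235]
    return y / 255.0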
def main():
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '6'
    test_set = 'AI4K_val'  # Vid4 | YouKu10 | REDS4 | AI4K_val | zhibo | AI4K_val_bic
    test_name = 'PCD_Vis_Test_35_ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW_A01xxx_900000_AI4K_5000'  # 'AI4K_val_Denoise_A02_420000'
    data_mode = 'sharp_bicubic'  # sharp_bicubic | blur_bicubic
    N_in = 5

    # load test set
    if test_set == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    elif test_set == 'YouKu10':
        test_dataset_folder = '../datasets/YouKu10/LR'
        GT_dataset_folder = '../datasets/YouKu10/HR'
    elif test_set == 'YouKu_val':
        test_dataset_folder = '/data0/yhliu/DATA/YouKuVid/valid/valid_lr_bmp'
        GT_dataset_folder = '/data0/yhliu/DATA/YouKuVid/valid/valid_hr_bmp'
    elif test_set == 'REDS4':
        test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        GT_dataset_folder = '../datasets/REDS4/GT'
    elif test_set == 'AI4K_val':
        test_dataset_folder = '/home/yhliu/AI4K/contest2/val2_LR_png/'
        GT_dataset_folder = '/home/yhliu/AI4K/contest1/val1_HR_png/'
    elif test_set == 'AI4K_val_bic':
        test_dataset_folder = '/home/yhliu/AI4K/contest1/val1_LR_png_bic/'
        GT_dataset_folder = '/home/yhliu/AI4K/contest1/val1_HR_png_bic/'
    elif test_set == 'zhibo':
        test_dataset_folder = '/data1/yhliu/SR_ZHIBO_VIDEO/Test_video_LR/'
        GT_dataset_folder = '/data1/yhliu/SR_ZHIBO_VIDEO/Test_video_HR/'

    flip_test = False

    #model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
    #model_path = '../experiments/A01b/models/250000_G.pth'
    #model_path = '../experiments/A02_predenoise/models/415000_G.pth'
    model_path = '../experiments/A37_color_EDVR_35_220000_A01_5in_64f_10b_128_pretrain_A01xxx_900000_fix_before_pcd/models/5000_G.pth'

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True

    model = EDVR_arch.EDVR(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
    #model = my_EDVR_arch.MYEDVR(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
    #model = my_EDVR_arch.MYEDVR_RES(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True  # True | False

    save_folder = '../results/{}'.format(test_name)
    if test_set == 'zhibo':
        save_folder = '/data1/yhliu/SR_ZHIBO_VIDEO/SR_png_sample_150'
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    model = nn.DataParallel(model)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    print(subfolder_l)
    print(subfolder_GT_l)
    #exit()

    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        print(img_path_l)
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            #print(img_GT_path)
            img_GT_l.append(data_util.read_img(None, img_GT_path))
        #print(img_GT_l[0].shape)

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).cpu()  # .to(device)
            print(imgs_in.size())

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                start_time = time.time()
                output = util.single_forward(model, imgs_in)
                end_time = time.time()
                print('Forward One image:', end_time - start_time)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            '''
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)
            '''
            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        avg_psnr_center = avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border))

    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
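# ----------------------------------------------------------------------------
# flip_test above toggles util.flipx4_forward(), a flip self-ensemble. A
# minimal sketch of the idea (assuming forward_fn accepts and returns tensors
# whose last two dims are H and W, as in these scripts):
import torch

def _sketch_flipx4_forward(forward_fn, inp):
    """Average predictions over the identity and the 3 H/W flip variants."""
    out = forward_fn(inp)
    for dims in ((-1,), (-2,), (-2, -1)):  # flip W, flip H, flip both
        out = out + torch.flip(forward_fn(torch.flip(inp, dims)), dims)
    return out / 4.0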
def val(model_name, current_step, arch='EDVR'):
    #################
    # configurations
    #################
    device = torch.device('cuda')
    #os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4'
    test_set = 'REDS4'  # Vid4 | YouKu10 | REDS4 | AI4K_val
    data_mode = 'sharp_bicubic'  # sharp_bicubic | blur_bicubic
    N_in = 5

    # load test set
    if test_set == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        GT_dataset_folder = '../datasets/Vid4/GT'
    elif test_set == 'YouKu10':
        test_dataset_folder = '../datasets/YouKu10/LR'
        GT_dataset_folder = '../datasets/YouKu10/HR'
    elif test_set == 'YouKu_val':
        test_dataset_folder = '/data0/yhliu/DATA/YouKuVid/valid/valid_lr_bmp'
        GT_dataset_folder = '/data0/yhliu/DATA/YouKuVid/valid/valid_hr_bmp'
    elif test_set == 'REDS4':
        test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        GT_dataset_folder = '../datasets/REDS4/GT'
    elif test_set == 'AI4K_val':
        test_dataset_folder = '/data0/yhliu/AI4K/contest1/val1_LR_png/'
        GT_dataset_folder = '/data0/yhliu/AI4K/contest1/val1_HR_png/'
    elif test_set == 'AI4K_val_small':
        test_dataset_folder = '/home/yhliu/AI4K/contest1/val1_LR_png_small/'
        GT_dataset_folder = '/home/yhliu/AI4K/contest1/val1_HR_png_small/'

    flip_test = False
    #model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
    model_path = os.path.join('../experiments/', model_name, 'models/{}_G.pth'.format(current_step))

    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True

    if arch == 'EDVR':
        model = EDVR_arch.EDVR(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)
    elif arch == 'MY_EDVR':
        model = my_EDVR_arch.MYEDVR(64, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = False

    save_folder = '../validation/{}'.format(test_set)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    #print(subfolder_l)
    #print(subfolder_GT_l)
    #exit()

    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        #print(img_path_l)
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            #print(img_GT_path)
            img_GT_l.append(data_util.read_img(None, img_GT_path))
        #print(img_GT_l[0].shape)

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
            #print(imgs_in.size())

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            '''
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)
            '''
            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            #logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        avg_psnr_center = avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border))

    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))

    return sum(avg_psnr_l) / len(avg_psnr_l)
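# ----------------------------------------------------------------------------
# Because val() returns the clip-averaged PSNR, it can drive a checkpoint
# sweep. A hypothetical usage sketch (the experiment name and step range are
# placeholders, not values from this project):
#
#     best_step, best_psnr = None, -1.0
#     for step in range(5000, 100001, 5000):
#         psnr = val('my_experiment', step, arch='EDVR')
#         if psnr > best_psnr:
#             best_step, best_psnr = step, psnr
#     print('best checkpoint:', best_step, best_psnr)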
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(opt['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt['path'].items()
                         if not key == 'experiments_root' and 'pretrain_model' not in key
                         and 'resume' not in key and 'wandb_load_run_path' not in key))
        # config loggers. Before it, the log will not work
        util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info('You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                            .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
        if opt['use_wandb_logger'] and 'debug' not in opt['name']:
            json_path = os.path.join(os.path.expanduser('~'), '.wandb_api_keys.json')
            if os.path.exists(json_path):
                with open(json_path, 'r') as j:
                    json_file = json.loads(j.read())
                    os.environ['WANDB_API_KEY'] = json_file['ryul99']
            wandb.init(project="mmsr", config=opt, sync_tensorboard=True)
    else:
        util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
        if opt['use_wandb_logger'] and 'debug' not in opt['name']:
            wandb.config.update({'random_seed': seed})
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
                total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break

            #### update learning rate
            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data,
                            noise_mode=opt['datasets']['train']['noise_mode'],
                            noise_rate=opt['datasets']['train']['noise_rate'])
            model.optimize_parameters(current_step)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                    if opt['use_wandb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            wandb.log({k: v}, step=current_step)
                if rank <= 0:
                    logger.info(message)

            #### validation
            if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
                if opt['model'] in ['sr', 'srgan'] and rank <= 0:
                    # image restoration validation; does not support multi-GPU validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)
                        model.feed_data(val_data,
                                        noise_mode=opt['datasets']['val']['noise_mode'],
                                        noise_rate=opt['datasets']['val']['noise_rate'])
                        model.test()
                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8
                        # Save SR images for reference
                        save_img_path = os.path.join(
                            img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)
                        # calculate PSNR
                        sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        pbar.update('Test {}'.format(img_name))
                    avg_psnr = avg_psnr / idx
                    # log
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)
                    if opt['use_wandb_logger'] and 'debug' not in opt['name']:
                        wandb.log({'psnr': avg_psnr}, step=current_step)
                else:  # video restoration validation
                    if opt['dist']:
                        # multi-GPU testing
                        psnr_rlt = {}  # with border and center frames
                        if rank == 0:
                            pbar = util.ProgressBar(len(val_set))
                        for idx in range(rank, len(val_set), world_size):
                            val_data = val_set[idx]
                            val_data['LQs'].unsqueeze_(0)
                            val_data['GT'].unsqueeze_(0)
                            folder = val_data['folder']
                            idx_d, max_idx = val_data['idx'].split('/')
                            idx_d, max_idx = int(idx_d), int(max_idx)
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32,
                                                               device='cuda')
                            model.feed_data(val_data,
                                            noise_mode=opt['datasets']['val']['noise_mode'],
                                            noise_rate=opt['datasets']['val']['noise_rate'])
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8
                            # calculate PSNR
                            psnr_rlt[folder][idx_d] = util.calculate_psnr(rlt_img, gt_img)
                            if rank == 0:
                                for _ in range(world_size):
                                    pbar.update('Test {} - {}/{}'.format(folder, idx_d, max_idx))
                        # collect data
                        for _, v in psnr_rlt.items():
                            dist.reduce(v, 0)
                        dist.barrier()
                        if rank == 0:
                            psnr_rlt_avg = {}
                            psnr_total_avg = 0.
                            for k, v in psnr_rlt.items():
                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                                psnr_total_avg += psnr_rlt_avg[k]
                            psnr_total_avg /= len(psnr_rlt)
                            log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
                            for k, v in psnr_rlt_avg.items():
                                log_s += ' {}: {:.4e}'.format(k, v)
                            logger.info(log_s)
                            if opt['use_tb_logger'] and 'debug' not in opt['name']:
                                tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
                                for k, v in psnr_rlt_avg.items():
                                    tb_logger.add_scalar(k, v, current_step)
                            if opt['use_wandb_logger'] and 'debug' not in opt['name']:
                                lq_img, rlt_img, gt_img = map(
                                    util.tensor2img,
                                    [visuals['LQ'], visuals['rlt'], visuals['GT']])
                                wandb.log({'psnr_avg': psnr_total_avg}, step=current_step)
                                wandb.log(psnr_rlt_avg, step=current_step)
                                wandb.log({
                                    'Validation Image': [
                                        wandb.Image(lq_img[:, :, [2, 1, 0]], caption='LQ'),
                                        wandb.Image(rlt_img[:, :, [2, 1, 0]], caption='output'),
                                        wandb.Image(gt_img[:, :, [2, 1, 0]], caption='GT'),
                                    ]
                                }, step=current_step)
                    else:
                        pbar = util.ProgressBar(len(val_loader))
                        psnr_rlt = {}  # with border and center frames
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for val_data in val_loader:
                            folder = val_data['folder'][0]
                            idx_d = val_data['idx'].item()
                            # border = val_data['border'].item()
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = []
                            model.feed_data(val_data,
                                            noise_mode=opt['datasets']['val']['noise_mode'],
                                            noise_rate=opt['datasets']['val']['noise_rate'])
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8
                            # calculate PSNR
                            psnr = util.calculate_psnr(rlt_img, gt_img)
                            psnr_rlt[folder].append(psnr)
                            pbar.update('Test {} - {}'.format(folder, idx_d))
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = sum(v) / len(v)
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)
                        if opt['use_wandb_logger'] and 'debug' not in opt['name']:
                            lq_img, rlt_img, gt_img = map(
                                util.tensor2img,
                                [visuals['LQ'], visuals['rlt'], visuals['GT']])
                            wandb.log({'psnr_avg': psnr_total_avg}, step=current_step)
                            wandb.log(psnr_rlt_avg, step=current_step)
                            wandb.log({
                                'Validation Image': [
                                    wandb.Image(lq_img[:, :, [2, 1, 0]], caption='LQ'),
                                    wandb.Image(rlt_img[:, :, [2, 1, 0]], caption='output'),
                                    wandb.Image(gt_img[:, :, [2, 1, 0]], caption='GT'),
                                ]
                            }, step=current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            tb_logger.close()
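# ----------------------------------------------------------------------------
# The distributed validation above shards clips across ranks, fills a zeroed
# per-folder CUDA tensor on each rank, and sums the shards onto rank 0 with
# dist.reduce(). A standalone sketch of that pattern (assumes the process
# group is already initialised and each entry is written by exactly one rank):
import torch.distributed as dist

def _sketch_reduce_metric(per_folder_vals):
    """per_folder_vals: dict folder -> 1-D CUDA tensor, zeros except for the
    entries this rank computed. Returns folder means on rank 0, None elsewhere."""
    for v in per_folder_vals.values():
        dist.reduce(v, dst=0)  # element-wise sum onto rank 0
    dist.barrier()
    if dist.get_rank() == 0:
        return {k: v.mean().item() for k, v in per_folder_vals.items()}
    return None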
def main():
    ############################################
    # set options
    ############################################
    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    ############################################
    # distributed training settings
    ############################################
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        print('Rank:', rank)
        print('------------------DIST-------------------------')

    ############################################
    # loading resume state if exists
    ############################################
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    ############################################
    # mkdir and loggers
    ############################################
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(opt['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt['path'].items()
                         if not key == 'experiments_root' and 'pretrain_model' not in key
                         and 'resume' not in key))
        # config loggers. Before it, the log will not work
        util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        util.setup_logger('base_val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger_val = logging.getLogger('base_val')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info('You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                            .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        # config loggers. Before it, the log will not work
        util.setup_logger('base', opt['path']['log'], 'train_', level=logging.INFO, screen=True)
        print('set train log')
        util.setup_logger('base_val', opt['path']['log'], 'val_', level=logging.INFO, screen=True)
        print('set val log')
        logger = logging.getLogger('base')
        logger_val = logging.getLogger('base_val')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    ############################################
    # create train and val dataloader
    ############################################
    # dataset_ratio = 200  # enlarge the size of each epoch, todo: what it is
    dataset_ratio = 1  # enlarge the size of each epoch, todo: what it is
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            # total_iters = int(opt['train']['niter'])
            # total_epochs = int(math.ceil(total_iters / train_size))
            total_iters = train_size
            total_epochs = int(opt['train']['epoch'])
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
                # total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
                total_epochs = int(opt['train']['epoch'])
                if opt['train']['enable'] == False:
                    total_epochs = 1
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    ############################################
    # create model
    ############################################
    model = create_model(opt)
    print('Model created!')

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0
        print('Not resuming training.')

    ############################################
    # training
    ############################################
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))

    Avg_train_loss = AverageMeter()  # total
    if opt['train']['pixel_criterion'] == 'cb+ssim':
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_ssim = AverageMeter()
    elif opt['train']['pixel_criterion'] == 'cb+ssim+vmaf':
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_ssim = AverageMeter()
        Avg_train_loss_vmaf = AverageMeter()
    elif opt['train']['pixel_criterion'] == 'ssim':
        Avg_train_loss_ssim = AverageMeter()
    elif opt['train']['pixel_criterion'] == 'msssim':
        Avg_train_loss_msssim = AverageMeter()
    elif opt['train']['pixel_criterion'] == 'cb+msssim':
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_msssim = AverageMeter()

    saved_total_loss = 10e10
    saved_total_PSNR = -1

    for epoch in range(start_epoch, total_epochs):
        ############################################
        # start a new epoch
        ############################################
        # Turn into training mode
        #model = model.train()

        # reset the running losses
        Avg_train_loss.reset()
        current_step = 0
        if opt['train']['pixel_criterion'] == 'cb+ssim':
            Avg_train_loss_pix.reset()
            Avg_train_loss_ssim.reset()
        elif opt['train']['pixel_criterion'] == 'cb+ssim+vmaf':
            Avg_train_loss_pix.reset()
            Avg_train_loss_ssim.reset()
            Avg_train_loss_vmaf.reset()
        elif opt['train']['pixel_criterion'] == 'ssim':
            Avg_train_loss_ssim.reset()
        elif opt['train']['pixel_criterion'] == 'msssim':
            Avg_train_loss_msssim.reset()
        elif opt['train']['pixel_criterion'] == 'cb+msssim':
            Avg_train_loss_pix.reset()
            Avg_train_loss_msssim.reset()

        if opt['dist']:
            train_sampler.set_epoch(epoch)

        for train_idx, train_data in enumerate(train_loader):
            if 'debug' in opt['name']:
                # dump a few LQ/GT samples for visual inspection
                img_dir = os.path.join(opt['path']['train_images'])
                util.mkdir(img_dir)
                LQ = train_data['LQs']
                GT = train_data['GT']
                GT_img = util.tensor2img(GT)  # uint8
                save_img_path = os.path.join(img_dir, '{:4d}_{:s}.png'.format(train_idx, 'debug_GT'))
                util.save_img(GT_img, save_img_path)
                for i in range(5):
                    LQ_img = util.tensor2img(LQ[0, i, ...])  # uint8
                    save_img_path = os.path.join(
                        img_dir, '{:4d}_{:s}_{:1d}.png'.format(train_idx, 'debug_LQ', i))
                    util.save_img(LQ_img, save_img_path)
                if train_idx >= 3:
                    break

            if opt['train']['enable'] == False:
                message_train_loss = 'None'
                break

            current_step += 1
            if current_step > total_iters:
                print('Total iterations reached!')
                break

            #### update learning rate
            if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
                pass  # stepped on the validation loss at the end of the epoch
            else:
                model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)
            # if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            #     model.optimize_parameters_without_schedule(current_step)
            # else:
            model.optimize_parameters(current_step)

            if opt['train']['pixel_criterion'] == 'cb+ssim':
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
            elif opt['train']['pixel_criterion'] == 'cb+ssim+vmaf':
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
                Avg_train_loss_vmaf.update(model.log_dict['vmaf_loss'], 1)
            elif opt['train']['pixel_criterion'] == 'ssim':
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
            elif opt['train']['pixel_criterion'] == 'msssim':
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_msssim.update(model.log_dict['msssim_loss'], 1)
            elif opt['train']['pixel_criterion'] == 'cb+msssim':
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_msssim.update(model.log_dict['msssim_loss'], 1)
            else:
                Avg_train_loss.update(model.log_dict['l_pix'], 1)

            # add total train loss
            if opt['train']['pixel_criterion'] == 'cb+ssim':
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(Avg_train_loss_pix.avg)
                message_train_loss += ' ssim_avg_loss: {:.4e}'.format(Avg_train_loss_ssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(Avg_train_loss.avg)
            elif opt['train']['pixel_criterion'] == 'cb+ssim+vmaf':
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(Avg_train_loss_pix.avg)
                message_train_loss += ' ssim_avg_loss: {:.4e}'.format(Avg_train_loss_ssim.avg)
                message_train_loss += ' vmaf_avg_loss: {:.4e}'.format(Avg_train_loss_vmaf.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(Avg_train_loss.avg)
            elif opt['train']['pixel_criterion'] == 'ssim':
                message_train_loss = ' ssim_avg_loss: {:.4e}'.format(Avg_train_loss_ssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(Avg_train_loss.avg)
            elif opt['train']['pixel_criterion'] == 'msssim':
                message_train_loss = ' msssim_avg_loss: {:.4e}'.format(Avg_train_loss_msssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(Avg_train_loss.avg)
            elif opt['train']['pixel_criterion'] == 'cb+msssim':
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(Avg_train_loss_pix.avg)
                message_train_loss += ' msssim_avg_loss: {:.4e}'.format(Avg_train_loss_msssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(Avg_train_loss.avg)
            else:
                message_train_loss = ' train_avg_loss: {:.4e}'.format(Avg_train_loss.avg)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                message += message_train_loss
                if rank <= 0:
                    logger.info(message)

        ############################################
        # end of one epoch: save the epoch model
        ############################################
        # per-iteration checkpointing (save_checkpoint_freq) was replaced by the
        # per-epoch saving below; only the latest epoch weights are kept on disk
        if epoch == 1:
            save_filename = '{:04d}_{}.pth'.format(0, 'G')
            save_path = os.path.join(opt['path']['models'], save_filename)
            if os.path.exists(save_path):
                os.remove(save_path)
        save_filename = '{:04d}_{}.pth'.format(epoch - 1, 'G')
        save_path = os.path.join(opt['path']['models'], save_filename)
        if os.path.exists(save_path):
            os.remove(save_path)
        if rank <= 0:
            logger.info('Saving models and training states.')
            save_filename = '{:04d}'.format(epoch)
            model.save(save_filename)
            # model.save('latest')
            # model.save_training_state(epoch, current_step)

        ############################################
        # end of one epoch: do validation
        ############################################
        if opt['datasets'].get('val', None):
            if opt['model'] in ['sr', 'srgan'] and rank <= 0:
                # image restoration validation; does not support multi-GPU validation
                pbar = util.ProgressBar(len(val_loader))
                avg_psnr = 0.
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)
                    model.feed_data(val_data)
                    model.test()
                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['rlt'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8
                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                    #util.save_img(sr_img, save_img_path)
                    # calculate PSNR
                    sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
                    avg_psnr += util.calculate_psnr(sr_img, gt_img)
                    pbar.update('Test {}'.format(img_name))
                avg_psnr = avg_psnr / idx
                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
            else:  # video restoration validation
                if opt['dist']:
                    # todo: multi-GPU testing
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.
                    ssim_rlt = {}  # with border and center frames
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.
                    val_loss_rlt = {}
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.
                    if rank == 0:
                        pbar = util.ProgressBar(len(val_set))
                    for idx in range(rank, len(val_set), world_size):
                        print('idx', idx)
                        if 'debug' in opt['name']:
                            if idx >= 3:
                                break
                        val_data = val_set[idx]
                        val_data['LQs'].unsqueeze_(0)
                        val_data['GT'].unsqueeze_(0)
                        folder = val_data['folder']
                        idx_d, max_idx = val_data['idx'].split('/')
                        idx_d, max_idx = int(idx_d), int(max_idx)
                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32,
                                                           device='cuda')
                        if ssim_rlt.get(folder, None) is None:
                            ssim_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32,
                                                           device='cuda')
                        if val_loss_rlt.get(folder, None) is None:
                            val_loss_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32,
                                                               device='cuda')

                        model.feed_data(val_data)
                        # model.test()
                        # model.test_stitch()
                        if opt['stitch'] == True:
                            model.test_stitch()
                        else:
                            model.test()  # large GPU memory

                        # visuals = model.get_current_visuals()
                        visuals = model.get_current_visuals(
                            save=True, name='{}_{}'.format(folder, idx),
                            save_path=opt['path']['val_images'])

                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder][idx_d] = psnr
                        # calculate SSIM
                        ssim = util.calculate_ssim(rlt_img, gt_img)
                        ssim_rlt[folder][idx_d] = ssim
                        # calculate val loss
                        val_loss = model.get_loss()
                        val_loss_rlt[folder][idx_d] = val_loss

                        logger.info('{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format(
                            folder, idx, psnr, ssim))

                        if rank == 0:
                            for _ in range(world_size):
                                pbar.update('Test {} - {}/{}'.format(folder, idx_d, max_idx))

                    # collect data
                    for _, v in psnr_rlt.items():
                        dist.reduce(v, 0)
                    for _, v in ssim_rlt.items():
                        dist.reduce(v, 0)
                    for _, v in val_loss_rlt.items():
                        dist.reduce(v, 0)
                    dist.barrier()

                    if rank == 0:
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)

                        # ssim
                        ssim_rlt_avg = {}
                        ssim_total_avg = 0.
                        for k, v in ssim_rlt.items():
                            ssim_rlt_avg[k] = torch.mean(v).cpu().item()
                            ssim_total_avg += ssim_rlt_avg[k]
                        ssim_total_avg /= len(ssim_rlt)
                        log_s = '# Validation # SSIM: {:.4e}:'.format(ssim_total_avg)  # was mislabelled as PSNR
                        for k, v in ssim_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)

                        # val loss
                        val_loss_rlt_avg = {}
                        val_loss_total_avg = 0.
                        for k, v in val_loss_rlt.items():
                            val_loss_rlt_avg[k] = torch.mean(v).cpu().item()
                            val_loss_total_avg += val_loss_rlt_avg[k]
                        val_loss_total_avg /= len(val_loss_rlt)
                        log_l = '# Validation # Loss: {:.4e}:'.format(val_loss_total_avg)
                        for k, v in val_loss_rlt_avg.items():
                            log_l += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_l)

                        message = ''
                        for v in model.get_current_learning_rate():
                            message += '{:.5e}'.format(v)
                        logger_val.info(
                            'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, '
                            'Val Total Loss {:.4e}'.format(epoch, message, psnr_total_avg,
                                                           ssim_total_avg, message_train_loss,
                                                           val_loss_total_avg))

                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)
                            # add val loss
                            tb_logger.add_scalar('val_loss_avg', val_loss_total_avg, current_step)
                            for k, v in val_loss_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)
                else:
                    # todo: single-GPU path
                    pbar = util.ProgressBar(len(val_loader))
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.
                    ssim_rlt = {}  # with border and center frames
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.
                    val_loss_rlt = {}
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.
                    for val_inx, val_data in enumerate(val_loader):
                        if 'debug' in opt['name']:
                            if val_inx >= 5:
                                break
                        folder = val_data['folder'][0]
                        # idx_d = val_data['idx'].item()
                        idx_d = val_data['idx']
                        # border = val_data['border'].item()
                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = []
                        if ssim_rlt.get(folder, None) is None:
                            ssim_rlt[folder] = []
                        if val_loss_rlt.get(folder, None) is None:
                            val_loss_rlt[folder] = []

                        # detect the black letterbox rows in the centre LQ frame [B N C H W]
                        print(val_data['LQs'].size())
                        H_S = val_data['LQs'].size(3)  # 540
                        W_S = val_data['LQs'].size(4)  # 960
                        print(H_S)
                        print(W_S)
                        blank_1_S = 0
                        blank_2_S = 0
                        print(val_data['LQs'][0, 2, 0, :, :].size())
                        for i in range(H_S):
                            if not sum(val_data['LQs'][0, 2, 0, i, :]) == 0:
                                blank_1_S = i - 1
                                # assert not sum(data_S[:, :, 0][i+1]) == 0
                                break
                        for i in range(H_S):
                            # scan rows from the bottom; the original indexed the
                            # width axis with H_S here, which was a bug
                            if not sum(val_data['LQs'][0, 2, 0, H_S - i - 1, :]) == 0:
                                blank_2_S = (H_S - 1) - i - 1
                                # assert not sum(data_S[:, :, 0][blank_2_S-1]) == 0
                                break
                        print('LQ :', blank_1_S, blank_2_S)
                        if blank_1_S == -1:
                            print('LQ has no blank')
                            blank_1_S = 0
                            blank_2_S = H_S
                        # val_data['LQs'] = val_data['LQs'][:, :, :, blank_1_S:blank_2_S, :]
                        print('LQ', val_data['LQs'].size())
                        # end of black-blank processing

                        model.feed_data(val_data)
                        if opt['stitch'] == True:
                            model.test_stitch()
                        else:
                            model.test()  # large GPU memory

                        # zero out the letterbox region in the 4x output
                        blank_1_L = blank_1_S << 2
                        blank_2_L = blank_2_S << 2
                        print(blank_1_L, blank_2_L)
                        print(model.fake_H.size())
                        if not blank_1_S == 0:
                            # model.fake_H = model.fake_H[:, :, blank_1_L:blank_2_L, :]
                            model.fake_H[:, :, 0:blank_1_L, :] = 0
                            model.fake_H[:, :, blank_2_L:, :] = 0  # original sliced to H_S, which is the LR height
                        # end of blank processing

                        visuals = model.get_current_visuals(
                            save=True, name='{}_{:02d}'.format(folder, val_inx),
                            save_path=opt['path']['val_images'])

                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder].append(psnr)
                        # calculate SSIM
                        ssim = util.calculate_ssim(rlt_img, gt_img)
                        ssim_rlt[folder].append(ssim)
                        # val loss
                        val_loss = model.get_loss()
                        val_loss_rlt[folder].append(val_loss.item())

                        logger.info('{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format(
                            folder, val_inx, psnr, ssim))
                        pbar.update('Test {} - {}'.format(folder, idx_d))

                    # average PSNR
                    for k, v in psnr_rlt.items():
                        psnr_rlt_avg[k] = sum(v) / len(v)
                        psnr_total_avg += psnr_rlt_avg[k]
                    psnr_total_avg /= len(psnr_rlt)
                    log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
                    for k, v in psnr_rlt_avg.items():
                        log_s += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_s)

                    # average SSIM
                    for k, v in ssim_rlt.items():
                        ssim_rlt_avg[k] = sum(v) / len(v)
                        ssim_total_avg += ssim_rlt_avg[k]
                    ssim_total_avg /= len(ssim_rlt)
                    log_s = '# Validation # SSIM: {:.4e}:'.format(ssim_total_avg)
                    for k, v in ssim_rlt_avg.items():
                        log_s += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_s)

                    # average val loss
                    for k, v in val_loss_rlt.items():
                        val_loss_rlt_avg[k] = sum(v) / len(v)
                        val_loss_total_avg += val_loss_rlt_avg[k]
                    val_loss_total_avg /= len(val_loss_rlt)
                    log_l = '# Validation # Loss: {:.4e}:'.format(val_loss_total_avg)
                    for k, v in val_loss_rlt_avg.items():
                        log_l += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_l)

                    # total validation log
                    message = ''
                    for v in model.get_current_learning_rate():
                        message += '{:.5e}'.format(v)
                    logger_val.info(
                        'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, '
                        'Val Total Loss {:.4e}'.format(epoch, message, psnr_total_avg,
                                                       ssim_total_avg, message_train_loss,
                                                       val_loss_total_avg))

                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
                        for k, v in psnr_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)
                        # tb_logger.add_scalar('ssim_avg', ssim_total_avg, current_step)
                        # for k, v in ssim_rlt_avg.items():
                        #     tb_logger.add_scalar(k, v, current_step)
                        # add val loss
                        tb_logger.add_scalar('val_loss_avg', val_loss_total_avg, current_step)
                        for k, v in val_loss_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)

        ############################################
        # end of validation: save the best model
        ############################################
        # check the validation loss instead of the training loss
        if saved_total_loss >= val_loss_total_avg:
            saved_total_loss = val_loss_total_avg
            #torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
            model.save('best')
            logger.info('Best weights updated: validation loss decreased.')
        else:
            logger.info('Weights not updated: validation loss did not decrease.')
        if saved_total_PSNR <= psnr_total_avg:
            saved_total_PSNR = psnr_total_avg
            model.save('bestPSNR')
            logger.info('Best weights updated: validation PSNR increased.')
        else:
            logger.info('Weights not updated: validation PSNR did not increase.')

        ############################################
        # end of one epoch: schedule LR
        ############################################
        # add scheduler todo
        if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            for scheduler in model.schedulers:
                scheduler.step(val_loss_total_avg)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('last')
        logger.info('End of training.')
        tb_logger.close()
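# ----------------------------------------------------------------------------
# With lr_scheme == 'ReduceLROnPlateau' the script above skips per-iteration
# LR updates and instead steps the schedulers on the epoch's validation loss.
# A minimal sketch of that wiring with a bare optimizer (factor and patience
# are hypothetical values, not this project's configuration):
import torch

def _sketch_plateau_scheduler(optimizer):
    # halve the LR after 3 epochs without a validation-loss improvement
    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                      factor=0.5, patience=3)
# per epoch:  scheduler.step(val_loss_total_avg)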
def main():
    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False

    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            #model_path = '../experiments/pretrained_models/EDVR_REDS_SR_M.pth'
            model_path = '../experiments/002_EDVR_lr4e-4_600k_AI4KHDR/models/4000_G.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    elif data_mode == 'sharp_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'
    elif data_mode == 'blur_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'
    elif data_mode == 'blur':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'
    elif data_mode == 'blur_comp':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'
    else:
        raise NotImplementedError

    if data_mode == 'Vid4':
        N_in = 5  # use N_in images to restore one HR image
    else:
        N_in = 5
    predeblur, HR_in = False, False
    back_RBs = 10
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, HR_in = True, True
    if stage == 2:
        HR_in = True
        back_RBs = 20
    model = EDVR_arch.EDVR(64, 5, 8, 5, 10, predeblur=predeblur, HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/valid/540p_frames'
        GT_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/valid/4k_frames'
        #test_dataset_folder = '../datasets/Vid4/BIx4'
        #GT_dataset_folder = '../datasets/Vid4/GT'
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=False)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))

    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)
            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            #logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        avg_psnr_center = avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border))

    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
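# ----------------------------------------------------------------------------
# The script above loads its checkpoint with strict=False, which silently
# skips missing or unexpected keys. A hedged sketch of a more explicit partial
# load (hypothetical helper, not part of this codebase):
import torch

def _sketch_partial_load(model, path):
    ckpt = torch.load(path, map_location='cpu')
    own = model.state_dict()
    matched = {k: v for k, v in ckpt.items() if k in own and own[k].shape == v.shape}
    own.update(matched)
    model.load_state_dict(own)
    print('loaded {}/{} tensors from {}'.format(len(matched), len(own), path))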
def main():
    # Create object for parsing command-line options
    parser = argparse.ArgumentParser(
        description="Test with EDVR; requires the path to the test dataset folder.")
    # Add argument which takes the path to the test folder as input
    parser.add_argument("-i", "--input", type=str, help="Path to test folder")
    # Parse the command-line arguments to an object
    args = parser.parse_args()

    # Safety check if no parameter has been given
    if not args.input:
        print("No input parameter has been given.")
        print("For help type --help")
        exit()

    # derive the run name from the last path component (handle a trailing slash)
    folder_name = args.input.split("/")[-1]
    if folder_name == '':
        index = len(args.input.split("/")) - 2
        folder_name = args.input.split("/")[index]

    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'Vid4'
    # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2; use the two-stage strategy for the REDS dataset.
    flip_test = False

    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    else:
        raise NotImplementedError

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    predeblur, HR_in = False, False
    back_RBs = 40
    model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':  # debug
        test_dataset_folder = os.path.join(args.input, 'BIx4')
        GT_dataset_folder = os.path.join(args.input, 'GT')
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluating
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True
    save_folder = '../results/{}'.format(folder_name)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
            img_GT_l.append(data_util.read_img(None, img_GT_path))

        avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding)
            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output)

            # calculate PSNR
            output = output / 255.
            GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                GT = data_util.bgr2ycbcr(GT, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, GT = util.crop_border([output, GT], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, GT * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:  # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
        avg_psnr_center = avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border, N_border))

    logger.info('################ Tidy Outputs ################')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('################ Final Results ################')
    logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
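############################################################################
# The PSNR reported above (util.calculate_psnr) is the standard
# 20*log10(255 / RMSE) on images in the [0, 255] range; a minimal sketch:
def calculate_psnr_sketch(img1, img2):
    import math
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))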
    visuals = model.get_current_visuals(need_GT=need_GT)

    sr_img = util.tensor2img(visuals['rlt'])  # uint8

    # save images
    suffix = opt['suffix']
    if suffix:
        save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
    else:
        save_img_path = osp.join(dataset_dir, img_name + '.png')
    util.save_img(sr_img, save_img_path)

    # calculate PSNR and SSIM
    if need_GT:
        gt_img = util.tensor2img(visuals['GT'])
        sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
        psnr = util.calculate_psnr(sr_img, gt_img)
        ssim = util.calculate_ssim(sr_img, gt_img)
        test_results['psnr'].append(psnr)
        test_results['ssim'].append(ssim)

        if gt_img.shape[2] == 3:  # RGB image
            sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
            gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)
            psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
            ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
            test_results['psnr_y'].append(psnr_y)
            test_results['ssim_y'].append(ssim_y)
            logger.info(
                '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'.format(
                    img_name, psnr, ssim, psnr_y, ssim_y))
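############################################################################
# PSNR_Y/SSIM_Y above are computed on the luma channel. A minimal sketch of
# the BT.601 BGR -> Y conversion assumed by bgr2ycbcr(img, only_y=True) for
# float images in [0, 1] (full-range RGB mapped to TV-range luma):
def bgr_to_y_sketch(img):
    # img: H x W x 3 float BGR in [0, 1]; returns H x W luma in [0, 1]
    y = np.dot(img * 255., [24.966, 128.553, 65.481]) / 255. + 16.
    return y / 255.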
def main(opts):
    ################## configurations #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpus
    cache_all_imgs = opts.cache > 0
    n_gpus = len(opts.gpus.split(','))
    flip_test, save_imgs = False, False
    scale = 4  # overwritten below from the test_dir name
    N_in, nf = 5, 64
    back_RBs = 10
    w_TSA = False
    predeblur, HR_in = False, False
    crop_border = 0
    border_frame = N_in // 2
    padding = 'new_info'

    ################## model files ####################
    model_dir = opts.model_dir
    if osp.isfile(model_dir):
        model_names = [osp.basename(model_dir)]
        model_dir = osp.dirname(model_dir)
    elif osp.isdir(model_dir):
        model_names = [
            x for x in os.listdir(model_dir) if str.isdigit(x.split('_')[0])
        ]
        model_names = sorted(model_names, key=lambda x: int(x.split("_")[0]))
    else:
        raise IOError('Invalid model_dir: {}'.format(model_dir))

    ################## dataset ########################
    test_subs = sorted(os.listdir(opts.test_dir))
    gt_subs = os.listdir(opts.gt_dir)
    valid_test_subs = [sub in gt_subs for sub in test_subs]
    assert all(valid_test_subs), 'Invalid sub-folders exist in {}'.format(opts.test_dir)
    # e.g. '.../X4/<sub>' -> 4.0; the test dir's parent is expected to be named after the scale
    scale = float(os.path.basename(os.path.dirname(opts.test_dir))[1:])

    if cache_all_imgs:
        print('Caching all testing images ...')
        all_imgs = {}
        for sub in test_subs:
            print('Reading sub-folder: {} ...'.format(sub))
            test_sub_dir = osp.join(opts.test_dir, sub)
            gt_sub_dir = osp.join(opts.gt_dir, sub)
            all_imgs[sub] = {'test': [], 'gt': []}
            im_names = sorted(os.listdir(test_sub_dir))
            for i, name in enumerate(im_names):
                test_im_path = osp.join(test_sub_dir, name)
                gt_im_path = osp.join(gt_sub_dir, name)
                test_im = cv2.imread(test_im_path, cv2.IMREAD_UNCHANGED)[:, :, (2, 1, 0)]
                test_im = test_im.astype(np.float32).transpose((2, 0, 1)) / 255.
                all_imgs[sub]['test'].append(test_im)
                gt_im = cv2.imread(gt_im_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
                all_imgs[sub]['gt'].append(gt_im)

    all_psnrs = []
    for model_name in model_names:
        model_path = osp.join(model_dir, model_name)
        exp_name = model_name.split('_')[0]
        if 'meta' in opts.mode.lower():
            model = EDVR_arch.MetaEDVR(nf=nf, nframes=N_in, groups=8, front_RBs=5,
                                       center=None, back_RBs=back_RBs,
                                       predeblur=predeblur, HR_in=HR_in, w_TSA=w_TSA)
        elif opts.mode.lower() == 'edvr':
            model = EDVR_arch.EDVR(nf=nf, nframes=N_in, groups=8, front_RBs=5,
                                   center=None, back_RBs=back_RBs,
                                   predeblur=predeblur, HR_in=HR_in, w_TSA=w_TSA)
        elif opts.mode.lower() == 'upedvr':
            model = EDVR_arch.UPEDVR(nf=nf, nframes=N_in, groups=8, front_RBs=5,
                                     center=None, back_RBs=10, w_TSA=w_TSA,
                                     down_scale=True, align_target=True, ret_valid=True)
        elif opts.mode.lower() == 'upcont1':
            model = EDVR_arch.UPControlEDVR(nf=nf, nframes=N_in, groups=8, front_RBs=5,
                                            center=None, back_RBs=10, w_TSA=w_TSA,
                                            down_scale=True, align_target=True,
                                            ret_valid=True, multi_scale_cont=False)
        elif opts.mode.lower() in ('upcont2', 'upcont3'):
            # NOTE: as originally written, 'upcont2' and 'upcont3' mapped to the
            # same configuration; the branches are merged here without changing behavior.
            model = EDVR_arch.UPControlEDVR(nf=nf, nframes=N_in, groups=8, front_RBs=5,
                                            center=None, back_RBs=10, w_TSA=w_TSA,
                                            down_scale=True, align_target=True,
                                            ret_valid=True, multi_scale_cont=True)
        else:
            raise TypeError('Unknown model mode: {}'.format(opts.mode))

        save_folder = osp.join(opts.save_dir, exp_name)
        util.mkdirs(save_folder)
        util.setup_logger(exp_name, save_folder, 'test', level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger(exp_name)

        #### log info
        logger.info('Data: {}'.format(opts.test_dir))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))

        #### set up the models
        model.load_state_dict(torch.load(model_path), strict=True)
        model.eval()
        if n_gpus > 1:
            model = nn.DataParallel(model)
        model = model.to(device)

        avg_psnrs, avg_psnr_centers, avg_psnr_borders = [], [], []
        avg_ssims, avg_ssim_centers, avg_ssim_borders = [], [], []
        evaled_subs = []
        # for each subfolder
        for sub in test_subs:
            evaled_subs.append(sub)
            test_sub_dir = osp.join(opts.test_dir, sub)
            gt_sub_dir = osp.join(opts.gt_dir, sub)
            img_names = sorted(os.listdir(test_sub_dir))
            max_idx = len(img_names)
            if save_imgs:
                save_subfolder = osp.join(save_folder, sub)
                util.mkdirs(save_subfolder)

            #### get LQ and GT images
            if not cache_all_imgs:
                img_LQs, img_GTs = [], []
                for i, name in enumerate(img_names):
                    test_im_path = osp.join(test_sub_dir, name)
                    gt_im_path = osp.join(gt_sub_dir, name)
                    test_im = cv2.imread(test_im_path, cv2.IMREAD_UNCHANGED)[:, :, (2, 1, 0)]
                    test_im = test_im.astype(np.float32).transpose((2, 0, 1)) / 255.
                    gt_im = cv2.imread(gt_im_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
                    img_LQs.append(test_im)
                    img_GTs.append(gt_im)
            else:
                img_LQs = all_imgs[sub]['test']
                img_GTs = all_imgs[sub]['gt']

            avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0
            avg_ssim, avg_ssim_border, avg_ssim_center = 0, 0, 0

            # process each image, n_gpus frames per step
            for i in range(0, max_idx, n_gpus):
                end = min(i + n_gpus, max_idx)
                select_idxs = [
                    data_util.index_generation(j, max_idx, N_in, padding=padding)
                    for j in range(i, end)
                ]
                imgs = []
                for select_idx in select_idxs:
                    im = torch.from_numpy(np.stack([img_LQs[k] for k in select_idx]))
                    imgs.append(im)
                if (i + n_gpus) > max_idx:  # pad the last batch with dummy frames
                    for _ in range(max_idx, i + n_gpus):
                        imgs.append(torch.zeros_like(im))
                imgs = torch.stack(imgs, 0).to(device)

                # dispatch is an elif chain so the 'meta' output is not
                # overwritten by the default forward (a bug in the original flow)
                if flip_test:
                    output = util.flipx4_forward(model, imgs)
                elif 'meta' in opts.mode.lower():
                    output = util.meta_single_forward(model, imgs, scale, n_gpus)
                elif 'up' in opts.mode.lower():
                    output = util.up_single_forward(model, imgs, scale)
                else:
                    output = util.single_forward(model, imgs)
                output = [util.tensor2img(x).astype(np.float32) for x in output]

                if save_imgs:
                    for ii in range(i, end):
                        # strip the original extension so files are not named '*.png.png'
                        base = osp.splitext(img_names[ii])[0]
                        cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(base)),
                                    output[ii - i].astype(np.uint8))

                # calculate PSNR
                GT = np.copy(img_GTs[i:end])
                output = util.crop_border(output, crop_border)
                GT = util.crop_border(GT, crop_border)
                for m in range(i, end):
                    crt_psnr = util.calculate_psnr(output[m - i], GT[m - i])
                    crt_ssim = util.calculate_ssim(output[m - i], GT[m - i])
                    logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB SSIM: {:.6f}'.format(
                        m + 1, img_names[m], crt_psnr, crt_ssim))
                    if m >= border_frame and m < max_idx - border_frame:  # center frames
                        avg_psnr_center += crt_psnr
                        avg_ssim_center += crt_ssim
                        N_center += 1
                    else:  # border frames
                        avg_psnr_border += crt_psnr
                        avg_ssim_border += crt_ssim
                        N_border += 1

            avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
            avg_psnr_center = avg_psnr_center / N_center
            avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
            avg_ssim = (avg_ssim_center + avg_ssim_border) / (N_center + N_border)
            avg_ssim_center = avg_ssim_center / N_center
            avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border
            avg_psnrs.append(avg_psnr)
            avg_psnr_centers.append(avg_psnr_center)
            avg_psnr_borders.append(avg_psnr_border)
            avg_ssims.append(avg_ssim)
            avg_ssim_centers.append(avg_ssim_center)
            avg_ssim_borders.append(avg_ssim_border)

            logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                        'Center PSNR: {:.6f} dB for {} frames; '
                        'Border PSNR: {:.6f} dB for {} frames.'.format(
                            sub, avg_psnr, (N_center + N_border),
                            avg_psnr_center, N_center, avg_psnr_border, N_border))
            logger.info('Folder {} - Average SSIM: {:.6f} for {} frames; '
                        'Center SSIM: {:.6f} for {} frames; '
                        'Border SSIM: {:.6f} for {} frames.'.format(
                            sub, avg_ssim, (N_center + N_border),
                            avg_ssim_center, N_center, avg_ssim_border, N_border))

        logger.info('################ Tidy Outputs ################')
        for sub_name, psnr, psnr_center, psnr_border, ssim, ssim_center, ssim_border in zip(
                evaled_subs, avg_psnrs, avg_psnr_centers, avg_psnr_borders,
                avg_ssims, avg_ssim_centers, avg_ssim_borders):
            logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                        'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                            sub_name, psnr, psnr_center, psnr_border))
            logger.info('Folder {} - Average SSIM: {:.6f} '
                        'Center SSIM: {:.6f} Border SSIM: {:.6f} '.format(
                            sub_name, ssim, ssim_center, ssim_border))
        logger.info('################ Final Results ################')
        logger.info('Data: {}'.format(opts.test_dir))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnrs) / len(avg_psnrs), len(test_subs),
                        sum(avg_psnr_centers) / len(avg_psnr_centers),
                        sum(avg_psnr_borders) / len(avg_psnr_borders)))
        logger.info('Total Average SSIM: {:.6f} for {} clips. '
                    'Center SSIM: {:.6f} Border SSIM: {:.6f} '.format(
                        sum(avg_ssims) / len(avg_ssims), len(test_subs),
                        sum(avg_ssim_centers) / len(avg_ssim_centers),
                        sum(avg_ssim_borders) / len(avg_ssim_borders)))
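############################################################################
# util.flipx4_forward above is a flip self-ensemble: run the model on the
# input and its H/W/HW flips, un-flip each output, and average. A minimal
# sketch, assuming the last two dims of the (B, T, C, H, W) input are
# spatial and the model returns a tensor with matching spatial layout:
def flipx4_forward_sketch(model, inp):
    def fwd(x):
        with torch.no_grad():
            return model(x)
    out = fwd(inp)
    out = out + torch.flip(fwd(torch.flip(inp, [-1])), [-1])          # W flip
    out = out + torch.flip(fwd(torch.flip(inp, [-2])), [-2])          # H flip
    out = out + torch.flip(fwd(torch.flip(inp, [-2, -1])), [-2, -1])  # H+W flip
    return out / 4.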
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument("-opt", type=str, help="Path to option YAML file.")
    parser.add_argument("--launcher", choices=["none", "pytorch"],
                        default="none", help="job launcher")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    # distributed training settings
    if args.launcher == "none":  # disabled distributed training
        opt["dist"] = False
        rank = -1
        print("Disabled distributed training.")
    else:
        opt["dist"] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    # load the resume state if it exists
    if opt["path"].get("resume_state", None):
        # distributed resuming: all processes load into the default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt["path"]["resume_state"],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state["iter"])  # check resume options
    else:
        resume_state = None

    # mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt["path"]["experiments_root"])  # rename experiment folder if it exists
            util.mkdirs(
                (path for key, path in opt["path"].items()
                 if not key == "experiments_root"
                 and "pretrain_model" not in key and "resume" not in key))

        # config loggers. Before this, the log will not work
        util.setup_logger("base", opt["path"]["log"], "train_" + opt["name"],
                          level=logging.INFO, screen=True, tofile=True)
        logger = logging.getLogger("base")
        logger.info(option.dict2str(opt))

        # tensorboard logger
        if opt["use_tb_logger"] and "debug" not in opt["name"]:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1+
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info("You are using PyTorch {}. "
                            "Tensorboard will use [tensorboardX]".format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir="../tb_logger/" + opt["name"])
    else:
        util.setup_logger("base", opt["path"]["log"], "train",
                          level=logging.INFO, screen=True)
        logger = logging.getLogger("base")

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    # random seed
    seed = opt["train"]["manual_seed"]
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info("Random seed: {}".format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloaders
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt["datasets"].items():
        if phase == "train":
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt["batch_size"]))
            total_iters = int(opt["train"]["niter"])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt["dist"]:
                train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
                total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            if rank <= 0:
                logger.info("Number of train images: {:,d}, iters: {:,d}".format(
                    len(train_set), train_size))
                logger.info("Total epochs needed: {:d} for iters {:,d}".format(
                    total_epochs, total_iters))
        elif phase == "val":
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info("Number of val images in [{:s}]: {:d}".format(
                    dataset_opt["name"], len(val_set)))
        else:
            raise NotImplementedError("Phase [{:s}] is not recognized.".format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)
    print("Model created!")

    # resume training
    if resume_state:
        logger.info("Resuming training from epoch: {}, iter: {}.".format(
            resume_state["epoch"], resume_state["iter"]))
        start_epoch = resume_state["epoch"]
        current_step = resume_state["iter"]
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info("Start training from epoch: {:d}, iter: {:d}".format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt["dist"]:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt["train"]["warmup_iter"])

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt["logger"]["print_freq"] == 0:
                logs = model.get_current_log()
                message = "[epoch:{:3d}, iter:{:8,d}, lr:(".format(epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += "{:.3e},".format(v)
                message += ")] "
                for k, v in logs.items():
                    message += "{:s}: {:.4e} ".format(k, v)
                    # tensorboard logger
                    if opt["use_tb_logger"] and "debug" not in opt["name"]:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            # validation
            if opt["datasets"].get("val", None) and \
                    current_step % opt["train"]["val_freq"] == 0:
                if opt["model"] in ["sr", "srgan"] and rank <= 0:
                    # image restoration validation (does not support multi-GPU)
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.0
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(
                            os.path.basename(val_data["LQ_path"][0]))[0]
                        img_dir = os.path.join(opt["path"]["val_images"], img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test()
                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals["rlt"])  # uint8
                        gt_img = util.tensor2img(visuals["GT"])  # uint8

                        # save SR images for reference
                        save_img_path = os.path.join(
                            img_dir, "{:s}_{:d}.png".format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        sr_img, gt_img = util.crop_border([sr_img, gt_img], opt["scale"])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        pbar.update("Test {}".format(img_name))
                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info("# Validation # PSNR: {:.4e}".format(avg_psnr))
                    # tensorboard logger
                    if opt["use_tb_logger"] and "debug" not in opt["name"]:
                        tb_logger.add_scalar("psnr", avg_psnr, current_step)
                else:
                    # video restoration validation
                    if opt["dist"]:
                        # multi-GPU testing
                        psnr_rlt = {}  # with border and center frames
                        if rank == 0:
                            pbar = util.ProgressBar(len(val_set))
                        for idx in range(rank, len(val_set), world_size):
                            val_data = val_set[idx]
                            val_data["LQs"].unsqueeze_(0)
                            val_data["GT"].unsqueeze_(0)
                            folder = val_data["folder"]
                            idx_d, max_idx = val_data["idx"].split("/")
                            idx_d, max_idx = int(idx_d), int(max_idx)
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = torch.zeros(
                                    max_idx, dtype=torch.float32, device="cuda")
                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals["rlt"])  # uint8
                            gt_img = util.tensor2img(visuals["GT"])  # uint8
                            # calculate PSNR
                            psnr_rlt[folder][idx_d] = util.calculate_psnr(rlt_img, gt_img)
                            if rank == 0:
                                for _ in range(world_size):
                                    pbar.update("Test {} - {}/{}".format(
                                        folder, idx_d, max_idx))
                        # collect data
                        for _, v in psnr_rlt.items():
                            dist.reduce(v, 0)
                        dist.barrier()

                        if rank == 0:
                            psnr_rlt_avg = {}
                            psnr_total_avg = 0.0
                            for k, v in psnr_rlt.items():
                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                                psnr_total_avg += psnr_rlt_avg[k]
                            psnr_total_avg /= len(psnr_rlt)
                            log_s = "# Validation # PSNR: {:.4e}:".format(psnr_total_avg)
                            for k, v in psnr_rlt_avg.items():
                                log_s += " {}: {:.4e}".format(k, v)
                            logger.info(log_s)
                            if opt["use_tb_logger"] and "debug" not in opt["name"]:
                                tb_logger.add_scalar("psnr_avg", psnr_total_avg,
                                                     current_step)
                                for k, v in psnr_rlt_avg.items():
                                    tb_logger.add_scalar(k, v, current_step)
                    else:
                        pbar = util.ProgressBar(len(val_loader))
                        psnr_rlt = {}  # with border and center frames
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.0
                        for val_data in val_loader:
                            folder = val_data["folder"][0]
                            idx_d, max_id = val_data["idx"][0].split("/")
                            # border = val_data['border'].item()
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = []

                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals["rlt"])  # uint8
                            gt_img = util.tensor2img(visuals["GT"])  # uint8
                            lq_img = util.tensor2img(visuals["LQ"][2])  # uint8

                            img_dir = opt["path"]["val_images"]
                            util.mkdir(img_dir)
                            save_img_path = os.path.join(img_dir, "{}.png".format(idx_d))
                            util.save_img(np.hstack((lq_img, rlt_img, gt_img)),
                                          save_img_path)

                            # calculate PSNR
                            psnr = util.calculate_psnr(rlt_img, gt_img)
                            psnr_rlt[folder].append(psnr)
                            pbar.update("Test {} - {}".format(folder, idx_d))
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = sum(v) / len(v)
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = "# Validation # PSNR: {:.4e}:".format(psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += " {}: {:.4e}".format(k, v)
                        logger.info(log_s)
                        if opt["use_tb_logger"] and "debug" not in opt["name"]:
                            tb_logger.add_scalar("psnr_avg", psnr_total_avg, current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)

            # save models and training states
            if current_step % opt["logger"]["save_checkpoint_freq"] == 0:
                if rank <= 0:
                    logger.info("Saving models and training states.")
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info("Saving the final model.")
        model.save("latest")
        logger.info("End of training.")
        # close only if the tensorboard writer was actually created above
        if opt["use_tb_logger"] and "debug" not in opt["name"]:
            tb_logger.close()
def main(opts):
    ################## configurations #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpus
    with open(opts.cfg, mode='r') as f:
        # use an explicit Loader; a bare yaml.load() is deprecated and unsafe
        cfgs = yaml.load(f, Loader=yaml.FullLoader)
    flip_test, save_imgs = cfgs['test']['flip_test'], cfgs['test']['save_imgs']
    scale = cfgs['datasets']['scale']
    crop_border = cfgs['test']['crop_border']
    nframe = cfgs['datasets']['N_frames']
    border_frame = nframe // 2
    half_nframe = nframe // 2
    padding = cfgs['datasets']['padding']
    flow_path = cfgs['path']['pretrain_model_F']
    cache_all_imgs = cfgs['datasets']['cache_data']

    ################## model files ####################
    model_dir = opts.model_dir
    if osp.isfile(model_dir):
        model_names = [osp.basename(model_dir)]
        model_dir = osp.dirname(model_dir)
    elif osp.isdir(model_dir):
        model_names = [
            x for x in os.listdir(model_dir) if str.isdigit(x.split('_')[0])
        ]
        model_names = sorted(model_names, key=lambda x: int(x.split("_")[0]))
    else:
        raise IOError('Invalid model_dir: {}'.format(model_dir))

    ################## dataset ########################
    test_subs = sorted(os.listdir(opts.test_dir))
    gt_subs = os.listdir(opts.gt_dir)
    valid_test_subs = [sub in gt_subs for sub in test_subs]
    assert all(valid_test_subs), 'Invalid sub-folders exist in {}'.format(opts.test_dir)

    if cache_all_imgs:
        print('Caching all testing images ...')
        all_imgs = {}
        for sub in test_subs:
            print('Reading sub-folder: {} ...'.format(sub))
            test_sub_dir = osp.join(opts.test_dir, sub)
            gt_sub_dir = osp.join(opts.gt_dir, sub)
            all_imgs[sub] = {'test': [], 'gt': []}
            im_names = sorted(os.listdir(test_sub_dir))
            for i, name in enumerate(im_names):
                test_im_path = osp.join(test_sub_dir, name)
                gt_im_path = osp.join(gt_sub_dir, name)
                test_im = cv2.imread(test_im_path, cv2.IMREAD_UNCHANGED)[:, :, (2, 1, 0)]
                test_im = test_im.astype(np.float32).transpose((2, 0, 1)) / 255.
                all_imgs[sub]['test'].append(test_im)
                gt_im = cv2.imread(gt_im_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
                all_imgs[sub]['gt'].append(gt_im)

    #################### model ########################
    model = define_G(cfgs)
    netF = define_F(cfgs)
    netF.load_state_dict(torch.load(flow_path), strict=True)
    model = model.to(device)
    netF = netF.to(device)

    all_psnrs = []
    for model_name in model_names:
        model_path = osp.join(model_dir, model_name)
        exp_name = model_name.split('_')[0]
        save_folder = osp.join(opts.save_dir, exp_name)
        util.mkdirs(save_folder)
        util.setup_logger(exp_name, save_folder, 'test', level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger(exp_name)

        #### log info
        logger.info('Data: {}'.format(opts.test_dir))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))

        #### load model parameters
        model.load_state_dict(torch.load(model_path), strict=True)
        model.eval()

        avg_psnrs, avg_psnr_centers, avg_psnr_borders = [], [], []
        evaled_subs = []
        # for each subfolder
        for sub in test_subs:
            evaled_subs.append(sub)
            test_sub_dir = osp.join(opts.test_dir, sub)
            gt_sub_dir = osp.join(opts.gt_dir, sub)
            img_names = sorted(os.listdir(test_sub_dir))
            max_idx = len(img_names)
            if save_imgs:
                save_subfolder = osp.join(save_folder, sub)
                util.mkdirs(save_subfolder)

            #### get LQ and GT images
            if not cache_all_imgs:
                img_LQs, img_GTs = [], []
                for i, name in enumerate(img_names):
                    test_im_path = osp.join(test_sub_dir, name)
                    gt_im_path = osp.join(gt_sub_dir, name)
                    test_im = cv2.imread(test_im_path, cv2.IMREAD_UNCHANGED)[:, :, (2, 1, 0)]
                    test_im = test_im.astype(np.float32).transpose((2, 0, 1)) / 255.
                    gt_im = cv2.imread(gt_im_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
                    img_LQs.append(test_im)
                    img_GTs.append(gt_im)
            else:
                img_LQs = all_imgs[sub]['test']
                img_GTs = all_imgs[sub]['gt']

            avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

            # process each image; the previous output is re-used for warping
            previous = None
            for i in range(0, max_idx):
                select_idxs = data_util.index_generation(i, max_idx, nframe,
                                                         padding=padding)
                lqs = torch.from_numpy(np.stack([img_LQs[k] for k in select_idxs]))
                lqs = lqs.unsqueeze(0).to(device)
                upX = F.interpolate(lqs[:, half_nframe, :, :, :], scale_factor=scale,
                                    mode='bilinear', align_corners=False)
                if i == 0:
                    wrap_img = upX
                else:
                    if previous is not None:
                        pre = previous.clone()
                    else:
                        pre = img_GTs[i - 1][:, :, (2, 1, 0)].transpose((2, 0, 1)) / 255.
                        pre = torch.from_numpy(pre).unsqueeze(0).to(device)
                    first = F.interpolate(lqs[:, half_nframe - 1, :, :, :],
                                          scale_factor=scale, mode='bilinear',
                                          align_corners=False)
                    second = upX
                    b_flow = Flow_arch.estimate_flow(netF, second, first)
                    wrap_img, _ = Flow_arch.wraping(pre, b_flow)

                with torch.no_grad():
                    output = model(lqs, wrap_img, scale)
                    output = output[0].detach()
                previous = output
                output = util.tensor2img(output).astype(np.float32)

                if save_imgs:
                    cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_names[i])),
                                output.astype(np.uint8))

                # calculate PSNR
                GT = np.copy(img_GTs[i])
                output = util.crop_border(output, crop_border)
                GT = util.crop_border(GT, crop_border)
                crt_psnr = util.calculate_psnr(output, GT)
                logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(
                    i + 1, img_names[i], crt_psnr))

                if i >= border_frame and i < max_idx - border_frame:  # center frames
                    avg_psnr_center += crt_psnr
                    N_center += 1
                else:  # border frames
                    avg_psnr_border += crt_psnr
                    N_border += 1

            avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border)
            avg_psnr_center = avg_psnr_center / N_center
            avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
            avg_psnrs.append(avg_psnr)
            avg_psnr_centers.append(avg_psnr_center)
            avg_psnr_borders.append(avg_psnr_border)

            logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                        'Center PSNR: {:.6f} dB for {} frames; '
                        'Border PSNR: {:.6f} dB for {} frames.'.format(
                            sub, avg_psnr, (N_center + N_border),
                            avg_psnr_center, N_center, avg_psnr_border, N_border))

        logger.info('################ Tidy Outputs ################')
        for subfolder_name, psnr, psnr_center, psnr_border in zip(
                evaled_subs, avg_psnrs, avg_psnr_centers, avg_psnr_borders):
            logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                        'Center PSNR: {:.6f} dB. '
                        'Border PSNR: {:.6f} dB.'.format(
                            subfolder_name, psnr, psnr_center, psnr_border))
        logger.info('################ Final Results ################')
        logger.info('Data: {}'.format(opts.test_dir))
        logger.info('Padding mode: {}'.format(padding))
        logger.info('Model path: {}'.format(model_path))
        logger.info('Save images: {}'.format(save_imgs))
        logger.info('Flip test: {}'.format(flip_test))
        logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                    'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                        sum(avg_psnrs) / len(avg_psnrs), len(test_subs),
                        sum(avg_psnr_centers) / len(avg_psnr_centers),
                        sum(avg_psnr_borders) / len(avg_psnr_borders)))
        all_psnrs.append(avg_psnrs + [model_name])

    # NOTE: the '/ 4.' and the five format slots assume exactly four test clips
    with open(osp.join(opts.save_dir, 'all_psnrs.txt'), 'w') as f:
        for psnrs in all_psnrs:
            f.write("{:>14s}: {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} \n".format(
                psnrs[-1], sum(psnrs[:-1]) / 4., psnrs[0], psnrs[1],
                psnrs[2], psnrs[3]))
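############################################################################
# Flow_arch.wraping above backward-warps the previous SR output by the
# estimated flow. A minimal sketch of such a warp with F.grid_sample; the
# (dx, dy) flow channel convention is an assumption, not taken from the
# Flow_arch source:
def backward_warp_sketch(x, flow):
    # x: (B, C, H, W) image; flow: (B, 2, H, W) displacement in pixels
    B, C, H, W = x.shape
    yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W))  # base pixel grid
    grid = torch.stack((xx, yy), dim=0).float().to(x.device)   # (2, H, W), (x, y)
    vgrid = grid.unsqueeze(0) + flow                           # sample positions
    # normalize positions to [-1, 1] as grid_sample expects
    vx = 2.0 * vgrid[:, 0] / max(W - 1, 1) - 1.0
    vy = 2.0 * vgrid[:, 1] / max(H - 1, 1) - 1.0
    return F.grid_sample(x, torch.stack((vx, vy), dim=3), align_corners=True)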