def main(): ################# # configurations ################# device = torch.device('cuda') os.environ['CUDA_VISIBLE_DEVICES'] = '1' data_mode = 'ai4khdr_test' flip_test = False ############################################################################ #### model ################# if data_mode == 'ai4khdr_test': model_path = '../experiments/002_EDVR_lr4e-4_600k_AI4KHDR/models/4000_G.pth' else: raise NotImplementedError N_in = 5 front_RBs = 5 back_RBs = 10 predeblur, HR_in = False, False model = EDVR_arch.EDVR(64, N_in, 8, front_RBs, back_RBs, predeblur=predeblur, HR_in=HR_in) ############################################################################ #### dataset ################# if data_mode == 'ai4khdr_test': test_dataset_folder = '/workspace/nas_mengdongwei/dataset/AI4KHDR/test/540p_frames' else: raise NotImplementedError ############################################################################ #### evaluation crop_border = 0 border_frame = N_in // 2 # border frames when evaluate # temporal padding mode if data_mode == 'ai4khdr_test': padding = 'new_info' else: padding = 'replicate' save_imgs = True save_folder = '../results/{}_{}'.format(data_mode, util.get_timestamp()) util.mkdirs(save_folder) util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') #### log info logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test)) #### set up the models model.load_state_dict(torch.load(model_path), strict=False) model.eval() model = model.to(device) subfolder_name_l = [] subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*'))) # for each subfolder for subfolder in subfolder_l: subfolder_name = osp.basename(subfolder) subfolder_name_l.append(subfolder_name) save_subfolder = osp.join(save_folder, subfolder_name) img_path_l = sorted(glob.glob(osp.join(subfolder, '*'))) max_idx = len(img_path_l) if save_imgs: util.mkdirs(save_subfolder) #### read LQ and GT images imgs_LQ = data_util.read_img_seq(subfolder) # process each image for img_idx, img_path in enumerate(img_path_l): img_name = osp.splitext(osp.basename(img_path))[0] select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding) imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) if flip_test: output = util.flipx4_forward(model, imgs_in) else: output = util.single_forward(model, imgs_in) output = util.tensor2img(output.squeeze(0)) if save_imgs: cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output) logger.info('Folder {}'.format(subfolder_name)) logger.info('################ Final Results ################') logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test))
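# --- Illustrative sketch (not part of the original script) ---
# The test loop above delegates temporal neighbor selection to
# data_util.index_generation(img_idx, max_idx, N_in, padding=...), with
# 'replicate' or 'new_info' handling frames near clip borders. The helper
# below is a hypothetical re-implementation of that idea, added only to
# clarify what the two padding modes mean; it is not the repo's function.
def index_generation_sketch(crt_idx, max_n, n_frames, padding='replicate'):
    """Return n_frames neighbor indices centered on crt_idx."""
    half = n_frames // 2
    indices = []
    for i in range(crt_idx - half, crt_idx + half + 1):
        if i < 0:
            # replicate: clamp to the first frame; new_info: reflect inward
            indices.append(0 if padding == 'replicate' else -i)
        elif i >= max_n:
            # replicate: clamp to the last frame; new_info: reflect inward
            indices.append(max_n - 1 if padding == 'replicate' else 2 * max_n - 2 - i)
        else:
            indices.append(i)
    return indices

# index_generation_sketch(0, 100, 5, 'replicate') -> [0, 0, 0, 1, 2]
# index_generation_sketch(0, 100, 5, 'new_info')  -> [2, 1, 0, 1, 2]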
def main(): # options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, required=True, help='Path to options file.') opt = option.parse(parser.parse_args().opt, is_train=False) util.mkdirs((path for key, path in opt['path'].items() if not key == 'pretrain_model_G')) opt = option.dict_to_nonedict(opt) util.setup_logger(None, opt['path']['log'], 'test.log', level=logging.INFO, screen=True) logger = logging.getLogger('base') logger.info(option.dict2str(opt)) # Create test dataset and dataloader test_loaders = [] for phase, dataset_opt in sorted(opt['datasets'].items()): test_set = create_dataset(dataset_opt) test_loader = create_dataloader(test_set, dataset_opt) logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) test_loaders.append(test_loader) # Create model model = create_model(opt) for test_loader in test_loaders: test_set_name = test_loader.dataset.opt['name'] logger.info('\nTesting [{:s}]...'.format(test_set_name)) test_start_time = time.time() dataset_dir = os.path.join(opt['path']['results_root'], test_set_name) util.mkdir(dataset_dir) test_results = OrderedDict() test_results['psnr'] = [] test_results['ssim'] = [] test_results['psnr_y'] = [] test_results['ssim_y'] = [] for data in test_loader: need_HR = False if test_loader.dataset.opt['dataroot_HR'] is None else True model.feed_data(data, need_HR=need_HR) img_path = data['LR_path'][0] img_name = os.path.splitext(os.path.basename(img_path))[0] model.test() # test visuals = model.get_current_visuals(need_HR=need_HR) img_c = util.tensor2img(visuals['img_c']) # uint8 img_s = util.tensor2img(visuals['img_s']) # uint8 img_p = util.tensor2img(visuals['img_p']) # uint8 # save images suffix = opt['suffix'] if suffix: save_c_img_path = os.path.join(dataset_dir, img_name + suffix + '_c.png') save_s_img_path = os.path.join(dataset_dir, img_name + suffix + '_s.png') save_p_img_path = os.path.join(dataset_dir, img_name + suffix + '_p.png') else: save_c_img_path = os.path.join(dataset_dir, img_name + '_c.png') save_s_img_path = os.path.join(dataset_dir, img_name + '_s.png') save_p_img_path = os.path.join(dataset_dir, img_name + '_p.png') util.save_img(img_c, save_c_img_path) util.save_img(img_s, save_s_img_path) util.save_img(img_p, save_p_img_path) # calculate PSNR and SSIM if need_HR: gt_img = util.tensor2img(visuals['HR']) gt_img = gt_img / 255. sr_img = img_c / 255. 
crop_border = test_loader.dataset.opt['scale'] cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :] cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :] psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255) ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255) test_results['psnr'].append(psnr) test_results['ssim'].append(ssim) if gt_img.shape[2] == 3: # RGB image sr_img_y = bgr2ycbcr(sr_img, only_y=True) gt_img_y = bgr2ycbcr(gt_img, only_y=True) cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border] cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border] psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255) ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255) test_results['psnr_y'].append(psnr_y) test_results['ssim_y'].append(ssim_y) logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'\ .format(img_name, psnr, ssim, psnr_y, ssim_y)) else: logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim)) else: logger.info(img_name) if need_HR: # metrics # Average PSNR/SSIM results ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'\ .format(test_set_name, ave_psnr, ave_ssim)) if test_results['psnr_y'] and test_results['ssim_y']: ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y']) ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y']) logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'\ .format(ave_psnr_y, ave_ssim_y))
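# --- Illustrative sketch (not part of the original script) ---
# The evaluation above trims `scale` pixels from every border before PSNR/SSIM
# and, for RGB inputs, repeats the measurement on the Y channel. A minimal,
# NumPy-only version of the border-cropped PSNR step is sketched below,
# assuming util.calculate_psnr follows the standard definition.
import numpy as np

def cropped_psnr(sr, gt, crop_border):
    """PSNR between two images in the [0, 255] range after border cropping."""
    if crop_border > 0:
        sr = sr[crop_border:-crop_border, crop_border:-crop_border, ...]
        gt = gt[crop_border:-crop_border, crop_border:-crop_border, ...]
    mse = np.mean((sr.astype(np.float64) - gt.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(255.0 / np.sqrt(mse))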
def main():
    # Create object for parsing command-line options
    parser = argparse.ArgumentParser(
        description="Test with EDVR; requires the path to the test dataset folder.")
    # Add argument which takes the path to the test folder as input
    parser.add_argument("-i", "--input", type=str, help="Path to test folder")
    # Parse the command line arguments to an object
    args = parser.parse_args()
    # Safety check if no parameter has been given
    if not args.input:
        print("No input parameter has been given.")
        print("For help type --help")
        exit()
    folder_name = args.input.split("/")[-1]
    if folder_name == '':
        index = len(args.input.split("/")) - 2
        folder_name = args.input.split("/")[index]

    #################
    # configurations
    #################
    device = torch.device('cuda')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ############################################################################
    #### model
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'
        else:
            raise ValueError('Vid4 does not support stage 2.')
    else:
        raise NotImplementedError

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5
    predeblur, HR_in = False, False
    back_RBs = 40
    model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

    #### dataset
    if data_mode == 'Vid4':  # debug
        test_dataset_folder = os.path.join(args.input, 'BIx4')
        GT_dataset_folder = os.path.join(args.input, 'GT')
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        GT_dataset_folder = '../datasets/REDS4/GT'

    #### evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True
    save_folder = '../results/{}'.format(folder_name)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)
        img_GT_l = []
        for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
img_GT_l.append(data_util.read_img(None, img_GT_path)) avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0 # process each image for img_idx, img_path in enumerate(img_path_l): img_name = osp.splitext(osp.basename(img_path))[0] select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding) # print(select_idx) imgs_in = imgs_LQ.index_select( 0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) if flip_test: output = util.flipx4_forward(model, imgs_in) else: output = util.single_forward(model, imgs_in) output = util.tensor2img(output.squeeze(0)) if save_imgs: cv2.imwrite( osp.join(save_subfolder, '{}.png'.format(img_name)), output) # calculate PSNR output = output / 255. GT = np.copy(img_GT_l[img_idx]) # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel if data_mode == 'Vid4': # bgr2y, [0, 1] GT = data_util.bgr2ycbcr(GT, only_y=True) output = data_util.bgr2ycbcr(output, only_y=True) output, GT = util.crop_border([output, GT], crop_border) crt_psnr = util.calculate_psnr(output * 255, GT * 255) logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format( img_idx + 1, img_name, crt_psnr)) if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames avg_psnr_center += crt_psnr N_center += 1 else: # border frames avg_psnr_border += crt_psnr N_border += 1 avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border) avg_psnr_center = avg_psnr_center / N_center avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border avg_psnr_l.append(avg_psnr) avg_psnr_center_l.append(avg_psnr_center) avg_psnr_border_l.append(avg_psnr_border) logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; ' 'Center PSNR: {:.6f} dB for {} frames; ' 'Border PSNR: {:.6f} dB for {} frames.'.format( subfolder_name, avg_psnr, (N_center + N_border), avg_psnr_center, N_center, avg_psnr_border, N_border)) logger.info('################ Tidy Outputs ################') for subfolder_name, psnr, psnr_center, psnr_border in zip( subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l): logger.info('Folder {} - Average PSNR: {:.6f} dB. ' 'Center PSNR: {:.6f} dB. ' 'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border)) logger.info('################ Final Results ################') logger.info('Data: {} - {}'.format(folder_name, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test)) logger.info('Total Average PSNR: {:.6f} dB for {} clips. ' 'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format( sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l), sum(avg_psnr_center_l) / len(avg_psnr_center_l), sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
def create_test_png(model_path, device, gpu_id, opt, subfolder_l, save_folder, save_imgs,
                    frame_notation, N_in, PAD, flip_test, end, total_run_time, logger, padding):
    model = EDVR_arch.EDVR(nf=opt['network_G']['nf'],
                           nframes=opt['network_G']['nframes'],
                           groups=opt['network_G']['groups'],
                           front_RBs=opt['network_G']['front_RBs'],
                           back_RBs=opt['network_G']['back_RBs'],
                           predeblur=opt['network_G']['predeblur'],
                           HR_in=opt['network_G']['HR_in'],
                           w_TSA=opt['network_G']['w_TSA'])
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)
    # if (torch.cuda.is_available()):
    model = model.cuda(gpu_id)

    for subfolder in subfolder_l:
        input_subfolder = os.path.split(subfolder)[1]
        # subfolder_GT = os.path.join(GT_dataset_folder, input_subfolder)
        # if not os.path.exists(subfolder_GT):
        #     continue
        print("Evaluate Folders: ", input_subfolder)

        subfolder_name = osp.basename(subfolder)
        # subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ and GT images
        imgs_LQ = data_util.read_img_seq(subfolder)  # Num x 3 x H x W
        # img_GT_l = []
        # for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))):
        #     img_GT_l.append(data_util.read_img(None, img_GT_path))
        # avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # TODO: handle screen change here
            select_idx, log1, log2, nota = data_util.index_generation_process_screen_change_withlog_fixbug(
                input_subfolder, frame_notation, img_idx, max_idx, N_in, padding=padding)

            if log1 is not None:
                logger.info('screen change')
                logger.info(nota)
                logger.info(log1)
                logger.info(log2)

            imgs_in = imgs_LQ.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).cuda(gpu_id)  # 960 x 540

            # here we split the 960x540 input images into 320x180 patches
            gtWidth = 3840
            gtHeight = 2160
            intWidth_ori = imgs_in.shape[4]   # 960
            intHeight_ori = imgs_in.shape[3]  # 540
            split_lengthY = 180
            split_lengthX = 320
            scale = 4

            intPaddingRight_ = int(float(intWidth_ori) / split_lengthX + 1) * split_lengthX - intWidth_ori
            intPaddingBottom_ = int(float(intHeight_ori) / split_lengthY + 1) * split_lengthY - intHeight_ori
            intPaddingRight_ = 0 if intPaddingRight_ == split_lengthX else intPaddingRight_
            intPaddingBottom_ = 0 if intPaddingBottom_ == split_lengthY else intPaddingBottom_
            pader0 = torch.nn.ReplicationPad2d([0, intPaddingRight_, 0, intPaddingBottom_])
            print("Init pad right/bottom " + str(intPaddingRight_) + " / " + str(intPaddingBottom_))

            intPaddingRight = PAD   # 32 / 64 / 128 / 256
            intPaddingLeft = PAD
            intPaddingTop = PAD
            intPaddingBottom = PAD
            pader = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom])

            imgs_in = torch.squeeze(imgs_in, 0)  # N C H W
            imgs_in = pader0(imgs_in)  # N C 540 960
            imgs_in = pader(imgs_in)   # N C 604 1024

            assert (split_lengthY == int(split_lengthY) and split_lengthX == int(split_lengthX))
            split_lengthY = int(split_lengthY)
            split_lengthX = int(split_lengthX)
            split_numY = int(float(intHeight_ori) / split_lengthY)
            split_numX = int(float(intWidth_ori) / split_lengthX)
            splitsY = range(0, split_numY)
            splitsX = range(0, split_numX)

            intWidth = split_lengthX
            intWidth_pad = intWidth + intPaddingLeft + intPaddingRight
            intHeight = split_lengthY
            intHeight_pad = intHeight + intPaddingTop + intPaddingBottom
            # print("split " + str(split_numY) + ' , ' + str(split_numX))
            y_all = np.zeros((gtHeight, gtWidth, 3), dtype="float32")  # HWC

            for split_j, split_i in itertools.product(splitsY, splitsX):
                # print(str(split_j) + ", \t " + str(split_i))
                X0 = imgs_in[:, :,
                             split_j * split_lengthY:(split_j + 1) * split_lengthY + intPaddingBottom + intPaddingTop,
                             split_i * split_lengthX:(split_i + 1) * split_lengthX + intPaddingRight + intPaddingLeft]
                # y_ = torch.FloatTensor()
                X0 = torch.unsqueeze(X0, 0)  # N C H W -> 1 N C H W
                # X0 = X0.cuda(gpu_id)

                if flip_test:
                    output = util.flipx4_forward(model, X0)
                else:
                    output = util.single_forward(model, X0)

                output_depadded = output[0, :,
                                         intPaddingTop * scale:(intPaddingTop + intHeight) * scale,
                                         intPaddingLeft * scale:(intPaddingLeft + intWidth) * scale]
                output_depadded = output_depadded.squeeze(0)
                output = util.tensor2img(output_depadded)

                y_all[split_j * split_lengthY * scale:(split_j + 1) * split_lengthY * scale,
                      split_i * split_lengthX * scale:(split_i + 1) * split_lengthX * scale, :] = \
                    np.round(output).astype(np.uint8)

                # plt.figure(0)
                # plt.title("pic")
                # plt.imshow(y_all)

            if save_imgs:
                cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), y_all)

            print("*****************current image process time \t " + str(time.time() - end) + "s ******************")
            total_run_time.update(time.time() - end, 1)

            # calculate PSNR
            # y_all = y_all / 255.
            # GT = np.copy(img_GT_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel
            # if data_mode == 'Vid4':  # bgr2y, [0, 1]
            #     GT = data_util.bgr2ycbcr(GT, only_y=True)
            #     y_all = data_util.bgr2ycbcr(y_all, only_y=True)
            # y_all, GT = util.crop_border([y_all, GT], crop_border)
            # crt_psnr = util.calculate_psnr(y_all * 255, GT * 255)
            # logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr))

            logger.info('{} : {:3d} - {:25} \t'.format(input_subfolder, img_idx + 1, img_name))
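# --- Illustrative sketch (not part of the original script) ---
# create_test_png pads the 960x540 input, cuts it into 320x180 tiles with a
# PAD-pixel replication margin, runs the model per tile, crops the margin at
# the x4 output scale, and pastes the result into the 3840x2160 canvas. The
# helper below condenses that pad / crop-at-scale bookkeeping for a single
# image, assuming H and W are exact multiples of the tile size and a generic
# `upscale(tensor) -> tensor` callable; it is a sketch, not the repo's code.
import torch
import torch.nn.functional as F

def tiled_upscale(img, upscale, tile=(180, 320), pad=32, scale=4):
    """img: 1xCxHxW in [0, 1]; returns 1xCx(H*scale)x(W*scale)."""
    _, c, h, w = img.shape
    padded = F.pad(img, (pad, pad, pad, pad), mode='replicate')
    out = torch.zeros(1, c, h * scale, w * scale)
    for y in range(0, h, tile[0]):
        for x in range(0, w, tile[1]):
            # tile plus its replication margin
            patch = padded[:, :, y:y + tile[0] + 2 * pad, x:x + tile[1] + 2 * pad]
            up = upscale(patch)
            # drop the (now pad*scale wide) margin so tile seams are hidden
            up = up[:, :, pad * scale:(pad + tile[0]) * scale,
                    pad * scale:(pad + tile[1]) * scale]
            out[:, :, y * scale:(y + tile[0]) * scale,
                x * scale:(x + tile[1]) * scale] = up
    return out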
def main(): ###### SFTMD train ###### #### setup options parser = argparse.ArgumentParser() parser.add_argument( "-opt_F", type=str, default="options/train/SFTMD/train_SFTMD_x4.yml", help="Path to option YMAL file of SFTMD_Net.", ) parser.add_argument("--launcher", choices=["none", "pytorch"], default="none", help="job launcher") parser.add_argument("--local_rank", type=int, default=0) args = parser.parse_args() opt_F = option.parse(args.opt_F, is_train=True) # convert to NoneDict, which returns None for missing keys opt_F = option.dict_to_nonedict(opt_F) #### random seed seed = opt_F["train"]["manual_seed"] if seed is None: seed = random.randint(1, 10000) util.set_random_seed(seed) # load PCA matrix of enough kernel print("load PCA matrix") pca_matrix = torch.load( "../../../pca_matrix/IKC/pca_matrix.pth", map_location=lambda storage, loc: storage, ) print("PCA matrix shape: {}".format(pca_matrix.shape)) #### distributed training settings if args.launcher == "none": # disabled distributed training opt_F["dist"] = False rank = -1 print("Disabled distributed training.") else: opt_F["dist"] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### loading resume state if exists if opt_F["path"].get("resume_state", None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load( opt_F["path"]["resume_state"], map_location=lambda storage, loc: storage.cuda(device_id), ) option.check_resume(opt_F, resume_state["iter"]) # check resume options else: resume_state = None #### mkdir and loggers if rank <= 0: if resume_state is None: util.mkdir_and_rename( opt_F["path"] ["experiments_root"]) # rename experiment folder if exists util.mkdirs( (path for key, path in opt_F["path"].items() if not key == "experiments_root" and "pretrain_model" not in key and "resume" not in key)) os.system("rm ./log") os.symlink(os.path.join(opt_F["path"]["experiments_root"], ".."), "./log") # config loggers. Before it, the log will not work util.setup_logger( "base", opt_F["path"]["log"], "train_" + opt_F["name"], level=logging.INFO, screen=True, tofile=True, ) util.setup_logger( "val", opt_F["path"]["log"], "val_" + opt_F["name"], level=logging.INFO, screen=True, tofile=True, ) logger = logging.getLogger("base") logger.info(option.dict2str(opt_F)) # tensorboard logger if opt_F["use_tb_logger"] and "debug" not in opt_F["name"]: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( "You are using PyTorch {}. 
Tensorboard will use [tensorboardX]" .format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter( log_dir="log/{}/tb_logger/".format(opt_F["name"])) else: util.setup_logger("base", opt_F["path"]["log"], "train", level=logging.INFO, screen=True) logger = logging.getLogger("base") #### create train and val dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt_F["datasets"].items(): if phase == "train": train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt["batch_size"])) total_iters = int(opt_F["train"]["niter"]) total_epochs = int(math.ceil(total_iters / train_size)) if opt_F["dist"]: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) total_epochs = int( math.ceil(total_iters / (train_size * dataset_ratio))) else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt_F, train_sampler) if rank <= 0: logger.info( "Number of train images: {:,d}, iters: {:,d}".format( len(train_set), train_size)) logger.info("Total epochs needed: {:d} for iters {:,d}".format( total_epochs, total_iters)) elif phase == "val": val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt_F, None) if rank <= 0: logger.info("Number of val images in [{:s}]: {:d}".format( dataset_opt["name"], len(val_set))) else: raise NotImplementedError( "Phase [{:s}] is not recognized.".format(phase)) assert train_loader is not None assert val_loader is not None #### create model model_F = create_model(opt_F) #### resume training if resume_state: logger.info("Resuming training from epoch: {}, iter: {}.".format( resume_state["epoch"], resume_state["iter"])) start_epoch = resume_state["epoch"] current_step = resume_state["iter"] model_F.resume_training( resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 #### training logger.info("Start training from epoch: {:d}, iter: {:d}".format( start_epoch, current_step)) for epoch in range(start_epoch, total_epochs + 1): if opt_F["dist"]: train_sampler.set_epoch(epoch) for _, train_data in enumerate(train_loader): current_step += 1 if current_step > total_iters: break #### preprocessing for LR_img and kernel map prepro = util.SRMDPreprocessing( opt_F["scale"], pca_matrix, random=True, para_input=opt_F["code_length"], kernel=opt_F["kernel_size"], noise=False, cuda=True, sig=opt_F["sig"], sig_min=opt_F["sig_min"], sig_max=opt_F["sig_max"], rate_iso=1.0, scaling=3, rate_cln=0.2, noise_high=0.0, ) LR_img, ker_map = prepro(train_data["GT"]) #### update learning rate, schedulers model_F.update_learning_rate( current_step, warmup_iter=opt_F["train"]["warmup_iter"]) #### training model_F.feed_data(train_data, LR_img, ker_map) model_F.optimize_parameters(current_step) #### log if current_step % opt_F["logger"]["print_freq"] == 0: logs = model_F.get_current_log() message = "<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> ".format( epoch, current_step, model_F.get_current_learning_rate()) for k, v in logs.items(): message += "{:s}: {:.4e} ".format(k, v) # tensorboard logger if opt_F["use_tb_logger"] and "debug" not in opt_F["name"]: if rank <= 0: tb_logger.add_scalar(k, v, current_step) if rank <= 0: logger.info(message) # validation if current_step % opt_F["train"]["val_freq"] == 0 and rank <= 0: avg_psnr = 0.0 idx = 0 for _, val_data in enumerate(val_loader): idx += 1 #### preprocessing for LR_img and kernel map prepro = util.SRMDPreprocessing( opt_F["scale"], pca_matrix, random=True, 
para_input=opt_F["code_length"], kernel=opt_F["kernel_size"], noise=False, cuda=True, sig=opt_F["sig"], sig_min=opt_F["sig_min"], sig_max=opt_F["sig_max"], rate_iso=1.0, scaling=3, rate_cln=0.2, noise_high=0.0, ) LR_img, ker_map = prepro(val_data["GT"]) model_F.feed_data(val_data, LR_img, ker_map) model_F.test() visuals = model_F.get_current_visuals() sr_img = util.tensor2img(visuals["SR"]) # uint8 gt_img = util.tensor2img(visuals["GT"]) # uint8 # Save SR images for reference img_name = os.path.splitext( os.path.basename(val_data["LQ_path"][0]))[0] # img_dir = os.path.join(opt_F['path']['val_images'], img_name) img_dir = os.path.join(opt_F["path"]["val_images"], str(current_step)) util.mkdir(img_dir) save_img_path = os.path.join( img_dir, "{:s}_{:d}.png".format(img_name, current_step)) util.save_img(sr_img, save_img_path) # calculate PSNR crop_size = opt_F["scale"] gt_img = gt_img / 255.0 sr_img = sr_img / 255.0 cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255) avg_psnr = avg_psnr / idx # log logger.info("# Validation # PSNR: {:.4e}".format(avg_psnr)) logger_val = logging.getLogger("val") # validation logger logger_val.info( "<epoch:{:3d}, iter:{:8,d}> psnr: {:.6f}".format( epoch, current_step, avg_psnr)) # tensorboard logger if opt_F["use_tb_logger"] and "debug" not in opt_F["name"]: tb_logger.add_scalar("psnr", avg_psnr, current_step) #### save models and training states if current_step % opt_F["logger"]["save_checkpoint_freq"] == 0: if rank <= 0: logger.info("Saving models and training states.") model_F.save(current_step) model_F.save_training_state(epoch, current_step) if rank <= 0: logger.info("Saving the final model.") model_F.save("latest") logger.info("End of SFTMD training.")
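# --- Illustrative sketch (not part of the original script) ---
# util.SRMDPreprocessing blurs and downsamples each GT batch with a random
# Gaussian kernel and returns the LR image plus a low-dimensional kernel code,
# obtained by projecting the flattened kernel onto the PCA matrix loaded at
# startup. The projection step is sketched below with illustrative names and
# an assumed shape (pca_matrix: (code_length, k*k)); it is not the repo's API.
import torch

def kernel_to_code(kernels, pca_matrix):
    """kernels: (B, k, k) blur kernels -> (B, code_length) kernel codes."""
    flat = kernels.reshape(kernels.size(0), -1)  # (B, k*k)
    return flat @ pca_matrix.t()                 # (B, code_length)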
test_results['ssim_y'] = [] time_total = 0.0 time_cnt = 0 for data in test_loader: need_GT = False if test_loader.dataset.opt[ 'dataroot_GT'] is None else True model.feed_data(data, need_GT=need_GT) video_name = data['key'][0].split('_')[1] category_name = data['key'][0].split('_')[0] model.test() visuals = model.get_current_visuals(need_GT=need_GT) sr_img = util.tensor2img(visuals['rlt']) # uint8 # save images suffix = opt['suffix'] if suffix: save_img_path = osp.join(dataset_dir, img_name + suffix + '.png') else: os.makedirs(os.path.join(dataset_dir, category_name, video_name), exist_ok=True) save_img_path = osp.join(dataset_dir, category_name, video_name, 'im4.png') util.save_img(sr_img, save_img_path) # calculate PSNR and SSIM if need_GT: gt_img = util.tensor2img(visuals['GT'])
# Corrector test for step in range(opt_C['step']): step += 1 # Test SFTMD to produce SR images model_F.feed_data(test_data, LR_img, est_ker_map) model_F.test() F_visuals = model_F.get_current_visuals() SR_img = F_visuals['Batch_SR'] model_C.feed_data(SR_img, est_ker_map, ker_map) model_C.test() C_visuals = model_C.get_current_visuals() est_ker_map = C_visuals['Batch_est_ker_map'] sr_img = util.tensor2img(F_visuals['SR']) # uint8 # save images suffix = opt_F['suffix'] if suffix: save_img_path = os.path.join(dataset_dir, str(step), img_name + suffix + '.png') else: save_img_path = os.path.join(dataset_dir, str(step), img_name + '.png') util.save_img(sr_img, save_img_path) # calculate PSNR and SSIM if need_GT: gt_img = util.tensor2img(F_visuals['GT']) gt_img = gt_img / 255.
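# --- Illustrative sketch (not part of the original script) ---
# The corrector test above alternates between SFTMD (SR from the current
# kernel-code estimate) and the corrector network (a refined kernel code from
# that SR) for opt_C['step'] rounds. Stripped of the framework wrappers, the
# pattern looks like the hypothetical loop below.
def ikc_inference(lr_img, est_ker_map, sftmd, corrector, n_steps):
    """Predictor/corrector alternation used at IKC test time (sketch)."""
    sr_img = None
    for _ in range(n_steps):
        sr_img = sftmd(lr_img, est_ker_map)           # SR under the current code
        est_ker_map = corrector(sr_img, est_ker_map)  # refine the kernel code
    return sr_img, est_ker_map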
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    opt['dist'] = False
    rank = -1
    print('Disabled distributed training.')

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        resume_state_path, _ = get_resume_paths(opt)
        # distributed resuming: all load into default GPU
        if resume_state_path is None:
            resume_state = None
        else:
            device_id = torch.cuda.current_device()
            resume_state = torch.load(resume_state_path,
                                      map_location=lambda storage, loc: storage.cuda(device_id))
            option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt['path'].items()
                         if not key == 'experiments_root'
                         and 'pretrain_model' not in key and 'resume' not in key))

        # configure loggers; before this, logging will not work
        util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'],
                          level=logging.INFO, screen=True, tofile=True)
        util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'],
                          level=logging.INFO, screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))

        # tensorboard logger
        if opt.get('use_tb_logger', False) and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]'.format(version)) from tensorboardX import SummaryWriter conf_name = basename(args.opt).replace(".yml", "") exp_dir = opt['path']['experiments_root'] log_dir_train = os.path.join(exp_dir, 'tb', conf_name, 'train') log_dir_valid = os.path.join(exp_dir, 'tb', conf_name, 'valid') tb_logger_train = SummaryWriter(log_dir=log_dir_train) tb_logger_valid = SummaryWriter(log_dir=log_dir_valid) else: util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) logger = logging.getLogger('base') # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) if rank <= 0: logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### create train and val dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) print('Dataset created') train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) total_iters = int(opt['train']['niter']) total_epochs = int(math.ceil(total_iters / train_size)) train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) if rank <= 0: logger.info('Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None #### create model current_step = 0 if resume_state is None else resume_state['iter'] model = create_model(opt, current_step) #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 #### training timer = Timer() logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) timerData = TickTock() for epoch in range(start_epoch, total_epochs + 1): if opt['dist']: train_sampler.set_epoch(epoch) timerData.tick() for _, train_data in enumerate(train_loader): timerData.tock() current_step += 1 if current_step > total_iters: break #### training model.feed_data(train_data) #### update learning rate model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) try: nll = model.optimize_parameters(current_step) except RuntimeError as e: print("Skipping ERROR caught in nll = model.optimize_parameters(current_step): ") print(e) if nll is None: nll = 0 #### log def eta(t_iter): return (t_iter * (opt['train']['niter'] - current_step)) / 3600 if current_step % opt['logger']['print_freq'] == 0 \ or current_step - (resume_state['iter'] if resume_state else 0) < 25: avg_time = timer.get_average_and_reset() avg_data_time = timerData.get_average_and_reset() message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.2e}, td:{:.2e}, eta:{:.2e}, nll:{:.3e}> '.format( epoch, 
current_step, model.get_current_learning_rate(), avg_time, avg_data_time, eta(avg_time), nll) print(message) timer.tick() # Reduce number of logs if current_step % 5 == 0: tb_logger_train.add_scalar('loss/nll', nll, current_step) tb_logger_train.add_scalar('lr/base', model.get_current_learning_rate(), current_step) tb_logger_train.add_scalar('time/iteration', timer.get_last_iteration(), current_step) tb_logger_train.add_scalar('time/data', timerData.get_last_iteration(), current_step) tb_logger_train.add_scalar('time/eta', eta(timer.get_last_iteration()), current_step) for k, v in model.get_current_log().items(): tb_logger_train.add_scalar(k, v, current_step) # validation if current_step % opt['train']['val_freq'] == 0 and rank <= 0: avg_psnr = 0.0 idx = 0 nlls = [] for val_data in val_loader: idx += 1 img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir) model.feed_data(val_data) nll = model.test() if nll is None: nll = 0 nlls.append(nll) visuals = model.get_current_visuals() sr_img = None # Save SR images for reference if hasattr(model, 'heats'): for heat in model.heats: for i in range(model.n_sample): sr_img = util.tensor2img(visuals['SR', heat, i]) # uint8 save_img_path = os.path.join(img_dir, '{:s}_{:09d}_h{:03d}_s{:d}.png'.format(img_name, current_step, int(heat * 100), i)) util.save_img(sr_img, save_img_path) else: sr_img = util.tensor2img(visuals['SR']) # uint8 save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) util.save_img(sr_img, save_img_path) assert sr_img is not None # Save LQ images for reference save_img_path_lq = os.path.join(img_dir, '{:s}_LQ.png'.format(img_name)) if not os.path.isfile(save_img_path_lq): lq_img = util.tensor2img(visuals['LQ']) # uint8 util.save_img( cv2.resize(lq_img, dsize=None, fx=opt['scale'], fy=opt['scale'], interpolation=cv2.INTER_NEAREST), save_img_path_lq) # Save GT images for reference gt_img = util.tensor2img(visuals['GT']) # uint8 save_img_path_gt = os.path.join(img_dir, '{:s}_GT.png'.format(img_name)) if not os.path.isfile(save_img_path_gt): util.save_img(gt_img, save_img_path_gt) # calculate PSNR crop_size = opt['scale'] gt_img = gt_img / 255. sr_img = sr_img / 255. cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255) avg_psnr = avg_psnr / idx avg_nll = sum(nlls) / len(nlls) # log logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) logger_val = logging.getLogger('val') # validation logger logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format( epoch, current_step, avg_psnr)) # tensorboard logger tb_logger_valid.add_scalar('loss/psnr', avg_psnr, current_step) tb_logger_valid.add_scalar('loss/nll', avg_nll, current_step) tb_logger_train.flush() tb_logger_valid.flush() #### save models and training states if current_step % opt['logger']['save_checkpoint_freq'] == 0: if rank <= 0: logger.info('Saving models and training states.') model.save(current_step) model.save_training_state(epoch, current_step) timerData.tick() with open(os.path.join(opt['path']['root'], "TRAIN_DONE"), 'w') as f: f.write("TRAIN_DONE") if rank <= 0: logger.info('Saving the final model.') model.save('latest') logger.info('End of training.')
def test(model, test_data_loader, save=False): # pdb.set_trace() psnr_cal = PSNR() msssim_cal = MS_SSIM(data_range=1.0) ssim_cal = SSIM(data_range=1.0) psnr_meter_mono, psnr_meter_resL, psnr_meter_resR = AverageMeter( ), AverageMeter(), AverageMeter() msssim_meter_mono, msssim_meter_resL, msssim_meter_resR = AverageMeter( ), AverageMeter(), AverageMeter() ssim_meter_mono, ssim_meter_resL, ssim_meter_resR = AverageMeter( ), AverageMeter(), AverageMeter() with torch.no_grad(): model.eval() for i, (left, right, original_shape) in enumerate(tqdm(test_data_loader)): batch_size = left.shape[0] assert batch_size == 1, 'Only support batch of 1 now!' left = left.cuda(non_blocking=True) right = right.cuda(non_blocking=True) res_mono, res_l, res_r = model(left, right) original_shape = [x.item() for x in original_shape] def fun(x): x = test_data_loader.dataset.depad_tensor(x, original_shape) x = inverse_normalize(x).clamp(0.0, 1.0) return x left, right, res_mono, res_l, res_r = fun(left), fun(right), fun( res_mono), fun(res_l), fun(res_r) name = test_data_loader.dataset.frames[i][0].split('/')[-1].split( '.')[0] # pdb.set_trace() psnr_meter_mono.update(psnr_cal(res_mono, left), n=batch_size) psnr_meter_resL.update(psnr_cal(res_l, left), n=batch_size) psnr_meter_resR.update(psnr_cal(res_r, right), n=batch_size) msssim_meter_mono.update(msssim_cal(res_mono, left), n=batch_size) msssim_meter_resL.update(msssim_cal(res_l, left), n=batch_size) msssim_meter_resR.update(msssim_cal(res_r, right), n=batch_size) ssim_meter_mono.update(ssim_cal(res_mono, left), n=batch_size) ssim_meter_resL.update(ssim_cal(res_l, left), n=batch_size) ssim_meter_resR.update(ssim_cal(res_r, right), n=batch_size) if save: for x, last_fix in zip( [res_mono, res_l, res_r], ["_res_mono.png", "_res_l.png", "_res_r.png"]): cv2.imwrite( join(args.save_folder, args.data_name, name + last_fix), util.tensor2img(x)[..., ::-1] * 255) logger.info( '==>Mononized: \n' 'PSNR: {psnr_meter_mono.avg:.2f}\n' 'MS-SSIM: {msssim_meter_mono.avg:.2f}\n' 'SSIM: {ssim_meter_mono.avg:.2f}\n' '==>restored: \n' 'PSNR: {psnr_meter_resL.avg:.2f}, {psnr_meter_resR.avg:.2f}\n' 'MS-SSIM: {msssim_meter_resL.avg:.2f}, {msssim_meter_resR.avg:.2f}\n' 'SSIM: {ssim_meter_resL.avg:.2f}, {ssim_meter_resR.avg:.2f}'.format( psnr_meter_mono=psnr_meter_mono, msssim_meter_mono=msssim_meter_mono, ssim_meter_mono=ssim_meter_mono, psnr_meter_resL=psnr_meter_resL, msssim_meter_resL=msssim_meter_resL, ssim_meter_resL=ssim_meter_resL, psnr_meter_resR=psnr_meter_resR, msssim_meter_resR=msssim_meter_resR, ssim_meter_resR=ssim_meter_resR))
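# --- Illustrative sketch (not part of the original script) ---
# The metric bookkeeping above relies on an AverageMeter helper that is not
# shown in this section. A common minimal implementation matching the
# update(val, n) / .avg usage is sketched below; the repo's own class may differ.
class AverageMeter:
    """Running average over a stream of values."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # float() so 0-dim tensors and Python numbers are both accepted
        self.sum += float(val) * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0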
model.feed_data(data) if test_set_name == 'Vid4': folder = osp.split(osp.dirname(data['GT_path'][0][0]))[1] else: folder = '' util.mkdir(osp.join(dataset_dir, folder)) model.test() visuals = model.get_current_visuals() if test_set_name == 'Vimeo90K': center = visuals['SR'].shape[0] // 2 img_path = data['GT_path'][0] img_name = osp.splitext(osp.basename(img_path))[0] sr_img = util.tensor2img(visuals['SR']) # uint8 gt_img = util.tensor2img(visuals['GT'][center]) # uint8 lr_img = util.tensor2img(visuals['LR']) # uint8 lrgt_img = util.tensor2img(visuals['LR_ref'][center]) # uint8 test_results = cal_pnsr_ssim(sr_img, gt_img, lr_img, lrgt_img) else: t_step = visuals['SR'].shape[0] for i in range(t_step): img_path = data['GT_path'][i][0] img_name = osp.splitext(osp.basename(img_path))[0] sr_img = util.tensor2img(visuals['SR'][i]) # uint8 gt_img = util.tensor2img(visuals['GT'][i]) # uint8 lr_img = util.tensor2img(visuals['LR'][i]) # uint8
def main(): # options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.') opt = option.parse(parser.parse_args().opt, is_train=True) opt = option.dict_to_nonedict( opt) # Convert to NoneDict, which return None for missing key. pytorch_ver = get_pytorch_ver() # train from scratch OR resume training if opt['path']['resume_state']: if os.path.isdir(opt['path']['resume_state']): import glob resume_state_path = util.sorted_nicely( glob.glob( os.path.normpath(opt['path']['resume_state']) + '/*.state'))[-1] else: resume_state_path = opt['path']['resume_state'] resume_state = torch.load(resume_state_path) else: # training from scratch resume_state = None util.mkdir_and_rename( opt['path']['experiments_root']) # rename old folder if exists util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger(None, opt['path']['log'], 'train', level=logging.INFO, screen=True) util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO) logger = logging.getLogger('base') if resume_state: logger.info('Set [resume_state] to ' + resume_state_path) logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) option.check_resume(opt) # check resume options logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: from tensorboardX import SummaryWriter try: tb_logger = SummaryWriter( logdir='../tb_logger/' + opt['name']) #for version tensorboardX >= 1.7 except: tb_logger = SummaryWriter( log_dir='../tb_logger/' + opt['name']) #for version tensorboardX < 1.6 # random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) # if the model does not change and input sizes remain the same during training then there may be benefit # from setting torch.backends.cudnn.benchmark = True, otherwise it may stall training torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True # create train and val dataloader for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt['batch_size'])) logger.info('Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) total_iters = int(opt['train']['niter']) total_epochs = int(math.ceil(total_iters / train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) train_loader = create_dataloader(train_set, dataset_opt) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt) logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None # create model model = create_model(opt) # resume training if resume_state: start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers model.update_schedulers( opt['train'] ) # updated schedulers in case JSON configuration has changed else: current_step = 0 start_epoch = 0 # training logger.info('Start training from epoch: {:d}, iter: {:d}'.format( 
start_epoch, current_step)) for epoch in range(start_epoch, total_epochs): for n, train_data in enumerate(train_loader, start=1): current_step += 1 if current_step > total_iters: break if pytorch_ver == "pre": #Order for PyTorch ver < 1.1.0 # update learning rate model.update_learning_rate(current_step - 1) # training model.feed_data(train_data) model.optimize_parameters(current_step) elif pytorch_ver == "post": #Order for PyTorch ver > 1.1.0 # training model.feed_data(train_data) model.optimize_parameters(current_step) # update learning rate model.update_learning_rate(current_step - 1) else: print('Error identifying PyTorch version. ', torch.__version__) break # log if current_step % opt['logger']['print_freq'] == 0: logs = model.get_current_log() message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format( epoch, current_step, model.get_current_learning_rate()) for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar(k, v, current_step) logger.info(message) # save models and training states (changed to save models before validation) if current_step % opt['logger']['save_checkpoint_freq'] == 0: model.save(current_step) model.save_training_state(epoch + (n >= len(train_loader)), current_step) logger.info('Models and training states saved.') # validation if val_loader and current_step % opt['train']['val_freq'] == 0: avg_psnr_c = 0.0 avg_psnr_s = 0.0 avg_psnr_p = 0.0 avg_ssim_c = 0.0 avg_ssim_s = 0.0 avg_ssim_p = 0.0 idx = 0 val_sr_imgs_list = [] val_gt_imgs_list = [] for val_data in val_loader: idx += 1 img_name = os.path.splitext( os.path.basename(val_data['LR_path'][0]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir) model.feed_data(val_data) model.test() visuals = model.get_current_visuals() if opt['datasets']['train'][ 'znorm']: # If the image range is [-1,1] img_c = util.tensor2img(visuals['img_c'], min_max=(-1, 1)) # uint8 img_s = util.tensor2img(visuals['img_s'], min_max=(-1, 1)) # uint8 img_p = util.tensor2img(visuals['img_p'], min_max=(-1, 1)) # uint8 gt_img = util.tensor2img(visuals['HR'], min_max=(-1, 1)) # uint8 else: # Default: Image range is [0,1] img_c = util.tensor2img(visuals['img_c']) # uint8 img_s = util.tensor2img(visuals['img_s']) # uint8 img_p = util.tensor2img(visuals['img_p']) # uint8 gt_img = util.tensor2img(visuals['HR']) # uint8 # Save SR images for reference save_c_img_path = os.path.join( img_dir, '{:s}_{:d}_c.png'.format(img_name, current_step)) save_s_img_path = os.path.join( img_dir, '{:s}_{:d}_s.png'.format(img_name, current_step)) save_p_img_path = os.path.join( img_dir, '{:s}_{:d}_d.png'.format(img_name, current_step)) util.save_img(img_c, save_c_img_path) util.save_img(img_s, save_s_img_path) util.save_img(img_p, save_p_img_path) # calculate PSNR, SSIM and LPIPS distance crop_size = opt['scale'] gt_img = gt_img / 255. #sr_img = sr_img / 255. #ESRGAN #PPON sr_img_c = img_c / 255. #C sr_img_s = img_s / 255. #S sr_img_p = img_p / 255. #D # For training models with only one channel ndim==2, if RGB ndim==3, etc. 
if gt_img.ndim == 2: cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size] else: # gt_img.ndim == 3, # Default: RGB images cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] # All 3 output images will have the same dimensions if sr_img_c.ndim == 2: cropped_sr_img_c = sr_img_c[crop_size:-crop_size, crop_size:-crop_size] cropped_sr_img_s = sr_img_s[crop_size:-crop_size, crop_size:-crop_size] cropped_sr_img_p = sr_img_p[crop_size:-crop_size, crop_size:-crop_size] else: #sr_img_c.ndim == 3, # Default: RGB images cropped_sr_img_c = sr_img_c[crop_size:-crop_size, crop_size:-crop_size, :] cropped_sr_img_s = sr_img_s[crop_size:-crop_size, crop_size:-crop_size, :] cropped_sr_img_p = sr_img_p[crop_size:-crop_size, crop_size:-crop_size, :] avg_psnr_c += util.calculate_psnr(cropped_sr_img_c * 255, cropped_gt_img * 255) avg_ssim_c += util.calculate_ssim(cropped_sr_img_c * 255, cropped_gt_img * 255) avg_psnr_s += util.calculate_psnr(cropped_sr_img_s * 255, cropped_gt_img * 255) avg_ssim_s += util.calculate_ssim(cropped_sr_img_s * 255, cropped_gt_img * 255) avg_psnr_p += util.calculate_psnr(cropped_sr_img_p * 255, cropped_gt_img * 255) avg_ssim_p += util.calculate_ssim(cropped_sr_img_p * 255, cropped_gt_img * 255) # LPIPS only works for RGB images # Using only the final perceptual image to calulate LPIPS if sr_img_c.ndim == 3: #avg_lpips += lpips.calculate_lpips([cropped_sr_img], [cropped_gt_img]) # If calculating for each image val_gt_imgs_list.append( cropped_gt_img ) # If calculating LPIPS only once for all images val_sr_imgs_list.append( cropped_sr_img_p ) # If calculating LPIPS only once for all images # PSNR avg_psnr_c = avg_psnr_c / idx avg_psnr_s = avg_psnr_s / idx avg_psnr_p = avg_psnr_p / idx # SSIM avg_ssim_c = avg_ssim_c / idx avg_ssim_s = avg_ssim_s / idx avg_ssim_p = avg_ssim_p / idx # LPIPS #avg_lpips = avg_lpips / idx # If calculating for each image avg_lpips = lpips.calculate_lpips( val_sr_imgs_list, val_gt_imgs_list ) # If calculating only once for all images # log # PSNR logger.info('# Validation # PSNR_c: {:.5g}'.format(avg_psnr_c)) logger.info('# Validation # PSNR_s: {:.5g}'.format(avg_psnr_s)) logger.info('# Validation # PSNR_p: {:.5g}'.format(avg_psnr_p)) # SSIM logger.info('# Validation # SSIM_c: {:.5g}'.format(avg_ssim_c)) logger.info('# Validation # SSIM_s: {:.5g}'.format(avg_ssim_s)) logger.info('# Validation # SSIM_p: {:.5g}'.format(avg_ssim_p)) # LPIPS logger.info('# Validation # LPIPS: {:.5g}'.format(avg_lpips)) logger_val = logging.getLogger('val') # validation logger # logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr_c: {:.5g}, psnr_s: {:.5g}, psnr_p: {:.5g}'.format( # epoch, current_step, avg_psnr_c, avg_psnr_s, avg_psnr_p)) logger_val.info('<epoch:{:3d}, iter:{:8,d}>'.format( epoch, current_step)) logger_val.info( 'psnr_c: {:.5g}, psnr_s: {:.5g}, psnr_p: {:.5g}'.format( avg_psnr_c, avg_psnr_s, avg_psnr_p)) logger_val.info( 'ssim_c: {:.5g}, ssim_s: {:.5g}, ssim_p: {:.5g}'.format( avg_ssim_c, avg_ssim_s, avg_ssim_p)) logger_val.info('lpips: {:.5g}'.format(avg_lpips)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr_c', avg_psnr_c, current_step) tb_logger.add_scalar('psnr_s', avg_psnr_s, current_step) tb_logger.add_scalar('psnr_p', avg_psnr_p, current_step) tb_logger.add_scalar('ssim_c', avg_ssim_c, current_step) tb_logger.add_scalar('ssim_s', avg_ssim_s, current_step) tb_logger.add_scalar('ssim_p', avg_ssim_p, current_step) tb_logger.add_scalar('lpips', avg_lpips, current_step) 
logger.info('Saving the final model.') model.save('latest') logger.info('End of training.')
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which returns None for missing keys.

    # train from scratch OR resume training
    if opt['path']['resume_state']:  # resuming training
        resume_state = torch.load(opt['path']['resume_state'])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # configure loggers; before this, logging will not work
    util.setup_logger(None, opt['path']['log'], 'train', level=logging.INFO, screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options
    logger.info(option.dict2str(opt))

    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            # model.update_learning_rate()

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    # model.feed_data2(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(
                        img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                    epoch, current_step, avg_psnr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)

            model.update_learning_rate()

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
            np.ascontiguousarray(np.transpose(img_Ref, (2, 0, 1)))).float().unsqueeze(0).cuda()
        img_Ref_DUX4 = cv2.imread(osp.join(Ref_DUX4_path, use_name)) / 255.
        img_Ref_DUX4 = img_Ref_DUX4[:, :, [2, 1, 0]]
        img_Ref_DUX4 = torch.from_numpy(
            np.ascontiguousarray(np.transpose(img_Ref_DUX4, (2, 0, 1)))).float().unsqueeze(0).cuda()

        with torch.no_grad():
            begin_time = time.time()
            output = model(img_LR, img_LR_UX4, img_Ref, img_Ref_DUX4)
            end_time = time.time()
            stat_time += (end_time - begin_time)
        output = util.tensor2img(output.squeeze(0))

        # save images
        save_path_name = osp.join(save_path, '{}_exp{}/{}.png'.format(dataset, exp_name, base_name))
        util.save_img(output, save_path_name)

        # calculate PSNR
        sr_img, gt_img = util.crop_border([output, img_GT], scale)
        PSNR_avg += util.calculate_psnr(sr_img, gt_img)
        SSIM_avg += util.calculate_ssim(sr_img, gt_img)

    print('average PSNR: ', PSNR_avg / len(img_list))
    print('average SSIM: ', SSIM_avg / len(img_list))
    print('time: ', stat_time / len(img_list))
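# NOTE (added sketch): the fragment above repeats one conversion for every input image:
# read with OpenCV, scale to [0, 1], flip BGR to RGB, transpose HWC -> CHW, add a batch
# dimension and move to the GPU. The helper below captures that pattern;
# `load_img_as_tensor` is a hypothetical name, not a function of this codebase.
import cv2
import numpy as np
import torch


def load_img_as_tensor(path, device='cuda'):
    """Read an image file and return a 1x3xHxW float tensor in RGB, range [0, 1]."""
    img = cv2.imread(path).astype(np.float32) / 255.  # HWC, BGR
    img = img[:, :, [2, 1, 0]]                        # BGR -> RGB
    tensor = torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1))))
    return tensor.float().unsqueeze(0).to(device)     # NCHW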
def main(): ############################################ # # set options # ############################################ parser = argparse.ArgumentParser() parser.add_argument('--opt', type=str, help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=True) ############################################ # # distributed training settings # ############################################ if args.launcher == 'none': # disabled distributed training opt['dist'] = False rank = -1 print('Disabled distributed training.') else: opt['dist'] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() print("Rank:", rank) print("------------------DIST-------------------------") ############################################ # # loading resume state if exists # ############################################ if opt['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load( opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options else: resume_state = None ############################################ # # mkdir and loggers # ############################################ if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) if resume_state is None: util.mkdir_and_rename( opt['path'] ['experiments_root']) # rename experiment folder if exists util.mkdirs( (path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) util.setup_logger('base_val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger_val = logging.getLogger('base_val') logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. Tensorboard will use [tensorboardX]' .format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) else: # config loggers. 
Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_', level=logging.INFO, screen=True) print("set train log") util.setup_logger('base_val', opt['path']['log'], 'val_', level=logging.INFO, screen=True) print("set val log") logger = logging.getLogger('base') logger_val = logging.getLogger('base_val') # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) if rank <= 0: logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True ############################################ # # create train and val dataloader # ############################################ #### # dataset_ratio = 200 # enlarge the size of each epoch, todo: what it is dataset_ratio = 1 # enlarge the size of each epoch, todo: what it is for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt['batch_size'])) # total_iters = int(opt['train']['niter']) # total_epochs = int(math.ceil(total_iters / train_size)) total_iters = train_size total_epochs = int(opt['train']['epoch']) if opt['dist']: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) # total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) total_epochs = int(opt['train']['epoch']) if opt['train']['enable'] == False: total_epochs = 1 else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) if rank <= 0: logger.info( 'Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None ############################################ # # create model # ############################################ #### model = create_model(opt) #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 print("Not Resume Training") ############################################ # # training # ############################################ #### #### logger.info('Start training from epoch: {:d}, iter: {:d}'.format( start_epoch, current_step)) Avg_train_loss = AverageMeter() # total Avg_train_psnr = AverageMeter() if opt['datasets']['train']['color'] == 'YUV': Avg_train_yuv_psnr = AverageMeter() if (opt['train']['pixel_criterion'] == 'cb+ssim'): Avg_train_loss_pix = AverageMeter() Avg_train_loss_ssim = AverageMeter() saved_total_loss = 10e10 saved_total_PSNR = -1 for epoch in range(start_epoch, total_epochs): ############################################ # # Start a new epoch # ############################################ # Turn into training mode #model = model.train() # reset total loss Avg_train_loss.reset() # reset psnr 
Avg_train_psnr.reset() if opt['datasets']['train']['color'] == 'YUV': Avg_train_yuv_psnr.reset() current_step = 0 if (opt['train']['pixel_criterion'] == 'cb+ssim'): Avg_train_loss_pix.reset() Avg_train_loss_ssim.reset() if opt['dist']: train_sampler.set_epoch(epoch) for train_idx, train_data in enumerate(train_loader): if 'debug' in opt['name']: img_dir = os.path.join(opt['path']['train_images']) util.mkdir(img_dir) LQ = train_data['LQs'] GT = train_data['GT'] GT_img = util.tensor2img(GT) # uint8 NCHW print('GT_img', GT_img.shape) print('LQ', LQ.shape) if opt['datasets']['train']['color'] == 'YUV': GT_img = data_util.ycbcr2rgb(GT_img) save_img_path = os.path.join( img_dir, '{:4d}_{:s}.png'.format(train_idx, 'debug_GT')) if opt['datasets']['train']['color'] == 'YUV': util.save_img(GT_img, save_img_path, mode='RGB') else: util.save_img(GT_img, save_img_path) for i in range(5): LQ_img = util.tensor2img(LQ[:, i, ...]) # uint8 if opt['datasets']['train']['color'] == 'YUV': LQ_img = data_util.ycbcr2rgb(LQ_img) save_img_path = os.path.join( img_dir, '{:4d}_{:s}_{:1d}.png'.format(train_idx, 'debug_LQ', i)) if opt['datasets']['train']['color'] == 'YUV': util.save_img(LQ_img, save_img_path, mode='RGB') else: util.save_img(LQ_img, save_img_path) if (train_idx >= 10): break if opt['train']['enable'] == False: message_train_loss = 'None' break current_step += 1 if current_step > total_iters: print("Total Iteration Reached !") break #### update learning rate if opt['train']['lr_scheme'] == 'ReduceLROnPlateau': pass else: model.update_learning_rate( current_step, warmup_iter=opt['train']['warmup_iter']) #### training model.feed_data(train_data) # if opt['train']['lr_scheme'] == 'ReduceLROnPlateau': # model.optimize_parameters_without_schudlue(current_step) # else: model.optimize_parameters(current_step) visuals = model.get_current_visuals(need_GT=True, save=False) rlt_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 if opt['datasets']['train']['color'] == 'YUV': yuv_psnr = util.calculate_psnr(rlt_img, gt_img) rlt_img = data_util.ycbcr2rgb(rlt_img) gt_img = data_util.ycbcr2rgb(gt_img) # calculate PSNR psnr = util.calculate_psnr(rlt_img, gt_img) if opt['train']['pixel_criterion'] == 'cb+ssim': Avg_train_loss.update(model.log_dict['total_loss'], 1) Avg_train_loss_pix.update(model.log_dict['l_pix'], 1) Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1) Avg_train_psnr.update(psnr, 1) if opt['datasets']['train']['color'] == 'YUV': Avg_train_yuv_psnr.update(yuv_psnr, 1) else: Avg_train_loss.update(model.log_dict['l_pix'], 1) Avg_train_psnr.update(psnr, 1) if opt['datasets']['train']['color'] == 'YUV': Avg_train_yuv_psnr.update(yuv_psnr, 1) # add total train loss if (opt['train']['pixel_criterion'] == 'cb+ssim'): message_train_loss = ' pix_avg_loss: {:.4e}'.format( Avg_train_loss_pix.avg) message_train_loss += ' ssim_avg_loss: {:.4e}'.format( Avg_train_loss_ssim.avg) message_train_loss += ' total_avg_loss: {:.4e}'.format( Avg_train_loss.avg) message_train_loss += ' psnr_inst : {:.2f}'.format(psnr) message_train_loss += ' psnr_avg : {:.2f}'.format( Avg_train_psnr.avg) else: message_train_loss = ' train_avg_loss: {:.4e}'.format( Avg_train_loss.avg) if opt['datasets']['train']['color'] == 'YUV': message_train_loss += ' yuv_psnr_inst : {:.2f}'.format( yuv_psnr) message_train_loss += ' psnr_inst : {:.2f}'.format(psnr) if opt['datasets']['train']['color'] == 'YUV': message_train_loss += ' yuv_psnr_avg : {:.2f}'.format( Avg_train_yuv_psnr.avg) message_train_loss += ' 
psnr_avg : {:.2f}'.format( Avg_train_psnr.avg) #### log if current_step % opt['logger']['print_freq'] == 0: logs = model.get_current_log() message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format( epoch, current_step) for v in model.get_current_learning_rate(): message += '{:.3e},'.format(v) message += ')] ' for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: if rank <= 0: tb_logger.add_scalar(k, v, current_step) # tensorboard logger - avg part if opt['use_tb_logger'] and 'debug' not in opt['name']: if rank <= 0: tb_logger.add_scalar('train_avg_loss', Avg_train_loss.avg, current_step) if opt['datasets']['train']['color'] == 'YUV': tb_logger.add_scalar('yuv_psnr_avg', Avg_train_yuv_psnr.avg, current_step) tb_logger.add_scalar('psnr_avg', Avg_train_psnr.avg, current_step) message += message_train_loss if rank <= 0: logger.info(message) ############################################ # # end of one epoch, save epoch model # ############################################ #### save models and training states if epoch == 1: save_filename = '{:04d}_{}.pth'.format(0, 'G') save_path = os.path.join(opt['path']['models'], save_filename) if os.path.exists(save_path): os.remove(save_path) save_filename = '{:04d}_{}.pth'.format(epoch - 1, 'G') save_path = os.path.join(opt['path']['models'], save_filename) if os.path.exists(save_path): os.remove(save_path) if rank <= 0: logger.info('Saving models and training states.') save_filename = '{:04d}'.format(epoch) model.save(save_filename) # model.save('latest') # model.save_training_state(epoch, current_step) ############################################ # # end of one epoch, do validation # ############################################ #### validation #if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: if opt['datasets'].get('val', None): if opt['model'] in [ 'sr', 'srgan' ] and rank <= 0: # image restoration validation # does not support multi-GPU validation pbar = util.ProgressBar(len(val_loader)) avg_psnr = 0. idx = 0 for val_data in val_loader: idx += 1 img_name = os.path.splitext( os.path.basename(val_data['LQ_path'][0]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir) model.feed_data(val_data) model.test() visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # Save SR images for reference save_img_path = os.path.join( img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) #util.save_img(sr_img, save_img_path) # calculate PSNR sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) avg_psnr += util.calculate_psnr(sr_img, gt_img) pbar.update('Test {}'.format(img_name)) avg_psnr = avg_psnr / idx # log logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr', avg_psnr, current_step) else: # video restoration validation if opt['dist']: # todo : multi-GPU testing psnr_rlt = {} # with border and center frames psnr_rlt_avg = {} psnr_total_avg = 0. ssim_rlt = {} # with border and center frames ssim_rlt_avg = {} ssim_total_avg = 0. val_loss_rlt = {} val_loss_rlt_avg = {} val_loss_total_avg = 0. 
if rank == 0: pbar = util.ProgressBar(len(val_set)) for idx in range(rank, len(val_set), world_size): if 'debug' in opt['name']: print('idx', idx) # if (idx >= 3): # break val_data = val_set[idx] val_data['LQs'].unsqueeze_(0) val_data['GT'].unsqueeze_(0) folder = val_data['folder'] idx_d, max_idx = val_data['idx'].split('/') idx_d, max_idx = int(idx_d), int(max_idx) if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32, device='cuda') if ssim_rlt.get(folder, None) is None: ssim_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32, device='cuda') if val_loss_rlt.get(folder, None) is None: val_loss_rlt[folder] = torch.zeros( max_idx, dtype=torch.float32, device='cuda') # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda') model.feed_data(val_data) # model.test() # model.test_stitch() if opt['stitch'] == True: model.test_stitch() else: model.test() # large GPU memory # visuals = model.get_current_visuals() visuals = model.get_current_visuals( save=True, name='{}_{}'.format(folder, idx), save_path=opt['path']['val_images']) rlt_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 if opt['datasets']['train']['color'] == 'YUV': rlt_img = data_util.ycbcr2rgb(rlt_img) gt_img = data_util.ycbcr2rgb(gt_img) # calculate PSNR psnr = util.calculate_psnr(rlt_img, gt_img) psnr_rlt[folder][idx_d] = psnr # calculate SSIM # ssim = util.calculate_ssim(rlt_img, gt_img) ssim = 0 # to do save time do not use it ssim_rlt[folder][idx_d] = ssim # calculate Val loss val_loss = model.get_loss() val_loss_rlt[folder][idx_d] = val_loss logger.info( '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format( folder, idx, psnr, ssim)) if rank == 0: for _ in range(world_size): pbar.update('Test {} - {}/{}'.format( folder, idx_d, max_idx)) # # collect data for _, v in psnr_rlt.items(): dist.reduce(v, 0) for _, v in ssim_rlt.items(): dist.reduce(v, 0) for _, v in val_loss_rlt.items(): dist.reduce(v, 0) dist.barrier() if rank == 0: psnr_rlt_avg = {} psnr_total_avg = 0. for k, v in psnr_rlt.items(): psnr_rlt_avg[k] = torch.mean(v).cpu().item() psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = '# Validation # PSNR: {:.4e}:'.format( psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) # ssim ssim_rlt_avg = {} ssim_total_avg = 0. for k, v in ssim_rlt.items(): ssim_rlt_avg[k] = torch.mean(v).cpu().item() ssim_total_avg += ssim_rlt_avg[k] ssim_total_avg /= len(ssim_rlt) log_s = '# Validation # SSIM: {:.4e}:'.format( ssim_total_avg) for k, v in ssim_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) # added val_loss_rlt_avg = {} val_loss_total_avg = 0. 
for k, v in val_loss_rlt.items(): val_loss_rlt_avg[k] = torch.mean(v).cpu().item() val_loss_total_avg += val_loss_rlt_avg[k] val_loss_total_avg /= len(val_loss_rlt) log_l = '# Validation # Loss: {:.4e}:'.format( val_loss_total_avg) for k, v in val_loss_rlt_avg.items(): log_l += ' {}: {:.4e}'.format(k, v) logger.info(log_l) message = '' for v in model.get_current_learning_rate(): message += '{:.5e}'.format(v) logger_val.info( 'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}' .format(epoch, message, psnr_total_avg, ssim_total_avg, message_train_loss, val_loss_total_avg)) if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) # add val loss tb_logger.add_scalar('val_loss_avg', val_loss_total_avg, current_step) for k, v in val_loss_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) else: # Todo: our function One GPU pbar = util.ProgressBar(len(val_loader)) psnr_rlt = {} # with border and center frames psnr_rlt_avg = {} psnr_total_avg = 0. ssim_rlt = {} # with border and center frames ssim_rlt_avg = {} ssim_total_avg = 0. val_loss_rlt = {} val_loss_rlt_avg = {} val_loss_total_avg = 0. for val_inx, val_data in enumerate(val_loader): # if 'debug' in opt['name']: # if (val_inx >= 5): # break folder = val_data['folder'][0] # idx_d = val_data['idx'].item() idx_d = val_data['idx'] # border = val_data['border'].item() if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = [] if ssim_rlt.get(folder, None) is None: ssim_rlt[folder] = [] if val_loss_rlt.get(folder, None) is None: val_loss_rlt[folder] = [] # process the black blank [B N C H W] # print(val_data['LQs'].size()) # H_S = val_data['LQs'].size(3) # 540 # W_S = val_data['LQs'].size(4) # 960 # print(H_S) # print(W_S) # blank_1_S = 0 # blank_2_S = 0 # print(val_data['LQs'][0,2,0,:,:].size()) # for i in range(H_S): # if not sum(val_data['LQs'][0,2,0,i,:]) == 0: # blank_1_S = i - 1 # # assert not sum(data_S[:, :, 0][i+1]) == 0 # break # for i in range(H_S): # if not sum(val_data['LQs'][0,2,0,:,H_S - i - 1]) == 0: # blank_2_S = (H_S - 1) - i - 1 # # assert not sum(data_S[:, :, 0][blank_2_S-1]) == 0 # break # print('LQ :', blank_1_S, blank_2_S) # if blank_1_S == -1: # print('LQ has no blank') # blank_1_S = 0 # blank_2_S = H_S # val_data['LQs'] = val_data['LQs'][:,:,:,blank_1_S:blank_2_S,:] # print("LQ",val_data['LQs'].size()) # end of process the black blank model.feed_data(val_data) if opt['stitch'] == True: model.test_stitch() else: model.test() # large GPU memory # process blank # blank_1_L = blank_1_S << 2 # blank_2_L = blank_2_S << 2 # print(blank_1_L, blank_2_L) # print(model.fake_H.size()) # if not blank_1_S == 0: # # model.fake_H = model.fake_H[:,:,blank_1_L:blank_2_L,:] # model.fake_H[:, :,0:blank_1_L, :] = 0 # model.fake_H[:, :,blank_2_L:H_S, :] = 0 # end of # process blank visuals = model.get_current_visuals( save=True, name='{}_{:02d}'.format(folder, val_inx), save_path=opt['path']['val_images']) rlt_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 if opt['datasets']['train']['color'] == 'YUV': rlt_img = data_util.ycbcr2rgb(rlt_img) gt_img = data_util.ycbcr2rgb(gt_img) # calculate PSNR psnr = util.calculate_psnr(rlt_img, gt_img) psnr_rlt[folder].append(psnr) # calculate SSIM # ssim = util.calculate_ssim(rlt_img, gt_img) ssim = 0 # to do save time do not use it ssim_rlt[folder].append(ssim) # val loss val_loss = 
model.get_loss() val_loss_rlt[folder].append(val_loss.item()) logger.info( '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format( folder, val_inx, psnr, ssim)) pbar.update('Test {} - {}'.format(folder, idx_d)) # average PSNR for k, v in psnr_rlt.items(): psnr_rlt_avg[k] = sum(v) / len(v) psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = '# Validation # PSNR: {:.4e}:'.format( psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) # average SSIM for k, v in ssim_rlt.items(): ssim_rlt_avg[k] = sum(v) / len(v) ssim_total_avg += ssim_rlt_avg[k] ssim_total_avg /= len(ssim_rlt) log_s = '# Validation # SSIM: {:.4e}:'.format( ssim_total_avg) for k, v in ssim_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) # average Val LOSS for k, v in val_loss_rlt.items(): val_loss_rlt_avg[k] = sum(v) / len(v) val_loss_total_avg += val_loss_rlt_avg[k] val_loss_total_avg /= len(val_loss_rlt) log_l = '# Validation # Loss: {:.4e}:'.format( val_loss_total_avg) for k, v in val_loss_rlt_avg.items(): log_l += ' {}: {:.4e}'.format(k, v) logger.info(log_l) # toal validation log message = '' for v in model.get_current_learning_rate(): message += '{:.5e}'.format(v) logger_val.info( 'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}' .format(epoch, message, psnr_total_avg, ssim_total_avg, message_train_loss, val_loss_total_avg)) # end add if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) # tb_logger.add_scalar('ssim_avg', ssim_total_avg, current_step) # for k, v in ssim_rlt_avg.items(): # tb_logger.add_scalar(k, v, current_step) # add val loss tb_logger.add_scalar('val_loss_avg', val_loss_total_avg, current_step) for k, v in val_loss_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) ############################################ # # end of validation, save model # ############################################ # if rank <= 0: logger.info( "Finished an epoch, Check and Save the model weights") # we check the validation loss instead of training loss. OK~ if saved_total_loss >= val_loss_total_avg: saved_total_loss = val_loss_total_avg #torch.save(model.state_dict(), args.save_path + "/best" + ".pth") model.save('best') logger.info( "Best Weights updated for decreased validation loss") else: logger.info( "Weights Not updated for undecreased validation loss") if saved_total_PSNR <= psnr_total_avg: saved_total_PSNR = psnr_total_avg model.save('bestPSNR') logger.info( "Best Weights updated for increased validation PSNR") else: logger.info( "Weights Not updated for unincreased validation PSNR") ############################################ # # end of one epoch, schedule LR # ############################################ # add scheduler todo if opt['train']['lr_scheme'] == 'ReduceLROnPlateau': for scheduler in model.schedulers: # scheduler.step(val_loss_total_avg) scheduler.step(val_loss_total_avg) if rank <= 0: logger.info('Saving the final model.') model.save('last') logger.info('End of training.') tb_logger.close()
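# NOTE (added sketch): the training loop above tracks running loss/PSNR statistics with
# `AverageMeter` objects whose definition is not shown in this file. The class below is a
# common minimal implementation given as an assumption; the project's actual meter may
# carry extra fields.
class AverageMeter(object):
    """Keeps the latest value, running sum, count and average of a scalar."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count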
def main(): # options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.') parser.add_argument('-single_GPU', action='store_true', help='Utilize only one GPU') if parser.parse_args().single_GPU: available_GPUs = util.Assign_GPU() else: available_GPUs = util.Assign_GPU(max_GPUs=None) opt = option.parse(parser.parse_args().opt, is_train=True, batch_size_multiplier=len(available_GPUs)) if not opt['train']['resume']: util.mkdir_and_rename( opt['path'] ['experiments_root']) # Modify experiment name if exists util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and \ not key == 'pretrained_model_G' and not key == 'pretrained_model_D')) option.save(opt) opt = option.dict_to_nonedict( opt) # Convert to NoneDict, which return None for missing key. # print to file and std_out simultaneously sys.stdout = PrintLogger(opt['path']['log']) # random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) print("Random Seed: ", seed) random.seed(seed) torch.manual_seed(seed) # create train and val dataloader for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': max_accumulation_steps = max([ opt['train']['grad_accumulation_steps_G'], opt['train']['grad_accumulation_steps_D'] ]) train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt['batch_size'])) print('Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) total_iters = int(opt['train']['niter'] * max_accumulation_steps) #-current_step total_epoches = int(math.ceil(total_iters / train_size)) print('Total epoches needed: {:d} for iters {:,d}'.format( total_epoches, total_iters)) train_loader = create_dataloader(train_set, dataset_opt) elif phase == 'val': val_dataset_opt = dataset_opt val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt) print('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None # Create model if max_accumulation_steps != 1: model = create_model(opt, max_accumulation_steps) else: model = create_model(opt) # create logger logger = Logger(opt) # Save validation set results as image collage: SAVE_IMAGE_COLLAGE = True per_image_saved_patch = min( [min(im['HR'].shape[1:]) for im in val_loader.dataset]) - 2 num_val_images = len(val_loader.dataset) val_images_collage_rows = int(np.floor(np.sqrt(num_val_images))) while val_images_collage_rows > 1: if np.round(num_val_images / val_images_collage_rows ) == num_val_images / val_images_collage_rows: break val_images_collage_rows -= 1 start_time = time.time() min_accumulation_steps = min([ opt['train']['grad_accumulation_steps_G'], opt['train']['grad_accumulation_steps_D'] ]) save_GT_HR = True lr_too_low = False print('---------- Start training -------------') last_saving_time = time.time() recently_saved_models = deque(maxlen=4) for epoch in range(int(math.floor(model.step / train_size)), total_epoches): for i, train_data in enumerate(train_loader): gradient_step_num = model.step // max_accumulation_steps not_within_batch = model.step % max_accumulation_steps == ( max_accumulation_steps - 1) saving_step = ( (time.time() - last_saving_time) > 60 * opt['logger']['save_checkpoint_freq']) and not_within_batch if saving_step: last_saving_time = time.time() # save models if lr_too_low or saving_step: 
recently_saved_models.append(model.save(gradient_step_num)) model.save_log() if len(recently_saved_models) > 3: model_2_delete = recently_saved_models.popleft() os.remove(model_2_delete) if model.D_exists: os.remove(model_2_delete.replace('_G.', '_D.')) print('{}: Saving the model before iter {:d}.'.format( datetime.now().strftime('%H:%M:%S'), gradient_step_num)) if lr_too_low: break if model.step > total_iters: break # training model.feed_data(train_data) model.optimize_parameters() if not model.D_exists: #Avoid using the naive MultiLR scheduler when using adversarial loss for scheduler in model.schedulers: scheduler.step(model.gradient_step_num) time_elapsed = time.time() - start_time if not_within_batch: start_time = time.time() # log if gradient_step_num % opt['logger'][ 'print_freq'] == 0 and not_within_batch: logs = model.get_current_log() print_rlt = OrderedDict() print_rlt['model'] = opt['model'] print_rlt['epoch'] = epoch print_rlt['iters'] = gradient_step_num print_rlt['time'] = time_elapsed for k, v in logs.items(): print_rlt[k] = v print_rlt['lr'] = model.get_current_learning_rate() logger.print_format_results('train', print_rlt, keys_ignore_list=IGNORED_KEYS_LIST) model.display_log_figure() # validation if not_within_batch and (gradient_step_num) % opt['train'][ 'val_freq'] == 0: # and gradient_step_num>=opt['train']['D_init_iters']: print_rlt = OrderedDict() if model.generator_changed: print('---------- validation -------------') start_time = time.time() if False and SAVE_IMAGE_COLLAGE and model.gradient_step_num % opt[ 'train'][ 'val_save_freq'] == 0: #Saving training images: GT_image_collage = [] cur_train_results = model.get_current_visuals( entire_batch=True) train_psnrs = [ util.calculate_psnr( util.tensor2img( cur_train_results['SR'][im_num], out_type=np.float32) * 255, util.tensor2img( cur_train_results['HR'][im_num], out_type=np.float32) * 255) for im_num in range(len(cur_train_results['SR'])) ] #Save latest training batch output: save_img_path = os.path.join( os.path.join(opt['path']['val_images']), '{:d}_Tr_PSNR{:.3f}.png'.format( gradient_step_num, np.mean(train_psnrs))) util.save_img( np.clip( np.concatenate( (np.concatenate([ util.tensor2img( cur_train_results['HR'][im_num], out_type=np.float32) * 255 for im_num in range( len(cur_train_results['SR'])) ], 0), np.concatenate([ util.tensor2img( cur_train_results['SR'][im_num], out_type=np.float32) * 255 for im_num in range( len(cur_train_results['SR'])) ], 0)), 1), 0, 255).astype(np.uint8), save_img_path) Z_latent = [0] + ([-1, 1] if opt['network_G']['latent_input'] else []) print_rlt['psnr'] = 0 for cur_Z in Z_latent: sr_images = model.perform_validation( data_loader=val_loader, cur_Z=cur_Z, print_rlt=print_rlt, save_GT_HR=save_GT_HR, save_images=((model.gradient_step_num) % opt['train']['val_save_freq'] == 0) or save_GT_HR) if logger.use_tb_logger: logger.tb_logger.log_images( 'validation_Z%.2f' % (cur_Z), [im[:, :, [2, 1, 0]] for im in sr_images], model.gradient_step_num) if save_GT_HR: # Save GT Uncomp images save_GT_HR = False model.log_dict['psnr_val'].append( (gradient_step_num, print_rlt['psnr'] / len(Z_latent))) else: print('Skipping validation because generator is unchanged') time_elapsed = time.time() - start_time # Save to log print_rlt['model'] = opt['model'] print_rlt['epoch'] = epoch print_rlt['iters'] = gradient_step_num print_rlt['time'] = time_elapsed model.display_log_figure() logger.print_format_results('val', print_rlt, keys_ignore_list=IGNORED_KEYS_LIST) 
                print('-----------------------------------')

            # update learning rate
            if not_within_batch:
                lr_too_low = model.update_learning_rate(gradient_step_num)
                if lr_too_low:
                    print('Stopping training because LR is too low')
                    break

    print('Saving the final model.')
    model.save(gradient_step_num)
    print('End of training.')
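# NOTE (added sketch): this script sizes its optimizer steps with
# `grad_accumulation_steps_G/D`, i.e. gradients from several mini-batches are accumulated
# before a single optimizer step, emulating a larger batch on limited GPU memory. The
# generic pattern is sketched below for a single plain optimizer; here the model class
# hides the same logic inside `optimize_parameters`.
import torch


def train_with_accumulation(model, loss_fn, optimizer, loader, accumulation_steps):
    """Step the optimizer once every `accumulation_steps` mini-batches."""
    optimizer.zero_grad()
    for step, (lr_batch, hr_batch) in enumerate(loader):
        output = model(lr_batch)
        # scale so the accumulated gradient matches one large batch
        loss = loss_fn(output, hr_batch) / accumulation_steps
        loss.backward()
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()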
        img_path = data['LR_path'][0]
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        target = test_set1[0]['HR']
        for i in range(5000):
            cur_out = model.netG(data['LR'], code_val_0, code_val_1, code_val_2)[-1]
            copied_cur_out = Variable(cur_out.detach().to(device1), requires_grad=True)
            output = model1.netG(copied_cur_out, code_val_3, None, None)[-1]
            dist = model1.loss_fn.forward(output, target, normalize=True)
            optimizer.zero_grad()
            optimizer1.zero_grad()
            dist.backward()
            cur_out.backward(copied_cur_out.grad)
            optimizer1.step()
            optimizer.step()
            if i % 10 == 0:
                print('iter %d, dist %.3g' % (i, dist.view(-1).data.cpu().numpy()[0]))
            if i % 100 == 0:
                save_img_path = os.path.join(dataset_dir, img_name + '_%d.png' % i)
                sr_img = util.tensor2img(output.detach()[0].float().cpu())  # uint8
                print("saving: %s" % save_img_path)
                util.save_img(sr_img, save_img_path)
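# NOTE (added sketch): the loop above optimizes two cascaded generators that can sit on
# different devices. The first network's output is detached, re-wrapped as a leaf tensor
# with requires_grad=True, pushed through the second network, and after the loss backward
# the gradient collected on the copy is fed back into the first graph with
# `cur_out.backward(copied_cur_out.grad)`. A stripped-down version of that hand-off is
# shown below; `net_a`, `net_b` and the MSE loss are placeholders, not names from this
# codebase.
import torch
import torch.nn as nn

net_a = nn.Linear(8, 8)   # stands in for the first generator
net_b = nn.Linear(8, 8)   # stands in for the second generator
opt_a = torch.optim.Adam(net_a.parameters(), lr=1e-3)
opt_b = torch.optim.Adam(net_b.parameters(), lr=1e-3)

x, target = torch.randn(1, 8), torch.randn(1, 8)

mid = net_a(x)
mid_copy = mid.detach().clone().requires_grad_(True)  # leaf tensor for the second graph
out = net_b(mid_copy)
loss = nn.functional.mse_loss(out, target)

opt_a.zero_grad()
opt_b.zero_grad()
loss.backward()                # fills mid_copy.grad and net_b's parameter grads
mid.backward(mid_copy.grad)    # pushes the same gradient through net_a
opt_b.step()
opt_a.step()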
def main(): #### setup options of three networks parser = argparse.ArgumentParser() parser.add_argument('-opt_P', type=str, help='Path to option YMAL file of Predictor.') parser.add_argument('-opt_C', type=str, help='Path to option YMAL file of Corrector.') parser.add_argument('-opt_F', type=str, help='Path to option YMAL file of SFTMD_Net.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt_P = option.parse(args.opt_P, is_train=True) opt_C = option.parse(args.opt_C, is_train=True) opt_F = option.parse(args.opt_F, is_train=True) # convert to NoneDict, which returns None for missing keys opt_P = option.dict_to_nonedict(opt_P) opt_C = option.dict_to_nonedict(opt_C) opt_F = option.dict_to_nonedict(opt_F) # choose small opt for SFTMD test, fill path of pre-trained model_F opt_F = opt_F['sftmd'] #### set random seed seed = opt_P['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) util.set_random_seed(seed) # load PCA matrix of enough kernel print('load PCA matrix') pca_matrix = torch.load('./pca_matrix.pth', map_location=lambda storage, loc: storage) print('PCA matrix shape: {}'.format(pca_matrix.shape)) #### distributed training settings if args.launcher == 'none': # disabled distributed training opt_P['dist'] = False opt_F['dist'] = False opt_C['dist'] = False rank = -1 print('Disabled distributed training.') else: opt_P['dist'] = True opt_F['dist'] = True opt_C['dist'] = True init_dist() world_size = torch.distributed.get_world_size( ) #Returns the number of processes in the current process group rank = torch.distributed.get_rank( ) #Returns the rank of current process group torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True ###### Predictor&Corrector train ###### #### loading resume state if exists if opt_P['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load( opt_P['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt_P, resume_state['iter']) # check resume options else: resume_state = None #### mkdir and loggers if rank <= 0: # normal training (rank -1) OR distributed training (rank 0-7) if resume_state is None: # Predictor path util.mkdir_and_rename( opt_P['path'] ['experiments_root']) # rename experiment folder if exists util.mkdirs( (path for key, path in opt_P['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # Corrector path util.mkdir_and_rename( opt_C['path'] ['experiments_root']) # rename experiment folder if exists util.mkdirs( (path for key, path in opt_C['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. 
Before it, the log will not work util.setup_logger('base', opt_P['path']['log'], 'train_' + opt_P['name'], level=logging.INFO, screen=True, tofile=True) util.setup_logger('val', opt_P['path']['log'], 'val_' + opt_P['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger.info(option.dict2str(opt_P)) logger.info(option.dict2str(opt_C)) # tensorboard logger if opt_P['use_tb_logger'] and 'debug' not in opt_P['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. Tensorboard will use [tensorboardX]' .format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt_P['name']) else: util.setup_logger('base', opt_P['path']['log'], 'train', level=logging.INFO, screen=True) logger = logging.getLogger('base') torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### create train and val dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt_P['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt['batch_size'])) total_iters = int(opt_P['train']['niter']) total_epochs = int(math.ceil(total_iters / train_size)) if opt_P['dist']: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) total_epochs = int( math.ceil(total_iters / (train_size * dataset_ratio))) else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt_P, train_sampler) if rank <= 0: logger.info( 'Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt_P, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None assert val_loader is not None #### create model model_F = create_model(opt_F) #load pretrained model of SFTMD model_P = create_model(opt_P) model_C = create_model(opt_C) #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model_P.resume_training( resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 #### training logger.info('Start training from epoch: {:d}, iter: {:d}'.format( start_epoch, current_step)) for epoch in range(start_epoch, total_epochs + 1): if opt_P['dist']: train_sampler.set_epoch(epoch) for _, train_data in enumerate(train_loader): current_step += 1 if current_step > total_iters: break #### update learning rate, schedulers # model.update_learning_rate(current_step, warmup_iter=opt_P['train']['warmup_iter']) #### preprocessing for LR_img and kernel map prepro = util.SRMDPreprocessing(opt_P['scale'], pca_matrix, random=True, para_input=opt_P['code_length'], kernel=opt_P['kernel_size'], noise=False, cuda=True, sig=opt_P['sig'], sig_min=opt_P['sig_min'], sig_max=opt_P['sig_max'], rate_iso=1.0, scaling=3, rate_cln=0.2, noise_high=0.0) LR_img, ker_map = prepro(train_data['GT']) #### training Predictor 
model_P.feed_data(LR_img, ker_map) model_P.optimize_parameters(current_step) P_visuals = model_P.get_current_visuals() est_ker_map = P_visuals['Batch_est_ker_map'] #### log of model_P if current_step % opt_P['logger']['print_freq'] == 0: logs = model_P.get_current_log() message = 'Predictor <epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format( epoch, current_step, model_P.get_current_learning_rate()) for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt_P['use_tb_logger'] and 'debug' not in opt_P['name']: if rank <= 0: tb_logger.add_scalar(k, v, current_step) if rank <= 0: logger.info(message) #### training Corrector for step in range(opt_C['step']): # test SFTMD for corresponding SR image model_F.feed_data(train_data, LR_img, est_ker_map) model_F.test() F_visuals = model_F.get_current_visuals() SR_img = F_visuals['Batch_SR'] # Test SFTMD to produce SR images # train corrector given SR image and estimated kernel map model_C.feed_data(SR_img, est_ker_map, ker_map) model_C.optimize_parameters(current_step) C_visuals = model_C.get_current_visuals() est_ker_map = C_visuals['Batch_est_ker_map'] #### log of model_C if current_step % opt_C['logger']['print_freq'] == 0: logs = model_C.get_current_log() message = 'Corrector <epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format( epoch, current_step, model_C.get_current_learning_rate()) for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt_C['use_tb_logger'] and 'debug' not in opt_C[ 'name']: if rank <= 0: tb_logger.add_scalar(k, v, current_step) if rank <= 0: logger.info(message) # validation, to produce ker_map_list(fake) if current_step % opt_P['train']['val_freq'] == 0 and rank <= 0: avg_psnr = 0.0 idx = 0 for _, val_data in enumerate(val_loader): prepro = util.SRMDPreprocessing( opt_P['scale'], pca_matrix, random=True, para_input=opt_P['code_length'], kernel=opt_P['kernel_size'], noise=False, cuda=True, sig=opt_P['sig'], sig_min=opt_P['sig_min'], sig_max=opt_P['sig_max'], rate_iso=1.0, scaling=3, rate_cln=0.2, noise_high=0.0) LR_img, ker_map = prepro(val_data['GT']) single_img_psnr = 0.0 lr_img = util.tensor2img( LR_img) #save LR image for reference # valid Predictor model_P.feed_data(LR_img, ker_map) model_P.test() P_visuals = model_P.get_current_visuals() est_ker_map = P_visuals['Batch_est_ker_map'] # Save images for reference img_name = os.path.splitext( os.path.basename(val_data['LQ_path'][0]))[0] img_dir = os.path.join(opt_P['path']['val_images'], img_name) # img_dir = os.path.join(opt_F['path']['val_images'], str(current_step), '_', str(step)) util.mkdir(img_dir) save_lr_path = os.path.join(img_dir, '{:s}_LR.png'.format(img_name)) util.save_img(lr_img, save_lr_path) for step in range(opt_C['step']): step += 1 idx += 1 model_F.feed_data(val_data, LR_img, est_ker_map) model_F.test() F_visuals = model_F.get_current_visuals() SR_img = F_visuals['Batch_SR'] # Test SFTMD to produce SR images model_C.feed_data(SR_img, est_ker_map, ker_map) model_C.test() C_visuals = model_C.get_current_visuals() est_ker_map = C_visuals['Batch_est_ker_map'] sr_img = util.tensor2img(F_visuals['SR']) # uint8 gt_img = util.tensor2img(F_visuals['GT']) # uint8 save_img_path = os.path.join( img_dir, '{:s}_{:d}_{:d}.png'.format( img_name, current_step, step)) util.save_img(sr_img, save_img_path) # calculate PSNR crop_size = opt_P['scale'] gt_img = gt_img / 255. sr_img = sr_img / 255. 
cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] step_psnr = util.calculate_psnr( cropped_sr_img * 255, cropped_gt_img * 255) logger.info( '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> img:{:s}, psnr: {:.6f}' .format(epoch, current_step, step, img_name, step_psnr)) single_img_psnr += step_psnr avg_psnr += util.calculate_psnr( cropped_sr_img * 255, cropped_gt_img * 255) avg_signle_img_psnr = single_img_psnr / step logger.info( '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> img:{:s}, average psnr: {:.6f}' .format(epoch, current_step, step, img_name, avg_signle_img_psnr)) avg_psnr = avg_psnr / idx # log logger.info('# Validation # PSNR: {:.6f}'.format(avg_psnr)) logger_val = logging.getLogger('val') # validation logger logger_val.info( '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> psnr: {:.6f}'. format(epoch, current_step, step, avg_psnr)) # tensorboard logger if opt_P['use_tb_logger'] and 'debug' not in opt_P['name']: tb_logger.add_scalar('psnr', avg_psnr, current_step) #### save models and training states if current_step % opt_P['logger']['save_checkpoint_freq'] == 0: if rank <= 0: logger.info('Saving models and training states.') model_P.save(current_step) model_P.save_training_state(epoch, current_step) model_C.save(current_step) model_C.save_training_state(epoch, current_step) if rank <= 0: logger.info('Saving the final model.') model_P.save('latest') model_C.save('latest') logger.info('End of Predictor and Corrector training.') tb_logger.close()
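# NOTE (added sketch): `pca_matrix.pth` is loaded once and handed to
# `util.SRMDPreprocessing`, which in SRMD/IKC-style pipelines projects each flattened blur
# kernel onto a low-dimensional PCA basis to obtain the kernel code that the Predictor and
# Corrector estimate. A hedged sketch of that projection follows; the shapes and the
# `reduce_kernel` name are assumptions, not this repository's API.
import torch


def reduce_kernel(kernels, pca_matrix):
    """Project flattened blur kernels (B, k*k) onto a PCA basis (k*k, t) -> codes (B, t)."""
    b = kernels.size(0)
    return torch.mm(kernels.view(b, -1), pca_matrix)


# usage sketch: a batch of 21x21 kernels reduced to 10-dimensional codes
kernels = torch.rand(4, 21, 21)
pca_basis = torch.rand(21 * 21, 10)        # stands in for the precomputed PCA matrix
codes = reduce_kernel(kernels, pca_basis)  # shape (4, 10)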
def main(): parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, default='options/train/train_ESRCNN_S2L8_2.json', help='Path to option JSON file.') opt = option.parse(parser.parse_args().opt, is_train=True) opt = option.dict_to_nonedict(opt) if opt['path']['resume_state']: resume_state = torch.load(opt['path']['resume_state']) else: resume_state = None util.mkdir_and_rename(opt['path']['experiments_root']) util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) util.setup_logger(None, opt['path']['log'], 'train', level=logging.INFO, screen=True) util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO) logger = logging.getLogger('base') if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) option.check_resume(opt) logger.info(option.dict2str(opt)) if opt['use_tb_logger'] and 'debug' not in opt['name']: from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='./tb_logger/' + opt['name']) seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benckmark = True # Setup TrainDataLoader trainloader = DataLoader(opt['datasets']['train']['dataroot'], split='train') train_size = int( math.ceil(len(trainloader) / opt['datasets']['train']['batch_size'])) logger.info('Number of train images: {:,d}, iters: {:,d}'.format( len(trainloader), train_size)) total_iters = int(opt['train']['niter']) total_epochs = int(math.ceil(total_iters / train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) TrainDataLoader = data.DataLoader( trainloader, batch_size=opt['datasets']['train']['batch_size'], num_workers=12, shuffle=True) #Setup for validate valloader = DataLoader(opt['datasets']['train']['dataroot'], split='val') VALDataLoader = data.DataLoader( valloader, batch_size=opt['datasets']['train']['batch_size'] // 5, num_workers=1, shuffle=True) logger.info('Number of val images:{:d}'.format(len(valloader))) # Setup Model model = get_model('esrcnn_s2l8_2', opt) if resume_state: start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) else: current_step = 0 start_epoch = 0 logger.info('Start training from epoch: {:d}, iter: {:d}'.format( start_epoch, current_step)) for epoch in range(start_epoch, total_epochs): for i, train_data in enumerate(TrainDataLoader): current_step += 1 if current_step > total_iters: break model.update_learning_rate() model.feed_data(train_data) model.optimize_parameters(current_step) if current_step % opt['logger']['print_freq'] == 0: logs = model.get_current_log() message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}>'.format( epoch, current_step, model.get_current_learning_rate()) for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v[0]) if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar(k, v[0], current_step) logger.info(message) if current_step % opt['train']['val_freq'] == 0: avg_psnr = 0.0 idx = 0 for i_val, val_data in enumerate(VALDataLoader): idx += 1 img_name = val_data[3][0].split('.')[0] model.feed_data(val_data) model.val() visuals = model.get_current_visuals() pred_img = util.tensor2img(visuals['Pred']) gt_img = util.tensor2img(visuals['label']) avg_psnr += util.calculate_psnr(pred_img, gt_img) avg_psnr = avg_psnr / 
idx
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                    epoch, current_step, avg_psnr))
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)

            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
def main(): print('hello') ################# # configurations ################# device = torch.device('cuda') os.environ['CUDA_VISIBLE_DEVICES'] = '0' data_mode = 'Vid4' # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp # Vid4: SR # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur); # blur (deblur-clean), blur_comp (deblur-compression). stage = 1 # 1 or 2, use two stage strategy for REDS dataset. flip_test = False ############################################################################ #### model if data_mode == 'Vid4': if stage == 1: model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth' else: raise ValueError('Vid4 does not support stage 2.') elif data_mode == 'sharp_bicubic': if stage == 1: model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth' else: model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth' elif data_mode == 'blur_bicubic': if stage == 1: model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth' else: model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth' elif data_mode == 'blur': if stage == 1: model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth' else: model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth' elif data_mode == 'blur_comp': if stage == 1: model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth' else: model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth' else: raise NotImplementedError if data_mode == 'Vid4': N_in = 7 # use N_in images to restore one HR image else: N_in = 5 predeblur, HR_in = False, False back_RBs = 40 if data_mode == 'blur_bicubic': predeblur = True if data_mode == 'blur' or data_mode == 'blur_comp': predeblur, HR_in = True, True if stage == 2: HR_in = True back_RBs = 20 model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in) #### dataset if data_mode == 'Vid4': test_dataset_folder = '../datasets/Vid4/test' GT_dataset_folder = '../datasets/Vid4/GT' else: if stage == 1: test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode) else: test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4' print('You should modify the test_dataset_folder path for stage 2') GT_dataset_folder = '../datasets/REDS4/GT' #### evaluation crop_border = 0 border_frame = N_in // 2 # border frames when evaluate # temporal padding mode if data_mode == 'Vid4' or data_mode == 'sharp_bicubic': padding = 'new_info' else: padding = 'replicate' save_imgs = True save_folder = '../results/{}'.format(data_mode) util.mkdirs(save_folder) util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') #### log info logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test)) #### set up the models model.load_state_dict(torch.load(model_path), strict=True) model.eval() model = model.to(device) avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] subfolder_name_l = [] print('test_dataset_folder:', test_dataset_folder) subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*'))) print('list:', subfolder_l) subfolder_l = ['../datasets/test/dance_small'] subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*'))) # for each subfolder for subfolder in subfolder_l: subfolder_name = 
osp.basename(subfolder) print(subfolder_name) subfolder_name_l.append(subfolder_name) save_subfolder = osp.join(save_folder, subfolder_name) img_path_l = sorted(glob.glob(osp.join(subfolder, '*'))) max_idx = len(img_path_l) if save_imgs: util.mkdirs(save_subfolder) #### read LQ and GT images imgs_LQ = data_util.read_img_seq(subfolder) # img_GT_l = [] # for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))): # img_GT_l.append(data_util.read_img(None, img_GT_path)) avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0 # process each image for img_idx, img_path in enumerate(img_path_l): print('path:', img_path) img_name = osp.splitext(osp.basename(img_path))[0] select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding) print(select_idx) imgs_in = imgs_LQ.index_select( 0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) if flip_test: output = util.flipx4_forward(model, imgs_in) else: output = util.single_forward(model, imgs_in) output = util.tensor2img(output.squeeze(0)) if save_imgs: print('im_path:', osp.join(save_subfolder, '{}.png'.format(img_name))) cv2.imwrite( osp.join(save_subfolder, '{}.png'.format(img_name)), output) # # calculate PSNR # output = output / 255. # GT = np.copy(img_GT_l[img_idx]) # # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel # if data_mode == 'Vid4': # bgr2y, [0, 1] # GT = data_util.bgr2ycbcr(GT, only_y=True) # output = data_util.bgr2ycbcr(output, only_y=True) # # output, GT = util.crop_border([output, GT], crop_border) # crt_psnr = util.calculate_psnr(output * 255, GT * 255) # logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr)) # # if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames # avg_psnr_center += crt_psnr # N_center += 1 # else: # border frames # avg_psnr_border += crt_psnr # N_border += 1 avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border) avg_psnr_center = avg_psnr_center / N_center avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border avg_psnr_l.append(avg_psnr) avg_psnr_center_l.append(avg_psnr_center) avg_psnr_border_l.append(avg_psnr_border) logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; ' 'Center PSNR: {:.6f} dB for {} frames; ' 'Border PSNR: {:.6f} dB for {} frames.'.format( subfolder_name, avg_psnr, (N_center + N_border), avg_psnr_center, N_center, avg_psnr_border, N_border)) logger.info('################ Tidy Outputs ################') for subfolder_name, psnr, psnr_center, psnr_border in zip( subfolder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l): logger.info('Folder {} - Average PSNR: {:.6f} dB. ' 'Center PSNR: {:.6f} dB. ' 'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, psnr_border)) logger.info('################ Final Results ################') logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test))
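# NOTE (added sketch): EDVR restores each frame from a window of N_in neighbours, and the
# `padding` option controls what happens at clip boundaries: 'replicate' repeats the edge
# frame, while 'new_info' substitutes frames from the other side of the window so all
# selected indices stay distinct. The function below mirrors that behaviour for
# illustration; it is not the repository's `data_util.index_generation` itself.
def generate_indices(center, total, window, padding='replicate'):
    """Return `window` frame indices centred on `center` for a clip of `total` frames."""
    half = window // 2
    last = total - 1
    indices = []
    for i in range(center - half, center + half + 1):
        if i < 0:
            idx = 0 if padding == 'replicate' else (center + half) + (-i)
        elif i > last:
            idx = last if padding == 'replicate' else (center - half) - (i - last)
        else:
            idx = i
        indices.append(idx)
    return indices


# e.g. generate_indices(0, 100, 5, padding='new_info') -> [4, 3, 0, 1, 2]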
def main(): # options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.') opt = option.parse(parser.parse_args().opt, is_train=True) util.mkdir_and_rename( opt['path']['experiments_root']) # rename old experiments if exists util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and \ not key == 'pretrain_model_G' and not key == 'pretrain_model_D')) option.save(opt) opt = option.dict_to_nonedict( opt) # Convert to NoneDict, which return None for missing key. # print to file and std_out simultaneously sys.stdout = PrintLogger(opt['path']['log']) # random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) print("Random Seed: ", seed) random.seed(seed) torch.manual_seed(seed) # create train and val dataloader for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt['batch_size'])) print('Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) total_iters = int(opt['train']['niter']) total_epoches = int(math.ceil(total_iters / train_size)) print('Total epoches needed: {:d} for iters {:,d}'.format( total_epoches, total_iters)) train_loader = create_dataloader(train_set, dataset_opt) elif phase == 'val': val_dataset_opt = dataset_opt val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt) print('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None # Create model model = create_model(opt) # create logger- logger = Logger(opt) current_step = 0 start_time = time.time() print('---------- Start training -------------') for epoch in range(total_epoches): for i, train_data in enumerate(train_loader): current_step += 1 if current_step > total_iters: break # training model.feed_data(train_data) model.optimize_parameters(current_step) time_elapsed = time.time() - start_time start_time = time.time() # log if current_step % opt['logger']['print_freq'] == 0: logs = model.get_current_log() print_rlt = OrderedDict() print_rlt['model'] = opt['model'] print_rlt['epoch'] = epoch print_rlt['iters'] = current_step print_rlt['time'] = time_elapsed for k, v in logs.items(): print_rlt[k] = v print_rlt['lr'] = model.get_current_learning_rate() logger.print_format_results('train', print_rlt) # save models if current_step % opt['logger']['save_checkpoint_freq'] == 0: print('Saving the model at the end of iter {:d}.'.format( current_step)) model.save(current_step) # validation if current_step % opt['train']['val_freq'] == 0: print('---------- validation -------------') start_time = time.time() avg_psnr = 0.0 idx = 0 for val_data in val_loader: idx += 1 img_name = os.path.splitext( os.path.basename(val_data['LR_path'][0]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir) model.feed_data(val_data) model.test() visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals['SR']) # uint8 gt_img = util.tensor2img(visuals['HR']) # uint8 # Save SR images for reference save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(\ img_name, current_step)) util.save_img(sr_img, save_img_path) # calculate PSNR crop_size = opt['scale'] cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] cropped_gt_img = 
gt_img[crop_size:-crop_size, crop_size:-crop_size, :] avg_psnr += util.psnr(cropped_sr_img, cropped_gt_img) avg_psnr = avg_psnr / idx time_elapsed = time.time() - start_time # Save to log print_rlt = OrderedDict() print_rlt['model'] = opt['model'] print_rlt['epoch'] = epoch print_rlt['iters'] = current_step print_rlt['time'] = time_elapsed print_rlt['psnr'] = avg_psnr logger.print_format_results('val', print_rlt) print('-----------------------------------') # update learning rate model.update_learning_rate() print('Saving the final model.') model.save('latest') print('End of training.')
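# --- Illustrative sketch, not part of the original script ---
# The validation loop above crops `opt['scale']` pixels from every border
# before computing PSNR so edge artifacts do not bias the metric. A minimal
# numpy stand-in for util.psnr (an assumption, not the project's helper):
import numpy as np

def cropped_psnr(sr, gt, crop):
    sr = sr[crop:-crop, crop:-crop].astype(np.float64)
    gt = gt[crop:-crop, crop:-crop].astype(np.float64)
    mse = np.mean((sr - gt) ** 2)
    return float('inf') if mse == 0 else 20 * np.log10(255.0 / np.sqrt(mse))

sr = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
print(cropped_psnr(sr, sr, crop=4))  # identical images -> inf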
def main(): #### options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger.info(option.dict2str(opt)) #### Create test dataset and dataloader test_loaders = [] for phase, dataset_opt in sorted(opt['datasets'].items()): test_set = create_dataset(dataset_opt) test_loader = create_dataloader(test_set, dataset_opt) logger.info('Number of test images in [{:s}]: {:d}'.format( dataset_opt['name'], len(test_set))) test_loaders.append(test_loader) model = create_model(opt) for test_loader in test_loaders: test_set_name = test_loader.dataset.opt['name'] logger.info('\nTesting [{:s}]...'.format(test_set_name)) test_start_time = time.time() dataset_dir = osp.join(opt['path']['results_root'], test_set_name) util.mkdir(dataset_dir) test_results = OrderedDict() test_results['psnr'] = [] test_results['ssim'] = [] test_results['psnr_y'] = [] test_results['ssim_y'] = [] for data in test_loader: need_GT = False if test_loader.dataset.opt[ 'dataroot_GT'] is None else True model.feed_data(data, need_GT=need_GT) img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0] img_name = osp.splitext(osp.basename(img_path))[0] model.test() visuals = model.get_current_visuals(need_GT=need_GT) sr_img = util.tensor2img(visuals['rlt']) # uint8 # save images suffix = opt['suffix'] if suffix: save_img_path = osp.join(dataset_dir, img_name + suffix + '.png') else: save_img_path = osp.join(dataset_dir, img_name + '.png') util.save_img(sr_img, save_img_path) # calculate PSNR and SSIM if need_GT: gt_img = util.tensor2img(visuals['GT']) sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) psnr = util.calculate_psnr(sr_img, gt_img) ssim = util.calculate_ssim(sr_img, gt_img) test_results['psnr'].append(psnr) test_results['ssim'].append(ssim) if gt_img.shape[2] == 3: # RGB image sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True) gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True) psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255) ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255) test_results['psnr_y'].append(psnr_y) test_results['ssim_y'].append(ssim_y) logger.info( '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.' .format(img_name, psnr, ssim, psnr_y, ssim_y)) else: logger.info( '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format( img_name, psnr, ssim)) else: logger.info(img_name) if need_GT: # metrics # Average PSNR/SSIM results ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) logger.info( '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n' .format(test_set_name, ave_psnr, ave_ssim)) if test_results['psnr_y'] and test_results['ssim_y']: ave_psnr_y = sum(test_results['psnr_y']) / len( test_results['psnr_y']) ave_ssim_y = sum(test_results['ssim_y']) / len( test_results['ssim_y']) logger.info( '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n' .format(ave_psnr_y, ave_ssim_y))
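# --- Illustrative sketch, not part of the original script ---
# The PSNR_Y/SSIM_Y branch above converts BGR images to the luma channel with
# bgr2ycbcr(..., only_y=True). This is the standard ITU-R BT.601 transform; a
# minimal stand-in (not the project's data_util implementation) looks like:
import numpy as np

def bgr_to_y(img):
    # img: HWC float in [0, 1], BGR channel order; returns Y in [16/255, 235/255]
    b, g, r = img[..., 0], img[..., 1], img[..., 2]
    return (24.966 * b + 128.553 * g + 65.481 * r + 16.0) / 255.0

img = np.random.rand(8, 8, 3).astype(np.float32)
y = bgr_to_y(img)  # H x W luma plane, ready for PSNR_Y / SSIM_Y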
test_results['psnr'] = [] test_results['ssim'] = [] test_results['psnr_y'] = [] test_results['ssim_y'] = [] for data in test_loader: need_GT = True model.feed_data(data, need_GT=need_GT) img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0] img_name = osp.splitext(osp.basename(img_path))[0] model.test() visuals = model.get_current_visuals(need_GT=need_GT) if which_model == 'RCAN': sr_img = util.tensor2img(visuals['rlt'], out_type=np.uint8, min_max=(0, 255)) # uint8 else: sr_img = util.tensor2img(visuals['rlt']) # uint8 # save images suffix = opt['suffix'] if suffix: save_img_path = osp.join(dataset_dir, img_name + suffix + '.png') else: save_img_path = osp.join(dataset_dir, img_name + '.png') util.save_img(sr_img, save_img_path) if need_GT: if which_model == 'RCAN': gt_img = util.tensor2img(visuals['GT'],
def main(): #### options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=True) #### distributed training settings if args.launcher == 'none': # disabled distributed training opt['dist'] = False rank = -1 print('Disabled distributed training.') else: opt['dist'] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() #### loading resume state if exists if opt['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load(opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options else: resume_state = None #### mkdir and loggers if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) if resume_state is None: util.mkdir_and_rename( opt['path']['experiments_root']) # rename experiment folder if exists util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]'.format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) else: util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) logger = logging.getLogger('base') # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) if rank <= 0: logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### create train and val dataloader dataset_ratio = 1 #200 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) total_iters = int(opt['train']['niter']) total_epochs = int(math.ceil(total_iters / train_size)) if opt['dist']: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) if rank <= 0: logger.info('Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None #### create model model = create_model(opt) #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 #### training logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) for epoch in range(start_epoch, total_epochs + 1): if opt['dist']: train_sampler.set_epoch(epoch) for _, train_data in enumerate(train_loader): current_step += 1 if current_step > total_iters: break #### update learning rate model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) #### training t0 = time.time() model.feed_data(train_data) model.optimize_parameters(current_step) t1 = time.time() #### log if current_step % opt['logger']['print_freq'] == 0: logs = model.get_current_log() message = '[epoch:{:3d}, iter:{:8d}, speed:{:5.1f}, lr:('.format(epoch, current_step, dataset_opt['batch_size']/(t1-t0)) # argument order must match the fields: {:8d} takes the integer step, {:5.1f} the float speed for v in model.get_current_learning_rate(): message += '{:.5f},'.format(v) message += ')] ' for k, v in logs.items(): message += '{:s}: {:.4f} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: if rank <= 0: tb_logger.add_scalar(k, v, current_step) if rank <= 0: logger.info(message) #### save models and training states if current_step % opt['logger']['save_checkpoint_freq'] == 0: if rank <= 0: model.save(current_step) model.save_training_state(epoch, current_step) logger.info('Saving models and training states.') #### validation if 
opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: if opt['model'] in ['sr', 'srgan'] and rank <= 0: # image restoration validation # does not support multi-GPU validation pbar = util.ProgressBar(len(val_loader)) avg_psnr = 0. idx = 0 for val_data in val_loader: idx += 1 img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir) model.feed_data(val_data) model.test() visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # Save SR images for reference save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) util.save_img(sr_img, save_img_path) # calculate PSNR sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) avg_psnr += util.calculate_psnr(sr_img, gt_img) pbar.update('Test {}'.format(img_name)) avg_psnr = avg_psnr / idx # log logger.info('# Validation # PSNR: {:.4f}'.format(avg_psnr)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr', avg_psnr, current_step) else: # video restoration validation if opt['dist']: # multi-GPU testing psnr_rlt = {} # with border and center frames if rank == 0: pbar = util.ProgressBar(len(val_set)) for idx in range(rank, len(val_set), world_size): val_data = val_set[idx] val_data['LQs'].unsqueeze_(0) val_data['GT'].unsqueeze_(0) folder = val_data['folder'] idx_d, max_idx = val_data['idx'].split('/') idx_d, max_idx = int(idx_d), int(max_idx) if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32, device='cuda') # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda') model.feed_data(val_data) model.test() visuals = model.get_current_visuals() rlt_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # calculate PSNR psnr_rlt[folder][idx_d] = util.calculate_psnr(rlt_img, gt_img) if rank == 0: for _ in range(world_size): pbar.update('Test {} - {}/{}'.format(folder, idx_d, max_idx)) # # collect data for _, v in psnr_rlt.items(): dist.reduce(v, 0) dist.barrier() if rank == 0: psnr_rlt_avg = {} psnr_total_avg = 0. for k, v in psnr_rlt.items(): psnr_rlt_avg[k] = torch.mean(v).cpu().item() psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = '# Validation # PSNR: {:.4f}:'.format(psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4f}'.format(k, v) logger.info(log_s) if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) else: pbar = util.ProgressBar(len(val_loader)) psnr_rlt = {} # with border and center frames psnr_rlt_avg = {} psnr_total_avg = 0. 
for val_data in val_loader: folder = val_data['folder'][0] idx_d = val_data['idx'].item() # border = val_data['border'].item() if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = [] model.feed_data(val_data) model.test() visuals = model.get_current_visuals() rlt_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # calculate PSNR psnr = util.calculate_psnr(rlt_img, gt_img) psnr_rlt[folder].append(psnr) pbar.update('Test {} - {}'.format(folder, idx_d)) for k, v in psnr_rlt.items(): psnr_rlt_avg[k] = sum(v) / len(v) psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = '# Validation # PSNR: {:.4f}:'.format(psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4f}'.format(k, v) logger.info(log_s) if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) if rank <= 0: logger.info('Saving the final model.') model.save('latest') logger.info('End of training.') tb_logger.close()
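# --- Illustrative sketch, not part of the original script ---
# In the multi-GPU validation above, each rank fills its frames of a
# zero-initialized per-folder PSNR tensor, then dist.reduce sums the partial
# tensors onto rank 0 before averaging. A single-process gloo group is used
# here only so the snippet runs standalone.
import os
import torch
import torch.distributed as dist

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29501')
dist.init_process_group('gloo', rank=0, world_size=1)

psnr = torch.zeros(4)      # one slot per frame of a clip
psnr[0] = 31.2             # this rank evaluated frame 0
dist.reduce(psnr, dst=0)   # element-wise sum of all ranks' tensors onto rank 0
if dist.get_rank() == 0:
    print(psnr.mean().item())
dist.destroy_process_group()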
def main(): # options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.') parser.add_argument('-single_GPU', action='store_true',help='Utilize only one GPU') parser.add_argument('-chroma', action='store_true',help='Training the chroma-channels generator') if parser.parse_args().single_GPU: available_GPUs = util.Assign_GPU(maxMemory=0.66) else: # available_GPUs = util.Assign_GPU(max_GPUs=None,maxMemory=0.8,maxLoad=0.8) available_GPUs = util.Assign_GPU(max_GPUs=None) opt = option.parse(parser.parse_args().opt, is_train=True,batch_size_multiplier=len(available_GPUs),name='JPEG'+('_chroma' if parser.parse_args().chroma else '')) if not opt['train']['resume']: util.mkdir_and_rename(opt['path']['experiments_root']) # Modify experiment name if exists util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and \ not key == 'pretrained_model_G' and not key == 'pretrained_model_D')) option.save(opt) opt = option.dict_to_nonedict(opt) # Convert to NoneDict, which return None for missing key. # print to file and std_out simultaneously sys.stdout = PrintLogger(opt['path']['log']) # random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) print("Random Seed: ", seed) random.seed(seed) torch.manual_seed(seed) # create train and val dataloader for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': max_accumulation_steps = max([opt['train']['grad_accumulation_steps_G'], opt['train']['grad_accumulation_steps_D']]) train_set = create_dataset(dataset_opt) train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) print('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size)) total_iters = int(opt['train']['niter']*max_accumulation_steps)#-current_step total_epochs = int(math.ceil(total_iters / train_size)) print('Total epochs needed: {:d} for iters {:,d}'.format(total_epochs, total_iters)) train_loader = create_dataloader(train_set, dataset_opt) elif phase == 'val': val_dataset_opt = dataset_opt val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt) print('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'], len(val_set))) else: raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None DEBUG = False # Create model if DEBUG: from models.base_model import BaseModel model = BaseModel model.step = 0 else: model = create_model(opt,max_accumulation_steps,chroma_mode=opt['name'][:len('JPEG/chroma')]=='JPEG/chroma') # create logger logger = Logger(opt) # Save validation set results as image collage: SAVE_IMAGE_COLLAGE = True start_time,start_time_gradient_step = time.time(),model.step // max_accumulation_steps save_GT_Uncomp = True lr_too_low = False print('---------- Start training -------------') last_saving_time = time.time() recently_saved_models = deque(maxlen=4) for epoch in range(int(math.floor(model.step / train_size)),total_epochs): for i, train_data in enumerate(train_loader): model.gradient_step_num = model.step // max_accumulation_steps not_within_batch = model.step % max_accumulation_steps == (max_accumulation_steps - 1) saving_step = ((time.time()-last_saving_time)>60*opt['logger']['save_checkpoint_freq']) and not_within_batch if saving_step: last_saving_time = time.time() # save models if lr_too_low or saving_step: model.save_log() recently_saved_models.append(model.save(model.gradient_step_num)) if 
len(recently_saved_models)>3: model_2_delete = recently_saved_models.popleft() os.remove(model_2_delete) if model.D_exists: os.remove(model_2_delete.replace('_G.','_D.')) print('{}: Saving the model before iter {:d}.'.format(datetime.now().strftime('%H:%M:%S'),model.gradient_step_num)) if lr_too_low: break if model.step > total_iters: break # time_elapsed = time.time() - start_time # if not_within_batch: start_time = time.time() # log if model.gradient_step_num % opt['logger']['print_freq'] == 0 and not_within_batch: logs = model.get_current_log() print_rlt = OrderedDict() print_rlt['model'] = opt['model'] print_rlt['epoch'] = epoch print_rlt['iters'] = model.gradient_step_num # time_elapsed = time.time() - start_time print_rlt['time'] = (time.time() - start_time)/np.maximum(1,model.gradient_step_num-start_time_gradient_step) start_time, start_time_gradient_step = time.time(), model.gradient_step_num for k, v in logs.items(): print_rlt[k] = v print_rlt['lr'] = model.get_current_learning_rate() logger.print_format_results('train', print_rlt,keys_ignore_list=['avg_est_err']) model.display_log_figure() # validation if (not_within_batch or i==0) and (model.gradient_step_num) % opt['train']['val_freq'] == 0: # and model.gradient_step_num>=opt['train']['D_init_iters']: print_rlt = OrderedDict() if model.generator_changed: print('---------- validation -------------') start_time = time.time() if False and SAVE_IMAGE_COLLAGE and model.gradient_step_num%opt['train']['val_save_freq'] == 0: #Saving training images: # GT_image_collage,quantized_image_collage = [],[] cur_train_results = model.get_current_visuals(entire_batch=True) train_psnrs = [util.calculate_psnr(util.tensor2img(cur_train_results['Decomp'][im_num], out_type=np.uint8,min_max=[0,255]), util.tensor2img(cur_train_results['Uncomp'][im_num], out_type=np.uint8,min_max=[0,255])) for im_num in range(len(cur_train_results['Decomp']))] #Save latest training batch output: save_img_path = os.path.join(os.path.join(opt['path']['val_images']), '{:d}_Tr_PSNR{:.3f}.png'.format(model.gradient_step_num, np.mean(train_psnrs))) util.save_img(np.clip(np.concatenate((np.concatenate([util.tensor2img(cur_train_results['Uncomp'][im_num], out_type=np.uint8,min_max=[0,255]) for im_num in range(len(cur_train_results['Decomp']))],0), np.concatenate( [util.tensor2img(cur_train_results['Decomp'][im_num], out_type=np.uint8,min_max=[0,255]) for im_num in range(len(cur_train_results['Decomp']))], 0)), 1), 0, 255).astype(np.uint8), save_img_path) Z_latent = [0]+([-0.5,0.5] if opt['network_G']['latent_input'] else []) print_rlt['psnr'] = 0 for cur_Z in Z_latent: model.perform_validation(data_loader=val_loader,cur_Z=cur_Z,print_rlt=print_rlt,GT_and_quantized=save_GT_Uncomp, save_images=((model.gradient_step_num) % opt['train']['val_save_freq'] == 0) or save_GT_Uncomp) if save_GT_Uncomp: # Save GT Uncomp images save_GT_Uncomp = False print_rlt['psnr'] /= len(Z_latent) model.log_dict['psnr_val'].append((model.gradient_step_num,print_rlt['psnr'])) else: print('Skipping validation because generator is unchanged') # time_elapsed = time.time() - start_time # Save to log print_rlt['model'] = opt['model'] print_rlt['epoch'] = epoch print_rlt['iters'] = model.gradient_step_num # print_rlt['time'] = time_elapsed print_rlt['time'] = (time.time() - start_time)/np.maximum(1,model.gradient_step_num-start_time_gradient_step) # model.display_log_figure() # model.generator_changed = False logger.print_format_results('val', print_rlt,keys_ignore_list=['avg_est_err']) 
print('-----------------------------------') model.feed_data(train_data,mixed_Y=True) model.optimize_parameters() # update learning rate if not_within_batch: lr_too_low = model.update_learning_rate(model.gradient_step_num) # current_step += 1 if lr_too_low: print('Stopping training because LR is too low') break print('Saving the final model.') model.save(model.gradient_step_num) model.save_log() print('End of training.')
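# --- Illustrative sketch, not part of the original script ---
# The training loop above counts micro-batches in model.step and only fires
# optimizer-level logic (LR updates, saving, validation) on the last
# micro-batch of each accumulated batch. The bookkeeping, with hypothetical
# numbers purely for illustration:
max_accumulation_steps = 4
for step in range(12):
    gradient_step_num = step // max_accumulation_steps
    not_within_batch = step % max_accumulation_steps == max_accumulation_steps - 1
    if not_within_batch:
        # one optimizer step per max_accumulation_steps micro-batches
        print('optimizer step {} completes at micro-batch {}'.format(gradient_step_num, step))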
test_results = OrderedDict() test_results['psnr'] = [] test_results['ssim'] = [] test_results['psnr_y'] = [] test_results['ssim_y'] = [] test_results['niqe'] = [] test_results['niqe_gt'] = [] for data in test_loader: need_HR = False if test_loader.dataset.opt[ 'dataroot_HR'] is None else True model.feed_data(data, need_HR=need_HR) img_path = data['LR_path'][0] img_name = os.path.splitext(os.path.basename(img_path))[0] model.test() # test visuals = model.get_current_visuals(need_HR=need_HR) sr_img = util.tensor2img(visuals['SR']) # uint8 # save images suffix = opt['suffix'] if suffix: save_img_path = os.path.join(dataset_dir, img_name + suffix + '.png') else: save_img_path = os.path.join(dataset_dir, img_name + '.png') util.save_img(sr_img, save_img_path) # print(save_img_path)
def main(): ################# # configurations ################# os.environ['CUDA_VISIBLE_DEVICES'] = '0' data_mode = 'Vid4' # Vid4 | sharp_bicubic (REDS) # Possible combinations: (2, 16), (3, 16), (4, 16), (4, 28), (4, 52) scale = 4 layer = 52 assert (scale, layer) in [(2, 16), (3, 16), (4, 16), (4, 28), (4, 52) ], 'Unrecognized (scale, layer) combination' # model N_in = 7 model_path = '../experiments/pretrained_models/DUF_x{}_{}L_official.pth'.format( scale, layer) adapt_official = True if 'official' in model_path else False DUF_downsampling = True # True | False if layer == 16: model = DUF_arch.DUF_16L(scale=scale, adapt_official=adapt_official) elif layer == 28: model = DUF_arch.DUF_28L(scale=scale, adapt_official=adapt_official) elif layer == 52: model = DUF_arch.DUF_52L(scale=scale, adapt_official=adapt_official) #### dataset if data_mode == 'Vid4': test_dataset_folder = '../datasets/Vid4/BIx4/*' else: # sharp_bicubic (REDS) test_dataset_folder = '../datasets/REDS4/{}/*'.format(data_mode) #### evaluation crop_border = 8 border_frame = N_in // 2 # border frames when evaluate # temporal padding mode padding = 'new_info' # different from the official testing codes, which pads zeros. save_imgs = True ############################################################################ device = torch.device('cuda') save_folder = '../results/{}'.format(data_mode) util.mkdirs(save_folder) util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') #### log info logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) def read_image(img_path): '''read one image from img_path Return img: HWC, BGR, [0,1], numpy ''' img_GT = cv2.imread(img_path) img = img_GT.astype(np.float32) / 255. 
return img def read_seq_imgs(img_seq_path): '''read a sequence of images''' img_path_l = sorted(glob.glob(img_seq_path + '/*')) img_l = [read_image(v) for v in img_path_l] # stack to TCHW, RGB, [0,1], torch imgs = np.stack(img_l, axis=0) imgs = imgs[:, :, :, [2, 1, 0]] imgs = torch.from_numpy( np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() return imgs def index_generation(crt_i, max_n, N, padding='reflection'): ''' padding: replicate | reflection | new_info | circle ''' max_n = max_n - 1 n_pad = N // 2 return_l = [] for i in range(crt_i - n_pad, crt_i + n_pad + 1): if i < 0: if padding == 'replicate': add_idx = 0 elif padding == 'reflection': add_idx = -i elif padding == 'new_info': add_idx = (crt_i + n_pad) + (-i) elif padding == 'circle': add_idx = N + i else: raise ValueError('Wrong padding mode') elif i > max_n: if padding == 'replicate': add_idx = max_n elif padding == 'reflection': add_idx = max_n * 2 - i elif padding == 'new_info': add_idx = (crt_i - n_pad) - (i - max_n) elif padding == 'circle': add_idx = i - N else: raise ValueError('Wrong padding mode') else: add_idx = i return_l.append(add_idx) return return_l def single_forward(model, imgs_in): with torch.no_grad(): model_output = model(imgs_in) if isinstance(model_output, list) or isinstance( model_output, tuple): output = model_output[0] else: output = model_output return output sub_folder_l = sorted(glob.glob(test_dataset_folder)) #### set up the models model.load_state_dict(torch.load(model_path), strict=True) model.eval() model = model.to(device) avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] sub_folder_name_l = [] # for each sub-folder for sub_folder in sub_folder_l: sub_folder_name = sub_folder.split('/')[-1] sub_folder_name_l.append(sub_folder_name) save_sub_folder = osp.join(save_folder, sub_folder_name) img_path_l = sorted(glob.glob(sub_folder + '/*')) max_idx = len(img_path_l) if save_imgs: util.mkdirs(save_sub_folder) #### read LR images imgs = read_seq_imgs(sub_folder) #### read GT images img_GT_l = [] if data_mode == 'Vid4': sub_folder_GT = osp.join(sub_folder.replace('/BIx4/', '/GT/'), '*') else: sub_folder_GT = osp.join( sub_folder.replace('/{}/'.format(data_mode), '/GT/'), '*') for img_GT_path in sorted(glob.glob(sub_folder_GT)): img_GT_l.append(read_image(img_GT_path)) # When using the downsampling in DUF official code, we downsample the HR images if DUF_downsampling: sub_folder = sub_folder_GT img_path_l = sorted(glob.glob(sub_folder)) max_idx = len(img_path_l) imgs = read_seq_imgs(sub_folder[:-2]) avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0 cal_n_border, cal_n_center = 0, 0 # process each image for img_idx, img_path in enumerate(img_path_l): c_idx = int(osp.splitext(osp.basename(img_path))[0]) select_idx = index_generation(c_idx, max_idx, N_in, padding=padding) # get input images imgs_in = imgs.index_select( 0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) # Downsample the HR images H, W = imgs_in.size(3), imgs_in.size(4) if DUF_downsampling: imgs_in = util.DUF_downsample(imgs_in, scale=scale) output = single_forward(model, imgs_in) # Crop to the original shape if scale == 3: pad_h = 3 - (H % 3) pad_w = 3 - (W % 3) if pad_h > 0: output = output[:, :, :-pad_h, :] if pad_w > 0: output = output[:, :, :, :-pad_w] output_f = output.data.float().cpu().squeeze(0) output = util.tensor2img(output_f) # save imgs if save_imgs: cv2.imwrite( osp.join(save_sub_folder, '{:08d}.png'.format(c_idx)), output) #### calculate PSNR output = output / 255. 
GT = np.copy(img_GT_l[img_idx]) # For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels if data_mode == 'Vid4': # bgr2y, [0, 1] GT = data_util.bgr2ycbcr(GT) output = data_util.bgr2ycbcr(output) if crop_border == 0: cropped_output = output cropped_GT = GT else: cropped_output = output[crop_border:-crop_border, crop_border:-crop_border] cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border] crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255) logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB'.format( img_idx + 1, c_idx, crt_psnr)) if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames avg_psnr_center += crt_psnr cal_n_center += 1 else: # border frames avg_psnr_border += crt_psnr cal_n_border += 1 avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border) avg_psnr_center = avg_psnr_center / cal_n_center if cal_n_border == 0: avg_psnr_border = 0 else: avg_psnr_border = avg_psnr_border / cal_n_border logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; ' 'Center PSNR: {:.6f} dB for {} frames; ' 'Border PSNR: {:.6f} dB for {} frames.'.format( sub_folder_name, avg_psnr, (cal_n_center + cal_n_border), avg_psnr_center, cal_n_center, avg_psnr_border, cal_n_border)) avg_psnr_l.append(avg_psnr) avg_psnr_center_l.append(avg_psnr_center) avg_psnr_border_l.append(avg_psnr_border) logger.info('################ Tidy Outputs ################') for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l): logger.info('Folder {} - Average PSNR: {:.6f} dB. ' 'Center PSNR: {:.6f} dB. ' 'Border PSNR: {:.6f} dB.'.format(name, psnr, psnr_center, psnr_border)) logger.info('################ Final Results ################') logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Total Average PSNR: {:.6f} dB for {} clips. ' 'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format( sum(avg_psnr_l) / len(avg_psnr_l), len(sub_folder_l), sum(avg_psnr_center_l) / len(avg_psnr_center_l), sum(avg_psnr_border_l) / len(avg_psnr_border_l)))
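# --- Illustrative sketch, not part of the original script ---
# A standalone check of the four temporal padding modes implemented by
# index_generation above (replicate | reflection | new_info | circle), shown
# for a 5-frame window at the first frame of a 10-frame clip. Restated here
# because the original helper is nested inside main().
def gen(crt_i, max_n, N, padding):
    max_n, n_pad, out = max_n - 1, N // 2, []
    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
        if i < 0:
            i = {'replicate': 0, 'reflection': -i,
                 'new_info': (crt_i + n_pad) - i, 'circle': N + i}[padding]
        elif i > max_n:
            i = {'replicate': max_n, 'reflection': max_n * 2 - i,
                 'new_info': (crt_i - n_pad) - (i - max_n), 'circle': i - N}[padding]
        out.append(i)
    return out

for mode in ('replicate', 'reflection', 'new_info', 'circle'):
    print(mode, gen(0, 10, 5, mode))
# replicate [0, 0, 0, 1, 2]; reflection [2, 1, 0, 1, 2];
# new_info [4, 3, 0, 1, 2]; circle [3, 4, 0, 1, 2]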
def main(): ################# # configurations ################# parser = argparse.ArgumentParser() parser.add_argument("--input_path", type=str, required=True) #parser.add_argument("--gt_path", type=str, required=True) parser.add_argument("--output_path", type=str, required=True) parser.add_argument("--model_path", type=str, required=True) parser.add_argument("--gpu_id", type=str, required=True) parser.add_argument("--screen_notation", type=str, required=True) parser.add_argument("--use_screen_notation", type=int, default=1) parser.add_argument('--opt', type=str, required=True, help='Path to option YAML file.') args = parser.parse_args() opt = option.parse(args.opt, is_train=False) PAD = 32 total_run_time = AverageMeter() print("GPU ", torch.cuda.device_count()) os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) device = torch.device('cuda') # os.environ['CUDA_VISIBLE_DEVICES'] = '0' data_mode = 'sharp_bicubic' # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp # Vid4: SR # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur); # blur (deblur-clean), blur_comp (deblur-compression). stage = 1 # 1 or 2, use two stage strategy for REDS dataset. flip_test = False # Input_folder = "/DATA7_DB7/data/4khdr/data/Dataset/train_sharp_bicubic" # GT_folder = "/DATA7_DB7/data/4khdr/data/Dataset/train_4k" # Result_folder = "/DATA7_DB7/data/4khdr/data/Results" Input_folder = args.input_path #GT_folder = args.gt_path Result_folder = args.output_path Model_path = args.model_path # create results folder if not os.path.exists(Result_folder): os.makedirs(Result_folder, exist_ok=True) ############################################################################ #### model # if data_mode == 'Vid4': # if stage == 1: # model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth' # else: # raise ValueError('Vid4 does not support stage 2.') # elif data_mode == 'sharp_bicubic': # if stage == 1: # # model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth' # else: # model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth' # elif data_mode == 'blur_bicubic': # if stage == 1: # model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth' # else: # model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth' # elif data_mode == 'blur': # if stage == 1: # model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth' # else: # model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth' # elif data_mode == 'blur_comp': # if stage == 1: # model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth' # else: # model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth' # else: # raise NotImplementedError model_path = Model_path if data_mode == 'Vid4': N_in = 7 # use N_in images to restore one HR image else: N_in = 5 predeblur, HR_in = False, False back_RBs = 40 if data_mode == 'blur_bicubic': predeblur = True if data_mode == 'blur' or data_mode == 'blur_comp': predeblur, HR_in = True, True if stage == 2: HR_in = True back_RBs = 20 model = EDVR_arch.EDVR(nf=opt['network_G']['nf'], nframes=opt['network_G']['nframes'], groups=opt['network_G']['groups'], front_RBs=opt['network_G']['front_RBs'], back_RBs=opt['network_G']['back_RBs'], predeblur=opt['network_G']['predeblur'], HR_in=opt['network_G']['HR_in'], w_TSA=opt['network_G']['w_TSA']) # model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in) #### dataset if data_mode == 'Vid4': test_dataset_folder = '../datasets/Vid4/BIx4' 
GT_dataset_folder = '../datasets/Vid4/GT' else: if stage == 1: # test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode) # test_dataset_folder = '/DATA/wangshen_data/REDS/val_sharp_bicubic/X4' test_dataset_folder = Input_folder else: test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4' print('You should modify the test_dataset_folder path for stage 2') # GT_dataset_folder = '../datasets/REDS4/GT' # GT_dataset_folder = '/DATA/wangshen_data/REDS/val_sharp' # GT_dataset_folder = GT_folder #### evaluation crop_border = 0 border_frame = N_in // 2 # border frames when evaluate # temporal padding mode if data_mode == 'Vid4' or data_mode == 'sharp_bicubic': padding = 'new_info' else: padding = 'replicate' save_imgs = True # save_folder = '../results/{}'.format(data_mode) # save_folder = '/DATA/wangshen_data/REDS/results/{}'.format(data_mode) save_folder = os.path.join(Result_folder, data_mode) util.mkdirs(save_folder) util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') #### log info logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test)) #### set up the models model.load_state_dict(torch.load(model_path), strict=True) model.eval() model = model.to(device) avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] subfolder_name_l = [] subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*'))) #subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*'))) # for each subfolder # for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l): end = time.time() # import json # with open(args.screen_notation) as f: # frame_notation = json.load(f) frame_notation = None for subfolder in subfolder_l: input_subfolder = os.path.split(subfolder)[1] # subfolder_GT = os.path.join(GT_dataset_folder, input_subfolder) # # if not os.path.exists(subfolder_GT): # continue print("Evaluate Folders: ", input_subfolder) subfolder_name = osp.basename(subfolder) subfolder_name_l.append(subfolder_name) save_subfolder = osp.join(save_folder, subfolder_name) img_path_l = sorted(glob.glob(osp.join(subfolder, '*'))) max_idx = len(img_path_l) if save_imgs: util.mkdirs(save_subfolder) #### read LQ and GT images imgs_LQ = data_util.read_img_seq(subfolder) # Num x 3 x H x W # img_GT_l = [] # for img_GT_path in sorted(glob.glob(osp.join(subfolder_GT, '*'))): # img_GT_l.append(data_util.read_img(None, img_GT_path)) avg_psnr, avg_psnr_border, avg_psnr_center, N_border, N_center = 0, 0, 0, 0, 0 # process each image for img_idx, img_path in enumerate(img_path_l): img_name = osp.splitext(osp.basename(img_path))[0] #select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding) select_idx, log1, log2, nota = data_util.index_generation_process_screen_change_withlog_fixbug( input_subfolder, frame_notation, img_idx, max_idx, N_in, padding=padding, enable=args.use_screen_notation) if not log1 == None: logger.info('screen change') logger.info(nota) logger.info(log1) logger.info(log2) imgs_in = imgs_LQ.index_select( 0, torch.LongTensor(select_idx)).unsqueeze(0).to( device) # 960 x 540 gtWidth = 3840 gtHeight = 2160 intWidth_ori = imgs_in.shape[4] # 960 intHeight_ori = imgs_in.shape[3] # 540 split_lengthY = 180 split_lengthX = 320 scale = 4 intPaddingRight_ = int(float(intWidth_ori) / split_lengthX + 1) * 
split_lengthX - intWidth_ori intPaddingBottom_ = int(float(intHeight_ori) / split_lengthY + 1) * split_lengthY - intHeight_ori intPaddingRight_ = 0 if intPaddingRight_ == split_lengthX else intPaddingRight_ intPaddingBottom_ = 0 if intPaddingBottom_ == split_lengthY else intPaddingBottom_ pader0 = torch.nn.ReplicationPad2d( [0, intPaddingRight_, 0, intPaddingBottom_]) print("Init pad right/bottom " + str(intPaddingRight_) + " / " + str(intPaddingBottom_)) intPaddingRight = PAD # 32# 64# 128# 256 intPaddingLeft = PAD # 32#64 #128# 256 intPaddingTop = PAD # 32#64 #128#256 intPaddingBottom = PAD # 32#64 # 128# 256 pader = torch.nn.ReplicationPad2d([ intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom ]) imgs_in = torch.squeeze(imgs_in, 0) # N C H W #imgs_in = pader0(imgs_in) # N C 540 960 imgs_in = pader(imgs_in) # N C 604 1024 # assert (split_lengthY == int(split_lengthY) and split_lengthX == int(split_lengthX)) # split_lengthY = int(split_lengthY) # split_lengthX = int(split_lengthX) # split_numY = int(float(intHeight_ori) / split_lengthY) # split_numX = int(float(intWidth_ori) / split_lengthX) # splitsY = range(0, split_numY) # splitsX = range(0, split_numX) # intWidth = split_lengthX # intWidth_pad = intWidth + intPaddingLeft + intPaddingRight # intHeight = split_lengthY # intHeight_pad = intHeight + intPaddingTop + intPaddingBottom # print("split " + str(split_numY) + ' , ' + str(split_numX)) # y_all = np.zeros((gtHeight, gtWidth, 3), dtype="float32") # HWC # todo: output 4k X0 = imgs_in X0 = torch.unsqueeze(X0, 0) if flip_test: output = util.flipx4_forward(model, X0) else: output = util.single_forward(model, X0) # todo remove padding output = output[0, :, intPaddingTop * scale:(intPaddingTop + intHeight_ori) * scale, intPaddingLeft * scale:(intPaddingLeft + intWidth_ori) * scale] output = util.tensor2img(output.squeeze(0)) # for split_j, split_i in itertools.product(splitsY, splitsX): # # print(str(split_j) + ", \t " + str(split_i)) # X0 = imgs_in[:, :, # split_j * split_lengthY:(split_j + 1) * split_lengthY + intPaddingBottom + intPaddingTop, # split_i * split_lengthX:(split_i + 1) * split_lengthX + intPaddingRight + intPaddingLeft] # # # y_ = torch.FloatTensor() # # X0 = torch.unsqueeze(X0, 0) # N C H W -> 1 N C H W # # if flip_test: # output = util.flipx4_forward(model, X0) # else: # output = util.single_forward(model, X0) # # output_depadded = output[0, :, intPaddingTop * scale:(intPaddingTop + intHeight) * scale, # intPaddingLeft * scale: (intPaddingLeft + intWidth) * scale] # output_depadded = output_depadded.squeeze(0) # output = util.tensor2img(output_depadded) # # y_all[split_j * split_lengthY * scale:(split_j + 1) * split_lengthY * scale, # split_i * split_lengthX * scale:(split_i + 1) * split_lengthX * scale, :] = \ # np.round(output).astype(np.uint8) # plt.figure(0) # plt.title("pic") # plt.imshow(y_all) if save_imgs: cv2.imwrite( osp.join(save_subfolder, '{}.png'.format(img_name)), output) print("*****************current image process time \t " + str(time.time() - end) + "s ******************") total_run_time.update(time.time() - end, 1) logger.info('{} : {:3d} - {:25} \t'.format(input_subfolder, img_idx + 1, img_name)) # # calculate PSNR # y_all = y_all / 255. 
# GT = np.copy(img_GT_l[img_idx]) # # For REDS, evaluate on RGB channels; for Vid4, evaluate on the Y channel # if data_mode == 'Vid4': # bgr2y, [0, 1] # GT = data_util.bgr2ycbcr(GT, only_y=True) # y_all = data_util.bgr2ycbcr(y_all, only_y=True) # y_all, GT = util.crop_border([y_all, GT], crop_border) # crt_psnr = util.calculate_psnr(y_all * 255, GT * 255) # logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(img_idx + 1, img_name, crt_psnr)) # if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames # avg_psnr_center += crt_psnr # N_center += 1 # else: # border frames # avg_psnr_border += crt_psnr # N_border += 1 # avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border) # avg_psnr_center = avg_psnr_center / N_center # avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border # avg_psnr_l.append(avg_psnr) # avg_psnr_center_l.append(avg_psnr_center) # avg_psnr_border_l.append(avg_psnr_border) # logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; ' # 'Center PSNR: {:.6f} dB for {} frames; ' # 'Border PSNR: {:.6f} dB for {} frames.'.format(subfolder_name, avg_psnr, # (N_center + N_border), # avg_psnr_center, N_center, # avg_psnr_border, N_border)) # logger.info('################ Tidy Outputs ################') # for subfolder_name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l, # avg_psnr_center_l, avg_psnr_border_l): # logger.info('Folder {} - Average PSNR: {:.6f} dB. ' # 'Center PSNR: {:.6f} dB. ' # 'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr, psnr_center, # psnr_border)) # logger.info('################ Final Results ################') logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) logger.info('Padding mode: {}'.format(padding)) logger.info('Model path: {}'.format(model_path)) logger.info('Save images: {}'.format(save_imgs)) logger.info('Flip test: {}'.format(flip_test))
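# --- Illustrative sketch, not part of the original script ---
# The inference loop above replicate-pads the 540p input by PAD on every side,
# runs the x4 model, then crops PAD*scale from the HR output so the saved 4K
# frame matches the original geometry. Bilinear upsampling stands in for the
# EDVR model here (an assumption) so the snippet runs standalone.
import torch
import torch.nn.functional as F

PAD, scale = 32, 4
lr = torch.rand(1, 3, 540, 960)                                   # one LR frame
padded = F.pad(lr, (PAD, PAD, PAD, PAD), mode='replicate')        # 1x3x604x1024
hr = F.interpolate(padded, scale_factor=scale, mode='bilinear',
                   align_corners=False)                            # placeholder model
out = hr[:, :, PAD * scale:(PAD + 540) * scale,
         PAD * scale:(PAD + 960) * scale]                          # de-pad in HR space
assert out.shape[-2:] == (540 * scale, 960 * scale)                # 2160 x 3840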
def main(): ################# # configurations ################# os.environ['CUDA_VISIBLE_DEVICES'] = '0' data_mode = 'sharp_bicubic' # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp # Vid4: SR # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur); # blur (deblur-clean), blur_comp (deblur-compression). stage = 1 # 1 or 2, use two stage strategy for REDS dataset. flip_test = False ############################################################################ #### model if data_mode == 'Vid4': if stage == 1: model_path = osp.join(root, '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth') else: raise ValueError('Vid4 does not support stage 2.') elif data_mode == 'sharp_bicubic': if stage == 1: model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_SR_L.pth') else: model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth') elif data_mode == 'blur_bicubic': if stage == 1: model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth') else: model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth') elif data_mode == 'blur': if stage == 1: model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth') else: model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth') elif data_mode == 'blur_comp': if stage == 1: model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth') else: model_path = osp.join(root, '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth') else: raise NotImplementedError if data_mode == 'Vid4': N_in = 7 # use N_in images to restore one HR image else: N_in = 5 predeblur, HR_in = False, False back_RBs = 40 if data_mode == 'blur_bicubic': predeblur = True if data_mode == 'blur' or data_mode == 'blur_comp': predeblur, HR_in = True, True if stage == 2: HR_in = True back_RBs = 20 model = EDVR_arch.EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in) #### dataset if data_mode == 'Vid4': test_dataset_folder = osp.join(root, '../datasets/Vid4/BIx4/*') GT_dataset_folder = osp.join(root, '../datasets/Vid4/GT/*') else: if stage == 1: test_dataset_folder = osp.join(root, f'../datasets/REDS4/{data_mode}/*') else: raise ValueError('You should modify the test_dataset_folder path for stage 2') GT_dataset_folder = osp.join(root, '../datasets/REDS4/GT/*') #### evaluation crop_border = 0 border_frame = N_in // 2 # border frames when evaluate # temporal padding mode if data_mode == 'Vid4' or data_mode == 'sharp_bicubic': padding = 'new_info' else: padding = 'replicate' save_imgs = True device = torch.device('cuda') save_folder = f'../results/{data_mode}' util.mkdirs(save_folder) util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') #### log info logger.info(f'Data: {data_mode} - {test_dataset_folder}') logger.info(f'Padding mode: {padding}') logger.info(f'Model path: {model_path}') logger.info(f'Save images: {save_imgs}') logger.info(f'Flip Test: {flip_test}') def read_image(img_path): '''read one image from img_path Return img: HWC, BGR, [0,1], numpy ''' img_GT = cv2.imread(img_path) img = img_GT.astype(np.float32) / 255. 
return img def read_seq_imgs(img_seq_path): '''read a sequence of images''' img_path_l = sorted(glob.glob(img_seq_path + '/*')) img_l = [read_image(v) for v in img_path_l] # stack to TCHW, RGB, [0,1], torch imgs = np.stack(img_l, axis=0) imgs = imgs[:, :, :, [2, 1, 0]] imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() return imgs def index_generation(crt_i, max_n, N, padding='reflection'): ''' padding: replicate | reflection | new_info | circle ''' max_n = max_n - 1 n_pad = N // 2 return_l = [] for i in range(crt_i - n_pad, crt_i + n_pad + 1): if i < 0: if padding == 'replicate': add_idx = 0 elif padding == 'reflection': add_idx = -i elif padding == 'new_info': add_idx = (crt_i + n_pad) + (-i) elif padding == 'circle': add_idx = N + i else: raise ValueError('Wrong padding mode') elif i > max_n: if padding == 'replicate': add_idx = max_n elif padding == 'reflection': add_idx = max_n * 2 - i elif padding == 'new_info': add_idx = (crt_i - n_pad) - (i - max_n) elif padding == 'circle': add_idx = i - N else: raise ValueError('Wrong padding mode') else: add_idx = i return_l.append(add_idx) return return_l def single_forward(model, imgs_in): with torch.no_grad(): model_output = model(imgs_in) if isinstance(model_output, list) or isinstance(model_output, tuple): output = model_output[0] else: output = model_output return output sub_folder_l = sorted(glob.glob(test_dataset_folder)) sub_folder_GT_l = sorted(glob.glob(GT_dataset_folder)) #### set up the models model.load_state_dict(torch.load(model_path), strict=True) model.eval() model = model.to(device) avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] sub_folder_name_l = [] # for each sub-folder for sub_folder, sub_folder_GT in zip(sub_folder_l, sub_folder_GT_l): sub_folder_name = sub_folder.split('/')[-1] sub_folder_name_l.append(sub_folder_name) save_sub_folder = osp.join(save_folder, sub_folder_name) img_path_l = sorted(glob.glob(sub_folder + '/*')) max_idx = len(img_path_l) if save_imgs: util.mkdirs(save_sub_folder) #### read LR images imgs = read_seq_imgs(sub_folder) #### read GT images img_GT_l = [] for img_GT_path in sorted(glob.glob(osp.join(sub_folder_GT, '*'))): img_GT_l.append(read_image(img_GT_path)) avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0 cal_n_border, cal_n_center = 0, 0 # process each image for img_idx, img_path in enumerate(img_path_l): c_idx = int(osp.splitext(osp.basename(img_path))[0]) select_idx = index_generation(c_idx, max_idx, N_in, padding=padding) # get input images imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) output = single_forward(model, imgs_in) output_f = output.data.float().cpu().squeeze(0) if flip_test: # flip W output = single_forward(model, torch.flip(imgs_in, (-1, ))) output = torch.flip(output, (-1, )) output = output.data.float().cpu().squeeze(0) output_f = output_f + output # flip H output = single_forward(model, torch.flip(imgs_in, (-2, ))) output = torch.flip(output, (-2, )) output = output.data.float().cpu().squeeze(0) output_f = output_f + output # flip both H and W output = single_forward(model, torch.flip(imgs_in, (-2, -1))) output = torch.flip(output, (-2, -1)) output = output.data.float().cpu().squeeze(0) output_f = output_f + output output_f = output_f / 4 output = util.tensor2img(output_f) # save imgs if save_imgs: cv2.imwrite(osp.join(save_sub_folder, f'{c_idx:08d}.png'), output) #### calculate PSNR output = output / 255. 
GT = np.copy(img_GT_l[img_idx]) # For REDS, evaluate on RGB channels; for Vid4, evaluate on Y channels if data_mode == 'Vid4': # bgr2y, [0, 1] GT = data_util.bgr2ycbcr(GT) output = data_util.bgr2ycbcr(output) if crop_border == 0: cropped_output = output cropped_GT = GT else: cropped_output = output[crop_border:-crop_border, crop_border:-crop_border] cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border] crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255) logger.info(f'{img_idx+1:3d} - {c_idx:25}.png \tPSNR: {crt_psnr:.6f} dB') if img_idx >= border_frame and img_idx < max_idx - border_frame: # center frames avg_psnr_center += crt_psnr cal_n_center += 1 else: # border frames avg_psnr_border += crt_psnr cal_n_border += 1 avg_psnr = (avg_psnr_center + avg_psnr_border) / (cal_n_center + cal_n_border) avg_psnr_center = avg_psnr_center / cal_n_center if cal_n_border == 0: avg_psnr_border = 0 else: avg_psnr_border = avg_psnr_border / cal_n_border logger.info(f'Folder {sub_folder_name} - Average PSNR: {avg_psnr:.6f} dB for {(cal_n_center + cal_n_border)} frames; ' f'Center PSNR: {avg_psnr_center:.6f} dB for {cal_n_center} frames; ' f'Border PSNR: {avg_psnr_border:.6f} dB for {cal_n_border} frames.') avg_psnr_l.append(avg_psnr) avg_psnr_center_l.append(avg_psnr_center) avg_psnr_border_l.append(avg_psnr_border) logger.info('################ Tidy Outputs ################') for name, psnr, psnr_center, psnr_border in zip(sub_folder_name_l, avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l): logger.info(f'Folder {name} - Average PSNR: {psnr:.6f} dB. ' f'Center PSNR: {psnr_center:.6f} dB. ' f'Border PSNR: {psnr_border:.6f} dB.') logger.info('################ Final Results ################') logger.info(f'Data: {data_mode} - {test_dataset_folder}') logger.info(f'Padding mode: {padding}') logger.info(f'Model path: {model_path}') logger.info(f'Save images: {save_imgs}') logger.info(f'Flip Test: {flip_test}') logger.info(f'Total Average PSNR: {sum(avg_psnr_l) / len(avg_psnr_l):.6f} dB for {len(sub_folder_l)} clips. ' f'Center PSNR: {sum(avg_psnr_center_l) / len(avg_psnr_center_l):.6f} dB. ' f'Border PSNR: {sum(avg_psnr_border_l) / len(avg_psnr_border_l):.6f} dB.')
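# --- Illustrative sketch, not part of the original script ---
# The flip_test branch above is an x4 self-ensemble: average the restorations
# of the identity, W-flipped, H-flipped and HW-flipped inputs, un-flipping
# each output first. `net` is a placeholder stand-in for the model.
import torch

def flipx4(net, x):
    out = net(x)
    for dims in ((-1,), (-2,), (-2, -1)):
        # flip input, restore, flip the result back, then accumulate
        out = out + torch.flip(net(torch.flip(x, dims)), dims)
    return out / 4

net = lambda t: t                      # identity model, for the sanity check
x = torch.rand(1, 5, 3, 16, 16)
assert torch.allclose(flipx4(net, x), x)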
def main(): #### options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to option YMAL file.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=True) #### distributed training settings if args.launcher == 'none': # disabled distributed training opt['dist'] = False rank = -1 print('Disabled distributed training.') else: opt['dist'] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() #### loading resume state if exists if opt['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load( opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options else: resume_state = None #### mkdir and loggers if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) if resume_state is None: util.mkdir_and_rename( opt['path'] ['experiments_root']) # rename experiment folder if exists util.mkdirs( (path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]' .format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) else: util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) logger = logging.getLogger('base') # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) if rank <= 0: logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### create train and val dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt['batch_size'])) total_iters = int(opt['train']['niter']) total_epochs = int(math.ceil(total_iters / train_size)) if opt['dist']: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) total_epochs = int( math.ceil(total_iters / (train_size * dataset_ratio))) else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) if rank <= 0: logger.info( 'Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None #### create model model = create_model(opt) #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 #### training logger.info('Start training from epoch: {:d}, iter: {:d}'.format( start_epoch, current_step)) first_time = True save_count = 0 max_psnr = 0 for epoch in range(start_epoch, total_epochs + 1): if opt['dist']: train_sampler.set_epoch(epoch) for _, train_data in enumerate(train_loader): if first_time: start_time = time.time() first_time = False current_step += 1 if current_step > total_iters: break #### update learning rate model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) #### training model.feed_data(train_data) model.optimize_parameters(current_step) #### log if current_step % opt['logger']['print_freq'] == 0: end_time = time.time() logs = model.get_current_log() message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, time:{:.3f}> '.format( epoch, current_step, model.get_current_learning_rate(), end_time - start_time) for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: if rank <= 0: tb_logger.add_scalar(k, v, current_step) if rank <= 0: logger.info(message) start_time = time.time() # validation if current_step % opt['train'][ 'val_freq'] == 0 and rank <= 0 and current_step >= opt[ 'train']['val_min_iter']: avg_psnr = 0.0 idx = 0 for val_data in val_loader: idx += 1 img_name = 
os.path.splitext( os.path.basename(val_data['LQ_path'][0]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir) model.feed_data(val_data) model.test() visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals['SR']) # uint8 gt_img = util.tensor2img(visuals['img_GT_bic4x']) # uint8 # Save SR images for reference save_img_path = os.path.join( img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) #util.save_img(sr_img, save_img_path) # calculate PSNR crop_size = opt['scale'] gt_img = gt_img / 255. sr_img = sr_img / 255. cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255) avg_psnr = avg_psnr / idx # log logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) logger_val = logging.getLogger('val') # validation logger logger_val.info( '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format( epoch, current_step, avg_psnr)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr', avg_psnr, current_step) #### save models and training states if current_step % opt['logger']['save_checkpoint_freq'] == 0: if rank <= 0: model.save_training_state(epoch, current_step) if current_step % opt['logger'][ 'save_checkpoint_freq'] == 0 and current_step >= opt[ 'train']['val_min_iter']: if rank <= 0: logger.info('Saving models and training states.') save_count += 1 if avg_psnr >= max_psnr: max_psnr = avg_psnr model.save('best') else: model.save(current_step) model.save_training_state(epoch, current_step) if rank <= 0: logger.info('Saving the final model.') model.save('latest') logger.info('End of training.')
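#### Editor's note: a minimal sketch (not part of this repo) of the border-cropped PSNR
#### computed in the validation loop above: `opt['scale']` pixels are stripped from every
#### border before SR and GT are compared. The helper below assumes the standard MSE-based
#### PSNR on images in the [0, 255] range; the repository's util.calculate_psnr may differ
#### in detail, and `cropped_psnr_sketch` is a hypothetical name.
import numpy as np

def cropped_psnr_sketch(sr_img, gt_img, crop_size):
    """Hypothetical helper: PSNR between two HxWxC images after border cropping."""
    sr = sr_img[crop_size:-crop_size, crop_size:-crop_size, :].astype(np.float64)
    gt = gt_img[crop_size:-crop_size, crop_size:-crop_size, :].astype(np.float64)
    mse = np.mean((sr - gt) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(255.0 / np.sqrt(mse))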
def main():
    need_chop = True
    cal_metrics = True
    save_images = False
    scale = 4

    # data_mode = 'KON_GRU_NONLOCAL_ratio=1'
    # test_dataset_folder = '../../datasets/KON/HR'
    # file_list = '../datasets/KON/test_100.txt'
    data_mode = 'Vimeo_GRU'
    test_dataset_folder = '../../datasets/vimeo_septuplet/sequences'
    file_list = '../datasets/Vimeo/sep_testlist.txt'

    # model path
    # model_path = '/mnt/02520c27-ec8e-4661-b88f-05aa2011ffa7/sxw/Zooming-Slow-Mo-CVPR-2020/experiments/KON_GRU/models/latest_G.pth'
    # model_path = '/mnt/02520c27-ec8e-4661-b88f-05aa2011ffa7/sxw/Zooming-Slow-Mo-CVPR-2020/experiments/KON_LSTM/models/latest_G.pth'
    # model_path = '/mnt/02520c27-ec8e-4661-b88f-05aa2011ffa7/sxw/Zooming-Slow-Mo-CVPR-2020/experiments/KON_LSTM_NONLOCAL/models/latest_G.pth'
    # model_path = '/mnt/02520c27-ec8e-4661-b88f-05aa2011ffa7/sxw/Zooming-Slow-Mo-CVPR-2020/experiments/KON_GRU_NONLOCAL_ratio=1/models/latest_G.pth'
    model_path = '/mnt/02520c27-ec8e-4661-b88f-05aa2011ffa7/sxw/Zooming-Slow-Mo-CVPR-2020/experiments/Vimeo_GRU/models/480000_G.pth'

    # model = Sakuya_arch.LunaTokisGRU(64, 7, 8, 5, 5)
    # model = Sakuya_arch.LunaTokis(64, 7, 8, 5, 5)
    # model = Sakuya_arch.NonLocalNet(64, 7, 8, 5, 5)
    model = Sakuya_arch.LunaTokisGRU(64, 7, 8, 5, 5)

    # use .get() so an unset CUDA_VISIBLE_DEVICES falls back to CPU instead of raising KeyError
    if torch.cuda.is_available() and os.environ.get('CUDA_VISIBLE_DEVICES', '') != '':
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')

    model_params = util.get_model_total_params(model)

    # log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Model parameters: {} M'.format(model_params))
    logger.info('Device: {}'.format(device))

    model.load_state_dict(torch.load(model_path), strict=True)
    model = model.to(device)
    model.eval()

    test_set = DatasetFromFolderTest(test_dataset_folder, 7, scale, file_list, transform=transform())
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=8)
    test_num = len(test_loader)

    avg_psnr = 0.0
    avg_ssim = 0.0
    avg_time = 0.0
    with torch.no_grad():
        for data in test_loader:
            input = Variable(data['LQs']).to(device)
            target = Variable(data['GT']).to(device)
            info = os.path.join(data['INFO'][0].split('/')[-2], data['INFO'][0].split('/')[-1])

            t0 = time.time()
            # If GPU memory is insufficient, split the whole frame into patches for testing
            if need_chop:
                predictions = chop_forward(input, model, scale, device)
            else:
                predictions = model(input)
            predictions = predictions[0]  # batch
            t1 = time.time()

            pre_num = len(predictions)
            time_predicted = (t1 - t0) / pre_num
            psnr_predicted = 0.0
            ssim_predicted = 0.0
            for i in range(pre_num):
                # save images
                pre = util.tensor2img(predictions[i])
                if save_images:
                    img_path = os.path.join(save_folder, info)
                    if not os.path.exists(img_path):
                        os.makedirs(img_path)
                    img_path = os.path.join(img_path, 'im{}.jpg'.format(i + 1))
                    util.save_img(pre, img_path)
                if cal_metrics:
                    # calculate PSNR and SSIM
                    tar = util.tensor2img(target[0][i])
                    psnr_predicted += util.PSNR(pre, tar)
                    ssim_predicted += util.SSIM(pre, tar)
            psnr_predicted /= pre_num
            ssim_predicted /= pre_num

            avg_psnr += psnr_predicted
            avg_ssim += ssim_predicted
            avg_time += time_predicted
            logger.info("Processing: %s || PSNR: %.4f || SSIM: %.4f || Avg Timer: %.4f sec."
                        % (info, psnr_predicted, ssim_predicted, time_predicted))

    avg_time /= test_num
    avg_psnr /= test_num
    avg_ssim /= test_num
    logger.info("Finished: %s || PSNR: %.4f || SSIM: %.4f || Avg Timer: %.4f sec."
                % (data_mode, avg_psnr, avg_ssim, avg_time))
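#### Editor's note: `chop_forward` is defined elsewhere in this repo; the sketch below only
#### illustrates the general technique the call refers to for the memory-constrained case:
#### split the LR clip into four overlapping quadrants, super-resolve each independently,
#### and stitch the SR quadrants back together. The name `chop_forward_sketch`, the `shave`
#### margin, and the assumption that the model returns a [B, T_out, C, H*scale, W*scale]
#### tensor are hypothetical; the repository's implementation may recurse or differ otherwise.
def chop_forward_sketch(x, model, scale, shave=8):
    b, t, c, h, w = x.size()
    h_half, w_half = h // 2, w // 2
    h_size, w_size = h_half + shave, w_half + shave
    patches = [
        x[..., 0:h_size, 0:w_size],              # top-left
        x[..., 0:h_size, (w - w_size):w],        # top-right
        x[..., (h - h_size):h, 0:w_size],        # bottom-left
        x[..., (h - h_size):h, (w - w_size):w],  # bottom-right
    ]
    with torch.no_grad():
        outputs = [model(p) for p in patches]
    t_out = outputs[0].size(1)
    H, W = h * scale, w * scale
    out = x.new_zeros(b, t_out, c, H, W)
    h_half, w_half = h_half * scale, w_half * scale
    h_size, w_size = h_size * scale, w_size * scale
    out[..., 0:h_half, 0:w_half] = outputs[0][..., 0:h_half, 0:w_half]
    out[..., 0:h_half, w_half:W] = outputs[1][..., 0:h_half, (w_size - (W - w_half)):w_size]
    out[..., h_half:H, 0:w_half] = outputs[2][..., (h_size - (H - h_half)):h_size, 0:w_half]
    out[..., h_half:H, w_half:W] = outputs[3][..., (h_size - (H - h_half)):h_size, (w_size - (W - w_half)):w_size]
    return out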