def vimeo(img_root, lmdb_save_path):
    """Create lmdb for the vimeo dataset, each image with a fixed size
    GT: [3, 256, 448], key: 00001_0001_4
    """
    #### configurations
    BATCH = 50000
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with 'lmdb'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    txt_file = osp.join(img_root, 'sep_trainlist.txt')
    with open(txt_file, 'r') as f:
        lines = f.readlines()
    img_list = [line.strip() for line in lines]

    imgs, keys = [], []
    for item in img_list:
        key_pre = item.replace('/', '_')
        im_dir = osp.join(img_root, 'sequences', item)
        names = sorted(os.listdir(im_dir))
        for name in names:
            imgs.append(osp.join(im_dir, name))
            keys.append(key_pre + '_' + name[2])

    im1 = cv2.imread(imgs[0], cv2.IMREAD_UNCHANGED)
    H, W, C = im1.shape
    print('data size per image is: ', im1.nbytes)
    data_size = im1.nbytes * len(imgs)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    txn = env.begin(write=True)
    for i in range(0, len(imgs), BATCH):
        batch_imgs = imgs[i:i + BATCH]
        batch_keys = keys[i:i + BATCH]
        batch_data = read_imgs_multi_thread(batch_imgs, batch_keys, n_thread)

        pbar = util.ProgressBar(len(batch_imgs))
        for k, v in batch_data.items():
            pbar.update('Write {}'.format(k))
            key_byte = k.encode('ascii')
            txn.put(key_byte, v)
        txn.commit()
        txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'vimeo_train'
    meta_info['resolution'] = '{}_{}_{}'.format(C, H, W)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')
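# The values written by vimeo() are the raw buffers returned by cv2.imread, and the
# fixed frame shape is recorded in meta_info['resolution'] as 'C_H_W'. A minimal sketch
# of reading one frame back under those assumptions follows; the function name, path and
# key are illustrative and not part of this script.
def read_vimeo_entry_sketch(lmdb_save_path, key):
    """Sketch: fetch one frame from the lmdb created above and restore its shape."""
    import pickle
    import lmdb
    import numpy as np
    import os.path as osp  # local imports keep the sketch self-contained
    meta_info = pickle.load(open(osp.join(lmdb_save_path, 'meta_info.pkl'), 'rb'))
    C, H, W = [int(x) for x in meta_info['resolution'].split('_')]
    env = lmdb.open(lmdb_save_path, readonly=True, lock=False, readahead=False, meminit=False)
    with env.begin(write=False) as txn:
        buf = txn.get(key.encode('ascii'))
    return np.frombuffer(buf, dtype=np.uint8).reshape(H, W, C)  # BGR image, as saved by cv2

# Hypothetical usage: read_vimeo_entry_sketch('vimeo_train.lmdb', '00001_0001_4')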
def vimeo_test(img_root, lmdb_save_path):
    """Create lmdb for the vimeo test set: GT center frame (im4) plus the LQ sequence."""
    gt_root = osp.join(img_root, 'target')
    lq_root = osp.join(img_root, 'low_resolution')
    txt_file = osp.join(img_root, 'sep_testlist.txt')
    with open(txt_file, 'r') as f:
        lines = f.readlines()
    img_list = [line.strip() for line in lines]

    imgs, keys = [], []
    for item in img_list:
        gt_key = 'gt_' + item.replace('/', '_') + '_4'
        lq_key_pre = 'lq_' + item.replace('/', '_')
        gt_img = osp.join(gt_root, item, 'im4.png')
        imgs.append(gt_img)
        keys.append(gt_key)
        lq_im_dir = osp.join(lq_root, item)
        lq_names = sorted(os.listdir(lq_im_dir))
        for name in lq_names:
            imgs.append(osp.join(lq_im_dir, name))
            keys.append(lq_key_pre + '_' + name[2])

    im1 = cv2.imread(imgs[0], cv2.IMREAD_UNCHANGED)
    H, W, C = im1.shape
    im2 = cv2.imread(imgs[1], cv2.IMREAD_UNCHANGED)
    lH, lW, lC = im2.shape
    print('data size per image is: ', im1.nbytes)
    data_size = im1.nbytes * len(imgs)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    txn = env.begin(write=True)
    img_data = read_imgs_multi_thread(imgs, keys, 40)
    pbar = util.ProgressBar(len(imgs))
    for k, v in img_data.items():
        pbar.update('Write {}'.format(k))
        key_byte = k.encode('ascii')
        txn.put(key_byte, v)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'vimeo_test'
    meta_info['gt_resolution'] = '{}_{}_{}'.format(C, H, W)
    meta_info['lq_resolution'] = '{}_{}_{}'.format(lC, lH, lW)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')
def read_imgs_multi_thread(imgs, keys, n_thread=40):
    #### read all images to memory (multiprocessing)
    dataset = {}  # store all image data. list cannot keep the order, use dict
    print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
    pbar = util.ProgressBar(len(imgs))

    def mycallback(arg):
        '''get the image data and update pbar'''
        key = arg[0]
        dataset[key] = arg[1]
        pbar.update('Reading {}'.format(key))

    pool = Pool(n_thread)
    for path, key in zip(imgs, keys):
        pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
    pool.close()
    pool.join()
    print('Finish reading {} images.'.format(len(imgs)))
    return dataset
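# read_image_worker is dispatched through Pool.apply_async above but is not defined in
# this excerpt; judging from how mycallback unpacks its result, it returns a (key, image)
# pair. A plausible sketch (not necessarily the repo's exact worker):
def read_image_worker(path, key):
    """Sketch of the multiprocessing worker: read one image and return it with its key."""
    import cv2  # local import keeps the sketch self-contained
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    return (key, img)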
def main(): #### options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--exp_name', type=str, default='temp') parser.add_argument('--degradation_type', type=str, default=None) parser.add_argument('--sigma_x', type=float, default=None) parser.add_argument('--sigma_y', type=float, default=None) parser.add_argument('--theta', type=float, default=None) args = parser.parse_args() if args.exp_name == 'temp': opt = option.parse(args.opt, is_train=True) else: opt = option.parse(args.opt, is_train=True, exp_name=args.exp_name) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) inner_loop_name = opt['train']['maml']['optimizer'][0] + str( opt['train']['maml']['adapt_iter']) + str( math.floor(math.log10(opt['train']['maml']['lr_alpha']))) meta_loop_name = opt['train']['optim'][0] + str( math.floor(math.log10(opt['train']['lr_G']))) if args.degradation_type is not None: if args.degradation_type == 'preset': opt['datasets']['val']['degradation_mode'] = args.degradation_type else: opt['datasets']['val']['degradation_type'] = args.degradation_type if args.sigma_x is not None: opt['datasets']['val']['sigma_x'] = args.sigma_x if args.sigma_y is not None: opt['datasets']['val']['sigma_y'] = args.sigma_y if args.theta is not None: opt['datasets']['val']['theta'] = args.theta if opt['datasets']['val']['degradation_mode'] == 'set': degradation_name = str(opt['datasets']['val']['degradation_type'])\ + '_' + str(opt['datasets']['val']['sigma_x']) \ + '_' + str(opt['datasets']['val']['sigma_y'])\ + '_' + str(opt['datasets']['val']['theta']) else: degradation_name = opt['datasets']['val']['degradation_mode'] patch_name = 'p{}x{}'.format( opt['train']['maml']['patch_size'], opt['train']['maml'] ['num_patch']) if opt['train']['maml']['use_patch'] else 'full' use_real_flag = '_ideal' if opt['train']['use_real'] else '' folder_name = opt[ 'name'] + '_' + degradation_name # + '_' + inner_loop_name + meta_loop_name + '_' + degradation_name + '_' + patch_name + use_real_flag if args.exp_name != 'temp': folder_name = args.exp_name #### distributed training settings if args.launcher == 'none': # disabled distributed training opt['dist'] = False rank = -1 print('Disabled distributed training.') else: opt['dist'] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() #### loading resume state if exists if opt['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load( opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options else: resume_state = None #### mkdir and loggers if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) if resume_state is None: #util.mkdir_and_rename( # opt['path']['experiments_root']) # rename experiment folder if exists #util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' # and 'pretrain_model' not in key and 'resume' not in key)) if not os.path.exists(opt['path']['experiments_root']): os.mkdir(opt['path']['experiments_root']) # raise ValueError('Path does not exists - check path') # config loggers. 
Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') #logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. Tensorboard will use [tensorboardX]' .format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + folder_name) else: util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) logger = logging.getLogger('base') #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) if rank <= 0: logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True #### create train and val dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': pass elif phase == 'val': if '+' in opt['datasets']['val']['name']: val_set, val_loader = [], [] valname_list = opt['datasets']['val']['name'].split('+') for i in range(len(valname_list)): val_set.append( create_dataset( dataset_opt, scale=opt['scale'], kernel_size=opt['datasets']['train'] ['kernel_size'], model_name=opt['network_E']['which_model_E'], idx=i)) val_loader.append( create_dataloader(val_set[-1], dataset_opt, opt, None)) else: val_set = create_dataset( dataset_opt, scale=opt['scale'], kernel_size=opt['datasets']['train']['kernel_size'], model_name=opt['network_E']['which_model_E']) # val_set = loader.get_dataset(opt, train=False) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) #### create model models = create_model(opt) assert len(models) == 2 model, est_model = models[0], models[1] modelcp, est_modelcp = create_model(opt) _, est_model_fixed = create_model(opt) center_idx = (opt['datasets']['val']['N_frames']) // 2 lr_alpha = opt['train']['maml']['lr_alpha'] update_step = opt['train']['maml']['adapt_iter'] pd_log = pd.DataFrame( columns=['PSNR_Bicubic', 'PSNR_Ours', 'SSIM_Bicubic', 'SSIM_Ours']) def crop(LR_seq, HR, num_patches_for_batch=4, patch_size=44): """ Crop given patches. 
Args: LR_seq: (B=1) x T x C x H x W HR: (B=1) x C x H x W patch_size (int, optional): Return: B(=batch_size) x T x C x H x W """ # Find the lowest resolution cropped_lr = [] cropped_hr = [] assert HR.size(0) == 1 LR_seq_ = LR_seq[0] HR_ = HR[0] for _ in range(num_patches_for_batch): patch_lr, patch_hr = preprocessing.common_crop( LR_seq_, HR_, patch_size=patch_size // 2) cropped_lr.append(patch_lr) cropped_hr.append(patch_hr) cropped_lr = torch.stack(cropped_lr, dim=0) cropped_hr = torch.stack(cropped_hr, dim=0) return cropped_lr, cropped_hr # Single GPU # PSNR_rlt: psnr_init, psnr_before, psnr_after psnr_rlt = [{}, {}] # SSIM_rlt: ssim_init, ssim_after ssim_rlt = [{}, {}] pbar = util.ProgressBar(len(val_set)) for val_data in val_loader: folder = val_data['folder'][0] idx_d = int(val_data['idx'][0].split('/')[0]) if 'name' in val_data.keys(): name = val_data['name'][0][center_idx][0] else: #name = '{}/{:08d}'.format(folder, idx_d) name = folder train_folder = os.path.join('../results_for_paper', folder_name, name) hr_train_folder = os.path.join(train_folder, 'hr') bic_train_folder = os.path.join(train_folder, 'bic') maml_train_folder = os.path.join(train_folder, 'maml') #slr_train_folder = os.path.join(train_folder, 'slr') # print(train_folder) if not os.path.exists(train_folder): os.makedirs(train_folder, exist_ok=False) if not os.path.exists(hr_train_folder): os.mkdir(hr_train_folder) if not os.path.exists(bic_train_folder): os.mkdir(bic_train_folder) if not os.path.exists(maml_train_folder): os.mkdir(maml_train_folder) #if not os.path.exists(slr_train_folder): # os.mkdir(slr_train_folder) for i in range(len(psnr_rlt)): if psnr_rlt[i].get(folder, None) is None: psnr_rlt[i][folder] = [] for i in range(len(ssim_rlt)): if ssim_rlt[i].get(folder, None) is None: ssim_rlt[i][folder] = [] if idx_d % 10 != 5: #continue pass cropped_meta_train_data = {} meta_train_data = {} meta_test_data = {} # Make SuperLR seq using estimation model meta_train_data['GT'] = val_data['LQs'][:, center_idx] meta_test_data['LQs'] = val_data['LQs'][0:1] meta_test_data['GT'] = val_data['GT'][0:1, center_idx] # Check whether the batch size of each validation data is 1 assert val_data['SuperLQs'].size(0) == 1 if opt['network_G']['which_model_G'] == 'TOF': LQs = meta_test_data['LQs'] B, T, C, H, W = LQs.shape LQs = LQs.reshape(B * T, C, H, W) Bic_LQs = F.interpolate(LQs, scale_factor=opt['scale'], mode='bicubic', align_corners=True) meta_test_data['LQs'] = Bic_LQs.reshape(B, T, C, H * opt['scale'], W * opt['scale']) ## Before start training, first save the bicubic, real outputs # Bicubic modelcp.load_network(opt['path']['bicubic_G'], modelcp.netG) modelcp.feed_data(meta_test_data) modelcp.test() model_start_visuals = modelcp.get_current_visuals(need_GT=True) hr_image = util.tensor2img(model_start_visuals['GT'], mode='rgb') start_image = util.tensor2img(model_start_visuals['rlt'], mode='rgb') #####imageio.imwrite(os.path.join(hr_train_folder, '{:08d}.png'.format(idx_d)), hr_image) #####imageio.imwrite(os.path.join(bic_train_folder, '{:08d}.png'.format(idx_d)), start_image) psnr_rlt[0][folder].append(util.calculate_psnr(start_image, hr_image)) ssim_rlt[0][folder].append(util.calculate_ssim(start_image, hr_image)) modelcp.netG, est_modelcp.netE = deepcopy(model.netG), deepcopy( est_model.netE) ########## SLR LOSS Preparation ############ est_model_fixed.load_network(opt['path']['fixed_E'], est_model_fixed.netE) optim_params = [] for k, v in modelcp.netG.named_parameters(): if v.requires_grad: optim_params.append(v) if not 
opt['train']['use_real']: for k, v in est_modelcp.netE.named_parameters(): if v.requires_grad: optim_params.append(v) if opt['train']['maml']['optimizer'] == 'Adam': inner_optimizer = torch.optim.Adam( optim_params, lr=lr_alpha, betas=(opt['train']['maml']['beta1'], opt['train']['maml']['beta2'])) elif opt['train']['maml']['optimizer'] == 'SGD': inner_optimizer = torch.optim.SGD(optim_params, lr=lr_alpha) else: raise NotImplementedError() # Inner Loop Update st = time.time() for i in range(update_step): # Make SuperLR seq using UPDATED estimation model if not opt['train']['use_real']: est_modelcp.feed_data(val_data) # est_model.test() est_modelcp.forward_without_optim() superlr_seq = est_modelcp.fake_L meta_train_data['LQs'] = superlr_seq else: meta_train_data['LQs'] = val_data['SuperLQs'] if opt['network_G']['which_model_G'] == 'TOF': # Bicubic upsample to match the size LQs = meta_train_data['LQs'] B, T, C, H, W = LQs.shape LQs = LQs.reshape(B * T, C, H, W) Bic_LQs = F.interpolate(LQs, scale_factor=opt['scale'], mode='bicubic', align_corners=True) meta_train_data['LQs'] = Bic_LQs.reshape( B, T, C, H * opt['scale'], W * opt['scale']) # Update both modelcp + estmodelcp jointly inner_optimizer.zero_grad() if opt['train']['maml']['use_patch']: cropped_meta_train_data['LQs'], cropped_meta_train_data['GT'] = \ crop(meta_train_data['LQs'], meta_train_data['GT'], opt['train']['maml']['num_patch'], opt['train']['maml']['patch_size']) modelcp.feed_data(cropped_meta_train_data) else: modelcp.feed_data(meta_train_data) loss_train = modelcp.calculate_loss() ##################### SLR LOSS ################### est_model_fixed.feed_data(val_data) est_model_fixed.test() slr_initialized = est_model_fixed.fake_L slr_initialized = slr_initialized.to('cuda') if opt['network_G']['which_model_G'] == 'TOF': loss_train += 10 * F.l1_loss( LQs.to('cuda').squeeze(0), slr_initialized) else: loss_train += 10 * F.l1_loss(meta_train_data['LQs'].to('cuda'), slr_initialized) loss_train.backward() inner_optimizer.step() et = time.time() update_time = et - st modelcp.feed_data(meta_test_data) modelcp.test() model_update_visuals = modelcp.get_current_visuals(need_GT=False) update_image = util.tensor2img(model_update_visuals['rlt'], mode='rgb') # Save and calculate final image imageio.imwrite( os.path.join(maml_train_folder, '{:08d}.png'.format(idx_d)), update_image) psnr_rlt[1][folder].append(util.calculate_psnr(update_image, hr_image)) ssim_rlt[1][folder].append(util.calculate_ssim(update_image, hr_image)) name_df = '{}/{:08d}'.format(folder, idx_d) if name_df in pd_log.index: pd_log.at[name_df, 'PSNR_Bicubic'] = psnr_rlt[0][folder][-1] pd_log.at[name_df, 'PSNR_Ours'] = psnr_rlt[1][folder][-1] pd_log.at[name_df, 'SSIM_Bicubic'] = ssim_rlt[0][folder][-1] pd_log.at[name_df, 'SSIM_Ours'] = ssim_rlt[1][folder][-1] else: pd_log.loc[name_df] = [ psnr_rlt[0][folder][-1], psnr_rlt[1][folder][-1], ssim_rlt[0][folder][-1], ssim_rlt[1][folder][-1] ] pd_log.to_csv( os.path.join('../results_for_paper', folder_name, 'psnr_update.csv')) pbar.update( 'Test {} - {}: I: {:.3f}/{:.4f} \tF+: {:.3f}/{:.4f} \tTime: {:.3f}s' .format(folder, idx_d, psnr_rlt[0][folder][-1], ssim_rlt[0][folder][-1], psnr_rlt[1][folder][-1], ssim_rlt[1][folder][-1], update_time)) psnr_rlt_avg = {} psnr_total_avg = 0. # Just calculate the final value of psnr_rlt(i.e. 
psnr_rlt[2]) for k, v in psnr_rlt[0].items(): psnr_rlt_avg[k] = sum(v) / len(v) psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt[0]) log_s = '# Validation # Bic PSNR: {:.4e}:'.format(psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) psnr_rlt_avg = {} psnr_total_avg = 0. # Just calculate the final value of psnr_rlt(i.e. psnr_rlt[2]) for k, v in psnr_rlt[1].items(): psnr_rlt_avg[k] = sum(v) / len(v) psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt[1]) log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) ssim_rlt_avg = {} ssim_total_avg = 0. # Just calculate the final value of ssim_rlt(i.e. ssim_rlt[1]) for k, v in ssim_rlt[0].items(): ssim_rlt_avg[k] = sum(v) / len(v) ssim_total_avg += ssim_rlt_avg[k] ssim_total_avg /= len(ssim_rlt[0]) log_s = '# Validation # Bicubic SSIM: {:.4e}:'.format(ssim_total_avg) for k, v in ssim_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) ssim_rlt_avg = {} ssim_total_avg = 0. # Just calculate the final value of ssim_rlt(i.e. ssim_rlt[1]) for k, v in ssim_rlt[1].items(): ssim_rlt_avg[k] = sum(v) / len(v) ssim_total_avg += ssim_rlt_avg[k] ssim_total_avg /= len(ssim_rlt[1]) log_s = '# Validation # SSIM: {:.4e}:'.format(ssim_total_avg) for k, v in ssim_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) logger.info('End of evaluation.')
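# util.calculate_psnr / util.calculate_ssim are not shown in this listing. For reference,
# a minimal PSNR helper consistent with the uint8 images produced by util.tensor2img could
# look like the sketch below (an assumption about the metric, not the repo's actual code):
def calculate_psnr_sketch(img1, img2):
    """Sketch: PSNR between two uint8 images of identical shape."""
    import numpy as np
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(255.0 / np.sqrt(mse))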
def main(): #### options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, default='options/train/train_EDVR_woTSA_M.yml', help='Path to option YAML file.') parser.add_argument('--set', dest='set_opt', default=None, nargs=argparse.REMAINDER, help='set options') args = parser.parse_args() opt = option.parse(args.opt, args.set_opt, is_train=True) #### loading resume state if exists if opt['path'].get('resume_state', None): # distributed resuming: all load into default GPU print('Training from state: {}'.format(opt['path']['resume_state'])) device_id = torch.cuda.current_device() resume_state = torch.load( opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options elif opt['auto_resume']: exp_dir = opt['path']['experiments_root'] # first time run: create dirs if not os.path.exists(exp_dir): os.makedirs(exp_dir) os.makedirs(opt['path']['models']) os.makedirs(opt['path']['training_state']) os.makedirs(opt['path']['val_images']) os.makedirs(opt['path']['tb_logger']) resume_state = None else: # detect experiment directory and get the latest state state_dir = opt['path']['training_state'] state_files = [ x for x in os.listdir(state_dir) if x.endswith('state') ] # no valid state detected if len(state_files) < 1: print( 'No previous training state found, train from start state') resume_state = None else: state_files = sorted(state_files, key=lambda x: int(x.split('.')[0])) latest_state = state_files[-1] print('Training from lastest state: {}'.format(latest_state)) latest_state_file = os.path.join(state_dir, latest_state) opt['path']['resume_state'] = latest_state_file device_id = torch.cuda.current_device() resume_state = torch.load( latest_state_file, map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) else: resume_state = None if resume_state is None and not opt['auto_resume'] and not opt['no_log']: util.mkdir_and_rename( opt['path'] ['experiments_root']) # rename experiment folder if exists util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: version = float(torch.__version__[0:3]) if version >= 1.2: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]' .format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger']) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### create train and val dataloader if opt['datasets']['train']['ratio']: dataset_ratio = opt['datasets']['train']['ratio'] else: dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt['batch_size'])) total_iters = int(opt['train']['niter']) total_epochs = int( math.ceil(total_iters / (train_size * dataset_ratio))) if dataset_opt['mode'] in ['MetaREDS', 'MetaREDSOnline']: train_sampler = MetaIterSampler(train_set, dataset_opt['batch_size'], len(opt['scale']), dataset_ratio) elif dataset_opt['mode'] in ['REDS', 'MultiREDS']: train_sampler = IterSampler(train_set, dataset_opt['batch_size'], dataset_ratio) else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) logger.info('Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) #### create model model = create_model(opt) #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 #### training logger.info('Start training from epoch: {:d}, iter: {:d}'.format( start_epoch, current_step)) for epoch in range(start_epoch, total_epochs + 1): train_sampler.set_epoch(epoch) for _, train_data in enumerate(train_loader): current_step += 1 if current_step > total_iters: break #### update learning rate model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) #### training model.feed_data(train_data) model.optimize_parameters(current_step) #### log if current_step % opt['logger']['print_freq'] == 0: logs = model.get_current_log() message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format( epoch, current_step) for v in model.get_current_learning_rate(): message += '{:.3e},'.format(v) message += ')] ' for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar(k, v, current_step) logger.info(message) print("PROGRESS: {:02d}%".format( int(current_step / total_iters * 100))) #### validation if opt['datasets'].get( 'val', None) and current_step % opt['train']['val_freq'] == 0: pbar = util.ProgressBar(len(val_loader)) psnr_rlt = {} # with border and center frames psnr_rlt_avg = {} psnr_total_avg = 0. 
for val_data in val_loader: folder = val_data['folder'][0] idx_d = val_data['idx'].item() # border = val_data['border'].item() if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = [] model.feed_data(val_data) model.test() visuals = model.get_current_visuals() rlt_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # calculate PSNR psnr = util.calculate_psnr(rlt_img, gt_img) psnr_rlt[folder].append(psnr) pbar.update('Test {} - {}'.format(folder, idx_d)) for k, v in psnr_rlt.items(): psnr_rlt_avg[k] = sum(v) / len(v) psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) #### save models and training states if current_step % opt['logger']['save_checkpoint_freq'] == 0: logger.info('Saving models and training states.') model.save(current_step) model.save_training_state(epoch, current_step) logger.info('Saving the final model.') model.save('latest') logger.info('End of training.') tb_logger.close()
def REDS(mode): """Create lmdb for the REDS dataset, each image with a fixed size GT: [3, 720, 1280], key: 000_00000000 LR: [3, 180, 320], key: 000_00000000 key: 000_00000000 flow: downsampled flow: [3, 360, 320], keys: 000_00000005_[p2, p1, n1, n2] Each flow is calculated with the GT images by PWCNet and then downsampled by 1/4 Flow map is quantized by mmcv and saved in png format """ #### configurations read_all_imgs = False # whether real all images to memory with multiprocessing # Set False for use limited memory BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False if mode == 'train_sharp': img_folder = '../../datasets/REDS/train_sharp' lmdb_save_path = '../../datasets/REDS/train_sharp_wval.lmdb' H_dst, W_dst = 720, 1280 elif mode == 'train_sharp_bicubic': img_folder = '../../datasets/REDS/train_sharp_bicubic' lmdb_save_path = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb' H_dst, W_dst = 180, 320 elif mode == 'train_blur_bicubic': img_folder = '../../datasets/REDS/train_blur_bicubic' lmdb_save_path = '../../datasets/REDS/train_blur_bicubic_wval.lmdb' H_dst, W_dst = 180, 320 elif mode == 'train_blur': img_folder = '../../datasets/REDS/train_blur' lmdb_save_path = '../../datasets/REDS/train_blur_wval.lmdb' H_dst, W_dst = 720, 1280 elif mode == 'train_blur_comp': img_folder = '../../datasets/REDS/train_blur_comp' lmdb_save_path = '../../datasets/REDS/train_blur_comp_wval.lmdb' H_dst, W_dst = 720, 1280 elif mode == 'train_sharp_flowx4': img_folder = '../../datasets/REDS/train_sharp_flowx4' lmdb_save_path = '../../datasets/REDS/train_sharp_flowx4.lmdb' H_dst, W_dst = 360, 320 n_thread = 40 ######################################################## if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") if osp.exists(lmdb_save_path): print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) sys.exit(1) #### read all the image paths to a list print('Reading image path list ...') all_img_list = data_util._get_paths_from_images(img_folder) keys = [] for img_path in all_img_list: split_rlt = img_path.split('/') folder = split_rlt[-2] img_name = split_rlt[-1].split('.png')[0] keys.append(folder + '_' + img_name) if read_all_imgs: #### read all images to memory (multiprocessing) dataset = { } # store all image data. list cannot keep the order, use dict print('Read images with multiprocessing, #thread: {} ...'.format( n_thread)) pbar = util.ProgressBar(len(all_img_list)) def mycallback(arg): '''get the image data and update pbar''' key = arg[0] dataset[key] = arg[1] pbar.update('Reading {}'.format(key)) pool = Pool(n_thread) for path, key in zip(all_img_list, keys): pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) pool.close() pool.join() print('Finish reading {} images.\nWrite lmdb...'.format( len(all_img_list))) #### create lmdb environment data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes print('data size per image is: ', data_size_per_img) data_size = data_size_per_img * len(all_img_list) env = lmdb.open(lmdb_save_path, map_size=data_size * 10) #### write data to lmdb pbar = util.ProgressBar(len(all_img_list)) txn = env.begin(write=True) for idx, (path, key) in enumerate(zip(all_img_list, keys)): pbar.update('Write {}'.format(key)) key_byte = key.encode('ascii') data = dataset[key] if read_all_imgs else cv2.imread( path, cv2.IMREAD_UNCHANGED) if 'flow' in mode: H, W = data.shape assert H == H_dst and W == W_dst, 'different shape.' 
else: H, W, C = data.shape assert H == H_dst and W == W_dst and C == 3, 'different shape.' txn.put(key_byte, data) if not read_all_imgs and idx % BATCH == 0: txn.commit() txn = env.begin(write=True) txn.commit() env.close() print('Finish writing lmdb.') #### create meta information meta_info = {} meta_info['name'] = 'REDS_{}_wval'.format(mode) channel = 1 if 'flow' in mode else 3 meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst) meta_info['keys'] = keys pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.')
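# The REDS docstring above notes that the *_flowx4 maps were quantized with mmcv and saved
# as PNGs. Assuming the stored (360, 320) map is the x/y components of a 180x320 flow
# stacked vertically (an inference from the sizes, not stated in this script), and assuming
# an mmcv version that exposes dequantize_flow, recovering the flow could look roughly like:
def load_quantized_flow_sketch(png_path, max_val=0.02):
    """Sketch: split a stacked quantized flow PNG into dx/dy and dequantize it."""
    import cv2
    import mmcv  # assumed dependency providing dequantize_flow(dx, dy, max_val=...)
    data = cv2.imread(png_path, cv2.IMREAD_UNCHANGED)  # expected shape (360, 320), uint8
    h = data.shape[0] // 2
    dx, dy = data[:h, :], data[h:, :]
    return mmcv.dequantize_flow(dx, dy, max_val=max_val)  # float flow, shape (180, 320, 2)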
def SR4K(mode): """Create lmdb for the 4k dataset, each image with a fixed size GT: [3, 3840, 2160], key: 000_00000000 LR: [3, 960, 540], key: 000_00000000 key: 000_00000000 """ #### configurations read_all_imgs = False # whether real all images to memory with multiprocessing # Set False for use limited memory BATCH = 500 # After BATCH images, lmdb commits, if read_all_imgs = False train_txt = '/home/mcc/4khdr/train.txt' if mode == 'train_4k': img_folder = '/home/mcc/4khdr/image/4k' lmdb_save_path = '/home/mcc/4khdr/4k.lmdb' H_dst, W_dst = 1080, 1920 BATCH = 1000 elif mode == 'train_540p': img_folder = '/home/mcc/4khdr/image/540p' lmdb_save_path = '/home/mcc/4khdr/540p.lmdb' H_dst, W_dst = 270, 480 BATCH = 5000 n_thread = 12 ######################################################## if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") if osp.exists(lmdb_save_path): print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) sys.exit(1) #### read all the image paths to a list print('Reading image path list ...') with open(train_txt, 'r') as f: train_list = [x.strip() for x in f.readlines()] all_img_list = [] for dirpath, _, fnames in sorted(os.walk(img_folder)): if not osp.basename(dirpath)[:-2] in train_list: continue for fname in sorted(fnames): if fname.endswith('.png'): img_path = osp.join(dirpath, fname) all_img_list.append(img_path) keys = [] for img_path in all_img_list: split_rlt = img_path.split('/') folder = split_rlt[-2] img_name = split_rlt[-1].split('.png')[0] keys.append(folder + '_' + img_name) if read_all_imgs: #### read all images to memory (multiprocessing) dataset = { } # store all image data. list cannot keep the order, use dict print('Read images with multiprocessing, #thread: {} ...'.format( n_thread)) pbar = util.ProgressBar(len(all_img_list)) def mycallback(arg): '''get the image data and update pbar''' key = arg[0] dataset[key] = arg[1] pbar.update('Reading {}'.format(key)) pool = Pool(n_thread) for path, key in zip(all_img_list, keys): pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) pool.close() pool.join() print('Finish reading {} images.\nWrite lmdb...'.format( len(all_img_list))) #### create lmdb environment data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes print('data size per image is: ', data_size_per_img) data_size = data_size_per_img * len(all_img_list) env = lmdb.open(lmdb_save_path, map_size=data_size * 10) #### write data to lmdb pbar = util.ProgressBar(len(all_img_list)) txn = env.begin(write=True) for idx, (path, key) in enumerate(zip(all_img_list, keys)): pbar.update('Write {}'.format(key)) key_byte = key.encode('ascii') data = dataset[key] if read_all_imgs else cv2.imread( path, cv2.IMREAD_UNCHANGED) H, W, C = data.shape assert H == H_dst and W == W_dst and C == 3, 'different shape.' txn.put(key_byte, data) if not read_all_imgs and idx % BATCH == 0: txn.commit() txn = env.begin(write=True) txn.commit() env.close() print('Finish writing lmdb.') #### create meta information meta_info = {} meta_info['name'] = 'REDS_{}_wval'.format(mode) channel = 1 if 'flow' in mode else 3 meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst) meta_info['keys'] = keys pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.')
def youku(mode): """Create lmdb for the youku dataset, each image with a fixed size GT: [3, 1080, 1920] or [3, 1152, 2048], key: 00000_000000 LR: [3, 270, 480] or [3, 288, 512], key: 00000_000000 """ #### configurations read_all_imgs = False # whether real all images to memory with multiprocessing # Set False for use limited memory BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False if mode == 'gt': train_folder = '/media/tclwh2/public/youku/train/gt' val_folder = '/media/tclwh2/public/youku/val/gt' lmdb_save_path = '/media/tclwh2/public/youku/youku_train_gt.lmdb' H_dst, W_dst = (1080, 1152), (1920, 2048) elif mode == 'lq': train_folder = '/media/tclwh2/public/youku/train/lq' val_folder = '/media/tclwh2/public/youku/val/lq' lmdb_save_path = '/media/tclwh2/public/youku/youku_train_lq.lmdb' H_dst, W_dst = (270, 288), (480, 512) n_thread = 40 ######################################################## if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") if osp.exists(lmdb_save_path): print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) sys.exit(1) #### read all the image paths to a list print('Reading image path list ...') train_img_list = data_util._get_paths_from_images(train_folder) val_img_list = data_util._get_paths_from_images(val_folder) all_img_list = sorted(train_img_list + val_img_list) keys = [] for img_path in all_img_list: split_rlt = img_path.split('/') folder = split_rlt[-2] img_name = split_rlt[-1].split('.png')[0] keys.append(folder + '_' + img_name) if read_all_imgs: #### read all images to memory (multiprocessing) dataset = { } # store all image data. list cannot keep the order, use dict print('Read images with multiprocessing, #thread: {} ...'.format( n_thread)) pbar = util.ProgressBar(len(all_img_list)) def mycallback(arg): '''get the image data and update pbar''' key = arg[0] dataset[key] = arg[1] pbar.update('Reading {}'.format(key)) pool = Pool(n_thread) for path, key in zip(all_img_list, keys): pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) pool.close() pool.join() print('Finish reading {} images.\nWrite lmdb...'.format( len(all_img_list))) #### create lmdb environment cnt_1 = 993 * 100 cnt_2 = 7 * 100 assert cnt_1 + cnt_2 == len(all_img_list) img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED) data_size_per_img_1 = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes data_size_per_img_2 = cv2.imread(all_img_list[30 * 100], cv2.IMREAD_UNCHANGED).nbytes assert data_size_per_img_1 != data_size_per_img_2 print('data size per image is: %d and %d' % (data_size_per_img_1, data_size_per_img_2)) data_size = data_size_per_img_1 * cnt_1 + data_size_per_img_2 + cnt_2 env = lmdb.open(lmdb_save_path, map_size=data_size * 10) #### write data to lmdb pbar = util.ProgressBar(len(all_img_list)) txn = env.begin(write=True) for idx, (path, key) in enumerate(zip(all_img_list, keys)): pbar.update('Write {}'.format(key)) key_byte = key.encode('ascii') data = dataset[key] if read_all_imgs else cv2.imread( path, cv2.IMREAD_UNCHANGED) H, W, C = data.shape assert H in H_dst and W in W_dst and C == 3, 'different shape.' 
txn.put(key_byte, data) if not read_all_imgs and idx % BATCH == 0: txn.commit() txn = env.begin(write=True) txn.commit() env.close() print('Finish writing lmdb.') #### create meta information meta_info = {} meta_info['name'] = 'youku_train_{}'.format(mode) channel = 3 meta_info['resolution_2'] = '{}_{}_{}'.format(channel, H_dst[1], W_dst[1]) meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst[0], W_dst[0]) meta_info['res_2_list'] = '_'.join( ['31', '44', '54', '101', '121', '142', '177']) meta_info['keys'] = keys pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.')
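# youku() stores frames of two different sizes in one lmdb and records both shapes in
# meta_info as 'resolution' and 'resolution_2'. A reader therefore cannot assume a single
# shape; one hedged way to reshape an entry is to pick whichever recorded resolution matches
# the buffer size. Illustrative sketch only; the function name, path and key are placeholders.
def read_youku_entry_sketch(lmdb_save_path, key):
    """Sketch: reshape a youku entry using whichever recorded resolution fits its byte count."""
    import pickle
    import lmdb
    import numpy as np
    import os.path as osp  # local imports keep the sketch self-contained
    meta_info = pickle.load(open(osp.join(lmdb_save_path, 'meta_info.pkl'), 'rb'))
    env = lmdb.open(lmdb_save_path, readonly=True, lock=False)
    with env.begin(write=False) as txn:
        buf = txn.get(key.encode('ascii'))
    for res_key in ('resolution', 'resolution_2'):
        c, h, w = [int(x) for x in meta_info[res_key].split('_')]
        if h * w * c == len(buf):
            return np.frombuffer(buf, dtype=np.uint8).reshape(h, w, c)
    raise ValueError('buffer size matches neither recorded resolution')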
def REDS(): '''create lmdb for the REDS dataset, each image with fixed size GT: [3, 720, 1280], key: 000_00000000 LR: [3, 180, 320], key: 000_00000000 key: 000_00000000 ''' #### configurations mode = 'train_sharp' # train_sharp | train_sharp_bicubic | train_blur_bicubic| train_blur | train_blur_comp if mode == 'train_sharp': img_folder = '/home/xtwang/datasets/REDS/train_sharp' lmdb_save_path = '/home/xtwang/datasets/REDS/train_sharp_wval.lmdb' H_dst, W_dst = 720, 1280 elif mode == 'train_sharp_bicubic': img_folder = '/home/xtwang/datasets/REDS/train_sharp_bicubic' lmdb_save_path = '/home/xtwang/datasets/REDS/train_sharp_bicubic_wval.lmdb' H_dst, W_dst = 180, 320 elif mode == 'train_blur_bicubic': img_folder = '/home/xtwang/datasets/REDS/train_blur_bicubic' lmdb_save_path = '/home/xtwang/datasets/REDS/train_blur_bicubic_wval.lmdb' H_dst, W_dst = 180, 320 elif mode == 'train_blur': img_folder = '/home/xtwang/datasets/REDS/train_blur' lmdb_save_path = '/home/xtwang/datasets/REDS/train_blur_wval.lmdb' H_dst, W_dst = 720, 1280 elif mode == 'train_blur_comp': img_folder = '/home/xtwang/datasets/REDS/train_blur_comp' lmdb_save_path = '/home/xtwang/datasets/REDS/train_blur_comp_wval.lmdb' H_dst, W_dst = 720, 1280 n_thread = 40 ######################################################## if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") #### whether the lmdb file exist if osp.exists(lmdb_save_path): print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) sys.exit(1) #### read all the image paths to a list print('Reading image path list ...') all_img_list = data_util._get_paths_from_images(img_folder) keys = [] for img_path in all_img_list: split_rlt = img_path.split('/') a = split_rlt[-2] b = split_rlt[-1].split('.png')[0] keys.append(a + '_' + b) #### read all images to memory (multiprocessing) dataset = {} # store all image data. list cannot keep the order, use dict print('Read images with multiprocessing, #thread: {} ...'.format(n_thread)) pbar = util.ProgressBar(len(all_img_list)) def mycallback(arg): '''get the image data and update pbar''' key = arg[0] dataset[key] = arg[1] pbar.update('Reading {}'.format(key)) pool = Pool(n_thread) for path, key in zip(all_img_list, keys): pool.apply_async(reading_image_worker, args=(path, key), callback=mycallback) pool.close() pool.join() print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list))) #### create lmdb environment data_size_per_img = dataset['000_00000000'].nbytes if 'flow' in mode: data_size_per_img = dataset['000_00000002_n1'].nbytes print('data size per image is: ', data_size_per_img) data_size = data_size_per_img * len(all_img_list) env = lmdb.open(lmdb_save_path, map_size=data_size * 10) #### write data to lmdb pbar = util.ProgressBar(len(all_img_list)) with env.begin(write=True) as txn: for key in keys: pbar.update('Write {}'.format(key)) key_byte = key.encode('ascii') data = dataset[key] if 'flow' in mode: H, W = data.shape assert H == H_dst and W == W_dst, 'different shape.' else: H, W, C = data.shape # fixed shape assert H == H_dst and W == W_dst and C == 3, 'different shape.' 
txn.put(key_byte, data) print('Finish writing lmdb.') #### create meta information meta_info = {} meta_info['name'] = 'REDS_{}_wval'.format(mode) if 'flow' in mode: meta_info['resolution'] = '{}_{}_{}'.format(1, H_dst, W_dst) else: meta_info['resolution'] = '{}_{}_{}'.format(3, H_dst, W_dst) meta_info['keys'] = keys pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.')
def main(): ############################################ # # set options # ############################################ parser = argparse.ArgumentParser() parser.add_argument('--opt', type=str, help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=True) ############################################ # # distributed training settings # ############################################ if args.launcher == 'none': # disabled distributed training opt['dist'] = False rank = -1 print('Disabled distributed training.') else: opt['dist'] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() print("Rank:", rank) print("World Size", world_size) print("------------------DIST-------------------------") ############################################ # # loading resume state if exists # ############################################ if opt['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load( opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options else: resume_state = None ############################################ # # mkdir and loggers # ############################################ if 'debug' in opt['name']: debug_mode = True else: debug_mode = False if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) if resume_state is None: util.mkdir_and_rename( opt['path'] ['experiments_root']) # rename experiment folder if exists util.mkdirs( (path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) util.setup_logger('base_val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger_val = logging.getLogger('base_val') logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. Tensorboard will use [tensorboardX]' .format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) else: # config loggers. 
Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_', level=logging.INFO, screen=True) print("set train log") util.setup_logger('base_val', opt['path']['log'], 'val_', level=logging.INFO, screen=True) print("set val log") logger = logging.getLogger('base') logger_val = logging.getLogger('base_val') # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) if rank <= 0: logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True ############################################ # # create train and val dataloader # ############################################ #### # dataset_ratio = 200 # enlarge the size of each epoch dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': if opt['datasets']['train'].get('split', None): train_set, val_set = create_dataset(dataset_opt) else: train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt['batch_size'])) # total_iters = int(opt['train']['niter']) # total_epochs = int(math.ceil(total_iters / train_size)) total_iters = train_size total_epochs = int(opt['train']['epoch']) if opt['dist']: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) # total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) total_epochs = int(opt['train']['epoch']) if opt['train']['enable'] == False: total_epochs = 1 else: # train_sampler = None train_sampler = RandomBalancedSampler(train_set, train_size) train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler, vscode_debug=debug_mode) if rank <= 0: logger.info( 'Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': if not opt['datasets']['train'].get('split', None): val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None, vscode_debug=debug_mode) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None ############################################ # # create model # ############################################ #### model = create_model(opt) print("Model Created! 
") #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 print("Not Resume Training") ############################################ # # training # ############################################ logger.info('Start training from epoch: {:d}, iter: {:d}'.format( start_epoch, current_step)) model.train_AverageMeter() saved_total_loss = 10e10 saved_total_PSNR = -1 saved_total_SSIM = -1 for epoch in range(start_epoch, total_epochs): ############################################ # # Start a new epoch # ############################################ current_step = 0 if opt['dist']: train_sampler.set_epoch(epoch) for train_idx, train_data in enumerate(train_loader): # print('current_step', current_step) if 'debug' in opt['name']: img_dir = os.path.join(opt['path']['train_images']) util.mkdir(img_dir) LQs = train_data['LQs'] # B N C H W if not 'sr' in opt['name']: GTenh = train_data['GTenh'] GTinp = train_data['GTinp'] for imgs, name in zip([LQs, GTenh, GTinp], ['LQs', 'GTenh', 'GTinp']): num = imgs.size(1) for i in range(num): img = util.tensor2img(imgs[0, i, ...]) # uint8 save_img_path = os.path.join( img_dir, '{:4d}_{:s}_{:1d}.png'.format( train_idx, str(name), i)) util.save_img(img, save_img_path) else: if 'GT' in train_data: GT_name = 'GT' elif 'GTs' in train_data: GT_name = 'GTs' GT = train_data[GT_name] for imgs, name in zip([LQs, GT], ['LQs', GT_name]): if name == 'GT': num = imgs.size(0) img = util.tensor2img(imgs[0, ...]) # uint8 save_img_path = os.path.join( img_dir, '{:4d}_{:s}_{:1d}.png'.format( train_idx, str(name), 0)) util.save_img(img, save_img_path) elif name == 'GTs': num = imgs.size(1) for i in range(num): img = util.tensor2img(imgs[:, i, ...]) # uint8 save_img_path = os.path.join( img_dir, '{:4d}_{:s}_{:1d}.png'.format( train_idx, str(name), i)) util.save_img(img, save_img_path) else: num = imgs.size(1) for i in range(num): img = util.tensor2img(imgs[:, i, ...]) # uint8 save_img_path = os.path.join( img_dir, '{:4d}_{:s}_{:1d}.png'.format( train_idx, str(name), i)) util.save_img(img, save_img_path) if (train_idx >= 3): # set to 0, just do validation break # if pre-load weight first do validation and skip the first epoch # if opt['path'].get('pretrain_model_G', None) and epoch == 0: # epoch += 1 # break if opt['train']['enable'] == False: message_train_loss = 'None' break current_step += 1 if current_step > total_iters: print("Total Iteration Reached !") break #### update learning rate if opt['train']['lr_scheme'] == 'ReduceLROnPlateau': pass else: model.update_learning_rate( current_step, warmup_iter=opt['train']['warmup_iter']) #### training model.feed_data(train_data) model.optimize_parameters(current_step) model.train_AverageMeter_update() #### log if current_step % opt['logger']['print_freq'] == 0: logs_inst, logs_avg = model.get_current_log( ) # training loss mode='train' message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format( epoch, current_step) for v in model.get_current_learning_rate(): message += '{:.3e},'.format(v) message += ')] ' # if 'debug' in opt['name']: # debug model print the instant loss # for k, v in logs_inst.items(): # message += '{:s}: {:.4e} '.format(k, v) # # tensorboard logger # if opt['use_tb_logger'] and 'debug' not in opt['name']: # if rank <= 0: # tb_logger.add_scalar(k, v, 
current_step) # for avg loss current_iters_epoch = epoch * total_iters + current_step for k, v in logs_avg.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: if rank <= 0: tb_logger.add_scalar(k, v, current_iters_epoch) if rank <= 0: logger.info(message) # saving models if epoch == 1: save_filename = '{:04d}_{}.pth'.format(0, 'G') save_path = os.path.join(opt['path']['models'], save_filename) if os.path.exists(save_path): os.remove(save_path) save_filename = '{:04d}_{}.pth'.format(epoch - 1, 'G') save_path = os.path.join(opt['path']['models'], save_filename) if os.path.exists(save_path): os.remove(save_path) if rank <= 0: logger.info('Saving models and training states.') save_filename = '{:04d}'.format(epoch) model.save(save_filename) # ======================================================================= # # Main validation loop # # ======================================================================= # if opt['datasets'].get('val', None): if opt['dist']: # multi-GPU testing psnr_rlt = {} # with border and center frames psnr_rlt_avg = {} psnr_total_avg = 0. ssim_rlt = {} # with border and center frames ssim_rlt_avg = {} ssim_total_avg = 0. val_loss_rlt = {} # the averaged loss val_loss_rlt_avg = {} val_loss_total_avg = 0. if rank == 0: pbar = util.ProgressBar(len(val_set)) for idx in range( rank, len(val_set), world_size): # distributed parallel validation # print('idx', idx) if 'debug' in opt['name']: if (idx >= 3): break if (idx >= 1000): break val_data = val_set[idx] # use idx method to fetch must extend batch dimension val_data['LQs'].unsqueeze_(0) val_data['GTenh'].unsqueeze_(0) val_data['GTinp'].unsqueeze_(0) key = val_data['key'][0] # IMG_0034_00809 max_idx = len(val_set) val_name = 'val_set' num = model.get_info( ) # each model has different number of loss if psnr_rlt.get(val_name, None) is None: psnr_rlt[val_name] = torch.zeros([num, max_idx], dtype=torch.float32, device='cuda') if ssim_rlt.get(val_name, None) is None: ssim_rlt[val_name] = torch.zeros([num, max_idx], dtype=torch.float32, device='cuda') if val_loss_rlt.get(val_name, None) is None: val_loss_rlt[val_name] = torch.zeros( [num, max_idx], dtype=torch.float32, device='cuda') model.feed_data(val_data) model.test() avg_loss, loss_list = model.get_loss(ret=1) save_enable = True if idx >= 100: save_enable = False psnr_list, ssim_list = model.compute_current_psnr_ssim( save=save_enable, name=key, save_path=opt['path']['val_images']) # print('psnr_list',psnr_list) assert len(loss_list) == num assert len(psnr_list) == num for i in range(num): psnr_rlt[val_name][i, idx] = psnr_list[i] ssim_rlt[val_name][i, idx] = ssim_list[i] val_loss_rlt[val_name][i, idx] = loss_list[i] # print('psnr_rlt[val_name][i, idx]',psnr_rlt[val_name][i, idx]) # print('ssim_rlt[val_name][i, idx]',ssim_rlt[val_name][i, idx]) # print('val_loss_rlt[val_name][i, idx] ',val_loss_rlt[val_name][i, idx] ) if rank == 0: for _ in range(world_size): pbar.update('Test {} - {}/{}'.format( key, idx, max_idx)) # # collect data for _, v in psnr_rlt.items(): for i in v: dist.reduce(i, 0) for _, v in ssim_rlt.items(): for i in v: dist.reduce(i, 0) for _, v in val_loss_rlt.items(): for i in v: dist.reduce(i, 0) dist.barrier() if rank == 0: psnr_rlt_avg = {} psnr_total_avg = 0. 
for k, v in psnr_rlt.items(): # key, value # print('k', k, 'v', v, 'v.shape', v.shape) psnr_rlt_avg[k] = [] for i in range(num): non_zero_idx = v[i, :].nonzero() # logger.info('non_zero_idx {}'.format(non_zero_idx.shape)) # check matrix = v[i, :][non_zero_idx] # print('matrix', matrix) value = torch.mean(matrix).cpu().item() # print('value', value) psnr_rlt_avg[k].append(value) psnr_total_avg += psnr_rlt_avg[k][i] psnr_total_avg = psnr_total_avg / (len(psnr_rlt) * num) log_p = '# Validation # Avg. PSNR: {:.2f},'.format( psnr_total_avg) for k, v in psnr_rlt_avg.items(): for i, it in enumerate(v): log_p += ' {}: {:.2f}'.format(i, it) logger.info(log_p) logger_val.info(log_p) # ssim ssim_rlt_avg = {} ssim_total_avg = 0. for k, v in ssim_rlt.items(): ssim_rlt_avg[k] = [] for i in range(num): non_zero_idx = v[i, :].nonzero() # print('non_zero_idx', non_zero_idx) matrix = v[i, :][non_zero_idx] # print('matrix', matrix) value = torch.mean(matrix).cpu().item() # print('value', value) ssim_rlt_avg[k].append( torch.mean(matrix).cpu().item()) ssim_total_avg += ssim_rlt_avg[k][i] ssim_total_avg /= (len(ssim_rlt) * num) log_s = '# Validation # Avg. SSIM: {:.2f},'.format( ssim_total_avg) for k, v in ssim_rlt_avg.items(): for i, it in enumerate(v): log_s += ' {}: {:.2f}'.format(i, it) logger.info(log_s) logger_val.info(log_s) # added val_loss_rlt_avg = {} val_loss_total_avg = 0. for k, v in val_loss_rlt.items(): # k, key, the folder name # v, value, the torch matrix val_loss_rlt_avg[k] = [] # loss0 - loss_N for i in range(num): non_zero_idx = v[i, :].nonzero() # print('non_zero_idx', non_zero_idx) matrix = v[i, :][non_zero_idx] # print('matrix', matrix) value = torch.mean(matrix).cpu().item() # print('value', value) val_loss_rlt_avg[k].append( torch.mean(matrix).cpu().item()) val_loss_total_avg += val_loss_rlt_avg[k][i] val_loss_total_avg /= (len(val_loss_rlt) * num) log_l = '# Validation # Avg. 
Loss: {:.4e},'.format( val_loss_total_avg) for k, v in val_loss_rlt_avg.items(): for i, it in enumerate(v): log_l += ' {}: {:.4e}'.format(i, it) logger.info(log_l) logger_val.info(log_l) message = '' for v in model.get_current_learning_rate(): message += '{:.5e}'.format(v) logger_val.info( 'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f}, Val Loss {:.4e}' .format(epoch, message, psnr_total_avg, ssim_total_avg, val_loss_total_avg)) else: pbar = util.ProgressBar(len(val_loader)) model.val_loss_AverageMeter() model.val_AverageMeter_para() for val_inx, val_data in enumerate(val_loader): # if 'debug' in opt['name']: # if (val_inx >= 10): # break save_enable = True if val_inx >= 100: save_enable = False if val_inx >= 100: break key = val_data['key'][0] folder = key[:-6] model.feed_data(val_data) model.test() avg_loss, loss_list = model.get_loss(ret=1) model.val_loss_AverageMeter_update(loss_list, avg_loss) psnr_list, ssim_list = model.compute_current_psnr_ssim( save=save_enable, name=key, save_path=opt['path']['val_images']) model.val_AverageMeter_para_update(psnr_list, ssim_list) if 'debug' in opt['name']: msg_psnr = '' msg_ssim = '' for i, psnr in enumerate(psnr_list): msg_psnr += '{} :{:.02f} '.format(i, psnr) for i, ssim in enumerate(ssim_list): msg_ssim += '{} :{:.02f} '.format(i, ssim) logger.info('{}_{:02d} {}'.format( key, val_inx, msg_psnr)) logger.info('{}_{:02d} {}'.format( key, val_inx, msg_ssim)) pbar.update('Test {} - {}'.format(key, val_inx)) # toal validation log lr = '' for v in model.get_current_learning_rate(): lr += '{:.5e}'.format(v) logs_avg, logs_psnr_avg, psnr_total_avg, ssim_total_avg, val_loss_total_avg = model.get_current_log( mode='val') msg_logs_avg = '' for k, v in logs_avg.items(): msg_logs_avg += '{:s}: {:.4e} '.format(k, v) logger_val.info('Val-Epoch {:02d}, LR {:s}, {:s}'.format( epoch, lr, msg_logs_avg)) logger.info('Val-Epoch {:02d}, LR {:s}, {:s}'.format( epoch, lr, msg_logs_avg)) msg_logs_psnr_avg = '' for k, v in logs_psnr_avg.items(): msg_logs_psnr_avg += '{:s}: {:.4e} '.format(k, v) logger_val.info('Val-Epoch {:02d}, LR {:s}, {:s}'.format( epoch, lr, msg_logs_psnr_avg)) logger.info('Val-Epoch {:02d}, LR {:s}, {:s}'.format( epoch, lr, msg_logs_psnr_avg)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('val_psnr', psnr_total_avg, epoch) tb_logger.add_scalar('val_loss', val_loss_total_avg, epoch) ############################################ # # end of validation, save model # ############################################ # if rank <= 0: logger.info("Finished an epoch, Check and Save the model weights") # we check the validation loss instead of training loss. 
if saved_total_loss >= val_loss_total_avg: saved_total_loss = val_loss_total_avg #torch.save(model.state_dict(), args.save_path + "/best" + ".pth") model.save('best') logger.info( "Best weights updated: validation loss decreased") else: logger.info( "Weights not updated: validation loss did not decrease") if saved_total_PSNR <= psnr_total_avg: saved_total_PSNR = psnr_total_avg model.save('bestPSNR') logger.info( "Best weights updated: validation PSNR increased") else: logger.info( "Weights not updated: validation PSNR did not increase") ############################################ # # end of one epoch, schedule LR # ############################################ model.train_AverageMeter_reset() # TODO: add support for other schedulers if needed if opt['train']['lr_scheme'] == 'ReduceLROnPlateau': for scheduler in model.schedulers: scheduler.step(val_loss_total_avg) if rank <= 0: logger.info('Saving the final model.') model.save('last') logger.info('End of training.') tb_logger.close()
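# A minimal sketch of the checkpoint/LR policy used above: keep the weights with the
# lowest validation loss and the highest validation PSNR, and drive a ReduceLROnPlateau
# scheduler with the validation loss. `net`, `optimizer` and `validate` are illustrative
# placeholders, not the model wrapper classes used in this repo.
import torch

def fit(net, optimizer, validate, num_epochs=10):
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2)
    best_loss, best_psnr = float('inf'), 0.0
    for epoch in range(num_epochs):
        val_loss, val_psnr = validate(net)  # averaged validation metrics for this epoch
        if val_loss < best_loss:  # lower validation loss -> refresh 'best'
            best_loss = val_loss
            torch.save(net.state_dict(), 'best.pth')
        if val_psnr > best_psnr:  # higher validation PSNR -> refresh 'bestPSNR'
            best_psnr = val_psnr
            torch.save(net.state_dict(), 'bestPSNR.pth')
        scheduler.step(val_loss)  # the plateau scheduler watches the validation loss
    torch.save(net.state_dict(), 'last.pth')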
def SDR4k(mode): """Create lmdb for the REDS dataset, each image with a fixed size 10bit: [3, 2160, 3840], key: 00000000_000 4bit: [3, 2160, 3840], key: 00000000_000 """ #### configurations read_all_imgs = False # whether real all images to memory with multiprocessing # Set False for use limited memory BATCH = 5000 # 5000 # After BATCH images, lmdb commits, if read_all_imgs = False if mode == "10bit": # img_folder = "..\\..\\datasets\\SDR_10bit" # for windows # lmdb_save_path = "..\\..\\datasets\\SDR_10bit.lmdb" # for windows img_folder = "../../datasets/SDR4k/train/SDR_10BIT_patch" # for linux lmdb_save_path = "../../datasets/SDR4k/train/SDR_10BIT_patch.lmdb" # for linux # H_dst, W_dst = 2160, 3840 H_dst, W_dst = 480, 480 elif mode == "4bit": # img_folder = "..\\..\\datasets\\SDR_4bit" # for windows # lmdb_save_path = "..\\..\\datasets\\SDR_4bit.lmdb" img_folder = "../../datasets/SDR4k/train/SDR_4BIT_patch" # for linux lmdb_save_path = "../../datasets/SDR4k/train/SDR_4BIT_patch.lmdb" # for linux # H_dst, W_dst = 2160, 3840 H_dst, W_dst = 480, 480 n_thread = 40 ######################################################## if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") if osp.exists(lmdb_save_path): print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) sys.exit(1) #### read all the image paths to a list print('Reading image path list ...') all_img_list = data_util._get_paths_from_images(img_folder) keys = [] for img_path in all_img_list: split_rlt = img_path.split('/') # for linux # split_rlt = img_path.split("\\") # for windows folder = split_rlt[-2] img_name = split_rlt[-1].split('.png')[0] keys.append(folder + '_' + img_name) # keys: 00000000_000_000 if read_all_imgs: #### read all images to memory (multiprocessing) dataset = {} # store all image data. list cannot keep the order, use dict print('Read images with multiprocessing, #thread: {} ...'.format(n_thread)) pbar = util.ProgressBar(len(all_img_list)) def mycallback(arg): '''get the image data and update pbar''' key = arg[0] dataset[key] = arg[1] pbar.update('Reading {}'.format(key)) pool = Pool(n_thread) for path, key in zip(all_img_list, keys): pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) pool.close() pool.join() print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list))) #### create lmdb environment data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes print('data size per image is: ', data_size_per_img) data_size = data_size_per_img * len(all_img_list) env = lmdb.open(lmdb_save_path, map_size=data_size * 10) #### write data to lmdb print("Start writing...") pbar = util.ProgressBar(len(all_img_list)) txn = env.begin(write=True) rm_list = [] for idx, (path, key) in enumerate(zip(all_img_list, keys)): pbar.update('Write {}'.format(key)) key_byte = key.encode('ascii') data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED) H, W, C = data.shape assert H == H_dst and W == W_dst and C == 3, 'different shape.' 
txn.put(key_byte, data) # record the image so it can be removed later rm_list.append(path) if not read_all_imgs and idx % BATCH == 0: txn.commit() for img in rm_list: print("os.system('rm {}')".format(img)) rm_list = [] txn = env.begin(write=True) txn.commit() env.close() print('Finish writing lmdb.') #### create meta information meta_info = {} meta_info['name'] = 'SDR4k_{}'.format(mode) channel = 3 meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst) meta_info['keys'] = keys pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.')
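# A minimal sketch of reading one patch back from the lmdb written above, assuming the
# meta_info.pkl layout produced by SDR4k() ('resolution' stored as 'C_H_W') and that each
# value is the raw byte buffer of the cv2 image array. For the 10-bit data OpenCV yields
# uint16, so pass the dtype accordingly; the example key follows the '00000000_000_000' scheme.
import os.path as osp
import pickle
import lmdb
import numpy as np

def read_patch(lmdb_path, key, dtype=np.uint8):
    with open(osp.join(lmdb_path, 'meta_info.pkl'), 'rb') as f:
        meta = pickle.load(f)
    C, H, W = (int(x) for x in meta['resolution'].split('_'))
    env = lmdb.open(lmdb_path, readonly=True, lock=False, readahead=False, meminit=False)
    with env.begin(write=False) as txn:
        buf = txn.get(key.encode('ascii'))
    env.close()
    return np.frombuffer(buf, dtype=dtype).reshape(H, W, C)  # BGR order, as saved by cv2

# e.g. patch = read_patch('../../datasets/SDR4k/train/SDR_4BIT_patch.lmdb', '00000000_000_000')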
def OURS(mode="input"): '''create lmdb for our dataset, each image with fixed size GT: [3, H, W], key: 000000_000000 LR: [3, H, W], key: 000000_000000 key: 000000_00000 ** recall our data layout: {sub-folder name}_{image name} ''' #### configurations mode = mode # ** data mode: input / gt read_all_imgs = False # whether to read all images into memory. Set False with limited memory BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False if mode == 'input': img_folder = './../../datasets/train/input' # ** relative path to the input split of our dataset lmdb_save_path = './../../datasets/train_input_wval.lmdb' # ** path where the generated lmdb is stored '''the original used absolute paths; we use relative paths''' H_dst, W_dst = 480, 640 # frame size: H, W elif mode == 'gt': img_folder = './../../datasets/train/gt' # ** relative path to the gt split of our dataset lmdb_save_path = './../../datasets/train_gt_wval.lmdb' # ** path where the generated lmdb is stored '''the original used absolute paths; we use relative paths''' H_dst, W_dst = 480, 640 # frame size: H, W n_thread = 2 ######################################################## if not lmdb_save_path.endswith('.lmdb'): raise ValueError( "lmdb_save_path must end with \'lmdb\'.") # the save path must end with '.lmdb' #### whether the lmdb file exist if osp.exists(lmdb_save_path): print('Folder [{:s}] already exists. Exit...'.format( lmdb_save_path)) # abort if the folder already exists sys.exit(1) #### read all the image paths to a list print('Reading image path list ...') all_img_list = data_util._get_paths_from_images( img_folder) # full paths of all frames under input/gt, as a list keys = [] for img_path in all_img_list: split_rlt = img_path.split('/') # take the sub-folder name xxxxxx a = split_rlt[-2] # take the frame name without the file extension xxxxxx b = split_rlt[-1].split('.jpg')[0] # ** our images end with '.jpg' keys.append(a + '_' + b) if read_all_imgs: # read_all_imgs = False, so this branch is skipped #### read all images to memory (multiprocessing) dataset = { } # store all image data. list cannot keep the order, use dict print('Read images with multiprocessing, #thread: {} ...'.format( n_thread)) pbar = util.ProgressBar(len(all_img_list)) def mycallback(arg): '''get the image data and update pbar''' key = arg[0] dataset[key] = arg[1] pbar.update('Reading {}'.format(key)) pool = Pool(n_thread) for path, key in zip(all_img_list, keys): pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) pool.close() pool.join() print('Finish reading {} images.\nWrite lmdb...'.format( len(all_img_list))) #### create lmdb environment data_size_per_img = cv2.imread( all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes # size of one frame in bytes if 'flow' in mode: data_size_per_img = dataset['000_00000002_n1'].nbytes print('data size per image is: ', data_size_per_img) data_size = data_size_per_img * len(all_img_list) # total space required env = lmdb.open(lmdb_save_path, map_size=data_size * 10) # reserve roughly this many bytes #### write data to lmdb pbar = util.ProgressBar(len(all_img_list)) txn = env.begin(write=True) idx = 1 for path, key in zip(all_img_list, keys): idx = idx + 1 pbar.update('Write {}'.format(key)) key_byte = key.encode('ascii') data = dataset[key] if read_all_imgs else cv2.imread( path, cv2.IMREAD_UNCHANGED) if 'flow' in mode: H, W = data.shape assert H == H_dst and W == W_dst, 'different shape.' else: H, W, C = data.shape # fixed shape assert H == H_dst and W == W_dst and C == 3, 'different shape.'
txn.put(key_byte, data) if not read_all_imgs and idx % BATCH == 1: txn.commit() txn = env.begin(write=True) txn.commit() env.close() print('Finish writing lmdb.') #### create meta information # store meta data: name (str) + resolution (str) meta_info = {} meta_info['name'] = 'OURS_{}_wval'.format(mode) # ** the dataset name is now OURS if 'flow' in mode: meta_info['resolution'] = '{}_{}_{}'.format(1, H_dst, W_dst) else: meta_info['resolution'] = '{}_{}_{}'.format(3, H_dst, W_dst) meta_info['keys'] = keys pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.')
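# A minimal sketch of the '{sub-folder}_{frame}' key scheme used above, built with os.path
# instead of splitting the raw path string on '/', so the same logic works with the native
# separator on both Linux and Windows. The sample path is illustrative.
import os.path as osp

def make_key(img_path, ext='.jpg'):
    assert img_path.endswith(ext), 'unexpected extension'
    sub_folder = osp.basename(osp.dirname(img_path))  # e.g. '000001'
    frame = osp.basename(img_path)[:-len(ext)]        # e.g. '000010'
    return '{}_{}'.format(sub_folder, frame)

# make_key('./../../datasets/train/input/000001/000010.jpg') -> '000001_000010'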
def vimeo7(): '''create lmdb for the Vimeo90K-7 frames dataset, each image with fixed size GT: [3, 256, 448] Only need the 4th frame currently, e.g., 00001_0001_4 LR: [3, 64, 112] With 1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7 key: Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001 ''' #### configurations mode = 'GT' # GT | LR batch = 3000 # TODO: depending on your mem size if mode == 'GT': img_folder = '/data/datasets/SR/vimeo_septuplet/sequences/train' lmdb_save_path = '/data/datasets/SR/vimeo_septuplet/vimeo7_train_GT.lmdb' txt_file = '/data/datasets/SR/vimeo_septuplet/sep_trainlist.txt' H_dst, W_dst = 256, 448 elif mode == 'LR': img_folder = '/data/datasets/SR/vimeo_septuplet/sequences_LR/LR/x4/train' lmdb_save_path = '/data/datasets/SR/vimeo_septuplet/vimeo7_train_LR7.lmdb' txt_file = '/data/datasets/SR/vimeo_septuplet/sep_trainlist.txt' H_dst, W_dst = 64, 112 n_thread = 40 ######################################################## if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") #### whether the lmdb file exist if osp.exists(lmdb_save_path): print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) sys.exit(1) #### read all the image paths to a list print('Reading image path list ...') with open(txt_file) as f: train_l = f.readlines() train_l = [v.strip() for v in train_l] all_img_list = [] keys = [] for line in train_l: folder = line.split('/')[0] sub_folder = line.split('/')[1] file_l = glob.glob(osp.join(img_folder, folder, sub_folder) + '/*') all_img_list.extend(file_l) for j in range(7): keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1)) all_img_list = sorted(all_img_list) keys = sorted(keys) if mode == 'GT': all_img_list = [v for v in all_img_list if v.endswith('.png')] keys = [v for v in keys] print('Calculating the total size of images...') data_size = sum(os.stat(v).st_size for v in all_img_list) #### read all images to memory (multiprocessing) print('Read images with multiprocessing, #thread: {} ...'.format(n_thread)) #### create lmdb environment env = lmdb.open(lmdb_save_path, map_size=data_size * 10) txn = env.begin(write=True) # txn is a Transaction object #### write data to lmdb pbar = util.ProgressBar(len(all_img_list)) i = 0 for path, key in zip(all_img_list, keys): pbar.update('Write {}'.format(key)) img = cv2.imread(path, cv2.IMREAD_UNCHANGED) key_byte = key.encode('ascii') H, W, C = img.shape # fixed shape assert H == H_dst and W == W_dst and C == 3, 'different shape.' txn.put(key_byte, img) i += 1 if i % batch == 1: txn.commit() txn = env.begin(write=True) txn.commit() env.close() print('Finish reading and writing {} images.'.format(len(all_img_list))) print('Finish writing lmdb.') #### create meta information meta_info = {} if mode == 'GT': meta_info['name'] = 'Vimeo7_train_GT' elif mode == 'LR': meta_info['name'] = 'Vimeo7_train_LR7' meta_info['resolution'] = '{}_{}_{}'.format(3, H_dst, W_dst) key_set = set() for key in keys: a, b, _ = key.split('_') key_set.add('{}_{}'.format(a, b)) meta_info['keys'] = key_set pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'Vimeo7_train_keys.pkl'), "wb")) print('Finish creating lmdb meta info.')
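# A minimal sketch of how a clip-level key stored by vimeo7() (e.g. '00001_0001') expands back
# into its seven per-frame lmdb keys when loading. The LR default size [3, 64, 112] from the
# script is assumed, as is raw uint8 storage of the cv2 arrays; the lmdb path is illustrative.
import lmdb
import numpy as np

def read_clip(lmdb_path, clip_key, H=64, W=112, C=3):
    env = lmdb.open(lmdb_path, readonly=True, lock=False)
    frames = []
    with env.begin(write=False) as txn:
        for i in range(1, 8):  # frames im1 ... im7
            buf = txn.get('{}_{}'.format(clip_key, i).encode('ascii'))
            frames.append(np.frombuffer(buf, dtype=np.uint8).reshape(H, W, C))
    env.close()
    return np.stack(frames, axis=0)  # [7, H, W, C]

# e.g. clip = read_clip('/data/datasets/SR/vimeo_septuplet/vimeo7_train_LR7.lmdb', '00001_0001')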
def VideoSR(mode): """Create lmdb for the Video dataset, each image with a fixed size LR: [3, 540, 960], key: 000_00000000 GT: [3, 2160, 3840], key: 000_00000000 key: 000_00000000 """ #### configurations read_all_imgs = False # whether real all images to memory with multiprocessing # Set False for use limited memory BATCH = 2000 #5000 # After BATCH images, lmdb commits, if read_all_imgs = False if mode == 'GT': img_folder = '/home/yhliu/AI4K/train1_HR_png/' lmdb_save_path = '/home/yhliu/AI4K/train1_HR.lmdb' H_dst, W_dst = 2160, 3840 elif mode == 'LR': #img_folder = '/home/yhliu/AI4K/contest2/train2_LR_png/' #lmdb_save_path = '/home/yhliu/AI4K/contest2/train2_LR.lmdb' img_folder = '/home/yhliu/BasicSR/results/trainLR_35_ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW_re_100k_220000/trainLR_35_ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW_re_100k_220000/' lmdb_save_path = '/home/yhliu/AI4K/contest2/train2_LR_35_220000.lmdb' H_dst, W_dst = 540, 960 n_thread = 40 #40 ######################################################## if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") if osp.exists(lmdb_save_path): print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) sys.exit(1) #### read all the image paths to a list print('Reading image path list ...') all_img_list = data_util._get_paths_from_images(img_folder) keys = [] for img_path in all_img_list: split_rlt = img_path.split('/') folder = split_rlt[-2] img_name = split_rlt[-1].split('.png')[0] keys.append(folder + '_' + img_name) if read_all_imgs: #### read all images to memory (multiprocessing) dataset = { } # store all image data. list cannot keep the order, use dict print('Read images with multiprocessing, #thread: {} ...'.format( n_thread)) pbar = util.ProgressBar(len(all_img_list)) def mycallback(arg): '''get the image data and update pbar''' key = arg[0] dataset[key] = arg[1] pbar.update('Reading {}'.format(key)) pool = Pool(n_thread) for path, key in zip(all_img_list, keys): pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) pool.close() pool.join() print('Finish reading {} images.\nWrite lmdb...'.format( len(all_img_list))) #### create lmdb environment data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes print('data size per image is: ', data_size_per_img) data_size = data_size_per_img * len(all_img_list) env = lmdb.open(lmdb_save_path, map_size=data_size * 10) #### write data to lmdb pbar = util.ProgressBar(len(all_img_list)) txn = env.begin(write=True) for idx, (path, key) in enumerate(zip(all_img_list, keys)): pbar.update('Write {}'.format(key)) key_byte = key.encode('ascii') data = dataset[key] if read_all_imgs else cv2.imread( path, cv2.IMREAD_UNCHANGED) if 'flow' in mode: H, W = data.shape assert H == H_dst and W == W_dst, 'different shape.' else: H, W, C = data.shape assert H == H_dst and W == W_dst and C == 3, 'different shape.' txn.put(key_byte, data) if not read_all_imgs and idx % BATCH == 0: txn.commit() txn = env.begin(write=True) txn.commit() env.close() print('Finish writing lmdb.') #### create meta information meta_info = {} meta_info['name'] = 'AI4K_{}_train1'.format(mode) channel = 1 if 'flow' in mode else 3 meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst) meta_info['keys'] = keys pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.')
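# A minimal sketch of the map_size rule shared by these writers: the byte size of the first
# image times the image count, padded by a generous factor, because an lmdb write aborts once
# map_size is exhausted.
import cv2

def estimate_map_size(img_paths, safety_factor=10):
    per_img = cv2.imread(img_paths[0], cv2.IMREAD_UNCHANGED).nbytes
    return per_img * len(img_paths) * safety_factor

# env = lmdb.open(lmdb_save_path, map_size=estimate_map_size(all_img_list))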
def main(): #### options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--exp_name', type=str, default='temp') parser.add_argument('--degradation_type', type=str, default=None) parser.add_argument('--sigma_x', type=float, default=None) parser.add_argument('--sigma_y', type=float, default=None) parser.add_argument('--theta', type=float, default=None) args = parser.parse_args() if args.exp_name == 'temp': opt = option.parse(args.opt, is_train=False) else: opt = option.parse(args.opt, is_train=False, exp_name=args.exp_name) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) inner_loop_name = opt['train']['maml']['optimizer'][0] + str(opt['train']['maml']['adapt_iter']) + str(math.floor(math.log10(opt['train']['maml']['lr_alpha']))) meta_loop_name = opt['train']['optim'][0] + str(math.floor(math.log10(opt['train']['lr_G']))) if args.degradation_type is not None: if args.degradation_type == 'preset': opt['datasets']['val']['degradation_mode'] = args.degradation_type else: opt['datasets']['val']['degradation_type'] = args.degradation_type if args.sigma_x is not None: opt['datasets']['val']['sigma_x'] = args.sigma_x if args.sigma_y is not None: opt['datasets']['val']['sigma_y'] = args.sigma_y if args.theta is not None: opt['datasets']['val']['theta'] = args.theta if 'degradation_mode' not in opt['datasets']['val'].keys(): degradation_name = '' elif opt['datasets']['val']['degradation_mode'] == 'set': degradation_name = '_' + str(opt['datasets']['val']['degradation_type'])\ + '_' + str(opt['datasets']['val']['sigma_x']) \ + '_' + str(opt['datasets']['val']['sigma_y'])\ + '_' + str(opt['datasets']['val']['theta']) else: degradation_name = '_' + opt['datasets']['val']['degradation_mode'] folder_name = opt['name'] + '_' + degradation_name if args.exp_name != 'temp': folder_name = args.exp_name torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True #### create train and val dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': pass elif phase == 'val': if '+' in opt['datasets']['val']['name']: raise NotImplementedError('Do not use + signs in test mode') else: val_set = create_dataset(dataset_opt, scale=opt['scale'], kernel_size=opt['datasets']['train']['kernel_size'], model_name=opt['network_E']['which_model_E']) # val_set = loader.get_dataset(opt, train=False) val_loader = create_dataloader(val_set, dataset_opt, opt, None) print('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'], len(val_set))) else: raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) #### create model models = create_model(opt) assert len(models) == 2 model, est_model = models[0], models[1] modelcp, est_modelcp = create_model(opt) _, est_model_fixed = create_model(opt) center_idx = (opt['datasets']['val']['N_frames']) // 2 lr_alpha = opt['train']['maml']['lr_alpha'] update_step = opt['train']['maml']['adapt_iter'] with_GT = False if opt['datasets']['val']['mode'] == 'demo' else True pd_log = pd.DataFrame(columns=['PSNR_Bicubic', 'PSNR_Ours', 'SSIM_Bicubic', 'SSIM_Ours']) def crop(LR_seq, HR, num_patches_for_batch=4, patch_size=44): """ Crop given patches. 
Args: LR_seq: (B=1) x T x C x H x W HR: (B=1) x C x H x W patch_size (int, optional): Return: B(=batch_size) x T x C x H x W """ # Find the lowest resolution cropped_lr = [] cropped_hr = [] assert HR.size(0) == 1 LR_seq_ = LR_seq[0] HR_ = HR[0] for _ in range(num_patches_for_batch): patch_lr, patch_hr = preprocessing.common_crop(LR_seq_, HR_, patch_size=patch_size // 2) cropped_lr.append(patch_lr) cropped_hr.append(patch_hr) cropped_lr = torch.stack(cropped_lr, dim=0) cropped_hr = torch.stack(cropped_hr, dim=0) return cropped_lr, cropped_hr # Single GPU # PSNR_rlt: psnr_init, psnr_before, psnr_after psnr_rlt = [{}, {}] # SSIM_rlt: ssim_init, ssim_after ssim_rlt = [{}, {}] pbar = util.ProgressBar(len(val_set)) for val_data in val_loader: folder = val_data['folder'][0] idx_d = int(val_data['idx'][0].split('/')[0]) if 'name' in val_data.keys(): name = val_data['name'][0][center_idx][0] else: name = folder train_folder = os.path.join('../test_results', folder_name, name) maml_train_folder = os.path.join(train_folder, 'DynaVSR') if not os.path.exists(train_folder): os.makedirs(train_folder, exist_ok=False) if not os.path.exists(maml_train_folder): os.mkdir(maml_train_folder) for i in range(len(psnr_rlt)): if psnr_rlt[i].get(folder, None) is None: psnr_rlt[i][folder] = [] for i in range(len(ssim_rlt)): if ssim_rlt[i].get(folder, None) is None: ssim_rlt[i][folder] = [] cropped_meta_train_data = {} meta_train_data = {} meta_test_data = {} # Make SuperLR seq using estimation model meta_train_data['GT'] = val_data['LQs'][:, center_idx] meta_test_data['LQs'] = val_data['LQs'][0:1] meta_test_data['GT'] = val_data['GT'][0:1, center_idx] if with_GT else None # Check whether the batch size of each validation data is 1 assert val_data['LQs'].size(0) == 1 if opt['network_G']['which_model_G'] == 'TOF': LQs = meta_test_data['LQs'] B, T, C, H, W = LQs.shape LQs = LQs.reshape(B*T, C, H, W) Bic_LQs = F.interpolate(LQs, scale_factor=opt['scale'], mode='bicubic', align_corners=True) meta_test_data['LQs'] = Bic_LQs.reshape(B, T, C, H*opt['scale'], W*opt['scale']) ## Before start testing # Bicubic Model Results modelcp.load_network(opt['path']['bicubic_G'], modelcp.netG) modelcp.feed_data(meta_test_data, need_GT=with_GT) modelcp.test() if with_GT: model_start_visuals = modelcp.get_current_visuals(need_GT=True) hr_image = util.tensor2img(model_start_visuals['GT'], mode='rgb') start_image = util.tensor2img(model_start_visuals['rlt'], mode='rgb') psnr_rlt[0][folder].append(util.calculate_psnr(start_image, hr_image)) ssim_rlt[0][folder].append(util.calculate_ssim(start_image, hr_image)) modelcp.netG, est_modelcp.netE = deepcopy(model.netG), deepcopy(est_model.netE) ########## SLR LOSS Preparation ############ est_model_fixed.load_network(opt['path']['fixed_E'], est_model_fixed.netE) optim_params = [] for k, v in modelcp.netG.named_parameters(): if v.requires_grad: optim_params.append(v) if not opt['train']['use_real']: for k, v in est_modelcp.netE.named_parameters(): if v.requires_grad: optim_params.append(v) if opt['train']['maml']['optimizer'] == 'Adam': inner_optimizer = torch.optim.Adam(optim_params, lr=lr_alpha, betas=( opt['train']['maml']['beta1'], opt['train']['maml']['beta2'])) elif opt['train']['maml']['optimizer'] == 'SGD': inner_optimizer = torch.optim.SGD(optim_params, lr=lr_alpha) else: raise NotImplementedError() # Inner Loop Update st = time.time() for i in range(update_step): # Make SuperLR seq using UPDATED estimation model if not opt['train']['use_real']: est_modelcp.feed_data(val_data) 
est_modelcp.forward_without_optim() superlr_seq = est_modelcp.fake_L meta_train_data['LQs'] = superlr_seq else: meta_train_data['LQs'] = val_data['SuperLQs'] if opt['network_G']['which_model_G'] == 'TOF': # Bicubic upsample to match the size LQs = meta_train_data['LQs'] B, T, C, H, W = LQs.shape LQs = LQs.reshape(B*T, C, H, W) Bic_LQs = F.interpolate(LQs, scale_factor=opt['scale'], mode='bicubic', align_corners=True) meta_train_data['LQs'] = Bic_LQs.reshape(B, T, C, H*opt['scale'], W*opt['scale']) # Update both modelcp + estmodelcp jointly inner_optimizer.zero_grad() if opt['train']['maml']['use_patch']: cropped_meta_train_data['LQs'], cropped_meta_train_data['GT'] = \ crop(meta_train_data['LQs'], meta_train_data['GT'], opt['train']['maml']['num_patch'], opt['train']['maml']['patch_size']) modelcp.feed_data(cropped_meta_train_data) else: modelcp.feed_data(meta_train_data) loss_train = modelcp.calculate_loss() ##################### SLR LOSS ################### est_model_fixed.feed_data(val_data) est_model_fixed.test() slr_initialized = est_model_fixed.fake_L slr_initialized = slr_initialized.to('cuda') if opt['network_G']['which_model_G'] == 'TOF': loss_train += 10 * F.l1_loss(LQs.to('cuda').squeeze(0), slr_initialized) else: loss_train += 10 * F.l1_loss(meta_train_data['LQs'].to('cuda'), slr_initialized) loss_train.backward() inner_optimizer.step() et = time.time() update_time = et - st modelcp.feed_data(meta_test_data, need_GT=with_GT) modelcp.test() model_update_visuals = modelcp.get_current_visuals(need_GT=False) update_image = util.tensor2img(model_update_visuals['rlt'], mode='rgb') # Save and calculate final image imageio.imwrite(os.path.join(maml_train_folder, '{:08d}.png'.format(idx_d)), update_image) if with_GT: psnr_rlt[1][folder].append(util.calculate_psnr(update_image, hr_image)) ssim_rlt[1][folder].append(util.calculate_ssim(update_image, hr_image)) name_df = '{}/{:08d}'.format(folder, idx_d) if name_df in pd_log.index: pd_log.at[name_df, 'PSNR_Bicubic'] = psnr_rlt[0][folder][-1] pd_log.at[name_df, 'PSNR_Ours'] = psnr_rlt[1][folder][-1] pd_log.at[name_df, 'SSIM_Bicubic'] = ssim_rlt[0][folder][-1] pd_log.at[name_df, 'SSIM_Ours'] = ssim_rlt[1][folder][-1] else: pd_log.loc[name_df] = [psnr_rlt[0][folder][-1], psnr_rlt[1][folder][-1], ssim_rlt[0][folder][-1], ssim_rlt[1][folder][-1]] pd_log.to_csv(os.path.join('../test_results', folder_name, 'psnr_update.csv')) pbar.update('Test {} - {}: I: {:.3f}/{:.4f} \tF+: {:.3f}/{:.4f} \tTime: {:.3f}s' .format(folder, idx_d, psnr_rlt[0][folder][-1], ssim_rlt[0][folder][-1], psnr_rlt[1][folder][-1], ssim_rlt[1][folder][-1], update_time )) else: pbar.update() if with_GT: psnr_rlt_avg = {} psnr_total_avg = 0. # Just calculate the final value of psnr_rlt(i.e. psnr_rlt[2]) for k, v in psnr_rlt[0].items(): psnr_rlt_avg[k] = sum(v) / len(v) psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt[0]) log_s = '# Validation # Bic PSNR: {:.4e}:'.format(psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) print(log_s) psnr_rlt_avg = {} psnr_total_avg = 0. # Just calculate the final value of psnr_rlt(i.e. psnr_rlt[2]) for k, v in psnr_rlt[1].items(): psnr_rlt_avg[k] = sum(v) / len(v) psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt[1]) log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) print(log_s) ssim_rlt_avg = {} ssim_total_avg = 0. # Just calculate the final value of ssim_rlt(i.e. 
ssim_rlt[0], the bicubic results) for k, v in ssim_rlt[0].items(): ssim_rlt_avg[k] = sum(v) / len(v) ssim_total_avg += ssim_rlt_avg[k] ssim_total_avg /= len(ssim_rlt[0]) log_s = '# Validation # Bicubic SSIM: {:.4e}:'.format(ssim_total_avg) for k, v in ssim_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) print(log_s) ssim_rlt_avg = {} ssim_total_avg = 0. # Just calculate the final value of ssim_rlt(i.e. ssim_rlt[1]) for k, v in ssim_rlt[1].items(): ssim_rlt_avg[k] = sum(v) / len(v) ssim_total_avg += ssim_rlt_avg[k] ssim_total_avg /= len(ssim_rlt[1]) log_s = '# Validation # SSIM: {:.4e}:'.format(ssim_total_avg) for k, v in ssim_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) print(log_s) print('End of evaluation.')
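# A minimal sketch of the averaging performed above: per-frame PSNR/SSIM values are first
# averaged within each folder, then across folders, so clips with many frames do not dominate
# the reported mean.
def average_metric(per_folder):
    """per_folder maps a folder name to the list of its per-frame metric values."""
    folder_avg = {k: sum(v) / len(v) for k, v in per_folder.items()}
    total_avg = sum(folder_avg.values()) / len(folder_avg)
    return folder_avg, total_avg

# folder_avg, total = average_metric(psnr_rlt[1])  # e.g. the adapted-model PSNRs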
def vimeo90k(): '''create lmdb for the Vimeo90K dataset, each image with fixed size GT: [3, 256, 448] Only need the 4th frame currently, e.g., 00001_0001_4 LR: [3, 64, 112] With 1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7 key: Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001 ''' #### configurations mode = 'GT' # GT | LR if mode == 'GT': img_folder = '/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences' lmdb_save_path = '/home/xtwang/datasets/vimeo90k/vimeo90k_train_GT.lmdb' txt_file = '/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' H_dst, W_dst = 256, 448 elif mode == 'LR': img_folder = '/home/xtwang/datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences' lmdb_save_path = '/home/xtwang/datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb' txt_file = '/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' H_dst, W_dst = 64, 112 n_thread = 40 ######################################################## if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") #### whether the lmdb file exist if osp.exists(lmdb_save_path): print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) sys.exit(1) #### read all the image paths to a list print('Reading image path list ...') with open(txt_file) as f: train_l = f.readlines() train_l = [v.strip() for v in train_l] all_img_list = [] keys = [] for line in train_l: folder = line.split('/')[0] sub_folder = line.split('/')[1] file_l = glob.glob(osp.join(img_folder, folder, sub_folder) + '/*') all_img_list.extend(file_l) for j in range(7): keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1)) all_img_list = sorted(all_img_list) keys = sorted(keys) if mode == 'GT': # read the 4th frame only for GT mode print('Only keep the 4th frame.') all_img_list = [v for v in all_img_list if v.endswith('im4.png')] keys = [v for v in keys if v.endswith('_4')] #### read all images to memory (multiprocessing) dataset = {} # store all image data. list cannot keep the order, use dict print('Read images with multiprocessing, #thread: {} ...'.format(n_thread)) pbar = util.ProgressBar(len(all_img_list)) def mycallback(arg): '''get the image data and update pbar''' key = arg[0] dataset[key] = arg[1] pbar.update('Reading {}'.format(key)) pool = Pool(n_thread) for path, key in zip(all_img_list, keys): pool.apply_async(reading_image_worker, args=(path, key), callback=mycallback) pool.close() pool.join() print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list))) #### create lmdb environment data_size_per_img = dataset['00001_0001_4'].nbytes print('data size per image is: ', data_size_per_img) data_size = data_size_per_img * len(all_img_list) env = lmdb.open(lmdb_save_path, map_size=data_size * 10) #### write data to lmdb pbar = util.ProgressBar(len(all_img_list)) with env.begin(write=True) as txn: for key in keys: pbar.update('Write {}'.format(key)) key_byte = key.encode('ascii') data = dataset[key] H, W, C = data.shape # fixed shape assert H == H_dst and W == W_dst and C == 3, 'different shape.' 
txn.put(key_byte, data) print('Finish writing lmdb.') #### create meta information meta_info = {} if mode == 'GT': meta_info['name'] = 'Vimeo90K_train_GT' elif mode == 'LR': meta_info['name'] = 'Vimeo90K_train_LR' meta_info['resolution'] = '{}_{}_{}'.format(3, H_dst, W_dst) key_set = set() for key in keys: a, b, _ = key.split('_') key_set.add('{}_{}'.format(a, b)) meta_info['keys'] = key_set pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.')
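# A minimal sketch of the key pairing implied by the naming above: each clip key stored in
# meta_info ('00001_0001') corresponds to a single GT centre-frame key in the GT lmdb and seven
# frame keys in the LR lmdb, following vimeo90k()'s key layout.
def paired_keys(clip_key):
    gt_key = '{}_4'.format(clip_key)  # only im4 is kept for GT
    lr_keys = ['{}_{}'.format(clip_key, i) for i in range(1, 8)]
    return gt_key, lr_keys

# paired_keys('00001_0001') -> ('00001_0001_4', ['00001_0001_1', ..., '00001_0001_7'])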
def main(): #### options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=True) #pdb.set_trace() #### distributed training settings if args.launcher == 'none': # disabled distributed training opt['dist'] = False rank = -1 print('Disabled distributed training.') else: opt['dist'] = True init_dist(args.launcher) world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() #### loading resume state if exists if opt['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load( opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options else: resume_state = None #### mkdir and loggers if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) if resume_state is None: util.mkdir_and_rename( opt['path'] ['experiments_root']) # rename experiment folder if exists util.mkdirs( (path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]' .format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) else: util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) logger = logging.getLogger('base') # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) if rank <= 0: logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### create train and val dataloader #pdb.set_trace() dataset_ratio = 1000 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt['batch_size'])) total_iters = int(opt['train']['niter']) total_epochs = int(math.ceil(total_iters / train_size)) if opt['dist']: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) total_epochs = int( math.ceil(total_iters / (train_size * dataset_ratio))) else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) if rank <= 0: logger.info( 'Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None #### create model model = create_model(opt) #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 #### training logger.info('Start training from epoch: {:d}, iter: {:d}'.format( start_epoch, current_step)) for epoch in range(start_epoch, total_epochs + 1): if opt['dist']: train_sampler.set_epoch(epoch) for _, train_data in enumerate(train_loader): current_step += 1 if current_step > total_iters: break #### update learning rate model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) #### training model.feed_data(train_data) model.optimize_parameters(current_step) #### log if current_step % opt['logger']['print_freq'] == 0: logs = model.get_current_log() message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format( epoch, current_step) for v in model.get_current_learning_rate(): message += '{:.3e},'.format(v) message += ')] ' for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: if rank <= 0: tb_logger.add_scalar(k, v, current_step) if rank <= 0: logger.info(message) #### validation if opt['datasets'].get( 'val', None) and current_step % opt['train']['val_freq'] == 0: if rank <= 0: # does not support multi-GPU validation pbar = util.ProgressBar(len(val_loader)) avg_psnr = 0. 
idx = 0 for val_data in val_loader: idx += 1 img_name = os.path.splitext( os.path.basename(val_data['LQ_path'][0]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir) model.feed_data(val_data) model.test() visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals['SR']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # Save SR images for reference save_img_path = os.path.join( img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) util.save_img(sr_img, save_img_path) # calculate PSNR sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) avg_psnr += util.calculate_psnr(sr_img, gt_img) pbar.update('Test {}'.format(img_name)) avg_psnr = avg_psnr / idx # log logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr', avg_psnr, current_step) #### save models and training states if current_step % opt['logger']['save_checkpoint_freq'] == 0: if rank <= 0: logger.info('Saving models and training states.') model.save(current_step) model.save_training_state(epoch, current_step) if current_step % 50000 == 0: torch.cuda.empty_cache() torch.cuda.empty_cache() if rank <= 0: logger.info('Saving the final model.') model.save('latest') logger.info('End of training.')
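# A minimal sketch of the validation metric above: crop a border of `scale` pixels, then
# compute PSNR on uint8 images. This is plain numpy; the repo's util.crop_border and
# util.calculate_psnr helpers may differ in detail.
import numpy as np

def psnr_cropped(sr, gt, scale):
    sr = sr[scale:-scale, scale:-scale].astype(np.float64)  # drop the border affected by upscaling
    gt = gt[scale:-scale, scale:-scale].astype(np.float64)
    mse = np.mean((sr - gt) ** 2)
    if mse == 0:
        return float('inf')
    return 20.0 * np.log10(255.0 / np.sqrt(mse))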
def AI4K(model='gt'): model = model read_all_imgs = False BATCH = 700 if model == 'gt': img_folder = 'dataset/gt' lmdb_save_path = 'dataset/train_gt_wval.lmdb' H_dst, W_dst = 2160, 3840 if model == 'X4': img_folder = 'dataset/X4' lmdb_save_path = 'dataset/train_x4_wval.lmdb' H_dst, W_dst = 540, 960 n_thread = 40 ######################################################## if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") #### whether the lmdb file exist # if osp.exists(lmdb_save_path): # print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) # sys.exit(1) print('Reading image path list ...') all_clips_list = sorted(os.listdir(img_folder)) all_clips_list_path = [] for x in all_clips_list: all_clips_list_path.append(os.path.join(img_folder, x)) keys = [] all_imgs_path = [] index_clip = 0 for clips_path in all_clips_list_path: index_clip += 1 if model == 'X4': for imgs_x4_path in data_util._get_paths_from_images(clips_path): all_imgs_path.append(imgs_x4_path) for index_imgs_x4 in range(100): a = (index_imgs_x4 + 1) // 7 + 1 b = (index_imgs_x4 + 1) % 7 if b == 0: b = 7 c = '%.5d' % (index_clip) + '_' + '%.4d' % (a) + '_' + '%d' % ( b) keys.append(c) else: for index, imgs_path in enumerate( data_util._get_paths_from_images(clips_path)): if index % 7 == 3: all_imgs_path.append(imgs_path) for index_imgs_gt in range(100): if index_imgs_gt % 7 == 3: a = (index_imgs_gt + 1) // 7 + 1 c = '%.5d' % (index_clip) + '_' + '%.4d' % (a) + '_4' keys.append(c) data_size_per_img = cv2.imread(all_imgs_path[0], cv2.IMREAD_UNCHANGED).nbytes print('data size per image is: ', data_size_per_img) data_size = data_size_per_img * len(all_imgs_path) env = lmdb.open(lmdb_save_path, map_size=data_size * 10) pbar = util.ProgressBar(len(all_imgs_path)) txn = env.begin(write=True) idx = 1 for path, key in zip(all_imgs_path, keys): idx = idx + 1 pbar.update('Write {}'.format(key)) key_byte = key.encode('ascii') data = cv2.imread(path, cv2.IMREAD_UNCHANGED) H, W, C = data.shape # fixed shape assert H == H_dst and W == W_dst and C == 3, 'different shape.' txn.put(key_byte, data) if not read_all_imgs and idx % BATCH == 1: txn.commit() txn = env.begin(write=True) txn.commit() env.close() print('Finish writing lmdb.') #### create meta information meta_info = {} if model == 'gt': meta_info['name'] = 'AI4K_train_GT' elif model == 'X4': meta_info['name'] = 'AI4K_train_X4' meta_info['resolution'] = '{}_{}_{}'.format(3, H_dst, W_dst) key_set = set() for key in keys: a, b, _ = key.split('_') key_set.add('{}_{}'.format(a, b)) meta_info['keys'] = key_set pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.')
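# A minimal sketch of the AI4K key arithmetic above: the 100 frames of a clip are mapped onto
# septuplet-style keys '{clip:05d}_{group:04d}_{pos}', using the same formula as the loop
# (note that position 7 falls at the start of the next group, consistent with the GT branch,
# which keeps frames with index % 7 == 3 as '_4').
def ai4k_key(clip_index, frame_index):
    group = (frame_index + 1) // 7 + 1
    pos = (frame_index + 1) % 7 or 7
    return '%.5d_%.4d_%d' % (clip_index, group, pos)

# ai4k_key(1, 0) -> '00001_0001_1'; ai4k_key(1, 3) -> '00001_0001_4'; ai4k_key(1, 6) -> '00001_0002_7'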
def main(): #### options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=True) #### distributed training settings if args.launcher == 'none': # disabled distributed training opt['dist'] = False rank = -1 print('Disabled distributed training.') else: opt['dist'] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() #### loading resume state if exists if opt['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load( opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options else: resume_state = None #### mkdir and loggers if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) if resume_state is None: util.mkdir_and_rename( opt['path'] ['experiments_root']) # rename experiment folder if exists util.mkdirs( (path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key and 'wandb_load_run_path' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]' .format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) if opt['use_wandb_logger'] and 'debug' not in opt['name']: json_path = os.path.join(os.path.expanduser('~'), '.wandb_api_keys.json') if os.path.exists(json_path): with open(json_path, 'r') as j: json_file = json.loads(j.read()) os.environ['WANDB_API_KEY'] = json_file['ryul99'] wandb.init(project="mmsr", config=opt, sync_tensorboard=True) else: util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) logger = logging.getLogger('base') # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) if rank <= 0: logger.info('Random seed: {}'.format(seed)) if opt['use_wandb_logger'] and 'debug' not in opt['name']: wandb.config.update({'random_seed': seed}) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### create train and val dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt['batch_size'])) total_iters = int(opt['train']['niter']) total_epochs = int(math.ceil(total_iters / train_size)) if opt['dist']: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) total_epochs = int( math.ceil(total_iters / (train_size * dataset_ratio))) else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) if rank <= 0: logger.info( 'Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None #### create model model = create_model(opt) #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 #### training logger.info('Start training from epoch: {:d}, iter: {:d}'.format( start_epoch, current_step)) for epoch in range(start_epoch, total_epochs + 1): if opt['dist']: train_sampler.set_epoch(epoch) for _, train_data in enumerate(train_loader): current_step += 1 if current_step > total_iters: break #### update learning rate model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) #### training model.feed_data(train_data, noise_mode=opt['datasets']['train']['noise_mode'], noise_rate=opt['datasets']['train']['noise_rate']) model.optimize_parameters(current_step) #### log if current_step % opt['logger']['print_freq'] == 0: logs = model.get_current_log() message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format( epoch, current_step) for v in model.get_current_learning_rate(): message += '{:.3e},'.format(v) message += ')] ' for k, v in logs.items(): message += 
'{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: if rank <= 0: tb_logger.add_scalar(k, v, current_step) if opt['use_wandb_logger'] and 'debug' not in opt['name']: if rank <= 0: wandb.log({k: v}, step=current_step) if rank <= 0: logger.info(message) #### validation if opt['datasets'].get( 'val', None) and current_step % opt['train']['val_freq'] == 0: if opt['model'] in [ 'sr', 'srgan' ] and rank <= 0: # image restoration validation # does not support multi-GPU validation pbar = util.ProgressBar(len(val_loader)) avg_psnr = 0. idx = 0 for val_data in val_loader: idx += 1 img_name = os.path.splitext( os.path.basename(val_data['LQ_path'][0]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir) model.feed_data( val_data, noise_mode=opt['datasets']['val']['noise_mode'], noise_rate=opt['datasets']['val']['noise_rate']) model.test() visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # Save SR images for reference save_img_path = os.path.join( img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) util.save_img(sr_img, save_img_path) # calculate PSNR sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) avg_psnr += util.calculate_psnr(sr_img, gt_img) pbar.update('Test {}'.format(img_name)) avg_psnr = avg_psnr / idx # log logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr', avg_psnr, current_step) if opt['use_wandb_logger'] and 'debug' not in opt['name']: wandb.log({'psnr': avg_psnr}, step=current_step) else: # video restoration validation if opt['dist']: # multi-GPU testing psnr_rlt = {} # with border and center frames if rank == 0: pbar = util.ProgressBar(len(val_set)) for idx in range(rank, len(val_set), world_size): val_data = val_set[idx] val_data['LQs'].unsqueeze_(0) val_data['GT'].unsqueeze_(0) folder = val_data['folder'] idx_d, max_idx = val_data['idx'].split('/') idx_d, max_idx = int(idx_d), int(max_idx) if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = torch.zeros( max_idx, dtype=torch.float32, device='cuda') # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda') model.feed_data(val_data, noise_mode=opt['datasets']['val'] ['noise_mode'], noise_rate=opt['datasets']['val'] ['noise_rate']) model.test() visuals = model.get_current_visuals() rlt_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # calculate PSNR psnr_rlt[folder][idx_d] = util.calculate_psnr( rlt_img, gt_img) if rank == 0: for _ in range(world_size): pbar.update('Test {} - {}/{}'.format( folder, idx_d, max_idx)) # # collect data for _, v in psnr_rlt.items(): dist.reduce(v, 0) dist.barrier() if rank == 0: psnr_rlt_avg = {} psnr_total_avg = 0. 
for k, v in psnr_rlt.items(): psnr_rlt_avg[k] = torch.mean(v).cpu().item() psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = '# Validation # PSNR: {:.4e}:'.format( psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) if opt['use_tb_logger'] and 'debug' not in opt[ 'name']: tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) if opt['use_wandb_logger'] and 'debug' not in opt[ 'name']: lq_img, rlt_img, gt_img = map( util.tensor2img, [ visuals['LQ'], visuals['rlt'], visuals['GT'] ]) wandb.log({'psnr_avg': psnr_total_avg}, step=current_step) wandb.log(psnr_rlt_avg, step=current_step) wandb.log( { 'Validation Image': [ wandb.Image(lq_img[:, :, [2, 1, 0]], caption='LQ'), wandb.Image(rlt_img[:, :, [2, 1, 0]], caption='output'), wandb.Image(gt_img[:, :, [2, 1, 0]], caption='GT'), ] }, step=current_step) else: pbar = util.ProgressBar(len(val_loader)) psnr_rlt = {} # with border and center frames psnr_rlt_avg = {} psnr_total_avg = 0. for val_data in val_loader: folder = val_data['folder'][0] idx_d = val_data['idx'].item() # border = val_data['border'].item() if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = [] model.feed_data(val_data, noise_mode=opt['datasets']['val'] ['noise_mode'], noise_rate=opt['datasets']['val'] ['noise_rate']) model.test() visuals = model.get_current_visuals() rlt_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # calculate PSNR psnr = util.calculate_psnr(rlt_img, gt_img) psnr_rlt[folder].append(psnr) pbar.update('Test {} - {}'.format(folder, idx_d)) for k, v in psnr_rlt.items(): psnr_rlt_avg[k] = sum(v) / len(v) psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = '# Validation # PSNR: {:.4e}:'.format( psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) if opt['use_wandb_logger'] and 'debug' not in opt[ 'name']: lq_img, rlt_img, gt_img = map( util.tensor2img, [visuals['LQ'], visuals['rlt'], visuals['GT']]) wandb.log({'psnr_avg': psnr_total_avg}, step=current_step) wandb.log(psnr_rlt_avg, step=current_step) wandb.log( { 'Validation Image': [ wandb.Image(lq_img[:, :, [2, 1, 0]], caption='LQ'), wandb.Image(rlt_img[:, :, [2, 1, 0]], caption='output'), wandb.Image(gt_img[:, :, [2, 1, 0]], caption='GT'), ] }, step=current_step) #### save models and training states if current_step % opt['logger']['save_checkpoint_freq'] == 0: if rank <= 0: logger.info('Saving models and training states.') model.save(current_step) model.save_training_state(epoch, current_step) if rank <= 0: logger.info('Saving the final model.') model.save('latest') logger.info('End of training.') if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.close()
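# A minimal sketch of the multi-GPU validation gather above: each rank evaluates a round-robin
# split of the frames, writes its results into a zero-initialised tensor, and a sum-reduce
# collects everything on rank 0 before averaging. Assumes torch.distributed is already
# initialised with a CUDA-capable backend; names are illustrative.
import torch
import torch.distributed as dist

def gather_avg_psnr(local_psnr, num_frames, rank, world_size):
    """local_psnr maps a frame index to the PSNR this rank computed for it."""
    buf = torch.zeros(num_frames, dtype=torch.float32, device='cuda')
    for idx in range(rank, num_frames, world_size):  # round-robin frame split
        buf[idx] = local_psnr[idx]
    dist.reduce(buf, dst=0)  # element-wise sum over ranks, result lands on rank 0
    dist.barrier()
    return buf.mean().item() if rank == 0 else None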
def REDS(): """create lmdb for the REDS dataset, each image with fixed size GT: [3, 720, 1280], key: 000_00000000 LR: [3, 180, 320], key: 000_00000000 key: 000_00000000 """ #### configurations mode = "train_sharp" read_all_imgs = ( False ) # whether real all images to the memory. Set False with limited memory BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False # train_sharp | train_sharp_bicubic | train_blur_bicubic| train_blur | train_blur_comp if mode == "train_sharp": img_folder = "/home/xtwang/datasets/REDS/train_sharp" lmdb_save_path = "/home/xtwang/datasets/REDS/train_sharp_wval.lmdb" H_dst, W_dst = 720, 1280 elif mode == "train_sharp_bicubic": img_folder = "/home/xtwang/datasets/REDS/train_sharp_bicubic" lmdb_save_path = "/home/xtwang/datasets/REDS/train_sharp_bicubic_wval.lmdb" H_dst, W_dst = 180, 320 elif mode == "train_blur_bicubic": img_folder = "/home/xtwang/datasets/REDS/train_blur_bicubic" lmdb_save_path = "/home/xtwang/datasets/REDS/train_blur_bicubic_wval.lmdb" H_dst, W_dst = 180, 320 elif mode == "train_blur": img_folder = "/home/xtwang/datasets/REDS/train_blur" lmdb_save_path = "/home/xtwang/datasets/REDS/train_blur_wval.lmdb" H_dst, W_dst = 720, 1280 elif mode == "train_blur_comp": img_folder = "/home/xtwang/datasets/REDS/train_blur_comp" lmdb_save_path = "/home/xtwang/datasets/REDS/train_blur_comp_wval.lmdb" H_dst, W_dst = 720, 1280 n_thread = 40 ######################################################## if not lmdb_save_path.endswith(".lmdb"): raise ValueError("lmdb_save_path must end with 'lmdb'.") #### whether the lmdb file exist if osp.exists(lmdb_save_path): print("Folder [{:s}] already exists. Exit...".format(lmdb_save_path)) sys.exit(1) #### read all the image paths to a list print("Reading image path list ...") all_img_list = data_util._get_paths_from_images(img_folder) keys = [] for img_path in all_img_list: split_rlt = img_path.split("/") a = split_rlt[-2] b = split_rlt[-1].split(".png")[0] keys.append(a + "_" + b) if read_all_imgs: #### read all images to memory (multiprocessing) dataset = { } # store all image data. list cannot keep the order, use dict print("Read images with multiprocessing, #thread: {} ...".format( n_thread)) pbar = util.ProgressBar(len(all_img_list)) def mycallback(arg): """get the image data and update pbar""" key = arg[0] dataset[key] = arg[1] pbar.update("Reading {}".format(key)) pool = Pool(n_thread) for path, key in zip(all_img_list, keys): pool.apply_async(reading_image_worker, args=(path, key), callback=mycallback) pool.close() pool.join() print("Finish reading {} images.\nWrite lmdb...".format( len(all_img_list))) #### create lmdb environment data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes if "flow" in mode: data_size_per_img = dataset["000_00000002_n1"].nbytes print("data size per image is: ", data_size_per_img) data_size = data_size_per_img * len(all_img_list) env = lmdb.open(lmdb_save_path, map_size=data_size * 10) #### write data to lmdb pbar = util.ProgressBar(len(all_img_list)) txn = env.begin(write=True) idx = 1 for path, key in zip(all_img_list, keys): idx = idx + 1 pbar.update("Write {}".format(key)) key_byte = key.encode("ascii") data = dataset[key] if read_all_imgs else cv2.imread( path, cv2.IMREAD_UNCHANGED) if "flow" in mode: H, W = data.shape assert H == H_dst and W == W_dst, "different shape." else: H, W, C = data.shape # fixed shape assert H == H_dst and W == W_dst and C == 3, "different shape." 
txn.put(key_byte, data) if not read_all_imgs and idx % BATCH == 1: txn.commit() txn = env.begin(write=True) txn.commit() env.close() print("Finish writing lmdb.") #### create meta information meta_info = {} meta_info["name"] = "REDS_{}_wval".format(mode) if "flow" in mode: meta_info["resolution"] = "{}_{}_{}".format(1, H_dst, W_dst) else: meta_info["resolution"] = "{}_{}_{}".format(3, H_dst, W_dst) meta_info["keys"] = keys pickle.dump(meta_info, open(osp.join(lmdb_save_path, "meta_info.pkl"), "wb")) print("Finish creating lmdb meta info.")
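# A minimal sketch of the batched-commit pattern shared by these writers: when read_all_imgs is
# False each image is read lazily from disk and the transaction is committed every `batch` puts,
# keeping memory use and transaction size bounded. Paths and keys here are placeholders.
import cv2
import lmdb

def write_images(lmdb_path, paths, keys, map_size, batch=5000):
    env = lmdb.open(lmdb_path, map_size=map_size)
    txn = env.begin(write=True)
    for idx, (path, key) in enumerate(zip(paths, keys), start=1):
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        txn.put(key.encode('ascii'), img)  # the ndarray exposes the buffer protocol
        if idx % batch == 0:  # flush the current transaction
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()  # commit the remainder
    env.close()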
def HDR(mode): """Create lmdb for the REDS dataset, each image with a fixed size GT: [3, 720, 1280], key: 000_00000000 LR: [3, 180, 320], key: 000_00000000 key: 000_00000000 flow: downsampled flow: [3, 360, 320], keys: 000_00000005_[p2, p1, n1, n2] Each flow is calculated with the GT images by PWCNet and then downsampled by 1/4 Flow map is quantized by mmcv and saved in png format """ #### configurations read_all_imgs = False # whether real all images to memory with multiprocessing # Set False for use limited memory BATCH = args.batch # After BATCH images, lmdb commits, if read_all_imgs = False if mode == 'train_sharp': # img_folder = '../../datasets/REDS/train_sharp' # lmdb_save_path = '../../datasets/REDS/train_sharp_wval.lmdb' img_folder = '/DATA/wangshen_data/REDS/train_sharp' lmdb_save_path = '/DATA/wangshen_data/REDS/train_sharp_wval.lmdb' H_dst, W_dst = 720, 1280 elif mode == 'train_sharp_bicubic': # img_folder = '../../datasets/REDS/train_sharp_bicubic' # lmdb_save_path = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb' img_folder = '/DATA/wangshen_data/REDS/train_sharp_bicubic' lmdb_save_path = '/DATA/wangshen_data/REDS/train_sharp_bicubic_wval.lmdb' H_dst, W_dst = 180, 320 elif mode == 'train_blur_bicubic': img_folder = '../../datasets/REDS/train_blur_bicubic' lmdb_save_path = '../../datasets/REDS/train_blur_bicubic_wval.lmdb' H_dst, W_dst = 180, 320 elif mode == 'train_blur': img_folder = '../../datasets/REDS/train_blur' lmdb_save_path = '../../datasets/REDS/train_blur_wval.lmdb' H_dst, W_dst = 720, 1280 elif mode == 'train_blur_comp': img_folder = '../../datasets/REDS/train_blur_comp' lmdb_save_path = '../../datasets/REDS/train_blur_comp_wval.lmdb' H_dst, W_dst = 720, 1280 elif mode == 'train_sharp_flowx4': img_folder = '../../datasets/REDS/train_sharp_flowx4' lmdb_save_path = '../../datasets/REDS/train_sharp_flowx4.lmdb' H_dst, W_dst = 360, 320 elif mode == 'train_540p': img_folder = "/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/train_540p" lmdb_save_path = '/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/{}.lmdb'.format( args.name) H_dst, W_dst = 540, 960 elif mode == 'train_4k': img_folder = "/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/train_4k" lmdb_save_path = '/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/{}.lmdb'.format( args.name) H_dst, W_dst = 540, 960 H_dst, W_dst = 2160, 3840 elif mode == 'both': img_folder_S = "/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/train_540p" lmdb_save_path_S = '/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/{}_540p.lmdb'.format( args.name) H_dst_S, W_dst_S = 256, 256 img_folder_L = "/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/train_4k" lmdb_save_path_L = '/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/{}_4k.lmdb'.format( args.name) H_dst_L, W_dst_L = 1024, 1024 assert mode == 'both' N = 8 # divide one 4k into 8 parts n_thread = 40 ######################################################## import os import shutil if os.path.exists(lmdb_save_path_S): shutil.rmtree(lmdb_save_path_S) if os.path.exists(lmdb_save_path_L): shutil.rmtree(lmdb_save_path_L) if not lmdb_save_path_S.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") if osp.exists(lmdb_save_path_S): print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path_S)) sys.exit(1) if not lmdb_save_path_L.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") if osp.exists(lmdb_save_path_L): print('Folder [{:s}] already exists. 
Exit...'.format(lmdb_save_path_L)) sys.exit(1) #### read all the image paths to a list print('Reading image path list ...') all_img_list_S = data_util._get_paths_from_images_suzhou( img_folder_S, args.small) all_img_list_L = data_util._get_paths_from_images_suzhou( img_folder_L, args.small) keys_S = [] keys_L = [] for img_path in all_img_list_S: split_rlt = img_path.split('/') folder = split_rlt[-2] img_name = split_rlt[-1].split('.png')[0] keys_S.append(folder + '_' + img_name) for img_path in all_img_list_L: split_rlt = img_path.split('/') folder = split_rlt[-2] img_name = split_rlt[-1].split('.png')[0] keys_L.append(folder + '_' + img_name) assert keys_S == keys_L keys_patch = [] for img_path in all_img_list_L: split_rlt = img_path.split('/') folder = split_rlt[-2] img_name = split_rlt[-1].split('.png')[0] for i in range(N): keys_patch.append(folder + '_' + img_name + '_' + str(i)) keys = keys_S if read_all_imgs: # todo never use that # read all images to memory (multiprocessing) dataset = { } # store all image data. list cannot keep the order, use dict print('Read images with multiprocessing, #thread: {} ...'.format( n_thread)) pbar = util.ProgressBar(len(all_img_list_S)) def mycallback(arg): '''get the image data and update pbar''' key = arg[0] dataset[key] = arg[1] pbar.update('Reading {}'.format(key)) pool = Pool(n_thread) for path, key in zip(all_img_list_S, keys_S): pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) pool.close() pool.join() print('Finish reading {} images.\nWrite lmdb...'.format( len(all_img_list_S))) #### create lmdb environment # for small pic data_size_per_img = cv2.imread(all_img_list_S[0], cv2.IMREAD_UNCHANGED).nbytes print('data size per small image is: ', data_size_per_img) data_size = data_size_per_img * len(all_img_list_S) # env_S = lmdb.open(lmdb_save_path_S, map_size=data_size * 10) # for large pic data_size_per_img = cv2.imread(all_img_list_L[0], cv2.IMREAD_UNCHANGED).nbytes print('data size per large image is: ', data_size_per_img) data_size = data_size_per_img * len(all_img_list_S) # env_L = lmdb.open(lmdb_save_path_L, map_size=data_size * 10) #### write data to lmdb pbar = util.ProgressBar(len(all_img_list_S)) # txn_S = env_S.begin(write=True) # txn_L = env_L.begin(write=True) for idx_all, (path_S, path_L, key) in enumerate(zip(all_img_list_S, all_img_list_L, keys)): # pbar.update('Write {}'.format(key)) data_S = dataset[key] if read_all_imgs else cv2.imread( path_S, cv2.IMREAD_UNCHANGED) # shape H W C ndarray data_L = dataset[key] if read_all_imgs else cv2.imread( path_L, cv2.IMREAD_UNCHANGED) # process the black blank H_S = data_S.shape[0] # 540 W_S = data_S.shape[1] # 960 H_L = data_L.shape[0] W_L = data_L.shape[1] blank_1_S = 0 blank_2_S = 0 for i in range(H_S): if not sum(data_S[:, :, 0][i]) == 0: blank_1_S = i - 1 # assert not sum(data_S[:, :, 0][i+1]) == 0 break for i in range(H_S): if not sum(data_S[:, :, 0][H_S - i - 1]) == 0: blank_2_S = (H_S - 1) - i - 1 # assert not sum(data_S[:, :, 0][blank_2_S-1]) == 0 break print('LQ :', blank_1_S, blank_2_S) if blank_1_S == -1: print('LQ has no blank') blank_1_L = 0 blank_2_L = 0 for i in range(H_L): if not sum(data_L[:, :, 0][i]) == 0: blank_1_L = i - 1 # assert not sum(data_L[:, :, 0][i + 1]) == 0 break for i in range(H_L): if not sum(data_L[:, :, 0][H_L - i - 1]) == 0: blank_2_L = (H_L - 1) - i - 1 # assert not sum(data_L[:, :, 0][blank_2_L - 1]) == 0 break print('GT :', blank_1_L, blank_2_L) if blank_1_L == -1 and blank_2_L == H_L - 2: print('No blank', key) U_L1 = 56 D_L2 
= 2104 else: U_L1 = ((blank_1_L >> 2) + 1) << 2 D_L2 = ((blank_2_L >> 2) - 1) << 2 print('Content:', U_L1, D_L2) # crop into eight patches U_L1_d = U_L1 + H_dst_L D_L2_u = D_L2 - H_dst_L assert U_L1_d <= H_L assert D_L2_u >= 0 # for idx in range(N): # crop eight part H_list = [U_L1, D_L2_u] W_list = [0, 1024, 2048, 2816] h_list = [U_L1 // 4, D_L2_u // 4] w_list = [0, 256, 512, 704] for h_idx, _ in enumerate(H_list): for w_idx, _ in enumerate(W_list): key_idx = key + '_' + str(4 * h_idx + w_idx) print(key_idx) key_byte = key_idx.encode('ascii') data_gt = data_L[H_list[h_idx]:(H_list[h_idx] + 1024), W_list[w_idx]:(W_list[w_idx] + 1024), :] data_lq = data_S[h_list[h_idx]:(h_list[h_idx] + 256), w_list[w_idx]:(w_list[w_idx] + 256), :] # txn_L.put(key_byte, data_gt.copy(order='C')) # txn_S.put(key_byte, data_lq.copy(order='C')) # if not read_all_imgs and idx_all % BATCH == 0: # txn_L.commit() # txn_S.commit() # txn_L = env_L.begin(write=True) # txn_S = env_S.begin(write=True) # txn_L.commit() # txn_S.commit() # env_L.close() # env_S.close() print('Finish writing lmdb.') #### create meta information meta_info = {} meta_info['name'] = 'HDR_{}_wval'.format(mode) channel = 1 if 'flow' in mode else 3 meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst_L, W_dst_L) meta_info['keys'] = keys_patch pickle.dump(meta_info, open(osp.join(lmdb_save_path_L, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.') # for LQ meta_info = {} meta_info['name'] = 'HDR_{}_wval'.format(mode) channel = 1 if 'flow' in mode else 3 meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst_S, W_dst_S) meta_info['keys'] = keys_patch pickle.dump(meta_info, open(osp.join(lmdb_save_path_S, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.')
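
# In the 'both' branch above, each 4K GT frame is cut into 2 x 4 = 8 patches of
# 1024 x 1024 (rows anchored at U_L1 and D_L2 - 1024, columns at 0/1024/2048/2816)
# and each 540p LQ frame into matching 256 x 256 patches at the same offsets divided
# by 4; patch i is keyed '{clip}_{frame}_{i}' with i = 4 * row + col. The helper
# below is only a sketch that reproduces this index-to-window mapping; U_L1 and D_L2
# must come from the same blank-row detection as above.


def hdr_patch_windows_example(U_L1, D_L2, gt_patch=1024, lq_patch=256):
    """Sketch: map patch index i -> (GT crop window, LQ crop window)."""
    H_list = [U_L1, D_L2 - gt_patch]   # GT row offsets (top rows of the two bands)
    W_list = [0, 1024, 2048, 2816]     # GT column offsets
    windows = {}
    for h_idx, top in enumerate(H_list):
        for w_idx, left in enumerate(W_list):
            i = 4 * h_idx + w_idx
            gt_win = (top, top + gt_patch, left, left + gt_patch)
            lq_win = (top // 4, top // 4 + lq_patch, left // 4, left // 4 + lq_patch)
            windows[i] = (gt_win, lq_win)
    return windows
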
def general_image_folder(opt): """Create lmdb for general image folders Users should define the keys, such as: '0321_s035' for DIV2K sub-images If all the images have the same resolution, it will only store one copy of resolution info. Otherwise, it will store every resolution info. """ #### configurations read_all_imgs = False # whether real all images to memory with multiprocessing # Set False for use limited memory BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False n_thread = 40 ######################################################## img_folder = opt['img_folder'] lmdb_save_path = opt['lmdb_save_path'] meta_info = {'name': opt['name']} if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") if osp.exists(lmdb_save_path): print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) sys.exit(1) #### read all the image paths to a list print('Reading image path list ...') all_img_list = sorted(glob.glob(osp.join(img_folder, '*'))) keys = [] for img_path in all_img_list: keys.append(osp.splitext(osp.basename(img_path))[0]) if read_all_imgs: #### read all images to memory (multiprocessing) dataset = { } # store all image data. list cannot keep the order, use dict print('Read images with multiprocessing, #thread: {} ...'.format( n_thread)) pbar = util.ProgressBar(len(all_img_list)) def mycallback(arg): '''get the image data and update pbar''' key = arg[0] dataset[key] = arg[1] pbar.update('Reading {}'.format(key)) pool = Pool(n_thread) for path, key in zip(all_img_list, keys): pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) pool.close() pool.join() print('Finish reading {} images.\nWrite lmdb...'.format( len(all_img_list))) #### create lmdb environment data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes print('data size per image is: ', data_size_per_img) data_size = data_size_per_img * len(all_img_list) env = lmdb.open(lmdb_save_path, map_size=data_size * 10) #### write data to lmdb pbar = util.ProgressBar(len(all_img_list)) txn = env.begin(write=True) resolutions = [] for idx, (path, key) in enumerate(zip(all_img_list, keys)): pbar.update('Write {}'.format(key)) key_byte = key.encode('ascii') data = dataset[key] if read_all_imgs else cv2.imread( path, cv2.IMREAD_UNCHANGED) if data.ndim == 2: H, W = data.shape C = 1 else: H, W, C = data.shape txn.put(key_byte, data) resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W)) if not read_all_imgs and idx % BATCH == 0: txn.commit() txn = env.begin(write=True) txn.commit() env.close() print('Finish writing lmdb.') #### create meta information # check whether all the images are the same size assert len(keys) == len(resolutions) if len(set(resolutions)) <= 1: meta_info['resolution'] = [resolutions[0]] meta_info['keys'] = keys print('All images have the same resolution. Simplify the meta info.') else: meta_info['resolution'] = resolutions meta_info['keys'] = keys print( 'Not all images have the same resolution. Save meta info for each image.' ) pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.')
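
# general_image_folder() only needs 'img_folder', 'lmdb_save_path' and 'name' in its
# opt dict; keys are derived from the file names. The paths below are placeholders
# for a DIV2K sub-image run like the one mentioned in the docstring.


def general_image_folder_example():
    """Sketch: build an lmdb from a flat image folder."""
    general_image_folder({
        'img_folder': '../../datasets/DIV2K/DIV2K800_sub',
        'lmdb_save_path': '../../datasets/DIV2K/DIV2K800_sub.lmdb',
        'name': 'DIV2K800_sub_GT',
    })
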
def main(): # options parser = argparse.ArgumentParser() parser.add_argument("-opt", type=str, help="Path to option YAML file.") parser.add_argument("--launcher", choices=["none", "pytorch"], default="none", help="job launcher") parser.add_argument("--local_rank", type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=True) # distributed training settings if args.launcher == "none": # disabled distributed training opt["dist"] = False rank = -1 print("Disabled distributed training.") else: opt["dist"] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() # loading resume state if exists if opt["path"].get("resume_state", None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load( opt["path"]["resume_state"], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state["iter"]) # check resume options else: resume_state = None # mkdir and loggers if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) if resume_state is None: util.mkdir_and_rename( opt["path"] ["experiments_root"]) # rename experiment folder if exists util.mkdirs( (path for key, path in opt["path"].items() if not key == "experiments_root" and "pretrain_model" not in key and "resume" not in key)) # config loggers. Before it, the log will not work util.setup_logger("base", opt["path"]["log"], "train_" + opt["name"], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger("base") logger.info(option.dict2str(opt)) # tensorboard logger if opt["use_tb_logger"] and "debug" not in opt["name"]: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info("You are using PyTorch {}. 
\ Tensorboard will use [tensorboardX]".format( version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir="../tb_logger/" + opt["name"]) else: util.setup_logger("base", opt["path"]["log"], "train", level=logging.INFO, screen=True) logger = logging.getLogger("base") # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) # random seed seed = opt["train"]["manual_seed"] if seed is None: seed = random.randint(1, 10000) if rank <= 0: logger.info("Random seed: {}".format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True # create train and val dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt["datasets"].items(): if phase == "train": train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt["batch_size"])) total_iters = int(opt["train"]["niter"]) total_epochs = int(math.ceil(total_iters / train_size)) if opt["dist"]: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) total_epochs = int( math.ceil(total_iters / (train_size * dataset_ratio))) else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) if rank <= 0: logger.info( "Number of train images: {:,d}, iters: {:,d}".format( len(train_set), train_size)) logger.info("Total epochs needed: {:d} for iters {:,d}".format( total_epochs, total_iters)) elif phase == "val": val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info("Number of val images in [{:s}]: {:d}".format( dataset_opt["name"], len(val_set))) else: raise NotImplementedError( "Phase [{:s}] is not recognized.".format(phase)) assert train_loader is not None # create model model = create_model(opt) print("Model created!") # resume training if resume_state: logger.info("Resuming training from epoch: {}, iter: {}.".format( resume_state["epoch"], resume_state["iter"])) start_epoch = resume_state["epoch"] current_step = resume_state["iter"] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 # training logger.info("Start training from epoch: {:d}, iter: {:d}".format( start_epoch, current_step)) for epoch in range(start_epoch, total_epochs + 1): if opt["dist"]: train_sampler.set_epoch(epoch) for _, train_data in enumerate(train_loader): current_step += 1 if current_step > total_iters: break # update learning rate model.update_learning_rate(current_step, warmup_iter=opt["train"]["warmup_iter"]) # training model.feed_data(train_data) model.optimize_parameters(current_step) # log if current_step % opt["logger"]["print_freq"] == 0: logs = model.get_current_log() message = "[epoch:{:3d}, iter:{:8,d}, lr:(".format( epoch, current_step) for v in model.get_current_learning_rate(): message += "{:.3e},".format(v) message += ")] " for k, v in logs.items(): message += "{:s}: {:.4e} ".format(k, v) # tensorboard logger if opt["use_tb_logger"] and "debug" not in opt["name"]: if rank <= 0: tb_logger.add_scalar(k, v, current_step) if rank <= 0: logger.info(message) # validation if opt["datasets"].get( "val", None) and current_step % opt["train"]["val_freq"] == 0: # image restoration validation if opt["model"] in ["sr", "srgan"] and rank <= 0: # does not support multi-GPU validation pbar = util.ProgressBar(len(val_loader)) avg_psnr = 0.0 idx = 0 for val_data in val_loader: idx += 1 img_name = 
os.path.splitext( os.path.basename(val_data["LQ_path"][0]))[0] img_dir = os.path.join(opt["path"]["val_images"], img_name) util.mkdir(img_dir) model.feed_data(val_data) model.test() visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals["rlt"]) # uint8 gt_img = util.tensor2img(visuals["GT"]) # uint8 # Save SR images for reference save_img_path = os.path.join( img_dir, "{:s}_{:d}.png".format(img_name, current_step)) util.save_img(sr_img, save_img_path) # calculate PSNR sr_img, gt_img = util.crop_border([sr_img, gt_img], opt["scale"]) avg_psnr += util.calculate_psnr(sr_img, gt_img) pbar.update("Test {}".format(img_name)) avg_psnr = avg_psnr / idx # log logger.info("# Validation # PSNR: {:.4e}".format(avg_psnr)) # tensorboard logger if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger.add_scalar("psnr", avg_psnr, current_step) else: # video restoration validation if opt["dist"]: # multi-GPU testing psnr_rlt = {} # with border and center frames if rank == 0: pbar = util.ProgressBar(len(val_set)) for idx in range(rank, len(val_set), world_size): val_data = val_set[idx] val_data["LQs"].unsqueeze_(0) val_data["GT"].unsqueeze_(0) folder = val_data["folder"] idx_d, max_idx = val_data["idx"].split("/") idx_d, max_idx = int(idx_d), int(max_idx) if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = torch.zeros( max_idx, dtype=torch.float32, device="cuda") model.feed_data(val_data) model.test() visuals = model.get_current_visuals() rlt_img = util.tensor2img(visuals["rlt"]) # uint8 gt_img = util.tensor2img(visuals["GT"]) # uint8 # calculate PSNR psnr_rlt[folder][idx_d] = util.calculate_psnr( rlt_img, gt_img) if rank == 0: for _ in range(world_size): pbar.update("Test {} - {}/{}".format( folder, idx_d, max_idx)) # collect data for _, v in psnr_rlt.items(): dist.reduce(v, 0) dist.barrier() if rank == 0: psnr_rlt_avg = {} psnr_total_avg = 0.0 for k, v in psnr_rlt.items(): psnr_rlt_avg[k] = torch.mean(v).cpu().item() psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = "# Validation # PSNR: {:.4e}:".format( psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += " {}: {:.4e}".format(k, v) logger.info(log_s) if opt["use_tb_logger"] and "debug" not in opt[ "name"]: tb_logger.add_scalar("psnr_avg", psnr_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) else: pbar = util.ProgressBar(len(val_loader)) psnr_rlt = {} # with border and center frames psnr_rlt_avg = {} psnr_total_avg = 0.0 for val_data in val_loader: folder = val_data["folder"][0] idx_d, max_id = val_data["idx"][0].split("/") # border = val_data['border'].item() if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = [] model.feed_data(val_data) model.test() visuals = model.get_current_visuals() rlt_img = util.tensor2img(visuals["rlt"]) # uint8 gt_img = util.tensor2img(visuals["GT"]) # uint8 lq_img = util.tensor2img(visuals["LQ"][2]) # uint8 img_dir = opt["path"]["val_images"] util.mkdir(img_dir) save_img_path = os.path.join( img_dir, "{}.png".format(idx_d)) util.save_img(np.hstack((lq_img, rlt_img, gt_img)), save_img_path) # calculate PSNR psnr = util.calculate_psnr(rlt_img, gt_img) psnr_rlt[folder].append(psnr) pbar.update("Test {} - {}".format(folder, idx_d)) for k, v in psnr_rlt.items(): psnr_rlt_avg[k] = sum(v) / len(v) psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = "# Validation # PSNR: {:.4e}:".format( psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += " {}: {:.4e}".format(k, v) logger.info(log_s) if 
opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger.add_scalar("psnr_avg", psnr_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) # save models and training states if current_step % opt["logger"]["save_checkpoint_freq"] == 0: if rank <= 0: logger.info("Saving models and training states.") model.save(current_step) model.save_training_state(epoch, current_step) if rank <= 0: logger.info("Saving the final model.") model.save("latest") logger.info("End of training.") tb_logger.close()
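
# The validation loop above relies on util.crop_border() and util.calculate_psnr().
# The stand-alone versions below are only a sketch of the usual definitions these
# helpers are assumed to implement (crop `scale` pixels from every border, PSNR on
# 8-bit images); the project's own util module is the authoritative implementation.


def crop_border_example(img_list, crop_border):
    """Sketch: crop `crop_border` pixels from each spatial border of every image."""
    if crop_border == 0:
        return img_list
    return [img[crop_border:-crop_border, crop_border:-crop_border, ...]
            for img in img_list]


def calculate_psnr_example(img1, img2):
    """Sketch: PSNR between two uint8 images of identical shape."""
    import numpy as np

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(255.0 / np.sqrt(mse))
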
def REDS(mode = 'train_sharp', overwrite = True): '''create lmdb for the REDS dataset, each image with fixed size GT: [3, 720, 1280], key: 000_00000000 LR: [3, 180, 320], key: 000_00000000 key: 000_00000000 ''' #### configurations read_all_imgs = False # whether real all images to the memory. Set False with limited memory BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False # train_sharp | train_sharp_bicubic | train_blur_bicubic| train_blur | train_blur_comp if mode == 'train_sharp': img_folder = osp.join(root,'datasets/REDS/train/sharp') lmdb_save_path = osp.join(root,'datasets/REDS/train/sharp_wval.lmdb') H_dst, W_dst = 720, 1280 elif mode == 'train_sharp_bicubic': img_folder = osp.join(root,'datasets/REDS/train/sharp_bicubic') lmdb_save_path = osp.join(root,'datasets/REDS/train/sharp_bicubic_wval.lmdb') H_dst, W_dst = 180, 320 elif mode == 'train_blur_bicubic': img_folder = osp.join(root,'datasets/REDS/train/blur_bicubic') lmdb_save_path = osp.join(root,'datasets/REDS/train/blur_bicubic_wval.lmdb') H_dst, W_dst = 180, 320 elif mode == 'train_blur': img_folder = osp.join(root,'datasets/REDS/train/blur') lmdb_save_path = osp.join(root,'datasets/REDS/train/blur_wval.lmdb') H_dst, W_dst = 720, 1280 elif mode == 'train_blur_comp': img_folder = osp.join(root,'datasets/REDS/train/blur_comp') lmdb_save_path = osp.join(root,'datasets/REDS/train/blur_comp_wval.lmdb') H_dst, W_dst = 720, 1280 n_thread = 40 ######################################################## if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") #### whether the lmdb file exist if not overwrite and osp.exists(lmdb_save_path): print(f'Folder [{lmdb_save_path}] already exists. Exit...') sys.exit(1) #### read all the image paths to a list print('Reading image path list ...') all_img_list = data_util._get_paths_from_images(img_folder) keys = [] for img_path in all_img_list: split_rlt = img_path.split('/') a = split_rlt[-2] b = split_rlt[-1].split('.png')[0] keys.append(a + '_' + b) if read_all_imgs: #### read all images to memory (multiprocessing) dataset = {} # store all image data. list cannot keep the order, use dict print(f'Read images with multiprocessing, #thread: {n_thread} ...') pbar = util.ProgressBar(len(all_img_list)) def mycallback(arg): '''get the image data and update pbar''' key = arg[0] dataset[key] = arg[1] pbar.update(f'Reading {key}') pool = Pool(n_thread) for path, key in zip(all_img_list, keys): pool.apply_async(reading_image_worker, args=(path, key), callback=mycallback) pool.close() pool.join() print(f'Finish reading {len(all_img_list)} images.\nWrite lmdb...') #### create lmdb environment data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes if 'flow' in mode: data_size_per_img = dataset['000_00000002_n1'].nbytes print('data size per image is: ', data_size_per_img) data_size = data_size_per_img * len(all_img_list) env = lmdb.open(lmdb_save_path, map_size=data_size * 10) #### write data to lmdb pbar = util.ProgressBar(len(all_img_list)) txn = env.begin(write=True) idx = 1 for path, key in zip(all_img_list, keys): idx = idx + 1 pbar.update(f'Write {key}') key_byte = key.encode('ascii') data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED) if 'flow' in mode: H, W = data.shape assert H == H_dst and W == W_dst, 'different shape.' else: H, W, C = data.shape # fixed shape assert H == H_dst and W == W_dst and C == 3, 'different shape.' 
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 1:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'REDS_{}_wval'.format(mode)
    if 'flow' in mode:
        meta_info['resolution'] = f'1_{H_dst}_{W_dst}'
    else:
        meta_info['resolution'] = f'3_{H_dst}_{W_dst}'
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')
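
# REDS() writes one lmdb per data flavour and keys every frame as '{clip}_{frame}',
# e.g. '000_00000000'. A typical x4 training setup needs the sharp GT lmdb plus its
# bicubic LR counterpart, which is just two calls (paths are resolved from `root`).


def build_reds_pair_example():
    """Sketch: build matching GT and bicubic-LR lmdbs for REDS x4 training."""
    REDS(mode='train_sharp')          # GT: 3 x 720 x 1280 per frame
    REDS(mode='train_sharp_bicubic')  # LR: 3 x 180 x 320 per frame
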
def vimeo90k(): """create lmdb for the Vimeo90K dataset, each image with fixed size GT: [3, 256, 448] Only need the 4th frame currently, e.g., 00001_0001_4 LR: [3, 64, 112] With 1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7 key: Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001 """ #### configurations mode = "GT" # GT | LR read_all_imgs = ( False ) # whether real all images to the memory. Set False with limited memory BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False if mode == "GT": img_folder = "/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences" lmdb_save_path = "/home/xtwang/datasets/vimeo90k/vimeo90k_train_GT.lmdb" txt_file = "/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt" H_dst, W_dst = 256, 448 elif mode == "LR": img_folder = ( "/home/xtwang/datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences" ) lmdb_save_path = "/home/xtwang/datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb" txt_file = "/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt" H_dst, W_dst = 64, 112 n_thread = 40 ######################################################## if not lmdb_save_path.endswith(".lmdb"): raise ValueError("lmdb_save_path must end with 'lmdb'.") #### whether the lmdb file exist if osp.exists(lmdb_save_path): print("Folder [{:s}] already exists. Exit...".format(lmdb_save_path)) sys.exit(1) #### read all the image paths to a list print("Reading image path list ...") with open(txt_file) as f: train_l = f.readlines() train_l = [v.strip() for v in train_l] all_img_list = [] keys = [] for line in train_l: folder = line.split("/")[0] sub_folder = line.split("/")[1] file_l = glob.glob(osp.join(img_folder, folder, sub_folder) + "/*") all_img_list.extend(file_l) for j in range(7): keys.append("{}_{}_{}".format(folder, sub_folder, j + 1)) all_img_list = sorted(all_img_list) keys = sorted(keys) if mode == "GT": # read the 4th frame only for GT mode print("Only keep the 4th frame.") all_img_list = [v for v in all_img_list if v.endswith("im4.png")] keys = [v for v in keys if v.endswith("_4")] if read_all_imgs: #### read all images to memory (multiprocessing) dataset = { } # store all image data. list cannot keep the order, use dict print("Read images with multiprocessing, #thread: {} ...".format( n_thread)) pbar = util.ProgressBar(len(all_img_list)) def mycallback(arg): """get the image data and update pbar""" key = arg[0] dataset[key] = arg[1] pbar.update("Reading {}".format(key)) pool = Pool(n_thread) for path, key in zip(all_img_list, keys): pool.apply_async(reading_image_worker, args=(path, key), callback=mycallback) pool.close() pool.join() print("Finish reading {} images.\nWrite lmdb...".format( len(all_img_list))) #### create lmdb environment data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes print("data size per image is: ", data_size_per_img) data_size = data_size_per_img * len(all_img_list) env = lmdb.open(lmdb_save_path, map_size=data_size * 10) #### write data to lmdb pbar = util.ProgressBar(len(all_img_list)) txn = env.begin(write=True) idx = 1 for path, key in zip(all_img_list, keys): idx = idx + 1 pbar.update("Write {}".format(key)) key_byte = key.encode("ascii") data = dataset[key] if read_all_imgs else cv2.imread( path, cv2.IMREAD_UNCHANGED) H, W, C = data.shape # fixed shape assert H == H_dst and W == W_dst and C == 3, "different shape." 
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 1:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print("Finish writing lmdb.")

    #### create meta information
    meta_info = {}
    if mode == "GT":
        meta_info["name"] = "Vimeo90K_train_GT"
    elif mode == "LR":
        meta_info["name"] = "Vimeo90K_train_LR"
    meta_info["resolution"] = "{}_{}_{}".format(3, H_dst, W_dst)
    key_set = set()
    for key in keys:
        a, b, _ = key.split("_")
        key_set.add("{}_{}".format(a, b))
    meta_info["keys"] = key_set
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, "meta_info.pkl"), "wb"))
    print("Finish creating lmdb meta info.")
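
# vimeo90k() stores per-frame entries ('00001_0001_1' ... '00001_0001_7' for LR,
# only '00001_0001_4' for GT), while meta_info['keys'] keeps the de-duplicated clip
# keys ('00001_0001'). A reader therefore has to expand the clip keys again; a
# minimal sketch of that expansion:


def expand_vimeo_keys_example(clip_keys, mode='LR'):
    """Sketch: turn clip keys from meta_info.pkl back into per-frame lmdb keys."""
    frame_ids = [4] if mode == 'GT' else list(range(1, 8))
    return ['{}_{}'.format(clip, i) for clip in clip_keys for i in frame_ids]
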
def MultiScaleREDS(img_root, lmdb_save_path, scales): """Create lmdb for the REDS dataset with multiple scales """ #### configurations BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False n_thread = 40 ######################################################## if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") if osp.exists(lmdb_save_path): print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) sys.exit(1) print('Reading image path list ...') # all_img_list = get_paths_from_images(img_folder) scale_folders = sorted(os.listdir(img_root)) all_imgs, all_keys = [], [] resolution = {} for i, folder in enumerate(scale_folders): print('[{:02d}/{:02d}] Reading scale-folder: {:s} ...'.format( i, len(scale_folders), folder)) folder_dir = osp.join(img_root, folder) sub_folders = sorted(os.listdir(folder_dir)) for sub in sub_folders: sub_dir = osp.join(folder_dir, sub) img_names = sorted(os.listdir(sub_dir)) imgs = [osp.join(sub_dir, name) for name in img_names] keys = [folder + '_' + sub + '_' + name[:-4] for name in img_names] all_imgs.extend(imgs) all_keys.extend(keys) resolution[folder] = cv2.imread(imgs[-1]).shape #### create lmdb environment data_size_per_img = cv2.imread(all_imgs[0], cv2.IMREAD_UNCHANGED).nbytes print('max data size per image is: ', data_size_per_img) data_size = data_size_per_img * len(all_imgs) env = lmdb.open(lmdb_save_path, map_size=data_size * 10) #### write data to lmdb txn = env.begin(write=True) for i in range(0, len(all_imgs), BATCH): imgs = all_imgs[i:i + BATCH] keys = all_keys[i:i + BATCH] batch_data = read_imgs_multi_thread(imgs, keys, n_thread) pbar = util.ProgressBar(len(imgs)) for k, v in batch_data.items(): pbar.update('Write {}'.format(k)) key_byte = k.encode('ascii') txn.put(key_byte, v) txn.commit() txn = env.begin(write=True) txn.commit() env.close() print('Finish writing lmdb.') #### create meta information meta_info = {} meta_info['name'] = 'REDS_X1_X6_wval' channel = 3 meta_info['resolution'] = resolution meta_info['keys'] = all_keys pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.')
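
# MultiScaleREDS() keys entries as '{scale_folder}_{clip}_{frame}' and stores one
# cv2 shape tuple per scale folder in meta_info['resolution'], so a reader picks the
# shape by scale before reshaping the raw buffer. A minimal sketch (arguments are
# placeholders):


def read_multiscale_entry_example(lmdb_path, scale_folder, clip, frame):
    """Sketch: read one frame from the multi-scale REDS lmdb written above."""
    import pickle
    import os.path as osp

    import lmdb
    import numpy as np

    meta = pickle.load(open(osp.join(lmdb_path, 'meta_info.pkl'), 'rb'))
    H, W, C = meta['resolution'][scale_folder]
    key = '{}_{}_{}'.format(scale_folder, clip, frame)
    env = lmdb.open(lmdb_path, readonly=True, lock=False)
    with env.begin(write=False) as txn:
        buf = txn.get(key.encode('ascii'))
    env.close()
    return np.frombuffer(buf, dtype=np.uint8).reshape(H, W, C)
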
def vimeo90k(mode): """Create lmdb for the Vimeo90K dataset, each image with a fixed size GT: [3, 256, 448] Now only need the 4th frame, e.g., 00001_0001_4 LR: [3, 64, 112] 1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7 key: Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001 flow: downsampled flow: [3, 360, 320], keys: 00001_0001_4_[p3, p2, p1, n1, n2, n3] Each flow is calculated with GT images by PWCNet and then downsampled by 1/4 Flow map is quantized by mmcv and saved in png format """ #### configurations read_all_imgs = False # whether real all images to memory with multiprocessing # Set False for use limited memory BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False if mode == 'GT': img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences' lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb' txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' H_dst, W_dst = 256, 448 elif mode == 'LR': img_folder = '../../datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences' lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb' txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' H_dst, W_dst = 64, 112 elif mode == 'flow': img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences_flowx4' lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_flowx4.lmdb' txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' H_dst, W_dst = 128, 112 else: raise ValueError('Wrong dataset mode: {}'.format(mode)) n_thread = 40 ######################################################## if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with \'lmdb\'.") if osp.exists(lmdb_save_path): print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) sys.exit(1) #### read all the image paths to a list print('Reading image path list ...') with open(txt_file) as f: train_l = f.readlines() train_l = [v.strip() for v in train_l] all_img_list = [] keys = [] for line in train_l: folder = line.split('/')[0] sub_folder = line.split('/')[1] all_img_list.extend( glob.glob(osp.join(img_folder, folder, sub_folder, '*'))) if mode == 'flow': for j in range(1, 4): keys.append('{}_{}_4_n{}'.format(folder, sub_folder, j)) keys.append('{}_{}_4_p{}'.format(folder, sub_folder, j)) else: for j in range(7): keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1)) all_img_list = sorted(all_img_list) keys = sorted(keys) if mode == 'GT': # only read the 4th frame for the GT mode print('Only keep the 4th frame.') all_img_list = [v for v in all_img_list if v.endswith('im4.png')] keys = [v for v in keys if v.endswith('_4')] if read_all_imgs: #### read all images to memory (multiprocessing) dataset = { } # store all image data. 
list cannot keep the order, use dict print('Read images with multiprocessing, #thread: {} ...'.format( n_thread)) pbar = util.ProgressBar(len(all_img_list)) def mycallback(arg): """get the image data and update pbar""" key = arg[0] dataset[key] = arg[1] pbar.update('Reading {}'.format(key)) pool = Pool(n_thread) for path, key in zip(all_img_list, keys): pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) pool.close() pool.join() print('Finish reading {} images.\nWrite lmdb...'.format( len(all_img_list))) #### write data to lmdb data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes print('data size per image is: ', data_size_per_img) data_size = data_size_per_img * len(all_img_list) env = lmdb.open(lmdb_save_path, map_size=data_size * 10) txn = env.begin(write=True) pbar = util.ProgressBar(len(all_img_list)) for idx, (path, key) in enumerate(zip(all_img_list, keys)): pbar.update('Write {}'.format(key)) key_byte = key.encode('ascii') data = dataset[key] if read_all_imgs else cv2.imread( path, cv2.IMREAD_UNCHANGED) if 'flow' in mode: H, W = data.shape assert H == H_dst and W == W_dst, 'different shape.' else: H, W, C = data.shape assert H == H_dst and W == W_dst and C == 3, 'different shape.' txn.put(key_byte, data) if not read_all_imgs and idx % BATCH == 0: txn.commit() txn = env.begin(write=True) txn.commit() env.close() print('Finish writing lmdb.') #### create meta information meta_info = {} if mode == 'GT': meta_info['name'] = 'Vimeo90K_train_GT' elif mode == 'LR': meta_info['name'] = 'Vimeo90K_train_LR' elif mode == 'flow': meta_info['name'] = 'Vimeo90K_train_flowx4' channel = 1 if 'flow' in mode else 3 meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst) key_set = set() for key in keys: if mode == 'flow': a, b, _, _ = key.split('_') else: a, b, _ = key.split('_') key_set.add('{}_{}'.format(a, b)) meta_info['keys'] = list(key_set) pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.')
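
# This vimeo90k() variant takes the mode as an argument, so a full preparation run
# for x4 training with precomputed flow is three calls against the hard-coded paths
# above. Shapes follow the asserts in the write loop (flow maps are stored as
# single-channel 128 x 112 quantized pngs).


def build_vimeo90k_all_example():
    """Sketch: build GT, LR and downsampled-flow lmdbs for Vimeo90K."""
    vimeo90k('GT')    # 4th frame only, 3 x 256 x 448
    vimeo90k('LR')    # frames 1-7,     3 x 64 x 112
    vimeo90k('flow')  # quantized flow, 1 x 128 x 112 per map
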
def main(): ############################################ # # set options # ############################################ parser = argparse.ArgumentParser() parser.add_argument('--opt', type=str, help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=True) ############################################ # # distributed training settings # ############################################ if args.launcher == 'none': # disabled distributed training opt['dist'] = False rank = -1 print('Disabled distributed training.') else: opt['dist'] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() print("Rank:", rank) print("------------------DIST-------------------------") ############################################ # # loading resume state if exists # ############################################ if opt['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load( opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options else: resume_state = None ############################################ # # mkdir and loggers # ############################################ if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) if resume_state is None: util.mkdir_and_rename( opt['path'] ['experiments_root']) # rename experiment folder if exists util.mkdirs( (path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) util.setup_logger('base_val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger_val = logging.getLogger('base_val') logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. Tensorboard will use [tensorboardX]' .format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) else: # config loggers. 
Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_', level=logging.INFO, screen=True) print("set train log") util.setup_logger('base_val', opt['path']['log'], 'val_', level=logging.INFO, screen=True) print("set val log") logger = logging.getLogger('base') logger_val = logging.getLogger('base_val') # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) if rank <= 0: logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True ############################################ # # create train and val dataloader # ############################################ #### # dataset_ratio = 200 # enlarge the size of each epoch, todo: what it is dataset_ratio = 1 # enlarge the size of each epoch, todo: what it is for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt['batch_size'])) # total_iters = int(opt['train']['niter']) # total_epochs = int(math.ceil(total_iters / train_size)) total_iters = train_size total_epochs = int(opt['train']['epoch']) if opt['dist']: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) # total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) total_epochs = int(opt['train']['epoch']) if opt['train']['enable'] == False: total_epochs = 1 else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) if rank <= 0: logger.info( 'Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None ############################################ # # create model # ############################################ #### model = create_model(opt) print("Model Created! 
") #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 print("Not Resume Training") ############################################ # # training # ############################################ #### #### logger.info('Start training from epoch: {:d}, iter: {:d}'.format( start_epoch, current_step)) Avg_train_loss = AverageMeter() # total if (opt['train']['pixel_criterion'] == 'cb+ssim'): Avg_train_loss_pix = AverageMeter() Avg_train_loss_ssim = AverageMeter() elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'): Avg_train_loss_pix = AverageMeter() Avg_train_loss_ssim = AverageMeter() Avg_train_loss_vmaf = AverageMeter() elif (opt['train']['pixel_criterion'] == 'ssim'): Avg_train_loss_ssim = AverageMeter() elif (opt['train']['pixel_criterion'] == 'msssim'): Avg_train_loss_msssim = AverageMeter() elif (opt['train']['pixel_criterion'] == 'cb+msssim'): Avg_train_loss_pix = AverageMeter() Avg_train_loss_msssim = AverageMeter() saved_total_loss = 10e10 saved_total_PSNR = -1 for epoch in range(start_epoch, total_epochs): ############################################ # # Start a new epoch # ############################################ # Turn into training mode #model = model.train() # reset total loss Avg_train_loss.reset() current_step = 0 if (opt['train']['pixel_criterion'] == 'cb+ssim'): Avg_train_loss_pix.reset() Avg_train_loss_ssim.reset() elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'): Avg_train_loss_pix.reset() Avg_train_loss_ssim.reset() Avg_train_loss_vmaf.reset() elif (opt['train']['pixel_criterion'] == 'ssim'): Avg_train_loss_ssim = AverageMeter() elif (opt['train']['pixel_criterion'] == 'msssim'): Avg_train_loss_msssim = AverageMeter() elif (opt['train']['pixel_criterion'] == 'cb+msssim'): Avg_train_loss_pix = AverageMeter() Avg_train_loss_msssim = AverageMeter() if opt['dist']: train_sampler.set_epoch(epoch) for train_idx, train_data in enumerate(train_loader): if 'debug' in opt['name']: img_dir = os.path.join(opt['path']['train_images']) util.mkdir(img_dir) LQ = train_data['LQs'] GT = train_data['GT'] GT_img = util.tensor2img(GT) # uint8 save_img_path = os.path.join( img_dir, '{:4d}_{:s}.png'.format(train_idx, 'debug_GT')) util.save_img(GT_img, save_img_path) for i in range(5): LQ_img = util.tensor2img(LQ[0, i, ...]) # uint8 save_img_path = os.path.join( img_dir, '{:4d}_{:s}_{:1d}.png'.format(train_idx, 'debug_LQ', i)) util.save_img(LQ_img, save_img_path) if (train_idx >= 3): break if opt['train']['enable'] == False: message_train_loss = 'None' break current_step += 1 if current_step > total_iters: print("Total Iteration Reached !") break #### update learning rate if opt['train']['lr_scheme'] == 'ReduceLROnPlateau': pass else: model.update_learning_rate( current_step, warmup_iter=opt['train']['warmup_iter']) #### training model.feed_data(train_data) # if opt['train']['lr_scheme'] == 'ReduceLROnPlateau': # model.optimize_parameters_without_schudlue(current_step) # else: model.optimize_parameters(current_step) if (opt['train']['pixel_criterion'] == 'cb+ssim'): Avg_train_loss.update(model.log_dict['total_loss'], 1) Avg_train_loss_pix.update(model.log_dict['l_pix'], 1) Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1) elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'): 
Avg_train_loss.update(model.log_dict['total_loss'], 1) Avg_train_loss_pix.update(model.log_dict['l_pix'], 1) Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1) Avg_train_loss_vmaf.update(model.log_dict['vmaf_loss'], 1) elif (opt['train']['pixel_criterion'] == 'ssim'): Avg_train_loss.update(model.log_dict['total_loss'], 1) Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1) elif (opt['train']['pixel_criterion'] == 'msssim'): Avg_train_loss.update(model.log_dict['total_loss'], 1) Avg_train_loss_msssim.update(model.log_dict['msssim_loss'], 1) elif (opt['train']['pixel_criterion'] == 'cb+msssim'): Avg_train_loss.update(model.log_dict['total_loss'], 1) Avg_train_loss_pix.update(model.log_dict['l_pix'], 1) Avg_train_loss_msssim.update(model.log_dict['msssim_loss'], 1) else: Avg_train_loss.update(model.log_dict['l_pix'], 1) # add total train loss if (opt['train']['pixel_criterion'] == 'cb+ssim'): message_train_loss = ' pix_avg_loss: {:.4e}'.format( Avg_train_loss_pix.avg) message_train_loss += ' ssim_avg_loss: {:.4e}'.format( Avg_train_loss_ssim.avg) message_train_loss += ' total_avg_loss: {:.4e}'.format( Avg_train_loss.avg) elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'): message_train_loss = ' pix_avg_loss: {:.4e}'.format( Avg_train_loss_pix.avg) message_train_loss += ' ssim_avg_loss: {:.4e}'.format( Avg_train_loss_ssim.avg) message_train_loss += ' vmaf_avg_loss: {:.4e}'.format( Avg_train_loss_vmaf.avg) message_train_loss += ' total_avg_loss: {:.4e}'.format( Avg_train_loss.avg) elif (opt['train']['pixel_criterion'] == 'ssim'): message_train_loss = ' ssim_avg_loss: {:.4e}'.format( Avg_train_loss_ssim.avg) message_train_loss += ' total_avg_loss: {:.4e}'.format( Avg_train_loss.avg) elif (opt['train']['pixel_criterion'] == 'msssim'): message_train_loss = ' msssim_avg_loss: {:.4e}'.format( Avg_train_loss_msssim.avg) message_train_loss += ' total_avg_loss: {:.4e}'.format( Avg_train_loss.avg) elif (opt['train']['pixel_criterion'] == 'cb+msssim'): message_train_loss = ' pix_avg_loss: {:.4e}'.format( Avg_train_loss_pix.avg) message_train_loss += ' msssim_avg_loss: {:.4e}'.format( Avg_train_loss_msssim.avg) message_train_loss += ' total_avg_loss: {:.4e}'.format( Avg_train_loss.avg) else: message_train_loss = ' train_avg_loss: {:.4e}'.format( Avg_train_loss.avg) #### log if current_step % opt['logger']['print_freq'] == 0: logs = model.get_current_log() message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format( epoch, current_step) for v in model.get_current_learning_rate(): message += '{:.3e},'.format(v) message += ')] ' for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: if rank <= 0: tb_logger.add_scalar(k, v, current_step) message += message_train_loss if rank <= 0: logger.info(message) ############################################ # # end of one epoch, save epoch model # ############################################ #### save models and training states # if current_step % opt['logger']['save_checkpoint_freq'] == 0: # if rank <= 0: # logger.info('Saving models and training states.') # model.save(current_step) # model.save('latest') # # model.save_training_state(epoch, current_step) # # todo delete previous weights # previous_step = current_step - opt['logger']['save_checkpoint_freq'] # save_filename = '{}_{}.pth'.format(previous_step, 'G') # save_path = os.path.join(opt['path']['models'], save_filename) # if os.path.exists(save_path): # os.remove(save_path) if epoch == 1: save_filename = 
'{:04d}_{}.pth'.format(0, 'G') save_path = os.path.join(opt['path']['models'], save_filename) if os.path.exists(save_path): os.remove(save_path) save_filename = '{:04d}_{}.pth'.format(epoch - 1, 'G') save_path = os.path.join(opt['path']['models'], save_filename) if os.path.exists(save_path): os.remove(save_path) if rank <= 0: logger.info('Saving models and training states.') save_filename = '{:04d}'.format(epoch) model.save(save_filename) # model.save('latest') # model.save_training_state(epoch, current_step) ############################################ # # end of one epoch, do validation # ############################################ #### validation #if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: if opt['datasets'].get('val', None): if opt['model'] in [ 'sr', 'srgan' ] and rank <= 0: # image restoration validation # does not support multi-GPU validation pbar = util.ProgressBar(len(val_loader)) avg_psnr = 0. idx = 0 for val_data in val_loader: idx += 1 img_name = os.path.splitext( os.path.basename(val_data['LQ_path'][0]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir) model.feed_data(val_data) model.test() visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # Save SR images for reference save_img_path = os.path.join( img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) #util.save_img(sr_img, save_img_path) # calculate PSNR sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) avg_psnr += util.calculate_psnr(sr_img, gt_img) pbar.update('Test {}'.format(img_name)) avg_psnr = avg_psnr / idx # log logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr', avg_psnr, current_step) else: # video restoration validation if opt['dist']: # todo : multi-GPU testing psnr_rlt = {} # with border and center frames psnr_rlt_avg = {} psnr_total_avg = 0. ssim_rlt = {} # with border and center frames ssim_rlt_avg = {} ssim_total_avg = 0. val_loss_rlt = {} val_loss_rlt_avg = {} val_loss_total_avg = 0. 
if rank == 0: pbar = util.ProgressBar(len(val_set)) for idx in range(rank, len(val_set), world_size): print('idx', idx) if 'debug' in opt['name']: if (idx >= 3): break val_data = val_set[idx] val_data['LQs'].unsqueeze_(0) val_data['GT'].unsqueeze_(0) folder = val_data['folder'] idx_d, max_idx = val_data['idx'].split('/') idx_d, max_idx = int(idx_d), int(max_idx) if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32, device='cuda') if ssim_rlt.get(folder, None) is None: ssim_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32, device='cuda') if val_loss_rlt.get(folder, None) is None: val_loss_rlt[folder] = torch.zeros( max_idx, dtype=torch.float32, device='cuda') # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda') model.feed_data(val_data) # model.test() # model.test_stitch() if opt['stitch'] == True: model.test_stitch() else: model.test() # large GPU memory # visuals = model.get_current_visuals() visuals = model.get_current_visuals( save=True, name='{}_{}'.format(folder, idx), save_path=opt['path']['val_images']) rlt_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # calculate PSNR psnr = util.calculate_psnr(rlt_img, gt_img) psnr_rlt[folder][idx_d] = psnr # calculate SSIM ssim = util.calculate_ssim(rlt_img, gt_img) ssim_rlt[folder][idx_d] = ssim # calculate Val loss val_loss = model.get_loss() val_loss_rlt[folder][idx_d] = val_loss logger.info( '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format( folder, idx, psnr, ssim)) if rank == 0: for _ in range(world_size): pbar.update('Test {} - {}/{}'.format( folder, idx_d, max_idx)) # # collect data for _, v in psnr_rlt.items(): dist.reduce(v, 0) for _, v in ssim_rlt.items(): dist.reduce(v, 0) for _, v in val_loss_rlt.items(): dist.reduce(v, 0) dist.barrier() if rank == 0: psnr_rlt_avg = {} psnr_total_avg = 0. for k, v in psnr_rlt.items(): psnr_rlt_avg[k] = torch.mean(v).cpu().item() psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = '# Validation # PSNR: {:.4e}:'.format( psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) # ssim ssim_rlt_avg = {} ssim_total_avg = 0. for k, v in ssim_rlt.items(): ssim_rlt_avg[k] = torch.mean(v).cpu().item() ssim_total_avg += ssim_rlt_avg[k] ssim_total_avg /= len(ssim_rlt) log_s = '# Validation # PSNR: {:.4e}:'.format( ssim_total_avg) for k, v in ssim_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) # added val_loss_rlt_avg = {} val_loss_total_avg = 0. 
for k, v in val_loss_rlt.items(): val_loss_rlt_avg[k] = torch.mean(v).cpu().item() val_loss_total_avg += val_loss_rlt_avg[k] val_loss_total_avg /= len(val_loss_rlt) log_l = '# Validation # Loss: {:.4e}:'.format( val_loss_total_avg) for k, v in val_loss_rlt_avg.items(): log_l += ' {}: {:.4e}'.format(k, v) logger.info(log_l) message = '' for v in model.get_current_learning_rate(): message += '{:.5e}'.format(v) logger_val.info( 'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}' .format(epoch, message, psnr_total_avg, ssim_total_avg, message_train_loss, val_loss_total_avg)) if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) # add val loss tb_logger.add_scalar('val_loss_avg', val_loss_total_avg, current_step) for k, v in val_loss_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) else: # Todo: our function One GPU pbar = util.ProgressBar(len(val_loader)) psnr_rlt = {} # with border and center frames psnr_rlt_avg = {} psnr_total_avg = 0. ssim_rlt = {} # with border and center frames ssim_rlt_avg = {} ssim_total_avg = 0. val_loss_rlt = {} val_loss_rlt_avg = {} val_loss_total_avg = 0. for val_inx, val_data in enumerate(val_loader): if 'debug' in opt['name']: if (val_inx >= 5): break folder = val_data['folder'][0] # idx_d = val_data['idx'].item() idx_d = val_data['idx'] # border = val_data['border'].item() if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = [] if ssim_rlt.get(folder, None) is None: ssim_rlt[folder] = [] if val_loss_rlt.get(folder, None) is None: val_loss_rlt[folder] = [] # process the black blank [B N C H W] print(val_data['LQs'].size()) H_S = val_data['LQs'].size(3) # 540 W_S = val_data['LQs'].size(4) # 960 print(H_S) print(W_S) blank_1_S = 0 blank_2_S = 0 print(val_data['LQs'][0, 2, 0, :, :].size()) for i in range(H_S): if not sum(val_data['LQs'][0, 2, 0, i, :]) == 0: blank_1_S = i - 1 # assert not sum(data_S[:, :, 0][i+1]) == 0 break for i in range(H_S): if not sum(val_data['LQs'][0, 2, 0, :, H_S - i - 1]) == 0: blank_2_S = (H_S - 1) - i - 1 # assert not sum(data_S[:, :, 0][blank_2_S-1]) == 0 break print('LQ :', blank_1_S, blank_2_S) if blank_1_S == -1: print('LQ has no blank') blank_1_S = 0 blank_2_S = H_S # val_data['LQs'] = val_data['LQs'][:,:,:,blank_1_S:blank_2_S,:] print("LQ", val_data['LQs'].size()) # end of process the black blank model.feed_data(val_data) if opt['stitch'] == True: model.test_stitch() else: model.test() # large GPU memory # process blank blank_1_L = blank_1_S << 2 blank_2_L = blank_2_S << 2 print(blank_1_L, blank_2_L) print(model.fake_H.size()) if not blank_1_S == 0: # model.fake_H = model.fake_H[:,:,blank_1_L:blank_2_L,:] model.fake_H[:, :, 0:blank_1_L, :] = 0 model.fake_H[:, :, blank_2_L:H_S, :] = 0 # end of # process blank visuals = model.get_current_visuals( save=True, name='{}_{:02d}'.format(folder, val_inx), save_path=opt['path']['val_images']) rlt_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # calculate PSNR psnr = util.calculate_psnr(rlt_img, gt_img) psnr_rlt[folder].append(psnr) # calculate SSIM ssim = util.calculate_ssim(rlt_img, gt_img) ssim_rlt[folder].append(ssim) # val loss val_loss = model.get_loss() val_loss_rlt[folder].append(val_loss.item()) logger.info( '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format( folder, val_inx, psnr, ssim)) pbar.update('Test {} - {}'.format(folder, idx_d)) # average PSNR for k, v 
in psnr_rlt.items(): psnr_rlt_avg[k] = sum(v) / len(v) psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = '# Validation # PSNR: {:.4e}:'.format( psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) # average SSIM for k, v in ssim_rlt.items(): ssim_rlt_avg[k] = sum(v) / len(v) ssim_total_avg += ssim_rlt_avg[k] ssim_total_avg /= len(ssim_rlt) log_s = '# Validation # SSIM: {:.4e}:'.format( ssim_total_avg) for k, v in ssim_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) # average VMAF # average Val LOSS for k, v in val_loss_rlt.items(): val_loss_rlt_avg[k] = sum(v) / len(v) val_loss_total_avg += val_loss_rlt_avg[k] val_loss_total_avg /= len(val_loss_rlt) log_l = '# Validation # Loss: {:.4e}:'.format( val_loss_total_avg) for k, v in val_loss_rlt_avg.items(): log_l += ' {}: {:.4e}'.format(k, v) logger.info(log_l) # toal validation log message = '' for v in model.get_current_learning_rate(): message += '{:.5e}'.format(v) logger_val.info( 'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}' .format(epoch, message, psnr_total_avg, ssim_total_avg, message_train_loss, val_loss_total_avg)) # end add if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) # tb_logger.add_scalar('ssim_avg', ssim_total_avg, current_step) # for k, v in ssim_rlt_avg.items(): # tb_logger.add_scalar(k, v, current_step) # add val loss tb_logger.add_scalar('val_loss_avg', val_loss_total_avg, current_step) for k, v in val_loss_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) ############################################ # # end of validation, save model # ############################################ # logger.info("Finished an epoch, Check and Save the model weights") # we check the validation loss instead of training loss. OK~ if saved_total_loss >= val_loss_total_avg: saved_total_loss = val_loss_total_avg #torch.save(model.state_dict(), args.save_path + "/best" + ".pth") model.save('best') logger.info( "Best Weights updated for decreased validation loss") else: logger.info( "Weights Not updated for undecreased validation loss") if saved_total_PSNR <= psnr_total_avg: saved_total_PSNR = psnr_total_avg model.save('bestPSNR') logger.info( "Best Weights updated for increased validation PSNR") else: logger.info( "Weights Not updated for unincreased validation PSNR") ############################################ # # end of one epoch, schedule LR # ############################################ # add scheduler todo if opt['train']['lr_scheme'] == 'ReduceLROnPlateau': for scheduler in model.schedulers: # scheduler.step(val_loss_total_avg) scheduler.step(val_loss_total_avg) if rank <= 0: logger.info('Saving the final model.') model.save('last') logger.info('End of training.') tb_logger.close()
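
# The epoch-based training loop above assumes an AverageMeter with reset(),
# update(value, n) and an .avg attribute (the timing variant used in the next main()
# additionally takes a name and a format string). The class below is only a sketch
# of that assumed interface, not the repo's actual implementation.


class AverageMeterExample:
    """Sketch: running average used for the per-epoch loss bookkeeping above."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
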
def AI_4K(mode):
    '''Create lmdb for the AI 4K dataset, each image with a fixed size
    GT: [3, 2160, 3840], key: 000000_000000
    LR: [3, 540, 960], key: 000000_000000
    key: 000000_00000
    ** Remember the data layout from before? {subfolder name}_{image name}
    '''
    # ** data mode: input / gt
    read_all_imgs = False  # whether read all images into memory. Set False with limited memory
    BATCH = 500  # After BATCH images, lmdb commits, if read_all_imgs = False
    if mode == 'input':
        img_folder = '/mnt/sdb/duan/EDVR/datasets/AI_4K/val/input'  # ** path to the input of our dataset
        lmdb_save_path = '/mnt/sdb/duan/EDVR/datasets/AI_4K/train_input_wval.lmdb'  # ** path where the generated lmdb is stored
        '''the original used absolute paths; we use relative paths'''
        H_dst, W_dst = 540, 960  # frame size: H, W
    elif mode == 'gt':
        img_folder = '/mnt/sdb/duan/EDVR/datasets/AI_4K/train/gt'  # ** path to the gt of our dataset
        lmdb_save_path = '/mnt/sdb/duan/EDVR/datasets/AI_4K/train_gt_wval.lmdb'  # ** path where the generated lmdb is stored
        '''the original used absolute paths; we use relative paths'''
        H_dst, W_dst = 2160, 3840  # frame size: H, W
    elif mode == 'test':
        img_folder = '/mnt/sdb/duan/EDVR/datasets/AI_4K/test/gt'  # ** path to the test gt of our dataset
        lmdb_save_path = '/mnt/sdb/duan/EDVR/datasets/AI_4K/test_gt_wval.lmdb'  # ** path where the generated lmdb is stored
        '''the original used absolute paths; we use relative paths'''
        H_dst, W_dst = 2160, 3840
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")  # the save path must end with '.lmdb'
    #### whether the lmdb file exists
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))  # the folder already exists
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = data_util._get_paths_from_images(img_folder)  # full paths of all frames under input/gt, as a list
    keys = []
    for img_path in all_img_list:
        split_rlt = img_path.split('/')
        # take the subfolder name xxxxxx
        a = split_rlt[-2]
        # take the frame name without the file extension xxxxxx
        b = split_rlt[-1].split('.png')[0]  # ** our images end with '.jpg'
        keys.append(a + '_' + b)

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes  # size of one frame in bytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)  # total space needed
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)  # reserve that much space

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    idx = 1
    for path, key in zip(all_img_list, keys):
        idx = idx + 1
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        H, W, C = data.shape  # fixed shape
        assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 1:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    # store the metadata: name (str) + resolution (str)
    meta_info = {}
    meta_info['name'] = 'OURS_{}_wval'.format(mode)  # ** the dataset is now OURS
    meta_info['resolution'] = '{}_{}_{}'.format(3, H_dst, W_dst)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')
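
# AI_4K() keys every frame as '{subfolder}_{frame name}', e.g. '000000_000000', and
# writes one lmdb per split, so building the 540p/4K training pair is two calls
# (the hard-coded paths above decide where the lmdbs land).


def build_ai4k_pair_example():
    """Sketch: build the 540p input and 4K GT lmdbs for the AI 4K dataset."""
    AI_4K('input')  # 3 x 540 x 960 per frame
    AI_4K('gt')     # 3 x 2160 x 3840 per frame
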
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### loading resume state if it exists
    if 'resume_latest' in opt and opt['resume_latest'] == True:
        if os.path.isdir(opt['path']['training_state']):
            name_state_files = os.listdir(opt['path']['training_state'])
            if len(name_state_files) > 0:
                latest_state_num = 0
                for name_state_file in name_state_files:
                    state_num = int(name_state_file.split('.')[0])
                    if state_num > latest_state_num:
                        latest_state_num = state_num
                opt['path']['resume_state'] = os.path.join(
                    opt['path']['training_state'], str(latest_state_num) + '.state')
        else:
            raise ValueError
    if opt['path'].get('resume_state', None):
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if resume_state is None:
        util.mkdir_and_rename(opt['path']['experiments_root'])  # rename experiment folder if it exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root' and 'pretrain_model' not in key
                     and 'resume' not in key))

    # config loggers. Before this, logging will not work
    util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'],
                      level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        version = float(torch.__version__[0:3])
        if version >= 1.1:  # PyTorch 1.1
            from torch.utils.tensorboard import SummaryWriter
        else:
            logger.info('You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
            from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'] + '_{}'.format(util.get_timestamp()))

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    is_time = False
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    if is_time:
        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
    for epoch in range(start_epoch, total_epochs + 1):
        if current_step > total_iters:
            break
        if is_time:
            torch.cuda.synchronize()
            end = time.time()
        for _, train_data in enumerate(train_loader):
            if 'adv_train' in opt:
                current_step += opt['adv_train']['m']
            else:
                current_step += 1
            if current_step > total_iters:
                break
            #### training
            model.feed_data(train_data)
            if is_time:
                torch.cuda.synchronize()
                data_time.update(time.time() - end)
            model.optimize_parameters(current_step)
            #### update learning rate
            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
            if is_time:
                torch.cuda.synchronize()
                batch_time.update(time.time() - end)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                # FIXME remove debug
                debug = True
                if debug:
                    torch.cuda.empty_cache()
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)
                if is_time:
                    logger.info(str(data_time))
                    logger.info(str(batch_time))

            #### validation
            if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
                if opt['model'] in ['sr', 'srgan']:  # image restoration validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # Save SR images for reference
                        save_img_path = os.path.join(
                            img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        pbar.update('Test {}'.format(img_name))

                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)
                tb_logger.flush()

            if is_time:
                torch.cuda.synchronize()
                end = time.time()

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
    tb_logger.close()
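# How main() is typically launched, as a sketch. The option file name and the values
# below are illustrative assumptions, not the project's actual config; only keys that
# main() and the validation/scheduler code actually read are listed.
#
#   python train.py --opt options/train/train_AI4K.yml
#
# Hypothetical excerpt of such a YAML file:
#   name: AI4K_EDVR_experiment
#   model: sr
#   scale: 4
#   use_tb_logger: true
#   datasets:
#     train: { name: AI4K, batch_size: 8 }
#     val:   { name: AI4K_val }
#   train:
#     niter: 600000
#     warmup_iter: -1
#     val_freq: 5000
#     manual_seed: 0
#     lr_scheme: ReduceLROnPlateau
#   logger:
#     print_freq: 100
#     save_checkpoint_freq: 5000
#   path:
#     val_images: ../experiments/AI4K_EDVR_experiment/val_images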