def main():
    """Train the kernel Predictor and Corrector around an SFTMD model.

    Command-line arguments:
        -opt_P / -opt_C / -opt_F: option files for the Predictor, the
            Corrector and the SFTMD network (only the ``'sftmd'`` sub-dict of
            the latter is used).
        --launcher: ``'none'`` (single process) or ``'pytorch'`` (distributed).
        --local_rank: accepted for the distributed launcher.

    For every training batch the Predictor is optimised once, then the
    Corrector is optimised ``opt_C['step']`` times; before each Corrector
    update ``model_F`` is run in test mode to produce the SR batch from the
    current kernel-map estimate. Validation and checkpointing are driven by
    the frequencies found in ``opt_P``.
    """
    #### setup options of three networks
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt_P', type=str,
                        help='Path to option YMAL file of Predictor.')
    parser.add_argument('-opt_C', type=str,
                        help='Path to option YMAL file of Corrector.')
    parser.add_argument('-opt_F', type=str,
                        help='Path to option YMAL file of SFTMD_Net.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'],
                        default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt_P = option.parse(args.opt_P, is_train=True)
    opt_C = option.parse(args.opt_C, is_train=True)
    opt_F = option.parse(args.opt_F, is_train=True)

    # convert to NoneDict, which returns None for missing keys
    opt_P = option.dict_to_nonedict(opt_P)
    opt_C = option.dict_to_nonedict(opt_C)
    opt_F = option.dict_to_nonedict(opt_F)

    # choose small opt for SFTMD test, fill path of pre-trained model_F
    opt_F = opt_F['sftmd']

    # create PCA matrix of enough kernel
    batch_ker = util.random_batch_kernel(batch=30000,
                                         l=opt_P['kernel_size'],
                                         sig_min=0.2,
                                         sig_max=4.0,
                                         rate_iso=1.0,
                                         scaling=3,
                                         tensor=False)
    print('batch kernel shape: {}'.format(batch_ker.shape))
    b = np.size(batch_ker, 0)
    batch_ker = batch_ker.reshape((b, -1))
    pca_matrix = util.PCA(batch_ker, k=opt_P['code_length']).float()
    print('PCA matrix shape: {}'.format(pca_matrix.shape))

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt_P['dist'] = False
        opt_F['dist'] = False
        opt_C['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt_P['dist'] = True
        opt_F['dist'] = True
        opt_C['dist'] = True
        init_dist()
        # number of processes in the current process group
        world_size = torch.distributed.get_world_size()
        # rank of this process inside the group
        rank = torch.distributed.get_rank()

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    ###### Predictor&Corrector train ######

    #### loading resume state if exists
    if opt_P['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt_P['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt_P, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    # tb_logger stays None unless a SummaryWriter is really created, so every
    # later use can test it instead of hitting an unbound name.
    tb_logger = None
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0-7)
        if resume_state is None:
            # Predictor path
            util.mkdir_and_rename(
                opt_P['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt_P['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))
            # Corrector path
            util.mkdir_and_rename(
                opt_C['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt_C['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base', opt_P['path']['log'],
                          'train_' + opt_P['name'],
                          level=logging.INFO, screen=True, tofile=True)
        util.setup_logger('val', opt_P['path']['log'],
                          'val_' + opt_P['name'],
                          level=logging.INFO, screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt_P))
        logger.info(option.dict2str(opt_C))
        # tensorboard logger
        if opt_P['use_tb_logger'] and 'debug' not in opt_P['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt_P['name'])
    else:
        util.setup_logger('base', opt_P['path']['log'], 'train',
                          level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    # The single writer is created from opt_P; Corrector scalars are written
    # only when that writer exists AND opt_C asks for tensorboard logging.
    log_C_to_tb = bool(tb_logger is not None and opt_C['use_tb_logger']
                       and 'debug' not in opt_C['name'])

    #### random seed
    seed = opt_P['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    # Pre-bind so the asserts below report a missing phase as AssertionError.
    train_loader = None
    val_loader = None
    train_sampler = None
    for phase, dataset_opt in opt_P['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt_P['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt_P['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt_P,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt_P, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None
    assert val_loader is not None

    #### create model
    model_F = create_model(opt_F)  # load pretrained model of SFTMD
    model_P = create_model(opt_P)
    model_C = create_model(opt_C)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model_P.resume_training(
            resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    num_C_steps = opt_C['step']  # Corrector iterations per batch

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt_P['dist']:
            train_sampler.set_epoch(epoch)
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate, schedulers
            # model.update_learning_rate(current_step, warmup_iter=opt_P['train']['warmup_iter'])

            #### preprocessing for LR_img and kernel map
            prepro = util.SRMDPreprocessing(opt_P['scale'],
                                            pca_matrix,
                                            para_input=opt_P['code_length'],
                                            kernel=opt_P['kernel_size'],
                                            noise=False,
                                            cuda=True,
                                            sig_min=0.2,
                                            sig_max=4.0,
                                            rate_iso=1.0,
                                            scaling=3,
                                            rate_cln=0.2,
                                            noise_high=0.0)
            LR_img, ker_map = prepro(train_data['GT'])

            #### training Predictor
            model_P.feed_data(LR_img, ker_map)
            model_P.optimize_parameters(current_step)
            P_visuals = model_P.get_current_visuals()
            est_ker_map = P_visuals['Batch_est_ker_map']

            #### log of model_P
            if current_step % opt_P['logger']['print_freq'] == 0:
                logs = model_P.get_current_log()
                message = 'Predictor <epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model_P.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger (writer only exists on rank <= 0)
                    if tb_logger is not None:
                        tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            #### training Corrector
            for step in range(num_C_steps):
                # test SFTMD for corresponding SR image
                model_F.feed_data(train_data, LR_img, est_ker_map)
                model_F.test()
                F_visuals = model_F.get_current_visuals()
                SR_img = F_visuals['Batch_SR']
                # Test SFTMD to produce SR images

                # train corrector given SR image and estimated kernel map
                model_C.feed_data(SR_img, est_ker_map, ker_map)
                model_C.optimize_parameters(current_step)
                C_visuals = model_C.get_current_visuals()
                est_ker_map = C_visuals['Batch_est_ker_map']

                #### log of model_C
                if current_step % opt_C['logger']['print_freq'] == 0:
                    logs = model_C.get_current_log()
                    message = 'Corrector <epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                        epoch, current_step,
                        model_C.get_current_learning_rate())
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                        # tensorboard logger
                        if log_C_to_tb:
                            tb_logger.add_scalar(k, v, current_step)
                    if rank <= 0:
                        logger.info(message)

            # validation, to produce ker_map_list(fake)
            if current_step % opt_P['train']['val_freq'] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0  # counts (image, corrector step) pairs
                for val_data in val_loader:
                    prepro = util.SRMDPreprocessing(
                        opt_P['scale'],
                        pca_matrix,
                        para_input=opt_P['code_length'],
                        kernel=opt_P['kernel_size'],
                        noise=False,
                        cuda=True,
                        sig_min=0.2,
                        sig_max=4.0,
                        rate_iso=1.0,
                        scaling=3,
                        rate_cln=0.2,
                        noise_high=0.0)
                    LR_img, ker_map = prepro(val_data['GT'])
                    single_img_psnr = 0.0

                    # valid Predictor
                    model_P.feed_data(LR_img, ker_map)
                    model_P.test()
                    P_visuals = model_P.get_current_visuals()
                    est_ker_map = P_visuals['Batch_est_ker_map']

                    # step is 1-based so it can be used directly in file
                    # names and log lines
                    for step in range(1, num_C_steps + 1):
                        idx += 1
                        model_F.feed_data(val_data, LR_img, est_ker_map)
                        model_F.test()
                        F_visuals = model_F.get_current_visuals()
                        SR_img = F_visuals['Batch_SR']
                        # Test SFTMD to produce SR images

                        model_C.feed_data(SR_img, est_ker_map, ker_map)
                        model_C.test()
                        C_visuals = model_C.get_current_visuals()
                        est_ker_map = C_visuals['Batch_est_ker_map']

                        sr_img = util.tensor2img(F_visuals['SR'])  # uint8
                        gt_img = util.tensor2img(F_visuals['GT'])  # uint8

                        # Save SR images for reference
                        img_name = os.path.splitext(
                            os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(opt_P['path']['val_images'],
                                               img_name)
                        # img_dir = os.path.join(opt_F['path']['val_images'], str(current_step), '_', str(step))

                        util.mkdir(img_dir)
                        save_img_path = os.path.join(
                            img_dir, '{:s}_{:d}_{:d}.png'.format(
                                img_name, current_step, step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        crop_size = opt_P['scale']
                        gt_img = gt_img / 255.
                        sr_img = sr_img / 255.
                        cropped_sr_img = sr_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                        step_psnr = util.calculate_psnr(
                            cropped_sr_img * 255, cropped_gt_img * 255)
                        logger.info(
                            '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> img:{:s}, psnr: {:.4f}'
                            .format(epoch, current_step, step, img_name,
                                    step_psnr))
                        single_img_psnr += step_psnr
                        # reuse the value instead of computing PSNR twice
                        avg_psnr += step_psnr

                    # with zero corrector steps there is nothing to average
                    if num_C_steps > 0:
                        avg_signle_img_psnr = single_img_psnr / num_C_steps
                        logger.info(
                            '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> img:{:s}, average psnr: {:.4f}'
                            .format(epoch, current_step, num_C_steps, img_name,
                                    avg_signle_img_psnr))

                if idx == 0:
                    logger.warning(
                        'Validation produced no samples; PSNR is not logged.')
                else:
                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info('# Validation # PSNR: {:.4f}'.format(avg_psnr))
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info(
                        '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> psnr: {:.4f}'.
                        format(epoch, current_step, num_C_steps, avg_psnr))
                    # tensorboard logger
                    if tb_logger is not None:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt_P['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model_P.save(current_step)
                    model_P.save_training_state(epoch, current_step)
                    model_C.save(current_step)
                    model_C.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model_P.save('latest')
        model_C.save('latest')
        logger.info('End of Predictor and Corrector training.')
        # the writer does not exist when tensorboard logging is disabled
        if tb_logger is not None:
            tb_logger.close()
def SFTMD_train(opt_F, rank, world_size, pca_matrix):
    """Train the SFTMD network on data degraded on the fly.

    Args:
        opt_F: option mapping for the SFTMD model; the keys ``path``,
            ``name``, ``use_tb_logger``, ``datasets``, ``train``, ``logger``,
            ``scale`` and ``dist`` are read here.
        rank: process rank; values ``<= 0`` do directory setup, logging,
            validation and checkpointing.
        world_size: number of processes, only used when ``opt_F['dist']`` is
            truthy (passed to ``DistIterSampler``).
        pca_matrix: projection matrix handed to ``util.SRMDPreprocessing``.
    """
    #### loading resume state if exists
    if opt_F['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt_F['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt_F, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    # Stays None unless a SummaryWriter is created below.
    tb_logger = None
    if rank <= 0:
        if resume_state is None:
            util.mkdir_and_rename(
                opt_F['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt_F['path'].items()
                         if not key == 'experiments_root'
                         and 'pretrain_model' not in key
                         and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base', opt_F['path']['log'],
                          'train_' + opt_F['name'],
                          level=logging.INFO, screen=True, tofile=True)
        util.setup_logger('val', opt_F['path']['log'],
                          'val_' + opt_F['name'],
                          level=logging.INFO, screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt_F))
        # tensorboard logger
        if opt_F['use_tb_logger'] and 'debug' not in opt_F['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt_F['name'])
    else:
        util.setup_logger('base', opt_F['path']['log'], 'train',
                          level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    # Pre-bind so a missing phase trips the asserts below rather than a
    # NameError.
    train_loader = None
    val_loader = None
    train_sampler = None
    for phase, dataset_opt in opt_F['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt_F['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt_F['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt_F,
                                             train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt_F, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None
    assert val_loader is not None

    #### create model
    model_F = create_model(opt_F)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model_F.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt_F['dist']:
            train_sampler.set_epoch(epoch)
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break
            #### preprocessing for LR_img and kernel map
            prepro = util.SRMDPreprocessing(opt_F['scale'], pca_matrix,
                                            para_input=10, kernel=21,
                                            noise=False, cuda=True,
                                            sig_min=0.2, sig_max=4.0,
                                            rate_iso=1.0, scaling=3,
                                            rate_cln=0.2, noise_high=0.0)
            LR_img, ker_map = prepro(train_data['GT'])

            #### update learning rate, schedulers
            model_F.update_learning_rate(
                current_step, warmup_iter=opt_F['train']['warmup_iter'])

            #### training
            model_F.feed_data(train_data, LR_img, ker_map)
            model_F.optimize_parameters(current_step)

            #### log
            if current_step % opt_F['logger']['print_freq'] == 0:
                logs = model_F.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model_F.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger (writer only exists on rank <= 0)
                    if tb_logger is not None:
                        tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            # validation
            if current_step % opt_F['train']['val_freq'] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    #### preprocessing for LR_img and kernel map
                    # NOTE(review): validation uses para_input=15 and the
                    # default kernel size, while training above uses
                    # para_input=10, kernel=21 - confirm this is intended.
                    prepro = util.SRMDPreprocessing(opt_F['scale'], pca_matrix,
                                                    para_input=15,
                                                    noise=False, cuda=True,
                                                    sig_min=0.2, sig_max=4.0,
                                                    rate_iso=1.0, scaling=3,
                                                    rate_cln=0.2,
                                                    noise_high=0.0)
                    LR_img, ker_map = prepro(val_data['GT'])

                    model_F.feed_data(val_data, LR_img, ker_map)
                    model_F.test()

                    visuals = model_F.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    # Save SR images for reference
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LQ_path'][0]))[0]
                    #img_dir = os.path.join(opt_F['path']['val_images'], img_name)
                    img_dir = os.path.join(opt_F['path']['val_images'],
                                           str(current_step))
                    util.mkdir(img_dir)

                    save_img_path = os.path.join(
                        img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    crop_size = opt_F['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)

                if idx == 0:
                    # an empty validation loader would otherwise divide by zero
                    logger.warning(
                        'Validation produced no samples; PSNR is not logged.')
                else:
                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                        epoch, current_step, avg_psnr))
                    # tensorboard logger
                    if tb_logger is not None:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt_F['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model_F.save(current_step)
                    model_F.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model_F.save('latest')
        logger.info('End of SFTMD training.')
def generate_mod_LR_bic():
    """Build HR / LR / bicubic / blurred-LR image sets from a folder of PNGs.

    For every ``.png`` in the hard-coded ``sourcedir`` the image is cropped so
    both sides are multiples of ``mod_scale``, then four files are written
    under ``savedir``: the cropped HR image, an ``imresize_np`` down-scaled LR
    image, that LR image re-upscaled ("Bic"), and an LR image produced by
    ``util.SRMDPreprocessing``. The kernel map returned for each image is
    collected and saved to ``./BSDS100_kermap.pth``.

    Raises:
        SystemExit: with status 1 when ``sourcedir`` does not exist.
        IOError: when an input image cannot be decoded.
    """
    # set parameters
    up_scale = 4
    mod_scale = 4
    # set data dir
    sourcedir = '/mnt/yjchai/SR_data/BSDS100'  #'/mnt/yjchai/SR_data/DIV2K_test_HR' #'/mnt/yjchai/SR_data/Flickr2K/Flickr2K_HR'
    savedir = '/mnt/yjchai/SR_data/BSDS100_test'  #'/mnt/yjchai/SR_data/DIV2K_test' #'/mnt/yjchai/SR_data/Flickr2K_train'

    # set random seed
    util.set_random_seed(0)

    # load PCA matrix of enough kernel
    print('load PCA matrix')
    pca_matrix = torch.load('/media/sdc/yjchai/IKC/codes/pca_matrix.pth')
    print('PCA matrix shape: {}'.format(pca_matrix.shape))

    saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale))
    saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale))
    saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale))
    saveLRblurpath = os.path.join(savedir, 'LRblur', 'x' + str(up_scale))

    if not os.path.isdir(sourcedir):
        print('Error: No source data found')
        # a missing source directory is an error, so exit non-zero; raising
        # SystemExit also works where the site-provided exit() is unavailable
        raise SystemExit(1)

    # makedirs creates savedir and the intermediate HR/LR/Bic/LRblur
    # directories in one call per leaf directory.
    for out_dir in (saveHRpath, saveLRpath, saveBicpath, saveLRblurpath):
        if os.path.isdir(out_dir):
            print('It will cover ' + str(out_dir))
        else:
            os.makedirs(out_dir)

    filepaths = sorted(
        [f for f in os.listdir(sourcedir) if f.endswith('.png')])
    print(filepaths)
    num_files = len(filepaths)

    kernel_map_tensor = torch.zeros(
        (num_files, 1, 10))  # each kernel map: 1*10

    # prepare data with augementation
    for i, filename in enumerate(filepaths):
        print('No.{} -- Processing {}'.format(i, filename))
        # read image
        src_path = os.path.join(sourcedir, filename)
        image = cv2.imread(src_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising
            raise IOError('Failed to read image: {}'.format(src_path))

        width = int(np.floor(image.shape[1] / mod_scale))
        height = int(np.floor(image.shape[0] / mod_scale))
        # modcrop
        if len(image.shape) == 3:
            image_HR = image[0:mod_scale * height, 0:mod_scale * width, :]
        else:
            image_HR = image[0:mod_scale * height, 0:mod_scale * width]
        img_HR = util.img2tensor(image_HR)
        C, H, W = img_HR.size()
        # LR_blur, by random gaussian kernel
        prepro = util.SRMDPreprocessing(up_scale,
                                        pca_matrix,
                                        para_input=10,
                                        kernel=21,
                                        noise=False,
                                        cuda=True,
                                        sig_min=0.2,
                                        sig_max=4.0,
                                        rate_iso=1.0,
                                        scaling=3,
                                        rate_cln=0.2,
                                        noise_high=0.0)
        LR_img, ker_map = prepro(img_HR.view(1, C, H, W))
        image_LR_blur = util.tensor2img(LR_img)
        # LR
        image_LR = imresize_np(image_HR, 1 / up_scale, True)
        # bic
        image_Bic = imresize_np(image_LR, up_scale, True)

        cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
        cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
        cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)
        cv2.imwrite(os.path.join(saveLRblurpath, filename), image_LR_blur)

        kernel_map_tensor[i] = ker_map
    # save dataset corresponding kernel maps
    torch.save(kernel_map_tensor, './BSDS100_kermap.pth')
    print("Image Blurring & Down smaple Done: X" + str(up_scale))
def generate_mod_LR_bic():
    """Build HR / LR / bicubic / blurred-LR image sets from a folder of PNGs.

    For every ``.png`` in the hard-coded ``sourcedir`` the image is cropped so
    both sides are multiples of ``mod_scale``. Eight blurred LR variants are
    then generated with ``util.SRMDPreprocessing`` and written, together with
    a copy of the HR image, under file names prefixed with ``sig<value>_``.
    One ``imresize_np`` down-scaled LR image and its re-upscaled ("Bic")
    version are written per source image.

    Raises:
        SystemExit: with status 1 when ``sourcedir`` does not exist.
        IOError: when an input image cannot be decoded.
    """
    # set parameters
    up_scale = 4
    mod_scale = 4
    # set data dir
    sourcedir = "/data/DIV2K_public/gt_k_x4"  #'/mnt/yjchai/SR_data/DIV2K_test_HR' #'/mnt/yjchai/SR_data/Flickr2K/Flickr2K_HR'
    savedir = "/data/DIV2KRK_public/x4HRblur.lmdb"  #'/mnt/yjchai/SR_data/DIV2K_test' #'/mnt/yjchai/SR_data/Flickr2K_train'

    # set random seed
    util.set_random_seed(0)

    # load PCA matrix of enough kernel
    print("load PCA matrix")
    pca_matrix = torch.load(
        "/data/IKC/pca_aniso_matrix.pth",
        map_location=lambda storage, loc: storage
    )
    print("PCA matrix shape: {}".format(pca_matrix.shape))

    saveHRpath = os.path.join(savedir, "HR", "x" + str(mod_scale))
    saveLRpath = os.path.join(savedir, "LR", "x" + str(up_scale))
    saveBicpath = os.path.join(savedir, "Bic", "x" + str(up_scale))
    saveLRblurpath = os.path.join(savedir, "LRblur", "x" + str(up_scale))

    if not os.path.isdir(sourcedir):
        print("Error: No source data found")
        # a missing source directory is an error, so exit non-zero; raising
        # SystemExit also works where the site-provided exit() is unavailable
        raise SystemExit(1)

    # makedirs creates savedir and the intermediate HR/LR/Bic/LRblur
    # directories in one call per leaf directory.
    for out_dir in (saveHRpath, saveLRpath, saveBicpath, saveLRblurpath):
        if os.path.isdir(out_dir):
            print("It will cover " + str(out_dir))
        else:
            os.makedirs(out_dir)

    filepaths = sorted([f for f in os.listdir(sourcedir) if f.endswith(".png")])
    print(filepaths)

    # kernel_map_tensor = torch.zeros((num_files, 1, 10)) # each kernel map: 1*10

    # prepare data with augementation
    for i, filename in enumerate(filepaths):
        print("No.{} -- Processing {}".format(i, filename))
        # read image
        src_path = os.path.join(sourcedir, filename)
        image = cv2.imread(src_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising
            raise IOError("Failed to read image: {}".format(src_path))

        width = int(np.floor(image.shape[1] / mod_scale))
        height = int(np.floor(image.shape[0] / mod_scale))
        # modcrop
        if len(image.shape) == 3:
            image_HR = image[0 : mod_scale * height, 0 : mod_scale * width, :]
        else:
            image_HR = image[0 : mod_scale * height, 0 : mod_scale * width]
        # LR_blur, by random gaussian kernel
        img_HR = util.img2tensor(image_HR)
        C, H, W = img_HR.size()

        # sig_list = [1.8, 2.0, 2.2, 2.4, 2.6, 2.8, 3.0, 3.2]
        # sig = 2.6
        # NOTE(review): `sig` is only used in the output file names; the
        # preprocessing below is built with random=True and sig=0 - confirm
        # that the per-sigma loop is meant to produce random kernels.
        for sig in np.linspace(1.8, 3.2, 8):
            prepro = util.SRMDPreprocessing(
                up_scale,
                pca_matrix,
                random=True,
                para_input=10,
                kernel=11,
                noise=False,
                cuda=True,
                sig=0,
                sig_min=0.6,
                sig_max=5,
                rate_iso=0,
                scaling=3,
                rate_cln=0.2,
                noise_high=0.0,
            )  # random(sig_min, sig_max) | stable kernel(sig)
            LR_img, ker_map = prepro(img_HR.view(1, C, H, W))
            image_LR_blur = util.tensor2img(LR_img)
            tagged_name = 'sig{}_{}'.format(sig, filename)
            cv2.imwrite(os.path.join(saveLRblurpath, tagged_name), image_LR_blur)
            cv2.imwrite(os.path.join(saveHRpath, tagged_name), image_HR)

        # LR
        image_LR = imresize_np(image_HR, 1 / up_scale, True)
        # bic
        image_Bic = imresize_np(image_LR, up_scale, True)

        # cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
        cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
        cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)

        # kernel_map_tensor[i] = ker_map
    # save dataset corresponding kernel maps
    # torch.save(kernel_map_tensor, './Set5_sig2.6_kermap.pth')
    print("Image Blurring & Down smaple Done: X" + str(up_scale))
def main():
    """Train a single model on data degraded on the fly by SRMDPreprocessing.

    Command-line arguments:
        -opt: path to the option file.
        --launcher: ``'none'`` (single process) or ``'pytorch'`` (distributed).
        --local_rank: accepted for the distributed launcher.

    The PCA matrix is loaded from ``opt["pca_matrix_path"]``; one
    ``util.SRMDPreprocessing`` instance built from the options turns every GT
    batch into an LR batch plus kernel map. Validation feeds the ``"LQ"``
    images provided by the validation loader, saves LR/SR images and logs the
    average PSNR.
    """
    #### setup options of three networks
    parser = argparse.ArgumentParser()
    parser.add_argument("-opt", type=str,
                        help="Path to option YMAL file of Predictor.")
    parser.add_argument("--launcher", choices=["none", "pytorch"],
                        default="none", help="job launcher")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    # choose small opt for SFTMD test, fill path of pre-trained model_F
    #### set random seed
    seed = opt["train"]["manual_seed"]
    if seed is None:
        seed = random.randint(1, 10000)
    util.set_random_seed(seed)

    # load PCA matrix of enough kernel
    print("load PCA matrix")
    pca_matrix = torch.load(opt["pca_matrix_path"],
                            map_location=lambda storage, loc: storage)
    print("PCA matrix shape: {}".format(pca_matrix.shape))

    #### distributed training settings
    if args.launcher == "none":  # disabled distributed training
        opt["dist"] = False
        rank = -1
        print("Disabled distributed training.")
    else:
        opt["dist"] = True
        init_dist()
        # number of processes in the current process group
        world_size = torch.distributed.get_world_size()
        # rank of this process inside the group
        rank = torch.distributed.get_rank()

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    ###### Predictor&Corrector train ######

    #### loading resume state if exists
    if opt["path"].get("resume_state", None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt["path"]["resume_state"],
            map_location=lambda storage, loc: storage.cuda(device_id),
        )
        option.check_resume(opt, resume_state["iter"])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    # Stays None unless a SummaryWriter is created below, so later uses
    # (including the final close) can test it safely.
    tb_logger = None
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0-7)
        if resume_state is None:
            # Predictor path
            util.mkdir_and_rename(
                opt["path"]["experiments_root"])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt["path"].items()
                 if not key == "experiments_root"
                 and "pretrain_model" not in key and "resume" not in key))
            # Replace a previous ./log file or symlink without spawning a
            # shell; islink() also catches a dangling link.
            log_link = "./log"
            if os.path.islink(log_link) or os.path.isfile(log_link):
                os.remove(log_link)
            os.symlink(os.path.join(opt["path"]["experiments_root"], ".."),
                       log_link)

        # config loggers. Before it, the log will not work
        util.setup_logger(
            "base",
            opt["path"]["log"],
            "train_" + opt["name"],
            level=logging.INFO,
            screen=True,
            tofile=True,
        )
        util.setup_logger(
            "val",
            opt["path"]["log"],
            "val_" + opt["name"],
            level=logging.INFO,
            screen=True,
            tofile=True,
        )
        logger = logging.getLogger("base")
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt["use_tb_logger"] and "debug" not in opt["name"]:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    "You are using PyTorch {}. Tensorboard will use [tensorboardX]"
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir="log/tb_logger/" + opt["name"])
    else:
        util.setup_logger("base", opt["path"]["log"], "train",
                          level=logging.INFO, screen=True)
        logger = logging.getLogger("base")

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    # Pre-bind so a missing phase trips the asserts below rather than a
    # NameError.
    train_loader = None
    val_loader = None
    train_sampler = None
    for phase, dataset_opt in opt["datasets"].items():
        if phase == "train":
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt["batch_size"]))
            total_iters = int(opt["train"]["niter"])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt["dist"]:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    "Number of train images: {:,d}, iters: {:,d}".format(
                        len(train_set), train_size))
                logger.info("Total epochs needed: {:d} for iters {:,d}".format(
                    total_epochs, total_iters))
        elif phase == "val":
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info("Number of val images in [{:s}]: {:d}".format(
                    dataset_opt["name"], len(val_set)))
        else:
            raise NotImplementedError(
                "Phase [{:s}] is not recognized.".format(phase))
    assert train_loader is not None
    assert val_loader is not None

    #### create model
    model = create_model(opt)  # load pretrained model of SFTMD

    #### resume training
    if resume_state:
        logger.info("Resuming training from epoch: {}, iter: {}.".format(
            resume_state["epoch"], resume_state["iter"]))

        start_epoch = resume_state["epoch"]
        current_step = resume_state["iter"]
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    prepro = util.SRMDPreprocessing(
        opt["scale"],
        pca_matrix,
        random=True,
        para_input=opt["code_length"],
        kernel=opt["kernel_size"],
        noise=False,
        cuda=True,
        sig=None,
        sig_min=opt["sig_min"],
        sig_max=opt["sig_max"],
        rate_iso=1.0,
        scaling=3,
        rate_cln=0.2,
        noise_high=0.0,
    )

    #### training
    logger.info("Start training from epoch: {:d}, iter: {:d}".format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt["dist"]:
            train_sampler.set_epoch(epoch)
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break
            #### preprocessing for LR_img and kernel map
            LR_img, ker_map = prepro(train_data["GT"])
            # snap to multiples of 1/255 before feeding the model
            LR_img = (LR_img * 255).round() / 255

            #### training Predictor
            model.feed_data(LR_img, train_data["GT"], ker_map)
            model.optimize_parameters(current_step)
            model.update_learning_rate(
                current_step, warmup_iter=opt["train"]["warmup_iter"])
            visuals = model.get_current_visuals()

            #### log of model_P
            if current_step % opt["logger"]["print_freq"] == 0:
                logs = model.get_current_log()
                message = "Predictor <epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> ".format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += "{:s}: {:.4e} ".format(k, v)
                    # tensorboard logger (writer only exists on rank <= 0)
                    if tb_logger is not None:
                        tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            # validation, to produce ker_map_list(fake)
            if current_step % opt["train"]["val_freq"] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    # LR_img, ker_map = prepro(val_data['GT'])
                    LR_img = val_data["LQ"]
                    lr_img = util.tensor2img(
                        LR_img)  # save LR image for reference

                    # valid Predictor
                    model.feed_data(LR_img, val_data["GT"])
                    model.test()
                    visuals = model.get_current_visuals()

                    # Save images for reference
                    img_name = os.path.splitext(
                        os.path.basename(val_data["LQ_path"][0]))[0]
                    img_dir = os.path.join(opt["path"]["val_images"], img_name)
                    # img_dir = os.path.join(opt['path']['val_images'], str(current_step), '_', str(step))
                    util.mkdir(img_dir)
                    save_lr_path = os.path.join(img_dir,
                                                "{:s}_LR.png".format(img_name))
                    util.save_img(lr_img, save_lr_path)

                    sr_img = util.tensor2img(visuals["SR"])  # uint8
                    gt_img = util.tensor2img(visuals["GT"])  # uint8

                    save_img_path = os.path.join(
                        img_dir,
                        "{:s}_{:d}.png".format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    crop_size = opt["scale"]
                    gt_img = gt_img / 255.0
                    sr_img = sr_img / 255.0
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]

                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)
                    idx += 1

                if idx == 0:
                    # an empty validation loader would otherwise divide by zero
                    logger.warning(
                        "Validation produced no samples; PSNR is not logged.")
                else:
                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info("# Validation # PSNR: {:.6f}".format(avg_psnr))
                    logger_val = logging.getLogger("val")  # validation logger
                    logger_val.info(
                        "<epoch:{:3d}, iter:{:8,d}, psnr: {:.6f}".format(
                            epoch, current_step, avg_psnr))
                    # tensorboard logger
                    if tb_logger is not None:
                        tb_logger.add_scalar("psnr", avg_psnr, current_step)

            #### save models and training states
            if current_step % opt["logger"]["save_checkpoint_freq"] == 0:
                if rank <= 0:
                    logger.info("Saving models and training states.")
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info("Saving the final model.")
        model.save("latest")
        logger.info("End of Predictor and Corrector training.")
        # the writer does not exist when tensorboard logging is disabled
        if tb_logger is not None:
            tb_logger.close()
# Per-image Y-channel metrics, filled in while iterating the test loader.
test_results['psnr_y'] = []
test_results['ssim_y'] = []

for test_data in test_loader:
    # Ground truth is available only when the dataset declares a GT root;
    # the image name comes from the GT path when it is, else from the LQ path.
    need_GT = test_loader.dataset.opt['dataroot_GT'] is not None
    if need_GT:
        img_path = test_data['GT_path'][0]
    else:
        img_path = test_data['LQ_path'][0]
    base_name = os.path.basename(img_path)
    img_name = os.path.splitext(base_name)[0]

    #### preprocessing for LR_img and kernel map
    prepro = util.SRMDPreprocessing(
        opt_F['scale'],
        pca_matrix,
        para_input=15,
        noise=False,
        cuda=True,
        sig_min=0.2,
        sig_max=4.0,
        rate_iso=1.0,
        scaling=3,
        rate_cln=0.2,
        noise_high=0.0,
    )
    LR_img, ker_map = prepro(test_data['GT'])

    # Run the SFTMD model on the degraded input and fetch its outputs.
    model_F.feed_data(test_data, LR_img, ker_map)
    model_F.test()
    F_visuals = model_F.get_current_visuals()
    sr_img = util.tensor2img(F_visuals['SR'])  # uint8
    # save images
def train(self):
    """Run one training epoch and checkpoint the model weights.

    The epoch index is ``self.scheduler.last_epoch + 1``. While
    ``epoch <= self.args.epochs_encoder`` only ``self.model_E`` is run and the
    contrastive loss alone is optimised; afterwards ``self.model`` is run and
    the SR loss (``self.loss``) is added to the contrastive loss.

    Side effects: steps the scheduler and the loss object, overwrites the
    learning rate of every optimizer param group, writes progress through
    ``self.ckp`` and saves ``model_<epoch>.pt`` under ``<self.ckp.dir>/model``.
    """
    # NOTE(review): the scheduler is stepped before any optimizer.step() of
    # this epoch; recent PyTorch releases warn about that order -- confirm.
    self.scheduler.step()
    self.loss.step()
    epoch = self.scheduler.last_epoch + 1
    encoder_stage = epoch <= self.args.epochs_encoder

    # Step-wise learning-rate decay; each stage has its own base lr, gamma
    # and decay interval (the second stage counts epochs from its own start).
    if encoder_stage:
        base_lr = self.args.lr_encoder
        gamma = self.args.gamma_encoder
        n_decays = epoch // self.args.lr_decay_encoder
    else:
        base_lr = self.args.lr_sr
        gamma = self.args.gamma_sr
        n_decays = (epoch - self.args.epochs_encoder) // self.args.lr_decay_sr
    learning_rate = base_lr * (gamma ** n_decays)
    for group in self.optimizer.param_groups:
        group['lr'] = learning_rate

    self.ckp.write_log('[Epoch {}]\tLearning rate: {:.2e}'.format(
        epoch, Decimal(learning_rate)))
    self.loss.start_log()
    self.model.train()

    degrade = util.SRMDPreprocessing(
        self.scale[0],
        kernel_size=self.args.blur_kernel,
        blur_type=self.args.blur_type,
        sig_min=self.args.sig_min,
        sig_max=self.args.sig_max,
        lambda_min=self.args.lambda_min,
        lambda_max=self.args.lambda_max,
        noise=self.args.noise,
    )

    timer = utility.timer()
    losses_contrast = utility.AverageMeter()
    losses_sr = utility.AverageMeter()

    for batch, (hr_batch, _, _) in enumerate(self.loader_train):
        # Original author's shape notes: hr is (b, n, c, h, w), lr is
        # "bn, c, h, w"; axis 1 is indexed with 0 and 1 below.
        hr_batch = hr_batch.cuda()
        lr_batch, _ = degrade(hr_batch)

        self.optimizer.zero_grad()
        timer.tic()

        # forward
        if encoder_stage:
            # train the degradation encoder only
            _, logits, labels = self.model_E(im_q=lr_batch[:, 0, ...],
                                             im_k=lr_batch[:, 1, ...])
            contrast_term = self.contrast_loss(logits, labels)
            total_loss = contrast_term
            losses_contrast.update(contrast_term.item())
        else:
            # train the whole network
            sr_batch, logits, labels = self.model(lr_batch)
            sr_term = self.loss(sr_batch, hr_batch[:, 0, ...])
            contrast_term = self.contrast_loss(logits, labels)
            total_loss = contrast_term + sr_term
            losses_sr.update(sr_term.item())
            losses_contrast.update(contrast_term.item())

        # backward
        total_loss.backward()
        self.optimizer.step()
        timer.hold()

        # periodic progress report
        if (batch + 1) % self.args.print_every != 0:
            continue
        n_seen = (batch + 1) * self.args.batch_size
        n_total = len(self.loader_train.dataset)
        if encoder_stage:
            message = ('Epoch: [{:03d}][{:04d}/{:04d}]\t'
                       'Loss [contrastive loss: {:.3f}]\t'
                       'Time [{:.1f}s]').format(
                           epoch, n_seen, n_total,
                           losses_contrast.avg,
                           timer.release())
        else:
            message = ('Epoch: [{:04d}][{:04d}/{:04d}]\t'
                       'Loss [SR loss:{:.3f} | contrastive loss: {:.3f}]\t'
                       'Time [{:.1f}s]').format(
                           epoch, n_seen, n_total,
                           losses_sr.avg,
                           losses_contrast.avg,
                           timer.release())
        self.ckp.write_log(message)

    self.loss.end_log(len(self.loader_train))

    # save model: drop every entry whose key mentions the key-encoder or the
    # queue before writing the state dict to disk
    state = self.model.get_model().state_dict()
    discard = [name for name in state.keys()
               if 'E.encoder_k' in name or 'queue' in name]
    for name in discard:
        del state[name]
    checkpoint_path = os.path.join(self.ckp.dir, 'model',
                                   'model_{}.pt'.format(epoch))
    torch.save(state, checkpoint_path)
def test(self):
    """Evaluate ``self.model`` on ``self.loader_test`` for each entry of ``self.scale``.

    For every scale the HR batch is border-cropped, degraded with a
    ``util.SRMDPreprocessing`` instance (called with ``random=False``),
    super-resolved, quantized, and scored with ``utility.calc_psnr`` and
    ``utility.calc_ssim``. The mean PSNR is stored in ``self.ckp.log`` and both
    means are written to the log. SR outputs are saved through
    ``self.ckp.save_results`` when ``self.args.save_results`` is set.
    """
    self.ckp.write_log('\nEvaluation:')
    self.ckp.add_log(torch.zeros(1, len(self.scale)))
    self.model.eval()

    timer_test = utility.timer()

    with torch.no_grad():
        for idx_scale, scale in enumerate(self.scale):
            self.loader_test.dataset.set_scale(idx_scale)
            psnr_sum = 0
            ssim_sum = 0

            # NOTE(review): built with self.scale[0] rather than the loop's
            # `scale`; confirm this is intended when several scales are used.
            degrade = util.SRMDPreprocessing(
                self.scale[0],
                kernel_size=self.args.blur_kernel,
                blur_type=self.args.blur_type,
                sig=self.args.sig,
                lambda_1=self.args.lambda_1,
                lambda_2=self.args.lambda_2,
                theta=self.args.theta,
                noise=self.args.noise,
            )

            for hr_batch, filename, _ in self.loader_test:
                # Original author's shape note: (b, 1, c, h, w).
                hr_batch = self.crop_border(hr_batch.cuda(), scale)
                lr_batch, _ = degrade(hr_batch, random=False)
                hr_img = hr_batch[:, 0, ...]

                # inference (only the forward pass is timed)
                timer_test.tic()
                sr_img = self.model(lr_batch[:, 0, ...])
                timer_test.hold()

                rgb_range = self.args.rgb_range
                sr_img = utility.quantize(sr_img, rgb_range)
                hr_img = utility.quantize(hr_img, rgb_range)

                # metrics
                psnr_sum += utility.calc_psnr(
                    sr_img, hr_img, scale, rgb_range,
                    benchmark=self.loader_test.dataset.benchmark)
                ssim_sum += utility.calc_ssim(
                    sr_img, hr_img, scale,
                    benchmark=self.loader_test.dataset.benchmark)

                # save results
                if self.args.save_results:
                    self.ckp.save_results(filename[0], [sr_img], scale)

            n_batches = len(self.loader_test)
            mean_psnr = psnr_sum / n_batches
            mean_ssim = ssim_sum / n_batches
            self.ckp.log[-1, idx_scale] = mean_psnr
            self.ckp.write_log(
                '[Epoch {}---{} x{}]\tPSNR: {:.3f} SSIM: {:.4f}'.format(
                    self.args.resume,
                    self.args.data_test,
                    scale,
                    mean_psnr,
                    mean_ssim,
                ))
def degradation_model():
    """Write blurred, down-sampled copies of every ``.bmp`` image in ``sourcedir``.

    Steps:
      1. Seed the RNG, draw a batch of random kernels and fit a PCA matrix
         on the flattened kernels.
      2. For each source image: crop it so both sides are multiples of
         ``mod_scale``, run it through ``util.SRMDPreprocessing`` with a fixed
         ``sig`` (``random=False``), and save the result as
         ``<savedir>/LRblur_sig<sig>/X<scale>/sig<sig>_<filename>``.

    All paths and parameters are hard-coded below. The preprocessing object is
    created with ``cuda=True``.

    Raises:
        SystemExit: (status 0) when ``sourcedir`` does not exist.
        IOError: when an image listed in ``sourcedir`` cannot be decoded.
    """
    # ---- parameters ----
    scale = 4
    mod_scale = 4
    # Fixed sigma handed to SRMDPreprocessing; also used in output names.
    sig = 0.6

    # ---- data directories ----
    sourcedir = 'D:/NEU/ImageRestoration/div2k/Set5/HR/'
    savedir = 'D:/NEU/ImageRestoration/div2k/Set5/'

    # set random seed
    util.set_random_seed(0)

    # Build the PCA matrix from a batch of random kernels. (An earlier
    # variant loaded a precomputed matrix from a .pth file instead.)
    batch_ker = util.random_batch_kernel(batch=1000,
                                         k=15,
                                         sig_min=0.2,
                                         sig_max=4.0,
                                         rate_iso=1.0,
                                         scaling=3,
                                         tensor=False)
    print('batch kernel shape: {}'.format(batch_ker.shape))
    b = np.size(batch_ker, 0)
    batch_ker = batch_ker.reshape((b, -1))
    dim_pca = 15  # original author's note: result is torch.Size([225, 15])
    pca_matrix = util.PCA(batch_ker, dim_pca).float()
    print('PCA matrix shape: {}'.format(pca_matrix.shape))

    # Only the blurred LR output is produced; the HR / LR / bicubic outputs
    # of earlier versions are disabled.
    saveLRblurpath = os.path.join(savedir, 'LRblur_sig{}'.format(sig),
                                  'X' + str(scale))

    if not os.path.isdir(sourcedir):
        print('Error: No source data found')
        # NOTE(review): status 0 on an error path is kept for backward
        # compatibility -- confirm whether a non-zero status is wanted.
        exit(0)

    # Create the whole output hierarchy in one call; warn when an existing
    # folder is about to be overwritten.
    overwriting = os.path.isdir(saveLRblurpath)
    os.makedirs(saveLRblurpath, exist_ok=True)
    if overwriting:
        print('It will cover ' + str(saveLRblurpath))

    # Remember to change the extension filter to match the dataset's image
    # format, otherwise nothing is generated.
    filepaths = sorted(f for f in os.listdir(sourcedir) if f.endswith('.bmp'))
    if not filepaths:
        print('Warning: no .bmp files found in ' + sourcedir)

    for index, filename in enumerate(filepaths, start=1):
        print('No.{} -- Processing {}'.format(index, filename))

        # read image (once); cv2.imread returns None for unreadable files
        source_path = os.path.join(sourcedir, filename)
        image = cv2.imread(source_path)
        if image is None:
            raise IOError('Failed to read image: {}'.format(source_path))

        # modcrop: trim both sides to a multiple of mod_scale (the slice
        # works for 2-D and 3-D arrays alike)
        height = image.shape[0] // mod_scale
        width = image.shape[1] // mod_scale
        image_HR = image[:mod_scale * height, :mod_scale * width]

        # LR_blur with a fixed kernel width (random=False, sig=sig)
        img_HR = util.img2tensor(image_HR)
        C, H, W = img_HR.size()
        prepro = util.SRMDPreprocessing(scale,
                                        pca_matrix,
                                        random=False,
                                        kernel=15,
                                        noise=False,
                                        cuda=True,
                                        sig=sig,
                                        sig_min=0.2,
                                        sig_max=4.0,
                                        rate_iso=1.0,
                                        scaling=4,
                                        rate_cln=0.2,
                                        noise_high=0.0)
        # The kernel map returned alongside the LR image is not saved here.
        LR_img, _ = prepro(img_HR.view(1, C, H, W))
        image_LR_blur = util.tensor2img(LR_img)
        cv2.imwrite(
            os.path.join(saveLRblurpath, 'sig{}_'.format(sig) + filename),
            image_LR_blur)

    # Original author's note: on Windows, paths must not contain special
    # characters.
    print("Image Blurring & Down smaple Done: X" + str(scale))
test_results['ssim_y'] = [] for test_data in test_loader: need_GT = False if test_loader.dataset.opt[ 'dataroot_GT'] is None else True img_path = test_data['GT_path'][0] if need_GT else test_data[ 'LQ_path'][0] img_name = os.path.splitext(os.path.basename(img_path))[0] #### preprocessing for LR_img and kernel map prepro = util.SRMDPreprocessing( opt_F['scale'], pca_matrix, random=False, para_input=opt_F['code_length'], noise=False, cuda=True, sig=opt_F['sig'], sig_min=opt_F['sig_min'], sig_max=opt_F['sig_max'], rate_iso=1.0, scaling=3, rate_cln=0.2, noise_high=0.0) # random(sig_min, sig_max) | stable kernel(sig) LR_img, ker_map = prepro(test_data['GT']) model_F.feed_data(test_data, LR_img, ker_map) model_F.test() F_visuals = model_F.get_current_visuals() sr_img = util.tensor2img(F_visuals['SR']) # uint8