def get_sr_and_score(imset, model, min_L=16):
    '''
    Super resolves an imageset with a given model and scores the result.

    Args:
        imset: ImageSet, or a tuple ``(lrs, alphas, hrs, hr_maps, names)`` of
            already-collated batches.
        model: HRNet, pytorch model, called as ``model(lrs, alphas)``.
        min_L: int, pad length passed to ``collateFunction`` when `imset` is
            an ImageSet.

    Returns:
        sr: np.ndarray, ``model(lrs, alphas)[:, 0]`` for the first batch item
            (i.e. the two leading axes removed). It is returned unclipped.
        scPSNR: float or None, shift cPSNR of ``clip(sr, 0, 1)`` against the
            first HR image / status map; None when the batch carries no HR.

    Raises:
        TypeError: if `imset` is neither a tuple nor an ImageSet.
    '''
    # Check for the plain tuple first: it is the cheap, common path for
    # pre-collated batches.
    if isinstance(imset, tuple):  # imset is a tuple of batches
        lrs, alphas, hrs, hr_maps, _names = imset
    elif isinstance(imset, ImageSet):
        collator = collateFunction(min_L=min_L)
        lrs, alphas, hrs, hr_maps, _names = collator([imset])
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise TypeError(
            f"imset must be an ImageSet or a tuple of batches, got {type(imset).__name__}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lrs = lrs.float().to(device)
    alphas = alphas.float().to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        sr = model(lrs, alphas)[:, 0]
    sr = sr.detach().cpu().numpy()[0]

    if len(hrs) > 0:
        # Only the scoring uses the clipped image; `sr` itself is returned raw.
        scPSNR = shift_cPSNR(sr=np.clip(sr, 0, 1), hr=hrs.numpy()[0], hr_map=hr_maps.numpy()[0])
    else:
        scPSNR = None
    return sr, scPSNR
def main(config):
    """
    Train HRNet (fusion) and ShiftNet (registration) for Multi-Frame Super
    Resolution (MFSR).

    Seeds NumPy/PyTorch and forces deterministic cuDNN, builds both networks
    with one Adam optimizer over their joint parameters, splits the image-set
    directories under ``<prefix>/train`` into train/validation lists, wraps
    them in DataLoaders and hands everything to ``trainAndGetBestModel``.

    Args:
        config: dict, configuration. Keys read here: ``config["network"]``,
            ``config["paths"]["prefix"]`` and, under ``config["training"]``,
            ``lr``, ``val_proportion``, ``batch_size``, ``n_workers``,
            ``n_views``, ``min_L`` and ``beta``.

    Side effects:
        Sets ``config["training"]["create_patches"] = False`` in place. This is
        kept from the original because the same ``config`` object is forwarded
        to ``trainAndGetBestModel``.
    """
    # Reproducibility options
    np.random.seed(0)  # RNG seeds
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Initialize the networks and a single optimizer over both parameter sets.
    fusion_model = HRNet(config["network"])
    regis_model = ShiftNet()
    optimizer = optim.Adam(
        list(fusion_model.parameters()) + list(regis_model.parameters()),
        lr=config["training"]["lr"],
    )

    # ESA dataset
    data_directory = config["paths"]["prefix"]
    norm_path = os.path.join(data_directory, "norm.csv")
    baseline_cpsnrs = readBaselineCPSNR(norm_path) if os.path.exists(norm_path) else None

    train_set_directories = getImageSetDirectories(os.path.join(data_directory, "train"))
    train_list, val_list = train_test_split(
        train_set_directories,
        test_size=config["training"]["val_proportion"],
        random_state=1,
        shuffle=True,
    )

    # Dataloaders
    training_config = config["training"]
    batch_size = training_config["batch_size"]
    n_workers = training_config["n_workers"]
    n_views = training_config["n_views"]
    min_L = training_config["min_L"]  # minimum number of views
    beta = training_config["beta"]

    # Give the training dataset its own (shallow) snapshot of the training
    # config. ``create_patches`` is switched off in place below for the
    # validation set; if ImagesetDataset keeps a reference to the dict it is
    # given (not visible from here), that switch would otherwise also disable
    # patch creation for the training set.
    train_dataset = ImagesetDataset(imset_dir=train_list, config=dict(training_config),
                                    top_k=n_views, beta=beta)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=n_workers,
                                  collate_fn=collateFunction(min_L=min_L),
                                  pin_memory=True)

    # Validation: patch creation disabled, one imageset per batch, fixed order.
    training_config["create_patches"] = False
    val_dataset = ImagesetDataset(imset_dir=val_list, config=training_config,
                                  top_k=n_views, beta=beta)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False,
                                num_workers=n_workers,
                                collate_fn=collateFunction(min_L=min_L),
                                pin_memory=True)

    dataloaders = {'train': train_dataloader, 'val': val_dataloader}

    # Train model
    torch.cuda.empty_cache()
    # To resume from a checkpoint, load previously saved state dicts
    # (e.g. HRNet.pth / ShiftNet.pth) into fusion_model / regis_model here.
    trainAndGetBestModel(fusion_model, regis_model, optimizer, dataloaders, baseline_cpsnrs, config)
def _to_unit_range(arr, check_range):
    '''Rescale a uint16 image to [0, 1]; optionally assert a float image already lies in [0, 1].'''
    if arr.dtype.type is np.uint16:
        # integer array is in the range [0, 65535] -> normalize to [0, 1]
        return arr / np.iinfo(np.uint16).max
    if check_range:
        assert 0 <= arr.min() and arr.max(
        ) <= 1, 'sr.dtype must be either uint16 (range 0-65536) or float64 in (0, 1).'
    return arr


def get_sr_and_score(imset, model, aposterior_gt, next_sr, num_frames, min_L=16):
    '''
    Super resolves an imageset with a given model and computes quality metrics.

    Args:
        imset: ImageSet, or a tuple ``(lrs, alphas, hrs, hr_maps, names)`` of
            already-collated batches.
        model: HRNet, pytorch model, called as ``model(lrs, alphas)``.
        aposterior_gt: np.ndarray or None, reference image for the a-posteriori
            SSIM. When None, the a-posteriori SSIM is reported as 1.0.
        next_sr: 2-D np.ndarray or None, another SR estimate to compare against
            (the "delta" metrics). uint16 input is rescaled to [0, 1] before use.
        num_frames: int, passed to ``collateFunction`` when `imset` is an ImageSet.
        min_L: int, pad length passed to ``collateFunction``.

    Returns:
        A 12-tuple ``(sr, val_gt_SSIM, val_aposterior_SSIM, val_cPSNR,
        val_usual_PSNR, val_shift_cPSNR, val_cMSE, val_L2, val_shift_cMSE,
        val_delta_cMSE, val_delta_L2, val_delta_shift_cMSE)``.

        - `sr` is the 2-D super-resolved image, clipped to [0, 1].
        - Ground-truth metrics are None when the batch carries no HR image.
        - `val_delta_L2` is None when `next_sr` is None.
        - `val_delta_cMSE` / `val_delta_shift_cMSE` additionally require the HR
          status map, so they are also None without HR.

    Raises:
        TypeError: if `imset` is neither a tuple nor an ImageSet.
        AssertionError: on non 2-D images or float images outside [0, 1].
    '''
    if isinstance(imset, tuple):  # imset is a tuple of batches
        lrs, alphas, hrs, hr_maps, _names = imset
    elif isinstance(imset, ImageSet):
        collator = collateFunction(num_frames, min_L=min_L)
        lrs, alphas, hrs, hr_maps, _names = collator([imset])
    else:
        raise TypeError(
            f"imset must be an ImageSet or a tuple of batches, got {type(imset).__name__}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lrs = lrs.float().to(device)
    alphas = alphas.float().to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        sr = model(lrs, alphas)[:, 0]
    sr = np.clip(sr.detach().cpu().numpy()[0], 0, 1)

    assert (sr.ndim == 2)
    cur_sr = _to_unit_range(sr, check_range=True)

    # Ground truth is optional (e.g. test imagesets). The original indexed
    # hrs/hr_maps before checking, so this case crashed.
    has_hr = len(hrs) > 0
    if has_hr:
        cur_hr = hrs.numpy()[0]
        cur_hr_map = hr_maps.numpy()[0]
        assert (cur_hr.ndim == 2)
        assert (cur_hr_map.ndim == 2)
        cur_hr = _to_unit_range(cur_hr, check_range=False)
        val_gt_SSIM = cSSIM(sr=cur_sr, hr=cur_hr)
        val_L2 = mean_squared_error(cur_hr, cur_sr)
    else:
        cur_hr = cur_hr_map = None
        val_gt_SSIM = None
        val_L2 = None

    if aposterior_gt is None:
        val_aposterior_SSIM = 1.0
    else:
        val_aposterior_SSIM = cSSIM(sr=cur_sr, hr=aposterior_gt)

    # Normalize next_sr once, up front, so every delta metric sees the same data.
    if next_sr is None:
        val_delta_L2 = None
    else:
        assert (next_sr.ndim == 2)
        next_sr = _to_unit_range(next_sr, check_range=True)
        val_delta_L2 = mean_squared_error(next_sr, cur_sr)

    # The c*MSE / c*PSNR helpers are called with a leading batch axis.
    sr_batch = cur_sr[None, ]

    if has_hr:
        hr_batch = cur_hr[None, ]
        hr_map_batch = cur_hr_map[None, ]
        val_cMSE = cMSE(sr=sr_batch, hr=hr_batch, hr_map=hr_map_batch)
        val_cPSNR = -10 * np.log10(val_cMSE)
        val_usual_PSNR = -10 * np.log10(val_L2)
        val_shift_cPSNR = shift_cPSNR(sr=sr_batch, hr=hr_batch, hr_map=hr_map_batch)
        val_shift_cMSE = shift_cMSE(sr=sr_batch, hr=hr_batch, hr_map=hr_map_batch)
    else:
        hr_map_batch = None
        val_cMSE = None
        val_cPSNR = None
        val_usual_PSNR = None
        val_shift_cPSNR = None
        val_shift_cMSE = None

    if next_sr is None or hr_map_batch is None:
        val_delta_cMSE = None
        val_delta_shift_cMSE = None
    else:
        # Give next_sr the same batch axis as sr/hr_map. The original guarded
        # this on cur_sr's rank after cur_sr was already expanded, so it never ran.
        next_batch = next_sr[None, ]
        val_delta_cMSE = cMSE(sr=sr_batch, hr=next_batch, hr_map=hr_map_batch)
        val_delta_shift_cMSE = shift_cMSE(sr=sr_batch, hr=next_batch, hr_map=hr_map_batch)

    return sr, val_gt_SSIM, val_aposterior_SSIM, val_cPSNR, val_usual_PSNR, val_shift_cPSNR, val_cMSE, \
        val_L2, val_shift_cMSE, val_delta_cMSE, val_delta_L2, val_delta_shift_cMSE