import datetime
import os

import numpy as np
import torch
from torch.optim import lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Repo-local helpers used below (ImageSet, collateFunction, get_crop_mask, get_loss,
# register_batch, apply_shifts, cSSIM, cMSE, shift_cMSE, shift_cPSNR, mean_squared_error)
# are assumed to be importable from the surrounding project; their exact import paths are
# not reproduced here.


def get_sr_and_score(imset, model, min_L=16):
    '''
    Super resolves an imset with a given model.
    Args:
        imset: imageset
        model: HRNet, pytorch model
        min_L: int, pad length
    Returns:
        sr: numpy array (W, H), super resolved image
        scPSNR: float, shift cPSNR score (None if no HR image is available)
    '''

    if imset.__class__ is ImageSet:
        collator = collateFunction(min_L=min_L)
        lrs, alphas, hrs, hr_maps, names = collator([imset])
    elif isinstance(imset, tuple):  # imset is a tuple of batches
        lrs, alphas, hrs, hr_maps, names = imset

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lrs = lrs.float().to(device)
    alphas = alphas.float().to(device)

    sr = model(lrs, alphas)[:, 0]  # fuse the low-res views, keep the single output channel
    sr = sr.detach().cpu().numpy()[0]

    if len(hrs) > 0:
        scPSNR = shift_cPSNR(sr=np.clip(sr, 0, 1),
                             hr=hrs.numpy()[0],
                             hr_map=hr_maps.numpy()[0])
    else:
        scPSNR = None

    return sr, scPSNR
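# The scoring above relies on shift_cPSNR: a cPSNR that (a) only counts pixels marked clear
# in the HR clearance map, (b) removes a per-image brightness bias, and (c) is maximized over
# small integer translations so sub-pixel misregistration is not penalized. The toy function
# below is a minimal, self-contained sketch of that idea, not the repo's implementation; the
# border size and cropping convention are illustrative assumptions.
def _toy_shift_cpsnr(sr, hr, hr_map, max_shift=3):
    """sr, hr, hr_map: 2D float arrays in [0, 1]; hr_map is a 0/1 map of clear (cloud-free) pixels."""
    c = max_shift
    sr_crop = sr[c:-c, c:-c]  # crop SR borders so the crop can slide inside HR
    best = -np.inf
    for du in range(2 * c + 1):
        for dv in range(2 * c + 1):
            hr_crop = hr[du:du + sr_crop.shape[0], dv:dv + sr_crop.shape[1]]
            mask = hr_map[du:du + sr_crop.shape[0], dv:dv + sr_crop.shape[1]]
            n = mask.sum()
            if n == 0:
                continue
            b = ((hr_crop - sr_crop) * mask).sum() / n                 # brightness bias on clear pixels
            cmse = (((hr_crop - sr_crop - b) * mask) ** 2).sum() / n   # bias-corrected, masked MSE
            best = max(best, -10 * np.log10(cmse))
    return best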
def trainAndGetBestModel(fusion_model, regis_model, optimizer, dataloaders, baseline_cpsnrs, config):
    """
    Trains HRNet and ShiftNet for Multi-Frame Super Resolution (MFSR), and saves the best model.
    Args:
        fusion_model: torch.model, HRNet
        regis_model: torch.model, ShiftNet
        optimizer: torch.optim, optimizer to minimize loss
        dataloaders: dict, wraps train and validation dataloaders
        baseline_cpsnrs: dict, ESA baseline scores
        config: dict, configuration file
    """

    np.random.seed(123)  # seed all RNGs for reproducibility
    torch.manual_seed(123)

    num_epochs = config["training"]["num_epochs"]
    batch_size = config["training"]["batch_size"]
    n_views = config["training"]["n_views"]
    min_L = config["training"]["min_L"]  # minimum number of views
    beta = config["training"]["beta"]

    subfolder_pattern = 'batch_{}_views_{}_min_{}_beta_{}_time_{}'.format(
        batch_size, n_views, min_L, beta,
        f"{datetime.datetime.now():%Y-%m-%d-%H-%M-%S-%f}")

    checkpoint_dir_run = os.path.join(config["paths"]["checkpoint_dir"], subfolder_pattern)
    os.makedirs(checkpoint_dir_run, exist_ok=True)

    tb_logging_dir = config['paths']['tb_log_file_dir']
    logging_dir = os.path.join(tb_logging_dir, subfolder_pattern)
    os.makedirs(logging_dir, exist_ok=True)

    writer = SummaryWriter(logging_dir)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    best_score = 100

    P = config["training"]["patch_size"]
    offset = (3 * config["training"]["patch_size"] - 128) // 2
    C = config["training"]["crop"]
    torch_mask = get_crop_mask(patch_size=P, crop_size=C)
    torch_mask = torch_mask.to(device)  # crop borders (loss)

    fusion_model.to(device)
    regis_model.to(device)

    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                               factor=config['training']['lr_decay'],
                                               verbose=True,
                                               patience=config['training']['lr_step'])

    for epoch in tqdm(range(1, num_epochs + 1)):

        # Train
        fusion_model.train()
        regis_model.train()
        train_loss = 0.0  # monitor train loss

        # Iterate over data.
        for lrs, alphas, hrs, hr_maps, names in tqdm(dataloaders['train']):

            optimizer.zero_grad()  # zero the parameter gradients
            lrs = lrs.float().to(device)
            alphas = alphas.float().to(device)
            hr_maps = hr_maps.float().to(device)
            hrs = hrs.float().to(device)

            # torch.autograd.set_detect_anomaly(mode=True)
            srs = fusion_model(lrs, alphas)  # fuse multi frames (B, 1, 3*W, 3*H)

            # Register batch wrt HR
            shifts = register_batch(regis_model,
                                    srs[:, :, offset:(offset + 128), offset:(offset + 128)],
                                    reference=hrs[:, offset:(offset + 128), offset:(offset + 128)].view(-1, 1, 128, 128))
            srs_shifted = apply_shifts(regis_model, srs, shifts, device)[:, 0]

            # Training loss
            cropped_mask = hr_maps  # torch_mask[0] * hr_maps  # Compute current mask (Batch size, W, H)
            # srs_shifted = torch.clamp(srs_shifted, min=0.0, max=1.0)  # correct over/under-shoots
            loss = -get_loss(srs_shifted, hrs, cropped_mask, metric='cPSNR')
            loss = torch.mean(loss)
            loss += config["training"]["lambda"] * torch.mean(shifts) ** 2

            # Backprop
            loss.backward()
            optimizer.step()
            epoch_loss = loss.detach().cpu().numpy() * len(hrs) / len(dataloaders['train'].dataset)
            train_loss += epoch_loss

        # Eval
        fusion_model.eval()
        val_score = 0.0  # monitor val score

        for lrs, alphas, hrs, hr_maps, names in dataloaders['val']:
            lrs = lrs.float().to(device)
            alphas = alphas.float().to(device)
            hrs = hrs.numpy()
            hr_maps = hr_maps.numpy()

            srs = fusion_model(lrs, alphas)[:, 0]  # fuse multi frames (B, 1, 3*W, 3*H)

            # compute ESA score
            srs = srs.detach().cpu().numpy()
            baseline_cpsnrs = None  # ESA baseline normalization disabled; score is negative shift cPSNR
            for i in range(srs.shape[0]):  # batch size
                if baseline_cpsnrs is None:
                    val_score -= shift_cPSNR(np.clip(srs[i], 0, 1), hrs[i], hr_maps[i])
                # else:
                #     ESA = baseline_cpsnrs[names[i]]
                #     val_score += ESA / shift_cPSNR(np.clip(srs[i], 0, 1), hrs[i], hr_maps[i])

        val_score /= len(dataloaders['val'].dataset)

        # Checkpoint only when the validation score improves (lower is better).
        if best_score > val_score:
            torch.save(fusion_model.state_dict(), os.path.join(checkpoint_dir_run, 'HRNet.pth'))
            torch.save(regis_model.state_dict(), os.path.join(checkpoint_dir_run, 'ShiftNet.pth'))
            best_score = val_score

        writer.add_image('SR Image', (srs[0] - np.min(srs[0])) / np.max(srs[0]), epoch, dataformats='HW')
        error_map = hrs[0] - srs[0]
        writer.add_image('Error Map', error_map, epoch, dataformats='HW')
        writer.add_scalar("train/loss", train_loss, epoch)
        writer.add_scalar("train/val_loss", val_score, epoch)
        scheduler.step(val_score)

    writer.close()
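# The training step above negates a per-sample cPSNR computed over the clearance mask and adds
# a penalty on the estimated shifts. The repo's get_loss is not reproduced here; the function
# below is a hedged, self-contained sketch of one plausible differentiable masked cPSNR,
# shown only to make the sign convention and masking explicit.
def _toy_masked_cpsnr(srs, hrs, masks, eps=1e-10):
    """srs, hrs, masks: float tensors of shape (B, H, W); masks are 0/1 clearance maps."""
    n = masks.sum(dim=(1, 2)).clamp(min=1)
    b = ((hrs - srs) * masks).sum(dim=(1, 2)) / n                        # per-sample brightness bias
    cmse = (((hrs - srs - b[:, None, None]) * masks) ** 2).sum(dim=(1, 2)) / n
    return -10 * torch.log10(cmse + eps)                                  # per-sample cPSNR, shape (B,)

# Usage mirroring the training step above (illustrative only):
#   loss = -_toy_masked_cpsnr(srs_shifted, hrs, cropped_mask).mean()
#   loss = loss + config["training"]["lambda"] * torch.mean(shifts) ** 2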
def get_sr_and_score(imset, model, aposterior_gt, next_sr, num_frames, min_L=16):
    '''
    Super resolves an imset with a given model and computes evaluation metrics.
    Args:
        imset: imageset
        model: HRNet, pytorch model
        aposterior_gt: 2D numpy array or None, a-posteriori ground truth for SSIM comparison
        next_sr: 2D numpy array or None, super resolved image of the next step, for delta metrics
        num_frames: int, number of low-res frames passed to the collator
        min_L: int, pad length
    Returns:
        sr: numpy array (W, H), super resolved image
        val_gt_SSIM, val_aposterior_SSIM, val_cPSNR, val_usual_PSNR, val_shift_cPSNR,
        val_cMSE, val_L2, val_shift_cMSE, val_delta_cMSE, val_delta_L2, val_delta_shift_cMSE:
            floats (or None when the corresponding reference is unavailable)
    '''

    if imset.__class__ is ImageSet:
        collator = collateFunction(num_frames, min_L=min_L)
        lrs, alphas, hrs, hr_maps, names = collator([imset])
    elif isinstance(imset, tuple):  # imset is a tuple of batches
        lrs, alphas, hrs, hr_maps, names = imset

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # print("LRS SHAPE:", lrs.shape)
    # print("ALPHAS SHAPE", alphas.shape)
    # lrs = lrs[:, :num_frames, :, :]
    # alphas = alphas[:, :num_frames]

    lrs = lrs.float().to(device)
    alphas = alphas.float().to(device)

    sr = model(lrs, alphas)[:, 0]
    sr = sr.detach().cpu().numpy()[0]
    sr = np.clip(sr, 0, 1)
    # sr = downscale_local_mean(sr, (2, 2))

    cur_hr = hrs.numpy()[0]
    cur_hr_map = hr_maps.numpy()[0]
    cur_sr = sr
    # cur_hr = downscale_local_mean(cur_hr, (2, 2))
    # cur_hr_map = downscale_local_mean(cur_hr_map, (2, 2))

    assert cur_sr.ndim == 2
    assert cur_hr.ndim == 2
    assert cur_hr_map.ndim == 2

    if cur_sr.dtype.type is np.uint16:  # integer array is in the range [0, 65535]
        cur_sr = cur_sr / np.iinfo(np.uint16).max  # normalize in the range [0, 1]
    else:
        assert 0 <= cur_sr.min() and cur_sr.max() <= 1, \
            'sr.dtype must be either uint16 (range 0-65535) or float64 in (0, 1).'
    if cur_hr.dtype.type is np.uint16:
        cur_hr = cur_hr / np.iinfo(np.uint16).max

    if len(hrs) > 0:
        val_gt_SSIM = cSSIM(sr=cur_sr, hr=cur_hr)
        val_L2 = mean_squared_error(cur_hr, cur_sr)
    else:
        val_gt_SSIM = None
        val_L2 = None

    if aposterior_gt is None:
        val_aposterior_SSIM = 1.0
    else:
        val_aposterior_SSIM = cSSIM(sr=cur_sr, hr=aposterior_gt)

    if next_sr is None:
        val_delta_L2 = None
    else:
        assert next_sr.ndim == 2
        val_delta_L2 = mean_squared_error(next_sr, cur_sr)

    if len(cur_sr.shape) == 2:
        cur_sr = cur_sr[None, ]
        cur_hr = cur_hr[None, ]
        cur_hr_map = cur_hr_map[None, ]

    if len(hrs) > 0:
        val_cMSE = cMSE(sr=cur_sr, hr=cur_hr, hr_map=cur_hr_map)
        val_cPSNR = -10 * np.log10(val_cMSE)
        val_usual_PSNR = -10 * np.log10(val_L2)
        val_shift_cPSNR = shift_cPSNR(sr=cur_sr, hr=cur_hr, hr_map=cur_hr_map)
        val_shift_cMSE = shift_cMSE(sr=cur_sr, hr=cur_hr, hr_map=cur_hr_map)
    else:
        val_cMSE = None
        val_cPSNR = None
        val_usual_PSNR = None
        val_shift_cPSNR = None
        val_shift_cMSE = None

    if next_sr is None:
        val_delta_cMSE = None
        val_delta_shift_cMSE = None
    else:
        if next_sr.dtype.type is np.uint16:  # integer array is in the range [0, 65535]
            next_sr = next_sr / np.iinfo(np.uint16).max  # normalize in the range [0, 1]
        else:
            assert 0 <= next_sr.min() and next_sr.max() <= 1, \
                'sr.dtype must be either uint16 (range 0-65535) or float64 in (0, 1).'
        if next_sr.ndim == 2:
            next_sr = next_sr[None, ]  # match the (1, W, H) layout of cur_sr
        val_delta_cMSE = cMSE(sr=cur_sr, hr=next_sr, hr_map=cur_hr_map)
        val_delta_shift_cMSE = shift_cMSE(sr=cur_sr, hr=next_sr, hr_map=cur_hr_map)

    return sr, val_gt_SSIM, val_aposterior_SSIM, val_cPSNR, val_usual_PSNR, val_shift_cPSNR, val_cMSE, \
        val_L2, val_shift_cMSE, val_delta_cMSE, val_delta_L2, val_delta_shift_cMSE
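# The extended get_sr_and_score returns a long positional tuple, which is easy to misorder in
# downstream logging code. The small helper below is a hypothetical convenience (not part of the
# repo) that names the eleven metric values in the exact order they are returned above.
METRIC_NAMES = ("gt_SSIM", "aposterior_SSIM", "cPSNR", "usual_PSNR", "shift_cPSNR",
                "cMSE", "L2", "shift_cMSE", "delta_cMSE", "delta_L2", "delta_shift_cMSE")


def sr_metrics_as_dict(outputs):
    """Split the tuple returned by get_sr_and_score into (sr, {metric_name: value})."""
    sr, *values = outputs
    return sr, dict(zip(METRIC_NAMES, values))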