def test_blockmatching_usage():
    z, sigma = generate_input_data()
    z2 = np.ones(z.shape)
    res_cr, bm = bm3d(z2, sigma, blockmatches=(True, True))
    res1 = bm3d(z, sigma)
    res2 = bm3d(z, sigma, blockmatches=bm)
    assert np.max(np.abs(res1 - res2)) > ALLOWED_ERROR_SAME
def test_consistency_traditional():
    z, sigma = generate_input_data()
    p = BM3DProfile()
    p.nf = 0
    res1 = bm3d(np.copy(z), np.copy(sigma), p)
    res2 = bm3d(np.copy(z), np.copy(sigma), p)
    assert np.max(np.abs(res2 - res1)) < ALLOWED_ERROR_SAME
def BM3D_proj(S1, S2, lambda_this):
    """ This function does the BM3D projection """
    # BM3D denoising
    S1 = bm3d(S1, lambda_this)
    S2 = bm3d(S2, lambda_this)
    return S1, S2
def test_blockmatches_wie_inconsistency():
    z, sigma = generate_input_data()
    z2, sigma = generate_input_data(2)
    res_ref, bms = bm3d(z2, sigma, blockmatches=(True, True))
    res_ht = bm3d(z, sigma, stage_arg=BM3DStages.HARD_THRESHOLDING)
    res1 = bm3d(z, sigma, stage_arg=res_ht)
    res2 = bm3d(z, sigma, stage_arg=res_ht, blockmatches=bms)
    assert np.max(np.abs(res1 - res2)) != 0
def bm3d_proj(S, lambda_this):
    """ This function does the BM3D projection """
    S1 = np.reshape(S[0, :], image_size)
    S2 = np.reshape(S[1, :], image_size)
    # BM3D denoising
    S1 = bm3d(S1, lambda_this)
    S2 = bm3d(S2, lambda_this)
    S[0, :] = np.reshape(S1, (1, n * n))
    S[1, :] = np.reshape(S2, (1, n * n))
    return S
def BM3D_proj(S, image_size, lambda_this):
    """ This function does the BM3D projection """
    S1 = np.reshape(S[0, :], image_size)
    S2 = np.reshape(S[1, :], image_size)
    # BM3D denoising
    S1 = bm3d(S1, lambda_this)
    S2 = bm3d(S2, lambda_this)
    S[0, :] = np.reshape(S1, (1, image_size[0] * image_size[1]))
    S[1, :] = np.reshape(S2, (1, image_size[0] * image_size[1]))
    return S
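# Minimal usage sketch for the three-argument BM3D_proj above. The 64x64 image
# size, the random test data and lambda value are illustrative assumptions, not
# values taken from the original code.
import numpy as np
from bm3d import bm3d

image_size = (64, 64)
S = np.random.rand(2, image_size[0] * image_size[1])  # two flattened images, one per row
S_denoised = BM3D_proj(S.copy(), image_size, lambda_this=0.1)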
def calc(self, projs, theta, sart_plot=False):
    image_r = super(SartBM3DReconstructor, self).calc(projs, theta, sart_plot=sart_plot)
    # denoise with BM3D
    self.image_r = bm3d(image_r, self.bm3d_sigma)
    return self.image_r
def bm3d_denoise(fn, save_audio, plot):
    power_sp = np.load(fn)
    db = np.log10(power_sp)
    z = np.atleast_3d(db)
    db_est = bm3d(z, np.sqrt(smoothing_factor), profile=profile)
    save_files(db, db_est, fn, power_sp, save_audio=save_audio, plot=plot)
def fast_hyde_eigen_image_denoising(img, k_subspace, r_w, e, eigen_y, n) -> np.ndarray:
    # %% --------------------------Eigen-image denoising ------------------------------------
    # send slices of the image to the GPU if that is the case,
    rows, cols, b = img.shape
    np_dtype = np.float32 if img.dtype is torch.float32 else np.float64
    eigen_y_bm3d = np.empty((k_subspace, n), dtype=np_dtype)
    ecpu = e.to(device="cpu", non_blocking=True)
    r_w = r_w.to(device="cpu", non_blocking=True)
    nxt_eigen = eigen_y[0].cpu()
    mx = min(k_subspace, eigen_y.shape[0])
    for i in range(mx):
        lp_eigen = nxt_eigen.numpy()
        if i < mx - 1:
            nxt_eigen = eigen_y[i + 1].to(device="cpu", non_blocking=True)
        # produce eigen-image
        eigen_im = lp_eigen
        min_x = np.min(eigen_im)
        max_x = np.max(eigen_im)
        eigen_im -= min_x
        scale = max_x - min_x
        # normalize eigen_im
        eigen_im = np.reshape(eigen_im, (rows, cols)) / scale
        if i == 0:
            ecpu = ecpu.numpy()
            r_w = r_w.numpy()
        sigma = np.sqrt(ecpu[:, i].T @ r_w @ ecpu[:, i]) / scale
        filt_eigen_im = bm3d.bm3d(eigen_im, sigma)
        eigen_y_bm3d[i, :] = (filt_eigen_im * scale + min_x).reshape(
            eigen_y_bm3d[i, :].shape)
    return eigen_y_bm3d
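# Minimal usage sketch for fast_hyde_eigen_image_denoising above. All shapes,
# the subspace dimension and the noise statistics are illustrative assumptions
# (small CPU torch tensors), not values from the original pipeline.
import numpy as np
import torch

rows, cols, bands = 32, 32, 16
k_subspace = 3
n = rows * cols
img = torch.rand(rows, cols, bands)                # hyperspectral cube (rows, cols, bands)
e = torch.rand(bands, k_subspace)                  # subspace basis
r_w = torch.eye(bands) * 1e-4                      # noise covariance estimate
eigen_y = torch.rand(k_subspace, n)                # flattened eigen-images
den = fast_hyde_eigen_image_denoising(img, k_subspace, r_w, e, eigen_y, n)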
def BM3D(noisy_images: np.ndarray, noise_std_dev: float) -> np.ndarray:
    """
    Params:
        noisy_images: noisy images with shape (IMG, M, N, C), where IMG is the
            number of noisy images of size MxN and C is the number of color
            channels, i.e. C = 3 for RGB images and C = 1 for grayscale.
        noise_std_dev: the standard deviation of the noise.
    """
    validate_array_input(noisy_images)
    validate_if_noise_std_dev_is_a_float(noise_std_dev)
    filtered_images = []
    for i in range(noisy_images.shape[0]):
        filtered_images.append(
            bm3d.bm3d(
                noisy_images[i, :, :, :],
                sigma_psd=noise_std_dev,
                stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING
            )
        )
    filtered_images = np.array(filtered_images)
    return filtered_images
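# Minimal usage sketch for the batch BM3D wrapper above, assuming the validator
# helpers it calls are available in scope. The batch shape and noise level are
# illustrative assumptions.
import numpy as np

noisy_batch = np.random.rand(4, 64, 64, 1).astype(np.float32)  # 4 grayscale images
denoised_batch = BM3D(noisy_batch, noise_std_dev=0.1)
print(denoised_batch.shape)  # bm3d may squeeze the singleton channel axis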
def filterBM3D(sLayers):
    optimg = copy.copy(sLayers[:])
    for i in range(len(sLayers)):
        copySAR = copy.copy(sLayers[:][i])
        sarArrayI = np.array(copySAR)
        denoised_image = bm3d.bm3d(sLayers[i], sigma_psd=4,
                                   stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)
        optimg[i] = denoised_image
    return optimg
def bm3dPrior(m_, simul_params, pnp_params):
    ngx = simul_params["ngx"]
    ngy = simul_params["ngy"]
    sigma = pnp_params["dn_sigma"]
    m = m_.reshape(ngy, ngx)
    m_denoise = bm3d(m, sigma)
    m_denoise = m_denoise.reshape(ngy * ngx, 1)
    return m_denoise
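# Minimal usage sketch for bm3dPrior above; the grid size, sigma value and
# random model vector are illustrative assumptions.
import numpy as np
from bm3d import bm3d

simul_params = {"ngx": 64, "ngy": 64}
pnp_params = {"dn_sigma": 0.05}
m_flat = np.random.rand(64 * 64, 1)                 # flattened model vector
m_denoised = bm3dPrior(m_flat, simul_params, pnp_params)  # shape (64*64, 1)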
def bm3d_fil():
    # Block Matching & 3D Filter works fine
    if locerror['text'] == "":
        messagebox.showerror("Error", "Please choose the location")
    else:
        img = cv2.imread(dir, 1)
        img1 = img_as_float(img)
        bm3d_img = bm3d.bm3d(img1, sigma_psd=0.05,
                             stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)
        cv2.imshow("Block Matching & 3D Filter", bm3d_img)
        cv2.waitKey()
        cv2.destroyAllWindows()
def run(self, rgb_img, nr_method, nr_parm, no_blend, mask_level,
        blend_morph_kernel, blend_alpha=0.5):
    y, u, v = self.yuv_split(rgb_img)
    laplacian_pyramid = self.lap_pyr(y)
    # rec = self.lap_pyr_rec(laplacian_pyramid)
    # imageio.imwrite('test.jpg', img_as_ubyte(rec))
    # imageio.imwrite('test_lap.jpg', img_as_float(laplacian_pyramid[-1]))
    #
    # laplacian_pyramid[-1] = np.zeros_like(laplacian_pyramid[-1])
    # rec = self.lap_pyr_rec(laplacian_pyramid)
    # imageio.imwrite('test_zero.jpg', img_as_ubyte(rec))
    if nr_method == 'freq':
        y_denoised = self.freq_domain_denoise(y, threshold_sigma=nr_parm)
    elif nr_method == 'freq_smooth':
        y_denoised = self.freq_domain_denoise_smooth(y, threshold_sigma=nr_parm)
    elif nr_method == "median":
        y_denoised = self.median_denoise(y, radius=int(nr_parm))
    elif nr_method == "gaussian":
        y_denoised = self.gaussian_denoise(y, radius=int(nr_parm))
    elif nr_method == 'tvl1':
        y_denoised = cv2.denoise_TVL1(y, y)
    elif nr_method == 'nlm':
        y_denoised = self.nlm_denoise(y, int(nr_parm))
    elif nr_method == 'bm3d':
        y_denoised = bm3d.bm3d(y, nr_parm)
    elif nr_method == 'lap_pyr':
        denoised_pyr = self.pyr_denoise(laplacian_pyramid, nr_parm)
        y_denoised = self.lap_pyr_rec(denoised_pyr)
    elif nr_method == 'bilateral':
        y_denoised = img_as_float(
            cv2.bilateralFilter(img_as_ubyte(y), int(nr_parm), 20, 5))
    else:
        y_denoised = self.nlm_denoise(y, int(nr_parm))
    mask = self.get_blending_mask(laplacian_pyramid,
                                  blend_freq_layer=mask_level,
                                  morph_kernel_size=blend_morph_kernel)
    y_final = self.alpha_blend(y, y_denoised, blend_alpha)
    # y_final = self.mask_blend(y, y_final, mask, no_blend)
    rgb_final = self.yuv_combine(y_final, u, v)
    return np.clip(rgb_final, a_min=0, a_max=1), np.clip(mask, a_min=0, a_max=1)
def main():
    # Experiment specifications
    imagename = 'cameraman256.png'

    # Load noise-free image
    y = np.array(Image.open(imagename)) / 255

    # Possible noise types to be generated 'gw', 'g1', 'g2', 'g3', 'g4', 'g1w',
    # 'g2w', 'g3w', 'g4w'.
    noise_type = 'g3'
    noise_var = 0.02  # Noise variance
    seed = 0  # seed for pseudorandom noise realization

    # Generate noise with given PSD
    noise, psd, kernel = get_experiment_noise(noise_type, noise_var, seed, y.shape)
    # N.B.: For the sake of simulating a more realistic acquisition scenario,
    # the generated noise is *not* circulant. Therefore there is a slight
    # discrepancy between PSD and the actual PSD computed from infinitely many
    # realizations of this noise with different seeds.

    # Generate noisy image corrupted by additive spatially correlated noise
    # with noise power spectrum PSD
    z = np.atleast_3d(y) + np.atleast_3d(noise)

    # Call BM3D with the default settings.
    y_est = bm3d(z, psd)

    # To include refiltering:
    # y_est = bm3d(z, psd, 'refilter')

    # For other settings, use BM3DProfile.
    # profile = BM3DProfile()  # equivalent to profile = BM3DProfile('np')
    # profile.gamma = 6  # redefine value of gamma parameter
    # y_est = bm3d(z, psd, profile)

    # Note: For white noise, you may instead of the PSD
    # also pass a standard deviation
    # y_est = bm3d(z, sqrt(noise_var))

    psnr = get_psnr(y, y_est)
    print("PSNR:", psnr)

    # PSNR ignoring 16-pixel wide borders (as used in the paper), due to
    # refiltering potentially leaving artifacts on the pixels near the boundary
    # of the image when noise is not circulant
    psnr_cropped = get_cropped_psnr(y, y_est, [16, 16])
    print("PSNR cropped:", psnr_cropped)

    # Ignore values outside range for display (or plt gives an error for multichannel input)
    y_est = np.minimum(np.maximum(y_est, 0), 1)
    z_rang = np.minimum(np.maximum(z, 0), 1)
    plt.title("y, z, y_est")
    plt.imshow(np.concatenate((y, np.squeeze(z_rang), y_est), axis=1), cmap='gray')
    plt.show()
def bm3d_filter(image, sigma_psd=50 / 255):
    """
    Perform BM3D filter on image
    :param image: Image (numpy ndarray) to perform bm3d filter on
    :param sigma_psd: Standard deviation for intensities in range [0,255]
    :return: Filtered image
    """
    image_copy = image.copy()
    denoised_image = bm3d.bm3d(image_copy, sigma_psd=sigma_psd,
                               stage_arg=bm3d.BM3DStages.ALL_STAGES)
    return denoised_image
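# Minimal usage sketch for bm3d_filter above. The skimage test image and noise
# level are illustrative assumptions, used only to produce a noisy float image
# in roughly [0, 1].
import numpy as np
from skimage import data, img_as_float

clean = img_as_float(data.camera())
noisy = clean + np.random.normal(0, 20 / 255, clean.shape)
clean_est = bm3d_filter(noisy, sigma_psd=20 / 255)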
def bm3d_denoise_slicewise(image, FirstSliceNumber, SliceNumDigits, Sigma_n, vl, vh):
    Nz, Ny, Nx = image.shape
    data_out = np.zeros((Nz, Ny, Nx), dtype=np.float32)
    for i in range(Nz):
        data = np.atleast_3d((image[i] - vl) / (vh - vl))  # scale to (0,1)
        data = np.minimum(np.maximum(data, 0), 1)  # clip
        data_out[i] = bm3d(data, Sigma_n)  # denoise
    data_out = np.minimum(np.maximum(data_out, 0), 1)  # clip
    data_out = data_out * (vh - vl) + vl  # re-scale
    return data_out
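# Minimal usage sketch for bm3d_denoise_slicewise above. The volume size, the
# intensity window (vl, vh) and sigma are illustrative assumptions; note that
# the slice-numbering arguments are not used inside the function body.
import numpy as np

volume = np.random.normal(0.5, 0.05, (8, 64, 64)).astype(np.float32)
denoised = bm3d_denoise_slicewise(volume, FirstSliceNumber=0, SliceNumDigits=4,
                                  Sigma_n=0.05, vl=0.0, vh=1.0)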
def filterSharpen(sLayers):
    ker = np.array([[0, -1, 0],
                    [-1, 5, -1],
                    [0, -1, 0]])
    optimg = copy.copy(sLayers[:])
    for i in range(len(sLayers)):
        copySAR = copy.copy(sLayers[:][i])
        sarArrayI = np.array(copySAR)
        denoised_image = bm3d.bm3d(sLayers[i], sigma_psd=4,
                                   stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)
        x = ndimage.convolve(denoised_image, ker)
        optimg[i] = x
    return optimg
def bm3d_method(agent_input, params):
    """
    BM3D prior

    Args:
        agent_input: full-size reconstruction
        params: params.noise_std for noise standard deviation

    Returns:
        New full-size reconstruction after update
    """
    return bm3d(agent_input, params.noise_std)
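# Minimal usage sketch for bm3d_method above; SimpleNamespace and the noise
# level stand in for the real params object, which is an assumption here.
import numpy as np
from types import SimpleNamespace
from bm3d import bm3d

params = SimpleNamespace(noise_std=0.1)
recon = np.random.rand(64, 64)          # stand-in full-size reconstruction
updated = bm3d_method(recon, params)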
def metrics_bm3d_from_ds(ds, noise_std=30, with_shape=True):
    metrics = Metrics()
    pred_and_gt_shape = [(
        (bm3d.bm3d(images_noisy[0, ..., 0].numpy() + 0.5,
                   sigma_psd=noise_std / 255,
                   stage_arg=bm3d.BM3DStages.ALL_STAGES) - 0.5)[None, ..., None],
        images_gt.numpy(),
        im_shape.numpy(),
    ) for images_noisy, images_gt, im_shape in tqdm_notebook(ds)]
    for im_recos, images, im_shape in tqdm_notebook(pred_and_gt_shape,
                                                    desc='Stats for BM3D'):
        metrics.push(images, im_recos, im_shape)
    return metrics
def main(args):
    input_arr = np.load(args.input)
    H, W = input_arr.shape
    L = 256
    gap_H, gap_W = H % (L // 2), W % (L // 2)
    i0, i1 = gap_H // 2, H - (gap_H - (gap_H // 2))
    j0, j1 = gap_W // 2, W - (gap_W - (gap_W // 2))
    input_arr = input_arr[i0:i1, j0:j1]
    H, W = input_arr.shape
    if args.is_logt:
        logtimg_norm = 0.5 * (1 + (input_arr / (3 * np.log(10))))
        output = bm3d.bm3d(logtimg_norm, args.sigma_bm3d, profile='high')
        output += ((0.5 * args.bias_correction) / (3 * np.log(10)))
        output = np.clip(output, 0.5, None)
        output = np.exp(-((2 * output - 1) * (3 * np.log(10))))
    else:
        assert input_arr.min() >= 0
        max_val = input_arr.max()
        timg_norm = input_arr / max_val
        output = max_val * bm3d.bm3d(timg_norm, args.sigma_bm3d, profile='high')
    io_utils.save_img(io_utils.array_to_img(output, 'L'), args.output)
    return
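# Small worked example of the centre-crop above: both sides are trimmed so the
# height and width become multiples of L // 2 = 128 before BM3D is applied.
# The 300x260 input shape is an illustrative assumption.
import numpy as np

arr = np.zeros((300, 260))
H, W = arr.shape
L = 256
gap_H, gap_W = H % (L // 2), W % (L // 2)          # 44 and 4
i0, i1 = gap_H // 2, H - (gap_H - (gap_H // 2))    # 22, 278
j0, j1 = gap_W // 2, W - (gap_W - (gap_W // 2))    # 2, 258
print(arr[i0:i1, j0:j1].shape)                     # (256, 256)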
def denoiseBM3D(image):
    t1 = time.time()
    print(" denoising image using BM3D", flush=True)
    denoised = np.zeros(image.shape, np.uint8)  # empty image
    # NLM pre-pass; note that its result is immediately replaced by the BM3D output below
    cv2.fastNlMeansDenoising(image, denoised, h=15, templateWindowSize=7,
                             searchWindowSize=(15 + 1))
    denoised = bm3d.bm3d(image, sigma_psd=0.2,
                         stage_arg=bm3d.BM3DStages.ALL_STAGES)
    print(", took %f s" % (time.time() - t1))
    return denoised
def bm3d_param_selection():
    p = Path(r'project_model/results/test/test_latest.mat')
    outputs = sio.loadmat(p.resolve())
    sigmaX = np.exp(np.arange(-10, 10))
    base_psnrs = []
    for x in sigmaX:
        for i in range(len(outputs['real_B'])):
            orig_image = outputs['real_B'][i]
            noisy_img = outputs['real_A'][i]
            # test base model
            y_pred = bm3d(noisy_img, x)
            psnr = get_psnr(orig_image, y_pred)
            base_psnrs.append({'param': x, 'rate': psnr})
    best_param = sorted(base_psnrs, key=lambda r: r['rate'], reverse=True)[0]
    print(f'BM3D best param is {best_param}')
    return best_param
def bm3d_denoise(marginal, noisy_hist, noise, noise_type, gt_hist=None, return_TVD=False):
    # print(' bm3d', marginal, noisy_hist.shape)
    if gt_hist is not None:
        print(' queried TVD: {:.5f}'.format(get_TVD(gt_hist, noisy_hist)))
    shape = noisy_hist.shape
    if len(shape) > 3:
        noisy_hist = noisy_hist.reshape((shape[0], shape[1], -1))  # axis=2 is channels if exists
    temp_hist = np.concatenate([noisy_hist] * (int(10 / noisy_hist.shape[1]) + 1), axis=1)
    temp_hist = np.concatenate([temp_hist] * (int(10 / temp_hist.shape[0]) + 1), axis=0)
    # print(temp_hist.shape)
    max_value = np.max(noisy_hist)
    if noise_type == 'normal':
        std_dev = noise / max_value
    elif noise_type == 'Laplace':
        std_dev = noise / max_value * 2 ** 0.5
    else:
        exit(-1)
    bm3d_hist = bm3d(temp_hist / max_value, sigma_psd=std_dev)
    bm3d_hist = bm3d_hist[:noisy_hist.shape[0], :noisy_hist.shape[1]]
    bm3d_hist *= max_value
    if len(shape) > 3:
        bm3d_hist = bm3d_hist.reshape(shape)
    if return_TVD and gt_hist is not None:
        bm3d_TVD = get_TVD(gt_hist, bm3d_hist)
        return bm3d_hist, bm3d_TVD
    return bm3d_hist
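# Small worked example of just the tiling arithmetic above: a small 2-D
# histogram is repeated until each side is at least ~10 bins, presumably
# because BM3D cannot operate on images smaller than its block size; the
# denoised result is then cropped back. The 4x6 shape is illustrative.
import numpy as np

noisy_hist = np.random.rand(4, 6)
temp_hist = np.concatenate([noisy_hist] * (int(10 / noisy_hist.shape[1]) + 1), axis=1)  # (4, 12)
temp_hist = np.concatenate([temp_hist] * (int(10 / temp_hist.shape[0]) + 1), axis=0)    # (12, 12)
print(temp_hist.shape)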
def DFFC(data, flats, darks, downsample, nrPArepetions): # Load frames meanDarkfield = np.mean(darks, axis=1, dtype=np.float64) whiteVect = np.zeros((flats.shape[1], flats.shape[0] * flats.shape[2]), dtype=np.float64) for i in range(flats.shape[1]): whiteVect[i] = flats[:, i, :].flatten() - meanDarkfield.flatten() mn = np.mean(whiteVect, axis=0) # Substract mean flat field M, N = whiteVect.shape Data = whiteVect - mn # ============================================================================= # Parallel Analysis (EEFs selection): # Selection of the number of components for PCA using parallel Analysis. # Each flat field is a single row of the matrix flatFields, different # rows are different observations. # ============================================================================= def cov(X): one_vector = np.ones((1, X.shape[0])) mu = np.dot(one_vector, X) / X.shape[0] X_mean_subtract = X - mu covA = np.dot(X_mean_subtract.T, X_mean_subtract) / (X.shape[0] - 1) return covA def parallelAnalysis(flatFields, repetitions): stdEFF = np.std(flatFields, axis=0, ddof=1, dtype=np.float64) H, W = flatFields.shape keepTrack = np.zeros((H, repetitions), dtype=np.float64) stdMatrix = np.tile(stdEFF, (H, 1)) for i in range(repetitions): print(f"Parallel Analysis - repetition {i}") sample = stdMatrix * np.random.randn(H, W) D1, _ = np.linalg.eig(np.cov(sample)) keepTrack[:, i] = D1.copy() mean_flat_fields_EFF = np.mean(flatFields, axis=0) F = flatFields - mean_flat_fields_EFF D1, V1 = np.linalg.eig(np.cov(F)) selection = np.zeros((1, H)) # mean + 2 * std selection[:, D1 > (np.mean(keepTrack, axis=1) + 2 * np.std(keepTrack, axis=1, ddof=1))] = 1 numberPC = np.sum(selection) return V1, D1, int(numberPC) # Parallel Analysis nrEigenflatfields = 0 print("Parallel Analysis:") while (nrEigenflatfields <= 0): V1, D1, nrEigenflatfields = parallelAnalysis(Data, nrPArepetions) print(f"{nrEigenflatfields} eigen flat fields selected!") idx = D1.argsort()[::-1] D1 = D1[idx] V1 = V1[:, idx] # Calculation eigen flat fields H, C, W = data.shape eig0 = mn.reshape((H, W)) EFF = np.zeros((nrEigenflatfields + 1, H, W)) #n_EFF + 1 eig0 EFF_denoised = np.zeros((nrEigenflatfields + 1, H, W)) #n_EFF + 1 eig0 print("Calculating EFFs:") EFF[0] = eig0 for i in range(nrEigenflatfields): EFF[i + 1] = (np.matmul(Data.T, V1[i]).T).reshape((H, W)) EFF_denoised = EFF.copy() # Denoise eigen flat fields print("Denoising EFFs using BM3D method:") for i in range(1, len(EFF)): print(f"Denoising EFF {i}") EFF_max, EFF_min = EFF_denoised[i, :, :].max(), EFF_denoised[ i, :, :].min() EFF_denoised[i, :, :] = (EFF_denoised[i, :, :] - EFF_min) / (EFF_max - EFF_min) sigma_bm3d = estimate_sigma(EFF_denoised[i, :, :]) * 10 #print(f"Estimated sigma: {sigma_bm3d}") EFF_denoised[i, :, :] = bm3d.bm3d(EFF_denoised[i, :, :], sigma_bm3d) EFF_denoised[i, :, :] = (EFF_denoised[i, :, :] * (EFF_max - EFF_min)) + EFF_min print("Denoising completed.") # ============================================================================= # cost_func: cost funcion used to estimate the weights using TV # ============================================================================= def cost_func(x, *args): (projections, meanFF, FF, DF) = args FF_eff = np.zeros((FF.shape[1], FF.shape[2])) for i in range(len(FF)): FF_eff = FF_eff + x[i] * FF[i] logCorProj = (projections - DF) / ( meanFF + FF_eff) * np.mean(meanFF.flatten() + FF_eff.flatten()) Gx, Gy = np.gradient(logCorProj) mag = (Gx**2 + Gy**2)**(1 / 2) cost = np.sum(mag.flatten()) return cost # 
============================================================================= # CondTVmean function: finds the optimal estimates of the coefficients of the # eigen flat fields. # ============================================================================= def condTVmean(projection, meanFF, FF, DF, x, DS): # Downsample image projection = downscale_local_mean(projection, (DS, DS)) meanFF = downscale_local_mean(meanFF, (DS, DS)) FF2 = np.zeros((FF.shape[0], meanFF.shape[0], meanFF.shape[1])) for i in range(len(FF)): FF2[i] = downscale_local_mean(FF[i], (DS, DS)) FF = FF2 DF = downscale_local_mean(DF, (DS, DS)) # Optimize weights (x) x = scipy.optimize.minimize(cost_func, x, args=(projection, meanFF, FF, DF), method='BFGS', tol=1e-8) return x.x H, C, W = data.shape print("TV optimisation for DFF coefficients:") clean_DFFC = np.zeros((H, C, W), dtype=np.float64) for i in range(C): if i % 5 == 0: print("Normalising projection", i) projection = data[:, i, :] # Estimate weights for a single projection meanFF = EFF_denoised[0] FF = EFF_denoised[1:] weights = np.zeros(nrEigenflatfields) x = condTVmean(projection, meanFF, FF, meanDarkfield, weights, downsample) # Dynamic FFC FFeff = np.zeros(meanDarkfield.shape) for j in range(nrEigenflatfields): FFeff = FFeff + x[j] * EFF_denoised[j + 1] tmp = np.divide((projection - meanDarkfield), (EFF_denoised[0] + FFeff)) clean_DFFC[:, i, :] = tmp return [clean_DFFC, EFF, EFF_denoised]
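# The EFF denoising step inside DFFC above rescales each eigen flat field to
# [0, 1], denoises it with BM3D using a sigma estimated from the data, and maps
# the result back to its original range. A minimal sketch of that
# normalize / denoise / rescale pattern: the synthetic eigen flat field is an
# illustrative assumption, the 10x factor is taken from the code above, and
# estimate_sigma is assumed to be skimage.restoration.estimate_sigma.
import bm3d
import numpy as np
from skimage.restoration import estimate_sigma

eff = np.linspace(0, 1, 128 * 128).reshape(128, 128) + np.random.normal(0, 0.02, (128, 128))
eff_max, eff_min = eff.max(), eff.min()
eff_norm = (eff - eff_min) / (eff_max - eff_min)             # rescale to [0, 1]
sigma_bm3d = estimate_sigma(eff_norm) * 10                   # heuristic used in DFFC
eff_denoised = bm3d.bm3d(eff_norm, sigma_bm3d)               # denoise in [0, 1]
eff_denoised = eff_denoised * (eff_max - eff_min) + eff_min  # back to original range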
def train_loop(cfg, model, scheduler, train_loader, epoch, record_losses, writer): # -=-=-=-=-=-=-=-=-=-=- # # Setup for epoch # # -=-=-=-=-=-=-=-=-=-=- model.align_info.model.train() model.denoiser_info.model.train() model.unet_info.model.train() model.denoiser_info.model = model.denoiser_info.model.to(cfg.device) model.align_info.model = model.align_info.model.to(cfg.device) model.unet_info.model = model.unet_info.model.to(cfg.device) N = cfg.N total_loss = 0 running_loss = 0 szm = ScaleZeroMean() blocksize = 128 unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize) use_record = False if record_losses is None: record_losses = pd.DataFrame({ 'burst': [], 'ave': [], 'ot': [], 'psnr': [], 'psnr_std': [] }) noise_type = cfg.noise_params.ntype # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Record Keeping # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- align_mse_losses, align_mse_count = 0, 0 rec_mse_losses, rec_mse_count = 0, 0 rec_ot_losses, rec_ot_count = 0, 0 running_loss, total_loss = 0, 0 dynamics_acc, dynamics_count = 0, 0 write_examples = False write_examples_iter = 200 noise_level = cfg.noise_params['g']['stddev'] # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Load Pre-Simulated Random Numbers # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- if cfg.use_kindex_lmdb: kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Dataset Augmentation # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- transforms = [tvF.vflip, tvF.hflip, tvF.rotate] aug = RandomChoice(transforms) def apply_transformations(burst, gt_img): N, B = burst.shape[:2] gt_img_rs = rearrange(gt_img, 'b c h w -> 1 b c h w') all_images = torch.cat([gt_img_rs, burst], dim=0) all_images = rearrange(all_images, 'n b c h w -> (n b) c h w') tv_utils.save_image(all_images, 'aug_original.png', nrow=N + 1, normalize=True) aug_images = aug(all_images) tv_utils.save_image(aug_images, 'aug_augmented.png', nrow=N + 1, normalize=True) aug_images = rearrange(aug_images, '(n b) c h w -> n b c h w', b=B) aug_gt_img = aug_images[0] aug_burst = aug_images[1:] return aug_burst, aug_gt_img # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Half Precision # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # model.align_info.model.half() # model.denoiser_info.model.half() # model.unet_info.model.half() # models = [model.align_info.model, # model.denoiser_info.model, # model.unet_info.model] # for model_l in models: # model_l.half() # for layer in model_l.modules(): # if isinstance(layer, torch.nn.BatchNorm2d): # layer.float() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Loss Functions # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- alignmentLossMSE = BurstRecLoss() denoiseLossMSE = BurstRecLoss(alpha=cfg.kpn_burst_alpha, gradient_L1=~cfg.supervised) # denoiseLossOT = BurstResidualLoss() entropyLoss = EntropyLoss() # -=-=-=-=-=-=-=-=-=-=-=-=- # # Add hooks for epoch # # -=-=-=-=-=-=-=-=-=-=-=-=- align_hook = AlignmentFilterHooks(cfg.N) align_hooks = [] for kpn_module in model.align_info.model.children(): for name, layer in kpn_module.named_children(): if name == "filter_cls": align_hook_handle = layer.register_forward_hook(align_hook) align_hooks.append(align_hook_handle) # -=-=-=-=-=-=-=-=-=-=- # # Noise2Noise # # -=-=-=-=-=-=-=-=-=-=- noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False) # -=-=-=-=-=-=-=-=-=-=- # # Final Configs # # -=-=-=-=-=-=-=-=-=-=- use_timer = False one = torch.FloatTensor([1.]).to(cfg.device) switch = True if use_timer: data_clock = Timer() clock = Timer() ds_size = len(train_loader) small_ds = ds_size < 500 
steps_per_epoch = ds_size if not small_ds else 500 write_examples_iter = steps_per_epoch // 3 all_filters = [] # -=-=-=-=-=-=-=-=-=-=- # # Start Epoch # # -=-=-=-=-=-=-=-=-=-=- dynamics_acc_i = -1. if cfg.use_seed: init = torch.initial_seed() torch.manual_seed(cfg.seed + 1 + epoch + init) train_iter = iter(train_loader) for batch_idx in range(steps_per_epoch): # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Setting up for Iteration # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- setup iteration timer -- if use_timer: data_clock.tic() clock.tic() # -- grab data batch -- sample = next(train_iter) burst, raw_img, motion = sample['burst'], sample['clean'], sample[ 'directions'] raw_img_iid = sample['iid'] raw_img_iid = raw_img_iid.cuda(non_blocking=True) burst = burst.cuda(non_blocking=True) aligned, est_nnf = align_burst(cfg, burst, model) sim_images = subsample_aligned(cfg, aligned) burst_in, tgt_out = create_training_pairs(burst, sim_images) dn_losses = [] for burst, target in zip(burst_in, tgt_out): # -- forward pass -- est_denoised = model(burst) dn_loss = compute_denoising_loss(est_denoised, target) # -- compute grads -- if cfg.use_seed: torch.set_deterministic(False) dn_loss.backward() if cfg.use_seed: torch.set_deterministic(True) # -- backprop -- optim.step() scheduler.step() # -- store info -- losses.append(dn_loss.item()) # -- average over losses -- dn_loss = torch.mean(dn_losses) # -- alignment loss -- align_loss = compute_nnf_loss(gt_nnf, est_nnf) # -- total loss -- final_loss = dn_loss + align_loss running_loss += final_loss.item() total_loss += final_loss.item() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Printing to Stdout # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0: # -- recompute model output for original images -- outputs = model(burst_og) m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4] aligned_filters, denoised_filters = outputs[4:] # -- compute mse for fun -- B = raw_img.shape[0] raw_img = raw_img.cuda(non_blocking=True) raw_img = get_nmlz_tgt_img(cfg, raw_img) # -- psnr for [average of aligned frames] -- mse_loss = F.mse_loss(raw_img, m_aligned_ave, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_aligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [average of input, misaligned frames] -- mis_ave = torch.mean(burst_og, dim=0) if noise_type == "qis": mis_ave = quantize_img(cfg, mis_ave) mse_loss = F.mse_loss(raw_img, mis_ave, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_misaligned_std = np.std(mse_to_psnr(mse_loss)) # tv_utils.save_image(raw_img,"raw.png",nrow=1,normalize=True,range=(-0.5,1.25)) # tv_utils.save_image(mis_ave,"mis.png",nrow=1,normalize=True,range=(-0.5,1.25)) # -- psnr for [bm3d] -- mid_img_og = burst[N // 2] bm3d_nb_psnrs = [] M = 4 if B > 4 else B for b in range(M): bm3d_rec = bm3d.bm3d(mid_img_og[b].cpu().transpose(0, 2) + 0.5, sigma_psd=noise_level / 255, stage_arg=bm3d.BM3DStages.ALL_STAGES) bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2) # maybe an issue here b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec, reduction='none').reshape(1, -1) b_loss = torch.mean(b_loss, 1).detach().cpu().numpy() bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss)) bm3d_nb_psnrs.append(bm3d_nb_psnr) bm3d_nb_ave = np.mean(bm3d_nb_psnrs) bm3d_nb_std = np.std(bm3d_nb_psnrs) # -- psnr for input averaged 
frames -- # burst_ave = torch.mean(burst_og,dim=0) # mse_loss = F.mse_loss(raw_img,burst_ave,reduction='none').reshape(B,-1) # mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy() # psnr_input_ave = np.mean(mse_to_psnr(mse_loss)) # psnr_input_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for aligned + denoised -- R = denoised.shape[1] raw_img_repN = raw_img.unsqueeze(1).repeat(1, R, 1, 1, 1) # if noise_type == "qis": denoised = quantize_img(cfg,denoised) # save_image(denoised_ave,"denoised_ave.png") # save_image(denoised,"denoised.png") mse_loss = F.mse_loss(raw_img_repN, denoised, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss)) psnr_denoised_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [model output image] -- mse_loss = F.mse_loss(raw_img, denoised_ave, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr = np.mean(mse_to_psnr(mse_loss)) psnr_std = np.std(mse_to_psnr(mse_loss)) # -- update losses -- running_loss /= cfg.log_interval # -- reconstruction MSE -- rec_mse_ave = rec_mse_losses / rec_mse_count rec_mse_losses, rec_mse_count = 0, 0 # -- reconstruction Dist. -- rec_ot_ave = rec_ot_losses / rec_ot_count rec_ot_losses, rec_ot_count = 0, 0 # -- ave dynamic acc -- ave_dyn_acc = dynamics_acc / dynamics_count * 100. dynamics_acc, dynamics_count = 0, 0 # -- write record -- if use_record: info = { 'burst': burst_loss, 'ave': ave_loss, 'ot': rec_ot_ave, 'psnr': psnr, 'psnr_std': psnr_std } record_losses = record_losses.append(info, ignore_index=True) # -- write to stdout -- write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch, running_loss, psnr, psnr_std, psnr_denoised_ave, psnr_denoised_std, psnr_aligned_ave, psnr_aligned_std, psnr_misaligned_ave, psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std, rec_mse_ave, ave_dyn_acc) #rec_ot_ave) #print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e" % write_info) print( "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [dyn]: %.2e" % write_info, flush=True) # -- write to summary writer -- if writer: writer.add_scalar('train/running-loss', running_loss, cfg.global_step) writer.add_scalars('train/model-psnr', { 'ave': psnr, 'std': psnr_std }, cfg.global_step) writer.add_scalars('train/dn-frame-psnr', { 'ave': psnr_denoised_ave, 'std': psnr_denoised_std }, cfg.global_step) # -- reset loss -- running_loss = 0 # -- write examples -- if write_examples and (batch_idx % write_examples_iter) == 0 and ( batch_idx > 0 or cfg.global_step == 0): write_input_output(cfg, model, stacked_burst, aligned, denoised, all_filters, motion) if use_timer: clock.toc() if use_timer: print("data_clock", data_clock.average_time) print("clock", clock.average_time) cfg.global_step += 1 # -- remove hooks -- for hook in align_hooks: hook.remove() total_loss /= len(train_loader) return total_loss, record_losses
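# The BM3D baseline PSNR computed inside the training loops above follows one
# recurring pattern: move a CHW torch frame to HWC numpy, shift it from the
# zero-mean range back to [0, 1], denoise with BM3D at the known noise level,
# and compare against the clean frame. A minimal sketch of that pattern; the
# shapes, noise level and shift convention are illustrative assumptions, not
# the loops' exact bookkeeping.
import bm3d
import numpy as np
import torch
import torch.nn.functional as F

noise_level = 25.0                                    # std-dev in [0, 255] units
clean = torch.rand(3, 64, 64)                         # clean CHW frame in [0, 1]
noisy = (clean - 0.5) + torch.randn_like(clean) * noise_level / 255  # zero-mean burst frame

rec = bm3d.bm3d(noisy.transpose(0, 2).numpy() + 0.5,  # HWC, shifted to [0, 1]
                sigma_psd=noise_level / 255,
                stage_arg=bm3d.BM3DStages.ALL_STAGES)
rec = torch.FloatTensor(rec).transpose(0, 2)          # back to CHW, in [0, 1]
mse = F.mse_loss(clean, rec).item()
psnr = 10 * np.log10(1.0 / mse)
print(psnr)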
def train_loop(cfg, model, optimizer, criterion, train_loader, epoch, record_losses): # -=-=-=-=-=-=-=-=-=-=- # # Setup for epoch # # -=-=-=-=-=-=-=-=-=-=- model.train() model = model.to(cfg.device) N = cfg.N total_loss = 0 running_loss = 0 szm = ScaleZeroMean() blocksize = 128 unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize) use_record = False if record_losses is None: record_losses = pd.DataFrame({ 'burst': [], 'ave': [], 'ot': [], 'psnr': [], 'psnr_std': [] }) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Record Keeping # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- align_mse_losses, align_mse_count = 0, 0 rec_mse_losses, rec_mse_count = 0, 0 rec_ot_losses, rec_ot_count = 0, 0 running_loss, total_loss = 0, 0 write_examples = True write_examples_iter = 800 noise_level = cfg.noise_params['g']['stddev'] # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Loss Functions # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- alignmentLossMSE = BurstRecLoss() denoiseLossMSE = BurstRecLoss() # denoiseLossOT = BurstResidualLoss() entropyLoss = EntropyLoss() # -=-=-=-=-=-=-=-=-=-=- # # Final Configs # # -=-=-=-=-=-=-=-=-=-=- use_timer = False one = torch.FloatTensor([1.]).to(cfg.device) switch = True if use_timer: clock = Timer() train_iter = iter(train_loader) D = 5 * 10**3 steps_per_epoch = len(train_loader) # -=-=-=-=-=-=-=-=-=-=- # # Start Epoch # # -=-=-=-=-=-=-=-=-=-=- for batch_idx in range(steps_per_epoch): # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Setting up for Iteration # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- setup iteration timer -- if use_timer: clock.tic() # -- zero gradients; ready 2 go -- optimizer.zero_grad() model.zero_grad() model.denoiser_info.optim.zero_grad() # -- grab data batch -- burst, res_imgs, raw_img, directions = next(train_iter) # -- getting shapes of data -- N, BS, C, H, W = burst.shape burst = burst.cuda(non_blocking=True) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Formatting Images for FP # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- creating some transforms -- stacked_burst = rearrange(burst, 'n b c h w -> b n c h w') cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w') # -- extract target image -- mid_img = burst[N // 2] raw_zm_img = szm(raw_img.cuda(non_blocking=True)) if cfg.supervised: gt_img = raw_zm_img else: gt_img = mid_img # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Foward Pass # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- aligned, aligned_ave, denoised, denoised_ave, filters = model(burst) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Entropy Loss for Filters # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- filters_shaped = rearrange(filters, 'b n k2 1 1 1 -> (b n) k2', n=N) filters_entropy = entropyLoss(filters_shaped) filters_entropy_coeff = 10. 
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Alignment Losses (MSE) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- losses = alignmentLossMSE(aligned, aligned_ave, gt_img, cfg.global_step) ave_loss, burst_loss = [loss.item() for loss in losses] align_mse = np.sum(losses) align_mse_coeff = 0 #.933**cfg.global_step if cfg.global_step < 100 else 0 # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Reconstruction Losses (MSE) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- denoised_ave_d = denoised_ave.detach() losses = criterion(denoised, denoised_ave, gt_img, cfg.global_step) ave_loss, burst_loss = [loss.item() for loss in losses] rec_mse = np.sum(losses) rec_mse_coeff = 0.997**cfg.global_step # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Reconstruction Losses (Distribution) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- regularization scheduler -- if cfg.global_step < 100: reg = 0.5 elif cfg.global_step < 200: reg = 0.25 elif cfg.global_step < 5000: reg = 0.15 elif cfg.global_step < 10000: reg = 0.1 else: reg = 0.05 # -- computation -- residuals = denoised - mid_img.unsqueeze(1).repeat(1, N, 1, 1, 1) residuals = rearrange(residuals, 'b n c h w -> b n (h w) c') # rec_ot_pair_loss_v1 = w_gaussian_bp(residuals,noise_level) rec_ot_pair_loss_v1 = kl_gaussian_bp(residuals, noise_level) # rec_ot_pair_loss_v1 = ot_pairwise2gaussian_bp(residuals,K=6,reg=reg) # rec_ot_pair_loss_v2 = ot_pairwise_bp(residuals,K=3) rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device) rec_ot_pair = (rec_ot_pair_loss_v1 + rec_ot_pair_loss_v2) / 2. rec_ot_pair_coeff = 100 # - .997**cfg.global_step # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Final Losses # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- align_loss = align_mse_coeff * align_mse rec_loss = rec_ot_pair_coeff * rec_ot_pair + rec_mse_coeff * rec_mse entropy_loss = filters_entropy_coeff * filters_entropy final_loss = align_loss + rec_loss + entropy_loss # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Record Keeping # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- alignment MSE -- align_mse_losses += align_mse.item() align_mse_count += 1 # -- reconstruction MSE -- rec_mse_losses += rec_mse.item() rec_mse_count += 1 # -- reconstruction Dist. -- rec_ot_losses += rec_ot_pair.item() rec_ot_count += 1 # -- total loss -- running_loss += final_loss.item() total_loss += final_loss.item() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Gradients & Backpropogration # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- compute the gradients! -- final_loss.backward() # -- backprop now. 
-- model.denoiser_info.optim.step() optimizer.step() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Printing to Stdout # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0: # -- compute mse for fun -- BS = raw_img.shape[0] raw_img = raw_img.cuda(non_blocking=True) # -- psnr for [average of aligned frames] -- mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5, reduction='none').reshape(BS, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_aligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [average of input, misaligned frames] -- mis_ave = torch.mean(stacked_burst, dim=1) mse_loss = F.mse_loss(raw_img, mis_ave + 0.5, reduction='none').reshape(BS, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_misaligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [bm3d] -- bm3d_nb_psnrs = [] for b in range(BS): bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5, sigma_psd=noise_level / 255, stage_arg=bm3d.BM3DStages.ALL_STAGES) bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2) b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec, reduction='none').reshape(1, -1) b_loss = torch.mean(b_loss, 1).detach().cpu().numpy() bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss)) bm3d_nb_psnrs.append(bm3d_nb_psnr) bm3d_nb_ave = np.mean(bm3d_nb_psnrs) bm3d_nb_std = np.std(bm3d_nb_psnrs) # -- psnr for aligned + denoised -- raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1) mse_loss = F.mse_loss(raw_img_repN, denoised + 0.5, reduction='none').reshape(BS, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss)) psnr_denoised_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [model output image] -- mse_loss = F.mse_loss(raw_img, denoised_ave + 0.5, reduction='none').reshape(BS, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr = np.mean(mse_to_psnr(mse_loss)) psnr_std = np.std(mse_to_psnr(mse_loss)) # -- update losses -- running_loss /= cfg.log_interval # -- alignment MSE -- align_mse_ave = align_mse_losses / align_mse_count align_mse_losses, align_mse_count = 0, 0 # -- reconstruction MSE -- rec_mse_ave = rec_mse_losses / rec_mse_count rec_mse_losses, rec_mse_count = 0, 0 # -- reconstruction Dist. -- rec_ot_ave = rec_ot_losses / rec_ot_count rec_ot_losses, rec_ot_count = 0, 0 # -- write record -- if use_record: info = { 'burst': burst_loss, 'ave': ave_loss, 'ot': rec_ot_ave, 'psnr': psnr, 'psnr_std': psnr_std } record_losses = record_losses.append(info, ignore_index=True) # -- write to stdout -- write_info = (epoch, cfg.epochs, batch_idx, len(train_loader), running_loss, psnr, psnr_std, psnr_denoised_ave, psnr_denoised_std, psnr_aligned_ave, psnr_aligned_std, psnr_misaligned_ave, psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std, rec_mse_ave, rec_ot_ave) print( "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e" % write_info) running_loss = 0 # -- write examples -- if write_examples and (batch_idx % write_examples_iter) == 0 and ( batch_idx > 0 or cfg.global_step == 0): write_input_output(cfg, model, stacked_burst, aligned, denoised, filters, directions) if use_timer: clock.toc() if use_timer: print(clock) cfg.global_step += 1 total_loss /= len(train_loader) return total_loss, record_losses
def train_loop(cfg, model, noise_critic, optimizer, criterion, train_loader, epoch, record_losses): # -=-=-=-=-=-=-=-=-=-=- # Setup for epoch # -=-=-=-=-=-=-=-=-=-=- model.train() model = model.to(cfg.device) N = cfg.N szm = ScaleZeroMean() blocksize = 128 unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize) D = 5 * 10**3 use_record = False if record_losses is None: record_losses = pd.DataFrame({ 'burst': [], 'ave': [], 'ot': [], 'psnr': [], 'psnr_std': [] }) write_examples = True write_examples_iter = 800 noise_level = cfg.noise_params['g']['stddev'] # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Init Record Keeping # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- losses_nc, losses_nc_count = 0, 0 losses_mse, losses_mse_count = 0, 0 running_loss, total_loss = 0, 0 # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Init Loss Functions # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- lossRecMSE = LossRec(tensor_grad=cfg.supervised) lossBurstMSE = LossRecBurst(tensor_grad=cfg.supervised) # -=-=-=-=-=-=-=-=-=-=- # Final Configs # -=-=-=-=-=-=-=-=-=-=- use_timer = False one = torch.FloatTensor([1.]).to(cfg.device) switch = True train_iter = iter(train_loader) if use_timer: clock = Timer() # -=-=-=-=-=-=-=-=-=-=- # GAN Scheduler # -=-=-=-=-=-=-=-=-=-=- # -- noise critic steps -- if epoch == 0: disc_steps = 0 elif epoch < 3: disc_steps = 1 elif epoch < 10: disc_steps = 1 else: disc_steps = 1 # -- denoising steps -- if epoch == 0: gen_steps = 1 if epoch < 3: gen_steps = 15 if epoch < 10: gen_steps = 10 else: gen_steps = 10 # -- steps each epoch -- steps_per_iter = disc_steps * gen_steps steps_per_epoch = len(train_loader) // steps_per_iter if steps_per_epoch > 120: steps_per_epoch = 120 # -=-=-=-=-=-=-=-=-=-=- # Start Epoch # -=-=-=-=-=-=-=-=-=-=- for batch_idx in range(steps_per_epoch): for gen_step in range(gen_steps): # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Setting up for Iteration # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- setup iteration timer -- if use_timer: clock.tic() # -- zero gradients -- optimizer.zero_grad() model.zero_grad() model.denoiser_info.model.zero_grad() model.denoiser_info.optim.zero_grad() noise_critic.disc.zero_grad() noise_critic.optim.zero_grad() # -- grab data batch -- burst, res_imgs, raw_img, directions = next(train_iter) # -- getting shapes of data -- N, BS, C, H, W = burst.shape burst = burst.cuda(non_blocking=True) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Formatting Images for FP # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- creating some transforms -- stacked_burst = rearrange(burst, 'n b c h w -> b n c h w') cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w') # -- extract target image -- mid_img = burst[N // 2] raw_zm_img = szm(raw_img.cuda(non_blocking=True)) if cfg.supervised: gt_img = raw_zm_img else: gt_img = mid_img # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Foward Pass # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- aligned, aligned_ave, denoised, denoised_ave, filters = model( burst) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # MSE (KPN) Reconstruction Loss # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- loss_rec = lossRecMSE(denoised_ave, gt_img) loss_burst = lossBurstMSE(denoised, gt_img) loss_mse = loss_rec + 100 * loss_burst gbs, spe = cfg.global_step, steps_per_epoch if epoch < 3: weight_mse = 10 else: weight_mse = 10 * 0.9999**(gbs - 3 * spe) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Noise Critic Loss # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- loss_nc = noise_critic.compute_residual_loss(denoised, gt_img) weight_nc = 1 # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Final Loss # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- final_loss = weight_mse * 
loss_mse + weight_nc * loss_nc # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Update Info for Record Keeping # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- update alignment kl loss info -- losses_nc += loss_nc.item() losses_nc_count += 1 # -- update reconstruction kl loss info -- losses_mse += loss_mse.item() losses_mse_count += 1 # -- update info -- running_loss += final_loss.item() total_loss += final_loss.item() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Backward Pass # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- compute the gradients! -- final_loss.backward() # -- backprop now. -- model.denoiser_info.optim.step() optimizer.step() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Iterate for Noise Critic # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- for disc_step in range(disc_steps): # -- zero gradients -- optimizer.zero_grad() model.zero_grad() model.denoiser_info.optim.zero_grad() noise_critic.disc.zero_grad() noise_critic.optim.zero_grad() # -- grab noisy data -- _burst, _res_imgs, _raw_img, _directions = next(train_iter) _burst = _burst.to(cfg.device) # -- generate "fake" data from noisy data -- _aligned, _aligned_ave, _denoised, _denoised_ave, _filters = model( _burst) _residuals = _denoised - _burst[N // 2].unsqueeze(1).repeat( 1, N, 1, 1, 1) # -- update discriminator -- loss_disc = noise_critic.update_disc(_residuals) # -- message to stdout -- first_update = (disc_step == 0) last_update = (disc_step == disc_steps - 1) iter_update = first_update or last_update # if (batch_idx % cfg.log_interval//2) == 0 and batch_idx > 0 and iter_update: print(f"[Noise Critic]: {loss_disc}") # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Print Message to Stdout # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0: # -- init -- BS = raw_img.shape[0] raw_img = raw_img.cuda(non_blocking=True) # -- psnr for [average of aligned frames] -- mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5, reduction='none').reshape(BS, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_aligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [average of input, misaligned frames] -- mis_ave = torch.mean(stacked_burst, dim=1) mse_loss = F.mse_loss(raw_img, mis_ave + 0.5, reduction='none').reshape(BS, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_misaligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [bm3d] -- bm3d_nb_psnrs = [] for b in range(BS): bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5, sigma_psd=noise_level / 255, stage_arg=bm3d.BM3DStages.ALL_STAGES) bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2) b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec, reduction='none').reshape(1, -1) b_loss = torch.mean(b_loss, 1).detach().cpu().numpy() bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss)) bm3d_nb_psnrs.append(bm3d_nb_psnr) bm3d_nb_ave = np.mean(bm3d_nb_psnrs) bm3d_nb_std = np.std(bm3d_nb_psnrs) # -- psnr for aligned + denoised -- raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1) mse_loss = F.mse_loss(raw_img_repN, denoised + 0.5, reduction='none').reshape(BS, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss)) psnr_denoised_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [model output image] -- mse_loss = F.mse_loss(raw_img, denoised_ave + 0.5, reduction='none').reshape(BS, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr = np.mean(mse_to_psnr(mse_loss)) 
psnr_std = np.std(mse_to_psnr(mse_loss)) # -- write record -- if use_record: record_losses = record_losses.append( { 'burst': burst_loss, 'ave': ave_loss, 'ot': ot_loss, 'psnr': psnr, 'psnr_std': psnr_std }, ignore_index=True) # -- update losses -- running_loss /= cfg.log_interval # -- average mse losses -- ave_losses_mse = losses_mse / losses_mse_count losses_mse, losses_mse_count = 0, 0 # -- average noise critic loss -- ave_losses_nc = losses_nc / losses_nc_count losses_nc, losses_nc_count = 0, 0 # -- write to stdout -- write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch, running_loss, psnr, psnr_std, psnr_denoised_ave, psnr_denoised_std, psnr_misaligned_ave, psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std, ave_losses_mse, ave_losses_nc) print( "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [mse]: %.2e [nc]: %.2e" % write_info) running_loss = 0 # -- write examples -- if write_examples and (batch_idx % write_examples_iter) == 0 and ( batch_idx > 0 or cfg.global_step == 0): write_input_output(cfg, stacked_burst, aligned, denoised, filters, directions) if use_timer: clock.toc() if use_timer: print(clock) cfg.global_step += 1 total_loss /= len(train_loader) return total_loss, record_losses
def train_loop(cfg, model, train_loader, epoch, record_losses): # -=-=-=-=-=-=-=-=-=-=- # # Setup for epoch # # -=-=-=-=-=-=-=-=-=-=- model.align_info.model.train() model.denoiser_info.model.train() model.unet_info.model.train() model.denoiser_info.model = model.denoiser_info.model.to(cfg.device) model.align_info.model = model.align_info.model.to(cfg.device) model.unet_info.model = model.unet_info.model.to(cfg.device) N = cfg.N total_loss = 0 running_loss = 0 szm = ScaleZeroMean() blocksize = 128 unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize) use_record = False if record_losses is None: record_losses = pd.DataFrame({ 'burst': [], 'ave': [], 'ot': [], 'psnr': [], 'psnr_std': [] }) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Record Keeping # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- align_mse_losses, align_mse_count = 0, 0 align_ot_losses, align_ot_count = 0, 0 rec_mse_losses, rec_mse_count = 0, 0 rec_ot_losses, rec_ot_count = 0, 0 running_loss, total_loss = 0, 0 write_examples = True noise_level = cfg.noise_params['g']['stddev'] # -=-=-=-=-=-=-=-=-=-=-=-=- # # Add hooks for epoch # # -=-=-=-=-=-=-=-=-=-=-=-=- align_hook = AlignmentFilterHooks(cfg.N) align_hooks = [] for kpn_module in model.align_info.model.children(): for name, layer in kpn_module.named_children(): if name == "filter_cls": align_hook_handle = layer.register_forward_hook(align_hook) align_hooks.append(align_hook_handle) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Loss Functions # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- alignmentLossMSE = BurstRecLoss() denoiseLossMSE = BurstRecLoss() # denoiseLossOT = BurstResidualLoss() entropyLoss = EntropyLoss() # -=-=-=-=-=-=-=-=-=-=- # # Final Configs # # -=-=-=-=-=-=-=-=-=-=- use_timer = False one = torch.FloatTensor([1.]).to(cfg.device) switch = True if use_timer: clock = Timer() train_iter = iter(train_loader) steps_per_epoch = len(train_loader) write_examples_iter = steps_per_epoch // 2 # -=-=-=-=-=-=-=-=-=-=- # # Start Epoch # # -=-=-=-=-=-=-=-=-=-=- for batch_idx in range(steps_per_epoch): # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Setting up for Iteration # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- setup iteration timer -- if use_timer: clock.tic() # -- zero gradients; ready 2 go -- model.align_info.model.zero_grad() model.align_info.optim.zero_grad() model.denoiser_info.model.zero_grad() model.denoiser_info.optim.zero_grad() model.unet_info.model.zero_grad() model.unet_info.optim.zero_grad() # -- grab data batch -- burst, res_imgs, raw_img, directions = next(train_iter) # -- getting shapes of data -- N, B, C, H, W = burst.shape burst = burst.cuda(non_blocking=True) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Formatting Images for FP # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- creating some transforms -- stacked_burst = rearrange(burst, 'n b c h w -> b n c h w') cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w') # -- extract target image -- mid_img = burst[N // 2] raw_zm_img = szm(raw_img.cuda(non_blocking=True)) if cfg.supervised: gt_img = raw_zm_img else: gt_img = mid_img # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Check Some Gradients # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- def mse_v_wassersteinG_check_some_gradients(cfg, burst, gt_img, model): grads = edict() gt_img_rs = gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1) model.unet_info.model.zero_grad() burst.requires_grad_(True) outputs = model(burst) aligned, aligned_ave, denoised, denoised_ave = outputs[:4] aligned_filters, denoised_filters = outputs[4:] residuals = denoised - gt_img_rs P = 1. 
#residuals.numel() denoised.retain_grad() rec_mse = (denoised.reshape(B, -1) - gt_img.reshape(B, -1))**2 rec_mse.retain_grad() ones = P * torch.ones_like(rec_mse) rec_mse.backward(ones, retain_graph=True) grads.rmse = rec_mse.grad.clone().reshape(B, -1) grad_rec_mse = grads.rmse grads.dmse = denoised.grad.clone().reshape(B, -1) grad_denoised_mse = grads.dmse ones = torch.ones_like(rec_mse) grads.d_to_b = torch.autograd.grad(rec_mse, denoised, ones)[0].reshape(B, -1) model.unet_info.model.zero_grad() outputs = model(burst) aligned, aligned_ave, denoised, denoised_ave = outputs[:4] aligned_filters, denoised_filters = outputs[4:] # residuals = denoised - gt_img_rs # rec_ot = w_gaussian_bp(residuals,noise_level) denoised.retain_grad() rec_ot_v = (denoised - gt_img_rs)**2 rec_ot_v.retain_grad() rec_ot = (rec_ot_v.mean() - noise_level / 255.)**2 rec_ot.retain_grad() ones = P * torch.ones_like(rec_ot) rec_ot.backward(ones) grad_denoised_ot = denoised.grad.clone().reshape(B, -1) grads.dot = grad_denoised_ot grad_rec_ot = rec_ot_v.grad.clone().reshape(B, -1) grads.rot = grad_denoised_ot print("Gradient Name Info") for name, g in grads.items(): g_norm = g.norm().item() g_mean = g.mean().item() g_std = g.std().item() print(name, g.shape, g_norm, g_mean, g_std) print_pairs = False if print_pairs: print("All Gradient Ratios") for name_t, g_t in grads.items(): for name_b, g_b in grads.items(): ratio = g_t / g_b ratio_m = ratio.mean().item() ratio_std = ratio.std().item() print("[%s/%s] [%2.2e +/- %2.2e]" % (name_t, name_b, ratio_m, ratio_std)) use_true_mse = False if use_true_mse: print("Ratios with Estimated MSE Gradient") true_dmse = 2 * torch.mean(denoised_ave - gt_img)**2 ratio_mse = grads.dmse / true_dmse ratio_mse_dtb = grads.dmse / grads.d_to_b print(ratio_mse) print(ratio_mse_dtb) dot_v_dmse = True if dot_v_dmse: print("Ratio of Denoised OT and Denoised MSE") ratio_mseot = (grads.dmse / grads.dot) print(ratio_mseot.mean(), ratio_mseot.std()) ratio_mseot = ratio_mseot[0, 0].item() c1 = torch.mean((denoised - gt_img_rs)**2).item() c2 = noise_level / 255 m = torch.mean(gt_img_rs).item() true_ratio = 2. 
* (c1 - c2) / (np.product(burst.shape)) # diff = denoised.reshape(B,-1)-gt_img_rs.reshape(B,-1) # true_ratio = 2.*(c1 - c2) * ( diff / ( np.product(burst.shape) ) ) # print(c1,c2,m,true_ratio,1./true_ratio) ratio_mseot = (grads.dmse / (grads.dot)) print(ratio_mseot * true_ratio) # ratio_mseot = (grads.dmse / ( grads.dot / diff) ) # print(ratio_mseot*true_ratio) # print(ratio_mseot.mean(),ratio_mseot.std()) exit() model.unet_info.model.zero_grad() # mse_v_wassersteinG_check_some_gradients(cfg,burst,gt_img,model) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Foward Pass # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- outputs = model(burst) aligned, aligned_ave, denoised, denoised_ave = outputs[:4] aligned_filters, denoised_filters = outputs[4:] # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Require Approx Equal Filter Norms (aligned) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- aligned_filters_rs = rearrange(aligned_filters, 'b n k2 c h w -> b n (k2 c h w)') norms = torch.norm(aligned_filters_rs, p=2., dim=2) norms_mid = norms[:, N // 2].unsqueeze(1).repeat(1, N) norm_loss_align = torch.mean( torch.pow(torch.abs(norms - norms_mid), 1.)) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Require Approx Equal Filter Norms (denoised) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- denoised_filters = rearrange(denoised_filters, 'b n k2 c h w -> b n (k2 c h w)') norms = torch.norm(denoised_filters, p=2., dim=2) norms_mid = norms[:, N // 2].unsqueeze(1).repeat(1, N) norm_loss_denoiser = torch.mean( torch.pow(torch.abs(norms - norms_mid), 1.)) norm_loss_coeff = 0. # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Decrease Entropy within a Kernel # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- filters_entropy = 0 filters_entropy_coeff = 0. # 1000. all_filters = [] L = len(align_hook.filters) iter_filters = align_hook.filters if L > 0 else [aligned_filters] for filters in iter_filters: filters_shaped = rearrange(filters, 'b n k2 c h w -> (b n c h w) k2', n=N) filters_entropy += entropyLoss(filters_shaped) all_filters.append(filters) if L > 0: filters_entropy /= L all_filters = torch.stack(all_filters, dim=1) align_hook.clear() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Increase Entropy across each Kernel # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- filters_dist_entropy = 0 # -- across each frame -- # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b l) (n c h w) k2') # filters_shaped = torch.mean(filters_shaped,dim=1) # filters_dist_entropy += -1 * entropyLoss(filters_shaped) # -- across each batch -- filters_shaped = rearrange(all_filters, 'b l n k2 c h w -> (n l) (b c h w) k2') filters_shaped = torch.mean(filters_shaped, dim=1) filters_dist_entropy += -1 * entropyLoss(filters_shaped) # -- across each kpn cascade -- # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b n) (l c h w) k2') # filters_shaped = torch.mean(filters_shaped,dim=1) # filters_dist_entropy += -1 * entropyLoss(filters_shaped) filters_dist_coeff = 0 # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Alignment Losses (MSE) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- losses = alignmentLossMSE(aligned, aligned_ave, gt_img, cfg.global_step) ave_loss, burst_loss = [loss.item() for loss in losses] align_mse = np.sum(losses) align_mse_coeff = 0. 
#0.95**cfg.global_step # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Alignment Losses (Distribution) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # pad = 2*cfg.N # fs = cfg.dynamic.frame_size residuals = aligned - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1) # centered_residuals = tvF.center_crop(residuals,(fs-pad,fs-pad)) # centered_residuals = tvF.center_crop(residuals,(fs//2,fs//2)) # align_ot = kl_gaussian_bp(residuals,noise_level,flip=True) align_ot = kl_gaussian_bp_patches(residuals, noise_level, flip=True, patchsize=16) align_ot_coeff = 0 # 100. # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Reconstruction Losses (MSE) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- losses = denoiseLossMSE(denoised, denoised_ave, gt_img, cfg.global_step) ave_loss, burst_loss = [loss.item() for loss in losses] rec_mse = np.sum(losses) rec_mse_coeff = 0.95**cfg.global_step # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Reconstruction Losses (Distribution) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- computation -- gt_img_rs = gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1) residuals = denoised - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1) # rec_ot = kl_gaussian_bp(residuals,noise_level) rec_ot = kl_gaussian_bp(residuals, noise_level, flip=True) # rec_ot /= 2. # alpha_grid = [0.,1.,5.,10.,25.] # for alpha in alpha_grid: # # residuals = torch.normal(torch.zeros_like(residuals)+ gt_img_rs*alpha/255.,noise_level/255.) # residuals = torch.normal(torch.zeros_like(residuals),noise_level/255.+ gt_img_rs*alpha/255.) # rec_ot_v2_a = kl_gaussian_bp_patches(residuals,noise_level,patchsize=16) # rec_ot_v1_b = kl_gaussian_bp(residuals,noise_level,flip=True) # rec_ot_v2_b = kl_gaussian_bp_patches(residuals,noise_level,flip=True,patchsize=16) # rec_ot_all = torch.tensor([rec_ot_v1_a,rec_ot_v2_a,rec_ot_v1_b,rec_ot_v2_b]) # rec_ot_v2 = (rec_ot_v2_a + rec_ot_v2_b).item()/2. # print(alpha,torch.min(rec_ot_all),torch.max(rec_ot_all),rec_ot_v1,rec_ot_v2) # exit() # rec_ot = w_gaussian_bp(residuals,noise_level) # print(residuals.numel()) rec_ot_coeff = 100. #residuals.numel()*2. # 1000.# - .997**cfg.global_step # residuals = rearrange(residuals,'b n c h w -> b n (h w) c') # rec_ot_pair_loss_v1 = w_gaussian_bp(residuals,noise_level) # rec_ot_loss_v1 = kl_gaussian_bp(residuals,noise_level,flip=True) # rec_ot_loss_v1 = kl_gaussian_pair_bp(residuals) # rec_ot_loss_v1 = ot_pairwise2gaussian_bp(residuals,K=6,reg=reg) # rec_ot_loss_v2 = ot_pairwise_bp(residuals,K=3) # rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device) # rec_ot = (rec_ot_loss_v1 + rec_ot_pair_loss_v2) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Final Losses # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- rec_loss = rec_ot_coeff * rec_ot + rec_mse_coeff * rec_mse norm_loss = norm_loss_coeff * (norm_loss_denoiser + norm_loss_align) align_loss = align_mse_coeff * align_mse + align_ot_coeff * align_ot entropy_loss = 0 #filters_entropy_coeff * filters_entropy + filters_dist_coeff * filters_dist_entropy # final_loss = align_loss + rec_loss + entropy_loss + norm_loss final_loss = rec_loss # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Record Keeping # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- alignment MSE -- align_mse_losses += align_mse.item() align_mse_count += 1 # -- alignment Dist -- align_ot_losses += align_ot.item() align_ot_count += 1 # -- reconstruction MSE -- rec_mse_losses += rec_mse.item() rec_mse_count += 1 # -- reconstruction Dist. 
-- rec_ot_losses += rec_ot.item() rec_ot_count += 1 # -- total loss -- running_loss += final_loss.item() total_loss += final_loss.item() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Gradients & Backpropogration # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- compute the gradients! -- final_loss.backward() # -- backprop now. -- model.align_info.optim.step() model.denoiser_info.optim.step() model.unet_info.optim.step() # for name,params in model.unet_info.model.named_parameters(): # if not ("weight" in name): continue # print(params.grad.norm()) # # print(module.conv1.parameters()) # # print(module.conv1.data.grad) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Printing to Stdout # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0: # -- compute mse for fun -- B = raw_img.shape[0] raw_img = raw_img.cuda(non_blocking=True) # -- psnr for [average of aligned frames] -- mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_aligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [average of input, misaligned frames] -- mis_ave = torch.mean(stacked_burst, dim=1) mse_loss = F.mse_loss(raw_img, mis_ave + 0.5, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_misaligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [bm3d] -- bm3d_nb_psnrs = [] M = 10 if B > 10 else B for b in range(B): bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5, sigma_psd=noise_level / 255, stage_arg=bm3d.BM3DStages.ALL_STAGES) bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2) b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec, reduction='none').reshape(1, -1) b_loss = torch.mean(b_loss, 1).detach().cpu().numpy() bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss)) bm3d_nb_psnrs.append(bm3d_nb_psnr) bm3d_nb_ave = np.mean(bm3d_nb_psnrs) bm3d_nb_std = np.std(bm3d_nb_psnrs) # -- psnr for aligned + denoised -- raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1) mse_loss = F.mse_loss(raw_img_repN, denoised + 0.5, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss)) psnr_denoised_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [model output image] -- mse_loss = F.mse_loss(raw_img, denoised_ave + 0.5, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr = np.mean(mse_to_psnr(mse_loss)) psnr_std = np.std(mse_to_psnr(mse_loss)) # -- update losses -- running_loss /= cfg.log_interval # -- alignment MSE -- align_mse_ave = align_mse_losses / align_mse_count align_mse_losses, align_mse_count = 0, 0 # -- alignment Dist. -- align_ot_ave = align_ot_losses / align_ot_count align_ot_losses, align_ot_count = 0, 0 # -- reconstruction MSE -- rec_mse_ave = rec_mse_losses / rec_mse_count rec_mse_losses, rec_mse_count = 0, 0 # -- reconstruction Dist. 
-- rec_ot_ave = rec_ot_losses / rec_ot_count rec_ot_losses, rec_ot_count = 0, 0 # -- write record -- if use_record: info = { 'burst': burst_loss, 'ave': ave_loss, 'ot': rec_ot_ave, 'psnr': psnr, 'psnr_std': psnr_std } record_losses = record_losses.append(info, ignore_index=True) # -- write to stdout -- write_info = (epoch, cfg.epochs, batch_idx, len(train_loader), running_loss, psnr, psnr_std, psnr_denoised_ave, psnr_denoised_std, psnr_aligned_ave, psnr_aligned_std, psnr_misaligned_ave, psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std, rec_mse_ave, rec_ot_ave) print( "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e" % write_info) running_loss = 0 # -- write examples -- if write_examples and (batch_idx % write_examples_iter) == 0 and ( batch_idx > 0 or cfg.global_step == 0): write_input_output(cfg, model, stacked_burst, aligned, denoised, all_filters, directions) if use_timer: clock.toc() if use_timer: print(clock) cfg.global_step += 1 # -- remove hooks -- for hook in align_hooks: hook.remove() total_loss /= len(train_loader) return total_loss, record_losses
def operate_Bm3d(self, image_noisy, mode, qp):
    return bm3d.bm3d(image_noisy, sigma_psd=self.getPsd(mode, qp))
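# operate_Bm3d above passes a noise power spectral density (from self.getPsd)
# rather than a scalar sigma; bm3d.bm3d accepts either form through sigma_psd.
# A minimal stand-alone sketch of both call styles. The image, the scalar sigma
# and the flat PSD values are illustrative assumptions (the package's exact PSD
# scaling convention is not reproduced here).
import bm3d
import numpy as np

z = np.random.rand(128, 128)
y_from_sigma = bm3d.bm3d(z, sigma_psd=0.1)     # scalar: white-noise std-dev
flat_psd = np.full(z.shape, 10.0)              # 2-D array: interpreted as a PSD
y_from_psd = bm3d.bm3d(z, sigma_psd=flat_psd)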