Example #1
def test_blockmatching_usage():
    z, sigma = generate_input_data()
    z2 = np.ones(z.shape)
    res_cr, bm = bm3d(z2, sigma, blockmatches=(True, True))
    res1 = bm3d(z, sigma)
    res2 = bm3d(z, sigma, blockmatches=bm)
    assert np.max(np.abs(res1 - res2)) > ALLOWED_ERROR_SAME
Example #2
def test_consistency_traditional():
    z, sigma = generate_input_data()
    p = BM3DProfile()
    p.nf = 0
    res1 = bm3d(np.copy(z), np.copy(sigma), p)
    res2 = bm3d(np.copy(z), np.copy(sigma), p)
    assert np.max(np.abs(res2 - res1)) < ALLOWED_ERROR_SAME
def BM3D_proj(S1, S2, lambda_this):
    """
    This function does the BM3D projection
    """
    # BM3D denoising
    S1 = bm3d(S1, lambda_this)
    S2 = bm3d(S2, lambda_this)

    return S1, S2
Example #4
def test_blockmatches_wie_inconsistency():
    z, sigma = generate_input_data()
    z2, sigma = generate_input_data(2)

    res_ref, bms = bm3d(z2, sigma, blockmatches=(True, True))
    res_ht = bm3d(z, sigma, stage_arg=BM3DStages.HARD_THRESHOLDING)

    res1 = bm3d(z, sigma, stage_arg=res_ht)
    res2 = bm3d(z, sigma, stage_arg=res_ht, blockmatches=bms)

    assert np.max(np.abs(res1 - res2)) != 0
Example #5
def bm3d_proj(S, lambda_this):
    """
    This function does the BM3D projection
    """
    S1 = np.reshape(S[0, :], image_size)
    S2 = np.reshape(S[1, :], image_size)
    # BM3D denoising
    S1 = bm3d(S1, lambda_this)
    S2 = bm3d(S2, lambda_this)

    S[0, :] = np.reshape(S1, (1, n * n))
    S[1, :] = np.reshape(S2, (1, n * n))

    return S
def BM3D_proj(S, image_size, lambda_this):
    """
    This function does the BM3D projection
    """
    S1 = np.reshape(S[0, :], image_size)
    S2 = np.reshape(S[1, :], image_size)
    # BM3D denoising
    S1 = bm3d(S1, lambda_this)
    S2 = bm3d(S2, lambda_this)

    S[0, :] = np.reshape(S1, (1, image_size[0] * image_size[1]))
    S[1, :] = np.reshape(S2, (1, image_size[0] * image_size[1]))

    return S
Example #7
    def calc(self, projs, theta, sart_plot=False):
        image_r = super(SartBM3DReconstructor, self).calc(projs,
                                                          theta,
                                                          sart_plot=sart_plot)
        # denoise with BM3D
        self.image_r = bm3d(image_r, self.bm3d_sigma)
        return self.image_r
def bm3d_denoise(fn, save_audio, plot):
    power_sp = np.load(fn)
    db = np.log10(power_sp)
    z = np.atleast_3d(db)

    db_est = bm3d(z, np.sqrt(smoothing_factor), profile=profile)
    save_files(db, db_est, fn, power_sp, save_audio=save_audio, plot=plot)
Example #9
def fast_hyde_eigen_image_denoising(img, k_subspace, r_w, e, eigen_y,
                                    n) -> np.ndarray:
    # %% --------------------------Eigen-image denoising ------------------------------------
    # send slices of the image to the GPU if that is the case,
    rows, cols, b = img.shape
    np_dtype = np.float32 if img.dtype is torch.float32 else np.float64
    eigen_y_bm3d = np.empty((k_subspace, n), dtype=np_dtype)
    ecpu = e.to(device="cpu", non_blocking=True)
    r_w = r_w.to(device="cpu", non_blocking=True)

    nxt_eigen = eigen_y[0].cpu()
    mx = min(k_subspace, eigen_y.shape[0])
    for i in range(mx):
        lp_eigen = nxt_eigen.numpy()
        if i < mx - 1:
            nxt_eigen = eigen_y[i + 1].to(device="cpu", non_blocking=True)
        # produce eigen-image
        eigen_im = lp_eigen
        min_x = np.min(eigen_im)
        max_x = np.max(eigen_im)
        eigen_im -= min_x
        scale = max_x - min_x
        # normalize eigen_im
        eigen_im = np.reshape(eigen_im, (rows, cols)) / scale
        if i == 0:
            ecpu = ecpu.numpy()
            r_w = r_w.numpy()
        sigma = np.sqrt(ecpu[:, i].T @ r_w @ ecpu[:, i]) / scale

        filt_eigen_im = bm3d.bm3d(eigen_im, sigma)

        eigen_y_bm3d[i, :] = (filt_eigen_im * scale + min_x).reshape(
            eigen_y_bm3d[i, :].shape)

    return eigen_y_bm3d
Example #10
def BM3D(noisy_images: np.ndarray, noise_std_dev: float) -> np.ndarray:
    """
    Params:
    noisy_images: receive noisy_images with shape (IMG, M, N, C), 
    where IMG is the quantity of noisy images 
    with dimensions MxN and
    C represents the color dimensions, i.e. if the images are colored,
    C = 3 (RGB), otherwise if C = 1, is grayscale. 

    noise_std_dev: the standart deviation from noise.
    """ 
    validate_array_input(noisy_images)
    validate_if_noise_std_dev_is_a_float(noise_std_dev)

    filtered_images = []
    for i in range(noisy_images.shape[0]):
        filtered_images.append(
            bm3d.bm3d(
                noisy_images[i, :,:,:], 
                sigma_psd=noise_std_dev, 
                stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING
            )
        )
    
    filtered_images = np.array(filtered_images)

    return filtered_images
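A minimal usage sketch for the batched wrapper above, assuming the BM3D function and its validators defined here are in scope; the synthetic data and sigma value are illustrative only:

import numpy as np

# four grayscale 64x64 images, shape (IMG, M, N, C) as described in the docstring
rng = np.random.default_rng(0)
clean = rng.random((4, 64, 64, 1))
sigma = 0.1
noisy = clean + sigma * rng.standard_normal(clean.shape)

denoised = BM3D(noisy, noise_std_dev=sigma)  # returns an array of the same shape
assert denoised.shape == noisy.shape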
def filterBM3D(sLayers):
    optimg=copy.copy(sLayers[:]) 
    for i in range(len(sLayers)):
        copySAR = copy.copy(sLayers[:][i])
        sarArrayI = np.array(copySAR)
        denoised_image = bm3d.bm3d(sLayers[i], sigma_psd=4, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)
        optimg[i] = denoised_image
    return(optimg)
def bm3dPrior(m_, simul_params, pnp_params):
    ngx = simul_params["ngx"]
    ngy = simul_params["ngy"]
    sigma = pnp_params["dn_sigma"]

    m = m_.reshape(ngy, ngx)
    m_denoise = bm3d(m, sigma)
    m_denoise = m_denoise.reshape(ngy * ngx, 1)
    return m_denoise
Example #13
def bm3d_fil(): #Block Matching & 3D Filter works fine
    if locerror['text']=="":
        messagebox.showerror("Error","Please choose the location")
    else:
        img=cv2.imread(dir,1)
        img1=img_as_float(img)
        bm3d_img=bm3d.bm3d(img1, sigma_psd=0.05,stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)
        cv2.imshow("Block Matching & 3D Filter",bm3d_img)
        cv2.waitKey()
        cv2.destroyAllWindows()
Example #14
    def run(self,
            rgb_img,
            nr_method,
            nr_parm,
            no_blend,
            mask_level,
            blend_morph_kernel,
            blend_alpha=0.5):
        y, u, v = self.yuv_split(rgb_img)
        laplacian_pyramid = self.lap_pyr(y)

        # rec = self.lap_pyr_rec(laplacian_pyramid)
        # imageio.imwrite('test.jpg', img_as_ubyte(rec))
        # imageio.imwrite('test_lap.jpg', img_as_float(laplacian_pyramid[-1]))
        #
        # laplacian_pyramid[-1] = np.zeros_like(laplacian_pyramid[-1])
        # rec = self.lap_pyr_rec(laplacian_pyramid)
        # imageio.imwrite('test_zero.jpg', img_as_ubyte(rec))

        if nr_method == 'freq':
            y_denoised = self.freq_domain_denoise(y, threshold_sigma=nr_parm)
        elif nr_method == 'freq_smooth':
            y_denoised = self.freq_domain_denoise_smooth(
                y, threshold_sigma=nr_parm)
        elif nr_method == "median":
            y_denoised = self.median_denoise(y, radius=int(nr_parm))
        elif nr_method == "gaussian":
            y_denoised = self.gaussian_denoise(y, radius=int(nr_parm))
        elif nr_method == 'tvl1':
            y_denoised = cv2.denoise_TVL1(y, y)
        elif nr_method == 'nlm':
            y_denoised = self.nlm_denoise(y, int(nr_parm))
        elif nr_method == 'bm3d':
            y_denoised = bm3d.bm3d(y, nr_parm)
        elif nr_method == 'lap_pyr':
            denoised_pyr = self.pyr_denoise(laplacian_pyramid, nr_parm)
            y_denoised = self.lap_pyr_rec(denoised_pyr)
        elif nr_method == 'bilateral':
            y_denoised = img_as_float(
                cv2.bilateralFilter(img_as_ubyte(y), int(nr_parm), 20, 5))
        else:
            y_denoised = self.nlm_denoise(y, int(nr_parm))

        mask = self.get_blending_mask(laplacian_pyramid,
                                      blend_freq_layer=mask_level,
                                      morph_kernel_size=blend_morph_kernel)

        y_final = self.alpha_blend(y, y_denoised, blend_alpha)
        # y_final = self.mask_blend(y, y_final, mask, no_blend)

        rgb_final = self.yuv_combine(y_final, u, v)

        return np.clip(rgb_final, a_min=0, a_max=1), np.clip(mask,
                                                             a_min=0,
                                                             a_max=1)
def main():
    # Experiment specifications
    imagename = 'cameraman256.png'

    # Load noise-free image
    y = np.array(Image.open(imagename)) / 255
    # Possible noise types to be generated 'gw', 'g1', 'g2', 'g3', 'g4', 'g1w',
    # 'g2w', 'g3w', 'g4w'.
    noise_type = 'g3'
    noise_var = 0.02  # Noise variance
    seed = 0  # seed for pseudorandom noise realization

    # Generate noise with given PSD
    noise, psd, kernel = get_experiment_noise(noise_type, noise_var, seed,
                                              y.shape)
    # N.B.: For the sake of simulating a more realistic acquisition scenario,
    # the generated noise is *not* circulant. Therefore there is a slight
    # discrepancy between PSD and the actual PSD computed from infinitely many
    # realizations of this noise with different seeds.

    # Generate noisy image corrupted by additive spatially correlated noise
    # with noise power spectrum PSD
    z = np.atleast_3d(y) + np.atleast_3d(noise)

    # Call BM3D With the default settings.
    y_est = bm3d(z, psd)

    # To include refiltering:
    # y_est = bm3d(z, psd, 'refilter')

    # For other settings, use BM3DProfile.
    # profile = BM3DProfile(); # equivalent to profile = BM3DProfile('np');
    # profile.gamma = 6;  # redefine value of gamma parameter
    # y_est = bm3d(z, psd, profile);

    # Note: for white noise, you may pass a standard deviation
    # instead of the PSD
    # y_est = bm3d(z, sqrt(noise_var));

    psnr = get_psnr(y, y_est)
    print("PSNR:", psnr)

    # PSNR ignoring 16-pixel wide borders (as used in the paper), due to refiltering potentially leaving artifacts
    # on the pixels near the boundary of the image when noise is not circulant
    psnr_cropped = get_cropped_psnr(y, y_est, [16, 16])
    print("PSNR cropped:", psnr_cropped)

    # Ignore values outside range for display (or plt gives an error for multichannel input)
    y_est = np.minimum(np.maximum(y_est, 0), 1)
    z_rang = np.minimum(np.maximum(z, 0), 1)
    plt.title("y, z, y_est")
    plt.imshow(np.concatenate((y, np.squeeze(z_rang), y_est), axis=1),
               cmap='gray')
    plt.show()
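The comments in main() note that for white noise a scalar standard deviation can be passed instead of a PSD; a minimal sketch of that variant, reusing the imports above (the noise level and seed are illustrative):

def demo_white_noise(imagename='cameraman256.png', noise_std=0.1, seed=0):
    # Load noise-free image and add i.i.d. Gaussian noise
    y = np.array(Image.open(imagename)) / 255
    rng = np.random.default_rng(seed)
    z = y + noise_std * rng.standard_normal(y.shape)
    # For white noise a scalar sigma can be passed in place of a PSD
    y_est = bm3d(z, noise_std)
    return np.clip(y_est, 0, 1)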
Example #16
def bm3d_filter(image, sigma_psd=50 / 255):
    """
    Perform BM3D filter on image
    :param image: Image (numpy ndarray) to perform bm3d filter on
    :param sigma_psd: Noise standard deviation, on the same scale as the image
        (the default 50/255 corresponds to a standard deviation of 50 for intensities in [0, 255])
    :return: Filtered image
    """
    image_copy = image.copy()
    denoised_image = bm3d.bm3d(image_copy,
                               sigma_psd=sigma_psd,
                               stage_arg=bm3d.BM3DStages.ALL_STAGES)
    return denoised_image
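A short call sketch for bm3d_filter showing the sigma convention from the docstring (a value given on the 0-255 intensity scale is divided by 255 for a float image in [0, 1]; the image and sigma here are only illustrative):

sigma_255 = 25                                   # noise std on the 0-255 scale
denoised = bm3d_filter(image, sigma_psd=sigma_255 / 255)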
Example #17
def bm3d_denoise_slicewise(image, FirstSliceNumber, SliceNumDigits, Sigma_n,
                           vl, vh):
    Nz, Ny, Nx = image.shape
    data_out = np.zeros((Nz, Ny, Nx), dtype=np.float32)
    for i in range(Nz):
        data = np.atleast_3d(((image[i] - vl) / (vh - vl)))  #scale to (0,1)
        data = np.minimum(np.maximum(data, 0), 1)  #clip
        data_out[i] = bm3d(data, Sigma_n)  #denoise

    data_out = np.minimum(np.maximum(data_out, 0), 1)  #clip
    data_out = data_out * (vh - vl) + vl  #re-scale
    return data_out
Example #18
def filterSharpen(sLayers):
    ker = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    optimg = copy.copy(sLayers[:])
    for i in range(len(sLayers)):
        copySAR = copy.copy(sLayers[:][i])
        sarArrayI = np.array(copySAR)
        denoised_image = bm3d.bm3d(sLayers[i],
                                   sigma_psd=4,
                                   stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)
        x = ndimage.convolve(denoised_image, ker)
        optimg[i] = x
    return (optimg)
Example #19
def bm3d_method(agent_input, params):
    """
    BM3D prior

    Args:
        agent_input: full-size reconstruction
        params:  params.noise_std for noise standard deviation

    Returns:
        New full-size reconstruction after update
    """

    return bm3d(agent_input, params.noise_std)
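A hypothetical call pattern for the plug-and-play prior above; params can be any object exposing a noise_std attribute (a SimpleNamespace stand-in and random data are used here purely for illustration):

from types import SimpleNamespace
import numpy as np

recon = np.random.rand(128, 128)          # current full-size reconstruction
params = SimpleNamespace(noise_std=0.05)  # per the docstring: noise standard deviation
updated = bm3d_method(recon, params)      # BM3D-denoised update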
Example #20
def metrics_bm3d_from_ds(ds, noise_std=30, with_shape=True):
    metrics = Metrics()
    pred_and_gt_shape = [(
        (bm3d.bm3d(images_noisy[0, ..., 0].numpy() + 0.5,
                   sigma_psd=noise_std / 255,
                   stage_arg=bm3d.BM3DStages.ALL_STAGES) - 0.5)[None, ...,
                                                                None],
        images_gt.numpy(),
        im_shape.numpy(),
    ) for images_noisy, images_gt, im_shape in tqdm_notebook(ds)]
    for im_recos, images, im_shape in tqdm_notebook(pred_and_gt_shape,
                                                    desc=f'Stats for BM3D'):
        metrics.push(images, im_recos, im_shape)
    return metrics
Example #21
def main(args):
    input_arr = np.load(args.input)
    H, W = input_arr.shape
    L = 256
    gap_H, gap_W = H % (L // 2), W % (L // 2)
    i0, i1 = gap_H // 2, H - (gap_H - (gap_H // 2))
    j0, j1 = gap_W // 2, W - (gap_W - (gap_W // 2))
    input_arr = input_arr[i0:i1, j0:j1]
    H, W = input_arr.shape
    if args.is_logt:
        logtimg_norm = 0.5 * (1 + (input_arr / (3 * np.log(10))))
        output = bm3d.bm3d(logtimg_norm, args.sigma_bm3d, profile='high')
        output += ((0.5 * args.bias_correction) / (3 * np.log(10)))
        output = np.clip(output, 0.5, None)
        output = np.exp(-((2 * output - 1) * (3 * np.log(10))))
    else:
        assert input_arr.min() >= 0
        max_val = input_arr.max()
        timg_norm = input_arr / max_val
        output = max_val * bm3d.bm3d(
            timg_norm, args.sigma_bm3d, profile='high')
    io_utils.save_img(io_utils.array_to_img(output, 'L'), args.output)
    return
Example #22
def denoiseBM3D(image):
    t1 = time.time()
    print("  denoising image using BM3D", end="", flush=True)
    denoised = np.zeros(image.shape, np.uint8)  # empty image
    cv2.fastNlMeansDenoising(image,
                             denoised,
                             h=15,
                             templateWindowSize=7,
                             searchWindowSize=(15 + 1))

    denoised = bm3d.bm3d(image,
                         sigma_psd=0.2,
                         stage_arg=bm3d.BM3DStages.ALL_STAGES)

    print(", took %f s" % (time.time() - t1))

    return denoised
Example #23
def bm3d_param_selection():

    p = Path(r'project_model/results/test/test_latest.mat')
    outputs = sio.loadmat(p.resolve())

    sigmaX = np.exp(np.arange(-10, 10))

    base_psnrs = []

    for x in sigmaX:
        for i in range(len(outputs['real_B'])):
            orig_image = outputs['real_B'][i]
            noisy_img = outputs['real_A'][i]

            #test base model
            y_pred = bm3d(noisy_img, x)

            psnr = get_psnr(orig_image, y_pred)
            base_psnrs.append({'param': x, 'rate': psnr})

    best_param = sorted(base_psnrs, key=lambda x: x['rate'], reverse=True)[0]
    print(f'BM3D best param is {best_param}')
    return best_param
Example #24
def bm3d_denoise(marginal, noisy_hist, noise, noise_type, gt_hist=None, return_TVD=False):

    # print('    bm3d', marginal, noisy_hist.shape)
    if not gt_hist is None:
        print('        queried TVD: {:.5f}'.format(get_TVD(gt_hist, noisy_hist)))

    shape = noisy_hist.shape
    if len(shape) > 3:
        noisy_hist = noisy_hist.reshape((shape[0], shape[1], -1))

    # axis=2 is channels if exists
    temp_hist = np.concatenate([noisy_hist] * (int(10/noisy_hist.shape[1])+1), axis=1)
    temp_hist = np.concatenate([temp_hist] * (int(10/temp_hist.shape[0])+1), axis=0)
    # print(temp_hist.shape)

    max_value = np.max(noisy_hist)
    if noise_type == 'normal':
        std_dev = noise/max_value
    elif noise_type == 'Laplace':
        std_dev = noise/max_value * 2 ** 0.5
    else:
        exit(-1)

    bm3d_hist = bm3d(temp_hist/max_value, sigma_psd=std_dev)

    bm3d_hist = bm3d_hist[:noisy_hist.shape[0], :noisy_hist.shape[1]]
    bm3d_hist *= max_value

    if len(shape) > 3:
        bm3d_hist = bm3d_hist.reshape(shape)
    
    if return_TVD and not gt_hist is None:
        bm3d_TVD = get_TVD(gt_hist, bm3d_hist)
        return bm3d_hist, bm3d_TVD

    return bm3d_hist
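A quick check of the Laplace scaling used above (a Laplace distribution with scale b has standard deviation b * sqrt(2), which is why the Laplace branch multiplies by 2 ** 0.5); the sample size is arbitrary:

import numpy as np

b = 1.0
samples = np.random.default_rng(0).laplace(scale=b, size=1_000_000)
print(samples.std())  # close to b * sqrt(2) ≈ 1.414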
Example #25
def DFFC(data, flats, darks, downsample, nrPArepetions):
    # Load frames
    meanDarkfield = np.mean(darks, axis=1, dtype=np.float64)
    whiteVect = np.zeros((flats.shape[1], flats.shape[0] * flats.shape[2]),
                         dtype=np.float64)
    for i in range(flats.shape[1]):
        whiteVect[i] = flats[:, i, :].flatten() - meanDarkfield.flatten()
    mn = np.mean(whiteVect, axis=0)

    # Subtract mean flat field
    M, N = whiteVect.shape
    Data = whiteVect - mn

    # =============================================================================
    # Parallel Analysis (EEFs selection):
    #      Selection of the number of components for PCA using parallel Analysis.
    #      Each flat field is a single row of the matrix flatFields, different
    #      rows are different observations.
    # =============================================================================

    def cov(X):
        one_vector = np.ones((1, X.shape[0]))
        mu = np.dot(one_vector, X) / X.shape[0]
        X_mean_subtract = X - mu
        covA = np.dot(X_mean_subtract.T, X_mean_subtract) / (X.shape[0] - 1)
        return covA

    def parallelAnalysis(flatFields, repetitions):
        stdEFF = np.std(flatFields, axis=0, ddof=1, dtype=np.float64)
        H, W = flatFields.shape
        keepTrack = np.zeros((H, repetitions), dtype=np.float64)
        stdMatrix = np.tile(stdEFF, (H, 1))
        for i in range(repetitions):
            print(f"Parallel Analysis - repetition {i}")
            sample = stdMatrix * np.random.randn(H, W)
            D1, _ = np.linalg.eig(np.cov(sample))
            keepTrack[:, i] = D1.copy()
        mean_flat_fields_EFF = np.mean(flatFields, axis=0)
        F = flatFields - mean_flat_fields_EFF
        D1, V1 = np.linalg.eig(np.cov(F))
        selection = np.zeros((1, H))
        # mean + 2 * std
        selection[:, D1 > (np.mean(keepTrack, axis=1) +
                           2 * np.std(keepTrack, axis=1, ddof=1))] = 1
        numberPC = np.sum(selection)
        return V1, D1, int(numberPC)

    # Parallel Analysis
    nrEigenflatfields = 0
    print("Parallel Analysis:")
    while (nrEigenflatfields <= 0):
        V1, D1, nrEigenflatfields = parallelAnalysis(Data, nrPArepetions)
    print(f"{nrEigenflatfields} eigen flat fields selected!")
    idx = D1.argsort()[::-1]
    D1 = D1[idx]
    V1 = V1[:, idx]

    # Calculation eigen flat fields
    H, C, W = data.shape
    eig0 = mn.reshape((H, W))
    EFF = np.zeros((nrEigenflatfields + 1, H, W))  #n_EFF + 1 eig0
    EFF_denoised = np.zeros((nrEigenflatfields + 1, H, W))  #n_EFF + 1 eig0
    print("Calculating EFFs:")
    EFF[0] = eig0
    for i in range(nrEigenflatfields):
        EFF[i + 1] = (np.matmul(Data.T, V1[i]).T).reshape((H, W))

    EFF_denoised = EFF.copy()
    # Denoise eigen flat fields
    print("Denoising EFFs using BM3D method:")
    for i in range(1, len(EFF)):
        print(f"Denoising EFF {i}")
        EFF_max, EFF_min = EFF_denoised[i, :, :].max(), EFF_denoised[
            i, :, :].min()
        EFF_denoised[i, :, :] = (EFF_denoised[i, :, :] - EFF_min) / (EFF_max -
                                                                     EFF_min)
        sigma_bm3d = estimate_sigma(EFF_denoised[i, :, :]) * 10
        #print(f"Estimated sigma: {sigma_bm3d}")
        EFF_denoised[i, :, :] = bm3d.bm3d(EFF_denoised[i, :, :], sigma_bm3d)
        EFF_denoised[i, :, :] = (EFF_denoised[i, :, :] *
                                 (EFF_max - EFF_min)) + EFF_min

    print("Denoising completed.")

    # =============================================================================
    # cost_func: cost function used to estimate the weights using TV
    # =============================================================================

    def cost_func(x, *args):
        (projections, meanFF, FF, DF) = args
        FF_eff = np.zeros((FF.shape[1], FF.shape[2]))
        for i in range(len(FF)):
            FF_eff = FF_eff + x[i] * FF[i]
        logCorProj = (projections - DF) / (
            meanFF + FF_eff) * np.mean(meanFF.flatten() + FF_eff.flatten())
        Gx, Gy = np.gradient(logCorProj)
        mag = (Gx**2 + Gy**2)**(1 / 2)
        cost = np.sum(mag.flatten())
        return cost

    # =============================================================================
    # condTVmean function: finds the optimal estimates of the coefficients of the
    # eigen flat fields.
    # =============================================================================

    def condTVmean(projection, meanFF, FF, DF, x, DS):
        # Downsample image
        projection = downscale_local_mean(projection, (DS, DS))
        meanFF = downscale_local_mean(meanFF, (DS, DS))
        FF2 = np.zeros((FF.shape[0], meanFF.shape[0], meanFF.shape[1]))
        for i in range(len(FF)):
            FF2[i] = downscale_local_mean(FF[i], (DS, DS))
        FF = FF2
        DF = downscale_local_mean(DF, (DS, DS))
        # Optimize weights (x)
        x = scipy.optimize.minimize(cost_func,
                                    x,
                                    args=(projection, meanFF, FF, DF),
                                    method='BFGS',
                                    tol=1e-8)
        return x.x

    H, C, W = data.shape
    print("TV optimisation for DFF coefficients:")
    clean_DFFC = np.zeros((H, C, W), dtype=np.float64)
    for i in range(C):
        if i % 5 == 0: print("Normalising projection", i)
        projection = data[:, i, :]
        # Estimate weights for a single projection
        meanFF = EFF_denoised[0]
        FF = EFF_denoised[1:]
        weights = np.zeros(nrEigenflatfields)
        x = condTVmean(projection, meanFF, FF, meanDarkfield, weights,
                       downsample)
        # Dynamic FFC
        FFeff = np.zeros(meanDarkfield.shape)
        for j in range(nrEigenflatfields):
            FFeff = FFeff + x[j] * EFF_denoised[j + 1]
        tmp = np.divide((projection - meanDarkfield),
                        (EFF_denoised[0] + FFeff))
        clean_DFFC[:, i, :] = tmp

    return [clean_DFFC, EFF, EFF_denoised]
Example #26
File: learn.py  Project: gauenk/cl_gen
def train_loop(cfg, model, scheduler, train_loader, epoch, record_losses,
               writer):

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Setup for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.unet_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)
    model.unet_info.model = model.unet_info.model.to(cfg.device)

    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    noise_type = cfg.noise_params.ntype

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Record Keeping
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    dynamics_acc, dynamics_count = 0, 0

    write_examples = False
    write_examples_iter = 200
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #   Load Pre-Simulated Random Numbers
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    if cfg.use_kindex_lmdb: kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Dataset Augmentation
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    transforms = [tvF.vflip, tvF.hflip, tvF.rotate]
    aug = RandomChoice(transforms)

    def apply_transformations(burst, gt_img):
        N, B = burst.shape[:2]
        gt_img_rs = rearrange(gt_img, 'b c h w -> 1 b c h w')
        all_images = torch.cat([gt_img_rs, burst], dim=0)
        all_images = rearrange(all_images, 'n b c h w -> (n b) c h w')
        tv_utils.save_image(all_images,
                            'aug_original.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = aug(all_images)
        tv_utils.save_image(aug_images,
                            'aug_augmented.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = rearrange(aug_images, '(n b) c h w -> n b c h w', b=B)
        aug_gt_img = aug_images[0]
        aug_burst = aug_images[1:]
        return aug_burst, aug_gt_img

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Half Precision
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    # model.align_info.model.half()
    # model.denoiser_info.model.half()
    # model.unet_info.model.half()
    # models = [model.align_info.model,
    #           model.denoiser_info.model,
    #           model.unet_info.model]
    # for model_l in models:
    #     model_l.half()
    #     for layer in model_l.modules():
    #         if isinstance(layer, torch.nn.BatchNorm2d):
    #             layer.float()

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Loss Functions
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss(alpha=cfg.kpn_burst_alpha,
                                  gradient_L1=~cfg.supervised)
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #    Add hooks for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-

    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Noise2Noise
    #
    # -=-=-=-=-=-=-=-=-=-=-

    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Final Configs
    #
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer:
        data_clock = Timer()
        clock = Timer()
    ds_size = len(train_loader)
    small_ds = ds_size < 500
    steps_per_epoch = ds_size if not small_ds else 500

    write_examples_iter = steps_per_epoch // 3
    all_filters = []

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Start Epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-
    dynamics_acc_i = -1.
    if cfg.use_seed:
        init = torch.initial_seed()
        torch.manual_seed(cfg.seed + 1 + epoch + init)
    train_iter = iter(train_loader)
    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Setting up for Iteration
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer:
            data_clock.tic()
            clock.tic()

        # -- grab data batch --
        sample = next(train_iter)
        burst, raw_img, motion = sample['burst'], sample['clean'], sample[
            'directions']
        raw_img_iid = sample['iid']
        raw_img_iid = raw_img_iid.cuda(non_blocking=True)
        burst = burst.cuda(non_blocking=True)

        aligned, est_nnf = align_burst(cfg, burst, model)
        sim_images = subsample_aligned(cfg, aligned)
        burst_in, tgt_out = create_training_pairs(burst, sim_images)

        dn_losses = []
        for burst, target in zip(burst_in, tgt_out):

            # -- forward pass --
            est_denoised = model(burst)
            dn_loss = compute_denoising_loss(est_denoised, target)

            # -- compute grads --
            if cfg.use_seed: torch.set_deterministic(False)
            dn_loss.backward()
            if cfg.use_seed: torch.set_deterministic(True)

            # -- backprop --
            optim.step()
            scheduler.step()

            # -- store info --
            dn_losses.append(dn_loss.item())

        # -- average over losses --
        dn_loss = np.mean(dn_losses)

        # -- alignment loss --
        align_loss = compute_nnf_loss(gt_nnf, est_nnf)

        # -- total loss --
        final_loss = dn_loss + align_loss
        running_loss += final_loss.item()
        total_loss += final_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #            Printing to Stdout
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- recompute model output for original images --
            outputs = model(burst_og)
            m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            # -- compute mse for fun --
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            raw_img = get_nmlz_tgt_img(cfg, raw_img)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, m_aligned_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(burst_og, dim=0)
            if noise_type == "qis": mis_ave = quantize_img(cfg, mis_ave)
            mse_loss = F.mse_loss(raw_img, mis_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # tv_utils.save_image(raw_img,"raw.png",nrow=1,normalize=True,range=(-0.5,1.25))
            # tv_utils.save_image(mis_ave,"mis.png",nrow=1,normalize=True,range=(-0.5,1.25))

            # -- psnr for [bm3d] --
            mid_img_og = burst[N // 2]
            bm3d_nb_psnrs = []
            M = 4 if B > 4 else B
            for b in range(M):
                bm3d_rec = bm3d.bm3d(mid_img_og[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                # maybe an issue here
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for input averaged frames --
            # burst_ave = torch.mean(burst_og,dim=0)
            # mse_loss = F.mse_loss(raw_img,burst_ave,reduction='none').reshape(B,-1)
            # mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            # psnr_input_ave = np.mean(mse_to_psnr(mse_loss))
            # psnr_input_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for aligned + denoised --
            R = denoised.shape[1]
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, R, 1, 1, 1)
            # if noise_type == "qis": denoised = quantize_img(cfg,denoised)
            # save_image(denoised_ave,"denoised_ave.png")
            # save_image(denoised,"denoised.png")
            mse_loss = F.mse_loss(raw_img_repN, denoised,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- ave dynamic acc --
            ave_dyn_acc = dynamics_acc / dynamics_count * 100.
            dynamics_acc, dynamics_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, ave_dyn_acc)  #rec_ot_ave)

            #print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e" % write_info)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [dyn]: %.2e"
                % write_info,
                flush=True)
            # -- write to summary writer --
            if writer:
                writer.add_scalar('train/running-loss', running_loss,
                                  cfg.global_step)
                writer.add_scalars('train/model-psnr', {
                    'ave': psnr,
                    'std': psnr_std
                }, cfg.global_step)
                writer.add_scalars('train/dn-frame-psnr', {
                    'ave': psnr_denoised_ave,
                    'std': psnr_denoised_std
                }, cfg.global_step)

            # -- reset loss --
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, motion)

        if use_timer: clock.toc()

        if use_timer:
            print("data_clock", data_clock.average_time)
            print("clock", clock.average_time)
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    total_loss /= len(train_loader)
    return total_loss, record_losses
Example #27
File: learn.py  Project: gauenk/cl_gen
def train_loop(cfg, model, optimizer, criterion, train_loader, epoch,
               record_losses):

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Setup for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Record Keeping
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0

    write_examples = True
    write_examples_iter = 800
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Loss Functions
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss()
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Final Configs
    #
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer: clock = Timer()
    train_iter = iter(train_loader)
    D = 5 * 10**3
    steps_per_epoch = len(train_loader)

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Start Epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Setting up for Iteration
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer: clock.tic()

        # -- zero gradients; ready 2 go --
        optimizer.zero_grad()
        model.zero_grad()
        model.denoiser_info.optim.zero_grad()

        # -- grab data batch --
        burst, res_imgs, raw_img, directions = next(train_iter)

        # -- getting shapes of data --
        N, BS, C, H, W = burst.shape
        burst = burst.cuda(non_blocking=True)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Formatting Images for FP
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- creating some transforms --
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
        cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

        # -- extract target image --
        mid_img = burst[N // 2]
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised: gt_img = raw_zm_img
        else: gt_img = mid_img

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #           Forward Pass
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        aligned, aligned_ave, denoised, denoised_ave, filters = model(burst)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Entropy Loss for Filters
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        filters_shaped = rearrange(filters, 'b n k2 1 1 1 -> (b n) k2', n=N)
        filters_entropy = entropyLoss(filters_shaped)
        filters_entropy_coeff = 10.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Alignment Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        losses = alignmentLossMSE(aligned, aligned_ave, gt_img,
                                  cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        align_mse = np.sum(losses)
        align_mse_coeff = 0  #.933**cfg.global_step if cfg.global_step < 100 else 0

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Reconstruction Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        denoised_ave_d = denoised_ave.detach()
        losses = criterion(denoised, denoised_ave, gt_img, cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        rec_mse = np.sum(losses)
        rec_mse_coeff = 0.997**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Reconstruction Losses (Distribution)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- regularization scheduler --
        if cfg.global_step < 100: reg = 0.5
        elif cfg.global_step < 200: reg = 0.25
        elif cfg.global_step < 5000: reg = 0.15
        elif cfg.global_step < 10000: reg = 0.1
        else: reg = 0.05

        # -- computation --
        residuals = denoised - mid_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        residuals = rearrange(residuals, 'b n c h w -> b n (h w) c')
        # rec_ot_pair_loss_v1 = w_gaussian_bp(residuals,noise_level)
        rec_ot_pair_loss_v1 = kl_gaussian_bp(residuals, noise_level)
        # rec_ot_pair_loss_v1 = ot_pairwise2gaussian_bp(residuals,K=6,reg=reg)
        # rec_ot_pair_loss_v2 = ot_pairwise_bp(residuals,K=3)
        rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device)
        rec_ot_pair = (rec_ot_pair_loss_v1 + rec_ot_pair_loss_v2) / 2.
        rec_ot_pair_coeff = 100  # - .997**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Final Losses
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        align_loss = align_mse_coeff * align_mse
        rec_loss = rec_ot_pair_coeff * rec_ot_pair + rec_mse_coeff * rec_mse
        entropy_loss = filters_entropy_coeff * filters_entropy
        final_loss = align_loss + rec_loss + entropy_loss

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Record Keeping
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- alignment MSE --
        align_mse_losses += align_mse.item()
        align_mse_count += 1

        # -- reconstruction MSE --
        rec_mse_losses += rec_mse.item()
        rec_mse_count += 1

        # -- reconstruction Dist. --
        rec_ot_losses += rec_ot_pair.item()
        rec_ot_count += 1

        # -- total loss --
        running_loss += final_loss.item()
        total_loss += final_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #        Gradients & Backpropagation
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute the gradients! --
        final_loss.backward()

        # -- backprop now. --
        model.denoiser_info.optim.step()
        optimizer.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #            Printing to Stdout
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- compute mse for fun --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] --
            bm3d_nb_psnrs = []
            for b in range(BS):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN,
                                  denoised + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img,
                                  denoised_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- alignment MSE --
            align_mse_ave = align_mse_losses / align_mse_count
            align_mse_losses, align_mse_count = 0, 0

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader),
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, rec_ot_ave)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e"
                % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               filters, directions)

        if use_timer: clock.toc()
        if use_timer: print(clock)
        cfg.global_step += 1
    total_loss /= len(train_loader)
    return total_loss, record_losses
Example #28
def train_loop(cfg, model, noise_critic, optimizer, criterion, train_loader,
               epoch, record_losses):

    # -=-=-=-=-=-=-=-=-=-=-
    #    Setup for epoch
    # -=-=-=-=-=-=-=-=-=-=-

    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    D = 5 * 10**3
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    write_examples = True
    write_examples_iter = 800
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #      Init Record Keeping
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    losses_nc, losses_nc_count = 0, 0
    losses_mse, losses_mse_count = 0, 0
    running_loss, total_loss = 0, 0

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #      Init Loss Functions
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    lossRecMSE = LossRec(tensor_grad=cfg.supervised)
    lossBurstMSE = LossRecBurst(tensor_grad=cfg.supervised)

    # -=-=-=-=-=-=-=-=-=-=-
    #    Final Configs
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    train_iter = iter(train_loader)
    if use_timer: clock = Timer()

    # -=-=-=-=-=-=-=-=-=-=-
    #    GAN Scheduler
    # -=-=-=-=-=-=-=-=-=-=-

    # -- noise critic steps --
    if epoch == 0: disc_steps = 0
    elif epoch < 3: disc_steps = 1
    elif epoch < 10: disc_steps = 1
    else: disc_steps = 1

    # -- denoising steps --
    if epoch == 0: gen_steps = 1
    elif epoch < 3: gen_steps = 15
    elif epoch < 10: gen_steps = 10
    else: gen_steps = 10

    # -- steps each epoch --
    steps_per_iter = disc_steps * gen_steps
    steps_per_epoch = len(train_loader) // steps_per_iter
    if steps_per_epoch > 120: steps_per_epoch = 120

    # -=-=-=-=-=-=-=-=-=-=-
    #    Start Epoch
    # -=-=-=-=-=-=-=-=-=-=-

    for batch_idx in range(steps_per_epoch):

        for gen_step in range(gen_steps):

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #      Setting up for Iteration
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- setup iteration timer --
            if use_timer: clock.tic()

            # -- zero gradients --
            optimizer.zero_grad()
            model.zero_grad()
            model.denoiser_info.model.zero_grad()
            model.denoiser_info.optim.zero_grad()
            noise_critic.disc.zero_grad()
            noise_critic.optim.zero_grad()

            # -- grab data batch --
            burst, res_imgs, raw_img, directions = next(train_iter)

            # -- getting shapes of data --
            N, BS, C, H, W = burst.shape
            burst = burst.cuda(non_blocking=True)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #      Formatting Images for FP
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- creating some transforms --
            stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
            cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

            # -- extract target image --
            mid_img = burst[N // 2]
            raw_zm_img = szm(raw_img.cuda(non_blocking=True))
            if cfg.supervised: gt_img = raw_zm_img
            else: gt_img = mid_img

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #           Forward Pass
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            aligned, aligned_ave, denoised, denoised_ave, filters = model(
                burst)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #    MSE (KPN) Reconstruction Loss
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            loss_rec = lossRecMSE(denoised_ave, gt_img)
            loss_burst = lossBurstMSE(denoised, gt_img)
            loss_mse = loss_rec + 100 * loss_burst
            gbs, spe = cfg.global_step, steps_per_epoch
            if epoch < 3: weight_mse = 10
            else: weight_mse = 10 * 0.9999**(gbs - 3 * spe)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #      Noise Critic Loss
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            loss_nc = noise_critic.compute_residual_loss(denoised, gt_img)
            weight_nc = 1

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #          Final Loss
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            final_loss = weight_mse * loss_mse + weight_nc * loss_nc

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #    Update Info for Record Keeping
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- update alignment kl loss info --
            losses_nc += loss_nc.item()
            losses_nc_count += 1

            # -- update reconstruction kl loss info --
            losses_mse += loss_mse.item()
            losses_mse_count += 1

            # -- update info --
            running_loss += final_loss.item()
            total_loss += final_loss.item()

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
            #          Backward Pass
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

            # -- compute the gradients! --
            final_loss.backward()

            # -- backprop now. --
            model.denoiser_info.optim.step()
            optimizer.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #     Iterate for Noise Critic
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        for disc_step in range(disc_steps):

            # -- zero gradients --
            optimizer.zero_grad()
            model.zero_grad()
            model.denoiser_info.optim.zero_grad()
            noise_critic.disc.zero_grad()
            noise_critic.optim.zero_grad()

            # -- grab noisy data --
            _burst, _res_imgs, _raw_img, _directions = next(train_iter)
            _burst = _burst.to(cfg.device)

            # -- generate "fake" data from noisy data --
            _aligned, _aligned_ave, _denoised, _denoised_ave, _filters = model(
                _burst)
            _residuals = _denoised - _burst[N // 2].unsqueeze(1).repeat(
                1, N, 1, 1, 1)

            # -- update discriminator --
            loss_disc = noise_critic.update_disc(_residuals)

            # -- message to stdout --
            first_update = (disc_step == 0)
            last_update = (disc_step == disc_steps - 1)
            iter_update = first_update or last_update
            # if (batch_idx % cfg.log_interval//2) == 0 and batch_idx > 0 and iter_update:
            print(f"[Noise Critic]: {loss_disc}")

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #      Print Message to Stdout
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- init --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] --
            bm3d_nb_psnrs = []
            for b in range(BS):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN,
                                  denoised + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img,
                                  denoised_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- write record --
            if use_record:
                record_losses = record_losses.append(
                    {
                        'burst': burst_loss,
                        'ave': ave_loss,
                        'ot': ot_loss,
                        'psnr': psnr,
                        'psnr_std': psnr_std
                    },
                    ignore_index=True)
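                # DataFrame.append was removed in pandas 2.0; an equivalent using
                # pd.concat on the same record_losses frame would be:
                #   row = pd.DataFrame([{'burst': burst_loss, 'ave': ave_loss,
                #                        'ot': ot_loss, 'psnr': psnr,
                #                        'psnr_std': psnr_std}])
                #   record_losses = pd.concat([record_losses, row], ignore_index=True)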

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- average mse losses --
            ave_losses_mse = losses_mse / losses_mse_count
            losses_mse, losses_mse_count = 0, 0

            # -- average noise critic loss --
            ave_losses_nc = losses_nc / losses_nc_count
            losses_nc, losses_nc_count = 0, 0

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          ave_losses_mse, ave_losses_nc)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [mse]: %.2e [nc]: %.2e"
                % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, stacked_burst, aligned, denoised, filters,
                               directions)

        if use_timer: clock.toc()
        if use_timer: print(clock)
        cfg.global_step += 1
    total_loss /= len(train_loader)
    return total_loss, record_losses
Example #29
def train_loop(cfg, model, train_loader, epoch, record_losses):

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Setup for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.unet_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)
    model.unet_info.model = model.unet_info.model.to(cfg.device)

    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
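    # nn.Unfold(kernel_size=128, dilation=1, padding=0, stride=128) extracts
    # non-overlapping 128x128 blocks; it is set up here but unused in this loop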
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Record Keeping
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    align_ot_losses, align_ot_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0

    write_examples = True
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #    Add hooks for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-

    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)
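    # NOTE: AlignmentFilterHooks is defined elsewhere; based on how it is used
    # below (align_hook.filters and align_hook.clear()), a minimal compatible
    # forward hook might look like this sketch (an assumption, not the actual class):
    #   class AlignmentFilterHooks:
    #       def __init__(self, N):
    #           self.N = N
    #           self.filters = []
    #       def __call__(self, module, inputs, output):
    #           self.filters.append(output)  # keep each cascade's predicted filters
    #       def clear(self):
    #           self.filters = []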

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Loss Functions
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss()
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Final Configs
    #
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer: clock = Timer()
    train_iter = iter(train_loader)
    steps_per_epoch = len(train_loader)
    write_examples_iter = steps_per_epoch // 2

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Start Epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Setting up for Iteration
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer: clock.tic()

        # -- zero gradients; ready 2 go --
        model.align_info.model.zero_grad()
        model.align_info.optim.zero_grad()
        model.denoiser_info.model.zero_grad()
        model.denoiser_info.optim.zero_grad()
        model.unet_info.model.zero_grad()
        model.unet_info.optim.zero_grad()

        # -- grab data batch --
        burst, res_imgs, raw_img, directions = next(train_iter)

        # -- getting shapes of data --
        N, B, C, H, W = burst.shape
        burst = burst.cuda(non_blocking=True)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Formatting Images for FP
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- creating some transforms --
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
        cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')
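        # stacked_burst: (B, N, C, H, W) keeps the frames grouped per sample;
        # cat_burst: (B*N, C, H, W) folds the frames into the batch dimension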

        # -- extract target image --
        mid_img = burst[N // 2]
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised: gt_img = raw_zm_img
        else: gt_img = mid_img

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Check Some Gradients
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        def mse_v_wassersteinG_check_some_gradients(cfg, burst, gt_img, model):
            grads = edict()
            gt_img_rs = gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            model.unet_info.model.zero_grad()
            burst.requires_grad_(True)

            outputs = model(burst)
            aligned, aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]
            residuals = denoised - gt_img_rs
            P = 1.  #residuals.numel()
            denoised.retain_grad()
            rec_mse = (denoised.reshape(B, -1) - gt_img.reshape(B, -1))**2
            rec_mse.retain_grad()
            ones = P * torch.ones_like(rec_mse)
            rec_mse.backward(ones, retain_graph=True)
            grads.rmse = rec_mse.grad.clone().reshape(B, -1)
            grad_rec_mse = grads.rmse
            grads.dmse = denoised.grad.clone().reshape(B, -1)
            grad_denoised_mse = grads.dmse
            ones = torch.ones_like(rec_mse)
            grads.d_to_b = torch.autograd.grad(rec_mse, denoised,
                                               ones)[0].reshape(B, -1)

            model.unet_info.model.zero_grad()
            outputs = model(burst)
            aligned, aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]
            # residuals = denoised - gt_img_rs
            # rec_ot = w_gaussian_bp(residuals,noise_level)
            denoised.retain_grad()
            rec_ot_v = (denoised - gt_img_rs)**2
            rec_ot_v.retain_grad()
            rec_ot = (rec_ot_v.mean() - noise_level / 255.)**2
            rec_ot.retain_grad()
            ones = P * torch.ones_like(rec_ot)
            rec_ot.backward(ones)
            grad_denoised_ot = denoised.grad.clone().reshape(B, -1)
            grads.dot = grad_denoised_ot
            grad_rec_ot = rec_ot_v.grad.clone().reshape(B, -1)
            grads.rot = grad_rec_ot

            print("Gradient Name Info")
            for name, g in grads.items():
                g_norm = g.norm().item()
                g_mean = g.mean().item()
                g_std = g.std().item()
                print(name, g.shape, g_norm, g_mean, g_std)

            print_pairs = False
            if print_pairs:
                print("All Gradient Ratios")
                for name_t, g_t in grads.items():
                    for name_b, g_b in grads.items():
                        ratio = g_t / g_b
                        ratio_m = ratio.mean().item()
                        ratio_std = ratio.std().item()
                        print("[%s/%s] [%2.2e +/- %2.2e]" %
                              (name_t, name_b, ratio_m, ratio_std))

            use_true_mse = False
            if use_true_mse:
                print("Ratios with Estimated MSE Gradient")
                true_dmse = 2 * torch.mean(denoised_ave - gt_img)**2
                ratio_mse = grads.dmse / true_dmse
                ratio_mse_dtb = grads.dmse / grads.d_to_b
                print(ratio_mse)
                print(ratio_mse_dtb)

            dot_v_dmse = True
            if dot_v_dmse:
                print("Ratio of Denoised OT and Denoised MSE")
                ratio_mseot = (grads.dmse / grads.dot)
                print(ratio_mseot.mean(), ratio_mseot.std())
                ratio_mseot = ratio_mseot[0, 0].item()

                c1 = torch.mean((denoised - gt_img_rs)**2).item()
                c2 = noise_level / 255
                m = torch.mean(gt_img_rs).item()
                true_ratio = 2. * (c1 - c2) / (np.product(burst.shape))
                # diff = denoised.reshape(B,-1)-gt_img_rs.reshape(B,-1)
                # true_ratio = 2.*(c1 - c2) * ( diff / ( np.product(burst.shape) ) )
                # print(c1,c2,m,true_ratio,1./true_ratio)
                ratio_mseot = (grads.dmse / (grads.dot))
                print(ratio_mseot * true_ratio)

                # ratio_mseot = (grads.dmse / ( grads.dot / diff) )
                # print(ratio_mseot*true_ratio)
                # print(ratio_mseot.mean(),ratio_mseot.std())

            exit()
            model.unet_info.model.zero_grad()

        # mse_v_wassersteinG_check_some_gradients(cfg,burst,gt_img,model)
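        # (the gradient-check helper above ends with exit(), so it is meant only for
        #  one-off inspection and its call stays commented out during training)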

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #           Forward Pass
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        outputs = model(burst)
        aligned, aligned_ave, denoised, denoised_ave = outputs[:4]
        aligned_filters, denoised_filters = outputs[4:]

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Require Approx Equal Filter Norms (aligned)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        aligned_filters_rs = rearrange(aligned_filters,
                                       'b n k2 c h w -> b n (k2 c h w)')
        norms = torch.norm(aligned_filters_rs, p=2., dim=2)
        norms_mid = norms[:, N // 2].unsqueeze(1).repeat(1, N)
        norm_loss_align = torch.mean(
            torch.pow(torch.abs(norms - norms_mid), 1.))

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Require Approx Equal Filter Norms (denoised)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        denoised_filters = rearrange(denoised_filters,
                                     'b n k2 c h w -> b n (k2 c h w)')
        norms = torch.norm(denoised_filters, p=2., dim=2)
        norms_mid = norms[:, N // 2].unsqueeze(1).repeat(1, N)
        norm_loss_denoiser = torch.mean(
            torch.pow(torch.abs(norms - norms_mid), 1.))
        norm_loss_coeff = 0.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Decrease Entropy within a Kernel
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        filters_entropy = 0
        filters_entropy_coeff = 0.  # 1000.
        all_filters = []
        L = len(align_hook.filters)
        iter_filters = align_hook.filters if L > 0 else [aligned_filters]
        for filters in iter_filters:
            filters_shaped = rearrange(filters,
                                       'b n k2 c h w -> (b n c h w) k2',
                                       n=N)
            filters_entropy += entropyLoss(filters_shaped)
            all_filters.append(filters)
        if L > 0: filters_entropy /= L
        all_filters = torch.stack(all_filters, dim=1)
        align_hook.clear()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Increase Entropy across each Kernel
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        filters_dist_entropy = 0

        # -- across each frame --
        # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b l) (n c h w) k2')
        # filters_shaped = torch.mean(filters_shaped,dim=1)
        # filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        # -- across each batch --
        filters_shaped = rearrange(all_filters,
                                   'b l n k2 c h w -> (n l) (b c h w) k2')
        filters_shaped = torch.mean(filters_shaped, dim=1)
        filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        # -- across each kpn cascade --
        # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b n) (l c h w) k2')
        # filters_shaped = torch.mean(filters_shaped,dim=1)
        # filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        filters_dist_coeff = 0

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Alignment Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        losses = alignmentLossMSE(aligned, aligned_ave, gt_img,
                                  cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        align_mse = sum(losses)  # built-in sum keeps the autograd graph intact
        align_mse_coeff = 0.  #0.95**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Alignment Losses (Distribution)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # pad = 2*cfg.N
        # fs = cfg.dynamic.frame_size
        residuals = aligned - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        # centered_residuals = tvF.center_crop(residuals,(fs-pad,fs-pad))
        # centered_residuals = tvF.center_crop(residuals,(fs//2,fs//2))
        # align_ot = kl_gaussian_bp(residuals,noise_level,flip=True)
        align_ot = kl_gaussian_bp_patches(residuals,
                                          noise_level,
                                          flip=True,
                                          patchsize=16)
        align_ot_coeff = 0  # 100.
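        # NOTE: kl_gaussian_bp / kl_gaussian_bp_patches are defined elsewhere; they
        # penalize divergence between the residuals and the known noise model
        # N(0, (noise_level/255)^2). A minimal whole-image sketch (an assumption,
        # not the library code):
        #   mu, var = residuals.mean(), residuals.var()
        #   sigma2 = (noise_level / 255.) ** 2
        #   kl = 0.5 * (torch.log(sigma2 / var) + (var + mu ** 2) / sigma2 - 1.)
        # The patch variant applies the same penalty over local patches (patchsize=16).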

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Reconstruction Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        losses = denoiseLossMSE(denoised, denoised_ave, gt_img,
                                cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        rec_mse = sum(losses)  # built-in sum keeps the autograd graph intact
        rec_mse_coeff = 0.95**cfg.global_step
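        # the MSE weight decays geometrically with the global step, so the
        # distribution (KL/OT) term comes to dominate the reconstruction loss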

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Reconstruction Losses (Distribution)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- computation --
        gt_img_rs = gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        residuals = denoised - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        # rec_ot = kl_gaussian_bp(residuals,noise_level)
        rec_ot = kl_gaussian_bp(residuals, noise_level, flip=True)
        # rec_ot /= 2.
        # alpha_grid = [0.,1.,5.,10.,25.]
        # for alpha in alpha_grid:
        #     # residuals = torch.normal(torch.zeros_like(residuals)+ gt_img_rs*alpha/255.,noise_level/255.)
        #     residuals = torch.normal(torch.zeros_like(residuals),noise_level/255.+ gt_img_rs*alpha/255.)

        #     rec_ot_v2_a = kl_gaussian_bp_patches(residuals,noise_level,patchsize=16)
        #     rec_ot_v1_b = kl_gaussian_bp(residuals,noise_level,flip=True)
        #     rec_ot_v2_b = kl_gaussian_bp_patches(residuals,noise_level,flip=True,patchsize=16)
        #     rec_ot_all = torch.tensor([rec_ot_v1_a,rec_ot_v2_a,rec_ot_v1_b,rec_ot_v2_b])

        #     rec_ot_v2 = (rec_ot_v2_a + rec_ot_v2_b).item()/2.
        #     print(alpha,torch.min(rec_ot_all),torch.max(rec_ot_all),rec_ot_v1,rec_ot_v2)
        # exit()
        # rec_ot = w_gaussian_bp(residuals,noise_level)
        # print(residuals.numel())
        rec_ot_coeff = 100.  #residuals.numel()*2.
        # 1000.# - .997**cfg.global_step

        # residuals = rearrange(residuals,'b n c h w -> b n (h w) c')
        # rec_ot_pair_loss_v1 = w_gaussian_bp(residuals,noise_level)
        # rec_ot_loss_v1 = kl_gaussian_bp(residuals,noise_level,flip=True)
        # rec_ot_loss_v1 = kl_gaussian_pair_bp(residuals)
        # rec_ot_loss_v1 = ot_pairwise2gaussian_bp(residuals,K=6,reg=reg)
        # rec_ot_loss_v2 = ot_pairwise_bp(residuals,K=3)
        # rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device)
        # rec_ot = (rec_ot_loss_v1 + rec_ot_pair_loss_v2)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Final Losses
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        rec_loss = rec_ot_coeff * rec_ot + rec_mse_coeff * rec_mse
        norm_loss = norm_loss_coeff * (norm_loss_denoiser + norm_loss_align)
        align_loss = align_mse_coeff * align_mse + align_ot_coeff * align_ot
        entropy_loss = 0  #filters_entropy_coeff * filters_entropy + filters_dist_coeff * filters_dist_entropy
        # final_loss = align_loss + rec_loss + entropy_loss + norm_loss
        final_loss = rec_loss
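        # only the reconstruction term contributes to the gradient here; the
        # alignment, entropy, and norm terms have zero coefficients and are kept
        # for record keeping and experimentation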

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Record Keeping
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- alignment MSE --
        align_mse_losses += align_mse.item()
        align_mse_count += 1

        # -- alignment Dist --
        align_ot_losses += align_ot.item()
        align_ot_count += 1

        # -- reconstruction MSE --
        rec_mse_losses += rec_mse.item()
        rec_mse_count += 1

        # -- reconstruction Dist. --
        rec_ot_losses += rec_ot.item()
        rec_ot_count += 1

        # -- total loss --
        running_loss += final_loss.item()
        total_loss += final_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #        Gradients & Backpropagation
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute the gradients! --
        final_loss.backward()

        # -- take optimizer steps --
        model.align_info.optim.step()
        model.denoiser_info.optim.step()
        model.unet_info.optim.step()

        # for name,params in model.unet_info.model.named_parameters():
        #     if not ("weight" in name): continue
        #     print(params.grad.norm())
        #     # print(module.conv1.parameters())
        #     # print(module.conv1.data.grad)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #            Printing to Stdout
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- compute mse for fun --
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] --
            bm3d_nb_psnrs = []
            M = 10 if B > 10 else B  # cap the (slow) BM3D baseline at 10 images
            for b in range(M):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN,
                                  denoised + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img,
                                  denoised_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- alignment MSE --
            align_mse_ave = align_mse_losses / align_mse_count
            align_mse_losses, align_mse_count = 0, 0

            # -- alignment Dist. --
            align_ot_ave = align_ot_losses / align_ot_count
            align_ot_losses, align_ot_count = 0, 0

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader),
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, rec_ot_ave)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e"
                % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, directions)

        if use_timer: clock.toc()
        if use_timer: print(clock)
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    total_loss /= len(train_loader)
    return total_loss, record_losses
Example #30
def operate_Bm3d(self, image_noisy, mode, qp):
    # run BM3D with a noise PSD derived from (mode, qp) and return the denoised estimate
    return bm3d.bm3d(image_noisy, sigma_psd=self.getPsd(mode, qp))