Ejemplo n.º 1
0
def ex_ITCA_denoise(fn, thresh_pts, wname, levels, noise_std):
    """Denoise a 64x64 image patch with ``solvers.ITCA`` in a stationary wavelet frame.

    The patch at rows/cols 128:192 of the loaded image is corrupted with
    ``utils.awgn``. ITCA is run starting from the regularization weight
    ``mu0 = max |ISWT^T y|``. The noisy, denoised and clean patches are then
    shown, together with the objective history.

    Args:
        fn: image file passed to ``utils.imgLoad``.
        thresh_pts: unused here; kept so the signature matches the other
            ``ex_*_denoise`` examples.
        wname: wavelet name passed to ``wvlt.getop_wavelet``.
        levels: number of decomposition levels (``J``).
        noise_std: noise level passed to ``utils.awgn``.
    """
    x = utils.imgLoad(fn)[:, :, 128:128 + 64, 128:128 + 64]
    # Stationary (undecimated) wavelet synthesis operator; ISWT.adjoint maps
    # images to coefficients.
    _, ISWT = wvlt.getop_wavelet(wname, J=levels, stationary=True)
    # l1 penalty used for printing only. It skips the low-pass band
    # (channel 0 of the last coefficient entry).
    obj2 = lambda w, mu: mu * (torch.flatten(torch.norm(w[:-1], 1)).sum()
                               + torch.flatten(torch.norm(w[-1][:, 1:], 1)).sum())
    y = utils.awgn(x, noise_std)
    # Starting weight for the continuation: ||ISWT^T y||_inf. The absolute
    # value matters because large negative coefficients also count.
    mu0 = ISWT.adjoint(y).abs().max().item()
    print(f"mu0 = {mu0:.2e}")
    what, hist = solvers.ITCA(ISWT,
                              ISWT.adjoint,
                              y,
                              mu=mu0,
                              gamma=0.8,
                              K=50,
                              max_iter=10,
                              verbose=True,
                              threshfun=wvlt.wthresh,
                              obj2=obj2)
    xhat = ISWT(what)
    # PSNR with peak value 1 (assumes intensities are in [0, 1]; confirm
    # against utils.imgLoad).
    PSNR = lambda v: -10 * np.log10(torch.mean((x - v)**2).item())
    utils.visplot(torch.cat([y, xhat, x]),
                  titles=[
                      f"Noisy, {PSNR(y):.2f} dB",
                      f"Denoised, {PSNR(xhat):.2f} dB", "Ground Truth"
                  ])
    plt.figure()
    plt.semilogy(hist['total'], '-b')
    plt.xlabel("iterations")
    plt.ylabel("BPDN functional")
    plt.show()
Ejemplo n.º 2
0
def main(args):
    """Evaluate a trained denoising model and visualize one example.

    Builds the model with ``initModel`` and runs ``test`` on it. It then
    denoises ``Set12/04.png`` corrupted with ``utils.awgn(x, 25)``, prints
    the PSNR, and shows the noisy, denoised and clean images. When the model
    returns an edge list, it also shows a neighbor visualization.

    Args:
        args: configuration object forwarded unchanged to ``initModel`` and
            ``test``.
    """
    ngpu = torch.cuda.device_count()
    device = torch.device("cuda:0" if ngpu > 0 else "cpu")
    print(f"Using device {device}.")
    if ngpu > 1:
        # Informational only: the model is not wrapped for multi-GPU here.
        print(f"Using {ngpu} GPUs.")
    model, _, _, _ = initModel(args, device=device)
    model.eval()

    test(args, model, device=device)

    x = utils.imgLoad("Set12/04.png", gray=True).to(device)
    y = utils.awgn(x, 25)
    with torch.no_grad():
        xhat, edge_list = model(y, ret_edge=True)
    # PSNR with peak value 1 (assumes intensities in [0, 1]; confirm).
    psnr = (-10 * torch.log10(torch.mean((x - xhat)**2))).item()
    print(f"PSNR = {psnr:.2f}")
    # Keep references to the returned objects until plt.show(). If the
    # handler owns interactive matplotlib callbacks, dropping it early could
    # get them garbage-collected. (Assumption; confirm against visual.)
    fig1 = visual.visplot(torch.cat([y, xhat, x]).cpu())
    if edge_list[0] is not None:
        fig2, handler = visual.visplotNeighbors(xhat.cpu(),
                                                edge_list[0].cpu(),
                                                local_area=None,
                                                depth=3)
    plt.show()
Ejemplo n.º 3
0
def ex_ITA_denoise(fn, thresh_pts, wname, levels, noise_std):
    """Denoise an image with ``solvers.FITA`` in a decimated wavelet basis.

    Steps:
      1. Estimate the Lipschitz constant of ``IDWT^T IDWT`` with
         ``solvers.powerMethod``.
      2. Pick the regularization weight with ``find_threshold`` on the
         128x128 patch at rows/cols 128:256. The denoiser for this search is
         a short (100-iteration) FITA run.
      3. Run 500 FITA iterations on the full noisy image with that weight,
         then plot the result and the objective history.

    Args:
        fn: image file passed to ``utils.imgLoad``.
        thresh_pts: candidate thresholds forwarded to ``find_threshold``.
        wname: wavelet name passed to ``wvlt.getop_wavelet``.
        levels: number of decomposition levels (``J``).
        noise_std: noise level passed to ``utils.awgn`` / ``find_threshold``.
    """
    x = utils.imgLoad(fn)
    # Decimated (non-stationary) wavelet transform. Only the synthesis
    # operator IDWT and its adjoint are used below.
    _, IDWT = wvlt.getop_wavelet(wname, J=levels, stationary=False)
    # l1 penalty used for printing only. It skips the low-pass band
    # (channel 0 of the last coefficient entry).
    obj2 = lambda w, mu: mu * (torch.flatten(torch.norm(w[:-1], 1)).sum()
                               + torch.flatten(torch.norm(w[-1][:, 1:], 1)).sum())
    # Lipschitz constant of IDWT^T IDWT, estimated with the power method on
    # coefficients of a 128x128 patch. Element [0] of the powerMethod
    # result is used as the estimate.
    L = solvers.powerMethod(lambda v: IDWT.adjoint(IDWT(v)),
                            torch.randn_like(IDWT.adjoint(
                                x[:, :, :128, :128])),
                            max_iter=100,
                            verbose=True)[0]

    def run_fita(y, mu, max_iter):
        # Shared FITA call; returns the (coefficients, history) pair.
        return solvers.FITA(IDWT,
                            IDWT.adjoint,
                            y,
                            L=L,
                            mu=mu,
                            max_iter=max_iter,
                            verbose=True,
                            threshfun=wvlt.wthresh,
                            obj2=obj2)

    denoise = lambda y, t: IDWT(run_fita(y, t, 100)[0])
    # find the best threshold on an image patch
    thresh = find_threshold(x[:, :, 128:128 + 128, 128:128 + 128],
                            denoise,
                            noise_std,
                            thresh_pts,
                            show=False)
    # denoise the full image with that threshold
    y = utils.awgn(x, noise_std)
    what, hist = run_fita(y, thresh, 500)
    xhat = IDWT(what)
    # PSNR with peak value 1 (assumes intensities in [0, 1]; confirm).
    PSNR = lambda v: -10 * np.log10(torch.mean((x - v)**2).item())
    utils.visplot(torch.cat([y, xhat, x]),
                  titles=[
                      f"Noisy, {PSNR(y):.2f} dB",
                      f"Denoised, {PSNR(xhat):.2f} dB", "Ground Truth"
                  ])
    plt.figure()
    plt.semilogy(hist['total'], '-b')
    plt.xlabel("iterations")
    plt.ylabel("BPDN functional")
    plt.show()
Ejemplo n.º 4
0
def main():
    """Sanity-check a one-level 2-D DWT built from plain torch convolutions.

    The four 'db3' analysis filters from ``filter_bank_2D`` are zero-padded
    by one tap (bottom/right) into (ks+1)x(ks+1) kernels. These kernels
    serve as the weights of a stride-2 ``torch.nn.Conv2d`` (analysis) and of
    a stride-2 ``torch.nn.ConvTranspose2d`` (synthesis).

    An image from Set12 is transformed and reconstructed. The MSE is measured
    on the interior only, to ignore boundary effects. The reconstruction and
    the sub-bands are displayed, and ``input()`` blocks so the figures stay
    open.
    """
    Wa, _ = filter_bank_2D('db3')
    ks = Wa.shape[-1]
    # Padded kernel size. For even-length filters (db3 has 6 taps) this gives
    # exactly H/2 outputs from the strided conv and 2*(H/2) from the
    # transposed conv. Derived from ks instead of hard-coding 7, so the
    # declared kernel_size matches the weights.
    ksize = ks + 1
    W1 = torch.zeros(4, 1, ksize, ksize)
    W1[:, :, :-1, :-1] = Wa
    # NOTE: placing Wa at [:, :, 1:, 1:] is the other possible alignment.
    pad = ks // 2
    DWT = torch.nn.Conv2d(1, 4, ksize, stride=2, padding=pad, bias=False)
    IDWT = torch.nn.ConvTranspose2d(4,
                                    1,
                                    ksize,
                                    stride=2,
                                    padding=pad,
                                    output_padding=1,
                                    bias=False)
    with torch.no_grad():
        DWT.weight.copy_(W1)
        IDWT.weight.copy_(W1)
        x = torch.cat([utils.imgLoad(f"Set12/{i:02d}.png") for i in [1]])
        z = DWT(x)
        print("z.shape =", z.shape)
        xhat = IDWT(z)
        # Interior-only error: the zero-padded borders are not perfectly
        # reconstructed.
        err = torch.mean(((x - xhat)[:, :, 20:-20, 20:-20])**2)
    print(f"MSE = {err:.2e}")
    fig = utils.visplot(xhat[:, :, 10:-10, 10:-10])
    fig.show()
    fig = utils.visplot(z.transpose(0, 1))
    fig.show()
    input()
Ejemplo n.º 5
0
def main(fn="/home/nikopj/doc/pic/film/7480/74800029.JPG", num_seams=300):
    """Wavelet-energy seam-carving demo.

    Loads an image and converts it to grayscale by averaging the channel
    dimension. It then keeps only the low-pass band (channel 0 of the last
    entry) of a 3-level decimated 'db3' DWT.

    Vertical seams are then removed one at a time. Each energy map is
    |ch1| + |ch2| of the last entry of a 4-level stationary 'db3' transform.
    ``shortest_path`` picks the seam, and that pixel is deleted from every
    row. The carved image is shown at the end.

    Args:
        fn: image path passed to ``utils.imgLoad`` (defaults to the original
            hard-coded file).
        num_seams: number of seams to remove. It is capped at ``width - 1``
            so that at least one column remains.
    """
    DWT, _ = wvlt.getop_wavelet('db3', J=3, stationary=False)
    # No gradients are needed. This also lets .numpy() be called on
    # transform outputs.
    with torch.no_grad():
        img = utils.imgLoad(fn)
        gry = img.mean(dim=1, keepdim=True)
        print("img.shape =", img.shape)
        print("gry.shape =", gry.shape)
        gry = DWT(gry)[-1][:, 0:1]
        print("gry.shape =", gry.shape)
        m, n = gry.shape[2:]
        SWT, _ = wvlt.getop_wavelet('db3', J=4, stationary=True)
        for ii in range(min(num_seams, n - 1)):
            print("ii =", ii)
            coarse = SWT(gry)[-1][0]
            energy = coarse[1].abs().numpy() + coarse[2].abs().numpy()
            bp = shortest_path(energy)
            # Row-major flat index of the seam pixel in each of the m rows.
            fltidx = bp + n * np.arange(m)
            # np.delete returns a NumPy array. Convert back to a tensor so the
            # next transform call and .abs()/.numpy() above keep working.
            gry = torch.from_numpy(
                np.delete(gry.numpy(), fltidx).reshape(1, 1, m, n - 1))
            n -= 1
    plt.imshow(gry[0, 0].numpy())
    plt.show()
Ejemplo n.º 6
0
def ex_wthresh_denoise(fn, thresh_pts, wname, levels, noise_std):
    x = utils.imgLoad(fn)
    DWT, IDWT = wvlt.getop_wavelet(wname, J=levels, stationary=False)
    denoise = lambda y, t: IDWT(wvlt.wthresh(DWT(y), t))
    thresh = find_threshold(x, denoise, noise_std, thresh_pts, show=True)