def inpaint_tv(
    img,
    mask,
    mu=np.float32(1e-2),
    gamma=np.float32(1e-1),
    tol=np.float32(1e-4),
    max_iter=1000,
    verbose=False,
):
    """Total Variation inpainting with the Split Bregman method.

    Each outer iteration first solves the linear u-subproblem

        ((1 - mask) + gamma * D^T (mask * D)) u
            = (1 - mask) * img + gamma * D^T (mask * (d - b))

    approximately with ``solve_cg``. The solve uses 10 inner iterations and
    is warm-started from the previous iterate. The auxiliary variables
    ``d = (d1, d2)`` are then updated by ``shrink`` with threshold
    ``mu * mask / gamma``, and the Bregman variables ``b = (b1, b2)`` are
    updated by the usual residual step.

    Here ``D`` denotes the pair of convolutions with ``dx`` / ``dy``, and
    ``D^T`` the convolutions with ``dxT`` / ``dyT``. These are presumably
    the adjoint kernels -- confirm where they are defined.

    Parameters
    ----------
    img : array
        Observed image (an ``npcl`` array). Must not be identically zero,
        because ``sum(img**2)`` is used as the divisor of the stopping gap.
    mask : array
        Weight array with the same shape as ``img``. The data term is
        weighted by ``1 - mask`` and the TV term by ``mask``. Presumably
        ``mask`` is 1 on pixels to fill and 0 on known pixels -- confirm.
    mu : float32
        TV regularisation weight.
    gamma : float32
        Split Bregman penalty parameter.
    tol : float32
        Stop once ``sum((u_k - u_{k-1})**2) / sum(img**2) < tol**2``.
    max_iter : int
        Maximum number of outer iterations; must be at least 1.
    verbose : bool
        When truthy, print the relative gap after each iteration.

    Returns
    -------
    u : array
        The most recent iterate.
    k : int
        Zero-based index of the last outer iteration performed.

    Raises
    ------
    ValueError
        If ``max_iter < 1``.
    """
    if max_iter < 1:
        # Without at least one iteration there is no result (and no k).
        raise ValueError("max_iter must be at least 1")

    def masked_divergence(px, py):
        # D^T (mask * p) for the vector field p = (px, py).
        qx = convolve(mask * px, dxT)
        qy = convolve(mask * py, dyT)
        return qx + qy

    def masked_laplacian(x):
        # D^T (mask * D x): the mask-weighted second-order operator.
        return masked_divergence(convolve(x, dx), convolve(x, dy))

    inv_mask = 1 - mask

    def A(x):
        # Left-hand-side operator of the u-subproblem, handed to solve_cg.
        return inv_mask * x + gamma * masked_laplacian(x)

    d1, d2, b1, b2, uf, ul = [npcl.zeros_like(img) for _ in range(6)]
    img_norm = npcl.sum(img**2)

    # Loop-invariant quantities, computed once.
    data_term = inv_mask * img
    threshold = mu * mask / gamma
    tol_sq = tol**2

    for k in range(max_iter):
        # u-subproblem (warm start from the previous iterate uf)
        b = data_term + gamma * masked_divergence(d1 - b1, d2 - b2)
        ul, _ = solve_cg(A, b, uf, max_iter=10)

        # d-subproblem
        u1 = convolve(ul, dx)
        u2 = convolve(ul, dy)
        d1, d2 = shrink(u1 + b1, u2 + b2, threshold)

        # Bregman update
        b1 = b1 + u1 - d1
        b2 = b2 + u2 - d2

        gap = npcl.sum((ul - uf)**2) / img_norm
        if verbose:
            print(
                'iteration number: ', k + 1,
                ', gap: ', gap.get(),
            )
        if gap < tol_sq:
            break
        uf = ul.copy()

    # On early exit uf still holds the previous iterate; ul is the newest
    # one, consistent with d and b. Otherwise uf and ul hold equal values.
    return ul, k
def denoise_tv(image, weight=0.1, eps=2.e-4, n_iter_max=100):
    """Chambolle-style total-variation denoising of a 2-D image.

    The dual field ``p = (px, py)`` is updated iteratively, and the result
    is ``image + d``, where ``d`` is the negated ``divergence2d(px, py)``.
    The iteration looks like a port of scikit-image's
    ``_denoise_tv_chambolle_nd`` restricted to 2-D -- compare if in doubt.

    Parameters
    ----------
    image : array
        Input image (an ``npcl`` array). Its first two dimensions define the
        pixel count used to normalise the energy.
    weight : float
        TV weight. It is cast to float32 and enters the update as
        ``1 + tau / weight * |grad out|``.
    eps : float
        Relative stopping tolerance. The loop stops once
        ``|E_previous - E| < eps * E_init``.
    n_iter_max : int
        Maximum number of iterations. If it is <= 0, an unprocessed copy of
        ``image`` is returned.

    Returns
    -------
    array
        The denoised image.

    Notes
    -----
    If the initial energy ``E_init`` is 0 (e.g. a constant image), the
    strict ``<`` test never succeeds and all ``n_iter_max`` iterations run.
    """
    img_dev = image.copy()
    ndim = 2
    weight = np.float32(weight)
    eps = np.float32(eps)
    px = npcl.zeros_like(image)
    py = npcl.zeros_like(image)
    d = npcl.zeros_like(image)
    tau = np.float32(1 / (2. * ndim))
    n_pixels = np.float32(img_dev.shape[0] * img_dev.shape[1])

    # Ensure ``out`` is bound even when no iteration runs (n_iter_max <= 0);
    # previously this raised UnboundLocalError.
    out = img_dev
    i = 0
    while i < n_iter_max:
        if i > 0:
            # d is the (negative) divergence of p = (px, py).
            d = -divergence2d(px, py)
            out = img_dev + d
        else:
            out = img_dev
        E = npcl.sum(d ** 2).get()

        # (gx, gy) are the gradients of out along each axis.
        gx, gy = grad2d(out)
        norm = norm2d(gx, gy)
        E += weight * npcl.sum(norm).get()
        norm *= tau / weight
        norm += np.float32(1)
        px = px - tau * gx
        py = py - tau * gy
        px /= norm
        py /= norm

        E /= n_pixels
        if i == 0:
            E_init = E
            E_previous = E
        elif np.abs(E_previous - E) < eps * E_init:
            break
        else:
            E_previous = E
        i += 1
    return out
def restart(y, x, xold):
    """Return True when the inner product <y - x, x - xold> is non-negative.

    The product is reduced with ``npcl.sum``, and the result is fetched with
    ``.get()`` before the comparison. This presumably serves as the
    gradient-based adaptive-restart test of an accelerated (FISTA-style)
    scheme: momentum is reset when the last step ``x - xold`` points along
    ``y - x``. Confirm against the caller.
    """
    residual = y - x
    step = x - xold
    inner = npcl.sum(residual * step).get()
    return inner >= 0
def norms(x):
    """Return the squared Euclidean norm ``sum(x**2)``, fetched via ``.get()``."""
    squares = x**2
    total = npcl.sum(squares)
    return total.get()
def norm(x):
    """Return the L1 norm ``sum(|x|)``, fetched via ``.get()``.

    Despite the generic name, this is the sum of absolute values,
    not the Euclidean norm.
    """
    magnitudes = npcl.fabs(x)
    total = npcl.sum(magnitudes)
    return total.get()