def primaldual(
    y,
    OpA,
    OpW,
    c0,
    eta,
    y0=None,
    iter=20,
    sigma=0.5,
    tau=0.5,
    theta=1.0,
    silent=False,
    report_pd_gap=False,
):
    """ Primal Dual algorithm.

    Reconstruction algorithm for the inverse problem y = Ax + e under
    the synthesis model x=Wc for some sparse coefficients c and
    ||e||_2 <= eta.

    min ||c||_1    s.t.    ||AWc - y||_2 <= eta

    Basic iteration steps are
    1) y_ = prox_l2_constraint_conjugate(y_ + sigma*AWc_, sigma*y, sigma*eta)
    2) c = shrink(c - tau*W'A'y_, tau)
    3) c_ = c + theta*(c - cold)

    Parameters
    ----------
    y : torch.Tensor
        The measurement vector.
    OpA : operators.LinearOperator
        The measurement operator, providing A and A'.
    OpW : operators.LinearOperator
        The synthesis operator, providing W and W'.
    c0 : torch.Tensor
        Initial guess for the coefficients, typically torch.zeros(...) of
        appropriate size.
    eta : float
        The measurement noise level, specifying the constraint.
    y0 : torch.Tensor, optional
        Initial guess for the dual variable. (Default None, in which case
        torch.zeros_like(y) is used)
    iter : int, optional
        Number of primal dual iterations. (Default 20)
    sigma : float, optional
        Step size of the dual update 1), should satisfy
        sigma*tau*||AW||_2^2 < 1. (Default 0.5)
    tau : float, optional
        Step size of the primal update 2), should satisfy
        sigma*tau*||AW||_2^2 < 1. (Default 0.5)
    theta : float, optional
        Extrapolation parameter of step 3), arbitrary in [0,1].
        (Default 1.0)
    silent : bool, optional
        Disable progress bar. (Default False)
    report_pd_gap : bool, optional
        Print the primal dual gap after the last iteration. (Default False)

    Returns
    -------
    tuple of torch.Tensor
        The recovered signal x = W c_, the coefficients c_ and the dual
        variable y_. Note that c_ is the extrapolated iterate of step 3).
    """
    # we do not explicitly check for sigma*tau*||AW||_2^2 < 1 and trust that
    # the user read the documentation ;)

    if y0 is None:
        y0 = torch.zeros_like(y)

    # helper functions for the primal dual gap
    def F(_y):
        # indicator of the constraint ||_y - y||_2 <= eta, relaxed by a
        # tolerance of 1e-2 and with "infinity" replaced by the penalty 1e4
        return ((_y - y).norm(p=2, dim=-1) > (eta + 1e-2)) * 1e4

    def Fstar(_y):
        # convex conjugate of F: eta * ||_y||_2 + <y, _y>
        return eta * _y.norm(p=2, dim=-1) + (y * _y).sum(dim=-1)

    def Gstar(_y):
        # conjugate of the l1-norm, i.e. the indicator of the unit
        # l-infinity ball, relaxed by 1e-2 and with penalty 1e4
        return ((torch.max(torch.abs(_y), dim=-1))[0] > (1.0 + 1e-2)) * 1e4

    # init iteration variables
    c = c0.clone()
    c_ = c.clone()
    y_ = y0.clone()

    # run main primal dual iterations
    for _ in tqdm(range(iter), desc="Primal-Dual iterations", disable=silent):
        # primal dual step 1) : dual update
        y_ = prox_l2_constraint_conjugate(
            y_ + sigma * OpA(OpW(c_)), sigma * y, sigma * eta
        )
        # primal dual step 2) : primal update (shrinkage with threshold tau)
        cold, c = c, shrink(c - tau * OpW.adj(OpA.adj(y_)), tau)
        # primal dual step 3) : extrapolation
        c_ = c + theta * (c - cold)

    # compute primal dual gap F(AWc_) + ||c_||_1 + F*(y_) + G*(-W'A'y_)
    if report_pd_gap:
        E = (
            F(OpA(OpW(c_)))
            + c_.abs().sum(dim=-1)
            + Fstar(y_)
            + Gstar(-OpW.adj(OpA.adj(y_)))
        )
        print("\n\n Primal Dual Gap: \t {:1.4e} \n\n".format(E.abs().max()))

    return OpW(c_), c_, y_
def admm_l1_rec_diag(
    y,
    OpA,
    OpW,
    x0,
    z0,
    lam,
    rho,
    iter=20,
    silent=False,
):
    """ ADMM for least squares solve with L1 regularization.

    Reconstruction algorithm for the inverse problem y = Ax + e under
    the analysis model z=Wx for sparse analysis coefficients z.

    min 1/2 ||Ax - y||^2_2 + lambda ||Wx||_1

    Note: it assumes that A'*A is diagonalizable in k-space. The signal
    update is delegated to OpA.tikh together with the Fourier kernel
    obtained from OpW.get_fourier_kernel().

    Parameters
    ----------
    y : torch.Tensor
        The measurement vector.
    OpA : operators.LinearOperator
        The measurement operator, providing A and A'.
    OpW : operators.LinearOperator
        The analysis operator, providing W and W'.
    x0 : torch.Tensor
        Initial guess for the signal, typically torch.zeros(...) of
        appropriate size.
    z0 : torch.Tensor
        Initial guess for the coefficients, typically torch.zeros(...) of
        appropriate size.
    lam : float
        The regularization parameter lambda for the sparsity constraint.
    rho : float
        The Lagrangian augmentation parameter for the ADMM algorithm.
    iter : int, optional
        Number of ADMM iterations. (Default 20)
    silent : bool, optional
        Disable progress bar (and the per-iteration evaluation of loss and
        residuals that is only shown there). (Default False)

    Returns
    -------
    tuple of torch.Tensor
        The recovered signal x and the analysis coefficients z.
    """
    # init iteration variables
    z = z0.clone()
    x = x0.clone()
    u = torch.zeros_like(z0)

    tv_kernel = OpW.get_fourier_kernel()

    # run main ADMM iterations
    t = tqdm(range(iter), desc="ADMM iterations", disable=silent)
    for _ in t:
        # ADMM step 1) : signal update
        # NOTE(review): presumably OpA.tikh solves (A'A + rho W'W) x = rhs
        # in k-space using tv_kernel -- confirm against OpA.tikh
        rhs = OpA.adj(y) + rho * OpW.adj(z - u)
        x = OpA.tikh(rhs, tv_kernel, rho)

        # W x is needed by steps 2), 3) and the evaluation; compute it once
        Wx = OpW(x)

        # ADMM step 2) : coefficient update
        zold, z = z, shrink(Wx + u, lam / rho)

        # ADMM step 3) : dual variable update
        u = u + Wx - z

        # evaluate -- the values only feed the progress bar postfix, so the
        # extra forward operations and device syncs are skipped when silent
        if not silent:
            with torch.no_grad():
                loss = (
                    0.5 * (OpA(x) - y).pow(2).sum(dim=(-1, -2))
                    + lam * Wx.abs().sum((-1, -2))
                ).mean()
                primal_residual = Wx - z
                dual_residual = rho * OpW.adj(zold - z)
                t.set_postfix(
                    loss=loss.item(),
                    pres=torch.norm(primal_residual).item(),
                    dres=torch.norm(dual_residual).item(),
                )

    return x, z
def admm_l1_rec(
    y,
    OpA,
    OpW,
    x0,
    z0,
    lam,
    rho,
    iter=20,
    silent=False,
    timeout=None,
):
    """ ADMM for least squares solve with L1 regularization.

    Reconstruction algorithm for the inverse problem y = Ax + e under
    the analysis model z=Wx for sparse analysis coefficients z.

    min 1/2 ||Ax - y||^2_2 + lambda ||Wx||_1

    The signal update solves (A'A + rho W'W) x = A'y + rho W'(z - u) with
    a conjugate gradient inversion layer.

    Parameters
    ----------
    y : torch.Tensor
        The measurement vector.
    OpA : operators.LinearOperator
        The measurement operator, providing A and A'.
    OpW : operators.LinearOperator
        The analysis operator, providing W and W'.
    x0 : torch.Tensor
        Initial guess for the signal, typically torch.zeros(...) of
        appropriate size.
    z0 : torch.Tensor
        Initial guess for the coefficients, typically torch.zeros(...) of
        appropriate size.
    lam : float
        The regularization parameter lambda for the sparsity constraint.
    rho : float
        The Lagrangian augmentation parameter for the ADMM algorithm.
    iter : int, optional
        Number of ADMM iterations. (Default 20)
    silent : bool, optional
        Disable progress bar (and the per-iteration evaluation of loss and
        residuals that is only shown there). (Default False)
    timeout : int, optional
        Set runtime limit in seconds. The limit is checked at the start of
        every iteration; once exceeded, the current x and z are returned.
        (Default None)

    Returns
    -------
    tuple of torch.Tensor
        The recovered signal x and the analysis coefficients z.
    """
    # init iteration variables
    z = z0.clone()
    x = x0.clone()
    u = torch.zeros_like(z0)

    # prepare conjugate gradient inversion of (A'A + rho W'W)
    inverter = CGInverterLayer(
        x.shape[1:],
        lambda v: OpA.adj(OpA(v)) + rho * OpW.adj(OpW(v)),
        rtol=1e-6,
        atol=0.0,
        maxiter=200,
    )

    # monotonic clock: elapsed-time measurement is unaffected by system
    # clock adjustments (unlike time.time())
    if timeout is not None:
        deadline = time.monotonic() + timeout

    # run main ADMM iterations
    t = tqdm(range(iter), desc="ADMM iterations", disable=silent)
    for _ in t:
        if timeout is not None and time.monotonic() > deadline:
            print("ADMM aborted due to timeout")
            return x, z

        # ADMM step 1) : signal update (x doubles as the CG warm start)
        rhs = OpA.adj(y) + rho * OpW.adj(z - u)
        x = inverter(rhs, x)

        # W x is needed by steps 2), 3) and the evaluation; compute it once
        Wx = OpW(x)

        # ADMM step 2) : coefficient update
        zold, z = z, shrink(Wx + u, lam / rho)

        # ADMM step 3) : dual variable update
        u = u + Wx - z

        # evaluate -- the values only feed the progress bar postfix, so the
        # extra forward operations and device syncs are skipped when silent
        if not silent:
            with torch.no_grad():
                loss = (
                    0.5 * (OpA(x) - y).pow(2).sum(dim=(-1, -2))
                    + lam * Wx.abs().sum((-1, -2))
                ).mean()
                primal_residual = Wx - z
                dual_residual = rho * OpW.adj(zold - z)
                t.set_postfix(
                    loss=loss.item(),
                    pres=torch.norm(primal_residual).item(),
                    dres=torch.norm(dual_residual).item(),
                )

    return x, z