Exemplo n.º 1
0
def primaldual(
    y,
    OpA,
    OpW,
    c0,
    eta,
    y0=None,
    iter=20,
    sigma=0.5,
    tau=0.5,
    theta=1.0,
    silent=False,
    report_pd_gap=False,
):
    """ Primal Dual algorithm.

    Reconstruction algorithm for the inverse problem y = Ax + e under the
    synthesis model x=Wc for some sparse coefficients c and ||e||_2 <= eta.

    min ||c||_1     s.t.    ||AWc - y||_2 <= eta

    Basic iteration steps are

    1) y_ = prox_l2_constraint_conjugate(y_+sig*AWc_, sigma*y, sigma*eta)
    2) c = shrink(c-tau*W'A'y_, tau)
    3) c_ = c + theta*(c - cold)


    Parameters
    ----------
    y : torch.Tensor
        The measurement vector.
    OpA : operators.LinearOperator
        The measurement operator, providing A and A'.
    OpW : operators.LinearOperator
        The synthesis operator, providing W and W'.
    c0 : torch.Tensor
        Initial guess for the coefficients, typically torch.zeros(...) of
        appropriate size.
    eta : float
        The measurement noise level, specifying the constraint.
    y0 : torch.Tensor, optional
        Initial guess for the dual variable. If None (default), a zero
        tensor shaped like ``y`` is used.
    iter : int, optional
        Number of primal dual iterations. (Default 20)
    sigma : float, optional
        Step size parameter, should satisfy sigma*tau*||AW||_2^2 < 1.
        (Default 0.5)
    tau : float, optional
        Step size parameter, should satisfy sigma*tau*||AW||_2^2 < 1.
        (Default 0.5)
    theta : float, optional
        Extrapolation (over-relaxation) parameter, arbitrary in [0,1].
        (Default 1.0)
    silent : bool, optional
        Disable progress bar. (Default False)
    report_pd_gap : bool, optional
        Print the largest absolute primal-dual gap after the final
        iteration. (Default False)

    Returns
    -------
    tuple of torch.Tensor
        ``(x, c, y)``: the recovered signal x = W c, the coefficients c and
        the dual variable y after the last iteration. The returned c is the
        extrapolated iterate c_ from step 3).
    """

    # we do not explicitly check for sig*tau*||AW||_2^2 < 1 and trust that
    # the user read the documentation ;)

    if y0 is None:
        y0 = torch.zeros_like(y)

    # Helper functions for the primal-dual gap. Reductions run over the last
    # dimension, so one gap value is produced per leading index. Indicator
    # functions are approximated: a violation beyond a 1e-2 tolerance is
    # charged a finite penalty of 1e4 instead of +inf.
    def F(_y):
        # (approximate) indicator of the l2-ball of radius eta around y
        return ((_y - y).norm(p=2, dim=-1) > (eta + 1e-2)) * 1e4

    def Fstar(_y):
        # convex conjugate of F: eta * ||_y||_2 + <y, _y>
        return eta * _y.norm(p=2, dim=-1) + (y * _y).sum(dim=-1)

    def Gstar(_c):
        # conjugate of the l1-norm: (approximate) indicator of the unit
        # l_inf-ball
        return ((torch.max(torch.abs(_c), dim=-1))[0] > (1.0 + 1e-2)) * 1e4

    # init iteration variables
    c = c0.clone()
    c_ = c.clone()
    y_ = y0.clone()

    # run main primal dual iterations
    for _ in tqdm(range(iter), desc="Primal-Dual iterations", disable=silent):

        # primal dual step 1): dual update
        y_ = prox_l2_constraint_conjugate(
            y_ + sigma * OpA(OpW(c_)), sigma * y, sigma * eta
        )
        # primal dual step 2): primal update by soft-thresholding
        cold, c = c, shrink(c - tau * OpW.adj(OpA.adj(y_)), tau)

        # primal dual step 3): extrapolation
        c_ = c + theta * (c - cold)

    # NOTE(review): the extrapolated iterate c_ (not c) is used for both the
    # gap and the output; the two coincide once c stops changing.
    x = OpW(c_)

    # compute primal dual gap
    if report_pd_gap:
        E = (
            F(OpA(x))
            + c_.abs().sum(dim=-1)
            + Fstar(y_)
            + Gstar(-OpW.adj(OpA.adj(y_)))
        )
        print("\n\n Primal Dual Gap: \t {:1.4e} \n\n".format(E.abs().max()))

    return x, c_, y_
Exemplo n.º 2
0
def admm_l1_rec_diag(
    y,
    OpA,
    OpW,
    x0,
    z0,
    lam,
    rho,
    iter=20,
    silent=False,
):
    """ ADMM for least squares solve with L1 regularization.

    Reconstruction algorithm for the inverse problem y = Ax + e under the
    analysis model z=Wx for sparse analysis coefficients z.

    min 1/2 ||Ax - y||^2_2 + lambda ||Wx||_1

    Note: it assumes that A'*A is diagonalizable in k-space. The signal
    update is delegated to ``OpA.tikh(rhs, kernel, rho)`` with the Fourier
    kernel of W; presumably this solves (A'A + rho W'W) x = rhs -- verify
    against the operator implementation.

    Parameters
    ----------
    y : torch.Tensor
        The measurement vector.
    OpA : operators.LinearOperator
        The measurement operator, providing A and A'.
    OpW : operators.LinearOperator
        The analysis operator, providing W and W'.
    x0 : torch.Tensor
        Initial guess for the signal, typically torch.zeros(...) of
        appropriate size.
    z0 : torch.Tensor
        Initial guess for the coefficients, typically torch.zeros(...) of
        appropriate size.
    lam : float
        The regularization parameter lambda for the sparsity constraint.
    rho : float
        The Lagrangian augmentation parameter for the ADMM algorithm.
    iter : int, optional
        Number of ADMM iterations. (Default 20)
    silent : bool, optional
        Disable progress bar. (Default False)

    Returns
    -------
    tuple of torch.Tensor
        ``(x, z)``: the recovered signal x and the analysis coefficients z
        after the last iteration.
    """

    # init iteration variables
    z = z0.clone()
    x = x0.clone()
    u = torch.zeros_like(z0)  # scaled dual variable
    tv_kernel = OpW.get_fourier_kernel()

    # A'y does not change between iterations
    Aty = OpA.adj(y)

    # run main ADMM iterations
    t = tqdm(range(iter), desc="ADMM iterations", disable=silent)
    for _ in t:
        # ADMM step 1) : signal update
        rhs = Aty + rho * OpW.adj(z - u)
        x = OpA.tikh(rhs, tv_kernel, rho)

        # W x is needed by steps 2), 3) and the evaluation: compute it once
        Wx = OpW(x)

        # ADMM step 2) : coefficient update by soft-thresholding
        zold, z = z, shrink(Wx + u, lam / rho)

        # ADMM step 3) : dual variable update
        u = u + Wx - z

        # evaluate (monitoring only, no autograd graph is built)
        with torch.no_grad():
            loss = (0.5 * (OpA(x) - y).pow(2).sum(dim=(-1, -2)) +
                    lam * Wx.abs().sum((-1, -2))).mean()
            primal_residual = Wx - z
            dual_residual = rho * OpW.adj(zold - z)

            t.set_postfix(
                loss=loss.item(),
                pres=torch.norm(primal_residual).item(),
                dres=torch.norm(dual_residual).item(),
            )

    return x, z
Exemplo n.º 3
0
def admm_l1_rec(
    y,
    OpA,
    OpW,
    x0,
    z0,
    lam,
    rho,
    iter=20,
    silent=False,
    timeout=None,
):
    """ ADMM for least squares solve with L1 regularization.

    Reconstruction algorithm for the inverse problem y = Ax + e under the
    analysis model z=Wx for sparse analysis coefficients z.

    min 1/2 ||Ax - y||^2_2 + lambda ||Wx||_1

    The signal update solves (A'A + rho W'W) x = rhs with a conjugate
    gradient inverter (rtol 1e-6, at most 200 CG iterations), warm-started
    at the previous x.

    Parameters
    ----------
    y : torch.Tensor
        The measurement vector.
    OpA : operators.LinearOperator
        The measurement operator, providing A and A'.
    OpW : operators.LinearOperator
        The analysis operator, providing W and W'.
    x0 : torch.Tensor
        Initial guess for the signal, typically torch.zeros(...) of
        appropriate size.
    z0 : torch.Tensor
        Initial guess for the coefficients, typically torch.zeros(...) of
        appropriate size.
    lam : float
        The regularization parameter lambda for the sparsity constraint.
    rho : float
        The Lagrangian augmentation parameter for the ADMM algorithm.
    iter : int, optional
        Number of ADMM iterations. (Default 20)
    silent : bool, optional
        Disable progress bar. (Default False)
    timeout : float, optional
        Runtime limit in seconds, checked before each iteration. When it is
        exceeded, the current iterates are returned early. (Default None,
        no limit)

    Returns
    -------
    tuple of torch.Tensor
        ``(x, z)``: the recovered signal x and the analysis coefficients z.
    """

    # init iteration variables
    z = z0.clone()
    x = x0.clone()
    u = torch.zeros_like(z0)  # scaled dual variable

    # prepare conjugate gradient inversion of (A'A + rho W'W)
    inverter = CGInverterLayer(
        x.shape[1:],
        lambda v: OpA.adj(OpA(v)) + rho * OpW.adj(OpW(v)),
        rtol=1e-6,
        atol=0.0,
        maxiter=200,
    )

    # monotonic clock: elapsed-time measurement must not jump with system
    # clock adjustments
    deadline = None if timeout is None else time.monotonic() + timeout

    # A'y does not change between iterations
    Aty = OpA.adj(y)

    # run main ADMM iterations
    t = tqdm(range(iter), desc="ADMM iterations", disable=silent)
    for _ in t:

        if deadline is not None and time.monotonic() > deadline:
            print("ADMM aborted due to timeout")
            break

        # ADMM step 1) : signal update (warm-started at the previous x)
        rhs = Aty + rho * OpW.adj(z - u)
        x = inverter(rhs, x)

        # W x is needed by steps 2), 3) and the evaluation: compute it once
        Wx = OpW(x)

        # ADMM step 2) : coefficient update by soft-thresholding
        zold, z = z, shrink(Wx + u, lam / rho)

        # ADMM step 3) : dual variable update
        u = u + Wx - z

        # evaluate (monitoring only, no autograd graph is built)
        with torch.no_grad():
            loss = (0.5 * (OpA(x) - y).pow(2).sum(dim=(-1, -2)) +
                    lam * Wx.abs().sum((-1, -2))).mean()
            primal_residual = Wx - z
            dual_residual = rho * OpW.adj(zold - z)

            t.set_postfix(
                loss=loss.item(),
                pres=torch.norm(primal_residual).item(),
                dres=torch.norm(dual_residual).item(),
            )

    return x, z