Exemplo n.º 1
0
def GD_temporal_TV(
        prior,
        kspace_u,
        mask,
        weight_fidelity,
        weight_temporal,
        forward_fun,
        inverse_fun,
        temporal_axis=-1,
        beta_sqrd=1e-7,
        x=None,
        maxiter=200):
    '''Gradient descent for generic encoding model, temporal TV.

    Parameters
    ----------
    prior : array_like
        Array handed to `sort_real_imag_parts` to obtain the orderings
        that are applied to the real and imaginary parts of the
        estimate along the last axis.
    kspace_u : array_like
        Measured data; passed through `inverse_fun` once to form the
        initial image estimate (presumably undersampled k-space --
        confirm against caller).
    mask : array_like
        Multiplied with the 2D FFT (axes 0 and 1) of the estimate on
        every iteration.
    weight_fidelity : float
        Factor applied to the data fidelity update.
    weight_temporal : float
        Factor applied to the temporal TV update.
    forward_fun : callable
        Currently unused by the body (only referenced in a
        commented-out line).
    inverse_fun : callable
        Applied to `kspace_u` to get the starting image.
    temporal_axis : int, optional
        Axis of `prior`, `kspace_u` and `mask` that is moved to the
        last position before iterating.
    beta_sqrd : float, optional
        Constant added under the square root that normalizes the
        finite differences.
    x : array_like, optional
        Reference used for the MSE/SSIM columns of the printed table.
        When None, the scalar 0 is used as the reference.
    maxiter : int, optional
        Number of iterations to run; there is no early exit.

    Returns
    -------
    img_est : array_like
        Image estimate after `maxiter` updates.

    Notes
    -----
    The result is returned with the temporal axis in the last
    position; it is not moved back to `temporal_axis`.

    A table row is written through `tqdm.write` on every iteration.
    The 'norm' column always shows 0 because `stop_criteria` is never
    updated inside the loop.
    '''
    # Make sure that the temporal axis is last
    if (temporal_axis != -1) and (temporal_axis != mask.ndim-1):
        prior = np.moveaxis(prior, temporal_axis, -1)
        kspace_u = np.moveaxis(kspace_u, temporal_axis, -1)
        mask = np.moveaxis(mask, temporal_axis, -1)

    # Get the image space of the data we did measure
    measuredImgDomain = inverse_fun(kspace_u)

    # Initialize estimates: both the running image and its "re-measured"
    # counterpart start out as the inverse transform of the data
    img_est = measuredImgDomain.copy()
    W_img_est = measuredImgDomain.copy()

    # Get monotonic ordering
    # NOTE(review): sort_real_imag_parts is defined elsewhere; it is
    # assumed to return index arrays usable with take_along_axis.
    sort_order_real, sort_order_imag = sort_real_imag_parts(
        prior, axis=-1)

    # # Construct R and C (rows and columns, I assume)
    # rows, cols, pages = img_est.shape[:]
    # R = np.tile(
    #     np.arange(rows), (cols, pages, 1)).transpose((2, 0, 1))
    # C = np.tile(
    #     np.arange(cols), (rows, pages, 1)).transpose((0, 2, 1))
    #
    # # Get the actual indices
    # nIdx_real = R + C*rows + (sort_order_real)*rows*cols
    # nIdx_imag = R + C*rows + (sort_order_imag)*rows*cols

    # Initialize output table and print its header
    table = Table(
        ['iter', 'norm', 'MSE', 'SSIM'],
        [len(repr(maxiter)), 8, 8, 8],
        ['d', 'e', 'e', 'e'])
    print(table.header())
    if x is None:
        # No reference image: metrics are computed against the scalar 0
        xabs = 0
    else:
        xabs = np.abs(x)
    # NOTE(review): never updated below, so the 'norm' column is always 0
    stop_criteria = 0

    # Work buffers, allocated once and overwritten on every iteration
    unsort_real_data = np.zeros(img_est.shape, dtype=c_float)
    unsort_imag_data = np.zeros(img_est.shape, dtype=c_float)
    temporal_term_update_real = np.zeros(img_est.shape, dtype=c_float)
    temporal_term_update_imag = np.zeros(img_est.shape, dtype=c_float)

    # Do the thing
    for ii in trange(maxiter, leave=False):

        # Difference between the measured image and the current
        # estimate after masking in the Fourier domain
        fidelity_update = weight_fidelity*(
            measuredImgDomain - W_img_est)

        ## computing TV term update for real and imag parts with
        # reordering
        real_smooth_data = np.take_along_axis( #pylint: disable=E1101
            img_est.real, sort_order_real, axis=-1)
        imag_smooth_data = np.take_along_axis( #pylint: disable=E1101
            img_est.imag, sort_order_imag, axis=-1)

        # Real part: first differences along the last axis, normalized
        # by sqrt(beta_sqrd + |diff|^2), then differenced again
        temp_a = np.diff(real_smooth_data, axis=-1)
        temp_b = temp_a/np.sqrt(beta_sqrd + (np.abs(temp_a)**2))
        temp_c = np.diff(temp_b, axis=-1)

        # The end samples only have a one-sided difference available
        temporal_term_update_real[..., 0] = temp_b[..., 0]
        temporal_term_update_real[..., 1:-1] = temp_c
        temporal_term_update_real[..., -1] = -temp_b[..., -1]

        # Scale, then scatter back to the original (unsorted) positions
        temporal_term_update_real *= weight_temporal
        np.put_along_axis( #pylint: disable=E1101
            unsort_real_data, sort_order_real,
            temporal_term_update_real, axis=-1)


        # Imag part: same procedure with the imaginary ordering
        temp_a = np.diff(imag_smooth_data, axis=-1)
        temp_b = temp_a/np.sqrt(beta_sqrd + (np.abs(temp_a)**2))
        temp_c = np.diff(temp_b, axis=-1)

        temporal_term_update_imag[..., 0] = temp_b[..., 0]
        temporal_term_update_imag[..., 1:-1] = temp_c
        temporal_term_update_imag[..., -1] = -temp_b[..., -1]

        temporal_term_update_imag *= weight_temporal
        np.put_along_axis( #pylint: disable=E1101
            unsort_imag_data, sort_order_imag,
            temporal_term_update_imag, axis=-1)

        # Do the updates
        temporal_term_update = unsort_real_data + 1j*unsort_imag_data
        img_est += fidelity_update + temporal_term_update

        # NOTE(review): the encoding model is hard-coded here as a
        # masked 2D FFT over axes (0, 1); forward_fun/inverse_fun are
        # not used for this step.
        # W_img_est = inverse_fun(forward_fun(img_est))
        W_img_est = np.fft.ifft2(np.fft.fft2(
            img_est, axes=(0, 1))*mask, axes=(0, 1))

        # from mr_utils import view
        # view(np.stack((prior, img_est)))

        # Give user an update
        curxabs = np.abs(img_est)
        tqdm.write(
            table.row(
                [ii, stop_criteria, compare_mse(curxabs, xabs),
                 compare_ssim(curxabs, xabs)]))


    return img_est
Exemplo n.º 2
0
 def test_table_row(self):
     '''A formatted row must be as long as the sum of the column widths.'''
     table = Table(
         self.headings, self.widths, self.formatters, pad=self.pad)
     # The header is generated first, exactly as a caller would do
     table.header()
     formatted = table.row(self.widths)
     expected_length = sum(self.widths)
     self.assertEqual(len(formatted), expected_length)
Exemplo n.º 3
0
def proximal_GD(
        y,
        forward_fun,
        inverse_fun,
        sparsify,
        unsparsify,
        reorder_fun=None,
        mode='soft',
        alpha=.5,
        alpha_start=.5,
        thresh_sep=True,
        selective=None,
        x=None,
        ignore_residual=False,
        ignore_mse=True,
        ignore_ssim=True,
        disp=False,
        maxiter=200,
        strikes=0):
    r'''Proximal gradient descent for generic encoding/sparsity model.

    Parameters
    ----------
    y : array_like
        Measured data (i.e., y = Ax).
    forward_fun : callable
        A, the forward transformation function.
    inverse_fun : callable
        A^H, the inverse transformation function.
    sparsify : callable
        Sparsifying transform.
    unsparsify : callable
        Inverse sparsifying transform.
    reorder_fun : callable, optional
        Reordering function.  Called with the gradient step; its
        result's real and imaginary parts are used as flat indices
        for reordering the real and imaginary parts, respectively.
    mode : {'soft', 'hard', 'garotte', 'greater', 'less'}, optional
        Thresholding mode.
    alpha : float or callable, optional
        Step size, used for thresholding.  A callable is invoked as
        ``alpha(current_alpha, iteration)`` after every iteration to
        produce the next value.
    alpha_start : float, optional
        Initial alpha to start with if alpha is callable.
    thresh_sep : bool, optional
        Whether or not to threshold real/imag individually.
    selective : callable, optional
        Function ``selective(x_hat, update, ii)`` returning indices
        of update to keep at each iter.
    x : array_like, optional
        The true image we are trying to reconstruct.
    ignore_residual : bool, optional
        If False, a strike is counted when the residual increases.
    ignore_mse : bool, optional
        If False, a strike is counted when the MSE increases.
    ignore_ssim : bool, optional
        If False, a strike is counted when the SSIM decreases.
    disp : bool, optional
        Whether or not to display iteration info.
    maxiter : int, optional
        Maximum number of iterations.
    strikes : int, optional
        Number of ending conditions tolerated before giving up.

    Returns
    -------
    x_hat : array_like
        Estimate of x.

    Notes
    -----
    Solves the problem:

    .. math::

        \min_x || y - Ax ||^2_2  + \lambda \text{Sparsify}(x)

    If `x=None`, then MSE/SSIM will not be calculated. You probably
    want `mode='soft'`.  For the other options, see docs for
    pywt.threshold. `selective=None` will not throw away any updates.

    MSE and SSIM are evaluated whenever they are displayed or used as
    a stopping condition.  All stopping conditions share a single
    strike counter.
    '''

    # Make sure compare_mse, compare_ssim is defined
    if x is None:
        compare_mse = lambda xx, yy: 0
        compare_ssim = lambda xx, yy: 0
        xabs = 0
        logging.info(
            'No true x provided, MSE/SSIM will not be calculated.')
    else:
        from skimage.measure import compare_mse, compare_ssim
        # Precompute absolute value of true image
        xabs = np.abs(x.astype(y.dtype))

    # Get some display stuff happening
    if disp:
        # Don't use tqdm, the table rows go through logging instead
        range_fun = range

        from mr_utils.utils.printtable import Table
        table = Table(
            ['iter', 'norm', 'MSE', 'SSIM'],
            [len(repr(maxiter)), 8, 8, 8],
            ['d', 'e', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)
    else:
        # Use tqdm to give us an idea of how fast we're going; the
        # progress bar is cosmetic, so carry on without it if tqdm is
        # not installed
        try:
            from tqdm import trange
            range_fun = lambda n: trange(
                n, leave=False, desc='Proximal GD')
        except ImportError:
            range_fun = range

    # Metrics are needed for display and for the MSE/SSIM criteria
    need_metrics = disp or not ignore_mse or not ignore_ssim

    # Initialize
    x_hat = np.zeros(y.shape, dtype=y.dtype)
    r = -y.copy()
    prev_stop_criteria = np.inf
    initial_abs = np.abs(inverse_fun(y))
    cur_ssim = 0
    prev_ssim = compare_ssim(xabs, initial_abs)
    cur_mse = 0
    prev_mse = compare_mse(xabs, initial_abs)
    norm_y = np.linalg.norm(y)

    # Any non-callable alpha (float, int, numpy scalar) is the fixed
    # step size; alpha_start only applies to a callable schedule
    alpha0 = alpha_start if callable(alpha) else alpha

    # Do the thing
    strike_count = 0
    for ii in range_fun(int(maxiter)):

        # Compute stop criteria
        stop_criteria = np.linalg.norm(r)/norm_y
        if not ignore_residual and stop_criteria > prev_stop_criteria:
            if strike_count > strikes:
                logging.warning(
                    'Breaking out of loop after %d iterations. '
                    'Norm of residual increased!', ii)
                break
            strike_count += 1
        prev_stop_criteria = stop_criteria

        # Compute gradient descent step in prep for reordering
        grad_step = x_hat - inverse_fun(r)

        # Do reordering if we asked for it
        if reorder_fun is not None:
            reorder_idx = reorder_fun(grad_step)
            reorder_idx_r = reorder_idx.real.astype(int)
            reorder_idx_i = reorder_idx.imag.astype(int)

            grad_step = (
                grad_step.real[np.unravel_index(
                    reorder_idx_r, y.shape)]
                + 1j*grad_step.imag[np.unravel_index(
                    reorder_idx_i, y.shape)]).reshape(y.shape)

        # Take the step, we would normally assign x_hat directly, but
        # because we might be reordering and selectively updating,
        # we'll store it in a temporary variable...
        if thresh_sep:
            tmp = sparsify(grad_step)
            # Take a half step in each real/imag after talk with Ed
            tmp_r = threshold(tmp.real, value=alpha0/2, mode=mode)
            tmp_i = threshold(tmp.imag, value=alpha0/2, mode=mode)
            update = unsparsify(tmp_r + 1j*tmp_i)
        else:
            update = unsparsify(
                threshold(
                    sparsify(grad_step), value=alpha0, mode=mode))

        # Undo the reordering if we did it by scattering the values
        # back to the positions they were gathered from
        if reorder_fun is not None:
            update_r = np.zeros(y.shape)
            update_r[np.unravel_index(
                reorder_idx_r, y.shape)] = update.real.flatten()
            update_i = np.zeros(y.shape)
            update_i[np.unravel_index(
                reorder_idx_i, y.shape)] = update.imag.flatten()
            update = update_r + 1j*update_i

        # Update image estimate; with a selective function only the
        # chosen entries are overwritten - tread carefully...
        if selective is not None:
            selective_idx = selective(x_hat, update, ii)
            x_hat[selective_idx] = update[selective_idx]
        else:
            x_hat = update

        # Evaluate quality metrics when something is going to use them
        if need_metrics:
            curxabs = np.abs(x_hat)
            cur_mse = compare_mse(curxabs, xabs)
            cur_ssim = compare_ssim(curxabs, xabs)

        # Tell the user what happened
        if disp:
            logging.info(
                table.row(
                    [ii, stop_criteria, cur_mse, cur_ssim]))

        if not ignore_mse and cur_mse > prev_mse:
            if strike_count > strikes:
                logging.warning(
                    'Breaking out of loop after %d iterations. '
                    'MSE increased!', ii)
                break
            strike_count += 1
        prev_mse = cur_mse

        # SSIM is a similarity score, so getting worse means decreasing
        if not ignore_ssim and cur_ssim < prev_ssim:
            if strike_count > strikes:
                logging.warning(
                    'Breaking out of loop after %d iterations. '
                    'SSIM decreased!', ii)
                break
            strike_count += 1
        prev_ssim = cur_ssim

        # Compute residual
        r = forward_fun(x_hat) - y

        # Get next step size
        if callable(alpha):
            alpha0 = alpha(alpha0, ii)

    return x_hat
Exemplo n.º 4
0
def IHT(A, y, k, mu=1, maxiter=500, tol=1e-8, x=None, disp=False):
    '''Iterative hard thresholding algorithm (IHT).

    Parameters
    ----------
    A : array_like
        Measurement matrix.
    y : array_like
        Measurements (i.e., y = Ax).
    k : int
        Number of expected nonzero coefficients.
    mu : float, optional
        Step size.
    maxiter : int, optional
        Maximum number of iterations.
    tol : float, optional
        Stopping criteria: iteration stops once ||r||/||y|| < tol.
    x : array_like, optional
        True signal we are trying to estimate.
    disp : bool, optional
        Whether or not to display iterations.

    Returns
    -------
    x_hat : array_like
        Estimate of x.

    Notes
    -----
    Solves the problem:
        min_x || y - Ax ||^2_2  s.t.  ||x||_0 <= k

    If disp=True, then MSE will be calculated using provided x (or
    x=0 when none is given). mu=1 seems to satisfy Theorem 8.4 often,
    but might need to be adjusted (usually < 1).  See normalized IHT
    for adaptive step size.

    A warning is logged only when the loop runs out of iterations
    without meeting the tolerance.

    Implements Algorithm 8.5 from:
        Eldar, Yonina C., and Gitta Kutyniok, eds. Compressed sensing: theory
        and applications. Cambridge University Press, 2012.
    '''

    # length of measurement vector and original signal
    _n, N = A.shape[:]

    # Make sure we have everything we need for disp
    if disp and x is None:
        logging.warning('No true x provided, using x=0 for MSE calc.')
        x = np.zeros(N)

    # Some fancy, aesthetic touches...
    if disp:
        table = Table(['iter', 'MSE'], [len(repr(maxiter)), 8], ['d', 'e'])
        range_fun = range
    else:
        # The progress bar is optional: fall back to a plain range when
        # tqdm is not installed
        try:
            from tqdm import trange
            range_fun = lambda n: trange(n, leave=False, desc='IHT')
        except ImportError:
            range_fun = range

    # Initial estimate of x, x_hat
    x_hat = np.zeros(N, dtype=y.dtype)

    # Get initial residue
    r = y.copy()

    # Set up header for logger
    if disp:
        for line in table.header().split('\n'):
            logging.info(line)

    # Loop invariants
    A_H = A.conj().T
    norm_y = np.linalg.norm(y)

    # Run until tol reached or maxiter reached
    for tt in range_fun(maxiter):
        # Update estimate using residual scaled by step size
        x_hat += mu * np.dot(A_H, r)

        # Find the k'th largest coefficient, use it as threshold
        thresh = -np.sort(-np.abs(x_hat))[k - 1]

        # Hard thresholding operator
        x_hat[np.abs(x_hat) < thresh] = 0

        # Show MSE at current iteration if we wanted it
        if disp:
            logging.info(table.row([tt, np.mean((np.abs(x - x_hat)**2))]))

        # update the residual
        r = y - np.dot(A, x_hat)

        # Check stopping criteria
        if np.linalg.norm(r) / norm_y < tol:
            break
    else:
        # Only reached when the loop was never broken out of, i.e. the
        # tolerance was not met (this also covers maxiter == 0)
        logging.warning(
            'Hit maximum iteration count, estimate may not be accurate!')
        return x_hat

    # Regroup and debrief...
    if disp:
        logging.info('Final || r || . || y ||^-1 : %g',
                     np.linalg.norm(r) / norm_y)

    return x_hat
Exemplo n.º 5
0
 def test_table_row(self):
     '''Check that a rendered row spans the total of all column widths.'''
     printer = Table(
         self.headings,
         self.widths,
         self.formatters,
         pad=self.pad)
     printer.header()
     rendered_row = printer.row(self.widths)
     self.assertEqual(len(rendered_row), sum(self.widths))
Exemplo n.º 6
0
def cosamp(A, y, k, lstsq='exact', tol=1e-8, maxiter=500, x=None, disp=False):
    '''Compressive sampling matching pursuit (CoSaMP) algorithm.

    Parameters
    ==========
    A : array_like
        Measurement matrix.
    y : array_like
        Measurements (i.e., y = Ax).
    k : int
        Number of expected nonzero coefficients.
    lstsq : {'exact', 'lm', 'gd'}, optional
        How to solve intermediate least squares problem.  Only
        'exact' is implemented.
    tol : float, optional
        Stopping criteria, compared against ||r||/||y||.
    maxiter : int, optional
        Maximum number of iterations.
    x : array_like, optional
        True signal we are trying to estimate (used for the MSE
        column when `disp=True`).
    disp : bool, optional
        Whether or not to display iterations.

    Returns
    =======
    x_hat : array_like
        Estimate of x.

    Raises
    ======
    NotImplementedError
        For any `lstsq` other than 'exact'.

    Notes
    =====
    The 'exact' least squares solver is numpy's linalg.lstsq.

    Implements Algorithm 8.7 from [1]_.

    References
    ==========
    .. [1] Eldar, Yonina C., and Gitta Kutyniok, eds. Compressed sensing:
           theory and applications. Cambridge University Press, 2012.
    '''

    # Number of unknowns is the column count of the measurement matrix
    _, num_coeffs = A.shape

    # Starting point: zero estimate, so the residual is y itself
    estimate = np.zeros(num_coeffs, dtype=y.dtype)
    residual = y.copy()
    norm_y = np.linalg.norm(y)

    # Reference signal used for the MSE column
    if x is None:
        x = np.zeros(estimate.shape, dtype=y.dtype)
    elif x.size < estimate.size:
        x = np.hstack(([0], x))

    # Only the exact solver is available
    if lstsq != 'exact':
        raise NotImplementedError()

    def solve_lstsq(A0, rhs):
        '''Least squares solution of A0 @ z = rhs.'''
        return np.linalg.lstsq(A0, rhs, rcond=None)[0]

    # Start up a table
    if disp:
        table = Table(['iter', 'norm', 'MSE'], [len(repr(maxiter)), 8, 8],
                      ['d', 'e', 'e'])
        for header_line in table.header().split('\n'):
            logging.info(header_line)

    for ii in range(maxiter):

        # Step direction: correlate the residual with the columns of A
        proxy = np.dot(A.conj().T, residual)

        # Merge the current support with the 2*k strongest proxy entries
        strongest = np.argsort(np.abs(proxy))[-(2 * k):]
        support = np.union1d(estimate.nonzero()[0], strongest)

        # Least squares fit restricted to the merged support
        candidate = np.zeros(num_coeffs, dtype=y.dtype)
        candidate[support] = solve_lstsq(A[:, support], y)

        # Prune everything except the k largest magnitudes
        weakest = np.argsort(np.abs(candidate))[:-k]
        candidate[weakest] = 0
        estimate = candidate

        # New residual and its relative size
        residual = y - np.dot(A, estimate)
        rel_residual = np.linalg.norm(residual) / norm_y

        # Show MSE at current iteration if we wanted it
        if disp:
            mse = np.mean((np.abs(x - estimate)**2))
            logging.info(table.row([ii, rel_residual, mse]))

        # Done once the residual is small enough
        if rel_residual < tol:
            break

    return estimate
Exemplo n.º 7
0
def IHT_TV(y,
           forward_fun,
           inverse_fun,
           k,
           mu=1,
           tol=1e-8,
           do_reordering=False,
           x=None,
           ignore_residual=False,
           disp=False,
           maxiter=500):
    r'''IHT for generic encoding model and TV constraint.

    Parameters
    ----------
    y : array_like
        Measured data, i.e., y = Ax.
    forward_fun : callable
        A, the forward transformation function.
    inverse_fun : callable
        A^H, the inverse transformation function.
    k : int
        Sparsity measure (number of nonzero coefficients expected).
    mu : float, optional
        Step size.
    tol : float, optional
        Stop when stopping criteria meets this threshold.
    do_reordering : bool, optional
        Reorder column-stacked true image.
    x : array_like, optional
        The true image we are trying to reconstruct.
    ignore_residual : bool, optional
        Whether or not to break out of loop if resid increases.
    disp : bool, optional
        Whether or not to display iteration info.
    maxiter : int, optional
        Maximum number of iterations.

    Returns
    -------
    x_hat : array_like
        Estimate of x.

    Notes
    -----
    Solves the problem:

    .. math::

        \min_x || y - Ax ||^2_2 \text{ s.t. } || \text{TV}(x) ||_0
        \leq k

    If `x=None`, then MSE will not be calculated (the MSE column of
    the display table shows 0).

    When `do_reordering=True` and `x` is given, `k` is replaced by the
    number of nonzero finite differences of the reordered `x`.
    '''

    # Make sure we have a defined compare_mse and Table for printing
    if disp:
        from mr_utils.utils.printtable import Table

        if x is not None:
            from skimage.measure import compare_mse
            xabs = np.abs(x)
        else:
            # Placeholders so the table row can still be built
            compare_mse = lambda xx, yy: 0
            xabs = 0

    # Right now we are doing absolute values on updates
    x_hat = np.zeros(y.shape)
    r = y.copy()
    prev_stop_criteria = np.inf
    norm_y = np.linalg.norm(y)

    # Initialize display table
    if disp:
        table = Table(['iter', 'norm', 'MSE'], [len(repr(maxiter)), 8, 8],
                      ['d', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)

    # Find perfect reordering (column-stacked-wise)
    # NOTE(review): col_stacked_order is called with x even when x is
    # None; confirm that helper accepts None.
    if do_reordering:
        from mr_utils.utils.orderings import (col_stacked_order,
                                              inverse_permutation)
        reordering = col_stacked_order(x)
        inverse_reordering = inverse_permutation(reordering)

        # Find new sparsity measure
        if x is not None:
            k = np.sum(np.abs(np.diff(x.flatten()[reordering])) > 0)
        else:
            logging.warning(('Make sure sparsity level k is '
                             'adjusted for reordering!'))

    # Do the thing
    for ii in range(int(maxiter)):

        # Take step
        val = (x_hat + mu * inverse_fun(r)).flatten()

        # Do the reordering
        if do_reordering:
            val = val[reordering]

        # Finite differences transformation
        first_samp = val[0]  # save first sample for inverse transform
        fd = np.diff(val)

        # Hard thresholding: keep only the k largest differences
        fd[np.argsort(np.abs(fd))[:-1 * k]] = 0

        # Inverse finite differences transformation
        res = np.hstack((first_samp, fd)).cumsum()
        if do_reordering:
            res = res[inverse_reordering]

        # Compute stopping criteria (from the residual that produced
        # this step)
        stop_criteria = np.linalg.norm(r) / norm_y

        # If the stop_criteria gets worse, get out of dodge; the
        # candidate computed above is discarded in that case
        if not ignore_residual and (stop_criteria > prev_stop_criteria):
            logging.warning('Residual increased! Not continuing!')
            break
        prev_stop_criteria = stop_criteria

        # Update x
        x_hat = res.reshape(x_hat.shape)

        # Show the people what they asked for
        if disp:
            logging.info(
                table.row([ii, stop_criteria,
                           compare_mse(xabs, x_hat)]))
        if stop_criteria < tol:
            break

        # update the residual
        r = y - forward_fun(x_hat)

    return x_hat
Exemplo n.º 8
0
def nIHT(A,
         y,
         k,
         c=0.1,
         kappa=None,
         x=None,
         maxiter=200,
         tol=1e-8,
         disp=False):
    '''Normalized iterative hard thresholding.

    Parameters
    ----------
    A : array_like
        Measurement matrix.
    y : array_like
        Measurements (i.e., y = Ax).
    k : int
        Number of nonzero coefficients preserved after thresholding.
    c : float, optional
        Small, fixed constant in (0, 1). Tunable.
    kappa : float, optional
        Constant, > 1/(1 - c).  Defaults to 2/(1 - c).
    x : array_like, optional
        True signal we want to estimate.
    maxiter : int, optional
        Maximum number of iterations (of the outer loop).
    tol : float, optional
        Stopping criteria, compared against ||r||/||y||.
    disp : bool, optional
        Whether or not to display iteration info.

    Returns
    -------
    x_hat : array_like
        Estimate of x.

    Notes
    -----
    The support of each new candidate is compared against the support
    of the previous iterate (initially the support of the thresholded
    A^T y); the step size is only backtracked when the support
    changes.

    Implements Algorithm 8.6 from:
        Eldar, Yonina C., and Gitta Kutyniok, eds. Compressed sensing: theory
        and applications. Cambridge University Press, 2012.
    '''

    # Basic checks
    assert 0 < c < 1, 'c must be in (0,1)'

    # length of measurement vector and original signal
    _n, N = A.shape[:]

    # Make sure we have everything we need for disp
    if disp and x is None:
        logging.warning('No true x provided, using x=0 for MSE calc.')
        x = np.zeros(N)

    if disp:
        table = Table(['iter', 'norm', 'MSE'], [len(repr(maxiter)), 8, 8],
                      ['d', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)

    def hard_threshold(vec):
        '''Zero, in place, entries below the k-th largest magnitude.'''
        thresh = -np.sort(-np.abs(vec))[k - 1]
        vec[np.abs(vec) < thresh] = 0
        return vec

    def step_bound(candidate, current):
        '''Largest step size accepted when the support has changed.'''
        diff = candidate - current
        return (1 - c) * np.linalg.norm(diff)**2 / np.linalg.norm(
            np.dot(A, diff))**2

    # Initializations
    x_hat = np.zeros(N)

    # Initial calculation of support
    T = np.nonzero(hard_threshold(A.T.dot(y)))

    # Find suitable kappa if the user didn't give us one
    if kappa is None:
        # kappa must be > 1/(1 - c), so try 2 times the lower bound
        kappa = 2 / (1 - c)
    else:
        assert kappa > 1 / (1 - c), 'kappa must be > 1/(1 - c)'

    norm_y = np.linalg.norm(y)

    # Do the iterative part of the thresholding...
    for ii in range(maxiter):

        # Compute residual
        r = y - np.dot(A, x_hat)

        # Check stopping criteria
        stop_criteria = np.linalg.norm(r) / norm_y
        if stop_criteria < tol:
            break

        # Let's check out what's going on
        if disp:
            logging.info(table.row([ii, stop_criteria, compare_mse(x, x_hat)]))

        # Compute step size
        # NOTE(review): the step size uses the full gradient; the
        # reference algorithm may restrict it to the current support
        # -- confirm against the text.
        g = np.dot(A.T, r)
        mu = np.linalg.norm(g)**2 / np.linalg.norm(np.dot(A, g))**2

        # Hard thresholding
        xn = hard_threshold(x_hat + mu * g)

        # Compute support of xn
        Tn = np.nonzero(xn)

        # Same support: accept the step as is.  Otherwise shrink the
        # step size until it satisfies the bound (zero iterations if
        # it already does).
        if not np.array_equal(Tn, T):
            cond = step_bound(xn, x_hat)
            while mu > cond:
                mu /= kappa * (1 - c)
                xn = hard_threshold(x_hat + mu * g)
                cond = step_bound(xn, x_hat)
            Tn = np.nonzero(xn)

        # Accept the candidate and remember its support for the next
        # iteration's comparison
        x_hat = xn
        T = Tn

    return x_hat
Exemplo n.º 9
0
def amp2d(y,
          forward_fun,
          inverse_fun,
          sigmaType=2,
          randshift=False,
          tol=1e-8,
          x=None,
          ignore_residual=False,
          disp=False,
          maxiter=100):
    r'''Approximate message passing using wavelet sparsifying transform.

    Parameters
    ==========
    y : array_like
        Measurements, i.e., y = Ax.  Must be 2D.
    forward_fun : callable
        A, the forward transformation function.
    inverse_fun : callable
        A^H, the inverse transformation function.
    sigmaType : int
        Method for determining threshold.  1 uses the median absolute
        wavelet coefficient divided by 0.6745; any other value uses
        the residual norm divided by the square root of the number of
        nonzero measurements.
    randshift : bool, optional
        Whether or not to randomly circular shift (along both axes)
        every iteration.
    tol : float, optional
        Stop when stopping criteria meets this threshold.
    x : array_like, optional
        The true image we are trying to reconstruct.
    ignore_residual : bool, optional
        Whether or not to ignore stopping criteria.
    disp : bool, optional
        Whether or not to display iteration info.
    maxiter : int, optional
        Maximum number of iterations.

    Returns
    =======
    wn : array_like
        Estimate of x.

    Notes
    =====
    Solves the problem:

    .. math::

        \min_x || \Psi(x) ||_1 \text{ s.t. } || y -
        \text{forward}(x) ||^2_2 < \epsilon^2

    The CDF-97 wavelet is used.  If `x=None`, then MSE will not be calculated.

    When the stopping criteria is met, the estimate from the previous
    iteration is returned.

    Algorithm described in [1]_, based on MATLAB implementation found at [2]_.

    References
    ==========
    .. [1] "Message Passing Algorithms for CS" Donoho et al., PNAS
           2009;106:18914

    .. [2] http://kyungs.bol.ucla.edu/Site/Software.html
    '''

    # Arrays below are sized by the iteration count, so make it an int
    # just like the loop does
    n_iter = int(maxiter)

    # Make sure we have a defined compare_mse and Table for printing
    if disp:
        # Initialize display table
        from mr_utils.utils.printtable import Table
        table = Table(['iter', 'resid', 'resid diff', 'MSE'],
                      [len(repr(maxiter)), 8, 8, 8], ['d', 'e', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)

        if x is not None:
            from skimage.measure import compare_mse
            xabs = np.abs(x)
        else:
            xabs = 0
            compare_mse = lambda xx, yy: 0

    # Do some initial calculations: count the nonzero measurements and
    # the resulting undersampling factor
    mm = np.sum(abs(y) > np.finfo(float).eps)
    rfact = y.size / mm

    # I'm currently not sure how we found these optimum lambdas...
    OptimumLambdaSigned = loadmat(dirname(__file__) \
        + '/OptimumLambdaSigned.mat')  # has the optimal values of lambda
    delta_vec = OptimumLambdaSigned['delta_vec'][0]
    lambda_opt = OptimumLambdaSigned['lambda_opt'][0]
    delta = 1 / rfact
    lambdas = np.interp(delta, delta_vec, lambda_opt)

    # Initial values
    wn = np.zeros(y.shape, dtype=y.dtype)
    zn = y - forward_fun(wn)
    nx, ny = y.shape[:]

    res_norm = np.zeros(n_iter + 1)
    nn = np.zeros(n_iter + 1)
    res_diff = np.zeros(n_iter)

    res_norm[0] = np.linalg.norm(zn)
    norm_y = np.linalg.norm(y)
    nn[0] = res_norm[0] / norm_y

    for it in range(n_iter):

        # First-order Approximate Message Passing
        temp_z = inverse_fun(zn) + wn

        # Randomly shift along both axes if we asked for it.  The axes
        # must be given explicitly: without them numpy rolls the
        # flattened array by the sum of the shifts.
        if randshift:
            rand_shift_x = np.random.randint(0, nx)
            rand_shift_y = np.random.randint(0, ny)
            temp_z = np.roll(
                temp_z, (rand_shift_x, rand_shift_y), axis=(0, 1))

        # Sparsify with wavelet transform
        temp_z, locations = cdf97_2d_forward(temp_z, level=5)

        # Compute sigma hat
        if sigmaType == 1:
            sigma_hat = np.median(np.abs(temp_z.flatten())) / .6745
        else:
            sigma_hat = res_norm[it] / np.sqrt(mm)

        # Avoid a zero threshold
        if sigma_hat == 0:
            sigma_hat = .1

        # Soft Thresholding
        thresh = lambdas*sigma_hat
        wn1 = (np.abs(temp_z) > thresh)*(np.abs(temp_z) \
            - thresh)*np.sign(temp_z)

        # Compute a sparsity/measurement ratio
        amp_weight = np.sum(np.abs(wn1) > np.finfo(float).eps) / mm

        # Un-sparsify
        wn1 = cdf97_2d_inverse(wn1, locations)

        # random shift back
        if randshift:
            wn1 = np.roll(
                wn1, (-rand_shift_x, -rand_shift_y), axis=(0, 1))

        # Update the residual term
        residual = y - forward_fun(wn1)

        # Normalized data fidelity term
        res_norm[it + 1] = np.linalg.norm(residual)
        nn[it + 1] = res_norm[it + 1] / norm_y
        res_diff[it] = np.abs(nn[it + 1] - nn[it])

        # Give the people what they asked for!
        if disp:
            logging.info(
                table.row([
                    it, nn[it + 1], res_diff[it],
                    compare_mse(xabs, np.abs(wn1))
                ]))

        # Check stopping criteria
        if not ignore_residual and (res_diff[it] < tol):
            break

        # Update Estimation
        wn = wn1

        # Weight the residual with a little extra sauce
        if amp_weight > 1:
            zn = residual + 0.25 * zn
        else:
            zn = residual + amp_weight * zn

    return wn
Exemplo n.º 10
0
def IST(A,
        y,
        mu=0.8,
        theta0=None,
        k=None,
        maxiter=500,
        tol=1e-8,
        x=None,
        disp=False):
    r'''Iterative soft thresholding algorithm (IST).

    Parameters
    ==========
    A : array_like
        Measurement matrix (real or complex).
    y : array_like
        Measurements (i.e., y = Ax).
    mu : float, optional
        Step size (theta contraction factor, 0 < mu <= 1).
    theta0 : float, optional
        Initial threshold, decreased by factor of mu each iteration.
    k : int, optional
        Number of expected nonzero coefficients.  Only used to choose
        the starting threshold when `theta0` is None.
    maxiter : int, optional
        Maximum number of iterations.
    tol : float, optional
        Stopping criteria: iteration stops once
        ||y - A x_hat|| / ||y|| drops below this value.
    x : array_like, optional
        True signal we are trying to estimate.
    disp : bool, optional
        Whether or not to display iterations.

    Returns
    =======
    x_hat : array_like
        Estimate of x.  Complex if either `A` or `y` is complex,
        float64 otherwise.

    Raises
    ======
    AssertionError
        If `mu` is outside (0, 1], if `theta0` is given but is not
        positive, or if both `theta0` and `k` are None.

    Notes
    =====
    Solves the problem:

    .. math::

        \min_x || y - Ax ||^2_2 \text{ s.t. } ||x||_0 \leq k

    If `disp=True`, then MSE will be calculated using provided x. If
    `theta0=None`, the initial threshold of the IHT will be used as the
    starting theta, i.e. the k-th largest magnitude of A^H y.

    Complex coefficients are shrunk in magnitude while keeping their
    phase; real coefficients use the ordinary sign function.

    Implements Equations [22-23] from [1]_

    References
    ==========
    .. [1] Rani, Meenu, S. B. Dhok, and R. B. Deshmukh. "A systematic review of
           compressive sensing: Concepts, implementations and applications."
           IEEE Access 6 (2018): 4875-4894.
    '''

    # Check to make sure we have good mu
    assert 0 < mu <= 1, 'mu should be 0 < mu <= 1!'

    # length of measurement vector and original signal
    _n, N = A.shape[:]

    # Adjoint of the measurement matrix; identical to A.T for real A
    AH = A.conj().T

    # Make sure we have everything we need for disp
    if disp and x is None:
        logging.warning('No true x provided, using x=0 for MSE calc.')
        x = np.zeros(N)

    # Some fancy, aesthetic touches...
    if disp:
        range_fun = range
        table = Table(['iter', 'norm', 'theta', 'MSE'],
                      [len(repr(maxiter)), 8, 8, 8], ['d', 'e', 'e', 'e'])
    else:
        # The progress bar is purely cosmetic, so fall back to a plain
        # range when tqdm is not installed instead of failing outright.
        try:
            from tqdm import trange
        except ImportError:
            range_fun = range
        else:
            range_fun = lambda num: trange(num, leave=False, desc='IST')

    # Initial estimate of x, x_hat.  Promote to complex when the data
    # demands it so the in-place update below cannot fail on casting.
    x_hat = np.zeros(N, dtype=np.result_type(A, y, np.float64))

    # Get initial residue
    r = y.copy()

    # The denominator of the stopping criteria never changes
    norm_y = np.linalg.norm(y)

    # Start theta at specified theta0 or use IHT first threshold
    if theta0 is None:
        assert k is not None, ('k (measure of sparsity) required to compute '
                               'initial threshold!')
        theta = -np.sort(-np.abs(np.dot(AH, r)))[k - 1]
    else:
        assert theta0 > 0, 'Threshold must be positive!'
        theta = theta0

    # Set up header for logger
    if disp:
        hdr = table.header()
        for line in hdr.split('\n'):
            logging.info(line)

    # Run until tol reached or maxiter reached
    converged = False
    for tt in range_fun(maxiter):
        # Update estimate using residual
        x_hat += np.dot(AH, r)

        # Just like IHT, but use soft thresholding operator: shrink the
        # magnitude by theta and keep the sign (real) or phase (complex)
        magnitude = np.abs(x_hat)
        if np.iscomplexobj(x_hat):
            # Guard the division; zero entries stay zero after shrinking
            phase = x_hat / np.where(magnitude > 0, magnitude, 1)
        else:
            phase = np.sign(x_hat)
        x_hat = np.maximum(magnitude - theta, 0) * phase

        # update the residual
        r = y - np.dot(A, x_hat)

        # Check stopping criteria
        stop_criteria = np.linalg.norm(r) / norm_y
        if stop_criteria < tol:
            converged = True
            break

        # Show MSE at current iteration if we wanted it
        if disp:
            logging.info(
                table.row([tt, stop_criteria, theta,
                           compare_mse(x, x_hat)]))

        # Contract theta before we go back around the horn
        theta *= mu

    # Regroup and debrief...  An explicit flag is used because comparing
    # the loop index with maxiter - 1 also fires when convergence happens
    # on the very last allowed iteration.
    if not converged:
        logging.warning(
            'Hit maximum iteration count, estimate may not be accurate!')
    elif disp:
        logging.info('Final || r || . || y ||^-1 : %g',
                     (np.linalg.norm(r) / norm_y))

    return x_hat
Exemplo n.º 11
0
def tv_l1_denoise(im, lam, disp=False, niter=100):
    '''Denoise a 2D image with the TV-L1 model (primal-dual iterations).

    Parameters
    ==========
    im : array_like
        2D image to denoise.  If its maximum exceeds 1, the image is
        divided by that maximum before processing.
    lam : float
        Regularization weight; smaller values denoise more aggressively
        and tend to give smoother results.
    disp : bool, optional
        Print the energy being minimized at every iteration.
    niter : int, optional
        Number of iterations to run.

    Returns
    =======
    newim : array_like
        l1 denoised image, on the (possibly normalized) scale
        described for `im`.

    Raises
    ======
    AssertionError
        When dimension of im is not 2.
    '''

    # Fixed step sizes of the iteration
    l2_const = 8.0
    step_primal = 0.02
    step_dual = 1.0/(l2_const*step_primal)
    relax = 1.0
    thresh = lam*step_primal

    assert im.ndim == 2, 'This function only works for 2D images!'
    nrows, ncols = im.shape

    def _forward_diff(img):
        '''Forward differences along columns and rows (zero at far edge).'''
        dcol = np.concatenate((img[:, 1:], img[:, -1:]), axis=1) - img
        drow = np.concatenate((img[1:, :], img[-1:, :]), axis=0) - img
        return dcol, drow

    # Bring bright images into [.., 1]; anything else is used untouched
    peak = np.max(im)
    nim = im/peak if peak > 1.0 else im

    # Starting point: the (normalized) image and its gradient field
    u = nim
    dual = np.zeros((nrows, ncols, 2))
    dual[:, :, 0], dual[:, :, 1] = _forward_diff(u)

    # Zero borders reused by the divergence at every iteration
    zero_row = np.zeros((1, ncols))
    zero_col = np.zeros((nrows, 1))

    # Either print a table of energies or show a progress bar
    if disp:
        from mr_utils.utils.printtable import Table
        table = Table(
            ['Iter', 'Energy'],
            [len(repr(niter)), 8],
            ['d', 'e'])
        print(table.header())
        iterate = range
    else:
        from tqdm import trange

        def iterate(count):
            return trange(count, leave=False, desc="TV Denoise")

    for kk in iterate(niter):
        # Dual ascent step along the gradient of the current image
        ux, uy = _forward_diff(u)
        dual += step_dual*np.stack((ux, uy), axis=2)

        # Project every vector of the field back onto the unit ball
        length = np.sqrt(dual[:, :, 0]**2 + dual[:, :, 1]**2)
        dual /= np.maximum(1.0, length)[:, :, np.newaxis]

        # Divergence of the field, built from backward differences
        down = dual[:nrows-1, :, 1]
        across = dual[:, :ncols-1, 0]
        div = np.concatenate((down, zero_row), axis=0) \
            - np.concatenate((zero_row, down), axis=0)
        div += np.concatenate((across, zero_col), axis=1) \
            - np.concatenate((zero_col, across), axis=1)

        # TV-L1 shrinkage of the primal step towards the input image
        v = u + step_primal*div
        delta = v - nim
        unew = (v - thresh)*(delta > thresh) \
            + (v + thresh)*(delta < -thresh) \
            + nim*(np.abs(delta) <= thresh)

        # Extragradient (over-relaxation) step
        u = unew + relax*(unew - u)

        # Report the energy being minimized
        if disp:
            E = np.sum(np.sqrt(ux.ravel()**2 + uy.ravel()**2)) \
                + lam*np.sum(np.abs(u.ravel() - nim.ravel()))
            print(table.row((kk, E)))

    return u
Exemplo n.º 12
0
def IHT(A, y, k, mu=1, maxiter=500, tol=1e-8, x=None, disp=False):
    r'''Iterative hard thresholding algorithm (IHT).

    Parameters
    ----------
    A : array_like
        Measurement matrix.
    y : array_like
        Measurements (i.e., y = Ax).
    k : int
        Number of expected nonzero coefficients.
    mu : float, optional
        Step size.
    maxiter : int, optional
        Maximum number of iterations.
    tol : float, optional
        Stopping criteria: iteration stops once
        ||y - A x_hat|| / ||y|| drops below this value.
    x : array_like, optional
        True signal we are trying to estimate.
    disp : bool, optional
        Whether or not to display iterations.

    Returns
    -------
    x_hat : array_like
        Estimate of x, with the dtype of `y`.

    Notes
    -----
    Solves the problem:

    .. math::

        \min_x || y - Ax ||^2_2 \text{ s.t. } ||x||_0 \leq k

    If `disp=True`, then MSE will be calculated using provided x.
    `mu=1` seems to satisfy Theorem 8.4 often, but might need to be
    adjusted (usually < 1). See normalized IHT for adaptive step size.

    The residual is refreshed before the stopping criteria is evaluated,
    so both the convergence test and the displayed norm describe the
    current estimate.

    Implements Algorithm 8.5 from [1]_.

    References
    ----------
    .. [1] Eldar, Yonina C., and Gitta Kutyniok, eds. Compressed
           sensing: theory and applications. Cambridge University
           Press, 2012.
    '''

    # length of measurement vector and original signal
    _n, N = A.shape[:]

    # Adjoint of the measurement matrix, reused every iteration
    AH = A.conj().T

    # Make sure we have everything we need for disp
    if disp and x is None:
        logging.warning('No true x provided, using x=0 for MSE calc.')
        x = np.zeros(N)

    # Some fancy, aesthetic touches...
    if disp:
        table = Table(['iter', 'norm', 'MSE'], [len(repr(maxiter)), 8, 8],
                      ['d', 'e', 'e'])
        range_fun = range
    else:
        # The progress bar is purely cosmetic, so fall back to a plain
        # range when tqdm is not installed instead of failing outright.
        try:
            from tqdm import trange
        except ImportError:
            range_fun = range
        else:
            range_fun = lambda num: trange(num, leave=False, desc='IHT')

    # Initial estimate of x, x_hat
    x_hat = np.zeros(N, dtype=y.dtype)

    # Get initial residue
    r = y.copy()

    # The denominator of the stopping criteria never changes
    norm_y = np.linalg.norm(y)

    # Number of smallest-magnitude entries to discard each iteration.
    # An explicit count is used because the slice [:-k] selects nothing
    # (instead of everything) when k == 0.
    num_drop = max(N - k, 0)

    # Set up header for logger
    if disp:
        hdr = table.header()
        for line in hdr.split('\n'):
            logging.info(line)

    # Run until tol reached or maxiter reached
    converged = False
    for tt in range_fun(int(maxiter)):
        # Update estimate using residual scaled by step size
        x_hat += mu * np.dot(AH, r)

        # Leave only k coefficients nonzero (hard threshold)
        x_hat[np.argsort(np.abs(x_hat))[:num_drop]] = 0

        # update the residual so it belongs to the new estimate
        r = y - np.dot(A, x_hat)
        stop_criteria = np.linalg.norm(r) / norm_y

        # Show MSE at current iteration if we wanted it
        if disp:
            logging.info(
                table.row([tt, stop_criteria,
                           np.mean((np.abs(x - x_hat)**2))]))

        # Check stopping criteria
        if stop_criteria < tol:
            converged = True
            break

    # Regroup and debrief...  An explicit flag is used because comparing
    # the loop index with maxiter - 1 also fires when convergence happens
    # on the very last allowed iteration.
    if not converged:
        logging.warning(('Hit maximum iteration count, estimate '
                         'may not be accurate!'))
    elif disp:
        logging.info('Final || r || . || y ||^-1 : %g',
                     (np.linalg.norm(r) / norm_y))

    return x_hat
Exemplo n.º 13
0
def cosamp(A, y, k, lstsq='exact', tol=1e-8, maxiter=500, x=None,
           disp=False):
    '''Compressive sampling matching pursuit (CoSaMP) algorithm.

    Parameters
    ----------
    A : array_like
        Measurement matrix (real or complex).
    y : array_like
        Measurements (i.e., y = Ax).
    k : int
        Number of expected nonzero coefficients.
    lstsq : {'exact', 'lm', 'gd'}, optional
        How to solve the intermediate least squares problem.
    tol : float, optional
        Stopping criteria: iteration stops once
        ||y - A x_hat|| / ||y|| drops below this value.
    maxiter : int, optional
        Maximum number of iterations.
    x : array_like, optional
        True signal we are trying to estimate.  Only used for the MSE
        column when `disp=True`; zeros are used if it is not provided.
    disp : bool, optional
        Whether or not to display iterations.

    Returns
    -------
    x_hat : array_like
        Estimate of x with at most k nonzero entries.

    Raises
    ------
    NotImplementedError
        If `lstsq` is not one of the supported solver names.

    Notes
    -----
    lstsq function:

    'exact' solves it using numpy's linalg.lstsq method.
    'lm' solves it with scipy.optimize.least_squares.
    'gd' uses 3 iterations of a gradient descent solver.

    Implements Algorithm 8.7 from:
        Eldar, Yonina C., and Gitta Kutyniok, eds. Compressed sensing:
        theory and applications. Cambridge University Press, 2012.
    '''

    # length of original signal
    N = A.shape[1]

    # Adjoint of the measurement matrix; identical to A.T for real A
    AH = A.conj().T

    # Initializations
    x_hat = np.zeros(N, dtype=y.dtype)
    r = y.copy()
    norm_y = np.linalg.norm(y)

    # The working estimate must be able to hold complex least squares
    # solutions, otherwise their imaginary part is silently discarded.
    work_dtype = np.result_type(A, y, np.float64)

    # Decide how we want to solve the intermediate least squares problem
    if lstsq == 'exact':
        def lstsq_fun(A0, y0):
            return np.linalg.lstsq(A0, y0, rcond=None)[0]
    elif lstsq == 'lm':
        from scipy.optimize import least_squares

        # NOTE(review): least_squares is called without a `method`
        # argument, so SciPy's default solver is used -- confirm this is
        # the intended Levenberg-Marquardt behaviour.
        def lstsq_fun(A0, y0):
            return least_squares(
                lambda x0: np.linalg.norm(y0 - np.dot(A0, x0)),
                np.zeros(A0.shape[1], dtype=y0.dtype))['x']
    elif lstsq == 'gd':
        # This doesn't work very well...
        from mr_utils.optimization import gd, fd_complex_step

        def lstsq_fun(A0, y0):
            return gd(
                lambda x0: np.linalg.norm(y0 - np.dot(A0, x0)),
                fd_complex_step,
                np.zeros(A0.shape[1], dtype=y0.dtype),
                iter=3)[0]
    else:
        raise NotImplementedError()

    # Start up a table
    if disp:
        # The MSE column needs a reference signal
        if x is None:
            logging.warning('No true x provided, using x=0 for MSE calc.')
            x = np.zeros(N)

        table = Table(['iter', 'residual', 'MSE'],
                      [len(repr(maxiter)), 8, 8], ['d', 'e', 'e'])
        hdr = table.header()
        for line in hdr.split('\n'):
            logging.info(line)

    for ii in range(maxiter):

        # Get step direction
        g = np.dot(AH, r)

        # Add 2*k largest elements of g to support set
        Tn = np.union1d(x_hat.nonzero()[0],
                        np.argsort(np.abs(g))[-(2*k):])

        # Solve the least squares problem restricted to the support
        xn = np.zeros(N, dtype=work_dtype)
        xn[Tn] = lstsq_fun(A[:, Tn], y)

        # Keep only the k largest coefficients
        xn[np.argsort(np.abs(xn))[:-k]] = 0
        x_hat = xn

        # Compute new residual
        r = y - np.dot(A, x_hat)

        # Compute stopping criteria
        stop_criteria = np.linalg.norm(r)/norm_y

        # Show MSE at current iteration if we wanted it
        if disp:
            logging.info(
                table.row([ii, stop_criteria,
                           np.mean((np.abs(x - x_hat)**2))]))

        # Check stopping criteria
        if stop_criteria < tol:
            break

    return x_hat
Exemplo n.º 14
0
        first_m = m_hat[0]

        # Transform into finite differences domain
        fd = np.diff(E.conj().T.dot(r))

        # Hard thresholding
        fd[np.argsort(np.abs(fd))[:-k]] = 0

        # Inverse transform and take the step
        m_hat += mu * np.hstack((first_m, fd)).cumsum()

        # This may or may not be a good stopping criteria for this
        stop_criteria = np.linalg.norm(r) / np.linalg.norm(s)

        # Show MSE at current iteration if we wanted it
        if disp:
            print(
                table.row([tt, stop_criteria,
                           np.mean((np.abs(m - m_hat)**2))]))

        # Check stopping criteria
        if stop_criteria < tol:
            break

        # Get new residual
        r = s - E.dot(m_hat)

    plt.imshow(np.abs(m_hat.reshape(smiley.shape)))
    plt.title('IHT Recon')
    plt.show()
Exemplo n.º 15
0
def GD_TV(y,
          forward_fun,
          inverse_fun,
          alpha=.5,
          lam=.01,
          do_reordering=False,
          x=None,
          ignore_residual=False,
          disp=False,
          maxiter=200):
    r'''Gradient descent for a generic encoding model and TV constraint.

    Parameters
    ==========
    y : array_like
        Measured data (i.e., y = Ax).
    forward_fun : callable
        A, the forward transformation function.
    inverse_fun : callable
        A^H, the inverse transformation function.
    alpha : float, optional
        Step size.
    lam : float, optional
        TV constraint weight.
    do_reordering : bool, optional
        Whether or not to reorder for sparsity constraint.  The
        reordering is derived from `x`, so `x` must be provided.
    x : array_like, optional
        The true image we are trying to reconstruct.
    ignore_residual : bool, optional
        Whether or not to break out of loop if resid increases.
    disp : bool, optional
        Whether or not to display iteration info.
    maxiter : int, optional
        Maximum number of iterations.

    Returns
    =======
    x_hat : array_like
        Estimate of x, with the shape and dtype of `y`.

    Raises
    ======
    ValueError
        If `do_reordering=True` but no `x` was provided.

    Notes
    =====
    Solves the problem:

    .. math::

        \min_x || y - Ax ||^2_2  + \lambda \text{TV}(x)

    If `x=None`, then MSE will not be calculated (it is displayed as 0).
    '''

    # The reordering indices come from the true image, so fail early and
    # clearly instead of hitting an unbound name inside the loop.
    if do_reordering and x is None:
        raise ValueError('do_reordering=True requires the true image x!')

    if x is None:
        logging.info('No true x provided, MSE will not be calculated.')
        xabs = None
    else:
        xabs = np.abs(x)  # Precompute absolute value of true image

        # Get the reordering indicies ready, both for real and imag parts
        if do_reordering:
            from mr_utils.utils.sort2d import sort2d
            from mr_utils.utils.orderings import inverse_permutation
            _, reordering_r = sort2d(x.real)
            _, reordering_i = sort2d(x.imag)
            inverse_reordering_r = inverse_permutation(reordering_r)
            inverse_reordering_i = inverse_permutation(reordering_i)

    # Get some display stuff happening
    if disp:
        from mr_utils.utils.printtable import Table
        table = Table(['iter', 'norm', 'MSE'], [len(repr(maxiter)), 8, 8],
                      ['d', 'e', 'e'])
        hdr = table.header()
        for line in hdr.split('\n'):
            logging.info(line)

    # Initialize
    x_hat = np.zeros(y.shape, dtype=y.dtype)
    r = -y.copy()
    prev_stop_criteria = np.inf
    norm_y = np.linalg.norm(y)

    # Do the thing
    for ii in range(int(maxiter)):

        # Fidelity term
        fidelity = inverse_fun(r)

        # Let's reorder if we said that was going to be a thing
        if do_reordering:
            # real part
            xr = x_hat.real.flatten()[reordering_r].reshape(x.shape)
            second_term_r = dTV(xr).flatten()[inverse_reordering_r] \
                .reshape(x.shape)

            # imag part
            xi = x_hat.imag.flatten()[reordering_i].reshape(x.shape)
            second_term_i = dTV(xi).flatten()[inverse_reordering_i] \
                .reshape(x.shape)

            # put it all together...
            second_term = second_term_r + 1j * second_term_i
        else:
            # Sparsity term
            second_term = dTV(x_hat)

        # Compute stop criteria
        stop_criteria = np.linalg.norm(r) / norm_y
        if not ignore_residual and stop_criteria > prev_stop_criteria:
            logging.warning(('Breaking out of loop after %d iterations. '
                             'Norm of residual increased!'), ii)
            break
        prev_stop_criteria = stop_criteria

        # Take the step
        x_hat -= alpha * (fidelity + lam * second_term)

        # Tell the user what happened
        if disp:
            # MSE of the magnitude images, computed with numpy directly:
            # skimage.measure.compare_mse no longer exists in current
            # scikit-image releases.
            if xabs is None:
                mse = 0
            else:
                mse = np.mean(np.square(np.abs(x_hat) - xabs))
            logging.info(table.row([ii, stop_criteria, mse]))

        # Compute residual
        r = forward_fun(x_hat) - y

    return x_hat