def _temporal_tv_grad(sorted_part, sort_order, beta_sqrd, weight, work, out):
    '''Weighted temporal TV gradient for one real-valued part of the image.

    Parameters
    ----------
    sorted_part : array_like
        Real or imaginary part of the current estimate, already reordered
        along the last axis with `sort_order` (np.take_along_axis).
    sort_order : array_like
        Indices used to reorder along the last axis.
    beta_sqrd : float
        Smoothing constant keeping the TV derivative finite when a
        difference is 0.
    weight : float
        Weight applied to the temporal term.
    work : ndarray
        Preallocated buffer (same shape as `sorted_part`) that receives the
        gradient in the sorted domain; overwritten.
    out : ndarray
        Preallocated buffer that receives the gradient scattered back to the
        original (unsorted) ordering along the last axis.

    Returns
    -------
    out : ndarray
        The `out` buffer, written in place.
    '''
    diffs = np.diff(sorted_part, axis=-1)
    normed = diffs/np.sqrt(beta_sqrd + (np.abs(diffs)**2))
    work[..., 0] = normed[..., 0]
    work[..., 1:-1] = np.diff(normed, axis=-1)
    work[..., -1] = -normed[..., -1]
    work *= weight
    np.put_along_axis(out, sort_order, work, axis=-1) #pylint: disable=E1101
    return out


def GD_temporal_TV(
        prior, kspace_u, mask, weight_fidelity, weight_temporal,
        forward_fun, inverse_fun, temporal_axis=-1, beta_sqrd=1e-7, x=None,
        maxiter=200):
    '''Gradient descent for generic encoding model, temporal TV.

    Parameters
    ----------
    prior : array_like
        Image passed to `sort_real_imag_parts` to obtain the orderings of the
        real and imaginary parts along the temporal axis.
    kspace_u : array_like
        Measured k-space data.
    mask : array_like
        Sampling mask multiplied into k-space in the data-consistency step.
    weight_fidelity : float
        Weight of the data-fidelity update.
    weight_temporal : float
        Weight of the temporal TV update.
    forward_fun : callable
        Forward transform.  NOTE(review): currently unused; the
        data-consistency step is hard-coded as
        ``ifft2(fft2(img_est)*mask)`` over axes (0, 1).
    inverse_fun : callable
        Maps `kspace_u` into image space.
    temporal_axis : int, optional
        Axis holding time.  It is moved last internally and moved back
        before returning.
    beta_sqrd : float, optional
        Smoothing constant of the TV derivative.
    x : array_like, optional
        True image, only used to report MSE/SSIM.  If None, 0 is reported.
    maxiter : int, optional
        Number of iterations.

    Returns
    -------
    img_est : ndarray
        Image estimate with the temporal axis in its original position.

    Notes
    -----
    The "norm" column of the progress table is currently always 0.
    '''
    # Work with the temporal axis last; remember to restore it at the end
    moved = (temporal_axis != -1) and (temporal_axis != mask.ndim-1)
    if moved:
        prior = np.moveaxis(prior, temporal_axis, -1)
        kspace_u = np.moveaxis(kspace_u, temporal_axis, -1)
        mask = np.moveaxis(mask, temporal_axis, -1)
        if x is not None:
            # Keep the true image aligned with the estimate for MSE/SSIM
            x = np.moveaxis(x, temporal_axis, -1)

    # Get the image space of the data we did measure
    measuredImgDomain = inverse_fun(kspace_u)

    # Initialize estimates
    img_est = measuredImgDomain.copy()
    W_img_est = measuredImgDomain.copy()

    # Get monotonic ordering
    sort_order_real, sort_order_imag = sort_real_imag_parts(prior, axis=-1)

    # Initialize output
    table = Table(
        ['iter', 'norm', 'MSE', 'SSIM'],
        [len(repr(maxiter)), 8, 8, 8],
        ['d', 'e', 'e', 'e'])
    print(table.header())
    xabs = None if x is None else np.abs(x)
    stop_criteria = 0

    # Work buffers reused every iteration (dtype c_float)
    unsort_real_data = np.zeros(img_est.shape, dtype=c_float)
    unsort_imag_data = np.zeros(img_est.shape, dtype=c_float)
    temporal_term_update_real = np.zeros(img_est.shape, dtype=c_float)
    temporal_term_update_imag = np.zeros(img_est.shape, dtype=c_float)

    for ii in trange(maxiter, leave=False):
        fidelity_update = weight_fidelity*(measuredImgDomain - W_img_est)

        # TV term for real and imag parts, computed in the sorted domain
        real_smooth_data = np.take_along_axis( #pylint: disable=E1101
            img_est.real, sort_order_real, axis=-1)
        imag_smooth_data = np.take_along_axis( #pylint: disable=E1101
            img_est.imag, sort_order_imag, axis=-1)
        _temporal_tv_grad(
            real_smooth_data, sort_order_real, beta_sqrd, weight_temporal,
            temporal_term_update_real, unsort_real_data)
        _temporal_tv_grad(
            imag_smooth_data, sort_order_imag, beta_sqrd, weight_temporal,
            temporal_term_update_imag, unsort_imag_data)

        # Do the updates
        temporal_term_update = unsort_real_data + 1j*unsort_imag_data
        img_est += fidelity_update + temporal_term_update
        W_img_est = np.fft.ifft2(np.fft.fft2(
            img_est, axes=(0, 1))*mask, axes=(0, 1))

        # Give user an update
        curxabs = np.abs(img_est)
        if xabs is None:
            cur_mse, cur_ssim = 0, 0
        else:
            cur_mse = compare_mse(curxabs, xabs)
            cur_ssim = compare_ssim(curxabs, xabs)
        tqdm.write(table.row([ii, stop_criteria, cur_mse, cur_ssim]))

    if moved:
        img_est = np.moveaxis(img_est, -1, temporal_axis)
    return img_est
def test_table_row(self):
    '''Make sure rows are the correct widths.'''
    table = Table(
        self.headings, self.widths, self.formatters, pad=self.pad)
    # Render the header too, as a normal caller would before any row
    table.header()
    rendered = table.row(self.widths)
    self.assertEqual(len(rendered), sum(self.widths))
def proximal_GD(
        y, forward_fun, inverse_fun, sparsify, unsparsify, reorder_fun=None,
        mode='soft', alpha=.5, alpha_start=.5, thresh_sep=True,
        selective=None, x=None, ignore_residual=False, ignore_mse=True,
        ignore_ssim=True, disp=False, maxiter=200, strikes=0):
    r'''Proximal gradient descent for generic encoding/sparsity model.

    Parameters
    ----------
    y : array_like
        Measured data (i.e., y = Ax).
    forward_fun : callable
        A, the forward transformation function.
    inverse_fun : callable
        A^H, the inverse transformation function.
    sparsify : callable
        Sparsifying transform.
    unsparsify : callable
        Inverse sparsifying transform.
    reorder_fun : callable, optional
        Reordering function.  Called on the gradient step; must return a
        complex array whose real/imaginary parts hold flat indices used to
        reorder the real/imaginary parts of the step.  The reordering is
        undone after thresholding.
    mode : {'soft', 'hard', 'garotte', 'greater', 'less'}, optional
        Thresholding mode.
    alpha : float or callable, optional
        Step size, used for thresholding.  If callable, the next value is
        ``alpha(alpha_prev, ii)`` after each iteration.
    alpha_start : float, optional
        Initial alpha to start with if alpha is callable.
    thresh_sep : bool, optional
        Whether or not to threshold real/imag individually.
    selective : callable, optional
        ``selective(x_hat, update, ii)`` returning indices of the update to
        keep at each iteration.
    x : array_like, optional
        The true image we are trying to reconstruct.
    ignore_residual : bool, optional
        Whether or not to ignore the "residual increased" stop condition.
    ignore_mse : bool, optional
        Whether or not to ignore the "MSE increased" stop condition.
    ignore_ssim : bool, optional
        Whether or not to ignore the "SSIM decreased" stop condition.
    disp : bool, optional
        Whether or not to display iteration info.
    maxiter : int, optional
        Maximum number of iterations.
    strikes : int, optional
        Stop-condition violations tolerated before giving up.  All three
        conditions share one counter; the loop breaks on violation number
        ``strikes + 2``.

    Returns
    -------
    x_hat : array_like
        Estimate of x.

    Notes
    -----
    Solves the problem:

    .. math::

        \min_x || y - Ax ||^2_2 + \lambda \text{Sparsify}(x)

    If `x=None`, then MSE/SSIM will not be calculated (reported as 0).

    You probably want `mode='soft'`.  For the other options, see docs for
    pywt.threshold.  `selective=None` will not throw away any updates.
    '''
    # Make sure compare_mse, compare_ssim is defined
    if x is None:
        compare_mse = lambda xx, yy: 0
        compare_ssim = lambda xx, yy: 0
        xabs = 0
        logging.info(
            'No true x provided, MSE/SSIM will not be calculated.')
    else:
        from skimage.measure import compare_mse, compare_ssim
        # Precompute absolute value of true image
        xabs = np.abs(x.astype(y.dtype))

    # Choose iteration driver and where early-stop messages go
    if disp:
        range_fun = range
        report = logging.warning
        from mr_utils.utils.printtable import Table
        table = Table(
            ['iter', 'norm', 'MSE', 'SSIM'],
            [len(repr(maxiter)), 8, 8, 8],
            ['d', 'e', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)
    else:
        from tqdm import trange, tqdm
        range_fun = lambda n: trange(n, leave=False, desc='Proximal GD')
        # tqdm.write prints without breaking the progress bar
        report = tqdm.write

    # Image metrics are needed for display or for the MSE/SSIM stop checks
    track_metrics = disp or not ignore_mse or not ignore_ssim

    # Initialize
    x_hat = np.zeros(y.shape, dtype=y.dtype)
    r = -y.copy()
    prev_stop_criteria = np.inf
    cur_ssim = 0
    prev_ssim = compare_ssim(xabs, np.abs(inverse_fun(y)))
    cur_mse = 0
    prev_mse = compare_mse(xabs, np.abs(inverse_fun(y)))
    norm_y = np.linalg.norm(y)
    # Any non-callable (int, float, numpy scalar) alpha is used directly
    alpha0 = alpha_start if callable(alpha) else alpha

    strike_count = 0
    for ii in range_fun(int(maxiter)):

        # Compute stop criteria
        stop_criteria = np.linalg.norm(r)/norm_y
        if not ignore_residual and stop_criteria > prev_stop_criteria:
            if strike_count > strikes:
                report(('Breaking out of loop after %d iterations. '
                        'Norm of residual increased!') % ii)
                break
            strike_count += 1
        prev_stop_criteria = stop_criteria

        # Compute gradient descent step in prep for reordering
        grad_step = x_hat - inverse_fun(r)

        # Do reordering if we asked for it
        if reorder_fun is not None:
            reorder_idx = reorder_fun(grad_step)
            idx_r = np.unravel_index(reorder_idx.real.astype(int), y.shape)
            idx_i = np.unravel_index(reorder_idx.imag.astype(int), y.shape)
            grad_step = (
                grad_step.real[idx_r] + 1j*grad_step.imag[idx_i]).reshape(
                    y.shape)

        # Threshold in the sparse domain
        if thresh_sep:
            tmp = sparsify(grad_step)
            # Take a half step in each real/imag after talk with Ed
            tmp_r = threshold(tmp.real, value=alpha0/2, mode=mode)
            tmp_i = threshold(tmp.imag, value=alpha0/2, mode=mode)
            update = unsparsify(tmp_r + 1j*tmp_i)
        else:
            update = unsparsify(
                threshold(sparsify(grad_step), value=alpha0, mode=mode))

        # Undo the reordering if we did it
        if reorder_fun is not None:
            update_r = np.zeros(y.shape)
            update_r[idx_r] = update.real.flatten()
            update_i = np.zeros(y.shape)
            update_i[idx_i] = update.imag.flatten()
            update = update_r + 1j*update_i

        # Update image estimate, optionally keeping only selected entries
        if selective is not None:
            selective_idx = selective(x_hat, update, ii)
            x_hat[selective_idx] = update[selective_idx]
        else:
            x_hat = update

        if track_metrics:
            curxabs = np.abs(x_hat)
            cur_mse = compare_mse(curxabs, xabs)
            cur_ssim = compare_ssim(curxabs, xabs)
        if disp:
            logging.info(
                table.row([ii, stop_criteria, cur_mse, cur_ssim]))

        if not ignore_mse and cur_mse > prev_mse:
            if strike_count > strikes:
                report(('Breaking out of loop after %d iterations. '
                        'MSE increased!') % ii)
                break
            strike_count += 1
        prev_mse = cur_mse

        # Higher SSIM means closer to x, so a *decrease* is the bad sign
        if not ignore_ssim and cur_ssim < prev_ssim:
            if strike_count > strikes:
                report(('Breaking out of loop after %d iterations. '
                        'SSIM decreased!') % ii)
                break
            strike_count += 1
        prev_ssim = cur_ssim

        # Compute residual
        r = forward_fun(x_hat) - y

        # Get next step size
        if callable(alpha):
            alpha0 = alpha(alpha0, ii)

    return x_hat
def IHT(A, y, k, mu=1, maxiter=500, tol=1e-8, x=None, disp=False):
    '''Iterative hard thresholding algorithm (IHT).

    A -- Measurement matrix.
    y -- Measurements (i.e., y = Ax).
    k -- Number of expected nonzero coefficients.
    mu -- Step size.
    maxiter -- Maximum number of iterations.
    tol -- Stopping criteria.
    x -- True signal we are trying to estimate.
    disp -- Whether or not to display iterations.

    Returns the estimate x_hat of x.

    Solves the problem:
        min_x || y - Ax ||^2_2  s.t.  ||x||_0 <= k

    Coefficients whose magnitude is below the k'th largest magnitude are
    zeroed, so ties may leave more than k nonzeros.  k <= 0 zeroes all.

    If disp=True, then MSE will be calculated using provided x. Otherwise a
    tqdm progress bar is shown when tqdm is installed.

    mu=1 seems to satisfy Theorem 8.4 often, but might need to be adjusted
    (usually < 1).  See normalized IHT for adaptive step size.

    Implements Algorithm 8.5 from:
        Eldar, Yonina C., and Gitta Kutyniok, eds. Compressed sensing:
        theory and applications. Cambridge University Press, 2012.
    '''

    # length of original signal
    N = A.shape[1]

    # Make sure we have everything we need for disp
    if disp and x is None:
        logging.warning('No true x provided, using x=0 for MSE calc.')
        x = np.zeros(N)

    # Some fancy, asthetic touches...
    if disp:
        table = Table(['iter', 'MSE'], [len(repr(maxiter)), 8], ['d', 'e'])
        range_fun = range
    else:
        try:
            from tqdm import trange
        except ImportError:
            # The progress bar is cosmetic; run without it
            range_fun = range
        else:
            range_fun = lambda n: trange(n, leave=False, desc='IHT')

    # Initial estimate of x, x_hat, and initial residue
    x_hat = np.zeros(N, dtype=y.dtype)
    r = y.copy()
    norm_y = np.linalg.norm(y)

    # Set up header for logger
    if disp:
        for line in table.header().split('\n'):
            logging.info(line)

    # Run until tol reached or maxiter reached
    converged = False
    for tt in range_fun(maxiter):
        # Update estimate using residual scaled by step size
        x_hat += mu*np.dot(A.conj().T, r)

        # Hard thresholding: k'th largest magnitude is the threshold
        if k > 0:
            thresh = -np.sort(-np.abs(x_hat))[min(k, x_hat.size) - 1]
            x_hat[np.abs(x_hat) < thresh] = 0
        else:
            x_hat[:] = 0

        # Show MSE at current iteration if we wanted it
        if disp:
            logging.info(table.row([tt, np.mean((np.abs(x - x_hat)**2))]))

        # update the residual and check stopping criteria
        r = y - np.dot(A, x_hat)
        if np.linalg.norm(r)/norm_y < tol:
            converged = True
            break

    # Regroup and debrief...
    if not converged:
        logging.warning(
            'Hit maximum iteration count, estimate may not be accurate!')
    elif disp:
        logging.info('Final || r || . || y ||^-1 : %g',
                     np.linalg.norm(r)/norm_y)

    return x_hat
def test_table_row(self):
    '''A row built from the widths list is exactly sum(widths) wide.'''
    tbl = Table(self.headings, self.widths, self.formatters, pad=self.pad)
    tbl.header()
    row_text = tbl.row(self.widths)
    expected = sum(self.widths)
    self.assertEqual(len(row_text), expected)
def cosamp(A, y, k, lstsq='exact', tol=1e-8, maxiter=500, x=None,
           disp=False):
    '''Compressive sampling matching pursuit (CoSaMP) algorithm.

    Parameters
    ==========
    A : array_like
        Measurement matrix.
    y : array_like
        Measurements (i.e., y = Ax).
    k : int
        Number of expected nonzero coefficients.
    lstsq : {'exact', 'lm', 'gd'}, optional
        How to solve intermediate least squares problem.  Only 'exact' is
        currently implemented.
    tol : float, optional
        Stopping criteria.
    maxiter : int, optional
        Maximum number of iterations.
    x : array_like, optional
        True signal we are trying to estimate (only used for the MSE shown
        when `disp=True`).  If it has fewer entries than `A` has columns, a
        single 0 is prepended.
    disp : bool, optional
        Whether or not to display iterations.

    Returns
    =======
    x_hat : array_like
        Estimate of x.

    Raises
    ======
    NotImplementedError
        If `lstsq` is not 'exact'.

    Notes
    =====
    lstsq function

    - 'exact' solves it using numpy's linalg.lstsq method.
    - 'lm' (Levenberg-Marquardt) and 'gd' (gradient descent) were
      prototyped but did not work well and are not implemented.

    Implements Algorithm 8.7 from [1]_.

    References
    ==========
    .. [1] Eldar, Yonina C., and Gitta Kutyniok, eds. Compressed sensing:
           theory and applications. Cambridge University Press, 2012.
    '''

    # length of original signal
    N = A.shape[1]

    # Initializations
    x_hat = np.zeros(N, dtype=y.dtype)
    r = y.copy()
    ynorm = np.linalg.norm(y)
    if x is None:
        x = np.zeros(x_hat.shape, dtype=y.dtype)
    elif x.size < x_hat.size:
        x = np.hstack(([0], x))

    # Decide how we want to solve the intermediate least squares problem
    if lstsq != 'exact':
        raise NotImplementedError(
            "lstsq='%s' is not implemented; use 'exact'." % lstsq)

    def lstsq_fun(A0, y0):
        return np.linalg.lstsq(A0, y0, rcond=None)[0]

    # Start up a table
    if disp:
        table = Table(['iter', 'norm', 'MSE'],
                      [len(repr(maxiter)), 8, 8],
                      ['d', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)

    # Slice starts that stay correct for k=0 and for 2*k or k larger than N
    first_top = max(N - 2*k, 0)
    n_drop = max(N - k, 0)

    for ii in range(maxiter):

        # Get step direction
        g = np.dot(A.conj().T, r)

        # Add 2*k largest elements of g to support set
        Tn = np.union1d(x_hat.nonzero()[0], np.argsort(np.abs(g))[first_top:])

        # Solve the least squares problem on the support, keep k largest
        xn = np.zeros(N, dtype=y.dtype)
        if Tn.size:
            xn[Tn] = lstsq_fun(A[:, Tn], y)
        xn[np.argsort(np.abs(xn))[:n_drop]] = 0
        x_hat = xn

        # Compute new residual and stopping criteria
        r = y - np.dot(A, x_hat)
        stop_criteria = np.linalg.norm(r)/ynorm

        # Show MSE at current iteration if we wanted it
        if disp:
            logging.info(
                table.row([ii, stop_criteria,
                           np.mean((np.abs(x - x_hat)**2))]))

        # Check stopping criteria
        if stop_criteria < tol:
            break

    return x_hat
def IHT_TV(y, forward_fun, inverse_fun, k, mu=1, tol=1e-8,
           do_reordering=False, x=None, ignore_residual=False, disp=False,
           maxiter=500):
    r'''IHT for generic encoding model and TV constraint.

    Parameters
    ----------
    y : array_like
        Measured data, i.e., y = Ax.
    forward_fun : callable
        A, the forward transformation function.
    inverse_fun : callable
        A^H, the inverse transformation function.
    k : int
        Sparsity measure (number of nonzero finite differences kept).
        ``k <= 0`` keeps none, giving a constant estimate.
    mu : float, optional
        Step size.
    tol : float, optional
        Stop when stopping criteria meets this threshold.
    do_reordering : bool, optional
        Reorder column-stacked true image.  When `x` is given, `k` is
        replaced by the number of nonzero differences of the reordered `x`.
    x : array_like, optional
        The true image we are trying to reconstruct.
    ignore_residual : bool, optional
        Whether or not to break out of loop if resid increases.
    disp : bool, optional
        Whether or not to display iteration info.
    maxiter : int, optional
        Maximum number of iterations.

    Returns
    -------
    x_hat : array_like
        Estimate of x.

    Notes
    -----
    Solves the problem:

    .. math::

        \min_x || y - Ax ||^2_2 \text{ s.t. } || \text{TV}(x) ||_0 \leq k

    If `x=None`, then MSE will not be calculated (reported as 0).
    '''

    # Make sure we have a defined compare_mse, xabs and Table for printing
    if disp:
        from mr_utils.utils.printtable import Table
        if x is not None:
            from skimage.measure import compare_mse
            xabs = np.abs(x)
        else:
            xabs = 0
            compare_mse = lambda xx, yy: 0

    x_hat = np.zeros(y.shape)
    r = y.copy()
    prev_stop_criteria = np.inf
    norm_y = np.linalg.norm(y)

    # Initialize display table
    if disp:
        table = Table(['iter', 'norm', 'MSE'],
                      [len(repr(maxiter)), 8, 8],
                      ['d', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)

    # Find perfect reordering (column-stacked-wise)
    if do_reordering:
        from mr_utils.utils.orderings import (col_stacked_order,
                                              inverse_permutation)
        reordering = col_stacked_order(x)
        inverse_reordering = inverse_permutation(reordering)

        # Find new sparsity measure
        if x is not None:
            k = np.sum(np.abs(np.diff(x.flatten()[reordering])) > 0)
        else:
            logging.warning(('Make sure sparsity level k is '
                             'adjusted for reordering!'))

    for ii in range(int(maxiter)):

        # Take step
        val = (x_hat + mu*inverse_fun(r)).flatten()

        # Do the reordering
        if do_reordering:
            val = val[reordering]

        # Finite differences transformation
        first_samp = val[0] # save first sample for inverse transform
        fd = np.diff(val)

        # Hard thresholding: keep only the k largest-magnitude differences
        fd[np.argsort(np.abs(fd))[:max(fd.size - k, 0)]] = 0

        # Inverse finite differences transformation
        res = np.hstack((first_samp, fd)).cumsum()
        if do_reordering:
            res = res[inverse_reordering]

        # Compute stopping criteria (residual of the previous estimate)
        stop_criteria = np.linalg.norm(r)/norm_y

        # If the stop_criteria gets worse, get out of dodge
        if not ignore_residual and (stop_criteria > prev_stop_criteria):
            logging.warning('Residual increased! Not continuing!')
            break
        prev_stop_criteria = stop_criteria

        # Update x
        x_hat = res.reshape(x_hat.shape)

        # Show the people what they asked for
        if disp:
            logging.info(
                table.row([ii, stop_criteria, compare_mse(xabs, x_hat)]))
        if stop_criteria < tol:
            break

        # update the residual
        r = y - forward_fun(x_hat)

    return x_hat
def nIHT(A, y, k, c=0.1, kappa=None, x=None, maxiter=200, tol=1e-8,
         disp=False):
    '''Normalized iterative hard thresholding.

    A -- Measurement matrix
    y -- Measurements (i.e., y = Ax)
    k -- Number of nonzero coefficients preserved after thresholding
         (k <= 0 zeroes everything).
    c -- Small, fixed constant. Tunable.
    kappa -- Constant, > 1/(1 - c).
    x -- True signal we want to estimate.
    maxiter -- Maximum number of iterations (of the outer loop).
    tol -- Stopping criteria.
    disp -- Whether or not to display iteration info.

    Returns the estimate x_hat of x.

    Implements Algorithm 8.6 from:
        Eldar, Yonina C., and Gitta Kutyniok, eds. Compressed sensing:
        theory and applications. Cambridge University Press, 2012.
    '''

    # Basic checks
    assert 0 < c < 1, 'c must be in (0,1)'

    # length of original signal
    N = A.shape[1]

    def _hard_threshold(v):
        '''Zero (in place) entries below the k'th largest magnitude.'''
        if k <= 0:
            v[:] = 0
            return v
        thresh = -np.sort(-np.abs(v))[min(k, v.size) - 1]
        v[np.abs(v) < thresh] = 0
        return v

    def _cond(xn, x_prev):
        '''Largest step size accepted when the support changes.'''
        delta = xn - x_prev
        return (1 - c)*np.linalg.norm(delta)**2/np.linalg.norm(
            np.dot(A, delta))**2

    # Make sure we have everything we need for disp
    if disp and x is None:
        logging.warning('No true x provided, using x=0 for MSE calc.')
        x = np.zeros(N)

    if disp:
        table = Table(['iter', 'norm', 'MSE'],
                      [len(repr(maxiter)), 8, 8],
                      ['d', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)

    # Initializations
    x_hat = np.zeros(N)

    # Inital calculation of support
    T = np.flatnonzero(_hard_threshold(A.T.dot(y)))

    # Find suitable kappa if the user didn't give us one
    if kappa is None:
        # kappa must be > 1/(1 - c), so try 2 times the lower bound
        kappa = 2/(1 - c)
    else:
        assert kappa > 1/(1 - c), 'kappa must be > 1/(1 - c)'

    for ii in range(maxiter):

        # Compute residual and check stopping criteria
        r = y - np.dot(A, x_hat)
        stop_criteria = np.linalg.norm(r)/np.linalg.norm(y)
        if stop_criteria < tol:
            break

        # Let's check out what's going on
        if disp:
            logging.info(table.row([ii, stop_criteria, compare_mse(x, x_hat)]))

        # Compute step size
        g = np.dot(A.T, r)
        mu = np.linalg.norm(g)**2/np.linalg.norm(np.dot(A, g))**2

        # Hard thresholding
        xn = _hard_threshold(x_hat + mu*g)
        Tn = np.flatnonzero(xn)

        # If the support changed, shrink mu until the step is acceptable
        if not np.array_equal(Tn, T):
            cond = _cond(xn, x_hat)
            while mu > cond:
                mu /= kappa*(1 - c)
                xn = _hard_threshold(x_hat + mu*g)
                cond = _cond(xn, x_hat)
            Tn = np.flatnonzero(xn)

        x_hat = xn
        # Track the support of the accepted iterate for the next comparison
        T = Tn

    return x_hat
def amp2d(y, forward_fun, inverse_fun, sigmaType=2, randshift=False,
          tol=1e-8, x=None, ignore_residual=False, disp=False, maxiter=100):
    r'''Approximate message passing using wavelet sparsifying transform.

    Parameters
    ==========
    y : array_like
        Measurements, i.e., y = Ax.  Must be 2D.
    forward_fun : callable
        A, the forward transformation function.
    inverse_fun : callable
        A^H, the inverse transformation function.
    sigmaType : int
        Method for determining threshold: 1 uses the median absolute wavelet
        coefficient divided by 0.6745; any other value uses the residual
        norm divided by the square root of the number of nonzero
        measurements.
    randshift : bool, optional
        Whether or not to apply a random 2D circular shift (both image axes)
        before the wavelet transform every iteration.
    tol : float, optional
        Stop when stopping criteria meets this threshold.
    x : array_like, optional
        The true image we are trying to reconstruct.
    ignore_residual : bool, optional
        Whether or not to ignore stopping criteria.
    disp : bool, optional
        Whether or not to display iteration info.
    maxiter : int, optional
        Maximum number of iterations.

    Returns
    =======
    wn : array_like
        Estimate of x.  When the stopping criteria is met, the estimate from
        the previous iteration is returned.

    Notes
    =====
    Solves the problem:

    .. math::

        \min_x || \Psi(x) ||_1 \text{ s.t. } || y - \text{forward}(x) ||^2_2 < \epsilon^2

    The CDF-97 wavelet is used.

    If `x=None`, then MSE will not be calculated.

    Algorithm described in [1]_, based on MATLAB implementation found at
    [2]_.

    References
    ==========
    .. [1] "Message Passing Algorithms for CS" Donoho et al., PNAS
           2009;106:18914
    .. [2] http://kyungs.bol.ucla.edu/Site/Software.html
    '''

    # Initialize display table, compare_mse and xabs
    if disp:
        from mr_utils.utils.printtable import Table
        table = Table(['iter', 'resid', 'resid diff', 'MSE'],
                      [len(repr(maxiter)), 8, 8, 8],
                      ['d', 'e', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)
        if x is not None:
            from skimage.measure import compare_mse
            xabs = np.abs(x)
        else:
            xabs = 0
            compare_mse = lambda xx, yy: 0

    # Number of (numerically) nonzero measurements
    mm = np.sum(abs(y) > np.finfo(float).eps)
    rfact = y.size/mm

    # Tabulated optimal lambdas indexed by delta (fraction of nonzero
    # measurements)
    OptimumLambdaSigned = loadmat(dirname(__file__) \
        + '/OptimumLambdaSigned.mat')
    delta_vec = OptimumLambdaSigned['delta_vec'][0]
    lambda_opt = OptimumLambdaSigned['lambda_opt'][0]
    delta = 1/rfact
    lambdas = np.interp(delta, delta_vec, lambda_opt)

    # Initial values
    wn = np.zeros(y.shape, dtype=y.dtype)
    zn = y - forward_fun(wn)
    nx, ny = y.shape[:]
    res_norm = np.zeros(maxiter + 1)
    nn = np.zeros(maxiter + 1)
    res_diff = np.zeros(maxiter)
    res_norm[0] = np.linalg.norm(zn)
    norm_y = np.linalg.norm(y)
    nn[0] = res_norm[0]/norm_y

    for abc in range(int(maxiter)):

        # First-order Approximate Message Passing
        temp_z = inverse_fun(zn) + wn

        # Random 2D circular shift; axis=(0, 1) so the shift is applied per
        # image axis instead of to the flattened array
        if randshift:
            rand_shift_x = np.random.randint(0, nx)
            rand_shift_y = np.random.randint(0, ny)
            temp_z = np.roll(
                temp_z, (rand_shift_x, rand_shift_y), axis=(0, 1))

        # Sparsify with wavelet transform
        temp_z, locations = cdf97_2d_forward(temp_z, level=5)

        # Compute sigma hat
        if sigmaType == 1:
            sigma_hat = np.median(np.abs(temp_z.flatten()))/.6745
        else:
            sigma_hat = res_norm[abc]/np.sqrt(mm)

        # If sigma is zero put any VERY small number
        if sigma_hat == 0:
            sigma_hat = .1

        # Soft Thresholding
        level = lambdas*sigma_hat
        wn1 = (np.abs(temp_z) > level)*(np.abs(temp_z) - level) \
            *np.sign(temp_z)

        # Compute a sparsity/measurement ratio
        amp_weight = np.sum(np.abs(wn1) > np.finfo(float).eps)/mm

        # Un-sparsify
        wn1 = cdf97_2d_inverse(wn1, locations)

        # random shift back
        if randshift:
            wn1 = np.roll(wn1, (-rand_shift_x, -rand_shift_y), axis=(0, 1))

        # Update the residual term
        residual = y - forward_fun(wn1)

        # Normalized data fidelity term
        res_norm[abc + 1] = np.linalg.norm(residual)
        nn[abc + 1] = res_norm[abc + 1]/norm_y
        res_diff[abc] = np.abs(nn[abc + 1] - nn[abc])

        # Give the people what they asked for!
        if disp:
            logging.info(
                table.row([
                    abc, nn[abc + 1], res_diff[abc],
                    compare_mse(xabs, np.abs(wn1))
                ]))

        # Check stopping criteria
        if not ignore_residual and (res_diff[abc] < tol):
            break

        # Update Estimation
        wn = wn1

        # Weight the residual with a little extra sauce
        if amp_weight > 1:
            zn = residual + 0.25*zn
        else:
            zn = residual + amp_weight*zn

    return wn
def IST(A, y, mu=0.8, theta0=None, k=None, maxiter=500, tol=1e-8, x=None,
        disp=False):
    r'''Iterative soft thresholding algorithm (IST).

    Parameters
    ==========
    A : array_like
        Measurement matrix.
    y : array_like
        Measurements (i.e., y = Ax).
    mu : float, optional
        Step size (theta contraction factor, 0 < mu <= 1).
    theta0 : float, optional
        Initial threshold, decreased by factor of mu each iteration.
    k : int, optional
        Number of expected nonzero coefficients.
    maxiter : int, optional
        Maximum number of iterations.
    tol : float, optional
        Stopping criteria.
    x : array_like, optional
        True signal we are trying to estimate.
    disp : bool, optional
        Whether or not to display iterations.  Otherwise a tqdm progress bar
        is shown when tqdm is installed.

    Returns
    =======
    x_hat : array_like
        Estimate of x.

    Notes
    =====
    Solves the problem:

    .. math::

        \min_x || y - Ax ||^2_2 \text{ s.t. } ||x||_0 \leq k

    If `disp=True`, then MSE will be calculated using provided x.
    If `theta0=None`, the initial threshold of the IHT will be used as the
    starting theta.

    Implements Equations [22-23] from [1]_

    References
    ==========
    .. [1] Rani, Meenu, S. B. Dhok, and R. B. Deshmukh. "A systematic review
           of compressive sensing: Concepts, implementations and
           applications." IEEE Access 6 (2018): 4875-4894.
    '''

    # Check to make sure we have good mu
    assert 0 < mu <= 1, 'mu should be 0 < mu <= 1!'

    # length of original signal
    N = A.shape[1]

    # Make sure we have everything we need for disp
    if disp and x is None:
        logging.warning('No true x provided, using x=0 for MSE calc.')
        x = np.zeros(N)

    # Some fancy, asthetic touches...
    if disp:
        range_fun = range
        table = Table(['iter', 'norm', 'theta', 'MSE'],
                      [len(repr(maxiter)), 8, 8, 8],
                      ['d', 'e', 'e', 'e'])
    else:
        try:
            from tqdm import trange
        except ImportError:
            # The progress bar is cosmetic; run without it
            range_fun = range
        else:
            range_fun = lambda n: trange(n, leave=False, desc='IST')

    # Initial estimate of x, x_hat, and initial residue
    x_hat = np.zeros(N)
    r = y.copy()
    norm_y = np.linalg.norm(y)

    # Start theta at specified theta0 or use IHT first threshold
    if theta0 is None:
        assert k is not None, ('k (measure of sparsity) required to compute '
                               'initial threshold!')
        theta = -np.sort(-np.abs(np.dot(A.T, r)))[k - 1]
    else:
        assert theta0 > 0, 'Threshold must be positive!'
        theta = theta0

    # Set up header for logger
    if disp:
        for line in table.header().split('\n'):
            logging.info(line)

    # Run until tol reached or maxiter reached
    converged = False
    for tt in range_fun(maxiter):
        # Update estimate using residual
        x_hat += np.dot(A.T, r)

        # Just like IHT, but use soft thresholding operator.
        # np.sign counts 0 as 0.
        x_hat = np.maximum(np.abs(x_hat) - theta, 0.0)*np.sign(x_hat)

        # update the residual and check stopping criteria
        r = y - np.dot(A, x_hat)
        stop_criteria = np.linalg.norm(r)/norm_y
        if stop_criteria < tol:
            converged = True
            break

        # Show MSE at current iteration if we wanted it
        if disp:
            logging.info(
                table.row([tt, stop_criteria, theta, compare_mse(x, x_hat)]))

        # Contract theta before we go back around the horn
        theta *= mu

    # Regroup and debrief...
    if not converged:
        logging.warning(
            'Hit maximum iteration count, estimate may not be accurate!')
    elif disp:
        logging.info('Final || r || . || y ||^-1 : %g',
                     (np.linalg.norm(r)/norm_y))

    return x_hat
def tv_l1_denoise(im, lam, disp=False, niter=100):
    '''TV-L1 image denoising with the primal-dual algorithm.

    Parameters
    ==========
    im : array_like
        2D image to be processed.  If its maximum exceeds 1 it is divided by
        that maximum first, and the result stays in that normalized range.
    lam : float
        Regularization parameter weighting data fidelity; smaller values
        give more aggressive (smoother) denoising.
    disp : bool, optional
        Print the energy being minimized at each iteration.  Otherwise a
        tqdm progress bar is shown.
    niter : int, optional
        Number of iterations.

    Returns
    =======
    newim : array_like
        l1 denoised image.

    Raises
    ======
    AssertionError
        When dimension of im is not 2.
    '''

    # Primal-dual step parameters
    L2 = 8.0
    tau = 0.02
    sigma = 1.0/(L2*tau)
    theta = 1.0
    lt = lam*tau

    assert im.ndim == 2, 'This function only works for 2D images!'
    height, width = im.shape[:]

    def _forward_diffs(arr):
        '''Forward differences along columns/rows, 0 on the last col/row.'''
        dx = np.hstack((arr[:, 1:], arr[:, -1:])) - arr
        dy = np.vstack((arr[1:, :], arr[-1:, :])) - arr
        return dx, dy

    def _divergence(px, py):
        '''Backward-difference divergence of the dual field (px, py).'''
        zero_row = np.zeros((1, width))
        zero_col = np.zeros((height, 1))
        div_y = np.vstack((py[:height-1], zero_row)) \
            - np.vstack((zero_row, py[:height-1]))
        div_x = np.hstack((px[:, :width-1], zero_col)) \
            - np.hstack((zero_col, px[:, :width-1]))
        return div_y + div_x

    # Normalize only when the image exceeds 1
    peak = np.max(im)
    nim = im/peak if peak > 1.0 else im

    # Primal variable starts at the image, dual at its gradient
    u = nim
    p = np.zeros((height, width, 2))
    p[:, :, 0], p[:, :, 1] = _forward_diffs(u)

    if disp:
        from mr_utils.utils.printtable import Table
        table = Table(
            ['Iter', 'Energy'], [len(repr(niter)), 8], ['d', 'e'])
        print(table.header())
        iterations = range(niter)
    else:
        from tqdm import trange
        iterations = trange(niter, leave=False, desc="TV Denoise")

    for it in iterations:
        # Dual ascent followed by projection onto the unit ball
        ux, uy = _forward_diffs(u)
        p += sigma*np.stack((ux, uy), axis=2)
        norm_p = np.maximum(np.ones(im.shape),
                            np.sqrt(p[:, :, 0]**2 + p[:, :, 1]**2))
        p[:, :, 0] /= norm_p
        p[:, :, 1] /= norm_p

        # Primal descent, then the TV-L1 shrinkage towards nim
        v = u + tau*_divergence(p[:, :, 0], p[:, :, 1])
        gap = v - nim
        unew = (v - lt)*(gap > lt) + (v + lt)*(gap < -lt) \
            + nim*(np.abs(gap) <= lt)

        # Extragradient step
        u = unew + theta*(unew - u)

        if disp:
            energy = np.sum(np.sqrt(ux.flatten()**2 + uy.flatten()**2)) \
                + lam*np.sum(np.abs(u.flatten() - nim.flatten()))
            print(table.row((it, energy)))

    return u
def IHT(A, y, k, mu=1, maxiter=500, tol=1e-8, x=None, disp=False):
    r'''Iterative hard thresholding algorithm (IHT).

    Parameters
    ----------
    A : array_like
        Measurement matrix.
    y : array_like
        Measurements (i.e., y = Ax).
    k : int
        Number of expected nonzero coefficients.
    mu : float, optional
        Step size.
    maxiter : int, optional
        Maximum number of iterations.
    tol : float, optional
        Stopping criteria.
    x : array_like, optional
        True signal we are trying to estimate.
    disp : bool, optional
        Whether or not to display iterations.

    Returns
    -------
    x_hat : array_like
        Estimate of x.

    Notes
    -----
    Solves the problem:

    .. math::

        \min_x || y - Ax ||^2_2 \text{ s.t. } ||x||_0 \leq k

    If `disp=True`, then MSE will be calculated using provided x.
    `mu=1` seems to satisfy Theorem 8.4 often, but might need to be
    adjusted (usually < 1).  See normalized IHT for adaptive step size.

    `k=0` yields the all-zero estimate and `k >= N` disables the hard
    threshold.  If `y` is all zeros, the zero estimate is returned
    immediately.

    Implements Algorithm 8.5 from [1]_.

    References
    ----------
    .. [1] Eldar, Yonina C., and Gitta Kutyniok, eds. Compressed sensing:
           theory and applications. Cambridge University Press, 2012.
    '''

    # length of measurement vector and original signal
    _n, N = A.shape[:]

    # Make sure we have everything we need for disp
    if disp and x is None:
        logging.warning('No true x provided, using x=0 for MSE calc.')
        x = np.zeros(N)

    # Number of coefficients the hard threshold zeroes.  Explicit count
    # because ``argsort(...)[:-k]`` is empty for k == 0 and would keep
    # every coefficient instead of none.
    n_zero = max(N - int(k), 0)

    # Initial estimate of x, x_hat
    x_hat = np.zeros(N, dtype=y.dtype)

    # y == 0 is solved exactly by x_hat == 0 (and would make the relative
    # residual below 0/0)
    norm_y = np.linalg.norm(y)
    if norm_y == 0:
        return x_hat

    # Some fancy, aesthetic touches...
    if disp:
        table = Table(['iter', 'norm', 'MSE'],
                      [len(repr(maxiter)), 8, 8],
                      ['d', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)
        range_fun = range
    else:
        try:
            from tqdm import trange
        except ImportError:
            # The progress bar is purely cosmetic; iterate plainly.
            range_fun = range
        else:
            def range_fun(n):
                return trange(n, leave=False, desc='IHT')

    # Get initial residue
    r = y.copy()
    AH = A.conj().T

    # Run until tol reached or maxiter reached
    converged = False
    for tt in range_fun(int(maxiter)):
        # Update estimate using residual scaled by step size
        x_hat += mu * np.dot(AH, r)

        # Leave only k coefficients nonzero (hard threshold)
        x_hat[np.argsort(np.abs(x_hat))[:n_zero]] = 0

        stop_criteria = np.linalg.norm(r) / norm_y

        # Show MSE at current iteration if we wanted it
        if disp:
            logging.info(
                table.row([tt, stop_criteria,
                           np.mean((np.abs(x - x_hat)**2))]))

        # update the residual
        r = y - np.dot(A, x_hat)

        # Check stopping criteria
        if stop_criteria < tol:
            converged = True
            break

    # Regroup and debrief...  (flag, not the loop index, so converging on
    # the final iteration does not trigger a false warning)
    if not converged:
        logging.warning(('Hit maximum iteration count, estimate '
                         'may not be accurate!'))
    elif disp:
        logging.info('Final || r || . || y ||^-1 : %g',
                     (np.linalg.norm(r) / norm_y))

    return x_hat
def cosamp(A, y, k, lstsq='exact', tol=1e-8, maxiter=500, x=None,
           disp=False):
    '''Compressive sampling matching pursuit (CoSaMP) algorithm.

    A -- Measurement matrix.
    y -- Measurements (i.e., y = Ax).
    k -- Number of expected nonzero coefficients.
    lstsq -- How to solve intermediate least squares problem.
    tol -- Stopping criteria.
    maxiter -- Maximum number of iterations.
    x -- True signal we are trying to estimate (only used for the MSE
         shown when disp=True; x=0 is assumed if not given).
    disp -- Whether or not to display iterations.

    lstsq function:
        lstsq = { 'exact', 'lm', 'gd' }.

        'exact' solves it using numpy's linalg.lstsq method.
        'lm' minimizes the residual norm with
            scipy.optimize.least_squares (called with its default method).
        'gd' uses 3 iterations of a gradient descent solver.

    Any other lstsq value raises NotImplementedError.

    Complex A and/or y are supported: correlations use the conjugate
    transpose of A and the estimate has the common dtype of A and y.
    k <= 0 (or an all-zero y) returns the all-zero estimate.

    Implements Algorithm 8.7 from:
        Eldar, Yonina C., and Gitta Kutyniok, eds. Compressed sensing:
        theory and applications. Cambridge University Press, 2012.
    '''

    # length of original signal
    _n, N = A.shape[:]

    # Decide how we want to solve the intermediate least squares problem
    if lstsq == 'exact':
        def lstsq_fun(A0, y0):
            return np.linalg.lstsq(A0, y0, rcond=None)[0]
    elif lstsq == 'lm':
        from scipy.optimize import least_squares

        def lstsq_fun(A0, y0):
            return least_squares(
                lambda x0: np.linalg.norm(y0 - np.dot(A0, x0)),
                np.zeros(A0.shape[1], dtype=y0.dtype))['x']
    elif lstsq == 'gd':
        # This doesn't work very well...
        from mr_utils.optimization import gd, fd_complex_step

        def lstsq_fun(A0, y0):
            return gd(
                lambda x0: np.linalg.norm(y0 - np.dot(A0, x0)),
                fd_complex_step,
                np.zeros(A0.shape[1], dtype=y0.dtype), iter=3)[0]
    else:
        raise NotImplementedError()

    # Something to compare against for the MSE display (same as IHT)
    if disp and x is None:
        logging.warning('No true x provided, using x=0 for MSE calc.')
        x = np.zeros(N)

    # Clamp sparsity to [0, N] and use explicit counts in the slices below:
    # ``[-(2*k):]`` and ``[:-k]`` select everything/nothing when k == 0.
    n_keep = min(max(int(k), 0), N)
    n_cand = min(2*n_keep, N)

    # Common dtype so complex least-squares solutions are not truncated to
    # their real part.
    dtype = np.result_type(A.dtype, y.dtype, np.float64)
    x_hat = np.zeros(N, dtype=dtype)

    norm_y = np.linalg.norm(y)
    if n_keep == 0 or norm_y == 0:
        return x_hat

    # Start up a table
    if disp:
        table = Table(
            ['iter', 'residual', 'MSE'], [len(repr(maxiter)), 8, 8],
            ['d', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)

    AH = A.conj().T
    r = y.copy()
    for ii in range(maxiter):

        # Get step direction (correlation of A^H with the residual)
        g = np.dot(AH, r)

        # Add 2*k largest elements of g to support set
        Tn = np.union1d(
            x_hat.nonzero()[0], np.argsort(np.abs(g))[N - n_cand:])

        # Solve the least squares problem on the merged support, then
        # prune to the k largest entries
        xn = np.zeros(N, dtype=dtype)
        xn[Tn] = lstsq_fun(A[:, Tn], y)
        xn[np.argsort(np.abs(xn))[:N - n_keep]] = 0
        x_hat = xn

        # Compute new residual and stopping criteria
        r = y - np.dot(A, x_hat)
        stop_criteria = np.linalg.norm(r)/norm_y

        # Show MSE at current iteration if we wanted it
        if disp:
            logging.info(table.row(
                [ii, stop_criteria, np.mean(np.abs(x - x_hat)**2)]))

        # Check stopping criteria
        if stop_criteria < tol:
            break

    return x_hat
first_m = m_hat[0] # Transform into finite differences domain fd = np.diff(E.conj().T.dot(r)) # Hard thresholding fd[np.argsort(np.abs(fd))[:-k]] = 0 # Inverse transform and take the step m_hat += mu * np.hstack((first_m, fd)).cumsum() # This may or may not be a good stopping criteria for this stop_criteria = np.linalg.norm(r) / np.linalg.norm(s) # Show MSE at current iteration if we wanted it if disp: print( table.row([tt, stop_criteria, np.mean((np.abs(m - m_hat)**2))])) # Check stopping criteria if stop_criteria < tol: break # Get new residual r = s - E.dot(m_hat) plt.imshow(np.abs(m_hat.reshape(smiley.shape))) plt.title('IHT Recon') plt.show()
def GD_TV(
        y, forward_fun, inverse_fun, alpha=.5, lam=.01, do_reordering=False,
        x=None, ignore_residual=False, disp=False, maxiter=200):
    r'''Gradient descent for a generic encoding model and TV constraint.

    Parameters
    ==========
    y : array_like
        Measured data (i.e., y = Ax).
    forward_fun : callable
        A, the forward transformation function.
    inverse_fun : callable
        A^H, the inverse transformation function.
    alpha : float, optional
        Step size.
    lam : float, optional
        TV constraint weight.
    do_reordering : bool, optional
        Whether or not to reorder for sparsity constraint.  The orderings
        are computed from the real and imaginary parts of `x`, so `x` is
        required when this is True.
    x : array_like, optional
        The true image we are trying to reconstruct.
    ignore_residual : bool, optional
        Whether or not to break out of loop if resid increases.
    disp : bool, optional
        Whether or not to display iteration info.
    maxiter : int, optional
        Maximum number of iterations.

    Returns
    =======
    x_hat : array_like
        Estimate of x.

    Raises
    ======
    ValueError
        If `do_reordering=True` and `x` is not provided.

    Notes
    =====
    Solves the problem:

    .. math::

        \min_x || y - Ax ||^2_2 + \lambda \text{TV}(x)

    If `x=None`, then MSE will not be calculated.

    The estimate is initialized to zeros with the shape and dtype of `y`,
    so `inverse_fun` must return arrays compatible with that shape.
    '''

    # Check what we can compare against before doing any work
    if x is None:
        if do_reordering:
            raise ValueError(
                'do_reordering=True requires the true image x to build '
                'the orderings.')
        logging.info('No true x provided, MSE will not be calculated.')
        xabs = 0
    else:
        # Precompute absolute value of true image
        xabs = np.abs(x)

    def compare_mse(est_abs, true_abs):
        '''Mean squared error for display; 0 when no true image given.'''
        if x is None:
            return 0
        # Same computation as the former skimage.measure.compare_mse,
        # which no longer exists in scikit-image (importing it fails).
        return np.mean(np.square(est_abs - true_abs), dtype=np.float64)

    # Get the reordering indices ready, both for real and imag parts
    if do_reordering:
        from mr_utils.utils.sort2d import sort2d
        from mr_utils.utils.orderings import inverse_permutation
        _, reordering_r = sort2d(x.real)
        _, reordering_i = sort2d(x.imag)
        inverse_reordering_r = inverse_permutation(reordering_r)
        inverse_reordering_i = inverse_permutation(reordering_i)

    # Get some display stuff happening
    if disp:
        from mr_utils.utils.printtable import Table
        table = Table(['iter', 'norm', 'MSE'],
                      [len(repr(maxiter)), 8, 8],
                      ['d', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)

    # Initialize; residual of the zero estimate is A(0) - y = -y
    x_hat = np.zeros(y.shape, dtype=y.dtype)
    r = -y.copy()
    prev_stop_criteria = np.inf
    norm_y = np.linalg.norm(y)

    # Do the thing
    for ii in range(int(maxiter)):

        # Fidelity term
        fidelity = inverse_fun(r)

        # Let's reorder if we said that was going to be a thing
        if do_reordering:
            # real part
            xr = x_hat.real.flatten()[reordering_r].reshape(x.shape)
            second_term_r = dTV(xr).flatten()[inverse_reordering_r] \
                .reshape(x.shape)

            # imag part
            xi = x_hat.imag.flatten()[reordering_i].reshape(x.shape)
            second_term_i = dTV(xi).flatten()[inverse_reordering_i] \
                .reshape(x.shape)

            # put it all together...
            second_term = second_term_r + 1j * second_term_i
        else:
            # Sparsity term
            second_term = dTV(x_hat)

        # Compute stop criteria
        stop_criteria = np.linalg.norm(r) / norm_y
        if not ignore_residual and stop_criteria > prev_stop_criteria:
            logging.warning(('Breaking out of loop after %d iterations. '
                             'Norm of residual increased!'), ii)
            break
        prev_stop_criteria = stop_criteria

        # Take the step
        x_hat -= alpha * (fidelity + lam * second_term)

        # Tell the user what happened
        if disp:
            logging.info(
                table.row(
                    [ii, stop_criteria, compare_mse(np.abs(x_hat), xabs)]))

        # Compute residual
        r = forward_fun(x_hat) - y

    return x_hat