def tikhonov(*, psfs, measurements, tikhonov_lam=0.129, tikhonov_order=1):
    """Perform Tikhonov regularization based image reconstruction for PSSI.

    Solves x_hat = argmin_{x} { ||Ax-y||_2^2 + lam * ||Dx||_2^2 } in closed
    form in the 2D DFT domain. D is the discrete derivative operator of order
    `tikhonov_order`.

    Args:
        psfs (PSFs): PSFs object containing psfs and other csbs state data.
            ``psfs.psfs`` and ``psfs.copies`` are read; ``psfs.psf_GAM`` is
            overwritten with the block Gram spectrum of the expanded PSF DFTs.
        measurements (ndarray): 3d array of noisy measurements; ``shape[1:]``
            is taken as (rows, cols)
        tikhonov_lam (float): regularization parameter of tikhonov
        tikhonov_order (int): [0,1 or 2] order of the discrete derivative
            operator used in the tikhonov regularization

    Returns:
        real ndarray of the reconstructed images, presumably shaped
        (num_sources, rows, cols) -- TODO confirm against block_mul
    """
    num_sources = psfs.psfs.shape[1]
    rows, cols = measurements.shape[1:]

    # PSF DFTs at the measurement size (via size_equalizer), with each leading
    # row repeated `copies` times so it lines up with the measurement stack
    exp_psf_dfts = np.repeat(
        np.fft.fft2(size_equalizer(psfs.psfs, ref_size=[rows, cols])),
        psfs.copies.astype(int),
        axis=0,
    )
    exp_psf_dfts_h = block_herm(exp_psf_dfts)
    psfs.psf_GAM = block_mul(exp_psf_dfts_h, exp_psf_dfts)

    # DFT of the kernel corresponding to (D^TD), replicated on the diagonal
    # of a num_sources x num_sources block matrix
    LAM = get_LAM(rows=rows, cols=cols, order=tikhonov_order)
    reg_spectrum = np.einsum('ij,kl->ijkl', np.eye(num_sources), LAM)

    # A^H y in the DFT domain (measurements fftshifted over the spatial axes)
    ATy = block_mul(
        exp_psf_dfts_h,
        np.fft.fft2(np.fft.fftshift(measurements, axes=(1, 2))),
    )

    SIG_inv = block_inv(psfs.psf_GAM + tikhonov_lam * reg_spectrum)
    return np.real(np.fft.ifft2(block_mul(SIG_inv, ATy)))
def test_block_mul(self):
    from mas.block import block_mul

    def assert_product_shape(rhs, expected_shape):
        # multiply the fixture self.x by `rhs` and check only the shape
        self.assertEqual(block_mul(self.x, rhs).shape, expected_shape)

    # block matrix x block matrix
    assert_product_shape(self.y, (3, 3, 5, 5))
    # block matrix x block column vector (first block column of self.y)
    assert_product_shape(self.y[:, 0, :, :], (3, 5, 5))
def iteration_end(psfs, lowest_psf_group_index):
    """Remove one PSF group's contribution from the cached GAM spectrum.

    With G = ``psfs.initialized_data['psf_dfts'][g:g + 1]`` for
    g = ``lowest_psf_group_index``, this subtracts block_mul(block_herm(G), G)
    from ``psfs.initialized_data['GAM']`` using ``-=``.

    Args:
        psfs (PSFs): object whose ``initialized_data`` dict holds 'psf_dfts'
            and 'GAM'
        lowest_psf_group_index (int): index of the PSF group to remove
    """
    cache = psfs.initialized_data
    first = lowest_psf_group_index
    # a length-1 slice (not an integer index) keeps the leading axis
    group_dfts = cache['psf_dfts'][first:first + 1]
    cache['GAM'] -= block_mul(block_herm(group_dfts), group_dfts)
def cost(psfs, psf_group_index, **kwargs):
    """Trace-based cost of dropping one PSF group.

    Forms SIG_e_dft(psfs, lam) - G^H G, where G is the DFT slice of group
    ``psf_group_index`` in ``psfs.initialized_data['psf_dfts']``, inverts it
    block-wise and returns the real part of the summed per-frequency trace.
    (Presumably the error-covariance cost used for sequential backward
    selection -- confirm against the caller.)

    Args:
        psfs (PSFs): object with ``initialized_data['psf_dfts']``
        psf_group_index (int): index of the candidate PSF group
        **kwargs:
            lam (float): required; forwarded to ``SIG_e_dft``
            no_dc (bool): optional, default False; if true, the [0, 0]
                (DC) frequency entries of the inverse are zeroed before
                summing. Previously omitting it raised KeyError.

    Returns:
        float cost (real part of the trace sum)
    """
    sig_e = SIG_e_dft(psfs, kwargs['lam'])
    # a length-1 slice keeps the leading axis for block_herm/block_mul
    group_dfts = psfs.initialized_data['psf_dfts'][
        psf_group_index:psf_group_index + 1]
    iteration_SIG_e_dft = sig_e - block_mul(block_herm(group_dfts), group_dfts)

    cov = block_inv(iteration_SIG_e_dft)
    if kwargs.get('no_dc', False):
        cov[:, :, 0, 0] = 0  # set the DC values to zero
    # np.trace reduces the first two (block) axes, leaving one value per
    # frequency; np.sum adds them up
    return np.real(np.sum(np.trace(cov)))
def primal1_update_gaussian(*, pre_primal1, SIG_inv, ATy, nu):
    """Function that updates the first primal variable of ADMM based on the
    least squares data fidelity term which inherently assumes that the
    measurements have additive Gaussian noise.

    Args:
        pre_primal1 (ndarray): intermediate ADMM variable which is typically
            W(primal2 - dual) where W is the regularization transform
        SIG_inv (ndarray): spectrum of (A^TA+nu*W^TW + ...)^{-1}
        ATy (ndarray): FA^Ty term where F is DFT matrix
        nu (float): augmented Lagrangian parameter of ADMM (step size)

    Returns:
        ndarray of updated primal1 (real part of the inverse 2D DFT)
    """
    # right-hand side of the normal equations, entirely in the DFT domain
    rhs = ATy + nu * np.fft.fft2(pre_primal1)
    return np.real(np.fft.ifft2(block_mul(SIG_inv, rhs)))
def init(psfs, **kwargs):
    """Populate ``psfs.initialized_data`` with precomputed spectra.

    The dict stored there has the keys:
        "psf_dfts": ``psfs.psf_dfts`` as-is
        "GAM": block_mul(block_herm(W), W), where W is psf_dfts with leading
            row i scaled by sqrt(psfs.copies[i])
        "LAM": get_LAM(rows, cols, order=kwargs['order']), with rows/cols
            read from ``psfs.psfs.shape[2:]``

    Args:
        psfs (PSFs): object providing ``psfs``, ``psf_dfts`` and ``copies``
        **kwargs: must contain 'order' (forwarded to get_LAM)
    """
    rows, cols = psfs.psfs.shape[2:]
    dfts = psfs.psf_dfts

    # Scale rows of psf_dfts by sqrt(copies): the copies weight is split
    # across both factors of the product to prevent rounding error.
    weighted = np.einsum('i,ijkl->ijkl', np.sqrt(psfs.copies), dfts)

    psfs.initialized_data = {
        "psf_dfts": dfts,
        "GAM": block_mul(block_herm(weighted), weighted),
        "LAM": get_LAM(rows=rows, cols=cols, order=kwargs['order']),
    }
def _group_average(template, blocks, indices):
    """Scatter-add grouped patch estimates and normalize by occurrence count.

    ``blocks[i]`` is accumulated with ``indsum`` onto the patch columns listed
    in ``indices[:, i]``. The sum is then divided by the number of times each
    patch index occurs anywhere in ``indices``.

    Args:
        template (ndarray): array whose shape and dtype the accumulator copies
        blocks (list of ndarray): one patch-estimate array per group member
        indices (ndarray): 2d integer patch indices; column i is used for
            ``blocks[i]``, and every entry is counted for the normalization

    Returns:
        ndarray of averaged patch estimates
    """
    acc = np.zeros_like(template)
    for member, block in enumerate(blocks):
        acc = indsum(acc, block, indices[:, member])
    # np.unique returns the counts sorted by patch index, so they line up with
    # the patch columns provided every index occurs at least once (the same
    # assumption the former Counter/argsort construction made).
    _, counts = np.unique(indices, return_counts=True)
    return acc / counts


def strollr(*, sources, psfs, measurements, recon_init, iternum, periter, lr,
            theta, s, lam, patch_shape, transform, learning, window_size,
            group_size, group_size_s, plot=True):
    """Function that implements PSSI deconvolution using both sparsifying
    transform learning and imposing low rankness on the grouped patches of
    the images.

    The algorithm details can be found on [arXiv:1808.01316]. For efficiency,
    the per-patch low-rank step is distributed over a multiprocessing pool.

    Args:
        sources (ndarray): array of sources (only passed to the plotter)
        psfs (PSFs): PSFs object containing psfs and related data; its
            ``psf_dfts`` and ``psf_GAM`` attributes are overwritten
        measurements (ndarray): 3d array of measurements; ``shape[1:]`` is
            taken as (rows, cols)
        recon_init (ndarray): initialization for the reconstructed image(s)
        iternum (int): number of iterations of ADMM
        periter (int): iteration period of displaying the reconstructions
        lr (float): augmented Lagrangian parameter of the low-rank term
        theta (float or ndarray): penalty parameter of the low-rank term; a
            scalar is broadcast to one value per source
        s (float): augmented Lagrangian parameter of the sparsity term
        lam (float or ndarray): penalty parameter of the sparsity term; a
            scalar is broadcast to one value per source
        patch_shape (tuple): tuple of the shape of the patches used
        transform (ndarray): ndarray of the sparsifying transform used
        learning (bool): whether the transform gets updated (truthiness is
            used, so e.g. numpy booleans work)
        window_size (tuple): tuple of the window size over which the group of
            similar patches to the centered reference patch is searched
        group_size (int): the number of patches in each group (of similar
            patches)
        group_size_s (int): the number of patches in each group (for sparsity)
        plot (bool): if True (default, the previous behavior), display the
            reconstructions every `periter` iterations and at the last one

    Returns:
        ndarray of the reconstructed images
    """
    num_sources = psfs.psfs.shape[1]
    rows, cols = measurements.shape[1:]
    psize = int(np.prod(patch_shape))
    image_shape = (num_sources, rows, cols)

    # Broadcast scalar penalties to one value per source. (The former
    # `type(x) is np.float` test raises AttributeError on NumPy >= 1.24.)
    if np.isscalar(theta):
        theta = np.ones(num_sources) * theta
    if np.isscalar(lam):
        lam = np.ones(num_sources) * lam

    recon = recon_init

    ################# pre-compute some arrays for efficiency #################
    patcher_spectrum = np.einsum('ij,kl->ijkl', np.eye(num_sources),
                                 psize * np.ones((rows, cols)))
    psfs.psf_dfts = np.repeat(
        np.fft.fft2(size_equalizer(psfs.psfs, ref_size=[rows, cols])),
        psfs.copies.astype(int), axis=0)
    psfs.psf_GAM = block_mul(block_herm(psfs.psf_dfts), psfs.psf_dfts)
    # this is reshaped FA^Ty term where F is DFT matrix
    psfdfts_h_meas = block_mul(
        block_herm(psfs.psf_dfts),
        np.fft.fft2(np.fft.fftshift(measurements, axes=(1, 2))))
    SIG_inv = block_inv(psfs.psf_GAM + (s + lr) * patcher_spectrum)

    # One worker pool for the whole run instead of a new one per iteration;
    # the context manager also shuts the workers down if an iteration raises.
    with multiprocessing.Pool() as pool:
        for it in range(iternum):
            # ----- Low-Rank Approximation -----
            patches = patch_extractor(recon, patch_shape=patch_shape)
            patch_len = patches.shape[0]
            num_patches = patches.shape[1]
            patch_means = np.mean(patches, axis=0)
            patches_zeromean = (patches - patch_means).T
            lowrank_i = functools.partial(
                lowrank, patches_zeromean=patches_zeromean,
                window_size=window_size, imsize=(rows, cols),
                threshold=theta, group_size=group_size)
            D, indices = zip(*pool.map(lowrank_i, np.arange(num_patches)))
            D = np.array(D)
            indices = np.array(indices, dtype=int)
            # add the patch means back onto the low-rank group estimates
            D += np.einsum('ik,j->ijk', patch_means[indices],
                           np.ones(patch_len))

            if s > 0:
                # ----- Sparse Coding -----
                # stack the first group_size_s members of every group
                patches_3d = np.zeros((patch_len * group_size_s, num_patches))
                for i in range(group_size_s):
                    patches_3d[i * patch_len:(i + 1) * patch_len, :] = \
                        patches[:, indices[:, i]]
                sparse_codes = transform @ patches_3d
                for i in range(num_sources):
                    src = slice(i * rows * cols, (i + 1) * rows * cols)
                    sparse_codes[:, src] = hard_thresholding(
                        sparse_codes[:, src], threshold=np.sqrt(lam[i] / s))
                if learning:
                    # The singular values go to `_`: unpacking them into `s`
                    # (as before) clobbered the sparsity parameter used in
                    # the image update and in every later iteration.
                    u, _, v_T = np.linalg.svd(sparse_codes @ patches_3d.T)
                    transform = u @ v_T

                # ----- Image Update -----
                Whb = transform.T @ sparse_codes
                Whb1 = _group_average(
                    patches,
                    [Whb[i * patch_len:(i + 1) * patch_len, :]
                     for i in range(group_size_s)],
                    indices[:, :group_size_s])
                Fc = np.fft.fft2(patch_aggregator(
                    Whb1, patch_shape=patch_shape, image_shape=image_shape))
            else:
                Fc = np.zeros_like(psfdfts_h_meas)

            if lr > 0:
                VhD = _group_average(
                    patches, [D[:, :, i].T for i in range(group_size)],
                    indices)
                Fd = np.fft.fft2(patch_aggregator(
                    VhD, patch_shape=patch_shape, image_shape=image_shape))
            else:
                Fd = np.zeros_like(psfdfts_h_meas)

            recon = np.real(np.fft.ifft2(
                block_mul(SIG_inv, s * Fc + lr * Fd + psfdfts_h_meas)))

            if plot and ((it + 1) % periter == 0 or it == iternum - 1):
                deconv_plotter(sources=sources, recons=recon, iter=it)
    return recon
def sparsepatch(*, sources, psfs, measurements, recon_init, iternum, plot=True,
                periter, nu, lam, patch_shape, transform, learning):
    """Function that implements PSSI deconvolution with a patch based
    sparsifying transform for sparse recovery.

    P1 and P3 formulations described in [doi:10.1137/141002293] have been
    implemented; the transform is additionally updated by an SVD-based step
    when `learning` is truthy.

    Args:
        sources (ndarray): array of sources (only passed to the plotter)
        psfs (PSFs): PSFs object containing psfs and related data; its
            ``psf_dfts`` and ``psf_GAM`` attributes are overwritten
        measurements (ndarray): 3d array of measurements; ``shape[1:]`` is
            taken as (rows, cols)
        recon_init (ndarray): initialization for the reconstructed image(s)
        iternum (int): number of iterations of ADMM
        plot (boolean): if set to True, display the reconstructions as the
            iterations go
        periter (int): iteration period of displaying the reconstructions
        nu (float): augmented Lagrangian parameter
        lam (float or ndarray): penalty parameter of the sparsity term; a
            scalar is broadcast to one value per source
        patch_shape (tuple): tuple of the shape of the patches used
        transform (ndarray): ndarray of the sparsifying transform used
        learning (bool): whether the transform gets updated

    Returns:
        ndarray of the reconstructed images
    """
    num_sources = psfs.psfs.shape[1]
    rows, cols = measurements.shape[1:]
    psize = int(np.prod(patch_shape))

    # Broadcast a scalar penalty to one value per source. (The former
    # `type(lam) is np.float` test raises AttributeError on NumPy >= 1.24.)
    if np.isscalar(lam):
        lam = np.ones(num_sources) * lam

    ################## initialize the reconstruction ##################
    recon = recon_init

    ################# pre-compute some arrays for efficiency #################
    psfs.psf_dfts = np.repeat(
        np.fft.fft2(size_equalizer(psfs.psfs, ref_size=[rows, cols])),
        psfs.copies.astype(int), axis=0)
    psfs.psf_GAM = block_mul(block_herm(psfs.psf_dfts), psfs.psf_dfts)
    # this is reshaped FA^Ty term where F is DFT matrix
    psfdfts_h_meas = block_mul(
        block_herm(psfs.psf_dfts),
        np.fft.fft2(np.fft.fftshift(measurements, axes=(1, 2))))
    LAM = psize * np.ones((rows, cols))
    spectrum = nu * np.einsum('ij,kl->ijkl', np.eye(num_sources), LAM)
    SIG_inv = block_inv(psfs.psf_GAM + spectrum)

    source_len = rows * cols
    for it in range(iternum):
        # ----- Sparse Coding -----
        patches = patch_extractor(recon, patch_shape=patch_shape)
        sparse_codes = transform @ patches
        for i in range(num_sources):
            src = slice(i * source_len, (i + 1) * source_len)
            sparse_codes[:, src] = hard_thresholding(
                sparse_codes[:, src], threshold=np.sqrt(lam[i] / nu))

        # ----- Image Update -----
        Fc = np.fft.fft2(
            patch_aggregator(transform.conj().T @ sparse_codes,
                             patch_shape=patch_shape,
                             image_shape=(num_sources, rows, cols)))
        recon = np.real(
            np.fft.ifft2(block_mul(SIG_inv, nu * Fc + psfdfts_h_meas)))

        if learning:
            # singular values are not needed; only the orthogonal factors
            u, _, v_T = np.linalg.svd(sparse_codes @ patches.T)
            transform = u @ v_T

        if plot and ((it + 1) % periter == 0 or it == iternum - 1):
            deconv_plotter(sources=sources, recons=recon, iter=it)
    return recon
plotter4d(sources, title='Orig', figsize=(5.6, 8), cmap='gist_heat') [k, num_sources, aa, bb] = psfs.selected_psfs.shape[:2] + sources.shape[2:] LAM = get_LAM(rows=aa, cols=bb, order=tikhonov_order) ssim = np.zeros((len(tikhonov_lam), num_instances, num_sources)) psnr = np.zeros((len(tikhonov_lam), num_instances, num_sources)) for i in range(len(tikhonov_lam)): SIG_inv = block_inv(psfs.selected_GAM + tikhonov_lam[i] * np.einsum('ij,kl', np.eye(num_sources), LAM)) for j in range(num_instances): # DFT of the kernel corresponding to (D^TD) recon = np.real( np.fft.ifft2( block_mul( SIG_inv, block_mul(psfs.selected_psf_dfts_h, np.fft.fft2(measured_noisy_instances[j]))))) ###### COMPUTE PERFORMANCE METRICS ###### mse = np.mean((sources - recon)**2, axis=(1, 2, 3)) psnr[i, j] = 20 * np.log10( np.max(sources, axis=(1, 2, 3)) / np.sqrt(mse)) for p in range(num_sources): ssim[i, j, p] = compare_ssim(sources[p, 0], recon[p, 0], data_range=np.max(recon[p, 0]) - np.min(recon[p, 0])) plotter4d(recon, cmap='gist_heat', figsize=(5.6, 8), title='Recon. SSIM={}\n Recon. PSNR={}'.format(
def admm(*, sources=None, psfs, measurements, regularizer, recon_init,
         iternum, plot=True, periter=5, nu, lam, **kwargs):
    """Function that implements PSSI deconvolution with the ADMM algorithm
    using the specified regularization method.

    Args:
        sources (ndarray): 3d array of sources
        psfs (PSFs): PSFs object containing psfs and related data; its
            ``psf_dfts`` and ``psf_GAM`` attributes are overwritten
        measurements (ndarray): 3d array of measurements
        regularizer (function): function that specifies the regularization
            type; called every iteration and must return
            (primal1, primal2, pre_primal2, dual)
        recon_init (ndarray): initialization for the reconstructed image(s)
        iternum (int): number of iterations of ADMM
        plot (boolean): if set to True, display the reconstructions as the
            iterations go
        periter (int): iteration period of displaying the reconstructions
        nu (float): augmented Lagrangian parameter (step size) of ADMM
        lam (list or scalar): regularization parameter of dimension
            num_sources; a scalar is broadcast to every source
        kwargs (dict): keyword arguments to be passed to the regularizers

    Returns:
        ndarray of reconstructed images
    """
    num_sources = psfs.psfs.shape[1]
    rows, cols = measurements.shape[1:]

    # Broadcast a scalar lam to one value per source. (The former
    # `type(lam) is np.float ...` chain raises AttributeError on
    # NumPy >= 1.24, where the np.float / np.int aliases were removed.)
    if np.isscalar(lam):
        lam = np.ones(num_sources) * lam

    ################## initialize the primal/dual variables ##################
    primal1 = recon_init
    primal2 = None
    dual = None

    ################# pre-compute some arrays for efficiency #################
    psfs.psf_dfts = np.repeat(
        np.fft.fft2(size_equalizer(psfs.psfs, ref_size=[rows, cols])),
        psfs.copies.astype(int), axis=0)
    psfs.psf_GAM = block_mul(block_herm(psfs.psf_dfts), psfs.psf_dfts)
    # this is reshaped FA^Ty term where F is DFT matrix
    psfdfts_h_meas = block_mul(
        block_herm(psfs.psf_dfts),
        np.fft.fft2(np.fft.fftshift(measurements, axes=(1, 2))))
    SIG_inv = get_SIG_inv(regularizer=regularizer, measurements=measurements,
                          psfs=psfs, nu=nu, **kwargs)

    for it in range(iternum):
        ######################### PRIMAL 1,2 UPDATES #########################
        primal1, primal2, pre_primal2, dual = regularizer(
            psfs=psfs,
            measurements=measurements,
            psfdfts_h_meas=psfdfts_h_meas,
            SIG_inv=SIG_inv,
            primal1=primal1,
            primal2=primal2,
            dual=dual,
            nu=nu,
            lam=lam,
            **kwargs)

        ########################### DUAL UPDATE ###########################
        dual += (pre_primal2 - primal2)

        if plot and ((it + 1) % periter == 0 or it == iternum - 1):
            deconv_plotter(sources=sources, recons=primal1, iter=it)

    return primal1