def tikhonov(*, psfs, measurements, tikhonov_lam=0.129, tikhonov_order=1):
    """Perform Tikhonov regularization based image reconstruction for PSSI.

    Solves x_hat = argmin_{x} { ||Ax-y||_2^2 + lam * ||Dx||_2^2 } in closed
    form in the 2D DFT domain, where D is the discrete derivative operator of
    order `tikhonov_order`.

    Args:
        psfs (PSFs): PSFs object containing psfs and other csbs state data.
            Axis 1 of ``psfs.psfs`` indexes sources; each entry along axis 0 is
            repeated ``psfs.copies`` times. Side effect: ``psfs.psf_GAM`` is
            overwritten with the block product A^H A.
        measurements (ndarray): 3d array of (noisy) measurements,
            shape (num_measurements, rows, cols)
        tikhonov_lam (float): regularization parameter of tikhonov
        tikhonov_order (int): [0, 1 or 2] order of the discrete derivative
            operator used in the tikhonov regularization

    Returns:
        ndarray: real part of the reconstructed images, presumably of shape
        (num_sources, rows, cols) — depends on the block helpers' conventions
    """
    num_sources = psfs.psfs.shape[1]
    rows, cols = measurements.shape[1:]

    # DFTs of the PSFs resized to the measurement size, with each row of the
    # PSF array repeated `copies` times along axis 0
    psf_dfts = np.fft.fft2(size_equalizer(psfs.psfs, ref_size=[rows, cols]))
    exp_psf_dfts = np.repeat(psf_dfts, psfs.copies.astype(int), axis=0)
    exp_psf_dfts_h = block_herm(exp_psf_dfts)
    psfs.psf_GAM = block_mul(exp_psf_dfts_h, exp_psf_dfts)

    # DFT of the kernel corresponding to (D^T D), placed on the block diagonal
    # (outer product with the num_sources x num_sources identity)
    LAM = get_LAM(rows=rows, cols=cols, order=tikhonov_order)
    regularization = tikhonov_lam * np.einsum('ij,kl->ijkl', np.eye(num_sources), LAM)

    # A^H y in the DFT domain; measurements are fftshifted over their spatial axes first
    meas_dfts = np.fft.fft2(np.fft.fftshift(measurements, axes=(1, 2)))
    rhs = block_mul(exp_psf_dfts_h, meas_dfts)

    # x = (A^H A + lam D^T D)^{-1} A^H y, solved per frequency
    recon_dfts = block_mul(block_inv(psfs.psf_GAM + regularization), rhs)
    return np.real(np.fft.ifft2(recon_dfts))
def test_size_equalizer(self):
    """size_equalizer zero-pads/crops around the centre and vectorizes over leading axes."""
    # growing a 2x2 array to 4x4 surrounds it with a zero border
    padded = size_equalizer(self.simple_array, (4, 4))
    expected_padded = np.array((
        (0, 0, 0, 0),
        (0, 1, 2, 0),
        (0, 3, 4, 0),
        (0, 0, 0, 0),
    ))
    self.assertEqualNp(padded, expected_padded)

    # shrinking back recovers the original array
    self.assertEqualNp(size_equalizer(padded, (2, 2)), self.simple_array)

    # an odd-sized array cropped to an even size still gets the requested shape
    odd = np.random.random((5, 5))
    self.assertEqual(size_equalizer(odd, (4, 4)).shape, (4, 4))

    # a stack of 2d arrays is resized along its last two axes only
    base = np.random.random((3, 3))
    stack = np.repeat(base[np.newaxis, :, :], 4, axis=0)
    self.assertEqual(size_equalizer(stack, (2, 2)).shape, (4, 2, 2))
def default_adjoint(x, psfs):
    """Map measurement-domain ``x`` back to the sources through the PSFs, then apply ``radon_forward``.

    For every source, the DFT of ``x`` is multiplied by the matching PSF DFTs
    and summed over the measurement axis. The result is inverse transformed and
    fftshifted over the spatial axes before being passed to ``radon_forward``.

    NOTE(review): the PSF DFTs are used without complex conjugation; confirm
    this matches the intended adjoint for the PSFs in use.

    Args:
        x (ndarray): 3d array, shape (num_measurements, rows, cols), where
            num_measurements equals the total of ``psfs.copies``
        psfs (PSFs): object providing ``psfs`` (4d: measurement, source, h, w)
            and ``copies``

    Returns:
        whatever ``radon_forward`` returns for the complex
        (num_sources, rows, cols) image stack
    """
    _, rows, cols = x.shape

    # PSFs resized to the input size, each measurement's PSFs repeated
    # `copies` times, then reordered to (source, measurement, rows, cols)
    expanded_psfs = np.repeat(
        size_equalizer(psfs.psfs, ref_size=[rows, cols]),
        psfs.copies.astype(int),
        axis=0,
    )
    expanded_psf_dfts = np.fft.fft2(expanded_psfs).transpose((1, 0, 2, 3))

    # per-frequency sum over measurements: (S, M, r, c) x (M, r, c) -> (S, r, c)
    source_dfts = np.einsum('ijkl,jkl->ikl', expanded_psf_dfts, np.fft.fft2(x))
    im = np.fft.fftshift(np.fft.ifft2(source_dfts), axes=(1, 2))
    return radon_forward(im)
def motion_deblur(*, sv, coadded, drift, plot=True):
    """Deblur a coadded image that was smeared by in-frame drift.

    A line-shaped motion-blur kernel is drawn from ``drift`` on a 1 um grid and
    resampled to the CCD pixel size. It is convolved with a copy of the first
    PSF of ``sv.psfs`` to form an "effective blurring kernel". ``coadded`` is
    then deconvolved with Tikhonov regularization, and that result initializes
    a plug-and-play ADMM reconstruction with BM3D.

    Args:
        sv: object providing ``pixel_size`` (converted to um via ``* 1e6``,
            so presumably metres) and ``psfs`` (deep-copied, not modified)
        coadded (ndarray): 2d coadded image. It is normalized to a maximum of
            1 on a copy; the caller's array is left unchanged.
        drift (sequence): drift in CCD pixels. ``drift[0]`` sets the
            horizontal (column) extent and ``drift[1]`` the vertical (row)
            extent of the blur line. NOTE(review): ``shift_and_sum`` pads rows
            with ``drift[0]``; confirm both use the same axis convention.
        plot (bool): if True (default), display the Tikhonov and final
            reconstructions with matplotlib (``plt.show()`` may block)

    Returns:
        ndarray: reconstruction returned by ``admm``
    """
    pixel_size_um = sv.pixel_size * 1e6
    # width of the final motion blur kernel with CCD pixel size
    kernel_size = 11

    # blur extent in um along columns (x) and rows (y)
    x = pixel_size_um * drift[0]
    y = pixel_size_um * drift[1]
    N = int(np.ceil(np.max((abs(x), abs(y)))))

    # draw the blur line centred in a (2N+1) x (2N+1) kernel with 1 um pixels
    kernel_um = np.zeros((2 * N + 1, 2 * N + 1))
    half_x = np.round(x / 2).astype(int)
    half_y = np.round(y / 2).astype(int)
    rr, cc, val = line_aa(N + half_y, N - half_x, N - half_y, N + half_x)
    kernel_um[rr, cc] = val

    # resize the initial 1 um kernel to the given pixel size
    kernel = resize(
        size_equalizer(kernel_um, [int(pixel_size_um) * kernel_size] * 2),
        [kernel_size] * 2,
        anti_aliasing=True,
    )

    # effective blurring kernel: a copy of the first PSF convolved with the
    # motion blur kernel, renormalized to unit sum
    psfs = copy.deepcopy(sv.psfs)
    psfs.psfs[0, 0] = convolve2d(psfs.psfs[0, 0], kernel, mode='same')
    psfs.psfs[0, 0] /= psfs.psfs[0, 0].sum()

    # normalize the coadded image out of place so the caller's array is not
    # modified (doesn't change anything but helps choosing regularization
    # parameter consistently)
    coadded = coadded / coadded.max()
    measurements = coadded[np.newaxis, :, :]

    # tikhonov regularized deblurring to remove the in-frame blur
    recon_tik = tikhonov(measurements=measurements, psfs=psfs,
                         tikhonov_lam=1e1, tikhonov_order=1)
    if plot:
        plt.figure()
        plt.imshow(recon_tik[0], cmap='gist_heat')
        plt.title('Deblurred Tikhonov')
        plt.show()

    # Plug and Play with BM3D, initialized with the tikhonov reconstruction
    recon = admm(measurements=measurements, psfs=psfs,
                 regularizer=partial(bm3d_pnp), recon_init=recon_tik,
                 plot=False, iternum=5, periter=1, nu=10**-0.0,
                 lam=[10**-0.5])
    if plot:
        plt.figure()
        plt.imshow(recon[0], cmap='gist_heat')
        plt.title('Deblurred')
        plt.show()

    return recon
def shift_and_sum(frames, drift, mode='full', shift_method='roll'):
    """Coadd frames by shifting each one according to a constant drift.

    Args:
        frames (ndarray): 3d stack of input frames, shape (num_frames, rows, cols)
        drift (ndarray): (row, col) drift between adjacent frames, in pixels
        mode (str): 'full' keeps the whole zero-padded canvas, 'crop' crops to
            the region of frame overlap, 'first' crops to the region of the
            first frame, 'center' crops the canvas centre to the frame size
        shift_method (str): method for shifting frames ('roll', 'fourier')

    Returns:
        (ndarray): real part of the coadded image. Each pixel is the sum of the
        shifted frames divided by the number of frames covering it; pixels
        covered by no frame are 0.

    Raises:
        TypeError: if ``drift`` is not an ndarray
        Exception: if ``mode`` or ``shift_method`` is invalid
    """
    # explicit check instead of `assert`, which is stripped under `python -O`
    if not isinstance(drift, np.ndarray):
        raise TypeError("'drift' should be ndarray")
    # validate up front rather than after all frames have been shifted
    if shift_method not in ('roll', 'fourier'):
        raise Exception('Invalid shift_method')
    if mode not in ('full', 'crop', 'first', 'center'):
        raise Exception('Invalid mode')

    # zero padding big enough for the total drift across the stack; it goes
    # after the frame for positive drift and before it otherwise
    pad = np.ceil(drift * (len(frames) - 1)).astype(int)
    pad_r = (0, pad[0]) if drift[0] > 0 else (-pad[0], 0)
    pad_c = (0, pad[1]) if drift[1] > 0 else (-pad[1], 0)
    pad_width = ((0, 0), pad_r, pad_c)

    # frames_ones is shifted alongside the frames to count per-pixel coverage
    frames_ones = np.pad(np.ones(frames.shape, dtype=int), pad_width, mode='constant')
    frames_pad = np.pad(frames, pad_width, mode='constant')

    summation = np.zeros(frames_pad[0].shape, dtype='complex128')
    # float rather than int so 'fourier' (non-integer) coverage can be accumulated
    summation_scale = np.zeros(frames_pad[0].shape, dtype=float)

    for time_diff, (frame, frame_ones) in tqdm(enumerate(zip(frames_pad, frames_ones))):
        shift = np.array(drift) * (time_diff + 1)
        if shift_method == 'roll':
            integer_shift = np.floor(shift).astype(int)
            shifted = roll(frame, (integer_shift[0], integer_shift[1]))
            shifted_ones = roll(frame_ones, (integer_shift[0], integer_shift[1]))
        else:  # 'fourier'
            shifted = np.fft.ifftn(
                fourier_shift(np.fft.fftn(frame), (shift[0], shift[1])))
            shifted_ones = np.fft.ifftn(
                fourier_shift(np.fft.fftn(frame_ones), (shift[0], shift[1]))).real
        summation += shifted
        summation_scale += shifted_ones

    # guard uncovered pixels BEFORE dividing; otherwise they become 0/0 = nan
    summation_scale[summation_scale == 0] = 1
    summation /= summation_scale

    if mode == 'crop':
        # overlap region: canvas minus the total drift on each axis (abs so
        # negative drift shrinks the canvas too)
        summation = size_equalizer(
            summation,
            np.array(frames_pad[0].shape).astype(int) - 2 * np.abs(pad))
    elif mode == 'first':
        summation = summation[:frames.shape[1], :frames.shape[2]]
    elif mode == 'center':
        summation = size_equalizer(summation, frames.shape[1:])
    # mode == 'full': keep the whole padded canvas

    return summation.real
from mas.forward_model import get_measurements, add_noise, size_equalizer from mas.psf_generator import PSFs, PhotonSieve, circ_incoherent_psf from mas.deconvolution import tikhonov, admm from mas.deconvolution.admm import bm3d_pnp from mas.plotting import plotter4d from mas.data import strands_ext from functools import partial from skimage.measure import compare_ssim as ssim # %% meas ------------------------ source_wavelengths = np.array([33.4e-9, 33.5e-9]) num_sources = len(source_wavelengths) meas_size = [160,160] sources = strands_ext[0:num_sources] sources_ref = size_equalizer(sources, ref_size=meas_size) ps = PhotonSieve(diameter=16e-2, smallest_hole_diameter=7e-6) # generate psfs psfs = PSFs( ps, sampling_interval=3.5e-6, measurement_wavelengths=source_wavelengths, source_wavelengths=source_wavelengths, psf_generator=circ_incoherent_psf, # image_width=psf_width, num_copies=1 ) # ############## INVERSE CRIME ############## # measured = get_measurements(
from imageio import imread import numpy as np import matplotlib.pyplot as plt from scipy.misc import face from scipy.ndimage import rotate from mas.forward_model import downsample, upsample, size_equalizer offset = (500, 500) x = face(gray=True) # roi = upsample(downsample(x[500:600, 500:600], factor=5), factor=5) crop_roi = x[:-1, :-1] roll_roi = np.roll(np.roll(x, -offset[0], axis=0), -offset[1], axis=1) x_fft = np.fft.fft2(x) # crop_fft = size_equalizer(np.fft.fft2(crop_roi), x.shape) crop_fft = np.fft.fft2(size_equalizer(crop_roi, x.shape)) roll_fft = np.fft.fft2(roll_roi) crop_csd = ( np.multiply(x_fft, np.conj(crop_fft)) / np.abs(np.multiply(x_fft, np.conj(crop_fft))) ) crop_phase = np.abs(np.fft.ifft2(crop_csd)) roll_csd = ( np.multiply(x_fft, np.conj(roll_fft)) / np.abs(np.multiply(x_fft, np.conj(roll_fft))) ) roll_phase = np.abs(np.fft.ifft2(roll_csd)) def zoom_plot(im, offset, width):
import numpy as np
import imageio
import pickle
import h5py
# used below but not imported in this chunk
from mas.forward_model import size_equalizer

num_instances = 1
tikhonov_order = 1
tikhonov_lam = 1e-2
# np.int / np.float were removed in NumPy 1.24; test the builtin types instead
if isinstance(tikhonov_lam, (int, float)):
    tikhonov_lam = [tikhonov_lam]

psf_width = 201
# source_wavelengths = np.array([9.4e-9])
source_wavelengths = np.array([33.4e-9, 33.5e-9])
num_sources = len(source_wavelengths)

# read both frames through one handle that is closed afterwards (previously
# the file was opened twice and never closed); frames are resized to 160x160
movie_path = '/home/kamo/Research/mas/nanoflare_videos/NanoMovie0_2000strands_94.h5'
with h5py.File(movie_path, 'r') as movie_file:
    movie = movie_file['NanoMovie0_2000strands_94']
    source1 = size_equalizer(np.array(movie[0]), (160, 160))
    source2 = size_equalizer(np.array(movie[250]), (160, 160))

# source1 = rectangle_adder(image=np.zeros((100,100)), size=(30,30), upperleft=(35,10))
# source2 = 10 * rectangle_adder(image=np.zeros((100,100)), size=(30,30), upperleft=(35,60))
# source1 = readsav('/home/kamo/Research/mas/nanoflare_videos/old/movie0_1250strands_335.sav',python_dict=True)['movie'][500]
# source2 = readsav('/home/kamo/Research/mas/nanoflare_videos/old/movie0_1250strands_94.sav',python_dict=True)['movie'][500]

[aa, bb] = source1.shape
meas_size = tuple(np.array([aa, bb]) - 0)
sources = np.zeros((len(source_wavelengths), 1, aa, bb))
sources[0, 0] = source1 / source1.max()
def admm(*, sources=None, psfs, measurements, regularizer, recon_init,
         iternum, plot=True, periter=5, nu, lam, **kwargs):
    """Function that implements PSSI deconvolution with the ADMM algorithm using
    the specified regularization method.

    Args:
        sources (ndarray): 3d array of sources (only passed to the plotter)
        psfs (PSFs): PSFs object containing psfs and related data.
            Side effect: ``psfs.psf_dfts`` and ``psfs.psf_GAM`` are overwritten.
        measurements (ndarray): 3d array of measurements
        regularizer (function): function that performs the primal updates for
            the chosen regularization; it is called with the keyword arguments
            below plus ``**kwargs`` and must return
            ``(primal1, primal2, pre_primal2, dual)``
        recon_init (ndarray): initialization for the reconstructed image(s)
        iternum (int): number of iterations of ADMM
        plot (boolean): if True, display the reconstructions as the iterations go
        periter (int): iteration period of displaying the reconstructions
        nu (float): augmented Lagrangian parameter (step size) of ADMM
        lam (float or list): regularization parameter(s), one per source; a
            scalar is broadcast to all num_sources entries
        kwargs (dict): keyword arguments to be passed to the regularizers
    Returns:
        ndarray of reconstructed images
    """
    num_sources = psfs.psfs.shape[1]
    rows, cols = measurements.shape[1:]
    # broadcast a scalar lam to every source. np.float / np.int were removed
    # in NumPy 1.24, so check against the builtin and NumPy scalar bases.
    if isinstance(lam, (int, float, np.integer, np.floating)):
        lam = np.ones(num_sources) * lam

    ################## initialize the primal/dual variables ##################
    primal1 = recon_init
    primal2 = None
    dual = None

    ################# pre-compute some arrays for efficiency #################
    psfs.psf_dfts = np.repeat(
        np.fft.fft2(size_equalizer(psfs.psfs, ref_size=[rows, cols])),
        psfs.copies.astype(int),
        axis=0,
    )
    psfs.psf_GAM = block_mul(block_herm(psfs.psf_dfts), psfs.psf_dfts)
    # this is reshaped FA^Ty term where F is DFT matrix
    psfdfts_h_meas = block_mul(
        block_herm(psfs.psf_dfts),
        np.fft.fft2(np.fft.fftshift(measurements, axes=(1, 2))),
    )
    SIG_inv = get_SIG_inv(regularizer=regularizer, measurements=measurements,
                          psfs=psfs, nu=nu, **kwargs)

    for it in range(iternum):
        ######################### PRIMAL 1,2 UPDATES #########################
        primal1, primal2, pre_primal2, dual = regularizer(
            psfs=psfs,
            measurements=measurements,
            psfdfts_h_meas=psfdfts_h_meas,
            SIG_inv=SIG_inv,
            primal1=primal1,
            primal2=primal2,
            dual=dual,
            nu=nu,
            lam=lam,
            **kwargs
        )

        ########################### DUAL UPDATE ###########################
        dual += (pre_primal2 - primal2)

        if plot and ((it + 1) % periter == 0 or it == iternum - 1):
            deconv_plotter(sources=sources, recons=primal1, iter=it)

    return primal1
def __init__(self, sieve, *, source_wavelengths=np.array([33.4, 33.5]) * 1e-9,
             measurement_wavelengths=30, image_width=1001, cropped_width=None,
             energy_ratio=0.9995, num_copies=1,
             psf_generator=circ_incoherent_psf, sampling_interval=3.5e-6,
             grid_width=10, zero_mean=False):
    """Generate the PSFs of ``sieve`` for every (measurement, source) wavelength pair.

    Args:
        sieve: object with ``diameter`` and ``smallest_hole_diameter``
            attributes (e.g. a PhotonSieve)
        source_wavelengths (ndarray): wavelengths of the sources
        measurement_wavelengths (int or array-like): explicit measurement
            wavelengths, or an integer count. A count is spaced linearly over
            the range given by the focal lengths +/- ``grid_width`` depths of
            field, and the source wavelengths are then inserted in sorted order.
        image_width (int): width of each generated PSF before cropping
        cropped_width (int): final PSF width; if None, the width is the larger
            ``size_compressor`` result of psfs[0, -1] and psfs[-1, 0]
        energy_ratio (float): passed to ``size_compressor`` when cropped_width is None
        num_copies (int): copies of each measurement wavelength (stored in ``self.copies``)
        psf_generator (callable): called once per wavelength pair to produce a PSF
        sampling_interval (float): passed to ``psf_generator``
        grid_width (float): number of depths of field the wavelength range is widened by
        zero_mean (bool): subtract each PSF's mean after cropping
    """
    focal_lengths = sieve.diameter * sieve.smallest_hole_diameter / source_wavelengths
    dofs = 2 * sieve.smallest_hole_diameter**2 / source_wavelengths

    # an integer (builtin or NumPy, but not bool) is a number of wavelengths to generate
    is_count = (isinstance(measurement_wavelengths, (int, np.integer))
                and not isinstance(measurement_wavelengths, bool))
    if is_count:
        approx_start = sieve.diameter * sieve.smallest_hole_diameter / (
            max(focal_lengths) + grid_width * max(dofs))
        approx_end = sieve.diameter * sieve.smallest_hole_diameter / (
            min(focal_lengths) - grid_width * min(dofs))
        measurement_wavelengths = np.linspace(approx_start, approx_end,
                                              measurement_wavelengths)
        measurement_wavelengths = np.insert(
            measurement_wavelengths,
            np.searchsorted(measurement_wavelengths, source_wavelengths),
            source_wavelengths)

    num_meas = len(measurement_wavelengths)
    num_src = len(source_wavelengths)

    # generate incoherent PSFs for each (measurement, source) wavelength pair;
    # collect into lists (np.append in the loop copied the whole array every
    # time) and let the context manager close the bar even if generation fails
    psf_rows = []
    with tqdm(total=num_meas * num_src, desc='PSFs', leave=None, position=1) as bar:
        for measurement_wavelength in measurement_wavelengths:
            psf_group = []
            for source_wavelength in source_wavelengths:
                bar.update(1)
                psf_group.append(psf_generator(
                    sieve=sieve,
                    source_wavelength=float(source_wavelength),
                    measurement_wavelength=measurement_wavelength,
                    image_width=image_width,
                    source_distance=float('inf'),
                    sampling_interval=float(sampling_interval)))
            psf_rows.append(psf_group)
    # reshape also covers the empty cases; promote to at least float64, as
    # appending onto float64 buffers did before
    psfs = np.asarray(psf_rows).reshape(num_meas, num_src, image_width, image_width)
    psfs = psfs.astype(np.promote_types(psfs.dtype, np.float64), copy=False)

    if cropped_width is not None:
        psfs = size_equalizer(psfs, [cropped_width, cropped_width])
    else:
        width0 = size_compressor(psfs[0, -1], energy_ratio=energy_ratio)
        width1 = size_compressor(psfs[-1, 0], energy_ratio=energy_ratio)
        width = max(width0, width1)
        psfs = size_equalizer(psfs, [width, width])

    if zero_mean:
        psfs -= psfs.mean(axis=(2, 3))[:, :, np.newaxis, np.newaxis]

    self.psfs = psfs
    self.psf_dfts = np.fft.fft2(psfs)
    self.num_copies = num_copies
    self.copies = np.ones((len(measurement_wavelengths))) * num_copies
    self.measurement_wavelengths = measurement_wavelengths
    self.source_wavelengths = source_wavelengths
    self.sampling_interval = sampling_interval
    self.cropped_width = cropped_width
    self.copies_history = []