def ft1D(x, y, *, axis=0, zero_DC=False):
    """Fourier transform y(x) -> f(k) along a single axis.

    Handles the fftshift / ifftshift bookkeeping and scales the result by
    the grid spacing dx so it approximates the continuous transform.

    Args:
        x (np.ndarray): 1D, uniformly spaced independent variable.
        y (np.ndarray): dependent variable sampled on x; may be n-dimensional.

    Kwargs:
        axis (int): axis of y along which the FFT is taken.
        zero_DC (bool): if True, the k = 0 component of f is set to 0.

    Returns:
        tuple: (k, f) -- angular frequencies in ascending (fftshift) order and
        the transformed data, also in fftshift order along ``axis``.
    """
    spacing = x[1] - x[0]
    k = fftshift(fftfreq(x.size, d=spacing)) * 2 * np.pi

    # When x, after undoing the centring shift, starts at 0 the samples are
    # centred on the origin, so y is un-shifted before transforming.
    centred_on_origin = np.isclose(ifftshift(x)[0], 0)
    y_for_fft = ifftshift(y, axes=axis) if centred_on_origin else y
    f = fft(y_for_fft, axis=axis) * spacing

    if zero_DC:
        # Index 0 of an unshifted FFT is the DC (k = 0) bin.
        dc_index = [slice(None)] * f.ndim
        dc_index[axis] = slice(0, 1, 1)
        f[tuple(dc_index)] = 0

    return k, fftshift(f, axes=axis)
def polarization_to_signal(self, P_of_t_in, *, return_polarization=False,
                           local_oscillator_number=-1):
    """Generate a frequency-resolved signal from a polarization field.

    Args:
        P_of_t_in (np.ndarray): polarization sampled on ``self.t``. The
            caller's array is not modified.
        return_polarization (bool): if True, return the time-domain
            polarization (after the optional ``self.gamma`` damping) instead
            of the signal.
        local_oscillator_number: index into ``self.pulse_times`` and
            ``self.efields`` selecting the local oscillator; usually the last
            pulse in the list self.efields (-1).

    Returns:
        np.ndarray: ``Im(P(w) * conj(E_LO(w)))``. If the polarization has an
        odd number of samples, its last sample is dropped first.
    """
    pulse_time = self.pulse_times[local_oscillator_number]
    P_of_t = P_of_t_in
    if self.gamma != 0:
        # Out-of-place multiply: ``*=`` would damp the caller's array in place
        # (and fails outright for integer-typed input).
        P_of_t = P_of_t_in * np.exp(-self.gamma * (self.t - pulse_time))
    if return_polarization:
        return P_of_t

    dt = self.t[1] - self.t[0]
    if local_oscillator_number == 'impulsive':
        # NOTE(review): reachable only if self.pulse_times accepts the key
        # 'impulsive' (it is indexed above) -- confirm.
        efield = np.exp(1j * self.w * (pulse_time))
    else:
        # Place the LO samples on the full time grid, centred on its pulse time.
        pulse_time_ind = np.argmin(np.abs(self.t - pulse_time))
        pulse_start_ind = pulse_time_ind - self.size // 2
        pulse_end_ind = pulse_time_ind + self.size // 2 + self.size % 2
        efield = np.zeros(self.t.size, dtype='complex')
        efield[pulse_start_ind:pulse_end_ind] = self.efields[local_oscillator_number]
        efield = fftshift(ifft(ifftshift(efield))) * len(P_of_t) * dt / np.sqrt(2 * np.pi)

    if P_of_t.size % 2:
        # Drop the last sample so the transform length is even.
        P_of_t = P_of_t[:-1]
        efield = efield[:len(P_of_t)]

    P_of_w = fftshift(ifft(ifftshift(P_of_t))) * len(P_of_t) * dt / np.sqrt(2 * np.pi)
    return np.imag(P_of_w * np.conjugate(efield))
def ift1D(k, f, *, axis=0, zero_DC=False):
    """Inverse Fourier transform f(k) -> y(x) along a single axis.

    Handles the fftshift / ifftshift bookkeeping and scales the result by
    ``dk * len(k) / (2 pi)`` so it approximates the continuous inverse
    transform.

    Args:
        k (np.ndarray): 1D, uniformly spaced angular frequencies.
        f (np.ndarray): data sampled on k; may be n-dimensional.

    Kwargs:
        axis (int): axis of f along which the inverse FFT is taken.
        zero_DC (bool): if True, the sample at index 0 of the unshifted
            inverse transform (x = 0) is set to 0.

    Returns:
        tuple: (x, y) -- positions in ascending (fftshift) order and the
        inverse-transformed data, also in fftshift order along ``axis``.
    """
    step = k[1] - k[0]
    x = fftshift(fftfreq(k.size, d=step)) * 2 * np.pi
    scale = step * k.size / (2 * np.pi)

    # If k (after undoing the centring shift) starts at 0, f is centred on
    # k = 0 and must be un-shifted before the inverse FFT.
    if np.isclose(ifftshift(k)[0], 0):
        spectrum = ifftshift(f, axes=axis)
    else:
        spectrum = f
    y = ifft(spectrum, axis=axis) * scale

    if zero_DC:
        selector = [slice(None)] * y.ndim
        selector[axis] = slice(0, 1, 1)
        y[tuple(selector)] = 0

    return x, fftshift(y, axes=axis)
def polarization_to_signal(self, P_of_t_in, *, return_polarization=False,
                           local_oscillator_number=-1, undersample_factor=1):
    """Generate a frequency-resolved signal from a polarization field.

    Args:
        P_of_t_in (np.ndarray): polarization sampled on ``self.t``. The
            caller's array is not modified.
        return_polarization (bool): if True, return the undersampled,
            damped time-domain polarization instead of the signal.
        local_oscillator_number: index into ``self.pulse_times`` and
            ``self.efields`` selecting the local oscillator; usually the last
            pulse in the list self.efields (-1).
        undersample_factor (int): keep every n-th time point of the
            polarization. The LO spectrum is cropped to the matching central
            frequency window.

    Returns:
        np.ndarray: ``Im(P(w) * conj(E_LO(w)))``.
    """
    undersample_slice = slice(None, None, undersample_factor)
    P_of_t = P_of_t_in[undersample_slice]
    t = self.t[undersample_slice]
    dt = t[1] - t[0]
    pulse_time = self.pulse_times[local_oscillator_number]
    # Out-of-place multiplies: P_of_t is a view of the caller's array, so
    # ``*=`` would silently overwrite the input.
    if self.gamma != 0:
        P_of_t = P_of_t * np.exp(-self.gamma * np.abs(t - pulse_time))
    if self.sigma_I != 0:
        P_of_t = P_of_t * np.exp(-(t - pulse_time)**2 * self.sigma_I**2 / 2)
    if return_polarization:
        return P_of_t

    # The LO is built on the full-resolution grid self.t.
    pulse_time_ind = np.argmin(np.abs(self.t - pulse_time))
    efield = np.zeros(self.t.size, dtype='complex')
    if self.efield_t.size == 1:
        # Impulsive limit
        efield[pulse_time_ind] = self.efields[local_oscillator_number]
        efield = fftshift(ifft(ifftshift(efield))) * efield.size / np.sqrt(2 * np.pi)
    else:
        pulse_start_ind = pulse_time_ind - self.size // 2
        pulse_end_ind = pulse_time_ind + self.size // 2 + self.size % 2
        efield[pulse_start_ind:pulse_end_ind] = self.efields[local_oscillator_number]
        efield = fftshift(ifft(ifftshift(efield))) * self.t.size * (
            self.t[1] - self.t[0]) / np.sqrt(2 * np.pi)

    # Keep only the central frequency window that the undersampled grid spans.
    halfway = self.w.size // 2
    pm = self.w.size // (2 * undersample_factor)
    efield = efield[halfway - pm:halfway + pm + self.w.size % 2]

    P_of_w = fftshift(ifft(ifftshift(P_of_t))) * len(P_of_t) * dt / np.sqrt(2 * np.pi)
    return np.imag(P_of_w * np.conjugate(efield))
def add_gaussian_linewidth(self, sigma):
    """Apply a Gaussian envelope of width ``sigma`` to ``self.signal``.

    The last axis of ``self.signal`` is transformed with a centred FFT. The
    result is multiplied by ``exp(-t**2 / (2 sigma**2))`` along that axis
    (using ``self.t``) and by ``exp(-t21**2 / (2 sigma**2))`` along the first
    axis (using ``self.t21_array``), then transformed back with a centred
    inverse FFT. A copy of the unmodified signal is kept in
    ``self.old_signal``.
    """
    self.old_signal = self.signal.copy()
    last = -1
    two_sigma_sq = 2 * sigma**2

    time_domain = fftshift(fft(ifftshift(self.old_signal, axes=last), axis=last),
                           axes=last)
    envelope_t = np.exp(-self.t**2 / two_sigma_sq)
    envelope_t21 = np.exp(-self.t21_array**2 / two_sigma_sq)
    envelope = (envelope_t[np.newaxis, np.newaxis, :]
                * envelope_t21[:, np.newaxis, np.newaxis])
    time_domain = time_domain * envelope

    self.signal = fftshift(ifft(ifftshift(time_domain, axes=last), axis=last),
                           axes=last)
def stCoefs_1d(signals, filters, block_size=100, second_order=False):
    """Convolve every 1D signal with every filter via the FFT.

    Parameters
    ----------
    signals : np.ndarray
        Array of shape (..., signal_length). All leading axes are flattened
        into one sample axis; a 1D array is treated as a single sample.
    filters : np.ndarray
        Array of shape (..., signal_length) of spatially centred filters
        (they are ``ifftshift``-ed before transforming).
    block_size : int
        Currently unused; kept for API compatibility.
    second_order : bool
        Second-order coefficients are not implemented. When True, a
        zero-valued complex array of the output shape is returned.

    Returns
    -------
    np.ndarray
        Complex array of shape (# samples, # filters, signal_length).
    """
    signal_length = signals.shape[-1]
    # np.prod of an empty shape is 1, so 1D inputs work (reduce() raised
    # TypeError on an empty shape tuple).
    signal_size = int(np.prod(signals.shape[:-1]))
    filter_length = filters.shape[-1]
    filter_size = int(np.prod(filters.shape[:-1]))
    # reshape() rather than assigning ``.shape`` so the caller's arrays keep
    # their original shape.
    signals = signals.reshape(signal_size, signal_length)
    filters = filters.reshape(filter_size, filter_length)

    f_signals = np.fft.fft(signals, axis=1)
    f_filters = np.fft.fft(fft.ifftshift(filters, axes=(1,)), axis=1)
    if not second_order:
        f_conv = f_signals[:, np.newaxis, :] * f_filters[np.newaxis, :, :]
    else:
        # The shape must be a single tuple; passing the dimensions as separate
        # arguments raised TypeError.
        f_conv = np.zeros((signal_size, filter_size, signal_length))
    return fft.ifft(f_conv, axis=2)
def ft_interface(a, s, axes, norm, **kwargs):
    """Call ``ftfunc`` and apply the matching fftshift / ifftshift.

    Forward transforms shift the output; inverse transforms shift the input.
    When ``omitlast`` is set (real-input transforms), the last transformed
    axis is never shifted, and a single-axis transform is returned unshifted.

    ``omitlast``, ``forward``, ``ftfunc`` and ``fftmod`` are free names that
    are resolved outside this function.
    """
    # call fft and shift the result
    if omitlast:
        if isinstance(axes, int) or len(axes) < 2:
            # No fftshift for a single rfft axis; slicing an int ``axes``
            # (``axes[:-1]``) would raise TypeError.
            return ftfunc(a, s, axes, norm, **kwargs)
        shftax = axes[:-1]
    else:
        shftax = axes
    if forward:
        return fftmod.fftshift(ftfunc(a, s, axes, norm, **kwargs), shftax)
    else:
        return ftfunc(fftmod.ifftshift(a, shftax), s, axes, norm, **kwargs)
def fillPowerFromTemplate(self, twodPower):
    """
    Fill the power2D.powerMap with the input power array.

    Parameters
    ----------
    twodPower : object
        Template exposing ``Nx``, ``Ny``, ``pixScaleX``, ``pixScaleY``,
        ``lx``, ``ly``, ``powerMap`` and a ``copy()`` method. If its grid
        size equals this object's, its powerMap is copied directly.
        Otherwise the template power is rescaled by the pixel-count / area
        factors and interpolated (``interp2d``) onto this object's
        ``lx`` / ``ly``.
    """
    template = twodPower.copy()

    if template.Nx == self.Nx and template.Ny == self.Ny:
        self.powerMap[:] = template.powerMap[:]
        return

    # Divide out the template's area factor before interpolating.
    # NOTE(review): if copy() is shallow, this in-place rescale also changes
    # the caller's template powerMap -- confirm copy() is deep.
    template_npix = template.Nx * template.Ny
    template_area = template_npix * template.pixScaleX * template.pixScaleY
    template.powerMap *= template_npix**2 / template_area

    # interp2d wants monotonic axes, hence the fftshift of the frequency grids.
    # NOTE(review): scipy.interpolate.interp2d is removed in recent SciPy.
    interpolator = interp2d(fftshift(template.lx), fftshift(template.ly),
                            fftshift(template.powerMap))
    resampled = ifftshift(interpolator(fftshift(self.lx), fftshift(self.ly)))

    npix = self.Nx * self.Ny
    resampled *= npix * self.pixScaleX * self.pixScaleY / (npix * 1.)**2
    self.powerMap[:] = resampled[:]
def convolve(arr1, arr2, dx=None, axes=None):
    """
    Performs a centred convolution of input arrays

    Parameters
    ----------
    arr1, arr2 : `numpy.ndarray`
        Arrays to be convolved. If dimensions are not equal then 1s are
        appended to the lower dimensional array. Otherwise, arrays must be
        broadcastable. The (possibly swapped) second array is treated as the
        centred one and is ifftshift-ed before transforming.
    dx : float > 0, list of float, or `None` , optional
        Grid spacing of input arrays. A scalar is raised to the number of
        convolved axes (all of ``arr1``'s axes when ``axes`` is `None`); a
        list is multiplied together. default=`None` applies no scaling
    axes : tuple of ints or `None`, optional
        Choice of axes to convolve. default=`None` convolves all axes; if
        ``arr2`` had more dimensions than ``arr1``, it defaults to the axes
        of the lower-dimensional array.

    Returns
    -------
    numpy.ndarray
        Complex, C-contiguous, aligned result.
    """
    if arr2.ndim > arr1.ndim:
        arr1, arr2 = arr2, arr1
        if axes is None:
            axes = tuple(range(arr2.ndim))
    arr2 = arr2.reshape(arr2.shape + (1, ) * (arr1.ndim - arr2.ndim))

    if dx is None:
        dx = 1
    elif isscalar(dx):
        dx = dx**(len(axes) if axes is not None else arr1.ndim)
    else:
        dx = prod(dx)

    arr1 = fftn(arr1, axes=axes)
    # Only un-centre the convolved axes; shifting every axis would roll the
    # result along axes that are not being convolved.
    arr2 = fftn(ifftshift(arr2, axes=axes), axes=axes)
    out = ifftn(arr1 * arr2, axes=axes) * dx
    return require(out, requirements="CA")
def decomposition_fft(X, filter, **kwargs):
    """Decompose a 2d input field into multiple spatial scales by using the
    Fast Fourier Transform (FFT) and a bandpass filter.

    Parameters
    ----------
    X : array_like
        Two-dimensional array containing the input field. All values are
        required to be finite.
    filter : dict
        A filter returned by any method implemented in bandpass_filters.py.
        ``filter["weights_1d"]`` has one entry per cascade level, and
        ``filter["weights_2d"][k]`` is the 2D Fourier-space weight for level k
        (its trailing two dimensions must equal ``X.shape``).

    Other Parameters
    ----------------
    MASK : array_like
        Optional mask with the shape of X used for computing the statistics
        for the cascade levels. Pixels with MASK==False are excluded from the
        computations.

    Returns
    -------
    out : dict
        ``"cascade_levels"``: stacked array of the filtered fields (one per
        level); ``"means"`` / ``"stds"``: lists of per-level statistics.

    Notes
    -----
    ``fft`` and ``fft_kwargs`` are not defined in this function; they are
    looked up in the enclosing module.
    """
    mask = kwargs.get("MASK", None)
    weights_2d = filter["weights_2d"]

    if len(X.shape) != 2:
        raise ValueError("the input is not two-dimensional array")
    if mask is not None and mask.shape != X.shape:
        raise ValueError(
            "dimension mismatch between X and MASK: X.shape=%s, MASK.shape=%s"
            % (str(X.shape), str(mask.shape)))
    if X.shape != weights_2d.shape[1:3]:
        raise ValueError(
            "dimension mismatch between X and filter: X.shape=%s, filter['weights_2d'].shape[1:3]=%s"
            % (str(X.shape), str(weights_2d.shape[1:3])))
    if np.any(~np.isfinite(X)):
        raise ValueError("X contains non-finite values")

    spectrum = fft.fftshift(fft.fft2(X, **fft_kwargs))

    levels, means, stds = [], [], []
    for level in range(len(filter["weights_1d"])):
        band = spectrum * weights_2d[level, :, :]
        field = np.real(fft.ifft2(fft.ifftshift(band), **fft_kwargs))
        levels.append(field)
        values = field if mask is None else field[mask]
        means.append(np.mean(values))
        stds.append(np.std(values))

    return {"cascade_levels": np.stack(levels), "means": means, "stds": stds}
def generate_psf(self, sphase=slice(4, None, None), size=None, zsize=None,
                 zrange=None):
    """Make a perfect PSF from the reconstructed pupil phase.

    Works on a shallow copy of ``self.model``, so setting zsize / zrange does
    not change the stored model. If ``self.zd_result`` does not exist yet,
    ``self.fit_to_zernikes(120)`` is called first.

    Parameters
    ----------
    sphase : slice
        Passed to ``self.zd_result.complex_pupil`` to choose phase terms.
    size : int, optional
        Lateral output size. Smaller PSFs are padded with ``fft_pad``
        (mode "constant"); larger ones are centre-cropped.
    zsize, zrange : optional
        Overrides for the copied model's zsize / zrange.

    Returns
    -------
    ndarray
        ``model.PSFi`` (shape (nz, ny, nx)) after the optional resize.
    """
    working_model = copy.copy(self.model)
    if zsize is not None:
        working_model.zsize = zsize
    if zrange is not None:
        working_model.zrange = zrange

    if not hasattr(self, 'zd_result'):
        self.fit_to_zernikes(120)
    pupil = self.zd_result.complex_pupil(sphase=sphase)
    working_model._gen_psf(ifftshift(pupil))

    psf = working_model.PSFi
    nz, ny, nx = psf.shape
    assert ny == nx, "Something is very wrong"

    if size is None or size == nx:
        return psf
    if size > nx:
        return fft_pad(psf, (nz, size, size), mode="constant")
    low = size // 2
    high = size - low
    window = slice(nx // 2 - low, nx // 2 + high)
    return psf[:, window, window]
def stCoefs_1d(signals, filters, block_size=100, second_order=False):
    """Convolve every 1D signal with every filter via the FFT.

    Parameters
    ----------
    signals : np.ndarray
        Array of shape (..., signal_length). All leading axes are flattened
        into one sample axis; a 1D array is treated as a single sample.
    filters : np.ndarray
        Array of shape (..., signal_length) of spatially centred filters
        (they are ``ifftshift``-ed before transforming).
    block_size : int
        Currently unused; kept for API compatibility.
    second_order : bool
        Second-order coefficients are not implemented. When True, a
        zero-valued complex array of the output shape is returned.

    Returns
    -------
    np.ndarray
        Complex array of shape (# samples, # filters, signal_length).
    """
    signal_length = signals.shape[-1]
    # np.prod of an empty shape is 1, so 1D inputs work (reduce() raised
    # TypeError on an empty shape tuple).
    signal_size = int(np.prod(signals.shape[:-1]))
    filter_length = filters.shape[-1]
    filter_size = int(np.prod(filters.shape[:-1]))
    # reshape() rather than assigning ``.shape`` so the caller's arrays keep
    # their original shape.
    signals = signals.reshape(signal_size, signal_length)
    filters = filters.reshape(filter_size, filter_length)

    f_signals = np.fft.fft(signals, axis=1)
    f_filters = np.fft.fft(fft.ifftshift(filters, axes=(1, )), axis=1)
    if not second_order:
        f_conv = f_signals[:, np.newaxis, :] * f_filters[np.newaxis, :, :]
    else:
        # The shape must be a single tuple; passing the dimensions as separate
        # arguments raised TypeError.
        f_conv = np.zeros((signal_size, filter_size, signal_length))
    return fft.ifft(f_conv, axis=2)
def check_efield_resolution(self, efield, *, plot_fields=False):
    """Warn if ``efield`` is not well resolved in time or frequency.

    Two heuristics, each emitting ``warnings.warn``:

    * time domain: the first/last samples exceed 1% of the peak magnitude;
    * frequency domain: the edge samples of the centred FFT (scaled by
      ``self.dt``) exceed 1% of its peak magnitude.

    Args:
        efield (np.ndarray): 1D field samples (plotted against
            ``self.efield_t``).
        plot_fields (bool): if True, plot the real and imaginary parts in both
            domains against ``self.efield_t`` and ``self.efield_w``.
    """
    efield_tail = np.max(np.abs([efield[0], efield[-1]]))
    if efield_tail > np.max(np.abs(efield)) / 100:
        warnings.warn(
            'Consider using larger time interval, pulse does not decay to less than 1% of maximum value in time domain'
        )

    efield_fft = fftshift(fft(ifftshift(efield))) * self.dt
    efield_fft_tail = np.max(np.abs([efield_fft[0], efield_fft[-1]]))
    if efield_fft_tail > np.max(np.abs(efield_fft)) / 100:
        warnings.warn(
            '''Consider using smaller value of dt, pulse does not decay to less than 1% of maximum value in frequency domain'''
        )

    if plot_fields:
        fig, axes = plt.subplots(1, 2)
        l1, l2, = axes[0].plot(self.efield_t, np.real(efield),
                               self.efield_t, np.imag(efield))
        # Put the legend on the axes that owns l1/l2; plt.legend() would
        # target the current axes, i.e. the last subplot created.
        axes[0].legend([l1, l2], ['Real', 'Imag'])
        axes[1].plot(self.efield_w, np.real(efield_fft),
                     self.efield_w, np.imag(efield_fft))
        axes[0].set_ylabel('Electric field Amp')
        # Raw strings: '\o' is an invalid escape sequence (SyntaxWarning);
        # the label text is unchanged.
        axes[0].set_xlabel(r'Time ($\omega_0^{-1})$')
        axes[1].set_xlabel(r'Frequency ($\omega_0$)')
        fig.suptitle(
            'Check that efield is well-resolved in time and frequency')
        plt.show()
def retrieve_phase_far_field(src_fname, save_path, output_fname=None,
                             pad_length=256, n_epoch=100, learning_rate=0.001):
    """Reconstruct an object from a far-field intensity by gradient descent.

    The object's real and imaginary parts (two channels) are optimised with
    Adam so that ``|fft2d(obj)|`` matches ``sqrt(intensity)``. Writes the
    object magnitude, and the detector-plane phase and intensity, as TIFFs in
    ``save_path``.

    Args:
        src_fname: TIFF with the measured intensity (raw data is assumed to
            be centred at zero frequency).
        save_path: output directory.
        output_fname: basename of the reconstruction; defaults to
            ``<source stem>_recon``.
        pad_length: currently unused.
        n_epoch: number of optimiser steps.
        learning_rate: Adam learning rate.
    """
    prj_np = dxchange.read_tiff(src_fname)
    if output_fname is None:
        output_fname = os.path.basename(
            os.path.splitext(src_fname)[0]) + '_recon'

    # take modulus and inverse shift
    prj_np = ifftshift(np.sqrt(prj_np))
    obj_init = np.random.normal(50, 10, list(prj_np.shape) + [2])
    obj = tf.Variable(obj_init, dtype=tf.float32, name='obj')
    prj = tf.constant(prj_np, name='prj')

    obj_real = tf.cast(obj[:, :, 0], dtype=tf.complex64)
    obj_imag = tf.cast(obj[:, :, 1], dtype=tf.complex64)
    det = tf.fft2d(obj_real + 1j * obj_imag, name='detector_plane')
    loss = tf.reduce_mean(tf.squared_difference(tf.abs(det), prj, name='loss'))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    optimizer = optimizer.minimize(loss)

    # The context manager closes the session (releasing its resources) even
    # if an iteration raises; previously the session was never closed.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i_epoch in range(n_epoch):
            t0 = time.time()
            _, current_loss = sess.run([optimizer, loss])
            print('Iteration {}: loss = {}, Δt = {} s.'.format(
                i_epoch, current_loss, time.time() - t0))
        det_final = sess.run(det)
        obj_final = sess.run(obj)

    # Object magnitude = norm over the (real, imag) channel axis.
    res = np.linalg.norm(obj_final, 2, axis=2)
    dxchange.write_tiff(res, os.path.join(save_path, output_fname),
                        dtype='float32', overwrite=True)
    dxchange.write_tiff(fftshift(np.angle(det_final)),
                        os.path.join(save_path, 'detector_phase'),
                        dtype='float32', overwrite=True)
    dxchange.write_tiff(fftshift(np.abs(det_final)**2),
                        os.path.join(save_path, 'detector_mag'),
                        dtype='float32', overwrite=True)
    return
def polarization_to_signal(self, P_of_t_in, *, local_oscillator_number=-1,
                           undersample_factor=1):
    """Generate a frequency-resolved signal from a polarization field.

    Args:
        P_of_t_in (np.ndarray): polarization sampled on ``self.t``; it is not
            modified.
        local_oscillator_number (int): index into ``self.efields`` /
            ``self.efield_times`` selecting the local oscillator; usually the
            last pulse in the list self.efields (-1).
        undersample_factor (int): keep every n-th time point of the
            polarization. The LO spectrum is cropped to the matching central
            frequency window.

    Returns:
        np.ndarray: ``Im(P(w) * conj(E_LO(w)))``, or
        ``1j * P(w) * conj(E_LO(w))`` when ``self.return_complex_signal`` is
        true.
    """
    undersample_slice = slice(None, None, undersample_factor)
    P_of_t = P_of_t_in[undersample_slice]
    t = self.t[undersample_slice]
    dt = t[1] - t[0]
    efield_t = self.efield_times[local_oscillator_number]

    if efield_t.size == 1:
        # Impulsive limit: a delta in time is flat in frequency.
        efield = np.ones(self.w.size) * self.efields[local_oscillator_number]
    else:
        # The LO pulse is centred on t = 0 of the full time grid.
        pulse_time_ind = np.argmin(np.abs(self.t))
        efield = np.zeros(self.t.size, dtype='complex')
        pulse_start_ind = pulse_time_ind - efield_t.size // 2
        pulse_end_ind = pulse_time_ind + efield_t.size // 2 + efield_t.size % 2
        efield[pulse_start_ind:pulse_end_ind] = self.efields[local_oscillator_number]
        efield = fftshift(ifft(ifftshift(efield))) * efield.size * dt

    # Keep only the central frequency window that the undersampled grid spans.
    halfway = self.w.size // 2
    pm = self.w.size // (2 * undersample_factor)
    efield = efield[halfway - pm:halfway + pm + self.w.size % 2]

    P_of_w = fftshift(ifft(ifftshift(P_of_t))) * P_of_t.size * dt
    signal = P_of_w * np.conjugate(efield)
    if not self.return_complex_signal:
        return np.imag(signal)
    return 1j * signal
def easy_ifft(data, axes=None):
    """utility method that includes fft shifting

    Computes ``ifftshift(ifftn(fftshift(data)))`` over ``axes`` (all axes
    when ``axes`` is None).

    NOTE(review): the conventional centred inverse is
    ``fftshift(ifftn(ifftshift(data)))``. The two orderings agree only for
    even-length axes -- confirm the odd-length behaviour is intended.
    """
    pre_shifted = fftshift(data, axes=axes)
    transformed = ifftn(pre_shifted, axes=axes)
    return ifftshift(transformed, axes=axes)
def cfftn(data, axes):
    """
    Centered fast fourier transform, n-dimensional.

    The centre sample of ``data`` is moved to index 0 (ifftshift), an
    orthonormal FFT is taken, and the zero frequency is moved back to the
    centre (fftshift).

    :param data: Complex input data.
    :param axes: Axes along which to shift and transform.
    :return: Fourier transformed data (norm='ortho').
    """
    uncentred = fft.ifftshift(data, axes=axes)
    spectrum = fft.fftn(uncentred, axes=axes, norm='ortho')
    return fft.fftshift(spectrum, axes=axes)
def filter_frames(self, data):
    """Filter each slice of ``data[0]`` along ``self.slice_dir`` in Fourier space.

    For every index along ``self.slice_dir``:

    * the slice is transformed with ``self.fft_object`` and centred with
      fftshift;
    * rows ``row1:row2`` are multiplied by ``self.filtercomplex``;
    * the spectrum is shifted back and inverted with ``self.ifft_object``.

    The real part of each result is stored. ``self.sslice`` is updated in
    place to address each slice.

    Returns:
        np.ndarray: filtered array with the shape and dtype of ``data[0]``.
    """
    frames = data[0]
    output = np.empty_like(frames)
    for i in range(frames.shape[self.slice_dir]):
        self.sslice[self.slice_dir] = i
        # NumPy needs a tuple (not a list) of slices for multi-dimensional
        # indexing; the write previously used the bare list, which NumPy
        # >= 1.23 rejects.
        index = tuple(self.sslice)
        sino = fft.fftshift(self.fft_object(frames[index]))
        sino[self.row1:self.row2] = \
            sino[self.row1:self.row2] * self.filtercomplex
        sino = fft.ifftshift(sino)
        output[index] = self.ifft_object(sino).real
    return output
def ft_interface(a, s, axes, norm, **kwargs):
    """Run ``ftfunc`` and apply the matching fftshift / ifftshift.

    Forward transforms shift the output; inverse transforms shift the input.
    With ``omitlast`` (real-input transforms), the last axis is left
    unshifted, and a single-axis transform is returned with no shift at all.

    ``omitlast``, ``forward``, ``ftfunc`` and ``fftmod`` are free names that
    are resolved outside this function.
    """
    if omitlast and (isinstance(axes, int) or len(axes) < 2):
        # no fftshift in the case of rfft
        return ftfunc(a, s, axes, norm, **kwargs)

    shftax = axes[:-1] if omitlast else axes
    if not forward:
        return ftfunc(fftmod.ifftshift(a, shftax), s, axes, norm, **kwargs)
    return fftmod.fftshift(ftfunc(a, s, axes, norm, **kwargs), shftax)
def fftdeconvolve(image, psf):
    """
    De-convolution by directly dividing the DFT of ``image`` by that of ``psf``.

    This may not be numerically desirable for many applications: noise and
    near-zero PSF frequencies are amplified.

    Taken from this post on stackoverflow.com:
    http://stackoverflow.com/questions/17473917/is-there-a-equivalent-of-scipy-signal-deconvolve-for-2d-arrays

    Raises NotImplementedError unless the module-level flag ``_pyfftw`` is
    truthy. Returns the complex kernel, fftshift-ed so its origin is centred.
    """
    if not _pyfftw:
        raise NotImplementedError

    image_f = image.astype('float')
    psf_f = psf.astype('float')

    image_spectrum = fftshift(fftn(image_f))
    psf_spectrum = fftshift(fftn(psf_f))
    ratio = image_spectrum / psf_spectrum
    return fftshift(ifftn(ifftshift(ratio)))
def fresnel_propagate(wavefield, energy_ev, psize_cm, dist_cm,
                      fresnel_approx=True, pad=0, sign_convention=1):
    """
    Perform Fresnel propagation on a batch of wavefields.

    :param wavefield: complex wavefield with shape [n_batches, n_y, n_x].
    :param energy_ev: photon energy in eV (wavelength = 1240 / energy_ev nm).
    :param psize_cm: pixel size in cm: a size-3 vector ([dy, dx, dz]), or a
        single value (scalar or length-1 sequence) used for all three.
    :param dist_cm: propagation distance in cm.
    :param fresnel_approx: forwarded to ``get_kernel``.
    :param pad: number of edge-replicated pixels added on each side of the
        y / x axes before propagation and cropped off afterwards.
    :param sign_convention: forwarded to ``get_kernel``.
    :return: propagated wavefield with the same shape as the input.
    """
    if pad > 0:
        wavefield = np.pad(wavefield, [[0, 0], [pad, pad], [pad, pad]],
                           mode='edge')
    grid_shape = wavefield.shape[1:]

    psize_cm = np.atleast_1d(np.asarray(psize_cm, dtype=float))
    if psize_cm.size == 1:
        # Repeat the single value; ``[psize_cm] * 3`` nested a length-1
        # sequence into shape (3, 1), and a scalar crashed on len().
        psize_cm = np.repeat(psize_cm, 3)
    voxel_nm = psize_cm * 1.e7
    lmbda_nm = 1240. / energy_ev
    dist_nm = dist_cm * 1e7

    h = get_kernel(dist_nm, lmbda_nm, voxel_nm, grid_shape,
                   fresnel_approx=fresnel_approx,
                   sign_convention=sign_convention)
    # Multiply the centred 2D spectrum (axes y, x) by the kernel and invert.
    wavefield = ifft2(
        ifftshift(fftshift(fft2(wavefield), axes=[1, 2]) * h, axes=[1, 2]))
    if pad > 0:
        wavefield = wavefield[:, pad:-pad, pad:-pad]
    return wavefield
def subtract_DC(signal, return_ft=False, axis=1):
    """Use the discrete Fourier transform to remove the DC component of a signal.

    Args:
        signal (np.ndarray): real signal to be processed
        return_ft (bool): if True, return the DC-free Fourier transform
            (unshifted, DC bin at index 0 along ``axis``) instead of the
            time-domain signal
        axis (int): axis along which the Fourier transform is taken

    Returns:
        np.ndarray: complex array with the same shape as ``signal``.
    """
    sig_fft = fft(ifftshift(signal, axes=axis), axis=axis)
    dc_index = [slice(None)] * sig_fft.ndim
    dc_index[axis] = slice(0, 1, 1)
    sig_fft[tuple(dc_index)] = 0
    if return_ft:
        return sig_fft
    # The inverse must run along the same axis as the forward transform;
    # ifft() alone defaults to the last axis.
    return fftshift(ifft(sig_fft, axis=axis), axes=axis)
def fft_filter(im, rois, shift, plot=True, imshow_kwargs=None):
    """Replace regions of an image's centred spectrum with shifted copies.

    For each ``roi`` in ``rois``, the spectrum values inside ``roi`` (and
    inside its counterpart from ``symmetric_roi``) are overwritten with the
    values of the same region translated by ``shift = (dx, dy)`` via
    ``translate_roi``.

    Args:
        im: image, transformed with ``fft2``.
        rois: iterable of regions usable as indices into the spectrum.
        shift: (dx, dy) translation applied to each roi.
        plot: if True, show the original / filtered images and spectra.
        imshow_kwargs: extra ``imshow`` keyword arguments for the images
            (defaults: vmin=-0.05, vmax=1). ``None`` means no extras.

    Returns:
        Real part of the inverse FFT of the modified spectrum.
    """
    # None instead of a shared mutable ``{}`` default.
    if imshow_kwargs is None:
        imshow_kwargs = {}
    f0 = fft.fftshift(fft.fft2(im))
    f1 = f0.copy()
    if plot:
        kw = dict(vmin=-0.05, vmax=1)
        kw.update(imshow_kwargs)
        fig, axes = plt.subplots(2, 2, figsize=(9, 4), sharex='col',
                                 sharey='col')
        (ax, ax_f), (ax1, ax1_f) = axes
        ax.imshow(im, **kw)
        ax_f.imshow(abs(f0), norm=LogNorm())
    dx, dy = shift
    for j, roi in enumerate(rois):
        print(j, roi)
        roi_t = translate_roi(roi, dx, dy)
        roi_symm = symmetric_roi(roi, im.shape)
        roi_symm_t = symmetric_roi(roi_t, im.shape)
        f1[roi] = f1[roi_t]
        f1[roi_symm] = f1[roi_symm_t]
        if plot:
            roi_to_rect(roi, ax=ax_f, index=j)
            roi_to_rect(roi, ax=ax1_f)
            roi_to_rect(roi_t, ax=ax_f, ec='r')
            roi_to_rect(roi_symm, ax=ax_f)
            roi_to_rect(roi_symm, ax=ax1_f)
            roi_to_rect(roi_symm_t, ax=ax_f, ec='r')
    im1 = fft.ifft2(fft.ifftshift(f1)).real
    if plot:
        ax1.imshow(im1, **kw)
        ax1_f.imshow(abs(f1), norm=LogNorm())
    return im1
def pattern_params(my_pat, size=2):
    """Estimate period, angle, phase and modulation depth of a 2D pattern.

    A real FFT is used; only the non-last axes are fftshift-ed, so the DC
    term sits at (ny // 2, 0). The strongest non-DC peak is refined with
    ``localize_peak`` on the window ``slice_maker(max_loc, 3)``.

    Returns a dict with keys "period", "angle", "phase", "fft", "mod" and
    "max_loc".
    """
    # REAL FFT! Do not shift the last (half-spectrum) axis.
    spectrum = fftshift(rfftn(ifftshift(my_pat)),
                        axes=tuple(range(my_pat.ndim))[:-1])
    magnitude = abs(spectrum)

    n_rows, _ = magnitude.shape
    dc_loc = (n_rows // 2, 0)
    dc_power = magnitude[dc_loc]
    magnitude[dc_loc] = 0  # mask DC so argmax finds the next biggest peak

    max_loc = np.unravel_index(magnitude.argmax(), magnitude.shape)
    sub_pixel = localize_peak(magnitude[slice_maker(max_loc, 3)])
    peak = np.array(max_loc) + np.array(sub_pixel) - np.array(dc_loc)
    # Normalise the index offset by the real-space shape.
    peak_corr = peak / np.array(my_pat.shape)

    angle = np.arctan2(*peak_corr)
    period = 1 / norm(peak_corr)
    phase = np.angle(spectrum[max_loc[0], max_loc[1]])
    numerator = abs(spectrum[slice_maker(max_loc, size)].sum())

    return {"period": period, "angle": angle, "phase": phase,
            "fft": spectrum, "mod": numerator / dc_power,
            "max_loc": max_loc}
def fftconvolve_fast(data, kernel, **kwargs):
    """A faster version of FFT convolution.

    Here only the kernel is ifftshift-ed before its FFT; the data is not.
    Fourier convolution wraps around the data edges, so shifting the data
    before the FFT and un-shifting afterwards makes no difference, and both
    steps can be skipped.

    The data is reflect-padded (``fft_pad``) to a fast FFT length that is at
    least as large as both inputs. The kernel -- assumed already
    background-subtracted and centred -- is zero-padded to the same shape.
    The result is cropped back to ``data.shape``. Extra ``kwargs`` are passed
    to ``rfftn`` / ``irfftn``.
    """
    # TODO: add error checking and functionality for complex inputs; could
    # also add options for different types of padding.
    largest = np.max((np.array(data.shape), np.array(kernel.shape)), 0)
    # NOTE(review): ``fftpack.helper`` is a private SciPy path that newer
    # SciPy removed -- confirm what ``sig`` refers to.
    fast_shape = [sig.fftpack.helper.next_fast_len(int(d)) for d in largest]
    pad_data = fft_pad(data, fast_shape, "reflect")

    # Presumably (before, after) padding per axis -- used to crop back.
    padding = tuple(
        _calc_pad(o, n) for o, n in zip(data.shape, pad_data.shape))
    crop = tuple(slice(start, -end) if end != 0 else slice(start, None)
                 for start, end in padding)

    if kernel.shape != pad_data.shape:
        kernel = fft_pad(kernel, pad_data.shape, mode='constant')

    k_kernel = rfftn(ifftshift(kernel), pad_data.shape, **kwargs)
    k_data = rfftn(pad_data, pad_data.shape, **kwargs)
    convolved = irfftn(k_kernel * k_data, pad_data.shape, **kwargs)
    return convolved[crop]
ax3.set_xlim([x1,x2-dx]) ax3.set_ylim([y1,y2-dy]) ax3.set_title("Spatial density (NUMERIC)") ax2.imshow(phi.T, origin='lower', interpolation='none', extent=[px1,px2-dpx,py1,py2-dpy],vmin=amin(phi_exact), vmax=amax(phi_exact)) ax4.set_xlabel('x') ax4.set_ylabel('y') ax2.set_xlim([px1,px2-dpx]) ax2.set_ylim([py1,py2-dpy]) ax4.set_title("Momentum density (NUMERIC)") plt.tight_layout() fig.savefig('frames' + '/%04d.png' % k) fig.clf() plt.close('all') dt = (t2-t1)/100. # the first very rough guess of time step expU = exp(dt*dU) expT = exp(dt*dT) W = fftshift(gauss(xgrid, ygrid, pxgrid, pygrid, x0, y0, px0, py0, sigma_x, sigma_y, sigma_px, sigma_py)) t = t1 Nt = 1 while t <= t2: Wexact = W_analytic(xgrid, ygrid, pxgrid, pygrid, t) draw_frame(ifftshift(W), Wexact, Nt) W = solve_spectral(W, expU, expT) t += dt Nt += 1
def multislice_propagate(grid_delta_batch, grid_beta_batch, probe_real, probe_imag, energy_ev, psize_cm, free_prop_cm=None, return_intermediate=False, sign_convention=1):
    """
    Perform multislice propagation on a batch of 3D objects.

    Each z-slice multiplies the wavefront by ``exp(1j*k*delta) * exp(-k*beta)``
    and, except after the last slice, propagates it by one slice thickness
    (``psize_cm[-1]``) using the kernel from ``get_kernel``.

    :param grid_delta_batch: 4D array for object delta with shape [n_batches, n_y, n_x, n_z].
    :param grid_beta_batch: 4D array for object beta with shape [n_batches, n_y, n_x, n_z].
    :param probe_real: 2D array for the real part of the probe.
    :param probe_imag: 2D array for the imaginary part of the probe.
    :param energy_ev: photon energy in eV; the wavelength is 1240 / energy_ev nm.
    :param psize_cm: size-3 vector with pixel size ([dy, dx, dz]).
    :param free_prop_cm: optional extra propagation distance (cm) after the last
        slice. 'inf' replaces the wavefront by its centred 2D FFT (far field);
        None or 0 skips this step.
    :param return_intermediate: if True, also return a list of |wavefront|
        snapshots (initial probe, after every slice, and after free propagation).
    :param sign_convention: forwarded to the kernel helpers.
    :return: wavefront, or (wavefront, wavefront_ls) if return_intermediate.
    """
    minibatch_size = grid_delta_batch.shape[0]
    grid_shape = grid_delta_batch.shape[1:]
    voxel_nm = np.array(psize_cm) * 1.e7
    # Every batch member starts from the same probe.
    wavefront = np.zeros([minibatch_size, grid_shape[0], grid_shape[1]], dtype='complex64')
    wavefront += (probe_real + 1j * probe_imag)

    lmbda_nm = 1240. / energy_ev
    mean_voxel_nm = np.prod(voxel_nm)**(1. / 3)
    size_nm = np.array(grid_shape) * voxel_nm

    n_slice = grid_shape[-1]
    delta_nm = voxel_nm[-1]

    # h = get_kernel_ir(delta_nm, lmbda_nm, voxel_nm, grid_shape)
    # Slice-to-slice propagation kernel for a distance of one slice thickness.
    h = get_kernel(delta_nm, lmbda_nm, voxel_nm, grid_shape, sign_convention=sign_convention)
    k = 2. * PI * delta_nm / lmbda_nm

    if return_intermediate:
        wavefront_ls = []
        wavefront_ls.append(abs(wavefront))

    # trange presumably comes from tqdm (progress bar only) -- confirm.
    for i in trange(n_slice):
        delta_slice = grid_delta_batch[:, :, :, i]
        beta_slice = grid_beta_batch[:, :, :, i]
        # Transmission through this slice: phase from delta, attenuation from beta.
        c = np.exp(1j * k * delta_slice) * np.exp(-k * beta_slice)
        wavefront = wavefront * c
        if i < n_slice - 1:
            # Propagate to the next slice: kernel applied to the centred spectrum.
            wavefront = ifft2(ifftshift(fftshift(fft2(wavefront), axes=[1, 2]) * h, axes=[1, 2]))
        if return_intermediate:
            wavefront_ls.append(abs(wavefront))

    if free_prop_cm not in [None, 0]:
        if free_prop_cm == 'inf':
            wavefront = fftshift(fft2(wavefront), axes=[1, 2])
        else:
            dist_nm = free_prop_cm * 1e7
            l = np.prod(size_nm)**(1. / 3)
            crit_samp = lmbda_nm * dist_nm / l
            # Pick the transfer-function ('TF') or impulse-response ('IR') kernel
            # by comparing the mean voxel size with the critical sampling.
            algorithm = 'TF' if mean_voxel_nm > crit_samp else 'IR'
            if algorithm == 'TF':
                h = get_kernel(dist_nm, lmbda_nm, voxel_nm, grid_shape, sign_convention=sign_convention)
                wavefront = ifft2(ifftshift(fftshift(fft2(wavefront), axes=[1, 2]) * h, axes=[1, 2]))
            else:
                h = get_kernel_ir(dist_nm, lmbda_nm, voxel_nm, grid_shape, sign_convention=sign_convention)
                # NOTE(review): the IR kernel is applied exactly like the TF kernel
                # (multiplied with the centred spectrum); confirm get_kernel_ir
                # returns a Fourier-domain kernel.
                wavefront = ifft2(ifftshift(fftshift(fft2(wavefront), axes=[1, 2]) * h, axes=[1, 2]))

    if return_intermediate:
        if free_prop_cm not in [None, 0]:
            wavefront_ls.append(abs(wavefront))
        return wavefront, wavefront_ls
    else:
        return wavefront
def fft_convolve2d(x, f, mode='same', boundary='constant', fft_filter=False):
    r"""
    Performs fast 2d convolution in the frequency domain convolving each image
    channel with its corresponding filter channel.

    Parameters
    ----------
    x : ``(channels, height, width)`` `ndarray`
        Image.
    f : ``(channels, height, width)`` `ndarray`
        Filter.
    mode : str {`full`, `same`, `valid`}, optional
        Determines the shape of the resulting convolution. `valid` needs the
        spatial filter size, so it requires ``fft_filter=False``.
    boundary: str {`constant`, `symmetric`}, optional
        Determines how the image is padded.
    fft_filter: `bool`, optional
        If `True`, the filter is assumed to be defined on the frequency
        domain, and its last two dimensions set the padded size.
        If `False` the filter is assumed to be defined on the spatial domain.

    Returns
    -------
    c: ``(channels, height, width)`` `ndarray`
        Result of convolving each image channel with its corresponding
        filter channel.

    Raises
    ------
    ValueError
        If ``mode`` is unsupported, or ``mode='valid'`` is combined with
        ``fft_filter=True``.
    """
    # Needed by the 'same'/'valid' crops in both branches; previously it was
    # only defined in the spatial-filter branch (NameError otherwise).
    x_shape = np.asarray(x.shape[-2:])
    f_half_shape = None
    if fft_filter:
        # extended shape is filter shape
        ext_shape = np.asarray(f.shape[-2:])
        # extend image
        ext_x = pad(x, ext_shape, boundary=boundary)
        # compute fft of extended image; the filter is already in Fourier space
        fft_ext_x = fft2(ext_x)
        fft_ext_f = f
    else:
        # extended shape
        f_shape = np.asarray(f.shape[-2:])
        f_half_shape = (f_shape / 2).astype(int)
        ext_shape = x_shape + f_half_shape - 1
        # extend image and filter
        ext_x = pad(x, ext_shape, boundary=boundary)
        ext_f = pad(f, ext_shape)
        # compute ffts of extended image and extended filter
        fft_ext_x = fft2(ext_x)
        fft_ext_f = fft2(ext_f)

    # compute extended convolution in Fourier domain
    fft_ext_c = fft_ext_f * fft_ext_x
    # compute ifft of extended convolution
    ext_c = np.real(ifftshift(ifft2(fft_ext_c), axes=(-2, -1)))

    # Compare strings with ==; ``is`` tests object identity, which is not
    # guaranteed for equal strings (SyntaxWarning on Python 3.8+).
    if mode == 'full':
        return ext_c
    elif mode == 'same':
        return crop(ext_c, x_shape)
    elif mode == 'valid':
        if f_half_shape is None:
            raise ValueError(
                "mode='valid' requires a spatial-domain filter "
                "(fft_filter=False).")
        return crop(ext_c, x_shape - f_half_shape + 1)
    else:
        raise ValueError(
            "mode={}, is not supported. The only supported "
            "modes are: 'full', 'same' and 'valid'.".format(mode))
def retrieve_phase(data, params, max_iters=200, pupil_tol=1e-8, mse_tol=1e-8,
                   phase_only=False, mclass=HanserPSF):
    """Retrieve the phase across the objective's back pupil from an
    experimentally measured PSF.

    Follows: [Hanser, B. M.; Gustafsson, M. G. L.; Agard, D. A.; Sedat, J. W.
    Phase Retrieval for High-Numerical-Aperture Optical Systems. Optics Letters
    2003, 28 (10), 801.](dx.doi.org/10.1364/OL.28.000801)

    Parameters
    ----------
    data : ndarray (3 dim)
        The experimentally measured PSF of a subdiffractive source, shape
        (nz, ny, nx) with ny == nx.
    params : dict
        Parameters to pass to ``mclass``. size and zsize are set from
        data.shape, and vec_corr / condition are forced to "none". The
        caller's dict is not modified.
    max_iters : int
        The maximum number of iterations to run, default is 200
    pupil_tol : float
        Tolerance on the relative change in the pupil between iterations,
        default is 1e-8
    mse_tol : float
        Tolerance on the relative change in the mean squared error between
        data and simulated data (also applied to the mse itself), default is
        1e-8
    phase_only : bool
        True means only the phase of the back pupil is retrieved while the
        amplitude is not.
    mclass : class
        PSF model class to instantiate, default HanserPSF.

    Returns
    -------
    PR_result : PhaseRetrievalResult
        An object that contains the phase retrieval result
    """
    # make sure data is 3D and square; check ndim first so a 2D array fails
    # this assertion instead of raising IndexError on data.shape[2]
    assert data.ndim == 3, "Data doesn't have enough dims"
    assert data.shape[1] == data.shape[2], "Data is not square in x/y"

    # make sure the user hasn't screwed up the params (work on a copy so the
    # caller's dict is left untouched)
    params = dict(params)
    params.update(
        dict(vec_corr="none", condition="none", zsize=data.shape[0],
             size=data.shape[-1]))

    # assume that data prep has been handled outside function
    # The field magnitude is the square root of the intensity
    mag = psqrt(data)
    # generate a model from parameters
    model = mclass(**params)
    # generate coordinates
    model._gen_kr()
    # start a list for iteration
    mse = np.zeros(max_iters)
    mse_diff = np.zeros(max_iters)
    pupil_diff = np.zeros(max_iters)
    # generate a pupil to start with
    new_pupil = model._gen_pupil()
    # save it as a mask
    mask = new_pupil.real
    # iterate
    old_mse, old_pupil = None, None
    for i in range(max_iters):
        # generate new mse and add it to the list
        model._gen_psf(new_pupil)
        new_mse = _calc_mse(data, model.PSFi)
        mse[i] = new_mse
        if i > 0:
            # calculate the difference in mse to test for convergence
            mse_diff[i] = abs(old_mse - new_mse) / old_mse
            # calculate the difference in pupil
            pupil_diff[i] = (abs(old_pupil - new_pupil)**2).mean() / (
                abs(old_pupil)**2).mean()
        else:
            mse_diff[i] = np.nan
            pupil_diff[i] = np.nan
        # check tolerances: how much has the pupil changed, how much has the
        # mse changed, and what's the absolute mse
        logger.info(
            "Iteration {}, mse_diff = {:.2g}, pupil_diff = {:.2g}".format(
                i, mse_diff[i], pupil_diff[i]))
        if pupil_diff[i] < pupil_tol or mse_diff[i] < mse_tol or mse[i] < mse_tol:
            break
        # update old_mse
        old_mse = new_mse
        # retrieve new pupil
        old_pupil = new_pupil
        # keep phase
        phase = np.angle(model.PSFa.squeeze())
        # replace magnitude with experimentally measured mag
        new_psf = mag * np.exp(1j * phase)
        # generate the new pupils
        new_pupils = fftn(ifftshift(new_psf, axes=(1, 2)), axes=(1, 2))
        # undo defocus and take the mean
        new_pupils /= model._calc_defocus()
        new_pupil = new_pupils.mean(0) * mask
        # if phase only discard magnitude info
        if phase_only:
            new_pupil = np.exp(1j * np.angle(new_pupil)) * mask
    else:
        # for/else: reached only when the loop did not break.
        # Logger.warn is a deprecated alias of Logger.warning.
        logger.warning("Reach max iterations without convergence")

    mse = mse[:i + 1]
    mse_diff = mse_diff[:i + 1]
    pupil_diff = pupil_diff[:i + 1]
    # shift mask
    mask = fftshift(mask)
    # shift phase then unwrap and mask
    phase = unwrap_phase(fftshift(np.angle(new_pupil))) * mask
    # shift magnitude
    magnitude = fftshift(abs(new_pupil)) * mask
    return PhaseRetrievalResult(magnitude, phase, mse, pupil_diff, mse_diff,
                                model)
def ifftnc(x, axes):
    """Inverse n-D FFT over ``axes`` wrapped in fftshift / ifftshift.

    Evaluates ``ifftshift(ifftn(fftshift(x)))`` where every step acts only on
    ``axes``. For even lengths along those axes this is the usual centred
    inverse transform (``fftshift`` and ``ifftshift`` coincide there).

    NOTE(review): for odd lengths the two shift functions differ. In this order
    the centre sample ``n // 2`` is not moved to index 0 before the transform.
    Confirm whether callers use odd sizes, and whether a matching forward
    transform uses the same mirrored order, before changing it.

    Args:
        x: input array.
        axes: axes to shift and transform.

    Returns:
        Complex array with the same shape as ``x``.
    """
    shifted_in = fft.fftshift(x, axes=axes)
    return fft.ifftshift(fft.ifftn(shifted_in, axes=axes), axes=axes)
def process_frames(self, data):
    """Filter a band of rows in the shifted spectrum of the first frame.

    ``data[0]`` is transformed with ``self.fft_object`` (presumably a forward
    FFT plan; TODO confirm) and centred with ``fftshift``. Rows
    ``self.row1:self.row2`` are multiplied by ``self.filtercomplex``. The
    spectrum is then un-shifted and transformed back with
    ``self.ifft_object``.

    Returns:
        The real part of the back-transformed array.
    """
    centred = fft.fftshift(self.fft_object(data[0]))
    band = slice(self.row1, self.row2)
    centred[band] = centred[band] * self.filtercomplex
    return self.ifft_object(fft.ifftshift(centred)).real
def set_pulse_shapes(self, pump_field, probe_field, *, plot_fields=True):
    """Sets a list of 4 pulse amplitudes, given an input pump shape and probe
    shape. Assumes 4-wave mixing signals, and so 4 interactions

    ``self.efields`` is set to ``[pump_field, pump_field, probe_field,
    probe_field]``. When ``self.efield_t`` has more than one point, the fields
    are also checked, and a warning is issued if:

    * a field's length differs from ``self.efield_t.size``;
    * the larger end-point magnitude of a field exceeds 1% of its peak
      magnitude (suggesting a larger ``num_conv_points``);
    * the same holds for the field's FFT (suggesting a smaller ``dt``).

    Args:
        pump_field (np.ndarray): pump amplitude, expected on ``self.efield_t``
        probe_field (np.ndarray): probe amplitude, expected on ``self.efield_t``

    Kwargs:
        plot_fields (bool): if True (and ``self.efield_t`` has more than one
            point), plot the real and imaginary parts of both fields against
            ``self.efield_t``, and of their FFTs against ``self.efield_w``.
    """
    self.efields = [pump_field, pump_field, probe_field, probe_field]

    if self.efield_t.size == 1:
        # Single-point time grid: nothing to check or plot.
        return

    pump_tail = np.max(np.abs([pump_field[0], pump_field[-1]]))
    probe_tail = np.max(np.abs([probe_field[0], probe_field[-1]]))
    if pump_field.size != self.efield_t.size:
        warnings.warn(
            'Pump must be evaluated on efield_t, the grid defined by dt and num_conv_points'
        )
    if probe_field.size != self.efield_t.size:
        warnings.warn(
            'Probe must be evaluated on efield_t, the grid defined by dt and num_conv_points'
        )
    if pump_tail > np.max(np.abs(pump_field)) / 100:
        warnings.warn(
            'Consider using larger num_conv_points, pump does not decay to less than 1% of maximum value in time domain'
        )
    if probe_tail > np.max(np.abs(probe_field)) / 100:
        warnings.warn(
            'Consider using larger num_conv_points, probe does not decay to less than 1% of maximum value in time domain'
        )

    # ifftshift moves the centre sample (index n // 2) to index 0 before the
    # FFT, fftshift puts zero frequency back in the middle, and dt scales the
    # sum (presumably t = 0 sits at the centre of efield_t; TODO confirm).
    pump_fft = fftshift(fft(ifftshift(pump_field))) * self.dt
    probe_fft = fftshift(fft(ifftshift(probe_field))) * self.dt

    pump_fft_tail = np.max(np.abs([pump_fft[0], pump_fft[-1]]))
    probe_fft_tail = np.max(np.abs([probe_fft[0], probe_fft[-1]]))
    if pump_fft_tail > np.max(np.abs(pump_fft)) / 100:
        warnings.warn(
            '''Consider using smaller value of dt, pump does not decay to less than 1% of maximum value in frequency domain'''
        )
    if probe_fft_tail > np.max(np.abs(probe_fft)) / 100:
        warnings.warn(
            '''Consider using smaller value of dt, probe does not decay to less than 1% of maximum value in frequency domain'''
        )

    if plot_fields:
        fig, axes = plt.subplots(2, 2)
        l1, l2, = axes[0, 0].plot(self.efield_t, np.real(pump_field),
                                  self.efield_t, np.imag(pump_field))
        # Attach the legend to the subplot that owns l1/l2. plt.legend would
        # target pyplot's current axes, which after plt.subplots is the last
        # subplot created (axes[1, 1]).
        axes[0, 0].legend([l1, l2], ['Real', 'Imag'])
        axes[0, 1].plot(self.efield_w, np.real(pump_fft),
                        self.efield_w, np.imag(pump_fft))
        axes[1, 0].plot(self.efield_t, np.real(probe_field),
                        self.efield_t, np.imag(probe_field))
        axes[1, 1].plot(self.efield_w, np.real(probe_fft),
                        self.efield_w, np.imag(probe_fft))
        axes[0, 0].set_ylabel('Pump Amp')
        axes[1, 0].set_ylabel('Probe Amp')
        axes[1, 0].set_xlabel('Time')
        axes[1, 1].set_xlabel('Frequency')
        fig.suptitle(
            'Check that pump and probe are well-resolved in time and frequency'
        )
def py_zogy(Nf, Rf, P_Nf, P_Rf, S_Nf, S_Rf, SN, SR, dx=0.25, dy=0.25):
    '''Python implementation of ZOGY image subtraction algorithm.

    As per Frank's instructions, will assume images have been aligned,
    background subtracted, and gain-matched.

    Arguments:
        Nf: New image (filename)
        Rf: Reference image (filename)
        P_Nf: PSF of New image (filename)
        P_Rf: PSF of Reference image (filename)
        S_Nf: 2D Uncertainty (sigma) of New image (filename)
        S_Rf: 2D Uncertainty (sigma) of Reference image (filename)
        SN: Average uncertainty (sigma) of New image
        SR: Average uncertainty (sigma) of Reference image
        dx: Astrometric uncertainty (sigma) in x coordinate
        dy: Astrometric uncertainty (sigma) in y coordinate

    Returns:
        D: Subtracted image
        P_D: PSF of subtracted image, cropped to the footprint of the New PSF
        S_corr: Corrected subtracted image

    Every file is read from its primary HDU (index 0). The New and Reference
    images are assumed to share one shape, and each PSF must fit inside it.
    '''

    def _load(filename):
        # Copy the primary-HDU data into memory and close the file, so no
        # FITS handle is left open.
        with fits.open(filename) as hdul:
            return np.array(hdul[0].data)

    def _centre(big_shape, small_shape):
        # Index tuple placing an array of small_shape at the centre (n // 2)
        # of an array of big_shape. Integer bounds are needed on Python 3, and
        # NumPy requires a tuple of slices. `start + p` fits even- and
        # odd-sized PSFs alike.
        starts = [n // 2 - p // 2 for n, p in zip(big_shape, small_shape)]
        return tuple(slice(s, s + p) for s, p in zip(starts, small_shape))

    # Load the new and ref images into memory
    N = _load(Nf)
    R = _load(Rf)

    # Load the PSFs into memory
    P_N_small = _load(P_Nf)
    P_R_small = _load(P_Rf)

    # Place PSF at center of image with same size as new / reference
    idx_N = _centre(N.shape, P_N_small.shape)
    idx_R = _centre(R.shape, P_R_small.shape)
    P_N = np.zeros(N.shape)
    P_R = np.zeros(R.shape)
    P_N[idx_N] = P_N_small
    P_R[idx_R] = P_R_small

    # Shift the PSF to the origin so it will not introduce a shift.
    # ifftshift moves index n // 2 to index 0 for odd and even n alike.
    # fftshift does the same only for even n.
    P_N = fft.ifftshift(P_N)
    P_R = fft.ifftshift(P_R)

    # Take all the Fourier Transforms
    N_hat = fft.fft2(N)
    R_hat = fft.fft2(R)
    P_N_hat = fft.fft2(P_N)
    P_R_hat = fft.fft2(P_R)

    # Fourier Transform of Difference Image (Equation 13)
    D_hat_num = (P_R_hat * N_hat - P_N_hat * R_hat)
    D_hat_den = np.sqrt(SN**2 * np.abs(P_R_hat**2) + SR**2 * np.abs(P_N_hat**2))
    D_hat = D_hat_num / D_hat_den

    # Flux-based zero point (Equation 15)
    FD = 1. / np.sqrt(SN**2 + SR**2)

    # Difference Image
    # TODO: Why is the FD normalization in there?
    D = np.real(fft.ifft2(D_hat)) / FD

    # Fourier Transform of PSF of Subtraction Image (Equation 14)
    P_D_hat = P_R_hat * P_N_hat / FD / D_hat_den

    # PSF of Subtraction Image. fftshift moves the origin (index 0) back to
    # index n // 2 so that the centred crop below picks up the PSF core.
    P_D = np.real(fft.ifft2(P_D_hat))
    P_D = fft.fftshift(P_D)
    P_D = P_D[idx_N]

    # Fourier Transform of Score Image (Equation 17)
    S_hat = FD * D_hat * np.conj(P_D_hat)

    # Score Image
    S = np.real(fft.ifft2(S_hat))

    # Now start calculating Scorr matrix (including all noise terms)

    # Start out with source noise
    # Load the sigma images into memory
    sigma_N = _load(S_Nf)
    sigma_R = _load(S_Rf)

    # Sigma to variance
    V_N = sigma_N**2
    V_R = sigma_R**2

    # Fourier Transform of variance images
    V_N_hat = fft.fft2(V_N)
    V_R_hat = fft.fft2(V_R)

    # Equation 28
    kr_hat = np.conj(P_R_hat) * np.abs(P_N_hat**2) / (D_hat_den**2)
    kr = np.real(fft.ifft2(kr_hat))

    # Equation 29
    kn_hat = np.conj(P_N_hat) * np.abs(P_R_hat**2) / (D_hat_den**2)
    kn = np.real(fft.ifft2(kn_hat))

    # Noise in New Image: Equation 26
    V_S_N = np.real(fft.ifft2(V_N_hat * fft.fft2(kn**2)))

    # Noise in Reference Image: Equation 27
    V_S_R = np.real(fft.ifft2(V_R_hat * fft.fft2(kr**2)))

    # Astrometric Noise
    # Equation 31
    # TODO: Check axis (0/1) vs x/y coordinates
    S_N = np.real(fft.ifft2(kn_hat * N_hat))
    dSNdx = S_N - np.roll(S_N, 1, axis=1)
    dSNdy = S_N - np.roll(S_N, 1, axis=0)

    # Equation 30
    V_ast_S_N = dx**2 * dSNdx**2 + dy**2 * dSNdy**2

    # Equation 33
    S_R = np.real(fft.ifft2(kr_hat * R_hat))
    dSRdx = S_R - np.roll(S_R, 1, axis=1)
    dSRdy = S_R - np.roll(S_R, 1, axis=0)

    # Equation 32
    V_ast_S_R = dx**2 * dSRdx**2 + dy**2 * dSRdy**2

    # Calculate Scorr
    S_corr = S / np.sqrt(V_S_N + V_S_R + V_ast_S_N + V_ast_S_R)

    return D, P_D, S_corr
def fillWithGRFFromTemplate(self, twodPower, bufferFactor=1, threads=1):
    """
    Generate a Gaussian random field from an input power spectrum specified as a 2d powerMap

    Notes
    -----
    BufferFactor = 1 means the map will have periodic boundary function, while
    BufferFactor > 1 means the map will be generated on a patch bufferFactor
    times larger in each dimension and then cut out so as to have non-periodic
    boundary conditions.

    Fills the data field of the map with the GRF realization.

    ``twodPower`` is not modified. Its power map is rescaled into a local
    copy, so repeated calls with the same template use the same
    normalisation.

    Parameters
    ----------
    twodPower :
        Template read for ``Nx``, ``Ny``, ``pixScaleX``, ``pixScaleY``,
        ``powerMap``, ``lx`` and ``ly``.
    bufferFactor : int, optional
        Size multiplier of the generation patch (converted with ``int``;
        must be >= 1).
    threads : int, optional
        Threads passed to ``ifft2`` when pyFFTW is available.
    """
    # Convert before use so that e.g. 2.0 still gives integer grid sizes.
    bufferFactor = int(bufferFactor)
    assert bufferFactor >= 1
    Ny = self.Ny * bufferFactor
    Nx = self.Nx * bufferFactor

    # Angular wavenumbers of the (possibly buffered) output grid, FFT order.
    ly = fftfreq(Ny, d=self.pixScaleY) * (2 * numpy.pi)
    lx = fftfreq(Nx, d=self.pixScaleX) * (2 * numpy.pi)

    # Divide out the area factor into a new array; the caller's
    # twodPower.powerMap is left untouched.
    area = twodPower.Nx * twodPower.Ny * twodPower.pixScaleX * twodPower.pixScaleY
    powerMap = twodPower.powerMap * ((twodPower.Nx * twodPower.Ny) ** 2 / area)

    if bufferFactor > 1 or twodPower.Nx != Nx or twodPower.Ny != Ny:
        # Interpolate the template onto this grid's wavenumbers, working in
        # fftshifted (ascending) order and returning to FFT order afterwards.
        # NOTE(review): scipy.interpolate.interp2d was removed in SciPy 1.14;
        # this branch needs a replacement interpolator on newer SciPy.
        f_interp = interp2d(fftshift(twodPower.lx), fftshift(twodPower.ly),
                            fftshift(powerMap))
        p = ifftshift(f_interp(fftshift(lx), fftshift(ly)))
    else:
        p = powerMap

    # Real and imaginary parts each get standard deviation sqrt(p); they are
    # drawn in that order.
    realPart = numpy.sqrt(p) * numpy.random.randn(Ny, Nx)
    imgPart = numpy.sqrt(p) * numpy.random.randn(Ny, Nx)
    kMap = realPart + 1j * imgPart

    if have_pyFFTW:
        data = numpy.real(ifft2(kMap, threads=threads))
    else:
        data = numpy.real(ifft2(kMap))

    # Keep one self.Ny x self.Nx block (block index (b - 1) // 2 along each
    # axis). Floor division keeps the slice bounds integral on Python 3.
    b = bufferFactor
    self.data = data[(b - 1) // 2 * self.Ny:(b + 1) // 2 * self.Ny,
                     (b - 1) // 2 * self.Nx:(b + 1) // 2 * self.Nx]
def upgradePixelPitch(m, N=1, threads=1):
    """
    Go to finer pixels with fourier interpolation.

    Parameters
    ----------
    m : liteMap
        The liteMap object holding the data to upgrade the pixel size of.
    N : int, optional
        Go to 2^N times smaller pixels. Default is 1.
    threads : int, optional
        Number of threads to use in pyFFTW calculations. Default is 1.

    Returns
    -------
    mNew : liteMap
        The map with smaller pixels. For ``N < 1`` a copy of ``m`` is
        returned unchanged. ``m`` itself is never modified.
    """
    if N < 1:
        return m.copy()

    scale = 2 ** N
    Ny = m.Ny * scale
    Nx = m.Nx * scale

    if have_pyFFTW:
        ft = fft2(m.data, threads=threads)
    else:
        ft = fft2(m.data)
    ftShifted = fftshift(ft)

    # Zero-pad the centred spectrum. fftshift puts zero frequency at index
    # n // 2, so the old block must start at Nx // 2 - m.Nx // 2. That puts
    # the zero frequency at the new centre, and ifftshift then returns it to
    # index 0. The formula covers odd and even sizes and stays integral.
    offsetX = Nx // 2 - m.Nx // 2
    offsetY = Ny // 2 - m.Ny // 2
    newFtShifted = numpy.zeros((Ny, Nx), dtype=numpy.complex128)
    newFtShifted[offsetY:offsetY + m.Ny, offsetX:offsetX + m.Nx] = ftShifted
    del ftShifted
    ftNew = ifftshift(newFtShifted)
    del newFtShifted

    # Finally, deconvolve by the pixel window: a centred 2^N x 2^N top-hat
    # whose entries sum to 1.
    half = 2 ** (N - 1)
    cy, cx = Ny // 2, Nx // 2
    mPix = numpy.zeros((Ny, Nx))
    mPix[cy - half:cy + half, cx - half:cx + half] = 1.0 / scale ** 2

    if have_pyFFTW:
        ftPix = fft2(mPix, threads=threads)
    else:
        ftPix = fft2(mPix)
    del mPix

    inds = numpy.where(ftNew != 0)
    ftNew[inds] /= numpy.abs(ftPix[inds])

    if have_pyFFTW:
        newData = ifft2(ftNew, threads=threads) * scale ** 2
    else:
        newData = ifft2(ftNew) * scale ** 2
    del ftNew
    del ftPix

    x0_new, y0_new = m.pixToSky(0, 0)
    m = m.copy()  # don't overwrite original

    # Item assignment instead of the legacy pyfits header.update(key, value)
    # call, which astropy's Header.update no longer accepts.
    m.wcs.header["NAXIS1"] = scale * m.wcs.header["NAXIS1"]
    m.wcs.header["NAXIS2"] = scale * m.wcs.header["NAXIS2"]
    m.wcs.header["CDELT1"] = m.wcs.header["CDELT1"] / 2.0 ** N
    m.wcs.header["CDELT2"] = m.wcs.header["CDELT2"] / 2.0 ** N
    m.wcs.updateFromHeader()

    # Re-anchor the reference pixel so the original (0, 0) sky position
    # lands on the new pixel (0, 0).
    p_x, p_y = m.skyToPix(x0_new, y0_new)
    m.wcs.header["CRPIX1"] = m.wcs.header["CRPIX1"] - p_x
    m.wcs.header["CRPIX2"] = m.wcs.header["CRPIX2"] - p_y
    m.wcs.updateFromHeader()

    realData = numpy.real(newData)
    mNew = liteMapFromDataAndWCS(realData, m.wcs)
    mNew.data[:] = realData
    return mNew
def fft_convolve2d(x, f, mode='same', boundary='constant', fft_filter=False):
    r"""
    Performs fast 2d convolution in the frequency domain convolving each image
    channel with its corresponding filter channel.

    Parameters
    ----------
    x : ``(channels, height, width)`` `ndarray`
        Image.
    f : ``(channels, height, width)`` `ndarray`
        Filter.
    mode : str {`full`, `same`, `valid`}, optional
        Determines the shape of the resulting convolution.
    boundary: str {`constant`, `symmetric`}, optional
        Determines how the image is padded.
    fft_filter: `bool`, optional
        If `True`, the filter is assumed to be defined on the frequency
        domain. If `False` the filter is assumed to be defined on the spatial
        domain. For a frequency-domain filter, its last two dimensions define
        the extended (padded) shape. For ``mode='valid'`` the filter's spatial
        half-size is recovered as ``ext_shape - x_shape + 1``, i.e. the filter
        is assumed to have been built on the same extended grid this function
        uses for spatial filters (NOTE(review): confirm with callers).

    Returns
    -------
    c: ``(channels, height, width)`` `ndarray`
        Result of convolving each image channel with its corresponding filter
        channel.

    Raises
    ------
    ValueError
        If ``mode`` is not one of 'full', 'same' or 'valid'.
    """
    # Validate up front, comparing by value: `is` on strings tests object
    # identity, which is an interpreter implementation detail.
    if mode not in ('full', 'same', 'valid'):
        raise ValueError("mode={}, is not supported. The only supported "
                         "modes are: 'full', 'same' and 'valid'.".format(mode))

    # Needed by the 'same'/'valid' crops in both branches, including
    # fft_filter=True with the default mode='same'.
    x_shape = np.asarray(x.shape[-2:])

    if fft_filter:
        # extended shape is filter shape
        ext_shape = np.asarray(f.shape[-2:])
        # inverse of ext_shape = x_shape + f_half_shape - 1 (spatial branch)
        f_half_shape = ext_shape - x_shape + 1
        # extend image
        ext_x = pad(x, ext_shape, boundary=boundary)
        # compute ffts of extended image
        fft_ext_x = fft2(ext_x)
        fft_ext_f = f
    else:
        # extended shape
        f_shape = np.asarray(f.shape[-2:])
        f_half_shape = (f_shape / 2).astype(int)
        ext_shape = x_shape + f_half_shape - 1
        # extend image and filter
        ext_x = pad(x, ext_shape, boundary=boundary)
        ext_f = pad(f, ext_shape)
        # compute ffts of extended image and extended filter
        fft_ext_x = fft2(ext_x)
        fft_ext_f = fft2(ext_f)

    # compute extended convolution in Fourier domain
    fft_ext_c = fft_ext_f * fft_ext_x

    # compute ifft of extended convolution
    ext_c = np.real(ifftshift(ifft2(fft_ext_c), axes=(-2, -1)))

    if mode == 'full':
        return ext_c
    if mode == 'same':
        return crop(ext_c, x_shape)
    return crop(ext_c, x_shape - f_half_shape + 1)