def test_resize(self):
    """util.resize zero-pads and crops, centered by default or at shift 0."""
    # (input values, target shape, extra keyword arguments, expected output)
    cases = [
        # Zero-pad
        ([1, 2, 3], [5], {}, [0, 1, 2, 3, 0]),
        ([1, 2, 3], [4], {}, [0, 1, 2, 3]),
        ([1, 2], [5], {}, [0, 1, 2, 0, 0]),
        ([1, 2], [4], {}, [0, 1, 2, 0]),
        # Zero-pad non centered
        ([1, 2, 3], [5], {'oshift': [0]}, [1, 2, 3, 0, 0]),
        # Crop
        ([0, 1, 2, 3, 0], [3], {}, [1, 2, 3]),
        ([0, 1, 2, 3], [3], {}, [1, 2, 3]),
        ([0, 1, 2, 0, 0], [2], {}, [1, 2]),
        ([0, 1, 2, 0], [2], {}, [1, 2]),
        # Crop non centered
        ([1, 2, 3, 0, 0], [3], {'ishift': [0]}, [1, 2, 3]),
    ]
    for values, target_shape, extra, expected in cases:
        resized = util.resize(np.array(values), target_shape, **extra)
        npt.assert_allclose(resized, expected)
def _fft_convolve_adjoint_filter(x, y, mode='full'):
    """FFT-based adjoint of convolution with respect to the filter.

    Correlates the (conjugated, flipped) input ``x`` with ``y`` over the
    last ``ndim = x.ndim - 2`` axes and sums over the batch axis.

    Args:
        x (array): input of shape (batch_size, input_channel, *input_shape).
        y (array): output of shape (batch_size, output_channel, *output_shape).
        mode (str): 'full' (filter_shape = output_shape - input_shape + 1)
            or 'valid' (filter_shape = input_shape - output_shape + 1).

    Returns:
        array: filter of shape (output_channel, input_channel, *filter_shape).

    Raises:
        ValueError: if mode is neither 'full' nor 'valid'.
    """
    # Validate first: an unknown mode previously left filter_shape /
    # pad_shape / shift unbound and failed later with UnboundLocalError.
    if mode not in ('full', 'valid'):
        raise ValueError(
            "mode must be 'full' or 'valid', got {!r}.".format(mode))

    ndim = x.ndim - 2
    batch_size = len(x)
    output_channel = y.shape[1]
    input_channel = x.shape[1]
    output_shape = y.shape[-ndim:]
    input_shape = x.shape[-ndim:]
    if mode == 'full':
        filter_shape = tuple(
            p - m + 1 for m, p in zip(input_shape, output_shape))
        pad_shape = output_shape
    else:  # mode == 'valid'
        filter_shape = tuple(
            m - p + 1 for m, p in zip(input_shape, output_shape))
        pad_shape = input_shape

    dtype = x.dtype
    device = backend.get_device(x)
    xp = device.xp
    axes = range(-ndim, 0)
    with device:
        # Conjugating and flipping x turns the FFT product below into a
        # correlation of x with y.
        x = xp.conj(util.flip(x, axes=axes))

        # Singleton axes so the product broadcasts to
        # (batch, output_channel, input_channel, ...).
        x = x.reshape((batch_size, 1, input_channel) + input_shape)
        y = y.reshape((batch_size, output_channel, 1) + output_shape)

        # Zero-pad both operands at their trailing end (oshift=0).
        x_pad = util.resize(x, (batch_size, 1, input_channel) + pad_shape,
                            oshift=[0] * x.ndim)
        y_pad = util.resize(y, (batch_size, output_channel, 1) + pad_shape,
                            oshift=[0] * y.ndim)

        if np.issubdtype(dtype, np.floating):
            # Real data: use the real-input FFT pair.
            x_fft = xp.fft.rfftn(x_pad, axes=axes, norm='ortho')
            y_fft = xp.fft.rfftn(y_pad, axes=axes, norm='ortho')
            W_fft = xp.sum(x_fft * y_fft, axis=0)  # sum over batch
            W = xp.fft.irfftn(W_fft, pad_shape, axes=axes,
                              norm='ortho').astype(dtype)
        else:
            x_fft = fourier.fft(x_pad, axes=axes, center=False)
            y_fft = fourier.fft(y_pad, axes=axes, center=False)
            W_fft = xp.sum(x_fft * y_fft, axis=0)  # sum over batch
            W = fourier.ifft(W_fft, axes=axes, center=False)

        # Offset of the filter support inside the padded correlation.
        if mode == 'full':
            shift = [0, 0] + [m - 1 for m in input_shape]
        else:
            shift = [0, 0] + [p - 1 for p in output_shape]

        W = util.resize(W, (output_channel, input_channel) + filter_shape,
                        ishift=shift)

        # norm='ortho' forward/inverse transforms leave the result scaled by
        # 1 / sqrt(prod(pad_shape)); undo that. (fourier.fft/ifft are
        # assumed to follow the same orthonormal convention -- confirm.)
        W *= util.prod(pad_shape)**0.5
        return W
def _fft_convolve_adjoint_input(W, y, mode='full'):
    """FFT-based adjoint of convolution with respect to the input.

    Convolves ``y`` with the conjugated, flipped filter ``W`` over the last
    ``ndim = y.ndim - 2`` axes and sums over output channels.

    Args:
        W (array): filter of shape
            (output_channel, input_channel, *filter_shape).
        y (array): output of shape (batch_size, output_channel, *output_shape).
        mode (str): 'full' (input_shape = output_shape - filter_shape + 1)
            or 'valid' (input_shape = output_shape + filter_shape - 1).

    Returns:
        array: input-domain array of shape
            (batch_size, input_channel, *input_shape).

    Raises:
        ValueError: if mode is neither 'full' nor 'valid'.
    """
    # Validate first: an unknown mode previously left input_shape /
    # pad_shape / shift unbound and failed later with UnboundLocalError.
    if mode not in ('full', 'valid'):
        raise ValueError(
            "mode must be 'full' or 'valid', got {!r}.".format(mode))

    ndim = y.ndim - 2
    batch_size = len(y)
    output_channel, input_channel = W.shape[:2]
    output_shape = y.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    if mode == 'full':
        input_shape = tuple(
            p - n + 1 for p, n in zip(output_shape, filter_shape))
        pad_shape = output_shape
    else:  # mode == 'valid'
        input_shape = tuple(
            p + n - 1 for p, n in zip(output_shape, filter_shape))
        pad_shape = input_shape

    dtype = y.dtype
    device = backend.get_device(y)
    xp = device.xp
    axes = range(-ndim, 0)
    with device:
        # Singleton input-channel axis so y broadcasts against W:
        # (batch, oc, 1, ...) * (oc, ic, ...) -> (batch, oc, ic, ...).
        y = y.reshape((batch_size, output_channel, 1) + output_shape)

        # Conjugate and flip the filter to apply its adjoint.
        W = xp.conj(util.flip(W, axes=axes))

        # Zero-pad both operands at their trailing end (oshift=0).
        y_pad = util.resize(y, (batch_size, output_channel, 1) + pad_shape,
                            oshift=[0] * y.ndim)
        W_pad = util.resize(W, (output_channel, input_channel) + pad_shape,
                            oshift=[0] * W.ndim)

        if np.issubdtype(dtype, np.floating):
            # Real data: use the real-input FFT pair.
            y_fft = xp.fft.rfftn(y_pad, axes=axes, norm='ortho')
            W_fft = xp.fft.rfftn(W_pad, axes=axes, norm='ortho')
            # Sum over the output-channel axis.
            x_fft = xp.sum(y_fft * W_fft, axis=-ndim - 2)
            x = xp.fft.irfftn(x_fft, pad_shape, axes=axes,
                              norm='ortho').astype(dtype)
        else:
            y_fft = fourier.fft(y_pad, axes=axes, center=False)
            W_fft = fourier.fft(W_pad, axes=axes, center=False)
            x_fft = xp.sum(y_fft * W_fft, axis=-ndim - 2)
            x = fourier.ifft(x_fft, axes=axes, center=False)

        # Offset of the input support inside the padded result.
        if mode == 'full':
            shift = [0, 0] + [n - 1 for n in filter_shape]
        else:
            shift = [0] * x.ndim

        x = util.resize(x, (batch_size, input_channel) + input_shape,
                        ishift=shift)

        # norm='ortho' forward/inverse transforms leave the result scaled by
        # 1 / sqrt(prod(pad_shape)); undo that. (fourier.fft/ifft are
        # assumed to follow the same orthonormal convention -- confirm.)
        x *= util.prod(pad_shape)**0.5
        return x
def _fft_convolve(x, W, mode='full'):
    """Multi-channel convolution of x with W computed via FFTs.

    Convolves over the last ``ndim = x.ndim - 2`` axes and sums over input
    channels.

    Args:
        x (array): input of shape (batch_size, input_channel, *input_shape).
        W (array): filter of shape
            (output_channel, input_channel, *filter_shape).
        mode (str): 'full' (output_shape = input_shape + filter_shape - 1)
            or 'valid' (output_shape = input_shape - filter_shape + 1).

    Returns:
        array: output of shape (batch_size, output_channel, *output_shape).

    Raises:
        ValueError: if mode is neither 'full' nor 'valid'.
    """
    # Validate first: an unknown mode previously left output_shape /
    # pad_shape / shift unbound and failed later with UnboundLocalError.
    if mode not in ('full', 'valid'):
        raise ValueError(
            "mode must be 'full' or 'valid', got {!r}.".format(mode))

    ndim = x.ndim - 2
    batch_size = len(x)
    output_channel, input_channel = W.shape[:2]
    input_shape = x.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    if mode == 'full':
        output_shape = tuple(
            m + n - 1 for m, n in zip(input_shape, filter_shape))
        pad_shape = output_shape
    else:  # mode == 'valid'
        output_shape = tuple(
            m - n + 1 for m, n in zip(input_shape, filter_shape))
        pad_shape = input_shape

    dtype = x.dtype
    device = backend.get_device(x)
    xp = device.xp
    axes = range(-ndim, 0)
    with device:
        # Singleton output-channel axis so x broadcasts against W:
        # (batch, 1, ic, ...) * (oc, ic, ...) -> (batch, oc, ic, ...).
        x = x.reshape((batch_size, 1, input_channel) + input_shape)

        # Zero-pad both operands at their trailing end (oshift=0).
        x_pad = util.resize(x, (batch_size, 1, input_channel) + pad_shape,
                            oshift=[0] * x.ndim)
        W_pad = util.resize(W, (output_channel, input_channel) + pad_shape,
                            oshift=[0] * W.ndim)

        if np.issubdtype(dtype, np.floating):
            # Real data: use the real-input FFT pair.
            x_fft = xp.fft.rfftn(x_pad, axes=axes, norm='ortho')
            W_fft = xp.fft.rfftn(W_pad, axes=axes, norm='ortho')
            # Sum over the input-channel axis.
            y_fft = xp.sum(x_fft * W_fft, axis=-ndim - 1)
            y = xp.fft.irfftn(y_fft, pad_shape, axes=axes,
                              norm='ortho').astype(dtype)
        else:
            x_fft = fourier.fft(x_pad, axes=axes, center=False)
            W_fft = fourier.fft(W_pad, axes=axes, center=False)
            y_fft = xp.sum(x_fft * W_fft, axis=-ndim - 1)
            y = fourier.ifft(y_fft, axes=axes, center=False)

        # 'valid' output starts at offset filter_shape - 1 of the result.
        if mode == 'full':
            shift = [0] * y.ndim
        else:
            shift = [0, 0] + [n - 1 for n in filter_shape]

        y = util.resize(y, (batch_size, output_channel) + output_shape,
                        ishift=shift)

        # norm='ortho' forward/inverse transforms leave the result scaled by
        # 1 / sqrt(prod(pad_shape)); undo that. (fourier.fft/ifft are
        # assumed to follow the same orthonormal convention -- confirm.)
        y *= util.prod(pad_shape)**0.5
        return y
def nufft(input, coord, oversamp=1.25, width=4):
    """Non-uniform Fast Fourier Transform.

    Args:
        input (array): input signal domain array of shape
            (..., n_{ndim - 1}, ..., n_1, n_0),
            where ndim is specified by coord.shape[-1]. The nufft
            is applied on the last ndim axes, and looped over
            the remaining axes.
        coord (array): Fourier domain coordinate array of shape (..., ndim).
            ndim determines the number of dimensions to apply the nufft.
            coord[..., i] should be scaled to have its range between
            -n_i // 2, and n_i // 2.
        oversamp (float): oversampling factor.
        width (float): interpolation kernel full-width in terms of
            oversampled grid.

    Returns:
        array: Fourier domain data of shape
            input.shape[:-ndim] + coord.shape[:-1].

    References:
        Fessler, J. A., & Sutton, B. P. (2003).
        Nonuniform fast Fourier transforms using min-max interpolation.
        IEEE Transactions on Signal Processing, 51(2), 560-574.
        Beatty, P. J., Nishimura, D. G., & Pauly, J. M. (2005).
        Rapid gridding reconstruction with a minimal oversampling ratio.
        IEEE transactions on medical imaging, 24(6), 799-808.

    """
    ndim = coord.shape[-1]
    # Kaiser-Bessel shape parameter, passed as `param` to the interpolator
    # (formula as in Beatty et al., 2005).
    beta = np.pi * (((width / oversamp) * (oversamp - 0.5))**2 - 0.8)**0.5
    os_shape = _get_oversamp_shape(input.shape, ndim, oversamp)

    # _apodize's return value is not used, so it presumably works in place;
    # copy so the caller's array is left untouched.
    output = input.copy()

    # Apodize
    _apodize(output, ndim, oversamp, width, beta)

    # Zero-pad. Dividing by sqrt(prod(n)) pairs with the unnormalized FFT
    # below (norm=None).
    output /= util.prod(input.shape[-ndim:])**0.5
    output = util.resize(output, os_shape)

    # FFT
    output = fft(output, axes=range(-ndim, 0), norm=None)

    # Interpolate onto the non-uniform coordinates, rescaled to the
    # oversampled grid.
    coord = _scale_coord(coord, input.shape, oversamp)
    output = interp.interpolate(
        output, coord, kernel='kaiser_bessel', width=width, param=beta)
    output /= width**ndim

    return output
def nufft_adjoint(input, coord, oshape=None, oversamp=1.25, width=4):
    """Adjoint non-uniform Fast Fourier Transform.

    Args:
        input (array): input Fourier domain array of shape
            (...) + coord.shape[:-1]. That is, the last dimensions
            of input must match the first dimensions of coord.
            The nufft_adjoint is applied on the last coord.ndim - 1 axes,
            and looped over the remaining axes.
        coord (array): Fourier domain coordinate array of shape (..., ndim).
            ndim determines the number of dimension to apply nufft adjoint.
            coord[..., i] should be scaled to have its range between
            -n_i // 2, and n_i // 2.
        oshape (tuple of ints): output shape of the form
            (..., n_{ndim - 1}, ..., n_1, n_0). If None, the leading
            (batch) dimensions of input are kept and the last ndim
            dimensions are estimated from coord via estimate_shape.
        oversamp (float): oversampling factor.
        width (float): interpolation kernel full-width in terms of
            oversampled grid.

    Returns:
        array: signal domain array with shape specified by oshape.

    See Also:
        :func:`sigpy.nufft.nufft`

    """
    ndim = coord.shape[-1]
    # Kaiser-Bessel shape parameter, matching the forward nufft.
    beta = np.pi * (((width / oversamp) * (oversamp - 0.5))**2 - 0.8)**0.5
    if oshape is None:
        # Batch dims are the leading input.ndim - (coord.ndim - 1) axes.
        # Slice with an explicit positive stop: the former
        # `[:-coord.ndim + 1]` became `[:0]` for a 1-D coord and
        # dropped every batch dimension.
        n_batch = input.ndim - coord.ndim + 1
        oshape = list(input.shape[:n_batch]) + estimate_shape(coord)
    else:
        oshape = list(oshape)

    os_shape = _get_oversamp_shape(oshape, ndim, oversamp)

    # Gridding
    coord = _scale_coord(coord, oshape, oversamp)
    output = interp.gridding(input, coord, os_shape,
                             kernel='kaiser_bessel', width=width, param=beta)
    output /= width**ndim

    # IFFT
    output = ifft(output, axes=range(-ndim, 0), norm=None)

    # Crop. The unnormalized IFFT divides by prod(os_shape), so multiply it
    # back (adjoint of the unnormalized FFT), then apply the
    # 1/sqrt(prod(n)) scaling used by the forward nufft.
    output = util.resize(output, oshape)
    output *= util.prod(os_shape[-ndim:]) / util.prod(oshape[-ndim:])**0.5

    # Apodize
    _apodize(output, ndim, oversamp, width, beta)

    return output
def _ifftc(input, oshape=None, axes=None, norm='ortho'):
    """Centered inverse FFT.

    Resizes ``input`` to ``oshape`` with util.resize, then applies
    ifftshift -> ifftn -> fftshift over ``axes``.

    Args:
        input (array): input array.
        oshape (None or tuple of ints): shape to resize to before the
            transform. Defaults to ``input.shape``.
        axes (None or tuple of ints): axes to transform, normalized with
            util._normalize_axes.
        norm (None or str): normalization mode passed to ``xp.fft.ifftn``.

    Returns:
        array: transformed array of shape ``oshape``.
    """
    ndim = input.ndim
    axes = util._normalize_axes(axes, ndim)
    device = backend.get_device(input)
    xp = device.xp

    if oshape is None:
        oshape = input.shape

    # Run under the input's device context so the resize and the FFT use
    # the same device that holds ``input``.
    with device:
        tmp = util.resize(input, oshape)
        tmp = xp.fft.ifftshift(tmp, axes=axes)
        tmp = xp.fft.ifftn(tmp, axes=axes, norm=norm)
        output = xp.fft.fftshift(tmp, axes=axes)

    return output
def _ifftc(input, oshape=None, axes=None, norm='ortho'):
    """Centered inverse FFT on the input's device.

    The input is resized to ``oshape`` (``input.shape`` when None), shifted
    with ifftshift, inverse-transformed with ifftn using ``norm``, and
    shifted back with fftshift, all over the normalized ``axes``.
    """
    axes = util._normalize_axes(axes, input.ndim)
    device = backend.get_device(input)
    xp = device.xp
    target_shape = input.shape if oshape is None else oshape

    with device:
        shifted = xp.fft.ifftshift(util.resize(input, target_shape),
                                   axes=axes)
        spectrum = xp.fft.ifftn(shifted, axes=axes, norm=norm)
        return xp.fft.fftshift(spectrum, axes=axes)
def iwt(input, oshape, coeff_slices, wave_name='db4', axes=None, level=None):
    """Inverse wavelet transform.

    The packed coefficient array is moved to the CPU, split back into
    PyWavelets coefficients using ``coeff_slices``, reconstructed with
    zero-boundary handling, resized to ``oshape``, and returned on the
    device the input came from.

    Args:
        input (array): Input array.
        oshape (tuple of ints): Output shape.
        coeff_slices (list of slice): Slices to split coefficients.
        wave_name (str): Wavelet name.
        axes (None or tuple of int): Axes to perform wavelet transform.
        level (None or int): Number of wavelet levels. Not used by this
            function; the decomposition depth is presumably implied by
            ``coeff_slices``.
    """
    source_device = backend.get_device(input)
    cpu_array = backend.to_device(input, backend.cpu_device)

    coeffs = pywt.array_to_coeffs(cpu_array, coeff_slices,
                                  output_format='wavedecn')
    reconstruction = pywt.waverecn(coeffs, wave_name, mode='zero', axes=axes)

    cropped = util.resize(reconstruction, oshape)
    return backend.to_device(cropped, source_device)
def fwt(input, wave_name='db4', axes=None, level=None):
    """Forward wavelet transform.

    The input is moved to the CPU, zero-padded so that every dimension (all
    axes, not only ``axes``) has even length, decomposed with PyWavelets
    using zero-boundary handling, packed into a single array, and returned
    on the input's original device.

    Args:
        input (array): Input array.
        wave_name (str): Wavelet name.
        axes (None or tuple of int): Axes to perform wavelet transform.
        level (None or int): Number of wavelet levels.
    """
    source_device = backend.get_device(input)
    cpu_array = backend.to_device(input, backend.cpu_device)

    # Round each length up to the nearest even number.
    even_shape = [size + size % 2 for size in cpu_array.shape]
    padded = util.resize(cpu_array, even_shape)

    coeffs = pywt.wavedecn(padded, wave_name, mode='zero',
                           axes=axes, level=level)
    packed, _slices = pywt.coeffs_to_array(coeffs, axes=axes)

    return backend.to_device(packed, source_device)
def _apply(self, input):
    """Resize input to self.oshape using self.ishift / self.oshift.

    Runs under the input's device context so util.resize works on the
    device holding ``input``.
    """
    with backend.get_device(input):
        return util.resize(input, self.oshape,
                           ishift=self.ishift, oshift=self.oshift)
def _apply(self, input):
    """Zero-pad or crop ``input`` to ``self.oshape`` on the input's device.

    The stored ``self.ishift`` and ``self.oshift`` offsets are forwarded
    to util.resize.
    """
    device = backend.get_device(input)
    with device:
        resized = util.resize(input, self.oshape,
                              ishift=self.ishift, oshift=self.oshift)
    return resized
def _get_dataset(self, i):
    """Load the i-th file in ``self.filepaths`` with np.load and resize it
    to ``self.shape`` without its leading dimension."""
    path = self.filepaths[i]
    data = np.load(path)
    item_shape = self.shape[1:]
    return util.resize(data, item_shape)