def test_convolve_filter_adjoint_full(self):
    """Check ``conv.convolve_filter_adjoint`` in 'full' mode.

    All-ones data/output arrays are pushed through the filter adjoint on
    every available device and for each dtype in ``dtypes``; the result is
    compared with hand-computed filter values.
    """
    mode = 'full'
    devices = [backend.cpu_device]
    if config.cupy_enabled:
        devices.append(backend.Device(0))

    for device in devices:
        xp = device.xp
        with device:
            for dtype in dtypes:
                with self.subTest(dtype=dtype, device=device):
                    # Single channel, odd-length filter.
                    x = xp.ones([1, 3], dtype=dtype)
                    y = xp.ones([1, 5], dtype=dtype)
                    w_shape = [1, 3]
                    w_adj = conv.convolve_filter_adjoint(
                        y, x, w_shape, mode=mode)
                    w_adj = backend.to_device(w_adj)
                    npt.assert_allclose(w_adj, [[3, 3, 3]], atol=1e-5)

                    # Single channel, even-length filter.
                    x = xp.ones([1, 3], dtype=dtype)
                    y = xp.ones([1, 4], dtype=dtype)
                    w_shape = [1, 2]
                    w_adj = conv.convolve_filter_adjoint(
                        y, x, w_shape, mode=mode)
                    w_adj = backend.to_device(w_adj)
                    npt.assert_allclose(w_adj, [[3, 3]], atol=1e-5)

                    # One input channel, two output channels.
                    x = xp.ones([1, 1, 3], dtype=dtype)
                    y = xp.ones([2, 1, 5], dtype=dtype)
                    w_shape = [2, 1, 1, 3]
                    w_adj = conv.convolve_filter_adjoint(
                        y, x, w_shape, mode=mode, multi_channel=True)
                    w_adj = backend.to_device(w_adj, backend.cpu_device)
                    npt.assert_allclose(
                        w_adj,
                        [[[[3, 3, 3]]], [[[3, 3, 3]]]],
                        atol=1e-5)

                    # Multi-channel with a stride of 2 on the last axis.
                    x = xp.ones([1, 1, 3], dtype=dtype)
                    y = xp.ones([2, 1, 3], dtype=dtype)
                    w_shape = [2, 1, 1, 3]
                    strides = [1, 2]
                    w_adj = conv.convolve_filter_adjoint(
                        y, x, w_shape, mode=mode, strides=strides,
                        multi_channel=True)
                    w_adj = backend.to_device(w_adj, backend.cpu_device)
                    npt.assert_allclose(
                        w_adj,
                        [[[[2, 1, 2]]], [[[2, 1, 2]]]],
                        atol=1e-5)
def test_convolve_adjoint_input_full(self):
    """Check ``conv.convolve_adjoint_input`` in 'full' mode.

    All-ones filters and outputs are used, so each recovered input sample
    equals the number of overlapping filter taps (summed over output
    channels in the multi-channel case).
    """
    mode = 'full'
    cpu = backend.cpu_device
    devices = [cpu]
    if config.cupy_enabled:
        devices.append(backend.Device(0))

    for device in devices:
        xp = device.xp
        with device:
            for dtype in [np.float32, np.float64,
                          np.complex64, np.complex128]:
                # Length-3 filter against a length-5 output.
                out = xp.ones([1, 5], dtype=dtype)
                kern = xp.ones([1, 3], dtype=dtype)
                adj = conv.convolve_adjoint_input(kern, out, mode=mode)
                adj = backend.to_device(adj, cpu)
                npt.assert_allclose(adj, [[3, 3, 3]], atol=1e-5)

                # Length-2 filter against a length-4 output.
                out = xp.ones([1, 4], dtype=dtype)
                kern = xp.ones([1, 2], dtype=dtype)
                adj = conv.convolve_adjoint_input(kern, out, mode=mode)
                adj = backend.to_device(adj, cpu)
                npt.assert_allclose(adj, [[2, 2, 2]], atol=1e-5)

                # Two output channels are summed in the adjoint.
                out = xp.ones([2, 1, 5], dtype=dtype)
                kern = xp.ones([2, 1, 3], dtype=dtype)
                adj = conv.convolve_adjoint_input(
                    kern, out, mode=mode, output_multi_channel=True)
                adj = backend.to_device(adj, cpu)
                npt.assert_allclose(adj, [[6, 6, 6]], atol=1e-5)
def test_convolve_valid(self):
    """Check ``conv.convolve`` in 'valid' mode for 1, 2 and 3 dimensions.

    A dirac input is convolved with all-ones filters, so every output
    sample is expected to be one. Single-channel, multi-channel and strided
    multi-channel calls are covered on each device and dtype in ``dtypes``.
    """
    mode = 'valid'
    devices = [backend.cpu_device]
    if config.cupy_enabled:
        devices.append(backend.Device(0))

    for D in [1, 2, 3]:
        # Singleton extents appended after the first spatial axis.
        tail = [1] * (D - 1)
        for device in devices:
            xp = device.xp
            with device:
                for dtype in dtypes:
                    with self.subTest(D=D, dtype=dtype, device=device):
                        # Filter as long as the data: one output sample.
                        data = util.dirac(
                            [3] + tail, device=device, dtype=dtype)
                        filt = xp.ones([3] + tail, dtype=dtype)
                        result = conv.convolve(data, filt, mode=mode)
                        result = backend.to_device(result)
                        npt.assert_allclose(
                            result, np.ones([1] * D), atol=1e-5)

                        # Length-2 filter: two output samples.
                        data = util.dirac(
                            [3] + tail, device=device, dtype=dtype)
                        filt = xp.ones([2] + tail, dtype=dtype)
                        result = conv.convolve(data, filt, mode=mode)
                        result = backend.to_device(result)
                        npt.assert_allclose(
                            result, np.ones([2] + tail), atol=1e-5)

                        # One input channel, two output channels.
                        data = util.dirac(
                            [1, 3] + tail, device=device, dtype=dtype)
                        filt = xp.ones([2, 1, 3] + tail, dtype=dtype)
                        result = conv.convolve(
                            data, filt, mode=mode, multi_channel=True)
                        result = backend.to_device(
                            result, backend.cpu_device)
                        npt.assert_allclose(
                            result, np.ones([2, 1] + tail), atol=1e-5)

                        # Same as above with stride 2 on the first axis.
                        data = util.dirac(
                            [1, 3] + tail, device=device, dtype=dtype)
                        filt = xp.ones([2, 1, 3] + tail, dtype=dtype)
                        strides = [2] + tail
                        result = conv.convolve(
                            data, filt, mode=mode, strides=strides,
                            multi_channel=True)
                        result = backend.to_device(
                            result, backend.cpu_device)
                        npt.assert_allclose(
                            result, np.ones([2, 1] + tail), atol=1e-5)
def test_convolve_adjoint_filter_valid(self):
    """Check ``conv.convolve_adjoint_filter`` in 'valid' mode.

    All-ones inputs and outputs are used with ``ndim = 2`` on every
    available device and for real and complex dtypes.
    """
    mode = 'valid'
    ndim = 2
    cpu = backend.cpu_device
    devices = [cpu]
    if config.cupy_enabled:
        devices.append(backend.Device(0))

    for device in devices:
        xp = device.xp
        with device:
            for dtype in [np.float32, np.float64,
                          np.complex64, np.complex128]:
                # Single output sample: the filter spans the whole input.
                inp = xp.ones([1, 3], dtype=dtype)
                out = xp.ones([1, 1], dtype=dtype)
                kern = conv.convolve_adjoint_filter(
                    inp, out, ndim, mode=mode)
                kern = backend.to_device(kern, cpu)
                npt.assert_allclose(kern, [[1, 1, 1]], atol=1e-5)

                # Two output samples: a length-2 filter.
                inp = xp.ones([1, 3], dtype=dtype)
                out = xp.ones([1, 2], dtype=dtype)
                kern = conv.convolve_adjoint_filter(
                    inp, out, ndim, mode=mode)
                kern = backend.to_device(kern, cpu)
                npt.assert_allclose(kern, [[2, 2]], atol=1e-5)

                # Two output channels yield two identical filters.
                inp = xp.ones([1, 1, 3], dtype=dtype)
                out = xp.ones([2, 1, 1], dtype=dtype)
                kern = conv.convolve_adjoint_filter(
                    inp, out, ndim, mode=mode, output_multi_channel=True)
                kern = backend.to_device(kern, cpu)
                npt.assert_allclose(
                    kern, [[[1, 1, 1]], [[1, 1, 1]]], atol=1e-5)
def test_convolve_full(self):
    """Check ``conv.convolve`` in 'full' mode with a dirac input.

    Convolving a dirac with an all-ones filter must reproduce the filter,
    zero-padded to the full output length, for single- and multi-output-
    channel filters on every available device and dtype.
    """
    mode = 'full'
    cpu = backend.cpu_device
    devices = [cpu]
    if config.cupy_enabled:
        devices.append(backend.Device(0))

    for device in devices:
        xp = device.xp
        with device:
            for dtype in [np.float32, np.float64,
                          np.complex64, np.complex128]:
                # Length-3 filter.
                impulse = util.dirac([1, 3], device=device, dtype=dtype)
                kern = xp.ones([1, 3], dtype=dtype)
                out = conv.convolve(impulse, kern, mode=mode)
                out = backend.to_device(out, cpu)
                npt.assert_allclose(out, [[0, 1, 1, 1, 0]], atol=1e-5)

                # Length-2 filter.
                impulse = util.dirac([1, 3], device=device, dtype=dtype)
                kern = xp.ones([1, 2], dtype=dtype)
                out = conv.convolve(impulse, kern, mode=mode)
                out = backend.to_device(out, cpu)
                npt.assert_allclose(out, [[0, 1, 1, 0]], atol=1e-5)

                # Two output channels.
                impulse = util.dirac([1, 3], device=device, dtype=dtype)
                kern = xp.ones([2, 1, 3], dtype=dtype)
                out = conv.convolve(
                    impulse, kern, mode=mode, output_multi_channel=True)
                out = backend.to_device(out, cpu)
                npt.assert_allclose(
                    out,
                    [[[0, 1, 1, 1, 0]], [[0, 1, 1, 1, 0]]],
                    atol=1e-5)
def convolve(data, filt, mode='full', strides=None, multi_channel=False):
    r"""Convolution that supports multi-dimensional and multi-channel inputs.

    This function follows the signal processing definition of convolution.

    The computation runs on the device that holds ``data``; ``filt`` is
    moved to that device and cast to ``data.dtype`` first. On a non-CPU
    device without cuDNN, the arrays are transferred with
    ``backend.to_device`` (no explicit target), convolved with the CPU
    implementation, and the result is moved back to the original device.

    Args:
        data (array): data array of shape:
            :math:`[..., m_1, ..., m_D]` if multi_channel is False,
            :math:`[..., c_i, m_1, ..., m_D]` otherwise.
        filt (array): filter array of shape:
            :math:`[n_1, ..., n_D]` if multi_channel is False
            :math:`[c_o, c_i, n_1, ..., n_D]` otherwise.
        mode (str): {'full', 'valid'}.
        strides (None or tuple of ints): convolution strides of length D.
        multi_channel (bool): specify if input/output has multiple channels.

    Returns:
        array: output array of shape:
            :math:`[..., p_1, ..., p_D]` if multi_channel is False,
            :math:`[..., c_o, p_1, ..., p_D]` otherwise.

    """
    device = backend.get_device(data)
    filt = backend.to_device(filt, device)
    with device:
        filt = filt.astype(data.dtype, copy=False)

    if device == backend.cpu_device:
        output = _convolve(data, filt, mode=mode, strides=strides,
                           multi_channel=multi_channel)
    else:  # pragma: no cover
        if config.cudnn_enabled:
            if np.issubdtype(data.dtype, np.floating):
                output = _convolve_cuda(data, filt,
                                        mode=mode, strides=strides,
                                        multi_channel=multi_channel)
            else:
                # Complex input is routed through the `_complex` wrapper
                # around the CUDA routine.
                output = _complex(_convolve_cuda, data, filt,
                                  mode=mode, strides=strides,
                                  multi_channel=multi_channel)
        else:
            # No cuDNN: fall back to the forward CPU convolution. The
            # previous code called the data-adjoint routine here and read
            # `output` before it was assigned.
            data = backend.to_device(data)
            filt = backend.to_device(filt)
            output = _convolve(data, filt,
                               mode=mode, strides=strides,
                               multi_channel=multi_channel)
            output = backend.to_device(output, device)

    return output
def _apply(self, input):
    """Grid ``input`` onto an array of shape ``self.oshape``.

    The stored coordinates and kernel are moved to the device holding
    ``input`` before ``interp.gridding`` is called under that device.
    """
    target = backend.get_device(input)
    coord = backend.to_device(self.coord, target)
    kernel = backend.to_device(self.kernel, target)

    with target:
        gridded = interp.gridding(
            input, self.oshape, self.width, kernel, coord)

    return gridded
def test_convolve_data_adjoint_valid(self):
    """Check ``conv.convolve_data_adjoint`` in 'valid' mode.

    All-ones outputs and filters are used; the recovered data is compared
    with hand-computed values for single-channel, multi-channel and strided
    multi-channel calls on each device and dtype in ``dtypes``.
    """
    mode = 'valid'
    devices = [backend.cpu_device]
    if config.cupy_enabled:
        devices.append(backend.Device(0))

    for device in devices:
        xp = device.xp
        with device:
            for dtype in dtypes:
                with self.subTest(dtype=dtype, device=device):
                    # One output sample spread over a length-3 filter.
                    y = xp.ones([1, 1], dtype=dtype)
                    w = xp.ones([1, 3], dtype=dtype)
                    x_shape = [1, 3]
                    x_adj = conv.convolve_data_adjoint(
                        y, w, x_shape, mode=mode)
                    x_adj = backend.to_device(x_adj)
                    npt.assert_allclose(x_adj, [[1, 1, 1]], atol=1e-5)

                    # Two output samples, length-2 filter.
                    y = xp.ones([1, 2], dtype=dtype)
                    w = xp.ones([1, 2], dtype=dtype)
                    x_shape = [1, 3]
                    x_adj = conv.convolve_data_adjoint(
                        y, w, x_shape, mode=mode)
                    x_adj = backend.to_device(x_adj)
                    npt.assert_allclose(x_adj, [[1, 2, 1]], atol=1e-5)

                    # Two output channels are summed into one input channel.
                    y = xp.ones([2, 1, 1], dtype=dtype)
                    w = xp.ones([2, 1, 1, 3], dtype=dtype)
                    x_shape = [1, 1, 3]
                    x_adj = conv.convolve_data_adjoint(
                        y, w, x_shape, mode=mode, multi_channel=True)
                    x_adj = backend.to_device(x_adj, backend.cpu_device)
                    npt.assert_allclose(x_adj, [[[2, 2, 2]]], atol=1e-5)

                    # Stride 2 on the last axis with a length-4 input.
                    y = xp.ones([2, 1, 1], dtype=dtype)
                    w = xp.ones([2, 1, 1, 3], dtype=dtype)
                    x_shape = [1, 1, 4]
                    strides = [1, 2]
                    x_adj = conv.convolve_data_adjoint(
                        y, w, x_shape, mode=mode, strides=strides,
                        multi_channel=True)
                    x_adj = backend.to_device(x_adj, backend.cpu_device)
                    npt.assert_allclose(x_adj, [[[2, 2, 2, 0]]], atol=1e-5)
def _apply(self, input):
    """Interpolate ``input`` at the scaled and shifted stored coordinates.

    Coordinates, kernel and shift are moved to the device holding ``input``;
    the sample locations are ``coord * self.scale + shift``.
    """
    target = backend.get_device(input)
    coord = backend.to_device(self.coord, target)
    kernel = backend.to_device(self.kernel, target)
    shift = backend.to_device(self.shift, target)

    with target:
        scaled_coord = coord * self.scale + shift
        return interp.interpolate(input, self.width, kernel, scaled_coord)
def _scale_coord(coord, shape, oversamp):
    """Scale and shift coordinates for the oversampled grid.

    For each of the last ``ndim`` entries ``i`` of ``shape`` (``ndim`` being
    ``coord.shape[-1]``), the coordinate is multiplied by
    ``_get_ugly_number(oversamp * i) / i`` and offset by
    ``_get_ugly_number(oversamp * i) // 2``.

    Args:
        coord (array): coordinate array whose last axis has length ndim.
        shape (sequence of ints): array shape; only the last ndim entries
            are used.
        oversamp (float): oversampling factor.

    Returns:
        array: transformed coordinates on the same device as ``coord``.

    """
    ndim = coord.shape[-1]
    device = backend.get_device(coord)
    dims = shape[-ndim:]

    scale_factors = [_get_ugly_number(oversamp * i) / i for i in dims]
    scale = backend.to_device(scale_factors, device)

    offsets = [_get_ugly_number(oversamp * i) // 2 for i in dims]
    shift = backend.to_device(offsets, device)

    with device:
        return scale * coord + shift
def test_convolve_full(self):
    """Check ``conv.convolve`` in 'full' mode with a dirac input.

    Covers single-channel, multi-channel and strided multi-channel calls
    for real and complex dtypes on every available device.
    """
    mode = 'full'
    devices = [backend.cpu_device]
    if config.cupy_enabled:
        devices.append(backend.Device(0))

    dtypes = [np.float32, np.float64, np.complex64, np.complex128]
    for device in devices:
        xp = device.xp
        with device:
            for dtype in dtypes:
                with self.subTest(dtype=dtype, device=device):
                    # Length-3 filter.
                    impulse = util.dirac(
                        [1, 3], device=device, dtype=dtype)
                    kern = xp.ones([1, 3], dtype=dtype)
                    result = conv.convolve(impulse, kern, mode=mode)
                    result = backend.to_device(result)
                    npt.assert_allclose(
                        result, [[0, 1, 1, 1, 0]], atol=1e-5)

                    # Length-2 filter.
                    impulse = util.dirac(
                        [1, 3], device=device, dtype=dtype)
                    kern = xp.ones([1, 2], dtype=dtype)
                    result = conv.convolve(impulse, kern, mode=mode)
                    result = backend.to_device(result)
                    npt.assert_allclose(
                        result, [[0, 1, 1, 0]], atol=1e-5)

                    # One input channel, two output channels.
                    impulse = util.dirac(
                        [1, 1, 3], device=device, dtype=dtype)
                    kern = xp.ones([2, 1, 1, 3], dtype=dtype)
                    result = conv.convolve(
                        impulse, kern, mode=mode, multi_channel=True)
                    result = backend.to_device(result, backend.cpu_device)
                    npt.assert_allclose(
                        result,
                        [[[0, 1, 1, 1, 0]], [[0, 1, 1, 1, 0]]],
                        atol=1e-5)

                    # Multi-channel with stride 2 on the last axis.
                    impulse = util.dirac(
                        [1, 1, 3], device=device, dtype=dtype)
                    kern = xp.ones([2, 1, 1, 3], dtype=dtype)
                    strides = [1, 2]
                    result = conv.convolve(
                        impulse, kern, mode=mode, strides=strides,
                        multi_channel=True)
                    result = backend.to_device(result)
                    npt.assert_allclose(
                        result, [[[0, 1, 0]], [[0, 1, 0]]], atol=1e-5)
def _apply(self, input):
    """Convolve the stored data with ``input`` used as the filter.

    ``self.data`` is first moved to the device holding ``input``.
    """
    target = backend.get_device(input)
    data = backend.to_device(self.data, target)
    options = dict(mode=self.mode,
                   strides=self.strides,
                   multi_channel=self.multi_channel)
    return conv.convolve(data, input, **options)
def _apply(self, input):
    """Interpolate ``input`` at the stored coordinates.

    Runs under the device holding ``input``; the coordinates are moved to
    that device and the kernel name, width and parameter are forwarded as
    keyword arguments.
    """
    target = backend.get_device(input)
    with target:
        coord = backend.to_device(self.coord, target)
        samples = interp.interpolate(
            input, coord,
            kernel=self.kernel, width=self.width, param=self.param)

    return samples
def _apply(self, input):
    """Convolve ``input`` with the stored filter.

    ``self.filt`` is moved to the device holding ``input`` and the
    convolution is evaluated under that device.
    """
    target = backend.get_device(input)
    filt = backend.to_device(self.filt, target)

    with target:
        convolved = conv.convolve(
            input, filt,
            mode=self.mode,
            strides=self.strides,
            multi_channel=self.multi_channel)

    return convolved
def _apply(self, input):
    """Apply the adjoint NUFFT to ``input`` at the stored coordinates.

    Runs under the device holding ``input``; the output shape, oversampling
    factor and kernel width come from the instance.
    """
    target = backend.get_device(input)
    with target:
        coord = backend.to_device(self.coord, target)
        image = fourier.nufft_adjoint(
            input, coord, self.oshape,
            oversamp=self.oversamp, width=self.width)

    return image
def soft_thresh(lamda, input):
    r"""Soft threshold.

    Performs:

    .. math::
        (| x | - \lambda)_+  \text{sgn}(x)

    The NumPy implementation is used when ``input`` lives on a NumPy
    device; otherwise the CUDA implementation is used, with a scalar
    ``lamda`` moved to the device of ``input`` first.

    Args:
        lamda (float, or array): Threshold parameter.
        input (array)

    Returns:
        array: soft-thresholded result.

    """
    device = backend.get_device(input)
    on_numpy = device.xp == np

    if on_numpy:
        return _soft_thresh(lamda, input)
    else:  # pragma: no cover
        threshold = lamda
        if np.isscalar(threshold):
            threshold = backend.to_device(threshold, device)

        return _soft_thresh_cuda(threshold, input)
def _apply(self, input):
    """Apply the filter adjoint of the convolution to ``input``.

    ``self.data`` is moved to the device holding ``input``; the result has
    the shape given by ``self.oshape``.
    """
    target = backend.get_device(input)
    data = backend.to_device(self.data, target)

    with target:
        filt_adj = conv.convolve_filter_adjoint(
            input, data, self.oshape,
            mode=self.mode,
            strides=self.strides,
            multi_channel=self.multi_channel)

    return filt_adj
def axpy(y, a, x):
    """Compute y = a * x + y.

    ``y`` is updated in place, so the computation runs on the device that
    holds ``y``; ``x`` and ``a`` are moved to that device first.

    Args:
        y (array): Output array.
        a (scalar): Input scalar.
        x (array): Input array.

    """
    # The in-place output decides the device (same convention as xpay).
    # Taking it from `x` made the transfer of `x` below a no-op.
    device = backend.get_device(y)
    x = backend.to_device(x, device)
    a = backend.to_device(a, device)

    with device:
        if device == backend.cpu_device:
            _axpy(y, a, x, out=y)
        else:
            # Output array goes last, mirroring `_xpay_cuda(a, x, y)`.
            # NOTE(review): confirm against the `_axpy_cuda` kernel
            # definition, which is not visible here.
            _axpy_cuda(a, x, y)
def xpay(y, a, x):
    """Compute y = x + a * y.

    The update is written into ``y``; ``x`` and ``a`` are moved to the
    device holding ``y`` before the kernel is invoked.

    Args:
        y (array): Output array.
        a (scalar): Input scalar.
        x (array): Input array.

    """
    target = backend.get_device(y)
    x = backend.to_device(x, target)
    a = backend.to_device(a, target)

    with target:
        on_cpu = target == backend.cpu_device
        if on_cpu:
            _xpay(y, a, x, out=y)
        else:
            _xpay_cuda(a, x, y)
def _apply(self, input):
    """Multiply ``input`` by the stored matrix from the right.

    Computes ``matmul(input, mat)`` on the device holding ``input``. When
    ``self.adjoint`` is set, ``mat`` is replaced by its conjugate with the
    last two axes swapped.
    """
    target = backend.get_device(input)
    xp = target.xp
    mat = backend.to_device(self.mat, target)

    with target:
        operand = xp.conj(mat).swapaxes(-1, -2) if self.adjoint else mat
        return xp.matmul(input, operand)
def nufft(input, coord, oversamp=1.25, width=4.0, n=128):
    """Non-uniform Fast Fourier Transform.

    The input is copied, apodized, normalized, resized to the oversampled
    shape, Fourier transformed over the last ndim axes, and finally
    interpolated at the scaled coordinates with a Kaiser-Bessel kernel.

    Args:
        input (array): input signal domain array of shape
            (..., n_{ndim - 1}, ..., n_1, n_0),
            where ndim is specified by coord.shape[-1]. The nufft
            is applied on the last ndim axes, and looped over
            the remaining axes.
        coord (array): Fourier domain coordinate array of shape (..., ndim).
            ndim determines the number of dimensions to apply the nufft.
            coord[..., i] should be scaled to have its range between
            -n_i // 2, and n_i // 2.
        oversamp (float): oversampling factor.
        width (float): interpolation kernel full-width in terms of
            oversampled grid.
        n (int): number of sampling points of the interpolation kernel.

    Returns:
        array: Fourier domain data of shape
            input.shape[:-ndim] + coord.shape[:-1].

    References:
        Fessler, J. A., & Sutton, B. P. (2003).
        Nonuniform fast Fourier transforms using min-max interpolation
        IEEE Transactions on Signal Processing, 51(2), 560-574.
        Beatty, P. J., Nishimura, D. G., & Pauly, J. M. (2005).
        Rapid gridding reconstruction with a minimal oversampling ratio.
        IEEE transactions on medical imaging, 24(6), 799-808.

    """
    device = backend.get_device(input)
    ndim = coord.shape[-1]
    # Kernel shape parameter derived from width and oversampling factor.
    beta = np.pi * (((width / oversamp) * (oversamp - 0.5))**2 - 0.8)**0.5
    os_shape = _get_oversamp_shape(input.shape, ndim, oversamp)

    with device:
        # Work on a copy so the caller's array is left untouched.
        image = input.copy()

        # Apodize
        _apodize(image, ndim, oversamp, width, beta)

        # Zero-pad
        image /= util.prod(input.shape[-ndim:])**0.5
        image = util.resize(image, os_shape)

        # FFT
        spectrum = fft(image, axes=range(-ndim, 0), norm=None)

        # Interpolate
        os_coord = _scale_coord(
            backend.to_device(coord, device), input.shape, oversamp)
        kernel = _get_kaiser_bessel_kernel(
            n, width, beta, os_coord.dtype, device)
        samples = interp.interpolate(spectrum, width, kernel, os_coord)

    return samples
def nufft_adjoint(input, coord, oshape=None, oversamp=1.25, width=4.0, n=128):
    """Adjoint non-uniform Fast Fourier Transform.

    The input is gridded onto the oversampled grid with a Kaiser-Bessel
    kernel, inverse Fourier transformed over the last ndim axes, resized
    to ``oshape``, rescaled and apodized.

    Args:
        input (array): input Fourier domain array of shape
            (...) + coord.shape[:-1]. That is, the last dimensions
            of input must match the first dimensions of coord.
            The nufft_adjoint is applied on the last coord.ndim - 1 axes,
            and looped over the remaining axes.
        coord (array): Fourier domain coordinate array of shape (..., ndim).
            ndim determines the number of dimension to apply nufft adjoint.
            coord[..., i] should be scaled to have its range between
            -n_i // 2, and n_i // 2.
        oshape (tuple of ints): output shape of the form
            (..., n_{ndim - 1}, ..., n_1, n_0). When None, it is built from
            the leading axes of input and ``estimate_shape(coord)``.
        oversamp (float): oversampling factor.
        width (float): interpolation kernel full-width in terms of
            oversampled grid.
        n (int): number of sampling points of the interpolation kernel.

    Returns:
        array: signal domain array with shape specified by oshape.

    See Also:
        :func:`sigpy.nufft.nufft`

    """
    device = backend.get_device(input)
    ndim = coord.shape[-1]
    # Kernel shape parameter derived from width and oversampling factor.
    beta = np.pi * (((width / oversamp) * (oversamp - 0.5))**2 - 0.8)**0.5

    if oshape is not None:
        oshape = list(oshape)
    else:
        batch_shape = list(input.shape[:-coord.ndim + 1])
        oshape = batch_shape + estimate_shape(coord)

    os_shape = _get_oversamp_shape(oshape, ndim, oversamp)

    with device:
        # Gridding
        os_coord = _scale_coord(
            backend.to_device(coord, device), oshape, oversamp)
        kernel = _get_kaiser_bessel_kernel(
            n, width, beta, os_coord.dtype, device)
        grid = interp.gridding(input, os_shape, width, kernel, os_coord)

        # IFFT
        image = ifft(grid, axes=range(-ndim, 0), norm=None)

        # Crop
        image = util.resize(image, oshape)
        image *= util.prod(os_shape[-ndim:]) / util.prod(oshape[-ndim:])**0.5

        # Apodize
        _apodize(image, ndim, oversamp, width, beta)

    return image
def asscalar(input):
    """Returns input array as scalar.

    The array is moved to the CPU device and converted with ``.item()``.
    This is what ``np.asscalar`` did internally; that function was
    deprecated in NumPy 1.16 and removed in NumPy 1.23.

    Args:
        input (array): Input array

    Returns:
        scalar.

    """
    return backend.to_device(input, backend.cpu_device).item()
def fwt(input, wave_name='db4', axes=None, level=None):
    """Forward wavelet transform.

    The input is moved to the CPU, resized so that every dimension has an
    even length, decomposed with ``pywt.wavedecn`` (mode 'zero'), packed
    into a single coefficient array, and moved back to the original device.

    Args:
        input (array): Input array.
        axes (None or tuple of int): Axes to perform wavelet transform.
        wave_name (str): Wavelet name.
        level (None or int): Number of wavelet levels.

    Returns:
        array: wavelet coefficient array on the device of ``input``.

    """
    origin = backend.get_device(input)
    cpu_input = backend.to_device(input, backend.cpu_device)

    # Round each dimension up to the next even number.
    even_shape = [((i + 1) // 2) * 2 for i in cpu_input.shape]
    padded = util.resize(cpu_input, even_shape)

    coeffs = pywt.wavedecn(
        padded, wave_name, mode='zero', axes=axes, level=level)
    coeff_array, _ = pywt.coeffs_to_array(coeffs, axes=axes)

    return backend.to_device(coeff_array, origin)
def iwt(input, oshape, coeff_slices, wave_name='db4', axes=None, level=None):
    """Inverse wavelet transform.

    The coefficient array is moved to the CPU, unpacked with
    ``pywt.array_to_coeffs``, reconstructed with ``pywt.waverecn``
    (mode 'zero'), resized to ``oshape``, and moved back to the original
    device. ``level`` is accepted but not used by this function.

    Args:
        input (array): Input array.
        oshape (tuple of ints): Output shape.
        coeff_slices (list of slice): Slices to split coefficients.
        axes (None or tuple of int): Axes to perform wavelet transform.
        wave_name (str): Wavelet name.
        level (None or int): Number of wavelet levels.

    Returns:
        array: reconstructed array of shape ``oshape`` on the device of
            ``input``.

    """
    origin = backend.get_device(input)
    cpu_coeffs = backend.to_device(input, backend.cpu_device)

    coeffs = pywt.array_to_coeffs(
        cpu_coeffs, coeff_slices, output_format='wavedecn')
    recon = pywt.waverecn(coeffs, wave_name, mode='zero', axes=axes)
    recon = util.resize(recon, oshape)

    return backend.to_device(recon, origin)
def _summarize(self):
    """Record the objective value and refresh the progress-bar postfix.

    When ``save_objective_values`` is set, the current objective is
    appended to ``objective_values``. When ``show_pbar`` is set, the bar
    shows the latest objective if it is being saved, and the algorithm
    residual otherwise.
    """
    if self.save_objective_values:
        self.objective_values.append(self.objective())

    if not self.show_pbar:
        return

    if self.save_objective_values:
        latest = self.objective_values[-1]
        self.pbar.set_postfix(obj='{0:.2E}'.format(latest))
    else:
        # Bring the residual to the CPU before formatting it.
        resid = backend.to_device(self.alg.resid, backend.cpu_device)
        self.pbar.set_postfix(resid='{0:.2E}'.format(resid))
def _apply(self, input):
    """Convolve the stored array ``x`` with ``input`` used as the filter.

    ``self.x`` is moved to the device holding ``input`` and cast to the
    dtype of ``input`` (without copying when already matching).
    """
    target = backend.get_device(input)
    x = backend.to_device(self.x, target)

    with target:
        x = x.astype(input.dtype, copy=False)
        convolved = conv.convolve(
            x, input,
            mode=self.mode,
            input_multi_channel=self.input_multi_channel,
            output_multi_channel=self.output_multi_channel)

    return convolved
def _apply(self, input):
    """Apply the input adjoint of the convolution with the stored filter.

    ``self.W`` is moved to the device holding ``input`` and cast to the
    dtype of ``input`` (without copying when already matching).
    """
    target = backend.get_device(input)
    W = backend.to_device(self.W, target)

    with target:
        W = W.astype(input.dtype, copy=False)
        adjoint = conv.convolve_adjoint_input(
            W, input,
            mode=self.mode,
            input_multi_channel=self.input_multi_channel,
            output_multi_channel=self.output_multi_channel)

    return adjoint
def interpolate(input, width, kernel, coord):
    """Interpolation from array to points specified by coordinates.

    Leading (batch) axes of ``input`` and the point axes of ``coord`` are
    flattened, a dimension-specific routine is chosen with
    ``_select_interpolate``, and the result is reshaped to
    ``batch_shape + coord.shape[:-1]``.

    Args:
        input (array): Input array of shape [..., ny, nx]
        width (float): Interpolation kernel width.
        kernel (array): Interpolation kernel.
        coord (array): Coordinate array of shape [..., ndim]

    Returns:
        output (array): Output array of coord.shape[:-1]

    """
    ndim = coord.shape[-1]

    batch_shape = input.shape[:-ndim]
    pts_shape = coord.shape[:-1]
    batch_size = util.prod(batch_shape)
    npts = util.prod(pts_shape)

    device = backend.get_device(input)
    xp = device.xp
    isreal = np.issubdtype(input.dtype, np.floating)
    coord = backend.to_device(coord, device)
    kernel = backend.to_device(kernel, device)

    with device:
        flat_input = input.reshape([batch_size] + list(input.shape[-ndim:]))
        flat_coord = coord.reshape([npts, ndim])
        output = xp.zeros([batch_size, npts], dtype=flat_input.dtype)

        _interpolate = _select_interpolate(ndim, npts, device, isreal)
        if device == backend.cpu_device:
            # The CPU routine takes the output buffer first.
            _interpolate(output, flat_input, width, kernel, flat_coord)
        else:  # pragma: no cover
            # The GPU routine takes the output buffer last, plus a size.
            _interpolate(flat_input, width, kernel, flat_coord, output,
                         size=npts)

    return output.reshape(batch_shape + pts_shape)
def xpay(y, a, x):
    """Compute y = x + a * y.

    The NumPy kernel is used when ``y`` lives on a NumPy device; otherwise
    the CUDA kernel is used, with a scalar ``a`` moved to the device of
    ``y`` first.

    Args:
        y (array): Output array.
        a (scalar or array): Input scalar.
        x (array): Input array.

    """
    device = backend.get_device(y)

    if device.xp == np:
        _xpay(y, a, x, out=y)
        return

    if np.isscalar(a):
        a = backend.to_device(a, device)

    _xpay_cuda(a, x, y)