def _get_convolve_adjoint_input_params(W, y,
                                       input_multi_channel,
                                       output_multi_channel):
    ndim = W.ndim - input_multi_channel - output_multi_channel
    output_shape = y.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    batch_shape = y.shape[:-ndim - output_multi_channel]
    batch_size = util.prod(batch_shape)

    if y.dtype != W.dtype:
        raise TypeError(
            'y and W must have the same dtype, got {} and {}.'.format(
                y.dtype, W.dtype))

    if backend.get_device(y) != backend.get_device(W):
        raise TypeError(
            'y and W must be on the same device, got {} and {}.'.format(
                backend.get_device(y), backend.get_device(W)))

    if input_multi_channel:
        input_channel = W.shape[-ndim - 1]
    else:
        input_channel = 1

    if output_multi_channel:
        output_channel = y.shape[-ndim - 1]
    else:
        output_channel = 1

    return ndim, output_shape, filter_shape, batch_shape, \
        batch_size, input_channel, output_channel
def _get_convolve_adjoint_filter_params(x, y, ndim,
                                        input_multi_channel,
                                        output_multi_channel):
    output_shape = y.shape[-ndim:]
    input_shape = x.shape[-ndim:]
    batch_shape = y.shape[:-ndim - output_multi_channel]
    batch_size = util.prod(batch_shape)

    if x.dtype != y.dtype:
        raise TypeError(
            'x and y must have the same dtype, got {} and {}.'.format(
                x.dtype, y.dtype))

    if backend.get_device(y) != backend.get_device(x):
        raise TypeError(
            'y and x must be on the same device, got {} and {}.'.format(
                backend.get_device(y), backend.get_device(x)))

    if input_multi_channel:
        input_channel = x.shape[-ndim - 1]
    else:
        input_channel = 1

    if output_multi_channel:
        output_channel = y.shape[-ndim - 1]
    else:
        output_channel = 1

    return input_shape, output_shape, batch_shape, \
        batch_size, input_channel, output_channel
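# A minimal shape check for the two helpers above, using plain numpy arrays
# in place of device arrays (the sizes are hypothetical, for illustration):
import numpy as np

W = np.zeros((4, 2, 3, 3), dtype=np.complex64)    # (oc, ic) + filter_shape
y = np.zeros((5, 4, 10, 10), dtype=np.complex64)  # (batch, oc) + output_shape
ndim = W.ndim - True - True  # the multi-channel flags count as 1 in arithmetic
assert ndim == 2
assert y.shape[-ndim:] == (10, 10)      # output_shape
assert W.shape[-ndim:] == (3, 3)        # filter_shape
assert y.shape[:-ndim - True] == (5,)   # batch_shape
assert W.shape[-ndim - 1] == 2          # input_channel
assert y.shape[-ndim - 1] == 4          # output_channel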
def __init__(self, A, y, x=None, proxg=None, lamda=0, G=None, g=None,
             z=None, solver=None, max_iter=100, P=None, alpha=None,
             max_power_iter=30, accelerate=True, tau=None, sigma=None,
             rho=1, max_cg_iter=10, tol=0, save_objective_values=False,
             show_pbar=True, leave_pbar=True):
    self.A = A
    self.y = y
    self.x = x
    self.proxg = proxg
    self.lamda = lamda
    self.G = G
    self.g = g
    self.z = z
    self.solver = solver
    self.max_iter = max_iter
    self.P = P
    self.alpha = alpha
    self.max_power_iter = max_power_iter
    self.accelerate = accelerate
    self.tau = tau
    self.sigma = sigma
    self.rho = rho
    self.max_cg_iter = max_cg_iter
    self.tol = tol
    self.save_objective_values = save_objective_values
    self.show_pbar = show_pbar
    self.leave_pbar = leave_pbar

    self.y_device = backend.get_device(y)
    if self.x is None:
        with self.y_device:
            self.x = self.y_device.xp.zeros(A.ishape, dtype=y.dtype)

    self.x_device = backend.get_device(self.x)
    self._get_alg()
    if self.save_objective_values:
        self.objective_values = [self.objective()]

    super().__init__(self.alg, show_pbar=show_pbar, leave_pbar=leave_pbar)
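# A minimal usage sketch for the app constructed above, assuming it is the
# LinearLeastSquares App exposed as sigpy.app.LinearLeastSquares (an
# assumption about where this __init__ lives) and that sigpy is installed:
import numpy as np
import sigpy as sp

A = sp.linop.FFT((8,))                  # forward model: unitary FFT
x_true = (np.random.randn(8) + 1j * np.random.randn(8)).astype(np.complex64)
y = A(x_true)                           # simulated measurements
x_hat = sp.app.LinearLeastSquares(A, y, max_iter=50).run()
assert np.allclose(x_hat, x_true, atol=1e-3)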
def _apply(self, input):
    device = backend.get_device(input)
    with device:
        input = device.xp.conj(input)

    output = self.A(input)
    device = backend.get_device(output)
    with device:
        return device.xp.conj(output)
def __init__(self, A, y, proxg, eps, x=None, G=None, max_iter=100,
             tau=None, sigma=None, show_pbar=True):
    self.y = y
    self.x = x

    self.y_device = backend.get_device(y)
    if self.x is None:
        with self.y_device:
            self.x = self.y_device.xp.zeros(A.ishape, dtype=self.y.dtype)

    self.x_device = backend.get_device(self.x)
    if G is None:
        self.max_eig_app = MaxEig(A.H * A, dtype=self.x.dtype,
                                  device=self.x_device,
                                  show_pbar=show_pbar)

        proxfc = prox.Conj(prox.L2Proj(A.oshape, eps, y=y))
    else:
        proxf1 = prox.L2Proj(A.oshape, eps, y=y)
        proxf2 = proxg
        proxfc = prox.Conj(prox.Stack([proxf1, proxf2]))
        proxg = prox.NoOp(A.ishape)
        A = linop.Vstack([A, G])

    if tau is None or sigma is None:
        max_eig = MaxEig(A.H * A, dtype=self.x.dtype,
                         device=self.x_device,
                         show_pbar=show_pbar).run()
        tau = 1
        sigma = 1 / max_eig

    with self.y_device:
        self.u = self.y_device.xp.zeros(A.oshape, dtype=self.y.dtype)

    alg = PrimalDualHybridGradient(proxfc, proxg, A, A.H, self.x,
                                   self.u, tau, sigma, max_iter=max_iter)
    super().__init__(alg, show_pbar=show_pbar)
def _apply(self, input):
    device = backend.get_device(input)
    x = backend.to_device(self.x, device)
    with device:
        x = x.astype(input.dtype, copy=False)
        return conv.convolve(
            x, input, mode=self.mode,
            input_multi_channel=self.input_multi_channel,
            output_multi_channel=self.output_multi_channel)
def _apply(self, input):
    device = backend.get_device(input)
    W = backend.to_device(self.W, device)
    with device:
        W = W.astype(input.dtype, copy=False)
        return conv.convolve_adjoint_input(
            W, input, mode=self.mode,
            input_multi_channel=self.input_multi_channel,
            output_multi_channel=self.output_multi_channel)
def __init__(self, A, y, x=None, proxg=None, lamda=0, G=None, g=None,
             R=None, mu=0, z=0, alg_name=None, max_iter=100, P=None,
             alpha=None, max_power_iter=30, accelerate=True, tau=None,
             sigma=None, save_objective_values=False, show_pbar=True):
    self.A = A
    self.y = y
    self.x = x
    self.proxg = proxg
    self.lamda = lamda
    self.G = G
    self.g = g
    self.R = R
    self.mu = mu
    self.z = z
    self.alg_name = alg_name
    self.max_iter = max_iter
    self.P = P
    self.alpha = alpha
    self.max_power_iter = max_power_iter
    self.accelerate = accelerate
    self.tau = tau
    self.sigma = sigma
    self.save_objective_values = save_objective_values
    self.show_pbar = show_pbar

    self.y_device = backend.get_device(y)
    if self.x is None:
        with self.y_device:
            self.x = self.y_device.xp.zeros(A.ishape, dtype=y.dtype)

    self.x_device = backend.get_device(self.x)
    self._get_alg()
    if self.save_objective_values:
        self.objective_values = []
def __init__(self, proxfc, proxg, A, AH, x, u, tau, sigma,
             theta=1, gamma_primal=0, gamma_dual=0, max_iter=100, tol=0):
    self.proxfc = proxfc
    self.proxg = proxg
    self.tol = tol
    self.A = A
    self.AH = AH

    self.u = u
    self.x = x
    self.tau = tau
    self.sigma = sigma
    self.theta = theta
    self.gamma_primal = gamma_primal
    self.gamma_dual = gamma_dual
    self.x_device = backend.get_device(x)
    self.u_device = backend.get_device(u)
    with self.x_device:
        self.x_ext = self.x.copy()

    if self.gamma_primal > 0:
        xp = self.x_device.xp
        with self.x_device:
            self.tau_min = xp.amin(xp.abs(tau)).item()

    if self.gamma_dual > 0:
        xp = self.u_device.xp
        with self.u_device:
            self.sigma_min = xp.amin(xp.abs(sigma)).item()

    self.resid = np.inf
    super().__init__(max_iter)
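# A self-contained numpy sketch of the primal-dual hybrid gradient iteration
# this class drives (the function and names below are local to the sketch,
# not the class API): alternate a dual proximal ascent step, a primal
# proximal descent step, and over-relaxation with parameter theta.
import numpy as np

def pdhg(proxfc, proxg, A, AH, x, u, tau, sigma, theta=1.0, max_iter=100):
    x_ext = x.copy()
    for _ in range(max_iter):
        u = proxfc(sigma, u + sigma * A(x_ext))   # dual ascent step
        x_new = proxg(tau, x - tau * AH(u))       # primal descent step
        x_ext = x_new + theta * (x_new - x)       # over-relaxation
        x = x_new
    return x

# Example: least squares with f(z) = 0.5||z - b||^2 and g = 0; the prox of
# the conjugate f* works out to v -> (v - sigma * b) / (1 + sigma).
A_mat = np.random.randn(6, 4)
b = np.random.randn(6)
L = np.linalg.norm(A_mat, 2)
tau = sigma = 0.9 / L        # step sizes satisfy tau * sigma * L**2 < 1
x_hat = pdhg(lambda s, v: (v - s * b) / (1 + s),
             lambda t, v: v,  # g = 0, so its prox is the identity
             lambda v: A_mat @ v, lambda w: A_mat.T @ w,
             np.zeros(4), np.zeros(6), tau, sigma, max_iter=2000)
# x_hat approaches the least-squares solution np.linalg.lstsq(A_mat, b)[0].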
def _apply(self, input):
    data = backend.to_device(self.data, backend.get_device(input))
    return conv.convolve(data, input, mode=self.mode,
                         strides=self.strides,
                         multi_channel=self.multi_channel)
def _apply(self, input):
    device = backend.get_device(input)
    filt = backend.to_device(self.filt, device)
    with device:
        return conv.convolve(input, filt, mode=self.mode,
                             strides=self.strides,
                             multi_channel=self.multi_channel)
def to_pytorch(array, requires_grad=True):  # pragma: no cover
    """Zero-copy conversion from numpy/cupy array to pytorch tensor.

    For complex array input, returns a tensor with shape + [2],
    where tensor[..., 0] and tensor[..., 1] represent the real
    and imaginary parts.

    Args:
        array (numpy/cupy array): input.
        requires_grad (bool): Set ``.requires_grad`` on the output tensor.

    Returns:
        PyTorch tensor.

    """
    import torch
    from torch.utils.dlpack import from_dlpack

    device = backend.get_device(array)
    if not np.issubdtype(array.dtype, np.floating):
        with device:
            shape = array.shape
            array = array.view(dtype=array.real.dtype)
            array = array.reshape(shape + (2, ))

    if device == backend.cpu_device:
        tensor = torch.from_numpy(array)
    else:
        tensor = from_dlpack(array.toDlpack())

    tensor.requires_grad = requires_grad
    return tensor
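# A small usage sketch, assuming this function is exposed as sigpy.to_pytorch
# and that torch is installed: complex arrays come back as real tensors with
# a trailing size-2 axis, and memory is shared with the input array.
import numpy as np
import sigpy as sp

x = np.arange(4, dtype=np.complex64)
t = sp.to_pytorch(x, requires_grad=False)
assert tuple(t.shape) == (4, 2)   # real/imaginary stacked on a new last axis
t[0, 0] = 7                       # zero-copy: writes show up in x
assert x[0] == 7 + 0j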
def _update(self):
    y = self.A(self.x)
    device = backend.get_device(y)
    xp = device.xp
    with device:
        self.max_eig = util.asscalar(xp.linalg.norm(y))
        backend.copyto(self.x, y / self.max_eig)
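# The update above is one step of power iteration. A self-contained numpy
# version of the same idea (a hypothetical standalone helper, not the class):
import numpy as np

def max_eig(A, n, max_iter=30, seed=0):
    x = np.random.default_rng(seed).standard_normal(n)
    for _ in range(max_iter):
        y = A @ x
        eig = np.linalg.norm(y)   # magnitude estimate of the top eigenvalue
        x = y / eig               # renormalize, as in _update above
    return eig

M = np.diag([3.0, 1.0, 0.5])      # symmetric PSD, so power iteration applies
assert abs(max_eig(M, 3) - 3.0) < 1e-6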
def __init__(self, gradf, x, alpha, proxg=None, f=None, beta=1,
             accelerate=False, max_iter=100, tol=0):
    if beta < 1 and f is None:
        raise TypeError(
            "Cannot do backtracking linesearch without specifying f.")

    self.gradf = gradf
    self.alpha = alpha
    self.f = f
    self.beta = beta
    self.accelerate = accelerate
    self.proxg = proxg
    self.x = x
    self.tol = tol

    self.device = backend.get_device(x)
    with self.device:
        if self.accelerate:
            self.z = self.x.copy()
            self.t = 1

    self.resid = np.inf
    super().__init__(max_iter)
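# A numpy sketch of the accelerated proximal gradient (FISTA-style) iteration
# consistent with the z/t state initialized above; the helper and names here
# are local to the sketch, not the class API.
import numpy as np

def fista(gradf, proxg, x, alpha, max_iter=100):
    z, t = x.copy(), 1.0
    for _ in range(max_iter):
        x_old = x
        x = proxg(alpha, z - alpha * gradf(z))          # proximal step
        t_old, t = t, (1 + (1 + 4 * t**2) ** 0.5) / 2   # momentum schedule
        z = x + (t_old - 1) / t * (x - x_old)           # extrapolation
    return x

# Example: LASSO, min_x 0.5||Ax - b||^2 + lamda ||x||_1, via soft thresholding.
A = np.random.randn(10, 5)
b = np.random.randn(10)
lamda = 0.1
soft = lambda s, v: np.sign(v) * np.maximum(np.abs(v) - lamda * s, 0)
x_hat = fista(lambda x: A.T @ (A @ x - b), soft,
              np.zeros(5), 1 / np.linalg.norm(A, 2) ** 2, max_iter=500)
# x_hat is an approximate, typically sparse, LASSO solution.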
def _apply(self, input):
    device = backend.get_device(input)
    output = 0
    with device:
        for n, linop in enumerate(self.linops):
            if n == 0:
                start = 0
            else:
                start = self.indices[n - 1]

            if n == self.nops - 1:
                end = None
            else:
                end = self.indices[n]

            if self.axis is None:
                output += linop(input[start:end].reshape(linop.ishape))
            else:
                ndim = len(linop.ishape)
                axis = self.axis % ndim
                slc = tuple([slice(None)] * axis + [slice(start, end)] +
                            [slice(None)] * (ndim - axis - 1))
                output += linop(input[slc])

    return output
def _apply(self, input):
    device = backend.get_device(input)
    with device:
        coord = backend.to_device(self.coord, device)
        return interp.interpolate(input, coord, kernel=self.kernel,
                                  width=self.width, param=self.param)
def monte_carlo_sure(f, y, sigma, eps=1e-10):
    """Monte Carlo Stein Unbiased Risk Estimator (SURE).

    Monte Carlo SURE assumes the observation y = x + e,
    where e is a white Gaussian array with standard deviation sigma.
    Monte Carlo SURE provides an unbiased estimate of the mean-squared
    error, i.e. 1 / n || f(y) - x ||_2^2.

    Args:
        f (function): x -> f(x).
        y (array): observed measurement.
        sigma (float): noise standard deviation.
        eps (float): perturbation size for the Monte Carlo
            divergence estimate.

    Returns:
        float: SURE.

    References:
        Ramani, S., Blu, T. and Unser, M. 2008.
        Monte-Carlo Sure: A Black-Box Optimization of Regularization
        Parameters for General Denoising Algorithms.
        IEEE Transactions on Image Processing 17, 9 (2008), 1540-1554.

    """
    device = backend.get_device(y)
    xp = device.xp
    n = y.size
    f_y = f(y)
    b = randn(y.shape, dtype=y.dtype, device=device)
    with device:
        divf_y = xp.real(xp.vdot(b, (f(y + eps * b) - f_y))) / eps
        sure = xp.mean(xp.abs(y - f_y)**2) - sigma**2 + \
            2 * sigma**2 * divf_y / n

    return sure
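# A quick usage sketch with soft thresholding as the denoiser f, assuming
# this function is reachable as sigpy.util.monte_carlo_sure (an assumption
# about its module); sigma and the threshold are arbitrary illustration values.
import numpy as np
from sigpy import util

sigma = 0.1
y = sigma * np.random.randn(1000)   # observation of x = 0 plus noise
f = lambda v: np.sign(v) * np.maximum(np.abs(v) - 2 * sigma, 0)
sure = util.monte_carlo_sure(f, y, sigma)  # MSE estimate without the clean x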
def flip(input, axes=None):
    """Flip input.

    Args:
        input (array): Input array.
        axes (None or tuple of ints): Axes to perform operation.

    Returns:
        array: Flipped result.

    """
    if axes is None:
        axes = range(input.ndim)
    else:
        axes = _normalize_axes(axes, input.ndim)

    slc = []
    for d in range(input.ndim):
        if d in axes:
            slc.append(slice(None, None, -1))
        else:
            slc.append(slice(None))

    slc = tuple(slc)
    device = backend.get_device(input)
    with device:
        output = input[slc]

    return output
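# Usage sketch, assuming this is sigpy.util.flip (the name L-shaped code above
# uses as util.flip): flipping an axis is equivalent to numpy slice reversal.
import numpy as np
from sigpy import util

a = np.arange(6).reshape(2, 3)
assert np.array_equal(util.flip(a, axes=(0,)), a[::-1, :])
assert np.array_equal(util.flip(a), a[::-1, ::-1])  # default flips all axes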
def l1_proj(eps, input):
    """Projection onto L1 ball.

    Args:
        eps (float, or array): L1 ball scaling.
        input (array): Input array.

    Returns:
        array: Result.

    References:
        J. Duchi, S. Shalev-Shwartz, and Y. Singer, "Efficient projections
        onto the l1-ball for learning in high dimensions" 2008.

    """
    device = backend.get_device(input)
    xp = device.xp
    with device:
        shape = input.shape
        input = input.ravel()

        if xp.linalg.norm(input, 1) < eps:
            return input.reshape(shape)
        else:
            size = len(input)
            s = xp.sort(xp.abs(input))[::-1]
            st = (xp.cumsum(s) - eps) / (xp.arange(size) + 1)
            idx = xp.flatnonzero((s - st) > 0).max()
            return soft_thresh(st[idx], input.reshape(shape))
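# Usage sketch, assuming this is sigpy.thresh.l1_proj: projecting onto the
# L1 ball of radius 1 shrinks the input until its L1 norm hits the boundary
# (the input values are illustrative).
import numpy as np
from sigpy import thresh

v = np.array([3.0, -1.0, 0.5])
p = thresh.l1_proj(1.0, v)                  # result is [1, 0, 0]
assert abs(np.abs(p).sum() - 1.0) < 1e-9    # lands on the L1 sphere
assert np.array_equal(np.sign(p[p != 0]), np.sign(v[p != 0]))  # signs kept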
def soft_thresh(lamda, input):
    r"""Soft threshold.

    Performs:

    .. math::
        (| x | - \lambda)_+ \text{sgn}(x)

    Args:
        lamda (float, or array): Threshold parameter.
        input (array): Input array.

    Returns:
        array: soft-thresholded result.

    """
    device = backend.get_device(input)
    xp = device.xp
    lamda = xp.real(lamda)
    with device:
        if device == backend.cpu_device:
            output = _soft_thresh(lamda, input)
        else:
            output = _soft_thresh_cuda(lamda, input)

        if np.issubdtype(input.dtype, np.floating):
            output = xp.real(output)

    return output
def _apply(self, input):
    device = backend.get_device(input)
    with device:
        coord = backend.to_device(self.coord, device)
        return fourier.nufft_adjoint(
            input, coord, self.oshape,
            oversamp=self.oversamp, width=self.width)
def soft_thresh(lamda, input):
    r"""Soft threshold.

    Performs:

    .. math::
        (| x | - \lambda)_+ \text{sgn}(x)

    Args:
        lamda (float, or array): Threshold parameter.
        input (array): Input array.

    Returns:
        array: soft-thresholded result.

    """
    device = backend.get_device(input)
    xp = device.xp

    if xp == np:
        return _soft_thresh(lamda, input)
    else:  # pragma: no cover
        if np.isscalar(lamda):
            lamda = backend.to_device(lamda, device)

        return _soft_thresh_cuda(lamda, input)
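# Usage sketch of the soft-threshold map above, assuming it is exposed as
# sigpy.thresh.soft_thresh: entries with |x| <= lamda collapse to zero and
# the rest shrink toward zero by lamda.
import numpy as np
from sigpy import thresh

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
expected = np.array([-1.0, 0.0, 0.0, 0.0, 1.0])
assert np.allclose(thresh.soft_thresh(1.0, x), expected)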
def _apply(self, input):
    output = 0
    # Accumulate on the input's device; get_device(output) would always
    # resolve to the CPU device because output starts as the Python int 0.
    with backend.get_device(input):
        for linop in self.linops:
            output += linop(input)

    return output
def apply(self, input):
    self._check_domain(input)
    with backend.get_device(input):
        output = self._apply(input)

    self._check_codomain(output)
    return output
def _apply(self, input):
    device = backend.get_device(input)
    xp = device.xp
    with device:
        output = xp.empty(self.oshape, dtype=input.dtype)
        for n, linop in enumerate(self.linops):
            if n == 0:
                start = 0
            else:
                start = self.indices[n - 1]

            if n == self.nops - 1:
                end = None
            else:
                end = self.indices[n]

            if self.axis is None:
                output[start:end] = linop(input).ravel()
            else:
                ndim = len(linop.oshape)
                axis = self.axis % ndim
                slc = tuple([slice(None)] * axis + [slice(start, end)] +
                            [slice(None)] * (ndim - axis - 1))
                output[slc] = linop(input)

    return output
def _cudnn_convolve_adjoint_filter(x, y, mode='full'):
    dtype = y.dtype
    device = backend.get_device(y)
    xp = device.xp
    if np.issubdtype(dtype, np.complexfloating):
        with device:
            xr = xp.real(x)
            xi = xp.imag(x)
            yr = xp.real(y)
            yi = xp.imag(y)

            # Concatenate real and imaginary to input/output channels
            x = xp.concatenate([xr, xi], axis=1)
            y = xp.concatenate([yr, yi], axis=1)

            W = _cudnn_convolve_adjoint_filter(x, y, mode=mode)

            # Convert back to complex
            Wr = W[:W.shape[0] // 2, :W.shape[1] // 2]
            Wr += W[W.shape[0] // 2:, W.shape[1] // 2:]
            Wi = W[W.shape[0] // 2:, :W.shape[1] // 2]
            Wi -= W[:W.shape[0] // 2, W.shape[1] // 2:]

            return (Wr + 1j * Wi).astype(dtype)

    ndim = y.ndim - 2
    input_channel = x.shape[1]
    output_channel = y.shape[1]
    input_shape = x.shape[-ndim:]
    output_shape = y.shape[-ndim:]
    strides = (1, ) * ndim
    dilations = (1, ) * ndim
    groups = 1
    auto_tune = True
    tensor_core = 'auto'
    deterministic = False
    if mode == 'full':
        filter_shape = tuple(p - m + 1
                             for m, p in zip(input_shape, output_shape))
        pads = tuple(n - 1 for n in filter_shape)
    elif mode == 'valid':
        filter_shape = tuple(m - p + 1
                             for m, p in zip(input_shape, output_shape))
        pads = (0, ) * ndim
    else:
        raise ValueError('Invalid mode, got {}'.format(mode))

    with device:
        W = xp.empty((output_channel, input_channel) + filter_shape,
                     dtype=dtype)
        cudnn.convolution_backward_filter(x, y, W, pads, strides,
                                          dilations, groups,
                                          deterministic=deterministic,
                                          auto_tune=auto_tune,
                                          tensor_core=tensor_core)
        W = util.flip(W, axes=range(-ndim, 0))

    return W
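# A numpy sanity sketch of the real/imaginary channel trick used above:
# a complex convolution decomposes into four real convolutions, which is
# exactly what the channel concatenation and Wr/Wi recombination exploit.
import numpy as np

x = np.random.randn(5) + 1j * np.random.randn(5)
w = np.random.randn(3) + 1j * np.random.randn(3)
re = np.convolve(x.real, w.real) - np.convolve(x.imag, w.imag)
im = np.convolve(x.real, w.imag) + np.convolve(x.imag, w.real)
assert np.allclose(np.convolve(x, w), re + 1j * im)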
def __init__(self, A, b, x, P=None, max_iter=100, tol=0):
    self.A = A
    self.P = P
    self.x = x
    self.tol = tol

    self.device = backend.get_device(x)
    with self.device:
        xp = self.device.xp
        self.r = b - self.A(self.x)
        if self.P is None:
            z = self.r
        else:
            z = self.P(self.r)

        if max_iter > 1:
            self.p = z.copy()
        else:
            self.p = z

        self.not_positive_definite = False
        self.rzold = xp.real(xp.vdot(self.r, z))
        self.resid = self.rzold.item()**0.5

    super().__init__(max_iter)
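# A self-contained numpy conjugate gradient sketch matching the state set up
# above (r, p, rzold), without the preconditioner path; the helper below is
# local to the sketch, not the class API.
import numpy as np

def cg(A, b, x, max_iter=100):
    r = b - A(x)
    p = r.copy()
    rzold = np.real(np.vdot(r, r))
    for _ in range(max_iter):
        Ap = A(p)
        alpha = rzold / np.real(np.vdot(p, Ap))
        x = x + alpha * p
        r = r - alpha * Ap
        rznew = np.real(np.vdot(r, r))
        if rznew ** 0.5 < 1e-12:   # residual has converged
            break
        p = r + (rznew / rzold) * p
        rzold = rznew
    return x

M = np.array([[4.0, 1.0], [1.0, 3.0]])   # symmetric positive definite
b = np.array([1.0, 2.0])
x_hat = cg(lambda v: M @ v, b, np.zeros(2), max_iter=10)
assert np.allclose(M @ x_hat, b)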
def _apply(self, input):
    with backend.get_device(input):
        if (np.issubdtype(self.idtype, np.complexfloating)
                and not np.issubdtype(self.odtype, np.complexfloating)):
            input = input.real

        return input.astype(self.odtype)
def ifft(input, oshape=None, axes=None, center=True, norm='ortho'):
    """IFFT function that supports centering.

    Args:
        input (array): input array.
        oshape (None or array of ints): output shape.
        axes (None or array of ints): Axes over which to compute
            the inverse FFT.
        center (bool): Toggle centered inverse FFT.
        norm (None or ``"ortho"``): Keyword to specify the
            normalization mode.

    Returns:
        array of dimension oshape.

    See Also:
        :func:`numpy.fft.ifftn`

    """
    device = backend.get_device(input)
    xp = device.xp
    with device:
        if not np.issubdtype(input.dtype, np.complexfloating):
            input = input.astype(np.complex128)

        if center:
            output = _ifftc(input, oshape=oshape, axes=axes, norm=norm)
        else:
            output = xp.fft.ifftn(input, s=oshape, axes=axes, norm=norm)

        if np.issubdtype(input.dtype,
                         np.complexfloating) and input.dtype != output.dtype:
            output = output.astype(input.dtype)

        return output
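# Usage sketch, assuming this is sigpy.fourier.ifft, also exported as
# sigpy.ifft: with center=True, a constant spectrum inverts to a unit
# impulse at the center of the array rather than at index 0.
import numpy as np
import sigpy as sp

spectrum = np.ones(8, dtype=np.complex64)
impulse = sp.ifft(spectrum)
assert np.argmax(np.abs(impulse)) == 4   # centered impulse at n // 2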
def _apply(self, input):
    device = backend.get_device(input)
    with device:
        return fourier.nufft(input, self.coord,
                             oversamp=self.oversamp, width=self.width)