Пример #1
0
def _get_convolve_adjoint_input_params(W, y, input_multi_channel,
                                       output_multi_channel):
    """Derive the shape parameters used by the adjoint-input convolution.

    Args:
        W (array): filter array. Its trailing ``ndim`` axes are taken as
            the filter shape.
        y (array): output array. Its trailing ``ndim`` axes are taken as
            the output shape.
        input_multi_channel (bool): whether W carries an input-channel axis.
        output_multi_channel (bool): whether W and y carry an
            output-channel axis.

    Returns:
        tuple: ``(ndim, output_shape, filter_shape, batch_shape,
        batch_size, input_channel, output_channel)``.

    Raises:
        TypeError: if y and W differ in dtype or live on different devices.

    """
    # Axes of W that are not channel axes are the convolution axes.
    ndim = W.ndim - input_multi_channel - output_multi_channel
    output_shape = y.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    # Everything in front of the (optional) output-channel axis is batch.
    batch_shape = y.shape[:-ndim - output_multi_channel]
    batch_size = util.prod(batch_shape)

    if y.dtype != W.dtype:
        raise TypeError(
            'y and W must have the same dtype, got {} and {}.'.format(
                y.dtype, W.dtype))

    y_device = backend.get_device(y)
    W_device = backend.get_device(W)
    if y_device != W_device:
        raise TypeError(
            'y and W must be on the same device, got {} and {}.'.format(
                y_device, W_device))

    input_channel = W.shape[-ndim - 1] if input_multi_channel else 1
    output_channel = y.shape[-ndim - 1] if output_multi_channel else 1

    return (ndim, output_shape, filter_shape, batch_shape,
            batch_size, input_channel, output_channel)
Пример #2
0
def _get_convolve_adjoint_filter_params(x, y, ndim, input_multi_channel,
                                        output_multi_channel):
    """Derive the shape parameters used by the adjoint-filter convolution.

    Args:
        x (array): input array. Its trailing ``ndim`` axes are taken as
            the input shape.
        y (array): output array. Its trailing ``ndim`` axes are taken as
            the output shape.
        ndim (int): number of convolution axes.
        input_multi_channel (bool): whether x carries an input-channel axis.
        output_multi_channel (bool): whether y carries an
            output-channel axis.

    Returns:
        tuple: ``(input_shape, output_shape, batch_shape, batch_size,
        input_channel, output_channel)``.

    Raises:
        TypeError: if x and y differ in dtype or live on different devices.

    """
    output_shape = y.shape[-ndim:]
    input_shape = x.shape[-ndim:]
    # Everything in front of the (optional) output-channel axis is batch.
    batch_shape = y.shape[:-ndim - output_multi_channel]
    batch_size = util.prod(batch_shape)

    if x.dtype != y.dtype:
        raise TypeError(
            'x and y must have the same dtype, got {} and {}.'.format(
                x.dtype, y.dtype))

    y_device = backend.get_device(y)
    x_device = backend.get_device(x)
    if y_device != x_device:
        raise TypeError(
            'y and x must be on the same device, got {} and {}.'.format(
                y_device, x_device))

    input_channel = x.shape[-ndim - 1] if input_multi_channel else 1
    output_channel = y.shape[-ndim - 1] if output_multi_channel else 1

    return (input_shape, output_shape, batch_shape,
            batch_size, input_channel, output_channel)
Пример #3
0
    def __init__(self,
                 A,
                 y,
                 x=None,
                 proxg=None,
                 lamda=0,
                 G=None,
                 g=None,
                 z=None,
                 solver=None,
                 max_iter=100,
                 P=None,
                 alpha=None,
                 max_power_iter=30,
                 accelerate=True,
                 tau=None,
                 sigma=None,
                 rho=1,
                 max_cg_iter=10,
                 tol=0,
                 save_objective_values=False,
                 show_pbar=True,
                 leave_pbar=True):
        """Record the solver options and build the underlying algorithm.

        Every argument is stored verbatim on the instance under the same
        name. When ``x`` is None, a zero array of shape ``A.ishape`` and
        dtype ``y.dtype`` is allocated on the device of ``y``. The
        algorithm is then created by ``self._get_alg()`` and handed to
        the base-class constructor as ``self.alg``.

        If ``save_objective_values`` is true, ``self.objective_values``
        starts as a one-element list holding ``self.objective()``.
        """
        # Problem definition.
        self.A = A
        self.y = y
        self.x = x
        self.proxg = proxg
        self.lamda = lamda
        self.G = G
        self.g = g
        self.z = z

        # Solver selection and tuning.
        self.solver = solver
        self.max_iter = max_iter
        self.P = P
        self.alpha = alpha
        self.max_power_iter = max_power_iter
        self.accelerate = accelerate
        self.tau = tau
        self.sigma = sigma
        self.rho = rho
        self.max_cg_iter = max_cg_iter
        self.tol = tol

        # Reporting.
        self.save_objective_values = save_objective_values
        self.show_pbar = show_pbar
        self.leave_pbar = leave_pbar

        y_device = backend.get_device(y)
        self.y_device = y_device
        if self.x is None:
            # Default initial guess: zeros living next to the data.
            with y_device:
                self.x = y_device.xp.zeros(A.ishape, dtype=y.dtype)

        self.x_device = backend.get_device(self.x)

        # _get_alg reads the attributes set above, so it must come last.
        self._get_alg()
        if self.save_objective_values:
            initial_objective = self.objective()
            self.objective_values = [initial_objective]

        super().__init__(self.alg, show_pbar=show_pbar, leave_pbar=leave_pbar)
Пример #4
0
    def _apply(self, input):
        """Return ``conj(A(conj(input)))``.

        Each conjugation runs under the device of the array it acts on.
        """
        in_device = backend.get_device(input)
        with in_device:
            conj_input = in_device.xp.conj(input)

        result = self.A(conj_input)

        out_device = backend.get_device(result)
        with out_device:
            return out_device.xp.conj(result)
Пример #5
0
    def __init__(self,
                 A,
                 y,
                 proxg,
                 eps,
                 x=None,
                 G=None,
                 max_iter=100,
                 tau=None,
                 sigma=None,
                 show_pbar=True):
        """Set up a primal-dual hybrid gradient solve with an L2-ball term.

        When ``x`` is None, a zero array of shape ``A.ishape`` and dtype
        ``y.dtype`` is allocated on the device of ``y``. The dual variable
        ``self.u`` is always allocated as zeros of the (possibly stacked)
        operator's output shape on the device of ``y``.

        If ``G`` is given, ``A`` and ``G`` are stacked vertically, the
        dual prox becomes the conjugate of the stacked
        ``[L2Proj, proxg]`` and the primal prox becomes a no-op.
        Otherwise the dual prox is the conjugate of ``L2Proj`` alone and
        ``proxg`` is used as passed.
        """
        self.y = y
        self.x = x
        self.y_device = backend.get_device(y)
        if self.x is None:
            xp = self.y_device.xp
            with self.y_device:
                self.x = xp.zeros(A.ishape, dtype=self.y.dtype)

        self.x_device = backend.get_device(self.x)

        if G is not None:
            data_proj = prox.L2Proj(A.oshape, eps, y=y)
            proxfc = prox.Conj(prox.Stack([data_proj, proxg]))
            # The regulariser now lives in the dual prox; primal is a no-op.
            proxg = prox.NoOp(A.ishape)
            A = linop.Vstack([A, G])
        else:
            self.max_eig_app = MaxEig(A.H * A,
                                      dtype=self.x.dtype,
                                      device=self.x_device,
                                      show_pbar=show_pbar)

            proxfc = prox.Conj(prox.L2Proj(A.oshape, eps, y=y))

        # NOTE: if only one of tau/sigma is supplied, both are still
        # replaced by the defaults derived from the maximum eigenvalue.
        if tau is None or sigma is None:
            eig_app = MaxEig(A.H * A,
                             dtype=self.x.dtype,
                             device=self.x_device,
                             show_pbar=show_pbar)
            max_eig = eig_app.run()
            tau = 1
            sigma = 1 / max_eig

        with self.y_device:
            self.u = self.y_device.xp.zeros(A.oshape, dtype=self.y.dtype)

        alg = PrimalDualHybridGradient(proxfc, proxg, A, A.H,
                                       self.x, self.u, tau, sigma,
                                       max_iter=max_iter)

        super().__init__(alg, show_pbar=show_pbar)
Пример #6
0
    def _apply(self, input):
        """Convolve the stored ``x`` with ``input``.

        ``self.x`` is first moved to the device of ``input`` and cast to
        ``input``'s dtype (without copying when no cast is needed).
        """
        device = backend.get_device(input)
        data = backend.to_device(self.x, device)
        with device:
            data = data.astype(input.dtype, copy=False)

        return conv.convolve(
            data, input,
            mode=self.mode,
            input_multi_channel=self.input_multi_channel,
            output_multi_channel=self.output_multi_channel)
Пример #7
0
    def _apply(self, input):
        """Apply ``conv.convolve_adjoint_input`` with the stored ``W``.

        ``self.W`` is first moved to the device of ``input`` and cast to
        ``input``'s dtype (without copying when no cast is needed).
        """
        device = backend.get_device(input)
        filt = backend.to_device(self.W, device)
        with device:
            filt = filt.astype(input.dtype, copy=False)

        return conv.convolve_adjoint_input(
            filt, input,
            mode=self.mode,
            input_multi_channel=self.input_multi_channel,
            output_multi_channel=self.output_multi_channel)
Пример #8
0
    def __init__(self,
                 A,
                 y,
                 x=None,
                 proxg=None,
                 lamda=0,
                 G=None,
                 g=None,
                 R=None,
                 mu=0,
                 z=0,
                 alg_name=None,
                 max_iter=100,
                 P=None,
                 alpha=None,
                 max_power_iter=30,
                 accelerate=True,
                 tau=None,
                 sigma=None,
                 save_objective_values=False,
                 show_pbar=True):
        """Record the solver options and build the underlying algorithm.

        Every argument is stored verbatim on the instance under the same
        name. When ``x`` is None, a zero array of shape ``A.ishape`` and
        dtype ``y.dtype`` is allocated on the device of ``y``. The
        algorithm is then created by ``self._get_alg()``.

        If ``save_objective_values`` is true, ``self.objective_values``
        starts as an empty list.
        """
        # Problem definition.
        self.A = A
        self.y = y
        self.x = x
        self.proxg = proxg
        self.lamda = lamda
        self.G = G
        self.g = g
        self.R = R
        self.mu = mu
        self.z = z

        # Solver selection and tuning.
        self.alg_name = alg_name
        self.max_iter = max_iter
        self.P = P
        self.alpha = alpha
        self.max_power_iter = max_power_iter
        self.accelerate = accelerate
        self.tau = tau
        self.sigma = sigma

        # Reporting.
        self.save_objective_values = save_objective_values
        self.show_pbar = show_pbar

        y_device = backend.get_device(y)
        self.y_device = y_device
        if self.x is None:
            # Default initial guess: zeros living next to the data.
            with y_device:
                self.x = y_device.xp.zeros(A.ishape, dtype=y.dtype)

        self.x_device = backend.get_device(self.x)

        # _get_alg reads the attributes set above, so it must come last.
        self._get_alg()
        if self.save_objective_values:
            self.objective_values = []
        # NOTE(review): no base-class __init__ call is visible in this
        # block - confirm whether one is expected to follow.
Пример #9
0
    def __init__(self,
                 proxfc,
                 proxg,
                 A,
                 AH,
                 x,
                 u,
                 tau,
                 sigma,
                 theta=1,
                 gamma_primal=0,
                 gamma_dual=0,
                 max_iter=100,
                 tol=0):
        """Store the primal-dual iteration state.

        All arguments except ``max_iter`` are stored on the instance under
        the same name; ``max_iter`` is forwarded to the base class.

        Besides the stored arguments this sets:

        * ``x_device`` / ``u_device``: devices of ``x`` and ``u``.
        * ``x_ext``: a copy of ``x`` made on ``x``'s device.
        * ``tau_min``: smallest ``|tau|`` as a Python scalar, only when
          ``gamma_primal > 0``.
        * ``sigma_min``: smallest ``|sigma|`` as a Python scalar, only
          when ``gamma_dual > 0``.
        * ``resid``: initialised to infinity.
        """
        self.proxfc = proxfc
        self.proxg = proxg
        self.tol = tol

        self.A = A
        self.AH = AH

        self.u = u
        self.x = x

        self.tau = tau
        self.sigma = sigma
        self.theta = theta
        self.gamma_primal = gamma_primal
        self.gamma_dual = gamma_dual

        self.x_device = backend.get_device(x)
        self.u_device = backend.get_device(u)

        with self.x_device:
            self.x_ext = self.x.copy()

        if self.gamma_primal > 0:
            xp = self.x_device.xp
            with self.x_device:
                self.tau_min = xp.amin(xp.abs(tau)).item()

        if self.gamma_dual > 0:
            xp = self.u_device.xp
            with self.u_device:
                self.sigma_min = xp.amin(xp.abs(sigma)).item()

        # np.infty was removed in NumPy 2.0; np.inf is the same value.
        self.resid = np.inf

        super().__init__(max_iter)
Пример #10
0
 def _apply(self, input):
     """Convolve the stored data with ``input``.

     ``self.data`` is moved to the device of ``input`` first.
     """
     device = backend.get_device(input)
     data = backend.to_device(self.data, device)
     return conv.convolve(
         data, input,
         mode=self.mode,
         strides=self.strides,
         multi_channel=self.multi_channel)
Пример #11
0
 def _apply(self, input):
     """Convolve ``input`` with the stored filter on ``input``'s device."""
     in_device = backend.get_device(input)
     kernel = backend.to_device(self.filt, in_device)
     with in_device:
         result = conv.convolve(
             input, kernel,
             mode=self.mode,
             strides=self.strides,
             multi_channel=self.multi_channel)
         return result
Пример #12
0
def to_pytorch(array, requires_grad=True):  # pragma: no cover
    """Zero-copy conversion from numpy/cupy array to pytorch tensor.

    For complex array input, returns a tensor with shape + [2],
    where tensor[..., 0] and tensor[..., 1] represent the real
    and imaginary.

    Args:
        array (numpy/cupy array): input.
        requires_grad (bool): value assigned to ``tensor.requires_grad``.

    Returns:
        PyTorch tensor.

    """
    import torch
    from torch.utils.dlpack import from_dlpack

    device = backend.get_device(array)
    # Only complex arrays are split into a trailing (real, imag) axis.
    # The previous "not floating" test also caught integer and boolean
    # arrays, for which the reshape to shape + (2,) cannot succeed.
    if np.issubdtype(array.dtype, np.complexfloating):
        with device:
            shape = array.shape
            array = array.view(dtype=array.real.dtype)
            array = array.reshape(shape + (2, ))

    if device == backend.cpu_device:
        tensor = torch.from_numpy(array)
    else:
        tensor = from_dlpack(array.toDlpack())

    tensor.requires_grad = requires_grad
    return tensor
Пример #13
0
 def _update(self):
     """Run one step: apply A, record the norm, renormalise x in place.

     ``self.max_eig`` is set to the norm of ``A(x)`` and ``self.x`` is
     overwritten with ``A(x)`` divided by that norm.
     """
     Ax = self.A(self.x)
     device = backend.get_device(Ax)
     with device:
         norm = device.xp.linalg.norm(Ax)
         self.max_eig = util.asscalar(norm)
         backend.copyto(self.x, Ax / self.max_eig)
Пример #14
0
    def __init__(self,
                 gradf,
                 x,
                 alpha,
                 proxg=None,
                 f=None,
                 beta=1,
                 accelerate=False,
                 max_iter=100,
                 tol=0):
        """Store the gradient-method state.

        All arguments except ``max_iter`` are stored on the instance under
        the same name; ``max_iter`` is forwarded to the base class. When
        ``accelerate`` is true, ``self.z`` is initialised to a copy of
        ``x`` and ``self.t`` to 1. ``self.resid`` starts at infinity.

        Raises:
            TypeError: if ``beta < 1`` (backtracking line search) is
                requested without providing ``f``.
        """
        if beta < 1 and f is None:
            raise TypeError(
                "Cannot do backtracking linesearch without specifying f.")

        self.gradf = gradf
        self.alpha = alpha
        self.f = f
        self.beta = beta
        self.accelerate = accelerate
        self.proxg = proxg
        self.x = x
        self.tol = tol

        self.device = backend.get_device(x)
        with self.device:
            if self.accelerate:
                self.z = self.x.copy()
                self.t = 1

        # np.infty was removed in NumPy 2.0; np.inf is the same value.
        self.resid = np.inf
        super().__init__(max_iter)
Пример #15
0
    def _apply(self, input):
        """Sum the outputs of each linop applied to its slice of ``input``.

        The n-th linop receives ``input`` restricted to
        ``[indices[n - 1], indices[n])`` (from 0 for the first linop, to
        the end for the last). With ``axis`` None the slice is taken on
        the leading axis and reshaped to the linop's ``ishape``;
        otherwise the slice is taken along ``axis``.
        """
        device = backend.get_device(input)
        output = 0
        with device:
            for n, op in enumerate(self.linops):
                start = self.indices[n - 1] if n != 0 else 0
                end = None if n == self.nops - 1 else self.indices[n]

                if self.axis is None:
                    piece = input[start:end].reshape(op.ishape)
                else:
                    ndim = len(op.ishape)
                    axis = self.axis % ndim

                    # Full slices everywhere except the stacking axis.
                    index = [slice(None)] * ndim
                    index[axis] = slice(start, end)
                    piece = input[tuple(index)]

                output += op(piece)

        return output
Пример #16
0
 def _apply(self, input):
     """Interpolate ``input`` at the stored coordinates.

     ``self.coord`` is moved to the device of ``input`` and the
     interpolation runs under that device.
     """
     in_device = backend.get_device(input)
     with in_device:
         points = backend.to_device(self.coord, in_device)
         return interp.interpolate(
             input, points,
             kernel=self.kernel,
             width=self.width,
             param=self.param)
Пример #17
0
def monte_carlo_sure(f, y, sigma, eps=1e-10):
    """Monte Carlo Stein Unbiased Risk Estimator (SURE).

    Monte Carlo SURE assumes the observation y = x + e,
    where e is a white Gaussian array with standard deviation sigma.
    Monte Carlo SURE provides an unbiased estimate of mean-squared error,
    i.e.: 1 / n || f(y) - x ||_2^2

    Args:
        f (function): x -> f(x).
        y (array): observed measurement.
        sigma (float): noise standard deviation.
        eps (float): step size of the finite difference used to estimate
            the divergence of f at y.

    Returns:
       float: SURE.

    References:
        Ramani, S., Blu, T. and Unser, M. 2008.
        Monte-Carlo Sure: A Black-Box Optimization of Regularization Parameters
        for General Denoising Algorithms. IEEE Transactions on Image Processing
        17, 9 (2008), 1540-1554.
    """
    device = backend.get_device(y)
    xp = device.xp

    n = y.size
    f_y = f(y)
    # Random probe with the same shape and dtype as y.
    probe = randn(y.shape, dtype=y.dtype, device=device)
    with device:
        # Divergence estimate: <probe, f(y + eps * probe) - f(y)> / eps.
        perturbed = f(y + eps * probe)
        divf_y = xp.real(xp.vdot(probe, (perturbed - f_y))) / eps

        data_term = xp.mean(xp.abs(y - f_y)**2)
        div_term = 2 * sigma**2 * divf_y / n
        sure = data_term - sigma**2 + div_term

    return sure
Пример #18
0
def flip(input, axes=None):
    """Reverse the order of elements along the given axes.

    Args:
        input (array): Input array.
        axes (None or tuple of ints): Axes to flip. None flips every axis.

    Returns:
        array: Flipped result.
    """
    if axes is None:
        axes = range(input.ndim)
    else:
        axes = _normalize_axes(axes, input.ndim)

    # Reversed slice on the selected axes, full slice on the rest.
    slc = tuple(
        slice(None, None, -1) if d in axes else slice(None)
        for d in range(input.ndim))

    with backend.get_device(input):
        return input[slc]
Пример #19
0
def l1_proj(eps, input):
    """Projection onto L1 ball.

    The result has the same shape as ``input``, both when ``input``
    already lies inside the ball and when it has to be shrunk.

    Args:
        eps (float, or array): L1 ball scaling.
        input (array)

    Returns:
        array: Result.

    References:
        J. Duchi, S. Shalev-Shwartz, and Y. Singer, "Efficient projections onto
        the l1-ball for learning in high dimensions" 2008.

    """
    device = backend.get_device(input)
    xp = device.xp

    with device:
        shape = input.shape
        flat = input.ravel()

        # Already inside the ball: nothing to shrink.
        if xp.linalg.norm(flat, 1) < eps:
            return flat.reshape(shape)

        # Keep the element count in its own name so the original shape
        # stays available for the final reshape.
        size = len(flat)
        s = xp.sort(xp.abs(flat))[::-1]
        st = (xp.cumsum(s) - eps) / (xp.arange(size) + 1)
        idx = xp.flatnonzero((s - st) > 0).max()
        return soft_thresh(st[idx], flat.reshape(shape))
Пример #20
0
def soft_thresh(lamda, input):
    r"""Soft threshold.

    Performs:

    .. math::
        (| x | - \lambda)_+  \text{sgn}(x)

    The real part of ``lamda`` is used as the threshold. The CPU or CUDA
    kernel is chosen from the device of ``input``.

    Args:
        lamda (float, or array): Threshold parameter.
        input (array)

    Returns:
        array: soft-thresholded result.

    """
    device = backend.get_device(input)
    xp = device.xp

    lamda = xp.real(lamda)
    with device:
        on_cpu = device == backend.cpu_device
        kernel = _soft_thresh if on_cpu else _soft_thresh_cuda
        output = kernel(lamda, input)

        # Real floating input gets the real part of the kernel output.
        if np.issubdtype(input.dtype, np.floating):
            output = xp.real(output)

    return output
Пример #21
0
 def _apply(self, input):
     """Apply the adjoint NUFFT to ``input`` at the stored coordinates.

     ``self.coord`` is moved to the device of ``input`` and the
     transform runs under that device, producing shape ``self.oshape``.
     """
     in_device = backend.get_device(input)
     with in_device:
         points = backend.to_device(self.coord, in_device)
         return fourier.nufft_adjoint(
             input, points, self.oshape,
             oversamp=self.oversamp,
             width=self.width)
Пример #22
0
def soft_thresh(lamda, input):
    r"""Soft threshold.

    Performs:

    .. math::
        (| x | - \lambda)_+  \text{sgn}(x)

    Dispatches to the numpy kernel when the array module of ``input``'s
    device is numpy, and to the CUDA kernel otherwise. On the CUDA path
    a scalar ``lamda`` is moved to that device first.

    Args:
        lamda (float, or array): Threshold parameter.
        input (array)

    Returns:
        array: soft-thresholded result.

    """
    device = backend.get_device(input)
    if device.xp == np:
        return _soft_thresh(lamda, input)

    # CUDA path.  # pragma: no cover
    if np.isscalar(lamda):
        lamda = backend.to_device(lamda, device)

    return _soft_thresh_cuda(lamda, input)
Пример #23
0
    def _apply(self, input):
        """Return the sum of every linop in ``self.linops`` applied to input.

        The accumulation runs under the device of ``input``. The device
        context used to be derived from the integer accumulator, which
        carries no device information.
        """
        output = 0
        with backend.get_device(input):
            for linop in self.linops:
                output += linop(input)

        return output
Пример #24
0
    def apply(self, input):
        """Validate ``input``, run ``_apply`` on its device, validate output.

        ``_check_domain`` is called on the input before, and
        ``_check_codomain`` on the output after, the call to ``_apply``.
        """
        self._check_domain(input)

        device = backend.get_device(input)
        with device:
            result = self._apply(input)

        self._check_codomain(result)
        return result
Пример #25
0
    def _apply(self, input):
        """Apply every linop to ``input`` and stack the results.

        The output is allocated with shape ``self.oshape`` and the dtype
        of ``input``. The n-th linop's result is written to
        ``[indices[n - 1], indices[n])`` (from 0 for the first linop, to
        the end for the last): flattened onto the leading axis when
        ``axis`` is None, otherwise along ``axis``.
        """
        device = backend.get_device(input)
        xp = device.xp
        with device:
            output = xp.empty(self.oshape, dtype=input.dtype)
            for n, op in enumerate(self.linops):
                start = self.indices[n - 1] if n != 0 else 0
                end = None if n == self.nops - 1 else self.indices[n]

                if self.axis is None:
                    output[start:end] = op(input).ravel()
                    continue

                ndim = len(op.oshape)
                axis = self.axis % ndim
                # Full slices everywhere except the stacking axis.
                index = [slice(None)] * ndim
                index[axis] = slice(start, end)
                output[tuple(index)] = op(input)

        return output
Пример #26
0
    def _cudnn_convolve_adjoint_filter(x, y, mode='full'):
        """Compute the filter-adjoint of a convolution with cuDNN.

        Both arrays are indexed with the channel on axis 1 and the
        convolution axes last (``ndim = y.ndim - 2``); axis 0 is
        presumably the batch axis - confirm against the caller.

        Complex input is handled by stacking real and imaginary parts on
        the channel axis, calling this function again on the real-valued
        arrays, and recombining the four quadrants of the result.

        Args:
            x (array): input array.
            y (array): output array; its dtype and device are used for
                the result.
            mode (str): ``'full'`` or ``'valid'``.

        Returns:
            array: filter of shape
            ``(output_channel, input_channel) + filter_shape``, flipped
            along the convolution axes.

        Raises:
            ValueError: if ``mode`` is neither ``'full'`` nor ``'valid'``.
        """
        dtype = y.dtype
        device = backend.get_device(y)
        xp = device.xp
        if np.issubdtype(dtype, np.complexfloating):
            with device:
                xr = xp.real(x)
                xi = xp.imag(x)
                yr = xp.real(y)
                yi = xp.imag(y)

                # Concatenate real and imaginary to input/output channels
                x = xp.concatenate([xr, xi], axis=1)
                y = xp.concatenate([yr, yi], axis=1)

                W = _cudnn_convolve_adjoint_filter(x, y, mode=mode)

                # Convert back to complex
                Wr = W[:W.shape[0] // 2, :W.shape[1] // 2]
                Wr += W[W.shape[0] // 2:, W.shape[1] // 2:]
                Wi = W[W.shape[0] // 2:, :W.shape[1] // 2]
                Wi -= W[:W.shape[0] // 2, W.shape[1] // 2:]
                return (Wr + 1j * Wi).astype(dtype)

        ndim = y.ndim - 2
        input_channel = x.shape[1]
        output_channel = y.shape[1]
        input_shape = x.shape[-ndim:]
        output_shape = y.shape[-ndim:]
        strides = (1, ) * ndim
        dilations = (1, ) * ndim
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        deterministic = False
        if mode == 'full':
            filter_shape = tuple(p - m + 1
                                 for m, p in zip(input_shape, output_shape))
            pads = tuple(n - 1 for n in filter_shape)
        elif mode == 'valid':
            filter_shape = tuple(m - p + 1
                                 for m, p in zip(input_shape, output_shape))
            pads = (0, ) * ndim
        else:
            # Without this branch filter_shape and pads stay unbound and
            # the failure surfaces later as an UnboundLocalError.
            raise ValueError(
                "Invalid mode, must be 'full' or 'valid', got {}.".format(
                    mode))

        with device:
            W = xp.empty((output_channel, input_channel) + filter_shape,
                         dtype=dtype)
            cudnn.convolution_backward_filter(x,
                                              y,
                                              W,
                                              pads,
                                              strides,
                                              dilations,
                                              groups,
                                              deterministic=deterministic,
                                              auto_tune=auto_tune,
                                              tensor_core=tensor_core)
            W = util.flip(W, axes=range(-ndim, 0))

        return W
Пример #27
0
    def __init__(self, A, b, x, P=None, max_iter=100, tol=0):
        """Initialise the conjugate-gradient state on the device of ``x``.

        Sets the residual ``r = b - A(x)``, the search direction ``p``
        (the residual, or ``P(r)`` when a preconditioner ``P`` is given),
        ``rzold = real(<r, z>)`` and ``resid = sqrt(rzold)``.
        """
        self.A = A
        self.P = P
        self.x = x
        self.tol = tol
        self.device = backend.get_device(x)
        with self.device:
            xp = self.device.xp
            self.r = b - self.A(self.x)

            z = self.r if self.P is None else self.P(self.r)

            # With a single iteration p can alias z, so the copy is skipped.
            self.p = z.copy() if max_iter > 1 else z

            self.not_positive_definite = False
            self.rzold = xp.real(xp.vdot(self.r, z))
            self.resid = self.rzold.item()**0.5

        super().__init__(max_iter)
Пример #28
0
    def _apply(self, input):
        """Cast ``input`` to ``self.odtype``.

        When ``self.idtype`` is complex and ``self.odtype`` is not, the
        real part is taken before the cast.
        """
        with backend.get_device(input):
            drop_imag = (
                np.issubdtype(self.idtype, np.complexfloating)
                and not np.issubdtype(self.odtype, np.complexfloating))
            source = input.real if drop_imag else input

            return source.astype(self.odtype)
Пример #29
0
def ifft(input, oshape=None, axes=None, center=True, norm='ortho'):
    """IFFT function that supports centering.

    Non-complex input is converted to double-precision complex before the
    transform. Complex input keeps its dtype in the output.

    Args:
        input (array): input array.
        oshape (None or array of ints): output shape.
        axes (None or array of ints): Axes over which to compute
            the inverse FFT.
        center (bool): if True, use the centered transform (``_ifftc``);
            otherwise call the array module's ``fft.ifftn`` directly.
        norm (None or ``"ortho"``): Keyword to specify the normalization mode.

    Returns:
        array of dimension oshape.

    See Also:
        :func:`numpy.fft.ifftn`

    """
    device = backend.get_device(input)
    xp = device.xp

    with device:
        if not np.issubdtype(input.dtype, np.complexfloating):
            # np.complex (alias of the builtin complex, i.e. complex128)
            # was removed in NumPy 1.24; name the dtype explicitly.
            input = input.astype(np.complex128)

        if center:
            output = _ifftc(input, oshape=oshape, axes=axes, norm=norm)
        else:
            output = xp.fft.ifftn(input, s=oshape, axes=axes, norm=norm)

        # Restore the input precision if the transform changed it.
        if np.issubdtype(input.dtype,
                         np.complexfloating) and input.dtype != output.dtype:
            output = output.astype(input.dtype)

        return output
Пример #30
0
 def _apply(self, input):
     """Apply the NUFFT to ``input`` at the stored coordinates.

     ``self.coord`` is moved to the device of ``input`` before the
     transform, so coordinates stored on another device do not end up
     mixed with the input array.
     """
     device = backend.get_device(input)
     with device:
         coord = backend.to_device(self.coord, device)
         return fourier.nufft(input,
                              coord,
                              oversamp=self.oversamp,
                              width=self.width)