示例#1
0
 def __init__(self,
              diag,
              dims=None,
              dir=0,
              device='cpu',
              togpu=(False, False),
              tocpu=(False, False),
              dtype=torch.float32):
     """Store the diagonal entries and derive the (square) operator shape.

     Parameters
     ----------
     diag : numpy.ndarray or torch.Tensor or ComplexTensor
         Diagonal values. Non-tensor input is flattened and wrapped with
         ``torch.from_numpy`` (real) or ``complextorch_fromnumpy``
         (complex, as detected by ``np.iscomplexobj``). Tensor input is
         only flattened.
     dims : tuple of int, optional
         Shape of the array the operator acts on. When given, the flattened
         diagonal is reshaped to size 1 on every axis except ``dir`` (a
         shape that broadcasts against an array of shape ``dims``) and the
         operator shape is ``(prod(dims), prod(dims))``. When ``None`` the
         shape is ``(len(diag), len(diag))``.
     dir : int, optional
         Axis of ``dims`` along which the diagonal is laid out.
     device, togpu, tocpu, dtype
         Stored unchanged as attributes; the constructor does not use them.
     """
     if isinstance(diag, (torch.Tensor, ComplexTensor)):
         # Already a tensor: just flatten it (complex tensors need the
         # dedicated ``flatten`` helper).
         self.complex = isinstance(diag, ComplexTensor)
         if self.complex:
             self.diag = flatten(diag)
         else:
             self.diag = diag.flatten()
     else:
         # numpy input: flatten, then convert to the matching tensor type.
         self.complex = bool(np.iscomplexobj(diag))
         flat = diag.flatten()
         if self.complex:
             self.diag = complextorch_fromnumpy(flat)
         else:
             self.diag = torch.from_numpy(flat)

     if dims is None:
         nelem = len(self.diag)
     else:
         # size 1 everywhere except along ``dir``
         diagdims = [1] * len(dims)
         diagdims[dir] = dims[dir]
         if self.complex:
             self.diag = reshape(self.diag, diagdims)
         else:
             self.diag = self.diag.reshape(diagdims)
         nelem = np.prod(dims)
     self.shape = (nelem, nelem)
     self.dims = dims
     self.reshape = dims is not None

     self.device = device
     self.togpu = togpu
     self.tocpu = tocpu
     self.dtype = dtype
     self.explicit = False
     self.Op = None
示例#2
0
 def _matvec(self, x):
     if not self.inplace:
         if self.complex:
             x = x.__graph_copy__(x.real, x.imag)
         else:
             x = x.clone()
     if self.shape[0] == self.shape[1]:
         y = x
     elif self.shape[0] < self.shape[1]:
         if self.complex:
             y = x[:, :self.shape[0]]
         else:
             y = x[:self.shape[0]]
     else:
         if self.complex:
             y = complextorch_fromnumpy(
                 np.zeros(self.shape[0], dtype=self.npdtype))
             y[:, :self.shape[1]] = x
         else:
             y = torch.zeros(self.shape[0], dtype=self.dtype)
             y[:self.shape[1]] = x
     return y
示例#3
0
def dottest(Op,
            nr,
            nc,
            tol=1e-6,
            dtype=torch.float32,
            complexflag=0,
            device='cpu',
            raiseerror=True,
            verb=False):
    r"""Dot test.

    Generate random vectors :math:`\mathbf{u}` and :math:`\mathbf{v}`
    and perform dot-test to verify the validity of forward and adjoint operators.
    This test can help to detect errors in the operator implementation.

    Parameters
    ----------
    Op : :obj:`object`
        Linear operator to test. It must provide ``matvec`` (forward) and
        ``rmatvec`` (adjoint) methods.
    nr : :obj:`int`
        Number of rows of operator (i.e., elements in data)
    nc : :obj:`int`
        Number of columns of operator (i.e., elements in model)
    tol : :obj:`float`, optional
        Dottest tolerance (relative difference between the two dot products)
    dtype : :obj:`torch.dtype`, optional
        Type of elements in random vectors. For complex vectors, the real and
        imaginary parts are drawn with the numpy precision matching ``dtype``.
    complexflag : :obj:`int`, optional
        generate random vectors with real (0) or complex numbers
        (1: only model, 2: only data, 3: both)
    device : :obj:`str`, optional
        Device to be used
    raiseerror : :obj:`bool`, optional
        Raise error or simply return ``False`` when dottest fails
    verb : :obj:`bool`, optional
        Verbosity

    Returns
    -------
    passed : :obj:`bool`
        ``True`` if the dot-test is verified within ``tol``. ``False`` is
        returned on failure only when ``raiseerror=False``.

    Raises
    ------
    ValueError
        If dot-test is not verified within chosen tolerance and
        ``raiseerror=True``.

    Notes
    -----
    A dot-test is a mathematical tool used in the development of numerical
    linear operators.

    More specifically, a correct implementation of forward and adjoint for
    a linear operator should verify the following *equality*
    within a numerical tolerance:

    .. math::
        (\mathbf{Op}*\mathbf{u})^H*\mathbf{v} =
        \mathbf{u}^H*(\mathbf{Op}^H*\mathbf{v})

    For complex vectors the real and the imaginary parts of both sides are
    checked separately.

    """
    def _complex_randn(n):
        # numpy precision matching the requested torch ``dtype`` (this was
        # previously hard-coded to float32, ignoring ``dtype``)
        np_dtype = torch.ones(1, dtype=dtype).numpy().dtype
        return complextorch_fromnumpy(
            np.random.randn(n).astype(np_dtype) +
            1j * np.random.randn(n).astype(np_dtype))

    def _relerr(a, b):
        # relative difference; 1e-15 guards against a zero denominator
        return (a - b) / ((a + b + 1e-15) / 2)

    def _fmtc(z):
        # '%f' cannot format complex numbers, so print both parts explicitly
        return '%f%+fj' % (np.real(z), np.imag(z))

    # model vector u (complex for complexflag 1 and 3)
    if complexflag in (0, 2):
        u = torch.randn(nc, dtype=dtype)
    else:
        u = _complex_randn(nc)
    # data vector v (complex for complexflag 2 and 3)
    if complexflag in (0, 1):
        v = torch.randn(nr, dtype=dtype)
    else:
        v = _complex_randn(nr)
    u, v = u.to(device), v.to(device)

    y = Op.matvec(u)  # Op * u
    x = Op.rmatvec(v)  # Op^H * v

    if complexflag == 0:
        yy = torch.dot(y, v)  # (Op  * u)^T * v
        xx = torch.dot(u, x)  # u^T * (Op^T * v)
        passed = bool(torch.abs(_relerr(yy, xx)) < tol)
    else:
        # NOTE(review): relies on np.vdot converting the tensors returned
        # by the operator to numpy arrays -- confirm this holds for GPU
        # tensors.
        yy = np.vdot(y, v)  # (Op  * u)^H * v
        xx = np.vdot(u, x)  # u^H * (Op^H * v)
        checkreal = np.abs(_relerr(np.real(yy), np.real(xx))) < tol
        # compare imaginary parts (previously the real parts were
        # compared twice, so the imaginary part was never checked)
        checkimag = np.abs(_relerr(np.imag(yy), np.imag(xx))) < tol
        passed = bool(checkreal and checkimag)

    def _message(status):
        # built lazily: only formatted when printed or raised
        if complexflag == 0:
            return 'Dot test %s, v^T(Opu)=%f - u^T(Op^Tv)=%f' % \
                   (status, yy, xx)
        return 'Dot test %s, v^H(Opu)=%s - u^H(Op^Hv)=%s' % \
               (status, _fmtc(yy), _fmtc(xx))

    if passed:
        if verb:
            print(_message('passed'))
        return True
    if raiseerror:
        raise ValueError(_message('failed'))
    if verb:
        print(_message('failed'))
    return False