Пример #1
0
 def test_csrsm2(self, dtype):
     """Compare the result of ``cusparse.csrsm2`` with the reference solution.

     The test is skipped when csrsm2 is not available, and also for the
     CSC format combined with a complex dtype and ``transa == 'H'``.

     Args:
         dtype: Data type passed to ``self._setup``.
     """
     if not cusparse.check_availability('csrsm2'):
         raise unittest.SkipTest('csrsm2 is not available')
     if self.format == 'csc':
         if numpy.dtype(dtype).char in 'FD' and self.transa == 'H':
             raise unittest.SkipTest('unsupported combination')

     self._setup(dtype)

     # The solver works on ``solution`` directly, so start from a copy of
     # the right-hand side laid out in the memory order under test.
     solution = self.b.copy(order=self.order)
     solver_options = dict(
         alpha=self.alpha,
         lower=self.lower,
         unit_diag=self.unit_diag,
         transa=self.transa,
         blocking=self.blocking,
         level_info=self.level_info,
     )
     cusparse.csrsm2(self.a, solution, **solver_options)

     testing.assert_allclose(
         solution, self.ref_x, atol=self.tol, rtol=self.tol)
Пример #2
0
    def solve(self, rhs, trans='N'):
        """Solves linear system of equations with one or several right-hand sides.

        Args:
            rhs (cupy.ndarray): Right-hand side(s) of equation with dimension
                ``(M)`` or ``(M, K)``.
            trans (str): 'N', 'T' or 'H'.
                'N': Solves ``A * x = rhs``.
                'T': Solves ``A.T * x = rhs``.
                'H': Solves ``A.conj().T * x = rhs``.

        Returns:
            cupy.ndarray:
                Solution vector(s). The dtype is that of ``self.L`` and the
                array is Fortran-contiguous.

        Raises:
            TypeError: If ``rhs`` is not a ``cupy.ndarray``.
            ValueError: If ``rhs`` is neither 1-D nor 2-D, if its first
                dimension does not match ``self.shape[0]``, or if ``trans``
                is not one of 'N', 'T' and 'H'.
            NotImplementedError: If csrsm2 is not available.
        """
        if not isinstance(rhs, cupy.ndarray):
            raise TypeError('rhs must be cupy.ndarray')
        if rhs.ndim not in (1, 2):
            raise ValueError('rhs.ndim must be 1 or 2 (actual: {})'.
                             format(rhs.ndim))
        if rhs.shape[0] != self.shape[0]:
            raise ValueError('shape mismatch (self.shape: {}, rhs.shape: {})'
                             .format(self.shape, rhs.shape))
        if trans not in ('N', 'T', 'H'):
            raise ValueError('trans must be \'N\', \'T\', or \'H\'')
        if not cusparse.check_availability('csrsm2'):
            raise NotImplementedError('csrsm2 is not available')

        # Work on a converted array; csrsm2 updates ``x`` in place.
        x = rhs.astype(self.L.dtype)
        if trans == 'N':
            # Reorder rows, solve with L and then with U, and finally apply
            # the column permutation.
            if self.perm_r is not None:
                x = x[self._perm_r_rev]
            cusparse.csrsm2(self.L, x, lower=True, transa=trans)
            cusparse.csrsm2(self.U, x, lower=False, transa=trans)
            if self.perm_c is not None:
                x = x[self.perm_c]
        else:
            # For the (conjugate) transposed system the factors are applied
            # in the opposite order and the permutations swap roles.
            if self.perm_c is not None:
                x = x[self._perm_c_rev]
            cusparse.csrsm2(self.U, x, lower=False, transa=trans)
            cusparse.csrsm2(self.L, x, lower=True, transa=trans)
            if self.perm_r is not None:
                x = x[self.perm_r]

        if not x._f_contiguous:
            # For compatibility with SciPy
            x = x.copy(order='F')
        return x
Пример #3
0
    def test_csrsm2(self, dtype):
        """Compare the result of ``cusparse.csrsm2`` with the reference solution.

        The test is skipped when csrsm2 is not available, and for the CSC
        format combined with a complex dtype and ``transa == 'H'``. On HIP
        it is marked as an expected failure for ``transa == 'H'`` and, when
        the driver build version is below 400, for CSC with ``transa == 'N'``
        and CSR with ``transa == 'T'``.

        Args:
            dtype: Data type passed to ``self._setup``.
        """
        if not cusparse.check_availability('csrsm2'):
            pytest.skip('csrsm2 is not available')

        if runtime.is_hip:
            known_buggy = self.transa == 'H'
            # Only query the build version when the operation alone did not
            # already decide the outcome.
            if not known_buggy and driver.get_build_version() < 400:
                known_buggy = (self.format, self.transa) in (
                    ('csc', 'N'), ('csr', 'T'))
            if known_buggy:
                pytest.xfail('may be buggy')

        if self.format == 'csc':
            if numpy.dtype(dtype).char in 'FD' and self.transa == 'H':
                pytest.skip('unsupported combination')

        self._setup(dtype)

        # The solver works on ``solution`` directly, so start from a copy of
        # the right-hand side laid out in the memory order under test.
        solution = self.b.copy(order=self.order)
        solver_options = dict(
            alpha=self.alpha,
            lower=self.lower,
            unit_diag=self.unit_diag,
            transa=self.transa,
            blocking=self.blocking,
            level_info=self.level_info,
        )
        cusparse.csrsm2(self.a, solution, **solver_options)

        testing.assert_allclose(
            solution, self.ref_x, atol=self.tol, rtol=self.tol)
Пример #4
0
def spsolve_triangular(A, b, lower=True, overwrite_A=False, overwrite_b=False,
                       unit_diagonal=False):
    """Solves a sparse triangular system ``A x = b``.

    Args:
        A (cupyx.scipy.sparse.spmatrix):
            Sparse matrix with dimension ``(M, M)``. Formats other than CSR
            and CSC are converted to CSR (with a warning).
        b (cupy.ndarray):
            Dense vector or matrix with dimension ``(M)`` or ``(M, K)``.
        lower (bool):
            Whether ``A`` is a lower or upper triangular matrix.
            If True, it is lower triangular, otherwise, upper triangular.
        overwrite_A (bool):
            (not supported)
        overwrite_b (bool):
            Allows overwriting data in ``b``. It only takes effect when
            ``b`` has the same dtype as ``A`` and is C- or F-contiguous.
        unit_diagonal (bool):
            If True, diagonal elements of ``A`` are assumed to be 1 and will
            not be referenced.

    Returns:
        cupy.ndarray:
            Solution to the system ``A x = b``. The shape is the same as ``b``.
            Single-precision results are promoted to double precision.

    Raises:
        NotImplementedError: If csrsm2 is not available.
        TypeError: If ``A`` is not a sparse matrix, ``b`` is not a
            ``cupy.ndarray``, or the dtype of ``A`` is unsupported.
        ValueError: If ``A`` is not square, ``b`` is neither 1-D nor 2-D,
            or the first dimensions of ``A`` and ``b`` differ.
    """
    if not cusparse.check_availability('csrsm2'):
        raise NotImplementedError

    # Argument validation.
    if not sparse.isspmatrix(A):
        raise TypeError('A must be cupyx.scipy.sparse.spmatrix')
    if not isinstance(b, cupy.ndarray):
        raise TypeError('b must be cupy.ndarray')
    num_rows = A.shape[0]
    if num_rows != A.shape[1]:
        message = 'A must be a square matrix (A.shape: {})'
        raise ValueError(message.format(A.shape))
    if b.ndim not in (1, 2):
        message = 'b must be 1D or 2D array (b.shape: {})'
        raise ValueError(message.format(b.shape))
    if num_rows != b.shape[0]:
        message = ('The size of dimensions of A must be equal to the '
                   'size of the first dimension of b '
                   '(A.shape: {}, b.shape: {})')
        raise ValueError(message.format(A.shape, b.shape))
    if A.dtype.char not in 'fdFD':
        raise TypeError('unsupported dtype (actual: {})'.format(A.dtype))

    # The solver needs a compressed (CSR or CSC) matrix.
    is_compressed = sparse.isspmatrix_csr(A) or sparse.isspmatrix_csc(A)
    if not is_compressed:
        warnings.warn('CSR or CSC format is required. Converting to CSR '
                      'format.', sparse.SparseEfficiencyWarning)
        A = A.tocsr()
    A.sum_duplicates()

    # Solve directly in ``b`` only when the caller allows it and no
    # conversion is required; otherwise work on a converted copy.
    reuse_b = (overwrite_b and A.dtype == b.dtype and
               (b._c_contiguous or b._f_contiguous))
    x = b if reuse_b else b.astype(A.dtype, copy=True)

    cusparse.csrsm2(A, x, lower=lower, unit_diag=unit_diagonal)

    if x.dtype.char in 'fF':
        # Note: This is for compatibility with SciPy.
        x = x.astype(numpy.promote_types(x.dtype, 'float64'))
    return x