def test_csrsm2(self, dtype):
    """Check ``cusparse.csrsm2`` against the reference solution from ``_setup``.

    The test is skipped when ``csrsm2`` is not available, and for the
    CSC + complex dtype + ``transa == 'H'`` combination, which is treated
    as unsupported.
    """
    if not cusparse.check_availability('csrsm2'):
        raise unittest.SkipTest('csrsm2 is not available')
    if self.format == 'csc' and numpy.dtype(dtype).char in 'FD':
        if self.transa == 'H':
            raise unittest.SkipTest('unsupported combination')

    self._setup(dtype)
    # csrsm2 writes the solution into its second argument, so work on a copy
    # of the right-hand side laid out in the requested memory order.
    solution = self.b.copy(order=self.order)
    options = dict(
        alpha=self.alpha,
        lower=self.lower,
        unit_diag=self.unit_diag,
        transa=self.transa,
        blocking=self.blocking,
        level_info=self.level_info,
    )
    cusparse.csrsm2(self.a, solution, **options)
    testing.assert_allclose(solution, self.ref_x, atol=self.tol, rtol=self.tol)
def solve(self, rhs, trans='N'):
    """Solves linear system of equations with one or several right-hand sides.

    Args:
        rhs (cupy.ndarray): Right-hand side(s) of equation with dimension
            ``(M)`` or ``(M, K)``.
        trans (str): 'N', 'T' or 'H'.
            'N': Solves ``A * x = rhs``.
            'T': Solves ``A.T * x = rhs``.
            'H': Solves ``A.conj().T * x = rhs``.

    Returns:
        cupy.ndarray: Solution vector(s). The result has the dtype of the
        ``L`` factor and is Fortran-contiguous.

    Raises:
        TypeError: If ``rhs`` is not a ``cupy.ndarray``.
        ValueError: If ``rhs`` is not 1-D or 2-D, if its first dimension
            does not match ``self.shape[0]``, or if ``trans`` is not one of
            'N', 'T', 'H'.
        NotImplementedError: If cuSPARSE ``csrsm2`` is not available.
    """
    if not isinstance(rhs, cupy.ndarray):
        raise TypeError('rhs must be cupy.ndarray')
    if rhs.ndim not in (1, 2):
        raise ValueError('rhs.ndim must be 1 or 2 (actual: {})'.
                         format(rhs.ndim))
    if rhs.shape[0] != self.shape[0]:
        raise ValueError('shape mismatch (self.shape: {}, rhs.shape: {})'
                         .format(self.shape, rhs.shape))
    if trans not in ('N', 'T', 'H'):
        raise ValueError('trans must be \'N\', \'T\', or \'H\'')
    if not cusparse.check_availability('csrsm2'):
        raise NotImplementedError

    # Cast to the factor dtype; the csrsm2 calls below overwrite x in place.
    x = rhs.astype(self.L.dtype)
    if trans == 'N':
        # Non-transposed system: apply the row permutation, solve with L
        # then U, and finally apply the column permutation.
        if self.perm_r is not None:
            x = x[self._perm_r_rev]
        cusparse.csrsm2(self.L, x, lower=True, transa=trans)
        cusparse.csrsm2(self.U, x, lower=False, transa=trans)
        if self.perm_c is not None:
            x = x[self.perm_c]
    else:
        # (Conjugate-)transposed system: the factors are applied in reverse
        # order (U first, then L) and the row/column permutations swap roles.
        if self.perm_c is not None:
            x = x[self._perm_c_rev]
        cusparse.csrsm2(self.U, x, lower=False, transa=trans)
        cusparse.csrsm2(self.L, x, lower=True, transa=trans)
        if self.perm_r is not None:
            x = x[self.perm_r]

    if not x._f_contiguous:
        # For compatibility with SciPy
        x = x.copy(order='F')
    return x
def test_csrsm2(self, dtype):
    """Compare ``cusparse.csrsm2`` with the reference solution built by ``_setup``.

    Skipped when ``csrsm2`` is unavailable. On HIP, ``transa == 'H'`` and
    (for build versions below 400) the CSC/'N' and CSR/'T' combinations are
    marked as expected failures. The CSC + complex dtype + ``transa == 'H'``
    combination is skipped as unsupported.
    """
    if not cusparse.check_availability('csrsm2'):
        pytest.skip('csrsm2 is not available')

    if runtime.is_hip:
        if self.transa == 'H':
            pytest.xfail('may be buggy')
        elif driver.get_build_version() < 400 and (
                (self.format, self.transa) in (('csc', 'N'), ('csr', 'T'))):
            pytest.xfail('may be buggy')

    if self.format == 'csc' and numpy.dtype(dtype).char in 'FD':
        if self.transa == 'H':
            pytest.skip('unsupported combination')

    self._setup(dtype)
    rhs = self.b.copy(order=self.order)
    cusparse.csrsm2(
        self.a, rhs,
        alpha=self.alpha,
        lower=self.lower,
        unit_diag=self.unit_diag,
        transa=self.transa,
        blocking=self.blocking,
        level_info=self.level_info,
    )
    # csrsm2 has overwritten rhs with the computed solution.
    testing.assert_allclose(rhs, self.ref_x, atol=self.tol, rtol=self.tol)
def spsolve_triangular(A, b, lower=True, overwrite_A=False, overwrite_b=False,
                       unit_diagonal=False):
    """Solves a sparse triangular system ``A x = b``.

    Args:
        A (cupyx.scipy.sparse.spmatrix):
            Sparse matrix with dimension ``(M, M)``.
        b (cupy.ndarray):
            Dense vector or matrix with dimension ``(M)`` or ``(M, K)``.
        lower (bool):
            Whether ``A`` is a lower or upper triangular matrix.
            If True, it is lower triangular, otherwise, upper triangular.
        overwrite_A (bool):
            (not supported)
        overwrite_b (bool):
            Allows overwriting data in ``b``. This is only honored when ``b``
            already has the computation dtype and is C- or F-contiguous;
            otherwise a copy of ``b`` is made.
        unit_diagonal (bool):
            If True, diagonal elements of ``A`` are assumed to be 1 and will
            not be referenced.

    Returns:
        cupy.ndarray:
            Solution to the system ``A x = b``. The shape is the same as
            ``b``. The system is solved in the common dtype of ``A`` and
            ``b``; single-precision results are then promoted to double
            precision for compatibility with SciPy.

    Raises:
        NotImplementedError: If cuSPARSE ``csrsm2`` is not available.
        TypeError: If ``A`` is not a sparse matrix, ``b`` is not a
            ``cupy.ndarray``, or ``A`` has an unsupported dtype.
        ValueError: If ``A`` is not square, ``b`` is not 1-D or 2-D, or the
            sizes of ``A`` and ``b`` do not match.
    """
    if not cusparse.check_availability('csrsm2'):
        raise NotImplementedError

    if not sparse.isspmatrix(A):
        raise TypeError('A must be cupyx.scipy.sparse.spmatrix')
    if not isinstance(b, cupy.ndarray):
        raise TypeError('b must be cupy.ndarray')
    if A.shape[0] != A.shape[1]:
        raise ValueError('A must be a square matrix (A.shape: {})'.
                         format(A.shape))
    if b.ndim not in [1, 2]:
        raise ValueError('b must be 1D or 2D array (b.shape: {})'.
                         format(b.shape))
    if A.shape[0] != b.shape[0]:
        raise ValueError('The size of dimensions of A must be equal to the '
                         'size of the first dimension of b '
                         '(A.shape: {}, b.shape: {})'.format(A.shape, b.shape))
    if A.dtype.char not in 'fdFD':
        raise TypeError('unsupported dtype (actual: {})'.format(A.dtype))

    if not (sparse.isspmatrix_csr(A) or sparse.isspmatrix_csc(A)):
        warnings.warn('CSR or CSC format is required. Converting to CSR '
                      'format.', sparse.SparseEfficiencyWarning)
        A = A.tocsr()

    # Solve in the common dtype of A and b. Casting b to A.dtype instead would
    # silently drop the imaginary part of a complex b when A is real, and
    # would downcast a double-precision b to a single-precision A.
    dtype = numpy.promote_types(A.dtype, b.dtype)
    if dtype != A.dtype:
        A = A.astype(dtype)
    A.sum_duplicates()

    if (overwrite_b and b.dtype == dtype and
            (b._c_contiguous or b._f_contiguous)):
        x = b
    else:
        x = b.astype(dtype, copy=True)

    # csrsm2 writes the solution into x in place.
    cusparse.csrsm2(A, x, lower=lower, unit_diag=unit_diagonal)

    if x.dtype.char in 'fF':
        # Note: This is for compatibility with SciPy.
        dtype = numpy.promote_types(x.dtype, 'float64')
        x = x.astype(dtype)
    return x