def solve(a, b):
    """Solves a linear matrix equation.

    It computes the exact solution of ``x`` in ``ax = b``,
    where ``a`` is a square and full rank matrix.

    Args:
        a (cupy.ndarray): The matrix with dimension ``(..., M, M)``.
        b (cupy.ndarray): The matrix with dimension ``(..., M)`` or
            ``(..., M, K)``.

    Returns:
        cupy.ndarray:
            The matrix with dimension ``(..., M)`` or ``(..., M, K)``.

    Raises:
        ValueError: If, on the non-batched path, the shapes of ``a`` and
            ``b`` do not match ``(..., M, M)`` with ``(..., M)`` or
            ``(..., M, K)``. The ``_util`` assertion helpers called first
            may raise their own errors for unsupported inputs.

    .. warning::
        This function calls one or more cuSOLVER routine(s) which may yield
        invalid results if input conditions are not met.
        To detect these invalid results, you can set the `linalg`
        configuration to a value that is not `ignore` in
        :func:`cupyx.errstate` or :func:`cupyx.seterr`.

    .. seealso:: :func:`numpy.linalg.solve`
    """
    # Validate the inputs before reading any array attribute, so that an
    # unsupported argument is rejected by the dedicated checks instead of
    # failing on ``a.ndim`` / ``a.shape`` in the dispatch below.
    _util._assert_cupy_array(a, b)
    _util._assert_stacked_2d(a)
    _util._assert_stacked_square(a)

    if a.ndim > 2 and a.shape[-1] <= get_batched_gesv_limit():
        # Note: There is a low performance issue in batched_gesv when matrix is
        # large, so it is not used in such cases.
        return batched_gesv(a, b)

    if not ((a.ndim == b.ndim or a.ndim == b.ndim + 1) and
            a.shape[:-1] == b.shape[:a.ndim - 1]):
        raise ValueError(
            'a must have (..., M, M) shape and b must have (..., M) '
            'or (..., M, K)')

    dtype, out_dtype = _util.linalg_common_type(a, b)
    if a.ndim == 2:
        # prevent 'a' and 'b' to be overwritten
        a = a.astype(dtype, copy=True, order='F')
        b = b.astype(dtype, copy=True, order='F')
        # The return value of gesv is not used; the solution is read back
        # from the copy of 'b' that was passed in.
        cupyx.lapack.gesv(a, b)
        return b.astype(out_dtype, copy=False)

    # prevent 'a' to be overwritten
    a = a.astype(dtype, copy=True, order='C')
    x = cupy.empty_like(b, dtype=out_dtype)
    # Solve each matrix of the stack separately, visiting the leading
    # (batch) dimensions in C order.
    for index in numpy.ndindex(*a.shape[:-2]):
        # prevent 'b' to be overwritten
        bi = b[index].astype(dtype, copy=True, order='F')
        cupyx.lapack.gesv(a[index], bi)
        x[index] = bi
    return x
def solve(a, b):
    """Solve the linear system ``ax = b`` for ``x``.

    ``a`` is expected to be a square, full rank matrix (or a stack of such
    matrices).

    Args:
        a (cupy.ndarray): Coefficient matrix of shape ``(..., M, M)``.
        b (cupy.ndarray): Right-hand side of shape ``(..., M)`` or
            ``(..., M, K)``.

    Returns:
        cupy.ndarray:
            The solution, of shape ``(..., M)`` or ``(..., M, K)``.

    Raises:
        ValueError: If, on the non-batched path, the shapes of ``a`` and
            ``b`` are not compatible.

    .. warning::
        This function calls one or more cuSOLVER routine(s) which may yield
        invalid results if input conditions are not met.
        To detect these invalid results, you can set the `linalg`
        configuration to a value that is not `ignore` in
        :func:`cupyx.errstate` or :func:`cupyx.seterr`.

    .. seealso:: :func:`numpy.linalg.solve`
    """
    use_batched = a.ndim > 2 and a.shape[-1] <= get_batched_gesv_limit()
    if use_batched:
        # batched_gesv has a low performance issue for large matrices, so
        # it is only used up to the configured size limit.
        return batched_gesv(a, b)

    _util._assert_cupy_array(a, b)
    _util._assert_nd_squareness(a)

    # 'b' must have the same rank as 'a' or exactly one dimension fewer,
    # and share every leading dimension of 'a' (including M).
    if (a.ndim - b.ndim not in (0, 1) or
            a.shape[:-1] != b.shape[:a.ndim - 1]):
        raise ValueError(
            'a must have (..., M, M) shape and b must have (..., M) '
            'or (..., M, K)')

    # Keep float32/float64 as is; promote any other dtype with float32.
    # Both operands are cast to the dtype derived from 'a'.
    work_dtype = (
        a.dtype if a.dtype.char in ('f', 'd')
        else numpy.promote_types(a.dtype.char, 'f'))
    a = a.astype(work_dtype)
    b = b.astype(work_dtype)

    if a.ndim == 2:
        return cupyx.lapack.gesv(a, b)

    # Stacked case: solve the systems one by one over the batch dimensions.
    batch_shape = a.shape[:-2]
    result = cupy.empty_like(b)
    for flat in range(numpy.prod(batch_shape)):
        pos = numpy.unravel_index(flat, batch_shape)
        result[pos] = cupyx.lapack.gesv(a[pos], b[pos])
    return result
def test_batched_gesv(self):
    """Compare ``cublas.batched_gesv(self.a, self.b)`` with ``self.x_ref``.

    ``self.tol`` is used as both the relative and the absolute tolerance.
    """
    actual = cublas.batched_gesv(self.a, self.b)
    cupy.testing.assert_allclose(
        actual, self.x_ref, rtol=self.tol, atol=self.tol)