Exemple #1
0
def solve(a, b):
    """Solves a linear matrix equation.

    It computes the exact solution of ``x`` in ``ax = b``,
    where ``a`` is a square and full rank matrix.

    Args:
        a (cupy.ndarray): The matrix with dimension ``(..., M, M)``.
        b (cupy.ndarray): The matrix with dimension ``(..., M)`` or
            ``(..., M, K)``.

    Returns:
        cupy.ndarray:
            The matrix with dimension ``(..., M)`` or ``(..., M, K)``.

    Raises:
        ValueError: If, on the non-batched path, the shape of ``b`` is not
            compatible with the shape of ``a``.

    .. warning::
        This function calls one or more cuSOLVER routine(s) which may yield
        invalid results if input conditions are not met.
        To detect these invalid results, you can set the `linalg`
        configuration to a value that is not `ignore` in
        :func:`cupyx.errstate` or :func:`cupyx.seterr`.

    .. seealso:: :func:`numpy.linalg.solve`
    """
    # Validate the operands before anything reads ``a.ndim`` / ``a.shape``,
    # so that invalid arguments are rejected by the dedicated checks rather
    # than by an incidental attribute error in the dispatch below.
    _util._assert_cupy_array(a, b)
    _util._assert_stacked_2d(a)
    _util._assert_stacked_square(a)

    if a.ndim > 2 and a.shape[-1] <= get_batched_gesv_limit():
        # Note: There is a low performance issue in batched_gesv when matrix is
        # large, so it is not used in such cases.
        return batched_gesv(a, b)

    # NOTE(review): this shape check is deliberately left after the batched
    # dispatch; which ``b`` shapes batched_gesv accepts is not visible here.
    if not ((a.ndim == b.ndim or a.ndim == b.ndim + 1)
            and a.shape[:-1] == b.shape[:a.ndim - 1]):
        raise ValueError(
            'a must have (..., M, M) shape and b must have (..., M) '
            'or (..., M, K)')

    dtype, out_dtype = _util.linalg_common_type(a, b)
    if a.ndim == 2:
        # Work on copies: gesv writes its result into the arrays it is given,
        # and the caller's 'a' and 'b' must not be overwritten.
        a = a.astype(dtype, copy=True, order='F')
        b = b.astype(dtype, copy=True, order='F')
        cupyx.lapack.gesv(a, b)
        return b.astype(out_dtype, copy=False)

    # Stacked case: solve each (M, M) system on its own.
    # Copy 'a' once up front so the caller's array is not overwritten.
    a = a.astype(dtype, copy=True, order='C')
    x = cupy.empty_like(b, dtype=out_dtype)
    for index in numpy.ndindex(*a.shape[:-2]):
        # Per-system copy of the right-hand side; gesv leaves the solution
        # in it, which is then stored into the output.
        bi = b[index].astype(dtype, copy=True, order='F')
        cupyx.lapack.gesv(a[index], bi)
        x[index] = bi
    return x
Exemple #2
0
def solve(a, b):
    """Solves a linear matrix equation.

    It computes the exact solution of ``x`` in ``ax = b``,
    where ``a`` is a square and full rank matrix.

    The computation is carried out in the common inexact dtype of ``a`` and
    ``b`` (at least single precision), so a higher-precision or complex
    right-hand side is not silently truncated to the dtype of ``a``.

    Args:
        a (cupy.ndarray): The matrix with dimension ``(..., M, M)``.
        b (cupy.ndarray): The matrix with dimension ``(..., M)`` or
            ``(..., M, K)``.

    Returns:
        cupy.ndarray:
            The matrix with dimension ``(..., M)`` or ``(..., M, K)``.

    Raises:
        ValueError: If, on the non-batched path, the shape of ``b`` is not
            compatible with the shape of ``a``.

    .. warning::
        This function calls one or more cuSOLVER routine(s) which may yield
        invalid results if input conditions are not met.
        To detect these invalid results, you can set the `linalg`
        configuration to a value that is not `ignore` in
        :func:`cupyx.errstate` or :func:`cupyx.seterr`.

    .. seealso:: :func:`numpy.linalg.solve`
    """
    if a.ndim > 2 and a.shape[-1] <= get_batched_gesv_limit():
        # Note: There is a low performance issue in batched_gesv when matrix is
        # large, so it is not used in such cases.
        return batched_gesv(a, b)

    _util._assert_cupy_array(a, b)
    _util._assert_nd_squareness(a)

    if not ((a.ndim == b.ndim or a.ndim == b.ndim + 1)
            and a.shape[:-1] == b.shape[:a.ndim - 1]):
        raise ValueError(
            'a must have (..., M, M) shape and b must have (..., M) '
            'or (..., M, K)')

    # Pick the working dtype from BOTH operands, promoted to at least
    # float32. Using only a.dtype would downcast b (e.g. float64 -> float32,
    # or complex -> real) before solving.
    dtype = numpy.promote_types(numpy.promote_types(a.dtype, b.dtype), 'f')

    # astype() copies by default, so the arrays handed to gesv are never the
    # caller's own buffers.
    a = a.astype(dtype)
    b = b.astype(dtype)
    if a.ndim == 2:
        return cupyx.lapack.gesv(a, b)

    # Stacked case: solve each (M, M) system on its own.
    x = cupy.empty_like(b)
    for index in numpy.ndindex(*a.shape[:-2]):
        x[index] = cupyx.lapack.gesv(a[index], b[index])
    return x
Exemple #3
0
 def setUp(self):
     """Apply this test case's batched-gesv limit, if it defines one.

     When ``self.batched_gesv_limit`` is ``None`` nothing is changed.
     Otherwise the limit currently in effect is stored in
     ``self.old_limit`` and the test's own limit is installed.
     """
     limit = self.batched_gesv_limit
     if limit is None:
         return
     # Remember the current limit before replacing it
     # (presumably restored in tearDown, which is not visible here).
     self.old_limit = get_batched_gesv_limit()
     set_batched_gesv_limit(limit)