def _gesvdj_batched(a, full_matrices, compute_uv, overwrite_a):
    """Batched singular value decomposition using cusolverDn<t>gesvdjBatched().

    Args:
        a (cupy.ndarray): Stack of matrices with shape ``(batch, M, N)``.
        full_matrices (bool): If ``False``, ``u`` and ``v`` are truncated to
            their first ``min(M, N)`` columns.
        compute_uv (bool): If ``False``, only the singular values are
            returned.
        overwrite_a (bool): If ``True``, ``a`` might be overwritten.

    Returns:
        ``(u, s, v)`` if ``compute_uv`` is ``True``, otherwise ``s``.
        ``s`` has shape ``(batch, K)`` with ``K = min(M, N)``; ``u`` and
        ``v`` have shapes ``(batch, M, M)`` and ``(batch, N, N)``, or
        ``(batch, M, K)`` and ``(batch, N, K)`` when ``full_matrices`` is
        ``False``.

    Raises:
        TypeError: If the dtype of ``a`` is not one of ``f``, ``d``, ``F``,
            ``D``.
    """
    if a.dtype == 'f':
        helper = cusolver.sgesvdjBatched_bufferSize
        solver = cusolver.sgesvdjBatched
        s_dtype = 'f'
    elif a.dtype == 'd':
        helper = cusolver.dgesvdjBatched_bufferSize
        solver = cusolver.dgesvdjBatched
        s_dtype = 'd'
    elif a.dtype == 'F':
        helper = cusolver.cgesvdjBatched_bufferSize
        solver = cusolver.cgesvdjBatched
        s_dtype = 'f'
    elif a.dtype == 'D':
        helper = cusolver.zgesvdjBatched_bufferSize
        solver = cusolver.zgesvdjBatched
        s_dtype = 'd'
    else:
        raise TypeError('unsupported dtype: {}'.format(a.dtype))

    check_dev_info = (
        cupy.linalg.util._check_cusolver_dev_info_if_synchronization_allowed)

    handle = device.get_cusolver_handle()
    batch_size, m, n = a.shape
    # Swap the last two axes and make the stack C-contiguous, so each matrix
    # is laid out column-major with leading dimension m.
    a = cupy.array(a.swapaxes(-2, -1), order='C', copy=not overwrite_a)
    lda = m
    mn = min(m, n)
    s = cupy.empty((batch_size, mn), dtype=s_dtype)
    ldu = m
    ldv = n
    if compute_uv:
        jobz = cusolver.CUSOLVER_EIG_MODE_VECTOR
    else:
        jobz = cusolver.CUSOLVER_EIG_MODE_NOVECTOR

    # The batched routine takes no ``econ`` argument (unlike the non-batched
    # one), so full U/V buffers are always allocated and truncated later.
    # Allocating (batch, k, ld) and swapping axes gives per-matrix
    # column-major buffers.
    u = cupy.empty((batch_size, m, ldu), dtype=a.dtype).swapaxes(-2, -1)
    v = cupy.empty((batch_size, n, ldv), dtype=a.dtype).swapaxes(-2, -1)
    # gesvdjBatched reports one status per matrix, so ``info`` needs
    # ``batch_size`` entries; a single int would be overrun on the device.
    info = cupy.empty(batch_size, dtype=numpy.int32)

    params = cusolver.createGesvdjInfo()
    try:
        lwork = helper(handle, jobz, m, n, a.data.ptr, lda, s.data.ptr,
                       u.data.ptr, ldu, v.data.ptr, ldv, params, batch_size)
        work = cupy.empty(lwork, dtype=a.dtype)
        solver(handle, jobz, m, n, a.data.ptr, lda, s.data.ptr,
               u.data.ptr, ldu, v.data.ptr, ldv, work.data.ptr, lwork,
               info.data.ptr, params, batch_size)
        if batch_size > 0:
            # Reduce the per-matrix statuses to one value: the first
            # nonzero status, or 0 when every matrix succeeded. This matches
            # the single-status form used by the non-batched path.
            status = info.take((info != 0).argmax().reshape(1))
            check_dev_info(gesvdj, status)
    finally:
        # Release the parameter object even if a call above raised.
        cusolver.destroyGesvdjInfo(params)

    if not full_matrices:
        u = u[..., :mn]
        v = v[..., :mn]
    if compute_uv:
        return u, s, v
    else:
        return s
def gesvdj(a, full_matrices=True, compute_uv=True, overwrite_a=False):
    """Singular value decomposition using cusolverDn<t>gesvdj().

    Factorizes the matrix ``a`` into two unitary matrices ``u`` and ``v`` and
    a singular values vector ``s`` such that ``a == u @ diag(s) @ v*``.

    Args:
        a (cupy.ndarray): The input matrix with dimension ``(M, N)``. A stack
            of matrices with dimension ``(B, M, N)`` is also accepted and is
            factorized with cusolverDn<t>gesvdjBatched(); the outputs then
            carry the same leading batch dimension.
        full_matrices (bool): If True, it returns u and v with dimensions
            ``(M, M)`` and ``(N, N)``. Otherwise, the dimensions of u and v
            are respectively ``(M, K)`` and ``(N, K)``, where
            ``K = min(M, N)``.
        compute_uv (bool): If ``False``, it only returns singular values.
        overwrite_a (bool): If ``True``, matrix ``a`` might be overwritten.

    Returns:
        tuple of :class:`cupy.ndarray`, or :class:`cupy.ndarray`:
        A tuple of ``(u, s, v)`` if ``compute_uv`` is ``True``; otherwise
        only ``s``. ``v`` holds the right singular vectors as columns, so
        ``v*`` in the identity above is its conjugate transpose.

    Raises:
        RuntimeError: If gesvdj is not available.
        TypeError: If the dtype of ``a`` is not one of ``f``, ``d``, ``F``,
            ``D``.
    """
    if not check_availability('gesvdj'):
        raise RuntimeError('gesvdj is not available.')

    if a.ndim == 3:
        return _gesvdj_batched(a, full_matrices, compute_uv, overwrite_a)

    assert a.ndim == 2

    if a.dtype == 'f':
        helper = cusolver.sgesvdj_bufferSize
        solver = cusolver.sgesvdj
        s_dtype = 'f'
    elif a.dtype == 'd':
        helper = cusolver.dgesvdj_bufferSize
        solver = cusolver.dgesvdj
        s_dtype = 'd'
    elif a.dtype == 'F':
        helper = cusolver.cgesvdj_bufferSize
        solver = cusolver.cgesvdj
        s_dtype = 'f'
    elif a.dtype == 'D':
        helper = cusolver.zgesvdj_bufferSize
        solver = cusolver.zgesvdj
        s_dtype = 'd'
    else:
        raise TypeError('unsupported dtype: {}'.format(a.dtype))

    handle = device.get_cusolver_handle()
    m, n = a.shape
    # The solver works on column-major (Fortran-order) data.
    a = cupy.array(a, order='F', copy=not overwrite_a)
    lda = m
    mn = min(m, n)
    s = cupy.empty(mn, dtype=s_dtype)
    ldu = m
    ldv = n
    if compute_uv:
        jobz = cusolver.CUSOLVER_EIG_MODE_VECTOR
    else:
        jobz = cusolver.CUSOLVER_EIG_MODE_NOVECTOR
        # No vectors are returned, so the smaller economy buffers suffice.
        full_matrices = False

    if full_matrices:
        econ = 0
        u = cupy.empty((ldu, m), dtype=a.dtype, order='F')
        v = cupy.empty((ldv, n), dtype=a.dtype, order='F')
    else:
        econ = 1
        u = cupy.empty((ldu, mn), dtype=a.dtype, order='F')
        v = cupy.empty((ldv, mn), dtype=a.dtype, order='F')

    info = cupy.empty(1, dtype=numpy.int32)
    params = cusolver.createGesvdjInfo()
    try:
        lwork = helper(handle, jobz, econ, m, n, a.data.ptr, lda,
                       s.data.ptr, u.data.ptr, ldu, v.data.ptr, ldv, params)
        work = cupy.empty(lwork, dtype=a.dtype)
        solver(handle, jobz, econ, m, n, a.data.ptr, lda, s.data.ptr,
               u.data.ptr, ldu, v.data.ptr, ldv, work.data.ptr, lwork,
               info.data.ptr, params)
        cupy.linalg.util._check_cusolver_dev_info_if_synchronization_allowed(
            gesvdj, info)
    finally:
        # Release the parameter object even if a call above raised.
        cusolver.destroyGesvdjInfo(params)

    if compute_uv:
        return u, s, v
    else:
        return s