def _tensordot_core(a, b, out, n, m, k, ret_shape):
    ret_dtype = a.dtype.char
    if ret_dtype != b.dtype.char:
        ret_dtype = numpy.find_common_type((ret_dtype, b.dtype), ()).char

    # Cast to float32 or float64
    if ret_dtype == 'f' or ret_dtype == 'd':
        dtype = ret_dtype
    else:
        dtype = numpy.find_common_type((ret_dtype, 'f'), ()).char

    a = a.astype(dtype, copy=False)
    b = b.astype(dtype, copy=False)

    if not a.size or not b.size:
        if a.size or b.size:
            raise ValueError('cannot dot zero-sized and non-zero-sized '
                             'arrays')
        if out is None:
            return cupy.zeros(ret_shape, dtype=ret_dtype)
        else:
            out.fill(0)
            return out

    if out is None:
        out = cupy.empty(ret_shape, dtype)
        if dtype == ret_dtype:
            ret = out
        else:
            ret = cupy.empty(ret_shape, ret_dtype)
    else:
        ret = out
        if out.dtype != dtype:
            out = cupy.empty(ret_shape, dtype)

    # It copies the operands if needed
    if a.shape != (k, n):
        a = cupy.reshape(a, (k, n))
    if b.shape != (k, m):
        b = cupy.reshape(b, (k, m))
    c = out
    if c.shape != (n, m):
        c = c.view()
        c.shape = (n, m)

    # Be careful that cuBLAS uses the FORTRAN-order matrix representation.
    if k == 1:
        if n == 1:
            # Scalar-vector product
            cupy.multiply(a, b, c)
        elif m == 1:
            # Scalar-vector product
            cupy.multiply(a.T, b, c)
        else:
            # Outer product A^T * B
            # c is C-contiguous while cuBLAS requires F-contiguous arrays, so
            # we compute C^T = B^T * A here.
            handle = cuda.Device().cublas_handle
            c.fill(0)
            a, inca = _to_cublas_vector(a, 1)
            b, incb = _to_cublas_vector(b, 1)
            if dtype == 'f':
                ger = cublas.sger
            elif dtype == 'd':
                ger = cublas.dger
            ger(handle, m, n, 1, b.data.ptr, incb, a.data.ptr, inca,
                c.data.ptr, m)

        if dtype != ret_dtype:
            elementwise.copy(out, ret)
        return ret

    handle = cuda.Device().cublas_handle
    if n == 1:
        if m == 1:
            # Inner product
            a, inca = _to_cublas_vector(a, 0)
            b, incb = _to_cublas_vector(b, 0)
            mode = cublas.getPointerMode(handle)
            cublas.setPointerMode(handle,
                                  cublas.CUBLAS_POINTER_MODE_DEVICE)
            if dtype == 'f':
                dot = cublas.sdot
            elif dtype == 'd':
                dot = cublas.ddot
            try:
                dot(handle, k, a.data.ptr, inca, b.data.ptr, incb,
                    c.data.ptr)
            finally:
                cublas.setPointerMode(handle, mode)
        else:
            # Matrix-vector product B^T * A
            a, inca = _to_cublas_vector(a, 0)
            b, transb, ldb = _mat_to_cublas_contiguous(b, 1)
            if transb:
                # gemv requires (m, k) as the original matrix dimensions
                # rather than the transposed dimensions.
                m, k = k, m
            if dtype == 'f':
                gemv = cublas.sgemv
            elif dtype == 'd':
                gemv = cublas.dgemv
            gemv(handle, transb, m, k, 1, b.data.ptr, ldb, a.data.ptr, inca,
                 0, c.data.ptr, 1)
    elif m == 1:
        # Matrix-vector product A^T * B
        a, transa, lda = _mat_to_cublas_contiguous(a, 1)
        b, incb = _to_cublas_vector(b, 0)
        if transa:
            # gemv requires (n, k) as the original matrix dimensions rather
            # than the transposed dimensions.
            n, k = k, n
        if dtype == 'f':
            gemv = cublas.sgemv
        elif dtype == 'd':
            gemv = cublas.dgemv
        gemv(handle, transa, n, k, 1, a.data.ptr, lda, b.data.ptr, incb,
             0, c.data.ptr, 1)
    else:
        # Matrix-Matrix product A^T * B
        # c is C-contiguous while cuBLAS assumes F-contiguous inputs, so we
        # compute C^T = B^T * A here.
        a, transa, lda = _mat_to_cublas_contiguous(a, 0)
        b, transb, ldb = _mat_to_cublas_contiguous(b, 1)
        if dtype == 'f':
            gemm = cublas.sgemm
        elif dtype == 'd':
            gemm = cublas.dgemm
        gemm(handle, transb, transa, m, n, k, 1, b.data.ptr, ldb,
             a.data.ptr, lda, 0, c.data.ptr, m)

    if dtype != ret_dtype:
        elementwise.copy(out, ret)
    return ret
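

# A minimal host-side sketch of what ``_tensordot_core`` computes, assuming
# NumPy only: once the operands are laid out as (k, n) and (k, m), every
# cuBLAS branch above is a specialization of the single product A^T * B.
# The helper name below is hypothetical and exists purely for exposition;
# it is not part of this module's API.
def _tensordot_core_reference(a, b, n, m, k, ret_shape):
    import numpy

    # Mirror the (k, n) / (k, m) operand layout used by _tensordot_core.
    a = numpy.asarray(a).reshape(k, n)
    b = numpy.asarray(b).reshape(k, m)
    # cuBLAS evaluates C^T = B^T * A on F-ordered data, which is exactly
    # A^T * B for the row-major (C-contiguous) result.
    return a.T.dot(b).reshape(ret_shape)
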
def tensordot(a, b, axes=2, out=None):
    """Returns the tensor dot product of two arrays along specified axes.

    This is equivalent to computing the dot product along the specified axes,
    which are treated as one axis by reshaping.

    Args:
        a (cupy.ndarray): The first argument.
        b (cupy.ndarray): The second argument.
        axes:
            - If it is an integer, then ``axes`` axes at the end of ``a`` and
              the beginning of ``b`` are used.
            - If it is a pair of sequences of integers, then these two
              sequences specify the list of axes for ``a`` and ``b``. The
              corresponding axes are paired for sum-product.
        out (cupy.ndarray): Output array.

    Returns:
        cupy.ndarray: The tensor dot product of ``a`` and ``b`` along the
        axes specified by ``axes``.

    .. seealso:: :func:`numpy.tensordot`

    """
    if a.ndim == 0 or b.ndim == 0:
        if axes != 0 and axes != ((), ()):
            raise ValueError('An input is zero-dim while axes has dimensions')
        return cupy.multiply(a, b, out=out)

    ret_dtype = numpy.find_common_type([a.dtype, b.dtype], [])

    # Cast to float32 or float64
    dtype = numpy.find_common_type([a.dtype, b.dtype, 'f'], [])
    a = a.astype(dtype, copy=False)
    b = b.astype(dtype, copy=False)

    if a.dtype.type == numpy.float32:
        dot = cublas.sdot
        gemv = cublas.sgemv
        ger = cublas.sger
        gemm = cublas.sgemm
    elif a.dtype.type == numpy.float64:
        dot = cublas.ddot
        gemv = cublas.dgemv
        ger = cublas.dger
        gemm = cublas.dgemm

    if numpy.isscalar(axes):
        axes = [list(six.moves.range(a.ndim - axes, a.ndim)),
                list(six.moves.range(axes))]
    else:
        axes = list(axes)
    # Check the pair structure before indexing into it; otherwise a
    # malformed ``axes`` raises IndexError instead of the intended error.
    if len(axes) != 2:
        raise ValueError('Axes must consist of two arrays.')
    if numpy.isscalar(axes[0]):
        axes[0] = (axes[0],)
    if numpy.isscalar(axes[1]):
        axes[1] = (axes[1],)
    if len(axes[0]) != len(axes[1]):
        raise ValueError('Axes length mismatch')
    for a_axis, b_axis in zip(*axes):
        if not (-a.ndim <= a_axis < a.ndim and
                -b.ndim <= b_axis < b.ndim):
            raise IndexError('Axis overrun')
        if a.shape[a_axis] != b.shape[b_axis]:
            raise ValueError('Axis dimension mismatch')

    # Make the axes non-negative
    axes = (tuple(axis % a.ndim for axis in axes[0]),
            tuple(axis % b.ndim for axis in axes[1]))

    sum_ndim = len(axes[0])
    a = _move_axes_to_head(a, axes[0])
    b = _move_axes_to_head(b, axes[1])

    m = internal.prod(b.shape[sum_ndim:])
    n = internal.prod(a.shape[sum_ndim:])
    ret_shape = a.shape[sum_ndim:] + b.shape[sum_ndim:]

    if out is not None:
        if out.size != internal.prod(ret_shape):
            raise ValueError('Output array has an invalid size')
        if not out.flags.c_contiguous:
            raise ValueError('Output array must be C-contiguous')

    if 0 in a.shape or 0 in b.shape:
        if 0 not in a.shape or 0 not in b.shape:
            raise ValueError('cannot dot zero-sized and non-zero-sized '
                             'arrays')
        if out is None:
            return cupy.zeros(ret_shape, dtype=ret_dtype)
        else:
            out.fill(0)
            return out

    if out is None:
        out = cupy.empty(ret_shape, dtype=dtype)
        if dtype == ret_dtype:
            ret = out
        else:
            ret = cupy.empty(ret_shape, dtype=ret_dtype)
    else:
        ret = out
        if out.dtype != dtype:
            out = cupy.empty(ret_shape, dtype=dtype)

    k = a.size // n

    # reshape() copies the operands if they are not contiguous
    a = a.reshape(k, n)
    b = b.reshape(k, m)
    c = out.view()
    c.shape = (n, m)

    # Be careful that cuBLAS uses the FORTRAN-order matrix representation.
    handle = cuda.Device().cublas_handle
    if k == 1:
        if n == 1 or m == 1:
            # Scalar-vector product
            cupy.multiply(a.T, b, c)
        else:
            # Outer product A^T * B
            # c is C-contiguous while cuBLAS requires F-contiguous arrays, so
            # we compute C^T = B^T * A here.
            c.fill(0)
            a, inca = _to_cublas_vector(a, 1)
            b, incb = _to_cublas_vector(b, 1)
            ger(handle, m, n, 1, b._fptr, incb, a._fptr, inca, c._fptr, m)
    elif n == 1:
        if m == 1:
            # Inner product
            a, inca = _to_cublas_vector(a, 0)
            b, incb = _to_cublas_vector(b, 0)
            mode = cublas.getPointerMode(handle)
            cublas.setPointerMode(handle,
                                  cublas.CUBLAS_POINTER_MODE_DEVICE)
            try:
                dot(handle, k, a._fptr, inca, b._fptr, incb, c._fptr)
            finally:
                cublas.setPointerMode(handle, mode)
        else:
            # Matrix-vector product B^T * A
            a, inca = _to_cublas_vector(a, 1)
            b, transb, ldb = _mat_to_cublas_contiguous(b.T)
            if transb:
                # gemv requires (m, k) as the original matrix dimensions
                # rather than the transposed dimensions.
                m, k = k, m
            gemv(handle, transb, m, k, 1, b._fptr, ldb, a._fptr, inca,
                 0, c._fptr, 1)
    elif m == 1:
        # Matrix-vector product A^T * B
        a, transa, lda = _mat_to_cublas_contiguous(a.T)
        b, incb = _to_cublas_vector(b, 1)
        if not transa:
            # gemv requires (n, k) as the original matrix dimensions rather
            # than the transposed dimensions.
            n, k = k, n
        gemv(handle, transa, n, k, 1, a._fptr, lda, b._fptr, incb,
             0, c._fptr, 1)
    else:
        # Matrix-Matrix product A^T * B
        # c is C-contiguous while cuBLAS assumes F-contiguous inputs, so we
        # compute C^T = B^T * A here.
        a, transa, lda = _mat_to_cublas_contiguous(a)
        b, transb, ldb = _mat_to_cublas_contiguous(b.T)
        gemm(handle, transb, transa, m, n, k, 1, b._fptr, ldb,
             a._fptr, lda, 0, c._fptr, m)

    if dtype != ret_dtype:
        elementwise.copy(out, ret)
    return ret
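

# Usage sketch for the ``axes`` argument.  ``tensordot`` above follows the
# semantics of :func:`numpy.tensordot` (see the docstring), so the example
# uses NumPy to stay runnable without a GPU; with a working CuPy install the
# same calls apply to ``cupy.tensordot``.  The helper name is hypothetical.
def _tensordot_axes_example():
    import numpy

    a = numpy.arange(24.0).reshape(2, 3, 4)
    b = numpy.arange(48.0).reshape(3, 4, 2, 2)
    # Integer form: contract the last two axes of ``a`` with the first two
    # axes of ``b``.
    c1 = numpy.tensordot(a, b, axes=2)
    # Pair form: name the contracted axes of each operand explicitly.
    c2 = numpy.tensordot(a, b, axes=([1, 2], [0, 1]))
    # The remaining axes of ``a`` then ``b`` form the result shape.
    assert c1.shape == (2, 2, 2)
    assert numpy.allclose(c1, c2)
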