import numpy as np
import pycuda.autoinit
import pycuda.gpuarray as gpuarray
import scikits.cuda.cublas as cublas
import scikits.cuda.misc


# Test method from a cublas test class; assumes self.cublas_handle was
# created beforehand (e.g. in setUp).
def test_cublasCgemmBatched(self):
    l, m, k, n = 11, 7, 5, 3
    A = (np.random.rand(l, m, k) + 1j*np.random.rand(l, m, k)).astype(np.complex64)
    B = (np.random.rand(l, k, n) + 1j*np.random.rand(l, k, n)).astype(np.complex64)
    C_res = np.einsum('nij,njk->nik', A, B)

    a_gpu = gpuarray.to_gpu(A)
    b_gpu = gpuarray.to_gpu(B)
    c_gpu = gpuarray.empty((l, m, n), np.complex64)

    alpha = np.complex64(1.0)
    beta = np.complex64(0.0)

    # cublasCgemmBatched takes arrays of device pointers, one pointer per
    # matrix in the batch (see the bptrs helper below).
    a_arr = bptrs(a_gpu)
    b_arr = bptrs(b_gpu)
    c_arr = bptrs(c_gpu)

    # numpy stores A and B row-major while cuBLAS expects column-major, so
    # we compute C^T = B^T * A^T instead: B's pointers are passed first and
    # the m/n dimensions are swapped.
    cublas.cublasCgemmBatched(self.cublas_handle, 'n', 'n',
                              n, m, k, alpha,
                              b_arr.gpudata, n,
                              a_arr.gpudata, k,
                              beta, c_arr.gpudata, n, l)

    assert np.allclose(C_res, c_gpu.get())
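# The test above relies on a bptrs helper that is not shown here. A minimal
# sketch, mirroring the helper used in scikits.cuda's own test suite: for a
# contiguous batch stored in one gpuarray, the i-th matrix starts at
# a.ptr + i*a.strides[0], so the pointer array is just an arithmetic range.
def bptrs(a):
    """Return a GPU array of device pointers to each matrix in the batch `a`."""
    return gpuarray.arange(a.ptr,
                           a.ptr + a.shape[0]*a.strides[0],
                           a.strides[0],
                           dtype=np.intp)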
def sc_complex_dot_batched(bx_gpu, by_gpu, bc_gpu, transa='N', transb='N',
                           handle=None):
    """
    Use cublasCgemmBatched to compute a batch of complex dot products
    in parallel: bc_gpu[i] = op(bx_gpu[i]) * op(by_gpu[i]) for each i.
    """
    if handle is None:
        handle = scikits.cuda.misc._global_cublas_handle

    assert len(bx_gpu.shape) == 3
    assert len(by_gpu.shape) == 3
    assert len(bc_gpu.shape) == 3
    assert bx_gpu.dtype == np.complex64
    assert by_gpu.dtype == np.complex64
    assert bc_gpu.dtype == np.complex64

    # Get the shapes of the arguments.
    bx_shape = bx_gpu.shape
    by_shape = by_gpu.shape

    alpha = np.complex64(1.0)
    beta = np.complex64(0.0)

    # str.lower() replaces the Python 2-only string.lower() idiom.
    transa = transa.lower()
    transb = transb.lower()

    if transb in ['t', 'c']:
        N, m, k = by_shape
    elif transb in ['n']:
        N, k, m = by_shape
    else:
        raise ValueError('invalid value for transb')

    if transa in ['t', 'c']:
        N2, l, n = bx_shape
    elif transa in ['n']:
        N2, n, l = bx_shape
    else:
        raise ValueError('invalid value for transa')

    if l != k:
        raise ValueError('objects are not aligned')

    if N != N2:
        raise ValueError('batch sizes are not the same')

    if transb == 'n':
        lda = max(1, m)
    else:
        lda = max(1, k)

    if transa == 'n':
        ldb = max(1, k)
    else:
        ldb = max(1, n)

    ldc = max(1, m)

    # Construct the pointer arrays needed for cublasCgemmBatched.
    bx_arr = bptrs(bx_gpu)
    by_arr = bptrs(by_gpu)
    bc_arr = bptrs(bc_gpu)

    # As in the test above, the operands are swapped (by first, bx second)
    # to reconcile numpy's row-major layout with cuBLAS's column-major one.
    cublas.cublasCgemmBatched(handle, transb, transa, m, n, k, alpha,
                              by_arr.gpudata, lda, bx_arr.gpudata, ldb,
                              beta, bc_arr.gpudata, ldc, N)
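# A quick way to exercise sc_complex_dot_batched against numpy. A minimal
# sketch assuming the imports and bptrs helper above, and assuming
# scikits.cuda.misc.init() is what creates the global cublas handle the
# function falls back on when handle is None.
if __name__ == '__main__':
    scikits.cuda.misc.init()
    l, m, k, n = 11, 7, 5, 3
    X = (np.random.rand(l, m, k) + 1j*np.random.rand(l, m, k)).astype(np.complex64)
    Y = (np.random.rand(l, k, n) + 1j*np.random.rand(l, k, n)).astype(np.complex64)
    bx_gpu = gpuarray.to_gpu(X)
    by_gpu = gpuarray.to_gpu(Y)
    bc_gpu = gpuarray.empty((l, m, n), np.complex64)
    sc_complex_dot_batched(bx_gpu, by_gpu, bc_gpu)
    assert np.allclose(np.einsum('nij,njk->nik', X, Y), bc_gpu.get())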