import ctypes

import numpy as np
import pycuda.autoinit  # noqa: F401 (initializes the CUDA context)
import pycuda.gpuarray as gpuarray
import skcuda.cublas as cublas


def bptrs(a):
    """Batch pointer array: one device pointer per matrix in the batch."""
    return gpuarray.arange(a.ptr, a.ptr + a.shape[0]*a.strides[0],
                           a.strides[0], dtype=ctypes.c_void_p)


def test_cublasCgemmBatched(self):
    l, m, k, n = 11, 7, 5, 3
    A = (np.random.rand(l, m, k) + 1j*np.random.rand(l, m, k)).astype(np.complex64)
    B = (np.random.rand(l, k, n) + 1j*np.random.rand(l, k, n)).astype(np.complex64)
    C_res = np.einsum('nij,njk->nik', A, B)

    a_gpu = gpuarray.to_gpu(A)
    b_gpu = gpuarray.to_gpu(B)
    c_gpu = gpuarray.empty((l, m, n), np.complex64)
    alpha = np.complex64(1.0)
    beta = np.complex64(0.0)

    # cublasCgemmBatched takes arrays of per-matrix device pointers,
    # not the base pointers of the contiguous batches.
    a_arr = bptrs(a_gpu)
    b_arr = bptrs(b_gpu)
    c_arr = bptrs(c_gpu)

    # cuBLAS is column-major; computing C^T = B^T A^T with the operands
    # swapped and dimensions (n, m, k) leaves the row-major product A B
    # in c_gpu.
    cublas.cublasCgemmBatched(self.cublas_handle, 'n', 'n',
                              n, m, k, alpha,
                              b_arr.gpudata, n,
                              a_arr.gpudata, k,
                              beta,
                              c_arr.gpudata, n, l)

    assert np.allclose(C_res, c_gpu.get())
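The test references `self.cublas_handle`, so it is meant to live inside a test fixture that creates and destroys the cuBLAS context. A minimal sketch of such a fixture, assuming unittest (the class name is hypothetical; cublasCreate/cublasDestroy are the actual skcuda wrappers):

import unittest


class TestCublasBatched(unittest.TestCase):  # hypothetical fixture name
    def setUp(self):
        # One cuBLAS context per test.
        self.cublas_handle = cublas.cublasCreate()

    def tearDown(self):
        cublas.cublasDestroy(self.cublas_handle)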
import torch
import skcuda.cublas as cublas


def batch_ptrs(t):
    """Device int64 array of pointers, one per matrix in a contiguous batch."""
    base, step = t.data_ptr(), t.stride(0) * t.element_size()
    return torch.arange(base, base + t.size(0)*step, step,
                        dtype=torch.int64, device=t.device)


def cgemm(A, B, transa=False, transb=False, alpha=1, beta=0):
    """Batched C = alpha*op(A) @ op(B) + beta*C via the skcuda C-wrapper
    around cublasCgemmBatched. A and B are contiguous torch.complex64
    CUDA tensors of shape (batch, m, k) and (batch, k, n), each matrix
    transposed when the corresponding trans flag is set. C is allocated
    here (uninitialized), so leave beta at 0.
    """
    m = A.size(2) if transa else A.size(1)
    k = A.size(1) if transa else A.size(2)
    n = B.size(1) if transb else B.size(2)
    batch_count = A.size(0)
    C = A.new_empty(batch_count, m, n)

    # Seen from cuBLAS's column-major side, each row-major matrix is its
    # own transpose with leading dimension equal to its row length.
    lda, ldb, ldc = A.size(2), B.size(2), n
    a_ptrs, b_ptrs, c_ptrs = batch_ptrs(A), batch_ptrs(B), batch_ptrs(C)

    # Run on the stream PyTorch is currently using, on PyTorch's own handle.
    handle = torch.cuda.current_blas_handle()
    cublas.cublasSetStream(handle, torch.cuda.current_stream().cuda_stream)

    # Swap the operands: the column-major product op(B)^T op(A)^T stored in
    # C is the row-major product op(A) op(B). alpha and beta are host
    # scalars; the skcuda wrapper converts them to cuFloatComplex itself.
    cublas.cublasCgemmBatched(handle,
                              't' if transb else 'n',
                              't' if transa else 'n',
                              n, m, k,
                              complex(alpha),
                              b_ptrs.data_ptr(), ldb,
                              a_ptrs.data_ptr(), lda,
                              complex(beta),
                              c_ptrs.data_ptr(), ldc,
                              batch_count)
    return C
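A quick sanity check of cgemm against PyTorch's native batched matmul; this assumes a CUDA build of PyTorch with complex64 support, and the shapes here are arbitrary:

A = torch.randn(11, 7, 5, dtype=torch.complex64, device='cuda')
B = torch.randn(11, 5, 3, dtype=torch.complex64, device='cuda')

C = cgemm(A, B)              # alpha=1, beta=0
ref = torch.matmul(A, B)     # reference batched complex matmul
assert torch.allclose(C, ref, atol=1e-5)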