def check_dot(self, dtype):
    """Compare ``Blas.dot`` with ``np.dot`` on two random length-10 vectors.

    Results are compared to 6 decimal places.
    """
    blas = Blas()
    vec_a, vec_b = [np.random.random(10).astype(dtype=dtype) for _ in range(2)]
    # Reference is evaluated first, then the BLAS result (left-to-right args).
    np.testing.assert_almost_equal(
        np.dot(vec_a, vec_b),
        blas.dot(x=vec_a, y=vec_b),
        decimal=6,
    )
def check_scal(self, dtype):
    """Check that ``Blas.scal`` scales a random vector in place by ``alpha``."""
    blas = Blas()
    vec = np.random.random(10).astype(dtype=dtype)
    alpha = 0.5
    # Snapshot the expected result before scal overwrites ``vec``.
    scaled = alpha * vec
    blas.scal(alpha=alpha, x=vec)
    np.testing.assert_almost_equal(scaled, vec)
def check_copy(self, dtype):
    """Check that ``Blas.copy`` writes ``x`` into a zero-initialised ``y``."""
    blas = Blas()
    src = np.random.random(10).astype(dtype=dtype)
    snapshot = src.copy()
    dst = np.zeros_like(src)
    blas.copy(x=src, y=dst)
    np.testing.assert_almost_equal(snapshot, dst)
def check_swap(self, dtype):
    """Check that ``Blas.swap`` exchanges the contents of ``x`` and ``y`` in place."""
    blas = Blas()
    x = np.random.random(10).astype(dtype=dtype)
    y = np.random.random(10).astype(dtype=dtype)
    old_x = x.copy()
    old_y = y.copy()
    blas.swap(x=x, y=y)
    # After the swap each vector must hold the other's original values.
    np.testing.assert_almost_equal(x, old_y)
    np.testing.assert_almost_equal(y, old_x)
def check_trtri(self, dtype):
    """Check ``Blas.trtri`` (triangular inverse) against ``np.linalg.inv``.

    A random 4x4 upper-triangular matrix is inverted via ``uplo='upper'``
    into a zeroed output of the same layout. The result is compared to
    4 decimal places.

    ``n`` is added to the diagonal so the matrix is well conditioned. With
    raw ``[0, 1)`` diagonal entries, a pivot close to zero makes the inverse
    blow up, and the comparison fails intermittently (mainly in single
    precision).
    """
    blas = Blas()
    n = 4
    A = np.asfortranarray(
        (np.triu(np.random.random((n, n))) + n * np.eye(n)).astype(dtype=dtype)
    )
    invA = np.zeros_like(A)
    expect = np.linalg.inv(A)
    blas.trtri(A=A, invA=invA, uplo='upper')
    np.testing.assert_almost_equal(expect, invA, decimal=4)
def check_axpy(self, dtype):
    """Check ``Blas.axpy``: ``y <- alpha * x + y`` in place on length-10 vectors.

    The comparison uses ``decimal=6``, like the other floating-point checks.
    The default ``decimal=7`` (1.5e-7 absolute) is below one float32 ulp for
    results near 4, so a fused multiply-add or different rounding in the
    BLAS kernel would fail the test spuriously.
    """
    blas = Blas()
    alpha = 3.23456789
    x = np.random.random(10).astype(dtype=dtype)
    y = np.random.random(10).astype(dtype=dtype)
    # Compute the reference before axpy overwrites ``y``.
    expect = alpha * x + y
    blas.axpy(alpha=alpha, x=x, y=y)
    np.testing.assert_almost_equal(expect, y, decimal=6)
def check_trtri_batched(self, dtype):
    """Check ``Blas.trtri_batched`` against per-matrix ``np.linalg.inv``.

    Three random 4x4 upper-triangular matrices are stacked along a new
    leading axis and inverted in one call (``uplo='upper'``). The results
    are compared to 4 decimal places.

    ``n`` is added to each diagonal so every matrix is well conditioned.
    Near-zero pivots from ``[0, 1)`` diagonals otherwise make the inverse
    explode and the comparison flaky.
    """
    blas = Blas()
    n = 4
    batch_count = 3
    As = [
        np.asfortranarray(
            (np.triu(np.random.random((n, n))) + n * np.eye(n)).astype(dtype=dtype)
        )
        for _ in range(batch_count)
    ]
    # NOTE(review): the memory layout the batched wrapper expects for the
    # stacked (batch, n, n) array is not visible here — confirm.
    A = np.stack(As)
    invA = np.zeros_like(A)
    expect = np.stack([np.linalg.inv(a) for a in A])
    blas.trtri_batched(A=A, invA=invA, uplo='upper')
    np.testing.assert_almost_equal(expect, invA, decimal=4)
def check_gemv(self, dtype):
    """Check ``Blas.gemv``: ``y <- alpha * A @ x + beta * y`` for a 3x4 ``A``.

    ``A`` is Fortran-ordered; ``y`` is updated in place. The comparison is
    to 6 decimal places.
    """
    blas = Blas()
    rows, cols = 3, 4
    x = np.random.random(cols).astype(dtype=dtype)
    y = np.random.random(rows).astype(dtype=dtype)
    A = np.random.random((rows, cols)).astype(dtype=dtype, order='F')
    alpha, beta = 0.7, 0.8
    # (alpha * x) . A^T equals alpha * A @ x; computed before y is overwritten.
    expect = np.dot(alpha * x, A.T) + beta * y
    blas.gemv(alpha=alpha, beta=beta, x=x, y=y, A=A)
    np.testing.assert_almost_equal(expect, y, decimal=6)
def check_gemm(self, dtype, order='F'):
    """Check ``Blas.gemm``: ``C <- alpha * A @ B + beta * C``.

    Shapes are ``A`` (5x6), ``B`` (6x4), ``C`` (5x4). All three are created
    with the memory ``order`` given by the caller. ``C`` is updated in place
    and compared to 6 decimal places.
    """
    blas = Blas()
    out_rows, out_cols, inner = 5, 4, 6
    A = np.random.random((out_rows, inner)).astype(dtype=dtype, order=order)
    B = np.random.random((inner, out_cols)).astype(dtype=dtype, order=order)
    C = np.random.random((out_rows, out_cols)).astype(dtype=dtype, order=order)
    alpha, beta = 0.7, 0.8
    # Reference must be taken before gemm overwrites C.
    expect = alpha * np.dot(A, B) + beta * C
    blas.gemm(alpha=alpha, beta=beta, A=A, B=B, C=C)
    np.testing.assert_almost_equal(expect, C, decimal=6)
def check_trsm(self, dtype):
    """Compare ``Blas.trsm`` with SciPy's triangular solve (currently disabled).

    The check deliberately raises ``NotImplementedError`` before ``trsm``
    is called, because the call crashed. The captured backtrace below ends
    in a lock on a null mutex under ``hipMemcpy`` within
    ``librocblas-hcc.so``. The statements after the ``raise`` are
    unreachable. They are kept so the check can be re-enabled once the
    crash is resolved.
    """
    import scipy.linalg

    blas = Blas()
    size, nrhs = 4, 4
    tri = np.asfortranarray(
        np.triu(np.random.random((size, size)).astype(dtype=dtype))
    )
    rhs = np.random.random((size, nrhs)).astype(dtype=dtype, order='F')
    alpha = 0.7
    # Reference: solve tri @ X = alpha * rhs using the upper triangle.
    expect = scipy.linalg.solve_triangular(tri, alpha * rhs)
    raise NotImplementedError('segfault!?!')
    # Backtrace captured when the trsm call below was enabled:
    # #0 __GI___pthread_mutex_lock (mutex=0x0) at ../nptl/pthread_mutex_lock.c:66
    # #1 0x00007ffff7b9f9bc in std::__1::mutex::lock() () from /usr/lib/x86_64-linux-gnu/libc++.so
    # #2 0x00007fffe8120aeb in LockedAccessor<ihipCtxCriticalBase_t<std::__1::mutex> >::LockedAccessor(ihipCtxCriticalBase_t<std::__1::mutex>&, bool) ()
    #    from /home/amd_user/rocBLAS/build/library-build/src/librocblas-hcc.so
    # #3 0x00007fffe811cc98 in ihipStream_t::canSeePeerMemory(ihipCtx_t const*, ihipCtx_t*, ihipCtx_t*) ()
    #    from /home/amd_user/rocBLAS/build/library-build/src/librocblas-hcc.so
    # #4 0x00007fffe811d97f in ihipStream_t::locked_copySync(void*, void const*, unsigned long, unsigned int, bool) () from /home/amd_user/rocBLAS/build/library-build/src/librocblas-hcc.so
    # #5 0x00007fffe8128ea6 in hipMemcpy ()
    blas.trsm(alpha=alpha, A=tri, B=rhs, uplo='upper', side='left')
    np.testing.assert_almost_equal(expect, rhs, decimal=6)
def check_amin(self, dtype):
    """Check ``Blas.amin``: index of the element of smallest magnitude.

    The reference is ``np.argmin(np.abs(x))``. An *amin* routine selects
    the minimum-magnitude entry, so comparing against ``np.argmax`` (as
    before) could never be correct. Inputs are drawn from ``[-1, 1)`` so
    the absolute-value part of the contract is actually exercised.
    """
    blas = Blas()
    x = (2 * np.random.random(10) - 1).astype(dtype=dtype)
    expect = np.argmin(np.abs(x))
    got = blas.amin(x=x)
    # NOTE(review): reference BLAS i?amin returns a 1-based index; this
    # assumes the Blas wrapper converts to 0-based like np.argmin — confirm.
    np.testing.assert_almost_equal(expect, got)
def check_nrm2(self, dtype):
    """Compare ``Blas.nrm2`` (Euclidean norm) with ``np.linalg.norm``.

    Uses a random length-10 vector; results agree to 6 decimal places.
    """
    blas = Blas()
    vec = np.random.random(10).astype(dtype=dtype)
    reference = np.linalg.norm(vec)
    np.testing.assert_almost_equal(reference, blas.nrm2(x=vec), decimal=6)
def check_asum(self, dtype):
    """Compare ``Blas.asum`` (sum of absolute values) with NumPy.

    The reference is ``np.sum(np.abs(x))``. The previous ``np.sum(x)`` only
    matched because every input was non-negative. Inputs are now drawn from
    ``[-1, 1)`` so the absolute-value part of asum is actually tested.
    Results agree to 5 decimal places.
    """
    blas = Blas()
    x = (2 * np.random.random(10) - 1).astype(dtype=dtype)
    expect = np.sum(np.abs(x))
    got = blas.asum(x=x)
    np.testing.assert_almost_equal(expect, got, decimal=5)