def test_mkl_mabt():
    """Compare the MKL sparse ``A @ B.T`` kernel against SciPy.

    Each of 50 trials builds a random 20x10 matrix and a random 5x10 matrix,
    multiplies the first by the transpose of the second through the
    low-level MKL wrapper, and checks that the 20x5 result matches the
    SciPy product.
    """
    for _ in range(50):
        lhs = lktu.rand_csr(20, 10, nnz=50)
        rhs = lktu.rand_csr(5, 10, nnz=20)

        # Hold the SparseM wrappers in locals so they outlive the MKL call
        # that consumes their `.ptr` values.
        lhs_mkl = mkl_ops.SparseM.from_csr(lhs)
        rhs_mkl = mkl_ops.SparseM.from_csr(rhs)

        handle = mkl_ops._lk_mkl_spmabt(lhs_mkl.ptr, rhs_mkl.ptr)
        raw = mkl_ops._to_csr(handle)
        product = lm.CSR(N=raw)

        assert product.nrows == 20
        assert product.ncols == 5

        reference = lhs.to_scipy() @ rhs.to_scipy().T
        expected = reference.toarray()
        actual = product.to_scipy().toarray()
        assert actual == approx(expected)
def test_unit_norm():
    """Check that ``normalize_rows('unit')`` scales each row to unit L2 norm.

    ``normalize_rows`` changes ``csr`` in place, which is why a SciPy copy is
    taken first. It returns one scale factor per row. Multiplying a row's
    normalized values by its factor must reproduce the original values.
    """
    for _ in range(50):
        csr = rand_csr()
        # Snapshot the original values before normalization rewrites them.
        spm = csr.to_scipy().copy()
        norms = csr.normalize_rows('unit')
        # One scale factor per row. Compare against the actual row count
        # instead of a hard-coded number tied to rand_csr()'s defaults.
        assert len(norms) == csr.nrows

        for i in range(csr.nrows):
            vs = csr.row_vs(i)
            if len(vs) == 0:
                # Empty rows have nothing to normalize.
                continue
            assert np.linalg.norm(vs) == approx(1.0)
            expected = spm.getrow(i).toarray()[0, csr.row_cs(i)]
            assert vs * norms[i] == approx(expected)
def test_mean_center():
    """Check that ``normalize_rows('center')`` gives each row zero mean.

    ``normalize_rows`` changes ``csr`` in place, which is why a SciPy copy is
    taken first. It returns one offset per row. Adding a row's offset back
    to its centered values must reproduce the original values.
    """
    for _ in range(50):
        csr = rand_csr()
        # Snapshot the original values before centering rewrites them.
        spm = csr.to_scipy().copy()
        means = csr.normalize_rows('center')
        # One offset per row. Compare against the actual row count instead
        # of a hard-coded number tied to rand_csr()'s defaults.
        assert len(means) == csr.nrows

        for i in range(csr.nrows):
            vs = csr.row_vs(i)
            if len(vs) == 0:
                # Empty rows have nothing to center.
                continue
            assert np.mean(vs) == approx(0.0)
            expected = spm.getrow(i).toarray()[0, csr.row_cs(i)]
            assert vs + means[i] == approx(expected)
def test_filter():
    """Check ``filter_nnzs`` with a mask that keeps only positive entries.

    The filtered matrix must contain only positive values. It must have no
    more stored entries overall, and no more per row, than the original.
    Densified, it must equal the original with its negative entries zeroed.
    """
    csr = rand_csr()
    keep = csr.values > 0
    csrf = csr.filter_nnzs(keep)

    assert all(csrf.values > 0)
    assert csrf.nnz <= csr.nnz

    # Filtering only drops entries, so no row may grow.
    for row in range(csr.nrows):
        orig_start, orig_end = csr.row_extent(row)
        filt_start, filt_end = csrf.row_extent(row)
        assert filt_end - filt_start <= orig_end - orig_start

    expected = csr.to_scipy().toarray()
    actual = csrf.to_scipy().toarray()
    # Dropping the non-positive entries is equivalent to zeroing the negatives.
    expected[expected < 0] = 0
    assert actual == approx(expected)
def test_csr_pickle(values):
    """Round-trip a random 100x50 CSR with 1000 entries through pickle.

    ``values`` is passed through to ``rand_csr`` (presumably supplied by a
    pytest parametrize decorator outside this view). When it is falsy, the
    unpickled matrix must have ``values is None``.
    """
    csr = rand_csr(100, 50, 1000, values=values)
    assert csr.nrows == 100
    assert csr.ncols == 50
    assert csr.nnz == 1000

    data = pickle.dumps(csr)
    csr2 = pickle.loads(data)

    assert csr2.nrows == csr.nrows
    assert csr2.ncols == csr.ncols
    assert csr2.nnz == csr.nnz
    # np.array_equal also compares shapes. `all(a == b)` raises instead of
    # failing cleanly when the arrays differ in length.
    assert np.array_equal(csr2.rowptrs, csr.rowptrs)
    assert np.array_equal(csr2.colinds, csr.colinds)
    if values:
        assert np.array_equal(csr2.values, csr.values)
    else:
        assert csr2.values is None
def test_csr64_pickle(values):
    """Round-trip a CSR with 64-bit row pointers through pickle.

    A random 100x50 matrix with 1000 entries is rebuilt with ``int64``
    ``rowptrs``. The test checks that the structure, the values and the
    ``int64`` row-pointer dtype all survive pickling. When ``values`` is
    falsy, the unpickled matrix must have ``values is None``.
    """
    csr = rand_csr(100, 50, 1000, values=values)
    csr = lm.CSR(csr.nrows, csr.ncols, csr.nnz,
                 csr.rowptrs.astype(np.int64), csr.colinds, csr.values)
    assert csr.nrows == 100
    assert csr.ncols == 50
    assert csr.nnz == 1000

    data = pickle.dumps(csr)
    csr2 = pickle.loads(data)

    assert csr2.nrows == csr.nrows
    assert csr2.ncols == csr.ncols
    assert csr2.nnz == csr.nnz
    # np.array_equal also compares shapes. `all(a == b)` raises instead of
    # failing cleanly when the arrays differ in length.
    assert np.array_equal(csr2.rowptrs, csr.rowptrs)
    # The 64-bit row-pointer width must not be narrowed by pickling.
    assert csr2.rowptrs.dtype == np.int64
    assert np.array_equal(csr2.colinds, csr.colinds)
    if values:
        assert np.array_equal(csr2.values, csr.values)
    else:
        assert csr2.values is None