def test_dot(self):
    """Check that ``__dot__`` on two CSR variables stays sparse.

    The symbolic result must carry a ``SparseTensorType`` and the compiled
    function must hand back a ``csr_matrix``.
    """
    lhs = sparse.csr_from_dense(at.lmatrix("x"))
    rhs = sparse.csr_from_dense(at.lmatrix("y"))

    product = lhs.__dot__(rhs)
    assert isinstance(product.type, SparseTensorType)

    dot_fn = aesara.function([lhs, rhs], product)

    # A (2, 3) left operand and a (3, 1) right operand.
    lhs_data = [[1, 0, 2], [-1, 0, 0]]
    rhs_data = [[-1], [2], [1]]
    outcome = dot_fn(lhs_data, rhs_data)

    assert isinstance(outcome, csr_matrix)
def test_unary(self, method, exp_type, cm):
    """Call a zero-argument method on a CSR variable and check result types.

    Parameters
    ----------
    method
        Name of the attribute looked up on the sparse variable and called
        with no arguments.
    exp_type
        Type class every symbolic output's ``.type`` must be an instance of.
    cm
        Context manager wrapped around the call; when ``None``, a
        ``pytest.warns`` expecting the "converted to dense" ``UserWarning``
        is used instead.
    """
    sp_var = sparse.csr_from_dense(at.dmatrix("x"))
    bound_method = getattr(sp_var, method)

    if cm is None:
        cm = pytest.warns(UserWarning, match=".*converted to dense.*")

    # Sparse symbolic outputs should evaluate to SciPy CSR matrices; anything
    # else is expected to evaluate to a NumPy array.
    runtime_type = csr_matrix if exp_type == SparseTensorType else np.ndarray

    with cm:
        sym_out = bound_method()

    # Normalise single outputs and tuples of outputs to one iterable.
    sym_outs = sym_out if isinstance(sym_out, tuple) else (sym_out,)
    assert all(isinstance(var.type, exp_type) for var in sym_outs)

    compiled = aesara.function([sp_var], sym_out, on_unused_input="ignore")
    values = compiled([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])

    value_list = values if isinstance(values, list) else [values]
    assert all(isinstance(val, runtime_type) for val in value_list)
def test_getitem(self):
    """Check that slicing columns of a CSR variable keeps it sparse.

    The sliced symbolic variable must carry a ``SparseTensorType`` and the
    compiled function must return a ``csr_matrix``.
    """
    sp_var = sparse.csr_from_dense(at.dmatrix("x"))

    first_two_cols = sp_var[:, :2]
    assert isinstance(first_two_cols.type, SparseTensorType)

    slicer = aesara.function([sp_var], first_two_cols)
    dense_input = [[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]]
    sliced_value = slicer(dense_input)

    assert isinstance(sliced_value, csr_matrix)
def test_repeat(self):
    """Check that ``repeat`` on a CSR variable warns and produces a dense result.

    The call must emit the "converted to dense" ``UserWarning``, the symbolic
    result must carry a ``DenseTensorType``, and the compiled function must
    return an ``np.ndarray``.
    """
    sp_var = sparse.csr_from_dense(at.dmatrix("x"))

    dense_warning = pytest.warns(UserWarning, match=".*converted to dense.*")
    with dense_warning:
        repeated = sp_var.repeat(2, axis=1)

    assert isinstance(repeated.type, DenseTensorType)

    repeat_fn = aesara.function([sp_var], repeated)
    dense_input = [[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]]
    repeated_value = repeat_fn(dense_input)

    assert isinstance(repeated_value, np.ndarray)
def test_binary(self, method, exp_type):
    """Call a one-argument method on two CSR variables and check result types.

    Parameters
    ----------
    method
        Name of the attribute looked up on the left sparse variable and
        called with the right sparse variable as its only argument.
    exp_type
        Type class every symbolic output's ``.type`` must be an instance of.
        When it is not ``SparseTensorType``, the call is also required to
        emit the "converted to dense" ``UserWarning``.
    """
    lhs = sparse.csr_from_dense(at.lmatrix("x"))
    rhs = sparse.csr_from_dense(at.lmatrix("y"))
    bound_method = getattr(lhs, method)

    if exp_type != SparseTensorType:
        runtime_type = np.ndarray
        cm = pytest.warns(UserWarning, match=".*converted to dense.*")
    else:
        runtime_type = csr_matrix
        # An empty ExitStack acts as a do-nothing context manager here.
        cm = ExitStack()

    with cm:
        sym_out = bound_method(rhs)

    # Normalise single outputs and tuples of outputs to one iterable.
    sym_outs = sym_out if isinstance(sym_out, tuple) else (sym_out,)
    assert all(isinstance(var.type, exp_type) for var in sym_outs)

    compiled = aesara.function([lhs, rhs], sym_out)
    lhs_data = [[1, 0, 2], [-1, 0, 0]]
    rhs_data = [[1, 1, 2], [1, 4, 1]]
    values = compiled(lhs_data, rhs_data)

    value_list = values if isinstance(values, list) else [values]
    assert all(isinstance(val, runtime_type) for val in value_list)