def inv_as_solve(fgraph, node):
    """Rewrite a ``dot`` that involves ``matrix_inverse`` into a ``solve``.

    * ``dot(matrix_inverse(A), r)`` becomes ``solve(A, r)``.
    * ``dot(l, matrix_inverse(A))`` becomes ``solve(A.T, l.T).T``. When
      ``is_symmetric(A)`` holds, it becomes ``solve(A, l.T).T`` instead,
      because ``A.T == A`` in that case.

    Returns ``False`` when ``imported_scipy`` is falsy. Returns ``None``
    when the node does not match. Otherwise returns a one-element list
    holding the replacement variable.
    """
    if not imported_scipy:
        return False
    if not isinstance(node.op, (Dot, Dot22)):
        return None

    left, right = node.inputs

    if left.owner and left.owner.op == matrix_inverse:
        return [solve(left.owner.inputs[0], right)]

    if right.owner and right.owner.op == matrix_inverse:
        inverted = right.owner.inputs[0]
        # l @ inv(A) = X  <=>  A.T @ X.T = l.T; skip the transpose if A is symmetric.
        coefficient = inverted if is_symmetric(inverted) else inverted.T
        return [solve(coefficient, left.T).T]

    return None
def inv_as_solve(fgraph, node):
    """Replace a ``dot`` with a ``MatrixInverse`` operand by an equivalent ``solve``.

    * ``dot(inv(A), r)`` is rewritten to ``solve(A, r)``.
    * ``dot(l, inv(A))`` uses the identity ``l @ inv(A) == solve(A.T, l.T).T``.
      If ``A.tag.symmetric`` is exactly ``True``, the transpose of ``A`` is
      skipped and ``solve(A, l.T).T`` is emitted instead.

    This relies on an optional boolean ``symmetric`` tag on the matrix. A
    missing tag, or any value other than ``True``, is treated as "not known
    to be symmetric".

    Returns a one-element list with the replacement, or ``None`` when the
    node does not match.
    """
    if not isinstance(node.op, (Dot, Dot22)):
        return None

    left, right = node.inputs

    if left.owner and isinstance(left.owner.op, MatrixInverse):
        return [solve(left.owner.inputs[0], right)]

    if right.owner and isinstance(right.owner.op, MatrixInverse):
        inverted = right.owner.inputs[0]
        known_symmetric = getattr(inverted.tag, "symmetric", None) is True
        coefficient = inverted if known_symmetric else inverted.T
        return [solve(coefficient, left.T).T]

    return None
def test_solve_dtype(self):
    """Check that the symbolic dtype of ``solve(A, b)`` matches the dtype
    actually produced at runtime, for every pair of listed input dtypes."""
    pytest.importorskip("scipy")

    all_dtypes = [
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "int8",
        "int16",
        "int32",
        "int64",
        "float16",
        "float32",
        "float64",
    ]

    identity = np.eye(2)
    ones = np.ones((2, 1))

    # Exhaustively cover every (A dtype, b dtype) combination.
    for a_dtype, b_dtype in itertools.product(all_dtypes, repeat=2):
        A_sym = matrix(dtype=a_dtype)
        b_sym = matrix(dtype=b_dtype)
        solution = solve(A_sym, b_sym)
        compiled = function([A_sym, b_sym], solution)
        computed = compiled(identity.astype(a_dtype), ones.astype(b_dtype))
        assert solution.dtype == computed.dtype
def test_correctness(self):
    """Compare the compiled ``solve`` against ``scipy.linalg.solve``.

    Two coefficient matrices are used: a random ``A^T A`` (symmetric
    positive semi-definite) and a fixed symmetric, invertible but
    indefinite matrix.
    """
    rng = np.random.default_rng(utt.fetch_seed())

    A_sym = matrix()
    b_sym = matrix()
    solve_fn = aesara.function([A_sym, b_sym], solve(A_sym, b_sym))

    # Draw b before A so the random stream is consumed in the same order.
    b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
    A_raw = np.asarray(rng.random((5, 5)), dtype=config.floatX)
    A_spd = np.dot(A_raw.transpose(), A_raw)
    assert np.allclose(scipy.linalg.solve(A_spd, b_val), solve_fn(A_spd, b_val))

    # Identity except for the trailing [[1, 1], [1, 0]] block, whose determinant
    # is -1: symmetric and invertible, but not positive definite.
    A_indefinite = np.array(
        [
            [1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1],
            [0, 0, 0, 1, 0],
        ],
        dtype=config.floatX,
    )
    assert np.allclose(
        scipy.linalg.solve(A_indefinite, b_val), solve_fn(A_indefinite, b_val)
    )
def test_tag_solve_triangular():
    """Check that solving against a Cholesky factor compiles to a non-"gen"
    ``Solve`` whose ``lower`` flag matches the factor's orientation.

    The graph inspection is skipped in FAST_COMPILE mode. The function is
    still compiled in that mode.
    """
    A = matrix("A")
    x = vector("x")

    lower_factor = Cholesky(lower=True)(A)
    upper_factor = Cholesky(lower=False)(A)
    cases = (
        (solve(lower_factor, x), True),
        (solve(upper_factor, x), False),
    )

    for solved, expect_lower in cases:
        f = aesara.function([A, x], solved)
        if config.mode == "FAST_COMPILE":
            continue
        for node in f.maker.fgraph.toposort():
            if isinstance(node.op, Solve):
                assert node.op.assume_a != "gen" and bool(node.op.lower) == expect_lower
def test_infer_shape(self, b_shape):
    """Check ``Solve``'s shape inference for a 5x5 ``A`` and a right-hand
    side whose shape is given by ``b_shape``."""
    rng = np.random.default_rng(utt.fetch_seed())

    A = matrix()
    # Draw b first, then A, to keep the random stream order unchanged.
    b_val = np.asarray(rng.random(b_shape), dtype=config.floatX)
    # Fresh symbolic variable of the same tensor type as b_val.
    b = aet.as_tensor_variable(b_val).type()
    A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)

    self._compile_and_check(
        [A, b],
        [solve(A, b)],
        [A_val, b_val],
        Solve,
        warn=False,
    )
def test_jax_basic():
    """Compare the JAX backend against the Python backend on a mix of
    elementwise, subtensor and linear-algebra graphs."""
    rng = np.random.default_rng(28494)

    x = matrix("x")
    y = matrix("y")
    b = vector("b")

    def check(inputs, output, values):
        # Wrap a single output in a FunctionGraph and run both backends on it.
        return compare_jax_and_py(FunctionGraph(inputs, [output]), values)

    def near_identity():
        # 10x10 identity plus Gaussian noise scaled by 0.01; consumes one rng draw.
        return (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX
        )

    def arange_square():
        # 10x10 matrix holding 0..99 in row-major order.
        return np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)

    # `ScalarOp`
    z = cosh(x**2 + y / 3.0)

    # `[Inc]Subtensor`
    sub_out = aet_subtensor.set_subtensor(z[0], -10.0)
    sub_out = aet_subtensor.inc_subtensor(sub_out[0, 1], 2.0)
    sub_out = sub_out[:5, :3]

    xy_vals = [
        np.tile(np.arange(10), (10, 1)).astype(config.floatX),
        np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
    ]
    (jax_res,) = check([x, y], sub_out, xy_vals)

    # The `Subtensor` slice yields the expected shape...
    assert jax_res.shape == (5, 3)
    # ...and both `IncSubtensor` updates are applied (-10 set, then +2).
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0

    check([x, y], clip(x, y, 5), xy_vals)

    check([x], aet.diagonal(x, 0), [arange_square()])

    check([x], aet_slinalg.cholesky(x), [near_identity()])

    # NOTE(review): an earlier comment said this case was not yet working
    # with lower=False; it is still exercised here. Confirm whether that
    # concern has been resolved.
    check([x], aet_slinalg.Cholesky(lower=False)(x), [near_identity()])

    check(
        [x, b],
        aet_slinalg.solve(x, b),
        [
            np.eye(10).astype(config.floatX),
            np.arange(10).astype(config.floatX),
        ],
    )

    check([b], aet.diag(b), [np.arange(10).astype(config.floatX)])

    check([x], aet_nlinalg.det(x), [arange_square()])

    check([x], aet_nlinalg.matrix_inverse(x), [near_identity()])