コード例 #1
0
ファイル: ops.py プロジェクト: ricardoV94/aesara
def inv_as_solve(fgraph, node):
    """Replace a ``Dot``/``Dot22`` that multiplies by ``matrix_inverse(A)`` with ``solve``.

    ``dot(inv(A), r)`` becomes ``solve(A, r)``. ``dot(l, inv(A))`` becomes
    ``solve(A', l.T).T``, where ``A'`` is ``A`` itself when ``is_symmetric(A)``
    holds and ``A.T`` otherwise.

    Returns ``False`` when ``imported_scipy`` is falsy. Returns a one-element
    list holding the replacement variable when a rewrite applies, and ``None``
    otherwise. ``fgraph`` is accepted but not used here.
    """
    if not imported_scipy:
        return False
    if not isinstance(node.op, (Dot, Dot22)):
        return None

    left, right = node.inputs

    # Inverse on the left: inv(A) @ right == solve(A, right).
    if left.owner and left.owner.op == matrix_inverse:
        return [solve(left.owner.inputs[0], right)]

    if not (right.owner and right.owner.op == matrix_inverse):
        return None

    # Inverse on the right: left @ inv(A) == solve(A.T, left.T).T.
    # The transpose of A can be skipped when A is symmetric.
    inverted = right.owner.inputs[0]
    coeff = inverted if is_symmetric(inverted) else inverted.T
    return [solve(coeff, left.T).T]
コード例 #2
0
def inv_as_solve(fgraph, node):
    """Rewrite products with a ``MatrixInverse`` operand into ``solve`` calls.

    This utilizes a boolean `symmetric` tag on the matrices. When the inverted
    matrix ``A`` has ``A.tag.symmetric is True``, its transpose is not taken
    in the right-operand case.

    * ``dot(inv(A), r)`` -> ``solve(A, r)``
    * ``dot(l, inv(A))`` -> ``solve(A, l.T).T`` if ``A`` is tagged symmetric,
      otherwise ``solve(A.T, l.T).T``

    Returns a one-element list with the replacement variable, or ``None`` when
    the node is not a ``Dot``/``Dot22`` with a ``MatrixInverse`` operand.
    ``fgraph`` is accepted but not used here.
    """
    if not isinstance(node.op, (Dot, Dot22)):
        return None

    lhs, rhs = node.inputs

    lhs_owner = lhs.owner
    if lhs_owner and isinstance(lhs_owner.op, MatrixInverse):
        return [solve(lhs_owner.inputs[0], rhs)]

    rhs_owner = rhs.owner
    if not (rhs_owner and isinstance(rhs_owner.op, MatrixInverse)):
        return None

    mat = rhs_owner.inputs[0]
    # Only an explicit `True` tag counts; a missing tag or any other value
    # falls back to the transposed (general) form.
    tagged_symmetric = getattr(mat.tag, "symmetric", None) is True
    coeff = mat if tagged_symmetric else mat.T
    return [solve(coeff, lhs.T).T]
コード例 #3
0
    def test_solve_dtype(self):
        """Check that ``solve`` declares the same output dtype it computes.

        Every ordered pair of the integer/float dtypes below is used for
        ``(A, b)``. The symbolic output dtype must equal the dtype of the
        array returned by the compiled function.
        """
        pytest.importorskip("scipy")

        dtypes = [
            "uint8", "uint16", "uint32", "uint64",
            "int8", "int16", "int32", "int64",
            "float16", "float32", "float64",
        ]

        identity = np.eye(2)
        ones_col = np.ones((2, 1))

        # Outer loop over A's dtype, inner over b's, covering all combinations.
        for a_dtype in dtypes:
            for rhs_dtype in dtypes:
                A_sym = matrix(dtype=a_dtype)
                b_sym = matrix(dtype=rhs_dtype)
                sol = solve(A_sym, b_sym)
                compiled = function([A_sym, b_sym], sol)
                sol_val = compiled(
                    identity.astype(a_dtype), ones_col.astype(rhs_dtype)
                )
                assert sol.dtype == sol_val.dtype
コード例 #4
0
ファイル: test_slinalg.py プロジェクト: lucianopaz/aesara
    def test_correctness(self):
        """Compare the compiled ``solve`` against ``scipy.linalg.solve``.

        Two coefficient matrices are checked with the same random right-hand
        side:

        * a random ``A.T @ A`` product;
        * a fixed 5x5 matrix that is the identity except for its bottom-right
          2x2 block ``[[1, 1], [1, 0]]``. That block has determinant -1, so the
          matrix is symmetric but not positive definite.
        """
        rng = np.random.default_rng(utt.fetch_seed())
        A = matrix()
        b = matrix()
        solve_fn = aesara.function([A, b], solve(A, b))

        # The right-hand side is drawn before the matrix. Keep this order so
        # the seeded values stay the same.
        rhs = np.asarray(rng.random((5, 1)), dtype=config.floatX)
        base = np.asarray(rng.random((5, 5)), dtype=config.floatX)
        gram = np.dot(base.transpose(), base)

        fixed = np.array(
            [
                [1, 0, 0, 0, 0],
                [0, 1, 0, 0, 0],
                [0, 0, 1, 0, 0],
                [0, 0, 0, 1, 1],
                [0, 0, 0, 1, 0],
            ],
            dtype=config.floatX,
        )

        for coeff in (gram, fixed):
            assert np.allclose(
                scipy.linalg.solve(coeff, rhs), solve_fn(coeff, rhs)
            )
コード例 #5
0
def test_tag_solve_triangular():
    """Check that ``Solve`` on a Cholesky factor is not left as a general solve.

    For both the lower and the upper Cholesky factor, every ``Solve`` node in
    the compiled graph must have ``assume_a != "gen"`` and a ``lower`` flag
    that matches the factor. The graph checks are skipped in ``FAST_COMPILE``
    mode, but both functions are still compiled.
    """
    lower_chol = Cholesky(lower=True)
    upper_chol = Cholesky(lower=False)
    A = matrix("A")
    x = vector("x")
    lower_factor = lower_chol(A)
    upper_factor = upper_chol(A)

    cases = (
        (solve(lower_factor, x), True),
        (solve(upper_factor, x), False),
    )
    for target, want_lower in cases:
        fn = aesara.function([A, x], target)
        if config.mode == "FAST_COMPILE":
            continue
        for apply_node in fn.maker.fgraph.toposort():
            if not isinstance(apply_node.op, Solve):
                continue
            assert apply_node.op.assume_a != "gen"
            assert bool(apply_node.op.lower) is want_lower
コード例 #6
0
ファイル: test_slinalg.py プロジェクト: lucianopaz/aesara
 def test_infer_shape(self, b_shape):
     """Check ``Solve`` shape inference for a 5x5 ``A`` and a ``b_shape``-shaped ``b``.

     The symbolic type of ``b`` is derived from a concrete random array of
     shape ``b_shape``. ``_compile_and_check`` is run with ``warn=False``.
     """
     rng = np.random.default_rng(utt.fetch_seed())
     A = matrix()

     # The `b` values are drawn before the `A` values. Keep this order so the
     # seeded values stay the same.
     rhs_val = np.asarray(rng.random(b_shape), dtype=config.floatX)
     b = aet.as_tensor_variable(rhs_val).type()
     outputs = [solve(A, b)]
     coeff_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)

     self._compile_and_check(
         [A, b],
         outputs,
         [coeff_val, rhs_val],
         Solve,
         warn=False,
     )
コード例 #7
0
def test_jax_basic():
    """Compare the JAX backend with the Python backend on a set of basic ops.

    The ops covered are elemwise scalar ops, ``[Inc]Subtensor``, ``clip``,
    ``diagonal``, ``cholesky`` (lower and upper), ``solve``, ``diag``, ``det``
    and ``matrix_inverse``. Random inputs come from a fixed-seed generator and
    are drawn in the order the checks run.
    """
    rng = np.random.default_rng(28494)

    x = matrix("x")
    y = matrix("y")
    b = vector("b")

    def check(inputs, output, values):
        # Wrap a single output in a graph and compare both backends on it.
        return compare_jax_and_py(FunctionGraph(inputs, [output]), values)

    def square_arange():
        return np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)

    def near_identity():
        # Identity plus small normal noise; consumes one draw from `rng`.
        noise = rng.standard_normal(size=(10, 10)) * 0.01
        return (np.eye(10) + noise).astype(config.floatX)

    # Elemwise `ScalarOp`s, then `[Inc]Subtensor` updates and a slice.
    z = cosh(x**2 + y / 3.0)
    sub = aet_subtensor.set_subtensor(z[0], -10.0)
    sub = aet_subtensor.inc_subtensor(sub[0, 1], 2.0)
    sub = sub[:5, :3]

    xy_vals = [
        np.tile(np.arange(10), (10, 1)).astype(config.floatX),
        np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
    ]
    (jax_res,) = check([x, y], sub, xy_vals)

    # The slice keeps the first 5 rows and the first 3 columns.
    assert jax_res.shape == (5, 3)
    # Row 0 was set to -10, then 2 was added at [0, 1].
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0

    check([x, y], clip(x, y, 5), xy_vals)
    check([x], aet.diagonal(x, 0), [square_arange()])
    check([x], aet_slinalg.cholesky(x), [near_identity()])

    # NOTE(review): an earlier comment said it was unclear why `lower=False`
    # was not working yet. The check below runs it regardless; confirm its
    # status.
    check([x], aet_slinalg.Cholesky(lower=False)(x), [near_identity()])

    check(
        [x, b],
        aet_slinalg.solve(x, b),
        [np.eye(10).astype(config.floatX), np.arange(10).astype(config.floatX)],
    )
    check([b], aet.diag(b), [np.arange(10).astype(config.floatX)])
    check([x], aet_nlinalg.det(x), [square_arange()])
    check([x], aet_nlinalg.matrix_inverse(x), [near_identity()])