Example #1
def test_jax_multioutput():
    x = vector("x")
    x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
    y = vector("y")
    y.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)

    w = cosh(x**2 + y / 3.0)
    v = cosh(x / 3.0 + y**2)

    fgraph = FunctionGraph([x, y], [w, v])

    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
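The helper compare_jax_and_py used throughout these examples is defined in the test suite rather than in Aesara itself. A minimal sketch of such a helper, assuming it compiles the FunctionGraph once with Aesara's predefined "JAX" mode and once with a pure-Python mode and checks that the outputs agree (the real helper's signature and tolerances may differ):

from functools import partial

import numpy as np

from aesara import function


def compare_jax_and_py(fgraph, test_input_vals, assert_fn=None):
    # Default comparison: approximate equality with a loose tolerance.
    if assert_fn is None:
        assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)

    # Compile and run the graph through the JAX backend...
    jax_fn = function(fgraph.inputs, fgraph.outputs, mode="JAX")
    jax_res = jax_fn(*test_input_vals)

    # ...and through a pure-Python backend, then compare the outputs.
    py_fn = function(fgraph.inputs, fgraph.outputs, mode="FAST_COMPILE")
    py_res = py_fn(*test_input_vals)

    for jax_out, py_out in zip(jax_res, py_res):
        assert_fn(jax_out, py_out)

    return jax_res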
Example #2
def test_reallocation():
    x = scalar("x")
    y = scalar("y")
    z = tanh(3 * x + y) + cosh(x + 5 * y)
    # This functionality is currently implemented only for the non-lazy, non-C VM.
    for linker in [
            VMLinker(allow_gc=False, lazy=False, use_cloop=False),
            VMLinker(allow_gc=True, lazy=False, use_cloop=False),
    ]:
        m = get_mode(Mode(linker=linker))
        m = m.excluding("fusion", "inplace")

        f = function([x, y], z, name="test_reduce_memory", mode=m)
        output = f(1, 2)
        assert output
        storage_map = f.fn.storage_map

        def check_storage(storage_map):
            for i in storage_map:
                if not isinstance(i, TensorConstant):
                    keys_copy = list(storage_map.keys())
                    keys_copy.remove(i)
                    for o in keys_copy:
                        if storage_map[i][0] and storage_map[i][0] is storage_map[o][0]:
                            return [True, storage_map[o][0]]
            return [False, None]

        assert check_storage(storage_map)[0]
        assert len({id(v) for v in storage_map.values()}) < len(storage_map)
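To observe the storage reuse that this test asserts, one can compile a small graph with the non-lazy Python VM and inspect storage_map directly. A stand-alone sketch, assuming current Aesara module paths (aesara.link.vm.VMLinker, aesara.compile.mode.get_mode):

import aesara
from aesara.compile.mode import Mode, get_mode
from aesara.link.vm import VMLinker
from aesara.tensor import cosh, tanh
from aesara.tensor.type import scalar

x = scalar("x")
y = scalar("y")
z = tanh(3 * x + y) + cosh(x + 5 * y)

# Disable garbage collection so freed intermediate storage can be reused,
# and exclude fusion/inplace rewrites so the intermediates stay distinct.
linker = VMLinker(allow_gc=False, lazy=False, use_cloop=False)
mode = get_mode(Mode(linker=linker)).excluding("fusion", "inplace")

f = aesara.function([x, y], z, mode=mode)
f(1, 2)

# Fewer distinct storage cells than variables means some cells were reused.
cells = [id(v) for v in f.fn.storage_map.values()]
print(f"{len(set(cells))} distinct cells for {len(cells)} variables")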
Example #3
def test_jax_basic():
    rng = np.random.default_rng(28494)

    x = matrix("x")
    y = matrix("y")
    b = vector("b")

    # `ScalarOp`
    z = cosh(x**2 + y / 3.0)

    # `[Inc]Subtensor`
    out = aet_subtensor.set_subtensor(z[0], -10.0)
    out = aet_subtensor.inc_subtensor(out[0, 1], 2.0)
    out = out[:5, :3]

    out_fg = FunctionGraph([x, y], [out])

    test_input_vals = [
        np.tile(np.arange(10), (10, 1)).astype(config.floatX),
        np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
    ]
    (jax_res,) = compare_jax_and_py(out_fg, test_input_vals)

    # Confirm that the `Subtensor` slice operations are correct
    assert jax_res.shape == (5, 3)

    # Confirm that the `IncSubtensor` operations are correct
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0

    out = clip(x, y, 5)
    out_fg = FunctionGraph([x, y], [out])
    compare_jax_and_py(out_fg, test_input_vals)

    out = aet.diagonal(x, 0)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)])

    out = aet_slinalg.cholesky(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg,
        [(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX)],
    )

    # Not sure why this isn't working yet with `lower=False`
    out = aet_slinalg.Cholesky(lower=False)(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg,
        [(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX)],
    )

    out = aet_slinalg.solve(x, b)
    out_fg = FunctionGraph([x, b], [out])
    compare_jax_and_py(
        out_fg,
        [
            np.eye(10).astype(config.floatX),
            np.arange(10).astype(config.floatX),
        ],
    )

    out = aet.diag(b)
    out_fg = FunctionGraph([b], [out])
    compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])

    out = aet_nlinalg.det(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)])

    out = aet_nlinalg.matrix_inverse(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg,
        [(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX)],
    )
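Outside the test harness, any of the graphs above can be compiled against the JAX backend directly. A minimal usage sketch, assuming the predefined "JAX" mode is registered (it requires jax to be installed):

import numpy as np

import aesara
from aesara.tensor import cosh
from aesara.tensor.type import matrix

x = matrix("x")
y = matrix("y")
z = cosh(x**2 + y / 3.0)

# `mode="JAX"` selects the JAX linker; the compiled function runs via JAX.
f = aesara.function([x, y], z, mode="JAX")

a = np.ones((3, 3), dtype=aesara.config.floatX)
print(f(a, a))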