示例#1
0
def test_extra_ops():
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = aet_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # This function also cannot take symbolic input.
    c = aet.as_tensor(5)
    out = aet_extra_ops.bartlett(c)
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    indices = np.arange(np.product((3, 4)))
    out = aet_extra_ops.unravel_index(indices, (3, 4), order="C")
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs],
                       must_be_device_array=False)

    multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4))
    out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs],
                       must_be_device_array=False)

    # The inputs are "concrete", yet it still has problems?
    out = aet_extra_ops.Unique()(aet.as_tensor(
        np.arange(6, dtype=config.floatX).reshape((3, 2))))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [])
示例#2
0
    def test_perform(self, shp):
        rng = np.random.default_rng(43)

        x = matrix()
        y = scalar()
        f = function([x, y], fill_diagonal(x, y))
        a = rng.random(shp).astype(config.floatX)
        val = np.cast[config.floatX](rng.random())
        out = f(a, val)
        # We can't use np.fill_diagonal as it is bugged.
        assert np.allclose(np.diag(out), val)
        assert (out == val).sum() == min(a.shape)
示例#3
0
 def test_perform_3d(self):
     rng = np.random.default_rng(43)
     a = rng.random((3, 3, 3)).astype(config.floatX)
     x = tensor3()
     y = scalar()
     f = function([x, y], fill_diagonal(x, y))
     val = np.cast[config.floatX](rng.random() + 10)
     out = f(a, val)
     # We can't use np.fill_diagonal as it is bugged.
     assert out[0, 0, 0] == val
     assert out[1, 1, 1] == val
     assert out[2, 2, 2] == val
     assert (out == val).sum() == min(a.shape)
示例#4
0
    def test_perform(self):
        x = matrix()
        y = scalar()
        f = function([x, y], fill_diagonal(x, y))
        for shp in [(8, 8), (5, 8), (8, 5)]:
            a = np.random.rand(*shp).astype(config.floatX)
            val = np.cast[config.floatX](np.random.rand())
            out = f(a, val)
            # We can't use np.fill_diagonal as it is bugged.
            assert np.allclose(np.diag(out), val)
            assert (out == val).sum() == min(a.shape)

        # test for 3dtt
        a = np.random.rand(3, 3, 3).astype(config.floatX)
        x = tensor3()
        y = scalar()
        f = function([x, y], fill_diagonal(x, y))
        val = np.cast[config.floatX](np.random.rand() + 10)
        out = f(a, val)
        # We can't use np.fill_diagonal as it is bugged.
        assert out[0, 0, 0] == val
        assert out[1, 1, 1] == val
        assert out[2, 2, 2] == val
        assert (out == val).sum() == min(a.shape)
示例#5
0
文件: test_jax.py 项目: mgorny/aesara
def test_extra_ops():
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = at_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    c = at.as_tensor(5)

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    indices = np.arange(np.product((3, 4)))
    out = at_extra_ops.unravel_index(indices, (3, 4), order="C")
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )