Example #1
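This excerpt comes from Aesara's JAX-backend test suite. The imports below are a best guess at what the snippet assumes; exact module paths vary across Aesara versions, and compare_jax_and_py appears to be a helper defined alongside these tests rather than a public API.

import numpy as np
import pytest
import aesara.tensor as aet
import aesara.tensor.extra_ops as aet_extra_ops
from aesara import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.tensor.type import matrix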
def test_extra_ops():
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = aet_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # This function cannot take symbolic input.
    c = aet.as_tensor(5)
    out = aet_extra_ops.bartlett(c)
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    indices = np.arange(np.prod((3, 4)))
    out = aet_extra_ops.unravel_index(indices, (3, 4), order="C")
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs],
                       must_be_device_array=False)

    multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
    out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs],
                       must_be_device_array=False)

    # The inputs are "concrete", yet it still has problems?
    out = aet_extra_ops.Unique()(aet.as_tensor(
        np.arange(6, dtype=config.floatX).reshape((3, 2))))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [])
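The test above builds its multi_index input with np.unravel_index and then ravels it back with ravel_multi_index; the round trip is easiest to see in plain NumPy. A minimal sketch, NumPy only, no Aesara or JAX involved:

import numpy as np

shape = (3, 4)
flat = np.arange(np.prod(shape))                      # 0 .. 11

# Split flat indices into per-axis coordinates (C order)...
multi = np.unravel_index(flat, shape, order="C")
# ...and fold the coordinates back into flat indices.
assert np.array_equal(np.ravel_multi_index(multi, shape, order="C"), flat)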
Example #2
        def check(shape, index_ndim, mode, order):
            multi_index = np.unravel_index(np.arange(np.prod(shape)),
                                           shape,
                                           order=order)
            # create some invalid indices to test the mode
            if mode in ("wrap", "clip"):
                multi_index = (multi_index[0] - 1, ) + multi_index[1:]
            # test with scalars and higher-dimensional indices
            if index_ndim == 0:
                multi_index = tuple(i[-1] for i in multi_index)
            elif index_ndim == 2:
                multi_index = tuple(i[:, np.newaxis] for i in multi_index)
            multi_index_symb = [aesara.shared(i) for i in multi_index]

            # reference result
            ref = np.ravel_multi_index(multi_index, shape, mode, order)

            def fn(mi, s):
                return function([], ravel_multi_index(mi, s, mode, order))

            # shape given as a tuple
            f_array_tuple = fn(multi_index, shape)
            f_symb_tuple = fn(multi_index_symb, shape)
            np.testing.assert_equal(ref, f_array_tuple())
            np.testing.assert_equal(ref, f_symb_tuple())

            # shape given as an array
            shape_array = np.array(shape)
            f_array_array = fn(multi_index, shape_array)
            np.testing.assert_equal(ref, f_array_array())

            # shape given as an Aesara variable
            shape_symb = aesara.shared(shape_array)
            f_array_symb = fn(multi_index, shape_symb)
            np.testing.assert_equal(ref, f_array_symb())

            # shape testing
            self._compile_and_check(
                [],
                [ravel_multi_index(multi_index, shape_symb, mode, order)],
                [],
                RavelMultiIndex,
            )
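The out-of-range index that check injects (multi_index[0] - 1) is only legal under the "wrap" and "clip" modes, which is why it is added for those modes alone. The mode semantics follow np.ravel_multi_index, which the test uses as its reference; a minimal NumPy sketch:

import numpy as np

shape = (3, 4)
bad = (np.array([-1, 3]), np.array([0, 1]))   # row indices outside [0, 3)

print(np.ravel_multi_index(bad, shape, mode="wrap"))  # [8 1]: rows wrap to 2 and 0
print(np.ravel_multi_index(bad, shape, mode="clip"))  # [0 9]: rows clamp to 0 and 2
# mode="raise" (the default) rejects the same indices with a ValueError.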
Example #3
def test_extra_ops_omni():
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    # This function cannot take symbolic input.
    c = aet.as_tensor(5)
    out = aet_extra_ops.bartlett(c)
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
    out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs],
                       must_be_device_array=False)

    # The inputs are "concrete", yet it still has problems?
    out = aet_extra_ops.Unique()(aet.as_tensor(
        np.arange(6, dtype=config.floatX).reshape((3, 2))))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [])
Example #4
 def fn(mi, s):
     return function([], ravel_multi_index(mi, s, mode, order))
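On its own, fn only makes sense inside the enclosing test: mode and order are captured from the surrounding scope, and because the compiled function takes no explicit inputs, mi and s end up embedded in the graph as constants (or, for shared variables, as implicit inputs). A self-contained sketch of the same pattern with the captured names bound explicitly, assuming ravel_multi_index lives in aesara.tensor.extra_ops; illustrative only, not part of the original test:

import numpy as np
from aesara import function
from aesara.tensor.extra_ops import ravel_multi_index

mode, order = "raise", "C"

def fn(mi, s):
    # mi and s are baked into the graph; mode/order come from the closure
    return function([], ravel_multi_index(mi, s, mode, order))

mi = np.unravel_index(np.arange(12), (3, 4))
f = fn(mi, (3, 4))
assert np.array_equal(f(), np.arange(12))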
Example #5
    def test_ravel_multi_index(self):
        def check(shape, index_ndim, mode, order):
            multi_index = np.unravel_index(np.arange(np.prod(shape)),
                                           shape,
                                           order=order)
            # create some invalid indices to test the mode
            if mode in ("wrap", "clip"):
                multi_index = (multi_index[0] - 1, ) + multi_index[1:]
            # test with scalars and higher-dimensional indices
            if index_ndim == 0:
                multi_index = tuple(i[-1] for i in multi_index)
            elif index_ndim == 2:
                multi_index = tuple(i[:, np.newaxis] for i in multi_index)
            multi_index_symb = [aesara.shared(i) for i in multi_index]

            # reference result
            ref = np.ravel_multi_index(multi_index, shape, mode, order)

            def fn(mi, s):
                return function([], ravel_multi_index(mi, s, mode, order))

            # shape given as a tuple
            f_array_tuple = fn(multi_index, shape)
            f_symb_tuple = fn(multi_index_symb, shape)
            np.testing.assert_equal(ref, f_array_tuple())
            np.testing.assert_equal(ref, f_symb_tuple())

            # shape given as an array
            shape_array = np.array(shape)
            f_array_array = fn(multi_index, shape_array)
            np.testing.assert_equal(ref, f_array_array())

            # shape given as an Aesara variable
            shape_symb = aesara.shared(shape_array)
            f_array_symb = fn(multi_index, shape_symb)
            np.testing.assert_equal(ref, f_array_symb())

            # shape testing
            self._compile_and_check(
                [],
                [ravel_multi_index(multi_index, shape_symb, mode, order)],
                [],
                RavelMultiIndex,
            )

        for mode in ("raise", "wrap", "clip"):
            for order in ("C", "F"):
                for index_ndim in (0, 1, 2):
                    check((3, ), index_ndim, mode, order)
                    check((3, 4), index_ndim, mode, order)
                    check((3, 4, 5), index_ndim, mode, order)

        # must provide integers
        with pytest.raises(TypeError):
            ravel_multi_index((fvector(), ivector()), (3, 4))
        with pytest.raises(TypeError):
            ravel_multi_index(((3, 4), ivector()), (3.4, 3.2))

        # dims must be a 1D sequence
        with pytest.raises(TypeError):
            ravel_multi_index(((3, 4), ), ((3, 4), ))
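The TypeError checks at the end mirror NumPy's requirements: every multi-index component and every dimension must be integer-typed, and dims must be a flat 1D sequence. Outside the test harness the op also works with ordinary symbolic inputs; a minimal sketch, assuming a recent Aesara layout where these names live under aesara.tensor.extra_ops and aesara.tensor.type:

import numpy as np
import aesara
from aesara.tensor.extra_ops import ravel_multi_index
from aesara.tensor.type import ivector

rows, cols = ivector("rows"), ivector("cols")
flat = ravel_multi_index((rows, cols), (3, 4), "raise", "C")   # mode, order
f = aesara.function([rows, cols], flat)

print(f(np.array([0, 2], dtype="int32"),
        np.array([1, 3], dtype="int32")))   # -> [1, 11], same as np.ravel_multi_index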