def test_extra_ops(): a = matrix("a") a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) out = aet_extra_ops.cumsum(a, axis=0) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) out = aet_extra_ops.cumprod(a, axis=1) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) out = aet_extra_ops.diff(a, n=2, axis=1) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) out = aet_extra_ops.repeat(a, (3, 3), axis=1) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) # This function also cannot take symbolic input. c = aet.as_tensor(5) out = aet_extra_ops.bartlett(c) fgraph = FunctionGraph([], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) with pytest.raises(NotImplementedError): out = aet_extra_ops.fill_diagonal(a, c) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) with pytest.raises(NotImplementedError): out = aet_extra_ops.fill_diagonal_offset(a, c, c) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) with pytest.raises(NotImplementedError): out = aet_extra_ops.Unique(axis=1)(a) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) indices = np.arange(np.product((3, 4))) out = aet_extra_ops.unravel_index(indices, (3, 4), order="C") fgraph = FunctionGraph([], out) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False) multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4)) out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4)) fgraph = FunctionGraph([], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False) # The inputs are "concrete", yet it still has problems? out = aet_extra_ops.Unique()(aet.as_tensor( np.arange(6, dtype=config.floatX).reshape((3, 2)))) fgraph = FunctionGraph([], [out]) compare_jax_and_py(fgraph, [])
def test_extra_ops_omni(): a = matrix("a") a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) # This function also cannot take symbolic input. c = aet.as_tensor(5) out = aet_extra_ops.bartlett(c) fgraph = FunctionGraph([], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4)) out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4)) fgraph = FunctionGraph([], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False) # The inputs are "concrete", yet it still has problems? out = aet_extra_ops.Unique()(aet.as_tensor( np.arange(6, dtype=config.floatX).reshape((3, 2)))) fgraph = FunctionGraph([], [out]) compare_jax_and_py(fgraph, [])