def test_extra_ops():
    """Compare the JAX and Python backends on a set of ``aet_extra_ops`` Ops.

    Each case builds a ``FunctionGraph`` and passes it to ``compare_jax_and_py``
    with the test values of the graph's inputs. Ops without a JAX
    implementation are expected to raise ``NotImplementedError``.
    """

    def compare(inputs, outputs, **kwargs):
        # Build a graph from ``inputs``/``outputs`` and compare both backends
        # on the inputs' test values; extra kwargs go to compare_jax_and_py.
        fgraph = FunctionGraph(inputs, outputs)
        compare_jax_and_py(
            fgraph, [get_test_value(i) for i in fgraph.inputs], **kwargs
        )

    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    compare([a], [aet_extra_ops.cumsum(a, axis=0)])
    compare([a], [aet_extra_ops.cumprod(a, axis=1)])
    compare([a], [aet_extra_ops.diff(a, n=2, axis=1)])
    compare([a], [aet_extra_ops.repeat(a, (3, 3), axis=1)])

    # This function also cannot take symbolic input.
    c = aet.as_tensor(5)
    compare([], [aet_extra_ops.bartlett(c)])

    # Graph construction stays inside ``pytest.raises``, as before, so an
    # error raised while building the graph is also covered.
    with pytest.raises(NotImplementedError):
        compare([a], [aet_extra_ops.fill_diagonal(a, c)])

    with pytest.raises(NotImplementedError):
        compare([a], [aet_extra_ops.fill_diagonal_offset(a, c, c)])

    with pytest.raises(NotImplementedError):
        compare([a], [aet_extra_ops.Unique(axis=1)(a)])

    # ``np.product`` is deprecated (NumPy 1.25) and removed in NumPy 2.0.
    indices = np.arange(np.prod((3, 4)))
    # ``unravel_index`` already returns a sequence of outputs, so it is passed
    # to the graph as-is rather than wrapped in a list.
    compare(
        [],
        aet_extra_ops.unravel_index(indices, (3, 4), order="C"),
        must_be_device_array=False,
    )

    multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
    compare(
        [],
        [aet_extra_ops.ravel_multi_index(multi_index, (3, 4))],
        must_be_device_array=False,
    )

    # The inputs are "concrete", yet it still has problems?
    compare(
        [],
        [
            aet_extra_ops.Unique()(
                aet.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2)))
            )
        ],
    )
def check(shape, index_ndim, mode, order):
    """Check ``ravel_multi_index`` against ``np.ravel_multi_index`` for one case.

    Parameters
    ----------
    shape : tuple of int
        Dimensions to ravel into.
    index_ndim : int
        0 to test scalar indices, 2 to test indices with an added trailing
        axis; any other value keeps the 1-D indices.
    mode, order
        Forwarded to both ``ravel_multi_index`` and ``np.ravel_multi_index``.

    NOTE(review): ``self`` is not a parameter here. It is used for
    ``self._compile_and_check``, so this helper only works when nested inside a
    test method that provides ``self``.
    """
    # ``np.product`` is deprecated (NumPy 1.25) and removed in NumPy 2.0.
    multi_index = np.unravel_index(np.arange(np.prod(shape)), shape, order=order)
    # create some invalid indices to test the mode
    if mode in ("wrap", "clip"):
        multi_index = (multi_index[0] - 1,) + multi_index[1:]
    # test with scalars and higher-dimensional indices
    if index_ndim == 0:
        multi_index = tuple(i[-1] for i in multi_index)
    elif index_ndim == 2:
        multi_index = tuple(i[:, np.newaxis] for i in multi_index)
    multi_index_symb = [aesara.shared(i) for i in multi_index]

    # reference result
    ref = np.ravel_multi_index(multi_index, shape, mode, order)

    def fn(mi, s):
        # Compile a no-input function for the given index/dims combination.
        return function([], ravel_multi_index(mi, s, mode, order))

    # shape given as a tuple
    f_array_tuple = fn(multi_index, shape)
    f_symb_tuple = fn(multi_index_symb, shape)
    np.testing.assert_equal(ref, f_array_tuple())
    np.testing.assert_equal(ref, f_symb_tuple())

    # shape given as an array
    shape_array = np.array(shape)
    f_array_array = fn(multi_index, shape_array)
    np.testing.assert_equal(ref, f_array_array())

    # shape given as an Aesara variable
    shape_symb = aesara.shared(shape_array)
    f_array_symb = fn(multi_index, shape_symb)
    np.testing.assert_equal(ref, f_array_symb())

    # shape testing
    self._compile_and_check(
        [],
        [ravel_multi_index(multi_index, shape_symb, mode, order)],
        [],
        RavelMultiIndex,
    )
def test_extra_ops_omni():
    """Compare the JAX and Python backends on constant-input extra Ops.

    Covers ``bartlett`` on a constant scalar, ``ravel_multi_index`` on a
    concrete multi-index, and ``Unique`` on a constant matrix. No graph here
    has symbolic inputs.
    """
    # This function also cannot take symbolic input.
    c = aet.as_tensor(5)
    out = aet_extra_ops.bartlett(c)
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # ``np.product`` is deprecated (NumPy 1.25) and removed in NumPy 2.0.
    multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
    out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )

    # The inputs are "concrete", yet it still has problems?
    out = aet_extra_ops.Unique()(
        aet.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2)))
    )
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [])
def fn(mi, s):
    """Return a compiled, input-free function computing ``ravel_multi_index``.

    ``mi`` is the multi-index and ``s`` the dims. ``mode`` and ``order`` are
    not parameters; they are free names resolved from the enclosing scope.
    """
    raveled = ravel_multi_index(mi, s, mode, order)
    return function([], raveled)
def test_ravel_multi_index(self):
    """Check ``ravel_multi_index`` against ``np.ravel_multi_index``.

    Runs every combination of the following:

    - ``mode``: "raise", "wrap", "clip"
    - ``order``: "C", "F"
    - index rank: 0, 1, 2
    - shape: (3,), (3, 4), (3, 4, 5)

    The multi-index is given as NumPy values or shared variables. The dims
    are given as a tuple, a NumPy array, or a shared variable. Non-integer
    indices or dims, and non-1D dims, must raise ``TypeError``.
    """

    def check(shape, index_ndim, mode, order):
        # ``np.product`` is deprecated (NumPy 1.25) and removed in NumPy 2.0.
        multi_index = np.unravel_index(
            np.arange(np.prod(shape)), shape, order=order
        )
        # create some invalid indices to test the mode
        if mode in ("wrap", "clip"):
            multi_index = (multi_index[0] - 1,) + multi_index[1:]
        # test with scalars and higher-dimensional indices
        if index_ndim == 0:
            multi_index = tuple(i[-1] for i in multi_index)
        elif index_ndim == 2:
            multi_index = tuple(i[:, np.newaxis] for i in multi_index)
        multi_index_symb = [aesara.shared(i) for i in multi_index]

        # reference result
        ref = np.ravel_multi_index(multi_index, shape, mode, order)

        def fn(mi, s):
            # Compile a no-input function for this index/dims combination.
            return function([], ravel_multi_index(mi, s, mode, order))

        # shape given as a tuple
        f_array_tuple = fn(multi_index, shape)
        f_symb_tuple = fn(multi_index_symb, shape)
        np.testing.assert_equal(ref, f_array_tuple())
        np.testing.assert_equal(ref, f_symb_tuple())

        # shape given as an array
        shape_array = np.array(shape)
        f_array_array = fn(multi_index, shape_array)
        np.testing.assert_equal(ref, f_array_array())

        # shape given as an Aesara variable
        shape_symb = aesara.shared(shape_array)
        f_array_symb = fn(multi_index, shape_symb)
        np.testing.assert_equal(ref, f_array_symb())

        # shape testing
        self._compile_and_check(
            [],
            [ravel_multi_index(multi_index, shape_symb, mode, order)],
            [],
            RavelMultiIndex,
        )

    for mode in ("raise", "wrap", "clip"):
        for order in ("C", "F"):
            for index_ndim in (0, 1, 2):
                for shape in ((3,), (3, 4), (3, 4, 5)):
                    check(shape, index_ndim, mode, order)

    # must provide integers
    with pytest.raises(TypeError):
        ravel_multi_index((fvector(), ivector()), (3, 4))
    with pytest.raises(TypeError):
        ravel_multi_index(((3, 4), ivector()), (3.4, 3.2))

    # dims must be a 1D sequence
    with pytest.raises(TypeError):
        ravel_multi_index(((3, 4),), ((3, 4),))