def test_extra_ops(): a = matrix("a") a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) out = aet_extra_ops.cumsum(a, axis=0) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) out = aet_extra_ops.cumprod(a, axis=1) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) out = aet_extra_ops.diff(a, n=2, axis=1) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) out = aet_extra_ops.repeat(a, (3, 3), axis=1) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) # This function also cannot take symbolic input. c = aet.as_tensor(5) out = aet_extra_ops.bartlett(c) fgraph = FunctionGraph([], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) with pytest.raises(NotImplementedError): out = aet_extra_ops.fill_diagonal(a, c) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) with pytest.raises(NotImplementedError): out = aet_extra_ops.fill_diagonal_offset(a, c, c) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) with pytest.raises(NotImplementedError): out = aet_extra_ops.Unique(axis=1)(a) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) indices = np.arange(np.product((3, 4))) out = aet_extra_ops.unravel_index(indices, (3, 4), order="C") fgraph = FunctionGraph([], out) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False) multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4)) out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4)) fgraph = FunctionGraph([], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False) # The inputs are "concrete", yet it still has problems? out = aet_extra_ops.Unique()(aet.as_tensor( np.arange(6, dtype=config.floatX).reshape((3, 2)))) fgraph = FunctionGraph([], [out]) compare_jax_and_py(fgraph, [])
def test_perform(self, shp): rng = np.random.default_rng(43) x = matrix() y = scalar() f = function([x, y], fill_diagonal(x, y)) a = rng.random(shp).astype(config.floatX) val = np.cast[config.floatX](rng.random()) out = f(a, val) # We can't use np.fill_diagonal as it is bugged. assert np.allclose(np.diag(out), val) assert (out == val).sum() == min(a.shape)
def test_perform_3d(self): rng = np.random.default_rng(43) a = rng.random((3, 3, 3)).astype(config.floatX) x = tensor3() y = scalar() f = function([x, y], fill_diagonal(x, y)) val = np.cast[config.floatX](rng.random() + 10) out = f(a, val) # We can't use np.fill_diagonal as it is bugged. assert out[0, 0, 0] == val assert out[1, 1, 1] == val assert out[2, 2, 2] == val assert (out == val).sum() == min(a.shape)
def test_perform(self): x = matrix() y = scalar() f = function([x, y], fill_diagonal(x, y)) for shp in [(8, 8), (5, 8), (8, 5)]: a = np.random.rand(*shp).astype(config.floatX) val = np.cast[config.floatX](np.random.rand()) out = f(a, val) # We can't use np.fill_diagonal as it is bugged. assert np.allclose(np.diag(out), val) assert (out == val).sum() == min(a.shape) # test for 3dtt a = np.random.rand(3, 3, 3).astype(config.floatX) x = tensor3() y = scalar() f = function([x, y], fill_diagonal(x, y)) val = np.cast[config.floatX](np.random.rand() + 10) out = f(a, val) # We can't use np.fill_diagonal as it is bugged. assert out[0, 0, 0] == val assert out[1, 1, 1] == val assert out[2, 2, 2] == val assert (out == val).sum() == min(a.shape)
def test_extra_ops(): a = matrix("a") a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) out = at_extra_ops.cumsum(a, axis=0) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) out = at_extra_ops.cumprod(a, axis=1) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) out = at_extra_ops.diff(a, n=2, axis=1) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) out = at_extra_ops.repeat(a, (3, 3), axis=1) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) c = at.as_tensor(5) with pytest.raises(NotImplementedError): out = at_extra_ops.fill_diagonal(a, c) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) with pytest.raises(NotImplementedError): out = at_extra_ops.fill_diagonal_offset(a, c, c) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) with pytest.raises(NotImplementedError): out = at_extra_ops.Unique(axis=1)(a) fgraph = FunctionGraph([a], [out]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) indices = np.arange(np.product((3, 4))) out = at_extra_ops.unravel_index(indices, (3, 4), order="C") fgraph = FunctionGraph([], out) compare_jax_and_py( fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False )