def test_grad(self):
    """Verify gradients of `diff` numerically, for orders 0..TestDiffOp.nb - 1."""
    x = vector("x")
    a = np.random.random(50).astype(config.floatX)

    # Base case: default first-order difference must be differentiable.
    aesara.function([x], grad(aet_sum(diff(x)), x))
    utt.verify_grad(self.op, [a])

    for order in range(TestDiffOp.nb):
        # Graph construction must succeed for every order ...
        aesara.function([x], grad(aet_sum(diff(x, n=order)), x))
        # ... and the gradient must match a finite-difference estimate.
        utt.verify_grad(DiffOp(n=order), [a], eps=7e-3)
def test_diffOp(self):
    """Compare `diff` against `np.diff` over every axis and every order."""
    x = matrix("x")
    a = np.random.random((30, 50)).astype(config.floatX)

    # Default call: first-order diff along the last axis.
    default_fn = aesara.function([x], diff(x))
    assert np.allclose(np.diff(a), default_fn(a))

    for ax in range(a.ndim):
        for order in range(TestDiffOp.nb):
            fn = aesara.function([x], diff(x, n=order, axis=ax))
            assert np.allclose(np.diff(a, n=order, axis=ax), fn(a))
def test_perform(self, axis, n):
    """Evaluate `diff(x, n, axis)` and check it against `np.diff` on seeded data."""
    rng = np.random.default_rng(4282)
    x = matrix("x")
    data = rng.random((30, 50)).astype(config.floatX)

    # Sanity check with the default (first-order, last-axis) call first.
    default_fn = aesara.function([x], diff(x))
    assert np.allclose(np.diff(data), default_fn(data))

    # Then the parametrized order/axis combination.
    param_fn = aesara.function([x], diff(x, n=n, axis=axis))
    assert np.allclose(np.diff(data, n=n, axis=axis), param_fn(data))
def test_extra_ops():
    """Compile various `extra_ops` graphs to JAX and compare with the Python backend.

    Ops that the JAX backend does not support are expected to raise
    ``NotImplementedError``.
    """
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = aet_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # This function also cannot take symbolic input.
    c = aet.as_tensor(5)
    out = aet_extra_ops.bartlett(c)
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # `np.product` is a deprecated alias removed in NumPy 2.0; use `np.prod`.
    indices = np.arange(np.prod((3, 4)))
    out = aet_extra_ops.unravel_index(indices, (3, 4), order="C")
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )

    multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
    out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )

    # The inputs are "concrete", yet it still has problems?
    out = aet_extra_ops.Unique()(
        aet.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2)))
    )
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [])
def test_basic(self):
    """A first-order `diff` on seeded random data matches `np.diff`."""
    rng = np.random.default_rng(4282)
    x = matrix("x")
    data = rng.random((30, 50)).astype(config.floatX)
    diff_fn = aesara.function([x], diff(x))
    assert np.allclose(np.diff(data), diff_fn(data))
def test_infer_shape(self):
    """Shape inference for `DiffOp` holds for every axis/order combination."""
    x = matrix("x")
    a = np.random.random((30, 50)).astype(config.floatX)

    # Default op first.
    self._compile_and_check([x], [self.op(x)], [a], self.op_class)

    for ax in range(a.ndim):
        for order in range(TestDiffOp.nb):
            self._compile_and_check(
                [x], [diff(x, n=order, axis=ax)], [a], self.op_class
            )
def test_output_type(self, x_type, axis, n):
    """Static output shape of `diff`: known dims match `np.diff`, unknown stay None."""
    x = x_type("x")
    x_test = np.empty((10, 30))
    out = diff(x, n=n, axis=axis)
    out_test = np.diff(x_test, n=n, axis=axis)
    for dim in range(2):
        static_len = x.type.shape[dim]
        if static_len is None:
            # An unknown input dimension must remain unknown in the output.
            assert out.type.shape[dim] is None
        else:
            # A known input dimension must yield the exact numpy output length.
            assert out.type.shape[dim] == out_test.shape[dim]
def test_extra_ops():
    """Compile various `extra_ops` graphs to JAX and compare with the Python backend.

    Ops that the JAX backend does not support are expected to raise
    ``NotImplementedError``.
    """
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = at_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    c = at.as_tensor(5)

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # `np.product` is a deprecated alias removed in NumPy 2.0; use `np.prod`.
    indices = np.arange(np.prod((3, 4)))
    out = at_extra_ops.unravel_index(indices, (3, 4), order="C")
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )