def test_repeatOp(self):
    """Compare ``repeat`` against ``np.repeat`` over ranks, axes and int dtypes.

    For ranks 1 and 3, every axis yielded by ``self._possible_axis`` and every
    dtype in ``integer_dtypes``, the repeat count is given as:

    * a symbolic scalar,
    * a symbolic vector,
    * a Python list holding a single integer,
    * a symbolic tensor whose broadcastable pattern is ``(True,)``.

    Cases expected to be rejected must raise ``TypeError`` as soon as
    ``repeat`` is called. The single-integer-list and broadcastable cases must
    also leave no ``RepeatOp`` node in the compiled function graph.
    """

    def has_repeat_op(fn):
        # True when any node of the compiled graph still applies a RepeatOp.
        return any(
            isinstance(node.op, RepeatOp) for node in fn.maker.fgraph.toposort()
        )

    for ndim in [1, 3]:
        data_var = TensorType(config.floatX, [False] * ndim)()
        data = np.random.random((10,) * ndim).astype(config.floatX)

        for axis in self._possible_axis(ndim):
            for dtype in integer_dtypes:
                # Repeat count as a symbolic scalar.
                reps_var = scalar(dtype=dtype)
                reps = np.asarray(3, dtype=dtype)
                scalar_rejected = dtype == "uint64" or (
                    dtype in self.numpy_unsupported_dtypes and reps_var.ndim == 1
                )
                if scalar_rejected:
                    with pytest.raises(TypeError):
                        repeat(data_var, reps_var, axis=axis)
                else:
                    fn = aesara.function(
                        [data_var, reps_var], repeat(data_var, reps_var, axis=axis)
                    )
                    assert np.allclose(
                        np.repeat(data, reps, axis=axis), fn(data, reps)
                    )

                # Repeat count as a symbolic vector: one count per element when
                # flattening (axis is None), otherwise one per slice along axis.
                reps_var = vector(dtype=dtype)
                reps_size = data.size if axis is None else (10,)
                reps = np.random.randint(1, 6, size=reps_size).astype(dtype)
                if dtype in self.numpy_unsupported_dtypes and reps_var.ndim == 1:
                    with pytest.raises(TypeError):
                        repeat(data_var, reps_var, axis=axis)
                else:
                    fn = aesara.function(
                        [data_var, reps_var], repeat(data_var, reps_var, axis=axis)
                    )
                    assert np.allclose(
                        np.repeat(data, reps, axis=axis), fn(data, reps)
                    )

                # Repeat count as a list holding one integer, e.g. [3].
                reps = np.random.randint(1, 11, size=()).astype(dtype) + 2
                fn = aesara.function([data_var], repeat(data_var, [reps], axis=axis))
                assert np.allclose(np.repeat(data, reps, axis=axis), fn(data))
                assert not has_repeat_op(fn)

                # Repeat count as a tensor whose broadcastable pattern is (True,).
                reps_var = TensorType(broadcastable=(True,), dtype=dtype)()
                reps = np.random.randint(1, 6, size=(1,)).astype(dtype)
                fn = aesara.function(
                    [data_var, reps_var], repeat(data_var, reps_var, axis=axis)
                )
                assert np.allclose(
                    np.repeat(data, reps[0], axis=axis), fn(data, reps)
                )
                assert not has_repeat_op(fn)
def test_extra_ops():
    """Run a series of ``aet_extra_ops`` graphs through ``compare_jax_and_py``.

    ``cumsum``, ``cumprod``, ``diff``, ``repeat``, ``bartlett``,
    ``unravel_index``, ``ravel_multi_index`` and ``Unique()`` on a constant are
    each wrapped in a ``FunctionGraph`` and compared. ``fill_diagonal``,
    ``fill_diagonal_offset`` and ``Unique(axis=1)`` are expected to raise
    ``NotImplementedError`` while building or comparing their graphs.
    """
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = aet_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # This function also cannot take symbolic input, so a constant is used.
    c = aet.as_tensor(5)

    out = aet_extra_ops.bartlett(c)
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # ``np.prod`` rather than ``np.product``: the latter was deprecated in
    # NumPy 1.25 and removed in NumPy 2.0.
    indices = np.arange(np.prod((3, 4)))
    out = aet_extra_ops.unravel_index(indices, (3, 4), order="C")
    # ``unravel_index`` yields several outputs, so ``out`` is passed as-is.
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )

    multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
    out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )

    # The inputs are "concrete", yet it still has problems?
    # NOTE(review): confirm whether this case is expected to pass.
    out = aet_extra_ops.Unique()(
        aet.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2)))
    )
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [])
def test_infer_shape(self):
    """Exercise ``RepeatOp`` through ``self._compile_and_check``.

    For ranks 1 and 3 (shapes ``(3,)`` and ``(3, 6, 9)``), every axis from
    ``self._possible_axis`` and the dtypes int8, uint8 and uint64:

    * dtypes listed in ``self.numpy_unsupported_dtypes`` must make
      ``repeat`` raise ``TypeError`` for a vector repeat count;
    * every other dtype is checked with a scalar count of 3, then with a
      random vector count.
    """

    def check_repeat_shape(data_var, data, reps_var, reps, axis):
        # Build a fresh RepeatOp node and hand it to the checker.
        # NOTE(review): _compile_and_check is presumably the shape-inference
        # helper of the enclosing test case; confirm.
        self._compile_and_check(
            [data_var, reps_var],
            [RepeatOp(axis=axis)(data_var, reps_var)],
            [data, reps],
            self.op_class,
        )

    for ndim in [1, 3]:
        data_var = TensorType(config.floatX, [False] * ndim)()
        shape = (np.arange(ndim) + 1) * 3
        data = np.random.random(shape).astype(config.floatX)

        for axis in self._possible_axis(ndim):
            for dtype in ["int8", "uint8", "uint64"]:
                reps_var = scalar(dtype=dtype)
                reps = np.asarray(3, dtype=dtype)

                if dtype in self.numpy_unsupported_dtypes:
                    # Unsupported dtypes must be refused while building the graph.
                    reps_var = vector(dtype=dtype)
                    with pytest.raises(TypeError):
                        repeat(data_var, reps_var)
                    continue

                check_repeat_shape(data_var, data, reps_var, reps, axis)

                # Vector count: one entry per element when flattening, one per
                # slice along ``axis`` otherwise, and a length-10 fallback when
                # the data is empty.
                reps_var = vector(dtype=dtype)
                if axis is None:
                    reps_size = data.size
                elif data.size > 0:
                    reps_size = data.shape[axis]
                else:
                    reps_size = (10,)
                reps = np.random.randint(1, 6, size=reps_size).astype(dtype)

                check_repeat_shape(data_var, data, reps_var, reps, axis)
def test_extra_ops():
    """Run a series of ``at_extra_ops`` graphs through ``compare_jax_and_py``.

    ``cumsum``, ``cumprod``, ``diff``, ``repeat`` and ``unravel_index`` are each
    wrapped in a ``FunctionGraph`` and compared. ``fill_diagonal``,
    ``fill_diagonal_offset`` and ``Unique(axis=1)`` are expected to raise
    ``NotImplementedError`` while building or comparing their graphs.
    """
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = at_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # Constant scalar reused as the fill value / offset below.
    c = at.as_tensor(5)

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # ``np.prod`` rather than ``np.product``: the latter was deprecated in
    # NumPy 1.25 and removed in NumPy 2.0.
    indices = np.arange(np.prod((3, 4)))
    out = at_extra_ops.unravel_index(indices, (3, 4), order="C")
    # ``unravel_index`` yields several outputs, so ``out`` is passed as-is.
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )