Beispiel #1
0
    def test_repeatOp(self):
        """Check `repeat` against `np.repeat` for each axis and integer dtype.

        For inputs of rank 1 and 3, and every axis yielded by
        ``self._possible_axis``, the repeat count is supplied in four forms:
        a symbolic scalar, a symbolic vector, a one-element Python list and
        a tensor whose broadcastable pattern is ``(True,)``.  The dtype
        combinations singled out by the conditions below are expected to
        raise ``TypeError`` instead of compiling.
        """
        for ndim in [1, 3]:
            x = TensorType(config.floatX, [False] * ndim)()
            a = np.random.random((10,) * ndim).astype(config.floatX)

            for axis in self._possible_axis(ndim):
                for dtype in integer_dtypes:
                    # Repeat count given as a symbolic scalar.
                    reps_var = scalar(dtype=dtype)
                    reps = np.asarray(3, dtype=dtype)
                    scalar_rejected = dtype == "uint64" or (
                        dtype in self.numpy_unsupported_dtypes
                        and reps_var.ndim == 1
                    )
                    if scalar_rejected:
                        with pytest.raises(TypeError):
                            repeat(x, reps_var, axis=axis)
                        # Nothing else is exercised for a rejected dtype.
                        continue

                    fn = aesara.function(
                        [x, reps_var], repeat(x, reps_var, axis=axis)
                    )
                    expected = np.repeat(a, reps, axis=axis)
                    assert np.allclose(expected, fn(a, reps))

                    # Repeat count given as a symbolic vector: one count per
                    # element when flattened (axis=None), otherwise one per
                    # entry along the (length-10) axis.
                    reps_var = vector(dtype=dtype)
                    n_counts = a.size if axis is None else (10,)
                    reps = np.random.randint(1, 6, size=n_counts).astype(dtype)

                    vector_rejected = (
                        dtype in self.numpy_unsupported_dtypes
                        and reps_var.ndim == 1
                    )
                    if vector_rejected:
                        with pytest.raises(TypeError):
                            repeat(x, reps_var, axis=axis)
                    else:
                        fn = aesara.function(
                            [x, reps_var], repeat(x, reps_var, axis=axis)
                        )
                        expected = np.repeat(a, reps, axis=axis)
                        assert np.allclose(expected, fn(a, reps))

                    # Repeat count given as a list holding a single integer,
                    # e.g. [3]; the compiled graph must not contain RepeatOp.
                    reps = np.random.randint(1, 11, size=()).astype(dtype) + 2
                    fn = aesara.function([x], repeat(x, [reps], axis=axis))
                    expected = np.repeat(a, reps, axis=axis)
                    assert np.allclose(expected, fn(a))
                    assert not any(
                        isinstance(node.op, RepeatOp)
                        for node in fn.maker.fgraph.toposort()
                    )

                    # Repeat count given as an aesara tensor whose
                    # broadcastable pattern is (True,); again no RepeatOp
                    # may remain in the compiled graph.
                    reps_var = TensorType(broadcastable=(True,), dtype=dtype)()
                    reps = np.random.randint(1, 6, size=(1,)).astype(dtype)
                    fn = aesara.function(
                        [x, reps_var], repeat(x, reps_var, axis=axis)
                    )
                    expected = np.repeat(a, reps[0], axis=axis)
                    assert np.allclose(expected, fn(a, reps))
                    assert not any(
                        isinstance(node.op, RepeatOp)
                        for node in fn.maker.fgraph.toposort()
                    )
Beispiel #2
0
def test_extra_ops():
    """Run a selection of `aet_extra_ops` graphs through `compare_jax_and_py`.

    Each case builds a `FunctionGraph` and passes it, together with the test
    values of its inputs, to `compare_jax_and_py`.  The `fill_diagonal`,
    `fill_diagonal_offset` and `Unique(axis=1)` cases are expected to raise
    `NotImplementedError` somewhere in that build-and-compare sequence.
    """
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = aet_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # This function also cannot take symbolic input.
    c = aet.as_tensor(5)
    out = aet_extra_ops.bartlett(c)
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # The whole build-and-compare sequence stays inside each `raises`
    # context: which of the three statements raises is not visible here.
    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # `np.prod` replaces `np.product`, a deprecated alias that NumPy 2.0
    # removed; both return the same element count for this shape.
    shape = (3, 4)
    indices = np.arange(np.prod(shape))
    out = aet_extra_ops.unravel_index(indices, shape, order="C")
    # `out` is passed as-is rather than wrapped in a list, presumably because
    # `unravel_index` already returns a sequence of outputs -- TODO confirm.
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs],
                       must_be_device_array=False)

    multi_index = np.unravel_index(np.arange(np.prod(shape)), shape)
    out = aet_extra_ops.ravel_multi_index(multi_index, shape)
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs],
                       must_be_device_array=False)

    # The inputs are "concrete", yet it still has problems?
    out = aet_extra_ops.Unique()(aet.as_tensor(
        np.arange(6, dtype=config.floatX).reshape((3, 2))))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [])
Beispiel #3
0
    def test_infer_shape(self):
        """Exercise ``self._compile_and_check`` on `RepeatOp` graphs.

        Inputs of rank 1 and 3 are combined with every axis yielded by
        ``self._possible_axis`` and with the dtypes int8, uint8 and uint64.
        dtypes listed in ``self.numpy_unsupported_dtypes`` are only checked
        for raising ``TypeError`` when used as a vector repeat count; all
        others are checked with both a scalar and a vector repeat count.
        """
        for ndim in [1, 3]:
            x = TensorType(config.floatX, [False] * ndim)()
            # Axis lengths 3, 6, 9, ... so that every axis has its own size.
            shp = (np.arange(ndim) + 1) * 3
            a = np.random.random(shp).astype(config.floatX)

            for axis in self._possible_axis(ndim):
                for dtype in ["int8", "uint8", "uint64"]:
                    reps_var = scalar(dtype=dtype)
                    reps = np.asarray(3, dtype=dtype)

                    if dtype in self.numpy_unsupported_dtypes:
                        reps_var = vector(dtype=dtype)
                        with pytest.raises(TypeError):
                            repeat(x, reps_var)
                        # Unsupported dtypes get no shape-inference check.
                        continue

                    # Scalar repeat count.
                    self._compile_and_check(
                        [x, reps_var],
                        [RepeatOp(axis=axis)(x, reps_var)],
                        [a, reps],
                        self.op_class,
                    )

                    # Vector repeat count: pick how many counts are needed
                    # first, then draw them in a single call.
                    reps_var = vector(dtype=dtype)
                    if axis is None:
                        n_counts = a.size
                    elif a.size > 0:
                        n_counts = a.shape[axis]
                    else:
                        n_counts = (10,)
                    reps = np.random.randint(1, 6, size=n_counts).astype(dtype)

                    self._compile_and_check(
                        [x, reps_var],
                        [RepeatOp(axis=axis)(x, reps_var)],
                        [a, reps],
                        self.op_class,
                    )
Beispiel #4
0
def test_extra_ops():
    """Run a selection of `at_extra_ops` graphs through `compare_jax_and_py`.

    Each case builds a `FunctionGraph` and passes it, together with the test
    values of its inputs, to `compare_jax_and_py`.  The `fill_diagonal`,
    `fill_diagonal_offset` and `Unique(axis=1)` cases are expected to raise
    `NotImplementedError` somewhere in that build-and-compare sequence.
    """
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = at_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    c = at.as_tensor(5)

    # The whole build-and-compare sequence stays inside each `raises`
    # context: which of the three statements raises is not visible here.
    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # `np.prod` replaces `np.product`, a deprecated alias that NumPy 2.0
    # removed; both return the same element count for this shape.
    shape = (3, 4)
    indices = np.arange(np.prod(shape))
    out = at_extra_ops.unravel_index(indices, shape, order="C")
    # `out` is passed as-is rather than wrapped in a list, presumably because
    # `unravel_index` already returns a sequence of outputs -- TODO confirm.
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )