Beispiel #1
0
def test_ordering():
    """Adjacent signatures produced by ``ordering`` never violate ``supercedes``."""
    x, y = var("x"), var("y")
    signatures = [(1,), (x,), (2,), (y,), (x, x), (1, x), (x, 1), (1, 2)]
    ordered = ordering(signatures)

    for earlier, later in zip(ordered[:-1], ordered[1:]):
        # An earlier entry may not be strictly less specific than a later one.
        assert supercedes(earlier, later) or not supercedes(later, earlier)
Beispiel #2
0
def test_supercedes():
    """``supercedes(a, b)`` holds when ``a`` is at least as specific as ``b``."""
    x, y, z = var("x"), var("y"), var("z")

    # (first, second, whether ``first`` must supercede ``second``)
    cases = [
        (1, 2, False),
        (1, x, True),
        (x, 1, False),
        ((1, 2), (1, x), True),
        ((1, x), (1, 2), False),
        ((1, x), (y, z), True),
        (x, y, True),
        ((1, (x, 3)), (1, y), True),
        ((1, y), (1, (x, 3)), False),
    ]
    for first, second, expected in cases:
        assert bool(supercedes(first, second)) is expected
Beispiel #3
0
def test_VarDispatcher():
    """``VarDispatcher`` hands matched logic-variable bindings to the implementation."""
    d = VarDispatcher("d")
    x, y, z = var("x"), var("y"), var("z")

    # Parameter names must equal the logic-variable tokens: the expected
    # ``(2, 1)`` below only holds if bindings are passed by name, not position.
    def swap(y, x):
        return y, x

    d.register(x, y)(swap)
    swapped = d(1, 2)
    assert swapped == (2, 1)

    def foo(z):
        return z

    d.register((1, z), 2)(foo)
    extracted = d((1, 3), 2)
    assert extracted == 3
Beispiel #4
0
def gen_long_chain(last_elem=None, N=None, use_lvars=False):
    """Generate a deeply nested list whose inner-most element is `last_elem`.

    The result has the shape ``[1, [2, [..., [N - 1, last_elem]]]]``, i.e.
    ``N - 1`` nested two-element lists.

    Parameters
    ----------
    last_elem: object
        The element placed as the second item of the inner-most nested list.
    N: int, optional
        One more than the number of nested lists to build.  Defaults to
        ``sys.getrecursionlimit()``, so the nesting is about as deep as the
        interpreter's recursion limit.  When ``N <= 1`` no list is built and
        ``None`` is returned in place of the structure.
    use_lvars: bool
        Whether the first element of each nested list is a `var` or simply
        an integer.  If ``True``, each `var` is passed its nesting level
        integer (i.e. ``var(i)``).

    Returns
    -------
    list or None, dict
        The generated nested list (or ``None`` when ``N <= 1``) and a
        ``dict`` mapping each generated `var` to its nesting level integer
        (empty unless `use_lvars` is ``True``).

    """
    if N is None:
        N = sys.getrecursionlimit()

    lvars = {}
    # Build from the inside out with a loop, so constructing the structure
    # needs no recursion however deep it is.
    b_struct = last_elem if N > 1 else None
    for i in range(N - 1, 0, -1):
        i_el = var(i) if use_lvars else i
        if use_lvars:
            lvars[i_el] = i
        b_struct = [i_el, b_struct]
    return b_struct, lvars
Beispiel #5
0
def test_dict():
    """A dict signature containing a logic variable matches a concrete dict."""
    d = Dispatcher("d")
    x = var("x")

    d.add(({"x": x, "key": 1}, ), identity)

    # ``identity`` is registered, so the dispatched call must echo its argument.
    assert d({"x": 1, "key": 1}) == {"x": 1, "key": 1}
Beispiel #6
0
def test_isvar():
    """``isvar`` accepts ``Var`` instances, including subclasses, and rejects others."""

    class CustomVar(Var):
        """Trivial ``Var`` subclass used to check that ``isvar`` honours inheritance."""

    assert not isvar(3)
    for candidate in (var(3), CustomVar()):
        assert isvar(candidate)
Beispiel #7
0
def test_complex():
    """Dispatch over literal, logic-variable and nested-tuple signatures."""
    d = Dispatcher("d")
    x, y = var("x"), var("y")

    registrations = [
        ((1,), inc),
        ((x,), inc),
        ((x, 1), add),
        ((y, y), mul),
        ((x, (x, x)), foo),
    ]
    for signature, impl in registrations:
        d.add(signature, impl)

    expectations = [
        ((1,), 2),
        ((2,), 3),
        ((2, 1), 3),
        ((10, 10), 100),
        ((10, (10, 10)), (10, (10, 10))),
    ]
    for call_args, expected in expectations:
        assert d(*call_args) == expected

    # ``(1, 2)`` fits none of the registered signatures.
    with raises(NotImplementedError):
        d(1, 2)
Beispiel #8
0
def test_dispatcher():
    """``match`` assembles a recursive Fibonacci from literal and variable cases."""
    x = var("x")

    # All three definitions must be named ``fib`` so they share one dispatcher.
    @match(1)
    def fib(_):
        return 1

    @match(0)
    def fib(_):
        return 0

    @match(x)
    def fib(k):
        return fib(k - 1) + fib(k - 2)

    expected = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
    assert list(map(fib, range(len(expected)))) == expected
Beispiel #9
0
    def _convert(y):
        """Recursively convert a pattern specification into unification terms.

        Conversion rules, in the order they are checked:

        * ``str`` -- the logic variable named by the string.  The first
          variable recorded for a name in ``var_map`` is reused for every
          later occurrence of that name.
        * ``dict`` -- must hold a ``str`` under ``"pattern"`` and a value under
          ``"constraint"``; yields a ``ConstrainedVar`` for that name, again
          recorded in / reused from ``var_map``.
        * ``tuple`` -- an ``etuple`` of the recursively converted elements.
        * ``Number`` or ``np.ndarray`` -- converted with
          ``aesara.tensor.as_tensor_variable``.
        * anything else -- returned unchanged.

        Raises
        ------
        TypeError
            If a ``dict`` spec's ``"pattern"`` is not a ``str``.
        KeyError
            If a ``dict`` spec lacks ``"pattern"`` or ``"constraint"``.
        """
        if isinstance(y, str):
            # Create the variable only when the name is new; ``dict.get`` with a
            # default would call ``var`` on every lookup, even for mapped names.
            if y not in var_map:
                var_map[y] = var(y)
            return var_map[y]

        if isinstance(y, dict):
            pattern = y["pattern"]
            if not isinstance(pattern, str):
                raise TypeError(
                    "Constraints can only be assigned to logic variables (i.e. strings)"
                )
            # Read before the membership test so a missing "constraint" key
            # still raises even when ``pattern`` is already mapped.
            constraint = y["constraint"]
            # NOTE(review): if ``pattern`` is already mapped (e.g. to a plain
            # ``var`` from an earlier string occurrence), the existing entry
            # wins and ``constraint`` is ignored -- confirm this is intended.
            if pattern not in var_map:
                var_map[pattern] = ConstrainedVar(constraint, pattern)
            return var_map[pattern]

        if isinstance(y, tuple):
            return etuple(*(_convert(e) for e in y))

        if isinstance(y, (Number, np.ndarray)):
            # Function-local import, presumably to avoid an import cycle with
            # ``aesara.tensor``; kept as in the original.
            from aesara.tensor import as_tensor_variable

            return as_tensor_variable(y)

        return y
Beispiel #10
0
def test_supercedes_more():
    """A partly literal pattern supercedes patterns built from a repeated variable."""
    x, y = var("x"), var("y")
    for repeated in ((y, y), (x, x)):
        assert supercedes((1, x), repeated)
Beispiel #11
0
def test_var():
    """Equal tokens give equal variables; token-less calls give fresh ones."""
    lhs, rhs = var(1), var(1)
    assert lhs == rhs
    fresh_a, fresh_b = var(), var()
    assert fresh_a != fresh_b
Beispiel #12
0
def test_var_inputs():
    """``var`` maps one token to one variable but mints a new one with no token."""
    token = 1
    assert var(token) == var(token)
    first_fresh = var()
    assert first_fresh != var()
Beispiel #13
0
def test_var():
    """Logic variables compare by token; anonymous ones are all distinct."""
    interned = [var(1) for _ in range(2)]
    assert interned[0] == interned[1]
    anonymous = [var() for _ in range(2)]
    assert anonymous[0] != anonymous[1]
Beispiel #14
0
def test_isvar():
    """Plain values are not logic variables; ``var`` results are."""
    literal, logic = 3, var(3)
    assert not isvar(literal)
    assert isvar(logic)
Beispiel #15
0
def test_var_inputs():
    """Same token yields equal variables; no token yields distinct variables."""
    pair = (var(1), var(1))
    assert pair[0] == pair[-1]
    left, right = var(), var()
    assert left != right
Beispiel #16
0
def test_transitive_get():
    """``transitive_get`` follows variable chains and returns non-keys unchanged."""
    x, y = var(), var()
    chain = {x: y, y: 1}
    assert transitive_get(x, chain) == 1
    not_a_key = {1: 2}
    assert transitive_get(not_a_key, chain) == {1: 2}
Beispiel #17
0
def test_isvar():
    """``isvar`` is true for a logic variable and false for a plain integer."""
    for value, expected in ((3, False), (var(3), True)):
        assert bool(isvar(value)) is expected
Beispiel #18
0
def test_var():
    """Check token interning of logic variables and uniqueness of fresh ones."""
    anchor = var(1)
    assert anchor == var(1)
    # Equal tokens must return the very same object, not merely an equal one.
    assert var(1) is anchor

    fresh_a, fresh_b = var(), var()
    assert fresh_a != fresh_b

    prefixed_a, prefixed_b = var(prefix="a"), var(prefix="a")
    assert prefixed_a != prefixed_b