def test_ordering():
    """Check that `ordering` never lists a signature before one that strictly supercedes it.

    For every consecutive pair ``(earlier, later)`` in the ordering, either
    ``earlier`` supercedes ``later`` or ``later`` does not supercede ``earlier``.
    """
    lv_x = var("x")
    lv_y = var("y")
    signatures = [
        (1,),
        (lv_x,),
        (2,),
        (lv_y,),
        (lv_x, lv_x),
        (1, lv_x),
        (lv_x, 1),
        (1, 2),
    ]
    ordered = ordering(signatures)
    for earlier, later in zip(ordered, ordered[1:]):
        assert supercedes(earlier, later) or not supercedes(later, earlier)
def test_supercedes():
    """Exercise `supercedes` on scalars, logic variables and nested tuples.

    Each case is ``(left, right, expected)``, where ``expected`` is the
    truthiness of ``supercedes(left, right)``.
    """
    x, y, z = var("x"), var("y"), var("z")
    cases = [
        (1, 2, False),
        (1, x, True),
        (x, 1, False),
        ((1, 2), (1, x), True),
        ((1, x), (1, 2), False),
        ((1, x), (y, z), True),
        (x, y, True),
        ((1, (x, 3)), (1, y), True),
        ((1, y), (1, (x, 3)), False),
    ]
    for left, right, expected in cases:
        # Compare truthiness only, exactly like a bare `assert` / `assert not`.
        assert bool(supercedes(left, right)) == expected
def test_VarDispatcher():
    """Check that `VarDispatcher` hands logic-variable bindings to the handler.

    The expectations below show that the bindings are passed as keyword
    arguments named after the variable tokens. With positional passing,
    ``swap(1, 2)`` would return ``(1, 2)`` rather than ``(2, 1)``.
    """
    dispatcher = VarDispatcher("d")
    lv_x, lv_y, lv_z = var("x"), var("y"), var("z")

    # Handler parameter names must match the variable tokens "x", "y" and "z".
    @dispatcher.register(lv_x, lv_y)
    def swap_args(y, x):
        return y, x

    assert dispatcher(1, 2) == (2, 1)

    @dispatcher.register((1, lv_z), 2)
    def inner_value(z):
        return z

    assert dispatcher((1, 3), 2) == 3
def gen_long_chain(last_elem=None, N=None, use_lvars=False):
    """Build a right-nested chain of two-element lists that ends in `last_elem`.

    The result has the shape ``[1, [2, [..., [N - 1, last_elem]]]]``. That is
    ``N - 1`` nested lists, whose heads are the nesting levels ``1 .. N - 1``.

    Parameters
    ----------
    last_elem: object
        The element placed as the second item of the inner-most list.
    N: int
        One more than the number of nested lists to build. Defaults to
        ``sys.getrecursionlimit()``. If ``N < 2`` no list is built and the
        returned structure is ``None``.
    use_lvars: bool
        If ``True``, each list head is ``var(level)``. Otherwise it is the
        plain integer ``level``.

    Returns
    -------
    list or None, dict
        The generated chain, and a ``dict`` mapping each generated `var` to
        its nesting level. The ``dict`` is empty unless `use_lvars` is ``True``.
    """
    depth = sys.getrecursionlimit() if N is None else N
    level_of = {}
    chain = None
    # Build from the inside out, so the innermost list (level depth - 1) comes first.
    for level in reversed(range(1, depth)):
        head = var(level) if use_lvars else level
        if use_lvars:
            level_of[head] = level
        # On the first (innermost) iteration there is nothing to wrap yet.
        tail = last_elem if chain is None else chain
        chain = [head, tail]
    return chain, level_of
def test_dict():
    """Check that a `Dispatcher` can match on a ``dict`` pattern containing a logic variable.

    The handler is ``identity``, so the call should return its argument unchanged.
    """
    d = Dispatcher("d")
    x = var("x")

    d.add(({"x": x, "key": 1},), identity)

    # Previously this comparison had no `assert`, so its result was discarded
    # and the test could never fail.
    assert d({"x": 1, "key": 1}) == {"x": 1, "key": 1}
def test_isvar():
    """Check that `isvar` rejects plain values and accepts `Var` instances, including subclasses.

    NOTE(review): this module later defines another function named
    ``test_isvar``. That later ``def`` rebinds the name, so pytest collects only
    the later definition and this one, including the subclass check, never
    runs. Consider renaming one of them.
    """
    assert not isvar(3)
    assert isvar(var(3))

    # A user-defined subclass of `Var` must still be recognised.
    class CustomVar(Var):
        pass

    custom = CustomVar()
    assert isvar(custom)
def test_complex():
    """Check `Dispatcher` against several overlapping signatures built from logic variables.

    A call that matches none of the registered signatures must raise
    ``NotImplementedError``. ``raises`` is presumably ``pytest.raises``.
    """
    d = Dispatcher("d")
    x = var("x")
    y = var("y")

    registrations = [
        ((1,), inc),
        ((x,), inc),
        ((x, 1), add),
        ((y, y), mul),
        ((x, (x, x)), foo),
    ]
    for signature, handler in registrations:
        d.add(signature, handler)

    expectations = [
        ((1,), 2),
        ((2,), 3),
        ((2, 1), 3),
        ((10, 10), 100),
        ((10, (10, 10)), (10, (10, 10))),
    ]
    for call_args, expected in expectations:
        assert d(*call_args) == expected

    with raises(NotImplementedError):
        d(1, 2)
def test_dispatcher():
    """Define Fibonacci by pattern matching with `match` and check the first ten terms.

    All three handlers must be named ``fib``, because the recursive body calls
    ``fib``, which the decorators rebind to the combined dispatcher.
    """
    n_var = var("x")

    @match(1)
    def fib(_):
        return 1

    @match(0)
    def fib(_):
        return 0

    @match(n_var)
    def fib(n):
        return fib(n - 1) + fib(n - 2)

    expected = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
    assert list(map(fib, range(10))) == expected
def _convert(y):
    """Recursively convert a pattern specification into logic-variable form.

    Conversion rules:

    - ``str``: a logic variable ``var(y)``, memoized in ``var_map``.
    - ``dict`` with keys ``"pattern"`` (a ``str``) and ``"constraint"``: a
      ``ConstrainedVar``, memoized in ``var_map`` under the pattern string.
    - ``tuple``: an ``etuple`` of the recursively converted elements.
    - ``Number`` / ``np.ndarray``: converted with
      ``aesara.tensor.as_tensor_variable``.
    - Anything else is returned unchanged.

    ``var_map`` is a free name here. It is presumably a ``dict`` supplied by an
    enclosing scope; confirm against the caller. Because lookups hit
    ``var_map`` first, a name already bound by an earlier plain-string spec is
    returned as-is, and a later ``dict`` spec's constraint is not applied to it.

    Raises
    ------
    TypeError
        If a ``dict`` spec's ``"pattern"`` is not a ``str``.
    KeyError
        If a ``dict`` spec lacks ``"pattern"`` or ``"constraint"``.
    """
    if isinstance(y, str):
        # Build the variable only on a cache miss. The previous
        # `var_map.get(y, var(y))` evaluated `var(y)` on every call.
        if y not in var_map:
            var_map[y] = var(y)
        return var_map[y]
    elif isinstance(y, dict):
        pattern = y["pattern"]
        if not isinstance(pattern, str):
            raise TypeError(
                "Constraints can only be assigned to logic variables (i.e. strings)"
            )
        # Read the constraint unconditionally, so a missing key still raises KeyError.
        constraint = y["constraint"]
        if pattern not in var_map:
            # Only construct a ConstrainedVar when it will actually be stored.
            var_map[pattern] = ConstrainedVar(constraint, pattern)
        return var_map[pattern]
    elif isinstance(y, tuple):
        return etuple(*(_convert(e) for e in y))
    elif isinstance(y, (Number, np.ndarray)):
        # Local import, kept as in the original (likely to avoid an import cycle).
        from aesara.tensor import as_tensor_variable

        return as_tensor_variable(y)
    return y
def test_supercedes_more():
    """Check that ``(1, x)`` supercedes both ``(y, y)`` and ``(x, x)``."""
    x, y = var("x"), var("y")
    specific = (1, x)
    for general in ((y, y), (x, x)):
        assert supercedes(specific, general)
def test_var():
    """Check that equal tokens give equal variables and tokenless `var()` calls give distinct ones.

    NOTE(review): a later function named ``test_var`` in this module rebinds
    this name, so pytest collects only that later definition. It repeats these
    two checks.
    """
    first_one, second_one = var(1), var(1)
    assert first_one == second_one
    assert var() != var()
def test_var_inputs():
    """Check `var` with an explicit token (equal tokens give equal vars) and with none (always distinct)."""
    tokened = (var(1), var(1))
    assert tokened[0] == tokened[1]

    anonymous = (var(), var())
    assert anonymous[0] != anonymous[1]
def test_isvar():
    """Check that `isvar` is false for a plain integer and true for a logic variable.

    NOTE(review): this definition rebinds an earlier ``test_isvar`` in this
    module that also checked a `Var` subclass. That earlier check is therefore
    never collected by pytest.
    """
    plain_value = 3
    assert not isvar(plain_value)
    assert isvar(var(3))
def test_transitive_get():
    """Check that `transitive_get` follows chained bindings and returns a non-key input as-is.

    The second check passes an unhashable ``dict`` as the key.
    """
    x, y = var(), var()

    def fresh_bindings():
        # A new mapping per call, as in the original inline literals.
        return {x: y, y: 1}

    assert transitive_get(x, fresh_bindings()) == 1
    assert transitive_get({1: 2}, fresh_bindings()) == {1: 2}
def test_var():
    """Check equality, caching and freshness of logic variables created by `var`.

    Equal tokens give equal variables, and ``var(1)`` returns the very same
    object on repeated calls. Calls without a token, with or without
    ``prefix``, always give distinct variables.
    """
    assert var(1) == var(1)

    cached = var(1)
    assert var(1) is cached

    for extra_kwargs in ({}, {"prefix": "a"}):
        assert var(**extra_kwargs) != var(**extra_kwargs)