示例#1
0
    def check(self, expected):
        """
        Run `is_same_graph` over a table of cases and assert each result.

        :param expected: A list of tuples (v1, v2, ((g1, o1), ..., (gN, oN)))
        with:
            - `v1` and `v2` two Variables (the graphs to be compared)
            - `gj` a `givens` dictionary to give as input to `is_same_graph`
            - `oj` the expected output of `is_same_graph(v1, v2, givens=gj)`

        Every comparison is made twice: once as `(v1, v2)` and once with the
        arguments swapped as `(v2, v1)`. Both calls must return `oj`.
        """
        for left, right, cases in expected:
            for givens, wanted in cases:
                # Check the forward order first, then the swapped order; each
                # result is asserted right after its call.
                for first, second in ((left, right), (right, left)):
                    outcome = is_same_graph(first, second, givens=givens)
                    assert outcome == wanted
示例#2
0
 def ok(expr1, expr2):
     """Assert that `expr1`, after rewriting, is the same graph as `expr2`.

     Both expressions are parsed with `parse_mul_tree`. The tree for `expr1`
     is passed through `perform_sigm_times_exp` and then `simplify_mul`. Both
     trees are rebuilt into graphs with `compute_mul` and compared with
     `is_same_graph`. On a mismatch, both trees and both graphs are printed
     before the assertion fails, to help debugging.
     """
     trees = [parse_mul_tree(e) for e in (expr1, expr2)]
     # The return value is discarded, so this presumably rewrites trees[0]
     # in place.
     perform_sigm_times_exp(trees[0])
     trees[0] = simplify_mul(trees[0])
     # Build each graph once and reuse it for the comparison and for the
     # debug output.
     got_graph = compute_mul(trees[0])
     want_graph = compute_mul(trees[1])
     good = is_same_graph(got_graph, want_graph)
     if not good:
         print(trees[0])
         print(trees[1])
         print("***")
         theano.printing.debugprint(got_graph)
         print("***")
         theano.printing.debugprint(want_graph)
     assert good
示例#3
0
 def test_is_1pexp(self):
     """Check which `1 + exp(...)` forms `is_1pexp` recognizes.

     `1 + exp(x)` and `exp(x) + 1` must give `(False, x)`. The same two forms
     built on `exp(-x)` must give a falsy first element and a second element
     that is the same graph as `-x`. Any other constant, sign, or coefficient
     must give `None`.
     """
     # Turn off the `warn__identify_1pexp_bug` flag while the test runs;
     # the `finally` block restores the previous value.
     backup = config.warn__identify_1pexp_bug
     config.warn__identify_1pexp_bug = False
     try:
         x = tt.vector("x")
         exp = tt.exp
         assert is_1pexp(1 + exp(x), False) == (False, x)
         assert is_1pexp(exp(x) + 1, False) == (False, x)
         # A plain loop replaces `map` + `lambda`. The old lambda's parameter
         # was also named `x`, which shadowed the vector above.
         for expr in (1 + exp(-x), exp(-x) + 1):
             neg, exp_arg = is_1pexp(expr, only_process_constants=False)
             assert not neg and is_same_graph(exp_arg, -x)
         assert is_1pexp(1 - exp(x), False) is None
         assert is_1pexp(2 + exp(x), False) is None
         assert is_1pexp(exp(x) + 2, False) is None
         assert is_1pexp(exp(x) - 1, False) is None
         assert is_1pexp(-1 + exp(x), False) is None
         assert is_1pexp(1 + 2 * exp(x), False) is None
     finally:
         config.warn__identify_1pexp_bug = backup
示例#4
0
 def test_compute_mul(self):
     """`compute_mul` and `parse_mul_tree` round-trip on `(x * y) * -z`."""
     x, y, z = tt.vectors("x", "y", "z")
     expr = (x * y) * -z
     parsed = parse_mul_tree(expr)
     # Parsing the rebuilt graph must give back an equal tree.
     rebuilt = compute_mul(parsed)
     assert parse_mul_tree(rebuilt) == parsed
     # Parsing `expr` again and rebuilding it must give the same graph.
     assert is_same_graph(compute_mul(parse_mul_tree(expr)), expr)