def _eager_rewriter_test_helper(self, fun, rewriter, *args, **kwargs):
  """Checks that substituting `rewriter` via remake_expr preserves `fun`'s value.

  Args:
    fun: the function under test; traced with `tracers.make_expr` unless an
      expression is supplied via the `expr` keyword.
    rewriter: callable substituted for `rewrite_node.fun` when the expression
      is rebuilt with `tracers.remake_expr`.
    *args: positional arguments used both to trace `fun` and, zipped with
      `fun`'s positional parameter names, as the evaluation environment.
    **kwargs: optional `expr` (a pre-built expression used instead of tracing
      `fun`) and optional `rewrite_node` (the node whose function is
      replaced; defaults to `expr.expr_node`).

  Returns:
    The rewritten expression after one more `tracers.remake_expr` pass
    (constant folding).
  """
  expr = kwargs.get('expr') or tracers.make_expr(fun, *args)
  self.assertIsInstance(expr, tracers.GraphExpr)

  # inspect.getargspec was removed in Python 3.11; use getfullargspec when it
  # exists and only fall back to getargspec on interpreters that lack it.
  argspec_fn = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
  env = dict(zip(argspec_fn(fun).args, args))
  self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))

  rewrite_node = kwargs.get('rewrite_node', expr.expr_node)
  expr = tracers.remake_expr(expr, {rewrite_node.fun: rewriter})
  self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))
  return tracers.remake_expr(expr)  # constant folding
def testEinsumOneArg(self):
  """The one-argument einsum 'a->a' is rewritten away by maybe_einsum.

  Before rewriting, the traced output node is not `tracers.env_lookup`.
  After `tracers.remake_expr` with `rewrites.maybe_einsum` substituted for
  `np.einsum`, the expression still evaluates to `fun(x)`, and its output
  node's function is `tracers.env_lookup`.
  """
  x = npr.randn(10)

  def fun(x):
    return np.einsum('a->a', x)

  traced = tracers.make_expr(fun, x)
  self.assertNotEqual(traced.expr_node.fun, tracers.env_lookup)

  substitutions = {np.einsum: rewrites.maybe_einsum}
  rewritten = tracers.remake_expr(traced, substitutions)
  evaluated = tracers.eval_expr(rewritten, {'x': x})
  self.assertAllClose(evaluated, fun(x))
  self.assertEqual(rewritten.expr_node.fun, tracers.env_lookup)
def _rewriter_test_helper(self, fun, rewrite_rule, *args, **kwargs):
  """Checks that applying `rewrite_rule` to a node preserves `fun`'s value.

  Args:
    fun: the function under test; traced with `tracers.make_expr` unless an
      expression is supplied via the `expr` keyword.
    rewrite_rule: rule passed to `rewrites.make_rewriter`; the resulting
      rewriter is called on `rewrite_node`.
    *args: positional arguments used both to trace `fun` and, zipped with
      `fun`'s positional parameter names, as the evaluation environment.
    **kwargs: optional `expr` (a pre-built expression used instead of tracing
      `fun`) and optional `rewrite_node` (the node to rewrite; defaults to
      `expr.expr_node`).

  Returns:
    `tracers.remake_expr(expr)` of the rewritten expression (constant
    folding).
  """
  expr = kwargs.get('expr') or tracers.make_expr(fun, *args)
  self.assertIsInstance(expr, tracers.GraphExpr)

  # inspect.getargspec was removed in Python 3.11; use getfullargspec when it
  # exists and only fall back to getargspec on interpreters that lack it.
  argspec_fn = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
  env = dict(zip(argspec_fn(fun).args, args))
  self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))

  rewriter = rewrites.make_rewriter(rewrite_rule)
  rewrite_node = kwargs.get('rewrite_node', expr.expr_node)
  rewriter(rewrite_node)  # modifies expr in-place
  self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))
  return tracers.remake_expr(expr)  # constant folding
def testCommonSubexpressionElimination(self):
  """Rebuilding an expression with remake_expr removes a repeated x**2.

  `f1` computes x**2 twice, while `f2` computes it once and reuses it. The
  printed code for `f1` starts out longer than that for `f2`. After
  `tracers.remake_expr` (which applies CSE), the printed code for `f1` has
  the same length as the printed code for `f2`.
  """
  def f1(x):
    return 3 * x**2 + x**2

  def f2(x):
    y = x**2
    return 3 * y + y

  repeated_expr = tracers.make_expr(f1, 1)
  shared_expr = tracers.make_expr(f2, 1)
  repeated_code = tracers.print_expr(repeated_expr)
  shared_code = tracers.print_expr(shared_expr)
  self.assertGreater(len(repeated_code), len(shared_code))

  deduped_expr = tracers.remake_expr(repeated_expr)  # applies cse
  deduped_code = tracers.print_expr(deduped_expr)
  self.assertEqual(len(deduped_code), len(shared_code))