Esempio n. 1
0
  def testCommonSubexpressionElimination(self):
    """remake_expr's CSE pass should collapse a duplicated subexpression.

    ``redundant`` spells out ``x**2`` twice, while ``shared`` binds it to a
    local once. Before CSE the printed code for ``redundant`` is longer; after
    ``tracers.remake_expr`` it should be exactly as long as ``shared``'s.
    """

    def redundant(x):
      return 3 * x**2 + x**2

    def shared(x):
      sq = x**2
      return 3 * sq + sq

    redundant_expr = tracers.make_expr(redundant, 1)
    shared_expr = tracers.make_expr(shared, 1)

    before_cse, reference = [
        tracers.print_expr(e) for e in (redundant_expr, shared_expr)]
    self.assertGreater(len(before_cse), len(reference))

    # remake_expr rebuilds the graph, applying common-subexpression elimination.
    after_cse = tracers.print_expr(tracers.remake_expr(redundant_expr))
    self.assertEqual(len(after_cse), len(reference))
Esempio n. 2
0
    def testPrintExpr(self):
        """print_expr should emit one ``temp_N = op(args)`` line per traced op.

        The argument names ``x`` and ``y`` of the traced function appear
        verbatim in the printed code, so they must not be renamed.
        """
        def fun(x, y):
            return 2 * x**2 + np.tanh(y)

        printed = tracers.print_expr(tracers.make_expr(fun, 4, 5))

        expected_lines = [
            "temp_0 = power(x, 2)",
            "temp_1 = multiply(2, temp_0)",
            "temp_2 = tanh(y)",
            "temp_3 = add(temp_1, temp_2)",
        ]
        # Every line, including the last, is newline-terminated.
        self.assertEqual(printed, "".join(line + "\n" for line in expected_lines))
Esempio n. 3
0
    def testEinsumTransposeRewriter(self):
        """transpose_inside_einsum should eliminate an explicit transpose operand.

        ``fun`` passes ``x.T`` into an einsum. After running the rewrite (via
        ``self._rewriter_test_helper``, defined elsewhere in this test class),
        the printed expression must no longer contain a ``transpose`` op.
        """
        def fun(x, y):
            return np.einsum('ij,j->i', x.T, y)

        x = npr.randn(4, 3)
        y = npr.randn(4)

        # Before rewriting, the root node of the traced graph is the einsum call.
        expr = tracers.make_expr(fun, x, y)
        self.assertEqual(expr.expr_node.fun.__name__, 'einsum')

        expr = self._rewriter_test_helper(fun,
                                          rewrites.transpose_inside_einsum, x,
                                          y)
        # assertNotIn shows the printed code on failure; assertFalse(... in ...)
        # would only report "True is not false".
        self.assertNotIn('transpose', tracers.print_expr(expr))
Esempio n. 4
0
    def testInlineExpr(self):
        """inline_expr should splice an existing graph node in for a subexpression input.

        The ``2 * x`` node from ``f``'s graph is bound to ``g``'s argument ``z``.
        The inlined expression is then evaluated with only ``x`` supplied, and
        the original expression for ``f`` must still evaluate correctly afterwards.
        """
        def f(x, y):
            return 2 * x + y

        def g(z):
            return 3 * z + z**2

        outer_expr = tracers.make_expr(f, 1, 2)
        inner_expr = tracers.make_expr(g, 3)

        # First parent of f's root node; it prints as multiply(2, x) below.
        doubled_x = outer_expr.expr_node.parents[0]
        inlined = tracers.inline_expr(inner_expr, {'z': doubled_x})

        self.assertEqual(
            tracers.print_expr(inlined),
            "temp_0 = multiply(2, x)\n"
            "temp_1 = multiply(3, temp_0)\n"
            "temp_2 = power(temp_0, 2)\n"
            "temp_3 = add(temp_1, temp_2)\n")

        doubled = 2 * 5
        self.assertEqual(3 * doubled + doubled**2,
                         tracers.eval_expr(inlined, {'x': 5}))
        # Inlining must not have disturbed the original expression.
        self.assertEqual(f(6, 7), tracers.eval_expr(outer_expr, {'x': 6, 'y': 7}))