コード例 #1
0
    def testFullSumRewriter(self):
        """Check that `replace_sum` turns a full `np.sum` into `np.einsum`.

        Both the implicit form `np.sum(x)` and the explicit form
        `np.sum(x, None)` are exercised.
        """
        data = npr.randn(4, 3)

        def implicit_axis(x):
            return np.sum(x)

        def explicit_none_axis(x):
            return np.sum(x, None)

        for summing_fun in (implicit_axis, explicit_none_axis):
            # Before the rewrite the traced root is the plain sum call.
            traced = tracers.make_expr(summing_fun, data)
            self.assertEqual(traced.expr_node.fun, np.sum)

            # After the rewrite the root has become an einsum call.
            rewritten = self._rewriter_test_helper(
                summing_fun, rewrites.replace_sum, data)
            self.assertEqual(rewritten.expr_node.fun, np.einsum)
コード例 #2
0
  def testAllDescendantsOf(self):
    """Check `all_descendants_of` for the `y` input of `2 * x ** 2 + y`."""

    def f(x, y):
      return 2 * x ** 2 + y

    graph = tracers.make_expr(f, 1, 2)
    xnode = graph.free_vars['x']  # looked up but not otherwise used here
    ynode = graph.free_vars['y']
    root = graph.expr_node

    found = tracers.all_descendants_of(root, ynode)
    # Only `y` itself and the root node are expected in the result.
    self.assertEqual(found, {ynode, root})
コード例 #3
0
    def testEinsumOneArg(self):
        """Check `maybe_einsum` on the single-operand einsum `'a->a'`.

        After remaking with the rewrite, the root node's function is
        `tracers.env_lookup` and evaluation still matches the traced
        function.
        """
        x = npr.randn(10)

        def fun(x):
            return np.einsum('a->a', x)

        traced = tracers.make_expr(fun, x)
        self.assertNotEqual(traced.expr_node.fun, tracers.env_lookup)

        rewrite_table = {np.einsum: rewrites.maybe_einsum}
        remade = tracers.remake_expr(traced, rewrite_table)
        rewritten_value = tracers.eval_expr(remade, {'x': x})
        self.assertAllClose(rewritten_value, fun(x))
        self.assertEqual(remade.expr_node.fun, tracers.env_lookup)
コード例 #4
0
    def testCompoundPatternNameConstraints(self):
        """Check that a name used twice in a pattern must bind one node."""
        def distinct_operands(x, y):
            return 3 * x + y**2

        ones = np.ones(2)
        twos = 2 * np.ones(2)
        mismatch_root = tracers.make_expr(distinct_operands, ones,
                                          twos).expr_node

        # The same name 'x' appears under both Multiply and Power.
        pattern = (Add, (Multiply, 3, Val('x')), (Power, Val('x'), 2))
        match = matchers.matcher(pattern)
        # Here the two positions hold different inputs, so no match.
        self.assertFalse(match(mismatch_root))

        def repeated_operand(x, y):
            return 3 * x + x**2  # note x used twice

        ones = np.ones(2)
        twos = 2 * np.ones(2)
        match_root = tracers.make_expr(repeated_operand, ones, twos).expr_node

        bindings = match(match_root)
        self.assertEqual(bindings, {'x': match_root.args[0].args[1]})
コード例 #5
0
  def testExtractSuperexpr(self):
    """Check `extract_superexpr` with the `x ** 2` node exposed as `x2`."""

    def f(x, y):
      return 2 * x ** 2 + y

    graph = tracers.make_expr(f, 1, 2)
    squared = graph.expr_node.parents[0].parents[0]  # x ** 2

    superexpr = tracers.extract_superexpr(graph, {'x2': squared})

    # With x2 = 5 and y = 6 the expected value is 2 * x2 + y.
    expected = 2 * 5 + 6
    actual = tracers.eval_expr(superexpr, {'x2': 5, 'y': 6})
    self.assertEqual(expected, actual)
コード例 #6
0
    def _eager_rewriter_test_helper(self, fun, rewriter, *args, **kwargs):
        """Apply `rewriter` via `remake_expr` and check the value is unchanged.

        Args:
          fun: function to trace; its positional parameter names become the
            keys of the evaluation environment.
          rewriter: rewrite registered for the function of the node being
            rewritten.
          *args: example inputs for `fun`, in parameter order.
          **kwargs: optional `expr` (expression to use instead of tracing
            `fun`) and `rewrite_node` (node whose function is rewritten;
            defaults to the root node of the expression).

        Returns:
          The rewritten expression after one more `remake_expr` pass.
        """
        expr = kwargs.get('expr') or tracers.make_expr(fun, *args)
        self.assertIsInstance(expr, tracers.GraphExpr)

        # `inspect.getargspec` was removed in Python 3.11; prefer
        # `getfullargspec` and fall back only where it does not exist.
        get_spec = (getattr(inspect, 'getfullargspec', None)
                    or inspect.getargspec)
        env = dict(zip(get_spec(fun).args, args))
        self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))

        rewrite_node = kwargs.get('rewrite_node', expr.expr_node)
        expr = tracers.remake_expr(expr, {rewrite_node.fun: rewriter})

        # The rewrite must not change the value of the expression.
        self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))
        return tracers.remake_expr(expr)  # constant folding
コード例 #7
0
    def testInlineExpr(self):
        """Check inlining `g` onto the `2 * x` node of `f`'s expression."""
        def f(x, y):
            return 2 * x + y

        def g(z):
            return 3 * z + z**2

        outer = tracers.make_expr(f, 1, 2)
        inner = tracers.make_expr(g, 3)

        # The first parent of the root is substituted for `z`.
        anchor = outer.expr_node.parents[0]
        inlined = tracers.inline_expr(inner, {'z': anchor})

        actual_text = tracers.print_expr(inlined)
        expected_text = ("temp_0 = multiply(2, x)\n"
                         "temp_1 = multiply(3, temp_0)\n"
                         "temp_2 = power(temp_0, 2)\n"
                         "temp_3 = add(temp_1, temp_2)\n")
        self.assertEqual(actual_text, expected_text)

        inlined_value = tracers.eval_expr(inlined, {'x': 5})
        self.assertEqual(3 * (2 * 5) + (2 * 5)**2, inlined_value)

        # The original expression must still evaluate like `f`.
        self.assertEqual(f(6, 7), tracers.eval_expr(outer, {'x': 6, 'y': 7}))
コード例 #8
0
    def testPrintExpr(self):
        """Check the text emitted by `print_expr` for `2 * x**2 + tanh(y)`."""
        def fun(x, y):
            return 2 * x**2 + np.tanh(y)

        traced = tracers.make_expr(fun, 4, 5)
        actual_text = tracers.print_expr(traced)

        # One assignment per operation, each terminated by a newline.
        expected_text = ("temp_0 = power(x, 2)\n"
                         "temp_1 = multiply(2, temp_0)\n"
                         "temp_2 = tanh(y)\n"
                         "temp_3 = add(temp_1, temp_2)\n")
        self.assertEqual(actual_text, expected_text)
コード例 #9
0
    def testInlineExprAndReplace(self):
        """Check splicing an inlined `g` into `f`'s expression in place.

        `g` is inlined on the `x ** 2` node and the result replaces the
        `2 * x ** 2` node of the original expression.
        """
        def f(x, y):
            return 2 * x**2 + y

        def g(z):
            return 3 * z**3

        host = tracers.make_expr(f, 1, 2)
        donor = tracers.make_expr(g, 3)

        squared = host.expr_node.parents[0].parents[0]  # x ** 2
        doubled = host.expr_node.parents[0]  # 2 * x ** 2

        replacement = tracers.inline_expr(donor, {'z': squared})
        # Modifies `host` in place.
        tracers.replace_node_with_expr(doubled, replacement)

        result = tracers.eval_expr(host, {'x': 6, 'y': 7})
        self.assertEqual(3 * 6**6 + 7, result)
コード例 #10
0
    def _rewriter_test_helper(self, fun, rewrite_rule, *args, **kwargs):
        """Apply `rewrite_rule` in place and check the value is unchanged.

        Args:
          fun: function to trace; its positional parameter names become the
            keys of the evaluation environment.
          rewrite_rule: rule handed to `rewrites.make_rewriter`.
          *args: example inputs for `fun`, in parameter order.
          **kwargs: optional `expr` (expression to use instead of tracing
            `fun`) and `rewrite_node` (node the rewriter is applied to;
            defaults to the root node of the expression).

        Returns:
          The result of `tracers.remake_expr` on the rewritten expression.
        """
        expr = kwargs.get('expr') or tracers.make_expr(fun, *args)
        self.assertIsInstance(expr, tracers.GraphExpr)

        # `inspect.getargspec` was removed in Python 3.11; prefer
        # `getfullargspec` and fall back only where it does not exist.
        get_spec = (getattr(inspect, 'getfullargspec', None)
                    or inspect.getargspec)
        env = dict(zip(get_spec(fun).args, args))
        self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))

        rewriter = rewrites.make_rewriter(rewrite_rule)
        rewrite_node = kwargs.get('rewrite_node', expr.expr_node)
        rewriter(rewrite_node)  # modifies expr in-place

        # The rewrite must not change the value of the expression.
        self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))
        return tracers.remake_expr(expr)  # constant folding
コード例 #11
0
    def testLiterals(self):
        """Check literal patterns, both bare and inside a compound pattern."""
        # A bare literal pattern applied to an equal literal.
        literal_match = matchers.matcher(3)
        self.assertTrue(literal_match(3))

        def fun(x):
            return 2 + x

        ones = np.ones(2)
        root = tracers.make_expr(fun, ones).expr_node

        # The literal 2 appears as the first operand of the addition.
        compound_match = matchers.matcher((Add, 2, Val))
        self.assertTrue(compound_match(root))
コード例 #12
0
  def testDescendantOf(self):
    """Check `is_descendant_of` on the expression `2 * x ** 2 + y`."""

    def f(x, y):
      return 2 * x ** 2 + y

    graph = tracers.make_expr(f, 1, 2)
    xnode = graph.free_vars['x']
    ynode = graph.free_vars['y']

    # Pairs expected to be related, including each node with itself.
    related_pairs = (
        (graph.expr_node, xnode),
        (xnode, xnode),
        (graph.expr_node, graph.expr_node),
    )
    for node, ancestor in related_pairs:
      self.assertTrue(tracers.is_descendant_of(node, ancestor))

    # The two inputs are not related to each other.
    self.assertFalse(tracers.is_descendant_of(xnode, ynode))
コード例 #13
0
    def testDotRewriter(self):
        """Check that `dot_as_einsum` rewrites `np.dot` into an einsum."""
        def fun(x, y):
            return np.dot(x, y)

        matrix = npr.randn(4, 3)
        vector = npr.randn(3)

        traced = tracers.make_expr(fun, matrix, vector)
        self.assertIsInstance(traced, tracers.GraphExpr)
        self.assertEqual(traced.expr_node.fun.__name__, 'dot')

        rewritten = self._eager_rewriter_test_helper(
            fun, rewrites.dot_as_einsum, matrix, vector)
        self.assertEqual(rewritten.expr_node.fun.__name__, 'einsum')
コード例 #14
0
    def testEinsumTransposeRewriter(self):
        """Check that `transpose_inside_einsum` removes the transpose op."""
        def fun(x, y):
            return np.einsum('ij,j->i', x.T, y)

        lhs = npr.randn(4, 3)
        rhs = npr.randn(4)

        traced = tracers.make_expr(fun, lhs, rhs)
        self.assertEqual(traced.expr_node.fun.__name__, 'einsum')

        rewritten = self._rewriter_test_helper(
            fun, rewrites.transpose_inside_einsum, lhs, rhs)
        # No 'transpose' may remain in the printed expression.
        printed = tracers.print_expr(rewritten)
        self.assertFalse('transpose' in printed)
コード例 #15
0
    def testEinsumDistributeRewriter(self):
        """Check that `distribute_einsum` moves `add_n` to the root."""
        def fun(x, y, z):
            return np.einsum('ij,j->i', x, tracers.add_n(y, z))

        matrix = npr.randn(4, 3)
        first_term = npr.randn(3)
        second_term = npr.randn(3)
        inputs = (matrix, first_term, second_term)

        traced = tracers.make_expr(fun, *inputs)
        self.assertEqual(traced.expr_node.fun.__name__, 'einsum')

        # After the rewrite the sum is the outermost operation.
        rewritten = self._rewriter_test_helper(
            fun, rewrites.distribute_einsum, *inputs)
        self.assertEqual(rewritten.expr_node.fun.__name__, 'add_n')
コード例 #16
0
    def testCompoundPatternNameBindings(self):
        """Check the bindings returned for two named values in a pattern."""
        def fun(x, y):
            return 3 * x + y**2

        ones = np.ones(2)
        twos = 2 * np.ones(2)
        root = tracers.make_expr(fun, ones, twos).expr_node

        scaled_pattern = (Multiply, 3, Val('x'))
        squared_pattern = (Power, Val('y'), 2)
        match = matchers.matcher((Add, scaled_pattern, squared_pattern))

        bindings = match(root)
        # 'x' is the second operand of the product, 'y' the base of the power.
        expected = {
            'x': root.args[0].args[1],
            'y': root.args[1].args[0],
        }
        self.assertEqual(bindings, expected)
コード例 #17
0
    def testSegmentsEmpty(self):
        """Check a pattern with `Segment`s around a subtract/add operand.

        In the traced einsum the `x - y` operand comes directly after the
        formula, so 'args1' presumably has to match an empty segment
        (NOTE(review): confirm against the `Segment` semantics).
        """
        def fun(x, y, z):
            return np.einsum('i,j,ij->', x - y, x, z)

        ones = np.ones(3)
        twos = 2 * np.ones(3)
        threes = 3 * np.ones((3, 3))
        root = tracers.make_expr(fun, ones, twos, threes).expr_node

        binary_op = Choice(Subtract('op'), Add('op'))
        operand_pattern = (binary_op, Val('x'), Val('y'))
        pattern = (Einsum, Str('formula'), Segment('args1'), operand_pattern,
                   Segment('args2'))

        match = matchers.matcher(pattern)
        self.assertTrue(match(root))
コード例 #18
0
  def testEinsumAddSubSimplify(self):
    """Check that canonicalizing an einsum over a sum yields `add_n`.

    After canonicalization the root is `add_n`, its first parent is an
    einsum, and the expression still evaluates to the same value.
    """
    # TODO(mhoffman): Think about broadcasting. We need to support `x - 2.0`.

    def test_fun(x):
      return np.einsum('i->', x + np.full(x.shape, 2.0))

    expr = make_expr(test_fun, np.ones(3))
    test_x = np.full(3, 0.5)
    correct_value = eval_expr(expr, {'x': test_x})
    expr = canonicalize(expr)
    self.assertIsInstance(expr, GraphExpr)
    self.assertEqual(expr.expr_node.fun, add_n)
    self.assertEqual(expr.expr_node.parents[0].fun.__name__, 'einsum')
    new_value = eval_expr(expr, {'x': test_x})
    # Canonicalization reorders floating-point operations, so compare with a
    # tolerance rather than bit-for-bit equality.
    self.assertAlmostEqual(correct_value, new_value)
コード例 #19
0
    def testEinsumCompositionRewriter(self):
        """Check that `combine_einsum_compositions` merges nested einsums."""
        def fun(x, y, z):
            return np.einsum('ij,jk->i', x, np.einsum('ija,ijk->ka', y, z))

        outer_operand = npr.randn(4, 3)
        inner_lhs = npr.randn(5, 4, 2)
        inner_rhs = npr.randn(5, 4, 3)
        inputs = (outer_operand, inner_lhs, inner_rhs)

        # Before the rewrite both the root and its second parent are einsums.
        traced = tracers.make_expr(fun, *inputs)
        root = traced.expr_node
        self.assertEqual(root.fun.__name__, 'einsum')
        self.assertEqual(root.parents[1].fun.__name__, 'einsum')

        rewritten = self._rewriter_test_helper(
            fun, rewrites.combine_einsum_compositions, *inputs)
        # The second parent of the root is no longer an einsum.
        second_parent = rewritten.expr_node.parents[1]
        self.assertNotEqual(second_parent.fun.__name__, 'einsum')
コード例 #20
0
    def testOneElementPatternNameBinding(self):
        """Check bindings produced by single-element named patterns."""
        def fun(x, y):
            return 3 * x + y**2

        ones = np.ones(2)
        twos = 2 * np.ones(2)
        root = tracers.make_expr(fun, ones, twos).expr_node

        # A named value binds the node itself.
        value_match = matchers.matcher(Val('z'))
        self.assertEqual(value_match(root), {'z': root})

        # A named Add binds the node's function.
        add_match = matchers.matcher(Add('z'))
        self.assertEqual(add_match(root), {'z': root.fun})

        # The root is an addition, so a Multiply pattern does not match.
        multiply_match = matchers.matcher(Multiply('z'))
        self.assertFalse(multiply_match(root))
コード例 #21
0
    def testOneElementPattern(self):
        """Check single-element unnamed patterns against an addition root."""
        def fun(x, y):
            return 3 * x + y**2

        ones = np.ones(2)
        twos = 2 * np.ones(2)
        root = tracers.make_expr(fun, ones, twos).expr_node

        # Both the generic value pattern and Add accept the root.
        for accepted_pattern in (Val, Add):
            match = matchers.matcher(accepted_pattern)
            self.assertTrue(match(root))

        # Multiply must be rejected.
        match = matchers.matcher(Multiply)
        self.assertFalse(match(root))
コード例 #22
0
  def testCanonicalize(self):
    """Check that `canonicalize` preserves the value of an einsum form."""

    def mahalanobis_distance(x, y, matrix):
      x_minus_y = x - y
      return np.einsum('i,j,ij->', x_minus_y, x_minus_y, matrix)

    x = np.array([1.3, 3.6])
    y = np.array([2.3, -1.2])
    matrix = np.arange(4).reshape([2, 2])
    env = {'x': x, 'y': y, 'matrix': matrix}

    # The freshly traced expression is not yet canonical.
    traced = make_expr(mahalanobis_distance, x, y, matrix)
    self.assertFalse(is_canonical(traced))
    value_before = eval_expr(traced, dict(env))

    canonical = canonicalize(traced)
    self.assertTrue(is_canonical(canonical))
    value_after = eval_expr(canonical, dict(env))

    self.assertAlmostEqual(value_before, value_after)
コード例 #23
0
  def _test_condition_and_marginalize_diagonal_zero_mean_normal(self,
                                                                log_joint):
    """Check `_condition_and_marginalize` on argument 0 of `log_joint`.

    Args:
      log_joint: function of `(x, tau)` that is traced and conditioned.
    """
    dim = 5
    x = np.random.randn(dim)
    tau = np.random.randn(dim) ** 2

    # Tracing and canonicalizing are exercised; the result is not used below.
    traced = make_expr(log_joint, x, tau)
    traced = canonicalize(traced)

    conditional, marginalized_value = _condition_and_marginalize(
        log_joint, 0, SupportTypes.REAL, x, tau)

    expected_marginal = (-0.5 * np.log(tau).sum()
                         + 0.5 * dim * np.log(2. * np.pi))
    self.assertAlmostEqual(expected_marginal, marginalized_value)

    # The conditional's arguments: zeros, then 1 / sqrt(tau).
    first_arg_ok = np.allclose(np.zeros(dim), conditional.args[0])
    self.assertTrue(first_arg_ok)
    second_arg_ok = np.allclose(1. / np.sqrt(tau), conditional.args[1])
    self.assertTrue(second_arg_ok)
コード例 #24
0
    def testLogEinsumRewriter(self):
        """Check that `replace_log_einsum` turns log(einsum) into log sums."""
        def fun(x, y):
            return np.log(np.einsum('ij,ij->ij', x, y))

        # Exponentiated draws keep the einsum operands positive for the log.
        lhs = np.exp(npr.randn(4, 3))
        rhs = np.exp(npr.randn(4, 3))
        # NOTE(review): this draw is never used; kept so the random stream
        # consumed by the test is unchanged.
        z = np.exp(npr.randn(4))

        traced = tracers.make_expr(fun, lhs, rhs)
        self.assertEqual(traced.expr_node.fun.__name__, 'log')
        self.assertEqual(traced.expr_node.parents[0].fun.__name__, 'einsum')

        rewritten = self._rewriter_test_helper(
            fun, rewrites.replace_log_einsum, lhs, rhs)
        root = rewritten.expr_node
        # The root is now an add whose first two parents are logs.
        self.assertEqual(root.fun.__name__, 'add')
        for parent_index in (0, 1):
            self.assertEqual(root.parents[parent_index].fun.__name__, 'log')
コード例 #25
0
  def testLinearRegression(self):
    """Check the conditional-factor pipeline on a linear-regression model.

    Verifies the sufficient-statistic nodes found for the second free
    variable, the natural-parameter functions extracted for `beta`, and the
    mean and covariance of the complete conditional.
    """
    def log_joint(X, beta, y):
      predictions = np.einsum('ij,j->i', X, beta)
      errors = y - predictions
      log_prior = np.einsum('i,i,i->', -0.5 * np.ones_like(beta), beta, beta)
      log_likelihood = np.einsum(',k,k->', -0.5, errors, errors)
      return log_prior + log_likelihood
    n_examples = 10
    n_predictors = 2
    X = np.random.randn(n_examples, n_predictors)
    beta = np.random.randn(n_predictors)
    y = np.random.randn(n_examples)
    graph = make_expr(log_joint, X, beta, y)
    graph = canonicalize(graph)

    # Materialize the keys: dict views are not subscriptable on Python 3.
    # Assumes free_vars keeps the argument order, so args[1] is 'beta' --
    # TODO confirm.
    args = list(graph.free_vars.keys())
    sufficient_statistic_nodes = find_sufficient_statistic_nodes(graph, args[1])
    sufficient_statistics = [eval_node(node, graph.free_vars,
                                       {'X': X, 'beta': beta, 'y': y})
                             for node in sufficient_statistic_nodes]
    correct_sufficient_statistics = [
        -0.5 * beta.dot(beta), beta,
        -0.5 * np.einsum('ij,ik,j,k', X, X, beta, beta)
    ]
    self.assertTrue(_match_values(sufficient_statistics,
                                  correct_sufficient_statistics))

    _, natural_parameter_funs = _extract_conditional_factors(graph, 'beta')
    # Pass real lists so the comparison does not depend on dict-view
    # behavior.
    self.assertTrue(_match_values(list(natural_parameter_funs.keys()),
                                  ['x', 'einsum(...a,...b->...ab, x, x)',
                                   'einsum(...,...->..., x, x)'],
                                  lambda x, y: x == y))
    natural_parameter_vals = [f(X, beta, y) for f in
                              natural_parameter_funs.values()]
    correct_parameter_vals = [-0.5 * np.ones(n_predictors), -0.5 * X.T.dot(X),
                              y.dot(X)]
    self.assertTrue(_match_values(natural_parameter_vals,
                                  correct_parameter_vals))

    conditional_factory = complete_conditional(log_joint, 1, SupportTypes.REAL,
                                               X, beta, y)
    conditional = conditional_factory(X, y)
    # Closed-form values the conditional's cov and mean are compared with.
    true_cov = np.linalg.inv(X.T.dot(X) + np.eye(n_predictors))
    true_mean = true_cov.dot(y.dot(X))
    self.assertTrue(np.allclose(true_cov, conditional.cov))
    self.assertTrue(np.allclose(true_mean, conditional.mean))
コード例 #26
0
 def testLinearRegression(self):
   """Check that canonicalizing a squared-loss expression keeps its value."""
   def squared_loss(X, beta, y):
     predictions = np.einsum('ij,j->i', X, beta)
     residuals = y - predictions
     return np.einsum('k,k->', residuals, residuals)

   n_examples, n_predictors = 10, 2
   X = np.random.randn(n_examples, n_predictors)
   beta = np.random.randn(n_predictors)
   y = np.random.randn(n_examples)
   env = {'X': X, 'beta': beta, 'y': y}

   traced = make_expr(squared_loss, X, beta, y)
   value_before = eval_expr(traced, dict(env))
   # The freshly traced expression is not yet canonical.
   self.assertFalse(is_canonical(traced))

   canonical = canonicalize(traced)
   self.assertTrue(is_canonical(canonical))
   value_after = eval_expr(canonical, dict(env))
   self.assertAlmostEqual(value_before, value_after)
コード例 #27
0
 def testEinsumCompose(self):
   """Check that composed einsums canonicalize to a single einsum root."""
   def Xbeta_squared(X, beta):
     Xbeta = np.einsum('ij,j->i', X, beta)
     # NOTE(review): Xbeta2 is computed but never used in the result;
     # possibly it was meant to be the second operand below -- confirm.
     Xbeta2 = np.einsum('lm,m->l', X, beta)
     return np.einsum('k,k->', Xbeta, Xbeta)

   n_examples, n_predictors = 10, 2
   X = np.random.randn(n_examples, n_predictors)
   beta = np.random.randn(n_predictors)
   env = {'X': X, 'beta': beta}

   traced = make_expr(Xbeta_squared, X, beta)
   value_before = eval_expr(traced, dict(env))
   self.assertFalse(is_canonical(traced))

   canonical = canonicalize(traced)
   value_after = eval_expr(canonical, dict(env))
   self.assertAlmostEqual(value_before, value_after)

   # The canonical form is a graph expression rooted at np.einsum.
   self.assertIsInstance(canonical, GraphExpr)
   self.assertEqual(canonical.expr_node.fun, np.einsum)
   self.assertTrue(is_canonical(canonical))
コード例 #28
0
  def testLinearRegression(self):
    """Check the statistic-representation pipeline on linear regression.

    Verifies the sufficient-statistic nodes found for the second free
    variable, the gradient of the re-expressed log joint with respect to
    the statistics of `beta`, and the mean and covariance of the complete
    conditional.
    """
    def log_joint(X, beta, y):
      predictions = np.einsum('ij,j->i', X, beta)
      errors = y - predictions
      log_prior = np.einsum('i,i,i->', -0.5 * np.ones_like(beta), beta, beta)
      log_likelihood = np.einsum(',k,k->', -0.5, errors, errors)
      return log_prior + log_likelihood
    n_examples = 10
    n_predictors = 2
    X = np.random.randn(n_examples, n_predictors)
    beta = np.random.randn(n_predictors)
    y = np.random.randn(n_examples)
    graph = make_expr(log_joint, X, beta, y)
    graph = canonicalize(graph)

    # Materialize the keys: dict views are not subscriptable on Python 3.
    # Assumes free_vars keeps the argument order, so args[1] is 'beta' --
    # TODO confirm.
    args = list(graph.free_vars.keys())
    sufficient_statistic_nodes = find_sufficient_statistic_nodes(graph, args[1])
    sufficient_statistics = [eval_node(node, graph.free_vars,
                                       {'X': X, 'beta': beta, 'y': y})
                             for node in sufficient_statistic_nodes]
    correct_sufficient_statistics = [
        -0.5 * beta.dot(beta), beta,
        -0.5 * np.einsum('ij,ik,j,k', X, X, beta, beta)
    ]
    self.assertTrue(_match_values(sufficient_statistics,
                                  correct_sufficient_statistics))

    new_log_joint, _, stats_funs, _ = (
        statistic_representation(log_joint, (X, beta, y),
                               (SupportTypes.REAL,), (1,)))
    beta_stat_fun = stats_funs[0]
    # Differentiate with respect to argument 1, which here receives the
    # statistics of beta instead of beta itself.
    beta_natparam = grad_namedtuple(new_log_joint, 1)(X, beta_stat_fun(beta), y)
    correct_beta_natparam = (-0.5 * X.T.dot(X), y.dot(X),
                             -0.5 * np.ones(n_predictors))
    self.assertTrue(_match_values(beta_natparam, correct_beta_natparam))

    conditional_factory = complete_conditional(log_joint, 1, SupportTypes.REAL,
                                               X, beta, y)
    conditional = conditional_factory(X, y)
    # Closed-form values the conditional's cov and mean are compared with.
    true_cov = np.linalg.inv(X.T.dot(X) + np.eye(n_predictors))
    true_mean = true_cov.dot(y.dot(X))
    self.assertTrue(np.allclose(true_cov, conditional.cov))
    self.assertTrue(np.allclose(true_mean, conditional.mean))
コード例 #29
0
    def testEvalPerturbed(self):
        """Check `_eval_perturbed` with an offset applied at a single node.

        The expected values correspond to adding 3 to the perturbed node's
        value before the rest of `2 * x**2 + y` is evaluated.
        """
        def f(x, y):
            return 2 * x**2 + y

        graph = tracers.make_expr(f, 1, 2)

        # Perturb the intermediate node x ** 2.
        squared = graph.expr_node.parents[0].parents[0]
        perturbed_at_square = tracers._eval_perturbed(
            graph, {squared: 3}, {'x': 4, 'y': 5})
        self.assertEqual(2 * (4**2 + 3) + 5, perturbed_at_square)

        # Perturb the input node x itself.
        x_input = graph.free_vars['x']
        perturbed_at_input = tracers._eval_perturbed(
            graph, {x_input: 3}, {'x': 4, 'y': 5})
        self.assertEqual(2 * (4 + 3)**2 + 5, perturbed_at_input)
コード例 #30
0
  def testSplitEinsumNode2(self):
    """Check `split_einsum_node2` on several einsum formulas.

    For each split, the statistic node is compared with the expected
    product of the selected operands (or with its expected formula string),
    and the potential node must still evaluate to the value of the original
    function.
    """
    n_dimensions = 5
    x = np.random.randn(n_dimensions)
    y = np.random.randn(n_dimensions)
    matrix = np.random.randn(n_dimensions, n_dimensions)

    env = {'x': x, 'y': y, 'matrix': matrix}

    def assert_node_close(target, graph, expected):
      # Evaluate `target` within `graph` and compare it with `expected`.
      self.assertTrue(np.allclose(eval_node(target, graph, env), expected))

    def check_split(graph, indices, expected_stat, expected_total):
      # Split on `indices`, then check the statistic and the potential.
      potential, stat = split_einsum_node2(graph.expr_node, indices)
      assert_node_close(stat, graph, expected_stat)
      assert_node_close(potential, graph, expected_total)

    # Two operands: 'i,i->'.
    inputs = (x, y)
    f = lambda x, y: np.einsum('i,i->', x, y)
    graph = make_expr(f, *inputs)
    total = f(*inputs)
    check_split(graph, [0], x, total)
    check_split(graph, [1], y, total)
    check_split(graph, [0, 1], x * y, total)

    # One operand repeated: 'i,i,i->' over x, y, y.
    inputs = (x, y)
    f = lambda x, y: np.einsum('i,i,i->', x, y, y)
    graph = make_expr(f, *inputs)
    total = f(*inputs)
    check_split(graph, [1, 2], y * y, total)
    check_split(graph, [0], x, total)

    # A constant operand next to a repeated input.
    inputs = (x,)
    f = lambda x: np.einsum('i,i,i->', np.ones_like(x), x, x)
    graph = make_expr(f, *inputs)
    total = f(*inputs)
    check_split(graph, [1, 2], x * x, total)

    # Bilinear form: 'ij,i,j->'.
    inputs = (matrix, x, y)
    f = lambda matrix, x, y: np.einsum('ij,i,j->', matrix, x, y)
    graph = make_expr(f, *inputs)
    total = f(*inputs)
    check_split(graph, [1, 2], np.outer(x, y), total)
    check_split(graph, [0], matrix, total)

    # Matrix used twice: 'i,j,ki,kj->'.
    inputs = (matrix, x, y)
    f = lambda matrix, x, y: np.einsum('i,j,ki,kj->', x, x, matrix, matrix)
    graph = make_expr(f, *inputs)
    total = f(*inputs)
    check_split(graph, [2, 3], matrix[:, None, :] * matrix[:, :, None], total)
    check_split(graph, [0, 1], np.outer(x, x), total)

    # Splits called with False as the third argument (its meaning is not
    # visible here); only the statistic node's formula string is checked.
    inputs = (matrix, x, y)
    f = lambda matrix, x, y: np.einsum(',kj,j,ka,a->', -0.5, matrix, x,
                                       matrix, y)
    graph = make_expr(f, *inputs)
    total = f(*inputs)
    formula_cases = (
        ([2, 4], 'j,a->ja'),
        ([0, 1, 3], ',kj,ka->kja'),
    )
    for indices, expected_formula in formula_cases:
      potential, stat = split_einsum_node2(graph.expr_node, indices, False)
      self.assertEqual(stat.args[0], expected_formula)
      assert_node_close(potential, graph, total)