コード例 #1
0
    def _eager_rewriter_test_helper(self, fun, rewriter, *args, **kwargs):
        """Check that an eager rewriter preserves the value of a traced function.

        Traces `fun` on `args` (unless a prebuilt expression is passed as
        `expr=`), checks the traced graph evaluates to `fun(*args)`, rebuilds
        the graph with `rewriter` substituted for `rewrite_node.fun`, and
        checks the value again.

        Args:
          fun: the Python function under test.
          rewriter: callable substituted for `rewrite_node.fun` when the
            graph is rebuilt with `tracers.remake_expr`.
          *args: concrete arguments to `fun`; they are bound positionally to
            `fun`'s parameter names to form the evaluation environment.
          **kwargs: optional `expr` (a prebuilt graph to use instead of
            tracing) and `rewrite_node` (defaults to the graph's output node).

        Returns:
          The rewritten expression after one more `tracers.remake_expr` pass
          (constant folding).
        """
        expr = kwargs.get('expr') or tracers.make_expr(fun, *args)
        self.assertIsInstance(expr, tracers.GraphExpr)

        # inspect.getargspec was removed in Python 3.11; prefer getfullargspec
        # and fall back to getargspec only on interpreters that lack it.
        get_argspec = (getattr(inspect, 'getfullargspec', None) or
                       inspect.getargspec)
        env = dict(zip(get_argspec(fun).args, args))
        self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))

        rewrite_node = kwargs.get('rewrite_node', expr.expr_node)
        expr = tracers.remake_expr(expr, {rewrite_node.fun: rewriter})

        self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))
        return tracers.remake_expr(expr)  # constant folding
コード例 #2
0
    def _rewriter_test_helper(self, fun, rewrite_rule, *args, **kwargs):
        """Check that an in-place rewrite rule preserves a traced function's value.

        Traces `fun` on `args` (unless a prebuilt expression is passed as
        `expr=`), checks the traced graph evaluates to `fun(*args)`, applies
        the rewriter built from `rewrite_rule` to `rewrite_node` (which
        mutates the graph in place), and checks the value again.

        Args:
          fun: the Python function under test.
          rewrite_rule: rule passed to `rewrites.make_rewriter`.
          *args: concrete arguments to `fun`; they are bound positionally to
            `fun`'s parameter names to form the evaluation environment.
          **kwargs: optional `expr` (a prebuilt graph to use instead of
            tracing) and `rewrite_node` (defaults to the graph's output node).

        Returns:
          The rewritten expression after a `tracers.remake_expr` pass
          (constant folding).
        """
        expr = kwargs.get('expr') or tracers.make_expr(fun, *args)
        self.assertIsInstance(expr, tracers.GraphExpr)

        # inspect.getargspec was removed in Python 3.11; prefer getfullargspec
        # and fall back to getargspec only on interpreters that lack it.
        get_argspec = (getattr(inspect, 'getfullargspec', None) or
                       inspect.getargspec)
        env = dict(zip(get_argspec(fun).args, args))
        self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))

        rewriter = rewrites.make_rewriter(rewrite_rule)
        rewrite_node = kwargs.get('rewrite_node', expr.expr_node)
        rewriter(rewrite_node)  # modifies expr in-place

        self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))
        return tracers.remake_expr(expr)  # constant folding
コード例 #3
0
  def testEinsumAddSubSimplify(self):
    """Canonicalizing a reduction of a sum yields an add_n of einsums.

    After `canonicalize`, the output node must be `add_n` with an einsum as
    its first parent, and the expression's value must be unchanged.
    """
    # TODO(mhoffman): Think about broadcasting. We need to support `x - 2.0`.

    def test_fun(x):
      return np.einsum('i->', x + np.full(x.shape, 2.0))

    expr = make_expr(test_fun, np.ones(3))
    test_x = np.full(3, 0.5)
    correct_value = eval_expr(expr, {'x': test_x})
    expr = canonicalize(expr)
    self.assertIsInstance(expr, GraphExpr)
    self.assertEqual(expr.expr_node.fun, add_n)
    self.assertEqual(expr.expr_node.parents[0].fun.__name__, 'einsum')
    new_value = eval_expr(expr, {'x': test_x})
    # The canonical form can sum terms in a different order than the traced
    # one, so compare within floating-point tolerance rather than exactly
    # (as the other canonicalization tests do).
    self.assertAlmostEqual(correct_value, new_value)
コード例 #4
0
  def testEvalExpr(self):
    """eval_expr on a traced graph reproduces the function at new inputs."""

    def fun(x, y):
      return 2 * x**2 + np.tanh(3 * y)

    traced = tracers.make_expr(fun, 4, 5)
    # Evaluate at a point different from the tracing inputs (4, 5).
    expected = fun(9, 10)
    actual = tracers.eval_expr(traced, {'x': 9, 'y': 10})
    self.assertEqual(expected, actual)
コード例 #5
0
  def testCanonicalize(self):
    """canonicalize() makes a Mahalanobis-distance graph canonical, same value."""

    def mahalanobis_distance(x, y, matrix):
      x_minus_y = x - y
      return np.einsum('i,j,ij->', x_minus_y, x_minus_y, matrix)

    x = np.array([1.3, 3.6])
    y = np.array([2.3, -1.2])
    matrix = np.arange(4).reshape([2, 2])
    env = {'x': x, 'y': y, 'matrix': matrix}

    expr = make_expr(mahalanobis_distance, x, y, matrix)
    self.assertFalse(is_canonical(expr))
    value_before = eval_expr(expr, env)

    expr = canonicalize(expr)
    self.assertTrue(is_canonical(expr))
    value_after = eval_expr(expr, env)
    self.assertAlmostEqual(value_before, value_after)
コード例 #6
0
 def testLinearRegression(self):
   """Canonicalizing a least-squares loss makes it canonical, same value."""
   def squared_loss(X, beta, y):
     predictions = np.einsum('ij,j->i', X, beta)
     errors = y - predictions
     return np.einsum('k,k->', errors, errors)
   n_examples, n_predictors = 10, 2
   # Draws happen in the order X, beta, y.
   X = np.random.randn(n_examples, n_predictors)
   beta = np.random.randn(n_predictors)
   y = np.random.randn(n_examples)
   env = {'X': X, 'beta': beta, 'y': y}

   expr = make_expr(squared_loss, X, beta, y)
   loss_before = eval_expr(expr, env)
   self.assertFalse(is_canonical(expr))

   expr = canonicalize(expr)
   self.assertTrue(is_canonical(expr))
   loss_after = eval_expr(expr, env)
   self.assertAlmostEqual(loss_before, loss_after)
コード例 #7
0
 def testSqrtToPow(self):
   """Canonicalization rewrites np.sqrt into an np.power node, same value."""
   def fun(x):
     return np.sqrt(x)
   expr = make_expr(fun, 3.)
   expr = canonicalize(expr)
   self.assertIsInstance(expr, GraphExpr)
   self.assertEqual(expr.expr_node.fun, np.power)
   # A power-based evaluation is not guaranteed to round bit-identically to
   # sqrt on every platform/libm, so compare within floating-point tolerance.
   self.assertAlmostEqual(eval_expr(expr, {'x': 3.}), fun(3.))
コード例 #8
0
 def testEinsumCompose(self):
   """Nested einsums canonicalize into a single einsum output node."""
   def Xbeta_squared(X, beta):
     Xbeta = np.einsum('ij,j->i', X, beta)
     # NOTE(review): Xbeta2 never feeds the result. Confirm whether it is
     # deliberately exercising dead computation before removing it.
     Xbeta2 = np.einsum('lm,m->l', X, beta)
     return np.einsum('k,k->', Xbeta, Xbeta)
   n_examples, n_predictors = 10, 2
   X = np.random.randn(n_examples, n_predictors)
   beta = np.random.randn(n_predictors)
   env = {'X': X, 'beta': beta}

   expr = make_expr(Xbeta_squared, X, beta)
   value_before = eval_expr(expr, env)
   self.assertFalse(is_canonical(expr))

   expr = canonicalize(expr)
   value_after = eval_expr(expr, env)
   self.assertAlmostEqual(value_before, value_after)
   self.assertIsInstance(expr, GraphExpr)
   self.assertEqual(expr.expr_node.fun, np.einsum)
   self.assertTrue(is_canonical(expr))
コード例 #9
0
  def testFindSufficientStatisticNodes(self):
    """find_sufficient_statistic_nodes recovers the statistics of `x`.

    The check runs twice on freshly traced, canonicalized graphs: once with
    default settings and once with `split_einsums=True`.
    """

    def log_joint(x, y, matrix):
      # Linear in x: y^T x
      result = np.einsum('i,i->', x, y)
      # Quadratic form: x^T matrix x
      result += np.einsum('ij,i,j->', matrix, x, x)
      # Rank-1 quadratic form: (x**2)^T(y**2)
      result += np.einsum('i,i,j,j->', x, y, x, y)
      # Linear in log(x): y^T log(x)
      result += np.einsum('i,i->', y, np.log(x))
      # Linear in reciprocal(x): y^T reciprocal(x)
      result += np.einsum('i,i->', y, np.reciprocal(x))
      # More obscurely linear in log(x): y^T matrix log(x)
      result += np.einsum('i,ij,j->', y, matrix, np.log(x))
      # Linear in x * log(x): y^T (x * log(x))
      result += np.einsum('i,i->', y, x * np.log(x))
      return result

    n_dimensions = 5
    x = np.exp(np.random.randn(n_dimensions))
    y = np.random.randn(n_dimensions)
    matrix = np.random.randn(n_dimensions, n_dimensions)
    env = {'x': x, 'y': y, 'matrix': matrix}

    def evaluated_statistics(**finder_options):
      # Each pass traces and canonicalizes from scratch, then evaluates every
      # sufficient-statistic node found for 'x' as its own subgraph.
      canonical = canonicalize(make_expr(log_joint, x, y, matrix))
      stat_nodes = find_sufficient_statistic_nodes(
          canonical, 'x', **finder_options)
      return [eval_expr(GraphExpr(stat_node, canonical.free_vars), env)
              for stat_node in stat_nodes]

    merged_stats = evaluated_statistics()
    self.assertTrue(_perfect_match_values(
        merged_stats,
        [x, x.dot(matrix.dot(x)), np.square(x.dot(y)),
         np.log(x), np.reciprocal(x), y.dot(x * np.log(x))]))

    split_stats = evaluated_statistics(split_einsums=True)
    self.assertTrue(_match_values(
        split_stats,
        [x, np.outer(x, x), x * x,
         np.log(x), np.reciprocal(x), x * np.log(x)]))
コード例 #10
0
  def testExtractSuperexpr(self):
    """extract_superexpr turns an interior node into a named free variable."""

    def f(x, y):
      return 2 * x ** 2 + y

    expr = tracers.make_expr(f, 1, 2)
    # The output's first parent is 2 * x ** 2; its first parent is x ** 2.
    doubled_square = expr.expr_node.parents[0]
    square_node = doubled_square.parents[0]

    superexpr = tracers.extract_superexpr(expr, {'x2': square_node})

    # With x ** 2 supplied directly as x2 = 5, the graph computes 2 * 5 + y.
    result = tracers.eval_expr(superexpr, {'x2': 5, 'y': 6})
    self.assertEqual(2 * 5 + 6, result)
コード例 #11
0
    def testEinsumOneArg(self):
        """maybe_einsum collapses an identity einsum ('a->a') to a lookup."""
        x = npr.randn(10)

        def fun(x):
            return np.einsum('a->a', x)

        traced = tracers.make_expr(fun, x)
        # Before rewriting, the output node is not a bare environment lookup.
        self.assertNotEqual(traced.expr_node.fun, tracers.env_lookup)

        substitutions = {np.einsum: rewrites.maybe_einsum}
        rewritten = tracers.remake_expr(traced, substitutions)
        self.assertAllClose(tracers.eval_expr(rewritten, {'x': x}), fun(x))
        self.assertEqual(rewritten.expr_node.fun, tracers.env_lookup)
コード例 #12
0
    def testInlineExpr(self):
        """inline_expr builds g(node) without modifying the source graph."""
        def f(x, y):
            return 2 * x + y

        def g(z):
            return 3 * z + z**2

        outer = tracers.make_expr(f, 1, 2)
        inner = tracers.make_expr(g, 3)

        # The first parent of f's output is the `2 * x` node; bind it to z.
        two_x_node = outer.expr_node.parents[0]
        inlined = tracers.inline_expr(inner, {'z': two_x_node})

        expected_listing = (
            "temp_0 = multiply(2, x)\n"
            "temp_1 = multiply(3, temp_0)\n"
            "temp_2 = power(temp_0, 2)\n"
            "temp_3 = add(temp_1, temp_2)\n"
        )
        actual_listing = tracers.print_expr(inlined)
        self.assertEqual(actual_listing, expected_listing)

        # g(2 * x) evaluated at x = 5.
        self.assertEqual(3 * (2 * 5) + (2 * 5)**2,
                         tracers.eval_expr(inlined, {'x': 5}))
        # The original graph must still compute f.
        self.assertEqual(f(6, 7), tracers.eval_expr(outer, {'x': 6, 'y': 7}))
コード例 #13
0
    def testReplaceNodeWithExpr(self):
        """replace_node_with_expr swaps a node for another graph in place."""
        def f(x):
            return 2 * x

        def g(x):
            return 3 * x

        expr = tracers.make_expr(f, 5)
        replacement = tracers.make_expr(g, 10)
        # Replacing the output node mutates `expr` so that it now computes g.
        tracers.replace_node_with_expr(expr.expr_node, replacement)

        result = tracers.eval_expr(expr, {'x': 7})
        self.assertEqual(3 * 7, result)
コード例 #14
0
  def testExtractSuperexprWithReplaceNode(self):
    """Emulate extract_superexpr using replace_node_with_expr.

    NOTE(mattjj): this shows an alternative way to get, in effect, what
    tracers.extract_superexpr does, using only
    tracers.replace_node_with_expr. Both exist because one of them modifies
    the graph in place.
    """

    def f(x, y):
      return 2 * x ** 2 + y

    expr = tracers.make_expr(f, 1, 2)
    square_node = expr.expr_node.parents[0].parents[0]  # x ** 2

    # An identity graph whose single input is named x2 acts as a lookup.
    x2_lookup = tracers.make_expr(lambda x: x, 3, names=('x2',))
    tracers.replace_node_with_expr(square_node, x2_lookup)  # mutates expr

    result = tracers.eval_expr(expr, {'x2': 5, 'y': 6})
    self.assertEqual(2 * 5 + 6, result)
コード例 #15
0
  def testInlineExprAndReplace(self):
    """Inline g onto x ** 2, then splice the result over 2 * x ** 2."""

    def f(x, y):
      return 2 * x ** 2 + y

    def g(z):
      return 3 * z ** 3

    expr = tracers.make_expr(f, 1, 2)
    g_expr = tracers.make_expr(g, 3)

    scaled_node = expr.expr_node.parents[0]  # 2 * x ** 2
    square_node = scaled_node.parents[0]  # x ** 2

    # Build g(x ** 2) = 3 * x ** 6, then put it where 2 * x ** 2 was.
    g_of_square = tracers.inline_expr(g_expr, {'z': square_node})
    tracers.replace_node_with_expr(scaled_node, g_of_square)  # mutates expr

    result = tracers.eval_expr(expr, {'x': 6, 'y': 7})
    self.assertEqual(3 * 6 ** 6 + 7, result)