def _eager_rewriter_test_helper(self, fun, rewriter, *args, **kwargs):
  """Checks that eagerly rewriting `fun`'s graph preserves its value.

  Traces `fun` into a GraphExpr (unless a pre-built one is supplied via
  `kwargs['expr']`), verifies the trace evaluates to the same value as the
  original function, remakes the graph with `rewriter` substituted for the
  target node's function, and verifies the value is unchanged.

  Args:
    fun: the function to trace.
    rewriter: eager replacement for the rewritten node's function.
    *args: example arguments used both for tracing and for evaluation.
    **kwargs: optional `expr` (pre-built GraphExpr) and `rewrite_node`
      (node whose function is rewritten; defaults to the output node).

  Returns:
    The rewritten expression after a final constant-folding remake.
  """
  expr = kwargs.get('expr') or tracers.make_expr(fun, *args)
  self.assertIsInstance(expr, tracers.GraphExpr)
  # getfullargspec replaces getargspec, which was removed in Python 3.11;
  # only the `.args` field is used here, so it is a drop-in substitute.
  env = dict(zip(inspect.getfullargspec(fun).args, args))
  self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))
  rewrite_node = kwargs.get('rewrite_node', expr.expr_node)
  expr = tracers.remake_expr(expr, {rewrite_node.fun: rewriter})
  self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))
  return tracers.remake_expr(expr)  # constant folding
def _rewriter_test_helper(self, fun, rewrite_rule, *args, **kwargs):
  """Checks that applying `rewrite_rule` to `fun`'s graph preserves its value.

  Traces `fun` into a GraphExpr (unless one is supplied via `kwargs['expr']`),
  verifies the trace matches the original function's value, applies the
  rewriter built from `rewrite_rule` in place, and verifies the value again.

  Args:
    fun: the function to trace.
    rewrite_rule: rule passed to `rewrites.make_rewriter`.
    *args: example arguments used both for tracing and for evaluation.
    **kwargs: optional `expr` (pre-built GraphExpr) and `rewrite_node`
      (node to rewrite; defaults to the output node).

  Returns:
    The rewritten expression after a final constant-folding remake.
  """
  expr = kwargs.get('expr') or tracers.make_expr(fun, *args)
  self.assertIsInstance(expr, tracers.GraphExpr)
  # getfullargspec replaces getargspec, which was removed in Python 3.11;
  # only the `.args` field is used here, so it is a drop-in substitute.
  env = dict(zip(inspect.getfullargspec(fun).args, args))
  self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))
  rewriter = rewrites.make_rewriter(rewrite_rule)
  rewriter(kwargs.get('rewrite_node', expr.expr_node))  # modifies expr in-place
  self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))
  return tracers.remake_expr(expr)  # constant folding
def testEinsumAddSubSimplify(self):
  """Canonicalizing einsum-of-sum splits it into an add_n of einsums."""
  # TODO(mhoffman): Think about broadcasting. We need to support `x - 2.0`.
  def test_fun(x):
    return np.einsum('i->', x + np.full(x.shape, 2.0))

  expr = make_expr(test_fun, np.ones(3))
  test_x = np.full(3, 0.5)
  correct_value = eval_expr(expr, {'x': test_x})

  expr = canonicalize(expr)
  self.assertIsInstance(expr, GraphExpr)
  # The root becomes an add_n whose first parent is an einsum node.
  self.assertEqual(expr.expr_node.fun, add_n)
  self.assertEqual(expr.expr_node.parents[0].fun.__name__, 'einsum')
  self.assertEqual(correct_value, eval_expr(expr, {'x': test_x}))
def testEvalExpr(self):
  """A traced graph evaluates correctly on an environment of fresh inputs."""
  def fun(x, y):
    return 2 * x**2 + np.tanh(3 * y)

  # Trace with one pair of arguments, then evaluate with a different
  # environment to confirm the graph, not the trace values, is what runs.
  expr = tracers.make_expr(fun, 4, 5)
  self.assertEqual(fun(9, 10), tracers.eval_expr(expr, {'x': 9, 'y': 10}))
def testCanonicalize(self):
  """Canonicalizing a Mahalanobis-distance graph preserves its value."""
  def mahalanobis_distance(x, y, matrix):
    x_minus_y = x - y
    return np.einsum('i,j,ij->', x_minus_y, x_minus_y, matrix)

  x = np.array([1.3, 3.6])
  y = np.array([2.3, -1.2])
  matrix = np.arange(4).reshape([2, 2])
  env = {'x': x, 'y': y, 'matrix': matrix}

  expr = make_expr(mahalanobis_distance, x, y, matrix)
  self.assertFalse(is_canonical(expr))
  correct_value = eval_expr(expr, env)

  expr = canonicalize(expr)
  self.assertTrue(is_canonical(expr))
  self.assertAlmostEqual(correct_value, eval_expr(expr, env))
def testLinearRegression(self):
  """Canonicalizing a squared-loss graph preserves its value."""
  def squared_loss(X, beta, y):
    predictions = np.einsum('ij,j->i', X, beta)
    errors = y - predictions
    return np.einsum('k,k->', errors, errors)

  n_examples, n_predictors = 10, 2
  X = np.random.randn(n_examples, n_predictors)
  beta = np.random.randn(n_predictors)
  y = np.random.randn(n_examples)
  env = {'X': X, 'beta': beta, 'y': y}

  expr = make_expr(squared_loss, X, beta, y)
  correct_value = eval_expr(expr, env)
  self.assertFalse(is_canonical(expr))

  expr = canonicalize(expr)
  self.assertTrue(is_canonical(expr))
  self.assertAlmostEqual(correct_value, eval_expr(expr, env))
def testSqrtToPow(self):
  """Canonicalization rewrites np.sqrt into an np.power node."""
  def take_sqrt(x):
    return np.sqrt(x)

  expr = canonicalize(make_expr(take_sqrt, 3.))
  self.assertIsInstance(expr, GraphExpr)
  self.assertEqual(expr.expr_node.fun, np.power)
  self.assertEqual(eval_expr(expr, {'x': 3.}), take_sqrt(3.))
def testEinsumCompose(self):
  """Composed einsums canonicalize to a single einsum node.

  Fix: the traced function previously computed an `Xbeta2` einsum that was
  never used — dead code that only duplicated the first einsum. It is
  removed; the test's assertions are unchanged.
  """
  def Xbeta_squared(X, beta):
    Xbeta = np.einsum('ij,j->i', X, beta)
    return np.einsum('k,k->', Xbeta, Xbeta)

  n_examples, n_predictors = 10, 2
  X = np.random.randn(n_examples, n_predictors)
  beta = np.random.randn(n_predictors)
  env = {'X': X, 'beta': beta}

  expr = make_expr(Xbeta_squared, X, beta)
  correct_value = eval_expr(expr, env)
  self.assertFalse(is_canonical(expr))

  expr = canonicalize(expr)
  self.assertAlmostEqual(correct_value, eval_expr(expr, env))
  # The nested einsums collapse into one canonical einsum at the root.
  self.assertIsInstance(expr, GraphExpr)
  self.assertEqual(expr.expr_node.fun, np.einsum)
  self.assertTrue(is_canonical(expr))
def testFindSufficientStatisticNodes(self):
  """Sufficient-statistic extraction finds each term's statistic in x."""
  def log_joint(x, y, matrix):
    # Linear in x: y^T x
    result = np.einsum('i,i->', x, y)
    # Quadratic form: x^T matrix x
    result += np.einsum('ij,i,j->', matrix, x, x)
    # Rank-1 quadratic form: (x**2)^T(y**2)
    result += np.einsum('i,i,j,j->', x, y, x, y)
    # Linear in log(x): y^T log(x)
    result += np.einsum('i,i->', y, np.log(x))
    # Linear in reciprocal(x): y^T reciprocal(x)
    result += np.einsum('i,i->', y, np.reciprocal(x))
    # More obscurely linear in log(x): y^T matrix log(x)
    result += np.einsum('i,ij,j->', y, matrix, np.log(x))
    # Linear in x * log(x): y^T (x * log(x))
    result += np.einsum('i,i->', y, x * np.log(x))
    return result

  n_dimensions = 5
  x = np.exp(np.random.randn(n_dimensions))  # positive, so log(x) is defined
  y = np.random.randn(n_dimensions)
  matrix = np.random.randn(n_dimensions, n_dimensions)
  env = {'x': x, 'y': y, 'matrix': matrix}

  # Without splitting einsums, each statistic is a full contraction in x.
  expr = canonicalize(make_expr(log_joint, x, y, matrix))
  stat_nodes = find_sufficient_statistic_nodes(expr, 'x')
  suff_stats = [eval_expr(GraphExpr(node, expr.free_vars), env)
                for node in stat_nodes]
  correct_suff_stats = [x, x.dot(matrix.dot(x)), np.square(x.dot(y)),
                        np.log(x), np.reciprocal(x), y.dot(x * np.log(x))]
  self.assertTrue(_perfect_match_values(suff_stats, correct_suff_stats))

  # With split_einsums=True, statistics are the raw x-dependent factors.
  expr = canonicalize(make_expr(log_joint, x, y, matrix))
  stat_nodes = find_sufficient_statistic_nodes(expr, 'x', split_einsums=True)
  suff_stats = [eval_expr(GraphExpr(node, expr.free_vars), env)
                for node in stat_nodes]
  correct_suff_stats = [x, np.outer(x, x), x * x,
                        np.log(x), np.reciprocal(x), x * np.log(x)]
  self.assertTrue(_match_values(suff_stats, correct_suff_stats))
def testExtractSuperexpr(self):
  """Extracting a superexpression turns an interior node into a free var."""
  def f(x, y):
    return 2 * x ** 2 + y

  expr = tracers.make_expr(f, 1, 2)
  power_node = expr.expr_node.parents[0].parents[0]  # x ** 2
  # Replace the x**2 subgraph with a free variable named 'x2'.
  new_expr = tracers.extract_superexpr(expr, {'x2': power_node})
  self.assertEqual(2 * 5 + 6,
                   tracers.eval_expr(new_expr, {'x2': 5, 'y': 6}))
def testEinsumOneArg(self):
  """maybe_einsum collapses an identity einsum to a plain env lookup."""
  x = npr.randn(10)

  def fun(x):
    return np.einsum('a->a', x)

  expr = tracers.make_expr(fun, x)
  # Before the rewrite the root is a real einsum node, not a lookup.
  self.assertNotEqual(expr.expr_node.fun, tracers.env_lookup)

  expr = tracers.remake_expr(expr, {np.einsum: rewrites.maybe_einsum})
  self.assertAllClose(tracers.eval_expr(expr, {'x': x}), fun(x))
  # 'a->a' is the identity, so the root is now just an env lookup.
  self.assertEqual(expr.expr_node.fun, tracers.env_lookup)
def testInlineExpr(self):
  """Inlining grafts a subexpression onto a node of the host graph."""
  def f(x, y):
    return 2 * x + y

  def g(z):
    return 3 * z + z**2

  expr = tracers.make_expr(f, 1, 2)
  subexpr = tracers.make_expr(g, 3)
  target_node = expr.expr_node.parents[0]  # the 2 * x node

  # Substitute g's free variable z with f's 2*x node.
  new_expr = tracers.inline_expr(subexpr, {'z': target_node})
  expected = ("temp_0 = multiply(2, x)\n"
              "temp_1 = multiply(3, temp_0)\n"
              "temp_2 = power(temp_0, 2)\n"
              "temp_3 = add(temp_1, temp_2)\n")
  self.assertEqual(tracers.print_expr(new_expr), expected)
  self.assertEqual(3 * (2 * 5) + (2 * 5)**2,
                   tracers.eval_expr(new_expr, {'x': 5}))
  # The original expression is untouched by inlining.
  self.assertEqual(f(6, 7), tracers.eval_expr(expr, {'x': 6, 'y': 7}))
def testReplaceNodeWithExpr(self):
  """Replacing the root node swaps in the other graph's computation."""
  def f(x):
    return 2 * x

  def g(x):
    return 3 * x

  expr = tracers.make_expr(f, 5)
  replacement = tracers.make_expr(g, 10)
  # In-place: expr's root now computes g instead of f.
  tracers.replace_node_with_expr(expr.expr_node, replacement)
  self.assertEqual(3 * 7, tracers.eval_expr(expr, {'x': 7}))
def testExtractSuperexprWithReplaceNode(self):
  """Emulates extract_superexpr via an in-place node replacement."""
  # NOTE(mattjj): this test shows an alternative way to implement, in effect,
  # tracers.extract_superexpr just using tracers.replace_node_with_expr. The
  # reason to have both is that one does in-place modification.
  def f(x, y):
    return 2 * x ** 2 + y

  expr = tracers.make_expr(f, 1, 2)
  power_node = expr.expr_node.parents[0].parents[0]  # x ** 2
  # An identity trace named 'x2' serves as the free-variable stand-in.
  lookup_expr = tracers.make_expr(lambda x: x, 3, names=('x2',))
  tracers.replace_node_with_expr(power_node, lookup_expr)  # modify expr in-place
  self.assertEqual(2 * 5 + 6,
                   tracers.eval_expr(expr, {'x2': 5, 'y': 6}))
def testInlineExprAndReplace(self):
  """Inlines one graph into another, then splices it in place."""
  def f(x, y):
    return 2 * x ** 2 + y

  def g(z):
    return 3 * z ** 3

  expr = tracers.make_expr(f, 1, 2)
  subexpr = tracers.make_expr(g, 3)
  input_node = expr.expr_node.parents[0].parents[0]  # x ** 2
  output_node = expr.expr_node.parents[0]            # 2 * x ** 2

  # Graft g onto the x**2 node, then overwrite the 2*x**2 node with the
  # result, so f becomes 3 * (x**2)**3 + y.
  grafted = tracers.inline_expr(subexpr, {'z': input_node})
  tracers.replace_node_with_expr(output_node, grafted)  # modify expr inplace
  self.assertEqual(3 * 6 ** 6 + 7, tracers.eval_expr(expr, {'x': 6, 'y': 7}))