def testFullSumRewriter(self):
    """Check that replace_sum turns a full np.sum node into np.einsum.

    Two spellings of the reduction are traced: np.sum(x) and np.sum(x, None).
    In both cases the traced root is np.sum before the rewrite and np.einsum
    after it.
    """
    def fun(x):
        return np.sum(x)

    x = npr.randn(4, 3)
    traced = tracers.make_expr(fun, x)
    self.assertEqual(traced.expr_node.fun, np.sum)
    rewritten = self._rewriter_test_helper(fun, rewrites.replace_sum, x)
    self.assertEqual(rewritten.expr_node.fun, np.einsum)

    # Same check with the axis argument passed explicitly as None.
    def fun(x):
        return np.sum(x, None)

    traced = tracers.make_expr(fun, x)
    self.assertEqual(traced.expr_node.fun, np.sum)
    rewritten = self._rewriter_test_helper(fun, rewrites.replace_sum, x)
    self.assertEqual(rewritten.expr_node.fun, np.einsum)
def testAllDescendantsOf(self):
    """Check all_descendants_of for the free variable y of 2 * x ** 2 + y.

    The expected result is the set holding the y node itself and the root
    node of the traced expression.
    """
    def f(x, y):
        return 2 * x ** 2 + y

    expr = tracers.make_expr(f, 1, 2)
    ynode = expr.free_vars['y']
    descendants = tracers.all_descendants_of(expr.expr_node, ynode)
    self.assertEqual(descendants, {ynode, expr.expr_node})
def testEinsumOneArg(self):
    """Check maybe_einsum on a one-argument 'a->a' einsum.

    Before the rewrite the root node is not an env_lookup; after remaking
    the expression with maybe_einsum registered for np.einsum it is, and the
    expression still evaluates to fun(x).
    """
    def fun(x):
        return np.einsum('a->a', x)

    x = npr.randn(10)
    traced = tracers.make_expr(fun, x)
    self.assertNotEqual(traced.expr_node.fun, tracers.env_lookup)

    remade = tracers.remake_expr(traced, {np.einsum: rewrites.maybe_einsum})
    value = tracers.eval_expr(remade, {'x': x})
    self.assertAllClose(value, fun(x))
    self.assertEqual(remade.expr_node.fun, tracers.env_lookup)
def testCompoundPatternNameConstraints(self):
    """Check that a name used twice in a pattern must bind the same node.

    The pattern names both the multiplied value and the squared value 'x'.
    It must not match 3 * x + y**2, and must match 3 * x + x**2 with 'x'
    bound to the second argument of the multiply node.
    """
    def fun(x, y):
        return 3 * x + y**2

    x = np.ones(2)
    y = 2 * np.ones(2)
    end_node = tracers.make_expr(fun, x, y).expr_node
    pattern = (Add, (Multiply, 3, Val('x')), (Power, Val('x'), 2))
    match = matchers.matcher(pattern)
    self.assertFalse(match(end_node))

    def fun(x, y):
        return 3 * x + x**2  # note x used twice

    x = np.ones(2)
    y = 2 * np.ones(2)
    end_node = tracers.make_expr(fun, x, y).expr_node
    expected_bindings = {'x': end_node.args[0].args[1]}
    self.assertEqual(match(end_node), expected_bindings)
def testExtractSuperexpr(self):
    """Check extract_superexpr with the x ** 2 node exposed as input 'x2'.

    The extracted expression is evaluated with x2 = 5 and y = 6 and must
    equal 2 * 5 + 6.
    """
    def f(x, y):
        return 2 * x ** 2 + y

    expr = tracers.make_expr(f, 1, 2)
    square_node = expr.expr_node.parents[0].parents[0]  # x ** 2
    new_expr = tracers.extract_superexpr(expr, {'x2': square_node})

    expected = 2 * 5 + 6
    actual = tracers.eval_expr(new_expr, {'x2': 5, 'y': 6})
    self.assertEqual(expected, actual)
def _eager_rewriter_test_helper(self, fun, rewriter, *args, **kwargs):
    """Trace `fun`, apply `rewriter` through remake_expr, and check values.

    The expression is evaluated before and after the rewrite and both
    results are compared with `fun(*args)`.

    Args:
      fun: function to trace; its positional argument names become the keys
        of the evaluation environment.
      rewriter: replacement registered for the `fun` attribute of the node
        being rewritten when the expression is remade.
      *args: positional arguments passed to `fun`.
      **kwargs: optional `expr` (used instead of tracing `fun` when truthy)
        and `rewrite_node` (defaults to the root node of the expression).

    Returns:
      The result of a final `tracers.remake_expr(expr)` call.
    """
    expr = kwargs.get('expr') or tracers.make_expr(fun, *args)
    self.assertIsInstance(expr, tracers.GraphExpr)

    # inspect.getargspec was removed in Python 3.11; getfullargspec is its
    # replacement and only Python 2 lacks it.
    try:
        getargspec = inspect.getfullargspec
    except AttributeError:
        getargspec = inspect.getargspec
    env = dict(zip(getargspec(fun).args, args))

    self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))
    rewrite_node = kwargs.get('rewrite_node', expr.expr_node)
    expr = tracers.remake_expr(expr, {rewrite_node.fun: rewriter})
    self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))
    return tracers.remake_expr(expr)  # constant folding
def testInlineExpr(self):
    """Check inline_expr by binding g's input z to a node of f's graph.

    The inlined expression is checked both through its printed form and its
    value at x = 5; the original expression for f must still evaluate as f.
    """
    def f(x, y):
        return 2 * x + y

    def g(z):
        return 3 * z + z**2

    expr = tracers.make_expr(f, 1, 2)
    subexpr = tracers.make_expr(g, 3)
    target_node = expr.expr_node.parents[0]
    new_expr = tracers.inline_expr(subexpr, {'z': target_node})

    expected = ("temp_0 = multiply(2, x)\n"
                "temp_1 = multiply(3, temp_0)\n"
                "temp_2 = power(temp_0, 2)\n"
                "temp_3 = add(temp_1, temp_2)\n")
    printed_expr = tracers.print_expr(new_expr)
    self.assertEqual(printed_expr, expected)

    inlined_value = tracers.eval_expr(new_expr, {'x': 5})
    self.assertEqual(3 * (2 * 5) + (2 * 5)**2, inlined_value)

    original_value = tracers.eval_expr(expr, {'x': 6, 'y': 7})
    self.assertEqual(f(6, 7), original_value)
def testPrintExpr(self):
    """Check the text produced by print_expr for 2 * x**2 + tanh(y)."""
    def fun(x, y):
        return 2 * x**2 + np.tanh(y)

    traced = tracers.make_expr(fun, 4, 5)
    expected = ("temp_0 = power(x, 2)\n"
                "temp_1 = multiply(2, temp_0)\n"
                "temp_2 = tanh(y)\n"
                "temp_3 = add(temp_1, temp_2)\n")
    self.assertEqual(tracers.print_expr(traced), expected)
def testInlineExprAndReplace(self):
    """Check inline_expr followed by replace_node_with_expr.

    g is inlined on top of the x ** 2 node of f, and the result replaces
    the 2 * x ** 2 node in place. The modified expression is evaluated at
    x = 6, y = 7 and must equal 3 * 6**6 + 7.
    """
    def f(x, y):
        return 2 * x**2 + y

    def g(z):
        return 3 * z**3

    expr = tracers.make_expr(f, 1, 2)
    subexpr = tracers.make_expr(g, 3)
    input_node = expr.expr_node.parents[0].parents[0]  # x ** 2
    output_node = expr.expr_node.parents[0]  # 2 * x ** 2

    new_expr = tracers.inline_expr(subexpr, {'z': input_node})
    # Modifies expr in place.
    tracers.replace_node_with_expr(output_node, new_expr)

    result = tracers.eval_expr(expr, {'x': 6, 'y': 7})
    self.assertEqual(3 * 6**6 + 7, result)
def _rewriter_test_helper(self, fun, rewrite_rule, *args, **kwargs):
    """Trace `fun`, apply `rewrite_rule` in place, and check values.

    The expression is evaluated before and after the rewrite and both
    results are compared with `fun(*args)`.

    Args:
      fun: function to trace; its positional argument names become the keys
        of the evaluation environment.
      rewrite_rule: rule passed to `rewrites.make_rewriter`.
      *args: positional arguments passed to `fun`.
      **kwargs: optional `expr` (used instead of tracing `fun` when truthy)
        and `rewrite_node` (the node handed to the rewriter; defaults to the
        root node of the expression).

    Returns:
      The result of a final `tracers.remake_expr(expr)` call.
    """
    expr = kwargs.get('expr') or tracers.make_expr(fun, *args)
    self.assertIsInstance(expr, tracers.GraphExpr)

    # inspect.getargspec was removed in Python 3.11; getfullargspec is its
    # replacement and only Python 2 lacks it.
    try:
        getargspec = inspect.getfullargspec
    except AttributeError:
        getargspec = inspect.getargspec
    env = dict(zip(getargspec(fun).args, args))

    self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))
    rewriter = rewrites.make_rewriter(rewrite_rule)
    rewrite_node = kwargs.get('rewrite_node', expr.expr_node)
    rewriter(rewrite_node)  # modifies expr in-place
    self.assertAllClose(fun(*args), tracers.eval_expr(expr, env))
    return tracers.remake_expr(expr)  # constant folding
def testLiterals(self):
    """Check that literal values inside patterns match equal literals."""
    # A bare literal pattern matches the same literal.
    literal_match = matchers.matcher(3)
    self.assertTrue(literal_match(3))

    def fun(x):
        return 2 + x

    x = np.ones(2)
    end_node = tracers.make_expr(fun, x).expr_node
    # A literal nested in a compound pattern matches the traced constant.
    compound_match = matchers.matcher((Add, 2, Val))
    self.assertTrue(compound_match(end_node))
def testDescendantOf(self):
    """Check is_descendant_of on the graph of 2 * x ** 2 + y.

    The root descends from x, every node descends from itself, and x does
    not descend from y.
    """
    def f(x, y):
        return 2 * x ** 2 + y

    expr = tracers.make_expr(f, 1, 2)
    xnode = expr.free_vars['x']
    ynode = expr.free_vars['y']

    root_from_x = tracers.is_descendant_of(expr.expr_node, xnode)
    self.assertTrue(root_from_x)
    x_from_x = tracers.is_descendant_of(xnode, xnode)
    self.assertTrue(x_from_x)
    root_from_root = tracers.is_descendant_of(expr.expr_node, expr.expr_node)
    self.assertTrue(root_from_root)
    x_from_y = tracers.is_descendant_of(xnode, ynode)
    self.assertFalse(x_from_y)
def testDotRewriter(self):
    """Check that dot_as_einsum turns a traced np.dot root into an einsum."""
    def fun(x, y):
        return np.dot(x, y)

    x = npr.randn(4, 3)
    y = npr.randn(3)

    traced = tracers.make_expr(fun, x, y)
    self.assertIsInstance(traced, tracers.GraphExpr)
    self.assertEqual(traced.expr_node.fun.__name__, 'dot')

    rewritten = self._eager_rewriter_test_helper(
        fun, rewrites.dot_as_einsum, x, y)
    self.assertEqual(rewritten.expr_node.fun.__name__, 'einsum')
def testEinsumTransposeRewriter(self):
    """Check that transpose_inside_einsum removes the transpose node.

    The traced function applies einsum to x.T; after the rewrite the printed
    expression must not mention 'transpose'.
    """
    def fun(x, y):
        return np.einsum('ij,j->i', x.T, y)

    x = npr.randn(4, 3)
    y = npr.randn(4)
    expr = tracers.make_expr(fun, x, y)
    self.assertEqual(expr.expr_node.fun.__name__, 'einsum')

    expr = self._rewriter_test_helper(
        fun, rewrites.transpose_inside_einsum, x, y)
    # assertNotIn reports both the needle and the printed expression on
    # failure, unlike assertFalse on a precomputed boolean.
    self.assertNotIn('transpose', tracers.print_expr(expr))
def testEinsumDistributeRewriter(self):
    """Check that distribute_einsum moves add_n to the root of the graph.

    The traced function is an einsum whose second operand is add_n(y, z);
    after the rewrite the root node is add_n.
    """
    def fun(x, y, z):
        return np.einsum('ij,j->i', x, tracers.add_n(y, z))

    x = npr.randn(4, 3)
    y = npr.randn(3)
    z = npr.randn(3)

    traced = tracers.make_expr(fun, x, y, z)
    self.assertEqual(traced.expr_node.fun.__name__, 'einsum')

    rewritten = self._rewriter_test_helper(
        fun, rewrites.distribute_einsum, x, y, z)
    self.assertEqual(rewritten.expr_node.fun.__name__, 'add_n')
def testCompoundPatternNameBindings(self):
    """Check the bindings returned for names nested in a compound pattern.

    For 3 * x + y**2 the pattern binds 'x' to the second argument of the
    multiply node and 'y' to the first argument of the power node.
    """
    def fun(x, y):
        return 3 * x + y**2

    x = np.ones(2)
    y = 2 * np.ones(2)
    end_node = tracers.make_expr(fun, x, y).expr_node

    pattern = (Add, (Multiply, 3, Val('x')), (Power, Val('y'), 2))
    match = matchers.matcher(pattern)
    bindings = match(end_node)

    multiply_node, power_node = end_node.args[0], end_node.args[1]
    expected = {'x': multiply_node.args[1], 'y': power_node.args[0]}
    self.assertEqual(bindings, expected)
def testSegmentsEmpty(self):
    """Check a Segment pattern placed before the first einsum operand.

    The subtraction x - y is the first operand of the traced einsum, so the
    'args1' Segment that precedes it in the pattern has no operands to cover.
    """
    def fun(x, y, z):
        return np.einsum('i,j,ij->', x - y, x, z)

    x = np.ones(3)
    y = 2 * np.ones(3)
    z = 3 * np.ones((3, 3))
    end_node = tracers.make_expr(fun, x, y, z).expr_node

    pat = (Einsum, Str('formula'), Segment('args1'),
           (Choice(Subtract('op'), Add('op')), Val('x'), Val('y')),
           Segment('args2'))
    matched = matchers.matcher(pat)(end_node)
    self.assertTrue(matched)
def testEinsumAddSubSimplify(self):
    """Check canonicalize on an einsum applied to a sum of two arrays.

    After canonicalization the root node is add_n, its first parent is an
    einsum, and the expression evaluates to the same value as before.
    """
    # TODO(mhoffman): Think about broadcasting. We need to support `x - 2.0`.
    def test_fun(x):
        return np.einsum('i->', x + np.full(x.shape, 2.0))

    expr = make_expr(test_fun, np.ones(3))
    test_x = np.full(3, 0.5)
    correct_value = eval_expr(expr, {'x': test_x})

    expr = canonicalize(expr)
    self.assertIsInstance(expr, GraphExpr)
    self.assertEqual(expr.expr_node.fun, add_n)
    self.assertEqual(expr.expr_node.parents[0].fun.__name__, 'einsum')

    new_value = eval_expr(expr, {'x': test_x})
    # The rewrite reorders floating-point additions, so compare with a
    # tolerance rather than exact equality.
    self.assertAlmostEqual(correct_value, new_value)
def testEinsumCompositionRewriter(self):
    """Check that combine_einsum_compositions removes the nested einsum.

    Before the rewrite both the root and its parent at index 1 are einsum
    nodes; afterwards the parent at index 1 is no longer an einsum.
    """
    def fun(x, y, z):
        return np.einsum('ij,jk->i', x, np.einsum('ija,ijk->ka', y, z))

    x = npr.randn(4, 3)
    y = npr.randn(5, 4, 2)
    z = npr.randn(5, 4, 3)

    traced = tracers.make_expr(fun, x, y, z)
    root = traced.expr_node
    self.assertEqual(root.fun.__name__, 'einsum')
    self.assertEqual(root.parents[1].fun.__name__, 'einsum')

    rewritten = self._rewriter_test_helper(
        fun, rewrites.combine_einsum_compositions, x, y, z)
    self.assertNotEqual(rewritten.expr_node.parents[1].fun.__name__, 'einsum')
def testOneElementPatternNameBinding(self):
    """Check bindings returned by named one-element patterns.

    On the root of 3 * x + y**2: Val('z') binds the node itself, Add('z')
    binds the node's function, and Multiply('z') does not match.
    """
    def fun(x, y):
        return 3 * x + y**2

    x = np.ones(2)
    y = 2 * np.ones(2)
    end_node = tracers.make_expr(fun, x, y).expr_node

    self.assertEqual(matchers.matcher(Val('z'))(end_node), {'z': end_node})
    self.assertEqual(matchers.matcher(Add('z'))(end_node),
                     {'z': end_node.fun})
    self.assertFalse(matchers.matcher(Multiply('z'))(end_node))
def testOneElementPattern(self):
    """Check unnamed one-element patterns on the root of 3 * x + y**2.

    Val and Add match the root node; Multiply does not.
    """
    def fun(x, y):
        return 3 * x + y**2

    x = np.ones(2)
    y = 2 * np.ones(2)
    end_node = tracers.make_expr(fun, x, y).expr_node

    self.assertTrue(matchers.matcher(Val)(end_node))
    self.assertTrue(matchers.matcher(Add)(end_node))
    self.assertFalse(matchers.matcher(Multiply)(end_node))
def testCanonicalize(self):
    """Check canonicalize on a quadratic form written with einsum.

    The traced expression is not canonical, becomes canonical after
    canonicalize, and evaluates to (almost) the same value before and after.
    """
    def mahalanobis_distance(x, y, matrix):
        x_minus_y = x - y
        return np.einsum('i,j,ij->', x_minus_y, x_minus_y, matrix)

    x = np.array([1.3, 3.6])
    y = np.array([2.3, -1.2])
    matrix = np.arange(4).reshape([2, 2])

    def evaluate(expression):
        # A fresh environment dict is built for every evaluation.
        return eval_expr(expression, {'x': x, 'y': y, 'matrix': matrix})

    expr = make_expr(mahalanobis_distance, x, y, matrix)
    self.assertFalse(is_canonical(expr))
    correct_value = evaluate(expr)

    expr = canonicalize(expr)
    self.assertTrue(is_canonical(expr))
    new_value = evaluate(expr)
    self.assertAlmostEqual(correct_value, new_value)
def _test_condition_and_marginalize_diagonal_zero_mean_normal(self, log_joint):
    """Shared check for _condition_and_marginalize on a given log joint.

    Args:
      log_joint: function called as log_joint(x, tau) with two length-5
        arrays; tau is a squared random draw.

    The marginalized value is compared with
    -0.5 * sum(log(tau)) + 0.5 * n * log(2 * pi), and the first two args of
    the returned conditional with zeros and 1 / sqrt(tau).
    """
    n_dimensions = 5
    x = np.random.randn(n_dimensions)
    tau = np.random.randn(n_dimensions) ** 2

    # Tracing and canonicalizing is exercised here; the result is not used
    # by the assertions below.
    end_node = canonicalize(make_expr(log_joint, x, tau))

    conditional, marginalized_value = _condition_and_marginalize(
        log_joint, 0, SupportTypes.REAL, x, tau)

    log_det_term = -0.5 * np.log(tau).sum()
    normalizer_term = 0.5 * n_dimensions * np.log(2. * np.pi)
    correct_marginalized_value = log_det_term + normalizer_term
    self.assertAlmostEqual(correct_marginalized_value, marginalized_value)

    mean_matches = np.allclose(np.zeros(n_dimensions), conditional.args[0])
    self.assertTrue(mean_matches)
    scale_matches = np.allclose(1. / np.sqrt(tau), conditional.args[1])
    self.assertTrue(scale_matches)
def testLogEinsumRewriter(self):
    """Check that replace_log_einsum rewrites log(einsum(...)) as an add.

    Before the rewrite the root is log with an einsum parent; afterwards the
    root is add and its first two parents are log nodes. Inputs are
    exponentiated random draws, so the logarithm is taken of positive values.
    """
    def fun(x, y):
        return np.log(np.einsum('ij,ij->ij', x, y))

    x = np.exp(npr.randn(4, 3))
    y = np.exp(npr.randn(4, 3))

    expr = tracers.make_expr(fun, x, y)
    self.assertEqual(expr.expr_node.fun.__name__, 'log')
    self.assertEqual(expr.expr_node.parents[0].fun.__name__, 'einsum')

    expr = self._rewriter_test_helper(fun, rewrites.replace_log_einsum, x, y)
    self.assertEqual(expr.expr_node.fun.__name__, 'add')
    self.assertEqual(expr.expr_node.parents[0].fun.__name__, 'log')
    self.assertEqual(expr.expr_node.parents[1].fun.__name__, 'log')
def testLinearRegression(self):
    """Check conjugacy machinery on a linear-regression log joint.

    Covers, in order: the sufficient statistics found for beta in the
    canonicalized graph, the names and values of the natural-parameter
    functions returned by _extract_conditional_factors, and the mean and
    covariance of the complete conditional for beta.
    """
    # NOTE(review): this source contains other methods that are also named
    # testLinearRegression; if they live in the same class only the last
    # definition runs -- confirm.
    def log_joint(X, beta, y):
        predictions = np.einsum('ij,j->i', X, beta)
        errors = y - predictions
        log_prior = np.einsum('i,i,i->', -0.5 * np.ones_like(beta), beta, beta)
        log_likelihood = np.einsum(',k,k->', -0.5, errors, errors)
        return log_prior + log_likelihood

    n_examples = 10
    n_predictors = 2
    X = np.random.randn(n_examples, n_predictors)
    beta = np.random.randn(n_predictors)
    y = np.random.randn(n_examples)
    graph = make_expr(log_joint, X, beta, y)
    graph = canonicalize(graph)

    # Dict key views cannot be indexed on Python 3, so materialize the names
    # before picking one. Presumably args[1] is 'beta' (free_vars keeping
    # argument order) -- TODO confirm.
    args = list(graph.free_vars.keys())
    sufficient_statistic_nodes = find_sufficient_statistic_nodes(graph, args[1])
    sufficient_statistics = [
        eval_node(node, graph.free_vars, {'X': X, 'beta': beta, 'y': y})
        for node in sufficient_statistic_nodes]
    correct_sufficient_statistics = [
        -0.5 * beta.dot(beta), beta,
        -0.5 * np.einsum('ij,ik,j,k', X, X, beta, beta)
    ]
    self.assertTrue(_match_values(sufficient_statistics,
                                  correct_sufficient_statistics))

    _, natural_parameter_funs = _extract_conditional_factors(graph, 'beta')
    self.assertTrue(_match_values(list(natural_parameter_funs.keys()),
                                  ['x', 'einsum(...a,...b->...ab, x, x)',
                                   'einsum(...,...->..., x, x)'],
                                  lambda x, y: x == y))
    natural_parameter_vals = [f(X, beta, y) for f in
                              natural_parameter_funs.values()]
    correct_parameter_vals = [-0.5 * np.ones(n_predictors), -0.5 * X.T.dot(X),
                              y.dot(X)]
    self.assertTrue(_match_values(natural_parameter_vals,
                                  correct_parameter_vals))

    conditional_factory = complete_conditional(log_joint, 1, SupportTypes.REAL,
                                               X, beta, y)
    conditional = conditional_factory(X, y)
    true_cov = np.linalg.inv(X.T.dot(X) + np.eye(n_predictors))
    true_mean = true_cov.dot(y.dot(X))
    self.assertTrue(np.allclose(true_cov, conditional.cov))
    self.assertTrue(np.allclose(true_mean, conditional.mean))
def testLinearRegression(self):
    """Check canonicalize on a squared-error loss written with einsum.

    The traced expression is not canonical, becomes canonical after
    canonicalize, and evaluates to (almost) the same value before and after.
    """
    # NOTE(review): this source contains other methods that are also named
    # testLinearRegression; if they live in the same class only the last
    # definition runs -- confirm.
    def squared_loss(X, beta, y):
        predictions = np.einsum('ij,j->i', X, beta)
        errors = y - predictions
        return np.einsum('k,k->', errors, errors)

    n_examples = 10
    n_predictors = 2
    X = np.random.randn(n_examples, n_predictors)
    beta = np.random.randn(n_predictors)
    y = np.random.randn(n_examples)

    def evaluate(expression):
        # A fresh environment dict is built for every evaluation.
        return eval_expr(expression, {'X': X, 'beta': beta, 'y': y})

    expr = make_expr(squared_loss, X, beta, y)
    correct_value = evaluate(expr)
    self.assertFalse(is_canonical(expr))

    expr = canonicalize(expr)
    self.assertTrue(is_canonical(expr))
    new_value = evaluate(expr)
    self.assertAlmostEqual(correct_value, new_value)
def testEinsumCompose(self):
    """Check canonicalize on an einsum applied to the results of einsums.

    After canonicalization the value is (almost) unchanged, the result is a
    GraphExpr whose root is np.einsum, and the expression is canonical.
    """
    def Xbeta_squared(X, beta):
        Xbeta = np.einsum('ij,j->i', X, beta)
        # NOTE(review): Xbeta2 is computed but never used; the return below
        # passes Xbeta twice. Possibly Xbeta2 was meant as the second
        # operand -- confirm before changing, since it alters the traced
        # graph.
        Xbeta2 = np.einsum('lm,m->l', X, beta)
        return np.einsum('k,k->', Xbeta, Xbeta)

    n_examples = 10
    n_predictors = 2
    X = np.random.randn(n_examples, n_predictors)
    beta = np.random.randn(n_predictors)

    def evaluate(expression):
        # A fresh environment dict is built for every evaluation.
        return eval_expr(expression, {'X': X, 'beta': beta})

    expr = make_expr(Xbeta_squared, X, beta)
    correct_value = evaluate(expr)
    self.assertFalse(is_canonical(expr))

    expr = canonicalize(expr)
    new_value = evaluate(expr)
    self.assertAlmostEqual(correct_value, new_value)
    self.assertIsInstance(expr, GraphExpr)
    self.assertEqual(expr.expr_node.fun, np.einsum)
    self.assertTrue(is_canonical(expr))
def testLinearRegression(self):
    """Check conjugacy machinery on a linear-regression log joint.

    Covers, in order: the sufficient statistics found for beta in the
    canonicalized graph, the natural parameters obtained by differentiating
    the statistic representation of the log joint, and the mean and
    covariance of the complete conditional for beta.
    """
    # NOTE(review): this source contains other methods that are also named
    # testLinearRegression; if they live in the same class only the last
    # definition runs -- confirm.
    def log_joint(X, beta, y):
        predictions = np.einsum('ij,j->i', X, beta)
        errors = y - predictions
        log_prior = np.einsum('i,i,i->', -0.5 * np.ones_like(beta), beta, beta)
        log_likelihood = np.einsum(',k,k->', -0.5, errors, errors)
        return log_prior + log_likelihood

    n_examples = 10
    n_predictors = 2
    X = np.random.randn(n_examples, n_predictors)
    beta = np.random.randn(n_predictors)
    y = np.random.randn(n_examples)
    graph = make_expr(log_joint, X, beta, y)
    graph = canonicalize(graph)

    # Dict key views cannot be indexed on Python 3, so materialize the names
    # before picking one. Presumably args[1] is 'beta' (free_vars keeping
    # argument order) -- TODO confirm.
    args = list(graph.free_vars.keys())
    sufficient_statistic_nodes = find_sufficient_statistic_nodes(graph, args[1])
    sufficient_statistics = [
        eval_node(node, graph.free_vars, {'X': X, 'beta': beta, 'y': y})
        for node in sufficient_statistic_nodes]
    correct_sufficient_statistics = [
        -0.5 * beta.dot(beta), beta,
        -0.5 * np.einsum('ij,ik,j,k', X, X, beta, beta)
    ]
    self.assertTrue(_match_values(sufficient_statistics,
                                  correct_sufficient_statistics))

    new_log_joint, _, stats_funs, _ = (
        statistic_representation(log_joint, (X, beta, y),
                                 (SupportTypes.REAL,), (1,)))
    beta_stat_fun = stats_funs[0]
    beta_natparam = grad_namedtuple(new_log_joint, 1)(X, beta_stat_fun(beta), y)
    correct_beta_natparam = (-0.5 * X.T.dot(X), y.dot(X),
                             -0.5 * np.ones(n_predictors))
    self.assertTrue(_match_values(beta_natparam, correct_beta_natparam))

    conditional_factory = complete_conditional(log_joint, 1, SupportTypes.REAL,
                                               X, beta, y)
    conditional = conditional_factory(X, y)
    true_cov = np.linalg.inv(X.T.dot(X) + np.eye(n_predictors))
    true_mean = true_cov.dot(y.dot(X))
    self.assertTrue(np.allclose(true_cov, conditional.cov))
    self.assertTrue(np.allclose(true_mean, conditional.mean))
def testEvalPerturbed(self):
    """Check _eval_perturbed with a perturbation of 3 on a single node.

    With x = 4 and y = 5, perturbing the x ** 2 node must give
    2 * (4**2 + 3) + 5, and perturbing the free variable x must give
    2 * (4 + 3)**2 + 5.
    """
    def f(x, y):
        return 2 * x**2 + y

    expr = tracers.make_expr(f, 1, 2)

    square_node = expr.expr_node.parents[0].parents[0]  # x ** 2
    perturbed_square = tracers._eval_perturbed(
        expr, {square_node: 3}, {'x': 4, 'y': 5})
    self.assertEqual(2 * (4**2 + 3) + 5, perturbed_square)

    x_node = expr.free_vars['x']
    perturbed_input = tracers._eval_perturbed(
        expr, {x_node: 3}, {'x': 4, 'y': 5})
    self.assertEqual(2 * (4 + 3)**2 + 5, perturbed_input)
def testSplitEinsumNode2(self):
    """Exercise split_einsum_node2 on several traced einsum expressions.

    For each expression and each list of argument indices, the returned
    potential node must evaluate to the value of the original function. The
    returned stat node is checked either by value or, in the final cases
    (which pass False as a third positional argument), by the einsum formula
    stored in its first argument.
    """
    n_dimensions = 5
    x = np.random.randn(n_dimensions)
    y = np.random.randn(n_dimensions)
    matrix = np.random.randn(n_dimensions, n_dimensions)
    env = {'x': x, 'y': y, 'matrix': matrix}

    def check_values(graph, indices, expected_stat, expected_val):
        # Split on `indices`, then compare both returned nodes by value.
        potential_node, stat_node = split_einsum_node2(graph.expr_node, indices)
        self.assertTrue(
            np.allclose(eval_node(stat_node, graph, env), expected_stat))
        self.assertTrue(
            np.allclose(eval_node(potential_node, graph, env), expected_val))

    def check_formula(graph, indices, expected_formula, expected_val):
        # Split with the extra False argument and compare the stat formula.
        potential_node, stat_node = split_einsum_node2(
            graph.expr_node, indices, False)
        self.assertEqual(stat_node.args[0], expected_formula)
        self.assertTrue(
            np.allclose(eval_node(potential_node, graph, env), expected_val))

    args = (x, y)
    f = lambda x, y: np.einsum('i,i->', x, y)
    node = make_expr(f, *args)
    val = f(*args)
    check_values(node, [0], x, val)
    check_values(node, [1], y, val)
    check_values(node, [0, 1], x * y, val)

    args = (x, y)
    f = lambda x, y: np.einsum('i,i,i->', x, y, y)
    node = make_expr(f, *args)
    val = f(*args)
    check_values(node, [1, 2], y * y, val)
    check_values(node, [0], x, val)

    args = (x,)
    f = lambda x: np.einsum('i,i,i->', np.ones_like(x), x, x)
    node = make_expr(f, *args)
    val = f(*args)
    check_values(node, [1, 2], x * x, val)

    args = (matrix, x, y)
    f = lambda matrix, x, y: np.einsum('ij,i,j->', matrix, x, y)
    node = make_expr(f, *args)
    val = f(*args)
    check_values(node, [1, 2], np.outer(x, y), val)
    check_values(node, [0], matrix, val)

    args = (matrix, x, y)
    f = lambda matrix, x, y: np.einsum('i,j,ki,kj->', x, x, matrix, matrix)
    node = make_expr(f, *args)
    val = f(*args)
    check_values(node, [2, 3],
                 matrix[:, None, :] * matrix[:, :, None], val)
    check_values(node, [0, 1], np.outer(x, x), val)

    args = (matrix, x, y)
    f = lambda matrix, x, y: np.einsum(',kj,j,ka,a->', -0.5, matrix, x,
                                       matrix, y)
    node = make_expr(f, *args)
    val = f(*args)
    check_formula(node, [2, 4], 'j,a->ja', val)
    check_formula(node, [0, 1, 3], ',kj,ka->kja', val)