def testDuplicatedAddNRewriter(self):
    """Check the result of `replace_duplicated_addn` on `add_n(x, y, z, y)`.

    The function, the rewriter and three random arrays are handed to
    `self._rewriter_test_helper`. The returned expression must be a
    `tracers.GraphExpr` whose root node is an `add_n` with exactly three
    parents, each of which is either a `multiply` or an `env_lookup` node.
    """
    x = npr.randn(4, 3)
    y = npr.randn(3)
    z = npr.randn(1, 3)

    expr = self._rewriter_test_helper(
        lambda x, y, z: tracers.add_n(x, y, z, y),
        rewrites.replace_duplicated_addn,
        x, y, z)

    self.assertIsInstance(expr, tracers.GraphExpr)
    root = expr.expr_node
    self.assertEqual(root.fun.__name__, 'add_n')
    # `y` appears twice in the input, yet only three parents are expected.
    self.assertEqual(len(root.parents), 3)

    allowed_names = ('multiply', 'env_lookup')
    # Collect every name first so all parents are inspected before checking.
    parent_names = [parent.fun.__name__ for parent in root.parents]
    self.assertTrue(all(name in allowed_names for name in parent_names))
def testAddNRewriter(self):
    """Check that add/add_n combinations are rewritten to a root `add_n`.

    Four cases are run through `self._rewriter_test_helper`:
      * `x + add_n(y, z)` and `add_n(x, y) + z` with `replace_add_addn`;
      * `add_n(add_n(x, y), z)` and `add_n(x, add_n(y, z), z)` with
        `replace_addn_addn`, where every parent of the root must
        additionally be an `env_lookup` node.
    """
    x = npr.randn(4, 3)
    y = npr.randn(3)
    z = npr.randn(1, 3)

    def check_rewrite(fn, rewriter, parents_are_lookups=False):
        # Shared assertions: the result is a GraphExpr rooted at `add_n`.
        expr = self._rewriter_test_helper(fn, rewriter, x, y, z)
        self.assertIsInstance(expr, tracers.GraphExpr)
        self.assertEqual(expr.expr_node.fun.__name__, 'add_n')
        if parents_are_lookups:
            # Collect every name first so all parents are inspected.
            parent_names = [
                parent.fun.__name__ for parent in expr.expr_node.parents
            ]
            self.assertTrue(
                all(name == 'env_lookup' for name in parent_names))

    # A plain `+` combined with an add_n on either side.
    check_rewrite(
        lambda x, y, z: x + tracers.add_n(y, z),
        rewrites.replace_add_addn)
    check_rewrite(
        lambda x, y, z: tracers.add_n(x, y) + z,
        rewrites.replace_add_addn)

    # An add_n nested inside another add_n.
    check_rewrite(
        lambda x, y, z: tracers.add_n(tracers.add_n(x, y), z),
        rewrites.replace_addn_addn,
        parents_are_lookups=True)
    check_rewrite(
        lambda x, y, z: tracers.add_n(x, tracers.add_n(y, z), z),
        rewrites.replace_addn_addn,
        parents_are_lookups=True)
def fun(x, a, b):
    """Apply `tracers.logdet` to the `add_n` of two scaled rank-2 operands.

    `x` enters each einsum with an empty subscript, so it acts as a scalar
    multiplier on `a` and on `b` (both indexed as `ab`).
    """
    scaled_a = np.einsum(',ab->ab', x, a)
    scaled_b = np.einsum(',ab->ab', x, b)
    combined = tracers.add_n(scaled_a, scaled_b)
    return tracers.logdet(combined)
def fun(x, a, b):
    """Invert the `add_n` of two scaled rank-2 operands.

    `x` enters each einsum with an empty subscript, so it acts as a scalar
    multiplier on `a` and on `b` (both indexed as `ab`).
    """
    scaled_a = np.einsum(',ab->ab', x, a)
    scaled_b = np.einsum(',ab->ab', x, b)
    combined = tracers.add_n(scaled_a, scaled_b)
    return np.linalg.inv(combined)
def fun(x, a, b):
    """Cube the `add_n` of two scaled rank-1 operands.

    `x` enters each einsum with an empty subscript, so it acts as a scalar
    multiplier on `a` and on `b` (both indexed as `a`).
    """
    scaled_a = np.einsum(',a->a', x, a)
    scaled_b = np.einsum(',a->a', x, b)
    combined = tracers.add_n(scaled_a, scaled_b)
    return combined**3
def fun(a, x, y, z):
    """Take the log of the `add_n` of three scaled, fully-reduced operands.

    Each einsum multiplies a rank-1 operand (`x`, `y` or `z`) by `a`, which
    has an empty subscript, and sums over the only index, giving a scalar.
    """
    term_x = np.einsum(',a->', a, x)
    term_y = np.einsum(',a->', a, y)
    term_z = np.einsum(',a->', a, z)
    combined = tracers.add_n(term_x, term_y, term_z)
    return np.log(combined)
def fun(x, y, z):
    """Multiply `x` by the result of `tracers.add_n(y, z)`."""
    combined = tracers.add_n(y, z)
    return x * combined
def fun(x, y, z):
    """Contract rank-2 `x` with `tracers.add_n(y, z)` over the shared index.

    The einsum spec `'ij,j->i'` is a matrix-vector product.
    """
    combined = tracers.add_n(y, z)
    return np.einsum('ij,j->i', x, combined)
def f(x):
    """Combine two scalar-scaled length-3 vectors with `tracers.add_n`.

    `x` enters each einsum with an empty subscript, so it scales a vector of
    ones in the first term and a vector of twos in the second.
    """
    scaled_ones = np.einsum(',i->i', x, np.ones(3))
    scaled_twos = np.einsum(',j->j', x, 2. * np.ones(3))
    return tracers.add_n(scaled_ones, scaled_twos)