Example 1
    def testDuplicatedAddNRewriter(self):
        """Checks that replace_duplicated_addn collapses a repeated operand.

        The traced expression is add_n(x, y, z, y), in which `y` appears twice.
        After the rewrite the test expects a top-level add_n with exactly three
        parents. Each parent must be either an input lookup ('env_lookup') or a
        'multiply' node (presumably the duplicated operand scaled by its count).
        """
        x = npr.randn(4, 3)
        y = npr.randn(3)
        z = npr.randn(1, 3)

        expr = self._rewriter_test_helper(
            lambda x, y, z: tracers.add_n(x, y, z, y),
            rewrites.replace_duplicated_addn, x, y, z)
        self.assertIsInstance(expr, tracers.GraphExpr)
        self.assertEqual(expr.expr_node.fun.__name__, 'add_n')
        self.assertEqual(len(expr.expr_node.parents), 3)
        # Check each parent on its own so a failure names the unexpected op,
        # instead of the bare "False is not true" from assertTrue(all([...])).
        for parent in expr.expr_node.parents:
            self.assertIn(parent.fun.__name__, ('multiply', 'env_lookup'))
Example 2
    def testAddNRewriter(self):
        """Checks the rewriters that merge `+` and nested add_n into one add_n.

        replace_add_addn is applied to `x + add_n(y, z)` and `add_n(x, y) + z`.
        In both cases the result must have add_n at the top level.
        replace_addn_addn is applied to add_n calls nested inside another
        add_n. The result must be a top-level add_n whose parents are all
        input lookups ('env_lookup').
        """
        x = npr.randn(4, 3)
        y = npr.randn(3)
        z = npr.randn(1, 3)

        def assert_parents_are_lookups(rewritten):
            # Assert per parent so a failure reports the offending op name,
            # rather than the uninformative result of assertTrue(all([...])).
            for parent in rewritten.expr_node.parents:
                self.assertEqual(parent.fun.__name__, 'env_lookup')

        expr = self._rewriter_test_helper(
            lambda x, y, z: x + tracers.add_n(y, z), rewrites.replace_add_addn,
            x, y, z)
        self.assertIsInstance(expr, tracers.GraphExpr)
        self.assertEqual(expr.expr_node.fun.__name__, 'add_n')

        expr = self._rewriter_test_helper(
            lambda x, y, z: tracers.add_n(x, y) + z, rewrites.replace_add_addn,
            x, y, z)
        self.assertIsInstance(expr, tracers.GraphExpr)
        self.assertEqual(expr.expr_node.fun.__name__, 'add_n')

        expr = self._rewriter_test_helper(
            lambda x, y, z: tracers.add_n(tracers.add_n(x, y), z),
            rewrites.replace_addn_addn, x, y, z)
        self.assertIsInstance(expr, tracers.GraphExpr)
        self.assertEqual(expr.expr_node.fun.__name__, 'add_n')
        assert_parents_are_lookups(expr)

        expr = self._rewriter_test_helper(
            lambda x, y, z: tracers.add_n(x, tracers.add_n(y, z), z),
            rewrites.replace_addn_addn, x, y, z)
        self.assertIsInstance(expr, tracers.GraphExpr)
        self.assertEqual(expr.expr_node.fun.__name__, 'add_n')
        assert_parents_are_lookups(expr)
Example 3
 def fun(x, a, b):
     """Return logdet(x*a + x*b), with each scaled term built by einsum.

     The ',ab->ab' subscripts treat `x` as a scalar and `a`/`b` as matrices.
     The two terms are combined with tracers.add_n.
     """
     scaled_a = np.einsum(',ab->ab', x, a)
     scaled_b = np.einsum(',ab->ab', x, b)
     summed = tracers.add_n(scaled_a, scaled_b)
     return tracers.logdet(summed)
Example 4
 def fun(x, a, b):
     """Return the matrix inverse of x*a + x*b.

     The ',ab->ab' subscripts treat `x` as a scalar and `a`/`b` as matrices.
     The two scaled terms are summed with tracers.add_n.
     """
     left = np.einsum(',ab->ab', x, a)
     right = np.einsum(',ab->ab', x, b)
     return np.linalg.inv(tracers.add_n(left, right))
Example 5
 def fun(x, a, b):
     """Return (x*a + x*b) cubed, elementwise via `**`.

     The ',a->a' subscripts treat `x` as a scalar and `a`/`b` as vectors.
     """
     total = tracers.add_n(
         np.einsum(',a->a', x, a),
         np.einsum(',a->a', x, b),
     )
     return total ** 3
Example 6
 def fun(a, x, y, z):
     """Return log(a*sum(x) + a*sum(y) + a*sum(z)).

     Each ',a->' einsum scales a vector by the scalar `a` and reduces it to a
     scalar. The three results are combined with tracers.add_n.
     """
     terms = [np.einsum(',a->', a, vec) for vec in (x, y, z)]
     return np.log(tracers.add_n(*terms))
Example 7
 def fun(x, y, z):
     """Multiply `x` by the tracers.add_n sum of `y` and `z`."""
     total = tracers.add_n(y, z)
     return x * total
Example 8
 def fun(x, y, z):
     """Return the matrix-vector product of `x` with add_n(y, z).

     The 'ij,j->i' subscripts treat `x` as 2-D and the summed operand as 1-D.
     """
     vec = tracers.add_n(y, z)
     return np.einsum('ij,j->i', x, vec)
Example 9
 def f(x):
     """Return add_n(x*ones(3), x*(2*ones(3))), with `x` contracted as a scalar.

     Each np.ones(3) is created separately, in the same order as before, so
     the sequence of numpy calls stays the same.
     """
     first = np.einsum(',i->i', x, np.ones(3))
     doubled = 2. * np.ones(3)
     second = np.einsum(',j->j', x, doubled)
     return tracers.add_n(first, second)