Example #1
0
    def test_unique_graph(self):
        """Test for pge.util.check_graphs and pge.util.get_unique_graph.

        Two independent TF graphs are built with identically named constants.
        Nodes drawn from a single graph must be accepted, while a mix of nodes
        from both graphs must be rejected with ValueError.
        """
        g0_graph = tf.Graph()
        with g0_graph.as_default():
            tf.constant(1, name="a")
            tf.constant(2, name="b")
        g1_graph = tf.Graph()
        with g1_graph.as_default():
            tf.constant(1, name="a")
            tf.constant(2, name="b")

        g0 = pge.Graph(g0_graph.as_graph_def())
        g1 = pge.Graph(g1_graph.as_graph_def())
        a0, b0, a1, b1 = g0["a"], g0["b"], g1["a"], g1["b"]

        # Same graph, should be fine.
        self.assertIsNone(pge.util.check_graphs(a0, b0))
        # Two different graphs, should raise.
        with self.assertRaises(ValueError):
            pge.util.check_graphs(a0, b0, a1, b1)
        # a0 and b0 belong to the same graph, should be fine.
        self.assertEqual(pge.util.get_unique_graph([a0, b0]), g0)
        # Nodes from different graphs, should raise an error.
        with self.assertRaises(ValueError):
            pge.util.get_unique_graph([a0, b0, a1, b1])
Example #2
0
    def test_placeholder(self):
        """Test placeholder functionalities.

        Covers pge.util.placeholder_name (with and without a tensor, with and
        without a scope) and the two placeholder factory helpers
        make_placeholder_from_tensor and make_placeholder_from_dtype_and_shape.
        """
        g0_graph = tf.Graph()
        with g0_graph.as_default():
            tf.constant(1, name="foo")

        g0 = pge.Graph(g0_graph)
        a0 = g0["foo"].output(0)

        # Test placeholder name. A tensor yields "geph__<node>_<index>", None
        # yields plain "geph"; a scope is prepended whether or not it already
        # ends with "/".
        self.assertEqual(pge.util.placeholder_name(a0), "geph__foo_0")
        self.assertEqual(pge.util.placeholder_name(None), "geph")
        self.assertEqual(pge.util.placeholder_name(a0, scope="foo/"),
                         "foo/geph__foo_0")
        self.assertEqual(pge.util.placeholder_name(a0, scope="foo"),
                         "foo/geph__foo_0")
        self.assertEqual(pge.util.placeholder_name(None, scope="foo/"),
                         "foo/geph")
        self.assertEqual(pge.util.placeholder_name(None, scope="foo"),
                         "foo/geph")

        # Test placeholder creation.
        g1_graph = tf.Graph()
        with g1_graph.as_default():
            tf.constant(1, dtype=tf.float32, name="a1")

        g1 = pge.Graph(g1_graph)
        a1_tensor = g1["a1"].output(0)

        ph1 = pge.util.make_placeholder_from_tensor(g1, a1_tensor)
        ph2 = pge.util.make_placeholder_from_dtype_and_shape(
            g1, dtype=tf.float32)
        self.assertEqual(ph1.name, "geph__a1_0")
        self.assertEqual(ph2.name, "geph")
Example #3
0
    def test_transform_nodedef_fn(self):
        """Check that a custom nodedef_fn is applied while copying a graph.

        copy_op_handler is wrapped so that every copied NodeDef loses its
        "_foo" attribute and gains "_bar" set to b"bar". The source graph
        (self.graph) must keep "_foo" on its "Const" node.
        """
        transformer = pge.Transformer()

        def strip_foo_add_bar(node_def):
            # Edit the NodeDef's attribute map in place, then hand it back.
            attrs = node_def.attr
            if "_foo" in attrs:
                del attrs["_foo"]
            attrs["_bar"].s = b"bar"
            return node_def

        transformer.transform_op_handler = functools.partial(
            pge.transform.copy_op_handler, nodedef_fn=strip_foo_add_bar)

        copied = pge.Graph()
        transformer(self.graph, copied, "", "")

        original_const = self.graph["Const"]
        copied_const = copied["Const"]
        # The source node still carries "_foo" ...
        self.assertEqual(original_const.get_attr("_foo"), "foo")
        # ... while the copy no longer has it.
        with self.assertRaises(ValueError):
            copied_const.get_attr("_foo")

        # Every node in the copy was given the new "_bar" attribute.
        for node in copied.nodes:
            self.assertEqual(node.get_attr("_bar"), "bar")
Example #4
0
 def setUp(self):
     """Build the small fixture graph stored on self.graph.

     Three constants are requested with the same name "Const", and the first
     one gets an extra "_foo" attribute. Together with an "Input" constant
     they are combined through nested tf.add calls into a final op named "o".
     """
     tf_graph = tf.Graph()
     with tf_graph.as_default():
         # Creation order matters here: TF uniquifies repeated op names, so
         # presumably only this first constant keeps the plain name "Const"
         # (later ones get suffixes) -- confirm against TF naming rules.
         c0 = tf.constant(1.0, shape=[10], name="Const")
         # Non-public (underscore) TF method: attach a custom "_foo" attribute
         # with value b"foo" to the first constant's op.
         c0.op._set_attr("_foo", tf.AttrValue(s=b"foo"))
         c1 = tf.constant(1.0, shape=[10], name="Const")
         c2 = tf.constant(1.0, shape=[10], name="Const")
         i = tf.constant(1.0, shape=[10], name="Input")
         # Computes c2 + (c1 + (c0 + i)); only the outermost add is named "o".
         tf.add(c2, tf.add(c1, tf.add(c0, i)), name="o")
     # Wrap the TF graph in the pge editor representation.
     self.graph = pge.Graph(tf_graph)
     # Keep a handle to the final node "o".
     self.o = self.graph["o"]
Example #5
0
    def test_make_list_of_node(self):
        """Test for pge.util.make_list_of_op.

        The helper is fed first a whole pge.Graph and then a tuple of nodes;
        both inputs describe the same two constants.
        """
        tf_graph = tf.Graph()
        with tf_graph.as_default():
            tf.constant(1, name="a0")
            tf.constant(2, name="b0")
        graph = pge.Graph(tf_graph)

        # A whole graph should expand to all of its ops.
        ops_from_graph = pge.util.make_list_of_op(graph)
        self.assertEqual(len(ops_from_graph), 2)

        # A tuple of nodes should yield one entry per node.
        node_tuple = (graph["a0"], graph["b0"])
        ops_from_tuple = pge.util.make_list_of_op(node_tuple)
        self.assertEqual(len(ops_from_tuple), 2)
Example #6
0
    def test_get_generating_consuming(self):
        """Test for get_generating_ops and get_consuming_ops in pge.util.

        The fixture graph is c0 = a0 + b0, built from two constants.
        """
        g0_graph = tf.Graph()
        with g0_graph.as_default():
            a0_tensor = tf.constant(1, name="a0")
            b0_tensor = tf.constant(2, name="b0")
            tf.add(a0_tensor, b0_tensor, name="c0")
        g0 = pge.Graph(g0_graph)
        a0 = g0["a0"].output(0)
        b0 = g0["b0"].output(0)
        c0 = g0["c0"].output(0)

        # a0 and b0 each come from their own constant node.
        self.assertEqual(len(pge.util.get_generating_ops([a0, b0])), 2)
        # Both are consumed by the single add node c0.
        self.assertEqual(len(pge.util.get_consuming_ops([a0, b0])), 1)
        # c0 is produced by the add node ...
        self.assertEqual(len(pge.util.get_generating_ops([c0])), 1)
        # ... and nothing consumes it.
        self.assertEqual(pge.util.get_consuming_ops([c0]), [])
Example #7
0
    def test_control_outputs(self):
        """Test for the pge.util.ControlOutputs class.

        x0 is made a control dependency of c0, so x0 should be the only node
        with control outputs, and its single control output should be c0.
        """
        tf_graph = tf.Graph()
        with tf_graph.as_default():
            a0 = tf.constant(1, name="a0")
            b0 = tf.constant(2, name="b0")
            x0 = tf.constant(3, name="x0")
            with tf.control_dependencies([x0.op]):
                tf.add(a0, b0, name="c0")

        graph = pge.Graph(tf_graph)
        x0_node = graph["x0"]
        c0_node = graph["c0"]
        outputs_by_node = pge.util.ControlOutputs(graph).get_all()

        self.assertEqual(len(outputs_by_node), 1)
        x0_outputs = outputs_by_node[x0_node]
        self.assertEqual(len(x0_outputs), 1)
        # Exactly one element is present (checked above); it must be c0.
        self.assertIs(next(iter(x0_outputs)), c0_node)
Example #8
0
    def test_make_list_of_t(self):
        """Test for pge.util.make_list_of_t.

        The fixture graph holds two constants and their (unnamed) sum, i.e.
        three output tensors in total.
        """
        g0_graph = tf.Graph()
        with g0_graph.as_default():
            # tf.constant returns tensors, not ops; name them accordingly.
            a0_tensor = tf.constant(1, name="a0")
            b0_tensor = tf.constant(2, name="b0")
            tf.add(a0_tensor, b0_tensor)
        g0 = pge.Graph(g0_graph)
        a0 = g0["a0"].output(0)
        b0 = g0["b0"].output(0)

        # Should extract the tensors from the graph.
        self.assertEqual(len(pge.util.make_list_of_t(g0)), 3)
        # Should extract the tensors from the tuple.
        self.assertEqual(len(pge.util.make_list_of_t((a0, b0))), 2)
        # Should extract the tensors and ignore the ops.
        self.assertEqual(
            len(pge.util.make_list_of_t((a0, a0.operator, b0),
                                        ignore_ops=True)), 2)