def test_unique_graph(self):
    """Test for pge.util.check_graphs and pge.util.get_unique_graph.

    Two TensorFlow graphs are built with identical node names ("a" and "b").
    Nodes taken from a single ``pge.Graph`` must be accepted, while a mix of
    nodes from both graphs must be rejected with ``ValueError``.
    """
    g0_graph = tf.Graph()
    with g0_graph.as_default():
        tf.constant(1, name="a")
        tf.constant(2, name="b")
    g1_graph = tf.Graph()
    with g1_graph.as_default():
        tf.constant(1, name="a")
        tf.constant(2, name="b")
    # Node names are the same in both graphs, so only graph membership
    # distinguishes a0/b0 from a1/b1.
    g0 = pge.Graph(g0_graph.as_graph_def())
    g1 = pge.Graph(g1_graph.as_graph_def())
    a0, b0, a1, b1 = (g0["a"], g0["b"], g1["a"], g1["b"])

    # Same graph, should be fine.
    self.assertIsNone(pge.util.check_graphs(a0, b0))
    # Two different graphs, should raise.
    with self.assertRaises(ValueError):
        pge.util.check_graphs(a0, b0, a1, b1)
    # a0 and b0 belong to the same graph, which is returned.
    self.assertEqual(pge.util.get_unique_graph([a0, b0]), g0)
    # Different graphs, should raise an error.
    with self.assertRaises(ValueError):
        pge.util.get_unique_graph([a0, b0, a1, b1])
def test_placeholder(self):
    """Test placeholder functionalities.

    Checks the names returned by ``pge.util.placeholder_name`` (with a tensor
    or ``None``, with and without a scope), then checks the names of the
    placeholders created by ``make_placeholder_from_tensor`` and
    ``make_placeholder_from_dtype_and_shape``.
    """
    g0_graph = tf.Graph()
    with g0_graph.as_default():
        tf.constant(1, name="foo")
    g0 = pge.Graph(g0_graph)
    a0 = g0["foo"].output(0)

    # Test placeholder name.
    self.assertEqual(pge.util.placeholder_name(a0), "geph__foo_0")
    self.assertEqual(pge.util.placeholder_name(None), "geph")
    # A scope with or without a trailing "/" must give the same result.
    self.assertEqual(pge.util.placeholder_name(a0, scope="foo/"), "foo/geph__foo_0")
    self.assertEqual(pge.util.placeholder_name(a0, scope="foo"), "foo/geph__foo_0")
    self.assertEqual(pge.util.placeholder_name(None, scope="foo/"), "foo/geph")
    self.assertEqual(pge.util.placeholder_name(None, scope="foo"), "foo/geph")

    # Test placeholder creation.
    g1_graph = tf.Graph()
    with g1_graph.as_default():
        tf.constant(1, dtype=tf.float32, name="a1")
    g1 = pge.Graph(g1_graph)
    a1_tensor = g1["a1"].output(0)
    ph1 = pge.util.make_placeholder_from_tensor(g1, a1_tensor)
    ph2 = pge.util.make_placeholder_from_dtype_and_shape(g1, dtype=tf.float32)
    self.assertEqual(ph1.name, "geph__a1_0")
    self.assertEqual(ph2.name, "geph")
def test_transform_nodedef_fn(self):
    """Check that a custom ``nodedef_fn`` is honoured by the copy handler.

    The callback removes the ``_foo`` attribute (when present) and sets
    ``_bar`` to ``b"bar"``. After copying ``self.graph``, the source "Const"
    node must still report ``_foo`` while its copy must not. Every copied
    node must report ``_bar``.
    """
    transformer = pge.Transformer()

    def strip_foo_tag_bar(node_def):
        # Mutates the NodeDef in place and hands it back to the handler.
        if "_foo" in node_def.attr:
            del node_def.attr["_foo"]
        node_def.attr["_bar"].s = b"bar"
        return node_def

    handler = functools.partial(pge.transform.copy_op_handler,
                                nodedef_fn=strip_foo_tag_bar)
    transformer.transform_op_handler = handler
    copied = pge.Graph()
    transformer(self.graph, copied, "", "")

    # Look both nodes up before asserting, so a missing node surfaces as an
    # error rather than being swallowed by assertRaises below.
    source_const = self.graph["Const"]
    copied_const = copied["Const"]
    self.assertEqual(source_const.get_attr("_foo"), "foo")
    with self.assertRaises(ValueError):
        copied_const.get_attr("_foo")
    for node in copied.nodes:
        self.assertEqual(node.get_attr("_bar"), "bar")
def setUp(self):
    """Build the fixture graph used by the transform tests.

    Three constants are all requested under the name "Const". Only the first
    one gets the ``_foo`` attribute, and it is the node the tests later fetch
    as "Const" (TensorFlow presumably uniquifies the later duplicates). The
    constants and "Input" are summed into the output node "o", which is
    stored on ``self.o``.
    """
    tf_graph = tf.Graph()
    with tf_graph.as_default():
        first = tf.constant(1.0, shape=[10], name="Const")
        first.op._set_attr("_foo", tf.AttrValue(s=b"foo"))
        second = tf.constant(1.0, shape=[10], name="Const")
        third = tf.constant(1.0, shape=[10], name="Const")
        source = tf.constant(1.0, shape=[10], name="Input")
        # Keep this creation order: the unnamed adds get auto-generated names.
        running = tf.add(first, source)
        running = tf.add(second, running)
        tf.add(third, running, name="o")
    self.graph = pge.Graph(tf_graph)
    self.o = self.graph["o"]
def test_make_list_of_node(self):
    """Test for pge.util.make_list_of_op.

    Accepts either a whole ``pge.Graph`` or a tuple of its nodes; both inputs
    here hold the same two nodes.
    """
    tf_graph = tf.Graph()
    with tf_graph.as_default():
        tf.constant(1, name="a0")
        tf.constant(2, name="b0")
    graph = pge.Graph(tf_graph)

    # Should extract the ops from the graph.
    ops_from_graph = pge.util.make_list_of_op(graph)
    self.assertEqual(len(ops_from_graph), 2)
    # Should extract the ops from the tuple.
    node_tuple = (graph["a0"], graph["b0"])
    ops_from_tuple = pge.util.make_list_of_op(node_tuple)
    self.assertEqual(len(ops_from_tuple), 2)
def test_get_generating_consuming(self):
    """Test for pge.util.get_generating_ops and pge.util.get_consuming_ops.

    In the graph ``c0 = a0 + b0``, the inputs a0 and b0 have two generating
    ops and one shared consumer. The output c0 has one generating op and no
    consumers.
    """
    g0_graph = tf.Graph()
    with g0_graph.as_default():
        a0_tensor = tf.constant(1, name="a0")
        b0_tensor = tf.constant(2, name="b0")
        tf.add(a0_tensor, b0_tensor, name="c0")
    g0 = pge.Graph(g0_graph)
    a0 = g0["a0"].output(0)
    b0 = g0["b0"].output(0)
    c0 = g0["c0"].output(0)

    self.assertEqual(len(pge.util.get_generating_ops([a0, b0])), 2)
    # Both inputs feed the same "c0" node, so there is a single consumer.
    self.assertEqual(len(pge.util.get_consuming_ops([a0, b0])), 1)
    self.assertEqual(len(pge.util.get_generating_ops([c0])), 1)
    # c0 is a graph output with no downstream consumers.
    self.assertEqual(pge.util.get_consuming_ops([c0]), [])
def test_control_outputs(self):
    """Test for the pge.util.ControlOutputs class.

    "c0" is created under a control dependency on "x0". The control-output
    map must therefore contain exactly one entry (x0), and that entry must
    hold exactly c0.
    """
    tf_graph = tf.Graph()
    with tf_graph.as_default():
        lhs = tf.constant(1, name="a0")
        rhs = tf.constant(2, name="b0")
        trigger = tf.constant(3, name="x0")
        with tf.control_dependencies([trigger.op]):
            tf.add(lhs, rhs, name="c0")
    graph = pge.Graph(tf_graph)
    x0_node = graph["x0"]
    c0_node = graph["c0"]

    outputs_by_node = pge.util.ControlOutputs(graph).get_all()
    self.assertEqual(len(outputs_by_node), 1)
    x0_outputs = outputs_by_node[x0_node]
    self.assertEqual(len(x0_outputs), 1)
    self.assertIs(next(iter(x0_outputs)), c0_node)
def test_make_list_of_t(self):
    """Test for pge.util.make_list_of_t.

    Covers three inputs: a whole ``pge.Graph`` (three tensors here), an
    explicit tuple of tensors, and a tuple mixing tensors with a node when
    ``ignore_ops=True``.
    """
    g0_graph = tf.Graph()
    with g0_graph.as_default():
        # tf.constant returns a Tensor, so name the locals accordingly.
        a0_tensor = tf.constant(1, name="a0")
        b0_tensor = tf.constant(2, name="b0")
        tf.add(a0_tensor, b0_tensor)
    g0 = pge.Graph(g0_graph)
    a0 = g0["a0"].output(0)
    b0 = g0["b0"].output(0)

    # Should extract the tensors from the graph: a0, b0 and the add output.
    self.assertEqual(len(pge.util.make_list_of_t(g0)), 3)
    # Should extract the tensors from the tuple.
    self.assertEqual(len(pge.util.make_list_of_t((a0, b0))), 2)
    # Should extract the tensors and ignore the ops. a0.operator is presumably
    # the node that produces a0.
    self.assertEqual(
        len(pge.util.make_list_of_t((a0, a0.operator, b0), ignore_ops=True)), 2)