Exemplo n.º 1
0
 def test_copy(self):
     """Check that ge.copy reproduces self.graph and records a two-way mapping.

     Copies ``self.graph`` into a fresh graph and verifies that:
       * both graphs contain exactly the same set of op names;
       * for every source op and tensor, ``info.transformed`` returns an
         object present in the destination graph with the same name;
       * ``info.original`` maps that destination object back to the source.
     """
     dst_graph = tf.Graph()
     _, info = ge.copy(self.graph, dst_graph)
     self.assertEqual({op.name for op in self.graph.get_operations()},
                      {op.name for op in dst_graph.get_operations()})

     def check_round_trip(src_items, dst_items):
         # Same checks for ops and tensors: each source item must map to a
         # same-named destination item, and map back to itself.
         for src in src_items:
             dst = info.transformed(src)
             # assertIn has the same membership semantics as
             # assertTrue(dst in dst_items) but names the missing item.
             self.assertIn(dst, dst_items)
             self.assertEqual(src.name, dst.name)
             self.assertEqual(info.original(dst), src)

     check_round_trip(self.graph.get_operations(), dst_graph.get_operations())
     check_round_trip(ge.util.get_tensors(self.graph),
                      ge.util.get_tensors(dst_graph))
Exemplo n.º 2
0
 def test_copy(self):
   """Check that ge.copy reproduces self.graph and records a two-way mapping.

   Copies ``self.graph`` into a fresh graph and verifies that:
     * both graphs contain exactly the same set of op names;
     * for every source op and tensor, ``info.transformed`` returns an object
       present in the destination graph with the same name;
     * ``info.original`` maps that destination object back to the source.
   """
   dst_graph = tf.Graph()
   _, info = ge.copy(self.graph, dst_graph)
   self.assertEqual({op.name for op in self.graph.get_operations()},
                    {op.name for op in dst_graph.get_operations()})

   def check_round_trip(src_items, dst_items):
     # Same checks for ops and tensors: each source item must map to a
     # same-named destination item, and map back to itself.
     for src in src_items:
       dst = info.transformed(src)
       # assertIn has the same membership semantics as
       # assertTrue(dst in dst_items) but names the missing item.
       self.assertIn(dst, dst_items)
       self.assertEqual(src.name, dst.name)
       self.assertEqual(info.original(dst), src)

   check_round_trip(self.graph.get_operations(), dst_graph.get_operations())
   check_round_trip(ge.util.get_tensors(self.graph),
                    ge.util.get_tensors(dst_graph))
Exemplo n.º 3
0
 def test_graph_cond(self):
   """A graph containing cond() still branches correctly after ge.copy.

   Builds cond(pred, 1, 2) on a scalar bool placeholder, copies the graph into
   a new Graph under the "imported" scope, then runs the copied output with the
   copied placeholder fed True and then False.
   """
   src_graph = ops.Graph()
   with src_graph.as_default():
     pred = array_ops.placeholder(shape=(), dtype=dtypes.bool)
     out = control_flow_ops.cond(
         pred,
         lambda: constant_op.constant(1),
         lambda: constant_op.constant(2))
   dst_graph = ops.Graph()
   _, copy_info = ge.copy(src_graph, dst_graph=dst_graph, dst_scope="imported")
   dst_out = copy_info.transformed(out)
   dst_pred = copy_info.transformed(pred)
   # Entering as_default() first means the Session binds to dst_graph.
   with dst_graph.as_default(), session.Session() as sess:
     for pred_value, expected in ((True, 1), (False, 2)):
       self.assertEqual(sess.run(dst_out, feed_dict={dst_pred: pred_value}),
                        expected)
Exemplo n.º 4
0
 def test_graph_cond(self):
     """Check that a copied cond() graph still selects the right branch.

     A scalar bool placeholder drives cond(flag, 1, 2). After ge.copy places the
     graph in a new Graph under the "imported" scope, running the copied output
     with the flag fed True and then False must give 1 and 2 respectively.
     """
     source = ops.Graph()
     with source.as_default():
         flag = array_ops.placeholder(shape=(), dtype=dtypes.bool)
         selected = control_flow_ops.cond(flag,
                                          lambda: constant_op.constant(1),
                                          lambda: constant_op.constant(2))
     target = ops.Graph()
     _, mapping = ge.copy(source, dst_graph=target, dst_scope="imported")
     target_selected = mapping.transformed(selected)
     target_flag = mapping.transformed(flag)
     expectations = {True: 1, False: 2}
     with target.as_default():
         with session.Session() as sess:
             for flag_value in (True, False):
                 got = sess.run(target_selected,
                                feed_dict={target_flag: flag_value})
                 self.assertEqual(got, expectations[flag_value])
Exemplo n.º 5
0
 def test_graph_while_loop(self):
   """A graph containing while_loop() still computes correctly after ge.copy.

   The loop starts with (i=1, s=0), runs while i <= max_index and updates
   (i, s) -> (i + 1, s + i), so s ends as 1 + 2 + ... + max_index. The graph is
   copied under the "imported" scope and the copied sum is evaluated there.
   """
   src_graph = ops.Graph()
   with src_graph.as_default():
     limit = array_ops.placeholder(dtype=dtypes.int32, shape=tuple())
     i_init = constant_op.constant(1)
     acc_init = constant_op.constant(0)
     _, total = control_flow_ops.while_loop(
         cond=lambda i, unused_s: i <= limit,
         body=lambda i, s: (i + 1, s + i),
         loop_vars=[i_init, acc_init])
   dst_graph = ops.Graph()
   _, copy_info = ge.copy(src_graph, dst_graph=dst_graph, dst_scope="imported")
   dst_total = copy_info.transformed(total)
   dst_limit = copy_info.transformed(limit)
   with dst_graph.as_default(), session.Session() as sess:
     n = 10
     expected = n * (n + 1) // 2  # closed form of 1 + ... + n (55 for n=10)
     self.assertEqual(sess.run(dst_total, feed_dict={dst_limit: n}), expected)
Exemplo n.º 6
0
 def test_graph_while_loop(self):
     """Check that a copied while_loop() graph still sums 1..max_index.

     Loop state (i, s) starts at (1, 0), continues while i <= max_index and
     steps to (i + 1, s + i). After copying under the "imported" scope, feeding
     max_index = n into the copy must produce sum(range(1, n + 1)).
     """
     source = ops.Graph()
     with source.as_default():
         upper = array_ops.placeholder(dtype=dtypes.int32,
                                       shape=tuple())
         counter_start = constant_op.constant(1)
         running_start = constant_op.constant(0)
         _, running_sum = control_flow_ops.while_loop(
             cond=lambda i, unused_s: i <= upper,
             body=lambda i, s: (i + 1, s + i),
             loop_vars=[counter_start, running_start])
     target = ops.Graph()
     _, mapping = ge.copy(source,
                          dst_graph=target,
                          dst_scope="imported")
     target_sum = mapping.transformed(running_sum)
     target_upper = mapping.transformed(upper)
     with target.as_default():
         with session.Session() as sess:
             n = 10
             actual = sess.run(target_sum, feed_dict={target_upper: n})
             self.assertEqual(actual, sum(range(1, n + 1)))
Exemplo n.º 7
0
 def test_copy(self):
   """Copy self.graph into a new graph with ge.copy and check op names match."""
   graph = tf.Graph()
   # The return value of ge.copy is not inspected here.
   ge.copy(self.graph, graph)
   # Order-insensitive comparison of the op names in source and destination.
   self.assertEqual(set(op.name for op in self.graph.get_operations()),
                    set(op.name for op in graph.get_operations()))