def test_copy(self):
    """Copy ``self.graph`` and check every op and tensor maps one-to-one.

    ``ge.copy`` is called with no destination scope. The test then asserts
    that:
      * the destination graph has exactly the same set of operation names;
      * each source op/tensor has a transformed counterpart that lives in the
        destination graph and carries the same name;
      * ``info.original`` maps each counterpart back to its source object.
    """
    graph = tf.Graph()
    _, info = ge.copy(self.graph, graph)

    # Fetch the op lists once and reuse them for both the name-set comparison
    # and the per-op checks below.
    src_ops = self.graph.get_operations()
    dst_ops = graph.get_operations()
    self.assertEqual({op.name for op in src_ops}, {op.name for op in dst_ops})

    for op in src_ops:
        op_ = info.transformed(op)
        # assertIn reports the offending element on failure; assertTrue does not.
        self.assertIn(op_, dst_ops)
        self.assertEqual(op.name, op_.name)
        self.assertEqual(info.original(op_), op)

    src_ts = ge.util.get_tensors(self.graph)
    dst_ts = ge.util.get_tensors(graph)
    for t in src_ts:
        t_ = info.transformed(t)
        self.assertIn(t_, dst_ts)
        self.assertEqual(t.name, t_.name)
        self.assertEqual(info.original(t_), t)
def test_graph_cond(self):
    """A copied ``cond`` graph must still pick the branch its predicate selects.

    Builds ``cond(choice, 1, 2)`` in a fresh graph, copies it into another
    graph under the ``"imported"`` scope, then runs the copy once per
    predicate value: True must yield 1 and False must yield 2.
    """
    src_graph = ops.Graph()
    with src_graph.as_default():
        pred = array_ops.placeholder(shape=(), dtype=dtypes.bool)
        out = control_flow_ops.cond(
            pred,
            lambda: constant_op.constant(1),
            lambda: constant_op.constant(2),
        )

    dst_graph = ops.Graph()
    _, info = ge.copy(src_graph, dst_graph=dst_graph, dst_scope="imported")
    copied_out = info.transformed(out)
    copied_pred = info.transformed(pred)

    # (predicate fed to the copy, expected branch output), checked in order.
    cases = ((True, 1), (False, 2))
    with dst_graph.as_default(), session.Session() as sess:
        for flag, expected in cases:
            got = sess.run(copied_out, feed_dict={copied_pred: flag})
            self.assertEqual(got, expected)
def test_graph_cond(self):
    # Copy a graph containing cond(choice, 1, 2) into a new graph under the
    # "imported" scope and check both branches still evaluate correctly.
    source = ops.Graph()
    with source.as_default():
        flag = array_ops.placeholder(shape=(), dtype=dtypes.bool)
        picked = control_flow_ops.cond(flag,
                                       lambda: constant_op.constant(1),
                                       lambda: constant_op.constant(2))

    target = ops.Graph()
    _sgv, mapping = ge.copy(source, dst_graph=target, dst_scope="imported")
    picked_copy = mapping.transformed(picked)
    flag_copy = mapping.transformed(flag)

    with target.as_default():
        with session.Session() as sess:

            def run_with(value):
                """Evaluate the copied cond output with the predicate fed."""
                return sess.run(picked_copy, feed_dict={flag_copy: value})

            self.assertEqual(run_with(True), 1)
            self.assertEqual(run_with(False), 2)
def test_graph_while_loop(self):
    """A copied ``while_loop`` graph must still compute the sum 1..n.

    The source graph loops ``i`` from 1 while ``i <= max_index``, accumulating
    ``s += i``. After copying it under the ``"imported"`` scope, feeding
    ``max_index = 10`` to the copy must produce 55.
    """
    src = ops.Graph()
    with src.as_default():
        limit = array_ops.placeholder(dtype=dtypes.int32, shape=())
        first_index = constant_op.constant(1)
        initial_sum = constant_op.constant(0)

        def keep_going(i, unused_s):
            return i <= limit

        def step(i, s):
            return (i + 1, s + i)

        _, total = control_flow_ops.while_loop(
            cond=keep_going, body=step, loop_vars=[first_index, initial_sum])

    dst = ops.Graph()
    _, info = ge.copy(src, dst_graph=dst, dst_scope="imported")
    total_copy = info.transformed(total)
    limit_copy = info.transformed(limit)

    n = 10
    expected = sum(range(1, n + 1))  # 1 + 2 + ... + 10 == 55
    with dst.as_default(), session.Session() as sess:
        self.assertEqual(sess.run(total_copy, feed_dict={limit_copy: n}), expected)
def test_graph_while_loop(self):
    # Summing 1..10 inside a while_loop must still give 55 after the graph is
    # copied into a new graph under the "imported" scope.
    original = ops.Graph()
    with original.as_default():
        upper = array_ops.placeholder(dtype=dtypes.int32, shape=tuple())
        _final_index, running_sum = control_flow_ops.while_loop(
            cond=lambda i, unused_s: i <= upper,
            body=lambda i, s: (i + 1, s + i),
            loop_vars=[constant_op.constant(1), constant_op.constant(0)])

    duplicate = ops.Graph()
    _sgv, copy_info = ge.copy(
        original, dst_graph=duplicate, dst_scope="imported")
    sum_in_copy = copy_info.transformed(running_sum)
    upper_in_copy = copy_info.transformed(upper)

    with duplicate.as_default(), session.Session() as sess:
        feed = {upper_in_copy: 10}
        self.assertEqual(sess.run(sum_in_copy, feed_dict=feed), 55)
def test_copy(self):
    """Copying ``self.graph`` into a new graph must preserve every op name."""
    target = tf.Graph()
    ge.copy(self.graph, target)
    expected_names = {op.name for op in self.graph.get_operations()}
    copied_names = {op.name for op in target.get_operations()}
    self.assertEqual(expected_names, copied_names)