def add_preprocessing(g, preproc_g):
    # type: (gde.Graph, gde.Graph) -> None
    """
    Add preprocessing ops to a graph.

    Replaces one or more input `Placeholders` in the target graph with
    subgraphs that preprocess the input values prior to feeding them into the
    original graph.

    After performing this rewrite, the inputs of the resulting graph may have a
    different shape and dtype than before, but they will have the same names.

    Args:
      g: `gde.Graph` to which preprocessing should be added. *Modified in
        place.*
      preproc_g: `gde.Graph` containing the preprocessing ops to add.
        For each placeholder in `g` that needs preprocessing, `preproc_g`
        should contain a placeholder with the same name and a second op named
        "<name of placeholder>_preprocessed", where `<name of placeholder>` is
        the name of the Placeholder op.

    Raises:
      ValueError: if a Placeholder in `preproc_g` has no node of the same name
        in `g`, if `preproc_g` lacks the matching "<name>_preprocessed" node,
        or if `g` already contains a node called "<name>_preprocessed" or
        "__original__<name>" (names this rewrite needs for itself). All checks
        run before `g` is modified, so `g` is unchanged when this is raised.
    """
    placeholders = gde.filter_ops_by_optype(preproc_g, "Placeholder")

    def preproc_name(placeholder_name):
        return placeholder_name + "_preprocessed"

    def orig_name(placeholder_name):
        # Temporary name for the target graph's placeholder while the
        # preprocessing graph (whose placeholder reuses that name) is copied in.
        return "__original__" + placeholder_name

    # Validate before modifying the graph, so a bad call leaves `g` intact.
    for p in placeholders:
        if not g.contains_node(p.name):
            raise ValueError("Preprocessing graph contains a Placeholder called "
                             "'{}', but target graph does not have an input "
                             "Placeholder by that name."
                             "".format(p.name))
        if not preproc_g.contains_node(preproc_name(p.name)):
            raise ValueError("Preprocessing graph contains a Placeholder called "
                             "'{}', but it does not have an output node called '{}' "
                             "to produce the preprocessed version of that input."
                             "".format(p.name, preproc_name(p.name)))
        # The steps below rename into, and look up by, these names. If `g`
        # already has a node with one of them (e.g. because preprocessing was
        # applied once before), the rename could clash, or the lookup after the
        # copy could return the pre-existing node instead of the copied one.
        for reserved_name in (preproc_name(p.name), orig_name(p.name)):
            if g.contains_node(reserved_name):
                raise ValueError("Target graph already contains a node called "
                                 "'{}', which conflicts with the name needed to "
                                 "add preprocessing for the Placeholder '{}'."
                                 "".format(reserved_name, p.name))

    # Rename all the target placeholders so we can bulk-copy the preprocessing
    # graph.
    for p in placeholders:
        g.rename_node(p.name, orig_name(p.name))

    # Now it should be safe to copy the preprocessing graph into the original
    # graph.
    gde.copy(preproc_g, g)

    for p in placeholders:
        preproc_p = g.get_node_by_name(preproc_name(p.name))
        orig_p = g.get_node_by_name(orig_name(p.name))
        # Reroute all connections from original placeholder to go to the
        # corresponding output of the preprocessing graph.
        gde.reroute_ts(preproc_p.output(0), orig_p.output(0))
        # Get rid of the original placeholder; the copied placeholder now
        # carries its name and serves as the graph input.
        g.remove_node_by_name(orig_p.name)
def test_graph_cond(self):
    """A graph containing `tf.cond` still selects the right branch after a copy."""
    source_tf_graph = tf.Graph()
    with source_tf_graph.as_default():
        pred = tf.placeholder(shape=(), dtype=tf.bool, name="choice")
        branch_out = tf.cond(pred,
                             lambda: tf.constant(1),
                             lambda: tf.constant(2))
        tf.identity(branch_out, name="result")

    src = gde.Graph(source_tf_graph)
    choice_t = src["choice"].output(0)
    result_t = src["result"].output(0)

    dst = gde.Graph()
    _, info = gde.copy(src, dst_graph=dst, dst_scope="imported")
    new_result = info.transformed(result_t)
    new_choice = info.transformed(choice_t)

    imported_tf_graph = tf.Graph()
    with imported_tf_graph.as_default():
        tf.import_graph_def(dst.to_graph_def(), name="")
        with tf.Session(graph=imported_tf_graph) as sess:
            # Each predicate value must pick its own branch.
            for flag, expected in ((True, 1), (False, 2)):
                out = sess.run(new_result.name,
                               feed_dict={new_choice.name: flag})
                self.assertEqual(out, expected)
def test_graph_while_loop(self):
    """A graph containing `tf.while_loop` computes the same sum after a copy."""
    source_tf_graph = tf.Graph()
    with source_tf_graph.as_default():
        limit = tf.placeholder(dtype=tf.int32, shape=tuple())
        i_init = tf.constant(1)
        acc_init = tf.constant(0)
        # Sums 1..limit.
        _, total = tf.while_loop(
            cond=lambda i, unused_s: i <= limit,
            body=lambda i, s: (i + 1, s + i),
            loop_vars=[i_init, acc_init])

    src = gde.Graph(source_tf_graph)
    total_t = src[total.op.name].output(0)
    limit_t = src[limit.op.name].output(0)
    src.frozen = True

    dst = gde.Graph()
    _, info = gde.copy(src, dst_graph=dst, dst_scope="imported")
    new_total = info.transformed(total_t)
    new_limit = info.transformed(limit_t)

    imported_tf_graph = tf.Graph()
    with imported_tf_graph.as_default():
        tf.import_graph_def(dst.to_graph_def(), name="")
        with tf.Session(graph=imported_tf_graph) as sess:
            upper = 10
            got = sess.run(new_total.name, feed_dict={new_limit.name: upper})
            self.assertEqual(got, 55)
def _graft_pre_and_post_processing_to_main_graph(g):
    # type: (gde.Graph) -> None
    """
    Splice the pre- and post-processing subgraphs onto the core graph.

    Preprocessing: the preprocessing graph is copied into `g`. Consumers of the
    original input node (`_INPUT_NODE_NAMES[0]`) are rerouted to read
    "preprocessed_image", the original input node is deleted, and the copied
    "raw_image" node takes over its name.

    Postprocessing: the original "detection_classes" op is renamed to
    "raw_detection_classes". The postprocessing graph, which goes from
    "detection_classes" to "decoded_detection_classes", is copied in and fed
    from the raw output. The decoded result is then renamed to
    "detection_classes", so callers see decoded classes under the original name.

    Args:
      g: GDE representation of the core graph. Modified in place.
    """
    pre_g = gde.Graph(_build_preprocessing_graph_def())
    post_g = gde.Graph(_build_postprocessing_graph_def())

    # ---- Preprocessing: swap the input node for the preprocessing chain ----
    input_name = _INPUT_NODE_NAMES[0]
    gde.copy(pre_g, g)
    preprocessed = g.get_node_by_name("preprocessed_image").output(0)
    original_input = g.get_node_by_name(input_name).output(0)
    gde.reroute_ts(preprocessed, original_input)
    g.remove_node_by_name(input_name)
    g.rename_node("raw_image", input_name)

    # ---- Postprocessing: decode classes behind the original output name ----
    g.rename_node("detection_classes", "raw_detection_classes")
    gde.copy(post_g, g)
    raw_classes = g.get_node_by_name("raw_detection_classes").output(0)
    copied_entry = g.get_node_by_name("detection_classes").output(0)
    gde.reroute_ts(raw_classes, copied_entry)
    g.remove_node_by_name("detection_classes")
    g.rename_node("decoded_detection_classes", "detection_classes")
def test_copy(self):
    """`gde.copy` reproduces every op and tensor and records the mapping both ways."""
    dst = gde.Graph()
    _, info = gde.copy(self.graph, dst)
    self.assertEqual({node.name for node in self.graph.nodes},
                     {node.name for node in dst.nodes})

    def check_mapping(originals, copies):
        # Every original must map to a same-named member of `copies`, and the
        # reverse mapping must lead back to the original.
        for item in originals:
            clone = info.transformed(item)
            self.assertIn(clone, copies)
            self.assertEqual(item.name, clone.name)
            self.assertEqual(info.original(clone), item)

    check_mapping(self.graph.nodes, dst.nodes)
    check_mapping(self.graph.tensors, dst.tensors)
def add_postprocessing(g, postproc_g):
    # type: (gde.Graph, gde.Graph) -> None
    """
    Add postprocessing ops to a graph.

    The postprocessing ops can replace one or more output operations of the
    original graph with a series of operations that apply additional
    transformations to the output and return the result of the
    transformations.

    After performing this rewrite, the outputs of the resulting graph may have
    a different shape and dtype than before, but they will have the same names.

    Args:
      g: `gde.Graph` to which postprocessing should be added. *Modified in
        place.*
      postproc_g: `gde.Graph` containing the postprocessing ops to add.
        For each op in `g` that needs postprocessing, `postproc_g` should
        contain a placeholder with the same name and a second op named
        "<name of output>_postprocessed", where `<name of output>` is the name
        of the original op.

    Raises:
      ValueError: if a Placeholder in `postproc_g` has no op of the same name
        in `g`, if that op does not have exactly one output tensor, if
        `postproc_g` lacks the matching "<name>_postprocessed" node, or if `g`
        already contains a node called "<name>_postprocessed" or
        "__original__<name>" (names this rewrite needs for itself). All checks
        run before `g` is modified, so `g` is unchanged when this is raised.
    """
    placeholders = gde.filter_ops_by_optype(postproc_g, "Placeholder")

    def postproc_name(placeholder_name):
        return placeholder_name + "_postprocessed"

    def orig_name(placeholder_name):
        # Name the original output op keeps after the rewrite; frees up the
        # original name for the copied placeholder and, later, the
        # postprocessed output.
        return "__original__" + placeholder_name

    # Validate before modifying the graph, so a bad call leaves `g` intact.
    for p in placeholders:
        if not g.contains_node(p.name):
            raise ValueError("Postprocessing graph contains a Placeholder called "
                             "'{}', but target graph does not have an op by that "
                             "name".format(p.name))
        num_outputs = len(g.get_node_by_name(p.name).outputs)
        if num_outputs != 1:
            raise ValueError("Output node '{}' of target graph has {} output "
                             "tensors. Only one output is supported."
                             "".format(p.name, num_outputs))
        if not postproc_g.contains_node(postproc_name(p.name)):
            raise ValueError("Postprocessing graph contains a Placeholder called "
                             "'{}', but it does not have a node called '{}' "
                             "to produce the postprocessed version of that output."
                             "".format(p.name, postproc_name(p.name)))
        # The steps below rename into, and look up by, these names. If `g`
        # already has a node with one of them (e.g. because postprocessing was
        # applied once before), the rename could clash, or the lookup after the
        # copy could return the pre-existing node instead of the copied one.
        for reserved_name in (postproc_name(p.name), orig_name(p.name)):
            if g.contains_node(reserved_name):
                raise ValueError("Target graph already contains a node called "
                                 "'{}', which conflicts with the name needed to "
                                 "add postprocessing for the output '{}'."
                                 "".format(reserved_name, p.name))

    # Rename all the original output ops so we can bulk-copy the
    # postprocessing graph.
    for p in placeholders:
        g.rename_node(p.name, orig_name(p.name))

    # Now it should be safe to copy the postprocessing graph into the original
    # graph.
    gde.copy(postproc_g, g)

    for p in placeholders:
        postproc_input_p = g.get_node_by_name(p.name)
        orig_output_node = g.get_node_by_name(orig_name(p.name))
        # Reroute every consumer of the copied placeholder so it reads the
        # corresponding output of the original graph instead.
        gde.reroute_ts(orig_output_node.output(0), postproc_input_p.output(0))
        # The placeholder no longer has consumers; get rid of it.
        g.remove_node_by_name(postproc_input_p.name)
        # Rename the postprocessed output to the name of the original output
        g.rename_node(postproc_name(p.name), p.name)