Example #1
0
def add_preprocessing(g, preproc_g):
  # type: (gde.Graph, gde.Graph) -> None
  """
  Splice preprocessing ops in front of a graph's input placeholders.

  Each `Placeholder` op found in `preproc_g` identifies, by name, an input
  of `g` that should be preprocessed. The contents of `preproc_g` are copied
  into `g`, the connections of the original placeholder are rerouted to the
  first output of the matching "<name>_preprocessed" op, and the original
  placeholder is removed. The placeholder copied from `preproc_g` keeps the
  original input name, although its shape and dtype may differ.

  Args:
    g: `gde.Graph` that receives the preprocessing ops. *Modified in place.*
    preproc_g: `gde.Graph` holding the preprocessing ops. For every
      placeholder named `<name>` in this graph, `g` must contain a node
      named `<name>` and `preproc_g` must contain a node named
      "<name>_preprocessed".

  Raises:
    ValueError: If `g` has no node named after one of the placeholders of
      `preproc_g`, or if `preproc_g` has no "<name>_preprocessed" node for
      one of its placeholders. Both checks run before `g` is modified.
  """
  preprocessed_suffix = "_preprocessed"
  stash_prefix = "__original__"

  input_names = [
      placeholder.name
      for placeholder in gde.filter_ops_by_optype(preproc_g, "Placeholder")
  ]

  # Check all requirements first so that a failure leaves `g` untouched.
  for name in input_names:
    if not g.contains_node(name):
      raise ValueError(
          "Preprocessing graph contains a Placeholder called '{}', but "
          "target graph does not have an input Placeholder by that name."
          .format(name))
    output_name = name + preprocessed_suffix
    if not preproc_g.contains_node(output_name):
      raise ValueError(
          "Preprocessing graph contains a Placeholder called '{}', but it "
          "does not have an output node called '{}' to produce the "
          "preprocessed version of that input."
          .format(name, output_name))

  # Move the target's own placeholders out of the way; the bulk copy below
  # brings in placeholders carrying the same names.
  for name in input_names:
    g.rename_node(name, stash_prefix + name)

  gde.copy(preproc_g, g)

  for name in input_names:
    preprocessed_node = g.get_node_by_name(name + preprocessed_suffix)
    stashed_node = g.get_node_by_name(stash_prefix + name)

    # Connections of the stashed placeholder now go to the preprocessed
    # output, after which the stashed placeholder can be dropped.
    gde.reroute_ts(preprocessed_node.output(0), stashed_node.output(0))
    g.remove_node_by_name(stashed_node.name)
  def test_graph_cond(self):
    """Copy a graph containing `tf.cond` and check both branches still run.

    The source graph selects the constant 1 or 2 depending on a boolean
    placeholder. The graph is copied under the "imported" scope, re-imported
    into a fresh TensorFlow graph, and evaluated once per branch.
    """
    source_tf_graph = tf.Graph()
    with source_tf_graph.as_default():
      choice_tensor = tf.placeholder(shape=(), dtype=tf.bool, name="choice")
      branch_value = tf.cond(
          choice_tensor,
          lambda: tf.constant(1),
          lambda: tf.constant(2))
      tf.identity(branch_value, name="result")

    g = gde.Graph(source_tf_graph)
    choice = g["choice"].output(0)
    result = g["result"].output(0)

    copied_g = gde.Graph()
    _, copy_info = gde.copy(g, dst_graph=copied_g, dst_scope="imported")
    copied_result = copy_info.transformed(result)
    copied_choice = copy_info.transformed(choice)

    # (value fed to the placeholder, value the selected branch must yield)
    cases = ((True, 1), (False, 2))

    tf_copied_graph = tf.Graph()
    with tf_copied_graph.as_default():
      tf.import_graph_def(copied_g.to_graph_def(), name="")
      with tf.Session() as sess:
        for fed_value, expected in cases:
          res = sess.run(copied_result.name,
                         feed_dict={copied_choice.name: fed_value})
          self.assertEqual(res, expected)
  def test_graph_while_loop(self):
    """Copy a graph containing `tf.while_loop` and check it still runs.

    The loop starts at index 1 with a running sum of 0 and adds the index to
    the sum while the index is at most the fed limit. With a limit of 10 the
    expected result is 1 + 2 + ... + 10 = 55.
    """
    source_tf_graph = tf.Graph()
    with source_tf_graph.as_default():
      limit = tf.placeholder(dtype=tf.int32, shape=())
      first_index = tf.constant(1)
      initial_sum = tf.constant(0)
      _, total = tf.while_loop(
          cond=lambda i, unused_s: i <= limit,
          body=lambda i, s: (i + 1, s + i),
          loop_vars=[first_index, initial_sum])

    g = gde.Graph(source_tf_graph)
    total_tensor = g[total.op.name].output(0)
    limit_tensor = g[limit.op.name].output(0)

    # The source graph is marked frozen before it is copied.
    g.frozen = True
    copied_graph = gde.Graph()
    _, copy_info = gde.copy(g, dst_graph=copied_graph, dst_scope="imported")
    copied_total_tensor = copy_info.transformed(total_tensor)
    copied_limit_tensor = copy_info.transformed(limit_tensor)

    tf_copied_graph = tf.Graph()
    with tf_copied_graph.as_default():
      tf.import_graph_def(copied_graph.to_graph_def(), name="")
      with tf.Session() as sess:
        n = 10
        feeds = {copied_limit_tensor.name: n}
        sum_val = sess.run(copied_total_tensor.name, feed_dict=feeds)
        self.assertEqual(sum_val, 55)
Example #4
0
def _graft_pre_and_post_processing_to_main_graph(g):
    # type: (gde.Graph) -> None
    """
    Splice the pre- and post-processing subgraphs onto the core graph.

    On the input side, the first node listed in `_INPUT_NODE_NAMES` is
    replaced: its connections are rerouted to the "preprocessed_image" node
    copied from the preprocessing graph, it is removed, and the copied
    "raw_image" node takes over its name.

    On the output side, the original "detection_classes" node is renamed to
    "raw_detection_classes" and feeds the postprocessing graph, whose
    "decoded_detection_classes" node is then renamed to "detection_classes".

    Args:
      g: GDE representation of the core graph. Modified in place.
    """
    # Both subgraphs are built and wrapped in GDE before `g` is changed.
    preprocessing_graph = gde.Graph(_build_preprocessing_graph_def())
    postprocessing_graph = gde.Graph(_build_postprocessing_graph_def())

    # --- Input side: swap the input placeholder for the preprocessing ops.
    input_node_name = _INPUT_NODE_NAMES[0]
    gde.copy(preprocessing_graph, g)
    preprocessed_tensor = g.get_node_by_name("preprocessed_image").output(0)
    old_input_tensor = g.get_node_by_name(input_node_name).output(0)
    gde.reroute_ts(preprocessed_tensor, old_input_tensor)
    g.remove_node_by_name(input_node_name)
    g.rename_node("raw_image", input_node_name)

    # --- Output side: the postprocessing graph goes from a node called
    # "detection_classes" to one called "decoded_detection_classes". The
    # original output is renamed first so the copy does not clash with it;
    # afterwards the decoded output is published under the original name.
    g.rename_node("detection_classes", "raw_detection_classes")
    gde.copy(postprocessing_graph, g)
    raw_classes_tensor = g.get_node_by_name("raw_detection_classes").output(0)
    postproc_input_tensor = g.get_node_by_name("detection_classes").output(0)
    gde.reroute_ts(raw_classes_tensor, postproc_input_tensor)
    g.remove_node_by_name("detection_classes")
    g.rename_node("decoded_detection_classes", "detection_classes")
 def test_copy(self):
     """Copy `self.graph` into an empty graph and check the copy's mapping.

     The copy must contain the same node names as the source. Every source
     node and tensor must map to an element of the destination graph with
     the same name, and map back to the original element.
     """
     dst_graph = gde.Graph()
     _, copy_info = gde.copy(self.graph, dst_graph)

     src_names = {node.name for node in self.graph.nodes}
     dst_names = {node.name for node in dst_graph.nodes}
     self.assertEqual(src_names, dst_names)

     def check_round_trip(src_items, dst_items):
         # Forward mapping lands in the destination, keeps the name, and
         # the reverse mapping returns the source element.
         for item in src_items:
             copied = copy_info.transformed(item)
             self.assertTrue(copied in dst_items)
             self.assertEqual(item.name, copied.name)
             self.assertEqual(copy_info.original(copied), item)

     check_round_trip(self.graph.nodes, dst_graph.nodes)
     check_round_trip(self.graph.tensors, dst_graph.tensors)
Example #6
0
def add_postprocessing(g, postproc_g):
  # type: (gde.Graph, gde.Graph) -> None
  """
  Append postprocessing ops behind one or more output ops of a graph.

  Each `Placeholder` op found in `postproc_g` identifies, by name, an output
  op of `g` whose result should be transformed further. The contents of
  `postproc_g` are copied into `g`, the copied placeholder's connections are
  rerouted to the first output of the original op, and the
  "<name>_postprocessed" op is renamed to the original op's name. The
  original op stays in the graph under the name "__original__<name>".

  After the rewrite the outputs keep their names, although their shape and
  dtype may differ from before.

  Args:
    g: `gde.Graph` that receives the postprocessing ops. *Modified in place.*
    postproc_g: `gde.Graph` holding the postprocessing ops. For every
      placeholder named `<name>` in this graph, `g` must contain a node
      named `<name>` with exactly one output, and `postproc_g` must contain
      a node named "<name>_postprocessed".

  Raises:
    ValueError: If `g` has no node named after one of the placeholders of
      `postproc_g`, if that node does not have exactly one output, or if
      `postproc_g` has no "<name>_postprocessed" node for one of its
      placeholders. All checks run before `g` is modified.
  """
  postprocessed_suffix = "_postprocessed"
  stash_prefix = "__original__"

  output_names = [
      placeholder.name
      for placeholder in gde.filter_ops_by_optype(postproc_g, "Placeholder")
  ]

  # Check all requirements first so that a failure leaves `g` untouched.
  for name in output_names:
    if not g.contains_node(name):
      raise ValueError(
          "Postprocessing graph contains a Placeholder called '{}', but "
          "target graph does not have an op by that name"
          .format(name))
    num_outputs = len(g.get_node_by_name(name).outputs)
    if num_outputs != 1:
      raise ValueError(
          "Output node '{}' of target graph has {} output tensors. "
          "Only one output is supported."
          .format(name, num_outputs))
    result_name = name + postprocessed_suffix
    if not postproc_g.contains_node(result_name):
      raise ValueError(
          "Postprocessing graph contains a Placeholder called '{}', but it "
          "does not have a node called '{}' to produce the postprocessed "
          "version of that output."
          .format(name, result_name))

  # Move the original output ops out of the way; the bulk copy of the
  # postprocessing graph brings in placeholders carrying the same names.
  for name in output_names:
    g.rename_node(name, stash_prefix + name)

  gde.copy(postproc_g, g)

  for name in output_names:
    copied_placeholder = g.get_node_by_name(name)
    original_output = g.get_node_by_name(stash_prefix + name)

    # Connections of the copied placeholder now go to the original graph's
    # output, after which the placeholder can be dropped.
    gde.reroute_ts(original_output.output(0), copied_placeholder.output(0))
    g.remove_node_by_name(copied_placeholder.name)

    # Publish the postprocessed result under the original output's name.
    g.rename_node(name + postprocessed_suffix, name)