Esempio n. 1
0
def freeze_graph(sess, input_tensors, output_tensors):
  """Returns a frozen GraphDef.

  If OpHints are present in the session's graph, the OpHint graph is converted
  by `_convert_op_hints_if_present` and no Grappler pass is run. If the graph
  is already frozen, the existing `sess.graph_def` is returned unchanged.
  Otherwise a Grappler pass is run to inline the functions in the graph and
  the Variables of the optimized graph are converted to constants.

  Args:
    sess: TensorFlow Session.
    input_tensors: List of input tensors.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
  """
  # Grappler inline function optimization will break OpHints graph
  # transformation, so if OpHints are present, just convert it.
  hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
  if hinted_outputs_nodes:
    return _convert_op_hints_if_present(sess, output_tensors)

  # An already-frozen graph is returned as-is; running Grappler on it first
  # would only produce a result that gets discarded.
  if is_frozen_graph(sess):
    return sess.graph_def

  # Runs a Grappler pass in order to inline any functions in the graph.
  config = get_grappler_config(function_only=True)
  graph_def = run_graph_optimizations(
      sess.graph_def, input_tensors, output_tensors, config, graph=sess.graph)

  # convert_variables_to_constants takes node names, so strip the
  # ":<output index>" suffix. Keeping it (e.g. "op:1") would name a node that
  # does not exist in the graph.
  output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors]
  return tf_graph_util.convert_variables_to_constants(sess, graph_def,
                                                      output_node_names)
Esempio n. 2
0
    def testFindHintedOutputNodes(self):
        """Test if all hinted output nodes are correctly found."""
        with ops.Graph().as_default():

            def _build_ophinted_op(name, input1, input2):
                """Builds `input1 * input2` wrapped in an OpHint named `name`."""
                custom_op = op_hint.OpHint(name)
                input1 = custom_op.add_input(input1)
                input2 = custom_op.add_input(input2)
                output = math_ops.mul(input1, input2)
                return custom_op.add_output(output)

            output_1 = _build_ophinted_op("custom_op_1",
                                          array_ops.constant([1.]),
                                          array_ops.constant([2.]))
            output_2 = _build_ophinted_op("custom_op_2",
                                          array_ops.constant([3.]),
                                          array_ops.constant([4.]))
            with self.cached_session() as sess:
                hinted_outputs_nodes = op_hint.find_all_hinted_output_nodes(
                    sess)
                expected_hinted_output_nodes = [
                    _node_name(output_1.name),
                    _node_name(output_2.name)
                ]
                self.assertEqual(len(hinted_outputs_nodes),
                                 len(expected_hinted_output_nodes))
                # A count check alone would pass even if the wrong nodes were
                # returned, so compare the names as well. The comparison is
                # order-insensitive because this test does not pin the order
                # in which hints are discovered.
                self.assertCountEqual(hinted_outputs_nodes,
                                      expected_hinted_output_nodes)
Esempio n. 3
0
def _convert_op_hints_if_present(sess, output_tensors):
    """Freezes the session's graph and converts its OpHints to stubs.

    Args:
      sess: TensorFlow Session holding an unfrozen graph.
      output_tensors: List of output tensors (only .name is used from this).

    Returns:
      GraphDef with variables converted to constants, OpHints converted by
      `convert_op_hints_to_stubs`, and training nodes removed by
      `tf_graph_util.remove_training_nodes`.

    Raises:
      ValueError: If the session's graph is already frozen.
    """
    if is_frozen_graph(sess):
        raise ValueError("Try to convert op hints, needs unfrozen graph.")
    hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
    # convert_variables_to_constants takes node names; strip any
    # ":<output index>" suffix so tensors other than output 0 still resolve.
    output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors]
    # The hinted output nodes are passed as extra outputs so they are retained
    # in the frozen graph alongside the requested outputs.
    graph_def = tf_graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names + hinted_outputs_nodes)
    graph_def = convert_op_hints_to_stubs(graph_def=graph_def)
    graph_def = tf_graph_util.remove_training_nodes(graph_def)
    return graph_def
Esempio n. 4
0
def _convert_op_hints_if_present(sess, output_tensors):
  """Freezes the session's graph and converts its OpHints to stubs.

  Args:
    sess: TensorFlow Session holding an unfrozen graph.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    GraphDef with variables converted to constants, OpHints converted by
    `convert_op_hints_to_stubs`, and training nodes removed by
    `tf_graph_util.remove_training_nodes`.

  Raises:
    ValueError: If the session's graph is already frozen.
  """
  if is_frozen_graph(sess):
    raise ValueError("Try to convert op hints, needs unfrozen graph.")
  hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
  # convert_variables_to_constants takes node names; strip any
  # ":<output index>" suffix so tensors other than output 0 still resolve.
  output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors]
  # The hinted output nodes are passed as extra outputs so they are retained
  # in the frozen graph alongside the requested outputs.
  graph_def = tf_graph_util.convert_variables_to_constants(
      sess, sess.graph_def, output_node_names + hinted_outputs_nodes)
  graph_def = convert_op_hints_to_stubs(graph_def=graph_def)
  graph_def = tf_graph_util.remove_training_nodes(graph_def)
  return graph_def
Esempio n. 5
0
    def getInferenceResult(self, x, output_class, sess):
        """Runs one MNIST training sample through the model and freezes it.

        Args:
          x: Tensor fed with the sample (presumably the input placeholder).
          output_class: Tensor evaluated to obtain the expected output; its op
            is also kept as an output node of the frozen graph.
          sess: TensorFlow Session used to run and freeze the graph.

        Returns:
          Tuple of (sample input, expected output, frozen GraphDef).
        """
        batch_images, _ = self.mnist.train.next_batch(batch_size=1)
        input_shape = (1, self.time_steps, self.n_input)
        sample_input = np.reshape(batch_images, input_shape)

        feed = {x: sample_input}
        expected_output = sess.run(output_class, feed_dict=feed)

        # Every OpHint output node must survive freezing, together with the
        # model's real output node.
        nodes_to_keep = find_all_hinted_output_nodes(sess)
        nodes_to_keep.append(output_class.op.name)
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, nodes_to_keep)
        return sample_input, expected_output, frozen_graph
  def getInferenceResult(self, x, output_class, sess):
    """Evaluates the model on one MNIST training sample and freezes the graph.

    Args:
      x: Tensor fed with the sample (presumably the input placeholder).
      output_class: Tensor whose value is the expected output; its op is kept
        as an output node when freezing.
      sess: TensorFlow Session used to run and freeze the graph.

    Returns:
      Tuple of (sample input, expected output, frozen GraphDef).
    """
    images, _unused_labels = self.mnist.train.next_batch(batch_size=1)
    sample_input = np.reshape(images,
                              (1, self.time_steps, self.n_input))
    expected_output = sess.run(output_class,
                               feed_dict={x: sample_input})

    # Keep all OpHint output nodes, plus the real output, through freezing.
    keep_nodes = find_all_hinted_output_nodes(sess)
    keep_nodes.append(output_class.op.name)
    frozen_graph = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, keep_nodes)
    result = (sample_input, expected_output, frozen_graph)
    return result
Esempio n. 7
0
def freeze_graph(sess, input_tensors, output_tensors):
    """Returns a frozen GraphDef.

    If the session's graph is already frozen and contains no OpHints, the
    existing `sess.graph_def` is returned unchanged. Otherwise a Grappler pass
    is run to inline the functions in the graph. If OpHints are present, the
    optimized graph is then handed to `_convert_op_hints_if_present`;
    otherwise its Variables are converted to constants.

    Args:
      sess: TensorFlow Session.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors (only .name is used from this).

    Returns:
      Frozen GraphDef.
    """
    # OpHint detection reads the session's graph, not the Grappler output, so
    # it can be done before deciding whether Grappler is needed at all.
    hinted_outputs_nodes = find_all_hinted_output_nodes(sess)

    # An already-frozen graph without OpHints is returned as-is; running the
    # Grappler pass on it would only produce a result that gets discarded.
    if not hinted_outputs_nodes and is_frozen_graph(sess):
        return sess.graph_def

    # Runs a Grappler pass in order to inline any functions in the graph.
    # Aside from inlining any simple function, Grappler will also try to lower
    # while loops into a switch/merge representation, which is undesired for
    # OpHints, so we remove those attributes to prevent Grappler from doing so.
    graph_def = _convert_to_constants.disable_lower_using_switch_merge(
        sess.graph_def)
    config = get_grappler_config(["function"])
    graph_def = run_graph_optimizations(graph_def,
                                        input_tensors,
                                        output_tensors,
                                        config,
                                        graph=sess.graph)

    # If ophints are present, just convert them.
    if hinted_outputs_nodes:
        return _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                            hinted_outputs_nodes)

    # Reaching here means the graph is not frozen. convert_variables_to_constants
    # takes node names, so the ":<output index>" suffix is stripped.
    output_node_names = [
        tensor.name.split(":")[0] for tensor in output_tensors
    ]
    return tf_graph_util.convert_variables_to_constants(
        sess, graph_def, output_node_names)
Esempio n. 8
0
  def testFindHintedOutputNodes(self):
    """Test if all hinted output nodes are correctly found."""

    def _build_ophinted_op(name, input1, input2):
      """Builds `input1 * input2` wrapped in an OpHint named `name`."""
      custom_op = op_hint.OpHint(name)
      input1 = custom_op.add_input(input1)
      input2 = custom_op.add_input(input2)
      output = math_ops.mul(input1, input2)
      return custom_op.add_output(output)

    output_1 = _build_ophinted_op("custom_op_1", array_ops.constant([1.]),
                                  array_ops.constant([2.]))
    output_2 = _build_ophinted_op("custom_op_2", array_ops.constant([3.]),
                                  array_ops.constant([4.]))
    with self.cached_session() as sess:
      hinted_outputs_nodes = op_hint.find_all_hinted_output_nodes(sess)
      expected_hinted_output_nodes = [
          _node_name(output_1.name),
          _node_name(output_2.name)
      ]
      self.assertEqual(
          len(hinted_outputs_nodes), len(expected_hinted_output_nodes))
      # A count check alone would pass even if the wrong nodes were returned,
      # so compare the names as well. The comparison is order-insensitive
      # because this test does not pin the order in which hints are found.
      self.assertCountEqual(hinted_outputs_nodes, expected_hinted_output_nodes)