Exemple #1
0
    def testFindHintedOutputNodes(self):
        """Check that the expected number of hinted output nodes is found.

        Two OpHint-wrapped `math_ops.mul` ops are built in a fresh graph, and
        the number of nodes returned by
        `op_hint.find_all_hinted_output_nodes` is compared with the number of
        hinted outputs that were created.
        """
        with ops.Graph().as_default():

            def _hinted_product(hint_name, lhs, rhs):
                # Route both operands and the product through one OpHint.
                hint = op_hint.OpHint(hint_name)
                wrapped_lhs = hint.add_input(lhs)
                wrapped_rhs = hint.add_input(rhs)
                product = math_ops.mul(wrapped_lhs, wrapped_rhs)
                return hint.add_output(product)

            # The constants are created in the same order as the call
            # arguments were, so the ops are added to the graph in the same
            # sequence.
            one = array_ops.constant([1.])
            two = array_ops.constant([2.])
            first_output = _hinted_product("custom_op_1", one, two)

            three = array_ops.constant([3.])
            four = array_ops.constant([4.])
            second_output = _hinted_product("custom_op_2", three, four)

            with self.cached_session() as sess:
                found_nodes = op_hint.find_all_hinted_output_nodes(sess)
                expected_nodes = [
                    _node_name(tensor.name)
                    for tensor in (first_output, second_output)
                ]
                # Only the counts are compared, not the node names.
                self.assertEqual(len(found_nodes), len(expected_nodes))
  def testFindHintedOutputNodes(self):
    """Check that the expected number of hinted output nodes is found.

    Two OpHint-wrapped `math_ops.mul` ops are built, and the number of nodes
    returned by `op_hint.find_all_hinted_output_nodes` is compared with the
    number of hinted outputs that were created.
    """

    def _hinted_product(hint_name, lhs, rhs):
      # Route both operands and the product through one OpHint.
      hint = op_hint.OpHint(hint_name)
      wrapped_lhs = hint.add_input(lhs)
      wrapped_rhs = hint.add_input(rhs)
      product = math_ops.mul(wrapped_lhs, wrapped_rhs)
      return hint.add_output(product)

    # The constants are created in the same order as the call arguments were,
    # so the ops are added to the graph in the same sequence.
    one = array_ops.constant([1.])
    two = array_ops.constant([2.])
    first_output = _hinted_product("custom_op_1", one, two)

    three = array_ops.constant([3.])
    four = array_ops.constant([4.])
    second_output = _hinted_product("custom_op_2", three, four)

    with self.cached_session() as sess:
      found_nodes = op_hint.find_all_hinted_output_nodes(sess)
      expected_nodes = [
          _node_name(tensor.name) for tensor in (first_output, second_output)
      ]
      # Only the counts are compared, not the node names.
      self.assertEqual(len(found_nodes), len(expected_nodes))
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
    """Fuse subgraph between input_nodes and output_nodes into a single custom op.

    The returned graph is assembled in this order: the nodes up to and
    including `input_nodes` (sorted by the sequence number reported by
    `_extract_graph_summary`), the new fused node, a copy of each output node
    rewired to read from the fused node, and finally the nodes that are not
    upstream of any output node.

    Args:
      graph_def: A graph_pb2.GraphDef proto.
      input_nodes: input nodes to the subgraph to be fused.
      output_nodes: output nodes to the subgraph to be fused.
      output_dtypes: A list of output datatypes for the custom op
      output_quantized: A boolean flag that indicates if output is quantized
      op_name: fused op name.
      op_type: fused op type.

    Returns:
      The GraphDef of the new graph.

    Raises:
      TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, if
        'input_nodes' or 'output_nodes' is a string instead of a list, or if
        a node inside the fused region depends on a node upstream of
        'input_nodes' that is not itself listed in 'input_nodes'.
      AssertionError: If an output node does not have exactly one input.
    """
    # Function-scope import: only the dependency check below needs it.
    import collections

    if not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

    if isinstance(input_nodes, six.string_types):
        raise TypeError("input_nodes must be a list.")

    if isinstance(output_nodes, six.string_types):
        raise TypeError("output_nodes must be a list.")

    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

    # Nodes upto and including input_nodes
    reachable_by_input = _bfs_for_reachable_nodes(input_nodes,
                                                  name_to_input_name)
    # Nodes upto and including output_nodes
    reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
                                                   name_to_input_name)

    input_nodes_set = set(input_nodes)
    output_nodes_set = set(output_nodes)

    nodes_post_output = []
    # Nodes already checked by an earlier traversal. A traversal that finished
    # without raising proved that everything upstream of the nodes it visited
    # is acceptable, so later traversals can skip them. This keeps the whole
    # check linear in the graph size instead of re-expanding shared ancestors
    # once per path, and it does not change which error is raised first.
    verified = set()
    for node in graph_def.node:
        n = _node_name(node.name)
        if n not in reachable_by_output:
            nodes_post_output.append(n)
            continue
        if n in reachable_by_input or n in output_nodes_set:
            continue

        # n is between input and output, i.e., part of the fused op. Walk its
        # inputs breadth-first and make sure the walk only leaves the fused
        # region through one of the declared input nodes.
        pending = collections.deque([n])
        while pending:
            cur_node = pending.popleft()
            if cur_node in verified:
                continue
            verified.add(cur_node)
            if cur_node in input_nodes_set:
                # Declared boundary of the fused region: stop expanding here.
                continue
            if cur_node in reachable_by_input:
                raise TypeError(
                    "Node %s uses input %s not in input_nodes." %
                    (n, cur_node))
            pending.extend(name_to_input_name[cur_node])

    # Add all nodes upto the input nodes
    out = graph_pb2.GraphDef()
    reachable_by_input_sorted = sorted(reachable_by_input,
                                       key=lambda n: name_to_seq_num[n])
    for node in reachable_by_input_sorted:
        out.node.extend([copy.deepcopy(name_to_node[node])])

    # Add the custom op
    new_node = node_def_pb2.NodeDef()
    new_node.input.extend(input_nodes)
    new_node.attr["_output_types"].list.type[:] = output_dtypes
    new_node.attr["_output_quantized"].b = output_quantized
    new_node.op = op_type
    new_node.name = op_name
    out.node.extend([new_node])

    # Add the nodes in the output of the custom op
    for index, n in enumerate(output_nodes):
        assert len(name_to_node[n].input) == 1
        new_node = copy.deepcopy(name_to_node[n])
        del new_node.input[:]
        # Output 0 is addressed by the bare op name, output k>0 by "name:k".
        new_node.input.append(op_name +
                              (":" + str(index) if index != 0 else ""))
        out.node.extend([new_node])

    # Add the nodes post output_nodes
    for n in nodes_post_output:
        out.node.extend([copy.deepcopy(name_to_node[n])])

    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)
    return out
Exemple #4
0
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
  """Fuse subgraph between input_nodes and output_nodes into a single custom op.

  The returned graph is assembled in this order: the nodes up to and including
  `input_nodes` (sorted by the sequence number reported by
  `_extract_graph_summary`), the new fused node, a copy of each output node
  rewired to read from the fused node, and finally the nodes that are not
  upstream of any output node.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    input_nodes: input nodes to the subgraph to be fused.
    output_nodes: output nodes to the subgraph to be fused.
    output_dtypes: A list of output datatypes for the custom op
    output_quantized: A boolean flag that indicates if output is quantized
    op_name: fused op name.
    op_type: fused op type.

  Returns:
    The GraphDef of the new graph.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, if
      'input_nodes' or 'output_nodes' is a string instead of a list, or if a
      node inside the fused region depends on a node upstream of 'input_nodes'
      that is not itself listed in 'input_nodes'.
    AssertionError: If an output node does not have exactly one input.
  """
  # Function-scope import: only the dependency check below needs it.
  import collections

  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

  if isinstance(input_nodes, six.string_types):
    raise TypeError("input_nodes must be a list.")

  if isinstance(output_nodes, six.string_types):
    raise TypeError("output_nodes must be a list.")

  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

  # Nodes upto and including input_nodes
  reachable_by_input = _bfs_for_reachable_nodes(input_nodes, name_to_input_name)
  # Nodes upto and including output_nodes
  reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
                                                 name_to_input_name)

  input_nodes_set = set(input_nodes)
  output_nodes_set = set(output_nodes)

  nodes_post_output = []
  # Nodes already checked by an earlier traversal. A traversal that finished
  # without raising proved that everything upstream of the nodes it visited is
  # acceptable, so later traversals can skip them. This keeps the whole check
  # linear in the graph size instead of re-expanding shared ancestors once per
  # path, and it does not change which error is raised first.
  verified = set()
  for node in graph_def.node:
    n = _node_name(node.name)
    if n not in reachable_by_output:
      nodes_post_output.append(n)
      continue
    if n in reachable_by_input or n in output_nodes_set:
      continue

    # n is between input and output, i.e., part of the fused op. Walk its
    # inputs breadth-first and make sure the walk only leaves the fused region
    # through one of the declared input nodes.
    pending = collections.deque([n])
    while pending:
      cur_node = pending.popleft()
      if cur_node in verified:
        continue
      verified.add(cur_node)
      if cur_node in input_nodes_set:
        # Declared boundary of the fused region: stop expanding here.
        continue
      if cur_node in reachable_by_input:
        raise TypeError("Node %s uses input %s not in input_nodes." %
                        (n, cur_node))
      pending.extend(name_to_input_name[cur_node])

  # Add all nodes upto the input nodes
  out = graph_pb2.GraphDef()
  reachable_by_input_sorted = sorted(
      reachable_by_input, key=lambda n: name_to_seq_num[n])
  for node in reachable_by_input_sorted:
    out.node.extend([copy.deepcopy(name_to_node[node])])

  # Add the custom op
  new_node = node_def_pb2.NodeDef()
  new_node.input.extend(input_nodes)
  new_node.attr["_output_types"].list.type[:] = output_dtypes
  new_node.attr["_output_quantized"].b = output_quantized
  new_node.op = op_type
  new_node.name = op_name
  out.node.extend([new_node])

  # Add the nodes in the output of the custom op
  for index, n in enumerate(output_nodes):
    assert len(name_to_node[n].input) == 1
    new_node = copy.deepcopy(name_to_node[n])
    del new_node.input[:]
    # Output 0 is addressed by the bare op name, output k>0 by "name:k".
    new_node.input.append(op_name + (":" + str(index) if index != 0 else ""))
    out.node.extend([new_node])

  # Add the nodes post output_nodes
  for n in nodes_post_output:
    out.node.extend([copy.deepcopy(name_to_node[n])])

  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)
  return out