示例#1
0
def _find_children_hints(call, graph_def):
  """Collect the child OpHints nested inside a parent OpHint.

  Every node of `graph_def` is copied into a fresh GraphDef. On top of that,
  for each "While" node that lies strictly inside the parent hint (reachable
  from the hint outputs, but neither reachable from the hint inputs nor an
  output node itself), the loop-body function def named by its "body" attr is
  handed to `_find_children_hints_in_while_loop`. The hints it reports are
  collected and the nodes it returns are appended to the new GraphDef.

  Args:
    call: Parent OpHint that contains children ophints.
    graph_def: Original graph def.

  Returns:
    A tuple `(children_hints, out, function_def_nodes)` where
      children_hints: list of child hints, in the order they were found.
      out: new GraphDef with copies of all original nodes followed by the
        nodes taken from while-loop body function defs (if any); `library`
        and `versions` are copied from `graph_def`.
      function_def_nodes: set of the names of the nodes that were taken from
        function defs.
  """
  input_map, _, _ = _extract_graph_summary(graph_def)
  hint_inputs, hint_outputs = call.flattened_inputs_and_outputs()

  upstream_of_inputs = _bfs_for_reachable_nodes(hint_inputs, input_map)
  upstream_of_outputs = _bfs_for_reachable_nodes(hint_outputs, input_map)
  hint_output_set = set(hint_outputs)

  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  children_hints = []
  function_def_nodes = set()
  for node in graph_def.node:
    out.node.extend([_copy.deepcopy(node)])
    base_name = _tensor_name_base(node.name)

    # Only nodes strictly between the hint's inputs and outputs matter.
    if base_name not in upstream_of_outputs:
      continue
    if base_name in upstream_of_inputs or base_name in hint_output_set:
      continue
    # Special handling applies to while-loop function defs only.
    if node.op != "While":
      continue

    body_name = node.attr["body"].func.name
    outer_inputs = node.input
    for function_def in graph_def.library.function:
      if function_def.signature.name != body_name:
        continue
      body_args = function_def.signature.input_arg
      assert len(outer_inputs) == len(body_args)
      # Pair each formal argument of the loop body with the tensor that the
      # While node feeds into it from outside the loop.
      nodes_mapping = {
          arg.name: outer_inputs[position]
          for position, arg in enumerate(body_args)
      }
      # TODO(b/123050804): Consider use grappler.
      loop_hints, loop_nodes = _find_children_hints_in_while_loop(
          function_def, nodes_mapping)
      function_def_nodes.update(loop_node.name for loop_node in loop_nodes)
      children_hints.extend(loop_hints)
      out.node.extend(loop_nodes)

  return children_hints, out, function_def_nodes
示例#2
0
def _find_children_hints(call, graph_def):
  """Collect the child OpHints nested inside a parent OpHint.

  Every node of `graph_def` is copied into a fresh GraphDef. On top of that,
  for each "While" node that lies strictly inside the parent hint (reachable
  from the hint outputs, but neither reachable from the hint inputs nor an
  output node itself), the loop-body function def named by its "body" attr is
  handed to `_find_children_hints_in_while_loop`. The hints it reports are
  collected and the nodes it returns are appended to the new GraphDef.

  Args:
    call: Parent OpHint that contains children ophints.
    graph_def: Original graph def.

  Returns:
    A tuple `(children_hints, out, function_def_nodes)` where
      children_hints: list of child hints, in the order they were found.
      out: new GraphDef with copies of all original nodes followed by the
        nodes taken from while-loop body function defs (if any); `library`
        and `versions` are copied from `graph_def`.
      function_def_nodes: set of the names of the nodes that were taken from
        function defs.
  """
  input_map, _, _ = _extract_graph_summary(graph_def)
  hint_inputs, hint_outputs = call.flattened_inputs_and_outputs()

  upstream_of_inputs = _bfs_for_reachable_nodes(hint_inputs, input_map)
  upstream_of_outputs = _bfs_for_reachable_nodes(hint_outputs, input_map)
  hint_output_set = set(hint_outputs)

  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  children_hints = []
  function_def_nodes = set()
  for node in graph_def.node:
    out.node.extend([_copy.deepcopy(node)])
    base_name = _tensor_name_base(node.name)

    # Only nodes strictly between the hint's inputs and outputs matter.
    if base_name not in upstream_of_outputs:
      continue
    if base_name in upstream_of_inputs or base_name in hint_output_set:
      continue
    # Special handling applies to while-loop function defs only.
    if node.op != "While":
      continue

    body_name = node.attr["body"].func.name
    outer_inputs = node.input
    for function_def in graph_def.library.function:
      if function_def.signature.name != body_name:
        continue
      body_args = function_def.signature.input_arg
      assert len(outer_inputs) == len(body_args)
      # Pair each formal argument of the loop body with the tensor that the
      # While node feeds into it from outside the loop.
      nodes_mapping = {
          arg.name: outer_inputs[position]
          for position, arg in enumerate(body_args)
      }
      # TODO(b/123050804): Consider use grappler.
      loop_hints, loop_nodes = _find_children_hints_in_while_loop(
          function_def, nodes_mapping)
      function_def_nodes.update(loop_node.name for loop_node in loop_nodes)
      children_hints.extend(loop_hints)
      out.node.extend(loop_nodes)

  return children_hints, out, function_def_nodes
示例#3
0
  def _getGraphOpTypes(self, graphdef, output_nodes):
    """Collects the op types of every node that `output_nodes` depend on.

    Tests use this to confirm that the expected kinds of nodes are present
    after the stub transformation.

    NOTE: this does not prove that the graph is exactly the right output; it
      trades precision for a compact way to sanity check a test result.

    Args:
      graphdef: TensorFlow proto graphdef.
      output_nodes: A list of output node names that we need to reach.

    Returns:
      A set with the `op` of each node reachable from `output_nodes`.
    """
    input_map, node_by_name, _ = _extract_graph_summary(graphdef)
    # Walk backwards from the outputs to find every node they need.
    needed_names = _bfs_for_reachable_nodes(output_nodes, input_map)
    return {node_by_name[name].op for name in needed_names}
示例#4
0
  def _getGraphOpTypes(self, graphdef, output_nodes):
    """Collects the op types of every node that `output_nodes` depend on.

    Tests use this to confirm that the expected kinds of nodes are present
    after the stub transformation.

    NOTE: this does not prove that the graph is exactly the right output; it
      trades precision for a compact way to sanity check a test result.

    Args:
      graphdef: TensorFlow proto graphdef.
      output_nodes: A list of output node names that we need to reach.

    Returns:
      A set with the `op` of each node reachable from `output_nodes`.
    """
    input_map, node_by_name, _ = _extract_graph_summary(graphdef)
    # Walk backwards from the outputs to find every node they need.
    needed_names = _bfs_for_reachable_nodes(output_nodes, input_map)
    return {node_by_name[name].op for name in needed_names}
示例#5
0
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
    """Fuse subgraph between input_nodes and output_nodes into a single custom op.

    The returned graph holds, in this order: the nodes up to and including
    `input_nodes`, the new fused node (fed by `input_nodes`), copies of
    `output_nodes` rewired to read from the fused node, and the remaining
    nodes that `output_nodes` do not depend on.

    Args:
      graph_def: A graph_pb2.GraphDef proto.
      input_nodes: input nodes to the subgraph to be fused.
      output_nodes: output nodes to the subgraph to be fused. Each of them
        must have exactly one input.
      output_dtypes: A list of output datatypes for the custom op
      output_quantized: A boolean flag that indicates if output is quantized
      op_name: fused op name.
      op_type: fused op type.

    Returns:
      The GraphDef of the new graph.

    Raises:
      TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, if
        `input_nodes` or `output_nodes` is a string instead of a list, or if
        a node of the fused subgraph depends on a node upstream of
        `input_nodes` without passing through one of `input_nodes`.
    """

    if not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

    if isinstance(input_nodes, six.string_types):
        raise TypeError("input_nodes must be a list.")

    if isinstance(output_nodes, six.string_types):
        raise TypeError("output_nodes must be a list.")

    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

    # Nodes upto and including input_nodes
    reachable_by_input = _bfs_for_reachable_nodes(input_nodes,
                                                  name_to_input_name)
    # Nodes upto and including output_nodes
    reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
                                                   name_to_input_name)

    # Set of nodes in the list input_nodes
    input_nodes_set = set(input_nodes)

    # Set of nodes in the list output_nodes
    output_nodes_set = set(output_nodes)

    # Nodes already queued by the closure check below. Sharing the set across
    # all starting nodes keeps the check linear in the size of the graph and
    # makes it terminate even when the fused region contains a cycle. A node
    # in this set was reached by a walk that finished without raising, so
    # walking it again cannot uncover a new violation.
    visited = set()
    nodes_post_output = []
    for node in graph_def.node:
        n = _node_name(node.name)
        if n in reachable_by_output:
            if n not in reachable_by_input and n not in output_nodes_set:
                # n is between input and output, i.e., part of the fused op.
                # Walk breadth-first towards the inputs and make sure that
                # every path ends at one of input_nodes.
                if n in visited:
                    continue
                visited.add(n)
                next_to_visit = [n]
                position = 0
                while position < len(next_to_visit):
                    cur_node = next_to_visit[position]
                    position += 1
                    if cur_node in input_nodes_set:
                        # Declared inputs are the legal boundary of the fused
                        # region; do not walk past them.
                        continue
                    if cur_node in reachable_by_input:
                        raise TypeError(
                            "Node %s uses input %s not in input_nodes." %
                            (n, cur_node))
                    for upstream in name_to_input_name[cur_node]:
                        if upstream not in visited:
                            visited.add(upstream)
                            next_to_visit.append(upstream)
        elif n not in reachable_by_input:
            # Nodes in reachable_by_input are emitted below already; adding
            # them here as well would put the same node into the result twice.
            nodes_post_output.append(n)

    # Add all nodes upto the input nodes
    out = graph_pb2.GraphDef()
    reachable_by_input_sorted = sorted(reachable_by_input,
                                       key=lambda n: name_to_seq_num[n])
    for node in reachable_by_input_sorted:
        out.node.extend([copy.deepcopy(name_to_node[node])])

    # Add the custom op
    new_node = node_def_pb2.NodeDef()
    new_node.input.extend(input_nodes)
    new_node.attr["_output_types"].list.type[:] = output_dtypes
    new_node.attr["_output_quantized"].b = output_quantized
    new_node.op = op_type
    new_node.name = op_name
    out.node.extend([new_node])

    # Add the nodes in the output of the custom op
    for index, n in enumerate(output_nodes):
        # The single input of each output node is replaced by the matching
        # output of the fused op.
        assert len(name_to_node[n].input) == 1
        new_node = copy.deepcopy(name_to_node[n])
        del new_node.input[:]
        # Output 0 is addressed by the bare op name, output k > 0 by "name:k".
        new_node.input.append(op_name +
                              (":" + str(index) if index != 0 else ""))
        out.node.extend([new_node])

    # Add the nodes post output_nodes
    for n in nodes_post_output:
        out.node.extend([copy.deepcopy(name_to_node[n])])

    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)
    return out
示例#6
0
def _convert_single_op_hint_to_stub(call, graph_def):
  """Given a graph_def, converts `call` into a stub and returns a new graph_def.

  The nodes strictly between the inputs and outputs of `call` are left out of
  the result and a single stub node takes their place. The stub's op is
  `call.function_name` and its name is `call.uuid`.

  Args:
    call: A single function call to be converted.
    graph_def: A graph_def to use as input (one that contains `call`).

  Returns:
    A new transformed graph-def that has call as a stub (single op).

  Note: after this process, the graph_def can no longer be loaded into
      the tensorflow runtime, so all future manipulations are done in graph_def
      level.
  """
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  input_names, output_names = call.flattened_inputs_and_outputs()

  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
  reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                 name_to_input_name)
  input_nodes_set = set(input_names)
  output_nodes_set = set(output_names)
  nodes_after_fuse = []
  nodes_deleted_by_fuse = set()
  # Classify each node. Everything reachable by input is kept (it is copied
  # below); nodes reachable by neither input nor output come after the fused
  # region and are decided on at the end.
  for node in graph_def.node:
    n = _tensor_name_base(node.name)
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        # n is an internal node. Check to make sure it is really internal.
        # TODO(aselle): this could be done more efficiently by flooding
        # the graph first.
        _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
                               name_to_input_name)
        nodes_deleted_by_fuse.add(n)
    elif n not in reachable_by_input:
      # n is a node that comes after all the fusings, so keep it.
      nodes_after_fuse.append(n)
    # Otherwise n is reachable by input only and is copied below together
    # with the rest of reachable_by_input.

  # Make a new graphdef with all the pre-input and input nodes
  out = _graph_pb2.GraphDef()
  reachable_by_input_sorted = sorted(
      reachable_by_input, key=lambda n: name_to_seq_num[n])
  for node in reachable_by_input_sorted:
    out.node.extend([_copy.deepcopy(name_to_node[node])])

  # Create any stacks to aggregate arguments into to a single input
  # i.e. for static_rnn's.
  # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
  sorted_input_indices = sorted(call.inputs.keys())
  sorted_output_indices = sorted(call.outputs.keys())
  new_node = _node_def_pb2.NodeDef()
  # Delegate to each operand to produce the proper new input for this stub node.
  # In particular, an aggregate input will now be a Pack of some previously
  # non-fused things.
  for input_index in sorted_input_indices:
    inputs = call.inputs[input_index]
    new_node.input.append(inputs.aggregate_and_return_name_for_input(out))
  new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)

  # Create the function
  new_node.op = call.function_name
  new_node.name = call.uuid
  out.node.extend([new_node])
  # extend() stores a copy of the message, so later writes to `new_node` never
  # reach the graph. Keep a handle on the stored copy so the output attrs set
  # below end up in the returned graph_def.
  stub_node = out.node[-1]

  # Now call each output argument to give them a chance to make the proper
  # output type and add it to the stub node.
  output_dtypes = []
  for output_index in sorted_output_indices:
    output = call.outputs[output_index]
    output_dtype = (
        output.aggregate_and_return_name_for_output(new_node.name, output_index,
                                                    out))
    output_dtypes.append(output_dtype)
  stub_node.attr["_output_types"].list.type[:] = output_dtypes
  # TODO(aselle): what is right here?
  stub_node.attr["_output_quantized"].b = False

  # Add post output nodes that do not depend on the outputs
  for n in nodes_after_fuse:
    # A node that reads directly from a node removed by the fuse cannot be
    # kept, its input no longer exists.
    uses_deleted_node = any(
        input_name in nodes_deleted_by_fuse
        for input_name in name_to_input_name[n])
    if not uses_deleted_node:
      out.node.extend([_copy.deepcopy(name_to_node[n])])

  # Misc. graph_def data that needs copying.
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out
示例#7
0
def _convert_single_op_hint_to_stub(call, graph_def):
    """Given a graph_def, converts `call` into a stub and returns a new graph_def.

    The nodes strictly between the inputs and outputs of `call` are left out
    of the result and a single stub node takes their place. The stub's op is
    `call.function_name` and its name is `call.uuid`.

    Args:
      call: A single function call to be converted.
      graph_def: A graph_def to use as input (one that contains `call`).

    Returns:
      A new transformed graph-def that has call as a stub (single op).

    Note: after this process, the graph_def can no longer be loaded into
        the tensorflow runtime, so all future manipulations are done in
        graph_def level.
    """
    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    input_names, output_names = call.flattened_inputs_and_outputs()

    reachable_by_input = _bfs_for_reachable_nodes(input_names,
                                                  name_to_input_name)
    reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                   name_to_input_name)
    input_nodes_set = set(input_names)
    output_nodes_set = set(output_names)
    nodes_after_fuse = []
    nodes_deleted_by_fuse = set()
    # Classify each node. Everything reachable by input is kept (it is copied
    # below); nodes reachable by neither input nor output come after the fused
    # region and are decided on at the end.
    for node in graph_def.node:
        n = _tensor_name_base(node.name)
        if n in reachable_by_output:
            if n not in reachable_by_input and n not in output_nodes_set:
                # n is an internal node. Check to make sure it is really
                # internal.
                # TODO(aselle): this could be done more efficiently by flooding
                # the graph first.
                _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
                                       name_to_input_name)
                nodes_deleted_by_fuse.add(n)
        elif n not in reachable_by_input:
            # n is a node that comes after all the fusings, so keep it.
            nodes_after_fuse.append(n)
        # Otherwise n is reachable by input only and is copied below together
        # with the rest of reachable_by_input.

    # Make a new graphdef with all the pre-input and input nodes
    out = _graph_pb2.GraphDef()
    reachable_by_input_sorted = sorted(reachable_by_input,
                                       key=lambda n: name_to_seq_num[n])
    for node in reachable_by_input_sorted:
        out.node.extend([_copy.deepcopy(name_to_node[node])])

    # Create any stacks to aggregate arguments into to a single input
    # i.e. for static_rnn's.
    # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
    sorted_input_indices = sorted(call.inputs.keys())
    sorted_output_indices = sorted(call.outputs.keys())
    new_node = _node_def_pb2.NodeDef()
    # Delegate to each operand to produce the proper new input for this stub
    # node. In particular, an aggregate input will now be a Pack of some
    # previously non-fused things.
    for input_index in sorted_input_indices:
        inputs = call.inputs[input_index]
        new_node.input.append(inputs.aggregate_and_return_name_for_input(out))
    new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(
        sorted_input_indices)

    # Create the function
    new_node.op = call.function_name
    new_node.name = call.uuid
    out.node.extend([new_node])
    # extend() stores a copy of the message, so later writes to `new_node`
    # never reach the graph. Keep a handle on the stored copy so the output
    # attrs set below end up in the returned graph_def.
    stub_node = out.node[-1]

    # Now call each output argument to give them a chance to make the proper
    # output type and add it to the stub node.
    output_dtypes = []
    for output_index in sorted_output_indices:
        output = call.outputs[output_index]
        output_dtype = (output.aggregate_and_return_name_for_output(
            new_node.name, output_index, out))
        output_dtypes.append(output_dtype)
    stub_node.attr["_output_types"].list.type[:] = output_dtypes
    # TODO(aselle): what is right here?
    stub_node.attr["_output_quantized"].b = False

    # Add post output nodes that do not depend on the outputs
    for n in nodes_after_fuse:
        # A node that reads directly from a node removed by the fuse cannot be
        # kept, its input no longer exists.
        uses_deleted_node = any(
            input_name in nodes_deleted_by_fuse
            for input_name in name_to_input_name[n])
        if not uses_deleted_node:
            out.node.extend([_copy.deepcopy(name_to_node[n])])

    # Misc. graph_def data that needs copying.
    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)

    return out
示例#8
0
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
  """Fuse subgraph between input_nodes and output_nodes into a single custom op.

  The returned graph holds, in this order: the nodes up to and including
  `input_nodes`, the new fused node (fed by `input_nodes`), copies of
  `output_nodes` rewired to read from the fused node, and the remaining nodes
  that `output_nodes` do not depend on.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    input_nodes: input nodes to the subgraph to be fused.
    output_nodes: output nodes to the subgraph to be fused. Each of them must
      have exactly one input.
    output_dtypes: A list of output datatypes for the custom op
    output_quantized: A boolean flag that indicates if output is quantized
    op_name: fused op name.
    op_type: fused op type.

  Returns:
    The GraphDef of the new graph.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, if
      `input_nodes` or `output_nodes` is a string instead of a list, or if a
      node of the fused subgraph depends on a node upstream of `input_nodes`
      without passing through one of `input_nodes`.
  """

  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

  if isinstance(input_nodes, six.string_types):
    raise TypeError("input_nodes must be a list.")

  if isinstance(output_nodes, six.string_types):
    raise TypeError("output_nodes must be a list.")

  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

  # Nodes upto and including input_nodes
  reachable_by_input = _bfs_for_reachable_nodes(input_nodes, name_to_input_name)
  # Nodes upto and including output_nodes
  reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
                                                 name_to_input_name)

  # Set of nodes in the list input_nodes
  input_nodes_set = set(input_nodes)

  # Set of nodes in the list output_nodes
  output_nodes_set = set(output_nodes)

  # Nodes already queued by the closure check below. Sharing the set across
  # all starting nodes keeps the check linear in the size of the graph and
  # makes it terminate even when the fused region contains a cycle. A node in
  # this set was reached by a walk that finished without raising, so walking
  # it again cannot uncover a new violation.
  visited = set()
  nodes_post_output = []
  for node in graph_def.node:
    n = _node_name(node.name)
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        # n is between input and output, i.e., part of the fused op. Walk
        # breadth-first towards the inputs and make sure that every path ends
        # at one of input_nodes.
        if n in visited:
          continue
        visited.add(n)
        next_to_visit = [n]
        position = 0
        while position < len(next_to_visit):
          cur_node = next_to_visit[position]
          position += 1
          if cur_node in input_nodes_set:
            # Declared inputs are the legal boundary of the fused region; do
            # not walk past them.
            continue
          if cur_node in reachable_by_input:
            raise TypeError("Node %s uses input %s not in input_nodes." %
                            (n, cur_node))
          for upstream in name_to_input_name[cur_node]:
            if upstream not in visited:
              visited.add(upstream)
              next_to_visit.append(upstream)
    elif n not in reachable_by_input:
      # Nodes in reachable_by_input are emitted below already; adding them
      # here as well would put the same node into the result twice.
      nodes_post_output.append(n)

  # Add all nodes upto the input nodes
  out = graph_pb2.GraphDef()
  reachable_by_input_sorted = sorted(
      reachable_by_input, key=lambda n: name_to_seq_num[n])
  for node in reachable_by_input_sorted:
    out.node.extend([copy.deepcopy(name_to_node[node])])

  # Add the custom op
  new_node = node_def_pb2.NodeDef()
  new_node.input.extend(input_nodes)
  new_node.attr["_output_types"].list.type[:] = output_dtypes
  new_node.attr["_output_quantized"].b = output_quantized
  new_node.op = op_type
  new_node.name = op_name
  out.node.extend([new_node])

  # Add the nodes in the output of the custom op
  for index, n in enumerate(output_nodes):
    # The single input of each output node is replaced by the matching output
    # of the fused op.
    assert len(name_to_node[n].input) == 1
    new_node = copy.deepcopy(name_to_node[n])
    del new_node.input[:]
    # Output 0 is addressed by the bare op name, output k > 0 by "name:k".
    new_node.input.append(op_name + (":" + str(index) if index != 0 else ""))
    out.node.extend([new_node])

  # Add the nodes post output_nodes
  for n in nodes_post_output:
    out.node.extend([copy.deepcopy(name_to_node[n])])

  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)
  return out