Example #1
def bonsai_parser(model, model_in):
    """
    Full parsing function, handling the route layers
    Args:
        model: pytorch model to be processed
        model_in: model input

    Returns:
        A complete BonsaiParsedModel
    """

    # Simple parsing, without the routing layers
    bonsai_parsed_model = parse_simple_model(model, model_in.size())

    # Getting the graph that represents the underlying network connectivity
    gd = pg.graph(model, args=(model_in,))
    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(gd[0])

    # Convert node numbers to their short weight name
    graph_layers_to_weights = {get_node_name(k):v for k,v in name_to_seq_num.items()}

    # Route layers
    route_layers = {k: v.op.split('::')[1] for k,v in name_to_node.items() if v.op in ['onnx::Concat', 'onnx::Add']}
    route_weight_names = [get_node_name(x) for x in route_layers]

    # Match each node (full name plus shortened weight name) to the predecessor nodes it is connected to in the graph
    raw_predecessors = {(k, get_node_name(k)): [get_node_name(x) for x in in_names_list] for k, in_names_list in name_to_input_name.items()}

    # removing duplicates and empty strings
    predecessors = {weight: list(set([x for x in weight_list if len(x) > 0])) for (name, weight), weight_list in raw_predecessors.items()}

    # removing nodes that are intermediate values; they don't correspond to graph layers
    real_nodes = list(set(bonsai_parsed_model.get_weight_names())) + route_weight_names
    real_predecessors = get_real_predecessors(predecessors.copy(), real_nodes)

    # getting the relevant layers for the routing computation
    route_real_predecessors = {k:v for k,v in real_predecessors.items() if k in route_weight_names}
    route_predecessors_layers = {k:bonsai_parsed_model.get_layers_by_weights(v) for k,v in route_real_predecessors.items()}

    # computing the layer number of the generated route layer
    # here we set it to be 1 after the node that is previous to it in the GraphDef graph
    graph_connection = {graph_layers_to_weights[k]: [graph_layers_to_weights[x] for x in v] for k, v in route_real_predecessors.items()}
    prev_index = {k: v.index(int(k)-1) for k, v in graph_connection.items()}
    res_layers = {v[prev_index[graph_layers_to_weights[k]]] + 1: v for k, v in route_predecessors_layers.items()}

    # keeping in mind that layer indices shift when we add new layers
    shifted_values = {k: [val + len([x for x in res_layers if int(x) < int(val)]) for val in v] for k, v in res_layers.items()}
    final_layers = {int(k) + len([x for x in shifted_values if int(x) < int(k)]): v for k, v in shifted_values.items()}

    # adding the layers to the model
    for (k, v), operation in zip(final_layers.items(), route_layers.values()):
        if operation == 'Concat':
            bonsai_parsed_model.insert_module(int(k), 'route')
            bonsai_parsed_model.insert_param(int(k), 'layers', str(v))
        elif operation == 'Add':
            bonsai_parsed_model.insert_module(int(k), 'residual_add')
            bonsai_parsed_model.insert_param(int(k), 'layers', str(v))

    return bonsai_parsed_model
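
A minimal usage sketch for the parser above. The model choice, input shape, and the torchvision import are assumptions for illustration only; any PyTorch model containing concat or residual-add connections would do.

# Usage sketch (assumed model and input shape, not part of the original code).
import torch
import torchvision

model = torchvision.models.resnet18()     # any PyTorch model with Concat/Add connectivity
model_in = torch.randn(1, 3, 224, 224)    # dummy input matching the model's expected input shape

parsed = bonsai_parser(model, model_in)   # BonsaiParsedModel with route/residual_add layers inserted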
Example #2
def _find_children_hints(call, graph_def):
  """Find all children hints.

  For a given OpHint, we find all children hints inside it. We also copy all
  the nodes inside function defs (if applicable) into the original graph_def;
  they are returned in a list as well.

  Args:
    call: Parent OpHint that contains children ophints.
    graph_def: Original graph def.

  Returns:
    Ordered children hints inside the parent ophint; new graph def that contains
    nodes inside function defs (if applicable); nodes inside function defs.
  """
  name_to_input_name, _, _ = _extract_graph_summary(graph_def)
  input_names, output_names = call.flattened_inputs_and_outputs()

  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
  reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                 name_to_input_name)
  output_nodes_set = set(output_names)
  children_hints = []
  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)
  function_def_nodes = set()
  for node in graph_def.node:
    out.node.extend([_copy.deepcopy(node)])
    n = _tensor_name_base(node.name)
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        # Special handling for the while-loop function def.
        if node.op == "While":
          body_name = node.attr["body"].func.name
          inputs_outside_loop = node.input
          for function_def in graph_def.library.function:
            if function_def.signature.name == body_name:
              function_inputs = function_def.signature.input_arg
              assert len(inputs_outside_loop) == len(function_inputs)
              nodes_mapping = {}
              for i in range(len(function_inputs)):
                nodes_mapping[function_inputs[i].name] = inputs_outside_loop[i]
              # TODO(b/123050804): Consider using grappler.
              (children_hints_in_loop,
               new_nodes) = _find_children_hints_in_while_loop(
                   function_def, nodes_mapping)
              function_def_nodes.update([x.name for x in new_nodes])
              children_hints.extend(children_hints_in_loop)
              out.node.extend(new_nodes)

  return children_hints, out, function_def_nodes
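
Every example in this collection starts from `_extract_graph_summary`. A sketch of the three maps it returns, reconstructed from how the examples use them rather than copied from the library source:

# Sketch of _extract_graph_summary's behaviour (reconstructed, not the library source).
def _extract_graph_summary_sketch(graph_def):
  def strip_name(name):
    # Drop control-dependency markers ("^node") and output indices ("node:1").
    return name[1:] if name.startswith("^") else name.split(":")[0]

  name_to_input_name = {}  # node name -> list of its input node names
  name_to_node = {}        # node name -> NodeDef proto
  name_to_seq_num = {}     # node name -> position of the node in graph_def.node
  for seq, node in enumerate(graph_def.node):
    n = strip_name(node.name)
    name_to_node[n] = node
    name_to_input_name[n] = [strip_name(x) for x in node.input]
    name_to_seq_num[n] = seq
  return name_to_input_name, name_to_node, name_to_seq_num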
Example #3
  def _getGraphOpTypes(self, graphdef, output_nodes):
    """Returns used op types in `graphdef` reachable from `output_nodes`.

    This is used to check that after the stub transformation the expected
    nodes are there.

    NOTE: this is not an exact test that the graph is the correct output, but
      it balances compact expressibility of the test with sanity checking.

    Args:
      graphdef: TensorFlow proto graphdef.
      output_nodes: A list of output node names that we need to reach.

    Returns:
      A set of node types reachable from `output_nodes`.
    """
    name_to_input_name, name_to_node, _ = (
        _extract_graph_summary(graphdef))
    # Find all nodes that are needed by the outputs
    used_node_names = _bfs_for_reachable_nodes(output_nodes, name_to_input_name)
    return set([name_to_node[node_name].op for node_name in used_node_names])
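
`_bfs_for_reachable_nodes`, used in every example, is a plain breadth-first walk backwards over the input map. A sketch of its behaviour (reconstructed from usage, not the library source):

# Sketch: BFS over name_to_input_name, returning every node name reachable
# from target_nodes (the targets themselves included).
def _bfs_for_reachable_nodes_sketch(target_nodes, name_to_input_name):
  reachable = set()
  next_to_visit = list(target_nodes)
  while next_to_visit:
    node = next_to_visit.pop(0)
    if node in reachable:
      continue
    reachable.add(node)
    next_to_visit += name_to_input_name.get(node, [])
  return reachable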
Example #4
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
    """Fuse subgraph between input_nodes and output_nodes into a single custom op.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    input_nodes: input nodes to the subgraph to be fused.
    output_nodes: output nodes to the subgraph to be fused.
    output_dtypes: A list of output datatypes for the custom op.
    output_quantized: A boolean flag that indicates if the output is quantized.
    op_name: fused op name.
    op_type: fused op type.
  Returns:
    The GraphDef of the new graph.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
  """

    if not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

    if isinstance(input_nodes, six.string_types):
        raise TypeError("input_nodes must be a list.")

    if isinstance(output_nodes, six.string_types):
        raise TypeError("output_nodes must be a list.")

    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

    # Nodes up to and including input_nodes
    reachable_by_input = _bfs_for_reachable_nodes(input_nodes,
                                                  name_to_input_name)
    # Nodes up to and including output_nodes
    reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
                                                   name_to_input_name)

    # Set of nodes in the list input_nodes
    input_nodes_set = set(input_nodes)

    # Set of nodes in the list output_nodes
    output_nodes_set = set(output_nodes)

    nodes_post_output = []
    for node in graph_def.node:
        n = _node_name(node.name)
        if n in reachable_by_output:
            if n not in reachable_by_input and n not in output_nodes_set:
                # n is between input and output, i.e., part of the fused op
                next_to_visit = [n]
                while next_to_visit:
                    cur_node = next_to_visit[0]
                    del next_to_visit[0]
                    if cur_node in reachable_by_input and cur_node not in input_nodes_set:
                        raise TypeError(
                            "Node %s uses input %s not in input_nodes." %
                            (n, cur_node))
                    if cur_node not in input_nodes_set:
                        next_to_visit += name_to_input_name[cur_node]
        else:
            nodes_post_output.append(n)

    # Add all nodes up to the input nodes
    out = graph_pb2.GraphDef()
    reachable_by_input_sorted = sorted(list(reachable_by_input),
                                       key=lambda n: name_to_seq_num[n])
    for node in reachable_by_input_sorted:
        out.node.extend([copy.deepcopy(name_to_node[node])])

    # Add the custom op
    new_node = node_def_pb2.NodeDef()
    for node in input_nodes:
        new_node.input.append(node)
    new_node.attr["_output_types"].list.type[:] = output_dtypes
    new_node.attr["_output_quantized"].b = output_quantized
    new_node.op = op_type
    new_node.name = op_name
    out.node.extend([new_node])

    # Add the nodes in the output of the custom op
    for index, n in enumerate(output_nodes):
        assert len(name_to_node[n].input) == 1
        new_node = copy.deepcopy(name_to_node[n])
        del new_node.input[:]
        new_node.input.append(op_name +
                              (":" + str(index) if index != 0 else ""))
        out.node.extend([new_node])

    # Add the nodes post output_nodes
    for n in nodes_post_output:
        out.node.extend([copy.deepcopy(name_to_node[n])])

    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)
    return out
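
`fuse_op` normalizes every node string with `_node_name` before using it as a dictionary key. A sketch of that normalization, matching how names are looked up above (reconstructed, not the library source); `_tensor_name_base` in the OpHint examples plays the same role:

# Sketch: map a tensor or control-input string to the producing node's name.
def _node_name_sketch(name):
    if name.startswith("^"):      # control dependency, e.g. "^init"
        return name[1:]
    return name.split(":")[0]     # drop the output index, e.g. "conv:1" -> "conv"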
Example #5
def _remove_one_redundant_stack_unstack(in_graph_def):
  """Removes a stack->unstack pattern from in_graph_def in a returned graph.

  Args:
    in_graph_def: Graph def to use as input.
  Returns:
    Simplified tuple (graph_def, changed_something) where changed_something
    is true if anything was done.
  """
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      in_graph_def)
  del name_to_seq_num

  # TODO(aselle): Make this not hardcoded.
  do_generic_pack_unpack = True

  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(in_graph_def.library)
  out.versions.CopyFrom(in_graph_def.versions)
  for n in in_graph_def.node:
    node_name = _tensor_name_base(n.name)
    if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
      continue
    next_to_visit = [node_name]
    visited = set()

    unpack_nodes = set()
    pack_node = node_name

    # Find a pattern of an unstack connected to a stack (with identities
    # in between).
    matches_pattern = True
    is_hint_created_stack = False
    while next_to_visit:
      current_node_name = next_to_visit[0]
      visited.add(current_node_name)
      del next_to_visit[0]
      node = name_to_node[current_node_name]
      is_op_hint_stack = node.name.startswith("OpHintStack")
      is_op_hint_unstack = node.name.startswith("OpHintUnstack")
      if (node.op == "Identity" or is_op_hint_stack
          or (do_generic_pack_unpack and node.op == "Pack")):
        is_hint_created_stack |= is_op_hint_stack
        next_to_visit += [
            input_node for input_node in name_to_input_name[current_node_name]
            if input_node not in visited
        ]
      elif (is_op_hint_unstack
            or (do_generic_pack_unpack and node.op == "Unpack")):
        unpack_nodes.add(node.name)
        is_hint_created_stack &= is_op_hint_unstack
      else:
        matches_pattern = False
        break
      visited.add(node.name)

    if matches_pattern and len(unpack_nodes) == 1:
      pack_node = node_name

      # Check to see if anyone depends on the intermediate identity or the
      # unstacked form.
      no_external_dependency = True
      for other_n in in_graph_def.node:
        if other_n.name in visited: continue
        for input_tensor in name_to_input_name[other_n.name]:
          input_op = _tensor_name_base(input_tensor)
          if input_op in visited and input_op != pack_node:
            no_external_dependency = False
      # Proceed with the substitution if the stack/unstack pair was created
      # through hints, or, if it was not, when nothing consumes anything
      # between the stack and the unstack.
      if is_hint_created_stack or no_external_dependency:
        end = unpack_nodes.pop()
        end_input = name_to_node[end].input[0]
        # All nodes that depend on the final stack need to be rewired to use
        # the unstack's input instead.
        for other_n in in_graph_def.node:
          node_name = _tensor_name_base(other_n.name)
          if node_name not in visited:
            new_node = _copy.deepcopy(other_n)
            new_node.input[:] = [
                (end_input if stripped == pack_node else
                 non_stripped) for stripped, non_stripped in zip(
                     name_to_input_name[node_name], new_node.input[:])
            ]
            out.node.extend([new_node])
        return out, True
  return in_graph_def, False
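
`_remove_one_redundant_stack_unstack` removes at most one pattern per call and returns as soon as a substitution happens. A sketch of the driver loop that iterates it to a fixed point (the wrapper name here is an assumption):

# Sketch: repeat the single-pattern removal until nothing changes.
def _remove_redundant_stack_unstack_sketch(graph_def):
  curr, changed = graph_def, True
  while changed:
    curr, changed = _remove_one_redundant_stack_unstack(curr)
  return curr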
Example #6
def _convert_single_op_hint_to_stub(call, graph_def):
  """Given a graph_def, converts `call` into a stub and returns a new graph_def.

  Args:
    call: A single function call to be converted.
    graph_def: A graph_def to use as input (one that contains `call`).
  Returns:
    A new transformed graph-def that has call as a stub (single op).

  Note: after this process, the graph_def can no longer be loaded into
      the tensorflow runtime, so all future manipulations are done at the
      graph_def level.
  """
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  input_names, output_names = call.flattened_inputs_and_outputs()

  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
  reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                 name_to_input_name)
  input_nodes_set = set(input_names)
  output_nodes_set = set(output_names)
  nodes_after_fuse = []
  nodes_deleted_by_fuse = set()
  # Classify each node. We want to keep everything reachable by input, but we
  # don't yet know what to do with nodes that are not reachable by output or
  # input (the nodes that come after the fusing).
  for node in graph_def.node:
    n = _tensor_name_base(node.name)
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        # n is an internal node. Check to make sure it is really internal.
        # TODO(aselle): this could be done more efficiently by flooding
        # the graph first.
        _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
                               name_to_input_name)
        nodes_deleted_by_fuse.add(n)
    elif n not in reachable_by_input:
      # n is a node that comes after all the fusings, so keep it.
      nodes_after_fuse.append(n)
    else:
      # n is a node that is randomly in the graph but not connected to
      # the chain of dependencies.
      pass

  # Make a new graphdef with all the pre-input and input nodes
  out = _graph_pb2.GraphDef()
  reachable_by_input_sorted = sorted(
      list(reachable_by_input), key=lambda n: name_to_seq_num[n])
  for node in reachable_by_input_sorted:
    out.node.extend([_copy.deepcopy(name_to_node[node])])

  # Create any stacks to aggregate arguments into a single input,
  # e.g. for static_rnn's.
  # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
  sorted_input_indices = list(call.inputs.keys())
  sorted_input_indices.sort()
  sorted_output_indices = list(call.outputs.keys())
  sorted_output_indices.sort()
  new_node = _node_def_pb2.NodeDef()
  # Delegate to each operand to produce the proper new input for this stub node.
  # In particular, an aggregate input will now be a Pack of some previously
  # non-fused things.
  for input_index in sorted_input_indices:
    inputs = call.inputs[input_index]
    new_node.input.append(inputs.aggregate_and_return_name_for_input(out))
  new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)

  # Create the function node.
  new_node.op = call.function_name
  new_node.name = call.uuid
  out.node.extend([new_node])

  # Now call each output argument to give it a chance to make the proper
  # output type and add it to our new_node.
  output_dtypes = []
  for output_index in sorted_output_indices:
    output = call.outputs[output_index]
    output_dtype = (
        output.aggregate_and_return_name_for_output(new_node.name, output_index,
                                                    out))
    output_dtypes.append(output_dtype)
  new_node.attr["_output_types"].list.type[:] = output_dtypes
  # TODO(aselle): what is right here?
  new_node.attr["_output_quantized"].b = False

  # Add post output nodes that do not depend on the outputs
  for n in nodes_after_fuse:
    should_keep = True
    for input_name in name_to_input_name[n]:
      if input_name in nodes_deleted_by_fuse:
        should_keep = False
    if should_keep:
      out.node.extend([_copy.deepcopy(name_to_node[n])])

  # Misc. graph_def data that needs copying.
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out
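
The `_check_subgraph_closed` call above enforces the same closure property that `fuse_op` checks inline in its example: an internal node may only reach the predecessor graph through the declared input nodes. A sketch of that helper, reconstructed from the inline loop in `fuse_op`:

# Sketch: verify that node n reaches the input-side graph only through nodes
# in input_nodes_set; otherwise the subgraph to be fused is not closed.
def _check_subgraph_closed_sketch(n, reachable_by_input, input_nodes_set,
                                  name_to_input_name):
  next_to_visit = [n]
  visited = set()
  while next_to_visit:
    current_node = next_to_visit.pop(0)
    visited.add(current_node)
    if current_node in reachable_by_input and current_node not in input_nodes_set:
      raise TypeError("Node %s uses input %s not in input_nodes." %
                      (n, current_node))
    if current_node not in input_nodes_set:
      next_to_visit += [
          x for x in name_to_input_name[current_node] if x not in visited
      ]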