Example #1
def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None):
    """Replaces all the variables in a graph with constants of the same values.
  If you have a trained graph containing Variable ops, it can be convenient to
  convert them all to Const ops holding the same values. This makes it possible
  to describe the network fully with a single GraphDef file, and allows the
  removal of a lot of ops related to loading and saving the variables.
  Args:
    sess: Active TensorFlow session containing the variables.
    input_graph_def: GraphDef object holding the network.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_whitelist: The set of variable names to convert (by default,
                              all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants.
  Returns:
    GraphDef containing a simplified version of the original.
  """
    def has_variable_as_input(node):
        """Checks if the input node has a variable in `variables_data_map`."""
        for name in node.input:
            if name in variables_data_map or\
                    (name in identity_ops_input_map
                     and identity_ops_input_map[name] in variables_data_map):
                return True
        return False

    def dfs_find_variable(origin_name, name_to_nodes):
        """Walks backwards from `origin_name` to find the variable feeding it.

        Returns a (variable name or None, set of node names on the path) tuple
        and raises ValueError if more than one variable is reachable.
        """
        if origin_name in variables_data_map:
            return origin_name, set()

        nodes_in_path = set()
        found_variables = set()

        def dfs(name):
            node = name_to_nodes[name]
            if node.op == "Switch":
                inputs = [node.input[0]]
            else:
                inputs = node.input
            for input_name in inputs:
                input_name = _node_name(input_name)
                if input_name in nodes_in_path:
                    continue
                elif input_name in variables_data_map:
                    found_variables.add(input_name)
                    continue
                else:
                    nodes_in_path.add(input_name)
                    dfs(input_name)

        nodes_in_path.add(origin_name)
        dfs(origin_name)

        if len(found_variables) > 1:
            raise ValueError("Found multiple variables (%s) feeding node %s." %
                             (found_variables, origin_name))

        variable = next(iter(found_variables), None)
        return variable, nodes_in_path

    def create_const_op(node_name, dtype, data, data_shape=None):
        """Creates a Const op."""
        output_node = node_def_pb2.NodeDef()
        output_node.op = "Const"
        output_node.name = node_name
        output_node.attr["dtype"].CopyFrom(dtype)
        output_node.attr["value"].CopyFrom(
            attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                data, dtype=dtype.type, shape=data_shape)))
        return output_node

    # This graph only includes the nodes needed to evaluate the output nodes, and
    # removes unneeded nodes like those involved in saving and assignment.
    inference_graph = extract_sub_graph(input_graph_def, output_node_names)

    # Get list of variables.
    variable_names = []
    variable_dict_names = []
    identity_ops_input_map = {}
    name_to_node = {}
    for node in inference_graph.node:
        name_to_node[node.name] = node
        if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
            variable_name = node.name
            if ((variable_names_whitelist is not None
                 and variable_name not in variable_names_whitelist)
                    or (variable_names_blacklist is not None
                        and variable_name in variable_names_blacklist)):
                continue
            variable_dict_names.append(variable_name)
            if node.op == "VarHandleOp":
                variable_names.append(variable_name + "/Read/ReadVariableOp:0")
            else:
                variable_names.append(variable_name + ":0")
        elif node.op == "Identity":
            # TODO(nupurgarg): Move and reuse get_name from lite/convert.py.
            # Creates a map of Identity node names to the input names.
            input_info = node.input[0].split(":")
            if (len(input_info) == 1
                    or (len(input_info) == 2 and int(input_info[1]) == 0)):
                identity_ops_input_map[node.name] = input_info[0]

    # Gets map of variables and the associated data.
    if variable_names:
        returned_variables = sess.run(variable_names)
    else:
        returned_variables = []
    variables_data_map = dict(zip(variable_dict_names, returned_variables))
    logging.info("Froze %d variables.", len(returned_variables))

    # Reconstruct the graph with constants in place of variables.

    path_node_to_variables = {}
    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = node_def_pb2.NodeDef()
        if input_node.name in variables_data_map:
            data = variables_data_map[input_node.name]
            output_node = create_const_op(input_node.name,
                                          input_node.attr["dtype"], data,
                                          data.shape)
            how_many_converted += 1
        elif input_node.op == "ReadVariableOp":
            variable, nodes_in_path = dfs_find_variable(
                input_node.input[0], name_to_node)
            if variable is not None:
                # The first branch converts all VarHandleOps of ResourceVariables to
                # constants, so we need to convert the associated ReadVariableOps to
                # Identity ops.
                #
                # Handles the following cases:
                #   Variable --> ReadVariableOp
                #   Variable --> Identity --> ReadVariableOp
                output_node.op = "Identity"
                output_node.name = input_node.name
                output_node.input.extend([input_node.input[0]])
                output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
                if "_class" in input_node.attr:
                    output_node.attr["_class"].CopyFrom(
                        input_node.attr["_class"])
                for name in nodes_in_path:
                    path_node_to_variables[name] = variable
            else:
                raise ValueError("Cannot find variable for %s" %
                                 input_node.name)

        elif input_node.op == "ResourceGather":

            variable, nodes_in_path = dfs_find_variable(
                input_node.input[0], name_to_node)
            if variable is not None:
                # The first branch converts all VarHandleOps of ResourceVariables to
                # constants, so we need to convert the associated ResourceGather ops to
                # GatherV2 ops with a Const axis feeding into them.
                if input_node.attr["batch_dims"].i != 0:
                    raise ValueError(
                        "batch_dims != 0 is not supported by freeze_graph.")
                axis_data = input_node.attr["batch_dims"].i
                axis_node_name = input_node.name + "/axis"
                axis_dtype = input_node.attr["Tindices"]
                output_axis_node = create_const_op(axis_node_name, axis_dtype,
                                                   axis_data)
                output_graph_def.node.extend([output_axis_node])

                output_node.op = "GatherV2"
                output_node.name = input_node.name
                output_node.input.extend(
                    [input_node.input[0], input_node.input[1], axis_node_name])
                output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
                output_node.attr["Tindices"].CopyFrom(
                    input_node.attr["Tindices"])
                output_node.attr["Taxis"].CopyFrom(axis_dtype)
                if "_class" in input_node.attr:
                    output_node.attr["_class"].CopyFrom(
                        input_node.attr["_class"])
                for name in nodes_in_path:
                    path_node_to_variables[name] = variable
            else:
                raise ValueError("Cannot find variable for %s" %
                                 input_node.name)
        elif input_node.op == "VariableShape":

            variable, nodes_in_path = dfs_find_variable(
                input_node.input[0], name_to_node)
            if variable is not None:
                input_variable = name_to_node[variable]
                output_node.op = "Shape"
                output_node.name = input_node.name
                output_node.input.extend([input_node.input[0]])
                output_node.attr["T"].CopyFrom(input_variable.attr["dtype"])
                output_node.attr["out_type"].CopyFrom(
                    input_node.attr["out_type"])
                for name in nodes_in_path:
                    path_node_to_variables[name] = variable
            else:
                raise ValueError("Cannot find variable for %s" %
                                 input_node.name)
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])

    output_graph_def.library.CopyFrom(inference_graph.library)

    inference_graph = output_graph_def
    output_graph_def = graph_pb2.GraphDef()
    for input_node in inference_graph.node:
        output_node = node_def_pb2.NodeDef()
        if input_node.name in path_node_to_variables:
            input_variable = path_node_to_variables[input_node.name]
            input_variable = name_to_node[input_variable]
            output_node.op = input_node.op
            output_node.name = input_node.name
            if input_node.op == "Enter":
                output_node.input.extend([input_node.input[0]])
                output_node.attr["T"].CopyFrom(input_variable.attr["dtype"])
                output_node.attr["frame_name"].CopyFrom(
                    input_node.attr["frame_name"])
                output_node.attr["is_constant"].CopyFrom(
                    input_node.attr["is_constant"])
                output_node.attr["parallel_iterations"]\
                    .CopyFrom(input_node.attr["parallel_iterations"])
            elif input_node.op == "Switch":
                output_node.input.extend(input_node.input)
                output_node.attr["T"].CopyFrom(input_variable.attr["dtype"])
            else:
                raise ValueError("cannot do type: %s" % input_node.op)
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])

    output_graph_def.library.CopyFrom(inference_graph.library)

    logging.info("Converted %d variables to const ops.", how_many_converted)
    return output_graph_def
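A minimal usage sketch for the routine above. The TF 1.x framework imports it relies on (extract_sub_graph, graph_pb2, node_def_pb2, attr_value_pb2, tensor_util, logging) are assumed to be in scope, and the graph, node names, and shapes below are made up for illustration.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.placeholder(tf.float32, shape=[None, 2], name="x")
    weights = tf.Variable([[1.0, 2.0], [3.0, 4.0]], name="weights")
    tf.matmul(x, weights, name="output")
    sess.run(tf.global_variables_initializer())

    # "weights" becomes a Const node and the saver/assign machinery is dropped.
    frozen = convert_variables_to_constants(sess, sess.graph_def, ["output"])
    print([n.op for n in frozen.node if n.name == "weights"])  # expect ['Const']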
Example #2
def fuse_resize_and_conv(input_graph_def, output_node_names):
  """Merges preceding resize and mirror pad ops into a specialized convolution.

  There's a common pattern of enlarging the input to a convolution using a
  resize operation, and also using MirrorPad to extend the boundaries so that
  zero edge pixels don't bleed inwards when convolving. This routine looks for
  that pattern of operations, and fuses them together into a single
  FusedResizeAndPadConv2D (or FusedPadConv2D) op.

  Args:
    input_graph_def: A GraphDef containing a model.
    output_node_names: A list of names of the nodes that produce the final
      results.

  Returns:
    Modified graph with resize and pad ops merged.

  Raises:
    ValueError: If the graph is badly formed with duplicate node names.
  """

  input_node_map = {}
  for node in input_graph_def.node:
    if node.name not in input_node_map:
      input_node_map[node.name] = node
    else:
      raise ValueError("Duplicate node names detected for ", node.name)

  node_reference_count = collections.defaultdict(int)
  for node in input_graph_def.node:
    for input_name in node.input:
      stripped_name = node_name_from_input(input_name)
      node_reference_count[stripped_name] += 1
  for output_name in output_node_names:
    node_reference_count[output_name] += 1

  new_ops = []
  for node in input_graph_def.node:

    if node.op != "Conv2D":
      continue
    conv_op = node

    input_op = node_from_map(input_node_map, conv_op.input[0])
    if input_op.op == "MirrorPad":
      mirror_pad_op = input_op
      resize_op = node_from_map(input_node_map, mirror_pad_op.input[0])
      if resize_op.op != "ResizeBilinear":
        resize_op = None
    else:
      mirror_pad_op = None
      if input_op.op == "ResizeBilinear":
        resize_op = input_op
      else:
        resize_op = None

    # There are no ops to be fused into the conv, so skip replacing this one.
    if not mirror_pad_op and not resize_op:
      continue

    # We're replacing this node, so make sure the old one is removed.
    node_reference_count[conv_op.name] = 0
    if mirror_pad_op:
      node_reference_count[mirror_pad_op.name] -= 1
    if resize_op:
      node_reference_count[resize_op.name] -= 1

    fused_conv_op = node_def_pb2.NodeDef()
    if resize_op:
      fused_conv_op.op = "FusedResizeAndPadConv2D"
    else:
      fused_conv_op.op = "FusedPadConv2D"
    fused_conv_op.name = conv_op.name
    if mirror_pad_op:
      mirror_paddings_name = mirror_pad_op.input[1]
      mirror_paddings_mode = mirror_pad_op.attr["mode"]
    else:
      # If there was no MirrorPad op, then create settings that make the padding
      # stage of the fused operation a no-op.
      paddings_op = node_def_pb2.NodeDef()
      paddings_op.op = "Const"
      paddings_op.name = conv_op.name + "_dummy_paddings"
      paddings_op.attr["dtype"].CopyFrom(
          attr_value_pb2.AttrValue(type=dtypes.int32.as_datatype_enum))
      paddings_op.attr["value"].CopyFrom(
          attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
              [0, 0, 0, 0, 0, 0, 0, 0], dtypes.int32, [4, 2])))
      new_ops.extend([paddings_op])
      mirror_paddings_name = paddings_op.name
      mirror_paddings_mode = attr_value_pb2.AttrValue(s=b"REFLECT")
    if resize_op:
      fused_conv_op.input.extend([
          resize_op.input[0], resize_op.input[1], mirror_paddings_name,
          conv_op.input[1]
      ])
      fused_conv_op.attr["resize_align_corners"].CopyFrom(
          resize_op.attr["align_corners"])
    else:
      fused_conv_op.input.extend(
          [mirror_pad_op.input[0], mirror_paddings_name, conv_op.input[1]])
    fused_conv_op.attr["T"].CopyFrom(conv_op.attr["T"])
    fused_conv_op.attr["mode"].CopyFrom(mirror_paddings_mode)
    fused_conv_op.attr["strides"].CopyFrom(conv_op.attr["strides"])
    fused_conv_op.attr["padding"].CopyFrom(conv_op.attr["padding"])
    new_ops.extend([fused_conv_op])

  result_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node_reference_count[node.name] < 1:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    result_graph_def.node.extend([new_node])

  result_graph_def.node.extend(new_ops)
  return result_graph_def
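A hedged usage sketch for the fusion above, wiring up the ResizeBilinear -> MirrorPad -> Conv2D pattern it looks for. The helpers it calls (node_from_map, node_name_from_input, plus the proto and dtypes imports) are assumed to be in scope, and the graph below is made up.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
with tf.Graph().as_default() as g:
    images = tf.placeholder(tf.float32, [1, 8, 8, 3], name="images")
    resized = tf.image.resize_bilinear(images, [16, 16], name="resize")
    # mode="REFLECT" emits a MirrorPad op feeding the convolution.
    padded = tf.pad(resized, [[0, 0], [1, 1], [1, 1], [0, 0]],
                    mode="REFLECT", name="pad")
    kernel = tf.constant(1.0, shape=[3, 3, 3, 4], name="kernel")
    conv = tf.nn.conv2d(padded, kernel, strides=[1, 1, 1, 1],
                        padding="VALID", name="conv")

fused = fuse_resize_and_conv(g.as_graph_def(), [conv.op.name])
# The Conv2D node is replaced in place by a FusedResizeAndPadConv2D node.
print([n.op for n in fused.node if n.name == "conv"])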
Example #3
def create_subgraph(tf_graph, node_list, sess, dst_scope=None):
    """
    Create a tf subgraph from the node list.
    :param tf_graph:
    :param node_list:
    :param sess:
    :param dst_scope:
    :return:
    """
    variable_dict_names = []
    variable_names = []
    tensor_op_names = []
    for n_ in node_list:  # type: tf.Operation
        tensor_op_names.extend([ts_.op.name for ts_ in n_.inputs])
        if n_.type in ["Variable", "VariableV2", "VarHandleOp"]:
            variable_name = n_.name
            variable_dict_names.append(variable_name)

            if n_.type == "VarHandleOp":
                variable_names.append(variable_name + "/Read/ReadVariableOp:0")
            else:
                variable_names.append(variable_name + ":0")
    if variable_names:
        returned_variables = sess.run(variable_names)
    else:
        returned_variables = []
    found_variables = dict(zip(variable_dict_names, returned_variables))
    all_op_names = set([n_.name for n_ in node_list])
    missing_ops = set(tensor_op_names) - all_op_names

    replacement = {}
    tf_graph_def = tf_graph.as_graph_def()
    subgraph_def = _extract_sub_graph(tf_graph_def, [n_.name for n_ in node_list], missing_ops)

    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in subgraph_def.node:
        output_node = node_def_pb2.NodeDef()
        if input_node.name in found_variables:
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            output_node.attr["dtype"].CopyFrom(dtype)
            output_node.attr["value"].CopyFrom(
                attr_value_pb2.AttrValue(
                    tensor=tensor_util.make_tensor_proto(
                        data, dtype=dtype.type, shape=data.shape)))
            how_many_converted += 1
        elif input_node.op == "ReadVariableOp" and (
                input_node.input[0] in found_variables):
            # The preceding branch converts all VarHandleOps of ResourceVariables to
            # constants, so we need to convert the associated ReadVariableOps to
            # Identity ops.
            output_node.op = "Identity"
            output_node.name = input_node.name
            output_node.input.extend([input_node.input[0]])
            output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
            if "_class" in input_node.attr:
                output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
        elif input_node.name not in missing_ops:
            output_node.CopyFrom(input_node)
        else:
            output_node = None
        if output_node is not None:
            output_graph_def.node.extend([output_node])

    for input_node in tf_graph_def.node:
        if input_node.name in missing_ops:
            output_node = node_def_pb2.NodeDef()
            output_node.op = "Placeholder"
            output_node.name = input_node.name
            replacement[input_node.name] = input_node.name
            if str(input_node.attr["dtype"]):
                output_node.attr["dtype"].CopyFrom(input_node.attr["dtype"])
            elif str(input_node.attr["T"]):
                output_node.attr["dtype"].CopyFrom(input_node.attr["T"])
            else:
                if input_node.op == 'All':
                    output_node.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(type="DT_BOOL"))
                elif input_node.op == 'Cast':
                    output_node.attr["dtype"].CopyFrom(input_node.attr["DstT"])
                else:
                    raise RuntimeError("Can't get the node data type for %s" % input_node.name)
            ts_shape = tf.graph_util.tensor_shape_from_node_def_name(tf_graph, input_node.name)
            output_node.attr["shape"].CopyFrom(
                attr_value_pb2.AttrValue(shape=ts_shape.as_proto()))
            output_graph_def.node.extend([output_node])

    output_graph_def.library.CopyFrom(subgraph_def.library)
    with tf.Graph().as_default() as sub_graph:
        im_scope = "" if dst_scope is None else dst_scope
        tf.import_graph_def(output_graph_def, name=im_scope)
        if im_scope:
            replacement = {k_: im_scope + '/' + k_ for k_ in replacement}

    return sub_graph, replacement
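A usage sketch, assuming this module's private helpers (_extract_sub_graph and the proto imports) are available. The graph and names are illustrative: only the matmul is kept, its missing producers come back as Placeholders, and `replacement` maps their original names to the imported names.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
with tf.Graph().as_default() as g, tf.Session() as sess:
    x = tf.placeholder(tf.float32, [2, 4], name="x")
    w = tf.Variable(tf.ones([4, 3]), name="w")
    y = tf.matmul(x, w, name="y")
    sess.run(tf.global_variables_initializer())

    sub_graph, replacement = create_subgraph(g, [y.op], sess, dst_scope="sub")
    # replacement maps each placeholder's original name to "sub/<name>".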
Example #4
 def create_node_def(self, op, name, inputs):
     """Creates a NodeDef with the given op type, name, and input names."""
     new_node = node_def_pb2.NodeDef()
     new_node.op = op
     new_node.name = name
     new_node.input.extend(inputs)
     return new_node
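For reference, the same NodeDef can also be built directly from the proto; a standalone sketch with made-up op and node names:

from tensorflow.core.framework import node_def_pb2

new_node = node_def_pb2.NodeDef()
new_node.op = "Relu"
new_node.name = "activation"
# A leading "^" marks a control-dependency input.
new_node.input.extend(["dense/BiasAdd", "^init_op"])
print(new_node)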
Example #5
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
    """Fuse subgraph between input_nodes and output_nodes into a single custom op.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    input_nodes: input nodes to the subgraph to be fused.
    output_nodes: output nodes to the subgraph to be fused.
    output_dtypes: A list of output datatypes for the custom op
    output_quantized: A boolean flag that indicates if output is quantized
    op_name: fused op name.
    op_type: fused op type.
  Returns:
    The GraphDef of the new graph.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
  """

    if not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

    if isinstance(input_nodes, six.string_types):
        raise TypeError("input_nodes must be a list.")

    if isinstance(output_nodes, six.string_types):
        raise TypeError("output_nodes must be a list.")

    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

    # Nodes up to and including input_nodes
    reachable_by_input = _bfs_for_reachable_nodes(input_nodes,
                                                  name_to_input_name)
    # Nodes up to and including output_nodes
    reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
                                                   name_to_input_name)

    # Set of nodes in the list input_nodes
    input_nodes_set = set(input_nodes)

    # Set of nodes in the list output_nodes
    output_nodes_set = set(output_nodes)

    nodes_post_output = []
    for node in graph_def.node:
        n = _node_name(node.name)
        if n in reachable_by_output:
            if n not in reachable_by_input and n not in output_nodes_set:
                # n is between input and output, i.e., part of the fused op
                next_to_visit = [n]
                while next_to_visit:
                    cur_node = next_to_visit[0]
                    del next_to_visit[0]
                    if cur_node in reachable_by_input and cur_node not in input_nodes_set:
                        raise TypeError(
                            "Node %s uses input %s not in input_nodes." %
                            (n, cur_node))
                    if cur_node not in input_nodes_set:
                        next_to_visit += name_to_input_name[cur_node]
        elif n not in reachable_by_input:
            nodes_post_output.append(n)

    # Add all nodes up to the input nodes
    out = graph_pb2.GraphDef()
    reachable_by_input_sorted = sorted(list(reachable_by_input),
                                       key=lambda n: name_to_seq_num[n])
    for node in reachable_by_input_sorted:
        out.node.extend([copy.deepcopy(name_to_node[node])])

    # Add the custom op
    new_node = node_def_pb2.NodeDef()
    for node in input_nodes:
        new_node.input.append(node)
    new_node.attr["_output_types"].list.type[:] = output_dtypes
    new_node.attr["_output_quantized"].b = output_quantized
    new_node.op = op_type
    new_node.name = op_name
    out.node.extend([new_node])

    # Add the nodes in the output of the custom op
    for index, n in enumerate(output_nodes):
        assert len(name_to_node[n].input) == 1
        new_node = copy.deepcopy(name_to_node[n])
        del new_node.input[:]
        new_node.input.append(op_name +
                              (":" + str(index) if index != 0 else ""))
        out.node.extend([new_node])

    # Add the nodes post output_nodes
    for n in nodes_post_output:
        out.node.extend([copy.deepcopy(name_to_node[n])])

    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)
    return out
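A hedged usage sketch: fuse the single Relu sitting between "x" and "y" into a custom op. The op_hint-style helpers the routine calls (_extract_graph_summary, _bfs_for_reachable_nodes, _node_name, _assert_nodes_are_present) plus copy and six are assumed imported; the graph, custom op name, and op type are made up.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
with tf.Graph().as_default() as g:
    x = tf.placeholder(tf.float32, [1, 4], name="x")
    h = tf.nn.relu(x, name="hidden")
    tf.identity(h, name="y")

fused = fuse_op(g.as_graph_def(),
                input_nodes=["x"], output_nodes=["y"],
                output_dtypes=[tf.float32.as_datatype_enum],
                output_quantized=False,
                op_name="fused_block", op_type="MyCustomOp")
# "hidden" is folded away; the "y" node now reads from the new "fused_block" node.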
Example #6
def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None):
    """Replaces all the variables in a graph with constants of the same values.

  If you have a trained graph containing Variable ops, it can be convenient to
  convert them all to Const ops holding the same values. This makes it possible
  to describe the network fully with a single GraphDef file, and allows the
  removal of a lot of ops related to loading and saving the variables.

  Args:
    sess: Active TensorFlow session containing the variables.
    input_graph_def: GraphDef object holding the network.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_whitelist: The set of variable names to convert (by default,
                              all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants.

  Returns:
    GraphDef containing a simplified version of the original.
  """
    # This graph only includes the nodes needed to evaluate the output nodes, and
    # removes unneeded nodes like those involved in saving and assignment.
    inference_graph = extract_sub_graph(input_graph_def, output_node_names)

    found_variables = {}
    variable_names = []
    variable_dict_names = []
    for node in inference_graph.node:
        if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
            variable_name = node.name
            if ((variable_names_whitelist is not None
                 and variable_name not in variable_names_whitelist)
                    or (variable_names_blacklist is not None
                        and variable_name in variable_names_blacklist)):
                continue
            variable_dict_names.append(variable_name)
            if node.op == "VarHandleOp":
                variable_names.append(variable_name + "/Read/ReadVariableOp:0")
            else:
                variable_names.append(variable_name + ":0")
    if variable_names:
        returned_variables = sess.run(variable_names)
    else:
        returned_variables = []
    found_variables = dict(zip(variable_dict_names, returned_variables))
    logging.info("Froze %d variables.", len(returned_variables))

    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = node_def_pb2.NodeDef()
        if input_node.name in found_variables:
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            output_node.attr["dtype"].CopyFrom(dtype)
            output_node.attr["value"].CopyFrom(
                attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                    data, dtype=dtype.type, shape=data.shape)))
            how_many_converted += 1
        elif input_node.op == "ReadVariableOp" and (input_node.input[0]
                                                    in found_variables):
            # The preceding branch converts all VarHandleOps of ResourceVariables to
            # constants, so we need to convert the associated ReadVariableOps to
            # Identity ops.
            output_node.op = "Identity"
            output_node.name = input_node.name
            output_node.input.extend([input_node.input[0]])
            output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
            if "_class" in input_node.attr:
                output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])

    output_graph_def.library.CopyFrom(inference_graph.library)
    logging.info("Converted %d variables to const ops.", how_many_converted)
    return output_graph_def
Example #7
    def apply_matmul_biasadd_relu_fusion(self, match_node_name):
        skip_node_name = match_node_name[1:]
        matched_node = self.node_name_mapping[match_node_name[0]]
        control_inputs, normal_inputs = self._get_node_input(
            matched_node.node.name)
        weight_name = normal_inputs[1]
        weight_node = self.node_name_mapping[helper.node_name_from_input(
            weight_name)].node

        # FIXME: We only quantize MatMul ops whose second input node is a Const.
        # This is a workaround for RNN models such as LSTM.
        if weight_node.op != 'Const':
            self.output_graph = self.input_graph
            return

        for i in self.node_name_mapping:
            if weight_node.name in self.node_name_mapping[i].output:
                self.output_graph = self.input_graph
                return

        q_weights_name, q_weights_min_name, q_weights_max_name = \
            self._intel_cpu_quantize_weight_eightbit(
                matched_node.node.op, self.node_name_mapping[weight_name].node, self.per_channel)

        skip_node_name.append(weight_name)

        for _, node in enumerate(self.input_graph.node):
            if node.name in skip_node_name:
                pass
            elif node.name == match_node_name[0]:
                self.logger.debug("matched node {} with input {}".format(
                    node.name, node.input))

                self.logger.debug("apply_matmul_biasadd_relu_fusion")

                quantized_node_name = node.name + "_eightbit_quantized_mat_mul"
                bias_node_name = self.node_name_mapping[
                    match_node_name[1]].node.input[1]
                relu_node_name = match_node_name[2]
                all_input_names = self._add_eightbit_prologue_nodes(
                    matched_node.node.name)
                all_input_names = all_input_names[:1] + [
                    q_weights_name
                ] + all_input_names[1:]
                all_input_names.append(q_weights_min_name)
                all_input_names.append(q_weights_max_name)
                quantized_node_input_names = all_input_names[:2] + [
                    bias_node_name
                ] + all_input_names[2:] + control_inputs

                quantized_matmul_node = helper.create_node(
                    "QuantizedMatMulWithBiasAndRelu", quantized_node_name,
                    quantized_node_input_names)

                helper.copy_attr(quantized_matmul_node, "transpose_a",
                                 node.attr["transpose_a"])
                helper.copy_attr(quantized_matmul_node, "transpose_b",
                                 node.attr["transpose_b"])
                helper.set_attr_dtype(quantized_matmul_node, "T1",
                                      dtypes.quint8)
                helper.set_attr_dtype(quantized_matmul_node, "T2",
                                      dtypes.qint8)
                helper.set_attr_dtype(quantized_matmul_node, "Toutput",
                                      dtypes.qint32)

                self.add_output_graph_node(quantized_matmul_node)

                quantize_down_name = self._add_quantize_down_nodes(
                    node, quantized_node_name, dtypes.quint8, False)
                self._intel_cpu_add_dequantize_result_node(
                    quantize_down_name, relu_node_name)
            else:
                new_node = node_def_pb2.NodeDef()
                new_node.CopyFrom(node)
                self.add_output_graph_node(new_node)
Example #8
    def do_transformation(self):
        cur_graph = GraphAnalyzer()

        # according to https://github.com/onnx/tensorflow-onnx/issues/77
        for node in self.model.node:
            if node.op == 'RefSwitch':
                node.op = 'Switch'
                for index in range(len(node.input)):
                    if 'moving_' in node.input[index]:
                        node.input[index] = node.input[index] + '/read'
            elif node.op == 'AssignSub':
                node.op = 'Sub'
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']
            elif node.op == 'AssignAdd':
                node.op = 'Add'
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']
            elif node.op == 'Assign':
                node.op = 'Identity'
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']
                if 'validate_shape' in node.attr:
                    del node.attr['validate_shape']
                if len(node.input) == 2:
                    # input0: ref: Should be from a Variable node. May be uninitialized.
                    # input1: value: The value to be assigned to the variable.
                    node.input[0] = node.input[1]
                    del node.input[1]

        cur_graph.graph = self.model

        graph_info = cur_graph.parse_graph()

        for name in self.input_node_names:
            if ':' in name:
                self.logger.debug("Name {} appears to refer to a Tensor, "
                                  "not a Operation.".format(name))
                return False

        type_attr = {"Sub": "T"}

        not_found = {name for name in self.input_node_names}
        for node_name, _ in graph_info.items():
            if node_name in not_found:
                not_found.remove(node_name)
                node = graph_info[node_name].node
                # Skip converting nodes with list-typed outputs (those carrying a
                # 'component_types' attr) to Placeholder.
                if 'component_types' in node.attr:
                    continue
                original_output = graph_info[node_name].outputs
                placeholder_node = node_def_pb2.NodeDef()
                placeholder_node.op = "Placeholder"
                placeholder_node.name = node.name

                if "dtype" in node.attr:
                    placeholder_node.attr["dtype"].CopyFrom(
                        attr_value_pb2.AttrValue(type=node.attr["dtype"].type))
                elif node.op in type_attr.keys():
                    placeholder_node.attr["dtype"].CopyFrom(
                        attr_value_pb2.AttrValue(
                            type=node.attr[type_attr[node.op]].type))
                else:
                    raise KeyError("%s op's type attribute is not found,"
                                   "you should add it to type_attr dict" %
                                   node.op)
                if "_output_shapes" in node.attr:
                    placeholder_node.attr["_output_shapes"].CopyFrom(
                        node.attr["_output_shapes"])
                if "shape" in node.attr:
                    placeholder_node.attr["shape"].CopyFrom(node.attr["shape"])

                cur_graph.remove_node(node_name)

                cur_graph.replace_const_node(placeholder_node, [node_name],
                                             original_output)

        import tensorflow as tf
        return tf.compat.v1.graph_util.extract_sub_graph(
            cur_graph.dump_graph(), self.output_node_names)
Example #9
    def test_freeze_then_sparsify(self, freeze_mock, graph_transform_mock):
        tag_name = 'tag'
        input_nodes = 'input_nodes'
        output_nodes = 'output_nodes'
        freeze_transform = 'freeze_graph'
        sparsify_transform = 'sparsify_gather'

        base_meta_graph_def = meta_graph_pb2.MetaGraphDef()

        # Add a table initializer.
        table_init_name = 'table_init'
        node_def = node_def_pb2.NodeDef(name=table_init_name,
                                        op='InitializeTableV2')
        base_meta_graph_def.graph_def.node.extend([node_def])

        # Add a group_deps node.
        group_deps_name = 'group_deps'
        node_def = node_def_pb2.NodeDef(name=group_deps_name, op='NoOp')
        node_def.input.extend(['^table_init'])
        base_meta_graph_def.graph_def.node.extend([node_def])

        base_meta_graph_def.collection_def[
            ops.GraphKeys.TABLE_INITIALIZERS].node_list.value.extend(
                [table_init_name])
        base_meta_graph_def.collection_def[
            saved_model_constants.LEGACY_INIT_OP_KEY].node_list.value.extend(
                [group_deps_name])

        # Expected metagraphdef.
        expected_meta_graph_def = meta_graph_pb2.MetaGraphDef()
        expected_meta_graph_def.CopyFrom(base_meta_graph_def)
        expected_meta_graph_def.meta_info_def.tags.append(tag_name)

        transformed_graph_def = graph_pb2.GraphDef()
        transformed_graph_def.CopyFrom(expected_meta_graph_def.graph_def)
        freeze_mock.return_value = transformed_graph_def
        graph_transform_mock.return_value = transformed_graph_def

        # Add unsaved init node.
        unsaved_init_name = 'unsaved_node'
        node_def = node_def_pb2.NodeDef(name=unsaved_init_name, op='NoOp')
        base_meta_graph_def.graph_def.node.extend([node_def])

        # Add a saver.
        base_meta_graph_def.saver_def.filename_tensor_name = 'node1'
        base_meta_graph_def.saver_def.save_tensor_name = 'node3'
        base_meta_graph_def.saver_def.restore_op_name = 'node6'

        transformed_meta_graph_def = meta_graph_transform.meta_graph_transform(
            base_meta_graph_def, [input_nodes], [output_nodes],
            [freeze_transform, sparsify_transform], [tag_name])

        self.assertEqual(expected_meta_graph_def, transformed_meta_graph_def)
        freeze_mock.assert_called_once_with(base_meta_graph_def.graph_def,
                                            [output_nodes], [table_init_name],
                                            group_deps_name,
                                            base_meta_graph_def.saver_def,
                                            None)
        graph_transform_mock.assert_called_once_with(transformed_graph_def, [
            input_nodes
        ], [output_nodes, group_deps_name, table_init_name], [
            sparsify_transform + '(group_init_node="sparify_gather_init_op")'
        ])
Example #10
    def do_transformation(self):
        float32_type = dtypes.float32.as_datatype_enum
        qint32_type = dtypes.qint32.as_datatype_enum
        target_nodes = self.graph_analyzer.query_fusion_pattern_nodes(
                self.fuse_patterns[self.version])
        for i in target_nodes:
            # TODO: Remove the check below once the TF limitation is removed.
            if len(i) == 5:
                continue

            quantized_node_name = i[0]
            quantized_node = self.graph_info[quantized_node_name].node
            requantize_node_name = i[1]
            requantize_node = self.graph_info[requantize_node_name].node
            requested_output_min_name = requantize_node.input[3]
            requested_output_max_name = requantize_node.input[4]
            deq_node_name = i[2]

            quantized_node_op = i[-1][0]

            new_node = node_def_pb2.NodeDef()

            new_node.op = quantized_node_op + "AndDequantize"
            new_node.name = requantize_node_name
            for _, value in enumerate(quantized_node.input):
                new_node.input.append(value)

            new_node.input.append(requested_output_min_name)
            new_node.input.append(requested_output_max_name)
            if 'T1' in quantized_node.attr:
                new_node.attr["T1"].CopyFrom(quantized_node.attr['T1'])
            if 'T2' in quantized_node.attr:
                new_node.attr["T2"].CopyFrom(quantized_node.attr['T2'])

            top_node_name = Helper.node_name_from_input(quantized_node.input[0])
            max_filter_node = self.graph_info[new_node.input[6]].node
            min_filter_node = self.graph_info[new_node.input[5]].node
            last_node = self.graph_info[new_node.input[0]].node

            bias_node = self.graph_info[new_node.input[2]].node
            max_input_node = self.graph_info[last_node.input[-1]].node
            min_input_node = self.graph_info[last_node.input[-2]].node
            min_input_value = (min_input_node.attr['value'].tensor.float_val)[0]
            max_input_value = (max_input_node.attr['value'].tensor.float_val)[0]

            max_filter_value = (max_filter_node.attr['value'].tensor.float_val)[0]
            min_filter_value = (min_filter_node.attr['value'].tensor.float_val)[0]

            weights_tensor = tensor_util.MakeNdarray(
                    self.graph_info[new_node.input[1]].node.attr['value'].tensor)
            bias_tensor = tensor_util.MakeNdarray(
                self.graph_info[new_node.input[2]].node.attr['value'].tensor)
            bias_scale = 255.0 * 127.0 / (
                    (max_input_value - min_input_value) *
                    max(abs(max_filter_value), abs(min_filter_value)))
            relative_scale = 255 * min_input_value / (max_input_value - min_input_value)
            int32_bias = []
            for bias_index, value in enumerate(
                    np.sum(np.array(weights_tensor, dtype=np.int32),
                            axis=0,
                            dtype=np.int32)):
                int32_bias.append(int(bias_tensor[bias_index] *
                                      bias_scale + value * relative_scale))

            bias_node.attr['dtype'].CopyFrom(
                attr_value_pb2.AttrValue(
                    type=float32_type if self.device == 'gpu' else qint32_type))
            bias_node.attr['value'].CopyFrom(
                attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                    bias_tensor if self.device == 'gpu' else int32_bias,
                    dtypes.float32 if self.device == 'gpu' else dtypes.int32,
                    bias_tensor.shape)))

            bias_node.attr['value'].tensor.dtype = float32_type \
                                    if self.device == 'gpu' else qint32_type
            new_node.attr["Tbias"].CopyFrom(attr_value_pb2.AttrValue(type=float32_type \
                                            if self.device == 'gpu' else qint32_type))

            new_node.attr["Toutput"].CopyFrom(attr_value_pb2.AttrValue(type=float32_type))

            self.graph_analyzer.remove_node(requantize_node_name)

            if self.graph_info[deq_node_name].outputs:
                self.graph_analyzer.replace_single_node(
                    new_node, [top_node_name], quantized_node_name,
                    self.graph_info[deq_node_name].outputs, deq_node_name)
                self.graph_analyzer.remove_node(deq_node_name)
            else:
                self.graph_analyzer.remove_node(deq_node_name)

                new_node.name = deq_node_name
                self.graph_analyzer.replace_single_node(
                    new_node, [top_node_name], quantized_node_name,
                    [], deq_node_name)

            self.graph_analyzer.remove_node(quantized_node_name)

        return self.graph_analyzer.dump_graph()
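The bias rescaling above is easier to follow in isolation. Below is a standalone numeric sketch of the same arithmetic with made-up calibration ranges and qint8 weight values; it mirrors the bias_scale / relative_scale computation used for the CPU (qint32) path.

import numpy as np

min_input, max_input = 0.0, 6.0        # calibrated input range
min_filter, max_filter = -0.5, 0.5     # per-tensor weight range
weights = np.array([[10, -20], [30, 40]], dtype=np.int32)  # qint8 weight values
bias = np.array([0.1, -0.2], dtype=np.float32)             # float bias

bias_scale = 255.0 * 127.0 / (
    (max_input - min_input) * max(abs(max_filter), abs(min_filter)))
relative_scale = 255 * min_input / (max_input - min_input)

int32_bias = [int(b * bias_scale + col_sum * relative_scale)
              for b, col_sum in zip(bias, weights.sum(axis=0, dtype=np.int32))]
print(int32_bias)  # values that replace the float bias Const on the CPU path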
Example #11
def _convert_single_op_hint_to_stub(call, graph_def):
    """Given a graph_def, converts `call` into a stub and returns a new graph_def.

  Args:
    call: A single function call to be converted.
    graph_def: A graph_def to use as input (that has `call` in it, obviously).
  Returns:
    A new transformed graph-def that has call as a stub (single op).

  Note: after this process, the graph_def can no longer be loaded into
      the tensorflow runtime, so all future manipulations are done in graph_def
      level.
  """
    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
        graph_def)
    input_names, output_names = call.flattened_inputs_and_outputs()

    reachable_by_input = _bfs_for_reachable_nodes(input_names,
                                                  name_to_input_name)
    reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                   name_to_input_name)
    input_nodes_set = set(input_names)
    output_nodes_set = set(output_names)
    nodes_after_fuse = []
    nodes_deleted_by_fuse = set()
    # Classify each node: keep everything reachable from the inputs, mark the
    # internal nodes between the inputs and outputs as deleted by the fuse, and
    # keep the nodes that come after the fused region.
    for node in graph_def.node:
        n = _tensor_name_base(node.name)
        if n in reachable_by_output:
            if n not in reachable_by_input and n not in output_nodes_set:
                # n is an internal node. Check to make sure it is really internal.
                # TODO(aselle): this could be done more efficiently by flooding
                # the graph first.
                _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
                                       name_to_input_name)
                nodes_deleted_by_fuse.add(n)
        elif n not in reachable_by_input:
            # n is a node that comes after all the fusing, so keep it.
            nodes_after_fuse.append(n)
        else:
            # n is a node that is randomly in the graph but not connected to
            # the chain of dependencies.
            pass

    # Make a new graphdef with all the pre-input and input nodes
    out = _graph_pb2.GraphDef()
    reachable_by_input_sorted = sorted(list(reachable_by_input),
                                       key=lambda n: name_to_seq_num[n])
    for node in reachable_by_input_sorted:
        out.node.extend([_copy.deepcopy(name_to_node[node])])

    # Create any stacks needed to aggregate arguments into a single input,
    # e.g. for static_rnn's.
    # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
    sorted_input_indices = list(call.inputs.keys())
    sorted_input_indices.sort()
    sorted_output_indices = list(call.outputs.keys())
    sorted_output_indices.sort()
    new_node = _node_def_pb2.NodeDef()
    # Delegate to each operand to produce the proper new input for this stub node.
    # In particular, an aggregate input will now be a Pack of some previously
    # non-fused things.
    for input_index in sorted_input_indices:
        inputs = call.inputs[input_index]
        new_node.input.append(inputs.aggregate_and_return_name_for_input(out))
    new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(
        sorted_input_indices)

    # Create the function
    new_node.op = call.function_name
    new_node.name = call.uuid
    out.node.extend([new_node])

    # Now call each output argument to give them a chance to make the proper
    # output type and add it to our new_node.
    output_dtypes = []
    for output_index in sorted_output_indices:
        output = call.outputs[output_index]
        output_dtype = (output.aggregate_and_return_name_for_output(
            new_node.name, output_index, out))
        output_dtypes.append(output_dtype)
    new_node.attr["_output_types"].list.type[:] = output_dtypes
    # TODO(aselle): what is right here?
    new_node.attr["_output_quantized"].b = False

    # Add post output nodes that do not depend on the outputs
    for n in nodes_after_fuse:
        should_keep = True
        for input_name in name_to_input_name[n]:
            if input_name in nodes_deleted_by_fuse:
                should_keep = False
        if should_keep:
            out.node.extend([_copy.deepcopy(name_to_node[n])])

    # Misc. graph_def data that needs copying.
    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)

    return out
Example #12
    def do_transformation(self):
        """Fuse the quantized op with the following requantize op.
        Returns:
            [graphdef]: the optimized graphdef object
        """
        uint8_type = dtypes.quint8.as_datatype_enum
        float32_type = dtypes.float32.as_datatype_enum
        qint32_type = dtypes.qint32.as_datatype_enum

        while True:
            target_nodes = self.graph_analyzer.query_fusion_pattern_nodes(
                self.fuse_patterns['default'])
            if len(target_nodes) == 0:
                break

            i = target_nodes[0]
            quantized_node_name = i[0]
            quantized_node = self.graph_info[quantized_node_name].node
            requantize_node_name = i[1]
            requantize_node = self.graph_info[requantize_node_name].node
            requested_output_min_name = requantize_node.input[3]
            requested_output_max_name = requantize_node.input[4]

            quantized_node_op = i[-1][0]

            new_node = node_def_pb2.NodeDef()

            new_node.op = quantized_node_op + "AndRequantize"
            new_node.name = requantize_node_name
            for _, value in enumerate(quantized_node.input):
                new_node.input.append(value)
            new_node.input.append(requested_output_min_name)
            new_node.input.append(requested_output_max_name)
            if 'T1' in quantized_node.attr:
                new_node.attr["T1"].CopyFrom(quantized_node.attr['T1'])
            if 'T2' in quantized_node.attr:
                new_node.attr["T2"].CopyFrom(quantized_node.attr['T2'])

            parent_node_name = Helper.node_name_from_input(quantized_node.input[0])
            max_filter_node = self.graph_info[new_node.input[6]].node
            min_filter_node = self.graph_info[new_node.input[5]].node
            last_node = self.graph_info[new_node.input[0]].node

            is_min_first = bool(quantized_node.attr['input_quant_mode'] == 'MIN_FIRST')
            if last_node.op.find('Requantize') != -1 or last_node.op.find('QuantizeV2') != -1:

                bias_node = self.graph_info[new_node.input[2]].node
                max_input_node = self.graph_info[last_node.input[-1]].node
                min_input_node = self.graph_info[last_node.input[-2]].node
                min_input_value = (min_input_node.attr['value'].tensor.float_val)[0]
                max_input_value = (max_input_node.attr['value'].tensor.float_val)[0]

                max_filter_value = (max_filter_node.attr['value'].tensor.float_val)[0]
                min_filter_value = (min_filter_node.attr['value'].tensor.float_val)[0]

                weights_tensor = tensor_util.MakeNdarray(
                        self.graph_info[new_node.input[1]].node.attr['value'].tensor)
                bias_tensor = tensor_util.MakeNdarray(
                    self.graph_info[new_node.input[2]].node.attr['value'].tensor)
                if is_min_first:
                    input_range = max_input_value - min_input_value
                else:
                    input_range = max(abs(max_input_value), abs(min_input_value))
                bias_scale = 255.0 * 127.0 / (
                        input_range * max(abs(max_filter_value), abs(min_filter_value)))
                relative_scale = 255 * min_input_value / (max_input_value - min_input_value)
                int32_bias = []
                for bias_index, value in enumerate(
                        np.sum(np.array(weights_tensor, dtype=np.int32),
                                axis=0,
                                dtype=np.int32)):
                    int32_bias.append(int(bias_tensor[bias_index]
                                          * bias_scale + value * relative_scale))

                bias_node.attr['dtype'].CopyFrom(
                    attr_value_pb2.AttrValue(
                        type=float32_type if self.device == 'gpu' else qint32_type))
                bias_node.attr['value'].CopyFrom(
                    attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                        bias_tensor if self.device == 'gpu' else int32_bias,
                        dtypes.float32 if self.device == 'gpu' else dtypes.int32,
                        bias_tensor.shape)))

                bias_node.attr['value'].tensor.dtype = float32_type \
                                        if self.device == 'gpu' else qint32_type
                new_node.attr["Tbias"].CopyFrom(attr_value_pb2.AttrValue(type=float32_type \
                                                if self.device == 'gpu' else qint32_type))

                new_node.attr["Toutput"].CopyFrom(
                        attr_value_pb2.AttrValue(type=uint8_type))
                # TODO: Revisit the block below once the pre_optimize graph refactor is committed.
                if quantized_node_op.find('Relu') == -1:
                    deq_node_name = self.graph_info[requantize_node_name].outputs[0]
                    deq_node = self.graph_info[deq_node_name].node
                    deq_node.attr['T'].CopyFrom(attr_value_pb2.AttrValue(type=uint8_type))
            else:
                new_node.attr["Tbias"].CopyFrom(attr_value_pb2.AttrValue(type=float32_type))

            self.graph_analyzer.replace_single_node(
                new_node, [parent_node_name], quantized_node_name,
                [self.graph_info[requantize_node_name].outputs[0]], requantize_node_name)
            self.graph_analyzer.remove_node(quantized_node_name)

        return self.graph_analyzer.dump_graph()
Example #13
def bn_fold(input_graph_def, conv_name, weight_name, mean_name,
            var_name, beta_name, gamma_name, epsilon_name, add_name):
    """Folds a batch-norm subgraph into the preceding convolution.

    Returns a (skip_ops, new_ops) tuple: the names of the original nodes to
    drop and the replacement NodeDefs (scaled weights, conv, offset, add).
    """
    input_node_map = get_input_node_map(input_graph_def)

    skip_ops = [conv_name, weight_name, mean_name,
                var_name, beta_name, gamma_name, epsilon_name, add_name]
    skip_ops.extend([])

    try:
        conv_op = input_node_map[conv_name]
        weights_op = input_node_map[weight_name]
        mean_op = input_node_map[mean_name]
        var_op = input_node_map[var_name]
        beta_op = input_node_map[beta_name]
        gamma_op = input_node_map[gamma_name]
        epsilon_op = input_node_map[epsilon_name]
        add_op = input_node_map[add_name]
    except KeyError as e:
        print("node %s not in graph" % e)
        return [], []

    weights = values_from_const(weights_op)
    mean_value = values_from_const(mean_op)
    var_value = values_from_const(var_op)
    beta_value = values_from_const(beta_op)
    gamma_value = values_from_const(gamma_op)
    variance_epsilon_value = values_from_const(epsilon_op)

    new_ops = []

    scale_value = (
        (1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value)) *
        gamma_value)

    offset_value = (-mean_value * scale_value) + beta_value
    scaled_weights = np.copy(weights)
    it = np.nditer(
        scaled_weights, flags=["multi_index"], op_flags=["readwrite"])
    while not it.finished:
        if conv_op.op == "DepthwiseConv2dNative":
            current_scale = scale_value[it.multi_index[2]]
        else:
            current_scale = scale_value[it.multi_index[3]]
        it[0] *= current_scale
        it.iternext()
    scaled_weights_op = node_def_pb2.NodeDef()
    scaled_weights_op.op = "Const"
    scaled_weights_op.name = weights_op.name
    scaled_weights_op.attr["dtype"].CopyFrom(weights_op.attr["dtype"])
    scaled_weights_op.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            scaled_weights, weights.dtype.type, weights.shape)))

    new_conv_op = node_def_pb2.NodeDef()
    new_conv_op.CopyFrom(conv_op)
    offset_op = node_def_pb2.NodeDef()
    offset_op.op = "Const"
    offset_op.name = conv_op.name + "_bn_offset"
    offset_op.attr["dtype"].CopyFrom(mean_op.attr["dtype"])
    offset_op.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            offset_value, mean_value.dtype.type, offset_value.shape)))

    new_add_op = node_def_pb2.NodeDef()
    new_add_op.CopyFrom(add_op)
    del new_add_op.input[:]
    new_add_op.input.extend([new_conv_op.name, offset_op.name])

    new_ops.extend([scaled_weights_op, new_conv_op, offset_op, new_add_op])
    return skip_ops, new_ops
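A short numeric check of the identity the fold relies on, with arbitrary values: scaling the conv output by gamma / sqrt(var + eps) and adding -mean * scale + beta reproduces the batch-norm result, which is why the weights can absorb the scale and the Add can absorb the offset.

import numpy as np

gamma, beta = 1.5, 0.2
mean, var, eps = 0.4, 0.09, 1e-3

scale = gamma / np.sqrt(var + eps)   # folded into the conv weights
offset = -mean * scale + beta        # becomes the "<conv>_bn_offset" Const

y = 2.0                              # any pre-batch-norm conv output value
assert np.isclose(gamma * (y - mean) / np.sqrt(var + eps) + beta,
                  scale * y + offset)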
Example #14
         ]


    for conv_name, weight_name, mean_name, var_name, beta_name, gamma_name, epsilon_name, add_name in zip(conv_names,
                            weight_names, mean_names, var_names, beta_names, gamma_names, epsilon_names, add_names):

        skip_op, new_op = bn_fold(output_graph_def, conv_name, weight_name, mean_name,
                                var_name, beta_name, gamma_name, epsilon_name, add_name)
        skip_ops.extend(skip_op)
        new_ops.extend(new_op)

    result_graph_def = graph_pb2.GraphDef()
    for node in output_graph_def.node:
        if node.name in skip_ops:
            continue
        new_node = node_def_pb2.NodeDef()
        new_node.CopyFrom(node)
        result_graph_def.node.extend([new_node])

    result_graph_def.node.extend(new_ops)
    output_graph_def = result_graph_def
    output_graph_def = strip_unused_lib.strip_unused(
        output_graph_def, input_node_names=input_node_names,
        output_node_names=output_node_names,
        placeholder_type_enum=dtypes.uint8.as_datatype_enum)

    with open(output_pb_file, 'wb') as f:
        f.write(output_graph_def.SerializeToString())


Example #15
def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None):
    """Replaces all the variables in a graph with constants of the same values.

  If you have a trained graph containing Variable ops, it can be convenient to
  convert them all to Const ops holding the same values. This makes it possible
  to describe the network fully with a single GraphDef file, and allows the
  removal of a lot of ops related to loading and saving the variables.

  Args:
    sess: Active TensorFlow session containing the variables.
    input_graph_def: GraphDef object holding the network.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_whitelist: The set of variable names to convert (by default,
                              all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants.

  Returns:
    GraphDef containing a simplified version of the original.
  """
    def get_input_name(node):
        """Gets the name of the first input. Errors if suffix is not :0."""
        details = node.input[0].split(":")
        if len(details) == 1 or int(details[1]) == 0:
            return details[0]
        # While it is valid for input tensors to have a suffix that is not :0, this
        # method is used to find the associated ops, not tensors, and therefore it
        # is not valid.
        raise ValueError("Tensor name '{0}' is invalid.".format(node.input[0]))

    def create_const_op(node_name, dtype, data, data_shape=None):
        """Creates a Const op."""
        output_node = node_def_pb2.NodeDef()
        output_node.op = "Const"
        output_node.name = node_name
        output_node.attr["dtype"].CopyFrom(dtype)
        output_node.attr["value"].CopyFrom(
            attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                data, dtype=dtype.type, shape=data_shape)))
        return output_node

    # This graph only includes the nodes needed to evaluate the output nodes, and
    # removes unneeded nodes like those involved in saving and assignment.
    inference_graph = extract_sub_graph(input_graph_def, output_node_names)

    # Identify the ops in the graph.
    map_name_to_node = {node.name: node for node in inference_graph.node}

    # Get list of variables.
    variable_names = []
    variable_dict_names = []
    resource_identity_types = {}
    for node in inference_graph.node:
        if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
            variable_name = node.name
            if ((variable_names_whitelist is not None
                 and variable_name not in variable_names_whitelist)
                    or (variable_names_blacklist is not None
                        and variable_name in variable_names_blacklist)):
                continue
            variable_dict_names.append(variable_name)
            if node.op == "VarHandleOp":
                variable_names.append(variable_name + "/Read/ReadVariableOp:0")
            else:
                variable_names.append(variable_name + ":0")
        elif node.op in ["ReadVariableOp", "ResourceGather"]:
            # There can be one or more Identity ops in between the ReadVariableOp and
            # VarHandleOp.  Store the Identity ops with the associated dtypes.
            source_op_name = get_input_name(node)
            while map_name_to_node[source_op_name].op == "Identity":
                resource_identity_types[source_op_name] = node.attr["dtype"]
                source_op_name = get_input_name(
                    map_name_to_node[source_op_name])
            if map_name_to_node[source_op_name].op != "VarHandleOp":
                raise ValueError("Cannot find the variable that is an input "
                                 "to the ReadVariableOp.")

    # Gets map of variables and the associated data.
    if variable_names:
        returned_variables = sess.run(variable_names)
    else:
        returned_variables = []
    variables_data_map = dict(zip(variable_dict_names, returned_variables))
    logging.info("Froze %d variables.", len(returned_variables))

    # Reconstruct the graph with constants in place of variables.
    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = node_def_pb2.NodeDef()
        if input_node.name in variables_data_map:
            data = variables_data_map[input_node.name]
            output_node = create_const_op(input_node.name,
                                          input_node.attr["dtype"], data,
                                          data.shape)
            how_many_converted += 1
        elif input_node.name in resource_identity_types:
            # Converts the Identities of type RESOURCE_DT to the appropriate type
            # based on the input they are referencing.
            output_node.CopyFrom(input_node)
            output_node.attr["T"].CopyFrom(
                resource_identity_types[input_node.name])
        elif input_node.op == "ReadVariableOp":
            # The first branch converts all VarHandleOps of ResourceVariables to
            # constants, so we need to convert the associated ReadVariableOps to
            # Identity ops.
            output_node.op = "Identity"
            output_node.name = input_node.name
            output_node.input.extend([input_node.input[0]])
            output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
            if "_class" in input_node.attr:
                output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
        elif input_node.op == "ResourceGather":
            # The first branch converts all VarHandleOps of ResourceVariables to
            # constants, so we need to convert the associated ResourceGather to Gather
            # ops with a Const axis feeding into it.
            if input_node.attr["batch_dims"].i != 0:
                raise ValueError(
                    "batch_dims != 0 is not supported by freeze_graph.")
            axis_data = input_node.attr["batch_dims"].i
            axis_node_name = input_node.name + "/axis"
            axis_dtype = input_node.attr["Tindices"]
            output_axis_node = create_const_op(axis_node_name, axis_dtype,
                                               axis_data)
            output_graph_def.node.extend([output_axis_node])

            output_node.op = "GatherV2"
            output_node.name = input_node.name
            output_node.input.extend(
                [input_node.input[0], input_node.input[1], axis_node_name])
            output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
            output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
            output_node.attr["Taxis"].CopyFrom(axis_dtype)
            if "_class" in input_node.attr:
                output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])

    output_graph_def.library.CopyFrom(inference_graph.library)
    logging.info("Converted %d variables to const ops.", how_many_converted)
    return output_graph_def
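
A minimal usage sketch for the function above, assuming a TF1-style session; the toy model and the "logits" output node name are hypothetical, chosen only to make the call self-contained:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Hypothetical tiny model: one variable feeding a result node named "logits".
x = tf.placeholder(tf.float32, [None, 4], name="x")
w = tf.Variable(tf.ones([4, 2]), name="w")
logits = tf.matmul(x, w, name="logits")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    frozen_graph_def = convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=["logits"])

# All VariableV2 / VarHandleOp nodes are now Const nodes in the returned GraphDef.
print(sorted({node.op for node in frozen_graph_def.node}))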
Ejemplo n.º 16
0
class TestFoldConstant(unittest.TestCase):
    x_node = node_def_pb2.NodeDef()
    x_node.name = "placeholder"
    x_node.op = "Placeholder"

    input0_node = node_def_pb2.NodeDef()
    input0_node.name = "input0"
    input0_node.op = "Const"
    input0_value = np.float32(np.abs(np.random.randn(4, 3, 2)))
    input0_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input0_value, input0_value.dtype.type, input0_value.shape)))

    input1_node = node_def_pb2.NodeDef()
    input1_node.name = "input1"
    input1_node.op = "Const"
    input1_value = np.float32(np.abs(np.random.randn(4, 1, 1)))
    input1_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input1_value, input1_value.dtype.type, input1_value.shape)))

    add_node = node_def_pb2.NodeDef()
    add_node.op = "Add"
    add_node.name = "add"
    add_node.input.extend([input0_node.name, input1_node.name])

    input2_node = node_def_pb2.NodeDef()
    input2_node.name = "input2"
    input2_node.op = "Const"
    input2_value = np.float32(np.abs(np.random.randn(1)))
    input2_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input2_value, input2_value.dtype.type, input2_value.shape)))

    input3_node = node_def_pb2.NodeDef()
    input3_node.name = "input3"
    input3_node.op = "Const"
    input3_value = np.float32(np.abs(np.random.randn(1)))
    input3_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input3_value, input3_value.dtype.type, input3_value.shape)))

    switch_node = node_def_pb2.NodeDef()
    switch_node.name = "switch"
    switch_node.op = "Switch"

    input4_node = node_def_pb2.NodeDef()
    input4_node.name = "input4"
    input4_node.op = "Const"
    input4_value = np.float32(np.abs(np.random.randn(1)))
    input4_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input4_value, input4_value.dtype.type, input4_value.shape)))
    input4_node.input.extend([switch_node.name])

    input5_node = node_def_pb2.NodeDef()
    input5_node.name = "input5"
    input5_node.op = "Const"
    input5_value = np.float32(np.abs(np.random.randn(1)))
    input5_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input5_value, input5_value.dtype.type, input5_value.shape)))
    input5_node.input.extend([switch_node.name])

    cond_end = node_def_pb2.NodeDef()
    cond_end.name = "cond"
    cond_end.op = "Add"
    cond_end.input.extend([input4_node.name, input5_node.name])

    mul_node = node_def_pb2.NodeDef()
    mul_node.op = "Mul"
    mul_node.name = "mul"
    mul_node.input.extend([add_node.name, input3_node.name])

    sqrt_node = node_def_pb2.NodeDef()
    sqrt_node.name = "rsqrt"
    sqrt_node.op = "Rsqrt"
    sqrt_node.input.extend([mul_node.name])

    relu_node = node_def_pb2.NodeDef()
    relu_node.op = "Relu"
    relu_node.name = "relu"
    relu_node.input.extend([sqrt_node.name])

    block_node = node_def_pb2.NodeDef()
    block_node.name = "block_output"
    block_node.op = "Add"
    block_node.input.extend([x_node.name, relu_node.name])

    res_node = node_def_pb2.NodeDef()
    res_node.name = "res_add"
    res_node.op = "Add"
    res_node.input.extend([sqrt_node.name, input2_node.name])

    end_node = node_def_pb2.NodeDef()
    end_node.name = "end"
    end_node.op = "Add"
    end_node.input.extend([block_node.name, res_node.name])

    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        x_node, input0_node, input1_node, input2_node, input3_node, add_node,
        mul_node, sqrt_node, relu_node, block_node, res_node, end_node
    ])

    def test_fold_constant(self):

        graph = self.graph_def
        rewriter = GraphFoldConstantOptimizer(graph)
        new_graph = rewriter.do_transformation()

        for node in new_graph.node:
            assert node.name in [
                "placeholder", "block_output", "rsqrt_const", "relu",
                "res_add_const", "end"
            ]

    def test_condition_fold_constant(self):
        graph_def = graph_pb2.GraphDef()
        graph_def.node.extend([
            self.cond_end, self.input4_node, self.input5_node, self.switch_node
        ])
        rewriter = GraphFoldConstantOptimizer(graph_def)
        new_graph = rewriter.do_transformation()
        for node in new_graph.node:
            assert node.name in ["switch", "cond", "input4", "input5"]

    def test_slice_int_input(self):
        graph_def = graph_pb2.GraphDef()
        index0_node = node_def_pb2.NodeDef()
        index0_node.name = "index0"
        index0_node.op = "Const"
        index0_value = np.array(3).astype(np.int32).reshape(())
        index0_node.attr["value"].CopyFrom(
            attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                index0_value, index0_value.dtype.type, index0_value.shape)))

        index1_node = node_def_pb2.NodeDef()
        index1_node.name = "index1"
        index1_node.op = "Const"
        index1_value = np.array(1).astype(np.int32).reshape(())
        index1_node.attr["value"].CopyFrom(
            attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                index1_value, index1_value.dtype.type, index1_value.shape)))

        minus_node = node_def_pb2.NodeDef()
        minus_node.name = "sub"
        minus_node.op = "Sub"
        minus_node.input.extend([index0_node.name, index1_node.name])

        graph_def.node.extend([index0_node, index1_node, minus_node])
        rewriter = GraphFoldConstantOptimizer(graph_def)
        new_graph = rewriter.do_transformation()
        with tf.compat.v1.Session() as sess:
            tf.compat.v1.import_graph_def(new_graph)
Ejemplo n.º 17
0
def generate_output_graph(input_graph_def, input_node_map, output_node_map,
                          fuse_op_list, fuse_op_deq_list):
    output_graph_def = graph_pb2.GraphDef()
    skip_list = []
    skip_node_name = []
    int8_type = dtypes.qint8.as_datatype_enum
    uint8_type = dtypes.quint8.as_datatype_enum
    float32_type = dtypes.float32.as_datatype_enum
    qint32_type = dtypes.qint32.as_datatype_enum
    for index, node in enumerate(input_graph_def.node):
        if index in fuse_op_list:
            const_node_1 = input_graph_def.node[index + 1]
            const_node_2 = input_graph_def.node[index + 2]
            requantize_node = input_graph_def.node[index + 3]
            new_node = node_def_pb2.NodeDef()

            new_node.op = node.op + "AndRequantize"
            new_node.name = requantize_node.name
            for _, value in enumerate(node.input):
                new_node.input.append(value)

            new_node.input.append(const_node_1.name)
            new_node.input.append(const_node_2.name)

            new_node.attr["Tinput"].CopyFrom(node.attr['Tinput'])
            new_node.attr["Tfilter"].CopyFrom(node.attr['Tfilter'])
            new_node.attr["strides"].CopyFrom(node.attr['strides'])
            new_node.attr["padding"].CopyFrom(node.attr['padding'])
            if input_node_map[new_node.input[0]].op.find("Requantize") != -1:
                bias_node = input_node_map[new_node.input[2]]
                last_node = input_node_map[new_node.input[0]]
                max_input_node = (input_node_map[last_node.input[4][:-2]])
                min_input_node = (input_node_map[last_node.input[3][:-2]])
                max_filter = input_node_map[new_node.input[6]]
                min_filter = input_node_map[new_node.input[5]]

                min_input = (min_input_node.attr['value'].tensor.float_val)[0]
                max_input = (max_input_node.attr['value'].tensor.float_val)[0]
                if 'Depthwise' in node.op or "RequantizePerChannel" in [
                        node.op for node in output_node_map[node.name]
                ]:

                    channel_size = max_filter.attr[
                        'value'].tensor.tensor_shape.dim[0].size
                    max_filter_tensor = tensor_util.MakeNdarray(
                        max_filter.attr['value'].tensor)
                    min_filter_tensor = tensor_util.MakeNdarray(
                        min_filter.attr['value'].tensor)
                else:

                    channel_size = 1
                    max_filter_tensor = []
                    min_filter_tensor = []
                    max_filter_tensor.append(
                        (max_filter.attr['value'].tensor.float_val)[0])
                    min_filter_tensor.append(
                        (min_filter.attr['value'].tensor.float_val)[0])

                bias_tensor = tensor_util.MakeNdarray(
                    input_node_map[new_node.input[2]].attr['value'].tensor)
                bias_length = bias_tensor.shape[0]
                scales = []
                for i in range(channel_size):
                    scales.append(255.0 * 127.0 /
                                  (max(abs(max_input), abs(min_input)) *
                                   max(abs(max_filter_tensor[i]),
                                       abs(min_filter_tensor[i]))))
                int32_bias = []
                if channel_size > 1:
                    for i in range(bias_length):
                        int32_bias.append(int(bias_tensor[i] * scales[i]))
                else:
                    for i in range(bias_length):
                        int32_bias.append(int(bias_tensor[i] * scales[0]))
                bias_node.attr['dtype'].CopyFrom(
                    attr_value_pb2.AttrValue(type=qint32_type))
                bias_node.attr['value'].CopyFrom(
                    attr_value_pb2.AttrValue(
                        tensor=tensor_util.make_tensor_proto(
                            int32_bias, dtypes.int32, bias_tensor.shape)))

                bias_node.attr['value'].tensor.dtype = qint32_type
                skip_node_name.append(bias_node.name)
                output_graph_def.node.extend([bias_node])
                new_node.attr["Tbias"].CopyFrom(
                    attr_value_pb2.AttrValue(type=qint32_type))
            else:
                new_node.attr["Tbias"].CopyFrom(
                    attr_value_pb2.AttrValue(type=float32_type))

            if "padding_list" in node.attr:
                new_node.attr["padding_list"].CopyFrom(
                    node.attr['padding_list'])
            if "dilations" in node.attr:
                new_node.attr["dilations"].CopyFrom(node.attr['dilations'])

            if node.op == "QuantizedConv2D" or node.op == "QuantizedConv2DWithBias":
                new_node.attr["out_type"].CopyFrom(
                    attr_value_pb2.AttrValue(type=int8_type))
            else:
                new_node.attr["out_type"].CopyFrom(
                    attr_value_pb2.AttrValue(type=uint8_type))

            skip_list.append(index + 1)
            skip_list.append(index + 2)
            skip_list.append(index + 3)
            output_graph_def.node.extend(
                [new_node, const_node_1, const_node_2])
        elif index in skip_list or node.name in skip_node_name:
            continue
        elif node.op == "Dequantize":
            new_node = node_def_pb2.NodeDef()
            new_node.CopyFrom(node)
            new_node.attr["mode"].s = b"SCALED"
            p_node = input_node_map[new_node.input[0]]
            pp_node = input_node_map[p_node.name].input[0]
            if input_node_map[pp_node].op.find("Relu") != -1 or p_node.op in (
                    "QuantizedAvgPool", "QuantizedMaxPool",
                    "QuantizedConcatV2"):
                new_node.attr["T"].CopyFrom(
                    attr_value_pb2.AttrValue(type=uint8_type))
            elif input_node_map[pp_node].op.find(
                    "QuantizedMatMulWithBias") != -1 and p_node.op.find(
                        "Requantize") != -1:
                new_node.attr["mode"].s = node.attr["mode"].s
                new_node.attr["T"].CopyFrom(
                    attr_value_pb2.AttrValue(type=node.attr["T"].type))
            else:
                new_node.attr["T"].CopyFrom(
                    attr_value_pb2.AttrValue(type=int8_type))
            output_graph_def.node.extend([new_node])
        elif index in fuse_op_deq_list:
            original_summand_node = input_node_map[
                input_graph_def.node[index].input[-1]]
            sum_const_node_1 = input_graph_def.node[index + 1]
            sum_const_node_2 = input_graph_def.node[index + 2]
            sum_requantize_node = input_graph_def.node[index + 3]

            new_node = node_def_pb2.NodeDef()

            new_node.op = node.op + "AndRequantize"
            new_node.name = sum_requantize_node.name
            for _, value in enumerate(node.input[:-1]):
                new_node.input.append(value)
            new_node.input.append(sum_const_node_1.name)
            new_node.input.append(sum_const_node_2.name)
            new_node.input.append(
                input_node_map[original_summand_node.name].input[0])
            new_node.input.append(
                input_node_map[original_summand_node.name].input[0] + ":1")
            new_node.input.append(
                input_node_map[original_summand_node.name].input[0] + ":2")

            # skip_list.append(index + 1)
            # skip_list.append(index + 2)
            skip_list.append(index + 3)

            new_node.attr["Tinput"].CopyFrom(node.attr['Tinput'])
            new_node.attr["Tfilter"].CopyFrom(node.attr['Tfilter'])
            new_node.attr["strides"].CopyFrom(node.attr['strides'])
            new_node.attr["padding"].CopyFrom(node.attr['padding'])
            if input_node_map[new_node.input[0]].op.find("Requantize") != -1:

                bias_node = input_node_map[new_node.input[2]]
                last_node = input_node_map[new_node.input[0]]
                max_input_node = (input_node_map[last_node.input[4][:-2]])
                min_input_node = (input_node_map[last_node.input[3][:-2]])
                max_filter = input_node_map[new_node.input[6]]
                min_filter = input_node_map[new_node.input[5]]

                min_input = (min_input_node.attr['value'].tensor.float_val)[0]
                max_input = (max_input_node.attr['value'].tensor.float_val)[0]

                if "RequantizePerChannel" in [
                        node.op for node in output_node_map[node.name]
                ]:
                    channel_size = max_filter.attr[
                        'value'].tensor.tensor_shape.dim[0].size
                    max_filter_tensor = tensor_util.MakeNdarray(
                        max_filter.attr['value'].tensor)
                    min_filter_tensor = tensor_util.MakeNdarray(
                        min_filter.attr['value'].tensor)
                else:
                    channel_size = 1
                    max_filter_tensor = []
                    min_filter_tensor = []
                    max_filter_tensor.append(
                        (max_filter.attr['value'].tensor.float_val)[0])
                    min_filter_tensor.append(
                        (min_filter.attr['value'].tensor.float_val)[0])

                bias_tensor = (tensor_util.MakeNdarray(
                    input_node_map[new_node.input[2]].attr['value'].tensor))
                bias_length = bias_tensor.shape[0]
                scales = []
                for i in range(channel_size):
                    scales.append(255.0 * 127.0 /
                                  (max(abs(max_input), abs(min_input)) *
                                   max(abs(max_filter_tensor[i]),
                                       abs(min_filter_tensor[i]))))
                int32_bias = []
                if channel_size > 1:
                    for i in range(bias_length):
                        int32_bias.append(int(bias_tensor[i] * scales[i]))
                else:
                    for i in range(bias_length):
                        int32_bias.append(int(bias_tensor[i] * scales[0]))
                bias_node.attr['dtype'].CopyFrom(
                    attr_value_pb2.AttrValue(type=qint32_type))
                bias_node.attr['value'].CopyFrom(
                    attr_value_pb2.AttrValue(
                        tensor=tensor_util.make_tensor_proto(
                            int32_bias, dtypes.int32, bias_tensor.shape)))
                bias_node.attr['value'].tensor.dtype = qint32_type
                new_node.attr["Tbias"].CopyFrom(
                    attr_value_pb2.AttrValue(type=qint32_type))
                skip_node_name.append(bias_node.name)
                output_graph_def.node.extend([bias_node])
            else:
                new_node.attr["Tbias"].CopyFrom(
                    attr_value_pb2.AttrValue(type=float32_type))

            if "padding_list" in node.attr:
                new_node.attr["padding_list"].CopyFrom(
                    node.attr['padding_list'])
            if "dilations" in node.attr:
                new_node.attr["dilations"].CopyFrom(node.attr['dilations'])

            new_node.attr["out_type"].CopyFrom(
                attr_value_pb2.AttrValue(type=uint8_type))

            summand_op_type = uint8_type if dtypes.as_dtype(
                original_summand_node.attr["T"].type
            ) == uint8_type else int8_type

            if summand_op_type == int8_type:
                new_node.op = "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"

            new_node.attr["Tsummand"].CopyFrom(
                attr_value_pb2.AttrValue(type=summand_op_type))
            output_graph_def.node.extend([new_node])
        else:
            new_node = node_def_pb2.NodeDef()
            new_node.CopyFrom(node)
            output_graph_def.node.extend([new_node])
    return output_graph_def
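
When the preceding op is already a Requantize, the fused *AndRequantize node expects a qint32 bias, so the float bias is rescaled into the qint32 domain using the quint8 activation range (255 levels) and the qint8 filter range (127 levels). A small NumPy sketch with made-up ranges, mirroring the per-channel `scales` loop above for the per-tensor (channel_size == 1) case:

import numpy as np

# Hypothetical quantization ranges taken from the min/max Const nodes.
min_input, max_input = -6.0, 6.0        # activation range, quantized as quint8
min_filter, max_filter = -0.8, 0.9      # per-tensor filter range, quantized as qint8
bias_fp32 = np.array([0.05, -0.12, 0.31], dtype=np.float32)

scale = 255.0 * 127.0 / (max(abs(max_input), abs(min_input)) *
                         max(abs(max_filter), abs(min_filter)))
int32_bias = [int(b * scale) for b in bias_fp32]
print(int32_bias)   # the values stored into the qint32 bias Const node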
Ejemplo n.º 18
0
    def do_transformation(self):
        """Removes batch normalization ops by folding them into convolutions.

        Batch normalization during training has multiple dynamic parameters that are
        updated, but once the graph is finalized these become constants. That means
        there's an opportunity to reduce the computations down to a scale and
        addition, rather than the more expensive multiple ops, and even bake the
        scaling into the convolution weights. This function identifies the typical
        pattern of batch normalization subgraphs, and performs the transformation to
        fold the computations down into a simpler form. It currently only spots batch
        normalization that's performed by the BatchNormWithGlobalNormalization and
        FusedBatchNorm ops, and will need to be extended in the future to handle the
        newer style.

        Returns:
          Modified graph with BN ops removed, and modified weights.

        Raises:
          ValueError: If the graph is badly formed with duplicate node names.
        """
        cur_graph = GraphAnalyzer()
        cur_graph.graph = self.model

        graph_info = cur_graph.parse_graph()
        target_nodes = cur_graph.query_fusion_pattern_nodes(
            [["Conv2D", "DepthwiseConv2dNative"], ("BiasAdd", "Add", "AddV2"),
             ["BatchNormWithGlobalNormalization", "FusedBatchNorm", "FusedBatchNormV3"]])
        for node_combination in target_nodes:
            matched_node = node_combination[:-1]
            has_add_op = len(node_combination[-1]) == 3
            conv_node = graph_info[Helper.node_name_from_input(matched_node[0])].node
            weights_node_name = graph_info[Helper.node_name_from_input(
                matched_node[0])].node.input[1]
            weights_node = graph_info[Helper.node_name_from_input(weights_node_name)].node
            bn_node = graph_info[Helper.node_name_from_input(matched_node[-1])].node

            if weights_node.op != "Const":
                self.logger.warning("Didn't find expected conv Constant input to '%s',"
                                    " found %s instead. Maybe because freeze_graph wasn't"
                                    " run first?" % (bn_node.name, weights_node_name))
                continue
            weights = Helper.values_from_const(weights_node)

            if conv_node.op == "Conv2D":
                channel_count = weights.shape[3]
            elif conv_node.op == "DepthwiseConv2dNative":
                channel_count = weights.shape[2] * weights.shape[3]

            mean_node_name = Helper.node_name_from_input(
                bn_node.input[self.INPUT_ORDER[bn_node.op].index("mean_op")])
            mean_node = graph_info[mean_node_name].node

            if mean_node.op != "Const":
                continue

            mean_value = Helper.values_from_const(mean_node)

            if has_add_op:
                bias_node_name = graph_info[Helper.node_name_from_input(
                    matched_node[1])].node.input[1]
                bias_node = graph_info[Helper.node_name_from_input(bias_node_name)].node
                if bias_node.op != "Const":
                    continue

                if mean_value.shape != (channel_count, ):
                    continue

                mean_value = mean_value - Helper.values_from_const(bias_node)
                cur_graph.remove_node(bias_node.name)
                cur_graph.remove_node(matched_node[1])

            if mean_value.shape != (channel_count, ):
                self.logger.warning("Incorrect shape for mean, found %s, expected %s,"
                                    " for node %s" % (str(mean_value.shape), str(
                                        (channel_count, )), conv_node.name))
                continue
            var_node_name = Helper.node_name_from_input(
                bn_node.input[self.INPUT_ORDER[bn_node.op].index("var_op")])
            var_node = graph_info[var_node_name].node
            if var_node.op != "Const":
                continue
            var_value = Helper.values_from_const(var_node)

            if var_value.shape != (channel_count, ):
                continue

            beta_node_name = Helper.node_name_from_input(
                bn_node.input[self.INPUT_ORDER[bn_node.op].index("beta_op")])
            beta_node = graph_info[beta_node_name].node
            if beta_node.op != "Const":
                continue
            beta_value = Helper.values_from_const(beta_node)

            if beta_value.shape != (channel_count, ):
                continue

            gamma_node_name = Helper.node_name_from_input(
                bn_node.input[self.INPUT_ORDER[bn_node.op].index("gamma_op")])
            gamma_node = graph_info[gamma_node_name].node

            if gamma_node.op != "Const":
                continue
            gamma_value = Helper.values_from_const(gamma_node)

            if gamma_value.shape != (channel_count, ):
                continue

            variance_epsilon_value = bn_node.attr[self.EPSILON_ATTR[bn_node.op]].f

            if self.scale_after_normalization(bn_node):
                scale_value = (
                    (1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value)) *
                    gamma_value)
            else:
                scale_value = (1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value))

            offset_value = (-mean_value * scale_value) + beta_value


            if conv_node.op == "Conv2D":
                original_shape = weights.shape
                tmp_shape = (original_shape[-1], int(weights.size/original_shape[-1]))
                tmp_order = [weights.ndim - 1] + [i for i in range(weights.ndim - 1)]
                scaled_weights = np.copy(weights).transpose(tmp_order).ravel().reshape(tmp_shape)
                reshape_scale = np.array(scale_value).reshape(len(scale_value), 1)
                scaled_weights = np.multiply(
                    scaled_weights, reshape_scale).transpose().reshape(original_shape)
            elif conv_node.op == "DepthwiseConv2dNative":
                scaled_weights = np.copy(weights)
                it = np.nditer(scaled_weights, flags=["multi_index"], op_flags=["readwrite"])
                channel_multiplier = weights.shape[3]
                while not it.finished:
                    current_scale = scale_value[it.multi_index[2] * channel_multiplier +
                                                it.multi_index[3]]
                    it[0] *= current_scale
                    it.iternext()

            scaled_weights_node = node_def_pb2.NodeDef()
            scaled_weights_node.op = "Const"
            scaled_weights_node.name = weights_node_name + "_bn_offset"
            scaled_weights_node.attr["dtype"].CopyFrom(weights_node.attr["dtype"])
            scaled_weights_node.attr["value"].CopyFrom(
                attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                    scaled_weights, weights.dtype.type, weights.shape)))
            cur_graph.replace_const_node(scaled_weights_node, [conv_node.name], weights_node_name)

            offset_node = node_def_pb2.NodeDef()
            offset_node.op = "Const"
            offset_node.name = conv_node.name + "_bn_offset"
            offset_node.attr["dtype"].CopyFrom(mean_node.attr["dtype"])
            offset_node.attr["value"].CopyFrom(
                attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                    offset_value, mean_value.dtype.type, offset_value.shape)))
            bias_add_node = node_def_pb2.NodeDef()
            bias_add_node.op = "BiasAdd"
            bias_add_node.name = bn_node.name
            bias_add_node.attr["T"].CopyFrom(conv_node.attr["T"])
            bias_add_node.attr["data_format"].CopyFrom(conv_node.attr["data_format"])
            bias_add_node.input.extend([conv_node.name, offset_node.name])

            cur_graph.add_node(offset_node, [], [bias_add_node.name])
            cur_graph.add_node(bias_add_node, conv_node.name,
                               graph_info[Helper.node_name_from_input(matched_node[-1])].outputs)

            cur_graph.remove_node(weights_node_name)
            cur_graph.remove_node(mean_node_name)
            cur_graph.remove_node(var_node_name)
            cur_graph.remove_node(beta_node_name)
            cur_graph.remove_node(gamma_node_name)

        return cur_graph.dump_graph()
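
The transpose/reshape/multiply sequence in the Conv2D branch above is just a per-output-channel scaling of the HWIO weight tensor; a short NumPy check (with hypothetical shapes) that it agrees with a plain broadcast over the last axis:

import numpy as np

weights = np.random.randn(3, 3, 8, 16).astype(np.float32)   # HWIO Conv2D weights
scale_value = np.random.rand(16).astype(np.float32)         # one scale per output channel

original_shape = weights.shape
tmp_shape = (original_shape[-1], int(weights.size / original_shape[-1]))
tmp_order = [weights.ndim - 1] + [i for i in range(weights.ndim - 1)]
scaled_weights = np.copy(weights).transpose(tmp_order).ravel().reshape(tmp_shape)
reshape_scale = np.array(scale_value).reshape(len(scale_value), 1)
scaled_weights = np.multiply(
    scaled_weights, reshape_scale).transpose().reshape(original_shape)

# Same result as broadcasting the scale over the output-channel axis.
np.testing.assert_allclose(scaled_weights, weights * scale_value, rtol=1e-6)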
Ejemplo n.º 19
0
def remove_training_nodes(input_graph, protected_nodes=None):
    """Prunes out nodes that aren't needed for inference.

  There are nodes like Identity and CheckNumerics that are only useful
  during training, and can be removed in graphs that will be used for
  nothing but inference. Here we identify and remove them, returning an
  equivalent graph. To be specific, CheckNumerics nodes are always removed, and
  Identity nodes that aren't involved in control edges are spliced out so that
  their input and outputs are directly connected.

  Args:
    input_graph: Model to analyze and prune.
    protected_nodes: An optional list of names of nodes to be kept
      unconditionally. This is for example useful to preserve Identity output
      nodes.

  Returns:
    A `GraphDef` with the unnecessary nodes removed.
  """
    if not protected_nodes:
        protected_nodes = []

    types_to_remove = {"CheckNumerics": True}

    input_nodes = input_graph.node
    names_to_remove = {}
    for node in input_nodes:
        if node.op in types_to_remove and node.name not in protected_nodes:
            names_to_remove[node.name] = True

    nodes_after_removal = []
    for node in input_nodes:
        if node.name in names_to_remove:
            continue
        new_node = node_def_pb2.NodeDef()
        new_node.CopyFrom(node)
        input_before_removal = node.input
        del new_node.input[:]
        for full_input_name in input_before_removal:
            input_name = re.sub(r"^\^", "", full_input_name)
            if input_name in names_to_remove:
                continue
            new_node.input.append(full_input_name)
        nodes_after_removal.append(new_node)

    types_to_splice = {"Identity": True}
    control_input_names = set()
    node_names_with_control_input = set()
    for node in nodes_after_removal:
        for node_input in node.input:
            if "^" in node_input:
                control_input_names.add(node_input.replace("^", ""))
                node_names_with_control_input.add(node.name)

    names_to_splice = {}
    for node in nodes_after_removal:
        if node.op in types_to_splice and node.name not in protected_nodes:
            # We don't want to remove nodes that have control edge inputs, because
            # they might be involved in subtle dependency issues that removing them
            # will jeopardize.
            if node.name not in node_names_with_control_input:
                names_to_splice[node.name] = node.input[0]

    # We also don't want to remove nodes which are used as control edge inputs.
    names_to_splice = {
        name: value
        for name, value in names_to_splice.items()
        if name not in control_input_names
    }

    nodes_after_splicing = []
    for node in nodes_after_removal:
        if node.name in names_to_splice:
            continue
        new_node = node_def_pb2.NodeDef()
        new_node.CopyFrom(node)
        input_before_removal = node.input
        del new_node.input[:]
        for full_input_name in input_before_removal:
            input_name = re.sub(r"^\^", "", full_input_name)
            while input_name in names_to_splice:
                full_input_name = names_to_splice[input_name]
                input_name = re.sub(r"^\^", "", full_input_name)
            new_node.input.append(full_input_name)
        nodes_after_splicing.append(new_node)

    output_graph = graph_pb2.GraphDef()
    output_graph.node.extend(nodes_after_splicing)
    return output_graph
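
A minimal sketch of the splicing behaviour, using a hypothetical three-node chain (Const -> Identity -> Relu); after pruning, the Relu reads directly from the Const and the Identity disappears:

from tensorflow.core.framework import graph_pb2, node_def_pb2

const_node = node_def_pb2.NodeDef(name="const", op="Const")
identity_node = node_def_pb2.NodeDef(name="identity", op="Identity", input=["const"])
relu_node = node_def_pb2.NodeDef(name="relu", op="Relu", input=["identity"])

graph = graph_pb2.GraphDef()
graph.node.extend([const_node, identity_node, relu_node])

pruned = remove_training_nodes(graph)
print([(node.name, list(node.input)) for node in pruned.node])
# Expected: [('const', []), ('relu', ['const'])]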
Ejemplo n.º 20
0
def strip_unused(input_graph_def, input_tensor_names, output_tensor_names,
                 placeholder_type_enum):
    """Removes unused nodes from a GraphDef.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_tensor_names: A list of the nodes we use as inputs.
    output_tensor_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list that specifies one value per input node name.

  Returns:
    A `GraphDef` with all unnecessary ops removed, and a map from the old input
    names to the new input names.

  Raises:
    ValueError: If any element in `input_tensor_names` refers to an operation
      instead of a tensor.
    KeyError: If any element in `input_tensor_names` is not found in the graph.
  """
    for name in input_tensor_names:
        if ":" not in name:
            raise ValueError("Input '%s' appears to refer to a Operation, "
                             "not a Tensor." % name)

    old2new = {}

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    not_found = {name for name in input_tensor_names}
    input_node_names = {name.split(":")[0] for name in input_tensor_names}
    output_node_names = list(
        {name.split(":")[0]
         for name in output_tensor_names})
    inputs_replaced_graph_def = graph_pb2.GraphDef()
    for node in input_graph_def.node:
        if node.name not in input_node_names:
            for i in range(len(node.input)):
                if _append_port(node.input[i]) in input_tensor_names:
                    old_name = _append_port(node.input[i])
                    not_found.remove(old_name)
                    new_input_name = node.input[i].replace(":", "_")
                    placeholder_node = node_def_pb2.NodeDef()
                    placeholder_node.op = "Placeholder"
                    placeholder_node.name = new_input_name
                    if isinstance(placeholder_type_enum, list):
                        input_node_index = input_tensor_names.index(old_name)
                        placeholder_node.attr["dtype"].CopyFrom(
                            attr_value_pb2.AttrValue(
                                type=placeholder_type_enum[input_node_index]))
                    else:
                        placeholder_node.attr["dtype"].CopyFrom(
                            attr_value_pb2.AttrValue(
                                type=placeholder_type_enum))
                    if "_output_shapes" in node.attr:
                        placeholder_node.attr["_output_shapes"].CopyFrom(
                            node.attr["_output_shapes"])
                    node.input[i] = new_input_name
                    old2new[old_name] = new_input_name + ":0"
                    inputs_replaced_graph_def.node.extend([placeholder_node])
            inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

    if not_found:
        raise KeyError("The following input nodes were not found: %s\n" %
                       not_found)

    output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                    output_node_names)
    return output_graph_def, old2new
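
A usage sketch for the function above (it relies on the surrounding module's `_append_port` and `graph_util` helpers being importable); the model path and tensor names are placeholders, not part of the original code:

from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes

graph_def = graph_pb2.GraphDef()
with open("model.pb", "rb") as f:            # hypothetical frozen model
    graph_def.ParseFromString(f.read())

pruned_def, old2new = strip_unused(
    graph_def,
    input_tensor_names=["preprocess/truediv:0"],   # must be tensor names with a ":<port>"
    output_tensor_names=["logits:0"],
    placeholder_type_enum=dtypes.float32.as_datatype_enum)

# old2new maps each original input tensor name to the ":0" tensor of the
# Placeholder that now replaces it in pruned_def.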
Ejemplo n.º 21
0
    def generate_output_graph(self, input_graph_def, input_node_map,
                              fuse_op_name):
        output_graph_def = graph_pb2.GraphDef()
        skip_list = []
        skip_node_name = []
        for index, node in enumerate(input_graph_def.node):

            if node.name in fuse_op_name:
                skip_list.append(index + 1)

                original_node = input_node_map[node.name]
                mul_node = input_node_map[fuse_op_name[node.name]]
                weights_node_name = original_node.input[1]
                weights_node = input_node_map[weights_node_name]
                mul_value_node_name = mul_node.input[1]
                mul_value_node = input_node_map[mul_value_node_name]

                new_node = node_def_pb2.NodeDef()
                new_node.op = original_node.op
                new_node.name = mul_node.name

                for _, value in enumerate(node.input):
                    new_node.input.append(value)

                if original_node.op == "DepthwiseConv2dNative":
                    weights_col = weights_node.attr[
                        'value'].tensor.tensor_shape.dim[
                            2].size * weights_node.attr[
                                'value'].tensor.tensor_shape.dim[3].size
                elif original_node.op == "Conv2D":
                    weights_col = weights_node.attr[
                        'value'].tensor.tensor_shape.dim[3].size
                else:
                    weights_col = weights_node.attr[
                        'value'].tensor.tensor_shape.dim[1].size
                mul_value_node_tensor = mul_value_node.attr['value'].tensor
                weights_node_tensor = weights_node.attr['value'].tensor

                if len(mul_value_node_tensor.tensor_shape.dim
                       ) != 1 or mul_value_node_tensor.tensor_shape.dim[
                           0].size != weights_col:
                    print("Invalid Mul OP fusion.")

                mul_value_node_list = [
                    i for i in tensor_util.MakeNdarray(
                        mul_value_node_tensor).flat
                ]
                new_weights = []
                for index, i in enumerate(
                        tensor_util.MakeNdarray(weights_node_tensor).flat):
                    new_weights_value = i * mul_value_node_list[
                        index % len(mul_value_node_list)]
                    new_weights.append(new_weights_value)

                weights_node.attr['value'].CopyFrom(
                    attr_value_pb2.
                    AttrValue(tensor=tensor_util.make_tensor_proto(
                        new_weights, dtypes.float32,
                        tensor_util.MakeNdarray(weights_node_tensor).shape)))
                skip_node_name.append(weights_node.name)
                output_graph_def.node.extend([weights_node])
                for key in original_node.attr:
                    new_node.attr[key].CopyFrom(original_node.attr[key])

                output_graph_def.node.extend([new_node])

            elif index in skip_list or node.name in skip_node_name:
                continue
            else:
                new_node = node_def_pb2.NodeDef()
                new_node.CopyFrom(node)
                output_graph_def.node.extend([new_node])
        return output_graph_def
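
In the fusion above, each flattened weight element is multiplied by the Mul constant indexed modulo the number of output channels; because the output channel is the fastest-varying axis of the HWIO layout in C order, this is equivalent to a broadcast multiply over the last axis. A quick NumPy check with hypothetical shapes:

import numpy as np

weights = np.random.randn(3, 3, 8, 16).astype(np.float32)    # HWIO Conv2D weights
mul_value = np.random.rand(16).astype(np.float32)            # per-output-channel Mul constant

new_weights = [w * mul_value[i % mul_value.size]
               for i, w in enumerate(weights.flat)]
new_weights = np.array(new_weights, dtype=np.float32).reshape(weights.shape)

np.testing.assert_allclose(new_weights, weights * mul_value, rtol=1e-6)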
Ejemplo n.º 22
0
    def eightbitize_nodes_recursively(self, current_node):
        if current_node.name in self.state.already_visited:
            if (self.should_merge_with_fake_quant_node()
                    or current_node.name in self.state.merged_with_fake_quant):
                raise ValueError(
                    "Unsupported graph structure: output of node %s "
                    "is processed by a FakeQuant* node and should have "
                    "no other outputs.", current_node.name)
            return
        self.state.already_visited[current_node.name] = True

        for i, input_node_name in enumerate(current_node.input):
            quantize_input = False
            if current_node.op in ("MatMul", "Conv2D", "BiasAdd", "MaxPool",
                                   "AvgPool", "Relu", "Relu6",
                                   "BatchNormWithGlobalNormalization"):
                quantize_input = True
            elif current_node.op == "Concat" and i > 0:
                quantize_input = (dtypes.as_dtype(
                    current_node.attr["T"].type) == dtypes.float32)
            elif current_node.op == "Reshape" and i == 0:
                quantize_input = (dtypes.as_dtype(
                    current_node.attr["T"].type) == dtypes.float32)

            self.state.output_node_stack.append(
                (current_node, i, quantize_input))

            input_node_name = node_name_from_input(input_node_name)
            input_node = self.nodes_map[input_node_name]
            self.eightbitize_nodes_recursively(input_node)

            self.state.output_node_stack.pop()

        if current_node.op == "MatMul":
            self.eightbitize_mat_mul_node(current_node)
        elif current_node.op == "Conv2D":
            self.eightbitize_conv_node(current_node)
        elif current_node.op == "BiasAdd":
            self.eightbitize_bias_add_node(current_node)
        elif current_node.op == "MaxPool" or current_node.op == "AvgPool":
            self.eightbitize_single_input_tensor_node(current_node,
                                                      self.add_pool_function)
        elif current_node.op == "Relu" or current_node.op == "Relu6":
            self.eightbitize_single_input_tensor_node(current_node,
                                                      self.add_relu_function)
        elif (current_node.op == "Concat" and dtypes.as_dtype(
                current_node.attr["T"].type) == dtypes.float32):
            self.eightbitize_concat_node(current_node)
        elif current_node.op == "BatchNormWithGlobalNormalization":
            self.eightbitize_batch_norm_node(current_node)
        elif (current_node.op == "Reshape" and dtypes.as_dtype(
                current_node.attr["T"].type) == dtypes.float32):
            self.eightbitize_reshape_node(current_node)
        elif (self.input_range
              and current_node.op in ("Placeholder", "PlaceholderV2")):
            self.eightbitize_placeholder_node(current_node)
        elif current_node.op == "FakeQuantWithMinMaxVars":
            pass
        elif current_node.op == "Const":
            if self.should_quantize_const(current_node):
                for n in quantize_weight_eightbit(current_node, b"MIN_FIRST"):
                    self.add_output_graph_node(n)
            else:
                new_node = node_def_pb2.NodeDef()
                new_node.CopyFrom(current_node)
                self.add_output_graph_node(new_node)
        else:
            new_node = node_def_pb2.NodeDef()
            new_node.CopyFrom(current_node)
            self.add_output_graph_node(new_node)

        if (self.should_merge_with_fake_quant_node() and current_node.name
                not in self.state.merged_with_fake_quant):
            raise ValueError(
                "FakeQuant* node %s failed to merge with node %s of type %s" %
                (self.state.output_node_stack[-1][0], current_node.name,
                 current_node.op))
Ejemplo n.º 23
0
class TestGraph_util(unittest.TestCase):
    x_node = node_def_pb2.NodeDef()
    x_node.name = "placeholder"
    x_node.op = "Placeholder"

    input0_node = node_def_pb2.NodeDef()
    input0_node.name = "input0"
    input0_node.op = "Const"
    input0_value = np.float32(np.abs(np.random.randn(4, 3, 2)))
    input0_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input0_value, input0_value.dtype.type, input0_value.shape)))

    input1_node = node_def_pb2.NodeDef()
    input1_node.name = "input1"
    input1_node.op = "Const"
    input1_value = np.float32(np.abs(np.random.randn(4, 1, 1)))
    input1_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input1_value, input1_value.dtype.type, input1_value.shape)))

    add_node = node_def_pb2.NodeDef()
    add_node.op = "Add"
    add_node.name = "add"
    add_node.input.extend([input0_node.name, input1_node.name])

    input2_node = node_def_pb2.NodeDef()
    input2_node.name = "input2"
    input2_node.op = "Const"
    input2_value = np.float32(np.abs(np.random.randn(1)))
    input2_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input2_value, input2_value.dtype.type, input2_value.shape)))

    input3_node = node_def_pb2.NodeDef()
    input3_node.name = "input3"
    input3_node.op = "Const"
    input3_value = np.float32(np.abs(np.random.randn(1)))
    input3_node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            input3_value, input3_value.dtype.type, input3_value.shape)))

    mul_node = node_def_pb2.NodeDef()
    mul_node.op = "Mul"
    mul_node.name = "mul"
    mul_node.input.extend([add_node.name, input3_node.name])

    sqrt_node = node_def_pb2.NodeDef()
    sqrt_node.name = "rsqrt"
    sqrt_node.op = "Rsqrt"
    sqrt_node.input.extend([mul_node.name])

    sqrt1_node = node_def_pb2.NodeDef()
    sqrt1_node.op = "Relu"
    sqrt1_node.name = "sqrt1"
    sqrt1_node.input.extend([sqrt_node.name])

    block_node = node_def_pb2.NodeDef()
    block_node.name = "block_output"
    block_node.op = "Add"
    block_node.input.extend([x_node.name, sqrt1_node.name])

    res_node = node_def_pb2.NodeDef()
    res_node.name = "res_add"
    res_node.op = "Add"
    res_node.input.extend([sqrt_node.name, input2_node.name])

    end_node = node_def_pb2.NodeDef()
    end_node.name = "end"
    end_node.op = "Add"
    end_node.input.extend([block_node.name, res_node.name])

    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        x_node, input0_node, input1_node, input2_node, input3_node, add_node,
        mul_node, sqrt_node, sqrt1_node, block_node, res_node, end_node
    ])

    def test_replace_constant_graph_with_constant_node(self):
        graph_analyzer = GraphAnalyzer()
        graph_analyzer.graph = copy.deepcopy(self.graph_def)

        graph_analyzer.parse_graph()

        new_constant_value = np.random.random([4, 1])
        new_constant_type = tf.as_dtype(np.float32(new_constant_value).dtype)
        new_constant_node = GraphRewriterHelper.create_constant_node(
            self.add_node.name + "_const", new_constant_value,
            new_constant_type)
        assert graph_analyzer.replace_constant_graph_with_constant_node(
            new_constant_node, self.add_node.name)
        result_graph = graph_analyzer.dump_graph()
        assert len(list(result_graph.node)) == 10

        new_constant_value = np.random.random([4, 1])
        new_constant_type = tf.as_dtype(np.float32(new_constant_value).dtype)
        new_constant_node = GraphRewriterHelper.create_constant_node(
            self.mul_node.name + "_const", new_constant_value,
            new_constant_type)
        assert graph_analyzer.replace_constant_graph_with_constant_node(
            new_constant_node, self.mul_node.name)
        result_graph = graph_analyzer.dump_graph()
        assert len(list(result_graph.node)) == 8

        new_constant_value = np.random.random([4, 1])
        new_constant_type = tf.as_dtype(np.float32(new_constant_value).dtype)
        new_constant_node = GraphRewriterHelper.create_constant_node(
            self.sqrt_node.name + "_const", new_constant_value,
            new_constant_type)
        assert graph_analyzer.replace_constant_graph_with_constant_node(
            new_constant_node, self.sqrt_node.name)
        result_graph = graph_analyzer.dump_graph()
        assert len(list(result_graph.node)) == 7

        new_constant_value = np.random.random([4, 1])
        new_constant_type = tf.as_dtype(np.float32(new_constant_value).dtype)
        new_constant_node = GraphRewriterHelper.create_constant_node(
            self.block_node.name + "_const", new_constant_value,
            new_constant_type)
        assert not graph_analyzer.replace_constant_graph_with_constant_node(
            new_constant_node, self.block_node.name)

    def test_replace_node(self):
        graph_analyzer = GraphAnalyzer()
        graph_analyzer.graph = copy.deepcopy(self.graph_def)

        graph_analyzer.parse_graph()

        new_add_node = node_def_pb2.NodeDef()
        new_add_node.op = "Add"
        new_add_node.name = "add1"
        new_add_node.input.extend(
            [self.input0_node.name, self.input1_node.name])
        graph_analyzer.replace_node(new_add_node, self.add_node.name,
                                    [self.mul_node.name])
        result_graph = graph_analyzer.dump_graph()
        assert self.add_node not in list(result_graph.node)
        assert new_add_node in list(result_graph.node)

    def test_freeze_value_regrex(self):
        sample_str_1 = ';efficientnet-b3/model/blocks_14/se/conv2d/Conv2D_eightbit_requant_range__print__;__requant_min_max:[-2.35420851e+09][2.59383834e+09]'
        sample_str_2 = ';efficientnet-b3/model/blocks_15/se/conv2d/Conv2D_eightbit_requant_range__print__;__requant_min_max:[-1.254][2.59383834]'
        print_suffix = '__print__'
        postfix = '__requant_min_max'
        res_1 = re.search(
            r"{};{}:\[\-?\d+\.?\d*e?\+?\d*\]".format(print_suffix, postfix),
            sample_str_1)
        res_2 = re.search(
            r"{};{}:\[\-?\d+\.?\d*e?\+?\d*\]".format(print_suffix, postfix),
            sample_str_2)
        self.assertNotEqual(res_1, None)
        self.assertNotEqual(res_2, None)
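
Beyond checking that the pattern matches, the same debug strings can be parsed to recover the actual requantization range; a small sketch (the grouped regex is an assumption, not part of the tested code):

import re

sample = (';efficientnet-b3/model/blocks_15/se/conv2d/Conv2D_eightbit_requant_range'
          '__print__;__requant_min_max:[-1.254][2.59383834]')

values = [float(v) for v in re.findall(r"\[(-?\d+\.?\d*(?:e[+-]?\d+)?)\]", sample)]
print(values)   # [-1.254, 2.59383834]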
Ejemplo n.º 24
0
def remove_training_nodes(input_graph):
    """Prunes out nodes that aren't needed for inference.

  There are nodes like Identity and CheckNumerics that are only useful
  during training, and can be removed in graphs that will be used for
  nothing but inference. Here we identify and remove them, returning an
  equivalent graph. To be specific, CheckNumerics nodes are always removed, and
  Identity nodes that aren't involved in control edges are spliced out so that
  their input and outputs are directly connected.

  Args:
    input_graph: Model to analyze and prune.

  Returns:
    A `GraphDef` with the unnecessary nodes removed.
  """

    types_to_remove = {"CheckNumerics": True}

    input_nodes = input_graph.node
    names_to_remove = {}
    for node in input_nodes:
        if node.op in types_to_remove:
            names_to_remove[node.name] = True

    nodes_after_removal = []
    for node in input_nodes:
        if node.name in names_to_remove:
            continue
        new_node = node_def_pb2.NodeDef()
        new_node.CopyFrom(node)
        input_before_removal = node.input
        del new_node.input[:]
        for full_input_name in input_before_removal:
            input_name = re.sub(r"^\^", "", full_input_name)
            if input_name in names_to_remove:
                continue
            new_node.input.append(full_input_name)
        nodes_after_removal.append(new_node)

    types_to_splice = {"Identity": True}
    names_to_splice = {}
    for node in nodes_after_removal:
        if node.op in types_to_splice:
            # We don't want to remove nodes that have control edge inputs, because
            # they might be involved in subtle dependency issues that removing them
            # will jeopardize.
            has_control_edge = False
            for input_name in node.input:
                if re.match(r"^\^", input_name):
                    has_control_edge = True
            if not has_control_edge:
                names_to_splice[node.name] = node.input[0]

    nodes_after_splicing = []
    for node in nodes_after_removal:
        if node.name in names_to_splice:
            continue
        new_node = node_def_pb2.NodeDef()
        new_node.CopyFrom(node)
        input_before_removal = node.input
        del new_node.input[:]
        for full_input_name in input_before_removal:
            input_name = re.sub(r"^\^", "", full_input_name)
            if input_name in names_to_splice:
                new_node.input.append(names_to_splice[input_name])
            else:
                new_node.input.append(full_input_name)
        nodes_after_splicing.append(new_node)

    output_graph = graph_pb2.GraphDef()
    output_graph.node.extend(nodes_after_splicing)
    return output_graph
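# A minimal usage sketch for the function above, assuming a frozen GraphDef file
# (the path "frozen_model.pb" is illustrative):
from tensorflow.core.framework import graph_pb2

graph_def = graph_pb2.GraphDef()
with open("frozen_model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

pruned = remove_training_nodes(graph_def)
print(len(graph_def.node), "->", len(pruned.node))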
Ejemplo n.º 25
0
def fold_batch_norms(input_graph_def):
  """Removes batch normalization ops by folding them into convolutions.

  Batch normalization during training has multiple dynamic parameters that are
  updated, but once the graph is finalized these become constants. That means
  there's an opportunity to reduce the computations down to a scale and
  addition, rather than the more expensive multiple ops, and even bake the
  scaling into the convolution weights. This function identifies the typical
  pattern of batch normalization subgraphs, and performs the transformation to
  fold the computations down into a simpler form. It currently only spots batch
  normalization that's performed by the BatchNormWithGlobalNormalization and
  FusedBatchNorm ops, and will need to be extended in the future to handle the
  newer style.

  Args:
    input_graph_def: A GraphDef containing a model.

  Returns:
    Modified graph with BN ops removed, and modified weights.

  Raises:
    ValueError: If the graph is badly formed with duplicate node names.
  """
  input_node_map = {}
  for node in input_graph_def.node:
    if node.name not in input_node_map:
      input_node_map[node.name] = node
    else:
      raise ValueError("Duplicate node names detected for %s" % node.name)

  nodes_to_skip = {}
  new_ops = []
  for node in input_graph_def.node:
    if node.op not in ("BatchNormWithGlobalNormalization", "FusedBatchNorm"):
      continue

    conv_op = node_from_map(input_node_map,
                            node.input[INPUT_ORDER[node.op].index("conv_op")])
    if conv_op.op != "Conv2D":
      tf_logging.warning(
          "Didn't find expected Conv2D input to '%s'" % node.name)
      continue

    weights_op = node_from_map(input_node_map, conv_op.input[1])
    if weights_op.op != "Const":
      tf_logging.warning("Didn't find expected conv Constant input to '%s',"
                         " found %s instead. Maybe because freeze_graph wasn't"
                         " run first?" % (conv_op.name, weights_op))
      continue
    weights = values_from_const(weights_op)
    channel_count = weights.shape[3]

    mean_op = node_from_map(input_node_map,
                            node.input[INPUT_ORDER[node.op].index("mean_op")])
    if mean_op.op != "Const":
      tf_logging.warning("Didn't find expected mean Constant input to '%s',"
                         " found %s instead. Maybe because freeze_graph wasn't"
                         " run first?" % (node.name, mean_op))
      continue
    mean_value = values_from_const(mean_op)
    if mean_value.shape != (channel_count,):
      tf_logging.warning("Incorrect shape for mean, found %s, expected %s,"
                         " for node %s" % (str(mean_value.shape), str(
                             (channel_count,)), node.name))
      continue

    var_op = node_from_map(input_node_map,
                           node.input[INPUT_ORDER[node.op].index("var_op")])
    if var_op.op != "Const":
      tf_logging.warning("Didn't find expected var Constant input to '%s',"
                         " found %s instead. Maybe because freeze_graph wasn't"
                         " run first?" % (node.name, var_op))
      continue
    var_value = values_from_const(var_op)
    if var_value.shape != (channel_count,):
      tf_logging.warning("Incorrect shape for var, found %s, expected %s,"
                         " for node %s" % (str(var_value.shape), str(
                             (channel_count,)), node.name))
      continue

    beta_op = node_from_map(input_node_map,
                            node.input[INPUT_ORDER[node.op].index("beta_op")])
    if beta_op.op != "Const":
      tf_logging.warning("Didn't find expected beta Constant input to '%s',"
                         " found %s instead. Maybe because freeze_graph wasn't"
                         " run first?" % (node.name, beta_op))
      continue
    beta_value = values_from_const(beta_op)
    if beta_value.shape != (channel_count,):
      tf_logging.warning("Incorrect shape for beta, found %s, expected %s,"
                         " for node %s" % (str(beta_value.shape), str(
                             (channel_count,)), node.name))
      continue

    gamma_op = node_from_map(input_node_map,
                             node.input[INPUT_ORDER[node.op].index("gamma_op")])
    if gamma_op.op != "Const":
      tf_logging.warning("Didn't find expected gamma Constant input to '%s',"
                         " found %s instead. Maybe because freeze_graph wasn't"
                         " run first?" % (node.name, gamma_op))
      continue
    gamma_value = values_from_const(gamma_op)
    if gamma_value.shape != (channel_count,):
      tf_logging.warning("Incorrect shape for gamma, found %s, expected %s,"
                         " for node %s" % (str(gamma_value.shape), str(
                             (channel_count,)), node.name))
      continue

    variance_epsilon_value = node.attr[EPSILON_ATTR[node.op]].f
    nodes_to_skip[node.name] = True
    nodes_to_skip[weights_op.name] = True
    nodes_to_skip[mean_op.name] = True
    nodes_to_skip[var_op.name] = True
    nodes_to_skip[beta_op.name] = True
    nodes_to_skip[gamma_op.name] = True
    nodes_to_skip[conv_op.name] = True

    if scale_after_normalization(node):
      scale_value = (
          (1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value)) *
          gamma_value)
    else:
      scale_value = (
          1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value))
    offset_value = (-mean_value * scale_value) + beta_value
    scaled_weights = np.copy(weights)
    it = np.nditer(
        scaled_weights, flags=["multi_index"], op_flags=["readwrite"])
    while not it.finished:
      current_scale = scale_value[it.multi_index[3]]
      it[0] *= current_scale
      it.iternext()
    scaled_weights_op = node_def_pb2.NodeDef()
    scaled_weights_op.op = "Const"
    scaled_weights_op.name = weights_op.name
    scaled_weights_op.attr["dtype"].CopyFrom(weights_op.attr["dtype"])
    scaled_weights_op.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            scaled_weights, weights.dtype.type, weights.shape)))
    new_conv_op = node_def_pb2.NodeDef()
    new_conv_op.CopyFrom(conv_op)
    offset_op = node_def_pb2.NodeDef()
    offset_op.op = "Const"
    offset_op.name = conv_op.name + "_bn_offset"
    offset_op.attr["dtype"].CopyFrom(mean_op.attr["dtype"])
    offset_op.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            offset_value, mean_value.dtype.type, offset_value.shape)))
    bias_add_op = node_def_pb2.NodeDef()
    bias_add_op.op = "BiasAdd"
    bias_add_op.name = node.name
    bias_add_op.attr["T"].CopyFrom(conv_op.attr["T"])
    bias_add_op.attr["data_format"].CopyFrom(conv_op.attr["data_format"])
    bias_add_op.input.extend([new_conv_op.name, offset_op.name])
    new_ops.extend([scaled_weights_op, new_conv_op, offset_op, bias_add_op])

  result_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node.name in nodes_to_skip:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    result_graph_def.node.extend([new_node])

  result_graph_def.node.extend(new_ops)
  return result_graph_def
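# A quick NumPy sanity check of the folding arithmetic used above: with
# scale = gamma / sqrt(var + eps) and offset = beta - mean * scale, applying batch
# norm to the convolution output equals scaling that output per channel and adding
# the offset (and since convolution is linear in its weights, the per-channel scale
# can equivalently be baked into the weights, as fold_batch_norms does). Values
# below are illustrative.
import numpy as np

channels = 4
conv_out = np.random.randn(1, 5, 5, channels)   # stand-in for a Conv2D output (NHWC)
gamma = np.random.rand(channels) + 0.5
beta = np.random.randn(channels)
mean = np.random.randn(channels)
var = np.random.rand(channels) + 0.1
eps = 1e-3

bn_out = (conv_out - mean) / np.sqrt(var + eps) * gamma + beta   # inference-time BN

scale = gamma / np.sqrt(var + eps)          # scale_value in fold_batch_norms
offset = -mean * scale + beta               # offset_value in fold_batch_norms
folded_out = conv_out * scale + offset

print(np.allclose(bn_out, folded_out))      # True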
Ejemplo n.º 26
0
    def generate_output_graph(self, input_graph_def, input_node_map,
                              fuse_op_list):
        output_graph_def = graph_pb2.GraphDef()
        skip_list = []
        skip_node_name = []
        float32_type = dtypes.float32.as_datatype_enum
        for index, node in enumerate(input_graph_def.node):
            if index in fuse_op_list:
                input_node = input_node_map[node.input[0]]
                if input_node.op == 'QuantizeV2':
                    new_node = node_def_pb2.NodeDef()

                    new_node.op = node.op + "AndDequantize"
                    new_node.input.extend(node.input)

                    dequantize_node = input_graph_def.node[index + 4]
                    frozen_max_node = input_graph_def.node[index + 2]
                    frozen_min_node = input_graph_def.node[index + 1]

                    new_node.name = dequantize_node.name

                    new_node.input.append(frozen_min_node.name)
                    new_node.input.append(frozen_max_node.name)

                    new_node.attr["T1"].CopyFrom(node.attr['T1'])
                    new_node.attr["T2"].CopyFrom(node.attr['T2'])

                    new_node.attr["Tbias"].CopyFrom(
                        attr_value_pb2.AttrValue(type=float32_type))
                    new_node.attr["Toutput"].CopyFrom(
                        attr_value_pb2.AttrValue(type=float32_type))

                    skip_list.append(index + 1)
                    skip_list.append(index + 2)
                    skip_list.append(index + 3)
                    skip_list.append(index + 4)
                    output_graph_def.node.extend(
                        [new_node, frozen_max_node, frozen_min_node])
                elif input_node.op == "Requantize":
                    new_node = node_def_pb2.NodeDef()
                    new_node.op = node.op + "AndDequantize"
                    new_node.input.extend(node.input)

                    dequantize_node = input_graph_def.node[index + 4]
                    frozen_max_node = input_graph_def.node[index + 2]
                    frozen_min_node = input_graph_def.node[index + 1]
                    new_node.name = dequantize_node.name
                    skip_list.append(index + 1)
                    skip_list.append(index + 2)
                    skip_list.append(index + 3)
                    skip_list.append(index + 4)
                    new_node.input.append(frozen_min_node.name)
                    new_node.input.append(frozen_max_node.name)

                    new_node.attr["T1"].CopyFrom(node.attr['T1'])
                    new_node.attr["T2"].CopyFrom(node.attr['T2'])

                    new_node.attr["Tbias"].CopyFrom(
                        attr_value_pb2.AttrValue(type=float32_type))
                    new_node.attr["Toutput"].CopyFrom(
                        attr_value_pb2.AttrValue(type=float32_type))

                    output_graph_def.node.extend(
                        [new_node, frozen_max_node, frozen_min_node])
                else:
                    new_node = node_def_pb2.NodeDef()
                    new_node.CopyFrom(node)
                    output_graph_def.node.extend([new_node])

            elif index in skip_list or node.name in skip_node_name:
                continue
            else:
                new_node = node_def_pb2.NodeDef()
                new_node.CopyFrom(node)
                output_graph_def.node.extend([new_node])
        return output_graph_def
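# Note: the fusion above is purely positional. For each index in fuse_op_list it
# assumes the surrounding nodes sit at fixed offsets in input_graph_def.node; the
# sketch below just spells out that implicit layout (labels mirror the variable
# names used above, nothing else is checked by the code):
assumed_layout = {
    "index + 0": "node being fused (replaced by op + 'AndDequantize')",
    "index + 1": "frozen min node",
    "index + 2": "frozen max node",
    "index + 3": "skipped, never inspected",
    "index + 4": "Dequantize node (its name is reused for the fused node)",
}
for offset, role in assumed_layout.items():
    print(offset, "->", role)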
Ejemplo n.º 27
0
import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.core.framework import attr_value_pb2, graph_pb2, node_def_pb2, types_pb2

graph_filename = 'mrt_graph_1.pb'
graph_filename_converted = 'mrt_graph_2.pb'

f = gfile.FastGFile(graph_filename, 'rb')

# define graph def object
graph_def = tf.GraphDef()

# store frozen graph from pb file
graph_def.ParseFromString(f.read())

# define new empty graph
modified_graph_def = graph_pb2.GraphDef()

# pre-define empty image placeholder node
image_placeholder_node = node_def_pb2.NodeDef()

# iterate through all nodes in graph
for node in graph_def.node:

    # print the 'vars/Cast' node for inspection
    if node.name == 'vars/Cast':
        print(node)

# iterate through all nodes in graph
for node in graph_def.node:

    # set dtype attribute of imagePlaceholder node to int32
    if node.name == 'imagePlaceholder':
        # print("found image placeholder")
        node.attr["dtype"].CopyFrom(
            # target dtype assumed to be int32, per the comment above
            attr_value_pb2.AttrValue(type=types_pb2.DT_INT32))
Ejemplo n.º 28
0
def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None,
                                   use_fp16=False):
    from tensorflow.python.framework.graph_util_impl import extract_sub_graph
    from tensorflow.core.framework import graph_pb2
    from tensorflow.core.framework import node_def_pb2
    from tensorflow.core.framework import attr_value_pb2
    from tensorflow.core.framework import types_pb2
    from tensorflow.python.framework import tensor_util

    def patch_dtype(input_node, field_name, output_node):
        if use_fp16 and (field_name in input_node.attr) and (input_node.attr[field_name].type == types_pb2.DT_FLOAT):
            output_node.attr[field_name].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_HALF))

    inference_graph = extract_sub_graph(input_graph_def, output_node_names)

    variable_names = []
    variable_dict_names = []
    for node in inference_graph.node:
        if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
            variable_name = node.name
            if ((variable_names_whitelist is not None and
                 variable_name not in variable_names_whitelist) or
                    (variable_names_blacklist is not None and
                     variable_name in variable_names_blacklist)):
                continue
            variable_dict_names.append(variable_name)
            if node.op == "VarHandleOp":
                variable_names.append(variable_name + "/Read/ReadVariableOp:0")
            else:
                variable_names.append(variable_name + ":0")
    if variable_names:
        returned_variables = sess.run(variable_names)
    else:
        returned_variables = []
    found_variables = dict(zip(variable_dict_names, returned_variables))

    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = node_def_pb2.NodeDef()
        if input_node.name in found_variables:
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]

            if use_fp16 and dtype.type == types_pb2.DT_FLOAT:
                output_node.attr["value"].CopyFrom(
                    attr_value_pb2.AttrValue(
                        tensor=tensor_util.make_tensor_proto(data.astype('float16'),
                                                             dtype=types_pb2.DT_HALF,
                                                             shape=data.shape)))
            else:
                output_node.attr["dtype"].CopyFrom(dtype)
                output_node.attr["value"].CopyFrom(attr_value_pb2.AttrValue(
                    tensor=tensor_util.make_tensor_proto(data, dtype=dtype.type,
                                                         shape=data.shape)))
            how_many_converted += 1
        elif input_node.op == "ReadVariableOp" and (input_node.input[0] in found_variables):
            # ReadVariableOp reading a converted variable: rewrite as Identity
            # print('- %s | %s ' % (input_node.name, input_node.attr["dtype"]))
            output_node.op = "Identity"
            output_node.name = input_node.name
            output_node.input.extend([input_node.input[0]])
            output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
            if "_class" in input_node.attr:
                output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
        else:
            # mostly op nodes
            output_node.CopyFrom(input_node)

        patch_dtype(input_node, 'dtype', output_node)
        patch_dtype(input_node, 'T', output_node)
        patch_dtype(input_node, 'DstT', output_node)
        patch_dtype(input_node, 'SrcT', output_node)
        patch_dtype(input_node, 'Tparams', output_node)

        if use_fp16 and ('value' in output_node.attr) and (
                output_node.attr['value'].tensor.dtype == types_pb2.DT_FLOAT):
            # hard-coded float constants need converting as well (scalar values only here)
            output_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
                tensor=tensor_util.make_tensor_proto(
                    output_node.attr['value'].tensor.float_val[0],
                    dtype=types_pb2.DT_HALF)))

        output_graph_def.node.extend([output_node])

    output_graph_def.library.CopyFrom(inference_graph.library)
    return output_graph_def
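# A minimal TF 1.x usage sketch for the fp16-aware freeze above (graph and names
# are illustrative):
import tensorflow as tf

with tf.Session() as sess:
    x = tf.placeholder(tf.float32, [None, 4], name="input")
    w = tf.Variable(tf.random_normal([4, 2]), name="w")
    y = tf.matmul(x, w, name="output")
    sess.run(tf.global_variables_initializer())

    frozen = convert_variables_to_constants(
        sess, sess.graph_def, ["output"], use_fp16=True)
    # Variables are now Const nodes, and float32 dtypes/values were patched to half.
    print([(n.name, n.op) for n in frozen.node])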
Ejemplo n.º 29
0
 def _StripNode(self, nd):
   snode = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
   if nd.device:
     snode.device = nd.device
   return snode
Ejemplo n.º 30
0
    def apply_conv_single_fusion(self, match_node_name):
        skip_node_name = match_node_name[1:]
        matched_node = self.node_name_mapping[match_node_name[0]]
        _, normal_inputs = self._get_node_input(matched_node.node.name)
        weight_name = normal_inputs[1]
        # TODO: this is a workaround, as TF 2.1 doesn't support the depthwise/conv
        # s8 feature.
        if self.enable_s8 and not self._find_relu_node(matched_node.node):
            self.output_graph = self.input_graph
            return

        self._intel_cpu_quantize_weight_eightbit(
            matched_node.node.op, self.node_name_mapping[weight_name].node,
            self.per_channel)

        all_input_names = self._add_eightbit_prologue_nodes(
            matched_node.node.name)
        skip_node_name.append(weight_name)

        for node in self.input_graph.node:
            if node.name in skip_node_name:
                self.logger.debug("skip node {}".format(node.name))
            elif node.name == match_node_name[0]:
                postfix = "_eightbit_quantized_conv" if node.op == "Conv2D" else "_eightbit_quantized_depthwise_conv"
                quantized_node_name = node.name + postfix
                if node.op == "Conv2D":
                    quantized_conv_node = helper.create_node(
                        "QuantizedConv2DPerChannel"
                        if self.per_channel else "QuantizedConv2D",
                        quantized_node_name, all_input_names)

                elif node.op == "DepthwiseConv2dNative":
                    quantized_conv_node = helper.create_node(
                        "QuantizedDepthwiseConv2D", quantized_node_name,
                        all_input_names)

                helper.copy_attr(quantized_conv_node, "strides",
                                 node.attr["strides"])
                helper.copy_attr(quantized_conv_node, "padding",
                                 node.attr["padding"])
                if node.op != 'DepthwiseConv2dNative' and "padding_list" in node.attr:
                    helper.copy_attr(quantized_conv_node, "padding_list",
                                     node.attr["padding_list"])
                helper.copy_attr(quantized_conv_node, "dilations",
                                 node.attr["dilations"])
                input_data_type = dtypes.quint8 if self._find_relu_node(
                    node) else dtypes.qint8
                helper.set_attr_dtype(quantized_conv_node, "Tinput",
                                      input_data_type)
                helper.set_attr_dtype(quantized_conv_node, "Tfilter",
                                      dtypes.qint8)
                helper.set_attr_dtype(quantized_conv_node, "out_type",
                                      dtypes.qint32)
                self.add_output_graph_node(quantized_conv_node)
                quantize_down_name = self._add_quantize_down_nodes(
                    node, quantized_node_name, dtypes.qint8)
                self._intel_cpu_add_dequantize_result_node(
                    quantize_down_name, node.name, dtypes.qint8)
            else:
                new_node = node_def_pb2.NodeDef()
                new_node.CopyFrom(node)
                self.add_output_graph_node(new_node)
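# Roughly what the helper calls above construct for a Conv2D match, written out
# with the raw protos used elsewhere in these examples (input names here are
# placeholders, not real tensor names):
from tensorflow.core.framework import attr_value_pb2, node_def_pb2
from tensorflow.python.framework import dtypes

qconv = node_def_pb2.NodeDef()
qconv.op = "QuantizedConv2D"
qconv.name = "conv1_eightbit_quantized_conv"
qconv.input.extend(["quantized_input", "quantized_weight",
                    "input_min", "input_max", "weight_min", "weight_max"])
qconv.attr["Tinput"].CopyFrom(
    attr_value_pb2.AttrValue(type=dtypes.quint8.as_datatype_enum))  # quint8 when a ReLU feeds the conv
qconv.attr["Tfilter"].CopyFrom(
    attr_value_pb2.AttrValue(type=dtypes.qint8.as_datatype_enum))
qconv.attr["out_type"].CopyFrom(
    attr_value_pb2.AttrValue(type=dtypes.qint32.as_datatype_enum))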