Beispiel #1
0
def strip_meta_graph(meta_graph_def, node_names, var_names):
    """Prune a MetaGraphDef in place down to what `node_names` needs.

    Besides `node_names`, the initializer of every variable (from the
    "variables" and "trainable_variables" collections) whose name is in
    `var_names` is kept. While-loop contexts whose pivot name is not a prefix
    of any kept name are deleted from the "while_context" collection. Finally
    `meta_graph_def.graph_def` is replaced by the sub-graph extracted for the
    kept names.

    Args:
      meta_graph_def: The MetaGraphDef to prune; it is modified in place.
      node_names: Names of the nodes to keep. The caller's list is copied,
        not mutated.
      var_names: Container of variable names whose initializers are kept.
    """
    node_names = node_names[:]
    collections = meta_graph_def.collection_def

    # NOTE: indexing a protobuf message-valued map creates an empty entry for
    # a missing key, so every collection lookup below is guarded with `in`
    # to avoid adding empty collections to the MetaGraphDef.

    # Look for matching variable names and initializers and keep them too.
    var_def = variable_pb2.VariableDef()
    for var_col_name in ["variables", "trainable_variables"]:
        if var_col_name not in collections:
            continue
        for var_def_b in collections[var_col_name].bytes_list.value:
            var_def.ParseFromString(var_def_b)
            if var_def.variable_name not in var_names:
                # TODO(adamb) Should remove variable from collection.
                continue
            node_names.append(var_def.initializer_name)

    if "while_context" in collections:
        wc_def = control_flow_pb2.WhileContextDef()
        wc_values = collections["while_context"].bytes_list.value
        # Walk backwards so deleting an entry does not shift the indices
        # that are still to be visited.
        for wc_ix in range(len(wc_values) - 1, -1, -1):
            wc_def.ParseFromString(wc_values[wc_ix])
            wc_pivot_name = wc_def.pivot_name
            if not any(name.startswith(wc_pivot_name) for name in node_names):
                del wc_values[wc_ix]

    graph_def = meta_graph_def.graph_def
    eprint("only keeping", node_names, "from",
           [n.name for n in graph_def.node])
    graph_def = graph_util.extract_sub_graph(graph_def, node_names)
    meta_graph_def.graph_def.CopyFrom(graph_def)
def strip_pruning_vars_fn(input_graph_def, output_node_names):
    """Strip pruning mask variables out of a frozen graph.

    Every node whose name contains 'masked_weight' is swapped for a Const
    node holding the value `_get_masked_weights` produced for it (the
    element-wise product of the mask and its weight variable). With those
    products baked in, the mask and weight nodes are no longer reachable from
    the outputs, and `graph_util.extract_sub_graph` discards them.

    Args:
      input_graph_def: A GraphDef whose variables were already converted to
        constants, typically the output of
        tf.graph_util.convert_variables_to_constants().
      output_node_names: List of name strings for the result nodes of the
        graph.

    Returns:
      A GraphDef from which the pruning-related variables have been removed.
    """
    precomputed = _get_masked_weights(input_graph_def)
    rewritten = graph_pb2.GraphDef()

    for src in input_graph_def.node:
        dst = node_def_pb2.NodeDef()
        if 'masked_weight' not in src.name:
            dst.CopyFrom(src)
        else:
            # Fold tf.multiply(mask, weight) into a constant; its dtype is
            # taken from the original node's 'T' attr.
            dst.op = 'Const'
            dst.name = src.name
            dtype_attr = src.attr['T']
            product = precomputed[src.name]
            dst.attr['dtype'].CopyFrom(dtype_attr)
            value_proto = tensor_util.make_tensor_proto(product)
            dst.attr['value'].CopyFrom(
                attr_value_pb2.AttrValue(tensor=value_proto))
        rewritten.node.extend([dst])

    # Mask and weight nodes are now stranded; keep only what outputs use.
    return graph_util.extract_sub_graph(rewritten, output_node_names)
Beispiel #3
0
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Prune a GraphDef down to what the output nodes depend on.

  Each node named in `input_node_names` is swapped for a Placeholder of type
  `placeholder_type_enum`, which cuts it off from its own inputs; everything
  no longer reachable from `output_node_names` is then discarded by
  `graph_util.extract_sub_graph()`.

  Args:
    input_graph_def: The GraphDef to prune.
    input_node_names: Names of the nodes to turn into placeholders.
    output_node_names: Names of the nodes whose dependencies are kept.
    placeholder_type_enum: AttrValue type enum used as the placeholders' dtype.

  Returns:
    A new GraphDef with all unnecessary ops removed.
  """
  def _as_placeholder(original):
    # Keep the name (so consumers still resolve) and any recorded output
    # shapes; dropping the inputs makes the upstream nodes unreachable.
    replacement = tf.NodeDef()
    replacement.op = "Placeholder"
    replacement.name = original.name
    replacement.attr["dtype"].CopyFrom(
        tf.AttrValue(type=placeholder_type_enum))
    if "_output_shapes" in original.attr:
      replacement.attr["_output_shapes"].CopyFrom(
          original.attr["_output_shapes"])
    return replacement

  rewritten = tf.GraphDef()
  rewritten.node.extend([
      _as_placeholder(node) if node.name in input_node_names
      else copy.deepcopy(node)
      for node in input_graph_def.node
  ])
  return graph_util.extract_sub_graph(rewritten, output_node_names)
Beispiel #4
0
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
    """Drop every node the requested outputs do not depend on.

    Args:
      input_graph_def: A graph with nodes we want to prune.
      input_node_names: Names of the nodes that become Placeholder inputs.
      output_node_names: Names of the output nodes to keep.
      placeholder_type_enum: The AttrValue enum for the placeholder data type.

    Returns:
      A GraphDef containing only the ops the outputs need.
    """
    pruned_input = tf.GraphDef()
    for original in input_graph_def.node:
        if original.name not in input_node_names:
            pruned_input.node.extend([copy.deepcopy(original)])
            continue

        # A bare Placeholder has no incoming edges, so extract_sub_graph()
        # can discard whatever used to feed this node.
        stand_in = tf.NodeDef()
        stand_in.op = "Placeholder"
        stand_in.name = original.name
        dtype_attr = tf.AttrValue(type=placeholder_type_enum)
        stand_in.attr["dtype"].CopyFrom(dtype_attr)
        if "_output_shapes" in original.attr:
            recorded_shapes = original.attr["_output_shapes"]
            stand_in.attr["_output_shapes"].CopyFrom(recorded_shapes)
        pruned_input.node.extend([stand_in])

    return graph_util.extract_sub_graph(pruned_input, output_node_names)
Beispiel #5
0
  def __init__(self,
               session,
               graph_def,
               output_node_names,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    """Prune `graph_def` to `output_node_names` and read variable values.

    Nodes with op "Variable", "VariableV2" or "VarHandleOp" that pass
    `self._should_convert` are evaluated with a single `session.run`. Each
    value is stored in `self._tensor_data` under the variable node's name.
    For "VarHandleOp" nodes the value is read from the
    "<name>/Read/ReadVariableOp:0" tensor instead of the handle itself.

    Args:
      session: Session used to evaluate the variables.
      graph_def: GraphDef to convert.
      output_node_names: Names of the nodes the converted graph must keep.
      variable_names_allowlist: Forwarded to the base-class constructor.
      variable_names_denylist: Forwarded to the base-class constructor.
    """
    super(_SessionConverterData, self).__init__(
        graph_util.extract_sub_graph(graph_def, output_node_names),
        variable_names_allowlist=variable_names_allowlist,
        variable_names_denylist=variable_names_denylist)

    variable_ops = ("Variable", "VariableV2", "VarHandleOp")
    pending = []  # (node, tensor name to fetch) pairs, in graph order.
    for candidate in self.graph_def.node:
      if candidate.op not in variable_ops:
        continue
      if not self._should_convert(candidate.name):
        continue
      read_from = candidate.name
      if candidate.op == "VarHandleOp":
        read_from += "/Read/ReadVariableOp"
      pending.append((candidate, read_from + ":0"))

    if not pending:
      return

    values = session.run([fetch_name for _, fetch_name in pending])
    for (candidate, _), value in zip(pending, values):
      self._tensor_data[candidate.name] = _TensorData(
          numpy=value, dtype=candidate.attr["dtype"].type, index=None)
  def test_remove_unneeded_nodes(self):
    """CheckNumerics/Identity pairs are removed and Add reads the constants."""
    add_name = "add"

    def scalar_const(name):
      return quantize_graph.create_constant_node(name,
                                                 value=1,
                                                 dtype=tf.float32,
                                                 shape=[])

    def float_add(lhs, rhs):
      node = quantize_graph.create_node("Add", add_name, [lhs, rhs])
      quantize_graph.set_attr_dtype(node, "T", tf.float32)
      return node

    # Input graph: const -> CheckNumerics, and Identity(const, ^check), for
    # both the "a" and "b" branches, feeding a single Add.
    graph_def = tf.GraphDef()
    identity_names = []
    for prefix in ("a", "b"):
      const_name = prefix + "_constant"
      check_name = prefix + "_check"
      identity_name = prefix + "_identity"
      graph_def.node.extend([scalar_const(const_name)])
      graph_def.node.extend([quantize_graph.create_node(
          "CheckNumerics", check_name, [const_name])])
      graph_def.node.extend([quantize_graph.create_node(
          "Identity", identity_name, [const_name, "^" + check_name])])
      identity_names.append(identity_name)
    graph_def.node.extend([float_add(*identity_names)])

    expected_output = tf.GraphDef()
    expected_output.node.extend([scalar_const("a_constant")])
    expected_output.node.extend([scalar_const("b_constant")])
    expected_output.node.extend([float_add("a_constant", "b_constant")])

    rewriter = quantize_graph.GraphRewriter(graph_def, [add_name])
    output = rewriter.remove_unneeded_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
  def testExtractSubGraph(self):
    """Only n3 and its transitive inputs (n1, n2, n5) survive, in order."""
    graph_def = graph_pb2.GraphDef()
    # (name, inputs) in the order the nodes are added to the graph.
    layout = [
        ("n1", ["n5"]),
        # Take the first output of the n1 node as the input.
        ("n2", ["n1:0"]),
        # A control input: it enforces execution order between nodes rather
        # than feeding data the kernel needs.
        ("n3", ["^n2"]),
        ("n4", []),
        # Loops in the graph are fine as well (n1 -> n5 -> n1).
        ("n5", ["n1"]),
    ]
    for name, inputs in layout:
      node = graph_def.node.add()
      node.name = name
      node.input.extend(inputs)

    sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
    for position, expected_name in enumerate(["n1", "n2", "n3", "n5"]):
      self.assertEqual(expected_name, sub_graph.node[position].name)
Beispiel #8
0
def save_graph_only(sess, output_file_path, output_node_names, as_text=False):
    """Save a small version of the graph based on a session and the output node names.

    Device placements are cleared from every node. Only the nodes needed to
    compute `output_node_names` are kept, and the result is written to
    `output_file_path` (binary unless `as_text` is true).
    """
    # `Session.graph_def` builds a new GraphDef on every access. Take one
    # snapshot so the device clearing survives into the extracted sub-graph;
    # clearing devices on one copy and extracting from another loses it.
    graph_def = sess.graph_def
    for node in graph_def.node:
        node.device = ''
    graph_def = graph_util.extract_sub_graph(graph_def, output_node_names)
    output_dir, output_filename = os.path.split(output_file_path)
    graph_io.write_graph(graph_def, output_dir, output_filename, as_text=as_text)
Beispiel #9
0
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
    """Removes unused nodes from a GraphDef.

    Args:
      input_graph_def: A graph with nodes we want to prune.
      input_node_names: A list of the nodes we use as inputs.
      output_node_names: A list of the output nodes.
      placeholder_type_enum: The AttrValue enum for the placeholder data type,
        or a list (or tuple) that specifies one value per input node name.

    Returns:
      A `GraphDef` with all unnecessary ops removed.

    Raises:
      ValueError: If any element in `input_node_names` refers to a tensor
        instead of an operation.
      KeyError: If any element in `input_node_names` is not found in the graph.
    """
    for name in input_node_names:
        if ":" in name:
            raise ValueError(
                f"Name '{name}' appears to refer to a Tensor, not an "
                "Operation.")

    # Position of each input name's first occurrence. This is the index
    # `input_node_names.index()` would return, computed once up front rather
    # than by a linear search for every matching graph node.
    input_positions = {}
    for position, name in enumerate(input_node_names):
        input_positions.setdefault(name, position)
    # A list/tuple supplies one dtype per input node; any other value is a
    # single dtype shared by all placeholders.
    per_input_dtype = isinstance(placeholder_type_enum, (list, tuple))

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    not_found = set(input_positions)
    inputs_replaced_graph_def = graph_pb2.GraphDef()
    for node in input_graph_def.node:
        if node.name not in input_positions:
            inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
            continue

        not_found.remove(node.name)
        if per_input_dtype:
            dtype = placeholder_type_enum[input_positions[node.name]]
        else:
            dtype = placeholder_type_enum
        placeholder_node = node_def_pb2.NodeDef()
        placeholder_node.op = "Placeholder"
        placeholder_node.name = node.name
        placeholder_node.attr["dtype"].CopyFrom(
            attr_value_pb2.AttrValue(type=dtype))
        if "_output_shapes" in node.attr:
            placeholder_node.attr["_output_shapes"].CopyFrom(
                node.attr["_output_shapes"])
        if "shape" in node.attr:
            placeholder_node.attr["shape"].CopyFrom(node.attr["shape"])
        inputs_replaced_graph_def.node.extend([placeholder_node])

    if not_found:
        raise KeyError(
            f"The following input nodes were not found: {not_found}.")

    output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                    output_node_names)
    return output_graph_def
  def test_remove_unneeded_nodes(self):
    """remove_unneeded_nodes bypasses CheckNumerics and Identity nodes."""
    add_name = "add"

    def scalar(name):
      return quantize_graph.create_constant_node(
          name, value=1, dtype=tf.float32, shape=[])

    def checked_identity_chain(prefix):
      # const, CheckNumerics(const), Identity(const, ^check) for one branch.
      const_name = prefix + "_constant"
      check_name = prefix + "_check"
      return [
          scalar(const_name),
          quantize_graph.create_node("CheckNumerics", check_name,
                                     [const_name]),
          quantize_graph.create_node("Identity", prefix + "_identity",
                                     [const_name, "^" + check_name]),
      ]

    def float_add(lhs, rhs):
      node = quantize_graph.create_node("Add", add_name, [lhs, rhs])
      quantize_graph.set_attr_dtype(node, "T", tf.float32)
      return node

    graph_def = tf.GraphDef()
    graph_def.node.extend(checked_identity_chain("a") +
                          checked_identity_chain("b") +
                          [float_add("a_identity", "b_identity")])

    expected_output = tf.GraphDef()
    expected_output.node.extend([scalar("a_constant"),
                                 scalar("b_constant"),
                                 float_add("a_constant", "b_constant")])

    rewriter = quantize_graph.GraphRewriter(graph_def, [add_name])
    output = rewriter.remove_unneeded_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
Beispiel #11
0
def parse_pb(file_or_path, output_nodes=None):
    """
    Parse a binary GraphDef (.pb) file into a node-name -> NodeDef mapping.

    arguments
    =========
    - file_or_path: a file object or a path string of the pb file. The file
      (whether passed in or opened here) is closed before returning.
    - output_nodes: list of output node names; when given, only the sub-graph
      needed to compute them is kept

    returns
    =======
    - nodes: mapping from node name to its NodeDef object

    raises
    ======
    - ValueError: if `file_or_path` is neither a path string nor a file object
    """
    if sys.version_info.major < 3:
        file_type = (file, io.IOBase)  # pylint: disable=E0602
    else:
        file_type = io.IOBase

    if isinstance(file_or_path, str):
        fid = open(file_or_path, "rb")
    elif isinstance(file_or_path, file_type):
        fid = file_or_path
    else:
        raise ValueError(
            "`file_or_path` has to be either file object or path string")

    # load pb file; close the file even if reading or parsing fails
    try:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(fid.read())
    finally:
        fid.close()

    if output_nodes:
        sub_graph_def = graph_util.extract_sub_graph(graph_def, output_nodes)
    else:
        sub_graph_def = graph_def

    # Key by node name (unique within a graph), as documented. Keying by
    # `node.op` collapsed every node sharing an op type into one entry.
    return {node.name: node for node in sub_graph_def.node}  # pylint: disable=E1101
  def test_keep_control_edges(self):
    """remove_training_nodes keeps Identity nodes that carry a NoOp control edge.

    The CheckNumerics nodes disappear, b_identity is bypassed so Add reads
    b_constant directly, and a_identity survives with only its ^no_op input.
    """
    no_op_name = "no_op"
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"

    def make(op, name, inputs):
      return quantize_graph.create_node(op, name, inputs)

    def scalar(name):
      return quantize_graph.create_constant_node(
          name, value=1, dtype=dtypes.float32, shape=[])

    def add_of(lhs, rhs):
      node = quantize_graph.create_node("Add", add_name, [lhs, rhs])
      quantize_graph.set_attr_dtype(node, "T", dtypes.float32)
      return node

    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        make("NoOp", no_op_name, []),
        scalar(a_constant_name),
        make("CheckNumerics", a_check_name, [a_constant_name]),
        make("Identity", a_identity_name,
             [a_constant_name, "^" + a_check_name, "^" + no_op_name]),
        scalar(b_constant_name),
        make("CheckNumerics", b_check_name, [b_constant_name]),
        make("Identity", b_identity_name,
             [b_constant_name, "^" + b_check_name]),
        add_of(a_identity_name, b_identity_name),
    ])

    expected_output = graph_pb2.GraphDef()
    expected_output.node.extend([
        make("NoOp", no_op_name, []),
        scalar(a_constant_name),
        make("Identity", a_identity_name,
             [a_constant_name, "^" + no_op_name]),
        scalar(b_constant_name),
        add_of(a_identity_name, b_constant_name),
    ])
    expected_output.versions.CopyFrom(graph_def.versions)
    expected_output.library.CopyFrom(graph_def.library)

    stripped_output = graph_util.extract_sub_graph(
        graph_util.remove_training_nodes(graph_def), [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
Beispiel #13
0
 def save_pb(self, pb_path):
     """Freeze the light_state/light_position sub-graph and write it to `pb_path`.

     The session graph is pruned to the nodes `<name>/light_state` and
     `<name>/light_position` need. Variables among those nodes are replaced
     by constants via `convert_variables_to_constants`, and the serialized
     GraphDef is written in binary form.
     """
     # `Session.graph_def` rebuilds a GraphDef on every access; take one
     # snapshot so extraction and conversion operate on the same graph.
     graph_def = self.sess.graph_def
     extracted_graph = graph_util.extract_sub_graph(
         graph_def,
         [self.name + '/light_state', self.name + '/light_position'])
     constant_graph = graph_util.convert_variables_to_constants(
         self.sess, graph_def,
         [n.name for n in extracted_graph.node])
     # tf.gfile.FastGFile is deprecated in favour of tf.gfile.GFile.
     with tf.gfile.GFile(pb_path, mode='wb') as f:
         f.write(constant_graph.SerializeToString())
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Removes unused nodes from a GraphDef.

  Each node named in `input_node_names` is replaced by a Placeholder that
  keeps its "_output_shapes" and "shape" attrs when present. Nodes no longer
  needed to compute `output_node_names` are then dropped by
  `graph_util.extract_sub_graph()`.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_node_names: A list of the nodes we use as inputs.
    output_node_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list that specifies one value per input node name.

  Returns:
    A `GraphDef` with all unnecessary ops removed.

  Raises:
    ValueError: If any element in `input_node_names` refers to a tensor instead
      of an operation.
    KeyError: If any element in `input_node_names` is not found in the graph.
  """
  for name in input_node_names:
    if ":" in name:
      raise ValueError("Name '%s' appears to refer to a Tensor, "
                       "not an Operation." % name)

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  not_found = {name for name in input_node_names}
  inputs_replaced_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node.name in input_node_names:
      not_found.remove(node.name)
      placeholder_node = node_def_pb2.NodeDef()
      placeholder_node.op = "Placeholder"
      placeholder_node.name = node.name
      if isinstance(placeholder_type_enum, list):
        input_node_index = input_node_names.index(node.name)
        placeholder_node.attr["dtype"].CopyFrom(
            attr_value_pb2.AttrValue(type=placeholder_type_enum[
                input_node_index]))
      else:
        placeholder_node.attr["dtype"].CopyFrom(
            attr_value_pb2.AttrValue(type=placeholder_type_enum))
      if "_output_shapes" in node.attr:
        placeholder_node.attr["_output_shapes"].CopyFrom(node.attr[
            "_output_shapes"])
      # Keep a declared static shape on the replacement placeholder too.
      if "shape" in node.attr:
        placeholder_node.attr["shape"].CopyFrom(node.attr["shape"])
      inputs_replaced_graph_def.node.extend([placeholder_node])
    else:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

  if not_found:
    raise KeyError("The following input nodes were not found: %s" % not_found)

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names)
  return output_graph_def
Beispiel #15
0
    def run(self):
        """Optimize the most recent SavedModel export and write it as a GraphDef.

        The last entry of the sorted sub-directories of
        `self.path_saved_model` whose path does not contain "temp" is loaded
        with the "serve" tag. A node named INIT_ALL_TABLES is fetched if
        present, or created from the TABLE_INITIALIZERS collection otherwise.
        The graph is then renamed, given placeholders for `self.feeds`,
        pruned to `self.fetch`, frozen, and written (binary) to
        `self.path_optimized_model/self.graph_name`.
        """
        # Normalize feeds and fetch. `fetch` is always a fresh list: it is
        # appended to below, and appending to `self.fetch` itself would make
        # every later call to run() add INIT_ALL_TABLES once more.
        if isinstance(self.fetch, str):
            fetch = self.fetch.split(",")
        else:
            fetch = list(self.fetch)
        feeds = self.feeds.split(",") if isinstance(self.feeds, str) else self.feeds

        # Find latest SavedModel export in path_saved_model
        subdirs = [
            str(path) for path in Path(self.path_saved_model).iterdir() if path.is_dir() and "temp" not in str(path)
        ]
        latest = str(sorted(subdirs)[-1])
        LOGGER.info(f"Using SavedModel {latest}")

        # Reload SavedModel Graph, optimize and export
        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            meta_graph_def = tf.compat.v1.saved_model.loader.load(sess, ["serve"], latest)
            graph_def = meta_graph_def.graph_def

            # Add table initializer if present, or create it
            if INIT_ALL_TABLES in {node.name for node in graph_def.node}:
                fetch.append(INIT_ALL_TABLES)
            else:
                table_initializers = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)
                if table_initializers:
                    LOGGER.info(f"Adding {INIT_ALL_TABLES} Node to the graph")
                    table_init_op = tf.group(*table_initializers, name=INIT_ALL_TABLES)
                    node_def = table_init_op.node_def
                    graph_def.node.append(node_def)
                    fetch.append(INIT_ALL_TABLES)

            # Rename nodes
            graph_def = rename_nodes(graph_def, self.new_names)

            # Setup (create / remove) placeholders
            graph_def = make_placeholders(graph_def, feeds)

            # Keep only part of the graph that produces tensor 'fetch'
            graph_def = extract_sub_graph(graph_def, fetch)

            # Replace variables by constants
            graph_def = freeze_graph_with_def_protos(
                input_graph_def=graph_def,
                input_saver_def=None,
                input_checkpoint=None,
                output_node_names=",".join(fetch),
                restore_op_name=None,
                filename_tensor_name=None,
                output_graph=None,
                clear_devices=True,
                initializer_nodes=None,
                variable_names_blacklist=",".join(self.blacklisted_variables),
                input_saved_model_dir=latest,
                saved_model_tags=["serve"],
            )
            tf.io.write_graph(graph_def, logdir=self.path_optimized_model, name=self.graph_name, as_text=False)
            LOGGER.info(f"Optimized Model successfully exported to {self.path_optimized_model}/{self.graph_name}")
Beispiel #16
0
def tf_optimize(sess, input_names, output_names, graph_def):
    """Prune `graph_def` to the given inputs/outputs and apply graph transforms.

    The transforms strip Identity/CheckNumerics nodes and fold batch norms.
    `sess` is not used by this function.
    """
    # Not enabled because it fails: "fold_constants(ignore_errors=true)".
    transforms = [
        "remove_nodes(op=Identity, op=CheckNumerics)",
        "fold_batch_norms",
        "fold_old_batch_norms",
    ]
    pruned = graph_util.extract_sub_graph(graph_def, input_names + output_names)
    return TransformGraph(pruned, input_names, output_names, transforms)
Beispiel #17
0
def tf_optimize(sess, inputs, outputs, graph_def):
    """Optimize tensorflow graph for inference.

    The graph is first pruned to the nodes behind `inputs` and `outputs`
    (tensor names are mapped to node names with `utils.node_name`). Constant
    folding and batch-norm folding are then applied with TransformGraph.
    `sess` is not used by this function.
    """
    keep = [utils.node_name(name) for group in (inputs, outputs) for name in group]
    pruned = graph_util.extract_sub_graph(graph_def, keep)
    return TransformGraph(pruned, inputs, outputs, [
        "fold_constants(ignore_errors=true)",
        "fold_batch_norms",
        "fold_old_batch_norms",
    ])
Beispiel #18
0
    def __init__(self,
                 meta_file,
                 checkpoint_file,
                 dest_nodes,
                 inputShape=None,
                 in_nodes=None):
        """Load a TensorFlow model and build its TensorflowGraph.

        Args:
            meta_file: Meta graph file loaded with `_load_meta`.
            checkpoint_file: Optional checkpoint loaded with `_load_weights`;
                when given, `ckpt_data` is set and `weight_loaded` is True.
            dest_nodes: Comma-separated output node names, or None.
            inputShape: Sizes of the input dimensions after the first one; the
                first dimension is left unknown (presumably the batch).
            in_nodes: Comma-separated input node names, or None.

        When both `in_nodes` and `inputShape` are given, the input nodes are
        replaced by float32 placeholders via `strip_unused` and annotated with
        the given shape. Otherwise, if `dest_nodes` is given, the graph is only
        pruned to what those nodes need.

        NOTE(review): `model` is only bound when `meta_file` is truthy, so a
        missing meta file fails later with UnboundLocalError.
        """
        super(TensorflowParser, self).__init__()

        # load model files into TensorFlow graph
        if meta_file:
            model = TensorflowParser._load_meta(meta_file)

        if checkpoint_file:
            self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
            self.weight_loaded = True

        # extract subgraph using in_nodes and dest_nodes
        if in_nodes is not None and inputShape is not None:
            from tensorflow.python.tools import strip_unused_lib
            from tensorflow.python.framework import dtypes
            input_node_names = in_nodes.split(',')
            output_node_names = dest_nodes.split(',')
            model = strip_unused_lib.strip_unused(
                input_graph_def=model,
                input_node_names=input_node_names,
                output_node_names=output_node_names,
                placeholder_type_enum=dtypes.float32.as_datatype_enum)

            # First dimension unknown, the rest taken from inputShape.
            input_list = [None] + [tensorflow.Dimension(dim) for dim in inputShape]
            tensor_input = tensorflow.TensorShape(input_list)
            # Build network graph
            self.tf_graph = TensorflowGraph(model)
            for node in self.tf_graph.model.node:
                if node.name in input_node_names:
                    node.attr['shape'].list.shape.extend(
                        [tensor_input.as_proto()])
                    output_shapes = node.attr['_output_shapes'].list.shape
                    # Replace the stale (unknown-rank) shape when one was
                    # recorded; popping an empty list would raise IndexError.
                    if output_shapes:
                        output_shapes.pop()
                    output_shapes.extend([tensor_input.as_proto()])

        # extract subgraph using dest_nodes
        elif dest_nodes is not None:
            from tensorflow.python.framework.graph_util import extract_sub_graph
            model = extract_sub_graph(model, dest_nodes.split(','))
            self.tf_graph = TensorflowGraph(model)

        else:
            self.tf_graph = TensorflowGraph(model)

        self.tf_graph.build()
Beispiel #19
0
def main(_):
    """Build the Keras model and export the graph behind FLAGS.output_nodes.

    The pruned GraphDef is written in text form to
    FLAGS.save_dir/FLAGS.graph_filename.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Learning phase 0 makes Keras build layers in test/inference mode.
    tf.keras.backend.set_learning_phase(0)

    model = build_model()
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy())

    session_graph = K.get_session().graph
    inference_graph = graph_util.extract_sub_graph(
        session_graph.as_graph_def(), [FLAGS.output_nodes])
    tf.train.write_graph(inference_graph, FLAGS.save_dir,
                         FLAGS.graph_filename, as_text=True)
    print("Finish export inference graph: {}".format(FLAGS.save_dir))
 def __init__(self,
              meta_file,
              checkpoint_file,
              frozen_file,
              dest_nodes=None):
     """Load a meta graph (and optional checkpoint) and build the TF graph.

     `frozen_file` is accepted but not used here. When `dest_nodes` (a
     comma-separated string of node names) is given, the graph is pruned to
     what those nodes need before being wrapped in a TensorflowGraph.
     """
     super(TensorflowParser, self).__init__()
     if meta_file:
         model = TensorflowParser._load_meta(meta_file)
     if checkpoint_file:
         self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
         self.weight_loaded = True

     if dest_nodes is None:
         pruned = model
     else:
         from tensorflow.python.framework.graph_util import extract_sub_graph
         pruned = extract_sub_graph(model, dest_nodes.split(','))

     graph = TensorflowGraph(pruned)
     self.tf_graph = graph
     graph.build()
Beispiel #21
0
    def load_meta(self):
        """Read `<_tf_model_prefix>.meta` and return its GraphDef.

        When `self._dest_nodes` (comma-separated node names) is set, the graph
        is pruned to the nodes those destinations need.

        Returns:
            model: A tensorflow protobuf file
        """
        from tensorflow.core.protobuf import meta_graph_pb2

        meta_graph = meta_graph_pb2.MetaGraphDef()
        meta_path = self._tf_model_prefix + '.meta'
        self.load_protobuf_from_file(meta_graph, meta_path)

        graph = meta_graph.graph_def
        if self._dest_nodes is not None:
            from tensorflow.python.framework.graph_util import extract_sub_graph
            destinations = self._dest_nodes.split(',')
            graph = extract_sub_graph(graph, destinations)

        print("Tensorflow model file [%s] loaded successfully." % self._tf_model_prefix)
        return graph
Beispiel #22
0
def tf_optimize(inputs, outputs, graph_def, fold_constant=None):
    """Optimize a TensorFlow graph for inference.

    First prunes ``graph_def`` to the nodes named by ``inputs`` and
    ``outputs``. Then it runs the Graph Transform Tool with batch-norm folding
    and, when requested, constant folding.

    Args:
      inputs: input tensor/node names.
      outputs: output tensor/node names.
      graph_def: the GraphDef to optimize.
      fold_constant: when truthy, also fold constants and strip ``_class``
        colocation attributes before the batch-norm folds.

    Returns:
      The transformed GraphDef.
    """
    if fold_constant:
        transforms = [
            "fold_constants(ignore_errors=true)",
            "remove_attribute(attribute_name=_class)",  # remove node colocation attributes
        ]
    else:
        transforms = []
    transforms += ["fold_batch_norms", "fold_old_batch_norms"]

    needed_names = [utils.node_name(name) for name in inputs]
    needed_names += [utils.node_name(name) for name in outputs]
    pruned_graph_def = graph_util.extract_sub_graph(graph_def, needed_names)
    return TransformGraph(pruned_graph_def, inputs, outputs, transforms)
Beispiel #23
0
    def extract_sub_graph(input_path, output_path=None, dest_nodes=None):
        """Prune a ``.pb`` graph to the subgraph feeding ``dest_nodes`` and save it.

        Args:
          input_path: path of the GraphDef to load.
          output_path: destination path. When empty, it is derived from
            ``input_path`` via ``append_file_name_suffix(input_path, "sub")``.
          dest_nodes: comma-separated output node names. When empty, the
            output nodes reported by ``get_graph_def_io_nodes`` are used.
        """
        output_path = output_path or append_file_name_suffix(input_path, "sub")

        logging.info("load from %s", input_path)
        full_graph_def = load_graph_def_from_pb(input_path)
        logging.info("\ttotal node = %s", len(full_graph_def.node))

        if not dest_nodes:
            _, dest_nodes = get_graph_def_io_nodes(full_graph_def)
        else:
            dest_nodes = dest_nodes.split(',')

        sub_graph_def = graph_util.extract_sub_graph(full_graph_def, dest_nodes)
        logging.info("save to %s", output_path)
        logging.info("\ttotal node = %s", len(sub_graph_def.node))
        save_graph_def(sub_graph_def, output_path)
Beispiel #24
0
    def run(self):
        """Freeze the newest SavedModel export into an optimized GraphDef.

        Picks the lexicographically greatest non-temporary subdirectory of
        ``self.path_saved_model`` and loads it with the ``"serve"`` tag. It
        then renames nodes, sets up placeholders, prunes the graph to
        ``self.fetch`` and freezes variables into constants. The binary
        GraphDef is written to ``self.path_optimized_model/self.graph_name``.

        Raises:
            FileNotFoundError: if ``self.path_saved_model`` contains no usable
                export subdirectory.
        """
        # Find latest SavedModel export in path_saved_model. Only the
        # directory's own name is checked for "temp". Checking the full path
        # would discard every export whenever a parent directory's name
        # contains "temp".
        subdirs = [
            str(path) for path in Path(self.path_saved_model).iterdir()
            if path.is_dir() and "temp" not in path.name
        ]
        if not subdirs:
            raise FileNotFoundError(
                f"No SavedModel export found in {self.path_saved_model}")
        latest = max(subdirs)
        LOGGER.info(f"Using SavedModel {latest}")

        # Reload SavedModel Graph, optimize and export
        with tf.Session(graph=tf.Graph()) as sess:
            graph = tf.saved_model.loader.load(sess, ["serve"], latest)
            graph_def = graph.graph_def

            # Rename nodes
            graph_def = rename_nodes(graph_def, self.new_names)

            # Setup (create / remove) placeholders
            graph_def = make_placeholders(graph_def, self.feeds)

            # Keep only part of the graph that produces tensor 'fetch'
            graph_def = extract_sub_graph(graph_def, [self.fetch])

            # Replace variables by constants
            graph_def = freeze_graph_with_def_protos(
                input_graph_def=graph_def,
                input_saver_def=None,
                input_checkpoint=None,
                output_node_names=self.fetch,
                restore_op_name=None,
                filename_tensor_name=None,
                output_graph=None,
                clear_devices=True,
                initializer_nodes=None,
                variable_names_blacklist=",".join(self.blacklisted_variables),
                input_saved_model_dir=latest,
                saved_model_tags=["serve"],
            )
            tf.io.write_graph(graph_def,
                              logdir=self.path_optimized_model,
                              name=self.graph_name,
                              as_text=False)
            LOGGER.info(
                f"Online KNN successfully exported to {self.path_optimized_model}/{self.graph_name}"
            )
    def __init__(self, meta_file, checkpoint_file, frozen_file, dest_nodes=None):
        """Load a TensorFlow meta graph (and optional checkpoint) and build it.

        Args:
            meta_file: path of the ``.meta`` file to load; required.
            checkpoint_file: optional checkpoint whose weights are loaded
                into ``self.ckpt_data``.
            frozen_file: accepted for interface compatibility; not used here.
            dest_nodes: optional comma-separated output node names; when
                given, the graph is pruned to the subgraph they depend on.

        Raises:
            ValueError: if ``meta_file`` is empty, since no graph could be
                built.
        """
        super(TensorflowParser, self).__init__()

        # load model files into TensorFlow graph
        if not meta_file:
            # Without a meta file `model` would be unbound further down.
            raise ValueError("meta_file is required to build the graph")
        model = TensorflowParser._load_meta(meta_file)

        if checkpoint_file:
            self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
            self.weight_loaded = True

        if dest_nodes is not None:
            from tensorflow.python.framework.graph_util import extract_sub_graph
            model = extract_sub_graph(model, dest_nodes.split(','))

        # Build network graph
        self.tf_graph = TensorflowGraph(model)
        self.tf_graph.build()
Beispiel #26
0
    def __init__(self, input_args, dest_nodes=None):
        """Load a TensorFlow meta graph (and optional checkpoint) and build it.

        Args:
            input_args: either a meta file path string, or a
                ``(meta_file, checkpoint_file)`` tuple. In the tuple form the
                checkpoint weights are loaded into ``self.ckpt_data``.
            dest_nodes: optional comma-separated output node names; when
                given, the graph is pruned to the subgraph they depend on.

        Raises:
            TypeError: if ``input_args`` is neither a string nor a tuple.
        """
        super(TensorflowParser, self).__init__()

        # load model files into TensorFlow graph
        from six import string_types as _string_types
        if isinstance(input_args, _string_types):
            model = TensorflowParser._load_meta(input_args)
        elif isinstance(input_args, tuple):
            model = TensorflowParser._load_meta(input_args[0])
            self.ckpt_data = TensorflowParser._load_weights(input_args[1])
            self.weight_loaded = True
        else:
            # Any other input would leave `model` unbound below.
            raise TypeError(
                "input_args must be a meta file path or a "
                "(meta_file, checkpoint_file) tuple, got %r" % (input_args,))

        if dest_nodes is not None:
            from tensorflow.python.framework.graph_util import extract_sub_graph
            model = extract_sub_graph(model, dest_nodes.split(','))

        # Build network graph
        self.tf_graph = TensorflowGraph(model)
        self.tf_graph.build()
Beispiel #27
0
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
    """Removes unused nodes from a graph.

    Reads a GraphDef from ``input_graph`` and replaces every node named in
    ``input_node_names`` with a ``Placeholder`` of type
    ``placeholder_type_enum``. It then prunes everything that
    ``output_node_names`` does not depend on and writes the result as a
    binary GraphDef to ``output_graph``.

    Args:
      input_graph: path of the GraphDef file to read.
      input_binary: True if ``input_graph`` is a binary proto, False for text.
      output_graph: path the pruned binary GraphDef is written to.
      input_node_names: comma-separated names of nodes to turn into
        placeholders.
      output_node_names: comma-separated names of the result nodes.
      placeholder_type_enum: DataType enum value used for the placeholders.

    Returns:
      -1 if the input file is missing or no output node name was given;
      None otherwise.
    """

    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    # GFile is the non-deprecated equivalent of FastGFile.
    with tf.gfile.GFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    # A set gives O(1) membership per node instead of a list scan.
    input_node_name_set = set(input_node_names.split(","))
    inputs_replaced_graph_def = tf.GraphDef()
    for node in input_graph_def.node:
        if node.name in input_node_name_set:
            placeholder_node = tf.NodeDef()
            placeholder_node.op = "Placeholder"
            placeholder_node.name = node.name
            placeholder_node.attr["dtype"].CopyFrom(
                tf.AttrValue(type=placeholder_type_enum))
            inputs_replaced_graph_def.node.extend([placeholder_node])
        else:
            inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

    output_graph_def = graph_util.extract_sub_graph(
        inputs_replaced_graph_def, output_node_names.split(","))

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
Beispiel #28
0
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
  """Removes unused nodes from a graph.

  Reads a GraphDef from ``input_graph`` and replaces every node named in
  ``input_node_names`` with a ``Placeholder`` of type
  ``placeholder_type_enum``. It then prunes everything that
  ``output_node_names`` does not depend on and writes the result as a binary
  GraphDef to ``output_graph``.

  Args:
    input_graph: path of the GraphDef file to read.
    input_binary: True if ``input_graph`` is a binary proto, False for text.
    output_graph: path the pruned binary GraphDef is written to.
    input_node_names: comma-separated names of nodes to turn into
      placeholders.
    output_node_names: comma-separated names of the result nodes.
    placeholder_type_enum: DataType enum value used for the placeholders.

  Returns:
    -1 if the input file is missing or no output node name was given;
    None otherwise.
  """

  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  # GFile is the non-deprecated equivalent of FastGFile.
  with tf.gfile.GFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  # A set gives O(1) membership per node instead of a list scan.
  input_node_name_set = set(input_node_names.split(","))
  inputs_replaced_graph_def = tf.GraphDef()
  for node in input_graph_def.node:
    if node.name in input_node_name_set:
      placeholder_node = tf.NodeDef()
      placeholder_node.op = "Placeholder"
      placeholder_node.name = node.name
      placeholder_node.attr["dtype"].CopyFrom(tf.AttrValue(
          type=placeholder_type_enum))
      inputs_replaced_graph_def.node.extend([placeholder_node])
    else:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names.split(","))

  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
Beispiel #29
0
def tf_optimize(sess, inputs, outputs, graph_def):
    """Optimize tensorflow graph for inference.

    ``sess`` is accepted for interface compatibility but is not used. The
    graph is pruned to the requested inputs/outputs. The Graph Transform Tool
    then folds constants and batch norms. Progress is printed to stdout.

    Returns:
      The transformed GraphDef.
    """
    transforms = ["fold_constants(ignore_errors=true)",
                  "fold_batch_norms",
                  "fold_old_batch_norms"]
    # TODO: look into "these two" further (translated from the original
    # note, which did not say which two it meant).
    needed_names = ([utils.node_name(name) for name in inputs] +
                    [utils.node_name(name) for name in outputs])
    print("---------------needed_names:", needed_names)
    pruned_graph_def = graph_util.extract_sub_graph(graph_def, needed_names)

    print("extract_sub_graph done")

    optimized_graph_def = TransformGraph(pruned_graph_def, inputs, outputs,
                                         transforms)

    print("TransformGraph done")
    return optimized_graph_def
Beispiel #30
0
def parse_pb(file_or_path, output_nodes=None):
    """
  Arguments
  =========
  - file_or_path: a file object or a path string of the pb file
  - output_nodes: list of output node names

  Returns
  =======
  - ops_info <dict>: a dict with information necessary for
    building context in uTensor
  - ops_topo <list>: list of op node names in topological sorted order
  - output_nodes <list>: list of output node names

  Raises
  ======
  - ValueError: if `file_or_path` is neither a path string nor a file object

  Note
  ====
  The file is always closed before the graph is processed, even when a file
  object was passed in by the caller or when parsing fails.
  """
    if sys.version_info.major < 3:
        file_type = (file, io.IOBase)  # pylint: disable=E0602
    else:
        file_type = io.IOBase

    if isinstance(file_or_path, str):
        fid = open(file_or_path, "rb")
    elif isinstance(file_or_path, file_type):
        fid = file_or_path
    else:
        raise ValueError(
            "`file_or_path` has to be either file object or path string")

    # load pb file; try/finally closes the handle even if parsing fails
    try:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(fid.read())
    finally:
        fid.close()

    if output_nodes is not None:
        graph_def = graph_util.extract_sub_graph(graph_def, output_nodes)

    ops_info, ops_topo, output_nodes = _parse_graph_def(
        graph_def, output_nodes)
    return ops_info, ops_topo, output_nodes
def strip_pruning_vars_fn(input_graph_def, output_node_names):
  """Removes mask variable from the graph.

  Every node whose name contains ``'masked_weight'`` is replaced by a
  ``Const`` node with the same name. The constant's dtype comes from the
  node's ``T`` attr, and its value is the precomputed mask * weight product
  from ``_get_masked_weights``. The graph is then pruned to
  ``output_node_names``, so the mask and weight nodes that nothing references
  any more are dropped.

  Args:
    input_graph_def: A GraphDef in which the variables have been converted to
      constants. This is typically the output of
      tf.graph_util.convert_variables_to_constant()
    output_node_names: List of name strings for the result nodes of the graph

  Returns:
    A GraphDef in which pruning-related variables have been removed
  """
  masked_weights = _get_masked_weights(input_graph_def)

  def _to_const(node):
    # Build the Const replacement for a single masked_weight node.
    dtype = node.attr['T']
    value = masked_weights[node.name]
    const_node = node_def_pb2.NodeDef()
    const_node.op = 'Const'
    const_node.name = node.name
    const_node.attr['dtype'].CopyFrom(dtype)
    const_node.attr['value'].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(value)))
    return const_node

  pruned_graph_def = graph_pb2.GraphDef()
  # extend() copies every message, so untouched nodes can be passed as-is.
  pruned_graph_def.node.extend([
      _to_const(node) if 'masked_weight' in node.name else node
      for node in input_graph_def.node
  ])

  # Remove stranded nodes: mask and weights
  return graph_util.extract_sub_graph(pruned_graph_def, output_node_names)
Beispiel #32
0
 def testExtractSubGraphWithInvalidDestNodes(self):
   """extract_sub_graph must reject a bare string passed as dest_nodes."""
   graph_def = graph_pb2.GraphDef()
   n1 = graph_def.node.add()
   n1.name = "n1"
   # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
   with self.assertRaisesRegex(TypeError, "must be a list"):
     graph_util.extract_sub_graph(graph_def, "n1")
 def remove_dead_nodes(self, output_names):
   """Prune ``self.output_graph`` to what ``output_names`` depends on.

   The pruned GraphDef returned by ``graph_util.extract_sub_graph`` replaces
   ``self.output_graph``. Nodes not needed for inference are dropped.
   """
   self.output_graph = graph_util.extract_sub_graph(
       self.output_graph, output_names)
Beispiel #34
0
    def __init__(self,
                 meta_file,
                 checkpoint_file,
                 dest_nodes,
                 inputShape=None,
                 in_nodes=None):
        """Load a TensorFlow meta graph, prune and fold it, then build it.

        Steps:
        - Load ``meta_file``. When ``checkpoint_file`` is given, also load its
          weights into ``self.ckpt_data``.
        - Extract the subgraph between ``in_nodes`` and ``dest_nodes``.
        - Apply the ``fold_constants`` graph transform.
        - Re-import the result behind fresh placeholders.
        - Build ``self.tf_graph`` and run ``process_graph`` on it.

        Args:
            meta_file: path of the ``.meta`` file. NOTE(review): when falsy,
                ``model`` is never bound and the code below fails.
            checkpoint_file: optional checkpoint whose weights are loaded.
            dest_nodes: output node names. They are forwarded unchanged to
                ``strip_unused``/``extract_sub_graph``/``TransformGraph``, so
                presumably a list of names (TODO confirm against callers).
            inputShape: dimensions of each input without the batch
                dimension; a ``None`` batch dimension is prepended below.
                Only used together with ``in_nodes``.
            in_nodes: input node names. When omitted, every ``Placeholder``
                node of the loaded graph is treated as an input.
        """
        super(TensorflowParser, self).__init__()

        # load model files into TensorFlow graph
        if meta_file:
            model = TensorflowParser._load_meta(meta_file)

        if checkpoint_file:
            self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
            self.weight_loaded = True

        # extract subgraph using in_nodes and dest_nodes
        # (strip_unused turns every in_node into a float32 placeholder)
        if in_nodes != None and inputShape != None:
            from tensorflow.python.tools import strip_unused_lib
            from tensorflow.python.framework import dtypes
            # NOTE(review): gfile is imported but not used in this block.
            from tensorflow.python.platform import gfile
            model = strip_unused_lib.strip_unused(
                input_graph_def=model,
                input_node_names=in_nodes,
                output_node_names=dest_nodes,
                placeholder_type_enum=dtypes.float32.as_datatype_enum)

            # Shape = [None (batch)] + inputShape.
            input_list = [None]
            for i in range(len(inputShape)):
                input_list.append(tensorflow.Dimension(inputShape[i]))
            tensor_input = tensorflow.TensorShape(input_list)
            # Build network graph
            self.tf_graph = TensorflowGraph(model)
            # Stamp the concrete input shape onto each input placeholder,
            # replacing its single unknown-rank `_output_shapes` entry.
            # NOTE(review): presumably self.tf_graph.model is `model` itself,
            # so these edits also reach TransformGraph below (TODO confirm).
            # If in_nodes is a string, `in` is a substring test.
            for node in self.tf_graph.model.node:
                if node.name in in_nodes:
                    node.attr['shape'].shape.CopyFrom(tensor_input.as_proto())
                    node.attr['_output_shapes'].list.shape.pop(
                    )  #unknown_rank pop
                    node.attr['_output_shapes'].list.shape.extend(
                        [tensor_input.as_proto()])

        # extract subgraph using dest_nodes
        elif dest_nodes != None:
            from tensorflow.python.framework.graph_util import extract_sub_graph
            model = extract_sub_graph(model, dest_nodes)
            self.tf_graph = TensorflowGraph(model)

        else:
            self.tf_graph = TensorflowGraph(model)

        # Graph Transform
        transforms = ["fold_constants(ignore_errors=true)"]

        #  Get input node name
        if not in_nodes:
            in_nodes = []
            for node in model.node:
                if node.op == 'Placeholder':
                    in_nodes.append(node.name)

        transformed_graph_def = TransformGraph(model, in_nodes, dest_nodes,
                                               transforms)
        # Record the dtype enum and shape string of every input node.
        in_type_list = {}
        in_shape_list = {}

        for n in transformed_graph_def.node:
            if n.name in in_nodes:
                in_type_list[n.name] = n.attr['dtype'].type
                in_node_shape = n.attr['shape'].shape
                in_node_shape_str = self._shapeToStr(in_node_shape)
                in_shape_list[n.name] = in_node_shape_str

        # Re-import the transformed graph, feeding each input from a new
        # placeholder. DataType enum -> dtype: 0/1 -> float32, 3 -> int32,
        # 10 -> bool.
        # NOTE(review): any other enum value falls through and reuses the
        # dtype chosen for the previous input (float32 initially).
        dtype = tensorflow.float32
        with tensorflow.Graph().as_default() as g:
            input_map = {}
            for in_node in in_nodes:
                if in_type_list[in_node] == 1 or in_type_list[in_node] == 0:
                    dtype = tensorflow.float32

                elif in_type_list[in_node] == 3:
                    dtype = tensorflow.int32

                elif in_type_list[in_node] == 10:
                    dtype = tensorflow.bool

                x = tensorflow.placeholder(dtype, shape=in_shape_list[in_node])
                input_map[in_node] = x

            tensorflow.import_graph_def(transformed_graph_def,
                                        name='',
                                        input_map=input_map)

        # Export a meta graph to a throwaway directory and keep only its
        # graph_def. NOTE(review): the temp dir is not removed if
        # export_meta_graph raises.
        with tensorflow.Session(graph=g) as sess:
            tempdir = tempfile.mkdtemp()
            meta_graph_def = tensorflow.train.export_meta_graph(
                filename=os.path.join(tempdir, 'my-model.meta'))
            model = meta_graph_def.graph_def
            shutil.rmtree(tempdir)

        self.tf_graph = TensorflowGraph(model)
        self.tf_graph.build()

        process_graph(self.tf_graph, self.ckpt_data)
 def testExtractSubGraphWithInvalidDestNodes(self):
     """extract_sub_graph must reject a bare string passed as dest_nodes."""
     graph_def = graph_pb2.GraphDef()
     n1 = graph_def.node.add()
     n1.name = "n1"
     # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
     with self.assertRaisesRegex(TypeError, "must be a list"):
         graph_util.extract_sub_graph(graph_def, "n1")
Beispiel #36
0
def strip_unused(input_graph_def, input_tensor_names, output_tensor_names,
                 placeholder_type_enum):
    """Removes unused nodes from a GraphDef.

  Every consumer of a tensor listed in ``input_tensor_names`` is rewired to
  read from a single new ``Placeholder`` instead. The placeholder is named
  after the consumer's input string with ``:`` replaced by ``_``. Whatever
  used to produce that tensor is then stripped by ``extract_sub_graph()``.
  ``input_graph_def`` itself is left unmodified.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_tensor_names: A list of the tensors we use as inputs, each in
      ``"node:port"`` form.
    output_tensor_names: A list of the output tensors; only their node part
      is used for pruning.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list that specifies one value per input tensor name.

  Returns:
    A tuple ``(graph_def, old2new)``:
    - ``graph_def``: the `GraphDef` with all unnecessary ops removed.
    - ``old2new``: a dict mapping each old input tensor name to the name of
      the placeholder tensor that replaces it.

  Raises:
    ValueError: If any element in `input_tensor_names` refers to an operation
      instead of a tensor.
    KeyError: If any element in `input_tensor_names` is not found in the graph.
  """
    for name in input_tensor_names:
        if ":" not in name:
            raise ValueError("Input '%s' appears to refer to a Operation, "
                             "not a Tensor." % name)

    old2new = {}
    # old tensor name -> name of the placeholder node already emitted for it
    placeholder_names = {}

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    not_found = set(input_tensor_names)
    input_node_names = {name.split(":")[0] for name in input_tensor_names}
    output_node_names = list(
        {name.split(":")[0]
         for name in output_tensor_names})
    inputs_replaced_graph_def = graph_pb2.GraphDef()
    for node in input_graph_def.node:
        if node.name in input_node_names:
            continue
        # Rewire a copy so the caller's graph is never mutated.
        new_node = copy.deepcopy(node)
        for i, input_name in enumerate(list(new_node.input)):
            old_name = _append_port(input_name)
            if old_name not in input_tensor_names:
                continue
            # discard(): a tensor may feed several consumers.
            not_found.discard(old_name)
            placeholder_name = placeholder_names.get(old_name)
            if placeholder_name is None:
                # First consumer of this tensor: emit exactly one placeholder.
                placeholder_name = input_name.replace(":", "_")
                placeholder_names[old_name] = placeholder_name
                placeholder_node = node_def_pb2.NodeDef()
                placeholder_node.op = "Placeholder"
                placeholder_node.name = placeholder_name
                if isinstance(placeholder_type_enum, list):
                    input_node_index = input_tensor_names.index(old_name)
                    placeholder_node.attr["dtype"].CopyFrom(
                        attr_value_pb2.AttrValue(
                            type=placeholder_type_enum[input_node_index]))
                else:
                    placeholder_node.attr["dtype"].CopyFrom(
                        attr_value_pb2.AttrValue(
                            type=placeholder_type_enum))
                # NOTE(review): these are the consumer's output shapes, not
                # necessarily the input tensor's; kept from the original.
                if "_output_shapes" in node.attr:
                    placeholder_node.attr["_output_shapes"].CopyFrom(
                        node.attr["_output_shapes"])
                old2new[old_name] = placeholder_name + ":0"
                inputs_replaced_graph_def.node.extend([placeholder_node])
            new_node.input[i] = placeholder_name
        inputs_replaced_graph_def.node.extend([new_node])

    if not_found:
        raise KeyError("The following input nodes were not found: %s\n" %
                       not_found)

    output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                    output_node_names)
    return output_graph_def, old2new
Beispiel #37
0
def parse_pb(file_or_path, output_nodes=None) -> (dict, list):
    """
  Arguments
  =========
  - file_or_path: a file object or a path string of the pb file
  - output_nodes: list of output node names

  Returns
  =======
  - graph_info <defaultdict>: a dict with information necessary for
    building context in uTensor
  - layers <list>: list of layer which is a list of operation names
    in the graph

  Raises
  ======
  - ValueError: if `file_or_path` is neither a path string nor a file object

  Note
  ====
  The file is always closed before the graph is processed, even when a file
  object was passed in by the caller or when parsing fails.

  graph_info example:
    { 'my_const': {
        "input_tensor": set([]),
        "output_tensor": set(["my_const:0"])
        "output_content": {"my_const:0": 3.14},
        "op_type": "Const"
      },
      'myop': {
        "input_tensor": set(["input1:0", "input2:0"]),
        "output_tensor": set(["my_op:0", "my_op:1"]),
        "output_content": {},
        "op_type": "MyOp"
      },
      ...
    }

  layers example:
    `bottom` <--------> `top`
      foo -
            \\
              tar - - var
            /
      bar -
  the return list, layers, will be [['foo', 'bar'], ['tar'], ['var']]
  That is, layers[0] is the bottom layer of the graph, layers[1] is the
  second bottom layer of the graph, so on and so forth
  """
    if sys.version_info.major < 3:
        file_type = (file, io.IOBase)  # pylint: disable=E0602
    else:
        file_type = io.IOBase

    if isinstance(file_or_path, str):
        fid = open(file_or_path, "rb")
    elif isinstance(file_or_path, file_type):
        fid = file_or_path
    else:
        raise ValueError(
            "`file_or_path` has to be either file object or path string")

    # load pb file; try/finally closes the handle even if parsing fails
    try:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(fid.read())
    finally:
        fid.close()

    if output_nodes is not None:
        graph_def = graph_util.extract_sub_graph(graph_def, output_nodes)

    graph_info, layers = _parse_graph_def(graph_def)
    return graph_info, layers
  def test_remove_redundant_quantization(self):
    """Redundant Dequantize -> QuantizeV2 round trips are removed.

    Each QuantizedMatMul operand (``a`` and ``b``) starts as a quint8 constant
    with min/max constants. It passes through Dequantize and then QuantizeV2
    before reaching the matmul. After ``remove_redundant_quantization``, the
    matmul is expected to read the constants and their min/max constants
    directly. Once the now-unreferenced round-trip nodes are pruned by
    ``extract_sub_graph``, the graph must equal ``expected_output`` (node
    order included).
    """
    a_constant_name = "a_constant"
    a_constant_min_name = "a_constant_min"
    a_constant_max_name = "a_constant_max"
    a_dequantize_name = "a_dequantize"
    a_quantize_name = "a_quantize"
    b_constant_name = "b_constant"
    b_constant_min_name = "b_constant_min"
    b_constant_max_name = "b_constant_max"
    b_dequantize_name = "b_dequantize"
    b_quantize_name = "b_quantize"
    mat_mul_name = "mat_mul"
    # Input graph, operand a: const/min/max -> Dequantize -> QuantizeV2.
    graph_def = graph_pb2.GraphDef()
    a_constant = quantize_graph.create_constant_node(
        a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
    graph_def.node.extend([a_constant])
    a_constant_min = quantize_graph.create_constant_node(
        a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([a_constant_min])
    a_constant_max = quantize_graph.create_constant_node(
        a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([a_constant_max])
    a_dequantize_node = quantize_graph.create_node(
        "Dequantize", a_dequantize_name,
        [a_constant_name, a_constant_min_name, a_constant_max_name])
    quantize_graph.set_attr_dtype(a_dequantize_node, "T", dtypes.uint8)
    graph_def.node.extend([a_dequantize_node])
    a_quantize_node = quantize_graph.create_node(
        "QuantizeV2", a_quantize_name,
        [a_dequantize_name, a_dequantize_name + ":1", a_dequantize_name + ":2"])
    quantize_graph.set_attr_dtype(a_quantize_node, "T", dtypes.uint8)
    graph_def.node.extend([a_quantize_node])
    # Same chain for operand b.
    b_constant = quantize_graph.create_constant_node(
        b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
    graph_def.node.extend([b_constant])
    b_constant_min = quantize_graph.create_constant_node(
        b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([b_constant_min])
    b_constant_max = quantize_graph.create_constant_node(
        b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([b_constant_max])
    b_dequantize_node = quantize_graph.create_node(
        "Dequantize", b_dequantize_name,
        [b_constant_name, b_constant_min_name, b_constant_max_name])
    quantize_graph.set_attr_dtype(b_dequantize_node, "T", dtypes.uint8)
    graph_def.node.extend([b_dequantize_node])
    b_quantize_node = quantize_graph.create_node(
        "QuantizeV2", b_quantize_name,
        [b_dequantize_name, b_dequantize_name + ":1", b_dequantize_name + ":2"])
    quantize_graph.set_attr_dtype(b_quantize_node, "T", dtypes.uint8)
    graph_def.node.extend([b_quantize_node])
    # The matmul reads both quantized values plus their min (:1) / max (:2).
    mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
        a_quantize_name, b_quantize_name, a_quantize_name + ":1",
        a_quantize_name + ":2", b_quantize_name + ":1", b_quantize_name + ":2"
    ])
    quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
    quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
    graph_def.node.extend([mat_mul_node])

    # Expected graph: the same constants feed QuantizedMatMul directly.
    expected_output = graph_pb2.GraphDef()
    a_constant = quantize_graph.create_constant_node(
        a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
    expected_output.node.extend([a_constant])
    a_constant_min = quantize_graph.create_constant_node(
        a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([a_constant_min])
    a_constant_max = quantize_graph.create_constant_node(
        a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([a_constant_max])
    b_constant = quantize_graph.create_constant_node(
        b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
    expected_output.node.extend([b_constant])
    b_constant_min = quantize_graph.create_constant_node(
        b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([b_constant_min])
    b_constant_max = quantize_graph.create_constant_node(
        b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([b_constant_max])
    mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
        a_constant_name, b_constant_name, a_constant_min_name,
        a_constant_max_name, b_constant_min_name, b_constant_max_name
    ])
    quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
    quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
    expected_output.node.extend([mat_mul_node])
    expected_output.versions.CopyFrom(graph_def.versions)
    expected_output.library.CopyFrom(graph_def.library)

    rewriter = quantize_graph.GraphRewriter(
        graph_def, [mat_mul_name], quantized_input_range=None)
    output = rewriter.remove_redundant_quantization(graph_def)
    # Drop the Dequantize/QuantizeV2 nodes that no longer feed mat_mul.
    stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
    self.assertProtoEquals(expected_output, stripped_output)