Example 1
  def testExtractSubGraph(self):
    graph_def = tf.GraphDef()
    n1 = graph_def.node.add()
    n1.name = "n1"
    n1.input.extend(["n5"])
    n2 = graph_def.node.add()
    n2.name = "n2"
    # Take the first output of the n1 node as the input.
    n2.input.extend(["n1:0"])
    n3 = graph_def.node.add()
    n3.name = "n3"
    # Add a control input (which isn't really needed by the kernel, but is
    # there to enforce execution order between nodes).
    n3.input.extend(["^n2"])
    n4 = graph_def.node.add()
    n4.name = "n4"

    # It is fine to have loops in the graph as well.
    n5 = graph_def.node.add()
    n5.name = "n5"
    n5.input.extend(["n1"])

    sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
    self.assertEqual("n1", sub_graph.node[0].name)
    self.assertEqual("n2", sub_graph.node[1].name)
    self.assertEqual("n3", sub_graph.node[2].name)
    self.assertEqual("n5", sub_graph.node[3].name)
Example 2
def convert_variables_to_constants(sess, input_graph_def, output_node_names):
    variable_names = []
    variable_dict_names = []
    for node in input_graph_def.node:
        if node.op == "Assign":
            variable_name = node.input[0]
            variable_dict_names.append(variable_name)
            variable_names.append(variable_name + ":0")
    returned_variables = sess.run(variable_names)
    found_variables = dict(zip(variable_dict_names, returned_variables))
    print("Frozen %d variables." % len(returned_variables))

    inference_graph = extract_sub_graph(input_graph_def, output_node_names)

    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = graph_pb2.NodeDef()
        if input_node.name in found_variables:
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            output_node.attr["dtype"].CopyFrom(dtype)
            output_node.attr["value"].CopyFrom(attr_value_pb2.AttrValue(
                tensor=tensor_util.make_tensor_proto(data, dtype=dtype.type, shape=data.shape)))
            how_many_converted += 1
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])
    print("Converted %d variables to const ops." % how_many_converted)
    return output_graph_def
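The helper above scans the GraphDef for "Assign" nodes, evaluates the corresponding variables in the session, and rewrites them as Const nodes in the extracted inference graph. A minimal usage sketch, assuming TensorFlow 1.x graph mode and that the helper and its dependencies (graph_pb2, attr_value_pb2, tensor_util, extract_sub_graph) are in scope; the variable and node names are illustrative.

import tensorflow as tf

with tf.Graph().as_default():
    w = tf.Variable(2.0, name="w")                      # frozen into a Const below
    x = tf.placeholder(tf.float32, shape=[], name="x")
    y = tf.multiply(w, x, name="y")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        frozen = convert_variables_to_constants(sess, sess.graph_def, ["y"])
        # frozen contains no Variable/Assign ops and can be serialized on its own.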
Example 3
  def test_remove_unneeded_nodes(self):
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"
    graph_def = tf.GraphDef()
    a_constant = quantize_graph.create_constant_node(a_constant_name,
                                                     value=1,
                                                     dtype=tf.float32,
                                                     shape=[])
    graph_def.node.extend([a_constant])
    a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name,
                                              [a_constant_name])
    graph_def.node.extend([a_check_node])
    a_identity_node = quantize_graph.create_node("Identity", a_identity_name,
                                                 [a_constant_name,
                                                  "^" + a_check_name])
    graph_def.node.extend([a_identity_node])
    b_constant = quantize_graph.create_constant_node(b_constant_name,
                                                     value=1,
                                                     dtype=tf.float32,
                                                     shape=[])
    graph_def.node.extend([b_constant])
    b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name,
                                              [b_constant_name])
    graph_def.node.extend([b_check_node])
    b_identity_node = quantize_graph.create_node("Identity", b_identity_name,
                                                 [b_constant_name,
                                                  "^" + b_check_name])
    graph_def.node.extend([b_identity_node])
    add_node = quantize_graph.create_node("Add", add_name,
                                          [a_identity_name,
                                           b_identity_name])
    quantize_graph.set_attr_dtype(add_node, "T", tf.float32)
    graph_def.node.extend([add_node])

    expected_output = tf.GraphDef()
    a_constant = quantize_graph.create_constant_node(a_constant_name,
                                                     value=1,
                                                     dtype=tf.float32,
                                                     shape=[])
    expected_output.node.extend([a_constant])
    b_constant = quantize_graph.create_constant_node(b_constant_name,
                                                     value=1,
                                                     dtype=tf.float32,
                                                     shape=[])
    expected_output.node.extend([b_constant])
    add_node = quantize_graph.create_node("Add", add_name,
                                          [a_constant_name,
                                           b_constant_name])
    quantize_graph.set_attr_dtype(add_node, "T", tf.float32)
    expected_output.node.extend([add_node])

    rewriter = quantize_graph.GraphRewriter(graph_def, [add_name])
    output = rewriter.remove_unneeded_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
Example 4
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
  """Removes unused nodes from a graph."""

  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with tf.gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  input_node_names_list = input_node_names.split(",")
  inputs_replaced_graph_def = tf.GraphDef()
  for node in input_graph_def.node:
    if node.name in input_node_names_list:
      placeholder_node = tf.NodeDef()
      placeholder_node.op = "Placeholder"
      placeholder_node.name = node.name
      placeholder_node.attr["dtype"].CopyFrom(tf.AttrValue(
          type=placeholder_type_enum))
      inputs_replaced_graph_def.node.extend([placeholder_node])
    else:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names.split(","))

  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
Example 5
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices):
    """Converts all variables in a graph and checkpoint into constants."""

    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    if not gfile.Exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with open(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    _ = tf.import_graph_def(input_graph_def, name="")

    with tf.Session() as sess:
        if input_saver:
            with open(input_saver, "rb") as f:
                saver_def = tf.train.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = tf.train.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name],
                     {filename_tensor_name: input_checkpoint})
        found_variables = {}
        for node in input_graph_def.node:
            if node.op == "Assign":
                variable_name = node.input[0]
                found_variables[variable_name] = sess.run(variable_name + ":0")

    # This graph only includes the nodes needed to evaluate the output nodes, and
    # removes unneeded nodes like those involved in saving and assignment.
    inference_graph = graph_util.extract_sub_graph(
        input_graph_def, output_node_names.split(","))

    output_graph_def = tf.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = tf.NodeDef()
        if input_node.name in found_variables:
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            set_attr_dtype(output_node, "dtype", dtype)
            set_attr_tensor(output_node, "value", data, dtype.type, data.shape)
            how_many_converted += 1
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])

    with gfile.FastGFile(output_graph, "w") as f:
        f.write(output_graph_def.SerializeToString())
    print("Converted %d variables to const ops." % how_many_converted)
    print("%d ops in the final graph." % len(output_graph_def.node))
Example 6
def freeze_graph(
    input_graph,
    input_saver,
    input_binary,
    input_checkpoint,
    output_node_names,
    restore_op_name,
    filename_tensor_name,
    output_graph,
    clear_devices,
):
    """Converts all variables in a graph and checkpoint into constants."""

    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not tf.gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    if not tf.gfile.Exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with open(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    _ = tf.import_graph_def(input_graph_def, name="")

    with tf.Session() as sess:
        if input_saver:
            with open(input_saver, mode) as f:
                saver_def = tf.train.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = tf.train.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
        found_variables = {}
        for node in input_graph_def.node:
            if node.op == "Assign":
                variable_name = node.input[0]
                found_variables[variable_name] = sess.run(variable_name + ":0")

    # This graph only includes the nodes needed to evaluate the output nodes, and
    # removes unneeded nodes like those involved in saving and assignment.
    inference_graph = graph_util.extract_sub_graph(input_graph_def, output_node_names.split(","))

    output_graph_def = tf.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = tf.NodeDef()
        if input_node.name in found_variables:
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            set_attr_dtype(output_node, "dtype", dtype)
            set_attr_tensor(output_node, "value", data, dtype.type, data.shape)
            how_many_converted += 1
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])

    with tf.gfile.FastGFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("Converted %d variables to const ops." % how_many_converted)
    print("%d ops in the final graph." % len(output_graph_def.node))
Example 7
def remove_dead_nodes(self, output_names):
    """Removes nodes that are no longer needed for inference from the graph."""
    old_output_graph = self.output_graph
    self.output_graph = graph_util.extract_sub_graph(old_output_graph,
                                                     output_names)
Example 8
def convert_variables_to_constants(sess, input_graph_def, output_node_names):
    """Replaces all the variables in a graph with constants of the same values.

  If you have a trained graph containing Variable ops, it can be convenient to
  convert them all to Const ops holding the same values. This makes it possible
  to describe the network fully with a single GraphDef file, and allows the
  removal of a lot of ops related to loading and saving the variables.

  Args:
    sess: Active TensorFlow session containing the variables.
    input_graph_def: GraphDef object holding the network.
    output_node_names: List of name strings for the result nodes of the graph.

  Returns:
    GraphDef containing a simplified version of the original.
  """
    print('call convert_variables')
    found_variables = {}
    variable_name_list = []
    found_variables_list = []
    print('search nodes...')
    for i, node in enumerate(input_graph_def.node):
        # print('node %s' % node)
        if node.op == "Assign":
            variable_name_list.append(node.input[0])
            sys.stdout.write(
                "\r%s" %
                "node: {0}/{1}".format(i + 1, len(input_graph_def.node)))
            sys.stdout.flush()
    print('')
    print('{0} nodes found'.format(len(variable_name_list)))
    print('evaluate nodes..')
    found_variables_list = sess.run([v + ":0" for v in variable_name_list])
    print('insert values..')
    for i, v in enumerate(variable_name_list):
        found_variables[v] = found_variables_list[i]
        sys.stdout.write(
            "\r%s" % "node: {0}/{1}".format(i + 1, len(variable_name_list)))
        sys.stdout.flush()
    print('')

    # This graph only includes the nodes needed to evaluate the output nodes, and
    # removes unneeded nodes like those involved in saving and assignment.
    inference_graph = graph_util.extract_sub_graph(input_graph_def,
                                                   output_node_names)

    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = graph_pb2.NodeDef()
        if input_node.name in found_variables:
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            output_node.attr["dtype"].CopyFrom(dtype)
            output_node.attr["value"].CopyFrom(
                attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                    data, dtype=dtype.type, shape=data.shape)))
            how_many_converted += 1
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])
    print("Converted %d variables to const ops." % how_many_converted)
    return output_graph_def
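The GraphDef returned here is self-contained, so it is typically written to a single .pb file and re-imported later without any checkpoint. A minimal sketch, assuming TensorFlow 1.x and a frozen_graph_def obtained from the helper above; the output path is illustrative.

import tensorflow as tf

with tf.gfile.GFile("/tmp/frozen_model.pb", "wb") as f:
    f.write(frozen_graph_def.SerializeToString())

# Later: reload the frozen graph without restoring any variables.
reloaded = tf.GraphDef()
with tf.gfile.GFile("/tmp/frozen_model.pb", "rb") as f:
    reloaded.ParseFromString(f.read())
tf.import_graph_def(reloaded, name="")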
Example 9
  def test_remove_redundant_quantization(self):
    a_constant_name = "a_constant"
    a_constant_min_name = "a_constant_min"
    a_constant_max_name = "a_constant_max"
    a_dequantize_name = "a_dequantize"
    a_quantize_name = "a_quantize"
    b_constant_name = "b_constant"
    b_constant_min_name = "b_constant_min"
    b_constant_max_name = "b_constant_max"
    b_dequantize_name = "b_dequantize"
    b_quantize_name = "b_quantize"
    mat_mul_name = "mat_mul"
    graph_def = tf.GraphDef()
    a_constant = quantize_graph.create_constant_node(a_constant_name,
                                                     value=(0,),
                                                     dtype=tf.quint8,
                                                     shape=[])
    graph_def.node.extend([a_constant])
    a_constant_min = quantize_graph.create_constant_node(a_constant_min_name,
                                                         value=2,
                                                         dtype=tf.float32,
                                                         shape=[])
    graph_def.node.extend([a_constant_min])
    a_constant_max = quantize_graph.create_constant_node(a_constant_max_name,
                                                         value=2,
                                                         dtype=tf.float32,
                                                         shape=[])
    graph_def.node.extend([a_constant_max])
    a_dequantize_node = quantize_graph.create_node("Dequantize",
                                                   a_dequantize_name,
                                                   [a_constant_name,
                                                    a_constant_min_name,
                                                    a_constant_max_name])
    quantize_graph.set_attr_dtype(a_dequantize_node, "T", tf.uint8)
    graph_def.node.extend([a_dequantize_node])
    a_quantize_node = quantize_graph.create_node("QuantizeV2",
                                                 a_quantize_name,
                                                 [a_dequantize_name,
                                                  a_dequantize_name + ":1",
                                                  a_dequantize_name + ":2"])
    quantize_graph.set_attr_dtype(a_quantize_node, "T", tf.uint8)
    graph_def.node.extend([a_quantize_node])
    b_constant = quantize_graph.create_constant_node(b_constant_name,
                                                     value=(0,),
                                                     dtype=tf.quint8,
                                                     shape=[])
    graph_def.node.extend([b_constant])
    b_constant_min = quantize_graph.create_constant_node(b_constant_min_name,
                                                         value=3,
                                                         dtype=tf.float32,
                                                         shape=[])
    graph_def.node.extend([b_constant_min])
    b_constant_max = quantize_graph.create_constant_node(b_constant_max_name,
                                                         value=3,
                                                         dtype=tf.float32,
                                                         shape=[])
    graph_def.node.extend([b_constant_max])
    b_dequantize_node = quantize_graph.create_node("Dequantize",
                                                   b_dequantize_name,
                                                   [b_constant_name,
                                                    b_constant_min_name,
                                                    b_constant_max_name])
    quantize_graph.set_attr_dtype(b_dequantize_node, "T", tf.uint8)
    graph_def.node.extend([b_dequantize_node])
    b_quantize_node = quantize_graph.create_node("QuantizeV2",
                                                 b_quantize_name,
                                                 [b_dequantize_name,
                                                  b_dequantize_name + ":1",
                                                  b_dequantize_name + ":2"])
    quantize_graph.set_attr_dtype(b_quantize_node, "T", tf.uint8)
    graph_def.node.extend([b_quantize_node])
    mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name,
                                              [a_quantize_name,
                                               b_quantize_name,
                                               a_quantize_name + ":1",
                                               a_quantize_name + ":2",
                                               b_quantize_name + ":1",
                                               b_quantize_name + ":2"])
    quantize_graph.set_attr_dtype(mat_mul_node, "T1", tf.uint8)
    quantize_graph.set_attr_dtype(mat_mul_node, "T2", tf.int32)
    graph_def.node.extend([mat_mul_node])

    expected_output = tf.GraphDef()
    a_constant = quantize_graph.create_constant_node(a_constant_name,
                                                     value=(0,),
                                                     dtype=tf.quint8,
                                                     shape=[])
    expected_output.node.extend([a_constant])
    a_constant_min = quantize_graph.create_constant_node(a_constant_min_name,
                                                         value=2,
                                                         dtype=tf.float32,
                                                         shape=[])
    expected_output.node.extend([a_constant_min])
    a_constant_max = quantize_graph.create_constant_node(a_constant_max_name,
                                                         value=2,
                                                         dtype=tf.float32,
                                                         shape=[])
    expected_output.node.extend([a_constant_max])
    b_constant = quantize_graph.create_constant_node(b_constant_name,
                                                     value=(0,),
                                                     dtype=tf.quint8,
                                                     shape=[])
    expected_output.node.extend([b_constant])
    b_constant_min = quantize_graph.create_constant_node(b_constant_min_name,
                                                         value=3,
                                                         dtype=tf.float32,
                                                         shape=[])
    expected_output.node.extend([b_constant_min])
    b_constant_max = quantize_graph.create_constant_node(b_constant_max_name,
                                                         value=3,
                                                         dtype=tf.float32,
                                                         shape=[])
    expected_output.node.extend([b_constant_max])
    mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name,
                                              [a_constant_name,
                                               b_constant_name,
                                               a_constant_min_name,
                                               a_constant_max_name,
                                               b_constant_min_name,
                                               b_constant_max_name])
    quantize_graph.set_attr_dtype(mat_mul_node, "T1", tf.uint8)
    quantize_graph.set_attr_dtype(mat_mul_node, "T2", tf.int32)
    expected_output.node.extend([mat_mul_node])

    rewriter = quantize_graph.GraphRewriter(graph_def, [mat_mul_name])
    output = rewriter.remove_redundant_quantization(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
    self.assertProtoEquals(expected_output, stripped_output)