def test_graph(float_graph_def, input_map, output_names):
    """Compares rewritten graphs against the original float graph.

    The float graph is run once to obtain reference outputs. It is then
    rewritten in "eightbit" mode and in "weights_rounded" mode, and each
    rewritten graph is run on the same inputs and checked output-by-output
    against the reference with are_tensors_near(..., 1.0).

    Args:
      float_graph_def: Graph handed to run_graph_def and to
        quantize_graph.GraphRewriter.
      input_map: Forwarded unchanged to every run_graph_def call.
      output_names: Node names to rewrite for; ":0" is appended to each one
        to build the names fetched from run_graph_def.

    Raises:
      AssertionError: If are_tensors_near reports a mismatch for any output.
    """
    reference_outputs = run_graph_def(
        float_graph_def, input_map, [name + ":0" for name in output_names])

    # TODO(petewarden): the "round" mode is not exercised because there is no
    # RoundToSteps op available.
    # TODO(petewarden): Add test for "quantize" mode.

    # "weights_rounded" is run with the rewriter's default bit_depth.
    for mode in ("eightbit", "weights_rounded"):
        rewriter = quantize_graph.GraphRewriter(float_graph_def, mode)
        rewritten_graph_def = rewriter.rewrite(output_names)
        rewritten_outputs = run_graph_def(
            rewritten_graph_def, input_map,
            [name + ":0" for name in output_names])
        for expected, actual in zip(reference_outputs, rewritten_outputs):
            assert are_tensors_near(expected, actual, 1.0)
Example #2
0
  def test_remove_unneeded_nodes(self):
    """Checks that remove_unneeded_nodes bypasses CheckNumerics/Identity.

    The input graph feeds each of two float constants through an Identity
    node that also lists a CheckNumerics node as a "^"-prefixed input, and
    adds the two Identity outputs. The expected graph adds the two constants
    directly.
    """
    a_constant_name, b_constant_name = "a_constant", "b_constant"
    a_check_name, b_check_name = "a_check", "b_check"
    a_identity_name, b_identity_name = "a_identity", "b_identity"
    add_name = "add"
    branches = (
        (a_constant_name, a_check_name, a_identity_name),
        (b_constant_name, b_check_name, b_identity_name),
    )

    # Input graph: constant -> CheckNumerics / Identity per branch, then Add.
    graph_def = tf.GraphDef()
    for constant_name, check_name, identity_name in branches:
      constant_node = quantize_graph.create_constant_node(
          constant_name, value=1, dtype=tf.float32, shape=[])
      graph_def.node.extend([constant_node])
      check_node = quantize_graph.create_node(
          "CheckNumerics", check_name, [constant_name])
      graph_def.node.extend([check_node])
      identity_node = quantize_graph.create_node(
          "Identity", identity_name, [constant_name, "^" + check_name])
      graph_def.node.extend([identity_node])
    input_add_node = quantize_graph.create_node(
        "Add", add_name, [a_identity_name, b_identity_name])
    quantize_graph.set_attr_dtype(input_add_node, "T", tf.float32)
    graph_def.node.extend([input_add_node])

    # Expected graph: the Add node consumes the constants directly.
    expected_output = tf.GraphDef()
    for constant_name in (a_constant_name, b_constant_name):
      constant_node = quantize_graph.create_constant_node(
          constant_name, value=1, dtype=tf.float32, shape=[])
      expected_output.node.extend([constant_node])
    expected_add_node = quantize_graph.create_node(
        "Add", add_name, [a_constant_name, b_constant_name])
    quantize_graph.set_attr_dtype(expected_add_node, "T", tf.float32)
    expected_output.node.extend([expected_add_node])

    # NOTE(review): test_graph passes a mode string such as "eightbit" as
    # GraphRewriter's second argument; here a list of node names is passed.
    # Kept as in the original -- confirm this is intended.
    rewriter = quantize_graph.GraphRewriter(graph_def, [add_name])
    output = rewriter.remove_unneeded_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
    def test_remove_redundant_quantization(self):
        """Checks that Dequantize -> QuantizeV2 round trips are bypassed.

        The input graph sends each quint8 constant, together with its min and
        max constants, through a Dequantize node and then a QuantizeV2 node
        before it reaches QuantizedMatMul. The expected graph has
        QuantizedMatMul read the constants and their min/max directly.
        """
        a_constant_name = "a_constant"
        a_constant_min_name = "a_constant_min"
        a_constant_max_name = "a_constant_max"
        a_dequantize_name = "a_dequantize"
        a_quantize_name = "a_quantize"
        b_constant_name = "b_constant"
        b_constant_min_name = "b_constant_min"
        b_constant_max_name = "b_constant_max"
        b_dequantize_name = "b_dequantize"
        b_quantize_name = "b_quantize"
        mat_mul_name = "mat_mul"
        # One row per operand: constant, min, max, dequantize and quantize
        # node names, then the value stored in both the min and max constants.
        operands = (
            (a_constant_name, a_constant_min_name, a_constant_max_name,
             a_dequantize_name, a_quantize_name, 2),
            (b_constant_name, b_constant_min_name, b_constant_max_name,
             b_dequantize_name, b_quantize_name, 3),
        )

        def add_constant_with_range(target, constant_name, min_name,
                                    max_name, range_value):
            """Appends a quint8 constant and its min/max constants to target."""
            constant_node = quantize_graph.create_constant_node(
                constant_name, value=(0, ), dtype=tf.quint8, shape=[])
            target.node.extend([constant_node])
            min_node = quantize_graph.create_constant_node(
                min_name, value=range_value, dtype=tf.float32, shape=[])
            target.node.extend([min_node])
            max_node = quantize_graph.create_constant_node(
                max_name, value=range_value, dtype=tf.float32, shape=[])
            target.node.extend([max_node])

        # Input graph: constant -> Dequantize -> QuantizeV2 per operand.
        graph_def = tf.GraphDef()
        for (constant_name, min_name, max_name, dequantize_name,
             quantize_name, range_value) in operands:
            add_constant_with_range(graph_def, constant_name, min_name,
                                    max_name, range_value)
            dequantize_node = quantize_graph.create_node(
                "Dequantize", dequantize_name,
                [constant_name, min_name, max_name])
            quantize_graph.set_attr_dtype(dequantize_node, "T", tf.uint8)
            graph_def.node.extend([dequantize_node])
            quantize_node = quantize_graph.create_node(
                "QuantizeV2", quantize_name,
                [dequantize_name, dequantize_name + ":1",
                 dequantize_name + ":2"])
            quantize_graph.set_attr_dtype(quantize_node, "T", tf.uint8)
            graph_def.node.extend([quantize_node])
        input_mat_mul_node = quantize_graph.create_node(
            "QuantizedMatMul", mat_mul_name,
            [a_quantize_name, b_quantize_name,
             a_quantize_name + ":1", a_quantize_name + ":2",
             b_quantize_name + ":1", b_quantize_name + ":2"])
        quantize_graph.set_attr_dtype(input_mat_mul_node, "T1", tf.uint8)
        quantize_graph.set_attr_dtype(input_mat_mul_node, "T2", tf.int32)
        graph_def.node.extend([input_mat_mul_node])

        # Expected graph: QuantizedMatMul wired straight to the constants.
        expected_output = tf.GraphDef()
        for constant_name, min_name, max_name, _, _, range_value in operands:
            add_constant_with_range(expected_output, constant_name, min_name,
                                    max_name, range_value)
        expected_mat_mul_node = quantize_graph.create_node(
            "QuantizedMatMul", mat_mul_name,
            [a_constant_name, b_constant_name,
             a_constant_min_name, a_constant_max_name,
             b_constant_min_name, b_constant_max_name])
        quantize_graph.set_attr_dtype(expected_mat_mul_node, "T1", tf.uint8)
        quantize_graph.set_attr_dtype(expected_mat_mul_node, "T2", tf.int32)
        expected_output.node.extend([expected_mat_mul_node])

        rewriter = quantize_graph.GraphRewriter(graph_def, [mat_mul_name])
        output = rewriter.remove_redundant_quantization(graph_def)
        stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
        self.assertProtoEquals(expected_output, stripped_output)