def test_graph(float_graph_def, input_map, output_names):
    """Compare rewritten-graph outputs against the original float graph.

    The float graph is executed once to obtain reference tensors. The graph is
    then rewritten with ``quantize_graph.GraphRewriter`` in "eightbit" mode and
    in "weights_rounded" mode; each rewritten graph is executed with the same
    inputs and every output is checked against its reference tensor with
    ``are_tensors_near(expected, result, 1.0)``.

    Args:
      float_graph_def: Graph definition handed to ``run_graph_def`` and to the
        rewriter.
      input_map: Inputs forwarded unchanged to every ``run_graph_def`` call.
      output_names: Output node names; ":0" is appended to each one to form
        the names fetched from the graph.
    """

    def fetch_names():
        # Rebuilt for every run so each call receives its own fresh list.
        return [name + ":0" for name in output_names]

    float_results = run_graph_def(float_graph_def, input_map, fetch_names())

    # TODO(petewarden): round test is currently failing because there is no
    # RoundToSteps op available.
    # round_rewriter = quantize_graph.GraphRewriter(float_graph_def, "round")
    # round_graph_def = round_rewriter.rewrite(output_name)
    # round_results = run_graph_def(round_graph_def, input_map,
    #                               [output_name + ":0"])
    # assert are_tensors_near(expected, round_results[0], 1.0)
    #
    # TODO(petewarden): Add test for "quantize" mode.

    # "weights_rounded" is exercised with the default bit_depth.
    for mode in ("eightbit", "weights_rounded"):
        rewriter = quantize_graph.GraphRewriter(float_graph_def, mode)
        rewritten_graph_def = rewriter.rewrite(output_names)
        rewritten_results = run_graph_def(
            rewritten_graph_def, input_map, fetch_names())
        for expected, result in zip(float_results, rewritten_results):
            assert are_tensors_near(expected, result, 1.0)
def test_remove_unneeded_nodes(self):
    """Check remove_unneeded_nodes() on a graph with check/identity nodes.

    Input graph: two scalar float32 constants (value 1). Each constant feeds a
    CheckNumerics node and an Identity node, where the Identity also carries a
    control dependency ("^name") on the CheckNumerics node. An Add node
    consumes the two Identity outputs.

    Expected graph (after extracting the sub-graph that feeds the Add node):
    only the two constants and an Add node wired directly to them.
    """
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"

    def float_one(name):
        # Scalar float32 constant with value 1, as used by both graphs.
        return quantize_graph.create_constant_node(
            name, value=1, dtype=tf.float32, shape=[])

    def float_add(inputs):
        # Add node with its "T" attribute set to float32.
        node = quantize_graph.create_node("Add", add_name, inputs)
        quantize_graph.set_attr_dtype(node, "T", tf.float32)
        return node

    # Input graph. Node order is constant, check, identity for "a", then the
    # same for "b", then the Add node.
    graph_def = tf.GraphDef()
    operands = (
        (a_constant_name, a_check_name, a_identity_name),
        (b_constant_name, b_check_name, b_identity_name),
    )
    for constant_name, check_name, identity_name in operands:
        graph_def.node.extend([
            float_one(constant_name),
            quantize_graph.create_node(
                "CheckNumerics", check_name, [constant_name]),
            quantize_graph.create_node(
                "Identity", identity_name,
                [constant_name, "^" + check_name]),
        ])
    graph_def.node.extend([float_add([a_identity_name, b_identity_name])])

    # Expected graph: the Add node reads the constants directly.
    expected_output = tf.GraphDef()
    expected_output.node.extend([
        float_one(a_constant_name),
        float_one(b_constant_name),
        float_add([a_constant_name, b_constant_name]),
    ])

    # The second GraphRewriter argument is the rewrite mode string, not a list
    # of output node names, so pass a real mode. The graph to process is given
    # explicitly to remove_unneeded_nodes() below.
    rewriter = quantize_graph.GraphRewriter(graph_def, "eightbit")
    output = rewriter.remove_unneeded_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
def test_remove_redundant_quantization(self):
    """Check remove_redundant_quantization() on Dequantize->QuantizeV2 pairs.

    Input graph: for each of two operands ("a" and "b") a quint8 constant plus
    float32 min/max constants feed a Dequantize node, whose three outputs feed
    a QuantizeV2 node. A QuantizedMatMul consumes the outputs of both
    QuantizeV2 nodes.

    Expected graph (after extracting the sub-graph that feeds the matmul):
    only the six constants and a QuantizedMatMul wired directly to them.
    """
    a_constant_name = "a_constant"
    a_constant_min_name = "a_constant_min"
    a_constant_max_name = "a_constant_max"
    a_dequantize_name = "a_dequantize"
    a_quantize_name = "a_quantize"
    b_constant_name = "b_constant"
    b_constant_min_name = "b_constant_min"
    b_constant_max_name = "b_constant_max"
    b_dequantize_name = "b_dequantize"
    b_quantize_name = "b_quantize"
    mat_mul_name = "mat_mul"

    def constant_nodes(constant_name, min_name, max_name, range_value):
        # The quint8 constant followed by its float32 min and max constants;
        # min and max share the same value in this test.
        return [
            quantize_graph.create_constant_node(
                constant_name, value=(0,), dtype=tf.quint8, shape=[]),
            quantize_graph.create_constant_node(
                min_name, value=range_value, dtype=tf.float32, shape=[]),
            quantize_graph.create_constant_node(
                max_name, value=range_value, dtype=tf.float32, shape=[]),
        ]

    def quantized_mat_mul(inputs):
        # QuantizedMatMul with T1=uint8 and T2=int32, as in both graphs.
        node = quantize_graph.create_node(
            "QuantizedMatMul", mat_mul_name, inputs)
        quantize_graph.set_attr_dtype(node, "T1", tf.uint8)
        quantize_graph.set_attr_dtype(node, "T2", tf.int32)
        return node

    a_constants = (a_constant_name, a_constant_min_name, a_constant_max_name)
    b_constants = (b_constant_name, b_constant_min_name, b_constant_max_name)

    # Input graph. For each operand the node order is: constant, min, max,
    # Dequantize, QuantizeV2. The matmul node comes last.
    graph_def = tf.GraphDef()
    operands = (
        (a_constants, a_dequantize_name, a_quantize_name, 2),
        (b_constants, b_dequantize_name, b_quantize_name, 3),
    )
    for constants, dequantize_name, quantize_name, range_value in operands:
        constant_name, min_name, max_name = constants
        graph_def.node.extend(
            constant_nodes(constant_name, min_name, max_name, range_value))

        dequantize_node = quantize_graph.create_node(
            "Dequantize", dequantize_name,
            [constant_name, min_name, max_name])
        quantize_graph.set_attr_dtype(dequantize_node, "T", tf.uint8)

        quantize_node = quantize_graph.create_node(
            "QuantizeV2", quantize_name,
            [dequantize_name, dequantize_name + ":1", dequantize_name + ":2"])
        quantize_graph.set_attr_dtype(quantize_node, "T", tf.uint8)

        graph_def.node.extend([dequantize_node, quantize_node])

    graph_def.node.extend([
        quantized_mat_mul([
            a_quantize_name, b_quantize_name,
            a_quantize_name + ":1", a_quantize_name + ":2",
            b_quantize_name + ":1", b_quantize_name + ":2",
        ])
    ])

    # Expected graph: the matmul reads the constants directly.
    expected_output = tf.GraphDef()
    expected_output.node.extend(constant_nodes(
        a_constant_name, a_constant_min_name, a_constant_max_name, 2))
    expected_output.node.extend(constant_nodes(
        b_constant_name, b_constant_min_name, b_constant_max_name, 3))
    expected_output.node.extend([
        quantized_mat_mul([
            a_constant_name, b_constant_name,
            a_constant_min_name, a_constant_max_name,
            b_constant_min_name, b_constant_max_name,
        ])
    ])

    # The second GraphRewriter argument is the rewrite mode string, not a list
    # of output node names, so pass a real mode. The graph to process is given
    # explicitly to remove_redundant_quantization() below.
    rewriter = quantize_graph.GraphRewriter(graph_def, "eightbit")
    output = rewriter.remove_redundant_quantization(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
    self.assertProtoEquals(expected_output, stripped_output)