def test_quantized_input_range_errors(self):
  with self.assertRaises(ValueError):
    # Invalid mode.
    quantize_graph.GraphRewriter(graph_pb2.GraphDef(), "weights_rounded",
                                 [0, 1])
  with self.assertRaises(ValueError):
    # Invalid range.
    quantize_graph.GraphRewriter(graph_pb2.GraphDef(), "eightbit", [0, -1])

def _RunTestsForQuantizedInputRange(self, float_graph_def, input_map,
                                    output_names, input_range):
  if sys.version_info[0] == 3:
    # uint8->quint8 conversion for numpy is not working currently.
    return

  quantized_input_map = {}
  for k, v in input_map.items():
    arr = [
        int(
            round((n - input_range[0]) * 255 /
                  (input_range[1] - input_range[0]))) for n in v.flat
    ]
    arr = np.array(arr, np.uint8)
    arr = arr.reshape(v.shape)
    arr = arr.astype(dtypes.quint8.as_numpy_dtype)
    quantized_input_map[k] = arr
  output_tensors = [output_name + ":0" for output_name in output_names]
  float_results = run_graph_def(float_graph_def, input_map, output_tensors)

  # Quantize treating the input as quantized in range <input_range>.
  rewriter = quantize_graph.GraphRewriter(float_graph_def, "eightbit",
                                          input_range)
  graph_def = rewriter.rewrite(output_names)
  results = run_graph_def(graph_def, quantized_input_map, output_tensors)
  for expected, result in zip(float_results, results):
    assert are_tensors_near(expected, result, .5)
  ops = [node.op for node in graph_def.node]
  self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
  self.assertEqual(len(output_names), ops.count("Dequantize"))

  # Quantize without treating input as quantized.
  rewriter = quantize_graph.GraphRewriter(
      float_graph_def, "eightbit", quantized_input_range=None)
  graph_def = rewriter.rewrite(output_names)
  results = run_graph_def(graph_def, input_map, output_tensors)
  for expected, result in zip(float_results, results):
    assert are_tensors_near(expected, result, .5)
  ops = [node.op for node in graph_def.node]
  self.assertEqual(len(input_map),
                   ops.count("QuantizeV2") + ops.count("Quantize"))
  self.assertEqual(len(output_names), ops.count("Dequantize"))

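# The following is a minimal standalone sketch (not part of the original test
# suite) of the linear float -> uint8 mapping that
# _RunTestsForQuantizedInputRange applies element by element above: values in
# [min, max] map linearly onto [0, 255]. The helper name _quantize_array is
# hypothetical.
def _quantize_array(values, input_range):
  """Maps a float numpy array onto [0, 255] uint8 over a linear scale."""
  min_value, max_value = input_range
  scale = 255.0 / (max_value - min_value)
  # np.round approximates the int(round(...)) loop in the test above
  # (np.round uses round-half-to-even rather than round-half-away-from-zero).
  return np.round((values - min_value) * scale).astype(np.uint8)
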
def test_bias_add_w_fake_quant_w_min_max_vars(self):
  input_node = quantize_graph.create_constant_node(
      "input",
      value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
      dtype=dtypes.float32,
      shape=[1, 1, 2, 5])
  offset_node = quantize_graph.create_constant_node(
      "offset", value=[1, 2, 3, 4, 5], dtype=dtypes.float32, shape=[5])
  bias_add_node = quantize_graph.create_node(
      "BiasAdd", "bias_add", [input_node.name, offset_node.name])
  quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32)

  min_node = quantize_graph.create_constant_node(
      "min_bias_add", value=-.5, dtype=dtypes.float32, shape=[])
  max_node = quantize_graph.create_constant_node(
      "max_bias_add", value=15.5, dtype=dtypes.float32, shape=[])
  fake_quant_node = quantize_graph.create_node(
      "FakeQuantWithMinMaxVars", "fake_quant",
      [bias_add_node.name, min_node.name, max_node.name])

  float_graph_def = graph_pb2.GraphDef()
  float_graph_def.node.extend([
      input_node, offset_node, bias_add_node, min_node, max_node,
      fake_quant_node
  ])
  graph_test(float_graph_def, {}, [fake_quant_node.name], log_graph=True)

  # Verify there is only one Quantize and one Requantize op.
  # Pass in fallback_quantization_range, although it will have no effect
  # because the FakeQuantWithMinMaxVars are used instead.
  eightbit_rewriter = quantize_graph.GraphRewriter(
      float_graph_def,
      "eightbit",
      quantized_input_range=None,
      fallback_quantization_range=[-100, 100])
  eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
  ops = [node.op for node in eightbit_graph_def.node]
  node_names = [node.name for node in eightbit_graph_def.node]
  # No quantize since all inputs are const and can be quantized up-front.
  self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
  # One dequantize at the end.
  self.assertEqual(1, ops.count("Dequantize"))
  # The fallback constants are not in the graph.
  self.assertEqual(0, node_names.count("fallback_quantization_min_value"))
  self.assertEqual(0, node_names.count("fallback_quantization_max_value"))

def test_reshape(self):
  """Tests that MatMul->Reshape->MatMul avoids extra quantize/dequantize."""

  def make_matmul(name, a, b):
    n = quantize_graph.create_node("MatMul", name, [a.name, b.name])
    quantize_graph.set_attr_dtype(n, "T", dtypes.float32)
    quantize_graph.set_attr_bool(n, "transpose_a", False)
    quantize_graph.set_attr_bool(n, "transpose_b", False)
    return n

  # matmul_1 = input * weight_1
  input_node = quantize_graph.create_constant_node(
      "input", value=[0, 1, 2, 3], dtype=dtypes.float32, shape=[4, 1])
  weight_1_node = quantize_graph.create_constant_node(
      "weight_1",
      value=[.5, .6, .7, .8, .9],
      dtype=dtypes.float32,
      shape=[1, 5])
  matmul_1_node = make_matmul("matmul_1", input_node, weight_1_node)

  # Reshape 4x5 to 10x2.
  new_shape_node = quantize_graph.create_constant_node(
      "new_shape_node", value=[10, 2], dtype=dtypes.int32, shape=[2])
  reshape_node = quantize_graph.create_node(
      "Reshape", "reshape", [matmul_1_node.name, new_shape_node.name])
  quantize_graph.set_attr_dtype(reshape_node, "T", dtypes.float32)

  # matmul_2 = reshape * weight_2
  weight_2_node = quantize_graph.create_constant_node(
      "weight_2", value=[1.5, 2.5], dtype=dtypes.float32, shape=[2, 1])
  matmul_2_node = make_matmul("matmul_2", reshape_node, weight_2_node)

  g = graph_pb2.GraphDef()
  g.node.extend([
      input_node, weight_1_node, matmul_1_node, new_shape_node, reshape_node,
      weight_2_node, matmul_2_node
  ])

  # Test the graph.
  graph_test(g, {}, ["matmul_2"])

  # Verify there is only one Quantize and one Requantize op.
  eightbit_rewriter = quantize_graph.GraphRewriter(
      g, "eightbit", quantized_input_range=None)
  eightbit_graph_def = eightbit_rewriter.rewrite(["matmul_2"])
  ops = [node.op for node in eightbit_graph_def.node]
  # No quantize since all inputs are const and can be quantized up-front.
  self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
  self.assertEqual(1, ops.count("QuantizedReshape"))
  # One dequantize at the end.
  self.assertEqual(1, ops.count("Dequantize"))

def graph_test(float_graph_def, input_map, output_names, log_graph=False):
  """Runs the float graph through the rewriter and tests the results."""
  float_results = run_graph_def(
      float_graph_def, input_map,
      [output_name + ":0" for output_name in output_names])
  # TODO(petewarden): round test is currently failing because there is no
  # RoundToSteps op available.
  # round_rewriter = quantize_graph.GraphRewriter(float_graph_def, "round")
  # round_graph_def = round_rewriter.rewrite(output_name)
  # round_results = run_graph_def(round_graph_def, input_map,
  #                               [output_name + ":0"])
  # assert are_tensors_near(expected, round_results[0], 1.0)
  #
  # TODO(petewarden): Add test for "quantize" mode.

  eightbit_rewriter = quantize_graph.GraphRewriter(
      float_graph_def, "eightbit", quantized_input_range=None)
  eightbit_graph_def = eightbit_rewriter.rewrite(output_names)
  eightbit_results = run_graph_def(
      eightbit_graph_def, input_map,
      [output_name + ":0" for output_name in output_names])
  for expected, result in zip(float_results, eightbit_results):
    assert are_tensors_near(expected, result, 1.0)
  if log_graph:
    tf_logging.info("8bit:\n%s", str(eightbit_graph_def))

  # Test the weights_rounded mode. This uses the default bit_depth.
  weights_rounded_rewriter = quantize_graph.GraphRewriter(
      float_graph_def, "weights_rounded", quantized_input_range=None)
  weights_rounded_graph_def = weights_rounded_rewriter.rewrite(output_names)
  weights_rounded_results = run_graph_def(
      weights_rounded_graph_def, input_map,
      [output_name + ":0" for output_name in output_names])
  for expected, result in zip(float_results, weights_rounded_results):
    assert are_tensors_near(expected, result, 1.0)

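# A plausible sketch, for reference, of the are_tensors_near helper that
# graph_test relies on; its signature follows from the calls above, but the
# real helper is defined elsewhere in this file and may differ in detail. It
# treats two tensors as near when every element-wise absolute difference is
# within the tolerance.
def _are_tensors_near_sketch(a, b, tolerance):
  flat_a = a.flatten()
  flat_b = b.flatten()
  if len(flat_a) != len(flat_b):
    return False
  # Cast to float64 so the subtraction cannot wrap for integer dtypes.
  diff = np.abs(flat_a.astype(np.float64) - flat_b.astype(np.float64))
  return np.max(diff) <= tolerance
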
def test_bias_add_w_fallback_min_max_vars(self):
  input_node = quantize_graph.create_constant_node(
      "input",
      value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
      dtype=dtypes.float32,
      shape=[1, 1, 2, 5])
  offset_node = quantize_graph.create_constant_node(
      "offset", value=[1, 2, 3, 4, 5], dtype=dtypes.float32, shape=[5])
  bias_add_node = quantize_graph.create_node(
      "BiasAdd", "bias_add", [input_node.name, offset_node.name])
  quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32)

  float_graph_def = graph_pb2.GraphDef()
  float_graph_def.node.extend([input_node, offset_node, bias_add_node])
  graph_test(float_graph_def, {}, [bias_add_node.name], log_graph=True)

  # Verify there is only one Quantize, one Requantize op, and no
  # RequantizationRange op.
  eightbit_rewriter = quantize_graph.GraphRewriter(
      float_graph_def,
      "eightbit",
      quantized_input_range=None,
      fallback_quantization_range=[-.5, 15.5])
  eightbit_graph_def = eightbit_rewriter.rewrite([bias_add_node.name])
  ops = [node.op for node in eightbit_graph_def.node]
  node_names = [node.name for node in eightbit_graph_def.node]
  # No quantize since all inputs are const and can be quantized up-front.
  self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
  # One dequantize at the end.
  self.assertEqual(1, ops.count("Dequantize"))
  # No RequantizationRange.
  self.assertEqual(0, ops.count("RequantizationRange"))
  # The fallback constants are in the graph.
  self.assertEqual(1, node_names.count("fallback_quantization_min_value"))
  self.assertEqual(1, node_names.count("fallback_quantization_max_value"))

def test_relu_w_fake_quant_w_min_max_vars(self):
  input_node = quantize_graph.create_constant_node(
      "input",
      value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
      dtype=dtypes.float32,
      shape=[1, 2, 6, 1])
  relu_node = quantize_graph.create_node("Relu", "relu", [input_node.name])
  quantize_graph.set_attr_dtype(relu_node, "T", dtypes.float32)

  min_node = quantize_graph.create_constant_node(
      "min_bias_add", value=0, dtype=dtypes.float32, shape=[])
  max_node = quantize_graph.create_constant_node(
      "max_bias_add", value=12, dtype=dtypes.float32, shape=[])
  fake_quant_node = quantize_graph.create_node(
      "FakeQuantWithMinMaxVars", "fake_quant",
      [relu_node.name, min_node.name, max_node.name])

  float_graph_def = graph_pb2.GraphDef()
  float_graph_def.node.extend(
      [input_node, relu_node, min_node, max_node, fake_quant_node])
  graph_test(float_graph_def, {}, [fake_quant_node.name], log_graph=True)

  # Verify there is only one Quantize and one Requantize op.
  eightbit_rewriter = quantize_graph.GraphRewriter(
      float_graph_def, "eightbit", quantized_input_range=None)
  eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
  ops = [node.op for node in eightbit_graph_def.node]
  # No quantize since all inputs are const and can be quantized up-front.
  self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
  # One dequantize at the end.
  self.assertEqual(1, ops.count("Dequantize"))

def test_concat(self):
  shape_constant_name = "shape_constant"
  a_constant_name = "a_constant"
  b_constant_name = "b_constant"
  concat_name = "concat"

  float_graph_def = graph_pb2.GraphDef()
  shape_constant = quantize_graph.create_constant_node(
      shape_constant_name, value=0, dtype=dtypes.int32, shape=[])
  float_graph_def.node.extend([shape_constant])
  a_constant = quantize_graph.create_constant_node(
      a_constant_name,
      value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
      dtype=dtypes.float32,
      shape=[2, 2, 3])
  float_graph_def.node.extend([a_constant])
  b_constant = quantize_graph.create_constant_node(
      b_constant_name,
      value=[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
      dtype=dtypes.float32,
      shape=[2, 2, 3])
  float_graph_def.node.extend([b_constant])
  concat_node = quantize_graph.create_node(
      "Concat", concat_name,
      [shape_constant_name, a_constant_name, b_constant_name])
  quantize_graph.set_attr_int(concat_node, "N", 2)
  quantize_graph.set_attr_dtype(concat_node, "T", dtypes.float32)
  float_graph_def.node.extend([concat_node])

  graph_test(float_graph_def, {}, [concat_name])

  # Verify the concat is quantized.
  eightbit_rewriter = quantize_graph.GraphRewriter(
      float_graph_def, "eightbit", quantized_input_range=None)
  eightbit_graph_def = eightbit_rewriter.rewrite([concat_name])
  ops = [node.op for node in eightbit_graph_def.node]
  self.assertEqual(1, ops.count("QuantizedConcat"))

def test_remove_redundant_quantization(self):
  a_constant_name = "a_constant"
  a_constant_min_name = "a_constant_min"
  a_constant_max_name = "a_constant_max"
  a_dequantize_name = "a_dequantize"
  a_quantize_name = "a_quantize"
  b_constant_name = "b_constant"
  b_constant_min_name = "b_constant_min"
  b_constant_max_name = "b_constant_max"
  b_dequantize_name = "b_dequantize"
  b_quantize_name = "b_quantize"
  mat_mul_name = "mat_mul"

  # Build a graph in which each quantized constant is dequantized and then
  # immediately requantized before feeding the QuantizedMatMul.
  graph_def = graph_pb2.GraphDef()
  a_constant = quantize_graph.create_constant_node(
      a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
  graph_def.node.extend([a_constant])
  a_constant_min = quantize_graph.create_constant_node(
      a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
  graph_def.node.extend([a_constant_min])
  a_constant_max = quantize_graph.create_constant_node(
      a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
  graph_def.node.extend([a_constant_max])
  a_dequantize_node = quantize_graph.create_node(
      "Dequantize", a_dequantize_name,
      [a_constant_name, a_constant_min_name, a_constant_max_name])
  quantize_graph.set_attr_dtype(a_dequantize_node, "T", dtypes.uint8)
  graph_def.node.extend([a_dequantize_node])
  a_quantize_node = quantize_graph.create_node(
      "QuantizeV2", a_quantize_name, [
          a_dequantize_name, a_dequantize_name + ":1",
          a_dequantize_name + ":2"
      ])
  quantize_graph.set_attr_dtype(a_quantize_node, "T", dtypes.uint8)
  graph_def.node.extend([a_quantize_node])
  b_constant = quantize_graph.create_constant_node(
      b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
  graph_def.node.extend([b_constant])
  b_constant_min = quantize_graph.create_constant_node(
      b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
  graph_def.node.extend([b_constant_min])
  b_constant_max = quantize_graph.create_constant_node(
      b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
  graph_def.node.extend([b_constant_max])
  b_dequantize_node = quantize_graph.create_node(
      "Dequantize", b_dequantize_name,
      [b_constant_name, b_constant_min_name, b_constant_max_name])
  quantize_graph.set_attr_dtype(b_dequantize_node, "T", dtypes.uint8)
  graph_def.node.extend([b_dequantize_node])
  b_quantize_node = quantize_graph.create_node(
      "QuantizeV2", b_quantize_name, [
          b_dequantize_name, b_dequantize_name + ":1",
          b_dequantize_name + ":2"
      ])
  quantize_graph.set_attr_dtype(b_quantize_node, "T", dtypes.uint8)
  graph_def.node.extend([b_quantize_node])
  mat_mul_node = quantize_graph.create_node(
      "QuantizedMatMul", mat_mul_name, [
          a_quantize_name, b_quantize_name, a_quantize_name + ":1",
          a_quantize_name + ":2", b_quantize_name + ":1",
          b_quantize_name + ":2"
      ])
  quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
  quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
  graph_def.node.extend([mat_mul_node])

  # After the rewrite, the Dequantize/QuantizeV2 pairs should be gone and
  # the constants should feed the QuantizedMatMul directly.
  expected_output = graph_pb2.GraphDef()
  a_constant = quantize_graph.create_constant_node(
      a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
  expected_output.node.extend([a_constant])
  a_constant_min = quantize_graph.create_constant_node(
      a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
  expected_output.node.extend([a_constant_min])
  a_constant_max = quantize_graph.create_constant_node(
      a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
  expected_output.node.extend([a_constant_max])
  b_constant = quantize_graph.create_constant_node(
      b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
  expected_output.node.extend([b_constant])
  b_constant_min = quantize_graph.create_constant_node(
      b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
  expected_output.node.extend([b_constant_min])
  b_constant_max = quantize_graph.create_constant_node(
      b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
  expected_output.node.extend([b_constant_max])
  mat_mul_node = quantize_graph.create_node(
      "QuantizedMatMul", mat_mul_name, [
          a_constant_name, b_constant_name, a_constant_min_name,
          a_constant_max_name, b_constant_min_name, b_constant_max_name
      ])
  quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
  quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
  expected_output.node.extend([mat_mul_node])
  expected_output.versions.CopyFrom(graph_def.versions)
  expected_output.library.CopyFrom(graph_def.library)

  rewriter = quantize_graph.GraphRewriter(
      graph_def, [mat_mul_name], quantized_input_range=None)
  output = rewriter.remove_redundant_quantization(graph_def)
  stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
  self.assertProtoEquals(expected_output, stripped_output)

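# A hedged sketch of the run_graph_def helper used throughout these tests,
# assuming it imports the GraphDef into a fresh graph and evaluates the
# requested output tensors in a Session. The real helper is defined elsewhere
# in this file; the ops_lib, importer, and session names are assumed to come
# from this file's module-level imports.
def _run_graph_def_sketch(graph_def, input_map, outputs):
  graph = ops_lib.Graph()
  with graph.as_default():
    # Import without remapping; feeds are supplied at run time instead.
    importer.import_graph_def(graph_def, input_map={}, name="")
  with session.Session(graph=graph) as sess:
    results = sess.run(outputs, feed_dict=input_map)
  return results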