def test_concat(self):
    """Build a float graph with a single Concat node and pass it to test_graph.

    The Concat node takes a scalar int32 constant (value 0) followed by two
    float32 constants of shape [2, 2, 3] holding 1..12 and 13..24.
    """
    dim_name = "shape_constant"
    lhs_name = "a_constant"
    rhs_name = "b_constant"
    out_name = "concat"

    dim_node = quantize_graph.create_constant_node(
        dim_name, value=0, dtype=tf.int32, shape=[])
    lhs_node = quantize_graph.create_constant_node(
        lhs_name, value=list(range(1, 13)), dtype=tf.float32, shape=[2, 2, 3])
    rhs_node = quantize_graph.create_constant_node(
        rhs_name, value=list(range(13, 25)), dtype=tf.float32, shape=[2, 2, 3])

    out_node = quantize_graph.create_node(
        "Concat", out_name, [dim_name, lhs_name, rhs_name])
    quantize_graph.set_attr_int(out_node, "N", 2)
    quantize_graph.set_attr_dtype(out_node, "T", tf.float32)

    # Nodes are appended in the same order the graph is wired up.
    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([dim_node, lhs_node, rhs_node, out_node])
    test_graph(float_graph_def, {}, [out_name])
def test_remove_unneeded_nodes(self):
    """Compare remove_unneeded_nodes output against a hand-built graph.

    The input graph routes two scalar constants through CheckNumerics and
    Identity nodes into an Add. The expected graph has the Add reading the
    two constants directly.
    """
    add_name = "add"
    branches = (
        ("a_constant", "a_check", "a_identity"),
        ("b_constant", "b_check", "b_identity"),
    )

    graph_def = tf.GraphDef()
    expected_output = tf.GraphDef()
    for const_name, check_name, identity_name in branches:
        const_node = quantize_graph.create_constant_node(
            const_name, value=1, dtype=tf.float32, shape=[])
        check_node = quantize_graph.create_node(
            "CheckNumerics", check_name, [const_name])
        # The identity depends on the check only through a control edge.
        identity_node = quantize_graph.create_node(
            "Identity", identity_name, [const_name, "^" + check_name])
        graph_def.node.extend([const_node, check_node, identity_node])
        expected_output.node.extend([
            quantize_graph.create_constant_node(
                const_name, value=1, dtype=tf.float32, shape=[])
        ])

    source_add = quantize_graph.create_node(
        "Add", add_name, [branches[0][2], branches[1][2]])
    quantize_graph.set_attr_dtype(source_add, "T", tf.float32)
    graph_def.node.extend([source_add])

    expected_add = quantize_graph.create_node(
        "Add", add_name, [branches[0][0], branches[1][0]])
    quantize_graph.set_attr_dtype(expected_add, "T", tf.float32)
    expected_output.node.extend([expected_add])

    rewriter = quantize_graph.GraphRewriter(graph_def, [add_name])
    output = rewriter.remove_unneeded_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
def test_mat_mul(m, n, k, a, b):
    """Build a float MatMul graph from the given operands and run test_graph.

    Args:
      m: First dimension of the left operand's shape.
      n: Second dimension of the right operand's shape.
      k: Shared dimension (left is [m, k], right is [k, n]).
      a: Values for the left constant.
      b: Values for the right constant.
    """
    lhs_name = "a_constant"
    rhs_name = "b_constant"
    product_name = "mat_mul"

    lhs_node = quantize_graph.create_constant_node(
        lhs_name, value=a, dtype=tf.float32, shape=[m, k])
    rhs_node = quantize_graph.create_constant_node(
        rhs_name, value=b, dtype=tf.float32, shape=[k, n])

    product_node = quantize_graph.create_node(
        "MatMul", product_name, [lhs_name, rhs_name])
    quantize_graph.set_attr_dtype(product_node, "T", tf.float32)
    quantize_graph.set_attr_bool(product_node, "transpose_a", False)
    quantize_graph.set_attr_bool(product_node, "transpose_b", False)

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([lhs_node, rhs_node, product_node])
    test_graph(float_graph_def, {}, [product_name])
def test_conv(depth, image_width, image_height, image_batch_count, filter_size,
              filter_count, stride, padding, input_values, filter_values):
    """Build a float Conv2D graph from the given parameters and run test_graph.

    Args:
      depth: Last dimension of the input and third dimension of the filter.
      image_width: Third dimension of the input shape.
      image_height: Second dimension of the input shape.
      image_batch_count: First dimension of the input shape.
      filter_size: First and second dimensions of the filter shape.
      filter_count: Last dimension of the filter shape.
      stride: Value used for the two middle entries of "strides".
      padding: Value stored in the node's "padding" attribute.
      input_values: Values for the input constant.
      filter_values: Values for the filter constant.
    """
    image_name = "input_constant"
    kernel_name = "filter_constant"
    conv_name = "conv"

    image_shape = [image_batch_count, image_height, image_width, depth]
    kernel_shape = [filter_size, filter_size, depth, filter_count]

    image_node = quantize_graph.create_constant_node(
        image_name, value=input_values, dtype=tf.float32, shape=image_shape)
    kernel_node = quantize_graph.create_constant_node(
        kernel_name, value=filter_values, dtype=tf.float32, shape=kernel_shape)

    conv_node = quantize_graph.create_node(
        "Conv2D", conv_name, [image_name, kernel_name])
    quantize_graph.set_attr_dtype(conv_node, "T", tf.float32)
    quantize_graph.set_attr_int_list(conv_node, "strides",
                                     [1, stride, stride, 1])
    quantize_graph.set_attr_string(conv_node, "padding", padding)

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([image_node, kernel_node, conv_node])
    test_graph(float_graph_def, {}, [conv_name])
def test_concat(self):
    """Assemble a Concat-over-two-constants float graph and call test_graph."""
    axis_name, first_name, second_name = (
        "shape_constant", "a_constant", "b_constant")
    result_name = "concat"

    def float_block(name, values):
        # Both operands share dtype float32 and shape [2, 2, 3].
        return quantize_graph.create_constant_node(
            name, value=values, dtype=tf.float32, shape=[2, 2, 3])

    result_node = quantize_graph.create_node(
        "Concat", result_name, [axis_name, first_name, second_name])
    quantize_graph.set_attr_int(result_node, "N", 2)
    quantize_graph.set_attr_dtype(result_node, "T", tf.float32)

    ordered_nodes = (
        quantize_graph.create_constant_node(
            axis_name, value=0, dtype=tf.int32, shape=[]),
        float_block(first_name, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
        float_block(second_name,
                    [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]),
        result_node,
    )

    float_graph_def = tf.GraphDef()
    for node in ordered_nodes:
        float_graph_def.node.extend([node])
    test_graph(float_graph_def, {}, [result_name])
def test_multiple_outputs(self):
    """Build a Split -> Concat float graph and pass it to test_graph.

    A [2, 6] float32 constant goes through a "Split" node with num_split=2,
    and the Split's outputs ":0" and ":1" are fed to a "Concat" node.
    """
    data_name = "input_constant"
    split_dim_name = "split_constant"
    split_name = "split"
    concat_dim_name = "concat_constant"
    concat_name = "concat"

    data_node = quantize_graph.create_constant_node(
        data_name, value=list(range(1, 13)), dtype=tf.float32, shape=[2, 6])
    split_dim_node = quantize_graph.create_constant_node(
        split_dim_name, value=1, dtype=tf.int32, shape=[])
    splitter = quantize_graph.create_node(
        "Split", split_name, [split_dim_name, data_name])
    quantize_graph.set_attr_int(splitter, "num_split", 2)
    quantize_graph.set_attr_dtype(splitter, "T", tf.float32)

    concat_dim_node = quantize_graph.create_constant_node(
        concat_dim_name, value=1, dtype=tf.int32, shape=[])
    # Reference the Split node's two outputs by their port suffixes.
    split_outputs = [split_name + ":0", split_name + ":1"]
    joiner = quantize_graph.create_node(
        "Concat", concat_name, [concat_dim_name] + split_outputs)
    quantize_graph.set_attr_int(joiner, "N", 2)
    quantize_graph.set_attr_dtype(joiner, "T", tf.float32)

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend(
        [data_node, split_dim_node, splitter, concat_dim_node, joiner])
    test_graph(float_graph_def, {}, [concat_name])
def test_mat_mul(m, n, k, a, b):
    """Create an [m, k] x [k, n] float MatMul graph and hand it to test_graph.

    `a` and `b` supply the values of the left and right constants.
    """
    output_name = "mat_mul"
    # (node name, values, shape) for each operand, in graph order.
    operand_specs = (
        ("a_constant", a, [m, k]),
        ("b_constant", b, [k, n]),
    )

    float_graph_def = tf.GraphDef()
    for operand_name, operand_values, operand_shape in operand_specs:
        float_graph_def.node.extend([
            quantize_graph.create_constant_node(
                operand_name, value=operand_values, dtype=tf.float32,
                shape=operand_shape)
        ])

    mat_mul = quantize_graph.create_node(
        "MatMul", output_name, [spec[0] for spec in operand_specs])
    quantize_graph.set_attr_dtype(mat_mul, "T", tf.float32)
    for flag in ("transpose_a", "transpose_b"):
        quantize_graph.set_attr_bool(mat_mul, flag, False)
    float_graph_def.node.extend([mat_mul])

    test_graph(float_graph_def, {}, [output_name])
def test_conv(depth, image_width, image_height, image_batch_count, filter_size,
              filter_count, stride, padding, input_values, filter_values):
    """Create a float Conv2D graph and hand it to test_graph.

    The input constant has shape
    [image_batch_count, image_height, image_width, depth] and the filter
    constant has shape [filter_size, filter_size, depth, filter_count].
    `stride` fills the two middle "strides" entries; `padding` is stored
    unchanged in the "padding" attribute.
    """
    output_name = "conv"
    # (node name, values, shape) for each Conv2D input, in graph order.
    tensor_specs = (
        ("input_constant", input_values,
         [image_batch_count, image_height, image_width, depth]),
        ("filter_constant", filter_values,
         [filter_size, filter_size, depth, filter_count]),
    )

    float_graph_def = tf.GraphDef()
    for tensor_name, tensor_values, tensor_shape in tensor_specs:
        float_graph_def.node.extend([
            quantize_graph.create_constant_node(
                tensor_name, value=tensor_values, dtype=tf.float32,
                shape=tensor_shape)
        ])

    conv = quantize_graph.create_node(
        "Conv2D", output_name, [spec[0] for spec in tensor_specs])
    quantize_graph.set_attr_dtype(conv, "T", tf.float32)
    quantize_graph.set_attr_int_list(conv, "strides", [1, stride, stride, 1])
    quantize_graph.set_attr_string(conv, "padding", padding)
    float_graph_def.node.extend([conv])

    test_graph(float_graph_def, {}, [output_name])
def test_identity(self):
    """Build a constant -> Identity float graph and pass it to test_graph."""
    source_name = "input_constant"
    passthrough_name = "identity"

    source_node = quantize_graph.create_constant_node(
        source_name, value=list(range(1, 13)), dtype=tf.float32, shape=[2, 6])
    passthrough_node = quantize_graph.create_node(
        "Identity", passthrough_name, [source_name])
    quantize_graph.set_attr_dtype(passthrough_node, "T", tf.float32)

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([source_node, passthrough_node])
    test_graph(float_graph_def, {}, [passthrough_name])
def test_relu6(self):
    """Build a constant -> Relu6 float graph and pass it to test_graph."""
    source_name = "input_constant"
    activation_name = "relu6"

    # Values 1..12 laid out as a [1, 2, 6, 1] float32 tensor.
    source_node = quantize_graph.create_constant_node(
        source_name, value=list(range(1, 13)), dtype=tf.float32,
        shape=[1, 2, 6, 1])
    activation_node = quantize_graph.create_node(
        "Relu6", activation_name, [source_name])
    quantize_graph.set_attr_dtype(activation_node, "T", tf.float32)

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([source_node, activation_node])
    test_graph(float_graph_def, {}, [activation_name])
def test_relu6(self):
    """Feed a [1, 2, 6, 1] float constant into Relu6 and call test_graph."""
    feed_name, op_name = "input_constant", "relu6"
    feed_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

    relu6_op = quantize_graph.create_node("Relu6", op_name, [feed_name])
    quantize_graph.set_attr_dtype(relu6_op, "T", tf.float32)

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([
        quantize_graph.create_constant_node(
            feed_name, value=feed_values, dtype=tf.float32,
            shape=[1, 2, 6, 1])
    ])
    float_graph_def.node.extend([relu6_op])
    test_graph(float_graph_def, {}, [op_name])
def test_identity(self):
    """Feed a [2, 6] float constant into Identity and call test_graph."""
    feed_name, op_name = "input_constant", "identity"
    feed_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

    identity_op = quantize_graph.create_node("Identity", op_name, [feed_name])
    quantize_graph.set_attr_dtype(identity_op, "T", tf.float32)

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([
        quantize_graph.create_constant_node(
            feed_name, value=feed_values, dtype=tf.float32, shape=[2, 6])
    ])
    float_graph_def.node.extend([identity_op])
    test_graph(float_graph_def, {}, [op_name])
def test_batch_norm(self):
    """Build a BatchNormWithGlobalNormalization graph and call test_graph.

    The node reads a [1, 1, 6, 2] float input plus four length-2 float
    constants (mean, variance, beta, gamma), with
    scale_after_normalization=False and variance_epsilon=0.001.
    """
    data_name = "input_constant"
    norm_name = "batch_norm"
    # (node name, values) for the per-channel parameters, in input order.
    channel_params = (
        ("mean_constant", [10, 20]),
        ("variance_constant", [0.25, 0.5]),
        ("beta_constant", [0.1, 0.6]),
        ("gamma_constant", [0, 0]),
    )

    graph_nodes = [
        quantize_graph.create_constant_node(
            data_name, value=[1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6],
            dtype=tf.float32, shape=[1, 1, 6, 2])
    ]
    graph_nodes.extend(
        quantize_graph.create_constant_node(
            param_name, value=param_values, dtype=tf.float32, shape=[2])
        for param_name, param_values in channel_params)

    norm_inputs = [data_name] + [param[0] for param in channel_params]
    norm_node = quantize_graph.create_node(
        "BatchNormWithGlobalNormalization", norm_name, norm_inputs)
    quantize_graph.set_attr_dtype(norm_node, "T", tf.float32)
    quantize_graph.set_attr_bool(norm_node, "scale_after_normalization", False)
    quantize_graph.set_attr_float(norm_node, "variance_epsilon", 0.001)
    graph_nodes.append(norm_node)

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend(graph_nodes)
    test_graph(float_graph_def, {}, [norm_name])
def test_avg_pool(self):
    """Build a constant -> AvgPool float graph and pass it to test_graph.

    The pool uses ksize [1, 2, 2, 1], strides [1, 1, 1, 1] and padding
    b"SAME" over a [1, 2, 6, 1] float32 constant holding 1..12.
    """
    source_name = "input_constant"
    pool_name = "avg_pool"

    source_node = quantize_graph.create_constant_node(
        source_name, value=list(range(1, 13)), dtype=tf.float32,
        shape=[1, 2, 6, 1])

    pool_node = quantize_graph.create_node("AvgPool", pool_name, [source_name])
    quantize_graph.set_attr_dtype(pool_node, "T", tf.float32)
    quantize_graph.set_attr_int_list(pool_node, "ksize", [1, 2, 2, 1])
    quantize_graph.set_attr_int_list(pool_node, "strides", [1, 1, 1, 1])
    quantize_graph.set_attr_string(pool_node, "padding", b"SAME")

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([source_node, pool_node])
    test_graph(float_graph_def, {}, [pool_name])
def test_avg_pool(self):
    """Feed a [1, 2, 6, 1] float constant into AvgPool and call test_graph."""
    feed_name, op_name = "input_constant", "avg_pool"
    feed_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

    avg_pool_op = quantize_graph.create_node("AvgPool", op_name, [feed_name])
    quantize_graph.set_attr_dtype(avg_pool_op, "T", tf.float32)
    # Integer-list attributes, applied in the order (ksize, strides).
    for attr_name, attr_value in (("ksize", [1, 2, 2, 1]),
                                  ("strides", [1, 1, 1, 1])):
        quantize_graph.set_attr_int_list(avg_pool_op, attr_name, attr_value)
    quantize_graph.set_attr_string(avg_pool_op, "padding", b"SAME")

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([
        quantize_graph.create_constant_node(
            feed_name, value=feed_values, dtype=tf.float32,
            shape=[1, 2, 6, 1])
    ])
    float_graph_def.node.extend([avg_pool_op])
    test_graph(float_graph_def, {}, [op_name])
def test_batch_norm(self):
    """Assemble a BatchNormWithGlobalNormalization graph for test_graph."""
    data_name = "input_constant"
    mean_name = "mean_constant"
    variance_name = "variance_constant"
    beta_name = "beta_constant"
    gamma_name = "gamma_constant"
    norm_name = "batch_norm"

    def pair_constant(name, values):
        # Every per-channel parameter is a float32 vector of length 2.
        return quantize_graph.create_constant_node(
            name, value=values, dtype=tf.float32, shape=[2])

    data_node = quantize_graph.create_constant_node(
        data_name, value=[1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6],
        dtype=tf.float32, shape=[1, 1, 6, 2])
    mean_node = pair_constant(mean_name, [10, 20])
    variance_node = pair_constant(variance_name, [0.25, 0.5])
    beta_node = pair_constant(beta_name, [0.1, 0.6])
    gamma_node = pair_constant(gamma_name, [0, 0])

    norm_node = quantize_graph.create_node(
        "BatchNormWithGlobalNormalization", norm_name,
        [data_name, mean_name, variance_name, beta_name, gamma_name])
    quantize_graph.set_attr_dtype(norm_node, "T", tf.float32)
    quantize_graph.set_attr_bool(norm_node, "scale_after_normalization", False)
    quantize_graph.set_attr_float(norm_node, "variance_epsilon", 0.001)

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([
        data_node, mean_node, variance_node, beta_node, gamma_node, norm_node
    ])
    test_graph(float_graph_def, {}, [norm_name])
def test_multiple_outputs(self):
    """Wire a two-way Split into a Concat and hand the graph to test_graph."""
    feed_name = "input_constant"
    split_axis_name = "split_constant"
    split_op_name = "split"
    concat_axis_name = "concat_constant"
    concat_op_name = "concat"

    def axis_constant(name):
        # Both axis constants are scalar int32 nodes with value 1.
        return quantize_graph.create_constant_node(
            name, value=1, dtype=tf.int32, shape=[])

    float_graph_def = tf.GraphDef()
    pending = float_graph_def.node

    pending.extend([
        quantize_graph.create_constant_node(
            feed_name, value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
            dtype=tf.float32, shape=[2, 6])
    ])
    pending.extend([axis_constant(split_axis_name)])

    split_op = quantize_graph.create_node(
        "Split", split_op_name, [split_axis_name, feed_name])
    quantize_graph.set_attr_int(split_op, "num_split", 2)
    quantize_graph.set_attr_dtype(split_op, "T", tf.float32)
    pending.extend([split_op])

    pending.extend([axis_constant(concat_axis_name)])

    concat_op = quantize_graph.create_node(
        "Concat", concat_op_name,
        [concat_axis_name, split_op_name + ":0", split_op_name + ":1"])
    quantize_graph.set_attr_int(concat_op, "N", 2)
    quantize_graph.set_attr_dtype(concat_op, "T", tf.float32)
    pending.extend([concat_op])

    test_graph(float_graph_def, {}, [concat_op_name])
def test_bias_add(self):
    """Build a BiasAdd float graph over two constants and call test_graph.

    The input is a [1, 1, 2, 6] float32 constant holding 1..12 and the
    offset is a length-6 float32 constant holding 1..6.
    """
    data_name = "input_constant"
    bias_name = "offset_constant"
    sum_name = "bias_add"

    data_node = quantize_graph.create_constant_node(
        data_name, value=list(range(1, 13)), dtype=tf.float32,
        shape=[1, 1, 2, 6])
    bias_node = quantize_graph.create_constant_node(
        bias_name, value=list(range(1, 7)), dtype=tf.float32, shape=[6])

    sum_node = quantize_graph.create_node(
        "BiasAdd", sum_name, [data_name, bias_name])
    quantize_graph.set_attr_dtype(sum_node, "T", tf.float32)

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([data_node, bias_node, sum_node])
    test_graph(float_graph_def, {}, [sum_name])
def test_bias_add(self):
    """Feed a float input and a length-6 offset into BiasAdd for test_graph."""
    op_name = "bias_add"
    # (node name, values, shape) for each BiasAdd input, in graph order.
    constant_specs = (
        ("input_constant", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
         [1, 1, 2, 6]),
        ("offset_constant", [1, 2, 3, 4, 5, 6], [6]),
    )

    float_graph_def = tf.GraphDef()
    for spec_name, spec_values, spec_shape in constant_specs:
        float_graph_def.node.extend([
            quantize_graph.create_constant_node(
                spec_name, value=spec_values, dtype=tf.float32,
                shape=spec_shape)
        ])

    bias_add_op = quantize_graph.create_node(
        "BiasAdd", op_name, [spec[0] for spec in constant_specs])
    quantize_graph.set_attr_dtype(bias_add_op, "T", tf.float32)
    float_graph_def.node.extend([bias_add_op])

    test_graph(float_graph_def, {}, [op_name])
def test_remove_redundant_quantization(self):
    """Compare remove_redundant_quantization output to a hand-built graph.

    The input graph sends each quantized constant through a Dequantize and a
    QuantizeV2 node before a QuantizedMatMul. The expected graph has the
    QuantizedMatMul reading the constants and their min/max nodes directly.
    """
    mat_mul_name = "mat_mul"
    # (constant, min, max, dequantize, quantize) names, then the min/max value.
    operands = (
        (("a_constant", "a_constant_min", "a_constant_max", "a_dequantize",
          "a_quantize"), 2),
        (("b_constant", "b_constant_min", "b_constant_max", "b_dequantize",
          "b_quantize"), 3),
    )

    graph_def = tf.GraphDef()
    expected_output = tf.GraphDef()
    for names, range_value in operands:
        const_name, min_name, max_name, dequantize_name, quantize_name = names

        const_node = quantize_graph.create_constant_node(
            const_name, value=(0,), dtype=tf.quint8, shape=[])
        min_node = quantize_graph.create_constant_node(
            min_name, value=range_value, dtype=tf.float32, shape=[])
        max_node = quantize_graph.create_constant_node(
            max_name, value=range_value, dtype=tf.float32, shape=[])

        dequantize_node = quantize_graph.create_node(
            "Dequantize", dequantize_name, [const_name, min_name, max_name])
        quantize_graph.set_attr_dtype(dequantize_node, "T", tf.uint8)

        quantize_node = quantize_graph.create_node(
            "QuantizeV2", quantize_name,
            [dequantize_name, dequantize_name + ":1", dequantize_name + ":2"])
        quantize_graph.set_attr_dtype(quantize_node, "T", tf.uint8)

        graph_def.node.extend(
            [const_node, min_node, max_node, dequantize_node, quantize_node])
        # Only the constant and its min/max nodes appear in the expected graph.
        expected_output.node.extend([const_node, min_node, max_node])

    (a_names, _), (b_names, _) = operands
    a_quantize_name, b_quantize_name = a_names[4], b_names[4]

    source_mat_mul = quantize_graph.create_node(
        "QuantizedMatMul", mat_mul_name, [
            a_quantize_name, b_quantize_name,
            a_quantize_name + ":1", a_quantize_name + ":2",
            b_quantize_name + ":1", b_quantize_name + ":2"
        ])
    quantize_graph.set_attr_dtype(source_mat_mul, "T1", tf.uint8)
    quantize_graph.set_attr_dtype(source_mat_mul, "T2", tf.int32)
    graph_def.node.extend([source_mat_mul])

    expected_mat_mul = quantize_graph.create_node(
        "QuantizedMatMul", mat_mul_name, [
            a_names[0], b_names[0],
            a_names[1], a_names[2],
            b_names[1], b_names[2]
        ])
    quantize_graph.set_attr_dtype(expected_mat_mul, "T1", tf.uint8)
    quantize_graph.set_attr_dtype(expected_mat_mul, "T2", tf.int32)
    expected_output.node.extend([expected_mat_mul])

    rewriter = quantize_graph.GraphRewriter(graph_def, [mat_mul_name])
    output = rewriter.remove_redundant_quantization(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
    self.assertProtoEquals(expected_output, stripped_output)
def test_keep_control_edges(self):
    """Compare graph_util.remove_training_nodes output to a hand-built graph.

    The input graph has two constant -> CheckNumerics -> Identity branches
    feeding an Add; the "a" Identity also has a control input from a NoOp.
    The expected graph keeps the NoOp and the "a" Identity (with only its
    NoOp control input) and has the Add read the "b" constant directly.
    """
    no_op_name = "no_op"
    a_const_name, a_check_name, a_identity_name = (
        "a_constant", "a_check", "a_identity")
    b_const_name, b_check_name, b_identity_name = (
        "b_constant", "b_check", "b_identity")
    add_name = "add"

    # --- Input graph ---
    source_add = quantize_graph.create_node(
        "Add", add_name, [a_identity_name, b_identity_name])
    quantize_graph.set_attr_dtype(source_add, "T", tf.float32)

    graph_def = tf.GraphDef()
    graph_def.node.extend([
        quantize_graph.create_node("NoOp", no_op_name, []),
        quantize_graph.create_constant_node(
            a_const_name, value=1, dtype=tf.float32, shape=[]),
        quantize_graph.create_node(
            "CheckNumerics", a_check_name, [a_const_name]),
        quantize_graph.create_node(
            "Identity", a_identity_name,
            [a_const_name, "^" + a_check_name, "^" + no_op_name]),
        quantize_graph.create_constant_node(
            b_const_name, value=1, dtype=tf.float32, shape=[]),
        quantize_graph.create_node(
            "CheckNumerics", b_check_name, [b_const_name]),
        quantize_graph.create_node(
            "Identity", b_identity_name, [b_const_name, "^" + b_check_name]),
        source_add,
    ])

    # --- Expected graph ---
    expected_add = quantize_graph.create_node(
        "Add", add_name, [a_identity_name, b_const_name])
    quantize_graph.set_attr_dtype(expected_add, "T", tf.float32)

    expected_output = tf.GraphDef()
    expected_output.node.extend([
        quantize_graph.create_node("NoOp", no_op_name, []),
        quantize_graph.create_constant_node(
            a_const_name, value=1, dtype=tf.float32, shape=[]),
        quantize_graph.create_node(
            "Identity", a_identity_name, [a_const_name, "^" + no_op_name]),
        quantize_graph.create_constant_node(
            b_const_name, value=1, dtype=tf.float32, shape=[]),
        expected_add,
    ])

    output = graph_util.remove_training_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
def test_remove_redundant_quantization(self):
    """Check remove_redundant_quantization against a hand-built graph.

    Input: two quint8 constants, each followed by Dequantize and QuantizeV2,
    feeding a QuantizedMatMul. Expected: the same QuantizedMatMul wired
    straight to the constants and their min/max constants.
    """
    a_name, a_min_name, a_max_name = (
        "a_constant", "a_constant_min", "a_constant_max")
    a_deq_name, a_quant_name = "a_dequantize", "a_quantize"
    b_name, b_min_name, b_max_name = (
        "b_constant", "b_constant_min", "b_constant_max")
    b_deq_name, b_quant_name = "b_dequantize", "b_quantize"
    mat_mul_name = "mat_mul"

    def quantized_scalar(name):
        return quantize_graph.create_constant_node(
            name, value=(0,), dtype=tf.quint8, shape=[])

    def float_scalar(name, value):
        return quantize_graph.create_constant_node(
            name, value=value, dtype=tf.float32, shape=[])

    def uint8_op(op, name, inputs):
        # Dequantize and QuantizeV2 nodes both get "T" set to uint8.
        node = quantize_graph.create_node(op, name, inputs)
        quantize_graph.set_attr_dtype(node, "T", tf.uint8)
        return node

    def quantized_mat_mul(inputs):
        node = quantize_graph.create_node(
            "QuantizedMatMul", mat_mul_name, inputs)
        quantize_graph.set_attr_dtype(node, "T1", tf.uint8)
        quantize_graph.set_attr_dtype(node, "T2", tf.int32)
        return node

    graph_def = tf.GraphDef()
    graph_def.node.extend([
        quantized_scalar(a_name),
        float_scalar(a_min_name, 2),
        float_scalar(a_max_name, 2),
        uint8_op("Dequantize", a_deq_name, [a_name, a_min_name, a_max_name]),
        uint8_op("QuantizeV2", a_quant_name,
                 [a_deq_name, a_deq_name + ":1", a_deq_name + ":2"]),
        quantized_scalar(b_name),
        float_scalar(b_min_name, 3),
        float_scalar(b_max_name, 3),
        uint8_op("Dequantize", b_deq_name, [b_name, b_min_name, b_max_name]),
        uint8_op("QuantizeV2", b_quant_name,
                 [b_deq_name, b_deq_name + ":1", b_deq_name + ":2"]),
        quantized_mat_mul([
            a_quant_name, b_quant_name,
            a_quant_name + ":1", a_quant_name + ":2",
            b_quant_name + ":1", b_quant_name + ":2"
        ]),
    ])

    expected_output = tf.GraphDef()
    expected_output.node.extend([
        quantized_scalar(a_name),
        float_scalar(a_min_name, 2),
        float_scalar(a_max_name, 2),
        quantized_scalar(b_name),
        float_scalar(b_min_name, 3),
        float_scalar(b_max_name, 3),
        quantized_mat_mul([
            a_name, b_name, a_min_name, a_max_name, b_min_name, b_max_name
        ]),
    ])

    rewriter = quantize_graph.GraphRewriter(graph_def, [mat_mul_name])
    output = rewriter.remove_redundant_quantization(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
    self.assertProtoEquals(expected_output, stripped_output)
def test_keep_control_edges(self):
    """Check remove_training_nodes against a graph with a NoOp control edge.

    Both branches of the input graph go constant -> CheckNumerics ->
    Identity -> Add, and the "a" Identity carries an extra control input
    from a NoOp. In the expected graph the "a" Identity remains with only
    the NoOp control input, while the Add consumes the "b" constant directly.
    """
    no_op_name = "no_op"
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"

    def unit_constant(name):
        # Scalar float32 constant with value 1.
        return quantize_graph.create_constant_node(
            name, value=1, dtype=tf.float32, shape=[])

    def float_add(inputs):
        node = quantize_graph.create_node("Add", add_name, inputs)
        quantize_graph.set_attr_dtype(node, "T", tf.float32)
        return node

    def make_node(op, name, inputs):
        return quantize_graph.create_node(op, name, inputs)

    graph_def = tf.GraphDef()
    source_nodes = (
        make_node("NoOp", no_op_name, []),
        unit_constant(a_constant_name),
        make_node("CheckNumerics", a_check_name, [a_constant_name]),
        make_node("Identity", a_identity_name,
                  [a_constant_name, "^" + a_check_name, "^" + no_op_name]),
        unit_constant(b_constant_name),
        make_node("CheckNumerics", b_check_name, [b_constant_name]),
        make_node("Identity", b_identity_name,
                  [b_constant_name, "^" + b_check_name]),
        float_add([a_identity_name, b_identity_name]),
    )
    for node in source_nodes:
        graph_def.node.extend([node])

    expected_output = tf.GraphDef()
    expected_nodes = (
        make_node("NoOp", no_op_name, []),
        unit_constant(a_constant_name),
        make_node("Identity", a_identity_name,
                  [a_constant_name, "^" + no_op_name]),
        unit_constant(b_constant_name),
        float_add([a_identity_name, b_constant_name]),
    )
    for node in expected_nodes:
        expected_output.node.extend([node])

    output = graph_util.remove_training_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)