def test_concat(self):
    """Build a float32 Concat of two constant tensors and run test_graph on it.

    The graph joins two [2, 2, 3] float32 constants (values 1..12 and
    13..24) along dimension 0 through a "Concat" node with N=2, and then
    passes the graph to test_graph (no feeds, "concat" as the output).
    """
    concat_name = "concat"
    # The node called "shape_constant" actually holds the concat dimension.
    dim_name = "shape_constant"
    lhs_name = "a_constant"
    rhs_name = "b_constant"

    dim_node = quantize_graph.create_constant_node(
        dim_name, value=0, dtype=tf.int32, shape=[])
    lhs_node = quantize_graph.create_constant_node(
        lhs_name, value=list(range(1, 13)), dtype=tf.float32,
        shape=[2, 2, 3])
    rhs_node = quantize_graph.create_constant_node(
        rhs_name, value=list(range(13, 25)), dtype=tf.float32,
        shape=[2, 2, 3])

    # Concat's inputs: the scalar dimension first, then the N tensors.
    concat_node = quantize_graph.create_node(
        "Concat", concat_name, [dim_name, lhs_name, rhs_name])
    # Set attributes before adding the node: repeated-field extend() copies.
    quantize_graph.set_attr_int(concat_node, "N", 2)
    quantize_graph.set_attr_dtype(concat_node, "T", tf.float32)

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([dim_node, lhs_node, rhs_node, concat_node])
    test_graph(float_graph_def, {}, [concat_name])
def test_multiple_outputs(self):
    """Split a float32 [2, 6] constant in two and Concat the halves back.

    The Concat node consumes the Split node's outputs individually via the
    "split:0" and "split:1" references, so the graph exercises a node with
    multiple outputs. The result is checked with test_graph (no feeds,
    "concat" as the output).
    """
    input_constant_name = "input_constant"
    split_constant_name = "split_constant"
    split_name = "split"
    concat_constant_name = "concat_constant"
    concat_name = "concat"

    input_node = quantize_graph.create_constant_node(
        input_constant_name, value=list(range(1, 13)),
        dtype=dtypes.float32, shape=[2, 6])
    # Scalar split dimension (1); Split receives it as its first input.
    split_dim_node = quantize_graph.create_constant_node(
        split_constant_name, value=1, dtype=dtypes.int32, shape=[])

    split_node = quantize_graph.create_node(
        "Split", split_name, [split_constant_name, input_constant_name])
    quantize_graph.set_attr_int(split_node, "num_split", 2)
    quantize_graph.set_attr_dtype(split_node, "T", dtypes.float32)

    concat_dim_node = quantize_graph.create_constant_node(
        concat_constant_name, value=1, dtype=dtypes.int32, shape=[])
    # Reference each Split output by index: "split:0", "split:1".
    concat_inputs = [concat_constant_name] + [
        "%s:%d" % (split_name, index) for index in range(2)]
    concat_node = quantize_graph.create_node(
        "Concat", concat_name, concat_inputs)
    quantize_graph.set_attr_int(concat_node, "N", 2)
    quantize_graph.set_attr_dtype(concat_node, "T", dtypes.float32)

    float_graph_def = graph_pb2.GraphDef()
    float_graph_def.node.extend(
        [input_node, split_dim_node, split_node, concat_dim_node, concat_node])
    test_graph(float_graph_def, {}, [concat_name])
def test_concat(self):
    """Concatenate two float32 [2, 2, 3] constants on dimension 0 and test it.

    Node names are "shape_constant" (the int32 concat dimension, despite its
    name), "a_constant", "b_constant" and the "concat" output node, which is
    handed to test_graph with no feeds.
    """
    concat_name = "concat"
    dim_name, first_name, second_name = (
        "shape_constant", "a_constant", "b_constant")
    float_graph_def = tf.GraphDef()
    add_nodes = float_graph_def.node.extend

    add_nodes([quantize_graph.create_constant_node(
        dim_name, value=0, dtype=tf.int32, shape=[])])

    # Operand values are consecutive integers: 1..12, then 13..24.
    for operand_name, start in ((first_name, 1), (second_name, 13)):
        add_nodes([quantize_graph.create_constant_node(
            operand_name, value=list(range(start, start + 12)),
            dtype=tf.float32, shape=[2, 2, 3])])

    concat_node = quantize_graph.create_node(
        "Concat", concat_name, [dim_name, first_name, second_name])
    quantize_graph.set_attr_int(concat_node, "N", 2)
    quantize_graph.set_attr_dtype(concat_node, "T", tf.float32)
    add_nodes([concat_node])

    test_graph(float_graph_def, {}, [concat_name])
def test_non_float_concat(self):
    """Run an int32-only Concat graph through test_graph.

    Mirrors the float Concat test, but every tensor (the concat dimension,
    both [2, 2, 3] operands and the Concat "T" attribute) is int32.
    """
    dim_node = quantize_graph.create_constant_node(
        "concat_dim", value=0, dtype=tf.int32, shape=[])
    # Operands "a" (values 1..12) and "b" (values 13..24).
    lhs, rhs = [
        quantize_graph.create_constant_node(
            node_name, value=list(range(start, start + 12)),
            dtype=tf.int32, shape=[2, 2, 3])
        for node_name, start in (("a", 1), ("b", 13))]

    concat_node = quantize_graph.create_node(
        "Concat", "concat", [dim_node.name, lhs.name, rhs.name])
    quantize_graph.set_attr_int(concat_node, "N", 2)
    quantize_graph.set_attr_dtype(concat_node, "T", tf.int32)

    graph_def = tf.GraphDef()
    graph_def.node.extend([dim_node, lhs, rhs, concat_node])
    test_graph(graph_def, {}, [concat_node.name])
def test_concat(self):
    """Check a float32 Concat graph and that the eightbit rewrite quantizes it.

    Two [2, 2, 3] float32 constants are joined along dimension 0. After
    test_graph runs on the float graph, the graph is rewritten in "eightbit"
    mode, and the result must contain exactly one QuantizedConcat node.
    """
    concat_name = "concat"
    # "shape_constant" holds the concat dimension; it must come first.
    concat_inputs = ["shape_constant", "a_constant", "b_constant"]
    dim_name, lhs_name, rhs_name = concat_inputs

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([
        quantize_graph.create_constant_node(
            dim_name, value=0, dtype=tf.int32, shape=[]),
        quantize_graph.create_constant_node(
            lhs_name, value=list(range(1, 13)), dtype=tf.float32,
            shape=[2, 2, 3]),
        quantize_graph.create_constant_node(
            rhs_name, value=list(range(13, 25)), dtype=tf.float32,
            shape=[2, 2, 3]),
    ])

    concat_node = quantize_graph.create_node(
        "Concat", concat_name, concat_inputs)
    quantize_graph.set_attr_int(concat_node, "N", 2)
    quantize_graph.set_attr_dtype(concat_node, "T", tf.float32)
    float_graph_def.node.extend([concat_node])

    test_graph(float_graph_def, {}, [concat_name])

    # Exactly one QuantizedConcat node is expected in the eightbit graph.
    rewriter = quantize_graph.GraphRewriter(
        float_graph_def, "eightbit", quantized_input_range=None)
    rewritten = rewriter.rewrite([concat_name])
    quantized_concat_count = sum(
        1 for node in rewritten.node if node.op == "QuantizedConcat")
    self.assertEqual(1, quantized_concat_count)
def test_non_float_concat(self):
    """Run an int32-only Concat graph through test_graph.

    The concat dimension, both [2, 2, 3] operands ("a": 1..12, "b": 13..24)
    and the Concat "T" attribute are all int32.
    """
    int_type = dtypes.int32
    nodes = [
        quantize_graph.create_constant_node(
            "concat_dim", value=0, dtype=int_type, shape=[]),
        quantize_graph.create_constant_node(
            "a", value=list(range(1, 13)), dtype=int_type, shape=[2, 2, 3]),
        quantize_graph.create_constant_node(
            "b", value=list(range(13, 25)), dtype=int_type, shape=[2, 2, 3]),
    ]
    # Concat inputs are the three constants in order: dimension, a, b.
    concat_node = quantize_graph.create_node(
        "Concat", "concat", [node.name for node in nodes])
    quantize_graph.set_attr_int(concat_node, "N", 2)
    quantize_graph.set_attr_dtype(concat_node, "T", int_type)
    nodes.append(concat_node)

    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend(nodes)
    test_graph(graph_def, {}, [concat_node.name])
def test_concat(self):
    """Check a float32 Concat graph and that the eightbit rewrite quantizes it.

    Joins two [2, 2, 3] float32 constants (1..12 and 13..24) along
    dimension 0. After test_graph, the graph is rewritten in "eightbit"
    mode, and exactly one QuantizedConcat node must be present.
    """
    concat_name = "concat"
    dim_name = "shape_constant"  # holds the concat dimension, not a shape
    lhs_name = "a_constant"
    rhs_name = "b_constant"

    dim_node = quantize_graph.create_constant_node(
        dim_name, value=0, dtype=dtypes.int32, shape=[])
    lhs_node = quantize_graph.create_constant_node(
        lhs_name, value=list(range(1, 13)), dtype=dtypes.float32,
        shape=[2, 2, 3])
    rhs_node = quantize_graph.create_constant_node(
        rhs_name, value=list(range(13, 25)), dtype=dtypes.float32,
        shape=[2, 2, 3])

    concat_node = quantize_graph.create_node(
        "Concat", concat_name, [dim_name, lhs_name, rhs_name])
    quantize_graph.set_attr_int(concat_node, "N", 2)
    quantize_graph.set_attr_dtype(concat_node, "T", dtypes.float32)

    float_graph_def = graph_pb2.GraphDef()
    float_graph_def.node.extend([dim_node, lhs_node, rhs_node, concat_node])
    test_graph(float_graph_def, {}, [concat_name])

    # Verify the concat is quantized by the eightbit rewrite.
    eightbit_graph_def = quantize_graph.GraphRewriter(
        float_graph_def, "eightbit",
        quantized_input_range=None).rewrite([concat_name])
    quantized_concats = [
        node for node in eightbit_graph_def.node
        if node.op == "QuantizedConcat"]
    self.assertEqual(1, len(quantized_concats))