def test_concat(self):
    """Runs `test_graph` on a float32 Concat of two constant tensors.

    The graph contains, in order: an int32 scalar constant (0), two float32
    constants of shape [2, 2, 3] holding 1..12 and 13..24, and a "Concat"
    node (N=2) that takes all three as inputs. Only the concat node is
    requested as an output.
    """
    dim_name = "shape_constant"
    lhs_name = "a_constant"
    rhs_name = "b_constant"
    out_name = "concat"

    # First Concat input (presumably the dimension to join along).
    dim_node = quantize_graph.create_constant_node(
        dim_name, value=0, dtype=tf.int32, shape=[])
    lhs_node = quantize_graph.create_constant_node(
        lhs_name, value=list(range(1, 13)), dtype=tf.float32,
        shape=[2, 2, 3])
    rhs_node = quantize_graph.create_constant_node(
        rhs_name, value=list(range(13, 25)), dtype=tf.float32,
        shape=[2, 2, 3])

    concat_op = quantize_graph.create_node(
        "Concat", out_name, [dim_name, lhs_name, rhs_name])
    quantize_graph.set_attr_int(concat_op, "N", 2)
    quantize_graph.set_attr_dtype(concat_op, "T", tf.float32)

    # Attributes are set before the node is added to the graph.
    graph = tf.GraphDef()
    graph.node.extend([dim_node, lhs_node, rhs_node, concat_op])

    test_graph(graph, {}, [out_name])
Exemplo n.º 2
0
  def test_multiple_outputs(self):
    """Runs `test_graph` on a Split whose two outputs are joined by a Concat.

    Graph layout, in insertion order: a float32 constant of shape [2, 6]
    holding 1..12, an int32 scalar constant (1) used as the first "Split"
    input, the "Split" node (num_split=2), a second int32 scalar constant (1),
    and a "Concat" node (N=2) reading the split's ":0" and ":1" outputs. Only
    the concat node is requested as an output.
    """
    data_name = "input_constant"
    split_dim_name = "split_constant"
    splitter_name = "split"
    concat_dim_name = "concat_constant"
    joined_name = "concat"

    data_node = quantize_graph.create_constant_node(
        data_name, value=list(range(1, 13)), dtype=dtypes.float32,
        shape=[2, 6])
    split_dim_node = quantize_graph.create_constant_node(
        split_dim_name, value=1, dtype=dtypes.int32, shape=[])

    splitter = quantize_graph.create_node(
        "Split", splitter_name, [split_dim_name, data_name])
    quantize_graph.set_attr_int(splitter, "num_split", 2)
    quantize_graph.set_attr_dtype(splitter, "T", dtypes.float32)

    concat_dim_node = quantize_graph.create_constant_node(
        concat_dim_name, value=1, dtype=dtypes.int32, shape=[])

    # Each Split output is addressed explicitly as "<node name>:<index>".
    split_outputs = ["%s:%d" % (splitter_name, index) for index in range(2)]
    joiner = quantize_graph.create_node(
        "Concat", joined_name, [concat_dim_name] + split_outputs)
    quantize_graph.set_attr_int(joiner, "N", 2)
    quantize_graph.set_attr_dtype(joiner, "T", dtypes.float32)

    graph = graph_pb2.GraphDef()
    graph.node.extend(
        [data_node, split_dim_node, splitter, concat_dim_node, joiner])

    test_graph(graph, {}, [joined_name])
Exemplo n.º 3
0
    def test_concat(self):
        """Runs `test_graph` on a float32 Concat of two constant tensors.

        Three constants are created -- an int32 scalar (0) and two float32
        tensors of shape [2, 2, 3] holding 1..12 and 13..24 -- and wired, in
        that order, into a "Concat" node with N=2. The concat node is the
        only requested output.
        """
        dim_name = "shape_constant"
        first_name = "a_constant"
        second_name = "b_constant"
        result_name = "concat"
        tensor_shape = [2, 2, 3]

        graph = tf.GraphDef()
        nodes = graph.node

        # First Concat input (presumably the dimension to join along).
        nodes.extend([
            quantize_graph.create_constant_node(
                dim_name, value=0, dtype=tf.int32, shape=[]),
        ])
        nodes.extend([
            quantize_graph.create_constant_node(
                first_name,
                value=list(range(1, 13)),
                dtype=tf.float32,
                shape=tensor_shape),
        ])
        nodes.extend([
            quantize_graph.create_constant_node(
                second_name,
                value=list(range(13, 25)),
                dtype=tf.float32,
                shape=tensor_shape),
        ])

        concat_op = quantize_graph.create_node(
            "Concat", result_name, [dim_name, first_name, second_name])
        quantize_graph.set_attr_int(concat_op, "N", 2)
        quantize_graph.set_attr_dtype(concat_op, "T", tf.float32)
        # Added only after its attributes have been filled in.
        nodes.extend([concat_op])

        test_graph(graph, {}, [result_name])
Exemplo n.º 4
0
  def test_non_float_concat(self):
    """Runs `test_graph` on a Concat whose inputs are int32, not float.

    Two int32 constants of shape [2, 2, 3] (holding 1..12 and 13..24) are fed,
    after an int32 scalar "concat_dim" constant of 0, into a "Concat" node
    with N=2 and T=int32. The concat node is the only requested output.
    """
    graph = tf.GraphDef()

    dim_node = quantize_graph.create_constant_node(
        "concat_dim", value=0, dtype=tf.int32, shape=[])
    graph.node.extend([dim_node])

    lhs_node = quantize_graph.create_constant_node(
        "a", value=list(range(1, 13)), dtype=tf.int32, shape=[2, 2, 3])
    graph.node.extend([lhs_node])

    rhs_node = quantize_graph.create_constant_node(
        "b", value=list(range(13, 25)), dtype=tf.int32, shape=[2, 2, 3])
    graph.node.extend([rhs_node])

    # Inputs are referenced by the names stored on the constant nodes.
    input_names = [dim_node.name, lhs_node.name, rhs_node.name]
    concat_op = quantize_graph.create_node("Concat", "concat", input_names)
    quantize_graph.set_attr_int(concat_op, "N", 2)
    quantize_graph.set_attr_dtype(concat_op, "T", tf.int32)
    graph.node.extend([concat_op])

    test_graph(graph, {}, [concat_op.name])
Exemplo n.º 5
0
  def test_concat(self):
    """Checks a float32 Concat graph and that rewriting quantizes the concat.

    Builds a graph with an int32 scalar constant (0), two float32 constants of
    shape [2, 2, 3] holding 1..12 and 13..24, and a "Concat" node (N=2) over
    them. The graph is first handed to `test_graph`, then rewritten in
    "eightbit" mode, and the rewritten graph must contain exactly one
    "QuantizedConcat" node.
    """
    dim_name = "shape_constant"
    lhs_name = "a_constant"
    rhs_name = "b_constant"
    out_name = "concat"

    # First Concat input (presumably the dimension to join along).
    dim_node = quantize_graph.create_constant_node(
        dim_name, value=0, dtype=tf.int32, shape=[])
    lhs_node = quantize_graph.create_constant_node(
        lhs_name, value=list(range(1, 13)), dtype=tf.float32,
        shape=[2, 2, 3])
    rhs_node = quantize_graph.create_constant_node(
        rhs_name, value=list(range(13, 25)), dtype=tf.float32,
        shape=[2, 2, 3])

    concat_op = quantize_graph.create_node(
        "Concat", out_name, [dim_name, lhs_name, rhs_name])
    quantize_graph.set_attr_int(concat_op, "N", 2)
    quantize_graph.set_attr_dtype(concat_op, "T", tf.float32)

    graph = tf.GraphDef()
    graph.node.extend([dim_node, lhs_node, rhs_node, concat_op])

    test_graph(graph, {}, [out_name])

    # Rewrite the same float graph in eight-bit mode and confirm that
    # exactly one quantized concat op appears in the result.
    rewriter = quantize_graph.GraphRewriter(
        graph, "eightbit", quantized_input_range=None)
    rewritten = rewriter.rewrite([out_name])

    quantized_concats = [
        node for node in rewritten.node if node.op == "QuantizedConcat"
    ]
    self.assertEqual(1, len(quantized_concats))
Exemplo n.º 6
0
  def test_non_float_concat(self):
    """Runs `test_graph` on a Concat whose inputs are int32, not float.

    Two int32 constants of shape [2, 2, 3] (holding 1..12 and 13..24) are fed,
    after an int32 scalar "concat_dim" constant of 0, into a "Concat" node
    with N=2 and T=int32. The concat node is the only requested output.
    """
    int_type = dtypes.int32
    tensor_shape = [2, 2, 3]

    graph = graph_pb2.GraphDef()

    dim_node = quantize_graph.create_constant_node(
        "concat_dim", value=0, dtype=int_type, shape=[])
    graph.node.extend([dim_node])

    lhs_node = quantize_graph.create_constant_node(
        "a", value=list(range(1, 13)), dtype=int_type, shape=tensor_shape)
    graph.node.extend([lhs_node])

    rhs_node = quantize_graph.create_constant_node(
        "b", value=list(range(13, 25)), dtype=int_type, shape=tensor_shape)
    graph.node.extend([rhs_node])

    # Inputs are referenced by the names stored on the constant nodes.
    input_names = [dim_node.name, lhs_node.name, rhs_node.name]
    concat_op = quantize_graph.create_node("Concat", "concat", input_names)
    quantize_graph.set_attr_int(concat_op, "N", 2)
    quantize_graph.set_attr_dtype(concat_op, "T", int_type)
    graph.node.extend([concat_op])

    test_graph(graph, {}, [concat_op.name])
Exemplo n.º 7
0
  def test_concat(self):
    """Checks a float32 Concat graph and that rewriting quantizes the concat.

    Builds a graph with an int32 scalar constant (0), two float32 constants of
    shape [2, 2, 3] holding 1..12 and 13..24, and a "Concat" node (N=2) over
    them. The graph is first handed to `test_graph`, then rewritten in
    "eightbit" mode, and the rewritten graph must contain exactly one
    "QuantizedConcat" node.
    """
    dim_name = "shape_constant"
    lhs_name = "a_constant"
    rhs_name = "b_constant"
    out_name = "concat"
    tensor_shape = [2, 2, 3]

    # First Concat input (presumably the dimension to join along).
    dim_node = quantize_graph.create_constant_node(
        dim_name, value=0, dtype=dtypes.int32, shape=[])
    lhs_node = quantize_graph.create_constant_node(
        lhs_name, value=list(range(1, 13)), dtype=dtypes.float32,
        shape=tensor_shape)
    rhs_node = quantize_graph.create_constant_node(
        rhs_name, value=list(range(13, 25)), dtype=dtypes.float32,
        shape=tensor_shape)

    concat_op = quantize_graph.create_node(
        "Concat", out_name, [dim_name, lhs_name, rhs_name])
    quantize_graph.set_attr_int(concat_op, "N", 2)
    quantize_graph.set_attr_dtype(concat_op, "T", dtypes.float32)

    graph = graph_pb2.GraphDef()
    graph.node.extend([dim_node, lhs_node, rhs_node, concat_op])

    test_graph(graph, {}, [out_name])

    # Rewrite the same float graph in eight-bit mode and confirm that
    # exactly one quantized concat op appears in the result.
    rewriter = quantize_graph.GraphRewriter(
        graph, "eightbit", quantized_input_range=None)
    rewritten = rewriter.rewrite([out_name])

    quantized_concats = [
        node for node in rewritten.node if node.op == "QuantizedConcat"
    ]
    self.assertEqual(1, len(quantized_concats))