예제 #1
0
    def test_multiple_outputs(self):
        """Feed both outputs of a Split node into a single Concat node.

        Builds a float GraphDef containing:
          * "input_constant": a [2, 6] float32 constant holding 1..12,
          * "split_constant": an int32 scalar 1, the Split node's first input,
          * "split": a Split node (num_split=2, T=float32),
          * "concat_constant": an int32 scalar 1, the Concat node's first input,
          * "concat": a Concat node (N=2, T=float32) consuming "split:0" and
            "split:1".
        The graph is then passed to ``test_graph`` with an empty feed dict and
        "concat" as the only requested output (presumably comparing float and
        quantized results -- see ``test_graph``).
        """
        input_name = "input_constant"
        split_dim_name = "split_constant"
        split_name = "split"
        concat_dim_name = "concat_constant"
        concat_name = "concat"

        input_node = quantize_graph.create_constant_node(
            input_name,
            value=list(range(1, 13)),
            dtype=tf.float32,
            shape=[2, 6])
        split_dim_node = quantize_graph.create_constant_node(
            split_dim_name, value=1, dtype=tf.int32, shape=[])

        split_node = quantize_graph.create_node(
            "Split", split_name, [split_dim_name, input_name])
        quantize_graph.set_attr_int(split_node, "num_split", 2)
        quantize_graph.set_attr_dtype(split_node, "T", tf.float32)

        concat_dim_node = quantize_graph.create_constant_node(
            concat_dim_name, value=1, dtype=tf.int32, shape=[])

        # Reference every output slot of the Split node, in slot order.
        split_outputs = ["%s:%d" % (split_name, slot) for slot in range(2)]
        concat_node = quantize_graph.create_node(
            "Concat", concat_name, [concat_dim_name] + split_outputs)
        quantize_graph.set_attr_int(concat_node, "N", len(split_outputs))
        quantize_graph.set_attr_dtype(concat_node, "T", tf.float32)

        # All attributes are set before the nodes are appended; the list order
        # below defines the node order inside the GraphDef.
        float_graph_def = tf.GraphDef()
        float_graph_def.node.extend(
            [input_node, split_dim_node, split_node, concat_dim_node,
             concat_node])

        test_graph(float_graph_def, {}, [concat_name])
예제 #2
0
    def test_concat(self):
        """Concatenate two [2, 2, 3] float constants with a Concat node.

        The int32 scalar constant named "shape_constant" (value 0) is wired in
        as the Concat node's first input, i.e. its concat dimension. The two
        operands are "a_constant" (values 1..12) and "b_constant" (values
        13..24). The resulting GraphDef is passed to ``test_graph`` with an
        empty feed dict and "concat" as the only requested output.
        """
        concat_dim_name = "shape_constant"
        a_name = "a_constant"
        b_name = "b_constant"
        concat_name = "concat"

        concat_dim_node = quantize_graph.create_constant_node(
            concat_dim_name, value=0, dtype=tf.int32, shape=[])
        a_node = quantize_graph.create_constant_node(
            a_name,
            value=list(range(1, 13)),
            dtype=tf.float32,
            shape=[2, 2, 3])
        b_node = quantize_graph.create_constant_node(
            b_name,
            value=list(range(13, 25)),
            dtype=tf.float32,
            shape=[2, 2, 3])

        concat_node = quantize_graph.create_node(
            "Concat", concat_name, [concat_dim_name, a_name, b_name])
        quantize_graph.set_attr_int(concat_node, "N", 2)
        quantize_graph.set_attr_dtype(concat_node, "T", tf.float32)

        # Node order in the GraphDef: dimension, operand a, operand b, concat.
        float_graph_def = tf.GraphDef()
        float_graph_def.node.extend(
            [concat_dim_node, a_node, b_node, concat_node])

        test_graph(float_graph_def, {}, [concat_name])
예제 #3
0
  def test_concat(self):
    """Concatenate two [2, 2, 3] float constants with a Concat node.

    An int32 scalar constant named "shape_constant" (value 0) is the Concat
    node's first input (its concat dimension). It is followed by the operands
    "a_constant" (values 1..12) and "b_constant" (values 13..24). The graph is
    run through ``test_graph`` with an empty feed dict and "concat" as output.
    """
    concat_dim_name = "shape_constant"
    operand_names = ("a_constant", "b_constant")
    concat_name = "concat"
    values_per_operand = 12  # 2 * 2 * 3 elements in each operand.

    nodes = [quantize_graph.create_constant_node(concat_dim_name,
                                                 value=0,
                                                 dtype=tf.int32,
                                                 shape=[])]
    # Operand i holds consecutive integers starting at 1 + 12 * i.
    for index, operand_name in enumerate(operand_names):
      first = 1 + index * values_per_operand
      nodes.append(quantize_graph.create_constant_node(
          operand_name,
          value=list(range(first, first + values_per_operand)),
          dtype=tf.float32,
          shape=[2, 2, 3]))

    concat_node = quantize_graph.create_node(
        "Concat", concat_name, [concat_dim_name] + list(operand_names))
    quantize_graph.set_attr_int(concat_node, "N", len(operand_names))
    quantize_graph.set_attr_dtype(concat_node, "T", tf.float32)
    nodes.append(concat_node)

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend(nodes)

    test_graph(float_graph_def, {}, [concat_name])
예제 #4
0
  def test_multiple_outputs(self):
    """Route every output of a Split node into one Concat node.

    The graph contains a [2, 6] float32 constant (values 1..12). It is split
    by a "Split" node whose first input is an int32 scalar 1 and which has
    num_split=2. The two Split outputs are then joined by a "Concat" node
    whose first input is another int32 scalar 1 and which has N=2. The graph
    is handed to ``test_graph`` with an empty feed dict and "concat" as output.
    """
    num_split = 2
    input_name = "input_constant"
    split_dim_name = "split_constant"
    split_name = "split"
    concat_dim_name = "concat_constant"
    concat_name = "concat"

    nodes = []
    nodes.append(quantize_graph.create_constant_node(input_name,
                                                     value=list(range(1, 13)),
                                                     dtype=tf.float32,
                                                     shape=[2, 6]))
    nodes.append(quantize_graph.create_constant_node(split_dim_name,
                                                     value=1,
                                                     dtype=tf.int32,
                                                     shape=[]))

    split_node = quantize_graph.create_node("Split", split_name,
                                            [split_dim_name, input_name])
    quantize_graph.set_attr_int(split_node, "num_split", num_split)
    quantize_graph.set_attr_dtype(split_node, "T", tf.float32)
    nodes.append(split_node)

    nodes.append(quantize_graph.create_constant_node(concat_dim_name,
                                                     value=1,
                                                     dtype=tf.int32,
                                                     shape=[]))

    # Concat inputs: the dimension constant, then "split:0", "split:1", ...
    concat_inputs = [concat_dim_name]
    for slot in range(num_split):
      concat_inputs.append("%s:%d" % (split_name, slot))
    concat_node = quantize_graph.create_node("Concat", concat_name,
                                             concat_inputs)
    quantize_graph.set_attr_int(concat_node, "N", num_split)
    quantize_graph.set_attr_dtype(concat_node, "T", tf.float32)
    nodes.append(concat_node)

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend(nodes)

    test_graph(float_graph_def, {}, [concat_name])