def test_multiple_outputs(self):
        """Split a 2x6 constant into two pieces and concatenate them again.

        The "Split" node exposes two outputs, which "Concat" consumes
        individually through the "split:0" and "split:1" tensor names. The
        graph therefore exercises a node whose outputs are referenced by index.
        """
        input_name = "input_constant"
        split_axis_name = "split_constant"
        split_name = "split"
        concat_axis_name = "concat_constant"
        concat_name = "concat"

        source = quantize_graph.create_constant_node(
            input_name, value=list(range(1, 13)), dtype=tf.float32,
            shape=[2, 6])
        # Scalar int32 constant fed as Split's first input.
        split_axis = quantize_graph.create_constant_node(
            split_axis_name, value=1, dtype=tf.int32, shape=[])
        splitter = quantize_graph.create_node(
            "Split", split_name, [split_axis_name, input_name])
        quantize_graph.set_attr_int(splitter, "num_split", 2)
        quantize_graph.set_attr_dtype(splitter, "T", tf.float32)

        # Scalar int32 constant fed as Concat's first input.
        concat_axis = quantize_graph.create_constant_node(
            concat_axis_name, value=1, dtype=tf.int32, shape=[])
        joiner = quantize_graph.create_node(
            "Concat", concat_name,
            [concat_axis_name, split_name + ":0", split_name + ":1"])
        quantize_graph.set_attr_int(joiner, "N", 2)
        quantize_graph.set_attr_dtype(joiner, "T", tf.float32)

        # Attributes are all set before the nodes are copied into the graph;
        # nodes are added in construction order.
        float_graph_def = tf.GraphDef()
        float_graph_def.node.extend(
            [source, split_axis, splitter, concat_axis, joiner])
        test_graph(float_graph_def, {}, [concat_name])
def test_conv(depth, image_width, image_height, image_batch_count, filter_size,
              filter_count, stride, padding, input_values, filter_values):
  """Tests a Conv replacement.

  Builds a float Conv2D node named "conv" over two constants and hands the
  graph to `test_graph`.

  Args:
    depth: last dimension of the input and third dimension of the filter.
    image_width: third dimension of the input.
    image_height: second dimension of the input.
    image_batch_count: first dimension of the input.
    filter_size: first and second dimension of the (square) filter.
    filter_count: last dimension of the filter.
    stride: stride used for both middle axes of the "strides" attribute.
    padding: value stored in the "padding" attribute.
    input_values: flat values for the input constant.
    filter_values: flat values for the filter constant.
  """
  conv_name = "conv"
  image_name = "input_constant"
  kernel_name = "filter_constant"

  image_shape = [image_batch_count, image_height, image_width, depth]
  kernel_shape = [filter_size, filter_size, depth, filter_count]

  image = quantize_graph.create_constant_node(
      image_name, value=input_values, dtype=tf.float32, shape=image_shape)
  kernel = quantize_graph.create_constant_node(
      kernel_name, value=filter_values, dtype=tf.float32, shape=kernel_shape)

  conv = quantize_graph.create_node("Conv2D", conv_name,
                                    [image_name, kernel_name])
  quantize_graph.set_attr_dtype(conv, "T", tf.float32)
  quantize_graph.set_attr_int_list(conv, "strides", [1, stride, stride, 1])
  quantize_graph.set_attr_string(conv, "padding", padding)

  float_graph_def = tf.GraphDef()
  float_graph_def.node.extend([image, kernel, conv])
  test_graph(float_graph_def, {}, [conv_name])
    def test_concat(self):
        """Concatenate two 2x2x3 float constants with a Concat node.

        "a_constant" holds 1..12 and "b_constant" holds 13..24. The scalar
        int32 "shape_constant" (value 0) is passed as Concat's first input.
        """
        concat_name = "concat"
        axis_name = "shape_constant"
        lhs_name = "a_constant"
        rhs_name = "b_constant"

        axis = quantize_graph.create_constant_node(
            axis_name, value=0, dtype=tf.int32, shape=[])
        lhs = quantize_graph.create_constant_node(
            lhs_name, value=list(range(1, 13)), dtype=tf.float32,
            shape=[2, 2, 3])
        rhs = quantize_graph.create_constant_node(
            rhs_name, value=list(range(13, 25)), dtype=tf.float32,
            shape=[2, 2, 3])

        concat = quantize_graph.create_node(
            "Concat", concat_name, [axis_name, lhs_name, rhs_name])
        quantize_graph.set_attr_int(concat, "N", 2)
        quantize_graph.set_attr_dtype(concat, "T", tf.float32)

        float_graph_def = tf.GraphDef()
        float_graph_def.node.extend([axis, lhs, rhs, concat])
        test_graph(float_graph_def, {}, [concat_name])
def test_mat_mul(m, n, k, a, b):
  """Tests a MatMul replacement.

  Builds a float MatMul named "mat_mul" from two constants and hands the
  graph to `test_graph`.

  Args:
    m: first dimension of the left operand.
    n: second dimension of the right operand.
    k: shared dimension (second of the left, first of the right operand).
    a: flat values for the [m, k] left operand "a_constant".
    b: flat values for the [k, n] right operand "b_constant".
  """
  lhs_name = "a_constant"
  rhs_name = "b_constant"
  mat_mul_name = "mat_mul"

  lhs = quantize_graph.create_constant_node(
      lhs_name, value=a, dtype=tf.float32, shape=[m, k])
  rhs = quantize_graph.create_constant_node(
      rhs_name, value=b, dtype=tf.float32, shape=[k, n])

  product = quantize_graph.create_node("MatMul", mat_mul_name,
                                       [lhs_name, rhs_name])
  quantize_graph.set_attr_dtype(product, "T", tf.float32)
  # Neither operand is transposed.
  for transpose_attr in ("transpose_a", "transpose_b"):
    quantize_graph.set_attr_bool(product, transpose_attr, False)

  float_graph_def = tf.GraphDef()
  float_graph_def.node.extend([lhs, rhs, product])
  test_graph(float_graph_def, {}, [mat_mul_name])
  def test_concat(self):
    """Concatenate two 2x2x3 float constants with a Concat node.

    "a_constant" holds 1..12, "b_constant" holds 13..24, and the scalar
    int32 "shape_constant" (value 0) is Concat's leading input.
    """
    concat_name = "concat"
    axis_name = "shape_constant"
    operand_names = ["a_constant", "b_constant"]

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([
        quantize_graph.create_constant_node(axis_name, value=0,
                                            dtype=tf.int32, shape=[])])
    # The operands hold consecutive runs of twelve integers: 1..12, 13..24.
    for index, operand_name in enumerate(operand_names):
      first = 1 + 12 * index
      float_graph_def.node.extend([quantize_graph.create_constant_node(
          operand_name, value=list(range(first, first + 12)),
          dtype=tf.float32, shape=[2, 2, 3])])

    concat_node = quantize_graph.create_node(
        "Concat", concat_name, [axis_name] + operand_names)
    quantize_graph.set_attr_int(concat_node, "N", 2)
    quantize_graph.set_attr_dtype(concat_node, "T", tf.float32)
    float_graph_def.node.extend([concat_node])

    test_graph(float_graph_def, {}, [concat_name])
def test_mat_mul(m, n, k, a, b):
    """Tests a MatMul replacement.

    Builds "a_constant" ([m, k]) times "b_constant" ([k, n]) as a float
    MatMul named "mat_mul" and hands the graph to `test_graph`.

    Args:
        m: first dimension of the left operand.
        n: second dimension of the right operand.
        k: shared dimension of the two operands.
        a: flat values for the left operand.
        b: flat values for the right operand.
    """
    mat_mul_name = "mat_mul"
    # (node name, values, shape) for each operand, in MatMul input order.
    operand_specs = (("a_constant", a, [m, k]),
                     ("b_constant", b, [k, n]))

    float_graph_def = tf.GraphDef()
    for const_name, values, dims in operand_specs:
        float_graph_def.node.extend([
            quantize_graph.create_constant_node(
                const_name, value=values, dtype=tf.float32, shape=dims)])

    mat_mul_node = quantize_graph.create_node(
        "MatMul", mat_mul_name, [spec[0] for spec in operand_specs])
    quantize_graph.set_attr_dtype(mat_mul_node, "T", tf.float32)
    quantize_graph.set_attr_bool(mat_mul_node, "transpose_a", False)
    quantize_graph.set_attr_bool(mat_mul_node, "transpose_b", False)
    float_graph_def.node.extend([mat_mul_node])

    test_graph(float_graph_def, {}, [mat_mul_name])
def test_conv(depth, image_width, image_height, image_batch_count, filter_size,
              filter_count, stride, padding, input_values, filter_values):
    """Tests a Conv replacement.

    Builds a float Conv2D named "conv" and hands the graph to `test_graph`.
    The input constant is shaped
    [image_batch_count, image_height, image_width, depth]. The filter
    constant is shaped [filter_size, filter_size, depth, filter_count].

    Args:
        depth, image_width, image_height, image_batch_count: input shape.
        filter_size, filter_count: filter shape (depth is shared).
        stride: stride for the two middle axes of the "strides" attribute.
        padding: value stored in the "padding" attribute.
        input_values, filter_values: flat values for the two constants.
    """
    conv_name = "conv"
    input_name, filter_name = "input_constant", "filter_constant"

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([
        quantize_graph.create_constant_node(
            input_name, value=input_values, dtype=tf.float32,
            shape=[image_batch_count, image_height, image_width, depth]),
        quantize_graph.create_constant_node(
            filter_name, value=filter_values, dtype=tf.float32,
            shape=[filter_size, filter_size, depth, filter_count]),
    ])

    conv_node = quantize_graph.create_node("Conv2D", conv_name,
                                           [input_name, filter_name])
    quantize_graph.set_attr_dtype(conv_node, "T", tf.float32)
    # The first and last strides stay 1; the same stride applies to both
    # middle axes.
    spatial_strides = [1, stride, stride, 1]
    quantize_graph.set_attr_int_list(conv_node, "strides", spatial_strides)
    quantize_graph.set_attr_string(conv_node, "padding", padding)
    float_graph_def.node.extend([conv_node])

    test_graph(float_graph_def, {}, [conv_name])
  def test_multiple_outputs(self):
    """Split "input_constant" into two pieces, then Concat them back.

    The Concat reads "split:0" and "split:1". The graph therefore contains a
    node whose individual outputs are consumed by index.
    """
    split_name = "split"
    concat_name = "concat"

    def int_scalar(name, value):
      # Scalar int32 constant, used for the Split/Concat leading inputs.
      return quantize_graph.create_constant_node(name, value=value,
                                                 dtype=tf.int32, shape=[])

    nodes = [
        quantize_graph.create_constant_node(
            "input_constant",
            value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
            dtype=tf.float32,
            shape=[2, 6]),
        int_scalar("split_constant", 1),
    ]

    split_node = quantize_graph.create_node(
        "Split", split_name, ["split_constant", "input_constant"])
    quantize_graph.set_attr_int(split_node, "num_split", 2)
    quantize_graph.set_attr_dtype(split_node, "T", tf.float32)
    nodes.append(split_node)

    nodes.append(int_scalar("concat_constant", 1))

    concat_node = quantize_graph.create_node(
        "Concat", concat_name,
        ["concat_constant", split_name + ":0", split_name + ":1"])
    quantize_graph.set_attr_int(concat_node, "N", 2)
    quantize_graph.set_attr_dtype(concat_node, "T", tf.float32)
    nodes.append(concat_node)

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend(nodes)
    test_graph(float_graph_def, {}, [concat_name])
  def test_remove_unneeded_nodes(self):
    """Check that remove_unneeded_nodes drops CheckNumerics/Identity nodes.

    Input graph, for each of the two operands: a float scalar constant, a
    CheckNumerics of it, and an Identity of the constant carrying a control
    dependency ("^<check>") on that CheckNumerics. "add" sums the two
    Identity outputs. After rewriting and extracting the subgraph feeding
    "add", the expected graph holds only the two constants and an "add" that
    reads them directly.
    """
    add_name = "add"
    # (constant, check, identity) node names for each Add operand, in order.
    operands = (("a_constant", "a_check", "a_identity"),
                ("b_constant", "b_check", "b_identity"))

    graph_def = tf.GraphDef()
    for const_name, check_name, identity_name in operands:
      const_node = quantize_graph.create_constant_node(
          const_name, value=1, dtype=tf.float32, shape=[])
      check_node = quantize_graph.create_node("CheckNumerics", check_name,
                                              [const_name])
      identity_node = quantize_graph.create_node(
          "Identity", identity_name, [const_name, "^" + check_name])
      graph_def.node.extend([const_node, check_node, identity_node])
    add_node = quantize_graph.create_node(
        "Add", add_name, [identity for _, _, identity in operands])
    quantize_graph.set_attr_dtype(add_node, "T", tf.float32)
    graph_def.node.extend([add_node])

    expected_output = tf.GraphDef()
    for const_name, _, _ in operands:
      expected_output.node.extend([quantize_graph.create_constant_node(
          const_name, value=1, dtype=tf.float32, shape=[])])
    expected_add = quantize_graph.create_node(
        "Add", add_name, [const for const, _, _ in operands])
    quantize_graph.set_attr_dtype(expected_add, "T", tf.float32)
    expected_output.node.extend([expected_add])

    rewriter = quantize_graph.GraphRewriter(graph_def, [add_name])
    stripped_output = graph_util.extract_sub_graph(
        rewriter.remove_unneeded_nodes(graph_def), [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
Example #10
0
  def test_remove_unneeded_nodes(self):
    """remove_unneeded_nodes should collapse CheckNumerics/Identity chains.

    Input graph, for each operand: a constant, a CheckNumerics of it, and an
    Identity of the constant with a control dependency on that CheckNumerics.
    "add" sums the two Identity outputs. Expected graph (after extracting the
    subgraph that feeds "add"): just the two constants and an "add" that
    reads them.
    """
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    add_name = "add"

    def scalar_one(name):
      # Float32 scalar constant holding 1.
      return quantize_graph.create_constant_node(name, value=1,
                                                 dtype=tf.float32, shape=[])

    def float_add(lhs, rhs):
      # "add" node with T=float32 over the two named inputs.
      node = quantize_graph.create_node("Add", add_name, [lhs, rhs])
      quantize_graph.set_attr_dtype(node, "T", tf.float32)
      return node

    graph_def = tf.GraphDef()
    for const_name, check_name, identity_name in (
        (a_constant_name, "a_check", "a_identity"),
        (b_constant_name, "b_check", "b_identity")):
      graph_def.node.extend([
          scalar_one(const_name),
          quantize_graph.create_node("CheckNumerics", check_name,
                                     [const_name]),
          quantize_graph.create_node("Identity", identity_name,
                                     [const_name, "^" + check_name]),
      ])
    graph_def.node.extend([float_add("a_identity", "b_identity")])

    expected_output = tf.GraphDef()
    expected_output.node.extend([scalar_one(a_constant_name),
                                 scalar_one(b_constant_name),
                                 float_add(a_constant_name, b_constant_name)])

    rewriter = quantize_graph.GraphRewriter(graph_def, [add_name])
    output = rewriter.remove_unneeded_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
    def test_identity(self):
        """Feed an Identity of a 2x6 float constant into both inputs of a Mul."""
        input_name = "input_constant"
        identity_name = "identity"
        mul_name = "mul"

        source = quantize_graph.create_constant_node(
            input_name, value=list(range(1, 13)), dtype=tf.float32,
            shape=[2, 6])

        passthrough = quantize_graph.create_node("Identity", identity_name,
                                                 [input_name])
        quantize_graph.set_attr_dtype(passthrough, "T", tf.float32)

        # The same Identity output is used for both Mul operands.
        square = quantize_graph.create_node("Mul", mul_name,
                                            [identity_name] * 2)
        quantize_graph.set_attr_dtype(square, "T", tf.float32)

        float_graph_def = tf.GraphDef()
        float_graph_def.node.extend([source, passthrough, square])
        test_graph(float_graph_def, {}, [mul_name])
  def test_identity(self):
    """Multiply the output of an Identity over a 2x6 constant by itself."""
    input_constant_name = "input_constant"
    identity_name = "identity"
    mul_name = "mul"

    # (op, name, inputs) for each non-constant node, in graph order. Every
    # one of them gets T=float32.
    op_specs = [("Identity", identity_name, [input_constant_name]),
                ("Mul", mul_name, [identity_name, identity_name])]

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([quantize_graph.create_constant_node(
        input_constant_name, value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
        dtype=tf.float32, shape=[2, 6])])
    for op, node_name, inputs in op_specs:
      node = quantize_graph.create_node(op, node_name, inputs)
      quantize_graph.set_attr_dtype(node, "T", tf.float32)
      float_graph_def.node.extend([node])

    test_graph(float_graph_def, {}, [mul_name])
 def test_relu6(self):
     """Apply Relu6 to a 1x2x6x1 float constant holding 1..12."""
     relu6_name = "relu6"
     input_name = "input_constant"

     relu6_input = quantize_graph.create_constant_node(
         input_name, value=list(range(1, 13)), dtype=tf.float32,
         shape=[1, 2, 6, 1])
     relu6_node = quantize_graph.create_node("Relu6", relu6_name, [input_name])
     quantize_graph.set_attr_dtype(relu6_node, "T", tf.float32)

     float_graph_def = tf.GraphDef()
     float_graph_def.node.extend([relu6_input, relu6_node])
     test_graph(float_graph_def, {}, [relu6_name])
 def test_relu6(self):
   """Run a Relu6 over a 1x2x6x1 float constant holding 1..12."""
   input_constant_name = "input_constant"
   relu6_name = "relu6"

   graph_nodes = [
       quantize_graph.create_constant_node(
           input_constant_name,
           value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
           dtype=tf.float32,
           shape=[1, 2, 6, 1]),
       quantize_graph.create_node("Relu6", relu6_name, [input_constant_name]),
   ]
   # Set the attribute on the Relu6 before the nodes are copied into the graph.
   quantize_graph.set_attr_dtype(graph_nodes[-1], "T", tf.float32)

   float_graph_def = tf.GraphDef()
   float_graph_def.node.extend(graph_nodes)
   test_graph(float_graph_def, {}, [relu6_name])
 def test_batch_norm(self):
     """Exercise BatchNormWithGlobalNormalization on float constants.

     Five float constants feed the op in the order input, mean, variance,
     beta, gamma. scale_after_normalization is False and variance_epsilon is
     0.001.
     """
     batch_norm_name = "batch_norm"
     float_graph_def = tf.GraphDef()

     # (node name, values, shape) for every constant input, in input order.
     constant_specs = [
         ("input_constant", [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6],
          [1, 1, 6, 2]),
         ("mean_constant", [10, 20], [2]),
         ("variance_constant", [0.25, 0.5], [2]),
         ("beta_constant", [0.1, 0.6], [2]),
         ("gamma_constant", [0, 0], [2]),
     ]
     for const_name, values, dims in constant_specs:
         float_graph_def.node.extend([quantize_graph.create_constant_node(
             const_name, value=values, dtype=tf.float32, shape=dims)])

     batch_norm_node = quantize_graph.create_node(
         "BatchNormWithGlobalNormalization", batch_norm_name,
         [spec[0] for spec in constant_specs])
     quantize_graph.set_attr_dtype(batch_norm_node, "T", tf.float32)
     quantize_graph.set_attr_bool(batch_norm_node,
                                  "scale_after_normalization", False)
     quantize_graph.set_attr_float(batch_norm_node, "variance_epsilon", 0.001)
     float_graph_def.node.extend([batch_norm_node])
     test_graph(float_graph_def, {}, [batch_norm_name])
 def test_max_pool(self):
     """MaxPool (ksize 1x2x2x1, stride 1, SAME) over a 1x2x6x1 constant."""
     max_pool_name = "max_pool"
     input_name = "input_constant"

     pool_input = quantize_graph.create_constant_node(
         input_name, value=list(range(1, 13)), dtype=tf.float32,
         shape=[1, 2, 6, 1])

     pool = quantize_graph.create_node("MaxPool", max_pool_name, [input_name])
     # Only the two middle axes get a window larger than 1; every stride is 1.
     quantize_graph.set_attr_int_list(pool, "ksize", [1, 2, 2, 1])
     quantize_graph.set_attr_int_list(pool, "strides", [1] * 4)
     quantize_graph.set_attr_string(pool, "padding", b"SAME")

     float_graph_def = tf.GraphDef()
     float_graph_def.node.extend([pool_input, pool])
     test_graph(float_graph_def, {}, [max_pool_name])
 def test_max_pool(self):
   """Run a SAME-padded, stride-1 MaxPool over 1..12 shaped 1x2x6x1."""
   input_constant_name = "input_constant"
   max_pool_name = "max_pool"

   input_constant = quantize_graph.create_constant_node(
       input_constant_name,
       value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
       dtype=tf.float32,
       shape=[1, 2, 6, 1])
   max_pool_node = quantize_graph.create_node("MaxPool", max_pool_name,
                                              [input_constant_name])
   for attr_name, attr_value in (("ksize", [1, 2, 2, 1]),
                                 ("strides", [1, 1, 1, 1])):
     quantize_graph.set_attr_int_list(max_pool_node, attr_name, attr_value)
   quantize_graph.set_attr_string(max_pool_node, "padding", b"SAME")

   float_graph_def = tf.GraphDef()
   float_graph_def.node.extend([input_constant, max_pool_node])
   test_graph(float_graph_def, {}, [max_pool_name])
 def test_batch_norm(self):
   """Run BatchNormWithGlobalNormalization over a 1x1x6x2 float constant.

   Length-2 mean, variance, beta and gamma constants (matching the input's
   last dimension) are supplied. scale_after_normalization is False and
   variance_epsilon is 0.001.
   """
   batch_norm_name = "batch_norm"
   float_graph_def = tf.GraphDef()

   def float_constant(name, values, dims):
     # Append a float32 constant node to the graph and return its name.
     node = quantize_graph.create_constant_node(name, value=values,
                                                dtype=tf.float32, shape=dims)
     float_graph_def.node.extend([node])
     return name

   # List items are evaluated left to right, so graph order matches input
   # order.
   inputs = [
       float_constant("input_constant",
                      [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6],
                      [1, 1, 6, 2]),
       float_constant("mean_constant", [10, 20], [2]),
       float_constant("variance_constant", [0.25, 0.5], [2]),
       float_constant("beta_constant", [0.1, 0.6], [2]),
       float_constant("gamma_constant", [0, 0], [2]),
   ]

   batch_norm_node = quantize_graph.create_node(
       "BatchNormWithGlobalNormalization", batch_norm_name, inputs)
   quantize_graph.set_attr_dtype(batch_norm_node, "T", tf.float32)
   quantize_graph.set_attr_bool(batch_norm_node, "scale_after_normalization",
                                False)
   quantize_graph.set_attr_float(batch_norm_node, "variance_epsilon", 0.001)
   float_graph_def.node.extend([batch_norm_node])
   test_graph(float_graph_def, {}, [batch_norm_name])
 def test_bias_add(self):
     """BiasAdd of a 6-element offset onto a 1x1x2x6 float constant."""
     bias_add_name = "bias_add"
     input_name = "input_constant"
     offset_name = "offset_constant"

     # The offset length (6) matches the input's last dimension.
     value_input = quantize_graph.create_constant_node(
         input_name, value=list(range(1, 13)), dtype=tf.float32,
         shape=[1, 1, 2, 6])
     offset = quantize_graph.create_constant_node(
         offset_name, value=list(range(1, 7)), dtype=tf.float32, shape=[6])

     bias_add = quantize_graph.create_node("BiasAdd", bias_add_name,
                                           [input_name, offset_name])
     quantize_graph.set_attr_dtype(bias_add, "T", tf.float32)

     float_graph_def = tf.GraphDef()
     float_graph_def.node.extend([value_input, offset, bias_add])
     test_graph(float_graph_def, {}, [bias_add_name])
 def test_bias_add(self):
   """Add a length-6 bias to a 1x1x2x6 float constant with BiasAdd."""
   bias_add_name = "bias_add"
   # (name, values, shape) for each float constant, in BiasAdd input order.
   constant_specs = (
       ("input_constant", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
        [1, 1, 2, 6]),
       ("offset_constant", [1, 2, 3, 4, 5, 6], [6]),
   )

   float_graph_def = tf.GraphDef()
   for const_name, values, dims in constant_specs:
     float_graph_def.node.extend([quantize_graph.create_constant_node(
         const_name, value=values, dtype=tf.float32, shape=dims)])

   bias_add_node = quantize_graph.create_node(
       "BiasAdd", bias_add_name, [spec[0] for spec in constant_specs])
   quantize_graph.set_attr_dtype(bias_add_node, "T", tf.float32)
   float_graph_def.node.extend([bias_add_node])
   test_graph(float_graph_def, {}, [bias_add_name])
    def test_remove_redundant_quantization(self):
        """Checks that Dequantize -> QuantizeV2 round trips are removed.

        Each of two quint8 constants ("a" and "b") is paired with float32
        min/max constants, dequantized, and immediately re-quantized before
        feeding a QuantizedMatMul.  After
        GraphRewriter.remove_redundant_quantization and extract_sub_graph,
        the matmul is expected to consume the constants and their min/max
        nodes directly, with the round-trip nodes gone.
        """
        mat_mul_name = "mat_mul"
        # (name prefix, value used for both the min and max constants).
        operands = (("a", 2), ("b", 3))

        def make_constants(prefix, bound):
            # quint8 constant plus its float32 min and max constants.
            constant_name = prefix + "_constant"
            return [
                quantize_graph.create_constant_node(
                    constant_name, value=(0, ), dtype=tf.quint8, shape=[]),
                quantize_graph.create_constant_node(
                    constant_name + "_min", value=bound, dtype=tf.float32,
                    shape=[]),
                quantize_graph.create_constant_node(
                    constant_name + "_max", value=bound, dtype=tf.float32,
                    shape=[]),
            ]

        def make_round_trip(prefix):
            # Dequantize of the constant followed by a QuantizeV2 of that
            # result -- the pattern the expected output shows removed.
            constant_name = prefix + "_constant"
            dequantize_name = prefix + "_dequantize"
            dequantize_node = quantize_graph.create_node(
                "Dequantize", dequantize_name,
                [constant_name, constant_name + "_min",
                 constant_name + "_max"])
            quantize_graph.set_attr_dtype(dequantize_node, "T", tf.uint8)
            quantize_node = quantize_graph.create_node(
                "QuantizeV2", prefix + "_quantize",
                [dequantize_name + suffix for suffix in ("", ":1", ":2")])
            quantize_graph.set_attr_dtype(quantize_node, "T", tf.uint8)
            return [dequantize_node, quantize_node]

        def make_mat_mul(inputs):
            # QuantizedMatMul named mat_mul_name with the test's T1/T2 attrs.
            node = quantize_graph.create_node(
                "QuantizedMatMul", mat_mul_name, inputs)
            quantize_graph.set_attr_dtype(node, "T1", tf.uint8)
            quantize_graph.set_attr_dtype(node, "T2", tf.int32)
            return node

        graph_def = tf.GraphDef()
        for prefix, bound in operands:
            graph_def.node.extend(
                make_constants(prefix, bound) + make_round_trip(prefix))
        graph_def.node.extend([
            make_mat_mul([
                "a_quantize", "b_quantize", "a_quantize:1", "a_quantize:2",
                "b_quantize:1", "b_quantize:2"
            ])
        ])

        expected_output = tf.GraphDef()
        for prefix, bound in operands:
            expected_output.node.extend(make_constants(prefix, bound))
        expected_output.node.extend([
            make_mat_mul([
                "a_constant", "b_constant", "a_constant_min",
                "a_constant_max", "b_constant_min", "b_constant_max"
            ])
        ])

        rewriter = quantize_graph.GraphRewriter(graph_def, [mat_mul_name])
        output = rewriter.remove_redundant_quantization(graph_def)
        stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
        self.assertProtoEquals(expected_output, stripped_output)
  def test_keep_control_edges(self):
    """Checks that remove_training_nodes keeps non-training control edges.

    a_identity has control inputs on both a CheckNumerics node and a NoOp;
    b_identity's only control input is a CheckNumerics node.  After
    graph_util.remove_training_nodes and extract_sub_graph on "add", the
    expected graph has no CheckNumerics nodes, keeps a_identity with its
    "^no_op" control input, and wires add directly to b_constant.
    """
    no_op_name = "no_op"
    add_name = "add"

    def scalar_one(name):
      # Scalar float32 constant holding 1.
      return quantize_graph.create_constant_node(
          name, value=1, dtype=tf.float32, shape=[])

    def check_numerics(name, source):
      return quantize_graph.create_node("CheckNumerics", name, [source])

    def identity(name, source, *controls):
      # Control-dependency inputs are written as "^<node name>".
      return quantize_graph.create_node(
          "Identity", name, [source] + ["^" + ctrl for ctrl in controls])

    def make_add(lhs, rhs):
      node = quantize_graph.create_node("Add", add_name, [lhs, rhs])
      quantize_graph.set_attr_dtype(node, "T", tf.float32)
      return node

    graph_def = tf.GraphDef()
    graph_def.node.extend([
        quantize_graph.create_node("NoOp", no_op_name, []),
        scalar_one("a_constant"),
        check_numerics("a_check", "a_constant"),
        identity("a_identity", "a_constant", "a_check", no_op_name),
        scalar_one("b_constant"),
        check_numerics("b_check", "b_constant"),
        identity("b_identity", "b_constant", "b_check"),
        make_add("a_identity", "b_identity"),
    ])

    expected_output = tf.GraphDef()
    expected_output.node.extend([
        quantize_graph.create_node("NoOp", no_op_name, []),
        scalar_one("a_constant"),
        identity("a_identity", "a_constant", no_op_name),
        scalar_one("b_constant"),
        make_add("a_identity", "b_constant"),
    ])

    output = graph_util.remove_training_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
Beispiel #23
0
    def test_keep_control_edges(self):
        """Checks that remove_training_nodes keeps non-training control edges.

        a_identity has control inputs on both a CheckNumerics node and a
        NoOp; b_identity's only control input is a CheckNumerics node.
        After graph_util.remove_training_nodes and extract_sub_graph on
        "add", the expected graph has no CheckNumerics nodes, keeps
        a_identity with its "^no_op" control input, and wires add directly
        to b_constant.
        """
        no_op_name = "no_op"
        add_name = "add"

        def scalar_one(name):
            # Scalar float32 constant holding 1.
            return quantize_graph.create_constant_node(
                name, value=1, dtype=tf.float32, shape=[])

        def check_numerics(name, source):
            return quantize_graph.create_node("CheckNumerics", name, [source])

        def identity(name, source, *controls):
            # Control-dependency inputs are written as "^<node name>".
            return quantize_graph.create_node(
                "Identity", name, [source] + ["^" + ctrl for ctrl in controls])

        def make_add(lhs, rhs):
            node = quantize_graph.create_node("Add", add_name, [lhs, rhs])
            quantize_graph.set_attr_dtype(node, "T", tf.float32)
            return node

        graph_def = tf.GraphDef()
        graph_def.node.extend([
            quantize_graph.create_node("NoOp", no_op_name, []),
            scalar_one("a_constant"),
            check_numerics("a_check", "a_constant"),
            identity("a_identity", "a_constant", "a_check", no_op_name),
            scalar_one("b_constant"),
            check_numerics("b_check", "b_constant"),
            identity("b_identity", "b_constant", "b_check"),
            make_add("a_identity", "b_identity"),
        ])

        expected_output = tf.GraphDef()
        expected_output.node.extend([
            quantize_graph.create_node("NoOp", no_op_name, []),
            scalar_one("a_constant"),
            identity("a_identity", "a_constant", no_op_name),
            scalar_one("b_constant"),
            make_add("a_identity", "b_constant"),
        ])

        output = graph_util.remove_training_nodes(graph_def)
        stripped_output = graph_util.extract_sub_graph(output, [add_name])
        self.assertProtoEquals(expected_output, stripped_output)
  def test_remove_redundant_quantization(self):
    """Checks that Dequantize -> QuantizeV2 round trips are removed.

    Each of two quint8 constants ("a" and "b") is paired with float32
    min/max constants, dequantized, and immediately re-quantized before
    feeding a QuantizedMatMul.  After
    GraphRewriter.remove_redundant_quantization and extract_sub_graph, the
    matmul is expected to consume the constants and their min/max nodes
    directly, with the round-trip nodes gone.
    """
    mat_mul_name = "mat_mul"
    # (name prefix, value used for both the min and max constants).
    operands = (("a", 2), ("b", 3))

    def make_constants(prefix, bound):
      # quint8 constant plus its float32 min and max constants.
      constant_name = prefix + "_constant"
      return [
          quantize_graph.create_constant_node(
              constant_name, value=(0,), dtype=tf.quint8, shape=[]),
          quantize_graph.create_constant_node(
              constant_name + "_min", value=bound, dtype=tf.float32,
              shape=[]),
          quantize_graph.create_constant_node(
              constant_name + "_max", value=bound, dtype=tf.float32,
              shape=[]),
      ]

    def make_round_trip(prefix):
      # Dequantize of the constant followed by a QuantizeV2 of that result --
      # the pattern the expected output shows removed.
      constant_name = prefix + "_constant"
      dequantize_name = prefix + "_dequantize"
      dequantize_node = quantize_graph.create_node(
          "Dequantize", dequantize_name,
          [constant_name, constant_name + "_min", constant_name + "_max"])
      quantize_graph.set_attr_dtype(dequantize_node, "T", tf.uint8)
      quantize_node = quantize_graph.create_node(
          "QuantizeV2", prefix + "_quantize",
          [dequantize_name + suffix for suffix in ("", ":1", ":2")])
      quantize_graph.set_attr_dtype(quantize_node, "T", tf.uint8)
      return [dequantize_node, quantize_node]

    def make_mat_mul(inputs):
      # QuantizedMatMul named mat_mul_name with the test's T1/T2 attrs.
      node = quantize_graph.create_node(
          "QuantizedMatMul", mat_mul_name, inputs)
      quantize_graph.set_attr_dtype(node, "T1", tf.uint8)
      quantize_graph.set_attr_dtype(node, "T2", tf.int32)
      return node

    graph_def = tf.GraphDef()
    for prefix, bound in operands:
      graph_def.node.extend(
          make_constants(prefix, bound) + make_round_trip(prefix))
    graph_def.node.extend([
        make_mat_mul(["a_quantize", "b_quantize",
                      "a_quantize:1", "a_quantize:2",
                      "b_quantize:1", "b_quantize:2"])
    ])

    expected_output = tf.GraphDef()
    for prefix, bound in operands:
      expected_output.node.extend(make_constants(prefix, bound))
    expected_output.node.extend([
        make_mat_mul(["a_constant", "b_constant",
                      "a_constant_min", "a_constant_max",
                      "b_constant_min", "b_constant_max"])
    ])

    rewriter = quantize_graph.GraphRewriter(graph_def, [mat_mul_name])
    output = rewriter.remove_redundant_quantization(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
    self.assertProtoEquals(expected_output, stripped_output)