def test_mat_mul(m, n, k, a, b):
    """Builds a float MatMul graph over two constants and runs test_graph.

    Args:
      m: First dimension of the left operand's shape, [m, k].
      n: Second dimension of the right operand's shape, [k, n].
      k: Dimension shared by both operands.
      a: Values for the left constant.
      b: Values for the right constant.
    """
    lhs_name = "a_constant"
    rhs_name = "b_constant"
    product_name = "mat_mul"

    lhs_node = quantize_graph.create_constant_node(
        lhs_name, value=a, dtype=tf.float32, shape=[m, k])
    rhs_node = quantize_graph.create_constant_node(
        rhs_name, value=b, dtype=tf.float32, shape=[k, n])

    # Finish configuring the MatMul node before it is added to the graph.
    product_node = quantize_graph.create_node(
        "MatMul", product_name, [lhs_name, rhs_name])
    quantize_graph.set_attr_dtype(product_node, "T", tf.float32)
    for transpose_attr in ("transpose_a", "transpose_b"):
        quantize_graph.set_attr_bool(product_node, transpose_attr, False)

    graph = tf.GraphDef()
    graph.node.extend([lhs_node, rhs_node, product_node])

    test_graph(graph, {}, [product_name])
示例#2
0
def test_mat_mul(m, n, k, a, b):
  """Runs test_graph on a float MatMul of two constant operands.

  Args:
    m: First dimension of the "a_constant" shape, [m, k].
    n: Second dimension of the "b_constant" shape, [k, n].
    k: Dimension shared by both constants.
    a: Values stored in "a_constant".
    b: Values stored in "b_constant".
  """
  mat_mul_name = "mat_mul"
  # (node name, values, shape) for each operand, left operand first.
  operands = (
      ("a_constant", a, [m, k]),
      ("b_constant", b, [k, n]),
  )

  float_graph_def = tf.GraphDef()
  for operand_name, operand_value, operand_shape in operands:
    operand_node = quantize_graph.create_constant_node(operand_name,
                                                       value=operand_value,
                                                       dtype=tf.float32,
                                                       shape=operand_shape)
    float_graph_def.node.extend([operand_node])

  input_names = [operand_name for operand_name, _, _ in operands]
  mat_mul_node = quantize_graph.create_node("MatMul", mat_mul_name,
                                            input_names)
  quantize_graph.set_attr_dtype(mat_mul_node, "T", tf.float32)
  # Neither operand is transposed.
  quantize_graph.set_attr_bool(mat_mul_node, "transpose_a", False)
  quantize_graph.set_attr_bool(mat_mul_node, "transpose_b", False)
  float_graph_def.node.extend([mat_mul_node])

  test_graph(float_graph_def, {}, [mat_mul_name])
 def test_batch_norm(self):
     """Runs test_graph on a constant-fed BatchNormWithGlobalNormalization.

     The graph holds five float32 constants (input, mean, variance, beta,
     gamma) wired into one batch-norm node that has
     scale_after_normalization disabled and variance_epsilon set to 0.001.
     """
     batch_norm_name = "batch_norm"
     # (node name, values, shape) per constant, listed in the order they
     # are wired into the batch-norm node.
     constant_specs = [
         ("input_constant", [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6],
          [1, 1, 6, 2]),
         ("mean_constant", [10, 20], [2]),
         ("variance_constant", [0.25, 0.5], [2]),
         ("beta_constant", [0.1, 0.6], [2]),
         ("gamma_constant", [0, 0], [2]),
     ]

     float_graph_def = tf.GraphDef()
     for const_name, const_value, const_shape in constant_specs:
         const_node = quantize_graph.create_constant_node(
             const_name,
             value=const_value,
             dtype=tf.float32,
             shape=const_shape)
         float_graph_def.node.extend([const_node])

     input_names = [spec[0] for spec in constant_specs]
     batch_norm_node = quantize_graph.create_node(
         "BatchNormWithGlobalNormalization", batch_norm_name, input_names)
     quantize_graph.set_attr_dtype(batch_norm_node, "T", tf.float32)
     quantize_graph.set_attr_bool(batch_norm_node,
                                  "scale_after_normalization", False)
     quantize_graph.set_attr_float(batch_norm_node, "variance_epsilon",
                                   0.001)
     float_graph_def.node.extend([batch_norm_node])

     test_graph(float_graph_def, {}, [batch_norm_name])
示例#4
0
 def test_batch_norm(self):
   """Builds a batch-norm graph from constants and passes it to test_graph.

   Five float32 constants (input, mean, variance, beta, gamma) feed one
   BatchNormWithGlobalNormalization node; the node has
   scale_after_normalization set to False and variance_epsilon set to 0.001.
   """
   input_name = "input_constant"
   mean_name = "mean_constant"
   variance_name = "variance_constant"
   beta_name = "beta_constant"
   gamma_name = "gamma_constant"
   output_name = "batch_norm"

   input_node = quantize_graph.create_constant_node(
       input_name,
       value=[1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6],
       dtype=tf.float32,
       shape=[1, 1, 6, 2])
   mean_node = quantize_graph.create_constant_node(
       mean_name, value=[10, 20], dtype=tf.float32, shape=[2])
   variance_node = quantize_graph.create_constant_node(
       variance_name, value=[0.25, 0.5], dtype=tf.float32, shape=[2])
   beta_node = quantize_graph.create_constant_node(
       beta_name, value=[0.1, 0.6], dtype=tf.float32, shape=[2])
   gamma_node = quantize_graph.create_constant_node(
       gamma_name, value=[0, 0], dtype=tf.float32, shape=[2])

   # Set every attribute on the op node before it is added to the graph.
   op_inputs = [input_name, mean_name, variance_name, beta_name, gamma_name]
   op_node = quantize_graph.create_node(
       "BatchNormWithGlobalNormalization", output_name, op_inputs)
   quantize_graph.set_attr_dtype(op_node, "T", tf.float32)
   quantize_graph.set_attr_bool(op_node, "scale_after_normalization", False)
   quantize_graph.set_attr_float(op_node, "variance_epsilon", 0.001)

   graph = tf.GraphDef()
   graph.node.extend([
       input_node, mean_node, variance_node, beta_node, gamma_node, op_node
   ])
   test_graph(graph, {}, [output_name])