def test_conv(depth, image_width, image_height, image_batch_count, filter_size,
              filter_count, stride, padding, input_values, filter_values):
  """Builds a float32 Conv2D graph and runs it through `test_graph`.

  Two float32 constants ("input_constant" and "filter_constant") feed a
  single Conv2D node named "conv". That node is the output requested from
  `test_graph`.

  Args:
    depth: Size of the last input dimension; also the filter's third
      dimension.
    image_width: Third dimension of the input constant's shape.
    image_height: Second dimension of the input constant's shape.
    image_batch_count: First dimension of the input constant's shape.
    filter_size: Used for both of the filter's first two dimensions.
    filter_count: Last dimension of the filter constant's shape.
    stride: Stride applied to the two middle (height/width) dimensions.
    padding: Value stored in the Conv2D "padding" attribute (presumably
      bytes such as b"SAME", as the pooling tests use — confirm at callers).
    input_values: Values passed to the input constant node.
    filter_values: Values passed to the filter constant node.
  """
  input_name = "input_constant"
  filter_name = "filter_constant"
  output_name = "conv"

  image_node = quantize_graph.create_constant_node(
      input_name,
      value=input_values,
      dtype=dtypes.float32,
      shape=[image_batch_count, image_height, image_width, depth])
  kernel_node = quantize_graph.create_constant_node(
      filter_name,
      value=filter_values,
      dtype=dtypes.float32,
      shape=[filter_size, filter_size, depth, filter_count])

  conv_node = quantize_graph.create_node("Conv2D", output_name,
                                         [input_name, filter_name])
  quantize_graph.set_attr_dtype(conv_node, "T", dtypes.float32)
  # Only the height/width positions of the input shape are strided; the
  # batch and depth positions keep a stride of 1.
  quantize_graph.set_attr_int_list(conv_node, "strides",
                                   [1, stride, stride, 1])
  quantize_graph.set_attr_string(conv_node, "padding", padding)

  graph = graph_pb2.GraphDef()
  # Node order in the GraphDef: input, filter, conv.
  graph.node.extend([image_node, kernel_node, conv_node])
  test_graph(graph, {}, [output_name])
 def test_max_pool(self):
   """Runs a single float32 MaxPool node through `test_graph`.

   A constant of shape [1, 2, 6, 1] holding the values 1..12 feeds a
   MaxPool with ksize [1, 2, 2, 1], unit strides and SAME padding.
   """
   input_constant_name = "input_constant"
   max_pool_name = "max_pool"
   float_graph_def = graph_pb2.GraphDef()
   input_constant = quantize_graph.create_constant_node(
       input_constant_name,
       value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
       dtype=dtypes.float32,
       shape=[1, 2, 6, 1])
   float_graph_def.node.extend([input_constant])
   max_pool_node = quantize_graph.create_node("MaxPool", max_pool_name,
                                              [input_constant_name])
   # Set the element type explicitly, matching the conv and avg-pool tests,
   # instead of relying on whatever default the MaxPool op registers.
   quantize_graph.set_attr_dtype(max_pool_node, "T", dtypes.float32)
   quantize_graph.set_attr_int_list(max_pool_node, "ksize", [1, 2, 2, 1])
   quantize_graph.set_attr_int_list(max_pool_node, "strides", [1, 1, 1, 1])
   quantize_graph.set_attr_string(max_pool_node, "padding", b"SAME")
   float_graph_def.node.extend([max_pool_node])
   test_graph(float_graph_def, {}, [max_pool_name])
 def test_avg_pool(self):
   """Runs a single float32 AvgPool node through `test_graph`.

   A constant of shape [1, 2, 6, 1] holding the values 1..12 feeds an
   AvgPool with ksize [1, 2, 2, 1], unit strides and SAME padding.

   NOTE(review): another `test_avg_pool` with the same graph is defined
   later in this file. If both sit in the same class, the later definition
   replaces this one and this test never runs; one of them should be
   removed.
   """
   input_constant_name = "input_constant"
   avg_pool_name = "avg_pool"
   # Use graph_pb2/dtypes like the rest of the file: the TF2 top-level `tf`
   # namespace no longer exposes GraphDef (it lives under tf.compat.v1).
   float_graph_def = graph_pb2.GraphDef()
   input_constant = quantize_graph.create_constant_node(
       input_constant_name,
       value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
       dtype=dtypes.float32,
       shape=[1, 2, 6, 1])
   float_graph_def.node.extend([input_constant])
   avg_pool_node = quantize_graph.create_node("AvgPool", avg_pool_name,
                                              [input_constant_name])
   quantize_graph.set_attr_dtype(avg_pool_node, "T", dtypes.float32)
   quantize_graph.set_attr_int_list(avg_pool_node, "ksize", [1, 2, 2, 1])
   quantize_graph.set_attr_int_list(avg_pool_node, "strides", [1, 1, 1, 1])
   quantize_graph.set_attr_string(avg_pool_node, "padding", b"SAME")
   float_graph_def.node.extend([avg_pool_node])
   test_graph(float_graph_def, {}, [avg_pool_name])
 def test_avg_pool(self):
   """Runs a single float32 AvgPool node through `test_graph`.

   A constant of shape [1, 2, 6, 1] holding the values 1..12 feeds an
   AvgPool with ksize [1, 2, 2, 1], unit strides and SAME padding.

   NOTE(review): this redefines `test_avg_pool`; an earlier definition in
   this file builds the same graph. If both sit in the same class, only
   this later one takes effect, so the duplicate should be removed.
   """
   input_constant_name = "input_constant"
   avg_pool_name = "avg_pool"
   # Use graph_pb2/dtypes like the rest of the file: the TF2 top-level `tf`
   # namespace no longer exposes GraphDef (it lives under tf.compat.v1).
   float_graph_def = graph_pb2.GraphDef()
   input_constant = quantize_graph.create_constant_node(
       input_constant_name,
       value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
       dtype=dtypes.float32,
       shape=[1, 2, 6, 1])
   float_graph_def.node.extend([input_constant])
   avg_pool_node = quantize_graph.create_node("AvgPool", avg_pool_name,
                                              [input_constant_name])
   quantize_graph.set_attr_dtype(avg_pool_node, "T", dtypes.float32)
   quantize_graph.set_attr_int_list(avg_pool_node, "ksize", [1, 2, 2, 1])
   quantize_graph.set_attr_int_list(avg_pool_node, "strides", [1, 1, 1, 1])
   quantize_graph.set_attr_string(avg_pool_node, "padding", b"SAME")
   float_graph_def.node.extend([avg_pool_node])
   test_graph(float_graph_def, {}, [avg_pool_name])