def test_conv(depth, image_width, image_height, image_batch_count, filter_size,
              filter_count, stride, padding, input_values, filter_values):
    """Builds a float graph with one Conv2D over two constants and checks it.

    The graph holds an input constant, a filter constant and a Conv2D node
    consuming both. It is handed to ``test_graph`` with no feeds, and the
    Conv2D node is the only requested output.

    Args:
      depth: Channel count shared by the input image and the filter.
      image_width: Width of the input image.
      image_height: Height of the input image.
      image_batch_count: Number of images in the input batch.
      filter_size: Height and width of the (square) filter.
      filter_count: Number of filters, i.e. the last filter dimension.
      stride: Stride applied to both spatial dimensions.
      padding: Value stored verbatim in the Conv2D "padding" attribute.
      input_values: Values for the input constant node.
      filter_values: Values for the filter constant node.
    """
    image_name, kernel_name, output_name = (
        "input_constant", "filter_constant", "conv")

    # Shapes are ordered as the caller's parameters name them:
    # image = [batch, height, width, depth],
    # kernel = [size, size, depth, count].
    image_shape = [image_batch_count, image_height, image_width, depth]
    kernel_shape = [filter_size, filter_size, depth, filter_count]

    graph = tf.GraphDef()

    image_node = quantize_graph.create_constant_node(
        image_name, value=input_values, dtype=tf.float32, shape=image_shape)
    kernel_node = quantize_graph.create_constant_node(
        kernel_name, value=filter_values, dtype=tf.float32, shape=kernel_shape)

    conv_op = quantize_graph.create_node(
        "Conv2D", output_name, [image_name, kernel_name])
    quantize_graph.set_attr_dtype(conv_op, "T", tf.float32)
    # Only the two middle (spatial) axes are strided; the outer axes use 1.
    quantize_graph.set_attr_int_list(conv_op, "strides", [1, stride, stride, 1])
    quantize_graph.set_attr_string(conv_op, "padding", padding)

    # Nodes are appended in the same order as before: input, filter, conv.
    graph.node.extend([image_node, kernel_node, conv_op])

    test_graph(graph, {}, [output_name])
def test_conv(depth, image_width, image_height, image_batch_count, filter_size,
              filter_count, stride, padding, input_values, filter_values):
  """Runs test_graph on a graph containing a single Conv2D node.

  Two float32 constants (the image and the filter) feed a Conv2D node. The
  resulting GraphDef is passed to ``test_graph`` with an empty feed dict and
  the Conv2D node as the output.

  Args:
    depth: Channel count shared by the image and the filter.
    image_width: Width of the image.
    image_height: Height of the image.
    image_batch_count: Number of images in the batch.
    filter_size: Height and width of the square filter.
    filter_count: Number of filters.
    stride: Spatial stride used for both height and width.
    padding: Value stored as-is in the Conv2D "padding" attribute.
    input_values: Values for the image constant.
    filter_values: Values for the filter constant.
  """
  src_name = "input_constant"
  weights_name = "filter_constant"
  op_name = "conv"

  def _float_const(node_name, values, dims):
    # Both operands are float32 constants; only name, data and shape differ.
    return quantize_graph.create_constant_node(
        node_name, value=values, dtype=tf.float32, shape=dims)

  graph = tf.GraphDef()
  graph.node.extend([
      _float_const(src_name, input_values,
                   [image_batch_count, image_height, image_width, depth])
  ])
  graph.node.extend([
      _float_const(weights_name, filter_values,
                   [filter_size, filter_size, depth, filter_count])
  ])

  conv = quantize_graph.create_node("Conv2D", op_name, [src_name, weights_name])
  quantize_graph.set_attr_dtype(conv, "T", tf.float32)
  spatial_strides = [1, stride, stride, 1]
  quantize_graph.set_attr_int_list(conv, "strides", spatial_strides)
  quantize_graph.set_attr_string(conv, "padding", padding)
  graph.node.extend([conv])

  test_graph(graph, {}, [op_name])
 def test_max_pool(self):
     """Runs test_graph on a MaxPool node fed by a [1, 2, 6, 1] float constant.

     The pool uses a [1, 2, 2, 1] window, unit strides and SAME padding.
     """
     source_name = "input_constant"
     pool_name = "max_pool"

     graph = tf.GraphDef()

     # 1..12: one value per element of the 1*2*6*1 shape.
     source_node = quantize_graph.create_constant_node(
         source_name,
         value=list(range(1, 13)),
         dtype=tf.float32,
         shape=[1, 2, 6, 1])
     graph.node.extend([source_node])

     pool_node = quantize_graph.create_node("MaxPool", pool_name, [source_name])
     window = [1, 2, 2, 1]
     unit_strides = [1, 1, 1, 1]
     quantize_graph.set_attr_int_list(pool_node, "ksize", window)
     quantize_graph.set_attr_int_list(pool_node, "strides", unit_strides)
     quantize_graph.set_attr_string(pool_node, "padding", b"SAME")
     graph.node.extend([pool_node])

     test_graph(graph, {}, [pool_name])
 def test_max_pool(self):
   """Checks a single MaxPool node over a constant input via test_graph.

   The input is a float32 constant of shape [1, 2, 6, 1] holding 1 through 12.
   The pool is configured with ksize [1, 2, 2, 1], strides [1, 1, 1, 1] and
   b"SAME" padding.
   """
   const_name, pool_name = "input_constant", "max_pool"

   graph = tf.GraphDef()
   graph.node.extend([
       quantize_graph.create_constant_node(
           const_name,
           value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
           dtype=tf.float32,
           shape=[1, 2, 6, 1])
   ])

   pool = quantize_graph.create_node("MaxPool", pool_name, [const_name])
   # Integer-list attributes are applied in the same order: ksize, then strides.
   for attr_key, attr_dims in (("ksize", [1, 2, 2, 1]),
                               ("strides", [1, 1, 1, 1])):
     quantize_graph.set_attr_int_list(pool, attr_key, attr_dims)
   quantize_graph.set_attr_string(pool, "padding", b"SAME")
   graph.node.extend([pool])

   test_graph(graph, {}, [pool_name])