Example #1
# Imports used by this snippet (TensorFlow 1.x contrib). The method below
# belongs to a test case class that also defines the _WeightInit and
# _BatchNormParams helpers, which are not shown on this page.
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import copy_graph
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables

batch_norm = layers.batch_norm
conv2d = layers.conv2d

  def _TestCompareFoldAndUnfolded(self, relu, relu_op_name, with_bypass,
                                  has_scaling, fused_batch_norm):
    """Tests that running folded and unfolded BN returns the same results.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
    """
    random_seed.set_random_seed(1234)
    unfolded_g = ops.Graph()
    with unfolded_g.as_default():
      batch_size, height, width = 5, 128, 128
      inputs = random_ops.random_uniform(
          (batch_size, height, width, 3), dtype=dtypes.float32, seed=1234)
      out_depth = 3 if with_bypass else 32
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else relu
      scope = 'test/test2' if with_bypass else 'test'
      node = conv2d(
          inputs,
          out_depth, [5, 5],
          stride=stride,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(
              scale=has_scaling, fused=fused_batch_norm),
          scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu_node = relu(node, name='test/' + relu_op_name)

    folded_g = copy_graph.CopyGraph(unfolded_g)
    with folded_g.as_default():
      fold_batch_norms.FoldBatchNorms(folded_g)

    with session.Session(graph=unfolded_g) as sess:
      sess.run(variables.global_variables_initializer())
      grad_node = gradients.gradients(relu_node, inputs)
      results = sess.run([relu_node, grad_node])
      unfolded_forward, unfolded_backward = results[0], results[1]

    with session.Session(graph=folded_g) as sess:
      sess.run(variables.global_variables_initializer())
      relu_node = folded_g.get_tensor_by_name(relu_node.name)
      inputs = folded_g.get_tensor_by_name(inputs.name)
      grad_node = gradients.gradients(relu_node, inputs)
      results = sess.run([relu_node, grad_node])
      folded_forward, folded_backward = results[0], results[1]

    # Check that the folded and unfolded results match.
    self.assertAllClose(unfolded_forward, folded_forward, atol=1e-3)
    self.assertAllClose(unfolded_backward, folded_backward, atol=1e-3)
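A minimal sketch of how concrete test cases might invoke this helper for one
ReLU variant, with and without the bypass connection. The test method names
here are hypothetical; nn_ops comes from tensorflow.python.ops:

from tensorflow.python.ops import nn_ops

  def testFoldMatchesUnfoldedRelu6(self):
    self._TestCompareFoldAndUnfolded(
        nn_ops.relu6, 'Relu6', with_bypass=False,
        has_scaling=True, fused_batch_norm=False)

  def testFoldMatchesUnfoldedReluWithBypass(self):
    self._TestCompareFoldAndUnfolded(
        nn_ops.relu, 'Relu', with_bypass=True,
        has_scaling=False, fused_batch_norm=True)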
Example #2
# Imports used by this snippet (TensorFlow 1.x contrib):
from tensorflow.contrib.quantize.python import copy_graph
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables


def _create_graph(input_graph,
                  is_training,
                  elements=None,
                  device_name_or_function=None):
    """Returns a transformed training input_graph for simulated quantization.

  The forward pass has fake quantization ops inserted to simulate the error
  introduced by quantization.

  Args:
    input_graph: The tf.Graph to be transformed.
    is_training: Whether quantizing training or eval graph.
    elements: (Optional) List of Tensors and Operations in input_graph whose
        corresponding elements in the new graph will be returned.
    device_name_or_function: (Optional) The device name or function to use.

  Returns:
    g is new tf.Graph that is rewritten for simulated quantization.
    l is a list of Tensors/Operations in g corresponding to the provided input
        elements, if elements is not None.

  Raises:
    ValueError: If elements contains an element that isn't a tf.Tensor or
        tf.Operation.
  """
    # TODO(suharshs): Describe the process in more detail in the doc string.
    g = copy_graph.CopyGraph(input_graph)
    if is_training:
        # TODO(raghuramank): Need to make freeze_batch_norm_delay
        # a function of the batch size. For now this is set to 250 epochs,
        # which corresponds to 5 million steps at a batch size of 64.
        freeze_batch_norm_delay = 5000000
    else:
        freeze_batch_norm_delay = None
    with g.as_default():
        with ops.device(device_name_or_function):
            fold_batch_norms.FoldBatchNorms(
                g,
                freeze_batch_norm_delay=freeze_batch_norm_delay,
                is_training=is_training)
            quantize.Quantize(g, is_training=is_training)
    if elements is None:
        return g

    return_elements = []
    for element in elements:
        if isinstance(element, (ops.Tensor, variables.Variable)):
            return_elements.append(g.get_tensor_by_name(element.name))
        elif isinstance(element, ops.Operation):
            return_elements.append(g.get_operation_by_name(element.name))
        else:
            raise ValueError(
                'elements must consist of Tensor or Operation objects, got: ' +
                str(element))
    return g, return_elements
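A hypothetical usage sketch (TensorFlow 1.x): build a small eval graph,
rewrite it for simulated quantization, and recover the output tensor in the
rewritten graph. The layer shapes and the eval_graph/images/net names are
illustrative only:

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

eval_graph = ops.Graph()
with eval_graph.as_default():
    images = array_ops.placeholder(dtypes.float32, shape=(1, 32, 32, 3))
    net = layers.conv2d(images, 16, [3, 3],
                        normalizer_fn=layers.batch_norm,
                        activation_fn=None)

# elements is not None, so a (graph, elements) pair is returned.
g, (quant_net,) = _create_graph(eval_graph, is_training=False, elements=[net])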
Example #3
# Imports used by this snippet (TensorFlow 1.x contrib). The method below
# belongs to a test case class that also defines the _CompareNodeInGraph
# helper, sketched after this example.
from tensorflow.contrib.quantize.python import copy_graph
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables

  def testCopyGraph(self):
    graph = ops.Graph()
    with graph.as_default():
      a = constant_op.constant(1.0)
      b = variables.Variable(2.0)
      c = a + b
    graph_copy = copy_graph.CopyGraph(graph)
    # Ensure that the three original nodes are in the new graph.
    # import_meta_graph also adds a saver node to the graph, which we don't
    # care about in this specific use case.
    for tensor in [a, b, c]:
      self._CompareNodeInGraph(tensor.op, graph_copy)
    # Test that the graph collections are the same.
    for key in graph.get_all_collection_keys():
      self.assertEqual(len(graph.get_collection(key)),
                       len(graph_copy.get_collection(key)),
                       'Collection %s differs.' % key)
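The _CompareNodeInGraph helper is not shown on this page. A plausible sketch,
assuming it checks that an op with the same name and identical node_def exists
in the copied graph:

  def _CompareNodeInGraph(self, node, graph):
    # Look up the op by name in the copy and compare serialized definitions.
    graph_node = graph.get_operation_by_name(node.name)
    self.assertEqual(str(node.node_def), str(graph_node.node_def))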
Example #4
# Imports used by this snippet, another version of the _create_graph function
# from Example #2 (TensorFlow 1.x contrib):
from tensorflow.contrib.quantize.python import copy_graph
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables


def _create_graph(input_graph, is_training, elements=None):
    """Returns a transformed training input_graph for simulated quantization.

  The forward pass has fake quantization ops inserted to simulate the error
  introduced by quantization.

  Args:
    input_graph: The tf.Graph to be transformed.
    is_training: Whether quantizing training or eval graph.
    elements: (Optional) List of Tensors and Operations in input_graph whose
        corresponding elements in the new graph will be returned.

  Returns:
    Returns a tuple(g, l) where:
    g is new tf.Graph that is rewritten for simulated quantization.
    l is a list of Tensors/Operations in g corresponding to the provided input
        elements.

  Raises:
    ValueError: If elements contains an element that isn't a tf.Tensor or
        tf.Operation.
  """
    # TODO(suharshs): Describe the process in more detail in the doc string.
    g = copy_graph.CopyGraph(input_graph)
    fold_batch_norms.FoldBatchNorms(g)
    quantize.Quantize(g, is_training=is_training)
    return_elements = []
    if elements is None:
        elements = []
    for element in elements:
        if isinstance(element, (ops.Tensor, variables.Variable)):
            return_elements.append(g.get_tensor_by_name(element.name))
        elif isinstance(element, ops.Operation):
            return_elements.append(g.get_operation_by_name(element.name))
        else:
            raise ValueError(
                'elements must consist of Tensor or Operation objects, got: ' +
                str(element))
    return g, return_elements
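Unlike the version in Example #2, this variant always returns a (graph, list)
pair, even when elements is None. A hypothetical sketch retrieving both a
Tensor and an Operation from the rewritten graph; the train_graph construction
here is illustrative only:

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables

train_graph = ops.Graph()
with train_graph.as_default():
    v = variables.Variable(1.0)
    loss = v * constant_op.constant(2.0)       # a Tensor
    init_op = variables.global_variables_initializer()  # an Operation

# The Tensor is recovered by tensor name, the Operation by op name.
g, [quant_loss, quant_init] = _create_graph(
    train_graph, is_training=True, elements=[loss, init_op])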