Example #1
def insert_quant_op(graph, node_name, is_train):
    """Insert quantization operations to the specified activation node.

  Args:
  * graph: TensorFlow graph
  * node_name: activation node's name
  * is_train: insert training-related operations or not
  """

    # locate the node & activation operation
    for op in graph.get_operations():
        if node_name in [node.name for node in op.outputs]:
            tf.logging.info('op: {} / inputs: {} / outputs: {}'.format(
                op.name, [node.name for node in op.inputs],
                [node.name for node in op.outputs]))
            node = op.outputs[0]
            activation_op = op
            break

    # re-route the graph to insert quantization operations
    input_to_ops_map = input_to_ops.InputToOps(graph)
    consumer_ops = input_to_ops_map.ConsumerOperations(activation_op)
    node_quant = quant_ops.MovingAvgQuantize(
        node, is_training=is_train, num_bits=FLAGS.uqtf_activation_bits)
    nb_update_inputs = common.RerouteTensor(node_quant, node, consumer_ops)
    tf.logging.info('nb_update_inputs = %d' % nb_update_inputs)
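The snippet above is not self-contained; it assumes the TF 1.x contrib quantization helpers and a project-specific FLAGS object. A minimal sketch of the assumed imports (the uqtf_activation_bits flag is defined here purely for illustration) might look like:

# Assumed imports for the snippet above (TF 1.x contrib.quantize layout).
import tensorflow as tf
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.contrib.quantize.python import quant_ops

# Project-specific flag assumed by the snippet; the default of 8 bits is a guess.
tf.app.flags.DEFINE_integer('uqtf_activation_bits', 8,
                            'Number of bits used to quantize activations.')
FLAGS = tf.app.flags.FLAGS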
Example #2
    def testNoConsumerOperations(self):
        graph = ops.Graph()
        with graph.as_default():
            input_tensor = array_ops.zeros((1, 2, 3, 4))

        input_to_ops_map = input_to_ops.InputToOps(graph)
        consumer_operations = input_to_ops_map.ConsumerOperations(
            input_tensor.op)

        self.assertEqual(0, len(consumer_operations))
Example #3
    def testOneConsumerOperation(self):
        graph = ops.Graph()
        with graph.as_default():
            input_tensor = array_ops.zeros((1, 2, 3, 4))
            output_tensor = nn_ops.relu6(input_tensor)

        input_to_ops_map = input_to_ops.InputToOps(graph)
        consumer_operations = input_to_ops_map.ConsumerOperations(
            input_tensor.op)

        self.assertEqual(consumer_operations, {output_tensor.op})
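These tests exercise InputToOps.ConsumerOperations directly; a standalone sketch using the same TF 1.x import paths, runnable outside the test harness, is:

# Standalone sketch of the consumer lookup shown in the tests above
# (TF 1.x contrib.quantize import paths assumed).
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops

graph = ops.Graph()
with graph.as_default():
    input_tensor = array_ops.zeros((1, 2, 3, 4))
    output_tensor = nn_ops.relu6(input_tensor)

input_to_ops_map = input_to_ops.InputToOps(graph)
consumers = input_to_ops_map.ConsumerOperations(input_tensor.op)
print(consumers == {output_tensor.op})  # expected: True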
Example #4
def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
    """Finds unfused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, True if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
    input_to_ops_map = input_to_ops.InputToOps(graph)

    for bn in common.BatchNormGroups(graph):
        has_scaling = _HasScaling(graph, input_to_ops_map, bn)

        if not _IsValidUnfusedBatchNorm(graph, bn):
            continue

        # The mangling code intimately depends on BatchNorm node's internals.
        original_op, folded_op = _CreateFoldedOp(
            graph,
            bn,
            has_scaling=has_scaling,
            freeze_batch_norm_delay=freeze_batch_norm_delay,
            is_training=is_training)

        activation = common.GetEndpointActivationOp(graph, bn)
        if activation:
            nodes_modified_count = common.RerouteTensor(
                folded_op.outputs[0],
                original_op.outputs[0],
                can_modify=[activation])
            if nodes_modified_count != 1:
                raise ValueError('Unexpected inputs to op: %s' %
                                 activation.name)
            continue

        # Treat consumer ops in bypass modules differently since they have Add
        # operations instead of Relu* above.
        add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
        add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
        nodes_modified_count = common.RerouteTensor(folded_op.outputs[0],
                                                    original_op.outputs[0],
                                                    can_modify=[add_bypass])
        if nodes_modified_count != 1:
            raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
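_FoldUnfusedBatchNorms is a private helper of fold_batch_norms.py; in practice batch norm folding is normally reached through the public contrib.quantize entry point rather than by calling it directly. A rough sketch of that path, assuming TF 1.x and with the model body elided, is:

# Hedged sketch: batch norm folding is usually triggered via the public rewrite
# entry point in TF 1.x rather than by calling _FoldUnfusedBatchNorms directly.
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # ... build the float model here ...
    tf.contrib.quantize.create_training_graph(input_graph=g, quant_delay=0)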
Example #5
def _QuantizeActivationLayers(quantized_ops,
                              graph,
                              is_training,
                              activation_bits=8,
                              ema_decay=0.999,
                              quant_delay=None,
                              vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                              scope=None,
                              use_qdq=False):
    """Quantize intermediate activation tensors after addition and multiplication.

  Args:
    quantized_ops: Set of previously quantized activation ops.
    graph: Graph to modify.
    is_training: Whether quantizing training graph or eval graph.
    activation_bits: Number of bits to use for quantizing activations.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    scope: The scope to be transformed. If it's not None, only the ops which are
      in this scope will be transformed.
    use_qdq: Use tf.quantize_and_dequantize_v3 (qdq) op instead of fake_quant_with_min_max_vars
      for quantization. The qdq op is used for scaling with no zero point.

  Raises:
    ValueError: When quantization fails.
  """
    input_to_ops_map = input_to_ops.InputToOps(graph)
    for op in (op for op in graph.get_operations()):
        if _CheckIfQuantizableOp(op, quantized_ops):
            logging.info(
                'Inserting fake quant op activation_%s_quant after %s',
                op.type, op.name)
            consumers = input_to_ops_map.ConsumerOperations(op)
            _InsertQuantOp(op.name,
                           'activation_' + op.type + '_quant',
                           op,
                           consumers,
                           is_training,
                           moving_avg=True,
                           ema_decay=ema_decay,
                           quant_delay=quant_delay,
                           vars_collection=vars_collection,
                           bits=activation_bits,
                           producer_scope=scope,
                           use_qdq=use_qdq)
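The use_qdq argument in the docstring switches between two low-level quantization primitives; a small sketch of both, using TF 1.x Python wrappers (values illustrative only), is:

# Sketch of the two primitives referenced above: fake_quant_with_min_max_vars
# and the quantize-and-dequantize (qdq) op. Ranges here are illustrative.
import tensorflow as tf

x = tf.constant([[-1.5, 0.2, 3.7]])
fq = tf.fake_quant_with_min_max_vars(x, min=-2.0, max=4.0, num_bits=8)
qdq = tf.quantization.quantize_and_dequantize(
    x, input_min=-2.0, input_max=4.0, num_bits=8, range_given=True)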
Example #6
    def testSeveralConsumerOperations(self):
        graph = ops.Graph()
        with graph.as_default():
            input_tensor = array_ops.zeros((1, 2, 3, 4))
            output_tensor_1 = nn_ops.relu6(input_tensor)
            output_tensor_2 = input_tensor + output_tensor_1
            output_tensor_3 = input_tensor * output_tensor_2

        input_to_ops_map = input_to_ops.InputToOps(graph)
        consumer_operations = input_to_ops_map.ConsumerOperations(
            input_tensor.op)

        self.assertEqual(
            consumer_operations,
            {output_tensor_1.op, output_tensor_2.op, output_tensor_3.op})
Example #7
def FoldBatchNorms(graph):
    """Finds batch norm layers in the graph, folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.

  Raises:
    ValueError: When batch norm folding fails.
  """
    # Fail immediately when the graph contains unsupported fused batch norm ops.
    if any(op for op in graph.get_operations() if op.type == 'FusedBatchNorm'):
        raise ValueError('Fused batch norm is not supported')

    input_to_ops_map = input_to_ops.InputToOps(graph)

    for bn in common.BatchNormGroups(graph):
        has_scaling = _HasScaling(graph, input_to_ops_map, bn)

        # The mangling code intimately depends on BatchNorm node's internals.
        original_op, folded_op = _CreateFoldedOp(graph,
                                                 bn,
                                                 has_scaling=has_scaling)

        activation = common.GetEndpointActivationOp(graph, bn)
        if activation:
            nodes_modified_count = graph_editor.reroute_ts(
                [folded_op.outputs[0]], [original_op.outputs[0]],
                can_modify=[activation])
            if nodes_modified_count != 1:
                raise ValueError('Unexpected inputs to op: %s' %
                                 activation.name)
            continue

        # Treat consumer ops in bypass modules differently since they have Add
        # operations instead of Relu* above.
        add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
        add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
        nodes_modified_count = graph_editor.reroute_ts(
            [folded_op.outputs[0]], [original_op.outputs[0]],
            can_modify=[add_bypass])
        if nodes_modified_count != 1:
            raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
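This older variant reroutes tensors with the contrib graph_editor package instead of common.RerouteTensor; a minimal runnable sketch of reroute_ts on a toy graph (TF 1.x contrib API assumed) is:

# Toy demonstration of graph_editor.reroute_ts as used above: redirect the
# consumer op 'c' from tensor 'a' to tensor 'b' and check the modification count.
import tensorflow as tf
from tensorflow.contrib import graph_editor

g = tf.Graph()
with g.as_default():
    a = tf.constant(1.0, name='a')
    b = tf.constant(2.0, name='b')
    c = tf.identity(a, name='c')  # c initially consumes a

count = graph_editor.reroute_ts([b], [a], can_modify=[c.op])
print(count)  # expected: 1 rerouted consumer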
Example #8
    def graph_rewrite_fn():
        """Function to quantize weights and activation of the default graph."""
        if (graph_rewriter_config.quantization.weight_bits != 8
                or graph_rewriter_config.quantization.activation_bits != 8):
            raise ValueError('Only 8bit quantization is supported')

        graph = tf.get_default_graph()

        # Insert custom quant ops.
        if quant_overrides_config is not None:
            input_to_ops_map = input_to_ops.InputToOps(graph)
            for q in quant_overrides_config.quant_configs:
                producer = graph.get_operation_by_name(q.op_name)
                if producer is None:
                    raise ValueError('Op name does not exist in graph.')
                context = _get_context_from_op(producer)
                consumers = input_to_ops_map.ConsumerOperations(producer)
                if q.fixed_range:
                    _insert_fixed_quant_op(
                        context,
                        q.quant_op_name,
                        producer,
                        consumers,
                        init_min=q.min,
                        init_max=q.max,
                        quant_delay=q.delay if is_training else 0)
                else:
                    raise ValueError('Learned ranges are not yet supported.')

        # Quantize the graph by inserting quantize ops for weights and activations
        if is_training:
            tf.contrib.quantize.experimental_create_training_graph(
                input_graph=graph,
                quant_delay=graph_rewriter_config.quantization.delay,
                freeze_bn_delay=graph_rewriter_config.quantization.delay)
        else:
            tf.contrib.quantize.experimental_create_eval_graph(
                input_graph=graph,
                quant_delay=graph_rewriter_config.quantization.delay
                if not is_export else 0)

        tf.contrib.layers.summarize_collection('quant_vars')
Example #9
def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
    """Finds unfused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, True if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
    input_to_ops_map = input_to_ops.InputToOps(graph)

    for bn in common.BatchNormGroups(graph):
        has_scaling = _HasScaling(graph, input_to_ops_map, bn)

        if not _IsValidUnfusedBatchNorm(graph, bn):
            continue

        # The mangling code intimately depends on BatchNorm node's internals.
        original_op, folded_op = _CreateFoldedOp(
            graph,
            bn,
            has_scaling=has_scaling,
            freeze_batch_norm_delay=freeze_batch_norm_delay,
            is_training=is_training)

        # TODO: generalise
        activation = input_to_ops_map.ConsumerOperations(original_op).pop()
        # assert any(activation.type == o or
        #            o.lower() in activation.name.split("/")[-1].lower()
        #            for o in (common._ACTIVATION_OP_SUFFIXES + ["Add"]))

        nodes_modified_count = common.RerouteTensor(folded_op.outputs[0],
                                                    original_op.outputs[0],
                                                    can_modify=[activation])
        if nodes_modified_count != 1:
            raise ValueError('Unexpected inputs to op: %s' % activation.name)
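This fork uses common.RerouteTensor, the in-tree replacement for graph_editor.reroute_ts; a toy sketch of that helper, mirroring the graph_editor example above (TF 1.x contrib path assumed), is:

# Toy demonstration of common.RerouteTensor: make op 'c' consume tensor 'b'
# instead of tensor 'a', as the folding code above does for folded outputs.
import tensorflow as tf
from tensorflow.contrib.quantize.python import common

g = tf.Graph()
with g.as_default():
    a = tf.constant(1.0, name='a')
    b = tf.constant(2.0, name='b')
    c = tf.identity(a, name='c')  # c initially consumes a

count = common.RerouteTensor(b, a, can_modify=[c.op])
print(count)  # expected: 1 rerouted consumer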
Example #10
    def __init__(self,
                 graph,
                 weight_bits,
                 weight_narrow_range,
                 activation_bits,
                 ema_decay=0.999,
                 quant_delay=None,
                 vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                 is_training=True,
                 quantize_folded_weights_use_ema=False):
        """Initializes context to hold references needed for quantization.

    Args:
      graph: Graph to modify.
      weight_bits: Number of bits to use for quantizing weights.
      weight_narrow_range: Whether to use a more efficient narrow range for
        weights quantization.  With weight_narrow_range true, the range is
        [1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1].
      activation_bits: Number of bits to use for quantizing activations.
      ema_decay: (Optional) Float, EMA decay parameter.
      quant_delay: (Optional, default None) Int, count of global steps for which
        to delay quantization.  This helps weights stabilize at the start of
        training.
      vars_collection: (Optional) Collection where to store the variables for
        quantization interval ends.
      is_training: (Optional) Whether quantizing training or eval graph.
      quantize_folded_weights_use_ema: (Optional, default False) Whether to
        quantize weights after batchnorm-folding with exponential average
        quantization.
    """
        self.graph = graph
        self.weight_bits = weight_bits
        self.weight_narrow_range = weight_narrow_range
        self.activation_bits = activation_bits
        self.ema_decay = ema_decay
        self.quant_delay = quant_delay
        self.vars_collection = vars_collection
        self.is_training = is_training
        self.quantize_folded_weights_use_ema = quantize_folded_weights_use_ema
        self.input_to_ops_map = input_to_ops.InputToOps(graph)
        self.add_contexts = set()
Example #11
File: calib.py  Project: pyjennings/tf_pg
def ModifyForCalibration(graph,
                         vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                         scope=None):
    """Update graph with calibration operations.

  Args:
    graph: Graph to modify.
    vars_collection: (Optional) Collection where to store the variables used
      for calibration.
    scope: The scope to be transformed. If it's not None, only the ops which
      are in this scope will be transformed.
  Raises:
    ValueError: When modification fails.
  """
    if scope and not scope.endswith('/'):
        scope += '/'

    input_to_ops_map = input_to_ops.InputToOps(graph)
    calib_ops = []

    for layer_match in _FindLayersToCalibration(graph):
        # Quantize the weights.
        context = _GetContextFromOp(layer_match.layer_op)

        # If `scope` is given, only quantize it if the consumer of weights
        # (the layer op) is in the right scope.
        if layer_match.weight_tensor is not None:
            print(layer_match.weight_tensor.op)
            calib_op = _InsertCalibOp(context,
                                      'weights_calib',
                                      layer_match.weight_tensor.op,
                                      input_to_ops_map.ConsumerOperations(
                                          layer_match.weight_tensor.op),
                                      vars_collection=vars_collection,
                                      consumer_scope=scope)
            calib_ops.append(calib_op)

    return calib_ops
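ModifyForCalibration comes from the pyjennings/tf_pg project rather than stock TensorFlow, so its exact behaviour is not documented here; a hypothetical usage sketch based only on the signature shown above is:

# Hypothetical usage sketch for ModifyForCalibration; the module path and the
# scope name are assumptions based on the file/project noted above.
import tensorflow as tf
from calib import ModifyForCalibration  # assumed import path

g = tf.Graph()
with g.as_default():
    # ... build the model to be calibrated ...
    calib_ops = ModifyForCalibration(g, scope='MyModel')  # scope name is illustrative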
Example #12
def Quantize(graph,
             is_training,
             weight_bits=8,
             activation_bits=8,
             ema_decay=0.999,
             quant_delay=None,
             vars_collection=ops.GraphKeys.GLOBAL_VARIABLES):
    """Updates graph with quantization operations.

  Args:
    graph: Graph to modify.
    is_training: Whether quantizing training graph or eval graph.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
  Raises:
    ValueError: When quantization fails.
  """
    input_to_ops_map = input_to_ops.InputToOps(graph)
    for layer_match in _FindLayersToQuantize(graph):
        # Quantize the weights.
        context = _GetContextFromOp(layer_match.layer_op)
        _InsertQuantOp(context,
                       'weights_quant',
                       layer_match.weight_tensor.op, [layer_match.layer_op],
                       is_training,
                       moving_avg=False,
                       ema_decay=ema_decay,
                       quant_delay=quant_delay,
                       narrow_range=True,
                       vars_collection=vars_collection,
                       bits=weight_bits)

        # Quantize the activations.
        consumer_ops = input_to_ops_map.ConsumerOperations(
            layer_match.activation_op)
        add_context = context
        if layer_match.bypass_op:
            add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
        _InsertQuantOp(add_context,
                       'act_quant',
                       layer_match.activation_op,
                       consumer_ops,
                       is_training,
                       moving_avg=True,
                       ema_decay=ema_decay,
                       quant_delay=quant_delay,
                       vars_collection=vars_collection,
                       bits=activation_bits,
                       init_min=0.0)

        # Quantize the inputs and output to the bypass (if it exists). The input to
        # the bypass is the bias add, and the output is the activation.
        if layer_match.bypass_op is not None:
            _InsertQuantOp(context,
                           'conv_quant',
                           layer_match.bias_add_op, [layer_match.bypass_op],
                           is_training,
                           moving_avg=True,
                           ema_decay=ema_decay,
                           quant_delay=quant_delay,
                           vars_collection=vars_collection,
                           bits=activation_bits)
            _InsertQuantOp(add_context,
                           'add_quant',
                           layer_match.bypass_op,
                           input_to_ops_map.ConsumerOperations(
                               layer_match.bypass_op),
                           is_training,
                           moving_avg=True,
                           ema_decay=ema_decay,
                           quant_delay=quant_delay,
                           vars_collection=vars_collection,
                           bits=activation_bits)

        if layer_match.post_activation_bypass_op is not None:
            _InsertQuantOp(add_context,
                           'post_activation_bypass_quant',
                           layer_match.post_activation_bypass_op,
                           input_to_ops_map.ConsumerOperations(
                               layer_match.post_activation_bypass_op),
                           is_training,
                           moving_avg=True,
                           ema_decay=ema_decay,
                           quant_delay=quant_delay,
                           vars_collection=vars_collection,
                           bits=activation_bits)
Example #13
def Quantize(graph,
             is_training,
             weight_bits=8,
             activation_bits=8,
             ema_decay=0.999,
             quant_delay=None,
             vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
             scope=None):
    """Updates graph with quantization operations.

  Currently we quantize the following tensors:
  * Conv/MatMul: Quantize the weights if it matches.
  * Activation: Quantize the output if it matches.
  * Bypass/Post-activation Bypass: Quantize both input and output
    if it matches.

  Args:
    graph: Graph to modify.
    is_training: Whether quantizing training graph or eval graph.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    scope: The scope to be transformed. If it's not None, only the ops which
      are in this scope will be transformed.
  Raises:
    ValueError: When quantization fails.
  """
    if scope and not scope.endswith('/'):
        scope += '/'

    input_to_ops_map = input_to_ops.InputToOps(graph)
    for layer_match in _FindLayersToQuantize(graph):
        # Quantize the weights.
        context = _GetContextFromOp(layer_match.layer_op)

        # If `scope` is given, only quantize it if the consumer of weights
        # (the layer op) is in the right scope.
        _InsertQuantOp(context,
                       'weights_quant',
                       layer_match.weight_tensor.op, [layer_match.layer_op],
                       is_training,
                       moving_avg=False,
                       ema_decay=ema_decay,
                       quant_delay=quant_delay,
                       narrow_range=True,
                       vars_collection=vars_collection,
                       bits=weight_bits,
                       consumer_scope=scope)

        # Quantize the activations.
        consumer_ops = input_to_ops_map.ConsumerOperations(
            layer_match.activation_op)
        add_context = context
        if layer_match.bypass_op:
            add_context = re.search(r'^(.*)/([^/]+)', context).group(1)

        # If `scope` is given, only quantize it if the producer of weights
        # (usually it's the layer op) is in the right scope.
        _InsertQuantOp(add_context,
                       'act_quant',
                       layer_match.activation_op,
                       consumer_ops,
                       is_training,
                       moving_avg=True,
                       ema_decay=ema_decay,
                       quant_delay=quant_delay,
                       vars_collection=vars_collection,
                       bits=activation_bits,
                       init_min=0.0,
                       producer_scope=scope)

        # Quantize the inputs and output to the bypass (if it exists). The input to
        # the bypass is the bias add, and the output is the activation.
        if layer_match.bypass_op is not None:
            # If `scope` is given, only quantize it if the both the producer and the
            # consumer are in the right scope.
            _InsertQuantOp(context,
                           'conv_quant',
                           layer_match.bias_add_op, [layer_match.bypass_op],
                           is_training,
                           moving_avg=True,
                           ema_decay=ema_decay,
                           quant_delay=quant_delay,
                           vars_collection=vars_collection,
                           bits=activation_bits,
                           producer_scope=scope,
                           consumer_scope=scope)
            # Make sure the op following this isn't an activation; in that case we
            # shouldn't quantize it, since the activation will be fused into the
            # Add at inference time.
            consumers = input_to_ops_map.ConsumerOperations(
                layer_match.bypass_op)
            if any(
                [consumer.type in _ACTIVATION_TYPES
                 for consumer in consumers]):
                logging.info(
                    'Skipping %s, because it is followed by an activation.',
                    layer_match.bypass_op.name)
            else:
                _InsertQuantOp(add_context,
                               'add_quant',
                               layer_match.bypass_op,
                               input_to_ops_map.ConsumerOperations(
                                   layer_match.bypass_op),
                               is_training,
                               moving_avg=True,
                               ema_decay=ema_decay,
                               quant_delay=quant_delay,
                               vars_collection=vars_collection,
                               bits=activation_bits,
                               producer_scope=scope,
                               consumer_scope=scope)

        # Quantize bypass ops that occur after the activation.
        if layer_match.post_activation_bypass_op is not None:
            post_activation_bypass_context = re.search(
                r'^(.*)/([^/]+)',
                layer_match.post_activation_bypass_op.name).group(1)
            # If `scope` is given, only quantize it if the producer is in the right
            # scope.
            # Make sure the op following this isn't an activation; in that case we
            # shouldn't quantize it, since the activation will be fused into the
            # Add at inference time.
            consumers = input_to_ops_map.ConsumerOperations(
                layer_match.post_activation_bypass_op)
            if any(
                [consumer.type in _ACTIVATION_TYPES
                 for consumer in consumers]):
                logging.info(
                    'Skipping %s, because it is followed by an activation.',
                    layer_match.post_activation_bypass_op.name)
            else:
                _InsertQuantOp(post_activation_bypass_context,
                               'post_activation_bypass_quant',
                               layer_match.post_activation_bypass_op,
                               consumers,
                               is_training,
                               moving_avg=True,
                               ema_decay=ema_decay,
                               quant_delay=quant_delay,
                               vars_collection=vars_collection,
                               bits=activation_bits,
                               producer_scope=scope)
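Quantize is the internal rewrite driver of contrib.quantize; a brief sketch of calling it directly with the defaults shown above (TF 1.x import path assumed, model body elided) is:

# Sketch of invoking the internal Quantize rewrite directly (TF 1.x contrib path).
# The public create_training_graph/create_eval_graph wrappers are the usual route.
import tensorflow as tf
from tensorflow.contrib.quantize.python import quantize

g = tf.Graph()
with g.as_default():
    # ... build the float model here ...
    quantize.Quantize(g, is_training=True, weight_bits=8, activation_bits=8,
                      quant_delay=2000, scope=None)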
Example #14
def Quantize(graph,
             is_training,
             weight_bits=8,
             activation_bits=8,
             symmetric=False,
             ema_decay=0.999,
             quant_delay=None,
             vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
             scope=None):
  """Updates graph with quantization operations.

  Currently we quantize the following tensors:
  * Conv/MatMul: Quantize the weights if it matches.
  * Activation: Quantize the output if it matches.
  * Bypass/Post-activation Bypass: Quantize both input and output
    if it matches.

  Args:
    graph: Graph to modify.
    is_training: Whether quantizing training graph or eval graph.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    symmetric: (Optional) If true, use symmetric quantization limits instead of
      training the minimum and maximum of each quantization range separately.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    scope: The scope to be transformed. If it's not None, only the ops which
      are in this scope will be transformed.
  Raises:
    ValueError: When quantization fails.
  """
  if scope and not scope.endswith('/'):
    scope += '/'

  input_to_ops_map = input_to_ops.InputToOps(graph)
  quantized_ops = set()
  for layer_match in _FindLayersToQuantize(graph):
    # Quantize the weights.
    context = _GetContextFromOp(layer_match.layer_op)

    op_re = re.search(r'^(.*)/([^/]+)', layer_match.layer_op.name)
    if op_re:
      op_name = op_re.group(2)
    else:
      op_name = layer_match.layer_op.type
    # print(op_name)

    _InsertQuantOp(
        context,
        op_name + '_quant',
        layer_match.layer_op,
        input_to_ops_map.ConsumerOperations(layer_match.layer_op),
        is_training,
        moving_avg=True,
        ema_decay=ema_decay,
        quant_delay=quant_delay,
        vars_collection=vars_collection,
        bits=activation_bits,
        symmetric=symmetric,
        producer_scope=scope)
Example #15
def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
    """Finds unfused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, True if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
    input_to_ops_map = input_to_ops.InputToOps(graph)

    for bn in common.BatchNormGroups(graph):
        has_scaling = _HasScaling(graph, input_to_ops_map, bn)

        if not _IsValidUnfusedBatchNorm(graph, bn):
            continue

        print("found unfused batchnarm")
        raise Exception("Not Implemented")

        # The mangling code intimately depends on BatchNorm node's internals.
        original_op, folded_op = _CreateFoldedOp(
            graph,
            bn,
            has_scaling=has_scaling,
            freeze_batch_norm_delay=freeze_batch_norm_delay,
            is_training=is_training)

        activation = common.GetEndpointActivationOp(graph, bn)
        if activation:
            nodes_modified_count = common.RerouteTensor(
                folded_op.outputs[0],
                original_op.outputs[0],
                can_modify=[activation])
            if nodes_modified_count != 1:
                raise ValueError('Unexpected inputs to op: %s' %
                                 activation.name)
            continue

        # Treat consumer ops in bypass modules differently since they have Add
        # operations instead of Relu* above.
        # Changes to make sure that the correct scope is selected for the bypass add
        # The rule here is that if the scope is of the form: str1/str2 for the
        # batch norm,
        # the bypass add is at scope str1. If bn is of scope just str1, then the
        # bypass add is at scope ''.
        # If there is no batch norm, then there is no bypass add.
        add_bypass_ctx = ''
        if bn:
            try:
                add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
            except AttributeError:
                add_bypass_ctx = ''

        if add_bypass_ctx:
            add_bypass_ctx = add_bypass_ctx + '/'

        add_bypass = graph.get_operation_by_name(add_bypass_ctx + 'Add')
        nodes_modified_count = common.RerouteTensor(folded_op.outputs[0],
                                                    original_op.outputs[0],
                                                    can_modify=[add_bypass])
        if nodes_modified_count != 1:
            raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)