Example #1
0
    def testRerouteTensor(self):
        """RerouteTensor rewires inputs only on ops listed in `can_modify`."""
        # Four named scalar constants, created in a, b, c, d order.
        tensor_a, tensor_b, tensor_c, tensor_d = [
            constant_op.constant(value, name=label)
            for value, label in ((1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'))
        ]

        # Two consumers of tensor_a; only the first one may be rewired.
        sum_ac = math_ops.add(tensor_a, tensor_c)
        sum_ad = math_ops.add(tensor_a, tensor_d)

        # Sanity check: inputs are as constructed before any rerouting.
        expected_before = [
            (sum_ac.op, [tensor_a, tensor_c]),
            (sum_ad.op, [tensor_a, tensor_d]),
        ]
        for consumer_op, expected_inputs in expected_before:
            self._CheckOpHasInputs(consumer_op, expected_inputs)

        # Replace tensor_a with tensor_b, restricted to sum_ac's op; sum_ad
        # is not in can_modify and must keep reading tensor_a.
        common.RerouteTensor(tensor_b, tensor_a, can_modify=[sum_ac.op])

        expected_after = [
            (sum_ac.op, [tensor_b, tensor_c]),
            (sum_ad.op, [tensor_a, tensor_d]),
        ]
        for consumer_op, expected_inputs in expected_after:
            self._CheckOpHasInputs(consumer_op, expected_inputs)
Example #2
0
def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
    """Folds every fused batch norm in `graph` into the layer feeding it.

    Only Conv2D, fully connected and depthwise convolution layers are folded.
    For each match, the layer is cloned with rescaled weights, a bias add is
    appended, and the batch norm's output tensor is rerouted to the new result.

    Args:
      graph: Graph to walk and modify in place.
      is_training: Bool, True when folding a training graph; enables the
        batch norm correction terms.
      freeze_batch_norm_delay: Number of steps to wait before freezing the
        moving mean and variance and using them for batch normalization.

    Raises:
      ValueError: If rerouting the folded output modified no consumers.
    """
    for fold_match in _FindFusedBatchNorms(graph):
        layer_op = fold_match.layer_op
        parent_scope, slash, _ = layer_op.name.rpartition('/')
        reuse_scope = parent_scope + slash
        # The trailing '/' carried by `slash` makes name_scope re-enter the
        # existing `parent_scope` instead of creating a uniquified sibling,
        # so the new ops land next to the layer in `graph`.
        with graph.as_default(), graph.name_scope(reuse_scope):
            with graph.name_scope(reuse_scope + 'BatchNorm_Fold' + slash):
                # multiplier = gamma / sqrt(variance + epsilon)
                # bias       = beta - mean * multiplier
                epsilon = fold_match.bn_op.get_attr('epsilon')
                multiplier = fold_match.gamma_tensor * math_ops.rsqrt(
                    fold_match.variance_tensor + epsilon)
                bias = math_ops.subtract(fold_match.beta_tensor,
                                         fold_match.mean_tensor * multiplier,
                                         name='bias')

                if is_training:
                    scale_fix, recip_fix, offset_fix = (
                        _ComputeBatchNormCorrections(
                            context='',
                            match=fold_match,
                            freeze_batch_norm_delay=freeze_batch_norm_delay))
                else:
                    scale_fix = recip_fix = offset_fix = None

                if layer_op.type == 'DepthwiseConv2dNative':
                    # Depthwise kernels have a different layout: reshape the
                    # per-channel factors to the kernel's last two dims
                    # (indices 2 and 3) so the weight product broadcasts.
                    kernel_dims = fold_match.weight_tensor.get_shape().as_list()
                    trailing_shape = [kernel_dims[2], kernel_dims[3]]
                    multiplier = array_ops.reshape(multiplier,
                                                   trailing_shape,
                                                   name='scale_reshape')
                    if scale_fix is not None:
                        scale_fix = array_ops.reshape(
                            scale_fix,
                            trailing_shape,
                            name='correction_reshape')

            kernel = fold_match.weight_tensor
            if scale_fix is not None:
                kernel = math_ops.multiply(scale_fix,
                                           kernel,
                                           name='correction_mult')
            folded_kernel = math_ops.multiply(kernel,
                                              multiplier,
                                              name='mul_fold')

            rebuilt = _CloneWithNewOperands(layer_op,
                                            fold_match.input_tensor,
                                            folded_kernel,
                                            fold_match.batch_to_space_op)

            if recip_fix is not None:
                # Undo the weight-side correction after the layer and add the
                # offset term (training only).
                rebuilt = math_ops.multiply(recip_fix,
                                            rebuilt,
                                            name='post_conv_mul')
                rebuilt = math_ops.add(rebuilt, offset_fix, 'correction_add')

            folded_output = math_ops.add(rebuilt, bias, name='add_fold')

            rewired = common.RerouteTensor(folded_output,
                                           fold_match.output_tensor)
            if rewired == 0:
                raise ValueError(
                    'Folding batch norms failed, %s had no outputs.' %
                    fold_match.output_tensor.name)
Example #3
0
def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
    """Folds every unfused batch norm in `graph` into its preceding layer.

    Only Conv2D, fully connected and depthwise convolution layers are folded.

    Args:
      graph: Graph to walk and modify in place.
      is_training: Bool, True when folding a training graph.
      freeze_batch_norm_delay: Number of steps to wait before freezing the
        moving mean and variance and using them for batch normalization.

    Raises:
      ValueError: If rerouting the folded output does not modify exactly one
        input of the downstream op (the endpoint activation or bypass AddV2).
    """
    op_index = input_to_ops.InputToOps(graph)

    for bn in common.BatchNormGroups(graph):
        # Queried before the validity check, preserving the call order.
        scaled = _HasScaling(graph, op_index, bn)
        if not _IsValidUnfusedBatchNorm(graph, bn):
            continue

        # _CreateFoldedOp depends intimately on the BatchNorm node internals.
        original_op, folded_op = _CreateFoldedOp(
            graph,
            bn,
            has_scaling=scaled,
            freeze_batch_norm_delay=freeze_batch_norm_delay,
            is_training=is_training)

        # The folded output replaces the original on exactly one consumer:
        # the endpoint activation if present, otherwise the bypass AddV2
        # (bypass modules use an add where others have a Relu*).
        target = common.GetEndpointActivationOp(graph, bn)
        if not target:
            # The bypass add sits one scope above the batch norm: for a bn
            # scope 'str1/str2' it is 'str1/AddV2'; for a bare 'str1' (or an
            # empty bn) it is the top-level 'AddV2'.
            parent_scope = ''
            if bn:
                scope_match = re.search(r'^(.*)/([^/]+)', bn)
                if scope_match is not None:
                    parent_scope = scope_match.group(1)
            bypass_prefix = parent_scope + '/' if parent_scope else ''
            target = graph.get_operation_by_name(bypass_prefix + 'AddV2')

        rewired = common.RerouteTensor(folded_op.outputs[0],
                                       original_op.outputs[0],
                                       can_modify=[target])
        if rewired != 1:
            raise ValueError('Unexpected inputs to op: %s' % target.name)
Example #4
0
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
    """Computes batch norm correction params.

    Before batch normalization is frozen (batch statistics are used):
        correction_scale  = sigma_b / sigma_mv
        correction_recip  = 1 / correction_scale
        correction_offset = 0

    After batch normalization is frozen:
        correction_scale  = sigma_b / sigma_mv
        correction_recip  = 1
        correction_offset = gamma * (mu_b / sigma_b - mu_mv / sigma_mv)

    where sigma = sqrt(variance + match.batch_epsilon). Batch norm counts as
    frozen once the quantization step returned by
    `common.CreateOrGetQuantizationStep()` is >= `freeze_batch_norm_delay`.
    If `freeze_batch_norm_delay` is None, it is never frozen.

    The corrections ensure that:
      a) The weights are quantized after scaling by gamma / sigma_mv. This
         gives smoother training, because the scaling on the weights changes
         slowly instead of jumping across mini-batches.
      b) Changing the values of the corrections switches from batch
         statistics to the moving mean and variance without requiring
         changes to batch_norm.

    Side effect: the existing consumers of `match.bn_decay_mean_tensor` and
    `match.bn_decay_var_tensor` are rerouted to read 0.0 once frozen and the
    original decay tensor otherwise. This presumably halts the moving-average
    updates after freezing (TODO confirm against the matcher).

    Args:
      context: Name-scope prefix (may be '' or None). Ops are created under
        `context + 'batch_norm_correction'`.
      match: Object holding the batch norm tensors used here: moving/batch
        mean and variance, gamma, batch_epsilon and the two decay tensors.
      freeze_batch_norm_delay: Step count at which computation switches from
        regular batch norm to frozen mean and variance, or None to never
        freeze.

    Returns:
      A tuple (correction_scale, correction_recip, correction_offset).
    """
    g = ops.get_default_graph()
    prefix = context or ''
    with g.name_scope(prefix + 'batch_norm_correction'):
        recip_sigma_mv = math_ops.rsqrt(match.moving_variance_tensor +
                                        match.batch_epsilon)
        recip_sigma = math_ops.rsqrt(match.variance_tensor +
                                     match.batch_epsilon)
        # (1 / sigma_mv) / (1 / sigma_b) == sigma_b / sigma_mv
        correction_scale = math_ops.divide(recip_sigma_mv,
                                           recip_sigma,
                                           name='scale_compute')
        correction_scale = array_ops.identity(correction_scale,
                                              name='correction_scale')
        correction_recip = math_ops.reciprocal(correction_scale,
                                               name='reciprocal_compute')
        correction_offset = math_ops.multiply(
            match.gamma_tensor,
            match.mean_tensor * recip_sigma -
            match.moving_mean_tensor * recip_sigma_mv,
            name='offset_compute')

        if freeze_batch_norm_delay is not None:
            use_mv_avg = math_ops.greater_equal(
                common.CreateOrGetQuantizationStep(),
                freeze_batch_norm_delay,
                name='use_moving_average')
        else:
            # A Python False lets smart_cond pick the "not frozen" branch
            # statically.
            use_mv_avg = False

        bn_decay_zero = 0.0

        # Snapshot consumers BEFORE smart_cond adds its own consumer of the
        # decay tensor, so the reroute below does not feed the cond's output
        # back into itself.
        bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
        bn_decay_mean_out = utils.smart_cond(
            use_mv_avg,
            lambda: bn_decay_zero,
            lambda: match.bn_decay_mean_tensor,
            name='freeze_moving_mean')
        common.RerouteTensor(bn_decay_mean_out,
                             match.bn_decay_mean_tensor,
                             can_modify=bn_decay_mean_consumers)

        # Same pattern for the variance decay (snapshot before smart_cond).
        bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
        bn_decay_var_out = utils.smart_cond(use_mv_avg,
                                            lambda: bn_decay_zero,
                                            lambda: match.bn_decay_var_tensor,
                                            name='freeze_moving_var')
        common.RerouteTensor(bn_decay_var_out,
                             match.bn_decay_var_tensor,
                             can_modify=bn_decay_var_consumers)

        correction_recip = utils.smart_cond(
            use_mv_avg,
            lambda: array_ops.ones(correction_scale.shape),
            lambda: correction_recip,
            name='correction_recip')

        correction_offset = utils.smart_cond(
            use_mv_avg,
            lambda: correction_offset,
            lambda: array_ops.zeros(correction_offset.shape),
            name='correction_offset')
    return correction_scale, correction_recip, correction_offset
Example #5
0
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   symmetric=False,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False,
                   producer_scope=None,
                   consumer_scope=None):
    """Inserts a quant op between a producer op and (multiple) consumer ops.

    Nothing is inserted (and None is returned early) when the producer is
    outside `producer_scope`, when any consumer is outside `consumer_scope`,
    or when the producer's first output is already followed by a fake-quant.

    Args:
      context: Context where producer and consumer operations are nested.
      name: Name for the new quantization op within the context.
      producer: Producer operation of the pairs where quantization will be
        inserted; its first output is the tensor that gets quantized.
      consumers: Consumer operations of the pairs.
      is_training: Whether quantizing training graph or eval graph.
      moving_avg: If True, use MovingAvgQuantize (EMA of the range);
        otherwise LastValueQuantize (last value seen).
      init_min: Starting minimum value for the new quantization op.
      init_max: Starting maximum value for the new quantization op.
      bits: Number of bits to use for quantization, must be between 2 and 8.
      symmetric: (Optional) If true, use symmetric quantization limits instead
        of training the minimum and maximum of each quantization range
        separately.
      ema_decay: (Optional) Float, EMA decay parameter, only used when
        `moving_avg` is True. EMA is used to update quantization intervals
        for quantizing activations (see
        https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
      quant_delay: (Optional, default None) Int, count of quantization steps
        for which to delay quantization; during the delay the unquantized
        tensor is passed through. Helps weights stabilize early in training.
      vars_collection: (Optional) Collection where to store the variables for
        quantization interval ends.
      narrow_range: Whether to use the narrow quantization range
        [1; 2^bits - 1] or wide range [0; 2^bits - 1].
      producer_scope: If not None, insert only when the producer's name starts
        with this scope.
      consumer_scope: If not None, insert only when every consumer's name
        starts with this scope.

    Raises:
      ValueError: When producer operation is not directly connected to the
        consumer operation.
    """
    if producer_scope and not producer.name.startswith(producer_scope):
        logging.info(
            '_InsertQuantOp ignores context="%s" name="%s" '
            'because producer "%s" is not in scope "%s"', context, name,
            producer.name, producer_scope)
        return

    if consumer_scope:
        scoped_consumers = list(consumers)
        for consumer in scoped_consumers:
            if not consumer.name.startswith(consumer_scope):
                # A single out-of-scope consumer aborts the whole insertion.
                logging.info(
                    '_InsertQuantOp context="%s" name="%s" ignores '
                    'consumer "%s" because it is not in scope "%s"', context,
                    name, consumer.name, consumer_scope)
                return
        consumers = scoped_consumers

    name_prefix = _AddContextToName(context, name)
    # On TPU the active name scope is e.g. 'TPUReplicate/loop' and
    # name_prefix already starts with it; strip it so variables do not end up
    # as TPUReplicate/loop/TPUReplicate/loop/..., which breaks later passes.
    active_scope = ops.get_name_scope()
    if active_scope:
        name_prefix = common.DropStringPrefix(name_prefix, active_scope + '/')

    inputs = producer.outputs[0]
    # Bypass ops can overlap between matches; never stack a second FakeQuant
    # on a tensor that is already quantized.
    if _FollowedByFakeQuant(inputs):
        return

    quantizer_kwargs = {
        'init_min': init_min,
        'init_max': init_max,
        'is_training': is_training,
        'num_bits': bits,
        'symmetric': symmetric,
        'narrow_range': narrow_range,
        'vars_collection': vars_collection,
        'name_prefix': name_prefix,
    }
    if moving_avg:
        quant = quant_ops.MovingAvgQuantize(inputs,
                                            ema_decay=ema_decay,
                                            **quantizer_kwargs)
    else:
        quant = quant_ops.LastValueQuantize(inputs, **quantizer_kwargs)

    if quant_delay and quant_delay > 0:
        # Pass `inputs` through unquantized until the step reaches quant_delay.
        quantized = quant
        activate_quant = math_ops.greater_equal(
            common.CreateOrGetQuantizationStep(),
            quant_delay,
            name=name_prefix + '/activate_quant')
        quant = control_flow_ops.cond(activate_quant,
                                      lambda: quantized,
                                      lambda: inputs,
                                      name=name_prefix + '/delayed_quant')

    if not consumers:
        return

    modified = common.RerouteTensor(quant, inputs, can_modify=consumers)
    # A consumer may read several of the producer's outputs, so more than
    # len(consumers) edits is fine; fewer means some consumer was missed.
    if modified < len(consumers):
        raise ValueError(
            'No inputs quantized for ops: [%s]' %
            ', '.join([consumer.name for consumer in consumers]))