Example #1
    def _TestCreateOrGetQuantizationStep(self, use_resource):
        g = ops.Graph()
        with session.Session(graph=g) as sess:
            variable_scope.get_variable_scope().set_use_resource(use_resource)
            quantization_step_tensor = common.CreateOrGetQuantizationStep()

            # Check that operations are added to the graph.
            num_nodes = len(g.get_operations())
            self.assertGreater(num_nodes, 0)

            # Check that getting the quantization step doesn't change the graph.
            get_quantization_step_tensor = common.CreateOrGetQuantizationStep()
            self.assertEqual(quantization_step_tensor,
                             get_quantization_step_tensor)
            self.assertEqual(num_nodes, len(g.get_operations()))

            # Ensure that running the graph increments the quantization step.
            sess.run(variables.global_variables_initializer())
            step_val = sess.run(quantization_step_tensor)
            self.assertEqual(step_val, 1)

            # Ensure that even running a graph that depends on the quantization step
            # multiple times only executes it once.
            a = quantization_step_tensor + 1
            b = a + quantization_step_tensor
            _, step_val = sess.run([b, quantization_step_tensor])
            self.assertEqual(step_val, 2)
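
The test above exercises the contract of common.CreateOrGetQuantizationStep(): the first call creates the step counter, later calls return the same tensor without adding graph nodes, and each session run increments the counter exactly once no matter how many ops read it. The snippet below is a minimal pure-Python sketch of that contract under illustrative names; the real helper builds TensorFlow graph ops rather than a module-level dict.

_QUANT_STEP = None  # illustrative module-level singleton

def create_or_get_quantization_step():
    # Create the counter on first request, return the same object afterwards.
    global _QUANT_STEP
    if _QUANT_STEP is None:
        _QUANT_STEP = {'value': 0}
    return _QUANT_STEP

def run_once(step):
    # One increment per "run", regardless of how many readers depend on it.
    step['value'] += 1
    return step['value']

step = create_or_get_quantization_step()
assert create_or_get_quantization_step() is step  # no new state created
assert run_once(step) == 1
assert run_once(step) == 2
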
  def delayed_quant(self,
                    inputs,
                    quant_min,
                    quant_max,
                    per_channel=False,
                    num_bits=4,
                    narrow_range=True,
                    quant_delay=None):
    """Turn on fake quantization after certain delay."""

    # The fake quantization operation does not support tf.float16 yet.
    quant = self._fake_quant_with_min_max_vars(
        tf.cast(inputs, tf.float32),
        tf.cast(quant_min, tf.float32),
        tf.cast(quant_max, tf.float32),
        per_channel=per_channel,
        num_bits=num_bits,
        narrow_range=narrow_range)
    quant = tf.cast(quant, self.dtype)
    if quant_delay and quant_delay > 0:
      activate_quant = tf.greater_equal(
          common.CreateOrGetQuantizationStep(),
          quant_delay,
          name='activate_quant')
      quant = tf.cond(
          activate_quant,
          lambda: quant,
          lambda: inputs,
          name='delayed_quant')
    return quant
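
delayed_quant gates the fake-quant result on the shared quantization step: until the step reaches quant_delay, the inputs pass through untouched. The NumPy sketch below approximates the rounding that a fake-quant kernel such as FakeQuantWithMinMaxVars performs (it omits the zero-point nudging the real kernel does); fake_quant_sketch and delayed_quant_sketch are illustrative names, not part of the library.

import numpy as np

def fake_quant_sketch(x, quant_min, quant_max, num_bits=8, narrow_range=False):
    # Clamp to [quant_min, quant_max], snap to a uniform grid, map back to floats.
    levels = 2**num_bits - 1 - (1 if narrow_range else 0)
    scale = (quant_max - quant_min) / levels
    clipped = np.clip(x, quant_min, quant_max)
    return np.round((clipped - quant_min) / scale) * scale + quant_min

def delayed_quant_sketch(x, quant_min, quant_max, step, quant_delay, **kwargs):
    # Pass inputs through unchanged until `step` reaches `quant_delay`.
    if quant_delay and step < quant_delay:
        return x
    return fake_quant_sketch(x, quant_min, quant_max, **kwargs)
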
Example #3
def _insert_fixed_quant_op(context,
                           name,
                           producer,
                           consumers,
                           init_min=-6.0,
                           init_max=6.0,
                           quant_delay=None):
    """Adds a fake quant op with fixed ranges.

  Args:
    context: The parent scope of the op to be quantized.
    name: The name of the fake quant op.
    producer: The producer op to be quantized.
    consumers: The consumer ops to the producer op.
    init_min: The minimum range for the fake quant op.
    init_max: The maximum range for the fake quant op.
    quant_delay: Number of steps to wait before activating the fake quant op.

  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
    name_prefix = name if not context else context + '/' + name
    inputs = producer.outputs[0]
    quant = quant_ops.FixedQuantize(inputs,
                                    init_min=init_min,
                                    init_max=init_max,
                                    scope=name_prefix)

    if quant_delay and quant_delay > 0:
        activate_quant = math_ops.greater_equal(
            common.CreateOrGetQuantizationStep(),
            quant_delay,
            name=name_prefix + '/activate_quant')
        quant = control_flow_ops.cond(activate_quant,
                                      lambda: quant,
                                      lambda: inputs,
                                      name=name_prefix + '/delayed_quant')

    if consumers:
        tensors_modified_count = common.RerouteTensor(quant,
                                                      inputs,
                                                      can_modify=consumers)
        # Some operations can have multiple output tensors going to the same
        # consumer. Since consumers is a set, we need to ensure that
        # tensors_modified_count is greater than or equal to the length of the set
        # of consumers.
        if tensors_modified_count < len(consumers):
            raise ValueError(
                'No inputs quantized for ops: [%s]' %
                ', '.join([consumer.name for consumer in consumers]))
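
common.RerouteTensor rewires every consumer that currently reads `inputs` so that it reads `quant` instead, and returns how many input slots were modified; because one consumer can read the same tensor in several slots, the count can exceed the number of consumers, which is why the check above only requires it to be at least len(consumers). The toy sketch below illustrates that contract on plain Python objects and is not the TensorFlow graph editor.

def reroute_tensor_sketch(new, old, can_modify):
    # Rewire every input slot that reads `old` to read `new`; return the count.
    modified = 0
    for consumer in can_modify:
        for i, tensor in enumerate(consumer['inputs']):
            if tensor is old:
                consumer['inputs'][i] = new
                modified += 1
    return modified

old_t, new_t = object(), object()
consumer = {'inputs': [old_t, old_t]}  # one op reading the tensor twice
assert reroute_tensor_sketch(new_t, old_t, [consumer]) == 2  # count > 1 consumer
assert consumer['inputs'] == [new_t, new_t]
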
Example #4
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False):
    """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
    name_prefix = _AddContextToName(context, name)
    # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
    # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
    # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
    # breaks things later.
    name_prefix = common.DropStringPrefix(name_prefix,
                                          ops.get_name_scope() + '/')

    inputs = producer.outputs[0]
    if moving_avg:
        quant = (quant_ops.MovingAvgQuantize(inputs,
                                             init_min=init_min,
                                             init_max=init_max,
                                             ema_decay=ema_decay,
                                             is_training=is_training,
                                             num_bits=bits,
                                             narrow_range=narrow_range,
                                             vars_collection=vars_collection,
                                             name_prefix=name_prefix))
    else:
        quant = (quant_ops.LastValueQuantize(inputs,
                                             init_min=init_min,
                                             init_max=init_max,
                                             is_training=is_training,
                                             num_bits=bits,
                                             narrow_range=narrow_range,
                                             vars_collection=vars_collection,
                                             name_prefix=name_prefix))

    if quant_delay and quant_delay > 0:
        activate_quant = math_ops.greater_equal(
            common.CreateOrGetQuantizationStep(),
            quant_delay,
            name=name_prefix + '/activate_quant')
        quant = control_flow_ops.cond(activate_quant,
                                      lambda: quant,
                                      lambda: inputs,
                                      name=name_prefix + '/delayed_quant')

    nodes_modified_count = graph_editor.reroute_ts([quant], [inputs],
                                                   can_modify=consumers)
    if nodes_modified_count != len(consumers):
        raise ValueError('Some inputs not quantized for ops: [%s]' %
                         ', '.join([consumer.name for consumer in consumers]))
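
The moving_avg flag selects between two ways of tracking the quantization interval: MovingAvgQuantize smooths the observed min/max with an exponential moving average (decay 0.999 by default), while LastValueQuantize simply takes the last batch's range. The NumPy sketch below shows the EMA update under an illustrative name; the real layers keep min/max in TensorFlow variables stored in vars_collection.

import numpy as np

def ema_range_update(batch, range_min, range_max, ema_decay=0.999):
    # Blend the current range with this batch's min/max; a high decay means the
    # interval drifts slowly instead of tracking each mini-batch exactly.
    batch_min = float(np.min(batch))
    batch_max = float(np.max(batch))
    new_min = ema_decay * range_min + (1.0 - ema_decay) * batch_min
    new_max = ema_decay * range_max + (1.0 - ema_decay) * batch_max
    return new_min, new_max

r_min, r_max = -6.0, 6.0  # init_min / init_max
r_min, r_max = ema_range_update(np.array([-1.0, 2.0]), r_min, r_max)
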
Example #5
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False,
                   producer_scope=None,
                   consumer_scope=None):
    """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    producer_scope: The restriction of producer scope. If not None, the new op
      will be inserted only when the producer is in this scope.
    consumer_scope: The restriction of consumer scope. If not None, the new op
      will be inserted only when all the consumers are in this scope.
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
    if producer_scope and not producer.name.startswith(producer_scope):
        logging.info(
            '_InsertQuantOp ignores context="%s" name="%s" '
            'because producer "%s" is not in scope "%s"', context, name,
            producer.name, producer_scope)
        return

    if consumer_scope:
        consumers_in_scope = []
        for consumer in consumers:
            if consumer.name.startswith(consumer_scope):
                consumers_in_scope.append(consumer)
            else:
                logging.info(
                    '_InsertQuantOp context="%s" name="%s" ignores '
                    'consumer "%s" because it is not in scope "%s"', context,
                    name, consumer.name, consumer_scope)
                return
        consumers = consumers_in_scope

    name_prefix = _AddContextToName(context, name)
    # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
    # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
    # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
    # breaks things later.
    name_scope = ops.get_name_scope()
    if name_scope:
        name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')

    inputs = producer.outputs[0]
    # Prevent ops from being quantized multiple times. Bypass ops can sometimes
    # overlap between multiple matches, so we need to ensure that we don't
    # add duplicate FakeQuant operations.
    fake_quant_ops = set(
        ['FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs'])
    if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])):
        return

    if moving_avg:
        quant = (quant_ops.MovingAvgQuantize(inputs,
                                             init_min=init_min,
                                             init_max=init_max,
                                             ema_decay=ema_decay,
                                             is_training=is_training,
                                             num_bits=bits,
                                             narrow_range=narrow_range,
                                             vars_collection=vars_collection,
                                             name_prefix=name_prefix))
    else:
        quant = (quant_ops.LastValueQuantize(inputs,
                                             init_min=init_min,
                                             init_max=init_max,
                                             is_training=is_training,
                                             num_bits=bits,
                                             narrow_range=narrow_range,
                                             vars_collection=vars_collection,
                                             name_prefix=name_prefix))

    if quant_delay and quant_delay > 0:
        activate_quant = math_ops.greater_equal(
            common.CreateOrGetQuantizationStep(),
            quant_delay,
            name=name_prefix + '/activate_quant')
        quant = control_flow_ops.cond(activate_quant,
                                      lambda: quant,
                                      lambda: inputs,
                                      name=name_prefix + '/delayed_quant')

    if consumers:
        tensors_modified_count = graph_editor.reroute_ts([quant], [inputs],
                                                         can_modify=consumers)
        # Some operations can have multiple output tensors going to the same
        # consumer. Since consumers is a set, we need to ensure that
        # tensors_modified_count is greater than or equal to the length of the set
        # of consumers.
        if tensors_modified_count < len(consumers):
            raise ValueError(
                'No inputs quantized for ops: [%s]' %
                ', '.join([consumer.name for consumer in consumers]))
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
    """Computes batch norm correction params.

     Before batch normalization is frozen:
     We use batch statistics for batch norm.
       correction_scale = sigma_b/sigma_mv
       correction_recip = 1/correction_scale
       correction_offset = 0

     After batch normalization is frozen:
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1
      correction_offset =  gamma*(mu_b/sigma_b-mu_mv/sigma_mv).

     Batch norm is frozen once the quantization step reaches freeze_batch_norm_delay.
     The corrections ensure that:
     a) The weights are quantized after scaling by gamma/sigma_mv. This enables
     smoother training as the scaling on the weights changes slowly, rather than
     jumping across mini-batches.
     b) Changing the values of the corrections allows one to switch from using
     batch statistics to using the moving averages, without requiring changes to
     batch_norm.


  Args:
    context: The scope under which we look for batch norm params
    match: Object containing required batch norm tensors for correction
      computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance.


  Returns:
    A tuple of correction_scale, correction_recip, correction_offset
  """

    g = ops.get_default_graph()
    prefix = '' if not context else context
    with g.name_scope(prefix + 'batch_norm_correction'):
        recip_sigma_mv = math_ops.rsqrt(match.moving_variance_tensor +
                                        match.batch_epsilon)
        recip_sigma = math_ops.rsqrt(match.variance_tensor +
                                     match.batch_epsilon)
        correction_scale = math_ops.divide(recip_sigma_mv,
                                           recip_sigma,
                                           name='scale_compute')
        correction_scale = array_ops.identity(correction_scale,
                                              name='correction_scale')
        correction_recip = math_ops.reciprocal(correction_scale,
                                               name='reciprocal_compute')
        mv = match.moving_mean_tensor  #if match.moving_mean_tensor is not None else 0
        correction_offset = math_ops.multiply(match.gamma_tensor,
                                              match.mean_tensor * recip_sigma -
                                              mv,
                                              name='offset_compute')

        if freeze_batch_norm_delay is not None:
            use_mv_avg = math_ops.greater_equal(
                common.CreateOrGetQuantizationStep(),
                freeze_batch_norm_delay,
                name='use_moving_average')
        else:
            use_mv_avg = False

        bn_decay_zero = 0.0
        bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())

        bn_decay_mean_out = utils.smart_cond(
            use_mv_avg,
            lambda: bn_decay_zero,
            lambda: match.bn_decay_mean_tensor,
            name='freeze_moving_mean')

        common.RerouteTensor(bn_decay_mean_out,
                             match.bn_decay_mean_tensor,
                             can_modify=bn_decay_mean_consumers)

        bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
        bn_decay_var_out = utils.smart_cond(use_mv_avg,
                                            lambda: bn_decay_zero,
                                            lambda: match.bn_decay_var_tensor,
                                            name='freeze_moving_var')
        common.RerouteTensor(bn_decay_var_out,
                             match.bn_decay_var_tensor,
                             can_modify=bn_decay_var_consumers)

        correction_recip = utils.smart_cond(
            use_mv_avg,
            lambda: array_ops.ones(correction_scale.shape),
            lambda: correction_recip,
            name='correction_recip')

        correction_offset = utils.smart_cond(
            use_mv_avg,
            lambda: correction_offset,
            lambda: array_ops.zeros(correction_offset.shape),
            name='correction_offset')
    return correction_scale, correction_recip, correction_offset
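
The docstring formulas above can be restated directly in NumPy. The sketch below follows those formulas (with sigma taken as sqrt(variance + epsilon)) and switches between the unfrozen and frozen cases with a plain boolean; in the graph version the switch is utils.smart_cond on the quantization step, and the function name and arguments here are illustrative.

import numpy as np

def batch_norm_corrections(gamma, mu_b, var_b, mu_mv, var_mv, eps, frozen):
    sigma_b = np.sqrt(var_b + eps)    # batch std dev
    sigma_mv = np.sqrt(var_mv + eps)  # moving (long-run) std dev
    correction_scale = sigma_b / sigma_mv
    if not frozen:
        # Still using batch statistics: undo the scale after quantization.
        correction_recip = 1.0 / correction_scale
        correction_offset = 0.0
    else:
        # Frozen: keep the scale on the weights, move the difference to the bias.
        correction_recip = 1.0
        correction_offset = gamma * (mu_b / sigma_b - mu_mv / sigma_mv)
    return correction_scale, correction_recip, correction_offset
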
  def last_value_quantize(self,
                          inputs,
                          per_channel=False,
                          init_min=-6.0,
                          init_max=6.0,
                          name_prefix='FixedValueQuant',
                          reuse=None,
                          is_training=False,
                          num_bits=8,
                          narrow_range=False,
                          relative_quantile=0,
                          freeze=False,
                          quant_delay=False):
    """Adds a layer that collects quantization ranges as last input ranges.

    LastValueQuantize creates variables called 'min' and 'max', representing the
    interval used for quantization and clamping.

    Args:
      inputs: a tensor containing values to be quantized.
      per_channel: (Optional) a boolean specifying whether to use different
        quantization ranges per output channel.
      init_min: a float scalar, the initial value for variable min.
      init_max: a float scalar, the initial value for variable max.
      name_prefix: name_prefix for created nodes.
      reuse: whether or not the layer and its variables should be reused. To be
        able to reuse the layer, a scope must be given.
      is_training: Whether the op is applied to a training or eval graph.
      num_bits: Number of bits to use for quantization, must be between 2 and 8.
      narrow_range: Whether to use the narrow quantization range
        [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
      relative_quantile: Specify the location of quantization min and max
        parameters. relative_quantile = 0 is equivalent to using the min and max
        of the input; relative_quantile = 1 sets min and max at the optimal
        location assuming the input distribution is uniform. In practice, a good
        value lies in the range [0, 1].
      freeze: If True, the min and max variables are calculated once at the
        beginning of training and then frozen. This is used for quantized
        fine-tuning of a pretrained checkpoint. If False, the min and max are
        recalculated and updated every cycle.
      quant_delay: The number of global steps after which fake quantization is
        turned on. Used for performing fine-tuning experiments without starting
        from a pre-trained checkpoint.
    Returns:
      a tensor containing quantized values.
    """

    with tf.variable_scope(
        None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope:
      scope.set_partitioner(None)
      input_shape = inputs.get_shape()
      input_dim = len(input_shape)
      if per_channel:
        # Only support quantizing 1-, 2- and 4-dimensional tensors.
        assert input_dim in [1, 2, 4]
        min_max_shape = [input_shape[-1]]
      else:
        min_max_shape = []

      min_var = tf.get_variable('min',
                                min_max_shape,
                                tf.float32,
                                initializer=tf.constant_initializer(init_min),
                                trainable=False)
      max_var = tf.get_variable('max',
                                min_max_shape,
                                tf.float32,
                                initializer=tf.constant_initializer(init_max),
                                trainable=False)
      if not is_training:
        return self.delayed_quant(
            inputs,
            min_var,
            max_var,
            per_channel=per_channel,
            num_bits=num_bits,
            narrow_range=narrow_range,
            quant_delay=None)

      if per_channel:
        if input_dim == 2:
          reduce_dims = [0]
        elif input_dim == 4:
          reduce_dims = [0, 1, 2]

      if num_bits >= 4:
        quantile = 0
      else:
        quantile = (1.0 / 2.0**(num_bits + 1.0)) * relative_quantile * 100

      if per_channel:
        if input_dim >= 2:
          batch_min = tfp.stats.percentile(
              inputs, q=quantile, axis=reduce_dims, name='BatchMin')
        else:
          batch_min = inputs
      else:
        batch_min = tfp.stats.percentile(
            inputs, q=quantile, name='BatchMin')

      if per_channel:
        if input_dim >= 2:
          batch_max = tfp.stats.percentile(
              inputs, q=100 - quantile, axis=reduce_dims, name='BatchMax')
        else:
          batch_max = inputs
      else:
        batch_max = tfp.stats.percentile(
            inputs, q=100 - quantile, name='BatchMax')

      batch_abs_max = tf.maximum(tf.abs(batch_min), tf.abs(batch_max))

      if narrow_range:
        batch_adjusted_min = 0 - batch_abs_max
      else:
        multiplier = 1.0 + 1.0 / (2.0**(num_bits-1.0) - 1.0)
        batch_adjusted_min = 0 - tf.scalar_mul(multiplier, batch_abs_max)

      batch_abs_max = tf.cast(batch_abs_max, tf.float32)
      batch_adjusted_min = tf.cast(batch_adjusted_min, tf.float32)

      if freeze:
        def make_var_op(var):
          def f():
            return var
          return f

        quant_step = common.CreateOrGetQuantizationStep()
        min_max_assign = tf.less_equal(
            quant_step, 1, name='MinMaxAssign')
        min_value = tf.cond(min_max_assign,
                            make_var_op(batch_adjusted_min),
                            make_var_op(min_var),
                            name='AssignMinCond')
        max_value = tf.cond(min_max_assign,
                            make_var_op(batch_abs_max),
                            make_var_op(max_var),
                            name='AssignMaxCond')
      else:
        min_value = batch_adjusted_min
        max_value = batch_abs_max

      assign_min = tf.assign(min_var, min_value)
      assign_max = tf.assign(max_var, max_value)

      return self.delayed_quant(
          inputs,
          assign_min,
          assign_max,
          per_channel=per_channel,
          num_bits=num_bits,
          narrow_range=narrow_range,
          quant_delay=quant_delay)
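
last_value_quantize picks its range from batch percentiles: the clipping quantile shrinks to zero once num_bits >= 4, the interval is made symmetric around zero via the absolute maximum, and for the wide (non-narrow) range the minimum is stretched by one extra quantization level. The NumPy sketch below mirrors that per-tensor selection under an illustrative name; the per-channel and freeze/assign paths are omitted.

import numpy as np

def last_value_range_sketch(inputs, num_bits=4, relative_quantile=1.0,
                            narrow_range=True):
    # Clip outliers via percentiles, then build a symmetric interval around 0.
    if num_bits >= 4:
        quantile = 0.0
    else:
        quantile = (1.0 / 2.0**(num_bits + 1.0)) * relative_quantile * 100
    batch_min = np.percentile(inputs, quantile)
    batch_max = np.percentile(inputs, 100 - quantile)
    abs_max = max(abs(batch_min), abs(batch_max))
    if narrow_range:
        adjusted_min = -abs_max
    else:
        # The wide range has one extra negative level, so stretch the minimum.
        multiplier = 1.0 + 1.0 / (2.0**(num_bits - 1.0) - 1.0)
        adjusted_min = -multiplier * abs_max
    return adjusted_min, abs_max
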
Example #8
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   symmetric=False,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False,
                   producer_scope=None,
                   consumer_scope=None):
    """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    symmetric: (Optional) If true, use symmetric quantization limits instead of
      training the minimum and maximum of each quantization range separately.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    producer_scope: The restriction of producer scope. If not None, the new op
      will be inserted only when the producer is in this scope.
    consumer_scope: The restriction of consumer scope. If not None, the new op
      will be inserted only when all the consumers are in this scope.
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
    if producer_scope and not producer.name.startswith(producer_scope):
        logging.info(
            '_InsertQuantOp ignores context="%s" name="%s" '
            'because producer "%s" is not in scope "%s"', context, name,
            producer.name, producer_scope)
        return

    if consumer_scope:
        consumers_in_scope = []
        for consumer in consumers:
            if consumer.name.startswith(consumer_scope):
                consumers_in_scope.append(consumer)
            else:
                logging.info(
                    '_InsertQuantOp context="%s" name="%s" ignores '
                    'consumer "%s" because it is not in scope "%s"', context,
                    name, consumer.name, consumer_scope)
                return
        consumers = consumers_in_scope

    name_prefix = _AddContextToName(context, name)
    # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
    # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
    # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
    # breaks things later.
    name_scope = ops.get_name_scope()
    if name_scope:
        name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')

    inputs = producer.outputs[0]
    # Prevent ops from being quantized multiple times. Bypass ops can sometimes
    # overlap between multiple matches, so we need to ensure that we don't
    # add duplicate FakeQuant operations.
    fake_quant_op = _GetFollowingFakeQuantOp(inputs)

    # If we find that we are attempting to insert a fake quant op following
    # an existing fake quant op, we skip the insertion.

    if fake_quant_op is None:
        if moving_avg:
            quant = (quant_ops.MovingAvgQuantize(
                inputs,
                init_min=init_min,
                init_max=init_max,
                ema_decay=ema_decay,
                is_training=is_training,
                num_bits=bits,
                symmetric=symmetric,
                narrow_range=narrow_range,
                vars_collection=vars_collection,
                name_prefix=name_prefix))
        else:
            quant = (quant_ops.LastValueQuantize(
                inputs,
                init_min=init_min,
                init_max=init_max,
                is_training=is_training,
                num_bits=bits,
                symmetric=symmetric,
                narrow_range=narrow_range,
                vars_collection=vars_collection,
                name_prefix=name_prefix))

        if quant_delay and quant_delay > 0:
            activate_quant = math_ops.greater_equal(
                common.CreateOrGetQuantizationStep(),
                quant_delay,
                name=name_prefix + '/activate_quant')
            quant = control_flow_ops.cond(activate_quant,
                                          lambda: quant,
                                          lambda: inputs,
                                          name=name_prefix + '/delayed_quant')
    else:
        # If a fake quant op is already present, make sure that any downstream
        # use of the tensor reroutes to the appropriate quantized tensor. If
        # there is no quant_delay, this is simply the output of the fake quant
        # op. If there is a quant_delay, we reroute to the output of the delayed
        # quant operation, which applies quantization only after the specified
        # quant_delay.

        quant = fake_quant_op.outputs[0]
        if quant_delay and quant_delay > 0:
            name_prefix = '/'.join(quant.name.split('/')[:-1])
            quant = quant.graph.get_tensor_by_name(name_prefix +
                                                   '/delayed_quant/Merge:0')
        pruned_consumer_set = set()
        for consumer in consumers:
            fake_quant_dest_op = _GetFollowingFakeQuantOp(consumer.outputs[0])
            if (fake_quant_dest_op is None
                    or fake_quant_dest_op.name != fake_quant_op.name):
                pruned_consumer_set.add(consumer)
        consumers = pruned_consumer_set

        # If we have
        # input->pass_through->fake_quant
        # there is nothing to reroute.
        #
        # If we have
        #  input-> pass_through->fake_quant
        #                |-> consumer
        # Then we reroute such that:
        # input-> pass_through->fake_quant
        #                            |-> consumer
    if consumers:
        tensors_modified_count = common.RerouteTensor(quant,
                                                      inputs,
                                                      can_modify=consumers)
        # Some operations can have multiple output tensors going to the same
        # consumer. Since consumers is a set, we need to ensure that
        # tensors_modified_count is greater than or equal to the length of the set
        # of consumers.
        if tensors_modified_count < len(consumers):
            raise ValueError(
                'No inputs quantized for ops: [%s]' %
                ', '.join([consumer.name for consumer in consumers]))
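
Examples #5 and #8 both guard against quantizing a tensor twice: before inserting a new op they check whether the producer's output already feeds a FakeQuant operation and skip the insertion if so. The sketch below reduces that check to its essence over a list of consumer op-type strings; the helper name is illustrative.

FAKE_QUANT_TYPES = frozenset({'FakeQuantWithMinMaxVars',
                              'FakeQuantWithMinMaxArgs'})

def should_skip_insert(consumer_types):
    # consumer_types: op type strings gathered from inputs.consumers().
    return bool(FAKE_QUANT_TYPES.intersection(consumer_types))

assert should_skip_insert(['FakeQuantWithMinMaxVars', 'Conv2D'])
assert not should_skip_insert(['Conv2D', 'BiasAdd'])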