def _TestCreateOrGetQuantizationStep(self, use_resource):
  g = ops.Graph()
  with session.Session(graph=g) as sess:
    variable_scope.get_variable_scope().set_use_resource(use_resource)
    quantization_step_tensor = common.CreateOrGetQuantizationStep()

    # Check that operations are added to the graph.
    num_nodes = len(g.get_operations())
    self.assertGreater(num_nodes, 0)

    # Check that getting the quantization step doesn't change the graph.
    get_quantization_step_tensor = common.CreateOrGetQuantizationStep()
    self.assertEqual(quantization_step_tensor, get_quantization_step_tensor)
    self.assertEqual(num_nodes, len(g.get_operations()))

    # Ensure that running the graph increments the quantization step.
    sess.run(variables.global_variables_initializer())
    step_val = sess.run(quantization_step_tensor)
    self.assertEqual(step_val, 1)

    # Ensure that even running a graph that depends on the quantization step
    # multiple times only executes it once.
    a = quantization_step_tensor + 1
    b = a + quantization_step_tensor
    _, step_val = sess.run([b, quantization_step_tensor])
    self.assertEqual(step_val, 2)
def delayed_quant(self,
                  inputs,
                  quant_min,
                  quant_max,
                  per_channel=False,
                  num_bits=4,
                  narrow_range=True,
                  quant_delay=None):
  """Turns on fake quantization after a certain delay."""
  # The fake quantization operation does not support tf.float16 yet.
  quant = self._fake_quant_with_min_max_vars(
      tf.cast(inputs, tf.float32),
      tf.cast(quant_min, tf.float32),
      tf.cast(quant_max, tf.float32),
      per_channel=per_channel,
      num_bits=num_bits,
      narrow_range=narrow_range)
  quant = tf.cast(quant, self.dtype)
  if quant_delay and quant_delay > 0:
    activate_quant = tf.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name='activate_quant')
    quant = tf.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name='delayed_quant')
  return quant
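The step-gated pattern above generalizes beyond this class. Below is a minimal sketch of the same gating applied with the public fake-quant op; it assumes TF1 graph mode and that the `common` module from this codebase is importable (both are assumptions, not shown in the snippet above).

import tensorflow.compat.v1 as tf

def gated_fake_quant(inputs, quant_min, quant_max, quant_delay=1000):
  # Fake-quantize unconditionally, then select quantized vs. raw values
  # based on the shared quantization step counter.
  quant = tf.quantization.fake_quant_with_min_max_vars(
      inputs, quant_min, quant_max, num_bits=8)
  if quant_delay and quant_delay > 0:
    activate = tf.greater_equal(
        common.CreateOrGetQuantizationStep(),  # assumed importable
        quant_delay,
        name='activate_quant')
    return tf.cond(activate, lambda: quant, lambda: inputs,
                   name='delayed_quant')
  return quant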
def _insert_fixed_quant_op(context,
                           name,
                           producer,
                           consumers,
                           init_min=-6.0,
                           init_max=6.0,
                           quant_delay=None):
  """Adds a fake quant op with fixed ranges.

  Args:
    context: The parent scope of the op to be quantized.
    name: The name of the fake quant op.
    producer: The producer op to be quantized.
    consumers: The consumer ops to the producer op.
    init_min: The minimum range for the fake quant op.
    init_max: The maximum range for the fake quant op.
    quant_delay: Number of steps to wait before activating the fake quant op.

  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  name_prefix = name if not context else context + '/' + name
  inputs = producer.outputs[0]
  quant = quant_ops.FixedQuantize(
      inputs, init_min=init_min, init_max=init_max, scope=name_prefix)

  if quant_delay and quant_delay > 0:
    activate_quant = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=name_prefix + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=name_prefix + '/delayed_quant')

  if consumers:
    tensors_modified_count = common.RerouteTensor(
        quant, inputs, can_modify=consumers)
    # Some operations can have multiple output tensors going to the same
    # consumer. Since consumers is a set, we need to ensure that
    # tensors_modified_count is greater than or equal to the length of the set
    # of consumers.
    if tensors_modified_count < len(consumers):
      raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
          [consumer.name for consumer in consumers]))
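A hypothetical call site, for orientation only: given a tf.Graph `graph`, quantize a ReLU6 output with the fixed [0, 6] range. The op names and context below are illustrative placeholders, not names produced by this code.

relu_op = graph.get_operation_by_name('net/conv1/Relu')  # illustrative name
consumers = relu_op.outputs[0].consumers()
_insert_fixed_quant_op(
    context='net/conv1',
    name='act_quant',
    producer=relu_op,
    consumers=consumers,
    init_min=0.0,
    init_max=6.0,
    quant_delay=2000)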
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False):
  """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization. This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].

  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  name_prefix = _AddContextToName(context, name)
  # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
  # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
  # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
  # breaks things later.
  name_prefix = common.DropStringPrefix(name_prefix,
                                        ops.get_name_scope() + '/')

  inputs = producer.outputs[0]
  if moving_avg:
    quant = (
        quant_ops.MovingAvgQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            ema_decay=ema_decay,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))
  else:
    quant = (
        quant_ops.LastValueQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))

  if quant_delay and quant_delay > 0:
    activate_quant = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=name_prefix + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=name_prefix + '/delayed_quant')

  nodes_modified_count = graph_editor.reroute_ts(
      [quant], [inputs], can_modify=consumers)
  if nodes_modified_count != len(consumers):
    raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join(
        [consumer.name for consumer in consumers]))
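One detail worth calling out: under TF1 control flow, the cond() above materializes as a Merge node whose output tensor is named '<name_prefix>/delayed_quant/Merge:0'; a later variant of this function retrieves the delayed-quant output by exactly that name. A self-contained sketch of the naming behavior (a sketch, assuming v1 control flow is in effect):

import tensorflow.compat.v1 as tf

tf.disable_control_flow_v2()  # the Merge-node naming below is v1 behavior

g = tf.Graph()
with g.as_default():
  x = tf.constant([1.0])
  pred = tf.greater_equal(tf.constant(5, tf.int64), tf.constant(3, tf.int64))
  out = tf.cond(pred, lambda: x * 2.0, lambda: x, name='delayed_quant')
  # The cond output is the Merge tensor, addressable by name.
  assert out.name == 'delayed_quant/Merge:0'
  assert g.get_tensor_by_name('delayed_quant/Merge:0') is out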
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False,
                   producer_scope=None,
                   consumer_scope=None):
  """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization. This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    producer_scope: The restriction of producer scope. If not None, the new
      op will be inserted only when the producer is in this scope.
    consumer_scope: The restriction of consumer scope. If not None, the new
      op will be inserted only when all the consumers are in this scope.

  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  if producer_scope and not producer.name.startswith(producer_scope):
    logging.info(
        '_InsertQuantOp ignores context="%s" name="%s" '
        'because producer "%s" is not in scope "%s"',
        context, name, producer.name, producer_scope)
    return

  if consumer_scope:
    consumers_in_scope = []
    for consumer in consumers:
      if consumer.name.startswith(consumer_scope):
        consumers_in_scope.append(consumer)
      else:
        logging.info(
            '_InsertQuantOp context="%s" name="%s" ignores '
            'consumer "%s" because it is not in scope "%s"',
            context, name, consumer.name, consumer_scope)
        return
    consumers = consumers_in_scope

  name_prefix = _AddContextToName(context, name)
  # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
  # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
  # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
  # breaks things later.
  name_scope = ops.get_name_scope()
  if name_scope:
    name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')

  inputs = producer.outputs[0]
  # Prevent ops from being quantized multiple times. Bypass ops can sometimes
  # overlap between multiple matches, so we need to ensure that we don't
  # add duplicate FakeQuant operations.
  fake_quant_ops = set(['FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs'])
  if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])):
    return

  if moving_avg:
    quant = (
        quant_ops.MovingAvgQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            ema_decay=ema_decay,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))
  else:
    quant = (
        quant_ops.LastValueQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))

  if quant_delay and quant_delay > 0:
    activate_quant = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=name_prefix + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=name_prefix + '/delayed_quant')

  if consumers:
    tensors_modified_count = graph_editor.reroute_ts(
        [quant], [inputs], can_modify=consumers)
    # Some operations can have multiple output tensors going to the same
    # consumer. Since consumers is a set, we need to ensure that
    # tensors_modified_count is greater than or equal to the length of the set
    # of consumers.
    if tensors_modified_count < len(consumers):
      raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
          [consumer.name for consumer in consumers]))
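The duplicate-FakeQuant guard at the top of the function can be exercised in isolation. A self-contained check (illustrative tensor shape and ranges):

import tensorflow.compat.v1 as tf

g = tf.Graph()
with g.as_default():
  x = tf.placeholder(tf.float32, [1, 4], name='x')
  _ = tf.quantization.fake_quant_with_min_max_args(x, min=-6.0, max=6.0)
  fake_quant_types = {'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs'}
  # x already feeds a FakeQuant consumer, so insertion would be skipped.
  assert fake_quant_types.intersection(c.type for c in x.consumers())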
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
  """Computes batch norm correction params.

  Before batch normalization is frozen:
    We use batch statistics for batch norm.
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1/correction_scale
      correction_offset = 0

  After batch normalization is frozen:
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1
      correction_offset = gamma*(mu_b/sigma_b - mu_mv/sigma_mv).

  Batch norm is frozen if global_step > bn_freeze_delay.
  The corrections ensure that:
  a) The weights are quantized after scaling by gamma/sigma_mv. This enables
     smoother training as the scaling on the weights changes slowly, rather
     than jumping across mini-batches.
  b) Changing the values of the corrections allows one to switch from using
     batch statistics to using the moving mean and variance, without requiring
     changes to batch_norm.

  Args:
    context: The scope under which we look for batch norm params.
    match: Object containing required batch norm tensors for correction
      computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance.

  Returns:
    A tuple of correction_scale, correction_recip, correction_offset.
  """

  g = ops.get_default_graph()
  prefix = '' if not context else context
  with g.name_scope(prefix + 'batch_norm_correction'):
    recip_sigma_mv = math_ops.rsqrt(
        match.moving_variance_tensor + match.batch_epsilon)
    recip_sigma = math_ops.rsqrt(match.variance_tensor + match.batch_epsilon)
    correction_scale = math_ops.divide(
        recip_sigma_mv, recip_sigma, name='scale_compute')
    correction_scale = array_ops.identity(
        correction_scale, name='correction_scale')
    correction_recip = math_ops.reciprocal(
        correction_scale, name='reciprocal_compute')
    correction_offset = math_ops.multiply(
        match.gamma_tensor,
        match.mean_tensor * recip_sigma -
        match.moving_mean_tensor * recip_sigma_mv,
        name='offset_compute')

    if freeze_batch_norm_delay is not None:
      use_mv_avg = math_ops.greater_equal(
          common.CreateOrGetQuantizationStep(),
          freeze_batch_norm_delay,
          name='use_moving_average')
    else:
      use_mv_avg = False

    bn_decay_zero = 0.0
    bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())

    bn_decay_mean_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_mean_tensor,
        name='freeze_moving_mean')
    common.RerouteTensor(
        bn_decay_mean_out,
        match.bn_decay_mean_tensor,
        can_modify=bn_decay_mean_consumers)

    bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
    bn_decay_var_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_var_tensor,
        name='freeze_moving_var')
    common.RerouteTensor(
        bn_decay_var_out,
        match.bn_decay_var_tensor,
        can_modify=bn_decay_var_consumers)

    correction_recip = utils.smart_cond(
        use_mv_avg,
        lambda: array_ops.ones(correction_scale.shape),
        lambda: correction_recip,
        name='correction_recip')

    correction_offset = utils.smart_cond(
        use_mv_avg,
        lambda: correction_offset,
        lambda: array_ops.zeros(correction_offset.shape),
        name='correction_offset')
  return correction_scale, correction_recip, correction_offset
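To make the docstring formulas concrete, here is a plain-NumPy numeric sketch of the correction terms (illustrative statistics, not taken from a real model):

import numpy as np

gamma, eps = 1.5, 1e-3
mu_b, var_b = 0.2, 4.0     # batch statistics (give sigma_b)
mu_mv, var_mv = 0.1, 9.0   # moving-average statistics (give sigma_mv)

sigma_b = np.sqrt(var_b + eps)
sigma_mv = np.sqrt(var_mv + eps)

# Weights are always scaled by gamma/sigma_mv; correction_scale restores the
# batch-statistics scaling while batch norm is not yet frozen.
correction_scale = sigma_b / sigma_mv
correction_recip_unfrozen = 1.0 / correction_scale  # before freezing
correction_recip_frozen = 1.0                       # after freezing
correction_offset_frozen = gamma * (mu_b / sigma_b - mu_mv / sigma_mv)
print(correction_scale, correction_recip_unfrozen, correction_offset_frozen)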
def last_value_quantize(self,
                        inputs,
                        per_channel=False,
                        init_min=-6.0,
                        init_max=6.0,
                        name_prefix='FixedValueQuant',
                        reuse=None,
                        is_training=False,
                        num_bits=8,
                        narrow_range=False,
                        relative_quantile=0,
                        freeze=False,
                        quant_delay=False):
  """Adds a layer that collects quantization ranges as last input ranges.

  LastValueQuantize creates variables called 'min' and 'max', representing the
  interval used for quantization and clamping.

  Args:
    inputs: a tensor containing values to be quantized.
    per_channel: (Optional) a boolean specifying whether to use different
      quantization ranges per output channel.
    init_min: a float scalar, the initial value for variable min.
    init_max: a float scalar, the initial value for variable max.
    name_prefix: name_prefix for created nodes.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    is_training: Whether the op is applied to a training or eval graph.
    num_bits: Number of bits to use for quantization, must be between 2 and 8.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
    relative_quantile: Specify the location of quantization min and max
      parameters. relative_quantile = 0 is equivalent to using the min and max
      of the input; relative_quantile = 1 sets min and max to the optimal
      location assuming the input distribution is uniform. In reality, a good
      value should be in the range [0, 1].
    freeze: If True, the min and max variables are calculated once at the
      beginning of training and then frozen. This is used for quantized
      fine-tuning of a pretrained checkpoint. If False, the min and max are
      calculated and updated every cycle.
    quant_delay: The number of global steps after which the fake quantization
      is turned on. Used for performing fine-tuning experiments without
      starting from a pre-trained checkpoint.

  Returns:
    a tensor containing quantized values.
  """

  with tf.variable_scope(
      None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope:
    scope.set_partitioner(None)
    input_shape = inputs.get_shape()
    input_dim = len(input_shape)
    if per_channel:
      # Only support quantizing 1-, 2- and 4-dimensional tensors.
      assert input_dim in [1, 2, 4]
      min_max_shape = [input_shape[-1]]
    else:
      min_max_shape = []

    min_var = tf.get_variable(
        'min',
        min_max_shape,
        tf.float32,
        initializer=tf.constant_initializer(init_min),
        trainable=False)
    max_var = tf.get_variable(
        'max',
        min_max_shape,
        tf.float32,
        initializer=tf.constant_initializer(init_max),
        trainable=False)
    if not is_training:
      return self.delayed_quant(
          inputs,
          min_var,
          max_var,
          per_channel=per_channel,
          num_bits=num_bits,
          narrow_range=narrow_range,
          quant_delay=None)

    if per_channel:
      if input_dim == 2:
        reduce_dims = [0]
      elif input_dim == 4:
        reduce_dims = [0, 1, 2]

    if num_bits >= 4:
      quantile = 0
    else:
      quantile = (1.0 / 2.0**(num_bits + 1.0)) * relative_quantile * 100

    if per_channel:
      if input_dim >= 2:
        batch_min = tfp.stats.percentile(
            inputs, q=quantile, axis=reduce_dims, name='BatchMin')
      else:
        batch_min = inputs
    else:
      batch_min = tfp.stats.percentile(inputs, q=quantile, name='BatchMin')

    if per_channel:
      if input_dim >= 2:
        batch_max = tfp.stats.percentile(
            inputs, q=100 - quantile, axis=reduce_dims, name='BatchMax')
      else:
        batch_max = inputs
    else:
      batch_max = tfp.stats.percentile(
          inputs, q=100 - quantile, name='BatchMax')

    batch_abs_max = tf.maximum(tf.abs(batch_min), tf.abs(batch_max))
    if narrow_range:
      batch_adjusted_min = 0 - batch_abs_max
    else:
      # Widen the negative end of the range so that zero maps exactly onto a
      # quantization level in the wide-range encoding.
      multiplier = 1.0 + 1.0 / (2.0**(num_bits - 1.0) - 1.0)
      batch_adjusted_min = 0 - tf.scalar_mul(multiplier, batch_abs_max)

    batch_abs_max = tf.cast(batch_abs_max, tf.float32)
    batch_adjusted_min = tf.cast(batch_adjusted_min, tf.float32)

    if freeze:
      def make_var_op(var):
        def f():
          return var
        return f

      quant_step = common.CreateOrGetQuantizationStep()
      min_max_assign = tf.less_equal(quant_step, 1, name='MinMaxAssign')
      min_value = tf.cond(
          min_max_assign,
          make_var_op(batch_adjusted_min),
          make_var_op(min_var),
          name='AssignMinCond')
      max_value = tf.cond(
          min_max_assign,
          make_var_op(batch_abs_max),
          make_var_op(max_var),
          name='AssignMaxCond')
    else:
      min_value = batch_adjusted_min
      max_value = batch_abs_max

    assign_min = tf.assign(min_var, min_value)
    assign_max = tf.assign(max_var, max_value)

    return self.delayed_quant(
        inputs,
        assign_min,
        assign_max,
        per_channel=per_channel,
        num_bits=num_bits,
        narrow_range=narrow_range,
        quant_delay=quant_delay)
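The quantile rule above only backs off from the raw min/max at low bit widths; at 4 bits and above the raw extremes are used. A quick standalone check of the formula:

for num_bits in [2, 3, 4, 8]:
  relative_quantile = 1.0  # illustrative
  if num_bits >= 4:
    quantile = 0
  else:
    quantile = (1.0 / 2.0**(num_bits + 1.0)) * relative_quantile * 100
  print(num_bits, quantile)  # 2 -> 12.5, 3 -> 6.25, 4+ -> 0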
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   symmetric=False,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False,
                   producer_scope=None,
                   consumer_scope=None):
  """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    symmetric: (Optional) If true, use symmetric quantization limits instead
      of training the minimum and maximum of each quantization range
      separately.
    ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization. This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    producer_scope: The restriction of producer scope. If not None, the new
      op will be inserted only when the producer is in this scope.
    consumer_scope: The restriction of consumer scope. If not None, the new
      op will be inserted only when all the consumers are in this scope.

  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  if producer_scope and not producer.name.startswith(producer_scope):
    logging.info(
        '_InsertQuantOp ignores context="%s" name="%s" '
        'because producer "%s" is not in scope "%s"',
        context, name, producer.name, producer_scope)
    return

  if consumer_scope:
    consumers_in_scope = []
    for consumer in consumers:
      if consumer.name.startswith(consumer_scope):
        consumers_in_scope.append(consumer)
      else:
        logging.info(
            '_InsertQuantOp context="%s" name="%s" ignores '
            'consumer "%s" because it is not in scope "%s"',
            context, name, consumer.name, consumer_scope)
        return
    consumers = consumers_in_scope

  name_prefix = _AddContextToName(context, name)
  # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
  # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
  # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
  # breaks things later.
  name_scope = ops.get_name_scope()
  if name_scope:
    name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')

  inputs = producer.outputs[0]
  # Prevent ops from being quantized multiple times. Bypass ops can sometimes
  # overlap between multiple matches, so we need to ensure that we don't
  # add duplicate FakeQuant operations.
  fake_quant_op = _GetFollowingFakeQuantOp(inputs)

  # If we find that we are attempting to insert a fake quant op following
  # a fake quant, we skip inserting a fake quant op.
  if fake_quant_op is None:
    if moving_avg:
      quant = (
          quant_ops.MovingAvgQuantize(
              inputs,
              init_min=init_min,
              init_max=init_max,
              ema_decay=ema_decay,
              is_training=is_training,
              num_bits=bits,
              symmetric=symmetric,
              narrow_range=narrow_range,
              vars_collection=vars_collection,
              name_prefix=name_prefix))
    else:
      quant = (
          quant_ops.LastValueQuantize(
              inputs,
              init_min=init_min,
              init_max=init_max,
              is_training=is_training,
              num_bits=bits,
              symmetric=symmetric,
              narrow_range=narrow_range,
              vars_collection=vars_collection,
              name_prefix=name_prefix))

    if quant_delay and quant_delay > 0:
      activate_quant = math_ops.greater_equal(
          common.CreateOrGetQuantizationStep(),
          quant_delay,
          name=name_prefix + '/activate_quant')
      quant = control_flow_ops.cond(
          activate_quant,
          lambda: quant,
          lambda: inputs,
          name=name_prefix + '/delayed_quant')
  else:
    # If a fake quant op is present already, make sure that
    # any downstream use of the tensor reroutes to the appropriate quantized
    # tensor. If there is no quant_delay, this is simply the output of the
    # fake quant op. If there is a quant delay, we reroute to the output
    # of the delayed quant operation, which inserts quantization only after
    # a specified quant_delay.
    quant = fake_quant_op.outputs[0]
    if quant_delay and quant_delay > 0:
      name_prefix = '/'.join(quant.name.split('/')[:-1])
      quant = quant.graph.get_tensor_by_name(name_prefix +
                                             '/delayed_quant/Merge:0')
    pruned_consumer_set = set()
    for consumer in consumers:
      fake_quant_dest_op = _GetFollowingFakeQuantOp(consumer.outputs[0])
      if (fake_quant_dest_op is None or
          fake_quant_dest_op.name != fake_quant_op.name):
        pruned_consumer_set.add(consumer)
    consumers = pruned_consumer_set

    # If we have
    # input->pass_through->fake_quant
    # there is nothing to reroute.
    #
    # If we have
    #  input-> pass_through->fake_quant
    #                |-> consumer
    # Then we reroute such that:
    #   input-> pass_through->fake_quant
    #                                |-> consumer
  if consumers:
    tensors_modified_count = common.RerouteTensor(
        quant, inputs, can_modify=consumers)
    # Some operations can have multiple output tensors going to the same
    # consumer. Since consumers is a set, we need to ensure that
    # tensors_modified_count is greater than or equal to the length of the set
    # of consumers.
    if tensors_modified_count < len(consumers):
      raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
          [consumer.name for consumer in consumers]))
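For orientation, a hypothetical call site for a weight tensor; `weight_read_op` and `conv_op` are placeholders for ops found by the pattern matcher, not values produced by this snippet:

_InsertQuantOp(
    context='net/conv1',          # illustrative context
    name='weights_quant',
    producer=weight_read_op,      # placeholder: op producing the weights
    consumers=[conv_op],          # placeholder: the consuming Conv2D op
    is_training=True,
    moving_avg=False,             # last-value ranges are typical for weights
    bits=8,
    symmetric=True,
    quant_delay=2000,
    narrow_range=True)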