Example #1
    def inference_quant(self, input_tensor_name_list, input_data_list,
                        trace_tensor_name_list):
        # The model is assumed to already live in the default graph; if not,
        # import it first, e.g. tf.import_graph_def(self._graph_def, name='').
        with tf.Session() as sess:

            # insert quant ops
            quant_bits = 8
            is_training = True
            target_weight_tensor_name = "CifarNet/conv1/weights:0"
            target_weight_tensor = sess.graph.get_tensor_by_name(
                target_weight_tensor_name)
            target_weight_tensor_quant = quant_ops.LastValueQuantize(
                target_weight_tensor,
                is_training=is_training,
                narrow_range=True,
                num_bits=quant_bits,
                # Drop the ':0' suffix so the prefix is a valid scope name.
                name_prefix=target_weight_tensor_name.split(':')[0] + "_Weights")
            target_input_tensor_name = "Placeholder:0"
            target_input_tensor = sess.graph.get_tensor_by_name(
                target_input_tensor_name)
            target_input_tensor_quant = quant_ops.MovingAvgQuantize(
                target_input_tensor,
                is_training=is_training,
                narrow_range=True,
                num_bits=quant_bits,
                name_prefix="Input")

            init = tf.global_variables_initializer()
            sess.run(init)

            trace_tensor_list = []
            for tensor_name in trace_tensor_name_list:
                trace_tensor_list.append(
                    sess.graph.get_tensor_by_name(tensor_name))

            input_tensor_list = []
            for tensor_name in input_tensor_name_list:
                input_tensor_list.append(
                    sess.graph.get_tensor_by_name(tensor_name))

            feed_dict = {}
            for input_tensor, input_data in zip(input_tensor_list,
                                                input_data_list):
                feed_dict[input_tensor] = input_data

            trace_tensor_list.append(target_weight_tensor)
            trace_tensor_list.append(target_weight_tensor_quant)
            trace_tensor_list.append(target_input_tensor)
            trace_tensor_list.append(target_input_tensor_quant)

            outputs = sess.run(trace_tensor_list, feed_dict=feed_dict)

            if len(outputs) != len(trace_tensor_list):
                raise AssertionError(
                    "inference error: expected %d outputs, got %d" %
                    (len(trace_tensor_list), len(outputs)))

            for tensor, data in zip(trace_tensor_list, outputs):
                print("%s\n%r" % (tensor, data))
            return outputs
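
For reference, the FakeQuant node that LastValueQuantize/MovingAvgQuantize emits performs a quantize-dequantize round trip. The standalone NumPy sketch below approximates that arithmetic under the narrow-range, 8-bit settings used above; it is illustrative only and omits the zero-point nudging the real FakeQuantWithMinMaxVars op performs.

import numpy as np

def fake_quant(x, min_val, max_val, num_bits=8, narrow_range=True):
    # narrow_range drops the lowest code, matching the weight quantization
    # above; the wide range would start at code 0 instead.
    quant_min = 1 if narrow_range else 0
    quant_max = 2 ** num_bits - 1
    scale = (max_val - min_val) / (quant_max - quant_min)
    # Quantize to an integer code, clip, then dequantize back to float.
    codes = np.clip(np.round((x - min_val) / scale) + quant_min,
                    quant_min, quant_max)
    return (codes - quant_min) * scale + min_val

print(fake_quant(np.array([-1.0, -0.3, 0.0, 0.7, 1.0]), -1.0, 1.0))
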
    def testVariablesNotPartitioned_LastValue(self):
        # Variables added should not use a default partitioner since they are
        # scalar. There would be a tensorflow error thrown if the partitioner
        # was respected by the rewrite.
        with ops.Graph().as_default():
            with variable_scope.variable_scope(
                    'part',
                    partitioner=partitioned_variables.fixed_size_partitioner(2)):
                x = array_ops.placeholder(dtypes.float32, shape=[2])
                _ = quant_ops.LastValueQuantize(x,
                                                init_min=0.0,
                                                init_max=0.0,
                                                is_training=True,
                                                vars_collection=_MIN_MAX_VARS)
    def testLastValueQuantizeTrainingAssign(self):
        g = ops.Graph()
        with session.Session(graph=g) as sess:
            x = array_ops.placeholder(dtypes.float32, shape=[2])
            y = quant_ops.LastValueQuantize(x,
                                            init_min=0.0,
                                            init_max=0.0,
                                            is_training=True,
                                            vars_collection=_MIN_MAX_VARS)

            # Run the step.
            sess.run(variables.global_variables_initializer())
            sess.run(y, feed_dict={x: [-1.0, 1.0]})
            # Now check that the min_max_vars were, in fact, updated.
            min_value, max_value = self._GetMinMaxValues(sess)
            self.assertEqual(min_value, -1.0)
            self.assertEqual(max_value, 1.0)
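
The test above relies on a _GetMinMaxValues helper that is not reproduced here. A minimal sketch of what it needs to do, assuming the min/max variables were registered in the _MIN_MAX_VARS collection as in the LastValueQuantize call:

    def _GetMinMaxValues(self, sess):
        # Read back the two scalar range variables that LastValueQuantize
        # added to the _MIN_MAX_VARS collection.
        with sess.graph.as_default():
            min_max_vars = ops.get_collection(_MIN_MAX_VARS)
        self.assertEqual(len(min_max_vars), 2)
        # Collection order is not guaranteed, so identify min/max by name.
        min_idx = 0 if 'min' in min_max_vars[0].name else 1
        min_value, max_value = sess.run(
            [min_max_vars[min_idx], min_max_vars[1 - min_idx]])
        return min_value, max_value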
Example #4
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False):
    """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
    name_prefix = _AddContextToName(context, name)
    # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
    # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
    # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
    # breaks things later.
    name_prefix = common.DropStringPrefix(name_prefix,
                                          ops.get_name_scope() + '/')

    inputs = producer.outputs[0]
    if moving_avg:
        quant = (quant_ops.MovingAvgQuantize(inputs,
                                             init_min=init_min,
                                             init_max=init_max,
                                             ema_decay=ema_decay,
                                             is_training=is_training,
                                             num_bits=bits,
                                             narrow_range=narrow_range,
                                             vars_collection=vars_collection,
                                             name_prefix=name_prefix))
    else:
        quant = (quant_ops.LastValueQuantize(inputs,
                                             init_min=init_min,
                                             init_max=init_max,
                                             is_training=is_training,
                                             num_bits=bits,
                                             narrow_range=narrow_range,
                                             vars_collection=vars_collection,
                                             name_prefix=name_prefix))

    if quant_delay and quant_delay > 0:
        activate_quant = math_ops.greater_equal(
            common.CreateOrGetQuantizationStep(),
            quant_delay,
            name=name_prefix + '/activate_quant')
        quant = control_flow_ops.cond(activate_quant,
                                      lambda: quant,
                                      lambda: inputs,
                                      name=name_prefix + '/delayed_quant')

    nodes_modified_count = graph_editor.reroute_ts([quant], [inputs],
                                                   can_modify=consumers)
    if nodes_modified_count != len(consumers):
        raise ValueError('Some inputs not quantized for ops: [%s]' %
                         ', '.join([consumer.name for consumer in consumers]))
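
A hypothetical call site, sketching how a graph rewrite might use this helper on a matched layer. The graph object, op names, and hyperparameters below are assumptions, not part of the original code:

conv_op = graph.get_operation_by_name('CifarNet/conv1/BiasAdd')  # assumed name
_InsertQuantOp('CifarNet/conv1',      # context the producer/consumers live in
               'act_quant',           # name of the new quantization node
               producer=conv_op,
               consumers=conv_op.outputs[0].consumers(),
               is_training=True,
               moving_avg=True,       # EMA-tracked activation range
               bits=8,
               quant_delay=2000)      # start quantizing after 2000 steps
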
Example #5
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False,
                   producer_scope=None,
                   consumer_scope=None):
    """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    producer_scope: The restriction of producer scope. If not None, the new op
      will be inserted only when the producer is in this scope.
    consumer_scope: The restriction of consumer scope. If not None, the new op
      will be inserted only when all the consumers are in this scope.
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
    if producer_scope and not producer.name.startswith(producer_scope):
        logging.info(
            '_InsertQuantOp ignores context="%s" name="%s" '
            'because producer "%s" is not in scope "%s"', context, name,
            producer.name, producer_scope)
        return

    if consumer_scope:
        consumers_in_scope = []
        for consumer in consumers:
            if consumer.name.startswith(consumer_scope):
                consumers_in_scope.append(consumer)
            else:
                logging.info(
                    '_InsertQuantOp context="%s" name="%s" ignores '
                    'consumer "%s" because it is not in scope "%s"', context,
                    name, consumer.name, consumer_scope)
                return
        consumers = consumers_in_scope

    name_prefix = _AddContextToName(context, name)
    # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
    # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
    # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
    # breaks things later.
    name_scope = ops.get_name_scope()
    if name_scope:
        name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')

    inputs = producer.outputs[0]
    # Prevent ops from being quantized multiple times. Bypass ops can sometimes
    # overlap between multiple matches, so we need to ensure that we don't
    # add duplicate FakeQuant operations.
    fake_quant_ops = set(
        ['FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs'])
    if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])):
        return

    if moving_avg:
        quant = (quant_ops.MovingAvgQuantize(inputs,
                                             init_min=init_min,
                                             init_max=init_max,
                                             ema_decay=ema_decay,
                                             is_training=is_training,
                                             num_bits=bits,
                                             narrow_range=narrow_range,
                                             vars_collection=vars_collection,
                                             name_prefix=name_prefix))
    else:
        quant = (quant_ops.LastValueQuantize(inputs,
                                             init_min=init_min,
                                             init_max=init_max,
                                             is_training=is_training,
                                             num_bits=bits,
                                             narrow_range=narrow_range,
                                             vars_collection=vars_collection,
                                             name_prefix=name_prefix))

    if quant_delay and quant_delay > 0:
        activate_quant = math_ops.greater_equal(
            common.CreateOrGetQuantizationStep(),
            quant_delay,
            name=name_prefix + '/activate_quant')
        quant = control_flow_ops.cond(activate_quant,
                                      lambda: quant,
                                      lambda: inputs,
                                      name=name_prefix + '/delayed_quant')

    if consumers:
        tensors_modified_count = graph_editor.reroute_ts([quant], [inputs],
                                                         can_modify=consumers)
        # Some operations can have multiple output tensors going to the same
        # consumer. Since consumers is a set, we need to ensure that
        # tensors_modified_count is greater than or equal to the length of the set
        # of consumers.
        if tensors_modified_count < len(consumers):
            raise ValueError(
                'No inputs quantized for ops: [%s]' %
                ', '.join([consumer.name for consumer in consumers]))
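
Compared to Example #4, this variant can restrict rewriting by scope. A short sketch of how that might be used to quantize only ops inside a feature extractor; the scope names and the layer_op variable are hypothetical:

_InsertQuantOp(context,
               'act_quant',
               producer=layer_op,
               consumers=layer_op.outputs[0].consumers(),
               is_training=True,
               producer_scope='FeatureExtractor/',  # skip producers elsewhere
               consumer_scope='FeatureExtractor/')  # bail out if any consumer is outside
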
Example #6
    def _InsertQuantOp(
        self,
        context,
        producer,
        consumers,
        name,
        moving_avg=True,
        init_min=-6.0,
        init_max=6.0,
        delay_requested=True,
        bits=8,
        narrow_range=False,
    ):
        """Inserts a quant op between a producer op and (multiple) consumer ops.

    Args:
      context: Context where producer and consumer operations are nested.
      producer: Producer operation of the pairs where quantization will be
        inserted.
      consumers: Consumer operations of the pairs.
      name: Name for the new quantization op within the context.
      moving_avg: Specifies whether to use exponential moving average or just
        the last value seen.
      init_min: Starting minimum value for the new quantization op.
      init_max: Starting maximum value for the new quantization op.
      delay_requested: If true, implement quantization delay where needed.
        False value explicitly disables delay quantization everywhere.
      bits: Number of bits to use for quantization, must be between 2 and 8.
      narrow_range: Whether to use the narrow quantization range
        [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    Raises:
      ValueError: When producer operation is not directly connected to the
        consumer operation.
    """
        scope = context + '/' + name
        inputs = producer.outputs[0]
        if moving_avg:
            quant = (quant_ops.MovingAvgQuantize(
                inputs,
                init_min=init_min,
                init_max=init_max,
                ema_decay=self.ema_decay,
                is_training=self.is_training,
                num_bits=bits,
                narrow_range=narrow_range,
                updates_collection=_UPDATE_QUANT_OPS,
                vars_collection=self.vars_collection,
                scope=scope))
        else:
            quant = (quant_ops.LastValueQuantize(
                inputs,
                init_min=init_min,
                init_max=init_max,
                is_training=self.is_training,
                num_bits=bits,
                narrow_range=narrow_range,
                updates_collection=_UPDATE_QUANT_OPS,
                vars_collection=self.vars_collection,
                scope=scope))

        if delay_requested and self.quant_delay and self.quant_delay > 0:
            activate_quant = math_ops.greater_equal(
                training_util.get_or_create_global_step(),
                self.quant_delay,
                name=scope + '/activate_quant')
            quant = control_flow_ops.cond(activate_quant,
                                          lambda: quant,
                                          lambda: inputs,
                                          name=scope + '/delayed_quant')

        nodes_modified_count = graph_editor.reroute_ts([quant], [inputs],
                                                       can_modify=consumers)
        if nodes_modified_count != len(consumers):
            raise ValueError(
                'Some inputs not quantized for ops: [%s]' %
                ', '.join([consumer.name for consumer in consumers]))
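
This variant is a method and pulls its settings from instance attributes. A plausible sketch of the surrounding state, inferred from the attribute uses above; the class name, defaults, and collection key are assumptions:

_UPDATE_QUANT_OPS = 'update_quant_ops'   # assumed collection key

class _QuantizeContext(object):
    def __init__(self, graph, is_training, ema_decay=0.999, quant_delay=None,
                 vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES):
        self.graph = graph
        self.is_training = is_training
        self.ema_decay = ema_decay        # decay for MovingAvgQuantize ranges
        self.quant_delay = quant_delay    # steps to wait before quantizing
        self.vars_collection = vars_collection
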
Example #7
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   symmetric=False,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False,
                   producer_scope=None,
                   consumer_scope=None):
    """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    symmetric: (Optional) If true, use symmetric quantization limits instead of
      training the minimum and maximum of each quantization range separately.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    producer_scope: The restriction of producer scope. If not None, the new op
      will be inserted only when the producer is in this scope.
    consumer_scope: The restriction of consumer scope. If not None, the new op
      will be inserted only when all the consumers are in this scope.
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
    if producer_scope and not producer.name.startswith(producer_scope):
        logging.info(
            '_InsertQuantOp ignores context="%s" name="%s" '
            'because producer "%s" is not in scope "%s"', context, name,
            producer.name, producer_scope)
        return

    if consumer_scope:
        consumers_in_scope = []
        for consumer in consumers:
            if consumer.name.startswith(consumer_scope):
                consumers_in_scope.append(consumer)
            else:
                logging.info(
                    '_InsertQuantOp context="%s" name="%s" ignores '
                    'consumer "%s" because it is not in scope "%s"', context,
                    name, consumer.name, consumer_scope)
                return
        consumers = consumers_in_scope

    name_prefix = _AddContextToName(context, name)
    # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
    # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
    # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
    # breaks things later.
    name_scope = ops.get_name_scope()
    if name_scope:
        name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')

    inputs = producer.outputs[0]
    # Prevent ops from being quantized multiple times. Bypass ops can sometimes
    # overlap between multiple matches, so we need to ensure that we don't
    # add duplicate FakeQuant operations.
    fake_quant_op = _GetFollowingFakeQuantOp(inputs)

    # If we find that we are attempting to insert a fake quant op following
    # a fake quant, we skip inserting a fake quant op

    if fake_quant_op is None:
        if moving_avg:
            quant = (quant_ops.MovingAvgQuantize(
                inputs,
                init_min=init_min,
                init_max=init_max,
                ema_decay=ema_decay,
                is_training=is_training,
                num_bits=bits,
                symmetric=symmetric,
                narrow_range=narrow_range,
                vars_collection=vars_collection,
                name_prefix=name_prefix))
        else:
            quant = (quant_ops.LastValueQuantize(
                inputs,
                init_min=init_min,
                init_max=init_max,
                is_training=is_training,
                num_bits=bits,
                symmetric=symmetric,
                narrow_range=narrow_range,
                vars_collection=vars_collection,
                name_prefix=name_prefix))

        if quant_delay and quant_delay > 0:
            activate_quant = math_ops.greater_equal(
                common.CreateOrGetQuantizationStep(),
                quant_delay,
                name=name_prefix + '/activate_quant')
            quant = control_flow_ops.cond(activate_quant,
                                          lambda: quant,
                                          lambda: inputs,
                                          name=name_prefix + '/delayed_quant')
    else:
        # If a fake quant op is present already, make sure that
        # any downstream use of the tensor reroutes to the appropriate quantized
        # tensor. If there is no quant_delay, this is simply the output of the
        # fake quant op. If there is a quant delay, we reroute to the output
        # of the delayed quant operation, which inserts quantization only after
        # a specified quant_delay

        quant = fake_quant_op.outputs[0]
        if quant_delay and quant_delay > 0:
            name_prefix = '/'.join(quant.name.split('/')[:-1])
            quant = quant.graph.get_tensor_by_name(name_prefix +
                                                   '/delayed_quant/Merge:0')
        pruned_consumer_set = set()
        for consumer in consumers:
            fake_quant_dest_op = _GetFollowingFakeQuantOp(consumer.outputs[0])
            if (fake_quant_dest_op is None
                    or fake_quant_dest_op.name != fake_quant_op.name):
                pruned_consumer_set.add(consumer)
        consumers = pruned_consumer_set

        # If we have
        # input->pass_through->fake_quant
        # there is nothing to reroute.
        #
        # If we have
        #  input-> pass_through->fake_quant
        #                |-> consumer
        # Then we reroute such that:
        # input-> pass_through->fake_quant
        #                            |-> consumer
    if consumers:
        tensors_modified_count = common.RerouteTensor(quant,
                                                      inputs,
                                                      can_modify=consumers)
        # Some operations can have multiple output tensors going to the same
        # consumer. Since consumers is a set, we need to ensure that
        # tensors_modified_count is greater than or equal to the length of the set
        # of consumers.
        if tensors_modified_count < len(consumers):
            raise ValueError(
                'No inputs quantized for ops: [%s]' %
                ', '.join([consumer.name for consumer in consumers]))
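
Example #7 depends on a _GetFollowingFakeQuantOp helper that is not shown. A sketch of the behaviour the code above assumes: walk the tensor's consumers, looking through simple pass-through ops, and return the first FakeQuant op found, or None. The exact set of pass-through op types is an assumption.

def _GetFollowingFakeQuantOp(tensor):
    # Return the FakeQuant op consuming `tensor` (possibly through
    # pass-through ops such as Identity/Reshape), or None if there is none.
    fake_quant_types = set(['FakeQuantWithMinMaxVars',
                            'FakeQuantWithMinMaxArgs',
                            'FakeQuantWithMinMaxVarsPerChannel'])
    pass_through_types = set(['Reshape', 'Identity'])
    consumers = tensor.consumers()
    for consumer in consumers:
        if consumer.type in fake_quant_types:
            return consumer
        if consumer.type in pass_through_types:
            # Keep searching on the far side of the pass-through op.
            for output in consumer.outputs:
                consumers.extend(output.consumers())
    return None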