Example #1
    def replacement(self, match_layer):
        concat_layer_node = match_layer
        feeding_layer_nodes = match_layer.input_layers

        default_registry = default_8bit_quantize_registry.QuantizeRegistry()

        feed_quantize_configs = []
        for feed_layer_node in feeding_layer_nodes:
            quantize_config = feed_layer_node.metadata.get('quantize_config')
            if not quantize_config:
                layer_class = self._get_layer_type(
                    feed_layer_node.layer['class_name'])
                if layer_class is None:
                    # Concat has an input layer we don't recognize. Return.
                    return match_layer

                if layer_class == keras.layers.Concatenate:
                    # Input layer to Concat is also Concat. Don't quantize it.
                    feed_layer_node.metadata['quantize_config'] = \
                      default_8bit_quantize_configs.NoOpQuantizeConfig()
                    continue

                if not default_registry._is_supported_layer(layer_class):
                    # Feeding layer is not supported by registry
                    return match_layer

                quantize_config = default_registry._get_quantize_config(
                    layer_class)
                feed_layer_node.metadata['quantize_config'] = quantize_config

            feed_quantize_configs.append(quantize_config)

        # TODO(pulkitb): This currently only disables the output quantize
        # config, but cannot properly handle the case where the FQ was added
        # to the activation. Handle this properly.
        for quantize_config in feed_quantize_configs:
            self._disable_output_quantize(quantize_config)

        if not concat_layer_node.metadata.get('quantize_config'):
            concat_layer_node.metadata['quantize_config'] = \
              default_8bit_quantize_configs.Default8BitOutputQuantizeConfig()

        return concat_layer_node
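
Example #1 calls two helpers whose bodies are not shown. Below is a minimal sketch of plausible implementations, assuming `_get_layer_type` resolves a serialized `class_name` string against the classes in `tf.keras.layers` and `_disable_output_quantize` simply suppresses a config's output quantizers; these bodies are illustrative and may differ from the library's actual code.

    # Assumes `import inspect` and `import tensorflow as tf` at module scope.
    def _get_layer_type(self, layer_class_name):
        # Resolve a serialized `class_name` to the matching Keras layer class.
        keras_layers = inspect.getmembers(tf.keras.layers, inspect.isclass)
        for layer_name, layer_class in keras_layers:
            if layer_name == layer_class_name:
                return layer_class
        return None

    def _disable_output_quantize(self, quantize_config):
        # Drop the output quantizers so only the Concatenate output carries a
        # fake-quant; the feeding layers' outputs stay un-quantized.
        quantize_config.get_output_quantizers = lambda layer: []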
Example #2
def quantize_apply(model):
    """Introduce quantization operations to a tf.keras model.

  This function takes a tf.keras model which has been annotated with
  `quantize_annotate` and constructs a new model in which each of the
  annotated layers will ultimately be quantized. The new quantization
  operations enable the model to **emulate* quantization during training
  and store information that downstream tools will use to produce
  an actually quantized model.

  Apply quantization to a model:

  ```python
  model = quantize_apply(annotated_model)
  ```

  Note that this function removes the optimizer from the original model.
  Additionally, training the model returned by `quantize_apply` will not affect
  the weights of the original model.

  Args:
    model: A tf.keras Sequential or Functional model which has been annotated
    with `quantize_annotate`. It can have pre-trained weights.

  Returns:
    Returns a new tf.keras model in which the annotated layers have been
    prepared for quantization.
  """
    if model is None:
        raise ValueError('`model` cannot be None')

    if not isinstance(model, keras.Model):
        raise ValueError('`model` can only be a `tf.keras.Model` instance. '
                         'You passed an instance of type: {input}.'.format(
                             input=model.__class__.__name__))

    if not isinstance(model, keras.Sequential) \
        and not model._is_graph_network:  # pylint: disable=protected-access
        raise ValueError('`model` must be either a tf.keras Sequential or '
                         'a Functional model.')

    # Have at least 1 layer annotated with QuantizeAnnotate
    if not any(
            isinstance(layer, quantize_annotate_mod.QuantizeAnnotate)
            for layer in model.layers):
        raise ValueError(
            '`model` must contain at least one layer which has been '
            'annotated with `quantize_annotate*`. There are no layers '
            'to quantize.')

    if not model.built:
        raise ValueError(
            '`model` must be a built model, but it has not '
            'been built yet. Please call `model.build(input_shape)` '
            'before quantizing your model.')

    def _clone_model_with_weights(model_to_clone):
        cloned_model = keras.models.clone_model(model_to_clone)
        cloned_model.set_weights(model_to_clone.get_weights())

        return cloned_model

    def _extract_original_model(model_to_unwrap):
        """Extracts original model by removing wrappers."""
        layer_quantize_map = {}

        def _unwrap(layer):
            if not isinstance(layer, quantize_annotate_mod.QuantizeAnnotate):
                return layer

            annotate_wrapper = layer
            layer_quantize_map[annotate_wrapper.layer.name] = {
                'quantize_config': annotate_wrapper.quantize_config
            }
            return annotate_wrapper.layer

        unwrapped_model = keras.models.clone_model(model_to_unwrap,
                                                   input_tensors=None,
                                                   clone_function=_unwrap)

        return unwrapped_model, layer_quantize_map

    def _quantize(layer):  # pylint: disable=missing-docstring
        if layer.name not in layer_quantize_map:
            return layer

        quantize_config = layer_quantize_map[layer.name].get('quantize_config')
        if not quantize_config and quantize_registry.supports(layer):
            quantize_config = quantize_registry.get_quantize_config(layer)

        if not quantize_config:
            error_msg = (
                'Layer {}:{} is not supported. You can quantize this '
                'layer by passing a `tfmot.quantization.keras.QuantizeConfig` '
                'instance to the `quantize_annotate_layer` '
                'API.')
            raise RuntimeError(
                error_msg.format(layer.name, layer.__class__))

        # `QuantizeWrapper` does not copy any additional layer params from
        # `QuantizeAnnotate`. This should generally be fine, but occasionally
        # `QuantizeAnnotate` wrapper may contain `batch_input_shape` like params.
        # TODO(pulkitb): Ensure this does not affect model cloning.
        return quantize_wrapper.QuantizeWrapper(layer, quantize_config)

    # 1. Create a copy of the model with the same weights. This ensures
    # modifications don't affect the original model, or its weights.
    model_copy = _clone_model_with_weights(model)

    # 2. Remove QuantizeAnnotate wrappers from the layers in the model. This
    # extracts the original model structure (easier to transform), and
    # stores relevant quantization information in a map.
    unwrapped_model, layer_quantize_map = _extract_original_model(model_copy)
    # Model cloning excludes input layers. Add input layers into the map
    # since they need to be matched for patterns as well.
    # pylint: disable=protected-access
    for input_layer in unwrapped_model._input_layers:
        for outbound_node in input_layer._outbound_nodes:
            if outbound_node.outbound_layer.name in layer_quantize_map:
                layer_quantize_map[input_layer.name] = {}
    # pylint: enable=protected-access

    # 3. Apply the graph transformations required to match the model structure
    # expected by the target device/dialect.
    quantize_transform = \
      default_8bit_quantize_layout_transform.QuantizeLayoutTransform()
    # layer_quantize_map gets modified by the transformations.
    transformed_model, layer_quantize_map = quantize_transform.apply(
        unwrapped_model, layer_quantize_map)

    # TODO(pulkitb): Think more about how to introduce Default specific code.
    quantize_registry = default_8bit_quantize_registry.QuantizeRegistry()

    # 4. Actually quantize all the relevant layers in the model. This is done by
    # wrapping the layers with QuantizeWrapper, and passing the associated
    # `QuantizeConfig`.

    return keras.models.clone_model(transformed_model,
                                    input_tensors=None,
                                    clone_function=_quantize)
Example #3
    def setUp(self):
        super(QuantizeRegistryTest, self).setUp()
        self.quantize_registry = default_8bit_quantize_registry.QuantizeRegistry()
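
A hedged sketch of a test this `setUp` might back, assuming the registry exposes the same `supports`/`get_quantize_config` interface used by `quantize_apply` in Example #2; the test name and layer choice are illustrative.

    def testSupportsDenseLayer(self):
        layer = tf.keras.layers.Dense(10)
        # The registry should recognize a core Keras layer...
        self.assertTrue(self.quantize_registry.supports(layer))
        # ...and return a QuantizeConfig for it.
        self.assertIsNotNone(self.quantize_registry.get_quantize_config(layer))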
Example #4
    def get_quantize_registry(self):
        return default_8bit_quantize_registry.QuantizeRegistry()
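
For context, a minimal sketch of how the registry returned by this accessor is consumed, reusing the `supports`/`get_quantize_config` calls seen in `quantize_apply` above; the Conv2D layer is illustrative.

registry = default_8bit_quantize_registry.QuantizeRegistry()
layer = tf.keras.layers.Conv2D(8, 3)
if registry.supports(layer):
    # Ask the registry for the default 8-bit config for this layer type.
    quantize_config = registry.get_quantize_config(layer)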