Example No. 1
  def apply(self, model, layer_quantize_map):
    """Implement TFLite transforms.

    Currently this means the following:
      1. Pull activations into layers, and apply fused activations. (TODO)
      2. Modify range in incoming layers for Concat. (TODO)
      3. Fuse Conv2D/DepthwiseConv2D + BN into a single layer.

    Args:
      model: Keras model to be quantized.
      layer_quantize_map: Map with layer names as keys, and values as dicts
        containing any custom `QuantizeConfig`s that may have been passed in
        with the layers.

    Returns:
      A tuple of (Keras model transformed to better match the TensorFlow Lite
      backend, updated layer quantize map).
    """

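    # Transforms are ordered from most to least specific (e.g. the 6-input
    # Concat pattern before the generic ConcatTransform) so that larger
    # patterns match first.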
    transforms = [
        tflite_transforms.InputLayerQuantize(),
        tflite_transforms.Conv2DBatchNormReLUQuantize(),
        tflite_transforms.Conv2DBatchNormActivationQuantize(),
        tflite_transforms.Conv2DBatchNormQuantize(),
        tflite_transforms.ConcatTransform6Inputs(),
        tflite_transforms.ConcatTransform5Inputs(),
        tflite_transforms.ConcatTransform4Inputs(),
        tflite_transforms.ConcatTransform3Inputs(),
        tflite_transforms.ConcatTransform(),
    ]

    return model_transformer.ModelTransformer(
        model, transforms, layer_quantize_map.keys(),
        layer_quantize_map).transform()
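A minimal sketch of how this pipeline gets exercised end to end, via the public tfmot quantize_model entry point (which runs the layout transforms internally). The toy model contains a Conv2D + BatchNormalization + ReLU chain of the kind the Conv2DBatchNormReLUQuantize pattern is designed to match.

import tensorflow as tf
import tensorflow_model_optimization as tfmot

inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(8, 3)(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.ReLU()(x)
model = tf.keras.Model(inputs, x)

# quantize_model annotates the model and runs the layout transform pipeline;
# the Conv2D + BatchNormalization + ReLU chain above should be fused before
# the quantization wrappers are added.
quantized_model = tfmot.quantization.keras.quantize_model(model)
quantized_model.summary()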
Example No. 2

  def apply(self, model, layer_quantize_map):
    """Implement TFLite transforms.

    Currently this means the following:
      1. Pull activations into layers, and apply fused activations. (TODO)
      2. Modify range in incoming layers for Concat. (TODO)
      3. Fuse Conv2D/DepthwiseConv2D + BN into a single layer.

    Args:
      model: Keras model to be quantized.
      layer_quantize_map: Map with layer names as keys, and values as dicts
        containing any custom `QuantizeProvider`s that may have been passed in
        with the layers.

    Returns:
      A tuple of (Keras model transformed to better match the TensorFlow Lite
      backend, updated layer quantize map).
    """

    # TODO(pulkitb): Sequential models not supported yet. Remove once support is
    # added.
    if isinstance(model, keras.Sequential):
      return model, layer_quantize_map

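    # The fused Conv2D + BN + ReLU pattern is tried before the plain
    # Conv2D + BN pattern so the more specific match wins.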
    transforms = [
        tflite_transforms.InputLayerQuantize(),
        tflite_transforms.Conv2DBatchNormReLUQuantize(),
        tflite_transforms.Conv2DBatchNormQuantize(),
    ]

    return model_transformer.ModelTransformer(
        model, transforms,
        layer_quantize_map.keys(), layer_quantize_map).transform()
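A hedged sketch of the Sequential short-circuit above: with this version of apply(), a Sequential model and its quantize map are returned untouched. `TFLiteQuantizeLayoutTransform` is an assumed name for the enclosing class, used purely for illustration; only the apply() body shown above is from the source.

from tensorflow import keras

seq_model = keras.Sequential([
    keras.layers.Conv2D(8, 3, input_shape=(32, 32, 3)),
    keras.layers.BatchNormalization(),
])
layer_quantize_map = {layer.name: {} for layer in seq_model.layers}

# `TFLiteQuantizeLayoutTransform` is a hypothetical name for the class that
# holds the apply() method above; it is not defined in this listing.
transform = TFLiteQuantizeLayoutTransform()
out_model, out_map = transform.apply(seq_model, layer_quantize_map)
assert out_model is seq_model  # Sequential models pass through untransformed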
Example No. 3

  def apply(self, model, layer_quantize_map):
    """Implement default n-bit transforms.

    Currently this means the following:
      1. Pull activations into layers, and apply fused activations. (TODO)
      2. Modify range in incoming layers for Concat. (TODO)
      3. Fuse Conv2D/DepthwiseConv2D + BN into a single layer.

    Args:
      model: Keras model to be quantized.
      layer_quantize_map: Map with layer names as keys, and values as dicts
        containing any custom `QuantizeConfig`s that may have been passed in
        with the layers.

    Returns:
      A tuple of (Keras model transformed to better match the TensorFlow Lite
      backend, updated layer quantize map).
    """

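    # As above, transforms are ordered most- to least-specific (e.g. the
    # Conv2D + BN + ReLU and 6-input Concat patterns before their generic
    # counterparts) so that larger patterns match first.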
    transforms = [
        default_n_bit_transforms.InputLayerQuantize(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.SeparableConv1DQuantize(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.SeparableConvQuantize(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.Conv2DReshapeBatchNormReLUQuantize(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.Conv2DReshapeBatchNormActivationQuantize(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.Conv2DBatchNormReLUQuantize(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.Conv2DBatchNormActivationQuantize(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.Conv2DReshapeBatchNormQuantize(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.Conv2DBatchNormQuantize(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.ConcatTransform6Inputs(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.ConcatTransform5Inputs(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.ConcatTransform4Inputs(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.ConcatTransform3Inputs(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.ConcatTransform(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.LayerReLUQuantize(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
        default_n_bit_transforms.LayerReluActivationQuantize(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation),
    ]
    return model_transformer.ModelTransformer(
        model, transforms, set(layer_quantize_map.keys()),
        layer_quantize_map).transform()
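A sketch of driving this n-bit pipeline through the public experimental API. It assumes the DefaultNBitQuantizeScheme exposed under tfmot.quantization.keras.experimental.default_n_bit in recent tfmot releases (the exact module path may vary across versions); the scheme plumbs num_bits_weight/num_bits_activation down to the transforms listed above.

import tensorflow as tf
import tensorflow_model_optimization as tfmot

inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(8, 3)(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation('relu')(x)
model = tf.keras.Model(inputs, x)

annotated = tfmot.quantization.keras.quantize_annotate_model(model)

# DefaultNBitQuantizeScheme carries num_bits_weight / num_bits_activation
# into the transform pipeline (4-bit weights, 8-bit activations here).
scheme = tfmot.quantization.keras.experimental.default_n_bit.DefaultNBitQuantizeScheme(
    num_bits_weight=4, num_bits_activation=8)
quantized_model = tfmot.quantization.keras.quantize_apply(annotated, scheme=scheme)
quantized_model.summary()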