Example #1
0
    def apply(self, model, layer_quantize_map):
        """Implement TFLite transforms.

        Runs the following pattern-based transforms over `model`, in list
        order:
          1. `InputLayerQuantize`.
          2. Conv2D + BatchNorm patterns: `Conv2DBatchNormReLUQuantize`,
             `Conv2DBatchNormActivationQuantize`, then
             `Conv2DBatchNormQuantize`.
          3. Concat patterns: `ConcatTransform6Inputs` down to
             `ConcatTransform3Inputs`, then `ConcatTransform`.

        Still TODO: pulling activations into layers and fusing them.

        Args:
          model: Keras model to be quantized.
          layer_quantize_map: Map with keys as layer names, and values as dicts
            containing custom `QuantizeConfig`s which may have been passed with
            layers. Its keys are also passed to `ModelTransformer` as the third
            positional argument (presumably restricting which layers may be
            transformed -- confirm against `ModelTransformer`).

        Returns:
          (Transformed Keras model to better match TensorFlow Lite backend,
          updated layer quantize map.)
        """

        transforms = [
            tflite_transforms.InputLayerQuantize(),
            # Conv2D + BatchNorm patterns. The list appears ordered from the
            # most specific pattern (with ReLU) to the least specific one --
            # confirm that ModelTransformer applies transforms in list order.
            tflite_transforms.Conv2DBatchNormReLUQuantize(),
            tflite_transforms.Conv2DBatchNormActivationQuantize(),
            tflite_transforms.Conv2DBatchNormQuantize(),
            # Concat patterns with a fixed input count, largest first,
            # followed by `ConcatTransform` (presumably the general case).
            tflite_transforms.ConcatTransform6Inputs(),
            tflite_transforms.ConcatTransform5Inputs(),
            tflite_transforms.ConcatTransform4Inputs(),
            tflite_transforms.ConcatTransform3Inputs(),
            tflite_transforms.ConcatTransform(),
        ]

        return model_transformer.ModelTransformer(
            model, transforms, layer_quantize_map.keys(),
            layer_quantize_map).transform()
  def apply(self, model, layer_quantize_map):
    """Apply the TFLite-specific graph transforms to `model`.

    Transforms applied, in list order: `InputLayerQuantize`,
    `Conv2DBatchNormReLUQuantize`, `Conv2DBatchNormQuantize`.

    Not yet implemented (TODO):
      - pulling activations into layers and fusing them;
      - modifying the range of layers feeding into Concat.

    `keras.Sequential` models are not transformed yet. They are returned
    unchanged, together with the original `layer_quantize_map`.

    Args:
      model: Keras model to be quantized.
      layer_quantize_map: Map with keys as layer names, and values as dicts
        containing custom `QuantizeProvider`s which may have been passed with
        layers.

    Returns:
      (Transformed Keras model to better match TensorFlow Lite backend, updated
      layer quantize map.)
    """

    # TODO(pulkitb): Sequential models not supported yet. Remove once support is
    # added.
    if isinstance(model, keras.Sequential):
      # Pass-through: nothing is rewritten for Sequential models yet.
      return model, layer_quantize_map

    pattern_transforms = [
        tflite_transforms.InputLayerQuantize(),
        tflite_transforms.Conv2DBatchNormReLUQuantize(),
        tflite_transforms.Conv2DBatchNormQuantize(),
    ]
    transformer = model_transformer.ModelTransformer(
        model,
        pattern_transforms,
        layer_quantize_map.keys(),
        layer_quantize_map,
    )
    return transformer.transform()
    def testConv2DBatchNormQuantize(self, layer_type):
        """Checks `Conv2DBatchNormQuantize` on a Conv + BatchNorm model.

        After the transform, the test asserts three things:
          - the conv layer's activation is a `NoOpActivation`;
          - the BN layer has an `OutputQuantizeProvider` under its
            'quantize_provider' metadata key;
          - the transformed model's predictions match the original model's
            predictions on random input.
        """
        model = self._get_model(layer_type, False)
        input_shape = self._get_input_shape(layer_type)

        transformed_model, updated_metadata = ModelTransformer(
            model,
            [tflite_transforms.Conv2DBatchNormQuantize()],
        ).transform()

        # Presumably layers[0] is the model's input layer -- confirm against
        # _get_model.
        conv_layer = transformed_model.layers[1]
        bn_layer = transformed_model.layers[2]

        self.assertIsInstance(conv_layer.activation,
                              quantize_aware_activation.NoOpActivation)

        # Check that the entry exists first, so a missing layer fails with a
        # clear message instead of an AttributeError on `None.get(...)`.
        bn_metadata = updated_metadata.get(bn_layer.name)
        self.assertIsNotNone(
            bn_metadata, 'No metadata recorded for layer %s' % bn_layer.name)
        self.assertIsInstance(
            bn_metadata.get('quantize_provider'),
            tflite_quantize_providers.OutputQuantizeProvider)

        inputs = np.random.standard_normal(input_shape)
        self.assertAllClose(transformed_model.predict(inputs),
                            model.predict(inputs))