def apply(self, model, layer_quantize_map): """Implement TFLite transforms. Currently this means the following. 1. Pull activations into layers, and apply fuse activations. (TODO) 2. Modify range in incoming layers for Concat. (TODO) 3. Fuse Conv2D/DepthwiseConv2D + BN into single layer. Args: model: Keras model to be quantized. layer_quantize_map: Map with keys as layer names, and values as dicts containing custom `QuantizeConfig`s which may have been passed with layers. Returns: (Transformed Keras model to better match TensorFlow Lite backend, updated layer quantize map.) """ transforms = [ tflite_transforms.InputLayerQuantize(), tflite_transforms.Conv2DBatchNormReLUQuantize(), tflite_transforms.Conv2DBatchNormActivationQuantize(), tflite_transforms.Conv2DBatchNormQuantize(), tflite_transforms.ConcatTransform6Inputs(), tflite_transforms.ConcatTransform5Inputs(), tflite_transforms.ConcatTransform4Inputs(), tflite_transforms.ConcatTransform3Inputs(), tflite_transforms.ConcatTransform(), ] return model_transformer.ModelTransformer( model, transforms, layer_quantize_map.keys(), layer_quantize_map).transform()
def apply(self, model, layer_quantize_map): """Implement TFLite transforms. Currently this means the following. 1. Pull activations into layers, and apply fuse activations. (TODO) 2. Modify range in incoming layers for Concat. (TODO) 3. Fuse Conv2D/DepthwiseConv2D + BN into single layer. Args: model: Keras model to be quantized. layer_quantize_map: Map with keys as layer names, and values as dicts containing custom `QuantizeProvider`s which may have been passed with layers. Returns: (Transformed Keras model to better match TensorFlow Lite backend, updated layer quantize map.) """ # TODO(pulkitb): Sequential models not supported yet. Remove once support is # added. if isinstance(model, keras.Sequential): return model, layer_quantize_map transforms = [ tflite_transforms.InputLayerQuantize(), tflite_transforms.Conv2DBatchNormReLUQuantize(), tflite_transforms.Conv2DBatchNormQuantize(), ] return model_transformer.ModelTransformer( model, transforms, layer_quantize_map.keys(), layer_quantize_map).transform()
def apply(self, model, layer_quantize_map): """Implement default 8-bit transforms. Currently this means the following. 1. Pull activations into layers, and apply fuse activations. (TODO) 2. Modify range in incoming layers for Concat. (TODO) 3. Fuse Conv2D/DepthwiseConv2D + BN into single layer. Args: model: Keras model to be quantized. layer_quantize_map: Map with keys as layer names, and values as dicts containing custom `QuantizeConfig`s which may have been passed with layers. Returns: (Transformed Keras model to better match TensorFlow Lite backend, updated layer quantize map.) """ transforms = [ default_n_bit_transforms.InputLayerQuantize( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.SeparableConv1DQuantize( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.SeparableConvQuantize( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.Conv2DReshapeBatchNormReLUQuantize( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.Conv2DReshapeBatchNormActivationQuantize( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.Conv2DBatchNormReLUQuantize( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.Conv2DBatchNormActivationQuantize( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.Conv2DReshapeBatchNormQuantize( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.Conv2DBatchNormQuantize( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.ConcatTransform6Inputs( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.ConcatTransform5Inputs( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.ConcatTransform4Inputs( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.ConcatTransform3Inputs( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.ConcatTransform( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.LayerReLUQuantize( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), default_n_bit_transforms.LayerReluActivationQuantize( num_bits_weight=self._num_bits_weight, num_bits_activation=self._num_bits_activation), ] return model_transformer.ModelTransformer( model, transforms, set(layer_quantize_map.keys()), layer_quantize_map).transform()