def quantize_apply(model):
  """Apply quantization operations to a keras model.

  This function takes a keras model which has been annotated with
  `quantize_annotate` and constructs a new keras model in which each of the
  annotated layers have been quantized. The quantization process introduces
  new quantization ops in the Tensorflow graph to appropriately emulate
  quantization loss.

  Note that to exactly emulate quantization loss, certain graph/model
  transformations may be applied. This is required since the actual quantized
  kernel implementations may apply similar transformations.

  Args:
    model: A keras Sequential or Functional model which has been annotated
      with `quantize_annotate`.

  Returns:
    Returns a new cloned keras model in which the annotated layers have been
    quantized. All the existing layers are cloned.

  Raises:
    ValueError: If `model` is not a keras `Model`, is a subclassed model,
      contains no `QuantizeAnnotate` layers, or has not been built.
    RuntimeError: If no `QuantizeProvider` can be found for an annotated
      layer, either from the annotation or from the registry.
  """
  if not isinstance(model, keras.Model):
    raise ValueError('Only a keras `Model` instance can be used.')

  if not isinstance(model, keras.Sequential) \
      and not model._is_graph_network:  # pylint: disable=protected-access
    raise ValueError('model should be either a keras.Sequential or a '
                     'keras functional model.')

  # Have at least 1 layer annotated with QuantizeAnnotate
  if not any(
      isinstance(layer, quantize_annotate_mod.QuantizeAnnotate)
      for layer in model.layers):
    raise ValueError(
        'model does not contain any layers which have been '
        'annotated with `quantize_annotate`. There are no layers '
        'to quantize.')

  if not model.built:
    # Fixed: the original adjacent string literals were missing separating
    # spaces, yielding 'notbeen' and '(input_shape)`before' in the message.
    raise ValueError(
        'quantization cannot be applied to a model which has not '
        'been built yet. Please call `model.build(input_shape)` '
        'before quantizing your model.')

  def _clone_model_with_weights(model_to_clone):
    """Clones a model and copies over its trained weights."""
    cloned_model = keras.models.clone_model(model_to_clone)
    cloned_model.set_weights(model_to_clone.get_weights())

    return cloned_model

  def _extract_original_model(model_to_unwrap):
    """Extracts original model by removing wrappers."""
    layer_quantize_map = {}

    def _unwrap(layer):
      if not isinstance(layer, quantize_annotate_mod.QuantizeAnnotate):
        return layer

      annotate_wrapper = layer
      # Record the quantization info so the inner layer can later be wrapped
      # with `QuantizeWrapper` in `_quantize`.
      layer_quantize_map[annotate_wrapper.layer.name] = {
          'quantize_provider': annotate_wrapper.quantize_provider
      }
      return annotate_wrapper.layer

    unwrapped_model = keras.models.clone_model(
        model_to_unwrap, input_tensors=None, clone_function=_unwrap)

    return unwrapped_model, layer_quantize_map

  def _quantize(layer):  # pylint: disable=missing-docstring
    if layer.name not in layer_quantize_map:
      return layer

    # Prefer the user-specified provider from the annotation; fall back to
    # the registry's default provider for this layer type.
    quantize_provider = layer_quantize_map[layer.name].get('quantize_provider')
    if not quantize_provider and quantize_registry.supports(layer):
      quantize_provider = quantize_registry.get_quantize_provider(layer)

    if not quantize_provider:
      # Fixed grammar: 'should be provide one' -> 'should provide one'.
      error_msg = (
          'Could not find a suitable QuantizeProvider for layer {}. '
          'Either the registry {} should provide one, or the user '
          'should provide one while annotating the layer using '
          'QuantizeAnnotate.')
      raise RuntimeError(
          error_msg.format(layer.__class__, quantize_registry.__class__))

    # `QuantizeWrapper` does not copy any additional layer params from
    # `QuantizeAnnotate`. This should generally be fine, but occasionally
    # `QuantizeAnnotate` wrapper may contain `batch_input_shape` like params.
    # TODO(pulkitb): Ensure this does not affect model cloning.
    return quantize_wrapper.QuantizeWrapper(layer, quantize_provider)

  # 1. Create a copy of the model with the same weights. This ensures
  # modifications don't affect the original model, or its weights.
  model_copy = _clone_model_with_weights(model)

  # 2. Remove QuantizeAnnotate wrappers from the layers in the model. This
  # extracts the original model structure (easier to transform), and
  # stores relevant quantization information in a map.
  unwrapped_model, layer_quantize_map = _extract_original_model(model_copy)

  # 3. Apply the graph transformations required to match model passes on
  # target device/dialect.
  quantize_transform = \
      tflite_quantize_layout_transform.TFLiteQuantizeLayoutTransform()
  transformed_model = quantize_transform.apply(
      unwrapped_model, layer_quantize_map)

  # TODO(pulkitb): Think more about how to introduce TFLite specific code.
  quantize_registry = tflite_quantize_registry.TFLiteQuantizeRegistry()

  # 4. Actually quantize all the relevant layers in the model. This is done by
  # wrapping the layers with QuantizeWrapper, and passing the associated
  # `QuantizeProvider`.
  return keras.models.clone_model(
      transformed_model, input_tensors=None, clone_function=_quantize)
def quantize_apply(model):
  """Introduce quantization operations to a tf.keras model.

  This function takes a tf.keras model which has been annotated with
  `quantize_annotate` and constructs a new model in which each of the
  annotated layers will ultimately be quantized. The new quantization
  operations enable the model to *emulate* quantization during training
  and store information that downstream tools will use to produce an
  actually quantized model.

  Apply quantization to a model:

  ```python
  model = quantize_apply(annotated_model)
  ```

  Note that this function removes the optimizer from the original model.

  Additionally, training the model returned by `quantize_apply` will not
  affect the weights of the original model.

  Args:
    model: A tf.keras Sequential or Functional model which has been annotated
      with `quantize_annotate`. It can have pre-trained weights.

  Returns:
    Returns a new tf.keras model in which the annotated layers have been
    prepared for quantization.

  Raises:
    ValueError: If `model` is not a tf.keras `Model`, is a subclassed model,
      contains no `QuantizeAnnotate` layers, or has not been built.
    RuntimeError: If no `QuantizeConfig` can be found for an annotated layer,
      either from the annotation or from the registry.
  """
  if not isinstance(model, keras.Model):
    raise ValueError('Only a tf.keras `Model` instance can be used.')

  if not isinstance(model, keras.Sequential) \
      and not model._is_graph_network:  # pylint: disable=protected-access
    raise ValueError('model should be either a tf.keras Sequential or '
                     'Functional model.')

  # Have at least 1 layer annotated with QuantizeAnnotate
  if not any(isinstance(layer, quantize_annotate_mod.QuantizeAnnotate)
             for layer in model.layers):
    raise ValueError('model does not contain any layers which have been '
                     'annotated with `quantize_annotate`. There are no layers '
                     'to quantize.')

  if not model.built:
    raise ValueError('quantization cannot be applied to a model which has not '
                     'been built yet. Please call `model.build(input_shape)` '
                     'before quantizing your model.')

  def _clone_model_with_weights(model_to_clone):
    """Clones a model and copies over its trained weights."""
    cloned_model = keras.models.clone_model(model_to_clone)
    cloned_model.set_weights(model_to_clone.get_weights())

    return cloned_model

  def _extract_original_model(model_to_unwrap):
    """Extracts original model by removing wrappers."""
    layer_quantize_map = {}

    def _unwrap(layer):
      if not isinstance(layer, quantize_annotate_mod.QuantizeAnnotate):
        return layer

      annotate_wrapper = layer
      # Record the quantization info so the inner layer can later be wrapped
      # with `QuantizeWrapper` in `_quantize`.
      layer_quantize_map[annotate_wrapper.layer.name] = {
          'quantize_config': annotate_wrapper.quantize_config
      }
      return annotate_wrapper.layer

    unwrapped_model = keras.models.clone_model(
        model_to_unwrap, input_tensors=None, clone_function=_unwrap)

    return unwrapped_model, layer_quantize_map

  def _quantize(layer):  # pylint: disable=missing-docstring
    if layer.name not in layer_quantize_map:
      return layer

    # Prefer the user-specified config from the annotation; fall back to
    # the registry's default config for this layer type.
    quantize_config = layer_quantize_map[layer.name].get('quantize_config')
    if not quantize_config and quantize_registry.supports(layer):
      quantize_config = quantize_registry.get_quantize_config(layer)

    if not quantize_config:
      # Fixed grammar: 'should be provide one' -> 'should provide one'.
      error_msg = ('Could not find a suitable QuantizeConfig for layer {}. '
                   'Either the registry {} should provide one, or the user '
                   'should provide one while annotating the layer using '
                   'QuantizeAnnotate.')
      raise RuntimeError(error_msg.format(
          layer.__class__, quantize_registry.__class__))

    # `QuantizeWrapper` does not copy any additional layer params from
    # `QuantizeAnnotate`. This should generally be fine, but occasionally
    # `QuantizeAnnotate` wrapper may contain `batch_input_shape` like params.
    # TODO(pulkitb): Ensure this does not affect model cloning.
    return quantize_wrapper.QuantizeWrapper(layer, quantize_config)

  # 1. Create a copy of the model with the same weights. This ensures
  # modifications don't affect the original model, or its weights.
  model_copy = _clone_model_with_weights(model)

  # 2. Remove QuantizeAnnotate wrappers from the layers in the model. This
  # extracts the original model structure (easier to transform), and
  # stores relevant quantization information in a map.
  unwrapped_model, layer_quantize_map = _extract_original_model(model_copy)

  # Model cloning excludes input layers. Add input layers into the map
  # since they need to be matched for patterns as well.
  # pylint: disable=protected-access
  for input_layer in unwrapped_model._input_layers:
    for outbound_node in input_layer._outbound_nodes:
      if outbound_node.outbound_layer.name in layer_quantize_map:
        layer_quantize_map[input_layer.name] = {}
  # pylint: enable=protected-access

  # 3. Apply the graph transformations required to match model passes on
  # target device/dialect.
  quantize_transform = \
      tflite_quantize_layout_transform.TFLiteQuantizeLayoutTransform()
  # layer_quantize_map gets modified by the transformations.
  transformed_model, layer_quantize_map = quantize_transform.apply(
      unwrapped_model, layer_quantize_map)

  # TODO(pulkitb): Think more about how to introduce TFLite specific code.
  quantize_registry = tflite_quantize_registry.TFLiteQuantizeRegistry()

  # 4. Actually quantize all the relevant layers in the model. This is done by
  # wrapping the layers with QuantizeWrapper, and passing the associated
  # `QuantizeConfig`.
  return keras.models.clone_model(
      transformed_model, input_tensors=None, clone_function=_quantize)