def replacement(self, match_layer):
  concat_layer_node = match_layer
  feeding_layer_nodes = match_layer.input_layers

  default_registry = default_8bit_quantize_registry.QuantizeRegistry()

  feed_quantize_configs = []
  for feed_layer_node in feeding_layer_nodes:
    quantize_config = feed_layer_node.metadata.get('quantize_config')
    if not quantize_config:
      layer_class = self._get_layer_type(
          feed_layer_node.layer['class_name'])
      if layer_class is None:
        # Concat has an input layer we don't recognize. Return.
        return match_layer

      if layer_class == keras.layers.Concatenate:
        # Input layer to Concat is also Concat. Don't quantize it.
        feed_layer_node.metadata['quantize_config'] = \
            default_8bit_quantize_configs.NoOpQuantizeConfig()
        continue

      if not default_registry._is_supported_layer(layer_class):
        # Feeding layer is not supported by the registry.
        return match_layer

      quantize_config = default_registry._get_quantize_config(layer_class)
      feed_layer_node.metadata['quantize_config'] = quantize_config

    feed_quantize_configs.append(quantize_config)

  # TODO(pulkitb): This currently only disables the output quantize config,
  # but cannot properly handle the case where the FakeQuant was added to the
  # activation. Handle this properly.
  for quantize_config in feed_quantize_configs:
    self._disable_output_quantize(quantize_config)

  if not concat_layer_node.metadata.get('quantize_config'):
    concat_layer_node.metadata['quantize_config'] = \
        default_8bit_quantize_configs.Default8BitOutputQuantizeConfig()

  return concat_layer_node
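
# Sketch (not from the original sources): end-to-end usage that exercises the
# Concatenate transform above. Two quantizable Dense branches feed a
# Concatenate layer; after `quantize_apply`, the branch output quantizers are
# disabled so the concat output carries a single quantize config. The public
# `tfmot.quantization.keras` entry points used here are assumptions based on
# the tfmot API, not part of this file.
import tensorflow as tf
import tensorflow_model_optimization as tfmot

inputs = tf.keras.Input(shape=(8,))
branch_a = tf.keras.layers.Dense(4)(inputs)
branch_b = tf.keras.layers.Dense(4)(inputs)
outputs = tf.keras.layers.Concatenate()([branch_a, branch_b])
model = tf.keras.Model(inputs, outputs)

annotated = tfmot.quantization.keras.quantize_annotate_model(model)
quantized = tfmot.quantization.keras.quantize_apply(annotated)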
def quantize_apply(model):
  """Introduce quantization operations to a tf.keras model.

  This function takes a tf.keras model which has been annotated with
  `quantize_annotate` and constructs a new model in which each of the
  annotated layers will ultimately be quantized. The new quantization
  operations enable the model to *emulate* quantization during training
  and store information that downstream tools will use to produce an
  actually quantized model.

  Apply quantization to a model:

  ```python
  model = quantize_apply(annotated_model)
  ```

  Note that this function removes the optimizer from the original model.
  Additionally, training the model returned by `quantize_apply` will not
  affect the weights of the original model.

  Args:
    model: A tf.keras Sequential or Functional model which has been annotated
      with `quantize_annotate`. It can have pre-trained weights.

  Returns:
    Returns a new tf.keras model in which the annotated layers have been
    prepared for quantization.
  """
  if model is None:
    raise ValueError('`model` cannot be None')

  if not isinstance(model, keras.Model):
    raise ValueError('`model` can only be a `tf.keras.Model` instance. '
                     'You passed an instance of type: {input}.'.format(
                         input=model.__class__.__name__))

  if not isinstance(model, keras.Sequential) \
      and not model._is_graph_network:  # pylint: disable=protected-access
    raise ValueError('`model` can only either be a tf.keras Sequential or '
                     'Functional model.')

  # Have at least 1 layer annotated with QuantizeAnnotate.
  if not any(
      isinstance(layer, quantize_annotate_mod.QuantizeAnnotate)
      for layer in model.layers):
    raise ValueError(
        '`model` must contain at least one layer that has been annotated '
        'with `quantize_annotate*`. There are no layers to quantize.')

  if not model.built:
    raise ValueError(
        '`model` must be a built model. Please call '
        '`model.build(input_shape)` before quantizing your model.')

  def _clone_model_with_weights(model_to_clone):
    cloned_model = keras.models.clone_model(model_to_clone)
    cloned_model.set_weights(model_to_clone.get_weights())
    return cloned_model

  def _extract_original_model(model_to_unwrap):
    """Extracts original model by removing wrappers."""
    layer_quantize_map = {}

    def _unwrap(layer):
      if not isinstance(layer, quantize_annotate_mod.QuantizeAnnotate):
        return layer

      annotate_wrapper = layer
      layer_quantize_map[annotate_wrapper.layer.name] = {
          'quantize_config': annotate_wrapper.quantize_config
      }
      return annotate_wrapper.layer

    unwrapped_model = keras.models.clone_model(
        model_to_unwrap, input_tensors=None, clone_function=_unwrap)

    return unwrapped_model, layer_quantize_map

  def _quantize(layer):  # pylint: disable=missing-docstring
    if layer.name not in layer_quantize_map:
      return layer

    quantize_config = layer_quantize_map[layer.name].get('quantize_config')
    # `quantize_registry` is assigned later in the enclosing function, before
    # `clone_model` invokes this closure.
    if not quantize_config and quantize_registry.supports(layer):
      quantize_config = quantize_registry.get_quantize_config(layer)

    if not quantize_config:
      error_msg = (
          'Layer {}:{} is not supported. You can quantize this '
          'layer by passing a `tfmot.quantization.keras.QuantizeConfig` '
          'instance to the `quantize_annotate_layer` API.')
      raise RuntimeError(error_msg.format(layer.name, layer.__class__))

    # `QuantizeWrapper` does not copy any additional layer params from
    # `QuantizeAnnotate`. This should generally be fine, but occasionally
    # `QuantizeAnnotate` wrapper may contain `batch_input_shape` like params.
    # TODO(pulkitb): Ensure this does not affect model cloning.
    return quantize_wrapper.QuantizeWrapper(layer, quantize_config)

  # 1. Create a copy of the model with the same weights. This ensures
  # modifications don't affect the original model, or its weights.
  model_copy = _clone_model_with_weights(model)

  # 2. Remove QuantizeAnnotate wrappers from the layers in the model. This
  # extracts the original model structure (easier to transform), and
  # stores relevant quantization information in a map.
  unwrapped_model, layer_quantize_map = _extract_original_model(model_copy)

  # Model cloning excludes input layers. Add input layers into the map
  # since they need to be matched for patterns as well.
  # pylint: disable=protected-access
  for input_layer in unwrapped_model._input_layers:
    for outbound_node in input_layer._outbound_nodes:
      if outbound_node.outbound_layer.name in layer_quantize_map:
        layer_quantize_map[input_layer.name] = {}
  # pylint: enable=protected-access

  # 3. Apply the graph transformations required to match model passes on
  # the target device/dialect.
  quantize_transform = \
      default_8bit_quantize_layout_transform.QuantizeLayoutTransform()
  # layer_quantize_map gets modified by the transformations.
  transformed_model, layer_quantize_map = quantize_transform.apply(
      unwrapped_model, layer_quantize_map)

  # TODO(pulkitb): Think more about how to introduce Default specific code.
  quantize_registry = default_8bit_quantize_registry.QuantizeRegistry()

  # 4. Actually quantize all the relevant layers in the model. This is done by
  # wrapping the layers with QuantizeWrapper, and passing the associated
  # `QuantizeConfig`.
  return keras.models.clone_model(
      transformed_model, input_tensors=None, clone_function=_quantize)
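
# Sketch (assumed public tfmot entry points): because `quantize_apply` strips
# the optimizer from the copied model, the returned model must be compiled
# again before training. Per-layer annotation via `quantize_annotate_layer`
# is shown here; the layer sizes are illustrative.
import tensorflow as tf
import tensorflow_model_optimization as tfmot

annotated = tf.keras.Sequential([
    tfmot.quantization.keras.quantize_annotate_layer(
        tf.keras.layers.Dense(10, activation='relu', input_shape=(20,))),
    tf.keras.layers.Dense(5),
])

quantized_model = tfmot.quantization.keras.quantize_apply(annotated)
quantized_model.compile(optimizer='adam', loss='mse')  # re-compile to train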
def setUp(self):
  super(QuantizeRegistryTest, self).setUp()
  self.quantize_registry = default_8bit_quantize_registry.QuantizeRegistry()
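
# Illustrative test method (not in the original suite; the name is
# hypothetical). It exercises the `supports` and `get_quantize_config`
# methods that `quantize_apply` relies on above.
def testDense_IsSupportedAndHasConfig(self):
  layer = keras.layers.Dense(10)
  self.assertTrue(self.quantize_registry.supports(layer))
  self.assertIsNotNone(self.quantize_registry.get_quantize_config(layer))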
def get_quantize_registry(self):
  return default_8bit_quantize_registry.QuantizeRegistry()