def quantize_scope(*args):
  """Returns a custom-object scope for loading quantized tf.keras h5 models.

  Args:
    *args: Zero or more dictionaries of name -> class pairs. Each is handed to
      the returned scope, followed by the built-in quantization objects.

  Returns:
    A `CustomObjectScope` that includes the quantization objects.

  Example:

  ```python
  tf.keras.models.save_model(quantized_model, keras_file)

  with quantize_scope():
    loaded_model = tf.keras.models.load_model(keras_file)
  ```
  """
  known_objects = {
      'QuantizeAnnotate': quantize_annotate_mod.QuantizeAnnotate,
      'QuantizeAwareActivation':
          quantize_aware_activation.QuantizeAwareActivation,
      'NoOpActivation': quantize_aware_activation.NoOpActivation,
      'QuantizeWrapper': quantize_wrapper.QuantizeWrapper,
      'QuantizeLayer': quantize_layer.QuantizeLayer,
      # TODO(tf-mot): add way for different quantization schemes to modify this.
      '_DepthwiseConvBatchNorm2D': conv_batchnorm._DepthwiseConvBatchNorm2D,  # pylint: disable=protected-access
      '_ConvBatchNorm2D': conv_batchnorm._ConvBatchNorm2D  # pylint: disable=protected-access
  }
  # Registry types are merged first, then quantizer types, so on a name clash
  # the quantizer entry is the one kept.
  registry_types = default_8bit_quantize_registry._types_dict()  # pylint: disable=protected-access
  known_objects.update(registry_types)
  quantizer_types = quantizers._types_dict()  # pylint: disable=protected-access
  known_objects.update(quantizer_types)

  # Caller-supplied dictionaries come first; the quantization objects are
  # appended as the final dictionary.
  scope_dicts = list(args)
  scope_dicts.append(known_objects)
  return tf.keras.utils.custom_object_scope(*scope_dicts)
def quantize_scope(*args):
  """Scope which can be used to deserialize quantized Keras models and layers.

  Under `quantize_scope`, Keras methods such as `tf.keras.models.load_model`
  and `tf.keras.models.model_from_config` will be able to deserialize Keras
  models and layers which contain quantization classes such as
  `QuantizeConfig` and `Quantizer`.

  Example:

  ```python
  tf.keras.models.save_model(quantized_model, keras_file)

  with quantize_scope():
    loaded_model = tf.keras.models.load_model(keras_file)

  # If your quantized model uses custom objects such as a specific `Quantizer`,
  # you can pass them to quantize_scope to deserialize your model.
  with quantize_scope({'FixedRangeQuantizer': FixedRangeQuantizer}):
    loaded_model = tf.keras.models.load_model(keras_file)
  ```

  For further understanding, see `tf.keras.utils.custom_object_scope`.

  Args:
    *args: Variable length list of dictionaries of `{name: class}` pairs to
      add to the scope created by this method.

  Returns:
    Object of type `CustomObjectScope` with quantization objects included.
  """
  quantization_objects = {
      'QuantizeAnnotate': quantize_annotate_mod.QuantizeAnnotate,
      'QuantizeAwareActivation':
          quantize_aware_activation.QuantizeAwareActivation,
      'NoOpActivation': quantize_aware_activation.NoOpActivation,
      'QuantizeWrapper': quantize_wrapper.QuantizeWrapper,
      'QuantizeLayer': quantize_layer.QuantizeLayer,
      # TODO(tf-mot): add way for different quantization schemes to modify this.
      '_DepthwiseConvBatchNorm2D': conv_batchnorm._DepthwiseConvBatchNorm2D,  # pylint: disable=protected-access
      '_ConvBatchNorm2D': conv_batchnorm._ConvBatchNorm2D  # pylint: disable=protected-access
  }
  # Later updates overwrite earlier entries with the same name, so quantizer
  # types win over registry types on a clash.
  quantization_objects.update(default_8bit_quantize_registry._types_dict())  # pylint: disable=protected-access
  quantization_objects.update(quantizers._types_dict())  # pylint: disable=protected-access

  # `args` is always a tuple here, so the quantization objects are appended as
  # the last dictionary after any caller-supplied ones.
  # NOTE(review): presumably custom_object_scope applies dictionaries in
  # order, letting these built-ins shadow same-named user entries; confirm.
  return tf.keras.utils.custom_object_scope(*(args + (quantization_objects,)))
def testSerializationQuantizeWrapper(self):
  """A QuantizeWrapper around a Dense layer survives a config round trip."""
  dense = keras.layers.Dense(3)
  original_wrapper = QuantizeWrapper(
      layer=dense,
      quantize_config=self.quantize_registry.get_quantize_config(dense),
      input_shape=(2,))

  # Names the deserializer must be able to resolve for this wrapper.
  scope_objects = dict(
      QuantizeAwareActivation=QuantizeAwareActivation,
      QuantizeWrapper=QuantizeWrapper)
  scope_objects.update(default_8bit_quantize_registry._types_dict())  # pylint: disable=protected-access

  layer_config = serialize_layer(original_wrapper)
  with custom_object_scope(scope_objects):
    restored_wrapper = deserialize_layer(layer_config)

  self.assertEqual(
      restored_wrapper.get_config(), original_wrapper.get_config())
def testSerialization(self):
  """The RNN quantize config serializes to the expected dict and back."""
  serialized = serialize_keras_object(self.quantize_config)

  # The same attribute names are listed once per entry, two entries each.
  expected_config = {
      'class_name': 'Default8BitQuantizeConfigRNN',
      'config': {
          'weight_attrs': [
              ['kernel', 'recurrent_kernel'] for _ in range(2)
          ],
          'activation_attrs': [
              ['activation', 'recurrent_activation'] for _ in range(2)
          ],
          'quantize_output': False
      }
  }
  self.assertEqual(expected_config, serialized)

  restored_config = deserialize_keras_object(
      serialized,
      module_objects=globals(),
      custom_objects=default_8bit_quantize_registry._types_dict())  # pylint: disable=protected-access
  self.assertEqual(self.quantize_config, restored_config)