def quantize_scope(*args):
  """Returns a custom-object scope for deserializing quantized Keras objects.

  A Keras model or layer that has been quantized references classes that must
  be registered as custom objects. Deserialize it inside this scope.

  Args:
    *args: Zero or more dictionaries mapping names to classes. They are
      handed to the scope ahead of the quantization objects.

  Returns:
    Object of type `CustomObjectScope` with quantization objects included.

  Example:

  ```python
  keras.models.save_model(quantized_model, keras_file)
  with quantize_scope():
    loaded_model = keras.models.load_model(keras_file)
  ```
  """
  # Classes that quantized models/layers may be built from.
  scope_objects = dict(
      QuantizeAnnotate=quantize_annotate_mod.QuantizeAnnotate,
      QuantizeAwareActivation=quantize_aware_activation.QuantizeAwareActivation,
      QuantizeWrapper=quantize_wrapper.QuantizeWrapper,
      # TODO(tf-mot): add way for different quantization schemes to modify this.
      _DepthwiseConvBatchNorm2D=conv_batchnorm._DepthwiseConvBatchNorm2D,  # pylint: disable=protected-access
      _ConvBatchNorm2D=conv_batchnorm._ConvBatchNorm2D,  # pylint: disable=protected-access
  )

  # Merge in the types exposed by the registry, then by the quantizers module,
  # in that order (later entries overwrite earlier ones on a name clash).
  # pylint: disable=protected-access
  for types_dict_fn in (tflite_quantize_registry._types_dict,
                        quantizers._types_dict):
    scope_objects.update(types_dict_fn())
  # pylint: enable=protected-access

  # The quantization dictionary is passed last, after the caller's
  # dictionaries. NOTE(review): with Keras' CustomObjectScope this presumably
  # lets quantization entries win on name collisions — confirm.
  all_scopes = args + (scope_objects,)
  return custom_object_scope(*all_scopes)
def quantize_scope(*args):
  """Creates the scope needed to load quantized models saved in tf.keras h5 format.

  Args:
    *args: Zero or more dictionaries mapping names to classes. They are
      forwarded to the scope before the quantization objects.

  Returns:
    Object of type `CustomObjectScope` with quantization objects included.

  Example:

  ```python
  tf.keras.models.save_model(quantized_model, keras_file)
  with quantize_scope():
    loaded_model = tf.keras.models.load_model(keras_file)
  ```
  """
  known_types = {}
  # Layers and activations used when building quantized models.
  known_types['QuantizeAnnotate'] = quantize_annotate_mod.QuantizeAnnotate
  known_types['QuantizeAwareActivation'] = (
      quantize_aware_activation.QuantizeAwareActivation)
  known_types['NoOpActivation'] = quantize_aware_activation.NoOpActivation
  known_types['QuantizeWrapper'] = quantize_wrapper.QuantizeWrapper
  known_types['QuantizeLayer'] = quantize_layer.QuantizeLayer
  # TODO(tf-mot): add way for different quantization schemes to modify this.
  known_types['_DepthwiseConvBatchNorm2D'] = (
      conv_batchnorm._DepthwiseConvBatchNorm2D)  # pylint: disable=protected-access
  known_types['_ConvBatchNorm2D'] = conv_batchnorm._ConvBatchNorm2D  # pylint: disable=protected-access

  # Registry types are merged first, then quantizer types; a later update
  # overwrites an earlier entry with the same name.
  known_types.update(tflite_quantize_registry._types_dict())  # pylint: disable=protected-access
  known_types.update(quantizers._types_dict())  # pylint: disable=protected-access

  # Caller-supplied dictionaries come first; the quantization dict is last.
  scope_dicts = args + (known_types,)
  return tf.keras.utils.custom_object_scope(*scope_dicts)
def testSerializationQuantizeWrapper(self):
  # Round-trip a QuantizeWrapper around a Dense(3) layer through
  # serialize_layer / deserialize_layer, then check the configs match.
  dense = keras.layers.Dense(3)
  original = QuantizeWrapper(
      layer=dense,
      quantize_provider=self.quantize_registry.get_quantize_provider(dense),
      input_shape=(2,))

  serialized = serialize_layer(original)

  # Names the deserializer must resolve. The registry types are merged last,
  # so they overwrite the two explicit entries on a name clash.
  scope_objects = {
      'QuantizeAwareActivation': QuantizeAwareActivation,
      'QuantizeWrapper': QuantizeWrapper,
  }
  scope_objects.update(tflite_quantize_registry._types_dict())  # pylint: disable=protected-access

  with keras.utils.custom_object_scope(scope_objects):
    restored = deserialize_layer(serialized)

  self.assertEqual(restored.get_config(), original.get_config())
def testSerialization(self):
  # self.quantize_config must serialize to exactly the dict below and
  # deserialize back into an equal object.
  kernel_attrs = [['kernel', 'recurrent_kernel'],
                  ['kernel', 'recurrent_kernel']]
  activation_attrs = [['activation', 'recurrent_activation'],
                      ['activation', 'recurrent_activation']]
  expected = {
      'class_name': 'TFLiteQuantizeConfigRNN',
      'config': {
          'weight_attrs': kernel_attrs,
          'activation_attrs': activation_attrs,
          'quantize_output': False,
      },
  }

  serialized = serialize_keras_object(self.quantize_config)
  self.assertEqual(expected, serialized)

  # Resolve class names against this module and the registry's own types.
  registry_types = tflite_quantize_registry._types_dict()  # pylint: disable=protected-access
  round_tripped = deserialize_keras_object(
      serialized, module_objects=globals(), custom_objects=registry_types)
  self.assertEqual(self.quantize_config, round_tripped)
def testSerialization(self):
  # A TFLiteQuantizeProvider must serialize to a plain config dict and
  # deserialize back into an equal provider.
  provider = tflite_quantize_registry.TFLiteQuantizeProvider(
      ['kernel'], ['activation'], False)

  serialized = serialize_keras_object(provider)
  self.assertEqual(
      {
          'class_name': 'TFLiteQuantizeProvider',
          'config': {
              'weight_attrs': ['kernel'],
              'activation_attrs': ['activation'],
              'quantize_output': False,
          },
      },
      serialized)

  # Class names are resolved against this module and the registry's types.
  restored = deserialize_keras_object(
      serialized,
      module_objects=globals(),
      custom_objects=tflite_quantize_registry._types_dict())  # pylint: disable=protected-access
  self.assertEqual(provider, restored)