def quantize_scope(*args):
    """Return a scope in which quantized Keras layers and models deserialize.

    A Keras model or layer that has been quantized must be loaded inside this
    scope in order to be deserialized successfully.

    Args:
      *args: Variable length list of dictionaries of name, class pairs to add
        to the scope created by this method.

    Returns:
      Object of type `CustomObjectScope` with quantization objects included.

    Example:

    ```python
    keras.models.save_model(quantized_model, keras_file)

    with quantize_scope():
      loaded_model = keras.models.load_model(keras_file)
    ```
    """
    # Classes registered by name so the deserializer can resolve them.
    # TODO(tf-mot): add way for different quantization schemes to modify this.
    registered = dict(
        QuantizeAnnotate=quantize_annotate_mod.QuantizeAnnotate,
        QuantizeAwareActivation=(
            quantize_aware_activation.QuantizeAwareActivation),
        QuantizeWrapper=quantize_wrapper.QuantizeWrapper,
        _DepthwiseConvBatchNorm2D=conv_batchnorm._DepthwiseConvBatchNorm2D,  # pylint: disable=protected-access
        _ConvBatchNorm2D=conv_batchnorm._ConvBatchNorm2D,  # pylint: disable=protected-access
    )
    # Merge in the types exported by the registry, then by the quantizers.
    registered.update(tflite_quantize_registry._types_dict())  # pylint: disable=protected-access
    registered.update(quantizers._types_dict())  # pylint: disable=protected-access

    # Caller-supplied dictionaries come first; the quantization mapping is
    # passed as the final positional argument.
    scope_dicts = args + (registered,)
    return custom_object_scope(*scope_dicts)
Example #2
0
def quantize_scope(*args):
  """Return the scope required to load quantized tf.keras h5 models.

  Quantized models stored in tf.keras h5 format must be deserialized inside
  this scope.

  Args:
    *args: Variable length list of dictionaries of name, class pairs to add to
      the scope created by this method.

  Returns:
    Object of type `CustomObjectScope` with quantization objects included.

  Example:

  ```python
  tf.keras.models.save_model(quantized_model, keras_file)

  with quantize_scope():
    loaded_model = tf.keras.models.load_model(keras_file)
  ```
  """
  # Name -> class pairs for the layers and activations quantization uses.
  # TODO(tf-mot): add way for different quantization schemes to modify this.
  known_classes = dict(
      QuantizeAnnotate=quantize_annotate_mod.QuantizeAnnotate,
      QuantizeAwareActivation=(
          quantize_aware_activation.QuantizeAwareActivation),
      NoOpActivation=quantize_aware_activation.NoOpActivation,
      QuantizeWrapper=quantize_wrapper.QuantizeWrapper,
      QuantizeLayer=quantize_layer.QuantizeLayer,
      _DepthwiseConvBatchNorm2D=conv_batchnorm._DepthwiseConvBatchNorm2D,  # pylint: disable=protected-access
      _ConvBatchNorm2D=conv_batchnorm._ConvBatchNorm2D,  # pylint: disable=protected-access
  )
  # Add the registry's serializable types, followed by the quantizers'.
  known_classes.update(tflite_quantize_registry._types_dict())  # pylint: disable=protected-access
  known_classes.update(quantizers._types_dict())  # pylint: disable=protected-access

  # Forward the caller's dictionaries first, then the quantization mapping.
  scope_dicts = list(args)
  scope_dicts.append(known_classes)
  return tf.keras.utils.custom_object_scope(*scope_dicts)
Example #3
0
  def testSerializationQuantizeWrapper(self):
    """A QuantizeWrapper around Dense survives a serialize/deserialize trip."""
    dense = keras.layers.Dense(3)
    provider = self.quantize_registry.get_quantize_provider(dense)
    original = QuantizeWrapper(
        layer=dense, quantize_provider=provider, input_shape=(2,))

    # Classes the deserializer has to look up by name.
    lookup = dict(
        QuantizeAwareActivation=QuantizeAwareActivation,
        QuantizeWrapper=QuantizeWrapper)
    lookup.update(tflite_quantize_registry._types_dict())  # pylint: disable=protected-access

    as_config = serialize_layer(original)
    with keras.utils.custom_object_scope(lookup):
      restored = deserialize_layer(as_config)

    self.assertEqual(restored.get_config(), original.get_config())

  def testSerialization(self):
    """TFLiteQuantizeConfigRNN serializes to a known config and back."""
    weight_names = ['kernel', 'recurrent_kernel']
    activation_names = ['activation', 'recurrent_activation']
    expected_config = {
        'class_name': 'TFLiteQuantizeConfigRNN',
        'config': {
            # Both entries are identical, so build them from shared names.
            'weight_attrs': [list(weight_names) for _ in range(2)],
            'activation_attrs': [list(activation_names) for _ in range(2)],
            'quantize_output': False,
        },
    }

    as_config = serialize_keras_object(self.quantize_config)
    self.assertEqual(expected_config, as_config)

    restored = deserialize_keras_object(
        as_config,
        module_objects=globals(),
        custom_objects=tflite_quantize_registry._types_dict())  # pylint: disable=protected-access
    self.assertEqual(self.quantize_config, restored)
Example #5
0
  def testSerialization(self):
    """TFLiteQuantizeProvider serializes to a known config and back."""
    provider = tflite_quantize_registry.TFLiteQuantizeProvider(
        ['kernel'], ['activation'], False)

    as_config = serialize_keras_object(provider)
    self.assertEqual(
        {
            'class_name': 'TFLiteQuantizeProvider',
            'config': {
                'weight_attrs': ['kernel'],
                'activation_attrs': ['activation'],
                'quantize_output': False,
            },
        },
        as_config)

    restored = deserialize_keras_object(
        as_config,
        module_objects=globals(),
        custom_objects=tflite_quantize_registry._types_dict())  # pylint: disable=protected-access
    self.assertEqual(provider, restored)