Code example #1
File: quantize.py — Project: Ruomei/model-optimization
def quantize_scope(*args):
    """Scope required to deserialize quantized models saved in tf.keras h5 format.

    Args:
      *args: Optional dictionaries mapping names to classes, merged into the
        scope this function creates.

    Returns:
      A `CustomObjectScope` that includes the quantization classes.

    Example:

    ```python
    tf.keras.models.save_model(quantized_model, keras_file)

    with quantize_scope():
      loaded_model = tf.keras.models.load_model(keras_file)
    ```
    """
    # Core quantization classes Keras needs to rebuild a quantized model.
    scope_objects = {
        'QuantizeAnnotate': quantize_annotate_mod.QuantizeAnnotate,
        'QuantizeAwareActivation':
            quantize_aware_activation.QuantizeAwareActivation,
        'NoOpActivation': quantize_aware_activation.NoOpActivation,
        'QuantizeWrapper': quantize_wrapper.QuantizeWrapper,
        'QuantizeLayer': quantize_layer.QuantizeLayer,
        # TODO(tf-mot): add way for different quantization schemes to modify this.
        '_DepthwiseConvBatchNorm2D': conv_batchnorm._DepthwiseConvBatchNorm2D,  # pylint: disable=protected-access
        '_ConvBatchNorm2D': conv_batchnorm._ConvBatchNorm2D,  # pylint: disable=protected-access
    }
    # Registry- and quantizer-provided classes join the scope as well.
    scope_objects.update(default_8bit_quantize_registry._types_dict())  # pylint: disable=protected-access
    scope_objects.update(quantizers._types_dict())  # pylint: disable=protected-access

    # Any caller-supplied dictionaries participate in the same scope.
    return tf.keras.utils.custom_object_scope(*args, scope_objects)
Code example #2
File: quantize.py — Project: nowke/model-optimization
def quantize_scope(*args):
    """Scope which can be used to deserialize quantized Keras models and layers.

  Under `quantize_scope`, Keras methods such as `tf.keras.load_model` and
  `tf.keras.models.model_from_config` will be able to deserialize Keras models
  and layers which contain quantization classes such as `QuantizeConfig`
  and `Quantizer`.

  Example:

  ```python
  tf.keras.models.save_model(quantized_model, keras_file)

  with quantize_scope():
    loaded_model = tf.keras.models.load_model(keras_file)

  # If your quantized model uses custom objects such as a specific `Quantizer`,
  # you can pass them to quantize_scope to deserialize your model.
  with quantize_scope({'FixedRangeQuantizer': FixedRangeQuantizer}):
    loaded_model = tf.keras.models.load_model(keras_file)
  ```

  For further understanding, see `tf.keras.utils.custom_object_scope`.

  Args:
    *args: Variable length list of dictionaries of `{name, class}` pairs to add
      to the scope created by this method.

  Returns:
    Object of type `CustomObjectScope` with quantization objects included.
  """
    # Built-in quantization classes required to deserialize any quantized model.
    quantization_objects = {
        'QuantizeAnnotate': quantize_annotate_mod.QuantizeAnnotate,
        'QuantizeAwareActivation':
        quantize_aware_activation.QuantizeAwareActivation,
        'NoOpActivation': quantize_aware_activation.NoOpActivation,
        'QuantizeWrapper': quantize_wrapper.QuantizeWrapper,
        'QuantizeLayer': quantize_layer.QuantizeLayer,
        # TODO(tf-mot): add way for different quantization schemes to modify this.
        '_DepthwiseConvBatchNorm2D': conv_batchnorm._DepthwiseConvBatchNorm2D,  # pylint: disable=protected-access
        '_ConvBatchNorm2D': conv_batchnorm._ConvBatchNorm2D  # pylint: disable=protected-access
    }
    # Merge in registry- and quantizer-provided classes (private helpers).
    quantization_objects.update(default_8bit_quantize_registry._types_dict())  # pylint: disable=protected-access
    quantization_objects.update(quantizers._types_dict())  # pylint: disable=protected-access

    return tf.keras.utils.custom_object_scope(*(args +
                                                (quantization_objects, )))
Code example #3
    def testSerializationQuantizeWrapper(self):
        # Round-trip a QuantizeWrapper through serialize/deserialize and
        # verify the config survives unchanged.
        dense = keras.layers.Dense(3)
        original = QuantizeWrapper(
            layer=dense,
            quantize_config=self.quantize_registry.get_quantize_config(dense),
            input_shape=(2, ))

        # Custom classes Keras needs in scope to rebuild the wrapper; registry
        # entries are merged last, matching the lookup precedence used here.
        scope = {
            'QuantizeAwareActivation': QuantizeAwareActivation,
            'QuantizeWrapper': QuantizeWrapper,
        }
        scope.update(default_8bit_quantize_registry._types_dict())

        serialized = serialize_layer(original)
        with custom_object_scope(scope):
            restored = deserialize_layer(serialized)

        self.assertEqual(restored.get_config(), original.get_config())
Code example #4
  def testSerialization(self):
    # Serializing the RNN quantize config must yield exactly this dict.
    weight_pair = ['kernel', 'recurrent_kernel']
    activation_pair = ['activation', 'recurrent_activation']
    expected = {
        'class_name': 'Default8BitQuantizeConfigRNN',
        'config': {
            'weight_attrs': [list(weight_pair), list(weight_pair)],
            'activation_attrs': [list(activation_pair), list(activation_pair)],
            'quantize_output': False
        }
    }

    serialized = serialize_keras_object(self.quantize_config)
    self.assertEqual(expected, serialized)

    # Deserializing the serialized form should give back an equal config.
    restored = deserialize_keras_object(
        serialized,
        module_objects=globals(),
        custom_objects=default_8bit_quantize_registry._types_dict())  # pylint: disable=protected-access

    self.assertEqual(self.quantize_config, restored)