def _quantize(layer):  # pylint: disable=missing-docstring
    if (layer.name not in layer_quantize_map and
        layer.name not in requires_output_quantize):
      return layer

    if layer.name in requires_output_quantize:
      if not quantize_registry.supports(layer):
        return layer
      full_quantize_config = quantize_registry.get_quantize_config(layer)
      if not full_quantize_config:
        return layer
      quantize_config = OutputOnlyConfig(full_quantize_config)
    else:
      quantize_config = layer_quantize_map[layer.name].get('quantize_config')
      if not quantize_config and quantize_registry.supports(layer):
        quantize_config = quantize_registry.get_quantize_config(layer)

    if not quantize_config:
      error_msg = (
          'Layer {}:{} is not supported. You can quantize this '
          'layer by passing a `tfmot.quantization.keras.QuantizeConfig` '
          'instance to the `quantize_annotate_layer` '
          'API.')
      raise RuntimeError(error_msg.format(layer.name, layer.__class__))

    # `QuantizeWrapper` does not copy any additional layer params from
    # `QuantizeAnnotate`. This should generally be fine, but occasionally
    # `QuantizeAnnotate` wrapper may contain `batch_input_shape` like params.
    # TODO(pulkitb): Ensure this does not affect model cloning.
    return quantize_wrapper.QuantizeWrapper(layer, quantize_config)
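
For context, the error raised above points at the public annotation API: a layer the default registry does not support can still be quantized by attaching a QuantizeConfig when annotating it. Below is a minimal sketch of that flow using the public tfmot API (not part of the snippet above; the model and layers are illustrative):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_apply = tfmot.quantization.keras.quantize_apply

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    # Dense is supported by the default registry; an unsupported layer would
    # additionally need quantize_config=<your QuantizeConfig> here.
    quantize_annotate_layer(tf.keras.layers.Dense(4, activation='relu')),
    tf.keras.layers.Dense(2),
])
# quantize_apply runs a clone function like _quantize over the model and
# wraps only the annotated layers.
qat_model = quantize_apply(model)
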
import logging

import tensorflow_model_optimization as tfmot
# QuantizeWrapper lives in tfmot's internal Keras quantization package.
from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper


def quantize_layer(layer, apply_quantization=True, quantize_config=None):
    """Quantizes a layer.

    Useful for quantization aware training.

    Args:
      layer: Input layer to quantize.
      apply_quantization: If True, the layer is quantized; otherwise it is
        returned unchanged.
      quantize_config: Optional quantization config for special cases such as
        a convolution followed by batch normalization.

    Returns:
      The layer wrapped for quantization, or the original layer if it is not
      quantized or not supported by the registry.
    """
    if apply_quantization:
        scheme = tfmot.quantization.keras.default_8bit.Default8BitQuantizeScheme(
        )

        quantize_registry = scheme.get_quantize_registry()

        if not quantize_registry.supports(layer):
            logging.info('layer is not supported: %s', str(layer))
            return layer

        if quantize_config is None:
            quantize_config = quantize_registry.get_quantize_config(layer)
        return quantize_wrapper.QuantizeWrapper(layer, quantize_config)
    else:
        return layer
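
A hypothetical call pattern for the quantize_layer helper above, assuming the imports shown with it: a Dense layer is supported by the default 8-bit registry and comes back wrapped in a QuantizeWrapper, while apply_quantization=False returns it untouched.

import tensorflow as tf

wrapped = quantize_layer(tf.keras.layers.Dense(16))   # -> QuantizeWrapper
untouched = quantize_layer(tf.keras.layers.Dense(16),
                           apply_quantization=False)  # -> plain Dense layer
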
Example #3
  def _quantize(layer):  # pylint: disable=missing-docstring
    if layer.name not in layer_quantize_map:
      return layer

    quantize_config = layer_quantize_map[layer.name].get('quantize_config')
    if not quantize_config and quantize_registry.supports(layer):
      quantize_config = quantize_registry.get_quantize_config(layer)

    if not quantize_config:
      error_msg = ('Could not find a suitable QuantizeConfig for layer {}. '
                   'Either the registry {} should provide one, or the user '
                   'should provide one while annotating the layer using '
                   'QuantizeAnnotate.')
      raise RuntimeError(error_msg.format(
          layer.__class__, quantize_registry.__class__))

    # `QuantizeWrapper` does not copy any additional layer params from
    # `QuantizeAnnotate`. This should generally be fine, but occasionally
    # `QuantizeAnnotate` wrapper may contain `batch_input_shape` like params.
    # TODO(pulkitb): Ensure this does not affect model cloning.
    return quantize_wrapper.QuantizeWrapper(layer, quantize_config)
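
The registry lookup this clone function relies on can also be exercised on its own. A small sketch follows (the layer_quantize_map consulted above is internal state of quantize_apply and is not reproduced here):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

registry = (tfmot.quantization.keras.default_8bit
            .Default8BitQuantizeScheme().get_quantize_registry())

dense = tf.keras.layers.Dense(8)
if registry.supports(dense):
    # The default QuantizeConfig the QuantizeWrapper would be built with.
    config = registry.get_quantize_config(dense)
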
Example #4
import tensorflow_model_optimization as tfmot
# QuantizeWrapper lives in tfmot's internal Keras quantization package.
from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper


def quantize_layer(layer, apply_quantization=True):
    """Quantizes a layer.

    Useful for quantization aware training.

    Args:
      layer: Input layer to quantize.
      apply_quantization: If True, the layer is quantized; otherwise it is
        returned unchanged.

    Returns:
      The layer wrapped for quantization, or the original layer if it is not
      quantized or not supported by the registry.
    """
    if apply_quantization:
        scheme = tfmot.quantization.keras.default_8bit.Default8BitQuantizeScheme(
        )

        quantize_registry = scheme.get_quantize_registry()

        if not quantize_registry.supports(layer):
            return layer
        quantize_config = quantize_registry.get_quantize_config(layer)
        return quantize_wrapper.QuantizeWrapper(layer, quantize_config)
    else:
        return layer
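
One way a helper like this might be used (a sketch, not from the original source): wrap each layer of a small Sequential model before compiling it for quantization aware training.

import tensorflow as tf

base_layers = [tf.keras.layers.Dense(32, activation='relu'),
               tf.keras.layers.Dense(10)]
qat_model = tf.keras.Sequential(
    [tf.keras.Input(shape=(16,))] + [quantize_layer(l) for l in base_layers])
qat_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
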
Example #5
# tfmot internals referenced below.
from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
    default_8bit_quantize_configs)


def quantize(layer, quantize_config=None):
    if quantize_config is None:
        quantize_config = default_8bit_quantize_configs.Default8BitOutputQuantizeConfig(
        )
    return quantize_wrapper.QuantizeWrapper(layer,
                                            quantize_config=quantize_config)
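
A hypothetical call for the helper above: with no explicit config, the layer's outputs are quantized with the default 8-bit output-only config, which suits layers without weights.

import tensorflow as tf

pooled = quantize(tf.keras.layers.GlobalAveragePooling2D())
activation = quantize(tf.keras.layers.ReLU())
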