Esempio n. 1
0
 def setUp(self):
   """Creates the registries used by the tests in this case.

   Instantiates a default 8-bit quantize registry, a clustering registry and
   a cluster-preserve quantize registry, storing each one on the instance.
   """
   super(ClusterPreserveDefault8bitQuantizeRegistryTest, self).setUp()

   default_registry = (
       default_8bit_quantize_registry.Default8BitQuantizeRegistry())
   self.default_8bit_quantize_registry = default_registry

   # Local names deliberately differ from the imported module names so the
   # modules are not shadowed inside this method.
   clustering = clustering_registry.ClusteringRegistry()
   self.cluster_registry = clustering

   preserve_registry = (
       cluster_preserve_quantize_registry.ClusterPreserveQuantizeRegistry())
   self.cluster_preserve_quantize_registry = preserve_registry
Esempio n. 2
0
 def __init__(self, registry=None):
     """Stores the quantize registry and creates two empty ordered maps.

     Args:
       registry: registry object to keep on the instance. When None, a new
         `Default8BitQuantizeRegistry` is created and used instead.
     """
     # The default registry is only constructed when the caller supplied none.
     self.registry = (
         registry if registry is not None else
         default_8bit_quantize_registry.Default8BitQuantizeRegistry())
     self.wmap = OrderedDict()
     self.amap = OrderedDict()
Esempio n. 3
0
 def setUp(self):
   """Creates the registries used by the tests in this case.

   Instantiates a default 8-bit quantize registry, a prune registry and a
   prune-preserve quantize registry, storing each one on the instance.
   """
   super(PrunePreserveDefault8bitQuantizeRegistryTest, self).setUp()

   default_registry = (
       default_8bit_quantize_registry.Default8BitQuantizeRegistry())
   self.default_8bit_quantize_registry = default_registry

   # Local names deliberately differ from the imported module names so the
   # modules are not shadowed inside this method.
   pruning = prune_registry.PruneRegistry()
   self.prune_registry = pruning

   preserve_registry = (
       prune_preserve_quantize_registry.PrunePreserveQuantizeRegistry())
   self.prune_preserve_quantize_registry = preserve_registry
    def get_quantize_config(self, layer):
        """Returns the quantization config with addon sparsity.

        The layer's config is first obtained from a freshly created
        `Default8BitQuantizeRegistry` and then passed through
        `apply_sparsity_preserve_quantize_config`.

        Args:
          layer: input layer to return quantize config for.

        Returns:
          The value returned by `apply_sparsity_preserve_quantize_config` for
          the layer and its default 8-bit quantize config.
        """
        default_registry = (
            default_8bit_quantize_registry.Default8BitQuantizeRegistry())
        base_config = default_registry.get_quantize_config(layer)
        return self.apply_sparsity_preserve_quantize_config(layer, base_config)
    def get_quantize_config(self, layer):
        """Returns the cluster-preserving quantization config for a layer.

        The layer's config is first obtained from a freshly created
        `Default8BitQuantizeRegistry` and then passed to the parent class's
        `apply_cluster_preserve_quantize_config`.

        Args:
          layer: input layer to return quantize config for.

        Returns:
          The value returned by the parent's
          `apply_cluster_preserve_quantize_config` for the layer and its
          default 8-bit quantize config.
        """
        default_registry = (
            default_8bit_quantize_registry.Default8BitQuantizeRegistry())
        base_config = default_registry.get_quantize_config(layer)

        # Explicitly dispatch to the parent implementation, as the original
        # code does, rather than going through `self`.
        parent = super(Default8bitClusterPreserveQuantizeRegistry, self)
        return parent.apply_cluster_preserve_quantize_config(
            layer, base_config)
Esempio n. 6
0
    def replacement(self, match_layer):
        """Assigns quantize configs to a matched Concat node and its inputs.

        Every input node ends up with a quantize config in its metadata: an
        existing one is reused, an input that is itself a Concatenate gets a
        `NoOpQuantizeConfig`, and any other supported layer gets the config
        supplied by the default 8-bit registry. Output quantization is then
        disabled on the collected input configs, and the matched node
        receives a `Default8BitOutputQuantizeConfig` if it has no config yet.

        Args:
          match_layer: the matched node; its `input_layers` and `metadata`
            attributes are read, and `metadata` may be updated.

        Returns:
          `match_layer`. It is returned early, before the output-quantize
          step, when an input layer type is unrecognized or unsupported by
          the default registry. Metadata already written for earlier inputs
          is left in place in that case.
        """
        input_nodes = match_layer.input_layers
        registry = default_8bit_quantize_registry.Default8BitQuantizeRegistry()

        collected_configs = []
        for input_node in input_nodes:
            config = input_node.metadata.get('quantize_config')
            if config:
                # Input already carries a config; reuse it as-is.
                collected_configs.append(config)
                continue

            input_class = self._get_layer_type(input_node.layer['class_name'])
            if input_class is None:
                # Concat has an input layer we don't recognize. Return.
                return match_layer

            if input_class == keras.layers.Concatenate:
                # Input layer to Concat is also Concat. Don't quantize it, and
                # don't include it in the configs whose output is disabled.
                input_node.metadata['quantize_config'] = (
                    default_8bit_quantize_configs.NoOpQuantizeConfig())
                continue

            if not registry._is_supported_layer(input_class):
                # Feeding layer is not supported by registry.
                return match_layer

            config = registry._get_quantize_config(input_class)
            input_node.metadata['quantize_config'] = config
            collected_configs.append(config)

        # TODO(pulkitb): this currently only disables output quantize config,
        # but cannot properly handle the case where the FQ was added to the
        # activation. Handle this properly.
        for config in collected_configs:
            self._disable_output_quantize(config)

        if not match_layer.metadata.get('quantize_config'):
            match_layer.metadata['quantize_config'] = (
                default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())

        return match_layer
Esempio n. 7
0
 def get_quantize_registry(self):
     """Returns a newly constructed `Default8BitQuantizeRegistry`."""
     registry = default_8bit_quantize_registry.Default8BitQuantizeRegistry()
     return registry
 def get_quantize_registry(self):
     """Returns a new `Default8BitQuantizeRegistry` for this instance.

     The registry is constructed with `disable_per_axis` set to this
     instance's `_disable_per_axis` attribute.
     """
     registry_cls = default_8bit_quantize_registry.Default8BitQuantizeRegistry
     return registry_cls(disable_per_axis=self._disable_per_axis)