def testConcatTransform(self):
        r"""Tests the Concat Transform.

               Input
              /     \
         Dense       Dense
             \      /
              Concat

        One Dense layer has a pre-specified QuantizeConfig, whereas the other
        does not. The Transform should ensure the output FakeQuants of both
        Dense layers are disabled, and that only a single FakeQuant after
        Concat is present.
        """
        dense_1 = keras.layers.Dense(3)
        dense_2 = keras.layers.Dense(3)
        concat = keras.layers.Concatenate()

        inp = keras.layers.Input((2, ))
        x1 = dense_1(inp)
        x2 = dense_2(inp)
        x = concat([x1, x2])
        model = keras.Model(inp, x)

        layer_metadata = {
            # dense_1 has an existing quantize_config.
            dense_1.name: {
                'quantize_config':
                default_8bit_quantize_configs.Default8BitOutputQuantizeConfig(
                )
            }
        }
        _, updated_metadata = ModelTransformer(
            model, [default_8bit_transforms.ConcatTransform()],
            layer_metadata=layer_metadata).transform()

        concat_quantize_config = updated_metadata.get(
            concat.name).get('quantize_config')
        # Concat should quantize the output.
        self.assertIsInstance(
            concat_quantize_config,
            default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
        self.assertNotEmpty(concat_quantize_config.get_output_quantizers(None))

        dense_1_quantize_config = updated_metadata.get(
            dense_1.name).get('quantize_config')
        # The existing quantize_config should do nothing for outputs.
        self.assertIsInstance(
            dense_1_quantize_config,
            default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
        self.assertEmpty(dense_1_quantize_config.get_output_quantizers(None))

        dense_2_quantize_config = updated_metadata.get(
            dense_2.name).get('quantize_config')
        # The quantize_config from the registry should do nothing at the output.
        self.assertEqual('Default8BitQuantizeConfig',
                         dense_2_quantize_config.__class__.__name__)
        self.assertEmpty(dense_2_quantize_config.get_output_quantizers(None))
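As a quick illustration of the flow this test verifies, the same transform runs as part of the regular quantization path. A minimal sketch, assuming the public tfmot API (tfmot.quantization.keras.quantize_model); layer sizes are arbitrary:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

inp = tf.keras.layers.Input((2,))
x = tf.keras.layers.Concatenate()(
    [tf.keras.layers.Dense(3)(inp), tf.keras.layers.Dense(3)(inp)])
model = tf.keras.Model(inp, x)

# quantize_model annotates the model and applies the default 8-bit transforms;
# per the docstring above, a single FakeQuant should follow Concatenate, with
# none on the individual Dense outputs.
quantized_model = tfmot.quantization.keras.quantize_model(model)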
Example #2
  def _replace(self, bn_layer_node, conv_layer_node):
    # Leave the pattern untouched if the user supplied a custom
    # quantize_config for either layer.
    if _has_custom_quantize_config(bn_layer_node, conv_layer_node):
      return bn_layer_node

    # Swap in a no-op activation so no FakeQuant lands between the Conv and
    # BatchNorm layers, and quantize the output of BatchNorm instead.
    conv_layer_node.layer['config']['activation'] = \
      keras.activations.serialize(quantize_aware_activation.NoOpActivation())
    bn_layer_node.metadata['quantize_config'] = \
      default_8bit_quantize_configs.Default8BitOutputQuantizeConfig()

    return bn_layer_node
Example #3
  def replacement(self, match_layer):
    # The matched pattern is a BatchNorm layer fed by a Conv layer.
    bn_layer_node, conv_layer_node = match_layer, match_layer.input_layers[0]

    if self._has_custom_quantize_config(bn_layer_node, conv_layer_node):
      return match_layer

    # Swap in a no-op activation so no FakeQuant lands between the Conv and
    # BatchNorm layers, and quantize the output of BatchNorm instead.
    conv_layer_node.layer['config']['activation'] = \
      keras.activations.serialize(quantize_aware_activation.NoOpActivation())
    bn_layer_node.metadata['quantize_config'] = \
      default_8bit_quantize_configs.Default8BitOutputQuantizeConfig()

    return match_layer
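For context, this replacement pairs with a pattern that matches a BatchNorm layer fed by a Conv layer, so match_layer.input_layers[0] is the Conv node. A minimal sketch of such a pattern method, assuming the LayerPattern helper from tfmot's graph_transformations.transforms module:

  def pattern(self):
    # Match a BatchNormalization layer whose sole input is a Conv2D layer.
    return transforms.LayerPattern(
        'BatchNormalization', inputs=[transforms.LayerPattern('Conv2D')])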
Example #4
    def replacement(self, match_layer):
        concat_layer_node = match_layer
        feeding_layer_nodes = match_layer.input_layers

        default_registry = (
            default_8bit_quantize_registry.Default8BitQuantizeRegistry())

        feed_quantize_configs = []
        for feed_layer_node in feeding_layer_nodes:
            quantize_config = feed_layer_node.metadata.get('quantize_config')
            if not quantize_config:
                layer_class = self._get_layer_type(
                    feed_layer_node.layer['class_name'])
                if layer_class is None:
                    # Concat has an input layer we don't recognize. Return.
                    return match_layer

                if layer_class == keras.layers.Concatenate:
                    # Input layer to Concat is also Concat. Don't quantize it.
                    feed_layer_node.metadata['quantize_config'] = (
                        default_8bit_quantize_configs.NoOpQuantizeConfig())
                    continue

                if not default_registry._is_supported_layer(layer_class):
                    # Feeding layer is not supported by the registry.
                    return match_layer

                quantize_config = default_registry._get_quantize_config(
                    layer_class)
                feed_layer_node.metadata['quantize_config'] = quantize_config

            feed_quantize_configs.append(quantize_config)

        # TODO(pulkitb): This currently only disables the output quantize
        # config, but cannot properly handle the case where the FakeQuant was
        # added to the activation. Handle this properly.
        for quantize_config in feed_quantize_configs:
            self._disable_output_quantize(quantize_config)

        if not concat_layer_node.metadata.get('quantize_config'):
            concat_layer_node.metadata['quantize_config'] = (
                default_8bit_quantize_configs.Default8BitOutputQuantizeConfig(
                ))

        return concat_layer_node
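The _disable_output_quantize helper is not shown in this snippet. A plausible minimal sketch, assuming it works by emptying the config's output quantizers (which is exactly what the test in the first example asserts via get_output_quantizers):

    def _disable_output_quantize(self, quantize_config):
        # Overriding get_output_quantizers to return no quantizers removes the
        # FakeQuant that would otherwise be placed on the layer's output.
        quantize_config.get_output_quantizers = lambda layer: []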
Example #5
def quantize(layer, quantize_config=None):
    # Wraps `layer` for quantization-aware training, defaulting to a config
    # that quantizes the layer's output.
    if quantize_config is None:
        quantize_config = (
            default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())
    return quantize_wrapper.QuantizeWrapper(layer,
                                            quantize_config=quantize_config)