def testConcatTransform(self):
  r"""Verifies ConcatTransform on a two-branch Dense -> Concatenate model.

        Input
        /   \
    Dense   Dense
        \   /
       Concat

  One Dense branch carries a pre-specified QuantizeConfig while the other
  does not. After the transform, both branch outputs must have their output
  FakeQuants disabled, and only a single FakeQuant after Concat remains.
  """
  branch_a = keras.layers.Dense(3)
  branch_b = keras.layers.Dense(3)
  merge = keras.layers.Concatenate()

  inputs = keras.layers.Input((2,))
  outputs = merge([branch_a(inputs), branch_b(inputs)])
  model = keras.Model(inputs, outputs)

  # branch_a comes in with an existing quantize_config; branch_b does not.
  layer_metadata = {
      branch_a.name: {
          'quantize_config':
              default_8bit_quantize_configs.Default8BitOutputQuantizeConfig()
      }
  }

  _, updated_metadata = ModelTransformer(
      model, [default_8bit_transforms.ConcatTransform()],
      layer_metadata=layer_metadata).transform()

  def config_of(layer):
    # Quantize config the transform recorded for `layer`.
    return updated_metadata.get(layer.name).get('quantize_config')

  # Concat should quantize the output.
  merge_config = config_of(merge)
  self.assertIsInstance(
      merge_config,
      default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
  self.assertNotEmpty(merge_config.get_output_quantizers(None))

  # The pre-existing quantize_config should do nothing for outputs.
  branch_a_config = config_of(branch_a)
  self.assertIsInstance(
      branch_a_config,
      default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
  self.assertEmpty(branch_a_config.get_output_quantizers(None))

  # The quantize_config pulled from the registry should do nothing at output.
  branch_b_config = config_of(branch_b)
  self.assertEqual('Default8BitQuantizeConfig',
                   branch_b_config.__class__.__name__)
  self.assertEmpty(branch_b_config.get_output_quantizers(None))
def _replace(self, bn_layer_node, conv_layer_node):
  """Rewrites a matched Conv -> BatchNorm pair for quantized fusion.

  Leaves the match untouched when either node already carries a custom
  quantize_config. Otherwise replaces the conv activation with a serialized
  NoOpActivation and attaches a Default8BitOutputQuantizeConfig to the
  BatchNorm node so quantization happens after BatchNorm.
  """
  if _has_custom_quantize_config(bn_layer_node, conv_layer_node):
    return bn_layer_node

  serialized_noop = keras.activations.serialize(
      quantize_aware_activation.NoOpActivation())
  conv_layer_node.layer['config']['activation'] = serialized_noop
  bn_layer_node.metadata['quantize_config'] = (
      default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())

  return bn_layer_node
def replacement(self, match_layer):
  """Rewrites the matched BatchNorm and its input conv layer for quantization.

  Skips the rewrite when either node carries a custom quantize_config.
  Otherwise swaps the conv activation for a serialized NoOpActivation and
  attaches a Default8BitOutputQuantizeConfig to the BatchNorm node.
  """
  bn_node = match_layer
  conv_node = match_layer.input_layers[0]

  if self._has_custom_quantize_config(bn_node, conv_node):
    return match_layer

  serialized_noop = keras.activations.serialize(
      quantize_aware_activation.NoOpActivation())
  conv_node.layer['config']['activation'] = serialized_noop
  bn_node.metadata['quantize_config'] = (
      default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())

  return match_layer
def replacement(self, match_layer):
  """Ensures a single FakeQuant is placed after the matched Concat.

  Resolves a quantize_config for every layer feeding the Concat — from its
  existing metadata, or from the default 8-bit registry — then disables
  output quantization on those feeding configs so only the Concat output is
  quantized. Bails out (returning the match unchanged) when any feeding
  layer is unrecognized or unsupported by the registry.
  """
  concat_layer_node = match_layer
  feeding_layer_nodes = match_layer.input_layers

  default_registry = (
      default_8bit_quantize_registry.Default8BitQuantizeRegistry())

  feed_quantize_configs = []
  for feed_layer_node in feeding_layer_nodes:
    quantize_config = feed_layer_node.metadata.get('quantize_config')
    if not quantize_config:
      layer_class = self._get_layer_type(
          feed_layer_node.layer['class_name'])
      if layer_class is None:
        # Concat has an input layer we don't recognize. Return.
        return match_layer

      if layer_class == keras.layers.Concatenate:
        # Input layer to Concat is also Concat. Don't quantize it.
        feed_layer_node.metadata['quantize_config'] = (
            default_8bit_quantize_configs.NoOpQuantizeConfig())
        continue

      # NOTE(review): relies on private registry helpers
      # (_is_supported_layer/_get_quantize_config) — intentional within this
      # package, but brittle against registry refactors.
      if not default_registry._is_supported_layer(layer_class):
        # Feeding layer is not supported by registry
        return match_layer

      quantize_config = default_registry._get_quantize_config(
          layer_class)
      feed_layer_node.metadata['quantize_config'] = quantize_config

    feed_quantize_configs.append(quantize_config)

  # TODO(pulkitb): this currently only disables output quantize config, but
  # cannot properly handle if the FQ was added to the activation. Handle this
  # properly.
  for quantize_config in feed_quantize_configs:
    self._disable_output_quantize(quantize_config)

  # Only attach a fresh output config if the Concat node has none yet.
  if not concat_layer_node.metadata.get('quantize_config'):
    concat_layer_node.metadata['quantize_config'] = (
        default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())

  return concat_layer_node
def quantize(layer, quantize_config=None):
  """Wraps `layer` in a QuantizeWrapper.

  Args:
    layer: Keras layer to wrap.
    quantize_config: Optional quantize config; when None, falls back to
      Default8BitOutputQuantizeConfig.

  Returns:
    A QuantizeWrapper around `layer`.
  """
  config = (
      quantize_config
      if quantize_config is not None
      else default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())
  return quantize_wrapper.QuantizeWrapper(layer, quantize_config=config)