def apply(self, model, layer_quantize_map):
  """Implement default 8-bit transforms.

  Currently this means the following.
    1. Pull activations into layers, and apply fuse activations. (TODO)
    2. Modify range in incoming layers for Concat. (TODO)
    3. Fuse Conv2D/DepthwiseConv2D + BN into single layer.

  Args:
    model: Keras model to be quantized.
    layer_quantize_map: Map with keys as layer names, and values as dicts
      containing custom `QuantizeConfig`s which may have been passed with
      layers.

  Returns:
    (Transformed Keras model to better match TensorFlow Lite backend,
    updated layer quantize map.)
  """
  # Ordering matters: Conv+BN+activation patterns are matched before the
  # plain Conv+BN pattern, and wider Concat patterns are matched before
  # narrower ones, so the largest applicable fusion wins.
  transform_classes = [
      default_8bit_transforms.InputLayerQuantize,
      default_8bit_transforms.Conv2DBatchNormReLUQuantize,
      default_8bit_transforms.Conv2DBatchNormActivationQuantize,
      default_8bit_transforms.Conv2DBatchNormQuantize,
      default_8bit_transforms.ConcatTransform6Inputs,
      default_8bit_transforms.ConcatTransform5Inputs,
      default_8bit_transforms.ConcatTransform4Inputs,
      default_8bit_transforms.ConcatTransform3Inputs,
      default_8bit_transforms.ConcatTransform,
      default_8bit_transforms.AddReLUQuantize,
      default_8bit_transforms.AddActivationQuantize,
  ]
  pipeline = [transform_cls() for transform_cls in transform_classes]

  transformer = model_transformer.ModelTransformer(
      model, pipeline, layer_quantize_map.keys(), layer_quantize_map)
  return transformer.transform()
def testConcatTransform(self):
  r"""Tests the Concat Transform.

         Input
        /     \
     Dense  Dense
        \     /
        Concat

  One Dense layer has a pre-specified QuantizeConfig, whereas the other does
  not. The Transform should ensure both the output FakeQuants are disabled,
  and only a FakeQuant after Concat is present.
  """
  first_dense = keras.layers.Dense(3)
  second_dense = keras.layers.Dense(3)
  merge = keras.layers.Concatenate()

  model_input = keras.layers.Input((2, ))
  merged = merge([first_dense(model_input), second_dense(model_input)])
  model = keras.Model(model_input, merged)

  # Only the first Dense layer arrives with a pre-existing quantize_config;
  # the second one gets its config from the registry during the transform.
  layer_metadata = {
      first_dense.name: {
          'quantize_config':
              default_8bit_quantize_configs.Default8BitOutputQuantizeConfig()
      }
  }

  _, updated_metadata = ModelTransformer(
      model, [default_8bit_transforms.ConcatTransform()],
      layer_metadata=layer_metadata).transform()

  # Concat should quantize the output.
  concat_config = updated_metadata.get(merge.name).get('quantize_config')
  self.assertIsInstance(
      concat_config,
      default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
  self.assertNotEmpty(concat_config.get_output_quantizers(None))

  # The existing quantize_config should do nothing for outputs.
  first_config = updated_metadata.get(
      first_dense.name).get('quantize_config')
  self.assertIsInstance(
      first_config,
      default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
  self.assertEmpty(first_config.get_output_quantizers(None))

  # The quantize_config from registry should do nothing at output.
  second_config = updated_metadata.get(
      second_dense.name).get('quantize_config')
  self.assertEqual('Default8BitQuantizeConfig',
                   second_config.__class__.__name__)
  self.assertEmpty(second_config.get_output_quantizers(None))
def testConcatMultipleLevels(self):
  r"""Tests case when concats applied to concats.

       Input --------------.
        / \         |      |
    Dense Dense     |      |
        \ /         |      |
      Concat      Dense  Dense
          \        /       |
           Concat          |
              \           /
                Concat

  The last Concat layer should be quantized but the rest of the outputs
  should just feed into it.
  """
  inp = keras.layers.Input((3,))
  x1 = keras.layers.Dense(3)(inp)
  x2 = keras.layers.Dense(3)(inp)
  x3 = keras.layers.Dense(3)(inp)
  x4 = keras.layers.Dense(3)(inp)
  c1 = keras.layers.Concatenate()([x1, x2])
  c2 = keras.layers.Concatenate()([c1, x3])
  c3 = keras.layers.Concatenate()([c2, x4])
  model = keras.Model(inp, c3)
  # NOTE: removed a leftover `model.summary()` debug call which printed the
  # model architecture to stdout on every test run without affecting any
  # assertion.

  _, layer_metadata = ModelTransformer(
      model, [default_8bit_transforms.ConcatTransform()]).transform()

  # All layers between the Input and the final Concat should have their
  # output quantizers disabled — their outputs feed (directly or via an
  # intermediate Concat) into a downstream Concat.
  for layer in model.layers[1:-1]:
    quantize_config = layer_metadata[layer.name].get('quantize_config')
    self.assertEmpty(quantize_config.get_output_quantizers(None))

  # Only the terminal Concat quantizes its output.
  c3_layer = model.layers[-1]
  quantize_config = layer_metadata[c3_layer.name].get('quantize_config')
  self.assertIsInstance(
      quantize_config,
      default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
  self.assertNotEmpty(quantize_config.get_output_quantizers(None))