def apply(self, model, layer_quantize_map):
        """Apply TFLite-specific graph transforms to a Keras model.

        The transforms currently implemented:
          1. Pull activations into layers, and apply fuse activations. (TODO)
          2. Modify range in incoming layers for Concat. (TODO)
          3. Fuse Conv2D/DepthwiseConv2D + BN into single layer.

        Args:
          model: Keras model to be quantized.
          layer_quantize_map: Map with keys as layer names, and values as dicts
            containing custom `QuantizeProvider`s which may have been passed with
            layers.

        Returns:
          Transformed Keras model to better match TensorFlow Lite backend.
        """
        # TODO(pulkitb): Sequential models not supported yet. Remove once
        # support is added.
        if isinstance(model, keras.Sequential):
            return model

        fold_transforms = [
            tflite_transforms.DepthwiseConv2DBatchNormReLU6Fold(),
            tflite_transforms.Conv2DBatchNormReLU6Fold(),
            tflite_transforms.Conv2DBatchNormFold(),
        ]

        transformer = model_transformer.ModelTransformer(
            model, fold_transforms, layer_quantize_map.keys(), layer_quantize_map)
        return transformer.transform()
    def testTransformsDepthwiseConvBNReLUPattern(self):
        """Folding DepthwiseConv2D + BN + ReLU6 preserves model predictions."""
        nonfolded = DepthwiseConv2DModel.get_nonfolded_batchnorm_model(
            post_bn_activation=keras.layers.ReLU(6.0), model_type='functional')
        folded = DepthwiseConv2DModel.get_folded_batchnorm_model(
            post_bn_activation=keras.layers.ReLU(6.0), is_quantized=True)

        transformed, _ = ModelTransformer(
            nonfolded,
            [tflite_transforms.DepthwiseConv2DBatchNormReLU6Fold()]).transform()

        # Same random batch through both models; outputs should match closely.
        batch = np.random.standard_normal(
            DepthwiseConv2DModel.get_batched_input_shape())
        self.assertAllClose(transformed.predict(batch), folded.predict(batch))
    def testTransformsDepthwiseConvBNReLUPatternPreservesWeights(self):
        """Folding DepthwiseConv2D + BN + ReLU6 keeps the original weights."""
        # random_init prevents deterministic initialization, which would make
        # the transformed and non-transformed models share identical weights
        # by accident and mask a weight-copying bug.
        model = DepthwiseConv2DModel.get_nonfolded_batchnorm_model(
            post_bn_activation=keras.layers.ReLU(6.0),
            model_type='functional',
            random_init=True)

        transformed_model, _ = ModelTransformer(
            model, [tflite_transforms.DepthwiseConv2DBatchNormReLU6Fold()
                    ]).transform()

        transformed_weights = transformed_model.get_weights()
        # Remove quantization related weights.
        del transformed_weights[3:8]

        # Hoist get_weights() out of the loop: each call rebuilds the full
        # weight list, so the original indexed loop re-fetched it per item.
        original_weights = model.get_weights()
        self.assertEqual(len(transformed_weights), len(original_weights))
        for transformed, original in zip(transformed_weights, original_weights):
            self.assertAllEqual(transformed, original)