  def testTransformsDepthwiseConvBNReLUPattern(self):
    model = DepthwiseConv2DModel.get_nonfolded_batchnorm_model(
        post_bn_activation=keras.layers.ReLU(6.0), model_type='functional')
    folded_model = DepthwiseConv2DModel.get_folded_batchnorm_model(
        post_bn_activation=keras.layers.ReLU(6.0), is_quantized=True)

    transformed_model, _ = ModelTransformer(
        model, [default_8bit_transforms.DepthwiseConv2DBatchNormReLU6Fold()]
    ).transform()

    inputs = np.random.standard_normal(
        DepthwiseConv2DModel.get_batched_input_shape())
    self.assertAllClose(
        transformed_model.predict(inputs), folded_model.predict(inputs))
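
  # Why the two models should produce matching predictions: folding rewrites
  #   y = gamma * (conv(x, W) - mean) / sqrt(var + eps) + beta
  # as a single depthwise conv with per-output-channel scaled weights:
  #   W_folded = W * gamma / sqrt(var + eps)
  #   b_folded = beta - mean * gamma / sqrt(var + eps)
  # A minimal numpy sketch of that arithmetic (illustrative only; not used by
  # the fixtures in these tests):
  @staticmethod
  def _fold_bn_into_depthwise_weights(kernel, gamma, beta, mean, var,
                                      eps=1e-3):
    """Returns (folded_kernel, folded_bias) for a DepthwiseConv2D kernel.

    `kernel` has shape (kh, kw, in_channels, depth_multiplier); the batch
    norm statistics are per output channel, which TF orders as
    in_channels-major (output channel = k * depth_multiplier + q).
    """
    scale = gamma / np.sqrt(var + eps)
    in_channels, multiplier = kernel.shape[-2], kernel.shape[-1]
    folded_kernel = kernel * scale.reshape(1, 1, in_channels, multiplier)
    folded_bias = beta - mean * scale
    return folded_kernel, folded_bias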

  def testTransformsDepthwiseConvBNReLUPatternPreservesWeights(self):
    # random_init prevents deterministic initialization from producing the
    # same weights in the transformed and non-transformed models.
    model = DepthwiseConv2DModel.get_nonfolded_batchnorm_model(
        post_bn_activation=keras.layers.ReLU(6.0),
        model_type='functional',
        random_init=True)

    transformed_model, _ = ModelTransformer(
        model, [default_8bit_transforms.DepthwiseConv2DBatchNormReLU6Fold()]
    ).transform()

    transformed_weights = transformed_model.get_weights()
    # Remove the quantization-related weights added by the transform.
    del transformed_weights[3:8]

    model_weights = model.get_weights()
    self.assertEqual(len(transformed_weights), len(model_weights))
    for transformed_weight, model_weight in zip(
        transformed_weights, model_weights):
      self.assertAllEqual(transformed_weight, model_weight)
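
  # A hedged alternative to the hard-coded `del transformed_weights[3:8]`
  # above: drop weights by variable name rather than by index. The name
  # substrings below are an assumption about how the quantize wrapper labels
  # its variables, not a verified API contract.
  def _non_quantization_weights(self, model):
    return [
        w.numpy() for w in model.weights
        if not any(t in w.name for t in ('quantize', 'min', 'max'))
    ]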