def testTransformsConvBNPattern(self):
    """Conv2DBatchNormFold output should match the reference folded model.

    A functional Conv2D + BatchNorm model is rewritten with the
    Conv2DBatchNormFold transform. Its predictions are compared against a
    model built already folded and quantized, using the same random batch.
    """
    source_model = Conv2DModel.get_nonfolded_batchnorm_model(
        model_type='functional')
    reference_model = Conv2DModel.get_folded_batchnorm_model(
        is_quantized=True)

    # The transform is applied inside quantize_scope. This is presumably so
    # that the quantization-specific Keras objects can be resolved; confirm
    # against the quantize module.
    with quantize.quantize_scope():
        fold_transforms = [default_8bit_transforms.Conv2DBatchNormFold()]
        transformer = ModelTransformer(source_model, fold_transforms)
        folded_model, _ = transformer.transform()

    batch = np.random.standard_normal(Conv2DModel.get_batched_input_shape())
    actual = folded_model.predict(batch)
    expected = reference_model.predict(batch)
    self.assertAllClose(actual, expected)
def testTransformsConvBNPatternPreservesWeights(self):
    """Conv2DBatchNormFold should carry the source model's weights over.

    After the weights added by quantization are removed, every remaining
    weight of the transformed model must equal the source model's weight
    at the same position.
    """
    # Use random initialization so a transform that does not copy weights
    # cannot pass by accident. With non-random initializers, a freshly
    # built model would end up with the same weights as the original anyway.
    model = Conv2DModel.get_nonfolded_batchnorm_model(
        model_type='functional', random_init=True)

    transformed_model, _ = ModelTransformer(
        model, [default_8bit_transforms.Conv2DBatchNormFold()]).transform()

    # Fetch each weight list once. get_weights() builds a fresh list per
    # call, so calling it inside the loop repeated that work per element.
    original_weights = model.get_weights()
    transformed_weights = transformed_model.get_weights()

    # Remove quantization related weights.
    # NOTE(review): positions 3..7 are assumed to be the weights the
    # transform adds for quantization; confirm against the layer's weight
    # order if the transform changes.
    del transformed_weights[3:8]

    # Checking the length first guarantees zip() below cannot silently drop
    # trailing weights.
    self.assertEqual(len(transformed_weights), len(original_weights))
    for index, (transformed, original) in enumerate(
            zip(transformed_weights, original_weights)):
        self.assertAllEqual(
            transformed, original, msg='Mismatch at weight index %d' % index)