Code example #1
0
  def testTransformsConvBNPattern(self):
    """Folding Conv2D+BatchNorm should produce the same predictions as the reference folded model."""
    # Reference: a model built with the fold already applied (quantized variant).
    reference_folded = Conv2DModel.get_folded_batchnorm_model(
        is_quantized=True)
    # Model under test: separate Conv2D and BatchNorm layers, functional API.
    unfolded = Conv2DModel.get_nonfolded_batchnorm_model(
        model_type='functional')

    # Apply the Conv2D/BatchNorm fold transform inside the quantize scope so
    # any quantization-aware layers deserialize correctly.
    with quantize.quantize_scope():
      folded_by_transform, _ = ModelTransformer(
          unfolded, [default_8bit_transforms.Conv2DBatchNormFold()]).transform()

    # Same random batch through both models; outputs must match numerically.
    batch = np.random.standard_normal(Conv2DModel.get_batched_input_shape())
    self.assertAllClose(
        folded_by_transform.predict(batch), reference_folded.predict(batch))
Code example #2
0
    def testTransformsConvBNPatternPreservesWeights(self):
        """Folding Conv2D+BatchNorm must carry the original weights over unchanged."""
        # Use random_init so the transformed and original models cannot match
        # merely because both started from identical deterministic init values.
        model = Conv2DModel.get_nonfolded_batchnorm_model(
            model_type='functional', random_init=True)

        folded_model, _ = ModelTransformer(
            model,
            [default_8bit_transforms.Conv2DBatchNormFold()]).transform()

        folded_weights = folded_model.get_weights()
        # Drop the weights added purely for quantization (indices 3..7) so the
        # remaining list lines up with the original model's weights.
        del folded_weights[3:8]

        original_weights = model.get_weights()
        self.assertEqual(len(folded_weights), len(original_weights))
        for got, expected in zip(folded_weights, original_weights):
            self.assertAllEqual(got, expected)