def testPassingNonPrunedModelToPCQAT(self):
        """Checks that PCQAT on a non-pruned model behaves like CQAT.

        The model is clustered without sparsity preservation, stripped, and
        then quantized with the cluster-preserve scheme constructed with
        `True`. The number of unique kernel weights must be the same before
        and after quantization-aware training.
        """
        sparsity_preserved = False
        model_with_clusters = self._get_clustered_model(sparsity_preserved)
        model_with_clusters = cluster.strip_clustering(model_with_clusters)
        expected_unique = self._get_number_of_unique_weights(
            model_with_clusters, 0, 'kernel')

        # Plain clustering leaves no zero weights, so PCQAT is expected to
        # fall back to CQAT even though the scheme is created with True.
        annotated = quantize.quantize_annotate_model(model_with_clusters)
        scheme = (default_8bit_cluster_preserve_quantize_scheme.
                  Default8BitClusterPreserveQuantizeScheme(True))
        qat_model = quantize.quantize_apply(annotated, scheme=scheme)

        self.compile_and_fit(qat_model)
        stripped = strip_clustering_cqat(qat_model)

        # Compare the unique-weight count of the clustered model (layer 0)
        # with the one of the stripped quantized model (layer 1).
        actual_unique = self._get_number_of_unique_weights(
            stripped, 1, 'kernel')
        self.assertAllEqual(expected_unique, actual_unique)
    def testEndToEndClusterPreserve(self):
        """Runs CQAT end to end with the whole model quantized.

        A single-Dense-layer model is clustered, trained and stripped; it is
        then annotated, quantized with the cluster-preserve scheme, trained
        again and stripped. The number of unique kernel weights must match
        between the clustered model and the CQAT model.
        """
        dense = layers.Dense(5, activation='softmax', input_shape=(10, ))
        base_model = tf.keras.Sequential([dense])

        model_with_clusters = cluster.cluster_weights(base_model,
                                                      **self.cluster_params)
        self.compile_and_fit(model_with_clusters)
        model_with_clusters = cluster.strip_clustering(model_with_clusters)
        expected_unique = self._get_number_of_unique_weights(
            model_with_clusters, 0, 'kernel')

        annotated = quantize.quantize_annotate_model(model_with_clusters)
        scheme = (default_8bit_cluster_preserve_quantize_scheme.
                  Default8BitClusterPreserveQuantizeScheme())
        qat_model = quantize.quantize_apply(annotated, scheme=scheme)

        self.compile_and_fit(qat_model)
        stripped = strip_clustering_cqat(qat_model)

        # The unique-weight count of the clustered model (layer 0) must be
        # preserved in the stripped CQAT model (layer 1).
        actual_unique = self._get_number_of_unique_weights(
            stripped, 1, 'kernel')
        self.assertAllEqual(expected_unique, actual_unique)
def prune_cluster_preserve_quantize_model(clustered_model, preserve_sparsity):
    """Applies cluster-preserving QAT to a clustered model and trains it.

    Args:
      clustered_model: Model to annotate and quantize.
      preserve_sparsity: Value forwarded to
        `Default8BitClusterPreserveQuantizeScheme`.

    Returns:
      A `(quant_aware_model, pcqat_stripped)` tuple: the model returned by
      `_train_model`, and that model after
      `cluster_utils.strip_clustering_cqat`.
    """
    pcqat_epochs = 1
    no_callbacks = []

    annotated = quantize.quantize_annotate_model(clustered_model)
    scheme = (default_8bit_cluster_preserve_quantize_scheme.
              Default8BitClusterPreserveQuantizeScheme(preserve_sparsity))
    qat_model = quantize.quantize_apply(annotated, scheme=scheme)

    trained = _train_model(qat_model, no_callbacks, pcqat_epochs)
    return trained, cluster_utils.strip_clustering_cqat(trained)
# Example no. 4
# 0
def main(unused_args):
    """Runs the MNIST clustering -> CQAT -> TFLite example end to end.

    Trains a baseline model, clusters it, applies cluster-preserving
    quantization aware training (CQAT), prints accuracy and unique-weight
    information for both models, converts the stripped CQAT model to TFLite,
    evaluates it, and writes it to `cqat.tflite`.

    Args:
      unused_args: Ignored.
    """
    # MNIST is delivered already split into training and testing sets.
    train_split, test_split = tf.keras.datasets.mnist.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split
    # Scale the pixel values so that each one lies between 0 and 1.
    train_images = train_images / 255.0
    test_images = test_images / 255.0

    # Baseline model: create and train.
    baseline = setup_model((28, 28), train_images, train_labels)

    # Clustering: apply the API and retrain.
    clustered = cluster_model(baseline, train_images, train_labels)
    print('Apply clustering:')
    clustered_accuracy = evaluate_model_fp32(clustered, test_images,
                                             test_labels)
    clustered_stripped = tfmot_cluster.strip_clustering(clustered)

    # CQAT: start from the pretrained clustered model and retrain.
    print('Apply cluster-preserve quantization aware training (cqat):')
    cqat = cluster_preserve_quantize_model(clustered_stripped, train_images,
                                           train_labels)
    cqat_accuracy = evaluate_model_fp32(cqat, test_images, test_labels)
    # Stripping only removes the extra variables introduced by clustering;
    # the quantize_wrapper stays in place.
    cqat_stripped = strip_clustering_cqat(cqat)

    # Clustering vs. CQAT: FP32 accuracy and number of unique weights.
    print('FP32 accuracy of clustered model:', clustered_accuracy)
    print_unique_weights(clustered_stripped)
    print('FP32 accuracy of cqat model:', cqat_accuracy)
    print_unique_weights(cqat_stripped)

    # Convert to TFLite to check that accuracy is consistent with TF.
    tflite_converter = tf.lite.TFLiteConverter.from_keras_model(cqat_stripped)
    tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_bytes = tflite_converter.convert()
    tflite_interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
    tflite_interpreter.allocate_tensors()
    tflite_accuracy = evaluate_model(tflite_interpreter, test_images,
                                     test_labels)

    with open('cqat.tflite', 'wb') as tflite_file:
        tflite_file.write(tflite_bytes)

    print('CQAT TFLite test_accuracy:', tflite_accuracy)
    print('CQAT TF test accuracy:', cqat_accuracy)
    def _pcqat_training(self, preserve_sparsity, quant_aware_annotate_model):
        """Runs PCQAT training on an annotated model.

        Args:
          preserve_sparsity: Value forwarded to
            `Default8BitClusterPreserveQuantizeScheme`.
          quant_aware_annotate_model: Quantize-annotated model to which the
            scheme is applied.

        Returns:
          A `(sparsity_pcqat, num_of_unique_weights_pcqat)` tuple computed
          from the trained model after `strip_clustering_cqat`.
        """
        scheme = (default_8bit_cluster_preserve_quantize_scheme.
                  Default8BitClusterPreserveQuantizeScheme(preserve_sparsity))
        qat_model = quantize.quantize_apply(quant_aware_annotate_model,
                                            scheme=scheme)
        self.compile_and_fit(qat_model)
        stripped = strip_clustering_cqat(qat_model)

        # Layer 0 is the quantize_layer, so the kernel of interest is read
        # from layer 1.
        unique_weights = self._get_number_of_unique_weights(
            stripped, 1, 'kernel')
        sparsity = self._get_sparsity(stripped)

        return sparsity, unique_weights
    def testEndToEndClusterPreserveOneLayer(self):
        """Runs CQAT end to end with only a single layer quantized.

        A two-Dense-layer model is clustered, trained and stripped. Only the
        Dense layer named 'qat' is annotated for quantization before the
        cluster-preserve scheme is applied. After training and stripping,
        the number of unique kernel weights of layer 1 must match the value
        measured on the clustered model.
        """
        base_model = tf.keras.Sequential([
            layers.Dense(5, activation='relu', input_shape=(10, )),
            layers.Dense(5,
                         activation='softmax',
                         input_shape=(10, ),
                         name='qat')
        ])
        model_with_clusters = cluster.cluster_weights(base_model,
                                                      **self.cluster_params)
        self.compile_and_fit(model_with_clusters)
        model_with_clusters = cluster.strip_clustering(model_with_clusters)
        expected_unique = self._get_number_of_unique_weights(
            model_with_clusters, 1, 'kernel')

        def annotate_qat_dense(layer):
            # Only the Dense layer named 'qat' is annotated; every other
            # layer is returned as is.
            is_target = (isinstance(layer, tf.keras.layers.Dense)
                         and layer.name == 'qat')
            if not is_target:
                return layer
            return quantize.quantize_annotate_layer(layer)

        annotated = tf.keras.models.clone_model(
            model_with_clusters,
            clone_function=annotate_qat_dense,
        )

        scheme = (default_8bit_cluster_preserve_quantize_scheme.
                  Default8BitClusterPreserveQuantizeScheme())
        qat_model = quantize.quantize_apply(annotated, scheme=scheme)

        self.compile_and_fit(qat_model)
        stripped = strip_clustering_cqat(qat_model)

        # The unique-weight count of layer 1 must be the same in the
        # clustered model and in the stripped CQAT model.
        actual_unique = self._get_number_of_unique_weights(
            stripped, 1, 'kernel')
        self.assertAllEqual(expected_unique, actual_unique)