def testPassingNonPrunedModelToPCQAT(self):
  """Verifies PCQAT behaves like CQAT when the input model is not pruned.

  A model clustered with sparsity preservation disabled is quantized using
  the cluster-preserve scheme constructed with `True`. After training and
  stripping, the number of unique kernel weights must equal the number
  measured right after clustering.
  """
  # Clustered without sparsity preservation, then stripped of the
  # clustering wrappers.
  baseline_model = cluster.strip_clustering(self._get_clustered_model(False))
  expected_unique = self._get_number_of_unique_weights(
      baseline_model, 0, 'kernel')

  # With no zero weights after plain clustering, PCQAT is expected to fall
  # back to CQAT.
  annotated_model = quantize.quantize_annotate_model(baseline_model)
  scheme = (
      default_8bit_cluster_preserve_quantize_scheme
      .Default8BitClusterPreserveQuantizeScheme(True))
  qat_model = quantize.quantize_apply(annotated_model, scheme=scheme)
  self.compile_and_fit(qat_model)

  # Compare the unique weight count of the clustered model with the one of
  # the stripped PCQAT model (layer index 1 in the quantized model).
  stripped_model = strip_clustering_cqat(qat_model)
  actual_unique = self._get_number_of_unique_weights(
      stripped_model, 1, 'kernel')
  self.assertAllEqual(expected_unique, actual_unique)
def testEndToEndClusterPreserve(self):
  """Runs CQAT end to end with the whole model quantized.

  A single-Dense-layer model is clustered, stripped, annotated for
  quantization as a whole and trained with the cluster-preserve scheme.
  The number of unique kernel weights after CQAT must match the number
  obtained after clustering.
  """
  dense = layers.Dense(5, activation='softmax', input_shape=(10, ))
  base_model = tf.keras.Sequential([dense])

  # Cluster, fine-tune and drop the clustering wrappers.
  wrapped_model = cluster.cluster_weights(base_model, **self.cluster_params)
  self.compile_and_fit(wrapped_model)
  baseline_model = cluster.strip_clustering(wrapped_model)
  expected_unique = self._get_number_of_unique_weights(
      baseline_model, 0, 'kernel')

  # Apply cluster-preserving quantization aware training.
  annotated_model = quantize.quantize_annotate_model(baseline_model)
  scheme = (
      default_8bit_cluster_preserve_quantize_scheme
      .Default8BitClusterPreserveQuantizeScheme())
  qat_model = quantize.quantize_apply(annotated_model, scheme=scheme)
  self.compile_and_fit(qat_model)

  # Compare the unique weights of the checked layer in the clustered model
  # and in the stripped CQAT model (layer index 1 in the quantized model).
  stripped_model = strip_clustering_cqat(qat_model)
  actual_unique = self._get_number_of_unique_weights(
      stripped_model, 1, 'kernel')
  self.assertAllEqual(expected_unique, actual_unique)
def prune_cluster_preserve_quantize_model(clustered_model, preserve_sparsity):
  """Applies cluster-preserving QAT to a clustered model and trains it.

  Args:
    clustered_model: Model to annotate for quantization.
    preserve_sparsity: Value forwarded to
      `Default8BitClusterPreserveQuantizeScheme`.

  Returns:
    A tuple `(quant_aware_model, pcqat_stripped)`: the trained
    quantization-aware model and the result of running
    `cluster_utils.strip_clustering_cqat` on it.
  """
  annotated_model = quantize.quantize_annotate_model(clustered_model)
  scheme = (
      default_8bit_cluster_preserve_quantize_scheme
      .Default8BitClusterPreserveQuantizeScheme(preserve_sparsity))
  qat_model = quantize.quantize_apply(annotated_model, scheme=scheme)

  # Train for a single epoch with no callbacks.
  qat_model = _train_model(qat_model, [], 1)

  return qat_model, cluster_utils.strip_clustering_cqat(qat_model)
def main(unused_args):
  """Runs the MNIST clustering -> CQAT -> TFLite example end to end.

  Trains a baseline model, clusters it, applies cluster-preserving
  quantization aware training, prints accuracy and unique-weight reports,
  converts the stripped CQAT model to TFLite, evaluates it and writes the
  flatbuffer to `cqat.tflite`.

  Args:
    unused_args: Ignored.
  """
  # Load MNIST, already split into training and testing sets.
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

  # Scale pixel values into the [0, 1] range.
  x_train, x_test = x_train / 255.0, x_test / 255.0

  # Baseline model, created and trained on 28x28 inputs.
  baseline = setup_model((28, 28), x_train, y_train)

  # Clustering API followed by retraining.
  clustered = cluster_model(baseline, x_train, y_train)

  print('Apply clustering:')
  clst_acc = evaluate_model_fp32(clustered, x_test, y_test)
  clustered_stripped = tfmot_cluster.strip_clustering(clustered)

  print('Apply cluster-preserve quantization aware training (cqat):')
  # CQAT starts from the pretrained, stripped clustered model.
  cqat = cluster_preserve_quantize_model(clustered_stripped, x_train, y_train)
  cqat_acc = evaluate_model_fp32(cqat, x_test, y_test)

  # Only the extra variables introduced by clustering are removed here;
  # the quantize wrapper stays in place.
  cqat_stripped = strip_clustering_cqat(cqat)

  # Report FP32 accuracy and number of unique weights for both models.
  print('FP32 accuracy of clustered model:', clst_acc)
  print_unique_weights(clustered_stripped)
  print('FP32 accuracy of cqat model:', cqat_acc)
  print_unique_weights(cqat_stripped)

  # Convert to TFLite to check that accuracy is consistent with TF.
  converter = tf.lite.TFLiteConverter.from_keras_model(cqat_stripped)
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  tflite_bytes = converter.convert()

  interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
  interpreter.allocate_tensors()
  test_accuracy = evaluate_model(interpreter, x_test, y_test)

  with open('cqat.tflite', 'wb') as tflite_file:
    tflite_file.write(tflite_bytes)

  print('CQAT TFLite test_accuracy:', test_accuracy)
  print('CQAT TF test accuracy:', cqat_acc)
def _pcqat_training(self, preserve_sparsity, quant_aware_annotate_model):
  """Trains an annotated model with PCQAT and reports its statistics.

  Args:
    preserve_sparsity: Value forwarded to
      `Default8BitClusterPreserveQuantizeScheme`.
    quant_aware_annotate_model: Model already annotated for quantization.

  Returns:
    A tuple `(sparsity_pcqat, num_of_unique_weights_pcqat)` computed on the
    stripped PCQAT model.
  """
  scheme = (
      default_8bit_cluster_preserve_quantize_scheme
      .Default8BitClusterPreserveQuantizeScheme(preserve_sparsity))
  qat_model = quantize.quantize_apply(
      quant_aware_annotate_model, scheme=scheme)
  self.compile_and_fit(qat_model)
  stripped_model = strip_clustering_cqat(qat_model)

  # Layer 0 is the quantize_layer, so the kernel of interest sits at
  # index 1.
  unique_weights = self._get_number_of_unique_weights(
      stripped_model, 1, 'kernel')
  sparsity = self._get_sparsity(stripped_model)
  return sparsity, unique_weights
def testEndToEndClusterPreserveOneLayer(self):
  """Runs CQAT end to end with only a single layer quantized.

  A two-layer Dense model is clustered and stripped, then only the layer
  named 'qat' is annotated for quantization. After cluster-preserving QAT
  the number of unique kernel weights of the layer at index 1 must match
  the number measured after clustering.
  """
  base_model = tf.keras.Sequential([
      layers.Dense(5, activation='relu', input_shape=(10, )),
      layers.Dense(5, activation='softmax', input_shape=(10, ), name='qat'),
  ])

  # Cluster, fine-tune and drop the clustering wrappers.
  wrapped_model = cluster.cluster_weights(base_model, **self.cluster_params)
  self.compile_and_fit(wrapped_model)
  baseline_model = cluster.strip_clustering(wrapped_model)
  expected_unique = self._get_number_of_unique_weights(
      baseline_model, 1, 'kernel')

  def apply_quantization_to_dense(layer):
    """Annotates only the Dense layer named 'qat'; returns others as is."""
    if isinstance(layer, tf.keras.layers.Dense) and layer.name == 'qat':
      return quantize.quantize_annotate_layer(layer)
    return layer

  annotated_model = tf.keras.models.clone_model(
      baseline_model,
      clone_function=apply_quantization_to_dense,
  )
  scheme = (
      default_8bit_cluster_preserve_quantize_scheme
      .Default8BitClusterPreserveQuantizeScheme())
  qat_model = quantize.quantize_apply(annotated_model, scheme=scheme)
  self.compile_and_fit(qat_model)

  # Compare the unique weights of the checked layer in the clustered model
  # and in the stripped CQAT model.
  stripped_model = strip_clustering_cqat(qat_model)
  actual_unique = self._get_number_of_unique_weights(
      stripped_model, 1, 'kernel')
  self.assertAllEqual(expected_unique, actual_unique)