def testSparsityIsPreservedDuringTraining(self):
  """Verifies that fine-tuning keeps the zeroed kernel rows at zero.

  A fixed random seed is set, the first two rows of the Dense kernel are
  overwritten with zeros, and the model is clustered with
  `preserve_sparsity=True` before being trained for 100 steps.
  """
  tf.random.set_seed(1)

  original_model = keras.Sequential([
      layers.Dense(5, input_shape=(5,)),
      layers.Flatten(),
  ])

  # Overwrite the kernel to mimic zero-drifting cluster centroids: rows 0-1
  # are exactly zero, rows 3-4 hold small values on both sides of zero.
  dense_weights = original_model.layers[0].get_weights()
  kernel = dense_weights[0]
  near_zero_row = [-0.13, -0.08, -0.05, 0.005, 0.13]
  kernel[0:2] = 0.0
  kernel[3] = near_zero_row
  kernel[4] = near_zero_row
  original_model.layers[0].set_weights(dense_weights)

  clustering_params = {
      "number_of_clusters": 6,
      "cluster_centroids_init": CentroidInitialization.LINEAR,
      "preserve_sparsity": True,
  }
  clustered_model = experimental_cluster.cluster_weights(
      original_model, **clustering_params)

  # Unique kernel values right after clustering, before any training.
  unique_before = self._get_number_of_unique_weights(
      cluster.strip_clustering(clustered_model), 0, "kernel")

  clustered_model.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer="adam",
      metrics=["accuracy"],
  )
  clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=100)

  tuned_model = cluster.strip_clustering(clustered_model)
  tuned_kernel = tuned_model.layers[0].kernel
  unique_after = self._get_number_of_unique_weights(tuned_model, 0, "kernel")

  # Even if the zero centroid drifts, training must not increase the number
  # of unique kernel values.
  self.assertLessEqual(unique_after, unique_before)

  # The rows that were zeroed before clustering must still be zero. Other
  # weights may have become zero as well; only the original positions are
  # checked here.
  self.assertTrue(np.array_equal(kernel[0:2], tuned_kernel[0:2]))

  # The number of unique values is bounded by the number of clusters.
  self.assertLessEqual(unique_after, clustering_params["number_of_clusters"])
def testSparsityIsPreservedDuringTraining(self):
  """Verifies that training a clustered model keeps its zero weights in place.

  A fixed random seed is set so that the run is repeatable; the intent is to
  end up with some zero weights to test sparsity preservation with.
  """
  tf.random.set_seed(1)

  original_model = keras.Sequential([
      layers.Dense(5, input_shape=(5,)),
      layers.Dense(5),
  ])

  # Use a minimum number of centroids to make it more likely that some
  # weights will be zero.
  clustering_params = {
      "number_of_clusters": 3,
      "cluster_centroids_init": CentroidInitialization.LINEAR,
      "preserve_sparsity": True,
  }
  clustered_model = experimental_cluster.cluster_weights(
      original_model, **clustering_params)

  stripped_model_before_tuning = cluster.strip_clustering(clustered_model)
  weights_before_tuning = stripped_model_before_tuning.get_weights()[0]
  non_zero_weight_indices_before_tuning = np.nonzero(weights_before_tuning)

  clustered_model.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer="adam",
      metrics=["accuracy"],
  )
  clustered_model.fit(x=self.dataset_generator2(), steps_per_epoch=1)

  stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
  weights_after_tuning = stripped_model_after_tuning.get_weights()[0]
  non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning)
  unique_weights_after_tuning = set(
      weights_after_tuning.reshape(-1).tolist())

  # The positions of the non-zero weights (and therefore of the zero
  # weights) must be the same before and after tuning.
  self.assertTrue(
      np.array_equal(non_zero_weight_indices_before_tuning,
                     non_zero_weight_indices_after_tuning))

  # The bound must come from the parameters this model was actually
  # clustered with, not from the fixture-wide `self.params`.
  self.assertLessEqual(len(unique_weights_after_tuning),
                       clustering_params["number_of_clusters"])
def testEndToEnd(self):
  """Clusters, trains, strips, converts to TFLite and verifies the result.

  The stripped model is converted with the TF2 converter, or via an
  intermediate `.h5` file when the v1 APIs are active, written to a
  temporary `.tflite` file and passed to `self._verify_tflite`.
  """
  original_model = keras.Sequential([
      layers.Dense(2, input_shape=(2,)),
      layers.Dense(2),
  ])

  clustered_model = cluster.cluster_weights(original_model, **self.params)
  clustered_model.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer="adam",
      metrics=["accuracy"],
  )
  clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
  stripped_model = cluster.strip_clustering(clustered_model)

  tflite_fd, tflite_file = tempfile.mkstemp(".tflite")
  keras_fd, keras_file = tempfile.mkstemp(".h5")
  # mkstemp hands back an open OS-level descriptor; only the paths are used
  # below, so close the descriptors right away to avoid leaking them.
  os.close(tflite_fd)
  os.close(keras_fd)

  try:
    if not compat.is_v1_apis():
      converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model)
    else:
      tf.keras.models.save_model(stripped_model, keras_file)
      converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
    tflite_model = converter.convert()

    with open(tflite_file, "wb") as f:
      f.write(tflite_model)

    self._verify_tflite(tflite_file, self.x_train)
  finally:
    # Remove the temporary files even when conversion or verification fails.
    os.remove(keras_file)
    os.remove(tflite_file)
def selective_cluster_model(original_model, sparsity_flag):
  """Clusters only the Conv2D layers of a model, trains it and strips it.

  Args:
    original_model: Keras model to clone; its Conv2D layers are wrapped for
      clustering, every other layer is kept as is.
    sparsity_flag: Value forwarded as `preserve_sparsity`.

  Returns:
    A tuple `(cluster_model, clustered_model_stripped)`: the model returned
    by `_train_model` and the same model with clustering wrappers stripped.
  """
  clustering_params = {
      'number_of_clusters': 8,
      'cluster_centroids_init':
          tfmot_cluster_config.CentroidInitialization.DENSITY_BASED,
      'preserve_sparsity': sparsity_flag,
  }

  def apply_clustering_to_conv2d(layer):
    # Anything that is not a Conv2D layer passes through untouched.
    if not isinstance(layer, tf.keras.layers.Conv2D):
      return layer
    return exp_tfmot_cluster.cluster_weights(layer, **clustering_params)

  wrapped_model = tf.keras.models.clone_model(
      original_model,
      clone_function=apply_clustering_to_conv2d,
  )

  # Train for a single epoch with no callbacks.
  trained_model = _train_model(wrapped_model, [], 1)
  stripped_model = tfmot_cluster.strip_clustering(trained_model)
  return trained_model, stripped_model
def testValuesAreClusteredAfterStripping(self, number_of_clusters,
                                         cluster_centroids_init):
  """Checks that stripping leaves at most `number_of_clusters` unique values.

  Args:
    number_of_clusters: Number of clusters passed to `cluster_weights`.
    cluster_centroids_init: Centroid initialization passed to
      `cluster_weights`.
  """
  original_model = tf.keras.Sequential([
      layers.Dense(32, input_shape=(10,)),
  ])

  # Sanity check: the untouched kernel must start with more unique values
  # than clusters, otherwise the final assertion would be trivially true.
  initial_kernel_values = original_model.get_weights()[0].flatten().tolist()
  self.assertGreater(len(set(initial_kernel_values)), number_of_clusters)

  clustered_model = cluster.cluster_weights(
      original_model,
      number_of_clusters=number_of_clusters,
      cluster_centroids_init=cluster_centroids_init)
  stripped_model = cluster.strip_clustering(clustered_model)

  stripped_kernel_values = stripped_model.get_weights()[0].flatten().tolist()

  # The unique-value count is bounded by the number of clusters.
  self.assertLessEqual(len(set(stripped_kernel_values)), number_of_clusters)
  # The wrapper is gone and the plain Dense layer is back.
  self.assertIsInstance(stripped_model.layers[0], layers.Dense)
def testPassingNonPrunedModelToPCQAT(self):
  """Runs PCQAT as CQAT if the input model is not pruned.

  The clustered model is built without sparsity preservation, quantized with
  the cluster-preserve scheme constructed with `True`, trained, and the
  unique kernel-value count is compared before and after.
  """
  stripped_clustered_model = cluster.strip_clustering(
      self._get_clustered_model(False))
  unique_after_clustering = self._get_number_of_unique_weights(
      stripped_clustered_model, 0, 'kernel')

  # After plain clustering there are no zero weights to preserve, so PCQAT
  # is expected to fall back to CQAT.
  annotated_model = quantize.quantize_annotate_model(stripped_clustered_model)
  scheme = (
      default_8bit_cluster_preserve_quantize_scheme
      .Default8BitClusterPreserveQuantizeScheme(True))
  quant_aware_model = quantize.quantize_apply(annotated_model, scheme=scheme)

  self.compile_and_fit(quant_aware_model)
  stripped_pcqat_model = strip_clustering_cqat(quant_aware_model)

  # Compare the unique weights of the clustered model and the PCQAT model.
  unique_after_pcqat = self._get_number_of_unique_weights(
      stripped_pcqat_model, 1, 'kernel')
  self.assertAllEqual(unique_after_clustering, unique_after_pcqat)
def testEndToEndClusterPreserve(self):
  """Runs CQAT end to end with the whole model quantized.

  Asserts that the number of unique kernel values after CQAT equals the
  number measured on the stripped clustered model.
  """
  dense_model = tf.keras.Sequential(
      [layers.Dense(5, activation='softmax', input_shape=(10,))])

  wrapped_model = cluster.cluster_weights(dense_model, **self.cluster_params)
  self.compile_and_fit(wrapped_model)
  stripped_clustered_model = cluster.strip_clustering(wrapped_model)
  unique_after_clustering = self._get_number_of_unique_weights(
      stripped_clustered_model, 0, 'kernel')

  annotated_model = quantize.quantize_annotate_model(stripped_clustered_model)
  scheme = (
      default_8bit_cluster_preserve_quantize_scheme
      .Default8BitClusterPreserveQuantizeScheme())
  quant_aware_model = quantize.quantize_apply(annotated_model, scheme=scheme)

  self.compile_and_fit(quant_aware_model)
  stripped_cqat_model = strip_clustering_cqat(quant_aware_model)

  # Layer index 1 is used for the quantized model versus index 0 above.
  unique_after_cqat = self._get_number_of_unique_weights(
      stripped_cqat_model, 1, 'kernel')
  self.assertAllEqual(unique_after_clustering, unique_after_cqat)
def cluster_train_eval_strip(
    model, x_train, y_train, x_test, y_test, batch_size, test_case):
  """Clusters, trains, evaluates and strips a model, printing the results.

  Args:
    model: Keras model to cluster (16 clusters, k-means++ initialization).
    x_train: Training inputs.
    y_train: Training targets.
    x_test: Evaluation inputs, also used as validation data while training.
    y_test: Evaluation targets.
    batch_size: Batch size for both `fit` and `evaluate`.
    test_case: String selecting which kernel is inspected; it must contain
      "Bidirectional" or "StackedRNNCells".

  Raises:
    ValueError: If `test_case` names neither supported layer type. This is
      raised only after training and stripping have already run.
  """
  centroid_init = cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS
  model = cluster.cluster_weights(
      model,
      number_of_clusters=16,
      cluster_centroids_init=centroid_init,
  )
  model.compile(
      loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])

  print("Train...")
  model.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      epochs=1,
      validation_data=(x_test, y_test),
      verbose=2)

  score, acc = model.evaluate(x_test, y_test, batch_size=batch_size)
  print("Test score:", score)
  print("Test accuracy:", acc)

  print("Strip clustering wrapper...")
  model = cluster.strip_clustering(model)

  # Pick the recurrent cell whose kernel is inspected for this test case.
  rnn_layer = model.layers[1]
  if "Bidirectional" in test_case:
    cell = rnn_layer.forward_layer.cell
  elif "StackedRNNCells" in test_case:
    cell = rnn_layer.cell.cells[0]
  else:
    raise ValueError("Only Bidirectional and StackedRNNCells are tested now.")

  unique_kernel_values = set(cell.kernel.numpy().flatten())
  print("Number of clusters:", len(unique_kernel_values))
def end_to_end_testing(self, original_model, clusters_check=None):
  """Clusters, trains, strips, converts to TFLite and verifies a model.

  Args:
    original_model: Keras model to run through the clustering pipeline.
    clusters_check: Optional callable invoked with the stripped model before
      the TFLite conversion, e.g. to assert on the clustered weights.
  """
  clustered_model = cluster.cluster_weights(original_model, **self.params)
  clustered_model.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer="adam",
      metrics=["accuracy"],
  )
  clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
  stripped_model = cluster.strip_clustering(clustered_model)

  if clusters_check is not None:
    clusters_check(stripped_model)

  # Only the .tflite file is needed: the converter takes the in-memory
  # model, so the previously created (and never used) .h5 file is gone.
  tflite_fd, tflite_file = tempfile.mkstemp(".tflite")
  # mkstemp returns an open OS-level descriptor; close it so it is not
  # leaked, the file is reopened by path below.
  os.close(tflite_fd)

  try:
    converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model)
    tflite_model = converter.convert()

    with open(tflite_file, "wb") as f:
      f.write(tflite_model)

    self._verify_tflite(tflite_file, self.x_test)
  finally:
    # Remove the temporary file even when conversion or verification fails.
    os.remove(tflite_file)
def _cluster_model(model, number_of_clusters):
  """Clusters a model, fine-tunes it and returns the compiled stripped model.

  Args:
    model: Keras model to cluster.
    number_of_clusters: Number of clusters passed to `cluster_weights`.

  Returns:
    The model with clustering wrappers stripped, compiled with the same
    loss, optimizer and metrics as used for fine-tuning.
  """
  (x_train, y_train), _ = _get_dataset()

  clustered_model = cluster.cluster_weights(
      model,
      number_of_clusters=number_of_clusters,
      cluster_centroids_init=(
          cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS),
  )

  # Fine-tuning uses a small learning rate.
  opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
  clustered_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=opt,
      metrics=['accuracy'])
  clustered_model.fit(x_train, y_train, epochs=EPOCHS_FINE_TUNING)

  stripped_model = cluster.strip_clustering(clustered_model)
  # The same optimizer instance is reused for the stripped model.
  stripped_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=opt,
      metrics=['accuracy'])
  return stripped_model
def testEndToEndPruneClusterPreserveQAT(self):
  """Runs PCQAT end to end when the whole model is quantized.

  Checks three things: training the sparsity-preserving clustered model does
  not change its unique-weight count, the first two kernel rows still match
  the reference variable, and PCQAT keeps the unique-weight count while its
  sparsity is at least that of the input model.
  """
  preserve_sparsity = True
  clustered_model = self._get_clustered_model(preserve_sparsity)

  # Reference to the second weight variable of the first clustered layer;
  # it is compared against the stripped kernel after training.
  reference_kernel = clustered_model.layers[0].weights[1]

  unique_before = self._get_number_of_unique_weights(
      cluster.strip_clustering(clustered_model), 0, 'kernel')

  self.compile_and_fit(clustered_model)

  stripped_clustered = cluster.strip_clustering(clustered_model)
  tuned_kernel = stripped_clustered.layers[0].kernel
  unique_after = self._get_number_of_unique_weights(
      stripped_clustered, 0, 'kernel')

  # Even though the zero centroid can drift during sparsity-aware
  # clustering, the number of unique weights must stay the same.
  self.assertEqual(unique_before, unique_after)

  # The zero weights must keep their original positions; new zeros may
  # appear elsewhere, so only the first two rows are compared.
  self.assertTrue(np.array_equal(reference_kernel[0:2], tuned_kernel[0:2]))

  # Sparsity of the model that is fed into PCQAT.
  sparsity_pruning = self._get_sparsity(stripped_clustered)

  annotated_model = quantize.quantize_annotate_model(stripped_clustered)

  # With preserve_sparsity enabled, the sparsity after PCQAT must be equal
  # to or larger than the input sparsity.
  sparsity_pcqat, unique_weights_pcqat = self._pcqat_training(
      preserve_sparsity, annotated_model)
  self.assertAllGreaterEqual(np.array(sparsity_pcqat), sparsity_pruning[0])
  self.assertAllEqual(unique_after, unique_weights_pcqat)
def _clusterTrainStrip(self, model):
  """Clusters `model`, trains it via `self._train` and strips the wrappers.

  Args:
    model: Keras model to cluster with `self.params_clustering`.

  Returns:
    The trained model with clustering wrappers removed.
  """
  wrapped = cluster.cluster_weights(model, **self.params_clustering)
  self._train(wrapped)
  return cluster.strip_clustering(wrapped)
def testStripClusteringSequentialModel(self):
  """Checks that stripping a clustered sequential model restores its config."""
  dense_stack = [layers.Dense(10), layers.Dense(10)]
  model = keras.Sequential(dense_stack)

  stripped_model = cluster.strip_clustering(
      cluster.cluster_weights(model, **self.params))

  # No clustering wrapper may survive stripping.
  self.assertEqual(self._count_clustered_layers(stripped_model), 0)
  # The stripped model has the same config as the original one.
  self.assertEqual(model.get_config(), stripped_model.get_config())
def testStripSelectivelyClusteredSequentialModel(self):
  """Checks stripping of a sequential model with one clustered layer."""
  wrapped_dense = cluster.cluster_weights(layers.Dense(10), **self.params)
  plain_dense = layers.Dense(10)
  clustered_model = keras.Sequential([wrapped_dense, plain_dense])
  clustered_model.build(input_shape=(1, 10))

  stripped_model = cluster.strip_clustering(clustered_model)

  # The wrapper is gone and the first layer is a plain Dense again.
  self.assertEqual(self._count_clustered_layers(stripped_model), 0)
  self.assertIsInstance(stripped_model.layers[0], layers.Dense)
def testStripSelectivelyClusteredFunctionalModel(self):
  """Checks stripping of a functional model with one clustered branch."""
  first_input = keras.Input(shape=(10,))
  second_input = keras.Input(shape=(10,))

  # Only the first branch is wrapped for clustering.
  clustered_branch = cluster.cluster_weights(
      layers.Dense(10), **self.params)(first_input)
  plain_branch = layers.Dense(10)(second_input)
  merged = layers.Add()([clustered_branch, plain_branch])

  clustered_model = keras.Model(
      inputs=[first_input, second_input], outputs=merged)
  stripped_model = cluster.strip_clustering(clustered_model)

  self.assertEqual(self._count_clustered_layers(stripped_model), 0)
  # layers[2] is expected to be a plain Dense layer after stripping.
  self.assertIsInstance(stripped_model.layers[2], layers.Dense)
def testStripSelectivelyClusteredSequentialModel(self):
  """Verifies that strip_clustering() unwraps a selectively clustered model.

  Only the first layer of the sequential model is wrapped for clustering;
  after stripping, no wrapper remains and that layer is a Dense layer.
  """
  model_layers = [
      cluster.cluster_weights(layers.Dense(10), **self.params),
      layers.Dense(10),
  ]
  clustered_model = keras.Sequential(model_layers)
  clustered_model.build(input_shape=(1, 10))

  stripped_model = cluster.strip_clustering(clustered_model)

  remaining_wrappers = self._count_clustered_layers(stripped_model)
  self.assertEqual(remaining_wrappers, 0)
  self.assertIsInstance(stripped_model.layers[0], layers.Dense)
def testStripClusteringSequentialModel(self):
  """Verifies that stripping a clustered sequential model restores its config.

  After `cluster_weights` followed by `strip_clustering`, no clustering
  wrapper may remain and the config must equal the original model's config.
  """
  model = keras.Sequential([
      layers.Dense(10),
      layers.Dense(10),
  ])
  expected_config = model.get_config()

  wrapped_model = cluster.cluster_weights(model, **self.params)
  stripped_model = cluster.strip_clustering(wrapped_model)

  self.assertEqual(self._count_clustered_layers(stripped_model), 0)
  self.assertEqual(expected_config, stripped_model.get_config())
def testClusterStrippingFunctionalModel(self):
  """Checks that stripping a clustered functional model restores its config."""
  model_inputs = [keras.Input(shape=(10,)), keras.Input(shape=(10,))]
  # One independent Dense branch per input, merged with Add.
  branches = [layers.Dense(10)(model_input) for model_input in model_inputs]
  model = keras.Model(inputs=model_inputs, outputs=layers.Add()(branches))

  stripped_model = cluster.strip_clustering(
      cluster.cluster_weights(model, **self.params))

  self.assertEqual(self._count_clustered_layers(stripped_model), 0)
  self.assertEqual(model.get_config(), stripped_model.get_config())
def testClusterWeightsStrippedWeights(self):
  """Verifies that stripping keeps the number of weight arrays unchanged.

  A functional model made of a single BatchNormalization layer is clustered
  and stripped; the stripped model must expose as many weight arrays as the
  clustered one and contain no clustering wrapper.
  """
  model_input = keras.Input(shape=(10,))
  normalized = layers.BatchNormalization()(model_input)
  model = keras.Model(inputs=[model_input], outputs=normalized)

  clustered_model = cluster.cluster_weights(model, **self.params)
  expected_weight_count = len(clustered_model.get_weights())

  stripped_model = cluster.strip_clustering(clustered_model)

  self.assertEqual(self._count_clustered_layers(stripped_model), 0)
  self.assertLen(stripped_model.get_weights(), expected_weight_count)
def testClusterWeightsStrippedWeights(self):
  """Checks that stripping keeps the number of weight arrays unchanged."""
  model_input = keras.Input(shape=(10,))
  model = keras.Model(
      inputs=[model_input],
      outputs=layers.BatchNormalization()(model_input))

  clustered_model = cluster.cluster_weights(model, **self.params)
  weight_count_before_strip = len(clustered_model.get_weights())

  stripped_model = cluster.strip_clustering(clustered_model)
  weight_count_after_strip = len(stripped_model.get_weights())

  self.assertEqual(self._count_clustered_layers(stripped_model), 0)
  self.assertEqual(weight_count_after_strip, weight_count_before_strip)
def testStripClusteringSequentialModelWithBiasRegularizer(self):
  """Verifies that stripping keeps a bias regularizer and the model saves.

  The second Dense layer carries an L1 bias regularizer; after clustering
  and stripping, the regularizer must still be set and the stripped model
  must be savable with `save_traces=True`.
  """
  regularized_dense = layers.Dense(
      10, bias_regularizer=tf.keras.regularizers.L1(0.01))
  model = keras.Sequential([
      layers.Dense(10, input_shape=(10,)),
      regularized_dense,
  ])

  stripped_model = cluster.strip_clustering(
      cluster.cluster_weights(model, **self.params))

  # The bias regularizer must still be present on the second Dense layer.
  self.assertIsNotNone(stripped_model.layers[1].bias_regularizer)

  # Saving must succeed; the directory is cleaned up on exit.
  with tempfile.TemporaryDirectory() as tmp_dir_name:
    save_path = os.path.join(tmp_dir_name, 'cluster_test')
    stripped_model.save(save_path, save_traces=True)
def testStripClusteringAndSetOriginalWeightsBack(self):
  """Verifies that set_weights works on the stripped model.

  The test has no explicit assertion: it passes when assigning the original
  weights onto the stripped model does not raise.
  """
  dense_layers = [
      layers.Dense(10, input_shape=(5,)),
      layers.Dense(10),
  ]
  model = keras.Sequential(dense_layers)

  # Snapshot the weights before the model is clustered.
  weights_snapshot = model.get_weights()

  stripped_model = cluster.strip_clustering(
      cluster.cluster_weights(model, **self.params))

  # Put the pre-clustering weights back onto the stripped model.
  stripped_model.set_weights(weights_snapshot)
def main(unused_args):
  """Runs the MNIST clustering + cluster-preserving QAT (CQAT) example.

  Trains a baseline model, clusters it, applies CQAT, compares FP32 accuracy
  and unique-weight counts, then converts the CQAT model to TFLite, writes it
  to `cqat.tflite` and reports the TFLite accuracy.

  Args:
    unused_args: Ignored.
  """
  # Load MNIST and scale the pixel values into [0, 1].
  (train_images, train_labels), (test_images, test_labels) = (
      tf.keras.datasets.mnist.load_data())
  train_images = train_images / 255.0
  test_images = test_images / 255.0

  # Baseline model.
  baseline_model = setup_model((28, 28), train_images, train_labels)

  # Clustering followed by retraining.
  clustered_model = cluster_model(baseline_model, train_images, train_labels)
  print('Apply clustering:')
  clst_acc = evaluate_model_fp32(clustered_model, test_images, test_labels)
  clustered_model_stripped = tfmot_cluster.strip_clustering(clustered_model)

  # CQAT starts from the stripped, already clustered model.
  print('Apply cluster-preserve quantization aware training (cqat):')
  cqat_model = cluster_preserve_quantize_model(
      clustered_model_stripped, train_images, train_labels)
  cqat_acc = evaluate_model_fp32(cqat_model, test_images, test_labels)

  # Only the extra clustering variables are removed here; the quantize
  # wrapper stays in place.
  cqat_model_stripped = strip_clustering_cqat(cqat_model)

  # FP32 accuracy and unique-weight counts: clustering versus CQAT.
  print('FP32 accuracy of clustered model:', clst_acc)
  print_unique_weights(clustered_model_stripped)
  print('FP32 accuracy of cqat model:', cqat_acc)
  print_unique_weights(cqat_model_stripped)

  # Convert to TFLite to compare accuracy between TF and TFLite.
  converter = tf.lite.TFLiteConverter.from_keras_model(cqat_model_stripped)
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  tflite_bytes = converter.convert()

  interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
  interpreter.allocate_tensors()
  test_accuracy = evaluate_model(interpreter, test_images, test_labels)

  with open('cqat.tflite', 'wb') as tflite_out:
    tflite_out.write(tflite_bytes)

  print('CQAT TFLite test_accuracy:', test_accuracy)
  print('CQAT TF test accuracy:', cqat_acc)
def testStrippedKernel(self):
  """Verifies the kernel of a stripped Conv2D layer in a functional model.

  After stripping, the layer's kernel must be a different object from the
  kernel held by the wrapped layer before stripping, and must compare equal
  to the first entry of the layer's weights.
  """
  model_input = keras.Input(shape=(1, 1, 1))
  model = keras.Model(
      inputs=[model_input], outputs=layers.Conv2D(1, 1)(model_input))

  clustered_model = cluster.cluster_weights(model, **self.params)
  # Kernel owned by the layer inside the clustering wrapper.
  kernel_before_strip = clustered_model.layers[1].layer.kernel

  stripped_model = cluster.strip_clustering(clustered_model)
  conv_layer = stripped_model.layers[1]

  self.assertEqual(self._count_clustered_layers(stripped_model), 0)
  self.assertIsNot(conv_layer.kernel, kernel_before_strip)
  self.assertEqual(conv_layer.kernel, conv_layer.weights[0])
def testValuesRemainClusteredAfterTraining(self):
  """Verifies that training a clustered model does not destroy the clusters.

  After one training step and stripping, the first weight array of the model
  must hold at most `number_of_clusters` unique values.
  """
  number_of_clusters = 10
  original_model = keras.Sequential([
      layers.Dense(2, input_shape=(2,)),
      layers.Dense(2),
  ])

  clustered_model = cluster.cluster_weights(
      original_model,
      number_of_clusters=number_of_clusters,
      cluster_centroids_init='linear'
  )
  clustered_model.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )

  def dataset_generator():
    """Yields single-sample (input, target) batches."""
    x_train = np.array([
        [0, 1],
        [2, 0],
        [0, 3],
        [4, 1],
        [5, 1],
    ])
    y_train = np.array([
        [0, 1],
        [1, 0],
        [1, 0],
        [0, 1],
        [0, 1],
    ])
    for x, y in zip(x_train, y_train):
      yield np.array([x]), np.array([y])

  # `Model.fit` accepts generators directly; `fit_generator` is deprecated.
  clustered_model.fit(x=dataset_generator(), steps_per_epoch=1)

  stripped_model = cluster.strip_clustering(clustered_model)
  weights_as_list = stripped_model.get_weights()[0].reshape(-1).tolist()
  unique_weights = set(weights_as_list)
  self.assertLessEqual(len(unique_weights), number_of_clusters)
def testMHA(self):
  """Checks that every kernel of the stripped layer has the expected clusters.

  The model from `self._get_model()` is clustered, trained for one epoch and
  stripped. Each weight of `layers[1]` whose name contains 'kernel' must then
  hold exactly `self.nr_of_clusters` unique values.
  """
  model = self._get_model()

  clustered_model = cluster.cluster_weights(model, **self.params_clustering)
  clustered_model.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])
  clustered_model.fit(
      self.x_train, self.y_train, epochs=1, batch_size=100, verbose=1)

  stripped_model = cluster.strip_clustering(clustered_model)

  # NOTE(review): layers[1] is presumably the multi-head attention layer
  # built by `_get_model`; confirm against that helper.
  mha_layer = stripped_model.layers[1]
  for weight in mha_layer.weights:
    if 'kernel' in weight.name:
      nr_unique_weights = len(np.unique(weight.numpy()))
      # A framework assertion reports both values on failure and is not
      # stripped when Python runs with -O, unlike a bare `assert`.
      self.assertEqual(nr_unique_weights, self.nr_of_clusters)
def testEndToEndClusterPreserveOneLayer(self):
  """Runs CQAT end to end with only the layer named 'qat' quantized.

  Asserts that the unique kernel-value count of layer index 1 is the same
  after clustering and after CQAT.
  """
  original_model = tf.keras.Sequential([
      layers.Dense(5, activation='relu', input_shape=(10,)),
      layers.Dense(5, activation='softmax', input_shape=(10,), name='qat'),
  ])

  wrapped_model = cluster.cluster_weights(original_model,
                                          **self.cluster_params)
  self.compile_and_fit(wrapped_model)
  stripped_clustered_model = cluster.strip_clustering(wrapped_model)
  unique_after_clustering = self._get_number_of_unique_weights(
      stripped_clustered_model, 1, 'kernel')

  def apply_quantization_to_dense(layer):
    # Annotate only the Dense layer named 'qat'; leave the rest alone.
    is_target = (
        isinstance(layer, tf.keras.layers.Dense) and layer.name == 'qat')
    return quantize.quantize_annotate_layer(layer) if is_target else layer

  annotated_model = tf.keras.models.clone_model(
      stripped_clustered_model,
      clone_function=apply_quantization_to_dense,
  )

  scheme = (
      default_8bit_cluster_preserve_quantize_scheme
      .Default8BitClusterPreserveQuantizeScheme())
  quant_aware_model = quantize.quantize_apply(annotated_model, scheme=scheme)

  self.compile_and_fit(quant_aware_model)
  stripped_cqat_model = strip_clustering_cqat(quant_aware_model)

  # Compare the unique weights of the same layer index in both models.
  unique_after_cqat = self._get_number_of_unique_weights(
      stripped_cqat_model, 1, 'kernel')
  self.assertAllEqual(unique_after_clustering, unique_after_cqat)
def testValuesAreClusteredAfterStripping(self, number_of_clusters,
                                         cluster_centroids_init):
  """Checks that stripping leaves at most `number_of_clusters` unique values.

  This must hold for any number of clusters and any initialization method.

  Args:
    number_of_clusters: Number of clusters passed to `cluster_weights`.
    cluster_centroids_init: Centroid initialization passed to
      `cluster_weights`.
  """
  original_model = tf.keras.Sequential([
      layers.Dense(32, input_shape=(10,)),
  ])

  stripped_model = cluster.strip_clustering(
      cluster.cluster_weights(
          original_model,
          number_of_clusters=number_of_clusters,
          cluster_centroids_init=cluster_centroids_init))

  kernel_values = stripped_model.get_weights()[0].flatten().tolist()

  # No more unique kernel values than clusters.
  self.assertLessEqual(len(set(kernel_values)), number_of_clusters)
  # The stripped layer is the plain Dense layer.
  self.assertIsInstance(stripped_model.layers[0], layers.Dense)
def testValuesRemainClusteredAfterTraining(self):
  """Verifies that training a clustered model does not destroy the clusters.

  After one training step and stripping, the first weight array must hold at
  most `self.params["number_of_clusters"]` unique values.
  """
  dense_stack = [
      layers.Dense(2, input_shape=(2,)),
      layers.Dense(2),
  ]
  original_model = keras.Sequential(dense_stack)

  clustered_model = cluster.cluster_weights(original_model, **self.params)
  clustered_model.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer="adam",
      metrics=["accuracy"],
  )
  clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)

  stripped_model = cluster.strip_clustering(clustered_model)
  first_weight_values = stripped_model.get_weights()[0].flatten().tolist()

  self.assertLessEqual(
      len(set(first_weight_values)), self.params["number_of_clusters"])
def _cluster_model(original_model, sparsity_flag):
  """Applies the clustering wrapper, trains the model and strips it.

  Args:
    original_model: Keras model to cluster (8 clusters, density-based
      centroid initialization).
    sparsity_flag: Value forwarded as `preserve_sparsity`.

  Returns:
    A tuple `(cluster_model, clustered_model_stripped)`: the model returned
    by `_train_model` and the same model with clustering wrappers stripped.
  """
  wrapped_model = exp_tfmot_cluster.cluster_weights(
      original_model,
      number_of_clusters=8,
      cluster_centroids_init=(
          tfmot_cluster_config.CentroidInitialization.DENSITY_BASED),
      preserve_sparsity=sparsity_flag,
  )

  # Train for a single epoch with no callbacks.
  trained_model = _train_model(wrapped_model, [], 1)
  stripped_model = tfmot_cluster.strip_clustering(trained_model)
  return trained_model, stripped_model