def testSparsityIsPreservedDuringTraining(self):
        """Verifies that training a clustered model keeps its null weights.

        A fixed random seed plus manually zeroed kernel rows guarantee that
        the model has null weights to check. After fine-tuning a model that
        was clustered with ``preserve_sparsity=True``:
          * the number of unique kernel values must not increase,
          * the originally null rows must still be null,
          * the number of unique kernel values must not exceed the number of
            clusters.
        """
        tf.random.set_seed(1)
        original_model = keras.Sequential([
            layers.Dense(5, input_shape=(5, )),
            layers.Flatten(),
        ])
        # Reset the kernel weights to reflect potential zero drifting of
        # the cluster centroids: rows 0-1 become null, rows 3-4 get values
        # spread around zero. `kernel` aliases first_layer_weights[0], so the
        # in-place writes below are what set_weights receives.
        first_layer_weights = original_model.layers[0].get_weights()
        kernel = first_layer_weights[0]
        kernel[0:2] = 0.0
        kernel[3] = [-0.13, -0.08, -0.05, 0.005, 0.13]
        kernel[4] = [-0.13, -0.08, -0.05, 0.005, 0.13]
        original_model.layers[0].set_weights(first_layer_weights)
        clustering_params = {
            "number_of_clusters": 6,
            "cluster_centroids_init": CentroidInitialization.LINEAR,
            "preserve_sparsity": True
        }
        clustered_model = experimental_cluster.cluster_weights(
            original_model, **clustering_params)
        stripped_model_before_tuning = cluster.strip_clustering(
            clustered_model)
        nr_of_unique_weights_before = self._get_number_of_unique_weights(
            stripped_model_before_tuning, 0, "kernel")
        clustered_model.compile(
            loss=keras.losses.categorical_crossentropy,
            optimizer="adam",
            metrics=["accuracy"],
        )
        clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=100)
        stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
        weights_after_tuning = stripped_model_after_tuning.layers[0].kernel
        nr_of_unique_weights_after = self._get_number_of_unique_weights(
            stripped_model_after_tuning, 0, "kernel")
        # Even though the zero centroid may drift during sparsity-aware
        # clustering, training must not increase the number of unique weights.
        self.assertLessEqual(nr_of_unique_weights_after,
                             nr_of_unique_weights_before)
        # The original null weights must stay null in their original positions.
        # Training may turn other weights into zeros; that is allowed.
        self.assertTrue(
            np.array_equal(kernel[0:2], weights_after_tuning[0:2]))
        # The number of unique weights must not exceed the number of clusters.
        self.assertLessEqual(nr_of_unique_weights_after,
                             clustering_params["number_of_clusters"])
    def testSparsityIsPreservedDuringTraining(self):
        """Verifies that training a clustered model keeps its null weights.

        A fixed random seed and a small number of clusters make it more
        likely that some weights are clustered to exactly zero. With
        ``preserve_sparsity=True``, the positions of the non-zero weights must
        be identical before and after fine-tuning, and the stripped kernel
        must hold at most ``number_of_clusters`` unique values.
        """
        tf.random.set_seed(1)

        original_model = keras.Sequential([
            layers.Dense(5, input_shape=(5, )),
            layers.Dense(5),
        ])

        # Using a minimum number of centroids to make it more likely that some
        # weights will be zero.
        clustering_params = {
            "number_of_clusters": 3,
            "cluster_centroids_init": CentroidInitialization.LINEAR,
            "preserve_sparsity": True
        }

        clustered_model = experimental_cluster.cluster_weights(
            original_model, **clustering_params)

        stripped_model_before_tuning = cluster.strip_clustering(
            clustered_model)
        weights_before_tuning = stripped_model_before_tuning.get_weights()[0]
        non_zero_weight_indices_before_tuning = np.nonzero(
            weights_before_tuning)

        clustered_model.compile(
            loss=keras.losses.categorical_crossentropy,
            optimizer="adam",
            metrics=["accuracy"],
        )
        clustered_model.fit(x=self.dataset_generator2(), steps_per_epoch=1)

        stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
        weights_after_tuning = stripped_model_after_tuning.get_weights()[0]
        non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning)
        weights_as_list_after_tuning = weights_after_tuning.reshape(
            -1, ).tolist()
        unique_weights_after_tuning = set(weights_as_list_after_tuning)

        # The null weights stayed in place: the non-zero positions are the
        # same before and after tuning.
        self.assertTrue(
            np.array_equal(non_zero_weight_indices_before_tuning,
                           non_zero_weight_indices_after_tuning))

        # The number of unique weights must not exceed the number of clusters
        # this model was actually clustered with (not a class-level default).
        self.assertLessEqual(len(unique_weights_after_tuning),
                             clustering_params["number_of_clusters"])
  def testEndToEnd(self):
    """Test End to End clustering.

    Clusters and fine-tunes a small model, strips the clustering wrappers,
    converts the stripped model to TFLite and verifies the converted model.
    Temporary files are removed even if conversion or verification fails.
    """
    original_model = keras.Sequential([
        layers.Dense(2, input_shape=(2,)),
        layers.Dense(2),
    ])

    clustered_model = cluster.cluster_weights(original_model, **self.params)

    clustered_model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer="adam",
        metrics=["accuracy"],
    )

    clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
    stripped_model = cluster.strip_clustering(clustered_model)

    # mkstemp returns an open OS-level file descriptor; close it right away
    # so it does not leak. Only the paths are needed below.
    tflite_fd, tflite_file = tempfile.mkstemp(".tflite")
    os.close(tflite_fd)
    keras_fd, keras_file = tempfile.mkstemp(".h5")
    os.close(keras_fd)

    try:
      if not compat.is_v1_apis():
        converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model)
      else:
        # On the v1 API path the converter is built from a saved model file.
        tf.keras.models.save_model(stripped_model, keras_file)
        converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)

      tflite_model = converter.convert()
      with open(tflite_file, "wb") as f:
        f.write(tflite_model)

      self._verify_tflite(tflite_file, self.x_train)
    finally:
      os.remove(keras_file)
      os.remove(tflite_file)
def selective_cluster_model(original_model, sparsity_flag):
    """Cluster only the Conv2D layers of a model, train it and strip it.

    Args:
      original_model: Keras model whose Conv2D layers get clustered; every
        other layer is cloned unchanged.
      sparsity_flag: forwarded as ``preserve_sparsity`` to the clustering API.

    Returns:
      A ``(cluster_model, stripped_model)`` tuple: the model returned by
      ``_train_model`` (still carrying clustering wrappers) and the same
      model with the clustering wrappers stripped.
    """
    params = dict(
        number_of_clusters=8,
        cluster_centroids_init=(
            tfmot_cluster_config.CentroidInitialization.DENSITY_BASED),
        preserve_sparsity=sparsity_flag,
    )

    def _wrap_conv2d(layer):
        # Anything that is not a Conv2D passes through untouched.
        if not isinstance(layer, tf.keras.layers.Conv2D):
            return layer
        return exp_tfmot_cluster.cluster_weights(layer, **params)

    wrapped = tf.keras.models.clone_model(
        original_model,
        clone_function=_wrap_conv2d,
    )
    # No callbacks, a single clustering epoch.
    wrapped = _train_model(wrapped, [], 1)

    return wrapped, tfmot_cluster.strip_clustering(wrapped)
Exemplo n.º 5
0
    def testValuesAreClusteredAfterStripping(self, number_of_clusters,
                                             cluster_centroids_init):
        """Checks the unique-value bound after stripping.

        For any number of clusters and any centroid initialization method,
        the stripped kernel must contain at most ``number_of_clusters``
        distinct values.
        """
        original_model = tf.keras.Sequential([
            layers.Dense(32, input_shape=(10, )),
        ])
        # Precondition: before clustering there are more distinct values than
        # clusters, so the bound checked below is meaningful.
        original_kernel = original_model.get_weights()[0]
        self.assertGreater(len(set(original_kernel.reshape(-1, ).tolist())),
                           number_of_clusters)

        stripped_model = cluster.strip_clustering(
            cluster.cluster_weights(
                original_model,
                number_of_clusters=number_of_clusters,
                cluster_centroids_init=cluster_centroids_init))

        stripped_kernel = stripped_model.get_weights()[0]
        distinct_values = set(stripped_kernel.reshape(-1, ).tolist())
        self.assertLessEqual(len(distinct_values), number_of_clusters)

        # The wrapper is gone: the first layer is a plain Dense layer again.
        self.assertIsInstance(stripped_model.layers[0], layers.Dense)
    def testPassingNonPrunedModelToPCQAT(self):
        """Runs PCQAT as CQAT if the input model is not pruned.

        The model is clustered without sparsity preservation. Applying the
        PCQAT scheme to it must leave the number of unique kernel weights
        unchanged.
        """
        stripped_clustered = cluster.strip_clustering(
            self._get_clustered_model(False))
        unique_count_clustered = self._get_number_of_unique_weights(
            stripped_clustered, 0, 'kernel')

        # With no zero weights from plain clustering, PCQAT (scheme argument
        # True) is expected to fall back to CQAT.
        annotated = quantize.quantize_annotate_model(stripped_clustered)
        pcqat_scheme = (default_8bit_cluster_preserve_quantize_scheme.
                        Default8BitClusterPreserveQuantizeScheme(True))
        qat_model = quantize.quantize_apply(annotated, scheme=pcqat_scheme)

        self.compile_and_fit(qat_model)
        stripped_pcqat = strip_clustering_cqat(qat_model)

        # Layer index 1 in the PCQAT model is compared with index 0 in the
        # clustered model; presumably quantize_apply inserts a layer in front
        # (TODO confirm).
        unique_count_pcqat = self._get_number_of_unique_weights(
            stripped_pcqat, 1, 'kernel')
        self.assertAllEqual(unique_count_clustered, unique_count_pcqat)
    def testEndToEndClusterPreserve(self):
        """Runs CQAT end to end and whole model is quantized.

        The number of unique kernel weights must be the same after clustering
        and after cluster-preserving quantization aware training (CQAT).
        """
        model = tf.keras.Sequential(
            [layers.Dense(5, activation='softmax', input_shape=(10, ))])
        model = cluster.cluster_weights(model, **self.cluster_params)
        self.compile_and_fit(model)
        model = cluster.strip_clustering(model)
        unique_after_clustering = self._get_number_of_unique_weights(
            model, 0, 'kernel')

        annotated = quantize.quantize_annotate_model(model)
        cqat_scheme = (default_8bit_cluster_preserve_quantize_scheme.
                       Default8BitClusterPreserveQuantizeScheme())
        qat_model = quantize.quantize_apply(annotated, scheme=cqat_scheme)

        self.compile_and_fit(qat_model)
        stripped_cqat = strip_clustering_cqat(qat_model)

        # Compare unique kernel weights of the clustered model and the CQAT
        # model (the dense layer is at index 1 after quantization).
        unique_after_cqat = self._get_number_of_unique_weights(
            stripped_cqat, 1, 'kernel')
        self.assertAllEqual(unique_after_clustering, unique_after_cqat)
Exemplo n.º 8
0
def cluster_train_eval_strip(
    model, x_train, y_train, x_test, y_test, batch_size, test_case):
  """Train, evaluate and strip clustering.

  Clusters ``model`` into 16 clusters (k-means++ initialization), trains it
  for one epoch, and prints the test score and accuracy. It then strips the
  clustering wrappers and prints the number of unique values in the kernel
  of the inspected RNN cell.

  Args:
    model: Keras model to cluster. Its layer at index 1 is expected to be a
      Bidirectional RNN or an RNN over StackedRNNCells, matching
      ``test_case``.
    x_train: training inputs.
    y_train: training labels.
    x_test: evaluation inputs, also used as validation data.
    y_test: evaluation labels, also used as validation data.
    batch_size: batch size for training and evaluation.
    test_case: test-case name. It must contain "Bidirectional" or
      "StackedRNNCells"; this selects which kernel is inspected.

  Returns:
    The model with clustering wrappers stripped.

  Raises:
    ValueError: if ``test_case`` names neither supported layer type. This is
      checked before training, so unsupported cases fail fast.
  """
  if "Bidirectional" not in test_case and "StackedRNNCells" not in test_case:
    raise ValueError("Only Bidirectional and StackedRNNCells are tested now.")

  model = cluster.cluster_weights(
      model,
      number_of_clusters=16,
      cluster_centroids_init=cluster_config.CentroidInitialization
      .KMEANS_PLUS_PLUS,)

  model.compile(loss="binary_crossentropy",
                optimizer="adam",
                metrics=["accuracy"])

  print("Train...")
  model.fit(x_train, y_train, batch_size=batch_size, epochs=1,
            validation_data=(x_test, y_test), verbose=2)
  score, acc = model.evaluate(x_test, y_test,
                              batch_size=batch_size)

  print("Test score:", score)
  print("Test accuracy:", acc)

  print("Strip clustering wrapper...")
  model = cluster.strip_clustering(model)
  if "Bidirectional" in test_case:
    layer_weight = model.layers[1].forward_layer.cell.kernel
  else:
    # Must be "StackedRNNCells": validated at the top of the function.
    layer_weight = model.layers[1].cell.cells[0].kernel
  print("Number of clusters:", len(set(layer_weight.numpy().flatten())))
  return model
    def end_to_end_testing(self, original_model, clusters_check=None):
        """Test End to End clustering.

        Clusters and fine-tunes ``original_model`` and strips the clustering
        wrappers. It then optionally runs ``clusters_check`` on the stripped
        model, converts it to TFLite and verifies the converted model. The
        temporary TFLite file is removed even if a step fails.

        Args:
          original_model: Keras model to cluster.
          clusters_check: optional callable invoked with the stripped model,
            for additional assertions on the clustered weights.
        """
        clustered_model = cluster.cluster_weights(original_model,
                                                  **self.params)

        clustered_model.compile(
            loss=keras.losses.categorical_crossentropy,
            optimizer="adam",
            metrics=["accuracy"],
        )

        clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
        stripped_model = cluster.strip_clustering(clustered_model)
        if clusters_check is not None:
            clusters_check(stripped_model)

        # mkstemp returns an open OS-level file descriptor; close it so it
        # does not leak. Only the path is needed.
        tflite_fd, tflite_file = tempfile.mkstemp(".tflite")
        os.close(tflite_fd)

        try:
            converter = tf.lite.TFLiteConverter.from_keras_model(
                stripped_model)
            tflite_model = converter.convert()

            with open(tflite_file, "wb") as f:
                f.write(tflite_model)

            self._verify_tflite(tflite_file, self.x_test)
        finally:
            os.remove(tflite_file)
def _cluster_model(model, number_of_clusters):
    """Cluster ``model``, fine-tune it on the training set and strip it.

    Args:
      model: Keras model to cluster.
      number_of_clusters: value passed as ``number_of_clusters`` to
        ``cluster.cluster_weights``.

    Returns:
      The stripped model. It is compiled with the same loss and metrics as
      during fine-tuning, and with its own Adam optimizer (learning rate
      1e-5).
    """
    (x_train, y_train), _ = _get_dataset()

    clustering_params = {
        'number_of_clusters':
        number_of_clusters,
        'cluster_centroids_init':
        cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS
    }

    # Cluster model
    clustered_model = cluster.cluster_weights(model, **clustering_params)

    def _make_optimizer():
        # Use a smaller learning rate for fine-tuning the clustered model.
        return tf.keras.optimizers.Adam(learning_rate=1e-5)

    clustered_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=_make_optimizer(),
        metrics=['accuracy'])

    # Fine-tune clustered model
    clustered_model.fit(x_train, y_train, epochs=EPOCHS_FINE_TUNING)

    stripped_model = cluster.strip_clustering(clustered_model)
    # Give the stripped model its own optimizer. The fine-tuning optimizer
    # already tracks the clustered model's variables, and newer Keras
    # optimizers can raise when later asked to update a different model's
    # variables.
    stripped_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=_make_optimizer(),
        metrics=['accuracy'])

    return stripped_model
    def testEndToEndPruneClusterPreserveQAT(self):
        """Runs PCQAT end to end when we quantize the whole model.

        Checks three things:
          * sparsity-aware fine-tuning keeps the number of unique weights and
            the original zero weights;
          * PCQAT keeps the sparsity at least as high as its input;
          * PCQAT keeps the unique-weight count unchanged.
        """
        preserve_sparsity = True
        clustered_model = self._get_clustered_model(preserve_sparsity)
        # Snapshot the weights *before* fine-tuning. weights[1] (presumably
        # the wrapped kernel — TODO confirm the index) is a live variable.
        # Without copying its value, the comparison after compile_and_fit
        # would read the already-updated values.
        first_layer_weights = clustered_model.layers[0].weights[1].numpy()
        stripped_model_before_tuning = cluster.strip_clustering(
            clustered_model)
        nr_of_unique_weights_before = self._get_number_of_unique_weights(
            stripped_model_before_tuning, 0, 'kernel')

        self.compile_and_fit(clustered_model)

        stripped_model_clustered = cluster.strip_clustering(clustered_model)
        weights_after_tuning = stripped_model_clustered.layers[0].kernel
        nr_of_unique_weights_after = self._get_number_of_unique_weights(
            stripped_model_clustered, 0, 'kernel')

        # Check after sparsity-aware clustering, despite zero centroid can drift,
        # the final number of unique weights remains the same
        self.assertEqual(nr_of_unique_weights_before,
                         nr_of_unique_weights_after)

        # Check that the zero weights stayed the same before and after tuning.
        # There might be new weights that become zeros but sparsity-aware
        # clustering preserves the original zero weights in the original positions
        # of the weight array
        self.assertTrue(
            np.array_equal(first_layer_weights[0:2],
                           weights_after_tuning[0:2]))

        # Check sparsity before the input of PCQAT
        sparsity_pruning = self._get_sparsity(stripped_model_clustered)

        # PCQAT: when the preserve_sparsity flag is True, the PCQAT should work
        quant_aware_annotate_model = (
            quantize.quantize_annotate_model(stripped_model_clustered))

        # When preserve_sparsity is True in PCQAT, the final sparsity of
        # the layer stays the same or larger than that of the input layer
        sparsity_pcqat, unique_weights_pcqat = self._pcqat_training(
            preserve_sparsity, quant_aware_annotate_model)
        self.assertAllGreaterEqual(np.array(sparsity_pcqat),
                                   sparsity_pruning[0])
        self.assertAllEqual(nr_of_unique_weights_after, unique_weights_pcqat)
    def _clusterTrainStrip(self, model):
        """Cluster ``model`` with ``self.params_clustering``, train, strip.

        Returns:
          The model after ``self._train``, with clustering wrappers removed.
        """
        wrapped = cluster.cluster_weights(model, **self.params_clustering)
        self._train(wrapped)
        return cluster.strip_clustering(wrapped)
Exemplo n.º 13
0
    def testStripClusteringSequentialModel(self):
        """Stripping a clustered Sequential model restores the original config."""
        original = keras.Sequential([layers.Dense(10), layers.Dense(10)])

        stripped = cluster.strip_clustering(
            cluster.cluster_weights(original, **self.params))

        # No clustering wrappers remain, and the config matches the original.
        self.assertEqual(self._count_clustered_layers(stripped), 0)
        self.assertEqual(original.get_config(), stripped.get_config())
Exemplo n.º 14
0
    def testStripSelectivelyClusteredSequentialModel(self):
        """Stripping a partly clustered Sequential model removes all wrappers."""
        wrapped_dense = cluster.cluster_weights(layers.Dense(10), **self.params)
        clustered_model = keras.Sequential([wrapped_dense, layers.Dense(10)])
        clustered_model.build(input_shape=(1, 10))

        stripped_model = cluster.strip_clustering(clustered_model)

        self.assertEqual(self._count_clustered_layers(stripped_model), 0)
        # The previously wrapped first layer is a plain Dense layer again.
        self.assertIsInstance(stripped_model.layers[0], layers.Dense)
Exemplo n.º 15
0
    def testStripSelectivelyClusteredFunctionalModel(self):
        """Stripping a partly clustered functional model removes all wrappers."""
        first_input = keras.Input(shape=(10, ))
        second_input = keras.Input(shape=(10, ))
        clustered_branch = cluster.cluster_weights(
            layers.Dense(10), **self.params)(first_input)
        plain_branch = layers.Dense(10)(second_input)
        summed = layers.Add()([clustered_branch, plain_branch])
        clustered_model = keras.Model(inputs=[first_input, second_input],
                                      outputs=summed)

        stripped_model = cluster.strip_clustering(clustered_model)

        self.assertEqual(self._count_clustered_layers(stripped_model), 0)
        # Layer 2 must be a plain Dense layer after stripping.
        self.assertIsInstance(stripped_model.layers[2], layers.Dense)
Exemplo n.º 16
0
  def testStripSelectivelyClusteredSequentialModel(self):
    """Verifies strip_clustering() on a selectively clustered Sequential model.

    Only the first Dense layer is wrapped; after stripping, no clustering
    wrapper may remain and the first layer must be a plain Dense layer.
    """
    wrapped_dense = cluster.cluster_weights(layers.Dense(10), **self.params)
    clustered_model = keras.Sequential([wrapped_dense, layers.Dense(10)])
    clustered_model.build(input_shape=(1, 10))

    stripped_model = cluster.strip_clustering(clustered_model)

    self.assertEqual(self._count_clustered_layers(stripped_model), 0)
    self.assertIsInstance(stripped_model.layers[0], layers.Dense)
Exemplo n.º 17
0
  def testStripClusteringSequentialModel(self):
    """Verifies that stripping a clustered Sequential model restores its config.

    After stripping, no clustering wrapper may remain and the model config
    must equal the config of the unclustered model.
    """
    original = keras.Sequential([layers.Dense(10), layers.Dense(10)])

    stripped = cluster.strip_clustering(
        cluster.cluster_weights(original, **self.params))

    self.assertEqual(self._count_clustered_layers(stripped), 0)
    self.assertEqual(original.get_config(), stripped.get_config())
Exemplo n.º 18
0
    def testClusterStrippingFunctionalModel(self):
        """Stripping a fully clustered functional model restores its config."""
        inputs = [keras.Input(shape=(10, )), keras.Input(shape=(10, ))]
        branches = [layers.Dense(10)(tensor) for tensor in inputs]
        model = keras.Model(inputs=inputs, outputs=layers.Add()(branches))

        stripped_model = cluster.strip_clustering(
            cluster.cluster_weights(model, **self.params))

        # No wrappers left, and the config matches the unclustered model.
        self.assertEqual(self._count_clustered_layers(stripped_model), 0)
        self.assertEqual(model.get_config(), stripped_model.get_config())
Exemplo n.º 19
0
  def testClusterWeightsStrippedWeights(self):
    """Verifies that stripping a functional model keeps its weight count.

    The model contains only a BatchNormalization layer. After stripping, no
    clustering wrapper may remain, and the stripped model must expose as many
    weights as the clustered one.
    """
    inputs = keras.Input(shape=(10,))
    model = keras.Model(inputs=[inputs],
                        outputs=layers.BatchNormalization()(inputs))

    clustered_model = cluster.cluster_weights(model, **self.params)
    expected_count = len(clustered_model.get_weights())
    stripped_model = cluster.strip_clustering(clustered_model)

    self.assertEqual(self._count_clustered_layers(stripped_model), 0)
    self.assertLen(stripped_model.get_weights(), expected_count)
Exemplo n.º 20
0
    def testClusterWeightsStrippedWeights(self):
        """Stripping a BatchNormalization-only model keeps its weight count."""
        inputs = keras.Input(shape=(10, ))
        model = keras.Model(inputs=[inputs],
                            outputs=layers.BatchNormalization()(inputs))

        clustered_model = cluster.cluster_weights(model, **self.params)
        expected_count = len(clustered_model.get_weights())
        stripped_model = cluster.strip_clustering(clustered_model)

        self.assertEqual(self._count_clustered_layers(stripped_model), 0)
        self.assertEqual(len(stripped_model.get_weights()), expected_count)
Exemplo n.º 21
0
 def testStripClusteringSequentialModelWithBiasRegularizer(self):
     """Verifies that stripping keeps a layer's bias regularizer.

     The test clusters and strips a Sequential model whose second Dense layer
     has an L1 bias regularizer. It checks that the regularizer (with the
     same config) survives stripping, and that the stripped model can be
     saved with ``save_traces=True``.
     """
     model = keras.Sequential([
         layers.Dense(10, input_shape=(10, )),
         layers.Dense(10, bias_regularizer=tf.keras.regularizers.L1(0.01)),
     ])
     clustered_model = cluster.cluster_weights(model, **self.params)
     stripped_model = cluster.strip_clustering(clustered_model)
     # Check that the bias regularizer is still present in the second dense
     # layer and is configured like the original one.
     stripped_regularizer = stripped_model.layers[1].bias_regularizer
     self.assertIsNotNone(stripped_regularizer)
     self.assertEqual(stripped_regularizer.get_config(),
                      model.layers[1].bias_regularizer.get_config())
     # Saving must succeed; any exception raised by save() fails the test.
     with tempfile.TemporaryDirectory() as tmp_dir_name:
         keras_file = os.path.join(tmp_dir_name, 'cluster_test')
         stripped_model.save(keras_file, save_traces=True)
Exemplo n.º 22
0
    def testStripClusteringAndSetOriginalWeightsBack(self):
        """Verifies that we can set_weights onto the stripped model.

        The stripped model must accept the original (pre-clustering) weights.
        After ``set_weights``, it must report exactly those weights back.
        """
        import numpy as np  # Local import keeps this test self-contained.

        model = keras.Sequential([
            layers.Dense(10, input_shape=(5, )),
            layers.Dense(10),
        ])

        # Save original weights
        original_weights = model.get_weights()

        # Cluster and strip
        clustered_model = cluster.cluster_weights(model, **self.params)
        stripped_model = cluster.strip_clustering(clustered_model)

        # Set back original weights onto the strip model
        stripped_model.set_weights(original_weights)

        # The weights must round-trip exactly.
        restored_weights = stripped_model.get_weights()
        self.assertEqual(len(restored_weights), len(original_weights))
        for restored, original in zip(restored_weights, original_weights):
            np.testing.assert_array_equal(restored, original)
Exemplo n.º 23
0
def main(unused_args):
    """Train a baseline MNIST model, then apply clustering and CQAT.

    The clustered and CQAT models are compared by FP32 accuracy and by
    number of unique weights. The stripped CQAT model is then converted to
    TFLite, evaluated with the TFLite interpreter, and written to
    ``cqat.tflite``.
    """
    (train_images, train_labels), (test_images, test_labels) = (
        tf.keras.datasets.mnist.load_data())
    # Scale pixel values into the [0, 1] range.
    train_images, test_images = train_images / 255.0, test_images / 255.0

    # Baseline model first, then clustering with retraining.
    baseline = setup_model((28, 28), train_images, train_labels)
    clustered = cluster_model(baseline, train_images, train_labels)
    print('Apply clustering:')
    clustered_accuracy = evaluate_model_fp32(clustered, test_images,
                                             test_labels)
    clustered_stripped = tfmot_cluster.strip_clustering(clustered)

    print('Apply cluster-preserve quantization aware training (cqat):')
    # CQAT starts from the stripped, already-clustered model.
    cqat = cluster_preserve_quantize_model(clustered_stripped, train_images,
                                           train_labels)
    cqat_accuracy = evaluate_model_fp32(cqat, test_images, test_labels)
    # Removes only the extra clustering variables; the quantize wrapper stays.
    cqat_stripped = strip_clustering_cqat(cqat)

    # FP32 accuracy and unique-weight counts: clustering vs. CQAT.
    comparisons = (
        ('FP32 accuracy of clustered model:', clustered_accuracy,
         clustered_stripped),
        ('FP32 accuracy of cqat model:', cqat_accuracy, cqat_stripped),
    )
    for label, accuracy, stripped in comparisons:
        print(label, accuracy)
        print_unique_weights(stripped)

    # Check that accuracy is consistent between TF and TFLite.
    converter = tf.lite.TFLiteConverter.from_keras_model(cqat_stripped)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_bytes = converter.convert()
    interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
    interpreter.allocate_tensors()
    tflite_accuracy = evaluate_model(interpreter, test_images, test_labels)

    with open('cqat.tflite', 'wb') as out_file:
        out_file.write(tflite_bytes)

    print('CQAT TFLite test_accuracy:', tflite_accuracy)
    print('CQAT TF test accuracy:', cqat_accuracy)
Exemplo n.º 24
0
  def testStrippedKernel(self):
    """Checks the Conv2D kernel after stripping.

    After stripping, the Conv2D layer's ``kernel`` must be a different object
    from the kernel inside the clustering wrapper. It must also equal the
    first entry of the layer's ``weights`` list.
    """
    inputs = keras.Input(shape=(1, 1, 1))
    model = keras.Model(inputs=[inputs], outputs=layers.Conv2D(1, 1)(inputs))

    clustered_model = cluster.cluster_weights(model, **self.params)
    # layers[1] is the clustering wrapper around the Conv2D (it has `.layer`).
    wrapped_kernel = clustered_model.layers[1].layer.kernel
    stripped_model = cluster.strip_clustering(clustered_model)
    stripped_conv2d = stripped_model.layers[1]

    self.assertEqual(self._count_clustered_layers(stripped_model), 0)
    self.assertIsNot(stripped_conv2d.kernel, wrapped_kernel)
    self.assertEqual(stripped_conv2d.kernel, stripped_conv2d.weights[0])
Exemplo n.º 25
0
  def testValuesRemainClusteredAfterTraining(self):
    """Verifies that training a clustered model does not destroy the clusters.

    After one training step and stripping, the first kernel must hold at most
    ``number_of_clusters`` unique values.
    """
    number_of_clusters = 10
    original_model = keras.Sequential([
        layers.Dense(2, input_shape=(2,)),
        layers.Dense(2),
    ])

    clustered_model = cluster.cluster_weights(
        original_model,
        number_of_clusters=number_of_clusters,
        cluster_centroids_init='linear'
    )

    clustered_model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy']
    )

    def dataset_generator():
      x_train = np.array([
          [0, 1],
          [2, 0],
          [0, 3],
          [4, 1],
          [5, 1],
      ])
      y_train = np.array([
          [0, 1],
          [1, 0],
          [1, 0],
          [0, 1],
          [0, 1],
      ])
      # One sample per batch: yields (x, y) pairs with a leading batch dim.
      for x, y in zip(x_train, y_train):
        yield np.array([x]), np.array([y])

    # `fit_generator` is deprecated (and removed in newer Keras); `fit`
    # accepts Python generators directly.
    clustered_model.fit(x=dataset_generator(), steps_per_epoch=1)
    stripped_model = cluster.strip_clustering(clustered_model)
    weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
    unique_weights = set(weights_as_list)
    self.assertLessEqual(len(unique_weights), number_of_clusters)
Exemplo n.º 26
0
  def testMHA(self):
    """Clusters a model and checks the kernels of the layer at index 1.

    The layer at index 1 is, going by the test name, the multi-head attention
    layer. After one epoch of fine-tuning and stripping, each of its weights
    whose name contains 'kernel' must hold exactly ``self.nr_of_clusters``
    unique values.
    """
    model = self._get_model()

    clustered_model = cluster.cluster_weights(model, **self.params_clustering)

    clustered_model.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])
    clustered_model.fit(
        self.x_train, self.y_train, epochs=1, batch_size=100, verbose=1)

    stripped_model = cluster.strip_clustering(clustered_model)

    mha_layer = stripped_model.layers[1]
    kernels = [w for w in mha_layer.weights if 'kernel' in w.name]
    # Guard against a vacuous pass when no kernel weights are found.
    self.assertGreater(len(kernels), 0)
    for kernel in kernels:
      # Use unittest assertions: bare `assert` is stripped under `python -O`.
      self.assertEqual(len(np.unique(kernel.numpy())), self.nr_of_clusters,
                       msg=kernel.name)
    def testEndToEndClusterPreserveOneLayer(self):
        """Runs CQAT end to end and model is quantized only for a single layer.

        Only the Dense layer named 'qat' is annotated for quantization. The
        number of unique kernel weights at layer index 1 must be the same
        after clustering and after CQAT.
        """
        original_model = tf.keras.Sequential([
            layers.Dense(5, activation='relu', input_shape=(10, )),
            layers.Dense(5,
                         activation='softmax',
                         input_shape=(10, ),
                         name='qat')
        ])
        model = cluster.cluster_weights(original_model, **self.cluster_params)
        self.compile_and_fit(model)
        model = cluster.strip_clustering(model)
        unique_after_clustering = self._get_number_of_unique_weights(
            model, 1, 'kernel')

        def annotate_qat_dense(layer):
            # Annotate only the Dense layer named 'qat'; keep the rest as is.
            is_qat_dense = (isinstance(layer, tf.keras.layers.Dense)
                            and layer.name == 'qat')
            return (quantize.quantize_annotate_layer(layer)
                    if is_qat_dense else layer)

        annotated = tf.keras.models.clone_model(
            model,
            clone_function=annotate_qat_dense,
        )

        cqat_scheme = (default_8bit_cluster_preserve_quantize_scheme.
                       Default8BitClusterPreserveQuantizeScheme())
        qat_model = quantize.quantize_apply(annotated, scheme=cqat_scheme)

        self.compile_and_fit(qat_model)
        stripped_cqat = strip_clustering_cqat(qat_model)

        # Compare unique kernel weights at layer index 1 of the clustered
        # model and of the CQAT model.
        unique_after_cqat = self._get_number_of_unique_weights(
            stripped_cqat, 1, 'kernel')
        self.assertAllEqual(unique_after_clustering, unique_after_cqat)
    def testValuesAreClusteredAfterStripping(self, number_of_clusters,
                                             cluster_centroids_init):
        """Checks the unique-value bound after stripping.

        For any number of clusters and any initialization method, the
        stripped kernel must contain no more than ``number_of_clusters``
        unique values, and the first layer must be a plain Dense layer.
        """
        dense_model = tf.keras.Sequential([
            layers.Dense(32, input_shape=(10, )),
        ])
        stripped_model = cluster.strip_clustering(
            cluster.cluster_weights(
                dense_model,
                number_of_clusters=number_of_clusters,
                cluster_centroids_init=cluster_centroids_init))

        flat_kernel = stripped_model.get_weights()[0].reshape(-1, ).tolist()
        self.assertLessEqual(len(set(flat_kernel)), number_of_clusters)

        self.assertIsInstance(stripped_model.layers[0], layers.Dense)
  def testValuesRemainClusteredAfterTraining(self):
    """Verifies that training a clustered model does not destroy the clusters.

    After one training step and stripping, the first kernel must hold at most
    ``self.params["number_of_clusters"]`` unique values.
    """
    model = keras.Sequential([
        layers.Dense(2, input_shape=(2,)),
        layers.Dense(2),
    ])
    model = cluster.cluster_weights(model, **self.params)
    model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer="adam",
        metrics=["accuracy"],
    )
    model.fit(x=self.dataset_generator(), steps_per_epoch=1)

    kernel = cluster.strip_clustering(model).get_weights()[0]
    distinct_values = set(kernel.reshape(-1,).tolist())
    self.assertLessEqual(len(distinct_values),
                         self.params["number_of_clusters"])
def _cluster_model(original_model, sparsity_flag):
    """Apply the clustering wrapper, compile and train the model.

    Args:
      original_model: Keras model to wrap with the clustering API.
      sparsity_flag: forwarded as ``preserve_sparsity`` to clustering.

    Returns:
      A ``(cluster_model, stripped_model)`` tuple: the model returned by
      ``_train_model`` and its counterpart with clustering wrappers removed.
    """
    params = dict(
        number_of_clusters=8,
        cluster_centroids_init=(
            tfmot_cluster_config.CentroidInitialization.DENSITY_BASED),
        preserve_sparsity=sparsity_flag,
    )
    wrapped = exp_tfmot_cluster.cluster_weights(original_model, **params)

    # No callbacks, a single clustering epoch.
    trained = _train_model(wrapped, [], 1)

    return trained, tfmot_cluster.strip_clustering(trained)