def testCanBeInitializedWithAlreadyClusterableLayer(self):
     """Verifies that ClusterWeights can wrap an AlreadyClusterableLayer."""
     clusterable = AlreadyClusterableLayer(10)
     wrapper = cluster_wrapper.ClusterWeights(
         clusterable,
         number_of_clusters=13,
         cluster_centroids_init='linear',
     )
     # Construction succeeded and produced a ClusterWeights wrapper.
     self.assertIsInstance(wrapper, cluster_wrapper.ClusterWeights)
 def testIfLayerHasBatchShapeClusterWeightsMustHaveIt(self):
     """Verifies that the wrapper exposes `_batch_input_shape`.

     The wrapped Dense layer is created with an explicit `input_shape`, and
     the resulting ClusterWeights object is expected to carry a
     `_batch_input_shape` attribute.
     """
     dense_with_shape = layers.Dense(10, input_shape=(10, ))
     wrapper = cluster_wrapper.ClusterWeights(
         dense_with_shape,
         number_of_clusters=13,
         cluster_centroids_init='linear',
     )
     has_batch_shape = hasattr(wrapper, '_batch_input_shape')
     self.assertTrue(has_batch_shape)
 def testCannotBeInitializedWithNonClusterableLayer(self):
     """Verifies that wrapping a NonClusterableLayer raises ValueError."""
     wrapper_kwargs = dict(number_of_clusters=13,
                           cluster_centroids_init='linear')
     with self.assertRaises(ValueError):
         # The layer is built inside the context, as in the usual call form.
         cluster_wrapper.ClusterWeights(NonClusterableLayer(10),
                                        **wrapper_kwargs)
 def testCannotBeInitializedWithFloatNumberOfClusters(self):
     """Verifies that a float `number_of_clusters` raises ValueError."""
     float_cluster_count = 13.4
     with self.assertRaises(ValueError):
         cluster_wrapper.ClusterWeights(
             layers.Dense(10),
             number_of_clusters=float_cluster_count,
             cluster_centroids_init='linear',
         )
示例#5
0
 def testKerasCustomLayerClusterable(self):
     """Verifies that a clusterable Keras custom layer can be wrapped."""
     wrapped_layer = cluster_wrapper.ClusterWeights(
         KerasCustomLayerClusterable(), **self.params)
     # The wrapper type itself is the only thing checked here.
     self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights)
 def testCannotBeInitializedWithNonLayerObject(self):
     """Verifies that wrapping a plain dict (not a layer) raises ValueError."""
     wrapper_kwargs = dict(number_of_clusters=13,
                           cluster_centroids_init='linear')
     with self.assertRaises(ValueError):
         cluster_wrapper.ClusterWeights({'this': 'is not a Layer instance'},
                                        **wrapper_kwargs)
示例#7
0
 def testCannotBeInitializedWithNonIntegerNumberOfClusters(self):
   """Verifies that a string `number_of_clusters` raises ValueError."""
   string_cluster_count = '13'
   with self.assertRaises(ValueError):
     # No centroid initialization is passed; only the cluster count is given.
     cluster_wrapper.ClusterWeights(layers.Dense(10),
                                    number_of_clusters=string_cluster_count)
  def testAssociationValuesPerReplica(self, distribution):
    """Verifies that weight-to-centroid associations are updated per replica.

    The test wraps a Dense layer inside `distribution.scope()`, overwrites the
    two centroids so that every kernel weight should map to one specific
    centroid, calls the layer, and checks the pulling indices returned for
    each local replica. This is done once for centroid 0 and once for
    centroid 1.

    Args:
      distribution: Object providing `scope()`, `extended.update()` and
        `experimental_local_results()` (presumably a tf.distribute strategy
        supplied by the test parameterization; confirm against the caller).
    """
    # A replica context is expected outside the scope and none inside it.
    assert tf.distribute.get_replica_context() is not None
    with distribution.scope():
      assert tf.distribute.get_replica_context() is None
      input_shape = (1, 2)
      output_shape = (2, 8)
      wrapper = cluster_wrapper.ClusterWeights(
          keras.layers.Dense(8, input_shape=input_shape),
          number_of_clusters=self.params["number_of_clusters"],
          cluster_centroids_init=self.params["cluster_centroids_init"],
      )
      wrapper.build(input_shape)

      # The wrapped layer is expected to expose exactly one clusterable
      # weight, named 'kernel'.
      clusterable_weights = wrapper.layer.get_clusterable_weights()
      self.assertEqual(len(clusterable_weights), 1)
      weights_name = clusterable_weights[0][0]
      self.assertEqual(weights_name, 'kernel')
      first_centroids = wrapper.cluster_centroids_tf[weights_name]

      # Kernel statistics used to position the centroids below.
      kernel = wrapper.layer.kernel
      mean_weight = tf.reduce_mean(kernel)
      min_weight = tf.reduce_min(kernel)
      max_weight = tf.reduce_max(kernel)
      max_dist = max_weight - min_weight

      def assert_all_cluster_indices(per_replica, indices_val):
        """Asserts every replica's indices all equal `indices_val` (0 or 1)."""
        if indices_val == 1:
          val_tensor = tf.dtypes.cast(
              tf.ones(shape=output_shape), per_replica[0].dtype)
        elif indices_val == 0:
          val_tensor = tf.dtypes.cast(
              tf.zeros(shape=output_shape), per_replica[0].dtype)
        for replica_indices in per_replica:
          all_equal = tf.reduce_all(tf.equal(replica_indices, val_tensor))
          self.assertTrue(all_equal)

      def update_fn(v, val):
        return v.assign(val)

      # Centroid 0 sits on the kernel mean; centroid 1 is two weight-ranges
      # above it, so index 0 is expected for every weight.
      initial_val = tf.Variable(
          [mean_weight, mean_weight + 2.0 * max_dist],
          aggregation=tf.VariableAggregation.MEAN)
      distribution.extended.update(
          first_centroids, update_fn, args=(initial_val,))
      wrapper.call(tf.ones(shape=input_shape))

      clst_indices = wrapper.pulling_indices_tf[weights_name]
      per_replica = distribution.experimental_local_results(clst_indices)
      assert_all_cluster_indices(per_replica, 0)

      # Now centroid 1 sits on the kernel mean and centroid 0 is two
      # weight-ranges below it, so index 1 is expected for every weight.
      second_val = tf.Variable(
          [mean_weight - 2.0 * max_dist, mean_weight],
          aggregation=tf.VariableAggregation.MEAN)
      second_centroids = wrapper.cluster_centroids_tf[weights_name]
      distribution.extended.update(
          second_centroids, update_fn, args=(second_val,))
      wrapper.call(tf.ones(shape=input_shape))

      clst_indices = wrapper.pulling_indices_tf[weights_name]
      per_replica = distribution.experimental_local_results(clst_indices)
      assert_all_cluster_indices(per_replica, 1)
 def testClusterCustomNonClusterableLayer(self):
     """Verifies that wrapping `self.custom_non_clusterable_layer` raises ValueError."""
     with self.assertRaises(ValueError):
         cluster_wrapper.ClusterWeights(
             self.custom_non_clusterable_layer, **self.params)
 def testCanBeInitializedWithNonIntegerNumberOfClusters(self):
     """Verifies that a string `number_of_clusters` raises ValueError.

     NOTE: despite the "CanBe" in the method name, the assertion below
     expects construction to fail.
     """
     string_cluster_count = "13"
     with self.assertRaises(ValueError):
         cluster_wrapper.ClusterWeights(
             layers.Dense(10),
             number_of_clusters=string_cluster_count,
             cluster_centroids_init='linear',
         )
    def testClusterReassociation(self):
        """Verifies that weights are re-associated with centroids on each call.

        The two centroids are moved twice; after each move the layer is
        called and the kernel is expected to consist solely of the value of
        the centroid that was placed on the kernel mean.
        """
        # Dummy two-cluster Dense wrapper used only by this test.
        input_shape = (1, 2)
        wrapper = cluster_wrapper.ClusterWeights(
            keras.layers.Dense(8, input_shape=input_shape),
            number_of_clusters=2,
            cluster_centroids_init=CentroidInitialization.LINEAR)
        wrapper.build(input_shape)

        # The wrapped layer is expected to expose exactly one clusterable
        # weight, named 'kernel'.
        clusterable_weights = wrapper.layer.get_clusterable_weights()
        self.assertEqual(len(clusterable_weights), 1)
        weights_name = clusterable_weights[0][0]
        self.assertEqual(weights_name, 'kernel')
        centroids = wrapper.cluster_centroids_tf[weights_name]

        # Kernel statistics used to position the centroids below.
        mean_weight = tf.reduce_mean(wrapper.layer.kernel)
        min_weight = tf.reduce_min(wrapper.layer.kernel)
        max_weight = tf.reduce_max(wrapper.layer.kernel)
        max_dist = max_weight - min_weight

        def assert_all_weights_associated(weights, centroid_index):
            """Asserts that every entry of `weights` equals the given centroid."""
            expected = tf.constant(centroids[centroid_index],
                                   shape=weights.shape)
            all_associated = tf.reduce_all(tf.equal(weights, expected))
            self.assertTrue(all_associated)

        # Centroid 0 on the kernel mean, centroid 1 two weight-ranges above:
        # every weight should be pulled to centroid 0.
        centroids[0].assign(mean_weight)
        centroids[1].assign(mean_weight + 2.0 * max_dist)

        # Calling the layer refreshes the weight-to-centroid associations.
        wrapper.call(tf.ones(shape=input_shape))
        assert_all_weights_associated(wrapper.layer.kernel, centroid_index=0)

        # Centroid 1 on the kernel mean, centroid 0 two weight-ranges below:
        # every weight should be pulled to centroid 1.
        centroids[0].assign(mean_weight - 2.0 * max_dist)
        centroids[1].assign(mean_weight)

        # Calling the layer again refreshes the associations once more.
        wrapper.call(tf.ones(shape=input_shape))
        assert_all_weights_associated(wrapper.layer.kernel, centroid_index=1)