def testCanBeInitializedWithAlreadyClusterableLayer(self):
    """Wrapping an `AlreadyClusterableLayer` yields a `ClusterWeights` object."""
    wrapped = cluster_wrapper.ClusterWeights(
        AlreadyClusterableLayer(10),
        number_of_clusters=13,
        cluster_centroids_init='linear',
    )
    self.assertIsInstance(wrapped, cluster_wrapper.ClusterWeights)
def testIfLayerHasBatchShapeClusterWeightsMustHaveIt(self):
    """The wrapper must expose `_batch_input_shape` for a layer given `input_shape`."""
    wrapper = cluster_wrapper.ClusterWeights(
        layers.Dense(10, input_shape=(10,)),
        number_of_clusters=13,
        cluster_centroids_init='linear',
    )
    has_batch_shape = hasattr(wrapper, '_batch_input_shape')
    self.assertTrue(has_batch_shape)
def testCannotBeInitializedWithNonClusterableLayer(self):
    """Wrapping a `NonClusterableLayer` instance must raise `ValueError`."""
    wrapper_kwargs = {
        'number_of_clusters': 13,
        'cluster_centroids_init': 'linear',
    }
    with self.assertRaises(ValueError):
        # The layer is built inside the context manager, as in the original
        # test, so an error from either constructor is what gets checked.
        cluster_wrapper.ClusterWeights(NonClusterableLayer(10), **wrapper_kwargs)
def testCannotBeInitializedWithFloatNumberOfClusters(self):
    """A float `number_of_clusters` (13.4) must be rejected with `ValueError`."""
    wrapper_kwargs = {
        'number_of_clusters': 13.4,
        'cluster_centroids_init': 'linear',
    }
    with self.assertRaises(ValueError):
        cluster_wrapper.ClusterWeights(layers.Dense(10), **wrapper_kwargs)
def testKerasCustomLayerClusterable(self):
    """Verifies that a clusterable Keras custom layer can be wrapped."""
    wrapped = cluster_wrapper.ClusterWeights(
        KerasCustomLayerClusterable(), **self.params)
    self.assertIsInstance(wrapped, cluster_wrapper.ClusterWeights)
def testCannotBeInitializedWithNonLayerObject(self):
    """Passing a plain dict instead of a layer must raise `ValueError`."""
    not_a_layer = {'this': 'is not a Layer instance'}
    with self.assertRaises(ValueError):
        cluster_wrapper.ClusterWeights(
            not_a_layer,
            number_of_clusters=13,
            cluster_centroids_init='linear',
        )
def testCannotBeInitializedWithNonIntegerNumberOfClusters(self):
    """Verifies that a string `number_of_clusters` is rejected.

    `ClusterWeights` is constructed with `number_of_clusters='13'` and no
    other keyword arguments; a `ValueError` is expected.
    """
    string_cluster_count = '13'
    with self.assertRaises(ValueError):
        cluster_wrapper.ClusterWeights(
            layers.Dense(10), number_of_clusters=string_cluster_count)
def testAssociationValuesPerReplica(self, distribution):
    """Verifies that associations of weights are updated per replica.

    The centroids are moved twice through `distribution.extended.update`;
    after each move the layer is called and the pulling indices of every
    local replica are checked to be uniformly 0, then uniformly 1.

    Args:
      distribution: object providing `scope()`, `extended.update(...)` and
        `experimental_local_results(...)` (presumably a `tf.distribute`
        strategy supplied by a parameterized decorator -- confirm at the
        definition site).
    """
    # unittest assertions instead of bare `assert`, which is removed under -O.
    self.assertIsNotNone(tf.distribute.get_replica_context())
    with distribution.scope():
        # The test expects no replica context while inside the scope.
        self.assertIsNone(tf.distribute.get_replica_context())

        input_shape = (1, 2)
        output_shape = (2, 8)
        l = cluster_wrapper.ClusterWeights(
            keras.layers.Dense(8, input_shape=input_shape),
            number_of_clusters=self.params["number_of_clusters"],
            cluster_centroids_init=self.params["cluster_centroids_init"],
        )
        l.build(input_shape)

        clusterable_weights = l.layer.get_clusterable_weights()
        self.assertEqual(len(clusterable_weights), 1)
        weights_name = clusterable_weights[0][0]
        self.assertEqual(weights_name, 'kernel')
        centroids = l.cluster_centroids_tf[weights_name]

        # Weight statistics used to place centroids far enough apart that
        # every weight is closest to one chosen centroid.
        mean_weight = tf.reduce_mean(l.layer.kernel)
        min_weight = tf.reduce_min(l.layer.kernel)
        max_weight = tf.reduce_max(l.layer.kernel)
        max_dist = max_weight - min_weight

        def assert_all_cluster_indices(per_replica, indices_val):
            """Asserts each replica's indices all equal `indices_val` (0 or 1).

            Raises:
              ValueError: if `indices_val` is neither 0 nor 1.
            """
            index_dtype = per_replica[0].dtype
            if indices_val == 1:
                val_tensor = tf.dtypes.cast(
                    tf.ones(shape=output_shape), index_dtype)
            elif indices_val == 0:
                val_tensor = tf.dtypes.cast(
                    tf.zeros(shape=output_shape), index_dtype)
            else:
                # Previously this fell through with `val_tensor` unbound.
                raise ValueError(
                    'indices_val must be 0 or 1, got %r' % (indices_val,))
            for replica_indices in per_replica:
                all_equal = tf.reduce_all(
                    tf.equal(replica_indices, val_tensor))
                self.assertTrue(all_equal)

        def update_fn(v, val):
            return v.assign(val)

        # Centroid 0 at the mean, centroid 1 two weight-ranges above it:
        # all weights should then map to index 0.
        initial_val = tf.Variable(
            [mean_weight, mean_weight + 2.0 * max_dist],
            aggregation=tf.VariableAggregation.MEAN)
        distribution.extended.update(
            centroids, update_fn, args=(initial_val,))
        l.call(tf.ones(shape=input_shape))

        clst_indices = l.pulling_indices_tf[weights_name]
        per_replica = distribution.experimental_local_results(clst_indices)
        assert_all_cluster_indices(per_replica, 0)

        # Centroid 1 at the mean, centroid 0 two weight-ranges below it:
        # all weights should then map to index 1.
        second_val = tf.Variable(
            [mean_weight - 2.0 * max_dist, mean_weight],
            aggregation=tf.VariableAggregation.MEAN)
        # Re-read the centroid variable from the wrapper, as the original did.
        centroids = l.cluster_centroids_tf[weights_name]
        distribution.extended.update(
            centroids, update_fn, args=(second_val,))
        l.call(tf.ones(shape=input_shape))

        clst_indices = l.pulling_indices_tf[weights_name]
        per_replica = distribution.experimental_local_results(clst_indices)
        assert_all_cluster_indices(per_replica, 1)
def testClusterCustomNonClusterableLayer(self):
    """Wrapping `self.custom_non_clusterable_layer` must raise `ValueError`."""
    with self.assertRaises(ValueError):
        cluster_wrapper.ClusterWeights(
            self.custom_non_clusterable_layer,
            **self.params
        )
def testCanBeInitializedWithNonIntegerNumberOfClusters(self):
    """A string `number_of_clusters` ("13") must raise `ValueError`.

    NOTE: despite the "CanBe" in the method name, this test asserts that
    construction fails.
    """
    wrapper_kwargs = {
        'number_of_clusters': "13",
        'cluster_centroids_init': 'linear',
    }
    with self.assertRaises(ValueError):
        cluster_wrapper.ClusterWeights(layers.Dense(10), **wrapper_kwargs)
def testClusterReassociation(self):
    """Checks that weight-to-centroid associations are refreshed on every call.

    A two-cluster wrapper around a Dense layer is built, the centroids are
    moved so that all weights should belong to centroid 0 and then to
    centroid 1, and after each `call` the kernel is compared with the
    expected centroid.
    """
    dense_input_shape = (1, 2)
    wrapper = cluster_wrapper.ClusterWeights(
        keras.layers.Dense(8, input_shape=dense_input_shape),
        number_of_clusters=2,
        cluster_centroids_init=CentroidInitialization.LINEAR)
    wrapper.build(dense_input_shape)

    # Exactly one clusterable weight is expected, and it is the kernel.
    clusterable = wrapper.layer.get_clusterable_weights()
    self.assertEqual(len(clusterable), 1)
    kernel_name = clusterable[0][0]
    self.assertEqual(kernel_name, 'kernel')

    centroids = wrapper.cluster_centroids_tf[kernel_name]

    # Statistics of the freshly built kernel, used to position centroids.
    initial_kernel = wrapper.layer.kernel
    mean_weight = tf.reduce_mean(initial_kernel)
    min_weight = tf.reduce_min(initial_kernel)
    max_weight = tf.reduce_max(initial_kernel)
    max_dist = max_weight - min_weight
    far_offset = 2.0 * max_dist

    def expect_kernel_equals_centroid(centroid_index):
        """Asserts every kernel entry equals the centroid at `centroid_index`."""
        # Read the kernel at check time rather than reusing an earlier handle.
        weights = wrapper.layer.kernel
        expected = tf.constant(centroids[centroid_index], shape=weights.shape)
        self.assertTrue(tf.reduce_all(tf.equal(weights, expected)))

    # Phase 1: centroid 0 sits at the mean, centroid 1 is pushed far above,
    # so all weights should be re-associated with centroid 0.
    centroids[0].assign(mean_weight)
    centroids[1].assign(mean_weight + far_offset)
    wrapper.call(tf.ones(shape=dense_input_shape))
    expect_kernel_equals_centroid(0)

    # Phase 2: centroid 1 sits at the mean, centroid 0 is pushed far below,
    # so all weights should be re-associated with centroid 1.
    centroids[0].assign(mean_weight - far_offset)
    centroids[1].assign(mean_weight)
    wrapper.call(tf.ones(shape=dense_input_shape))
    expect_kernel_equals_centroid(1)