def testConvolutionalWeightsCA(self, clustering_centroids, pulling_indices,
                               expected_output):
  """Checks value pulling of ConvolutionalWeightsCA under SUM aggregation.

  The given centroids are stored in a float32 `tf.Variable` before the
  clustering algorithm is built. The pull-value check itself is delegated to
  the shared `_check_pull_values` helper.
  """
  centroids_var = tf.Variable(clustering_centroids, dtype=tf.float32)
  algorithm = clustering_registry.ConvolutionalWeightsCA(
      centroids_var,
      GradientAggregation.SUM,
  )
  self._check_pull_values(algorithm, pulling_indices, expected_output)
def testConvolutionalWeightsCA(self, clustering_centroids, pulling_indices,
                               expected_output):
  """Verifies that ConvolutionalWeightsCA pulls the expected values.

  The centroids are stored in a float32 `tf.Variable`, and SUM gradient
  aggregation is passed explicitly. This matches how the other
  ConvolutionalWeightsCA tests in this file construct the algorithm and which
  `_check_*` helpers they call.

  NOTE(review): another method named `testConvolutionalWeightsCA` is defined
  earlier in this file. If both live in the same class, this later definition
  replaces the earlier one, and only one of them runs. The two are kept
  behaviorally identical so the effective test does not depend on which one
  wins. TODO: confirm and delete the duplicate.
  """
  centroids = tf.Variable(clustering_centroids, dtype=tf.float32)
  clustering_algo = clustering_registry.ConvolutionalWeightsCA(
      centroids, GradientAggregation.SUM)
  self._check_pull_values(clustering_algo, pulling_indices, expected_output)
def testConvolutionalWeightsCAGrad(
    self,
    cluster_gradient_aggregation,
    pulling_indices,
    expected_grad_centroids,
):
  """Checks the centroid gradients produced by ConvolutionalWeightsCA.

  The fixture uses two centroids (0.0 and 3.0) and a fixed 3x3 weight
  matrix. The gradient-aggregation mode and the expected centroid gradients
  are supplied by the caller. The comparison is delegated to the shared
  `_check_gradients_clustered_weight` helper.
  """
  weight_rows = [
      [0.1, 0.1, 0.1],
      [3.0, 3.0, 3.0],
      [0.2, 0.2, 0.2],
  ]
  centroids_var = tf.Variable([0.0, 3.0], dtype=tf.float32)
  weight_tensor = tf.constant(weight_rows)
  algorithm = clustering_registry.ConvolutionalWeightsCA(
      centroids_var,
      cluster_gradient_aggregation,
  )
  self._check_gradients_clustered_weight(
      algorithm,
      weight_tensor,
      pulling_indices,
      expected_grad_centroids,
  )