def testConvolutionalWeightsCA(
    self, clustering_centroids, pulling_indices, expected_output
):
    """Checks the values ConvolutionalWeightsCA pulls for the given indices.

    The centroids are wrapped in a float32 `tf.Variable`, handed to
    `ConvolutionalWeightsCA` together with `GradientAggregation.SUM`, and the
    resulting algorithm is passed to `self._check_pull_values` along with
    `pulling_indices` and `expected_output`.
    """
    # NOTE(review): a second `testConvolutionalWeightsCA` is defined further
    # down in this file; if both live in the same class the later one
    # replaces this one -- confirm which variant is intended.
    centroids_var = tf.Variable(clustering_centroids, dtype=tf.float32)
    algorithm = clustering_registry.ConvolutionalWeightsCA(
        centroids_var, GradientAggregation.SUM
    )
    self._check_pull_values(algorithm, pulling_indices, expected_output)
 def testConvolutionalWeightsCA(self, clustering_centroids, pulling_indices,
                                expected_output):
     """
 Verifies that ConvolutionalWeightsCA works as expected.
 """
     ca = clustering_registry.ConvolutionalWeightsCA(clustering_centroids)
     self._pull_values(ca, pulling_indices, expected_output)
  def testConvolutionalWeightsCAGrad(
      self,
      cluster_gradient_aggregation,
      pulling_indices,
      expected_grad_centroids,
  ):
    """Tests that the gradients of ConvolutionalWeightsCA work as expected."""
    clustering_centroids = tf.Variable([0.0, 3.0], dtype=tf.float32)
    weight = tf.constant([[0.1, 0.1, 0.1], [3.0, 3.0, 3.0], [0.2, 0.2, 0.2]])

    clustering_algo = clustering_registry.ConvolutionalWeightsCA(
        clustering_centroids, cluster_gradient_aggregation)
    self._check_gradients_clustered_weight(
        clustering_algo,
        weight,
        pulling_indices,
        expected_grad_centroids,
    )