def testDenseWeightsCA(self, clustering_centroids, pulling_indices,
                       expected_output):
  """Builds a DenseWeightsCA and checks it with `self._pull_values`.

  Args:
    clustering_centroids: Passed unchanged to `DenseWeightsCA`.
    pulling_indices: Forwarded to `self._pull_values`.
    expected_output: Forwarded to `self._pull_values`.
  """
  # NOTE(review): `testDenseWeightsCA` is defined several times in this file.
  # If the copies share one class body, Python keeps only the last one, so
  # this version would never run -- confirm and drop the stale copies.
  clustering_algo = clustering_registry.DenseWeightsCA(clustering_centroids)
  self._pull_values(clustering_algo, pulling_indices, expected_output)
 def testDenseWeightsCAGrad(self, clustering_centroids, weight,
                            pulling_indices, expected_output):
   """Builds a DenseWeightsCA and checks it with `self._check_gradients`.

   Args:
     clustering_centroids: Passed unchanged to `DenseWeightsCA`.
     weight: Forwarded to `self._check_gradients`.
     pulling_indices: Forwarded to `self._check_gradients`.
     expected_output: Forwarded to `self._check_gradients`.
   """
   # NOTE(review): `testDenseWeightsCAGrad` is defined more than once in this
   # file. If the copies share one class body, only the last definition is
   # kept, so this version would never run -- confirm and remove duplicates.
   grad_algo = clustering_registry.DenseWeightsCA(clustering_centroids)
   self._check_gradients(
       grad_algo, weight, pulling_indices, expected_output)
 def testDenseWeightsCA(self, clustering_centroids, pulling_indices,
                        expected_output):
   """Checks DenseWeightsCA built with SUM gradient aggregation.

   The centroids are wrapped in a float32 `tf.Variable` before being handed
   to `DenseWeightsCA` together with `GradientAggregation.SUM`; the result is
   verified with `self._check_pull_values`.

   Args:
     clustering_centroids: Initial centroid values for the `tf.Variable`.
     pulling_indices: Forwarded to `self._check_pull_values`.
     expected_output: Forwarded to `self._check_pull_values`.
   """
   # NOTE(review): another `testDenseWeightsCA` appears later in this file;
   # if both are in the same class, that later one replaces this one.
   centroids_var = tf.Variable(clustering_centroids, dtype=tf.float32)
   sum_algo = clustering_registry.DenseWeightsCA(
       centroids_var, GradientAggregation.SUM)
   self._check_pull_values(sum_algo, pulling_indices, expected_output)
  def testDenseWeightsCAGrad(
      self,
      cluster_gradient_aggregation,
      pulling_indices,
      expected_grad_centroids,
  ):
    """Checks DenseWeightsCA gradients on a fixed 2-centroid, 4x4 setup.

    Args:
      cluster_gradient_aggregation: Aggregation mode passed as the second
        argument of `DenseWeightsCA`.
      pulling_indices: Forwarded to `self._check_gradients_clustered_weight`.
      expected_grad_centroids: Forwarded to
        `self._check_gradients_clustered_weight`.
    """
    # Fixed fixture: two centroids and a 4x4 weight tensor.
    centroids = tf.Variable([-0.800450444, 0.864694357])
    dense_weight = tf.constant([
        [0.220442653, 0.854694366, 0.0328432359, 0.506857157],
        [0.0527950861, -0.659555554, -0.849919915, -0.54047],
        [-0.305815876, 0.0865516588, 0.659202456, -0.355699599],
        [-0.348868281, -0.662001, 0.6171574, -0.296582848],
    ])

    grad_algo = clustering_registry.DenseWeightsCA(
        centroids, cluster_gradient_aggregation)
    self._check_gradients_clustered_weight(
        grad_algo, dense_weight, pulling_indices, expected_grad_centroids)
 def testDenseWeightsCA(self, clustering_centroids, pulling_indices,
                        expected_output):
   """Builds a DenseWeightsCA and checks it with `self._pull_values`.

   Args:
     clustering_centroids: Passed unchanged to `DenseWeightsCA`.
     pulling_indices: Forwarded to `self._pull_values`.
     expected_output: Forwarded to `self._pull_values`.
   """
   # NOTE(review): this is the last `testDenseWeightsCA` in the file. If all
   # copies share one class body, it replaces the earlier variant that wraps
   # the centroids in a tf.Variable and passes GradientAggregation.SUM --
   # confirm which version is intended.
   pull_algo = clustering_registry.DenseWeightsCA(clustering_centroids)
   self._pull_values(pull_algo, pulling_indices, expected_output)