Пример #1
0
    def __call__(self, inputs):
        """Applies spherical pooling.

        Each output cell is the quadrature-weighted average of the
        `stride x stride` block of input cells that it covers.

        Args:
          inputs: An array of dimensions (batch_size, resolution, resolution,
            n_spins_in, n_channels_in).

        Returns:
          An array of dimensions (batch_size, resolution // stride,
          resolution // stride, n_spins_in, n_channels_in).
        """
        # This method does not cache the quadrature weights itself; they are
        # recomputed from the (static) input shape on every call.
        resolution_in = inputs.shape[1]
        resolution_out = resolution_in // self.stride
        weights_in = sphere_utils.sphere_quadrature_weights(resolution_in)
        weights_out = sphere_utils.sphere_quadrature_weights(resolution_out)

        # expand_dims with axes (0, 2, 3, 4) places the (presumably 1-D)
        # weights on axis 1 of `inputs`, broadcasting them over the batch,
        # second spatial, spin and channel axes.
        weighted = inputs * jnp.expand_dims(weights_in, (0, 2, 3, 4))
        pooled = nn.avg_pool(weighted,
                             window_shape=(self.stride, self.stride, 1),
                             strides=(self.stride, self.stride, 1))
        # avg_pool divides by the window size (stride**2). Multiplying it back
        # yields the weighted sum over each window; dividing by the output
        # weights then turns that sum into a weighted average.
        pooled = (pooled * self.stride**2 /
                  jnp.expand_dims(weights_out, (0, 2, 3, 4)))

        return pooled
 def test_sphere_quadrature_weights_2x2(self):
   """On a 2x2 grid, each of the four cells has area pi (a quarter of 4*pi)."""
   areas = sphere_utils.sphere_quadrature_weights(2)
   expected = np.pi * np.ones_like(areas)
   self.assertAllClose(areas, expected)
 def test_spherical_cell_area_sum(self, resolution):
   """The cell areas over the whole grid must add up to the sphere's area."""
   # NOTE(review): `resolution` is presumably supplied by a parameterization
   # decorator that is not visible here -- confirm.
   areas = sphere_utils.sphere_quadrature_weights(resolution)
   # Presumably one area per grid row; each row holds `resolution` cells.
   total_area = resolution * areas.sum()
   self.assertAllClose(total_area, 4 * np.pi)