Example #1
0
  def __call__(self, inputs, train):
    """Optionally pools, then convolves and batch-normalizes `inputs`.

    When `self.downsampling_factor` is not 1, the input is first downsampled
    with `layers.SphericalPooling`. The result then goes through
    `layers.SpinSphericalConvolution` and finally through
    `layers.SpinSphericalBatchNormalizationNonlinearity`.

    Args:
      inputs: (batch_size, resolution, resolution, n_spins_in, n_channels_in)
        array of spin-weighted spherical functions with equiangular sampling.
      train: whether to run in training or inference mode. In inference mode
        the batch normalization is told to use its running statistics.

    Returns:
      A (batch_size, resolution // downsampling_factor, resolution //
        downsampling_factor, n_spins_out, num_channels) complex64 array.
    """
    x = inputs
    if self.downsampling_factor != 1:
      pool = layers.SphericalPooling(stride=self.downsampling_factor,
                                     name='spherical_pool')
      x = pool(x)

    conv = layers.SpinSphericalConvolution(
        features=self.num_channels,
        spins_in=self.spins_in,
        spins_out=self.spins_out,
        num_filter_params=self.num_filter_params,
        transformer=self.transformer,
        name='spherical_conv')
    x = conv(x)

    # Running statistics are only used outside of training.
    norm_nonlin = layers.SpinSphericalBatchNormalizationNonlinearity(
        spins=self.spins_out,
        use_running_stats=not train,
        axis_name=self.axis_name,
        name='batch_norm_nonlin')
    return norm_nonlin(x)
    def test_azimuthal_equivariance(self, shift, stride):
        """SphericalPooling outputs agree on an azimuthally rotated input pair.

        The pair of outputs comes from
        `test_utils.apply_model_to_azimuthally_rotated_pairs` with the given
        `shift`, and the two outputs are checked to be close.
        """
        res = 16
        spin_weights = (0, -1, 2)
        transformer = _get_transformer()
        pooling = layers.SphericalPooling(stride=stride)

        outputs = test_utils.apply_model_to_azimuthally_rotated_pairs(
            transformer, pooling, res, spin_weights, shift)
        first, second = outputs

        self.assertAllClose(first, second)
    def test_SphericalPooling_matches_spin_spherical_mean(self, resolution):
        """SphericalPooling with max stride must match spin_spherical_mean."""
        spins = [0, -1, 2]
        # Batch of 2, one entry per spin on the spin axis, 4 channels.
        shape = [2, resolution, resolution, len(spins), 4]
        inputs, _ = test_utils.get_spin_spherical(
            _get_transformer(), shape, spins)
        expected = sphere_utils.spin_spherical_mean(inputs)

        # Stride equal to the resolution; only the [0, 0] grid sample of the
        # pooled output is compared against the spherical mean.
        pooling = layers.SphericalPooling(stride=resolution)
        variables = pooling.init(_JAX_RANDOM_KEY, inputs)
        pooled = pooling.apply(variables, inputs)

        # The tolerance is loose because the two computations use slightly
        # different quadratures.
        self.assertAllClose(expected, pooled[:, 0, 0], atol=1e-3)
    def test_constant_latitude_values(self, resolution):
        """Average for constant-latitude values tilts towards largest area."""
        first_value, second_value = 1, 2
        batch = (jnp.zeros([2, resolution, resolution, 1, 1])
                 .at[:, 0].set(first_value)
                 .at[:, 1].set(second_value))

        pooling = layers.SphericalPooling(stride=2)
        variables = pooling.init(_JAX_RANDOM_KEY, batch)

        # The second band has both a larger area and a larger value than the
        # first, so the pooled (area-weighted) value must exceed the plain
        # mean of the two band values.
        pooled = pooling.apply(variables, batch)
        plain_mean = (first_value + second_value) / 2
        self.assertAllGreater(pooled[:, 0], plain_mean)

        # With a smaller value in the second band, the pooled value must fall
        # below the plain mean instead.
        second_value = 0
        batch = batch.at[:, 1].set(second_value)
        plain_mean = (first_value + second_value) / 2
        pooled = pooling.apply(variables, batch)
        self.assertAllLess(pooled[:, 0], plain_mean)