Ejemplo n.º 1
0
  def test_invariance(self):
    """Checks that the classifier output is invariant to input rotations."""
    # A simple two-layer classifier with pooling is enough for this test.
    layer_resolutions = [16, 8]
    layer_spins = [[0, -1], [0, 1, 2]]
    layer_widths = [2, 3]
    transformer = _get_transformer()

    # Inputs match the first layer: resolution, spins and channel count.
    input_resolution = layer_resolutions[0]
    input_spins = layer_spins[0]
    input_shape = [2, input_resolution, input_resolution, len(input_spins),
                   layer_widths[0]]
    pair = test_utils.get_rotated_pair(transformer, input_shape, input_spins,
                                       alpha=1.0, beta=2.0, gamma=3.0)

    model = models.SpinSphericalClassifier(
        num_classes=5,
        resolutions=layer_resolutions,
        spins=layer_spins,
        widths=layer_widths,
        axis_name=None,
        input_transformer=transformer)
    variables = model.init(_JAX_RANDOM_KEY, pair.sphere, train=False)

    def _classify(inputs):
      # Only the first returned value (the model output) is compared below.
      predictions, _ = model.apply(variables, inputs, train=True,
                                   mutable=['batch_stats'])
      return predictions

    output = _classify(pair.sphere)
    rotated_output = _classify(pair.rotated_sphere)

    # The classifier should be rotation-invariant. The tolerance is high
    # because the local pooling introduces equivariance errors.
    self.assertAllClose(rotated_output, output, atol=1e-1)
    error = _normalized_mean_absolute_error(output, rotated_output)
    self.assertLess(error, 0.1)
Ejemplo n.º 2
0
 def test_get_rotated_pair_shapes(self, shape):
     """Checks that get_rotated_pair() preserves the requested shape."""
     # Only the second-to-last entry of `shape` (the spin count) is needed
     # to build the spins passed to get_rotated_pair().
     *_, num_spins, _ = shape
     pair = test_utils.get_rotated_pair(_get_transformer(),
                                        shape,
                                        jnp.arange(num_spins),
                                        alpha=1.0,
                                        beta=2.0,
                                        gamma=3.0)
     for sphere in (pair.sphere, pair.rotated_sphere):
         self.assertEqual(sphere.shape, shape)
Ejemplo n.º 3
0
    def test_swsconv_spatial_spectral_is_equivariant(self, resolution,
                                                     spins_in, spins_out):
        """Tests the SO(3)-equivariance of _swsconv_spatial_spectral().

        Convolving a rotated input must give the same result as rotating the
        convolved output; the comparison is made on spectral coefficients.
        """
        transformer = _get_transformer()
        channels_in, channels_out = 2, 3
        # Euler angles of the rotation relating the two inputs.
        alpha, beta, gamma = 1.0, 2.0, 3.0
        input_shape = (1, resolution, resolution, len(spins_in), channels_in)
        pair = test_utils.get_rotated_pair(transformer, input_shape, spins_in,
                                           alpha, beta, gamma)

        # The filter is defined by its spectral coefficients; fill it with
        # arbitrary but reproducible complex values.
        ell_max = resolution // 2 - 1
        filter_shape = [
            ell_max + 1,
            len(spins_in),
            len(spins_out),
            channels_in,
            channels_out,
        ]
        flat_coefficients = jnp.linspace(-0.5 + 0.2j, 0.2,
                                         np.prod(filter_shape))
        filter_coefficients = flat_coefficients.reshape(filter_shape)

        def _convolve(inputs):
            return layers._swsconv_spatial_spectral(transformer, inputs,
                                                    filter_coefficients,
                                                    spins_in, spins_out)

        def _to_spectral(outputs):
            return transformer.swsft_forward_spins_channels(
                outputs, spins_out)

        # Index 0 gets rid of the batch dimension.
        sphere_out = _convolve(pair.sphere[0])
        rotated_sphere_out = _convolve(pair.rotated_sphere[0])

        # Since the convolution is SO(3)-equivariant, the same rotation that
        # relates the inputs must relate the outputs. We apply it spectrally.
        coefficients_out = _to_spectral(sphere_out)

        # This is R(f) * g (in the spectral domain).
        convolved_rotated = _to_spectral(rotated_sphere_out)

        # And this is R(f * g) (in the spectral domain).
        rotated_convolved = test_utils.rotate_coefficients(
            coefficients_out, alpha, beta, gamma)

        # There is some loss of precision on the Wigner-D computation for
        # rotating the coefficients, hence we use a slightly higher tolerance.
        self.assertAllClose(convolved_rotated, rotated_convolved, atol=1e-5)
Ejemplo n.º 4
0
    def test_get_rotated_pair_azimuthal_rotation(self, shift):
        """Check that azimuthal rotation corresponds to horizontal shift."""
        resolution = 16
        spins = (0, 1)
        sphere_shape = (2, resolution, resolution, len(spins), 2)

        # Azimuthal rotation angle equivalent to a shift of `shift` samples.
        gamma = shift * 2 * jnp.pi / resolution
        # sympy returns nans for wigner-ds when beta==0, hence the tiny
        # non-zero value here.
        near_zero_beta = 1e-8

        pair = test_utils.get_rotated_pair(_get_transformer(),
                                           sphere_shape,
                                           spins,
                                           alpha=0.0,
                                           beta=near_zero_beta,
                                           gamma=gamma)

        # Rolling along axis 2 should reproduce the rotated sphere.
        expected_rotated_sphere = jnp.roll(pair.sphere, shift, axis=2)
        self.assertAllClose(expected_rotated_sphere, pair.rotated_sphere)
Ejemplo n.º 5
0
    def test_spin_spherical_mean(self, resolution):
        """Check that spin_spherical_mean is equivariant and Parseval holds."""
        spins = (0, 1, -1, 2)
        sphere_shape = (2, resolution, resolution, len(spins), 2)
        pair = test_utils.get_rotated_pair(_get_transformer(),
                                           sphere_shape,
                                           spins,
                                           alpha=1.0,
                                           beta=2.0,
                                           gamma=3.0)

        def squared_magnitude(values):
            return values.real**2 + values.imag**2

        # Mean should be zero for spin != 0 so we compare the squared norm.
        sphere_energy = squared_magnitude(pair.sphere)
        rotated_sphere_energy = squared_magnitude(pair.rotated_sphere)
        norm = sphere_utils.spin_spherical_mean(sphere_energy)
        rotated_norm = sphere_utils.spin_spherical_mean(rotated_sphere_energy)
        with self.subTest(name="Equivariance"):
            self.assertAllClose(norm, rotated_norm)

        # Energy of the coefficients, summed over axes 1 and 2; Parseval's
        # theorem relates it to the spherical mean scaled by 4 * pi.
        coefficients_energy = squared_magnitude(pair.coefficients)
        coefficients_norm = jnp.sum(coefficients_energy, axis=(1, 2))
        with self.subTest(name="Parseval"):
            self.assertAllClose(norm * 4 * np.pi, coefficients_norm)