def test_invariance(self):
  """Checks that a small pooling classifier gives rotation-invariant outputs."""
  # Two-layer classifier whose second layer runs at a lower resolution, so
  # pooling is exercised.
  resolutions = [16, 8]
  transformer = _get_transformer()
  spins = [[0, -1], [0, 1, 2]]
  channels = [2, 3]

  input_spins = spins[0]
  input_shape = [2, resolutions[0], resolutions[0], len(input_spins),
                 channels[0]]
  pair = test_utils.get_rotated_pair(
      transformer, input_shape, input_spins, alpha=1.0, beta=2.0, gamma=3.0)

  model = models.SpinSphericalClassifier(
      num_classes=5,
      resolutions=resolutions,
      spins=spins,
      widths=channels,
      axis_name=None,
      input_transformer=transformer)
  variables = model.init(_JAX_RANDOM_KEY, pair.sphere, train=False)

  def run(inputs):
    # Apply in training mode with `batch_stats` mutable; the updated state is
    # discarded and only the model output is kept.
    logits, _ = model.apply(
        variables, inputs, train=True, mutable=['batch_stats'])
    return logits

  output = run(pair.sphere)
  rotated_output = run(pair.rotated_sphere)

  # Rotating the input should leave the classifier output unchanged. The
  # tolerance is loose because local pooling is only approximately
  # equivariant.
  self.assertAllClose(rotated_output, output, atol=1e-1)
  self.assertLess(
      _normalized_mean_absolute_error(output, rotated_output), 0.1)
def test_get_rotated_pair_shapes(self, shape):
  """Checks that both spheres from get_rotated_pair have the requested shape."""
  transformer = _get_transformer()
  # The spin axis is second-to-last in `shape`.
  *_, num_spins, _ = shape
  spin_values = jnp.arange(num_spins)
  pair = test_utils.get_rotated_pair(
      transformer, shape, spin_values, alpha=1.0, beta=2.0, gamma=3.0)
  for sphere in (pair.sphere, pair.rotated_sphere):
    self.assertEqual(sphere.shape, shape)
def test_swsconv_spatial_spectral_is_equivariant(self, resolution, spins_in,
                                                 spins_out):
  """Tests the SO(3)-equivariance of _swsconv_spatial_spectral()."""
  transformer = _get_transformer()
  channels_in = 2
  channels_out = 3
  # Euler angles of the test rotation.
  alpha, beta, gamma = 1.0, 2.0, 3.0

  input_shape = (1, resolution, resolution, len(spins_in), channels_in)
  pair = test_utils.get_rotated_pair(
      transformer, shape=input_shape, spins=spins_in,
      alpha=alpha, beta=beta, gamma=gamma)
  # Drop the leading batch dimension (size 1).
  sphere, rotated_sphere = pair.sphere[0], pair.rotated_sphere[0]

  # The filter is given by spectral coefficients; the leading axis has
  # ell_max + 1 entries.
  ell_max = resolution // 2 - 1
  filter_shape = [ell_max + 1, len(spins_in), len(spins_out),
                  channels_in, channels_out]
  # Deterministic, non-trivial complex filter values.
  filter_coefficients = jnp.linspace(
      -0.5 + 0.2j, 0.2, np.prod(filter_shape)).reshape(filter_shape)

  def convolve(inputs):
    return layers._swsconv_spatial_spectral(
        transformer, inputs, filter_coefficients, spins_in, spins_out)

  def to_spectral(outputs):
    return transformer.swsft_forward_spins_channels(outputs, spins_out)

  sphere_out = convolve(sphere)
  rotated_sphere_out = convolve(rotated_sphere)

  # An SO(3)-equivariant convolution maps rotated inputs to outputs related by
  # the same rotation. Compare both sides in the spectral domain.
  coefficients_out = to_spectral(sphere_out)
  # Convolution of the rotated input: R(f) * g.
  conv_of_rotated = to_spectral(rotated_sphere_out)
  # Rotation of the convolved input: R(f * g).
  rotated_conv = test_utils.rotate_coefficients(
      coefficients_out, alpha, beta, gamma)

  # Rotating coefficients via Wigner-D matrices loses some precision, so a
  # slightly higher tolerance is used.
  self.assertAllClose(conv_of_rotated, rotated_conv, atol=1e-5)
def test_get_rotated_pair_azimuthal_rotation(self, shift):
  """Checks that an azimuthal rotation equals a horizontal roll of the grid."""
  resolution = 16
  transformer = _get_transformer()
  spins = (0, 1)
  shape = (2, resolution, resolution, len(spins), 2)
  # Rotation angle that corresponds to moving `shift` grid steps horizontally.
  gamma = shift * 2 * jnp.pi / resolution
  # sympy produces NaN Wigner-D values when beta == 0, so use a tiny nonzero
  # beta instead.
  tiny_beta = 1e-8
  pair = test_utils.get_rotated_pair(
      transformer, shape, spins, alpha=0.0, beta=tiny_beta, gamma=gamma)
  expected = jnp.roll(pair.sphere, shift, axis=2)
  self.assertAllClose(expected, pair.rotated_sphere)
def test_spin_spherical_mean(self, resolution):
  """Checks spin_spherical_mean of |f|^2 for rotation invariance and Parseval."""
  transformer = _get_transformer()
  spins = (0, 1, -1, 2)
  shape = (2, resolution, resolution, len(spins), 2)
  alpha, beta, gamma = 1.0, 2.0, 3.0
  pair = test_utils.get_rotated_pair(
      transformer, shape, spins, alpha=alpha, beta=beta, gamma=gamma)

  def squared_magnitude(x):
    return x.real**2 + x.imag**2

  # The plain mean is expected to be zero for spin != 0, so compare means of
  # the squared magnitude instead.
  energy = sphere_utils.spin_spherical_mean(squared_magnitude(pair.sphere))
  rotated_energy = sphere_utils.spin_spherical_mean(
      squared_magnitude(pair.rotated_sphere))
  with self.subTest(name="Equivariance"):
    self.assertAllClose(energy, rotated_energy)

  # Parseval: 4*pi times the spherical mean of |f|^2 should equal the sum of
  # squared coefficient magnitudes (summed over axes 1 and 2).
  coefficient_energy = jnp.sum(
      squared_magnitude(pair.coefficients), axis=(1, 2))
  with self.subTest(name="Parseval"):
    self.assertAllClose(energy * 4 * np.pi, coefficient_energy)