def __call__(self, inputs, train):
  """Optionally pools, then convolves, then normalizes `inputs`.

  The block runs, in order: spherical pooling (only when
  `downsampling_factor` differs from 1), a spin-spherical convolution, and a
  spin-spherical batch normalization followed by a nonlinearity.

  Args:
    inputs: (batch_size, resolution, resolution, n_spins_in, n_channels_in)
      array of spin-weighted spherical functions with equiangular sampling.
    train: whether to run in training or inference mode.

  Returns:
    A (batch_size, resolution // downsampling_factor,
    resolution // downsampling_factor, n_spins_out, num_channels) complex64
    array.
  """
  # A factor of 1 means no spatial downsampling, so the pooling layer is not
  # created at all in that case.
  if self.downsampling_factor == 1:
    pooled = inputs
  else:
    pool = layers.SphericalPooling(
        stride=self.downsampling_factor, name='spherical_pool')
    pooled = pool(inputs)

  conv = layers.SpinSphericalConvolution(
      features=self.num_channels,
      spins_in=self.spins_in,
      spins_out=self.spins_out,
      num_filter_params=self.num_filter_params,
      transformer=self.transformer,
      name='spherical_conv')
  convolved = conv(pooled)

  # Running (accumulated) statistics are used only outside of training.
  norm_and_nonlin = layers.SpinSphericalBatchNormalizationNonlinearity(
      spins=self.spins_out,
      use_running_stats=not train,
      axis_name=self.axis_name,
      name='batch_norm_nonlin')
  return norm_and_nonlin(convolved)
def test_azimuthal_equivariance(self, shift, stride):
  """SphericalPooling outputs agree on azimuthally rotated input pairs."""
  transformer = _get_transformer()
  spins = (0, -1, 2)
  resolution = 16
  pooling = layers.SphericalPooling(stride=stride)
  rotated_pair = test_utils.apply_model_to_azimuthally_rotated_pairs(
      transformer, pooling, resolution, spins, shift)
  output_1, output_2 = rotated_pair
  self.assertAllClose(output_1, output_2)
def test_SphericalPooling_matches_spin_spherical_mean(self, resolution):
  """Pooling with stride equal to the resolution matches spin_spherical_mean."""
  spins = [0, -1, 2]
  shape = [2, resolution, resolution, 3, 4]
  inputs, _ = test_utils.get_spin_spherical(_get_transformer(), shape, spins)
  expected_mean = sphere_utils.spin_spherical_mean(inputs)

  pooling = layers.SphericalPooling(stride=resolution)
  variables = pooling.init(_JAX_RANDOM_KEY, inputs)
  pooled = pooling.apply(variables, inputs)

  # With stride == resolution the pooled grid is expected to collapse to the
  # single point at [0, 0]. The looser tolerance accounts for the two code
  # paths using slightly different quadratures.
  self.assertAllClose(expected_mean, pooled[:, 0, 0], atol=1e-3)
def test_constant_latitude_values(self, resolution):
  """Pooling constant-latitude rows is weighted toward the larger-area row."""
  first_latitude = 1
  second_latitude = 2
  inputs = (jnp.zeros([2, resolution, resolution, 1, 1])
            .at[:, 0].set(first_latitude)
            .at[:, 1].set(second_latitude))
  model = layers.SphericalPooling(stride=2)
  params = model.init(_JAX_RANDOM_KEY, inputs)

  # Row 1 holds the larger value and also covers the larger area, so the
  # area-weighted pool of the first output row should exceed the plain mean.
  pooled = model.apply(params, inputs)
  self.assertAllGreater(pooled[:, 0], (first_latitude + second_latitude) / 2)

  # Swap the roles: row 1 now holds the smaller value, which should pull the
  # weighted pool below the plain mean.
  second_latitude = 0
  inputs = inputs.at[:, 1].set(second_latitude)
  pooled = model.apply(params, inputs)
  self.assertAllLess(pooled[:, 0], (first_latitude + second_latitude) / 2)