def test_azimuthal_equivariance(self, shift, train, downsampling_factor=1,
                                num_filter_params=None):
    """Checks that SpinSphericalBlock commutes with azimuthal rolls.

    Rolling the input by `shift` samples along axis 2 must roll the block's
    output by `shift // downsampling_factor` samples along axis 2.

    Args:
      shift: Number of samples the input is rolled by along axis 2.
      train: Value passed as the `train` flag to `model.apply`.
      downsampling_factor: Forwarded to `SpinSphericalBlock`; the expected
        output roll is `shift // downsampling_factor`.
      num_filter_params: Forwarded to `SpinSphericalBlock`.
    """
    resolution = 8
    transformer = _get_transformer()
    spins = (0, 1, 2)
    shape = (2, resolution, resolution, len(spins), 2)
    sphere, _ = test_utils.get_spin_spherical(transformer, shape, spins)
    rotated_sphere = jnp.roll(sphere, shift, axis=2)

    model = models.SpinSphericalBlock(
        num_channels=2,
        spins_in=spins,
        spins_out=spins,
        downsampling_factor=downsampling_factor,
        num_filter_params=num_filter_params,
        axis_name=None,
        transformer=transformer)
    params = model.init(_JAX_RANDOM_KEY, sphere, train=False)

    # Depending on the Flax version/config, `init` returns either a FrozenDict
    # (which must be unfrozen before it can be edited) or a plain, already
    # mutable dict that has no `unfreeze` method. Handle both.
    if hasattr(params, 'unfreeze'):
        params = params.unfreeze()

    # Add negative bias so that the magnitude nonlinearity is active.
    for key, value in params['params']['batch_norm_nonlin'].items():
        if 'magnitude_nonlin' in key:
            value['bias'] -= 0.1

    # The updated batch statistics returned alongside each output are
    # discarded; only the outputs are compared.
    output, _ = model.apply(params, sphere, train=train,
                            mutable=['batch_stats'])
    rotated_output, _ = model.apply(params, rotated_sphere, train=train,
                                    mutable=['batch_stats'])

    # NOTE(review): the expected roll is an exact grid shift only when `shift`
    # is a multiple of `downsampling_factor`; presumably the test parameters
    # guarantee this. Confirm.
    shifted_output = jnp.roll(output, shift // downsampling_factor, axis=2)
    self.assertAllClose(rotated_output, shifted_output, atol=1e-6)
def test_azimuthal_invariance(self, shift):
    """Checks that SpinSphericalClassifier ignores azimuthal rolls of its input.

    A small two-layer classifier with pooling is applied to a test input and
    to the same input rolled by `shift` samples along axis 2. Both outputs
    must agree.
    """
    resolutions = [8, 4]
    transformer = _get_transformer()
    spins_per_layer = [[0, -1], [0, 1, 2]]
    widths = [2, 3]

    input_spins = spins_per_layer[0]
    input_shape = [2, resolutions[0], resolutions[0], len(input_spins),
                   widths[0]]
    sphere, _ = test_utils.get_spin_spherical(transformer, input_shape,
                                              input_spins)
    rotated_sphere = jnp.roll(sphere, shift, axis=2)

    classifier = models.SpinSphericalClassifier(
        num_classes=5,
        resolutions=resolutions,
        spins=spins_per_layer,
        widths=widths,
        axis_name=None,
        input_transformer=transformer)
    variables = classifier.init(_JAX_RANDOM_KEY, sphere, train=False)

    def run_classifier(inputs):
        # Train mode; the batch-statistics update is thrown away.
        prediction, _ = classifier.apply(variables, inputs, train=True,
                                         mutable=['batch_stats'])
        return prediction

    output = run_classifier(sphere)
    rotated_output = run_classifier(rotated_sphere)
    # Rotating the input must leave the classifier output unchanged.
    self.assertAllClose(rotated_output, output, atol=1e-6)
def test_SphericalPooling_matches_spin_spherical_mean(self, resolution):
    """SphericalPooling with stride == resolution equals spin_spherical_mean."""
    spins = [0, -1, 2]
    shape = [2, resolution, resolution, len(spins), 4]
    inputs, _ = test_utils.get_spin_spherical(_get_transformer(), shape, spins)
    expected_mean = sphere_utils.spin_spherical_mean(inputs)

    # With a stride equal to the resolution, the pooled result is read at
    # spatial index [0, 0] and compared against the spherical mean.
    pooling = layers.SphericalPooling(stride=resolution)
    pooled = pooling.apply(pooling.init(_JAX_RANDOM_KEY, inputs), inputs)

    # The two computations use slightly different quadratures, so the
    # tolerance is looser than usual.
    self.assertAllClose(expected_mean, pooled[:, 0, 0], atol=1e-3)
def _evaluate_magnitudenonlinearity_versions(spins):
    """Runs MagnitudeNonlinearity and MagnitudeNonlinearityLeakyRelu.

    Both layers are initialized with the same PRNG key and the same nonzero
    bias initializer, then applied to one shared test input.

    Args:
      spins: Spins of the generated input; also passed to the leaky-ReLU
        variant.

    Returns:
      A tuple `(inputs, outputs, outputs_relu)`: the generated input, the
      MagnitudeNonlinearity output, and the MagnitudeNonlinearityLeakyRelu
      output.
    """
    transformer = _get_transformer()
    inputs, _ = test_utils.get_spin_spherical(
        transformer, shape=(2, 8, 8, len(spins), 2), spins=spins)

    def init_then_apply(layer):
        # Initialize variables for `layer` and evaluate it on `inputs`.
        variables = layer.init(_JAX_RANDOM_KEY, inputs)
        return layer.apply(variables, inputs)

    outputs = init_then_apply(layers.MagnitudeNonlinearity(
        bias_initializer=_magnitude_nonlinearity_nonzero_initializer))
    outputs_relu = init_then_apply(layers.MagnitudeNonlinearityLeakyRelu(
        spins=spins,
        bias_initializer=_magnitude_nonlinearity_nonzero_initializer))
    return inputs, outputs, outputs_relu