def __call__(self, inputs, train):
  """Applies the network to `inputs`.

  Args:
    inputs: (batch_size, resolution, resolution, n_spins, n_channels) array
      of spin-weighted spherical functions (SWSF) with equiangular sampling.
    train: whether to run in training or inference mode.

  Returns:
    A (batch_size, num_classes) float32 array with per-class scores (logits).
  """
  resolution, num_spins, num_channels = inputs.shape[2:]
  if (resolution != self.resolutions[0] or
      num_spins != len(self.spins[0]) or
      num_channels != self.widths[0]):
    raise ValueError('Incorrect input dimensions!')

  feature_maps = inputs
  for layer in self.layers:
    feature_maps = layer(feature_maps, train=train)

  # Current feature maps are still spin-spherical. Do final processing.
  # Global pooling is not equivariant for spin != 0, so we must take the
  # absolute values beforehand.
  mean_abs = sphere_utils.spin_spherical_mean(jnp.abs(feature_maps))
  mean = sphere_utils.spin_spherical_mean(feature_maps).real
  spins = jnp.expand_dims(jnp.array(self.spins[-1]), [0, 2])
  feature_maps = jnp.where(spins == 0, mean, mean_abs)
  # Shape is now (batch, spins, channels).
  feature_maps = feature_maps.reshape((feature_maps.shape[0], -1))

  return self.final_dense(feature_maps)
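# For intuition, a minimal sketch of what `sphere_utils.spin_spherical_mean`
# computes: a quadrature-weighted average over the spherical grid. The
# sin(colatitude) weights and the axis convention below are assumptions for
# equiangular sampling; the actual implementation may use a different
# quadrature.
def _naive_spin_spherical_mean(sphere):
  """Approximates the mean of a (batch, lat, lon, spin, channel) array."""
  resolution = sphere.shape[1]
  # Equiangular colatitudes in (0, pi); assumes axis 1 is latitude.
  colatitude = (jnp.arange(resolution) + 0.5) * jnp.pi / resolution
  weights = jnp.sin(colatitude)
  weights = (weights / weights.sum())[None, :, None, None, None]
  # Weighted mean over latitude, uniform mean over longitude.
  return (sphere * weights).sum(axis=1).mean(axis=1)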
def test_spin_spherical_mean(self, resolution):
  """Check that spin_spherical_mean is equivariant and Parseval holds."""
  transformer = _get_transformer()
  spins = (0, 1, -1, 2)
  shape = (2, resolution, resolution, len(spins), 2)
  alpha, beta, gamma = 1.0, 2.0, 3.0
  pair = test_utils.get_rotated_pair(transformer, shape, spins,
                                     alpha, beta, gamma)

  # Mean should be zero for spin != 0 so we compare the squared norm.
  abs_squared = lambda x: x.real**2 + x.imag**2
  norm = sphere_utils.spin_spherical_mean(abs_squared(pair.sphere))
  rotated_norm = sphere_utils.spin_spherical_mean(
      abs_squared(pair.rotated_sphere))
  with self.subTest(name="Equivariance"):
    self.assertAllClose(norm, rotated_norm)

  # Compute energy of coefficients and check that Parseval's theorem holds.
  coefficients_norm = jnp.sum(abs_squared(pair.coefficients), axis=(1, 2))
  with self.subTest(name="Parseval"):
    self.assertAllClose(norm * 4 * np.pi, coefficients_norm)
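# The "Parseval" subtest above relies on spin_spherical_mean(|f|^2) returning
# the normalized integral (1 / 4pi) * \int_{S^2} |f|^2 dOmega (an assumption
# consistent with the factor of 4pi in the assertion). Parseval's theorem for
# spin-weighted spherical harmonics then reads
#
#   \int_{S^2} |f|^2 dOmega = \sum_{ell, m} |\hat{f}_{ell m}|^2,
#
# which is why the sphere-side energy is multiplied by 4pi before comparison.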
def test_SphericalPooling_matches_spin_spherical_mean(self, resolution):
  """SphericalPooling with max stride must match spin_spherical_mean."""
  shape = [2, resolution, resolution, 3, 4]
  spins = [0, -1, 2]
  inputs, _ = test_utils.get_spin_spherical(_get_transformer(), shape, spins)
  spherical_mean = sphere_utils.spin_spherical_mean(inputs)

  model = layers.SphericalPooling(stride=resolution)
  params = model.init(_JAX_RANDOM_KEY, inputs)
  pooled = model.apply(params, inputs)

  # Tolerance here is higher because of slightly different quadratures.
  self.assertAllClose(spherical_mean, pooled[:, 0, 0], atol=1e-3)
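# Hypothetical usage sketch (not part of the test suite): with a smaller
# stride, SphericalPooling should downsample the grid instead of collapsing
# it; this assumes a stride of 2 halves the spatial resolution.
def _sketch_strided_pooling():
  inputs = jnp.zeros((2, 16, 16, 3, 4), dtype=jnp.complex64)
  model = layers.SphericalPooling(stride=2)
  params = model.init(_JAX_RANDOM_KEY, inputs)
  pooled = model.apply(params, inputs)
  return pooled.shape  # Expected: (2, 8, 8, 3, 4).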
def __call__(self, inputs, use_running_stats=None):
  """Normalizes the input using batch variances and (if `centered`) means.

  Stats are computed over the batch and spherical dimensions: (0, 1, 2).

  Args:
    inputs: An array of dimensions (batch_size, resolution, resolution,
      n_spins_in, n_channels_in).
    use_running_stats: If true, the statistics stored in batch_stats will be
      used instead of computing the batch statistics on the input.

  Returns:
    Normalized inputs (the same shape as inputs).
  """
  use_running_stats = nn.module.merge_param(
      "use_running_stats", self.use_running_stats, use_running_stats)

  # Normalization is independent per spin per channel.
  num_spins, num_channels = inputs.shape[-2:]
  feature_shape = (1, 1, 1, num_spins, num_channels)
  reduced_feature_shape = (num_spins, num_channels)

  initializing = not self.has_variable("batch_stats", "variance")

  running_variance = self.variable("batch_stats", "variance",
                                   lambda s: jnp.ones(s, jnp.float32),
                                   reduced_feature_shape)
  if self.centered:
    running_mean = self.variable("batch_stats", "mean",
                                 lambda s: jnp.zeros(s, jnp.complex64),
                                 reduced_feature_shape)

  if use_running_stats:
    variance = running_variance.value
    if self.centered:
      mean = running_mean.value
  else:
    # Compute the spherical mean over the spherical grid dimensions, then a
    # conventional mean over the batch.
    if self.centered:
      mean = sphere_utils.spin_spherical_mean(inputs)
      mean = jnp.mean(mean, axis=0)
    # Complex variance is E[x x*] - E[x]E[x*]. For spin != 0, E[x] should be
    # zero, although due to discretization this is not always true. We only
    # use E[x x*] here.
    # E[x x*]:
    mean_abs_squared = sphere_utils.spin_spherical_mean(
        inputs * inputs.conj())
    mean_abs_squared = jnp.mean(mean_abs_squared, axis=0)
    # Aggregate means over devices.
    if self.axis_name is not None and not initializing:
      if self.centered:
        mean = lax.pmean(mean, axis_name=self.axis_name)
      mean_abs_squared = lax.pmean(mean_abs_squared,
                                   axis_name=self.axis_name)

    # Imaginary part is negligible.
    variance = mean_abs_squared.real

    if not initializing:
      running_variance.value = (self.momentum * running_variance.value +
                                (1 - self.momentum) * variance)
      if self.centered:
        running_mean.value = (self.momentum * running_mean.value +
                              (1 - self.momentum) * mean)

  if self.centered:
    outputs = inputs - mean.reshape(feature_shape)
  else:
    outputs = inputs

  factor = lax.rsqrt(variance.reshape(feature_shape) + self.epsilon)
  if self.use_scale:
    scale = self.param("scale", self.scale_init,
                       reduced_feature_shape).reshape(feature_shape)
    factor = factor * scale
  outputs = outputs * factor

  if self.use_bias:
    bias = self.param("bias", self.bias_init,
                      reduced_feature_shape).reshape(feature_shape)
    outputs = outputs + bias

  return outputs
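# Minimal usage sketch, assuming `import jax`, that the module's
# `use_running_stats` field defaults to None, and that the remaining
# hyperparameters have defaults. Running statistics live in the
# "batch_stats" collection, so training-mode calls must mark it mutable.
def _sketch_batch_norm_usage():
  inputs = jnp.zeros((2, 8, 8, 3, 4), jnp.complex64)
  model = SpinSphericalBatchNormalization()
  variables = model.init(jax.random.PRNGKey(0), inputs)

  # Training step: compute batch statistics and update the running averages.
  outputs, updated_stats = model.apply(
      variables, inputs, use_running_stats=False, mutable=["batch_stats"])

  # Evaluation: reuse the stored running statistics.
  outputs = model.apply({**variables, **updated_stats},
                        inputs,
                        use_running_stats=True)
  return outputs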
def _batched_spherical_variance(inputs):
  """Computes variances over the sphere and batch dimensions."""
  # Assumes mean=0 as in SpinSphericalBatchNormalization.
  return sphere_utils.spin_spherical_mean(inputs * inputs.conj()).mean(axis=0)
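# Why dropping E[x]E[x*] is sound here: for zero-mean complex x, the variance
# E[x x*] - E[x]E[x*] reduces to E[|x|^2], which is real. A tiny worked
# example:
#
#   x = jnp.array([1 + 1j, -1 - 1j])  # zero-mean complex samples
#   jnp.mean(x * x.conj())            # == 2 + 0j, i.e. E[|x|^2]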