Example #1
  def __call__(self, inputs, train):
    """Apply the network to `inputs`.

    Args:
      inputs: (batch_size, resolution, resolution, n_spins, n_channels) array of
        spin-weighted spherical functions (SWSF) with equiangular sampling.
      train: whether to run in training or inference mode.
    Returns:
      A (batch_size, num_classes) float32 array with per-class scores (logits).
    """
    resolution, num_spins, num_channels = inputs.shape[2:]
    if (resolution != self.resolutions[0] or
        num_spins != len(self.spins[0]) or
        num_channels != self.widths[0]):
      raise ValueError('Incorrect input dimensions!')

    feature_maps = inputs
    for layer in self.layers:
      feature_maps = layer(feature_maps, train=train)

    # Current feature maps are still spin spherical. Do final processing.
    # Global pooling is not equivariant for spin != 0, so we must take the
    # absolute values first.
    mean_abs = sphere_utils.spin_spherical_mean(jnp.abs(feature_maps))
    mean = sphere_utils.spin_spherical_mean(feature_maps).real
    spins = jnp.expand_dims(jnp.array(self.spins[-1]), [0, 2])
    feature_maps = jnp.where(spins == 0, mean, mean_abs)
    # Shape is now (batch, spins, channel).
    feature_maps = feature_maps.reshape((feature_maps.shape[0], -1))

    return self.final_dense(feature_maps)
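The invariant head at the end of this example can be exercised on its own. A minimal sketch, assuming a naive unweighted average as a stand-in for sphere_utils.spin_spherical_mean (the real version applies quadrature weights):

import jax.numpy as jnp

# Stand-in for sphere_utils.spin_spherical_mean: an unweighted average over
# the two spherical grid axes.
def naive_spherical_mean(x):
  return x.mean(axis=(1, 2))

spins = jnp.array([0, -1, 2])
# (batch, resolution, resolution, n_spins, n_channels)
x = jnp.ones((2, 8, 8, 3, 4), dtype=jnp.complex64)

# Taking |.| first makes the spin != 0 entries invariant under global pooling.
mean_abs = naive_spherical_mean(jnp.abs(x))    # (2, 3, 4)
mean = naive_spherical_mean(x).real            # (2, 3, 4)
pooled = jnp.where(spins[None, :, None] == 0, mean, mean_abs)
print(pooled.reshape((pooled.shape[0], -1)).shape)  # (2, 12)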
Example #2
    def test_spin_spherical_mean(self, resolution):
        """Check that spin_spherical_mean is equivariant and Parseval holds."""
        transformer = _get_transformer()
        spins = (0, 1, -1, 2)
        shape = (2, resolution, resolution, len(spins), 2)
        alpha, beta, gamma = 1.0, 2.0, 3.0
        pair = test_utils.get_rotated_pair(transformer, shape, spins, alpha,
                                           beta, gamma)

        # The mean should be zero for spin != 0, so we compare the squared norms.
        abs_squared = lambda x: x.real**2 + x.imag**2
        norm = sphere_utils.spin_spherical_mean(abs_squared(pair.sphere))
        rotated_norm = sphere_utils.spin_spherical_mean(
            abs_squared(pair.rotated_sphere))
        with self.subTest(name="Equivariance"):
            self.assertAllClose(norm, rotated_norm)

        # Compute energy of coefficients and check that Parseval's theorem holds.
        coefficients_norm = jnp.sum(abs_squared(pair.coefficients),
                                    axis=(1, 2))
        with self.subTest(name="Parseval"):
            self.assertAllClose(norm * 4 * np.pi, coefficients_norm)
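The Parseval subtest encodes the identity below; because spin_spherical_mean averages over the sphere instead of integrating, the 4π surface area of S² appears as a factor. A sketch of the relation, with {}_s f_{ℓm} denoting the spin-weighted spherical harmonic coefficients:

\[
  \int_{S^2} \lvert f \rvert^2 \, d\Omega
  = 4\pi \cdot \operatorname{mean}_{S^2}\bigl(\lvert f \rvert^2\bigr)
  = \sum_{\ell, m} \lvert {}_s f_{\ell m} \rvert^2
\]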
Example #3
    def test_SphericalPooling_matches_spin_spherical_mean(self, resolution):
        """SphericalPooling with max stride must match spin_spherical_mean."""
        shape = [2, resolution, resolution, 3, 4]
        spins = [0, -1, 2]
        inputs, _ = test_utils.get_spin_spherical(_get_transformer(), shape,
                                                  spins)
        spherical_mean = sphere_utils.spin_spherical_mean(inputs)

        model = layers.SphericalPooling(stride=resolution)
        params = model.init(_JAX_RANDOM_KEY, inputs)
        pooled = model.apply(params, inputs)

        # Tolerance here is higher because of slightly different quadratures.
        self.assertAllClose(spherical_mean, pooled[:, 0, 0], atol=1e-3)
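Shape bookkeeping for the comparison above, as a sketch: with stride equal to the input resolution, the pooled spherical grid is 1×1, so pooled[:, 0, 0] lines up with the (batch, n_spins, n_channels) output of spin_spherical_mean. A plain average stands in for the quadrature-weighted pooling here, which is also why the test needs the looser atol:

import jax.numpy as jnp

resolution = 8
inputs = jnp.ones((2, resolution, resolution, 3, 4))

# Stand-in: unweighted average pooling over the whole grid.
pooled = inputs.mean(axis=(1, 2), keepdims=True)  # (2, 1, 1, 3, 4)
print(pooled[:, 0, 0].shape)  # (2, 3, 4)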
Example #4
    def __call__(self, inputs, use_running_stats=None):
        """Normalizes the input using batch (optional) means and variances.

    Stats are computed over the batch and spherical dimensions: (0, 1, 2).

    Args:
      inputs: An array of dimensions (batch_size, resolution, resolution,
        n_spins_in, n_channels_in).
      use_running_stats: if true, the statistics stored in batch_stats will be
        used instead of computing the batch statistics on the input.

    Returns:
      Normalized inputs (the same shape as inputs).
    """
        use_running_stats = nn.module.merge_param("use_running_stats",
                                                  self.use_running_stats,
                                                  use_running_stats)

        # Normalization is independent per spin per channel.
        num_spins, num_channels = inputs.shape[-2:]
        feature_shape = (1, 1, 1, num_spins, num_channels)
        reduced_feature_shape = (num_spins, num_channels)

        initializing = not self.has_variable("batch_stats", "variance")

        running_variance = self.variable("batch_stats", "variance",
                                         lambda s: jnp.ones(s, jnp.float32),
                                         reduced_feature_shape)

        if self.centered:
            running_mean = self.variable("batch_stats", "mean",
                                         lambda s: jnp.zeros(s, jnp.complex64),
                                         reduced_feature_shape)

        if use_running_stats:
            variance = running_variance.value
            if self.centered:
                mean = running_mean.value
        else:
            # Compute the spherical mean over the spherical grid dimensions, then a
            # conventional mean over the batch.
            if self.centered:
                mean = sphere_utils.spin_spherical_mean(inputs)
                mean = jnp.mean(mean, axis=0)
            # Complex variance is E[x x*] - E[x]E[x*].
            # For spin != 0, E[x] should be zero, although due to discretization this
            # is not always true. We only use E[x x*] here.
            # E[x x*]:
            mean_abs_squared = sphere_utils.spin_spherical_mean(inputs *
                                                                inputs.conj())
            mean_abs_squared = jnp.mean(mean_abs_squared, axis=0)
            # Aggregate means over devices.
            if self.axis_name is not None and not initializing:
                if self.centered:
                    mean = lax.pmean(mean, axis_name=self.axis_name)
                mean_abs_squared = lax.pmean(mean_abs_squared,
                                             axis_name=self.axis_name)

            # Imaginary part is negligible.
            variance = mean_abs_squared.real

            if not initializing:
                running_variance.value = (
                    self.momentum * running_variance.value +
                    (1 - self.momentum) * variance)
                if self.centered:
                    running_mean.value = (self.momentum * running_mean.value +
                                          (1 - self.momentum) * mean)

        if self.centered:
            outputs = inputs - mean.reshape(feature_shape)
        else:
            outputs = inputs

        factor = lax.rsqrt(variance.reshape(feature_shape) + self.epsilon)
        if self.use_scale:
            scale = self.param("scale", self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            factor = factor * scale

        outputs = outputs * factor

        if self.use_bias:
            bias = self.param("bias", self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            outputs = outputs + bias

        return outputs
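The uncentered branch above boils down to a two-stage mean of x x*: a spherical mean followed by a conventional mean over the batch. A minimal standalone sketch, again with a naive unweighted spherical mean standing in for sphere_utils.spin_spherical_mean:

import jax
import jax.numpy as jnp

def naive_spherical_mean(x):
    # Stand-in for sphere_utils.spin_spherical_mean (no quadrature weights).
    return x.mean(axis=(1, 2))

key_re, key_im = jax.random.split(jax.random.PRNGKey(0))
shape = (2, 8, 8, 3, 4)  # (batch, resolution, resolution, n_spins, n_channels)
x = jax.random.normal(key_re, shape) + 1j * jax.random.normal(key_im, shape)

# E[x x*]: spherical mean, then a conventional mean over the batch.
mean_abs_squared = naive_spherical_mean(x * x.conj()).mean(axis=0)
variance = mean_abs_squared.real  # imaginary part is zero up to roundoff
print(variance.shape)  # (3, 4)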
Example #5
def _batched_spherical_variance(inputs):
    """Computes variances over the sphere and batch dimensions."""
    # Assumes mean=0 as in SpinSphericalBatchNormalization.
    return sphere_utils.spin_spherical_mean(inputs *
                                            inputs.conj()).mean(axis=0)
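A hypothetical call, assuming jax.numpy is imported as jnp and the (batch, resolution, resolution, n_spins, n_channels) layout used throughout these examples:

inputs = jnp.ones((2, 8, 8, 3, 4), dtype=jnp.complex64)
variances = _batched_spherical_variance(inputs)  # shape (n_spins, n_channels)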