Example 1
  def test_azimuthal_equivariance(self, shift, train,
                                  downsampling_factor=1,
                                  num_filter_params=None):
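    """Checks that SpinSphericalBlock is equivariant to azimuthal rotations."""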
    resolution = 8
    transformer = _get_transformer()
    spins = (0, 1, 2)
    shape = (2, resolution, resolution, len(spins), 2)

    sphere, _ = test_utils.get_spin_spherical(transformer, shape, spins)
    rotated_sphere = jnp.roll(sphere, shift, axis=2)

    model = models.SpinSphericalBlock(num_channels=2,
                                      spins_in=spins,
                                      spins_out=spins,
                                      downsampling_factor=downsampling_factor,
                                      num_filter_params=num_filter_params,
                                      axis_name=None,
                                      transformer=transformer)
    params = model.init(_JAX_RANDOM_KEY, sphere, train=False)

    # Add negative bias so that the magnitude nonlinearity is active.
    params = params.unfreeze()
    for key, value in params['params']['batch_norm_nonlin'].items():
      if 'magnitude_nonlin' in key:
        value['bias'] -= 0.1

    output, _ = model.apply(params, sphere, train=train,
                            mutable=['batch_stats'])
    rotated_output, _ = model.apply(params, rotated_sphere, train=train,
                                    mutable=['batch_stats'])
    shifted_output = jnp.roll(output, shift // downsampling_factor, axis=2)

    self.assertAllClose(rotated_output, shifted_output, atol=1e-6)
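The azimuthal rotation in this test is just a cyclic shift of the longitude axis (axis=2 in the (batch, latitude, longitude, spin, channel) layout built above). A minimal, self-contained sketch of that convention, with _JAX_RANDOM_KEY assumed to be a standard JAX PRNG key (its actual definition is not shown in the listing):

import jax
import jax.numpy as jnp

# Assumed fixture: the tests above use a module-level PRNG key like this one.
_JAX_RANDOM_KEY = jax.random.PRNGKey(0)

# An azimuthal rotation of an equiangular spherical grid is a cyclic shift
# along longitude. Shape: (batch, latitude, longitude, spins, channels).
resolution = 8
sphere = jax.random.normal(_JAX_RANDOM_KEY, (2, resolution, resolution, 3, 2))
shift = 3
rotated_sphere = jnp.roll(sphere, shift, axis=2)
# Rolling back by the same shift recovers the original signal exactly.
assert jnp.allclose(jnp.roll(rotated_sphere, -shift, axis=2), sphere)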
Example 2
  def test_azimuthal_invariance(self, shift):
    # Make a simple two-layer classifier with pooling for testing.
    resolutions = [8, 4]
    transformer = _get_transformer()
    spins = [[0, -1], [0, 1, 2]]
    channels = [2, 3]
    shape = [2, resolutions[0], resolutions[0], len(spins[0]), channels[0]]
    sphere, _ = test_utils.get_spin_spherical(transformer, shape, spins[0])
    rotated_sphere = jnp.roll(sphere, shift, axis=2)

    model = models.SpinSphericalClassifier(num_classes=5,
                                           resolutions=resolutions,
                                           spins=spins,
                                           widths=channels,
                                           axis_name=None,
                                           input_transformer=transformer)
    params = model.init(_JAX_RANDOM_KEY, sphere, train=False)

    output, _ = model.apply(params, sphere, train=True,
                            mutable=['batch_stats'])
    rotated_output, _ = model.apply(params, rotated_sphere, train=True,
                                    mutable=['batch_stats'])

    # The classifier output should be invariant to azimuthal rotations of the
    # input.
    self.assertAllClose(rotated_output, output, atol=1e-6)
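Both tests above receive shift (and, in Example 1, train) as arguments, which suggests they are driven by parameterized test decorators that the listing omits. A sketch of how such parameters could be supplied with absl's parameterized module; the shift values below are illustrative, not taken from the source:

from absl.testing import absltest
from absl.testing import parameterized


class ShiftParameterTest(parameterized.TestCase):
  """Illustrative only: shows one way `shift` could reach the tests above."""

  @parameterized.parameters(0, 1, 3, 5)
  def test_receives_shift(self, shift):
    self.assertIsInstance(shift, int)


if __name__ == '__main__':
  absltest.main()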
Example 3
  def test_SphericalPooling_matches_spin_spherical_mean(self, resolution):
    """SphericalPooling with max stride must match spin_spherical_mean."""
    shape = [2, resolution, resolution, 3, 4]
    spins = [0, -1, 2]
    inputs, _ = test_utils.get_spin_spherical(_get_transformer(), shape,
                                              spins)
    spherical_mean = sphere_utils.spin_spherical_mean(inputs)

    model = layers.SphericalPooling(stride=resolution)
    params = model.init(_JAX_RANDOM_KEY, inputs)
    pooled = model.apply(params, inputs)

    # Tolerance here is higher because of slightly different quadratures.
    self.assertAllClose(spherical_mean, pooled[:, 0, 0], atol=1e-3)
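Pooling with stride equal to the full resolution collapses the grid to a single cell, which, up to quadrature weights, is the area-weighted mean over the sphere; that is what makes the comparison above meaningful. A simplified sketch of such a weighted mean (this is not the library's spin_spherical_mean nor its exact quadrature, which is why the test uses the looser atol=1e-3):

import numpy as np
import jax.numpy as jnp


def naive_spherical_mean(values):
  """Approximate area-weighted mean over (batch, lat, lon, spin, channel).

  Weights each latitude ring by sin(colatitude); a simplified stand-in shown
  only to illustrate what the single pooled cell approximates.
  """
  num_lat = values.shape[1]
  colatitude = (np.arange(num_lat) + 0.5) * np.pi / num_lat
  weights = jnp.asarray(np.sin(colatitude) / np.sin(colatitude).sum())
  longitude_mean = values.mean(axis=2)  # (batch, lat, spin, channel)
  return jnp.einsum('l,blsc->bsc', weights, longitude_mean)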
Example 4
def _evaluate_magnitudenonlinearity_versions(spins):
  """Evaluates MagnitudeNonlinearity and MagnitudeNonlinearityLeakyRelu."""
  transformer = _get_transformer()
  inputs, _ = test_utils.get_spin_spherical(transformer,
                                            shape=(2, 8, 8, len(spins), 2),
                                            spins=spins)
  model = layers.MagnitudeNonlinearity(
      bias_initializer=_magnitude_nonlinearity_nonzero_initializer)
  params = model.init(_JAX_RANDOM_KEY, inputs)
  outputs = model.apply(params, inputs)

  model_relu = layers.MagnitudeNonlinearityLeakyRelu(
      spins=spins,
      bias_initializer=_magnitude_nonlinearity_nonzero_initializer)
  params_relu = model_relu.init(_JAX_RANDOM_KEY, inputs)
  outputs_relu = model_relu.apply(params_relu, inputs)

  return inputs, outputs, outputs_relu
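The helper above relies on _magnitude_nonlinearity_nonzero_initializer, whose definition the listing omits. Consistent with the "Add negative bias so that the magnitude nonlinearity is active" comment in Example 1, a plausible stand-in is a constant negative-bias initializer with the standard Flax (key, shape, dtype) signature; the exact value is an assumption:

import jax.numpy as jnp


def _magnitude_nonlinearity_nonzero_initializer(key, shape, dtype=jnp.float32):
  """Hypothetical stand-in: constant negative bias keeps the nonlinearity active."""
  del key  # Deterministic initializer; the PRNG key is unused.
  return jnp.full(shape, -0.1, dtype=dtype)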