    def test_magnitude_thresholding(self, input_shape):
        small_row = 3
        inputs = jnp.ones(input_shape)
        inputs = inputs.at[:, small_row].set(0.1)

        model = layers.MagnitudeNonlinearity()
        params = model.init(_JAX_RANDOM_KEY, inputs)

        # With the bias set to zero, the output must match the input.
        bias = params["params"]["bias"].at[:].set(0.0)
        params_zero = flax.core.FrozenDict({"params": {"bias": bias}})
        inputs_unchanged = model.apply(params_zero, inputs)
        self.assertAllClose(inputs, inputs_unchanged)

        # Run again with bias = -0.2; now out[:, small_row] must be zero.
        bias_value = -0.2
        bias = params["params"]["bias"].at[:].set(bias_value)
        params_changed = flax.core.FrozenDict({"params": {"bias": bias}})
        inputs_changed = model.apply(params_changed, inputs)
        self.assertAllEqual(inputs_changed[:, small_row],
                            np.zeros_like(inputs[:, small_row]))
        # All other rows have the bias added.
        self.assertAllClose(inputs_changed[:, :small_row],
                            inputs[:, :small_row] + bias_value)
        self.assertAllClose(inputs_changed[:, small_row + 1:],
                            inputs[:, small_row + 1:] + bias_value)
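
    # For reference, a minimal sketch of the rule exercised above, assuming the
    # layer zeroes entries whose magnitude falls below -bias and shifts the
    # magnitude of the rest (a hypothetical reference implementation, not the
    # layer's API):
    #
    #   def _magnitude_nonlinearity_reference(x, bias):
    #     magnitude = jnp.abs(x)
    #     return jnp.where(magnitude + bias > 0,
    #                      x * (magnitude + bias) / magnitude,
    #                      0.0)
    #
    # With bias = -0.2, entries of magnitude 0.1 are zeroed and unit entries
    # shrink to 0.8, matching the assertions above.
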
    def test_azimuthal_equivariance(self, shift):
        resolution = 8
        spins = (0, -1, 2)
        transformer = _get_transformer()

        model = layers.MagnitudeNonlinearity(
            bias_initializer=_magnitude_nonlinearity_nonzero_initializer)

        output_1, output_2 = test_utils.apply_model_to_azimuthally_rotated_pairs(
            transformer, model, resolution, spins, shift)

        self.assertAllClose(output_1, output_2)
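
    # `apply_model_to_azimuthally_rotated_pairs` is assumed to return the model
    # output for an input rotated azimuthally by `shift` samples alongside the
    # correspondingly rotated output for the original input, i.e. roughly
    # model(rotate(x)) versus rotate(model(x)) (a sketch of the helper's
    # contract, not its implementation). Matching outputs mean the layer is
    # azimuthally equivariant.
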
    def test_equivariance(self):
        resolution = 16
        spins = (0, -1, 2)
        transformer = _get_transformer()

        model = layers.MagnitudeNonlinearity(
            bias_initializer=_magnitude_nonlinearity_nonzero_initializer)
        coefficients_1, coefficients_2, _ = test_utils.apply_model_to_rotated_pairs(
            transformer, model, resolution, spins)
        # The tolerance must be loose here because equivariance is only
        # approximate for pointwise operations on the sampled sphere. We also
        # check the mean absolute error, which should be much smaller.
        self.assertAllClose(coefficients_1, coefficients_2, atol=1e-1)
        self.assertLess(abs(coefficients_1 - coefficients_2).mean(), 5e-3)
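
    # A sketch of the check above, assuming test_utils.apply_model_to_rotated_pairs
    # follows this contract (hypothetical; the real helper may differ):
    #
    #   1. Sample a rotation and a band-limited spin-weighted input.
    #   2. coefficients_1 = transform(model(rotate(input)))
    #   3. coefficients_2 = rotate_spectrally(transform(model(input)))
    #
    # Exact equivariance would make the two coefficient sets equal; the
    # pointwise nonlinearity leaks energy above the band limit, hence the
    # loose tolerance.

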
def _evaluate_magnitudenonlinearity_versions(spins):
    """Evaluates MagnitudeNonlinearity and MagnitudeNonlinearityLeakyRelu."""
    transformer = _get_transformer()
    inputs, _ = test_utils.get_spin_spherical(transformer,
                                              shape=(2, 8, 8, len(spins), 2),
                                              spins=spins)
    model = layers.MagnitudeNonlinearity(
        bias_initializer=_magnitude_nonlinearity_nonzero_initializer)
    params = model.init(_JAX_RANDOM_KEY, inputs)
    outputs = model.apply(params, inputs)

    model_relu = layers.MagnitudeNonlinearityLeakyRelu(
        spins=spins,
        bias_initializer=_magnitude_nonlinearity_nonzero_initializer)
    params_relu = model_relu.init(_JAX_RANDOM_KEY, inputs)
    outputs_relu = model_relu.apply(params_relu, inputs)

    return inputs, outputs, outputs_relu
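

# A minimal usage sketch for the helper above (a hypothetical check, not part
# of the original suite): both versions are initialized with the same key and
# bias initializer, and the leaky-relu variant is assumed to differ only on
# spin-0 channels, so their outputs should agree on nonzero spins.
def _check_versions_agree_on_nonzero_spins():
    spins = (0, 1)
    _, outputs, outputs_relu = _evaluate_magnitudenonlinearity_versions(spins)
    # The spin axis is 3; index 1 selects spin == 1.
    np.testing.assert_allclose(outputs[..., 1, :], outputs_relu[..., 1, :])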