Example #1
0
    def test_equivariance(self, train):
        """Checks that SpinSphericalBlock is approximately rotation-equivariant.

        Args:
          train: Forwarded as the `train` argument when applying the block.
            Initialization always uses `train=False`.
        """
        spins = (0, -1, 2)
        resolution = 8
        transformer = _get_transformer()
        block = models.SpinSphericalBlock(
            num_channels=2,
            spins_in=spins,
            spins_out=spins,
            downsampling_factor=1,
            axis_name=None,
            transformer=transformer,
        )

        # `batch_stats` is marked mutable for the apply call; presumably needed
        # so the block can update its statistics when `train` is True.
        outputs_a, outputs_b, _ = test_utils.apply_model_to_rotated_pairs(
            transformer,
            block,
            resolution,
            spins,
            init_args={"train": False},
            apply_args={"train": train, "mutable": ["batch_stats"]},
        )

        # Equivariance is only approximate here, so rather than comparing
        # elementwise we bound the normalized mean absolute error loosely.
        error = _normalized_mean_absolute_error(outputs_a, outputs_b)
        self.assertLess(error, 0.1)
    def test_equivariance(self):
        """Checks that SpinSphericalConvolution is rotation-equivariant."""
        spins = (0, -1, 2)
        resolution = 8
        transformer = _get_transformer()
        convolution = layers.SpinSphericalConvolution(
            transformer=transformer,
            spins_in=spins,
            spins_out=spins,
            features=2,
        )

        outputs_a, outputs_b, _ = test_utils.apply_model_to_rotated_pairs(
            transformer, convolution, resolution, spins)

        # A tight elementwise tolerance is used for this layer, unlike the
        # looser checks applied to approximately equivariant layers.
        self.assertAllClose(outputs_a, outputs_b, atol=1e-6)
    def test_equivariance(self):
        """Checks that MagnitudeNonlinearity is approximately rotation-equivariant."""
        spins = (0, -1, 2)
        resolution = 16
        transformer = _get_transformer()

        # A nonzero bias initializer is used for the nonlinearity.
        nonlinearity = layers.MagnitudeNonlinearity(
            bias_initializer=_magnitude_nonlinearity_nonzero_initializer)
        outputs_a, outputs_b, _ = test_utils.apply_model_to_rotated_pairs(
            transformer, nonlinearity, resolution, spins)

        # Equivariance is only approximate: allow a loose elementwise tolerance,
        # but additionally require a small mean absolute error.
        self.assertAllClose(outputs_a, outputs_b, atol=1e-1)
        mean_abs_error = abs(outputs_a - outputs_b).mean()
        self.assertLess(mean_abs_error, 5e-3)
    def test_equivariance(self, train):
        """Checks that SpinSphericalBatchNormalization is rotation-equivariant.

        Args:
          train: When True the layer is applied with `use_running_stats=False`;
            otherwise with `use_running_stats=True`. Initialization always uses
            `use_running_stats=True`.
        """
        spins = (0, 1)
        resolution = 16
        transformer = _get_transformer()
        batch_norm = layers.SpinSphericalBatchNormalization(spins=spins)

        outputs_a, outputs_b, _ = test_utils.apply_model_to_rotated_pairs(
            transformer,
            batch_norm,
            resolution,
            spins,
            init_args={"use_running_stats": True},
            apply_args={
                "use_running_stats": not train,
                "mutable": ["batch_stats"],
            },
        )

        self.assertAllClose(outputs_a, outputs_b, atol=1e-5)
Example #5
0
    def test_apply_model_to_rotated_pairs_with_simple_model(self):
        """Checks `apply_model_to_rotated_pairs` using a trivial doubling model.

        Both outputs returned by the helper must equal twice the rotated
        coefficients stored in the returned pair.
        """
        transformer = _get_transformer()
        resolution = 8
        spins = (0, 1)

        class Double(nn.Module):
            """Module with no parameters that returns twice its input."""

            @nn.compact
            def __call__(self, inputs):
                return 2 * inputs

        outputs_a, outputs_b, pair = test_utils.apply_model_to_rotated_pairs(
            transformer, Double(), resolution, spins)

        expected = 2 * pair.rotated_coefficients
        for outputs in (outputs_a, outputs_b):
            self.assertAllClose(expected, outputs)