def test_equivariance(self, train):
    """Checks that SpinSphericalBlock is approximately rotation-equivariant.

    The block is initialized with `train=False` and applied with the given
    `train` flag, with 'batch_stats' marked mutable during application.

    Args:
      train: Value forwarded as the `train` keyword when applying the block.
    """
    grid_resolution = 8
    block_spins = (0, -1, 2)
    transformer = _get_transformer()
    block = models.SpinSphericalBlock(
        num_channels=2,
        spins_in=block_spins,
        spins_out=block_spins,
        downsampling_factor=1,
        axis_name=None,
        transformer=transformer,
    )

    outputs_a, outputs_b, _ = test_utils.apply_model_to_rotated_pairs(
        transformer,
        block,
        grid_resolution,
        block_spins,
        init_args={"train": False},
        apply_args={"train": train, "mutable": ['batch_stats']},
    )

    # Equivariance only holds approximately for this block, so rather than an
    # elementwise comparison we bound the normalized mean absolute error with
    # a deliberately loose threshold.
    error = _normalized_mean_absolute_error(outputs_a, outputs_b)
    self.assertLess(error, 0.1)
def test_equivariance(self):
    """Checks that SpinSphericalConvolution is rotation-equivariant.

    The outputs for a rotated input and the rotated outputs for the original
    input must agree elementwise to within an absolute tolerance of 1e-6.
    """
    grid_resolution = 8
    conv_spins = (0, -1, 2)
    transformer = _get_transformer()
    convolution = layers.SpinSphericalConvolution(
        transformer=transformer,
        spins_in=conv_spins,
        spins_out=conv_spins,
        features=2,
    )

    outputs_a, outputs_b, _ = test_utils.apply_model_to_rotated_pairs(
        transformer, convolution, grid_resolution, conv_spins)

    self.assertAllClose(outputs_a, outputs_b, atol=1e-6)
def test_equivariance(self):
    """Checks that MagnitudeNonlinearity is approximately rotation-equivariant.

    Uses a nonzero bias initializer, so the nonlinearity is actually active.
    """
    grid_resolution = 16
    nonlinearity_spins = (0, -1, 2)
    transformer = _get_transformer()
    nonlinearity = layers.MagnitudeNonlinearity(
        bias_initializer=_magnitude_nonlinearity_nonzero_initializer)

    outputs_a, outputs_b, _ = test_utils.apply_model_to_rotated_pairs(
        transformer, nonlinearity, grid_resolution, nonlinearity_spins)

    # Equivariance is only approximate here, so the elementwise tolerance is
    # loose. A much tighter bound on the mean absolute error compensates.
    self.assertAllClose(outputs_a, outputs_b, atol=1e-1)
    mean_abs_error = abs(outputs_a - outputs_b).mean()
    self.assertLess(mean_abs_error, 5e-3)
def test_equivariance(self, train):
    """Checks that SpinSphericalBatchNormalization is rotation-equivariant.

    The layer is initialized with running statistics enabled. When `train` is
    true it is applied with batch statistics; otherwise it is applied with
    running statistics. In both cases 'batch_stats' is marked mutable.

    Args:
      train: When true, apply with `use_running_stats=False`.
    """
    grid_resolution = 16
    norm_spins = (0, 1)
    transformer = _get_transformer()
    batch_norm = layers.SpinSphericalBatchNormalization(spins=norm_spins)

    use_running_stats_when_applied = not train
    outputs_a, outputs_b, _ = test_utils.apply_model_to_rotated_pairs(
        transformer,
        batch_norm,
        grid_resolution,
        norm_spins,
        init_args={"use_running_stats": True},
        apply_args={
            "use_running_stats": use_running_stats_when_applied,
            "mutable": ["batch_stats"],
        },
    )

    self.assertAllClose(outputs_a, outputs_b, atol=1e-5)
def test_apply_model_to_rotated_pairs_with_simple_model(self):
    """Sanity-checks `apply_model_to_rotated_pairs` with a trivial model.

    The model below multiplies its inputs by two. Both outputs returned by
    the helper must therefore equal twice the pair's rotated coefficients.
    """
    transformer = _get_transformer()

    class Double(nn.Module):
        """Toy model returning its inputs scaled by 2."""

        @nn.compact
        def __call__(self, inputs):
            return 2 * inputs

    output_1, output_2, pair = test_utils.apply_model_to_rotated_pairs(
        transformer, Double(), 8, (0, 1))

    expected = 2 * pair.rotated_coefficients
    for actual in (output_1, output_2):
        self.assertAllClose(expected, actual)