def test_random_invalid_model(self):
  """Checks that random_mask raises for a model without masked layers."""
  expected_message = ('Model does not support masking, i.e. no layers are '
                      'wrapped by a MaskedModule.')

  # Masking self._unmasked_model is expected to fail with a ValueError whose
  # message matches the pattern above.
  with self.assertRaisesRegex(ValueError, expected_message):
    masked.random_mask(self._unmasked_model, self._rng, 0.5)
def test_random_invalid_sparsity(self):
  """Checks that random_mask raises for sparsity values outside [0, 1]."""
  # Each entry: (subtest name, sparsity argument, regex the error must match).
  invalid_cases = (
      ('random_sparsity_too_small', -0.5,
       r'Given sparsity, -0.5, is not in range \[0, 1\]'),
      ('random_sparsity_too_large', 1.5,
       r'Given sparsity, 1.5, is not in range \[0, 1\]'),
  )

  for subtest_name, sparsity, error_regex in invalid_cases:
    with self.subTest(name=subtest_name):
      with self.assertRaisesRegex(ValueError, error_regex):
        masked.random_mask(self._masked_model, self._rng, sparsity)
def test_random_mask_sparsity_half_full(self):
  """Checks the mask sum for a random mask generated with sparsity 0.5."""
  mask = masked.random_mask(self._masked_model, self._rng, 0.5)

  unmasked_params = self._masked_model.params['MaskedModule_0']['unmasked']
  expected_sum = unmasked_params['kernel'].size / 2

  # The mask is random, so only require the sum to land in a tolerance band
  # around half of the kernel size.
  lower_bound = 0.66 * expected_sum
  upper_bound = 1.33 * expected_sum

  with self.subTest(name='random_mask_values'):
    self.assertBetween(
        jnp.sum(mask['MaskedModule_0']['kernel']), lower_bound, upper_bound)
def test_random_mask_sparsity_full(self):
  """Tests random mask generation, for 100% sparsity.

  Verifies that:
    * every entry of the generated kernel mask is zero,
    * the model output computed with that mask is zero everywhere,
    * the masked output has the same shape as the unmasked output.
  """
  mask = masked.random_mask(self._masked_model, self._rng, 1.)

  with self.subTest(name='random_full_mask_values'):
    self.assertTrue((mask['MaskedModule_0']['kernel'] == 0).all())

  masked_output = self._masked_model(self._input, mask=mask)

  with self.subTest(name='random_full_mask_dense_values'):
    # Compare element-wise first, then reduce. The previous form,
    # `masked_output.all() == 0`, reduced to a single boolean before the
    # comparison and so passed as soon as any one output element was zero.
    # NOTE(review): this assumes a fully masked kernel yields an exactly-zero
    # output (e.g. no non-zero bias term) -- confirm against the test model.
    self.assertTrue((masked_output == 0).all())

  with self.subTest(name='random_full_mask_dense_shape'):
    self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)
def test_random_mask_sparsity_empty(self):
  """Checks random mask generation with sparsity 0.

  The kernel mask must sum to its own size, and applying it must give an
  output that is close to, and shaped like, the unmasked output.
  """
  mask = masked.random_mask(self._masked_model, self._rng, 0.)

  with self.subTest(name='random_empty_mask_values'):
    kernel_mask = mask['MaskedModule_0']['kernel']
    self.assertEqual(jnp.sum(kernel_mask), kernel_mask.size)

  masked_output = self._masked_model(self._input, mask=mask)

  with self.subTest(name='random_empty_dense_values'):
    outputs_close = jnp.isclose(masked_output, self._unmasked_output)
    self.assertTrue(outputs_close.all())

  with self.subTest(name='random_empty_mask_dense_shape'):
    expected_shape = self._unmasked_output.shape
    self.assertSequenceEqual(masked_output.shape, expected_shape)