Esempio n. 1
0
 def test_random_invalid_model(self):
     """Checks random_mask raises ValueError for a model with no masked layers.

     `self._unmasked_model` contains no layer wrapped by a MaskedModule, so
     requesting a 0.5-sparsity random mask for it must fail with the message
     matched below (as a regex).
     """
     expected_message = ('Model does not support masking, i.e. no layers are '
                         'wrapped by a MaskedModule.')
     with self.assertRaisesRegex(ValueError, expected_message):
         masked.random_mask(self._unmasked_model, self._rng, 0.5)
Esempio n. 2
0
    def test_random_invalid_sparsity(self):
        """Tests random mask with invalid sparsity.

        Sparsities below 0 and above 1 must each raise ValueError whose message
        names the offending value. The message is matched as a regex, so the
        decimal point in each value is escaped; an unescaped '.' would match
        any character and make the assertion looser than intended.
        """

        with self.subTest(name='random_sparsity_too_small'):
            with self.assertRaisesRegex(
                    ValueError,
                    r'Given sparsity, -0\.5, is not in range \[0, 1\]'):
                masked.random_mask(self._masked_model, self._rng, -0.5)

        with self.subTest(name='random_sparsity_too_large'):
            with self.assertRaisesRegex(
                    ValueError,
                    r'Given sparsity, 1\.5, is not in range \[0, 1\]'):
                masked.random_mask(self._masked_model, self._rng, 1.5)
Esempio n. 3
0
    def test_random_mask_sparsity_half_full(self):
        """Checks a 0.5-sparsity random mask sums to roughly half the kernel.

        The sum of the MaskedModule_0 kernel mask must lie between 0.66x and
        1.33x of half the kernel's element count. The tolerance is wide because
        the mask is generated randomly.
        """
        mask = masked.random_mask(self._masked_model, self._rng, 0.5)
        module_params = self._masked_model.params['MaskedModule_0']
        expected_sum = module_params['unmasked']['kernel'].size / 2

        with self.subTest(name='random_mask_values'):
            mask_sum = jnp.sum(mask['MaskedModule_0']['kernel'])
            lower_bound = 0.66 * expected_sum
            upper_bound = 1.33 * expected_sum
            self.assertBetween(mask_sum, lower_bound, upper_bound)
Esempio n. 4
0
    def test_random_mask_sparsity_full(self):
        """Tests random mask generation, for 100% sparsity.

        With sparsity 1.0 every entry of the MaskedModule_0 kernel mask must be
        zero. Running the masked model with that mask must produce an output
        whose elements are all zero and whose shape matches the unmasked
        model's output.
        """
        mask = masked.random_mask(self._masked_model, self._rng, 1.)

        with self.subTest(name='random_full_mask_values'):
            self.assertTrue((mask['MaskedModule_0']['kernel'] == 0).all())

        masked_output = self._masked_model(self._input, mask=mask)

        with self.subTest(name='random_full_mask_dense_values'):
            # Compare element-wise and only then reduce. Reducing first
            # (`masked_output.all() == 0`) would collapse the output to a
            # single boolean and only check that *some* element is zero.
            # NOTE(review): an all-zero output assumes the masked layer's bias
            # (if any) is zero-initialized; confirm against the fixture.
            self.assertTrue((masked_output == 0).all())

        with self.subTest(name='random_full_mask_dense_shape'):
            self.assertSequenceEqual(masked_output.shape,
                                     self._unmasked_output.shape)
Esempio n. 5
0
    def test_random_mask_sparsity_empty(self):
        """Checks a 0-sparsity random mask keeps every kernel entry.

        The MaskedModule_0 kernel mask must sum to its element count. The
        output of the model run with this mask must be element-wise close to,
        and have the same shape as, the unmasked model's output.
        """
        mask = masked.random_mask(self._masked_model, self._rng, 0.)

        with self.subTest(name='random_empty_mask_values'):
            kernel_mask = mask['MaskedModule_0']['kernel']
            self.assertEqual(jnp.sum(kernel_mask), kernel_mask.size)

        masked_output = self._masked_model(self._input, mask=mask)

        with self.subTest(name='random_empty_dense_values'):
            outputs_close = jnp.isclose(masked_output, self._unmasked_output)
            self.assertTrue(outputs_close.all())

        with self.subTest(name='random_empty_mask_dense_shape'):
            expected_shape = self._unmasked_output.shape
            self.assertSequenceEqual(masked_output.shape, expected_shape)