Esempio n. 1
0
    def test_mask_layer_sparsity_ones_mask(self):
        """Tests that an all-zeros mask reports a layer sparsity of exactly 1.

        NOTE(review): despite the "ones" in the method name, the mask here is
        built with `jnp.zeros`; "ones" matches the expected sparsity value
        (1.0), not the mask contents. The method name is left unchanged so
        that selecting this test by name keeps working.
        """
        # `jnp.zeros` is handed to simple_mask for the 'kernel' parameters,
        # presumably producing mask arrays whose entries are all 0.
        zeros_mask = masked.simple_mask(self._masked_model, jnp.zeros,
                                        ['kernel'])

        # With no nonzero mask entries the layer is expected to be fully
        # sparse.
        self.assertEqual(
            masked.mask_layer_sparsity(zeros_mask['MaskedModule_0']), 1.)
Esempio n. 2
0
    def test_mask_layer_sparsity_zero_mask(self):
        """Tests that an all-ones mask reports a layer sparsity of exactly 0.

        NOTE(review): despite the "zero" in the method name, the mask here is
        built with `jnp.ones`; "zero" matches the expected sparsity value
        (0.0), not the mask contents. The method name is left unchanged so
        that selecting this test by name keeps working.
        """
        # `jnp.ones` is handed to simple_mask for the 'kernel' parameters,
        # presumably producing mask arrays whose entries are all 1.
        ones_mask = masked.simple_mask(self._masked_model, jnp.ones,
                                       ['kernel'])

        # With every mask entry nonzero the layer is expected to have no
        # sparsity at all.
        self.assertEqual(
            masked.mask_layer_sparsity(ones_mask['MaskedModule_0']), 0.)
Esempio n. 3
0
    def test_mask_layer_sparsity_half_mask(self):
        """Checks the sparsity measured on a randomly shuffled mask.

        The mask is requested from `shuffled_mask` with 0.5; the sparsity
        measured on 'MaskedModule_0' is expected to be approximately 0.5.
        """
        requested = 0.5
        shuffled = masked.shuffled_mask(self._masked_model, self._rng,
                                        requested)

        layer_mask = shuffled['MaskedModule_0']
        measured = masked.mask_layer_sparsity(layer_mask)

        self.assertAlmostEqual(measured, requested)