def test_mask_sparsity_mixed_mask(self):
    """Tests mask sparsity when the masked layers have different sparsities.

    The first layer's kernel mask is all zeros and the second layer's is all
    ones, so the expected sparsity is one minus the fraction of kernel entries
    that belong to the second layer.
    """
    params = self._masked_conv_model_twolayer.params
    kernel_0 = params['MaskedModule_0']['unmasked']['kernel']
    kernel_1 = params['MaskedModule_1']['unmasked']['kernel']

    mixed_mask = {
        'MaskedModule_0': {
            'kernel': jnp.zeros(kernel_0.shape),
            'bias': None,
        },
        'MaskedModule_1': {
            'kernel': jnp.ones(kernel_1.shape),
            'bias': None,
        },
    }

    computed_sparsity = masked.mask_sparsity(mixed_mask)

    # Fraction of all kernel entries covered by the all-ones mask.
    ones_fraction = kernel_1.size / (kernel_0.size + kernel_1.size)

    self.assertAlmostEqual(computed_sparsity, 1.0 - ones_fraction)
def test_prune_one_layer_conv_no_mask(self):
    """Tests pruning a single conv. layer model when no mask is supplied."""
    result_mask = pruning.prune(self._masked_conv_model, 0.5)
    sparsity = masked.mask_sparsity(result_mask)

    with self.subTest(name='test_mask_param_not_none'):
        kernel_mask = result_mask['MaskedModule_0']['kernel']
        self.assertNotEmpty(kernel_mask)

    with self.subTest(name='test_mask_sparsity'):
        # Only one decimal place is checked for the conv. model.
        self.assertAlmostEqual(sparsity, 0.5, places=1)
def test_prune_two_layer_local_pruning_rate(self):
    """Tests pruning a two-layer model using a per-layer pruning schedule.

    Only the second layer is given a pruning rate, so the first layer's mask
    is expected to have zero sparsity.
    """
    layer_rates = {
        'MaskedModule_1': 0.5,
    }
    result_mask = pruning.prune(self._masked_model_twolayer, layer_rates)

    layer_0_mask = result_mask['MaskedModule_0']
    layer_0_sparsity = masked.mask_sparsity(layer_0_mask)
    layer_1_mask = result_mask['MaskedModule_1']
    layer_1_sparsity = masked.mask_sparsity(layer_1_mask)

    with self.subTest(name='test_mask_layer1_param_not_none'):
        self.assertNotEmpty(layer_0_mask['kernel'])

    with self.subTest(name='test_mask_layer2_param_not_none'):
        self.assertNotEmpty(layer_1_mask['kernel'])

    with self.subTest(name='test_mask_layer_0_sparsity'):
        self.assertEqual(layer_0_sparsity, 0.)

    with self.subTest(name='test_mask_layer_1_sparsity'):
        self.assertAlmostEqual(layer_1_sparsity, 0.5, places=3)
def test_prune_single_layer_local_pruning(self):
    """Tests pruning a single-layer model using a per-layer pruning schedule."""
    layer_rates = {
        'MaskedModule_0': 0.5,
    }
    result_mask = pruning.prune(self._masked_model, layer_rates)
    sparsity = masked.mask_sparsity(result_mask)

    with self.subTest(name='test_mask_param_not_none'):
        kernel_mask = result_mask['MaskedModule_0']['kernel']
        self.assertNotEmpty(kernel_mask)

    with self.subTest(name='test_mask_sparsity'):
        self.assertAlmostEqual(sparsity, 0.5, places=3)
def test_prune_two_layers_dense_no_mask(self):
    """Tests pruning a two dense layer model when no mask is supplied.

    A single pruning rate is passed for the whole model and the sparsity of
    the complete resulting mask is checked against it.
    """
    result_mask = pruning.prune(self._masked_model_twolayer, 0.5)
    overall_sparsity = masked.mask_sparsity(result_mask)

    with self.subTest(name='test_mask_layer1_param_not_none'):
        first_kernel_mask = result_mask['MaskedModule_0']['kernel']
        self.assertNotEmpty(first_kernel_mask)

    with self.subTest(name='test_mask_layer2_param_not_none'):
        second_kernel_mask = result_mask['MaskedModule_1']['kernel']
        self.assertNotEmpty(second_kernel_mask)

    with self.subTest(name='test_mask_sparsity'):
        self.assertAlmostEqual(overall_sparsity, 0.5, places=3)
def test_prune_single_layer_dense_with_mask(self):
    """Tests pruning a single dense layer when an existing mask is supplied.

    The existing mask comes from `masked.shuffled_mask` called with 0.95, and
    the pruned mask is expected to have a sparsity of 0.95.
    """
    existing_mask = masked.shuffled_mask(self._masked_model, self._rng, 0.95)
    result_mask = pruning.prune(self._masked_model, 0.5, mask=existing_mask)
    sparsity = masked.mask_sparsity(result_mask)

    with self.subTest(name='test_mask_param_not_none'):
        kernel_mask = result_mask['MaskedModule_0']['kernel']
        self.assertNotEmpty(kernel_mask)

    with self.subTest(name='test_mask_sparsity'):
        self.assertAlmostEqual(sparsity, 0.95, places=3)
def get_mask_stats(mask):
    """Collects summary statistics for a model mask.

    Args:
      mask: A model mask to calculate the statistics of.

    Returns:
      The dictionary produced by `count_permutations_mask` for the mask,
      extended with the keys 'sparsity', 'permutation_num_digits' and
      'permutation_log10'.
    """
    stats = count_permutations_mask(mask)

    sparsity = masked.mask_sparsity(mask)
    # Number of characters in the decimal string form of the count.
    num_digits = len(str(stats['permutations']))
    # Offsetting by one keeps log10 defined for a zero count.
    log10_value = math.log10(stats['permutations'] + 1)

    stats.update(
        sparsity=sparsity,
        permutation_num_digits=num_digits,
        permutation_log10=log10_value,
    )
    return stats
def test_mask_sparsity_ones_mask(self):
    """Tests mask sparsity calculation with a mask full of ones."""
    # Build the kernel mask with jnp.ones so the test matches its name.
    one_mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])

    # An all-ones mask contains no zero entries, so its sparsity is 0.
    self.assertEqual(masked.mask_sparsity(one_mask), 0.)
def test_mask_sparsity_zero_mask(self):
    """Tests mask sparsity calculation with a zeroed mask."""
    # Build the kernel mask with jnp.zeros so the test matches its name.
    zero_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])

    # An all-zeros mask consists entirely of zero entries, so its sparsity is 1.
    self.assertEqual(masked.mask_sparsity(zero_mask), 1.)