def test_prune_one_layer_conv_no_mask(self):
    """Tests pruning of model with one conv. layer without an existing mask.

    Prunes the single-conv model at a uniform rate of 0.5 (no starting mask
    is supplied) and checks that the resulting kernel mask exists and that
    the overall mask sparsity lands near 0.5 (one decimal place).
    """
    new_mask = pruning.prune(self._masked_conv_model, 0.5)
    sparsity = masked.mask_sparsity(new_mask)

    with self.subTest(name='test_mask_param_not_none'):
        kernel_mask = new_mask['MaskedModule_0']['kernel']
        self.assertNotEmpty(kernel_mask)

    with self.subTest(name='test_mask_sparsity'):
        # Only one decimal place of precision is required for the conv case.
        self.assertAlmostEqual(sparsity, 0.5, places=1)
def test_prune_single_layer_local_pruning(self):
    """Test pruning of model with a single layer, and local pruning schedule.

    The rate is given per layer, as a dict keyed by module name, instead of
    as a single global float. The resulting mask must be non-empty for the
    layer and have overall sparsity of about 0.5 (three decimal places).
    """
    per_layer_rates = {
        'MaskedModule_0': 0.5,
    }
    new_mask = pruning.prune(self._masked_model, per_layer_rates)
    sparsity = masked.mask_sparsity(new_mask)

    with self.subTest(name='test_mask_param_not_none'):
        kernel_mask = new_mask['MaskedModule_0']['kernel']
        self.assertNotEmpty(kernel_mask)

    with self.subTest(name='test_mask_sparsity'):
        self.assertAlmostEqual(sparsity, 0.5, places=3)
def test_prune_two_layers_dense_no_mask(self):
    """Tests pruning of model with two dense layers without an existing mask.

    Prunes both dense layers at a uniform rate of 0.5. The test checks that
    each layer's kernel mask is non-empty and that the combined mask
    sparsity is about 0.5 (three decimal places).
    """
    new_mask = pruning.prune(self._masked_model_twolayer, 0.5)
    sparsity = masked.mask_sparsity(new_mask)

    # Sub-test labels are 1-based while module names are 0-based.
    layer_checks = (
        ('test_mask_layer1_param_not_none', 'MaskedModule_0'),
        ('test_mask_layer2_param_not_none', 'MaskedModule_1'),
    )
    for label, module_name in layer_checks:
        with self.subTest(name=label):
            self.assertNotEmpty(new_mask[module_name]['kernel'])

    with self.subTest(name='test_mask_sparsity'):
        self.assertAlmostEqual(sparsity, 0.5, places=3)
def test_prune_single_layer_dense_with_mask(self):
    """Tests pruning of single dense layer with an existing mask.

    A shuffled starting mask is built with sparsity 0.95. The model is then
    pruned at rate 0.5 with that mask supplied. The test expects the result
    to keep sparsity of about 0.95 (three decimal places), i.e. the lower
    requested rate must not reduce the existing sparsity.
    """
    starting_mask = masked.shuffled_mask(self._masked_model, self._rng, 0.95)
    new_mask = pruning.prune(self._masked_model, 0.5, mask=starting_mask)
    sparsity = masked.mask_sparsity(new_mask)

    with self.subTest(name='test_mask_param_not_none'):
        kernel_mask = new_mask['MaskedModule_0']['kernel']
        self.assertNotEmpty(kernel_mask)

    with self.subTest(name='test_mask_sparsity'):
        self.assertAlmostEqual(sparsity, 0.95, places=3)
def test_prune_two_layer_local_pruning_rate(self):
    """Test pruning of model with two layers, and a local pruning schedule.

    Only 'MaskedModule_1' appears in the per-layer schedule. The test
    expects 'MaskedModule_0' to still have a mask entry but with zero
    sparsity, and 'MaskedModule_1' to reach sparsity of about 0.5.
    """
    per_layer_rates = {
        'MaskedModule_1': 0.5,
    }
    new_mask = pruning.prune(self._masked_model_twolayer, per_layer_rates)

    first_layer_mask = new_mask['MaskedModule_0']
    second_layer_mask = new_mask['MaskedModule_1']
    first_sparsity = masked.mask_sparsity(first_layer_mask)
    second_sparsity = masked.mask_sparsity(second_layer_mask)

    # Sub-test labels are 1-based while module names are 0-based.
    layer_checks = (
        ('test_mask_layer1_param_not_none', 'MaskedModule_0'),
        ('test_mask_layer2_param_not_none', 'MaskedModule_1'),
    )
    for label, module_name in layer_checks:
        with self.subTest(name=label):
            self.assertNotEmpty(new_mask[module_name]['kernel'])

    with self.subTest(name='test_mask_layer_0_sparsity'):
        # The layer missing from the schedule is expected to be left untouched.
        self.assertEqual(first_sparsity, 0.)

    with self.subTest(name='test_mask_layer_1_sparsity'):
        self.assertAlmostEqual(second_sparsity, 0.5, places=3)