コード例 #1
0
ファイル: pruning_test.py プロジェクト: tawawhite/rigl
    def test_prune_one_layer_conv_no_mask(self):
        """Prune the single-conv-layer model at rate 0.5 with no prior mask.

        Expects the returned kernel mask to be non-empty and the overall mask
        sparsity to be close to 0.5 (compared with ``places=1``).
        """
        mask = pruning.prune(self._masked_conv_model, 0.5)
        sparsity = masked.mask_sparsity(mask)

        with self.subTest(name='test_mask_param_not_none'):
            kernel_mask = mask['MaskedModule_0']['kernel']
            self.assertNotEmpty(kernel_mask)

        with self.subTest(name='test_mask_sparsity'):
            # Loose tolerance (one decimal place) for the conv layer.
            self.assertAlmostEqual(sparsity, 0.5, places=1)
コード例 #2
0
ファイル: pruning_test.py プロジェクト: tawawhite/rigl
    def test_prune_single_layer_local_pruning(self):
        """Prune one layer using a per-layer (local) rate mapping.

        The mapping assigns rate 0.5 to 'MaskedModule_0'. Expects that layer's
        kernel mask to be non-empty and overall sparsity ~0.5 (``places=3``).
        """
        layer_rates = {'MaskedModule_0': 0.5}
        mask = pruning.prune(self._masked_model, layer_rates)
        sparsity = masked.mask_sparsity(mask)

        with self.subTest(name='test_mask_param_not_none'):
            kernel_mask = mask['MaskedModule_0']['kernel']
            self.assertNotEmpty(kernel_mask)

        with self.subTest(name='test_mask_sparsity'):
            self.assertAlmostEqual(sparsity, 0.5, places=3)
コード例 #3
0
ファイル: pruning_test.py プロジェクト: tawawhite/rigl
    def test_prune_two_layers_dense_no_mask(self):
        """Prune the two-dense-layer model at rate 0.5 with no prior mask.

        Expects both layers' kernel masks to be non-empty and the overall mask
        sparsity to be ~0.5 (``places=3``).
        """
        mask = pruning.prune(self._masked_model_twolayer, 0.5)
        sparsity = masked.mask_sparsity(mask)

        # One non-emptiness subtest per layer, checked in layer order.
        layer_checks = (
            ('test_mask_layer1_param_not_none', 'MaskedModule_0'),
            ('test_mask_layer2_param_not_none', 'MaskedModule_1'),
        )
        for subtest_name, layer_name in layer_checks:
            with self.subTest(name=subtest_name):
                self.assertNotEmpty(mask[layer_name]['kernel'])

        with self.subTest(name='test_mask_sparsity'):
            self.assertAlmostEqual(sparsity, 0.5, places=3)
コード例 #4
0
ファイル: pruning_test.py プロジェクト: tawawhite/rigl
    def test_prune_single_layer_dense_with_mask(self):
        """Prune a single dense layer at rate 0.5 starting from an existing mask.

        The starting mask comes from ``masked.shuffled_mask`` at 0.95. The test
        expects overall sparsity to remain ~0.95 (``places=3``) after pruning.
        Presumably this is because the existing mask is already sparser than
        the 0.5 rate (verify against ``pruning.prune``).
        """
        existing_mask = masked.shuffled_mask(
            self._masked_model, self._rng, 0.95)
        mask = pruning.prune(self._masked_model, 0.5, mask=existing_mask)
        sparsity = masked.mask_sparsity(mask)

        with self.subTest(name='test_mask_param_not_none'):
            kernel_mask = mask['MaskedModule_0']['kernel']
            self.assertNotEmpty(kernel_mask)

        with self.subTest(name='test_mask_sparsity'):
            self.assertAlmostEqual(sparsity, 0.95, places=3)
コード例 #5
0
  def test_prune_two_layer_local_pruning_rate(self):
    """Prune only the second of two layers via a per-layer rate mapping.

    'MaskedModule_0' is absent from the mapping. Its mask sparsity is expected
    to be exactly 0., while 'MaskedModule_1' is expected to reach ~0.5
    (``places=3``).
    """
    layer_rates = {'MaskedModule_1': 0.5}
    mask = pruning.prune(self._masked_model_twolayer, layer_rates)
    # Per-layer sparsity, computed in layer order.
    first_sparsity, second_sparsity = (
        masked.mask_sparsity(mask[layer_name])
        for layer_name in ('MaskedModule_0', 'MaskedModule_1'))

    layer_checks = (
        ('test_mask_layer1_param_not_none', 'MaskedModule_0'),
        ('test_mask_layer2_param_not_none', 'MaskedModule_1'),
    )
    for subtest_name, layer_name in layer_checks:
      with self.subTest(name=subtest_name):
        self.assertNotEmpty(mask[layer_name]['kernel'])

    with self.subTest(name='test_mask_layer_0_sparsity'):
      # Untouched layer: exact comparison against zero sparsity.
      self.assertEqual(first_sparsity, 0.)

    with self.subTest(name='test_mask_layer_1_sparsity'):
      self.assertAlmostEqual(second_sparsity, 0.5, places=3)