示例#1
0
    def test_mask_sparsity_mixed_mask(self):
        """Tests mask sparsity for a mask whose layers have different sparsities.

        MaskedModule_0's kernel mask is all zeros (fully pruned) and
        MaskedModule_1's kernel mask is all ones (fully kept), so the overall
        sparsity is the fraction of kernel weights that belong to
        MaskedModule_0.
        """
        params = self._masked_conv_model_twolayer.params
        kernel_0 = params['MaskedModule_0']['unmasked']['kernel']
        kernel_1 = params['MaskedModule_1']['unmasked']['kernel']
        mask = {
            'MaskedModule_0': {
                'kernel': jnp.zeros(kernel_0.shape),
                'bias': None,
            },
            'MaskedModule_1': {
                'kernel': jnp.ones(kernel_1.shape),
                'bias': None,
            },
        }

        mask_sparsity = masked.mask_sparsity(mask)
        # Only kernel entries are counted in the expected value, i.e. the
        # None bias masks are expected to be ignored by mask_sparsity.
        expected_sparsity = kernel_0.size / (kernel_0.size + kernel_1.size)

        self.assertAlmostEqual(mask_sparsity, expected_sparsity)
示例#2
0
    def test_prune_one_layer_conv_no_mask(self):
        """Prunes a single-conv-layer model at rate 0.5 with no starting mask."""
        result_mask = pruning.prune(self._masked_conv_model, 0.5)
        overall_sparsity = masked.mask_sparsity(result_mask)

        with self.subTest(name='test_mask_param_not_none'):
            self.assertNotEmpty(result_mask['MaskedModule_0']['kernel'])

        # The conv model is only required to match to one decimal place.
        with self.subTest(name='test_mask_sparsity'):
            self.assertAlmostEqual(overall_sparsity, 0.5, places=1)
示例#3
0
  def test_prune_two_layer_local_pruning_rate(self):
    """Prunes only MaskedModule_1 of a two-layer model via a per-layer rate."""
    rate_by_layer = {'MaskedModule_1': 0.5}
    result_mask = pruning.prune(self._masked_model_twolayer, rate_by_layer)
    sparsity_0, sparsity_1 = (
        masked.mask_sparsity(result_mask[layer_name])
        for layer_name in ('MaskedModule_0', 'MaskedModule_1'))

    layer_checks = (
        ('test_mask_layer1_param_not_none', 'MaskedModule_0'),
        ('test_mask_layer2_param_not_none', 'MaskedModule_1'),
    )
    for subtest_name, layer_name in layer_checks:
      with self.subTest(name=subtest_name):
        self.assertNotEmpty(result_mask[layer_name]['kernel'])

    # MaskedModule_0 has no rate in the schedule, so it must stay dense.
    with self.subTest(name='test_mask_layer_0_sparsity'):
      self.assertEqual(sparsity_0, 0.)

    with self.subTest(name='test_mask_layer_1_sparsity'):
      self.assertAlmostEqual(sparsity_1, 0.5, places=3)
示例#4
0
    def test_prune_single_layer_local_pruning(self):
        """Prunes a one-layer model through a per-layer pruning-rate dict."""
        rate_by_layer = {'MaskedModule_0': 0.5}
        result_mask = pruning.prune(self._masked_model, rate_by_layer)
        overall_sparsity = masked.mask_sparsity(result_mask)

        with self.subTest(name='test_mask_param_not_none'):
            self.assertNotEmpty(result_mask['MaskedModule_0']['kernel'])

        with self.subTest(name='test_mask_sparsity'):
            self.assertAlmostEqual(overall_sparsity, 0.5, places=3)
示例#5
0
    def test_prune_two_layers_dense_no_mask(self):
        """Prunes a two-dense-layer model at a global rate of 0.5 from scratch."""
        result_mask = pruning.prune(self._masked_model_twolayer, 0.5)
        overall_sparsity = masked.mask_sparsity(result_mask)

        layer_checks = (
            ('test_mask_layer1_param_not_none', 'MaskedModule_0'),
            ('test_mask_layer2_param_not_none', 'MaskedModule_1'),
        )
        for subtest_name, layer_name in layer_checks:
            with self.subTest(name=subtest_name):
                self.assertNotEmpty(result_mask[layer_name]['kernel'])

        with self.subTest(name='test_mask_sparsity'):
            self.assertAlmostEqual(overall_sparsity, 0.5, places=3)
示例#6
0
    def test_prune_single_layer_dense_with_mask(self):
        """Prunes a single dense layer that already carries a 0.95-sparse mask.

        Pruning at rate 0.5 on top of the existing mask is expected to leave
        the overall sparsity at 0.95.
        """
        initial_mask = masked.shuffled_mask(self._masked_model, self._rng, 0.95)
        result_mask = pruning.prune(self._masked_model, 0.5, mask=initial_mask)
        overall_sparsity = masked.mask_sparsity(result_mask)

        with self.subTest(name='test_mask_param_not_none'):
            self.assertNotEmpty(result_mask['MaskedModule_0']['kernel'])

        with self.subTest(name='test_mask_sparsity'):
            self.assertAlmostEqual(overall_sparsity, 0.95, places=3)
示例#7
0
def get_mask_stats(mask):
    """Calculates a dictionary of mask statistics.

    Args:
      mask: A model mask to calculate the statistics of.

    Returns:
      The dictionary returned by `count_permutations_mask(mask)`, updated in
      place with:
        'sparsity': `masked.mask_sparsity(mask)`.
        'permutation_num_digits': length of the decimal representation of
          `mask_stats['permutations']` (i.e. `len(str(...))`).
        'permutation_log10': `log10(permutations + 1)`.
    """
    mask_stats = count_permutations_mask(mask)
    num_permutations = mask_stats['permutations']
    mask_stats.update({
        'sparsity': masked.mask_sparsity(mask),
        # Avoid str() on the permutation count: for huge ints it is quadratic
        # and, on Python >= 3.11, raises ValueError past 4300 digits.
        'permutation_num_digits': _num_decimal_digits(num_permutations),
        'permutation_log10': math.log10(num_permutations + 1),
    })

    return mask_stats


def _num_decimal_digits(value):
    """Returns `len(str(value))`, computed without str() for Python ints.

    Args:
      value: The value whose decimal representation length is wanted. Python
        ints of any size are handled arithmetically; anything else (e.g.
        NumPy scalars) falls back to `len(str(value))`.

    Returns:
      The number of characters in `str(value)`, including a leading '-' for
      negative ints.
    """
    if isinstance(value, bool) or not isinstance(value, int):
        return len(str(value))
    magnitude = abs(value)
    sign_width = 1 if value < 0 else 0
    if magnitude == 0:
        return 1
    # bit_length gives a float estimate of floor(log10) + 1 that may be off
    # by one; the two loops below correct it exactly with integer compares.
    digits = int(magnitude.bit_length() * math.log10(2)) + 1
    while digits > 1 and magnitude < 10 ** (digits - 1):
        digits -= 1
    while magnitude >= 10 ** digits:
        digits += 1
    return digits + sign_width
示例#8
0
    def test_mask_sparsity_ones_mask(self):
        """Tests mask sparsity of a kernel mask full of ones (nothing pruned)."""
        ones_mask = masked.simple_mask(self._masked_model, jnp.ones,
                                       ['kernel'])

        # Every kernel weight is kept, so no entry is sparse.
        self.assertEqual(masked.mask_sparsity(ones_mask), 0.)
示例#9
0
    def test_mask_sparsity_zero_mask(self):
        """Tests mask sparsity of a zeroed kernel mask (everything pruned)."""
        zero_mask = masked.simple_mask(self._masked_model, jnp.zeros,
                                       ['kernel'])

        # Every kernel weight is masked out, so the mask is fully sparse.
        self.assertEqual(masked.mask_sparsity(zero_mask), 1.)