Exemple #1
0
 def test_shuffled_mask_invalid_model(self):
     """Tests shuffled mask with model containing no masked layers.

     Calls `masked.shuffled_mask` on the `_unmasked_model` fixture and expects
     a `ValueError` whose message matches the pattern below.
     """
     # assertRaisesRegex interprets the expected message as a regular
     # expression, so the literal periods are escaped; otherwise each '.'
     # would match any character and the check would be looser than intended.
     expected_message = (
         r'Model does not support masking, i\.e\. no layers are '
         r'wrapped by a MaskedModule\.')
     with self.assertRaisesRegex(ValueError, expected_message):
         masked.shuffled_mask(self._unmasked_model, self._rng, 0.5)
Exemple #2
0
    def test_shuffled_mask_invalid_sparsity(self):
        """Tests shuffled mask with invalid sparsity.

        Two out-of-range sparsities are tried against the `_masked_model`
        fixture (-0.5 and 1.5); each call is expected to raise a `ValueError`
        whose message matches the corresponding pattern.
        """
        # The expected messages are regular expressions: the decimal points are
        # escaped (like the brackets) so that only the literal value matches.
        with self.subTest(name='sparsity_too_small'):
            with self.assertRaisesRegex(
                    ValueError,
                    r'Given sparsity, -0\.5, is not in range \[0, 1\]'):
                masked.shuffled_mask(self._masked_model, self._rng, -0.5)

        with self.subTest(name='sparsity_too_large'):
            with self.assertRaisesRegex(
                    ValueError,
                    r'Given sparsity, 1\.5, is not in range \[0, 1\]'):
                masked.shuffled_mask(self._masked_model, self._rng, 1.5)
Exemple #3
0
    def test_shuffled_mask_sparsity_empty_twolayer(self):
        """Checks a 0%-sparsity shuffled mask on the two-layer model.

        Both layer entries must be present in the mask with all-ones kernel
        masks, and applying the mask must give an output that is close to the
        `_unmasked_output_twolayer` fixture and has the same shape.
        """
        model = self._masked_model_twolayer
        mask = masked.shuffled_mask(model, self._rng, 0.0)

        with self.subTest(name='shuffled_empty_mask_layer1'):
            self.assertIn('MaskedModule_0', mask)

        with self.subTest(name='shuffled_empty_mask_values_layer1'):
            first_kernel_mask = mask['MaskedModule_0']['kernel']
            self.assertTrue((first_kernel_mask == 1).all())

        with self.subTest(name='shuffled_empty_mask_layer2'):
            self.assertIn('MaskedModule_1', mask)

        with self.subTest(name='shuffled_empty_mask_values_layer2'):
            second_kernel_mask = mask['MaskedModule_1']['kernel']
            self.assertTrue((second_kernel_mask == 1).all())

        # Run the forward pass with the generated mask applied.
        masked_output = model(self._input, mask=mask)

        with self.subTest(name='shuffled_empty_dense_values'):
            elementwise_close = jnp.isclose(masked_output,
                                            self._unmasked_output_twolayer)
            self.assertTrue(elementwise_close.all())

        with self.subTest(name='shuffled_empty_mask_dense_shape'):
            expected_shape = self._unmasked_output_twolayer.shape
            self.assertSequenceEqual(masked_output.shape, expected_shape)
Exemple #4
0
    def test_shuffled_mask_sparsity_full_twolayer(self):
        """Tests shuffled mask generation for two layers, and 100% sparsity.

        For each of the two layer entries, checks that the entry exists, that
        its kernel mask is all zeros, and that its bias mask is None. Then
        applies the mask and checks that the output is all zeros with the same
        shape as the `_unmasked_output_twolayer` fixture.
        """
        mask = masked.shuffled_mask(self._masked_model_twolayer, self._rng,
                                    1.0)

        with self.subTest(name='shuffled_full_mask_layer1'):
            self.assertIn('MaskedModule_0', mask)

        with self.subTest(name='shuffled_full_mask_values_layer1'):
            self.assertTrue((mask['MaskedModule_0']['kernel'] == 0).all())

        with self.subTest(name='shuffled_full_mask_not_masked_values_layer1'):
            self.assertIsNone(mask['MaskedModule_0']['bias'])

        with self.subTest(name='shuffled_full_mask_layer2'):
            self.assertIn('MaskedModule_1', mask)

        with self.subTest(name='shuffled_full_mask_values_layer2'):
            self.assertTrue((mask['MaskedModule_1']['kernel'] == 0).all())

        # Each sub-test name carries its layer number so that a failure report
        # points at the layer actually being checked.
        with self.subTest(name='shuffled_full_mask_not_masked_values_layer2'):
            self.assertIsNone(mask['MaskedModule_1']['bias'])

        masked_output = self._masked_model_twolayer(self._input, mask=mask)

        with self.subTest(name='shuffled_full_mask_dense_values'):
            self.assertTrue((masked_output == 0).all())

        with self.subTest(name='shuffled_full_mask_dense_shape'):
            self.assertSequenceEqual(masked_output.shape,
                                     self._unmasked_output_twolayer.shape)
Exemple #5
0
    def test_shuffled_mask_sparsity_half_full(self):
        """Checks that a 50%-sparsity shuffled mask keeps half of the kernel.

        The sum of the first layer's kernel mask must equal the floor of half
        the number of elements in that layer's unmasked kernel parameter.
        """
        mask = masked.shuffled_mask(self._masked_model, self._rng, 0.5)

        # Number of kernel elements in the first masked layer.
        layer_params = self._masked_model.params['MaskedModule_0']
        param_len = layer_params['unmasked']['kernel'].size

        with self.subTest(name='shuffled_mask_values'):
            kept_weights = jnp.sum(mask['MaskedModule_0']['kernel'])
            self.assertEqual(kept_weights, param_len // 2)
Exemple #6
0
    def test_prune_one_layer_conv_with_mask(self):
        """Checks pruning of the conv model when an existing mask is supplied.

        A shuffled mask with 0.95 sparsity is passed to `pruning.prune` together
        with a pruning rate of 0.5; the resulting mask must have a non-empty
        kernel entry and an overall sparsity of approximately 0.95.
        """
        # NOTE(review): the initial mask is generated from `_masked_model`
        # while `_masked_conv_model` is the model being pruned -- confirm this
        # mismatch is intentional.
        initial_mask = masked.shuffled_mask(self._masked_model, self._rng,
                                            0.95)
        pruned_mask = pruning.prune(self._masked_conv_model,
                                    0.5,
                                    mask=initial_mask)
        mask_sparsity = masked.mask_sparsity(pruned_mask)

        with self.subTest(name='test_mask_param_not_none'):
            kernel_mask = pruned_mask['MaskedModule_0']['kernel']
            self.assertNotEmpty(kernel_mask)

        with self.subTest(name='test_mask_sparsity'):
            self.assertAlmostEqual(mask_sparsity, 0.95, places=3)
Exemple #7
0
  def test_count_permutations_shuffled_full_mask(self):
    """Checks permutation statistics for a generated mask with sparsity 1.

    The statistics returned by `symmetry.count_permutations_mask` must report
    zero unique neurons and zero permutations, with both the zeroed and total
    neuron counts equal to `MaskedConv.NUM_FEATURES`.
    """
    full_mask = masked.shuffled_mask(
        self._masked_model, rng=self._rng, sparsity=1)
    stats = symmetry.count_permutations_mask(full_mask)

    # (sub-test name, stats key, expected value), checked in this order.
    expectations = (
        ('count_permutations_mask_unique', 'unique_neurons', 0),
        ('count_permutations_permutations', 'permutations', 0),
        ('count_permutations_zeroed', 'zeroed_neurons',
         MaskedConv.NUM_FEATURES),
        ('count_permutations_total', 'total_neurons',
         MaskedConv.NUM_FEATURES),
    )

    for subtest_name, stats_key, expected in expectations:
      with self.subTest(name=subtest_name):
        self.assertEqual(stats[stats_key], expected)
Exemple #8
0
    def test_shuffled_mask_sparsity_empty(self):
        """Checks a 0%-sparsity shuffled mask on the single-layer model.

        The layer entry must exist with an all-ones kernel mask and a None bias
        mask, and applying the mask must give an output that is close to the
        `_unmasked_output` fixture and has the same shape.
        """
        model = self._masked_model
        mask = masked.shuffled_mask(model, self._rng, 0.0)

        with self.subTest(name='shuffled_empty_mask'):
            self.assertIn('MaskedModule_0', mask)

        with self.subTest(name='shuffled_empty_mask_values'):
            kernel_mask = mask['MaskedModule_0']['kernel']
            self.assertTrue((kernel_mask == 1).all())

        with self.subTest(name='shuffled_empty_mask_not_masked_values'):
            bias_mask = mask['MaskedModule_0']['bias']
            self.assertIsNone(bias_mask)

        # Run the forward pass with the generated mask applied.
        masked_output = model(self._input, mask=mask)

        with self.subTest(name='shuffled_empty_dense_values'):
            elementwise_close = jnp.isclose(masked_output,
                                            self._unmasked_output)
            self.assertTrue(elementwise_close.all())

        with self.subTest(name='shuffled_empty_mask_dense_shape'):
            expected_shape = self._unmasked_output.shape
            self.assertSequenceEqual(masked_output.shape, expected_shape)
Exemple #9
0
    def test_mask_layer_sparsity_half_mask(self):
        """Checks the layer sparsity reported for a 0.5-sparsity shuffled mask.

        `masked.mask_layer_sparsity` applied to the first layer's mask must be
        approximately 0.5.
        """
        half_mask = masked.shuffled_mask(self._masked_model, self._rng, 0.5)
        layer_mask = half_mask['MaskedModule_0']
        observed = masked.mask_layer_sparsity(layer_mask)

        self.assertAlmostEqual(observed, 0.5)
Exemple #10
0
      rng, ((input_shape, jnp.float32),),
      num_classes=dataset.num_classes,
      features=features)

  model_param_count = utils.count_param(base_model, ('kernel',))

  logging.info(
      'Model Config: param.: %d, depth: %d. max width: %d, min width: %d',
      model_param_count, len(features), max(features), min(features))

  logging.info('Generating random mask based on model')

  # Re-initialize the RNG to maintain same training pattern (as in prune code).
  mask_rng = jax.random.PRNGKey(FLAGS.random_seed)
  mask = masked.shuffled_mask(
      base_model,
      rng=mask_rng,
      sparsity=FLAGS.mask_sparsity)

  if jax.host_id() == 0:
    mask_stats = symmetry.get_mask_stats(mask)
    logging.info('Mask stats: %s', str(mask_stats))


    for label, value in mask_stats.items():
      try:
        summary_writer.scalar(f'mask/{label}', value, 0)
      # This is needed because permutations (long int) can't be cast to float32.
      except (OverflowError, ValueError):
        summary_writer.text(f'mask/{label}', str(value), 0)
        logging.error('Could not write mask/%s to tensorflow summary as float32'
                      ', writing as string instead.', label)