def test_shuffled_mask_invalid_model(self):
  """Checks mask generation rejects a model with no MaskedModule layers."""
  expected_error = ('Model does not support masking, i.e. no layers are '
                    'wrapped by a MaskedModule.')
  with self.assertRaisesRegex(ValueError, expected_error):
    masked.shuffled_mask(self._unmasked_model, self._rng, 0.5)
def test_shuffled_mask_invalid_sparsity(self):
  """Checks that sparsity values outside [0, 1] raise ValueError."""
  # Table-driven: one case below the valid range, one above it.
  cases = (
      ('sparsity_too_small', -0.5,
       r'Given sparsity, -0.5, is not in range \[0, 1\]'),
      ('sparsity_too_large', 1.5,
       r'Given sparsity, 1.5, is not in range \[0, 1\]'),
  )
  for case_name, sparsity, error_pattern in cases:
    with self.subTest(name=case_name):
      with self.assertRaisesRegex(ValueError, error_pattern):
        masked.shuffled_mask(self._masked_model, self._rng, sparsity)
def test_shuffled_mask_sparsity_empty_twolayer(self):
  """A 0% sparsity shuffled mask leaves both layers fully unmasked."""
  mask = masked.shuffled_mask(self._masked_model_twolayer, self._rng, 0.0)
  # Both masked layers must be present and have all-ones kernel masks.
  for index in (0, 1):
    layer_key = f'MaskedModule_{index}'
    with self.subTest(name=f'shuffled_empty_mask_layer{index + 1}'):
      self.assertIn(layer_key, mask)
    with self.subTest(name=f'shuffled_empty_mask_values_layer{index + 1}'):
      self.assertTrue((mask[layer_key]['kernel'] == 1).all())
  # An all-ones mask must reproduce the unmasked model's output exactly.
  masked_output = self._masked_model_twolayer(self._input, mask=mask)
  with self.subTest(name='shuffled_empty_dense_values'):
    self.assertTrue(
        jnp.isclose(masked_output, self._unmasked_output_twolayer).all())
  with self.subTest(name='shuffled_empty_mask_dense_shape'):
    self.assertSequenceEqual(masked_output.shape,
                             self._unmasked_output_twolayer.shape)
def test_shuffled_mask_sparsity_full_twolayer(self):
  """Tests shuffled mask generation for two layers, and 100% sparsity.

  With sparsity 1.0 every kernel mask entry is zero, bias masks stay None
  (biases are not masked), and the masked forward pass outputs all zeros.
  """
  mask = masked.shuffled_mask(self._masked_model_twolayer, self._rng, 1.0)
  with self.subTest(name='shuffled_full_mask_layer1'):
    self.assertIn('MaskedModule_0', mask)
  with self.subTest(name='shuffled_full_mask_values_layer1'):
    self.assertTrue((mask['MaskedModule_0']['kernel'] == 0).all())
  with self.subTest(name='shuffled_full_mask_not_masked_values_layer1'):
    self.assertIsNone(mask['MaskedModule_0']['bias'])
  with self.subTest(name='shuffled_full_mask_layer2'):
    self.assertIn('MaskedModule_1', mask)
  with self.subTest(name='shuffled_full_mask_values_layer2'):
    self.assertTrue((mask['MaskedModule_1']['kernel'] == 0).all())
  # Bug fix: this subTest previously reused the '..._layer1' name, making
  # failure reports ambiguous between the two bias checks.
  with self.subTest(name='shuffled_full_mask_not_masked_values_layer2'):
    self.assertIsNone(mask['MaskedModule_1']['bias'])
  masked_output = self._masked_model_twolayer(self._input, mask=mask)
  with self.subTest(name='shuffled_full_mask_dense_values'):
    self.assertTrue((masked_output == 0).all())
  with self.subTest(name='shuffled_full_mask_dense_shape'):
    self.assertSequenceEqual(masked_output.shape,
                             self._unmasked_output_twolayer.shape)
def test_shuffled_mask_sparsity_half_full(self):
  """A 0.5-sparsity shuffled mask keeps exactly half the kernel weights."""
  mask = masked.shuffled_mask(self._masked_model, self._rng, 0.5)
  kernel = self._masked_model.params['MaskedModule_0']['unmasked']['kernel']
  with self.subTest(name='shuffled_mask_values'):
    # The mask is binary, so its sum counts the surviving weights.
    self.assertEqual(jnp.sum(mask['MaskedModule_0']['kernel']),
                     kernel.size // 2)
def test_prune_one_layer_conv_with_mask(self):
  """Tests pruning of model with one conv. layer with an existing mask."""
  # NOTE(review): the pre-existing mask is generated from self._masked_model
  # while self._masked_conv_model is the model being pruned — presumably the
  # two share parameter shapes; confirm this cross-model mask is intentional.
  pruned_mask = pruning.prune(self._masked_conv_model, 0.5,
                              mask=masked.shuffled_mask(
                                  self._masked_model, self._rng, 0.95))
  # Pruning should never densify: the 0.95-sparse input mask dominates the
  # requested 0.5 target, so overall sparsity stays at ~0.95.
  mask_sparsity = masked.mask_sparsity(pruned_mask)
  with self.subTest(name='test_mask_param_not_none'):
    self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])
  with self.subTest(name='test_mask_sparsity'):
    self.assertAlmostEqual(mask_sparsity, 0.95, places=3)
def test_count_permutations_shuffled_full_mask(self):
  """Permutation statistics of a fully-sparse (all-zero) shuffled mask."""
  mask = masked.shuffled_mask(self._masked_model, rng=self._rng, sparsity=1)
  stats = symmetry.count_permutations_mask(mask)
  # With every weight masked out, no neuron is unique and all are zeroed.
  expectations = (
      ('count_permutations_mask_unique', 'unique_neurons', 0),
      ('count_permutations_permutations', 'permutations', 0),
      ('count_permutations_zeroed', 'zeroed_neurons',
       MaskedConv.NUM_FEATURES),
      ('count_permutations_total', 'total_neurons',
       MaskedConv.NUM_FEATURES),
  )
  for subtest_name, stat_key, expected_value in expectations:
    with self.subTest(name=subtest_name):
      self.assertEqual(stats[stat_key], expected_value)
def test_shuffled_mask_sparsity_empty(self):
  """A 0% sparsity shuffled mask keeps every weight and output intact."""
  mask = masked.shuffled_mask(self._masked_model, self._rng, 0.0)
  layer_mask = mask.get('MaskedModule_0')
  with self.subTest(name='shuffled_empty_mask'):
    self.assertIn('MaskedModule_0', mask)
  with self.subTest(name='shuffled_empty_mask_values'):
    # Zero sparsity means an all-ones kernel mask.
    self.assertTrue((layer_mask['kernel'] == 1).all())
  with self.subTest(name='shuffled_empty_mask_not_masked_values'):
    # Biases are never masked.
    self.assertIsNone(layer_mask['bias'])
  masked_output = self._masked_model(self._input, mask=mask)
  with self.subTest(name='shuffled_empty_dense_values'):
    self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())
  with self.subTest(name='shuffled_empty_mask_dense_shape'):
    self.assertSequenceEqual(masked_output.shape,
                             self._unmasked_output.shape)
def test_mask_layer_sparsity_half_mask(self):
  """mask_layer_sparsity reports 0.5 for a half-filled layer mask."""
  half_mask = masked.shuffled_mask(self._masked_model, self._rng, 0.5)
  layer_sparsity = masked.mask_layer_sparsity(half_mask['MaskedModule_0'])
  self.assertAlmostEqual(layer_sparsity, 0.5)
rng, ((input_shape, jnp.float32),), num_classes=dataset.num_classes, features=features) model_param_count = utils.count_param(base_model, ('kernel',)) logging.info( 'Model Config: param.: %d, depth: %d. max width: %d, min width: %d', model_param_count, len(features), max(features), min(features)) logging.info('Generating random mask based on model') # Re-initialize the RNG to maintain same training pattern (as in prune code). mask_rng = jax.random.PRNGKey(FLAGS.random_seed) mask = masked.shuffled_mask( base_model, rng=mask_rng, sparsity=FLAGS.mask_sparsity) if jax.host_id() == 0: mask_stats = symmetry.get_mask_stats(mask) logging.info('Mask stats: %s', str(mask_stats)) for label, value in mask_stats.items(): try: summary_writer.scalar(f'mask/{label}', value, 0) # This is needed because permutations (long int) can't be cast to float32. except (OverflowError, ValueError): summary_writer.text(f'mask/{label}', str(value), 0) logging.error('Could not write mask/%s to tensorflow summary as float32' ', writing as string instead.', label)