def test_get_mask_stats_keys_values(self):
  """Checks get_mask_stats returns each expected key with a sane value.

  A small two-layer mask is built. For every statistic, one subtest checks
  that the key is present and a second subtest checks its value against a
  type/range constraint or an exact expected number.
  """
  # Kernels are written as literals and transposed; the total_neurons check
  # below counts neurons along the last axis of each kernel.
  first_kernel = jnp.array(((1, 0), (1, 0), (0, 1))).T
  second_kernel = jnp.array(
      ((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T
  mask = {
      'MaskedModule_0': {'kernel': first_kernel},
      'MaskedModule_1': {'kernel': second_kernel},
  }
  stats = symmetry.get_mask_stats(mask)

  def expected_total_neurons():
    # Sum of the last-axis widths of both kernels.
    return sum(mask[layer]['kernel'].shape[-1]
               for layer in ('MaskedModule_0', 'MaskedModule_1'))

  # (statistic key, assertion run on its value), listed in subtest order.
  value_checks = (
      ('sparsity',
       lambda value: self.assertBetween(value, 0.0, 1.0)),
      ('permutation_num_digits',
       lambda value: self.assertGreaterEqual(value, 0.0)),
      ('permutation_log10',
       lambda value: self.assertGreaterEqual(value, 0.0)),
      ('unique_neurons',
       lambda value: self.assertEqual(value, 6)),
      ('permutations',
       lambda value: self.assertEqual(value, 1)),
      ('zeroed_neurons',
       lambda value: self.assertEqual(value, 1)),
      ('total_neurons',
       lambda value: self.assertEqual(value, expected_total_neurons())),
  )

  for key, check in value_checks:
    with self.subTest(name=f'{key}_exists'):
      self.assertIn(key, stats)
    # The lookup stays inside the subtest so a missing key is reported
    # against this subtest instead of aborting the whole test.
    with self.subTest(name=f'{key}_value'):
      check(stats[key])
input_shape = (1,) + dataset.shape base_model, _ = model_factory.create_model( FLAGS.model, rng, ((input_shape, jnp.float32),), num_classes=dataset.num_classes) logging.info('Generating random mask based on model') # Re-initialize the RNG to maintain same training pattern (as in prune code). mask_rng = jax.random.PRNGKey(FLAGS.mask_randomseed) mask = mask_factory.create_mask(FLAGS.mask_type, base_model, mask_rng, FLAGS.mask_sparsity) if jax.host_id() == 0: mask_stats = symmetry.get_mask_stats(mask) logging.info('Mask stats: %s', str(mask_stats)) for label, value in mask_stats.items(): try: summary_writer.scalar(f'mask/{label}', value, 0) # This is needed because permutations (long int) can't be cast to float32. except (OverflowError, ValueError): summary_writer.text(f'mask/{label}', str(value), 0) logging.error('Could not write mask/%s to tensorflow summary as float32' ', writing as string instead.', label) if FLAGS.dump_json: mask_stats['permutations'] = str(mask_stats['permutations']) utils.dump_dict_json(