コード例 #1
0
ファイル: symmetry_test.py プロジェクト: tawawhite/rigl
    def test_get_mask_stats_keys_values(self):
        """Verifies get_mask_stats returns every required key with sane values.

        A small two-layer binary mask is summarised. For each statistic the
        test first asserts that the key is present. It then checks the value
        against either a range/lower bound or an exact count expected for
        this particular mask.
        """
        # Each tuple row becomes a column after `.T`; the test counts columns
        # (shape[-1]) per layer when computing the expected 'total_neurons'.
        first_kernel = jnp.array(((1, 0), (1, 0), (0, 1))).T
        # The trailing (0, 0, 0) row becomes an all-zero column after `.T`.
        second_kernel = jnp.array(
            ((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T
        mask = {
            'MaskedModule_0': {'kernel': first_kernel},
            'MaskedModule_1': {'kernel': second_kernel},
        }

        stats = symmetry.get_mask_stats(mask)
        expected_total = sum(
            layer['kernel'].shape[-1] for layer in mask.values())

        # (stat key, value check) pairs, in the order they are verified.
        expectations = (
            ('sparsity', lambda v: self.assertBetween(v, 0.0, 1.0)),
            ('permutation_num_digits',
             lambda v: self.assertGreaterEqual(v, 0.0)),
            ('permutation_log10', lambda v: self.assertGreaterEqual(v, 0.0)),
            ('unique_neurons', lambda v: self.assertEqual(v, 6)),
            ('permutations', lambda v: self.assertEqual(v, 1)),
            ('zeroed_neurons', lambda v: self.assertEqual(v, 1)),
            ('total_neurons', lambda v: self.assertEqual(v, expected_total)),
        )

        for key, check in expectations:
            with self.subTest(name=f'{key}_exists'):
                self.assertIn(key, stats)
            # The lookup stays inside the subTest so a missing key is recorded
            # as a failure of this value check instead of aborting the test.
            with self.subTest(name=f'{key}_value'):
                check(stats[key])
コード例 #2
0
ファイル: shuffled_mask.py プロジェクト: yaelandau22/rigl
  input_shape = (1,) + dataset.shape
  base_model, _ = model_factory.create_model(
      FLAGS.model,
      rng, ((input_shape, jnp.float32),),
      num_classes=dataset.num_classes)

  logging.info('Generating random mask based on model')

  # Re-initialize the RNG to maintain same training pattern (as in prune code).
  mask_rng = jax.random.PRNGKey(FLAGS.mask_randomseed)

  mask = mask_factory.create_mask(FLAGS.mask_type, base_model, mask_rng,
                                  FLAGS.mask_sparsity)

  if jax.host_id() == 0:
    mask_stats = symmetry.get_mask_stats(mask)
    logging.info('Mask stats: %s', str(mask_stats))


    for label, value in mask_stats.items():
      try:
        summary_writer.scalar(f'mask/{label}', value, 0)
      # This is needed because permutations (long int) can't be cast to float32.
      except (OverflowError, ValueError):
        summary_writer.text(f'mask/{label}', str(value), 0)
        logging.error('Could not write mask/%s to tensorflow summary as float32'
                      ', writing as string instead.', label)

    if FLAGS.dump_json:
      mask_stats['permutations'] = str(mask_stats['permutations'])
      utils.dump_dict_json(