예제 #1
0
    def test_count_permutations_mask_twolayer_known_symmetric(self):
        """Checks the statistics reported for a hand-written two-layer mask.

        Both kernels are written as tuples and then transposed, so every inner
        tuple ends up as one column along the kernel's last axis. Layer 0 holds
        a duplicated column; layer 1 holds a duplicated column plus an all-zero
        one. The expected result is 4 unique neurons, 4 permutations and
        1 zeroed neuron.
        """
        first_kernel = jnp.array(((1, 0), (1, 0), (0, 1))).T
        second_kernel = jnp.array(
            ((1, 1, 0), (0, 0, 1), (0, 0, 1), (0, 0, 0))).T
        layer_masks = {
            'MaskedModule_0': {'kernel': first_kernel},
            'MaskedModule_1': {'kernel': second_kernel},
        }

        result = symmetry.count_permutations_mask(layer_masks)

        with self.subTest(name='count_permutations_full_mask_unique'):
            unique_count = result['unique_neurons']
            self.assertEqual(unique_count, 4)

        with self.subTest(name='count_permutations_full_mask_permutations'):
            permutation_count = result['permutations']
            self.assertEqual(permutation_count, 4)

        with self.subTest(name='count_permutations_full_mask_zeroed'):
            zeroed_count = result['zeroed_neurons']
            self.assertEqual(zeroed_count, 1)

        with self.subTest(name='Count_permutations_full_mask_total'):
            # The expected total is the summed last-axis size of both kernels,
            # read back from the mask dictionary itself.
            expected_total = sum(
                layer_masks[layer]['kernel'].shape[-1]
                for layer in ('MaskedModule_0', 'MaskedModule_1'))
            self.assertEqual(result['total_neurons'], expected_total)
예제 #2
0
    def test_count_permutations_mask_twolayer_known_non_symmetric(self):
        """Checks a hand-written two-layer mask that yields one permutation.

        Layer 0 on its own contains a duplicated column, yet the expected
        result over both layers is 6 unique neurons and a single permutation,
        with 1 zeroed neuron (the all-zero column of layer 1).
        """
        layer_masks = {
            'MaskedModule_0': {
                'kernel': jnp.array(((1, 0), (1, 0), (0, 1))).T,
            },
            'MaskedModule_1': {
                'kernel': jnp.array(
                    ((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T,
            },
        }

        result = symmetry.count_permutations_mask(layer_masks)

        # (sub-test label, statistics key, expected constant)
        constant_checks = (
            ('count_permutations_unique', 'unique_neurons', 6),
            ('count_permutations_permutations', 'permutations', 1),
            ('count_permutations_zeroed', 'zeroed_neurons', 1),
        )
        for label, key, expected in constant_checks:
            with self.subTest(name=label):
                self.assertEqual(result[key], expected)

        with self.subTest(name='count_permutations_total'):
            # Total neurons: last-axis size of layer 0 plus that of layer 1.
            width_0 = layer_masks['MaskedModule_0']['kernel'].shape[-1]
            width_1 = layer_masks['MaskedModule_1']['kernel'].shape[-1]
            self.assertEqual(result['total_neurons'], width_0 + width_1)
예제 #3
0
    def test_count_permutations_mask_empty(self):
        """Checks the permutation statistics of an all-zeros kernel mask.

        The mask is built with `jnp.zeros` for the 'kernel' parameters of the
        fixture model; no unique neurons and no permutations are expected, and
        both the zeroed and the total count must equal
        `MaskedConv.NUM_FEATURES`.
        """
        zero_mask = masked.simple_mask(
            self._masked_model, jnp.zeros, ['kernel'])

        result = symmetry.count_permutations_mask(zero_mask)

        with self.subTest(name='count_permutations_mask_unique'):
            unique_count = result['unique_neurons']
            self.assertEqual(unique_count, 0)

        with self.subTest(name='count_permutations_permutations'):
            permutation_count = result['permutations']
            self.assertEqual(permutation_count, 0)

        with self.subTest(name='count_permutations_zeroed'):
            zeroed_count = result['zeroed_neurons']
            self.assertEqual(zeroed_count, MaskedConv.NUM_FEATURES)

        with self.subTest(name='count_permutations_total'):
            total_count = result['total_neurons']
            self.assertEqual(total_count, MaskedConv.NUM_FEATURES)
예제 #4
0
  def test_count_permutations_shuffled_full_mask(self):
    """Checks permutation statistics of a shuffled mask built at sparsity=1.

    The assertions expect no unique neurons, no permutations, and both the
    zeroed and the total count equal to `MaskedConv.NUM_FEATURES`.

    NOTE(review): the method name says "full mask", but the expected values
    describe a mask in which every neuron is zeroed; presumably sparsity=1
    masks out every weight. Confirm against `masked.shuffled_mask`.
    """
    shuffled = masked.shuffled_mask(
        self._masked_model, rng=self._rng, sparsity=1)

    result = symmetry.count_permutations_mask(shuffled)

    with self.subTest(name='count_permutations_mask_unique'):
      unique_count = result['unique_neurons']
      self.assertEqual(unique_count, 0)

    with self.subTest(name='count_permutations_permutations'):
      permutation_count = result['permutations']
      self.assertEqual(permutation_count, 0)

    with self.subTest(name='count_permutations_zeroed'):
      zeroed_count = result['zeroed_neurons']
      self.assertEqual(zeroed_count, MaskedConv.NUM_FEATURES)

    with self.subTest(name='count_permutations_total'):
      total_count = result['total_neurons']
      self.assertEqual(total_count, MaskedConv.NUM_FEATURES)
예제 #5
0
    def test_count_permutations_mask_bn_layer_full(self):
        """Checks permutation statistics of an all-ones kernel mask.

        The mask is built with `jnp.ones` for the 'kernel' parameters of the
        fixture model (presumably a model that also has non-masked layers, as
        the test name suggests). One unique neuron, no zeroed neurons and a
        factorial number of permutations are expected.
        """
        ones_mask = masked.simple_mask(
            self._masked_model, jnp.ones, ['kernel'])

        result = symmetry.count_permutations_mask(ones_mask)

        with self.subTest(name='count_permutations_mask_unique'):
            unique_count = result['unique_neurons']
            self.assertEqual(unique_count, 1)

        with self.subTest(name='count_permutations_permutations'):
            # NOTE(review): the permutation count uses MaskedDense.NUM_FEATURES
            # while the total below uses MaskedConv.NUM_FEATURES -- confirm
            # that mixing the two classes is intentional.
            expected_permutations = math.factorial(MaskedDense.NUM_FEATURES)
            self.assertEqual(result['permutations'], expected_permutations)

        with self.subTest(name='count_permutations_zeroed'):
            zeroed_count = result['zeroed_neurons']
            self.assertEqual(zeroed_count, 0)

        with self.subTest(name='count_permutations_total'):
            total_count = result['total_neurons']
            self.assertEqual(total_count, MaskedConv.NUM_FEATURES)
예제 #6
0
    def test_count_permutations_mask_twolayers_empty(self):
        """Checks permutation statistics of an all-zeros two-layer mask.

        The mask is built with `jnp.zeros` for the 'kernel' parameters of the
        two-layer fixture model. No unique neurons and no permutations are
        expected; the zeroed and total counts must both equal the sum of
        `MaskedTwoLayerDense.NUM_FEATURES`.
        """
        zero_mask = masked.simple_mask(
            self._masked_two_layer_model, jnp.zeros, ['kernel'])

        result = symmetry.count_permutations_mask(zero_mask)

        with self.subTest(name='count_permutations_mask_unique'):
            unique_count = result['unique_neurons']
            self.assertEqual(unique_count, 0)

        with self.subTest(name='count_permutations_permutations'):
            permutation_count = result['permutations']
            self.assertEqual(permutation_count, 0)

        with self.subTest(name='count_permutations_zeroed'):
            expected_zeroed = sum(MaskedTwoLayerDense.NUM_FEATURES)
            self.assertEqual(result['zeroed_neurons'], expected_zeroed)

        with self.subTest(name='count_permutations_total'):
            expected_total = sum(MaskedTwoLayerDense.NUM_FEATURES)
            self.assertEqual(result['total_neurons'], expected_total)
예제 #7
0
    def test_count_permutations_mask_twolayer_full(self):
        """Checks permutation statistics of an all-ones two-layer mask.

        The mask is built with `jnp.ones` for the 'kernel' parameters of the
        two-layer fixture model. Two unique neurons and no zeroed neurons are
        expected; the permutation count must equal the product of the
        factorials of the entries of `MaskedTwoLayerDense.NUM_FEATURES`, and
        the total count must equal their sum.
        """
        ones_mask = masked.simple_mask(
            self._masked_two_layer_model, jnp.ones, ['kernel'])

        result = symmetry.count_permutations_mask(ones_mask)

        with self.subTest(name='count_permutations_mask_unique'):
            unique_count = result['unique_neurons']
            self.assertEqual(unique_count, 2)

        with self.subTest(name='count_permutations_permutations'):
            # Multiply together the factorial of every layer's feature count.
            layer_factorials = map(
                math.factorial, MaskedTwoLayerDense.NUM_FEATURES)
            self.assertEqual(
                result['permutations'],
                functools.reduce(operator.mul, layer_factorials))

        with self.subTest(name='count_permutations_zeroed'):
            zeroed_count = result['zeroed_neurons']
            self.assertEqual(zeroed_count, 0)

        with self.subTest(name='count_permutations_total'):
            expected_total = sum(MaskedTwoLayerDense.NUM_FEATURES)
            self.assertEqual(result['total_neurons'], expected_total)