예제 #1
0
    def test_count_permutations_mask_layer_twolayer_known_symmetric(self):
        """Checks two-layer permutation stats on a small hand-written mask.

        The first layer's kernel mask has two identical columns and one
        distinct column. The test expects 2 unique neurons, 2 permutations,
        no zeroed neurons, and a total equal to the size of the first
        kernel's last axis.
        """
        first_kernel = jnp.array(((1, 0), (1, 0), (0, 1))).T
        second_kernel = jnp.array(
            ((1, 1, 0), (0, 0, 1), (0, 0, 1), (0, 0, 0))).T
        mask = {
            'MaskedModule_0': {'kernel': first_kernel},
            'MaskedModule_1': {'kernel': second_kernel},
        }

        stats = symmetry.count_permutations_mask_layer(
            mask['MaskedModule_0'], mask['MaskedModule_1'])

        # (subtest name, key in the returned stats, expected value)
        expectations = (
            ('count_permutations_unique', 'unique_neurons', 2),
            ('count_permutations_permutations', 'permutations', 2),
            ('count_permutations_zeroed', 'zeroed_neurons', 0),
            ('count_permutations_total', 'total_neurons',
             first_kernel.shape[-1]),
        )
        for subtest_name, stat_key, expected in expectations:
            with self.subTest(name=subtest_name):
                self.assertEqual(stats[stat_key], expected)
예제 #2
0
  def test_count_permutations_layer_mask_known_perm_zeros(self):
    """Checks permutation stats for a mask that contains zeroed neurons.

    The mask is one randomly sampled pattern repeated along the last axis
    for a third of the neurons (integer division), followed by all-zero
    entries for the remaining neurons.
    """
    kernel_shape = self._masked_model.params['MaskedModule_0']['unmasked'][
        'kernel'].shape
    num_inputs = kernel_shape[0]
    num_neurons = kernel_shape[-1]

    # Two distinct per-neuron patterns: one sampled at random, one all zeros.
    # NOTE(review): the expectations below presume the sampled pattern has at
    # least one non-zero entry and that num_neurons // 3 > 0 -- confirm
    # against the fixture's rng and layer size.
    sampled_pattern = jax.random.bernoulli(
        self._rng, p=0.3, shape=(num_inputs,)).astype(jnp.int32)
    zero_pattern = jnp.zeros(shape=(num_inputs,), dtype=jnp.int32)

    # Lay the patterns out side by side along the last axis.
    num_sampled = num_neurons // 3
    num_zeroed = num_neurons - num_sampled
    sampled_block = jnp.repeat(
        sampled_pattern[:, jnp.newaxis], num_sampled, axis=-1)
    zeroed_block = jnp.repeat(
        zero_pattern[:, jnp.newaxis], num_zeroed, axis=-1)
    mask_layer = {
        'kernel': jnp.concatenate((sampled_block, zeroed_block), axis=-1),
    }

    stats = symmetry.count_permutations_mask_layer(mask_layer)

    # (subtest name, key in the returned stats, expected value)
    expectations = (
        ('count_permutations_mask_unique', 'unique_neurons', 1),
        ('count_permutations_permutations', 'permutations',
         math.factorial(num_sampled)),
        ('count_permutations_zeroed', 'zeroed_neurons', num_zeroed),
        ('count_permutations_total', 'total_neurons', num_neurons),
    )
    for subtest_name, stat_key, expected in expectations:
      with self.subTest(name=subtest_name):
        self.assertEqual(stats[stat_key], expected)
예제 #3
0
  def test_count_permutations_mask_layer_twolayer(self, mask, unique,
                                                  permutations, zeroed, total):
    """Checks two-layer permutation stats against supplied expected values.

    Args:
      mask: Mapping with 'MaskedModule_0' and 'MaskedModule_1' entries, which
        are passed as the first and second layer masks respectively.
      unique: Expected value of stats['unique_neurons'].
      permutations: Expected value of stats['permutations'].
      zeroed: Expected value of stats['zeroed_neurons'].
      total: Expected value of stats['total_neurons'].
    """
    first_layer = mask['MaskedModule_0']
    second_layer = mask['MaskedModule_1']
    stats = symmetry.count_permutations_mask_layer(first_layer, second_layer)

    # (subtest name, key in the returned stats, expected value)
    expectations = (
        ('count_permutations_unique', 'unique_neurons', unique),
        ('count_permutations_permutations', 'permutations', permutations),
        ('count_permutations_zeroed', 'zeroed_neurons', zeroed),
        ('count_permutations_total', 'total_neurons', total),
    )
    for subtest_name, stat_key, expected in expectations:
      with self.subTest(name=subtest_name):
        self.assertEqual(stats[stat_key], expected)
예제 #4
0
    def test_count_permutations_conv_layer_mask_empty(self):
        """Checks permutation stats for an all-zero conv. layer mask.

        With every mask entry zero, the test expects no unique neurons, zero
        permutations, and both the zeroed and total neuron counts to equal
        MaskedConv.NUM_FEATURES.
        """
        conv_params = self._masked_conv_model.params
        kernel_shape = conv_params['MaskedModule_0']['unmasked']['kernel'].shape
        mask_layer = {'kernel': jnp.zeros(kernel_shape)}

        stats = symmetry.count_permutations_mask_layer(mask_layer)

        # (subtest name, key in the returned stats, expected value)
        expectations = (
            ('count_permutations_mask_unique', 'unique_neurons', 0),
            ('count_permutations_permutations', 'permutations', 0),
            ('count_permutations_zeroed', 'zeroed_neurons',
             MaskedConv.NUM_FEATURES),
            ('count_permutations_total', 'total_neurons',
             MaskedConv.NUM_FEATURES),
        )
        for subtest_name, stat_key, expected in expectations:
            with self.subTest(name=subtest_name):
                self.assertEqual(stats[stat_key], expected)
예제 #5
0
    def test_count_permutations_layer_mask_full(self):
        """Checks permutation stats for a mask whose entries are all ones.

        The test expects a single unique neuron, factorial(NUM_FEATURES)
        permutations, no zeroed neurons, and a total of
        MaskedDense.NUM_FEATURES neurons.
        """
        dense_params = self._masked_model.params
        kernel_shape = dense_params['MaskedModule_0']['unmasked']['kernel'].shape
        mask_layer = {'kernel': jnp.ones(kernel_shape)}

        stats = symmetry.count_permutations_mask_layer(mask_layer)

        # (subtest name, key in the returned stats, expected value)
        expectations = (
            ('count_permutations_mask_unique', 'unique_neurons', 1),
            ('count_permutations_permutations', 'permutations',
             math.factorial(MaskedDense.NUM_FEATURES)),
            ('count_permutations_zeroed', 'zeroed_neurons', 0),
            ('count_permutations_total', 'total_neurons',
             MaskedDense.NUM_FEATURES),
        )
        for subtest_name, stat_key, expected in expectations:
            with self.subTest(name=subtest_name):
                self.assertEqual(stats[stat_key], expected)