def test_count_permutations_mask_twolayer_known_symmetric(self):
  """Tests permutation counts for a known two-layer mask with 4 permutations.

  In MaskedModule_0 the first two output neurons have identical incoming
  connections and identical outgoing connections into MaskedModule_1. In
  MaskedModule_1 the middle two output neurons have identical incoming
  connections and the last output neuron has no connections at all.
  """
  mask = {
      'MaskedModule_0': {
          'kernel': jnp.array(((1, 0), (1, 0), (0, 1))).T
      },
      'MaskedModule_1': {
          'kernel':
              jnp.array(((1, 1, 0), (0, 0, 1), (0, 0, 1), (0, 0, 0))).T
      }
  }
  stats = symmetry.count_permutations_mask(mask)

  # This is a hand-built partial mask, so subtest labels use the same
  # 'count_permutations_*' scheme as the other known-mask test.
  with self.subTest(name='count_permutations_unique'):
    self.assertEqual(stats['unique_neurons'], 4)

  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'], 4)

  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], 1)

  with self.subTest(name='count_permutations_total'):
    self.assertEqual(
        stats['total_neurons'],
        mask['MaskedModule_0']['kernel'].shape[-1] +
        mask['MaskedModule_1']['kernel'].shape[-1])
def test_count_permutations_mask_twolayer_known_non_symmetric(self):
  """Tests a two-layer mask that only admits a single permutation.

  Output neurons 0 and 1 of MaskedModule_0 share the same incoming
  connections but have different outgoing connections into MaskedModule_1,
  so the mask has only 1 permutation when both layers are considered.
  """
  first_kernel = jnp.array(((1, 0), (1, 0), (0, 1))).T
  second_kernel = jnp.array(
      ((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T
  mask = {
      'MaskedModule_0': {'kernel': first_kernel},
      'MaskedModule_1': {'kernel': second_kernel},
  }
  stats = symmetry.count_permutations_mask(mask)

  total_neurons = first_kernel.shape[-1] + second_kernel.shape[-1]
  # (subtest label, stats key, expected value), checked in this order.
  expectations = (
      ('count_permutations_unique', 'unique_neurons', 6),
      ('count_permutations_permutations', 'permutations', 1),
      ('count_permutations_zeroed', 'zeroed_neurons', 1),
      ('count_permutations_total', 'total_neurons', total_neurons),
  )
  for subtest_name, stat_key, expected in expectations:
    with self.subTest(name=subtest_name):
      self.assertEqual(stats[stat_key], expected)
def test_count_permutations_mask_empty(self):
  """Tests count_permutations_mask on an all-zeros mask of the masked model.

  Every neuron is expected to be reported as zeroed, with no unique neurons
  and no permutations.
  """
  empty_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])
  result = symmetry.count_permutations_mask(empty_mask)
  num_features = MaskedConv.NUM_FEATURES

  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(result['unique_neurons'], 0)

  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(result['permutations'], 0)

  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(result['zeroed_neurons'], num_features)

  with self.subTest(name='count_permutations_total'):
    self.assertEqual(result['total_neurons'], num_features)
def test_count_permutations_shuffled_full_mask(self):
  """Tests count of weight permutations on a shuffled mask with sparsity 1.

  Here "full" refers to full sparsity: with sparsity=1 every weight is masked
  out (as the assertions below expect), so every neuron is zeroed and there
  are no unique neurons or permutations.
  """
  mask = masked.shuffled_mask(self._masked_model, rng=self._rng, sparsity=1)
  stats = symmetry.count_permutations_mask(mask)

  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 0)

  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'], 0)

  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], MaskedConv.NUM_FEATURES)

  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'], MaskedConv.NUM_FEATURES)
def test_count_permutations_mask_bn_layer_full(self):
  """Tests count of permutations on a mask for model with non-masked layers.

  Uses an all-ones mask on self._masked_model. The assertions expect a single
  unique neuron, so the permutation count is the factorial of the model's
  masked neuron count.
  """
  mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])
  stats = symmetry.count_permutations_mask(mask)
  # The masked neuron count of self._masked_model is MaskedConv.NUM_FEATURES
  # (see the total below and the empty/shuffled tests on the same model), so
  # the permutation count must be derived from the same constant rather than
  # MaskedDense.NUM_FEATURES.
  num_features = MaskedConv.NUM_FEATURES

  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 1)

  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'], math.factorial(num_features))

  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], 0)

  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'], num_features)
def test_count_permutations_mask_twolayers_empty(self):
  """Tests an all-zeros mask on the two-layer dense model.

  Every neuron in both layers is expected to be counted as zeroed, leaving
  no unique neurons and no permutations.
  """
  zeros_mask = masked.simple_mask(
      self._masked_two_layer_model, jnp.zeros, ['kernel'])
  counts = symmetry.count_permutations_mask(zeros_mask)
  neurons_in_model = sum(MaskedTwoLayerDense.NUM_FEATURES)

  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(counts['unique_neurons'], 0)

  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(counts['permutations'], 0)

  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(counts['zeroed_neurons'], neurons_in_model)

  with self.subTest(name='count_permutations_total'):
    self.assertEqual(counts['total_neurons'], neurons_in_model)
def test_count_permutations_mask_twolayer_full(self):
  """Tests count of weight permutations in an all-ones mask for 2 layers.

  The assertions expect one unique neuron per layer, so the expected
  permutation count is the product over layers of factorial(NUM_FEATURES).
  """
  mask = masked.simple_mask(self._masked_two_layer_model, jnp.ones, ['kernel'])
  stats = symmetry.count_permutations_mask(mask)

  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 2)

  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(
        stats['permutations'],
        math.prod(
            math.factorial(x) for x in MaskedTwoLayerDense.NUM_FEATURES))

  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], 0)

  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'],
                     sum(MaskedTwoLayerDense.NUM_FEATURES))