def test_mask_layer_sparsity_zero_mask(self):
  """Tests that a mask of all ones has zero sparsity."""
  ones_mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])

  self.assertEqual(
      masked.mask_layer_sparsity(ones_mask['MaskedModule_0']), 0.)
def test_mask_layer_sparsity_ones_mask(self):
  """Tests that a mask of all zeros has a sparsity of one."""
  zeros_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])

  self.assertEqual(
      masked.mask_layer_sparsity(zeros_mask['MaskedModule_0']), 1.)
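# The two tests above pin down the contract of masked.mask_layer_sparsity:
# sparsity is the fraction of zeroed mask entries, so an all-ones mask has
# sparsity 0 and an all-zeros mask has sparsity 1. The helper below is a
# hypothetical reference sketch of that contract, not the library code.
def _reference_layer_sparsity(layer_mask):
  """Fraction of zeroed entries across the non-None params of one layer."""
  params = [p for p in layer_mask.values() if p is not None]
  total = sum(p.size for p in params)
  zeroed = sum(int((p == 0).sum()) for p in params)
  return zeroed / total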
def test_cnn_sparse_init_kaiming(self):
  """Checks kaiming normal sparse initialization for a convolutional layer."""
  _, initial_params = MaskedCNN.init_by_shape(self._rng,
                                              (self._input_shape,))
  self._unmasked_model = flax.nn.Model(MaskedCNN, initial_params)

  mask = masked.simple_mask(self._unmasked_model, jnp.ones,
                            masked.WEIGHT_PARAM_NAMES)
  _, initial_params = MaskedCNNSparseInit.init_by_shape(
      jax.random.PRNGKey(42), (self._input_shape,), mask=mask)
  self._masked_model_sparse_init = flax.nn.Model(MaskedCNNSparseInit,
                                                 initial_params)

  mean_init = jnp.mean(
      self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel'])
  stddev_init = jnp.std(
      self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel'])

  mean_sparse_init = jnp.mean(
      self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']
      ['kernel'])
  stddev_sparse_init = jnp.std(
      self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']
      ['kernel'])

  with self.subTest(name='test_cnn_sparse_init_mean'):
    self.assertBetween(mean_sparse_init, mean_init - 2 * stddev_init,
                       mean_init + 2 * stddev_init)

  with self.subTest(name='test_cnn_sparse_init_stddev'):
    self.assertBetween(stddev_sparse_init, 0.5 * stddev_init,
                       1.5 * stddev_init)
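# The test above only checks moments: the sparse initialization's mean must
# fall within two dense-init stddevs, and its stddev within [0.5x, 1.5x] of
# the dense kaiming-normal baseline. A hedged sketch of the underlying idea,
# assuming the sparse initializer rescales the kaiming stddev by the fan-in
# implied by the mask (an illustrative helper, not the library code):
def _sparse_kaiming_normal_stddev(kernel_mask):
  """Kaiming-normal stddev from the per-output fan-in implied by a mask."""
  # For a (spatial..., in, out) kernel, summing over all but the last axis
  # gives the number of unmasked inputs feeding each output feature.
  fan_in = jnp.mean(
      jnp.sum(kernel_mask, axis=tuple(range(kernel_mask.ndim - 1))))
  return jnp.sqrt(2. / fan_in)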
def test_simple_mask_two_layer(self):
  """Tests generation of a simple mask."""
  mask = {
      'MaskedModule_0': {
          'kernel':
              jnp.zeros(self._masked_model_twolayer
                        .params['MaskedModule_0']['unmasked']['kernel'].shape),
          'bias': None,
      },
      'MaskedModule_1': {
          'kernel':
              jnp.zeros(self._masked_model_twolayer
                        .params['MaskedModule_1']['unmasked']['kernel'].shape),
          'bias': None,
      },
  }

  gen_mask = masked.simple_mask(self._masked_model_twolayer, jnp.zeros,
                                ['kernel'])

  result, _ = jax.tree_flatten(
      jax.tree_util.tree_multimap(lambda x, *xs: (x == xs[0]).all(), mask,
                                  gen_mask))
  self.assertTrue(all(result))
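# As exercised above, masked.simple_mask(model, init_fn, param_names) builds a
# pytree mirroring the model's masked layers: each listed param gets
# init_fn(shape) and every other param gets None. A hypothetical reference
# sketch (assuming params live under an 'unmasked' subtree, as in the tests):
def _reference_simple_mask(model, init_fn, param_names):
  """Builds a {layer: {param: array-or-None}} mask for all masked layers."""
  mask = {}
  for layer_name, layer_params in model.params.items():
    if 'unmasked' in layer_params:
      mask[layer_name] = {
          name: init_fn(value.shape) if name in param_names else None
          for name, value in layer_params['unmasked'].items()
      }
  return mask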
def test_fully_masked_layer(self):
  """Tests masked module with a full-sparsity mask."""
  full_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])
  masked_output = self._masked_model(self._input, mask=full_mask)

  with self.subTest(name='fully_masked_dense_values'):
    self.assertTrue((masked_output == 0).all())

  with self.subTest(name='fully_masked_dense_shape'):
    self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)
def test_empty_mask_masked_layer(self):
  """Tests masked module with an empty (not sparse) mask."""
  empty_mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])
  masked_output = self._masked_model(self._input, mask=empty_mask)

  with self.subTest(name='empty_mask_masked_dense_values'):
    self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())

  with self.subTest(name='empty_mask_masked_dense_shape'):
    self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)
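# Together, the two tests above imply the masking contract: the mask is
# applied multiplicatively to the kernel before the layer's linear op, so an
# all-zeros mask yields an all-zeros output and an all-ones mask reproduces
# the unmasked layer. A minimal dense sketch (illustrative, not MaskedModule):
def _masked_dense_apply(inputs, kernel, kernel_mask):
  """Applies a dense layer with its kernel masked elementwise."""
  return jnp.dot(inputs, kernel * kernel_mask)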
def test_count_permutations_mask_empty(self):
  """Tests the count of weight permutations in an all-zeros (empty) mask."""
  mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])
  stats = symmetry.count_permutations_mask(mask)

  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 0)

  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'], 0)

  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], MaskedConv.NUM_FEATURES)

  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'], MaskedConv.NUM_FEATURES)
def test_count_permutations_mask_bn_layer_full(self):
  """Tests the permutation count on a mask for a model with non-masked layers."""
  mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])
  stats = symmetry.count_permutations_mask(mask)

  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 1)

  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'],
                     math.factorial(MaskedDense.NUM_FEATURES))

  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], 0)

  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)
def test_count_permutations_mask_twolayers_empty(self):
  """Tests the count of weight permutations in an all-zeros mask, two layers."""
  mask = masked.simple_mask(self._masked_two_layer_model, jnp.zeros,
                            ['kernel'])
  stats = symmetry.count_permutations_mask(mask)

  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 0)

  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'], 0)

  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'],
                     sum(MaskedTwoLayerDense.NUM_FEATURES))

  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'],
                     sum(MaskedTwoLayerDense.NUM_FEATURES))
def test_count_permutations_mask_twolayer_full(self):
  """Tests count of weight permutations in a full mask for 2 layers."""
  mask = masked.simple_mask(self._masked_two_layer_model, jnp.ones,
                            ['kernel'])
  stats = symmetry.count_permutations_mask(mask)

  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 2)

  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(
        stats['permutations'],
        functools.reduce(
            operator.mul,
            [math.factorial(x) for x in MaskedTwoLayerDense.NUM_FEATURES]))

  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], 0)

  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'],
                     sum(MaskedTwoLayerDense.NUM_FEATURES))
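# The four tests above fix the accounting in symmetry.count_permutations_mask:
# within a layer, neurons (output features) whose mask patterns are identical
# and nonzero are interchangeable, so permutations multiply the factorials of
# the duplicate-group sizes; fully zeroed neurons are excluded, and a layer
# with no nonzero neurons contributes zero permutations. A hypothetical
# single-layer sketch consistent with those expectations (assumes numpy is
# imported as onp, alongside the math/functools/operator/collections modules
# already used in this file):
def _count_layer_permutations(kernel_mask):
  """Counts mask permutation symmetries for one layer's kernel mask."""
  # One flattened, hashable mask pattern per output feature (last axis).
  patterns = [
      tuple(onp.asarray(kernel_mask)[..., i].ravel())
      for i in range(kernel_mask.shape[-1])
  ]
  nonzero = [p for p in patterns if any(p)]
  groups = collections.Counter(nonzero)
  permutations = 0 if not nonzero else functools.reduce(
      operator.mul, (math.factorial(n) for n in groups.values()), 1)
  return {
      'unique_neurons': len(groups),
      'zeroed_neurons': len(patterns) - len(nonzero),
      'total_neurons': len(patterns),
      'permutations': permutations,
  }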
def prune(model,
          pruning_rate,
          saliency_fn=weight_magnitude,
          mask=None,
          compare_fn=jnp.greater):
  """Returns a mask for a model, pruning each layer by a saliency function.

  Args:
    model: The model to create a pruning mask for.
    pruning_rate: The fraction of lowest-saliency weights to prune in each
      layer. If a float, the same rate is used for all layers; if a mapping,
      it must contain a rate for every masked layer in the model.
    saliency_fn: A function that returns a per-weight score used to rank the
      importance of individual weights in a layer.
    mask: If given, an existing mask that is applied to the weights before
      pruning the model.
    compare_fn: A pairwise operator comparing saliencies with the threshold;
      returns True where the value should be kept (not masked).

  Returns:
    A pruned mask for the given model.
  """
  if not mask:
    mask = masked.simple_mask(model, jnp.ones, masked.WEIGHT_PARAM_NAMES)

  if not isinstance(pruning_rate, collections.abc.Mapping):
    pruning_rate_dict = {}
    for param_name, _ in masked.iterate_mask(mask):
      # Get the layer name from the parameter's full name/path.
      layer_name = param_name.split('/')[-2]
      pruning_rate_dict[layer_name] = pruning_rate
    pruning_rate = pruning_rate_dict

  for param_path, param_mask in masked.iterate_mask(mask):
    split_param_path = param_path.split('/')
    layer_name = split_param_path[-2]
    param_name = split_param_path[-1]

    # If we don't have a pruning rate for the given layer, don't mask it.
    if layer_name in pruning_rate and mask[layer_name][param_name] is not None:
      param_value = model.params[layer_name][
          masked.MaskedModule.UNMASKED][param_name]

      # Apply any existing mask to the weights before scoring them.
      # Note: must check `is not None` explicitly for numpy arrays.
      if param_mask is not None:
        saliencies = saliency_fn(param_mask * param_value)
      else:
        saliencies = saliency_fn(param_value)

      # TODO: Use partition (partial sort) here instead of sort, since it's
      # O(N) rather than O(N log N); however JAX doesn't support it.
      sorted_param = jnp.sort(jnp.abs(saliencies.flatten()))

      # Find the saliency threshold below which weights are pruned.
      threshold_index = jnp.round(pruning_rate[layer_name] *
                                  sorted_param.size).astype(jnp.int32)
      threshold = sorted_param[threshold_index]

      mask[layer_name][param_name] = jnp.array(
          compare_fn(saliencies, threshold), dtype=jnp.int32)

  return mask
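# A short usage sketch for prune (model and layer names are illustrative):
# prune every masked layer at a global rate, then re-prune a single layer at a
# higher rate while reusing the first mask as the starting point.
def _example_prune_usage(model, inputs):
  """Illustrative only; not part of the library."""
  mask = prune(model, 0.5)  # Prune 50% of each masked layer by magnitude.
  mask = prune(model, {'MaskedModule_0': 0.9}, mask=mask)  # One layer at 90%.
  return model(inputs, mask=mask)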