Example #1
    def test_mask_layer_sparsity_zero_mask(self):
        """Tests the sparsity calculation for an all-ones (zero-sparsity) mask."""
        zero_sparsity_mask = masked.simple_mask(self._masked_model, jnp.ones,
                                                ['kernel'])

        self.assertEqual(
            masked.mask_layer_sparsity(zero_sparsity_mask['MaskedModule_0']),
            0.)
Example #2
    def test_mask_layer_sparsity_ones_mask(self):
        """Tests the sparsity calculation for an all-zeros (fully sparse) mask."""
        full_sparsity_mask = masked.simple_mask(self._masked_model, jnp.zeros,
                                                ['kernel'])

        self.assertEqual(
            masked.mask_layer_sparsity(full_sparsity_mask['MaskedModule_0']),
            1.)
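The sparsity value these two tests check is just the fraction of masked-out entries. A minimal sketch of such a calculation (a hypothetical helper, not the library's mask_layer_sparsity), assuming a layer mask is a dict mapping parameter names to arrays (or None) in which 0 marks a pruned weight and 1 a kept weight:

import jax.numpy as jnp

def layer_sparsity(layer_mask):
    """Returns the fraction of zero entries across all non-None arrays in a layer mask."""
    arrays = [m for m in layer_mask.values() if m is not None]
    total = sum(m.size for m in arrays)
    zeros = sum(int(jnp.sum(m == 0)) for m in arrays)
    return zeros / total

# An all-ones mask gives 0.0 and an all-zeros mask gives 1.0, matching the two
# assertions above.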
Example #3
    def test_cnn_sparse_init_kaiming(self):
        """Checks kaiming normal sparse initialization for convolutional layer."""
        _, initial_params = MaskedCNN.init_by_shape(self._rng,
                                                    (self._input_shape, ))
        self._unmasked_model = flax.nn.Model(MaskedCNN, initial_params)

        mask = masked.simple_mask(self._unmasked_model, jnp.ones,
                                  masked.WEIGHT_PARAM_NAMES)

        _, initial_params = MaskedCNNSparseInit.init_by_shape(
            jax.random.PRNGKey(42), (self._input_shape, ), mask=mask)
        self._masked_model_sparse_init = flax.nn.Model(MaskedCNNSparseInit,
                                                       initial_params)

        mean_init = jnp.mean(self._unmasked_model.params['MaskedModule_0']
                             ['unmasked']['kernel'])

        stddev_init = jnp.std(self._unmasked_model.params['MaskedModule_0']
                              ['unmasked']['kernel'])

        mean_sparse_init = jnp.mean(
            self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']
            ['kernel'])

        stddev_sparse_init = jnp.std(
            self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']
            ['kernel'])

        with self.subTest(name='test_cnn_sparse_init_mean'):
            self.assertBetween(mean_sparse_init, mean_init - 2 * stddev_init,
                               mean_init + 2 * stddev_init)

        with self.subTest(name='test_cnn_sparse_init_stddev'):
            self.assertBetween(stddev_sparse_init, 0.5 * stddev_init,
                               1.5 * stddev_init)
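The test above only checks that the statistics of the sparse initialization stay close to those of the dense Kaiming initialization. The underlying idea is to scale the initializer by the number of unmasked inputs rather than the full fan-in. A minimal sketch of that idea for a single kernel mask array (an assumed formulation, not the actual MaskedCNNSparseInit code):

import jax
import jax.numpy as jnp

def sparse_kaiming_normal(key, mask, shape, dtype=jnp.float32):
    """He/Kaiming-normal init whose fan-in counts only the unmasked inputs of each unit."""
    # For a kernel of shape (..., fan_in, fan_out), count surviving inputs per output unit.
    fan_in = jnp.maximum(jnp.sum(mask, axis=tuple(range(mask.ndim - 1))), 1.0)
    stddev = jnp.sqrt(2.0 / fan_in)  # broadcasts over the last (output) axis
    return jax.random.normal(key, shape, dtype) * stddev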
Example #4
    def test_simple_mask_two_layer(self):
        """Tests generation of a simple mask."""
        mask = {
            'MaskedModule_0': {
                'kernel':
                jnp.zeros(self._masked_model_twolayer.params['MaskedModule_0']
                          ['unmasked']['kernel'].shape),
                'bias':
                None,
            },
            'MaskedModule_1': {
                'kernel':
                jnp.zeros(self._masked_model_twolayer.params['MaskedModule_1']
                          ['unmasked']['kernel'].shape),
                'bias':
                None,
            },
        }

        gen_mask = masked.simple_mask(self._masked_model_twolayer, jnp.zeros,
                                      ['kernel'])

        result, _ = jax.tree_flatten(
            jax.tree_util.tree_multimap(lambda x, *xs: (x == xs[0]).all(),
                                        mask, gen_mask))

        self.assertTrue(all(result))
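In newer JAX releases jax.tree_util.tree_multimap has been folded into jax.tree_util.tree_map; the same elementwise comparison of the expected and generated masks could be written as below (a sketch using the same mask / gen_mask pytrees as above):

import jax

def masks_equal(expected, actual):
    """True if every corresponding leaf array of the two mask pytrees is identical."""
    equal_leaves = jax.tree_util.tree_map(
        lambda x, y: bool((x == y).all()), expected, actual)
    return jax.tree_util.tree_all(equal_leaves)

# e.g. self.assertTrue(masks_equal(mask, gen_mask))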
Example #5
    def test_fully_masked_layer(self):
        """Tests masked module with full-sparsity mask."""
        full_mask = masked.simple_mask(self._masked_model, jnp.zeros,
                                       ['kernel'])

        masked_output = self._masked_model(self._input, mask=full_mask)

        with self.subTest(name='fully_masked_dense_values'):
            self.assertTrue((masked_output == 0).all())

        with self.subTest(name='fully_masked_dense_shape'):
            self.assertSequenceEqual(masked_output.shape,
                                     self._unmasked_output.shape)
Example #6
    def test_empty_mask_masked_layer(self):
        """Tests masked module with an empty (not sparse) mask."""
        empty_mask = masked.simple_mask(self._masked_model, jnp.ones,
                                        ['kernel'])

        masked_output = self._masked_model(self._input, mask=empty_mask)

        with self.subTest(name='empty_mask_masked_dense_values'):
            self.assertTrue(
                jnp.isclose(masked_output, self._unmasked_output).all())

        with self.subTest(name='empty_mask_masked_dense_shape'):
            self.assertSequenceEqual(masked_output.shape,
                                     self._unmasked_output.shape)
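Both tests rely on the masked layer multiplying its kernel elementwise by the mask before the usual affine transform, so an all-zeros mask zeroes the output and an all-ones mask leaves it unchanged. A minimal sketch of that behaviour for a dense layer (assumed semantics, not the actual MaskedModule implementation):

import jax.numpy as jnp

def masked_dense(inputs, kernel, bias, mask=None):
    """Dense layer that zeroes pruned weights before the matrix multiply."""
    if mask is not None:
        kernel = kernel * mask  # 0 entries prune the corresponding weights
    return jnp.dot(inputs, kernel) + bias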
Example #7
    def test_count_permutations_mask_empty(self):
        """Tests count of weight permutations in an empty mask."""
        mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])

        stats = symmetry.count_permutations_mask(mask)

        with self.subTest(name='count_permutations_mask_unique'):
            self.assertEqual(stats['unique_neurons'], 0)

        with self.subTest(name='count_permutations_permutations'):
            self.assertEqual(stats['permutations'], 0)

        with self.subTest(name='count_permutations_zeroed'):
            self.assertEqual(stats['zeroed_neurons'], MaskedConv.NUM_FEATURES)

        with self.subTest(name='count_permutations_total'):
            self.assertEqual(stats['total_neurons'], MaskedConv.NUM_FEATURES)
Example #8
    def test_count_permutations_mask_bn_layer_full(self):
        """Tests count of permutations on a mask for model with non-masked layers."""
        mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])

        stats = symmetry.count_permutations_mask(mask)

        with self.subTest(name='count_permutations_mask_unique'):
            self.assertEqual(stats['unique_neurons'], 1)

        with self.subTest(name='count_permutations_permutations'):
            self.assertEqual(stats['permutations'],
                             math.factorial(MaskedDense.NUM_FEATURES))

        with self.subTest(name='count_permutations_zeroed'):
            self.assertEqual(stats['zeroed_neurons'], 0)

        with self.subTest(name='count_permutations_total'):
            self.assertEqual(stats['total_neurons'], MaskedConv.NUM_FEATURES)
Example #9
    def test_count_permutations_mask_twolayers_empty(self):
        """Tests count of weight permutations in an empty mask for 2 layers."""
        mask = masked.simple_mask(self._masked_two_layer_model, jnp.zeros,
                                  ['kernel'])

        stats = symmetry.count_permutations_mask(mask)

        with self.subTest(name='count_permutations_mask_unique'):
            self.assertEqual(stats['unique_neurons'], 0)

        with self.subTest(name='count_permutations_permutations'):
            self.assertEqual(stats['permutations'], 0)

        with self.subTest(name='count_permutations_zeroed'):
            self.assertEqual(stats['zeroed_neurons'],
                             sum(MaskedTwoLayerDense.NUM_FEATURES))

        with self.subTest(name='count_permutations_total'):
            self.assertEqual(stats['total_neurons'],
                             sum(MaskedTwoLayerDense.NUM_FEATURES))
Example #10
    def test_count_permutations_mask_twolayer_full(self):
        """Tests count of weight permutations in a full mask for 2 layers."""
        mask = masked.simple_mask(self._masked_two_layer_model, jnp.ones,
                                  ['kernel'])

        stats = symmetry.count_permutations_mask(mask)

        with self.subTest(name='count_permutations_mask_unique'):
            self.assertEqual(stats['unique_neurons'], 2)

        with self.subTest(name='count_permutations_permutations'):
            self.assertEqual(
                stats['permutations'],
                functools.reduce(operator.mul, [
                    math.factorial(x) for x in MaskedTwoLayerDense.NUM_FEATURES
                ]))

        with self.subTest(name='count_permutations_zeroed'):
            self.assertEqual(stats['zeroed_neurons'], 0)

        with self.subTest(name='count_permutations_total'):
            self.assertEqual(stats['total_neurons'],
                             sum(MaskedTwoLayerDense.NUM_FEATURES))
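The expected statistics in these tests follow from grouping each layer's neurons by their mask pattern: all-zero neurons are counted as zeroed, each distinct non-zero pattern is one unique neuron, and the number of permutation symmetries is the product of the factorials of the group sizes (with a product across layers for multi-layer models). A rough per-layer sketch of that bookkeeping (hypothetical, not the actual symmetry.count_permutations_mask):

import collections
import math

import numpy as np

def count_permutations_layer(kernel_mask):
    """Groups a layer's neurons (last-axis slices) by mask pattern and counts symmetries."""
    kernel_mask = np.asarray(kernel_mask)
    columns = [tuple(kernel_mask[..., i].ravel())
               for i in range(kernel_mask.shape[-1])]
    zeroed = sum(1 for c in columns if not any(c))
    groups = collections.Counter(c for c in columns if any(c))
    permutations = (math.prod(math.factorial(n) for n in groups.values())
                    if groups else 0)
    return {
        'unique_neurons': len(groups),
        'permutations': permutations,
        'zeroed_neurons': zeroed,
        'total_neurons': len(columns),
    }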
Example #11
def prune(model,
          pruning_rate,
          saliency_fn=weight_magnitude,
          mask=None,
          compare_fn=jnp.greater):
    """Returns a mask for a model where the params in each layer are pruned using a saliency function.

  Args:
    model: The model to create a pruning mask for.
    pruning_rate: The fraction of lowest magnitude saliency weights that are
      pruned. If a float, the same rate is used for all layers, otherwise if it
      is a mapping, it must contain a rate for all masked layers in the model.
    saliency_fn: A function that returns a float number used to rank
      the importance of individual weights in the layer.
    mask: If the model has an existing mask, the mask will be applied before
      pruning the model.
    compare_fn: A pairwise operator to compare saliency with threshold, and
      return True if the saliency indicates the value should not be masked.

  Returns:
    A pruned mask for the given model.
  """
    if not mask:
        mask = masked.simple_mask(model, jnp.ones, masked.WEIGHT_PARAM_NAMES)

    if not isinstance(pruning_rate, collections.abc.Mapping):
        pruning_rate_dict = {}
        for param_path, _ in masked.iterate_mask(mask):
            # Derive the layer name from the parameter's full path.
            layer_name = param_path.split('/')[-2]
            pruning_rate_dict[layer_name] = pruning_rate
        pruning_rate = pruning_rate_dict

    for param_path, param_mask in masked.iterate_mask(mask):
        split_param_path = param_path.split('/')
        layer_name = split_param_path[-2]
        param_name = split_param_path[-1]

        # If we don't have a pruning rate for the given layer, don't mask it.
        if layer_name in pruning_rate and mask[layer_name][
                param_name] is not None:
            param_value = model.params[layer_name][
                masked.MaskedModule.UNMASKED][param_name]

            # Apply any existing mask to the weight matrix first.
            # Note: compare against None explicitly, since param_mask is an ndarray.
            if param_mask is not None:
                saliencies = saliency_fn(param_mask * param_value)
            else:
                saliencies = saliency_fn(param_value)

            # TODO: Use a partition (partial sort) here instead of a full sort,
            # since it is O(N) rather than O(N log N); JAX doesn't support it yet.
            sorted_param = jnp.sort(jnp.abs(saliencies.flatten()))

            # Figure out the weight magnitude threshold.
            threshold_index = jnp.round(pruning_rate[layer_name] *
                                        sorted_param.size).astype(jnp.int32)
            threshold = sorted_param[threshold_index]

            mask[layer_name][param_name] = jnp.array(compare_fn(
                saliencies, threshold),
                                                     dtype=jnp.int32)

    return mask
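A usage sketch for the function above (model, inputs and the layer name are placeholders standing in for whatever the surrounding code defines):

# Prune the lowest-saliency 80% of weights in every masked layer.
mask = prune(model, 0.8)

# Per-layer rates: prune only 'MaskedModule_0', applying the existing mask first.
mask = prune(model, {'MaskedModule_0': 0.5}, mask=mask)

# Use the resulting mask in the forward pass.
logits = model(inputs, mask=mask)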