Пример #1
0
    def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None):
        """Prune by global weight magnitude, then randomly permute the resulting mask.

        Args:
            pruning_hparams: Supplies ``pruning_fraction`` (multiplied by the number of
                currently unpruned mask entries to get how many to prune) and
                ``pruning_layers_to_ignore`` (optional comma-separated tensor names that are
                excluded from the magnitude ranking).
            trained_model: Supplies ``prunable_layer_names`` and the weights in ``state_dict()``.
            current_mask: Mask to prune further; ``Mask.ones_like(trained_model)`` when None.

        Returns:
            A new Mask. Entries of prunable tensors whose weight magnitude is at or below the
            global threshold are zeroed, and the rest keep their current value. Entries of all
            other tensors are copied from ``current_mask``. The whole mask is then passed
            through ``vectorize`` / ``shuffle_tensor(seed=42)`` / ``unvectorize``.
        """
        current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy()

        # Determine the number of weights that need to be pruned.
        # NOTE(review): this sums every entry of current_mask, including any layers excluded
        # below via pruning_layers_to_ignore, so with an ignore list the count can exceed
        # pruning_fraction of the prunable weights — confirm this is intended.
        number_of_remaining_weights = np.sum([np.sum(v) for v in current_mask.values()])
        number_of_weights_to_prune = np.ceil(
            pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int)

        # Determine which layers can be pruned.
        prunable_tensors = set(trained_model.prunable_layer_names)
        if pruning_hparams.pruning_layers_to_ignore:
            prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(','))

        # Copy the prunable weights out of the model as numpy arrays.
        weights = {k: v.clone().cpu().detach().numpy()
                   for k, v in trained_model.state_dict().items()
                   if k in prunable_tensors}

        # Sorted magnitudes of all still-unpruned (mask == 1) weights, pooled across prunable layers.
        unpruned = [np.abs(v[current_mask[k] == 1]) for k, v in weights.items()]
        magnitudes = np.sort(np.concatenate(unpruned)) if unpruned else np.zeros(0)

        if magnitudes.size == 0:
            # No unpruned prunable weights (or no prunable layers at all): every prunable
            # entry ends up zero, since none of them is currently 1.
            threshold = np.inf
        else:
            # Clamp the index so pruning_fraction == 1 (or a count inflated by ignored layers)
            # prunes every remaining weight instead of raising IndexError. Weights at or below
            # the threshold are pruned, i.e. at least number_of_weights_to_prune + 1 of them
            # (more when magnitudes tie).
            threshold = magnitudes[min(number_of_weights_to_prune, magnitudes.size - 1)]

        new_mask = Mask({k: np.where(np.abs(v) > threshold, current_mask[k], np.zeros_like(v))
                         for k, v in weights.items()})
        # Mask entries for tensors outside `weights` are copied through unchanged.
        for k in current_mask:
            if k not in new_mask:
                new_mask[k] = current_mask[k]

        # Randomly permute the mask. NOTE(review): vectorize(new_mask) receives every entry of
        # new_mask, including the tensors copied through unchanged just above, so those are
        # presumably shuffled too; seed=42 is passed on every call — confirm both are intended.
        new_mask = Mask(unvectorize(shuffle_tensor(vectorize(new_mask), seed=42), new_mask))

        return new_mask
Пример #2
0
 def test_create_from_tensor(self):
     """A Mask built from torch tensors exposes both entries with their original values."""
     mask = Mask({'hello': torch.ones([2, 3]), 'world': torch.zeros([5, 6])})

     # len() must agree across the mapping itself and its key / value views.
     for view in (mask, mask.keys(), mask.values()):
         self.assertEqual(len(view), 2)
     self.assertEqual(set(mask.keys()), {'hello', 'world'})

     # Each entry compares equal to a numpy array of the same shape and fill value.
     self.assertTrue(np.array_equal(np.ones([2, 3]), mask['hello']))
     self.assertTrue(np.array_equal(np.zeros([5, 6]), mask['world']))
Пример #3
0
    def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None):
        """Global magnitude pruning applied independently to each initialization.

        Every mask and weight tensor carries the initialization index on axis 0. Each
        initialization gets its own global threshold, computed over its own unpruned weights.

        Args:
            pruning_hparams: Supplies ``pruning_fraction`` (multiplied by the per-initialization
                remaining count to get how many to prune) and ``pruning_layers_to_ignore``
                (optional comma-separated tensor names excluded from the magnitude ranking).
            trained_model: Supplies ``prunable_layer_names`` and the weights in ``state_dict()``.
            current_mask: Mask to prune further; ``Mask.ones_like(trained_model)`` when None.

        Returns:
            A new Mask. Prunable entries whose weight magnitude is at or below that
            initialization's threshold are zeroed, and all others keep their current value.
            Tensors outside the prunable set are copied from ``current_mask``.

        Raises:
            AssertionError: if the mask tensors disagree on the size of axis 0.
        """
        current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy()

        # Number of initializations, read from axis 0 of the first mask tensor; all tensors must agree.
        num_inits = next(iter(current_mask.values())).shape[0]
        assert all(v.shape[0] == num_inits for v in current_mask.values()), \
            'all mask tensors must share the same leading (initialization) dimension'

        # Determine the number of weights that need to be pruned in each initialization.
        # The total remaining count is split evenly (floor division) across initializations.
        # NOTE(review): the total includes any layers excluded below via pruning_layers_to_ignore.
        number_of_remaining_weights_per_init = np.sum([np.sum(v) for v in current_mask.values()]) // num_inits
        number_of_weights_to_prune_per_init = np.ceil(
            pruning_hparams.pruning_fraction * number_of_remaining_weights_per_init).astype(int)

        # Determine which layers can be pruned.
        prunable_tensors = set(trained_model.prunable_layer_names)
        if pruning_hparams.pruning_layers_to_ignore:
            prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(','))

        # Copy the prunable weights out of the model as numpy arrays.
        weights = {k: v.clone().cpu().detach().numpy()
                   for k, v in trained_model.state_dict().items()
                   if k in prunable_tensors}

        # One threshold per initialization, taken over that initialization's unpruned weights.
        thresholds = []
        for init_id in range(num_inits):
            unpruned = [np.abs(v[init_id, ...][current_mask[k][init_id, ...] == 1])
                        for k, v in weights.items()]
            magnitudes = np.sort(np.concatenate(unpruned)) if unpruned else np.zeros(0)
            if magnitudes.size == 0:
                # Nothing unpruned in this initialization: every prunable entry ends up zero.
                thresholds.append(np.inf)
            else:
                # Clamp so an over-large count prunes everything instead of raising IndexError.
                # The count can be too large because it is an average over initializations and
                # may include ignored layers. Weights at or below the threshold are pruned, i.e.
                # at least count + 1 of them (more when magnitudes tie).
                thresholds.append(magnitudes[min(number_of_weights_to_prune_per_init, magnitudes.size - 1)])
        thresholds = np.array(thresholds)

        mask_dict = {}
        for k, v in weights.items():
            # Shape (num_inits, 1, ..., 1): broadcasts each initialization's threshold over its slice.
            threshold_tensor = thresholds.reshape(-1, *([1] * (v.ndim - 1)))
            mask_dict[k] = np.where(np.abs(v) > threshold_tensor, current_mask[k], np.zeros_like(v))
        new_mask = Mask(mask_dict)
        # Mask entries for tensors outside `weights` are copied through unchanged.
        for k in current_mask:
            if k not in new_mask:
                new_mask[k] = current_mask[k]

        return new_mask
Пример #4
0
    def test_dict_behavior(self):
        """Mask supports dict-style insertion, lookup, deletion and length queries."""
        mask = Mask()

        def check_size(expected):
            # len() must agree across the mapping itself and its key / value views.
            for view in (mask, mask.keys(), mask.values()):
                self.assertEqual(len(view), expected)

        # A freshly constructed Mask is empty.
        check_size(0)

        # Item assignment adds entries that read back unchanged.
        mask['hello'] = np.ones([2, 3])
        mask['world'] = np.zeros([5, 6])
        check_size(2)
        self.assertEqual(set(mask.keys()), {'hello', 'world'})
        self.assertTrue(np.array_equal(np.ones([2, 3]), mask['hello']))
        self.assertTrue(np.array_equal(np.zeros([5, 6]), mask['world']))

        # Deleting one entry leaves the other intact.
        del mask['hello']
        check_size(1)
        self.assertEqual(set(mask.keys()), {'world'})
        self.assertTrue(np.array_equal(np.zeros([5, 6]), mask['world']))