Example #1
    def test_create_from_tensor(self):
        m = Mask({'hello': torch.ones([2, 3]), 'world': torch.zeros([5, 6])})
        self.assertEqual(len(m), 2)
        self.assertEqual(len(m.keys()), 2)
        self.assertEqual(len(m.values()), 2)
        self.assertEqual(set(m.keys()), set(['hello', 'world']))
        self.assertTrue(np.array_equal(np.ones([2, 3]), m['hello']))
        self.assertTrue(np.array_equal(np.zeros([5, 6]), m['world']))
Example #2
    def test_dict_behavior(self):
        m = Mask()
        self.assertEqual(len(m), 0)
        self.assertEqual(len(m.keys()), 0)
        self.assertEqual(len(m.values()), 0)

        m['hello'] = np.ones([2, 3])
        m['world'] = np.zeros([5, 6])
        self.assertEqual(len(m), 2)
        self.assertEqual(len(m.keys()), 2)
        self.assertEqual(len(m.values()), 2)
        self.assertEqual(set(m.keys()), set(['hello', 'world']))
        self.assertTrue(np.array_equal(np.ones([2, 3]), m['hello']))
        self.assertTrue(np.array_equal(np.zeros([5, 6]), m['world']))

        del m['hello']
        self.assertEqual(len(m), 1)
        self.assertEqual(len(m.keys()), 1)
        self.assertEqual(len(m.values()), 1)
        self.assertEqual(set(m.keys()), set(['world']))
        self.assertTrue(np.array_equal(np.zeros([5, 6]), m['world']))
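
Both tests above can be satisfied by a small dict subclass. The sketch below is an assumption for illustration, not the project's actual Mask implementation: it normalizes values to torch tensors (so numpy arrays and torch tensors are interchangeable) and checks that every entry is 0 or 1; len(), keys(), values(), and del then come straight from dict, which is what test_dict_behavior exercises.

    import numpy as np
    import torch

    class Mask(dict):
        """Hypothetical minimal Mask: a dict of 0/1 tensors keyed by layer name."""

        def __init__(self, other_dict=None):
            super().__init__()
            if other_dict is not None:
                for k, v in other_dict.items():
                    self[k] = v

        def __setitem__(self, key, value):
            if not isinstance(key, str) or not key:
                raise ValueError(f'Invalid tensor name: {key}')
            if isinstance(value, np.ndarray):
                value = torch.as_tensor(value)  # accept numpy input, store as torch
            if not isinstance(value, torch.Tensor):
                raise ValueError(f'Value for {key} must be a torch Tensor or numpy ndarray.')
            if ((value != 0) & (value != 1)).any():
                raise ValueError('All mask entries must be 0 or 1.')
            super().__setitem__(key, value)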
Example #3
    def branch_function(self, seed: int, strategy: str = 'layerwise', start_at: str = 'rewind',
                        layers_to_ignore: str = ''):
        # Randomize the mask.
        mask = Mask.load(self.level_root)

        # Randomize while keeping the same layerwise proportions as the original mask.
        if strategy == 'layerwise': mask = Mask(shuffle_state_dict(mask, seed=seed))

        # Randomize globally throughout all prunable layers.
        elif strategy == 'global': mask = Mask(unvectorize(shuffle_tensor(vectorize(mask), seed=seed), mask))

        # Randomize evenly across all layers.
        elif strategy == 'even':
            sparsity = mask.sparsity
            for i, k in enumerate(sorted(mask.keys())):
                num_elements = mask[k].numel()
                num_ones = torch.ceil(torch.as_tensor(sparsity * num_elements))
                layer_mask = torch.where(torch.arange(num_elements) < num_ones,
                                         torch.ones(num_elements), torch.zeros(num_elements))
                mask[k] = shuffle_tensor(layer_mask, seed=seed+i).reshape(mask[k].shape)

        # Identity.
        elif strategy == 'identity': pass

        # Error.
        else: raise ValueError(f'Invalid strategy: {strategy}')

        # Reset the masks of any layers that shouldn't be pruned.
        if layers_to_ignore:
            for k in layers_to_ignore.split(','): mask[k] = torch.ones_like(mask[k])

        # Save the new mask.
        mask.save(self.branch_root)

        # Determine the start step.
        if start_at == 'init':
            start_step = self.lottery_desc.str_to_step('0ep')
            state_step = start_step
        elif start_at == 'end':
            start_step = self.lottery_desc.str_to_step('0ep')
            state_step = self.lottery_desc.train_end_step
        elif start_at == 'rewind':
            start_step = self.lottery_desc.train_start_step
            state_step = start_step
        else:
            raise ValueError(f'Invalid starting point {start_at}')

        # Train the model with the new mask.
        model = PrunedModel(models.registry.load(self.level_root, state_step, self.lottery_desc.model_hparams), mask)
        train.standard_train(model, self.branch_root, self.lottery_desc.dataset_hparams,
                             self.lottery_desc.training_hparams, start_step=start_step, verbose=self.verbose)
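
branch_function relies on four helpers that are not shown here: shuffle_state_dict, shuffle_tensor, vectorize, and unvectorize. The sketch below is a plausible reading of their contracts, assuming a seeded element-wise permutation; the real utilities may differ in detail.

    import torch

    def shuffle_tensor(tensor: torch.Tensor, seed: int) -> torch.Tensor:
        # Return a copy of `tensor` with its elements randomly permuted (seeded).
        generator = torch.Generator().manual_seed(seed)
        flat = tensor.reshape(-1)
        return flat[torch.randperm(flat.numel(), generator=generator)].reshape(tensor.shape)

    def shuffle_state_dict(state_dict, seed: int):
        # Shuffle each tensor independently, so per-layer sparsity is preserved.
        return {k: shuffle_tensor(state_dict[k], seed=seed + i)
                for i, k in enumerate(sorted(state_dict.keys()))}

    def vectorize(state_dict) -> torch.Tensor:
        # Concatenate all tensors into one flat vector, in a fixed key order.
        return torch.cat([state_dict[k].reshape(-1) for k in sorted(state_dict.keys())])

    def unvectorize(vector: torch.Tensor, reference_state_dict):
        # Inverse of vectorize: split `vector` back into tensors shaped like the reference.
        result, offset = {}, 0
        for k in sorted(reference_state_dict.keys()):
            numel = reference_state_dict[k].numel()
            result[k] = vector[offset:offset + numel].reshape(reference_state_dict[k].shape)
            offset += numel
        return result

Under these contracts, the 'layerwise' strategy preserves each layer's sparsity, while 'global' preserves only the overall sparsity, since ones can move between layers across the vectorize / shuffle / unvectorize round trip.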