def test_create_from_tensor(self):
    """A Mask built from a dict of torch tensors exposes those entries like a dict."""
    source_tensors = {
        'hello': torch.ones([2, 3]),
        'world': torch.zeros([5, 6]),
    }
    mask = Mask(source_tensors)

    # The entry count must agree across len(), keys() and values().
    expected_entries = 2
    self.assertEqual(len(mask), expected_entries)
    self.assertEqual(len(mask.keys()), expected_entries)
    self.assertEqual(len(mask.values()), expected_entries)
    self.assertEqual(set(mask.keys()), {'hello', 'world'})

    # The stored values must compare equal to the equivalent numpy arrays.
    all_ones = np.ones([2, 3])
    all_zeros = np.zeros([5, 6])
    self.assertTrue(np.array_equal(all_ones, mask['hello']))
    self.assertTrue(np.array_equal(all_zeros, mask['world']))
def test_dict_behavior(self):
    """An empty Mask supports item assignment, lookup and deletion like a dict."""
    mask = Mask()

    def check_entry_count(expected):
        # len(), keys() and values() must all report the same number of entries.
        self.assertEqual(len(mask), expected)
        self.assertEqual(len(mask.keys()), expected)
        self.assertEqual(len(mask.values()), expected)

    # A freshly constructed mask is empty.
    check_entry_count(0)

    # Insert two numpy arrays and read them back.
    mask['hello'] = np.ones([2, 3])
    mask['world'] = np.zeros([5, 6])
    check_entry_count(2)
    self.assertEqual(set(mask.keys()), {'hello', 'world'})
    self.assertTrue(np.array_equal(np.ones([2, 3]), mask['hello']))
    self.assertTrue(np.array_equal(np.zeros([5, 6]), mask['world']))

    # Deleting one entry leaves the other intact.
    del mask['hello']
    check_entry_count(1)
    self.assertEqual(set(mask.keys()), {'world'})
    self.assertTrue(np.array_equal(np.zeros([5, 6]), mask['world']))
def branch_function(self, seed: int, strategy: str = 'layerwise', start_at: str = 'rewind',
                    layers_to_ignore: str = ''):
    """Randomize this level's pruning mask and train a model with the randomized mask.

    The mask is loaded from ``self.level_root``, randomized according to
    ``strategy``, saved to ``self.branch_root``, and then applied to a model whose
    weights are loaded from ``self.level_root`` at the step chosen by ``start_at``.

    Args:
        seed: Seed forwarded to the shuffling helpers. For the ``'even'`` strategy
            each layer uses ``seed + i``, where ``i`` is the layer's index in
            sorted key order.
        strategy: How to randomize the mask.
            ``'layerwise'`` shuffles via ``shuffle_state_dict``;
            ``'global'`` vectorizes the mask, shuffles it, and unvectorizes it;
            ``'even'`` rebuilds every layer from ``mask.sparsity`` and shuffles it;
            ``'identity'`` leaves the mask untouched.
        start_at: Which weights and step to start training from.
            ``'init'`` trains from step ``'0ep'`` using the weights of that step;
            ``'end'`` trains from step ``'0ep'`` using the weights of
            ``train_end_step``;
            ``'rewind'`` trains from ``train_start_step`` using the weights of
            that step.
        layers_to_ignore: Comma-separated mask keys whose masks are reset to all
            ones after randomization. An empty string resets nothing.

    Raises:
        ValueError: If ``strategy`` or ``start_at`` is not a supported value. An
            invalid ``start_at`` is only detected after the new mask was saved.
    """
    # Randomize the mask.
    mask = Mask.load(self.level_root)

    # Randomize while keeping the same layerwise proportions as the original mask.
    if strategy == 'layerwise':
        mask = Mask(shuffle_state_dict(mask, seed=seed))

    # Randomize globally throughout all prunable layers.
    elif strategy == 'global':
        mask = Mask(unvectorize(shuffle_tensor(vectorize(mask), seed=seed), mask))

    # Randomize evenly across all layers.
    elif strategy == 'even':
        sparsity = mask.sparsity
        # enumerate() supplies the integer per-layer seed offset; iterating the
        # sorted keys directly would try to unpack each key string into (i, k).
        for i, k in enumerate(sorted(mask.keys())):
            layer = mask[k]
            flat = layer.reshape(-1)
            # Tensor.size is a method on torch tensors; numel() is the element count.
            numel = flat.numel()
            cutoff = torch.ceil(torch.as_tensor(sparsity) * numel)
            # NOTE(review): ones fill the first ceil(sparsity * numel) slots. If
            # `sparsity` is the pruned fraction, this keeps that fraction instead
            # of pruning it -- confirm against Mask.sparsity.
            layer_mask = torch.where(torch.arange(numel, device=flat.device) < cutoff,
                                     torch.ones_like(flat), torch.zeros_like(flat))
            # Restore the layer's original shape after shuffling the flat mask.
            mask[k] = shuffle_tensor(layer_mask, seed=seed + i).reshape(layer.shape)

    # Identity.
    elif strategy == 'identity':
        pass

    # Error.
    else:
        raise ValueError(f'Invalid strategy: {strategy}')

    # Reset the masks of any layers that shouldn't be pruned.
    if layers_to_ignore:
        for k in layers_to_ignore.split(','):
            mask[k] = torch.ones_like(mask[k])

    # Save the new mask.
    mask.save(self.branch_root)

    # Determine the start step.
    if start_at == 'init':
        start_step = self.lottery_desc.str_to_step('0ep')
        state_step = start_step
    elif start_at == 'end':
        start_step = self.lottery_desc.str_to_step('0ep')
        state_step = self.lottery_desc.train_end_step
    elif start_at == 'rewind':
        start_step = self.lottery_desc.train_start_step
        state_step = start_step
    else:
        raise ValueError(f'Invalid starting point {start_at}')

    # Train the model with the new mask.
    model = PrunedModel(models.registry.load(self.level_root, state_step, self.lottery_desc.model_hparams), mask)
    train.standard_train(model, self.branch_root, self.lottery_desc.dataset_hparams,
                         self.lottery_desc.training_hparams, start_step=start_step, verbose=self.verbose)