Example No. 1
    def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None):
        current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy()

        # Determine the number of weights that need to be pruned.
        number_of_remaining_weights = np.sum([np.sum(v) for v in current_mask.values()])
        number_of_weights_to_prune = np.ceil(
            pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int)

        # Determine which layers can be pruned.
        prunable_tensors = set(trained_model.prunable_layer_names)
        if pruning_hparams.pruning_layers_to_ignore:
            prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(','))

        # Get the model weights.
        weights = {k: v.clone().cpu().detach().numpy()
                   for k, v in trained_model.state_dict().items()
                   if k in prunable_tensors}

        # Create a vector of all the unpruned weights in the model.
        weight_vector = np.concatenate([v[current_mask[k] == 1] for k, v in weights.items()])
        threshold = np.sort(np.abs(weight_vector))[number_of_weights_to_prune]

        new_mask = Mask({k: np.where(np.abs(v) > threshold, current_mask[k], np.zeros_like(v))
                         for k, v in weights.items()})
        for k in current_mask:
            if k not in new_mask:
                new_mask[k] = current_mask[k]

        # Randomize globally throughout all prunable layers.
        new_mask = Mask(unvectorize(shuffle_tensor(vectorize(new_mask), seed=42), new_mask))

        return new_mask
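The global-magnitude threshold step above is easier to see in isolation. Below is a minimal, self-contained NumPy sketch of the same idea using made-up tensors (not the repository's API): collect the absolute values of the surviving weights, take the value at index number_of_weights_to_prune of the sorted vector as the threshold, and zero out every weight whose magnitude does not exceed it.

    import numpy as np

    # Hypothetical stand-ins for a model's prunable layers and their current mask.
    weights = {'fc1': np.array([[0.3, -0.1], [0.05, 0.8]]),
               'fc2': np.array([0.02, -0.6, 0.4])}
    mask = {k: np.ones_like(v) for k, v in weights.items()}
    pruning_fraction = 0.4

    remaining = int(sum(m.sum() for m in mask.values()))
    num_to_prune = int(np.ceil(pruning_fraction * remaining))

    # Global threshold: the value at index num_to_prune of the sorted absolute values
    # of all currently unpruned weights.
    weight_vector = np.concatenate([np.abs(v[mask[k] == 1]) for k, v in weights.items()])
    threshold = np.sort(weight_vector)[num_to_prune]

    # Keep only the weights whose magnitude strictly exceeds the threshold.
    new_mask = {k: np.where(np.abs(v) > threshold, mask[k], 0.0) for k, v in weights.items()}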
Example No. 2
    def test_load_state_dict(self):
        mask = Mask.ones_like(self.model)
        mask['layer.weight'] = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0])
        self.model.layer.weight.data = torch.arange(10,
                                                    20,
                                                    dtype=torch.float32)
        pruned_model = PrunedModel(self.model, mask)
        self.assertEqual(70.0, pruned_model(self.example).item())

        self.optimizer.zero_grad()
        pruned_model.loss_criterion(pruned_model(self.example),
                                    self.label).backward()
        self.optimizer.step()
        self.assertEqual(67.0, pruned_model(self.example).item())

        # Save the old state dict.
        state_dict = pruned_model.state_dict()

        # Create a new model.
        self.model = InnerProductModel(10)
        mask = Mask.ones_like(self.model)
        pruned_model = PrunedModel(self.model, mask)
        self.assertEqual(45.0, pruned_model(self.example).item())

        # Load the state dict.
        pruned_model.load_state_dict(state_dict)
        self.assertEqual(67.0, pruned_model(self.example).item())

        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        self.optimizer.zero_grad()
        pruned_model.loss_criterion(pruned_model(self.example),
                                    self.label).backward()
        self.optimizer.step()
        self.assertEqual(64.3, np.round(pruned_model(self.example).item(), 1))
Example No. 3
 def test_create_from_tensor(self):
     m = Mask({'hello': torch.ones([2, 3]), 'world': torch.zeros([5, 6])})
     self.assertEqual(len(m), 2)
     self.assertEqual(len(m.keys()), 2)
     self.assertEqual(len(m.values()), 2)
     self.assertEqual(set(m.keys()), set(['hello', 'world']))
     self.assertTrue(np.array_equal(np.ones([2, 3]), m['hello']))
     self.assertTrue(np.array_equal(np.zeros([5, 6]), m['world']))
Example No. 4
    def branch_function(self,
                        target_model_name: str = None,
                        block_mapping: str = None,
                        start_at_step_zero: bool = False):
        # Process the mapping
        # A valid string format of a mapping is like:
        #   `0:0;1:1,2;2:3,4;3:5,6;4:7,8`
        if 'cifar' in target_model_name and 'resnet' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        elif 'imagenet' in target_model_name and 'resnet' in target_model_name:
            mappings = list(
                map(parse_block_mapping_for_stage, block_mapping.split('|')))
        elif 'cifar' in target_model_name and 'vggnfc' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        elif 'cifar' in target_model_name and 'vgg' in target_model_name:
            mappings = list(
                map(parse_block_mapping_for_stage, block_mapping.split('|')))
        elif 'cifar' in target_model_name and 'mobilenetv1' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        elif 'mnist' in target_model_name and 'lenet' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        else:
            raise NotImplementedError(
                'Other mapping cases not implemented yet')

        # Load source model at `train_start_step`
        src_mask = Mask.load(self.level_root)
        start_step = self.lottery_desc.str_to_step(
            '0it'
        ) if start_at_step_zero else self.lottery_desc.train_start_step
        # model = PrunedModel(models.registry.get(self.lottery_desc.model_hparams), src_mask)
        src_model = models.registry.load(self.level_root, start_step,
                                         self.lottery_desc.model_hparams)

        # Create target model
        target_model_hparams = copy.deepcopy(self.lottery_desc.model_hparams)
        target_model_hparams.model_name = target_model_name
        target_model = models.registry.get(target_model_hparams)
        target_ones_mask = Mask.ones_like(target_model)

        # Do the morphism
        target_sd = change_depth(target_model_name, src_model.state_dict(),
                                 target_model.state_dict(), mappings)
        target_model.load_state_dict(target_sd)
        target_mask = change_depth(target_model_name, src_mask,
                                   target_ones_mask, mappings)
        target_model = PrunedModel(target_model, target_mask)

        # Save and run a standard train
        target_mask.save(self.branch_root)
        train.standard_train(target_model,
                             self.branch_root,
                             self.lottery_desc.dataset_hparams,
                             self.lottery_desc.training_hparams,
                             start_step=start_step,
                             verbose=self.verbose)
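parse_block_mapping_for_stage is not listed among these examples. Based only on the mapping format documented in the comment above (0:0;1:1,2;2:3,4;3:5,6;4:7,8), a plausible parser would map each source block index to the list of target block indices; the sketch below is an assumption, not the repository's implementation.

    # Assumed parser sketch for strings like '0:0;1:1,2;2:3,4;3:5,6;4:7,8'.
    def parse_block_mapping_for_stage(mapping: str) -> dict:
        result = {}
        for entry in mapping.split(';'):
            src, targets = entry.split(':')
            result[int(src)] = [int(t) for t in targets.split(',')]
        return result

    # The ImageNet ResNet and CIFAR VGG branches above split on '|' first, one mapping per stage.
    parse_block_mapping_for_stage('0:0;1:1,2;2:3,4;3:5,6;4:7,8')
    # -> {0: [0], 1: [1, 2], 2: [3, 4], 3: [5, 6], 4: [7, 8]}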
Example No. 5
    def _prune_level(self, level: int):
        new_location = self.desc.run_path(self.replicate, level)
        if Mask.exists(new_location): return

        if level == 0:
            Mask.ones_like(models.registry.get(self.desc.model_hparams)).save(new_location)
        else:
            old_location = self.desc.run_path(self.replicate, level-1)
            model = models.registry.load(old_location, self.desc.train_end_step,
                                         self.desc.model_hparams, self.desc.train_outputs)
            pruning.registry.get(self.desc.pruning_hparams)(model, Mask.load(old_location)).save(new_location)
Example No. 6
 def branch_function(self, start_at_step_zero: bool = False):
     model = PrunedModel(
         models.registry.get(self.lottery_desc.model_hparams),
         Mask.load(self.level_root))
     start_step = self.lottery_desc.str_to_step(
         '0it'
     ) if start_at_step_zero else self.lottery_desc.train_start_step
     Mask.load(self.level_root).save(self.branch_root)
     train.standard_train(model,
                          self.branch_root,
                          self.lottery_desc.dataset_hparams,
                          self.lottery_desc.training_hparams,
                          start_step=start_step,
                          verbose=self.verbose)
Example No. 7
    def prune(pruning_hparams: PruningHparams,
              trained_model: models.base.Model,
              current_mask: Mask = None,
              training_hparams: hparams.TrainingHparams = None,
              dataset_hparams: hparams.DatasetHparams = None,
              data_order_seed: int = None):
        current_mask = Mask.ones_like(
            trained_model) if current_mask is None else current_mask
        current_mask_numpy = current_mask.numpy()

        # Determine the number of weights that need to be pruned.
        number_of_remaining_weights = np.sum(
            [np.sum(v) for v in current_mask_numpy.values()])
        number_of_weights_to_prune = np.ceil(
            pruning_hparams.pruning_fraction *
            number_of_remaining_weights).astype(int)

        # Determine which layers can be pruned.
        prunable_tensors = set(trained_model.prunable_layer_names)
        if pruning_hparams.pruning_layers_to_ignore:
            prunable_tensors -= set(
                pruning_hparams.pruning_layers_to_ignore.split(','))

        # Get the model score.
        scores = Strategy.get_score(trained_model, current_mask,
                                    prunable_tensors, training_hparams,
                                    dataset_hparams, data_order_seed)

        # Get the model weights.
        # weights = {k: v.clone().cpu().detach().numpy()
        #            for k, v in trained_model.state_dict().items()
        #            if k in prunable_tensors}

        # Create a vector of all the unpruned weights in the model.
        # weight_vector = np.concatenate([v[current_mask[k] == 1] for k, v in weights.items()])
        score_vector = np.concatenate(
            [v[current_mask_numpy[k] == 1] for k, v in scores.items()])
        threshold = np.sort(np.abs(score_vector))[number_of_weights_to_prune]

        new_mask = Mask({
            k: np.where(
                np.abs(v) > threshold, current_mask_numpy[k], np.zeros_like(v))
            for k, v in scores.items()
        })
        for k in current_mask_numpy:
            if k not in new_mask:
                new_mask[k] = current_mask_numpy[k]

        return new_mask
Example No. 8
 def branch_function(self,
                     retrain_d: hparams.DatasetHparams,
                     retrain_t: hparams.TrainingHparams,
                     start_at_step_zero: bool = False,
                     transfer_learn: bool = False):
     # Get the mask and model.
     if transfer_learn:
         m = models.registry.load(self.level_root,
                                  self.lottery_desc.train_end_step,
                                  self.lottery_desc.model_hparams)
     else:
         m = models.registry.load(self.level_root,
                                  self.lottery_desc.train_start_step,
                                  self.lottery_desc.model_hparams)
     m = PrunedModel(m, Mask.load(self.level_root))
     start_step = Step.from_iteration(
         0 if start_at_step_zero else
         self.lottery_desc.train_start_step.iteration,
         datasets.registry.iterations_per_epoch(retrain_d))
     train.standard_train(m,
                          self.branch_root,
                          retrain_d,
                          retrain_t,
                          start_step=start_step,
                          verbose=self.verbose)
Example No. 9
    def test_state_dict(self):
        mask = Mask.ones_like(self.model)
        pruned_model = PrunedModel(self.model, mask)

        state_dict = pruned_model.state_dict()
        self.assertEqual(set(['model.layer.weight', 'mask_layer___weight']),
                         state_dict.keys())
Example No. 10
    def branch_function(self,
                        seed: int,
                        strategy: str = 'sparse_global',
                        start_at: str = 'rewind',
                        layers_to_ignore: str = ''):

        # Determine the start step.
        if start_at == 'init':
            start_step = self.lottery_desc.str_to_step('0ep')
            state_step = start_step
        elif start_at == 'end':
            start_step = self.lottery_desc.str_to_step('0ep')
            state_step = self.lottery_desc.train_end_step
        elif start_at == 'rewind':
            start_step = self.lottery_desc.train_start_step
            state_step = start_step
        else:
            raise ValueError(f'Invalid starting point {start_at}')

        # Train the model with the new mask.
        model = models.registry.load(self.pretrain_root, state_step, self.lottery_desc.model_hparams)

        # Get the current level mask and get the target pruning ratio
        mask = Mask.load(self.level_root)
        sparsity_ratio = mask.get_sparsity_ratio()
        target_pruning_fraction = 1.0 - sparsity_ratio

        # Run pruning
        pruning_hparams = copy.deepcopy(self.lottery_desc.pruning_hparams)
        pruning_hparams.pruning_strategy = strategy
        pruning_hparams.pruning_fraction = target_pruning_fraction
        new_mask = pruning.registry.get(pruning_hparams)(
            model, Mask.ones_like(model),
            self.lottery_desc.training_hparams,
            self.lottery_desc.dataset_hparams, seed
        )
        # Reset the masks of any layers that shouldn't be pruned.
        if layers_to_ignore:
            for k in layers_to_ignore.split(','): new_mask[k] = torch.ones_like(new_mask[k])

        new_mask.save(self.branch_root)

        repruned_model = PrunedModel(model.to(device=get_platform().cpu_device), new_mask)

        # Run training
        train.standard_train(repruned_model, self.branch_root, self.lottery_desc.dataset_hparams,
                             self.lottery_desc.training_hparams, start_step=start_step, verbose=self.verbose)
Example No. 11
    def test_mask_with_ones_backward(self):
        mask = Mask.ones_like(self.model)
        pruned_model = PrunedModel(self.model, mask)

        self.optimizer.zero_grad()
        pruned_model.loss_criterion(pruned_model(self.example), self.label).backward()
        self.optimizer.step()
        self.assertEqual(44.0, pruned_model(self.example).item())
Example No. 12
    def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None):
        current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy()
        # number of initializations
        num_inits = next(iter(current_mask.values())).shape[0]
        assert np.array([num_inits == v.shape[0] for v in current_mask.values()]).all()

        # Determine the number of weights that need to be pruned.
        number_of_remaining_weights_per_init = np.sum([np.sum(v) for v in current_mask.values()]) // num_inits
        number_of_weights_to_prune_per_init = np.ceil(
            pruning_hparams.pruning_fraction * number_of_remaining_weights_per_init).astype(int)

        # Determine which layers can be pruned.
        prunable_tensors = set(trained_model.prunable_layer_names)
        if pruning_hparams.pruning_layers_to_ignore:
            prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(','))

        # Get the model weights.
        weights = {k: v.clone().cpu().detach().numpy()
                   for k, v in trained_model.state_dict().items()
                   if k in prunable_tensors}

        # Create a vector of all the unpruned weights in the model.
        weight_vectors = [
            np.concatenate([v[init_id, ...][current_mask[k][init_id, ...] == 1]
                            for k, v in weights.items()])
            for init_id in range(num_inits)]
        thresholds = np.array([np.sort(np.abs(wv))[number_of_weights_to_prune_per_init]
                               for wv in weight_vectors])
        mask_dict = {}
        for k, v in weights.items():
            threshold_tensor = thresholds.reshape(-1, *[1 for _ in range(v.ndim-1)])
            threshold_tensor = np.tile(threshold_tensor, v.shape[1:])
            mask_dict[k] = np.where(np.abs(v) > threshold_tensor, current_mask[k], np.zeros_like(v))
        new_mask = Mask(mask_dict)
        for k in current_mask:
            if k not in new_mask:
                new_mask[k] = current_mask[k]

        return new_mask
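The reshape-and-tile step above produces one magnitude threshold per initialization and compares it against every weight of that initialization. The same per-init comparison can be seen with plain NumPy broadcasting; the shapes below are made up for illustration.

    import numpy as np

    num_inits = 3
    v = np.random.randn(num_inits, 4, 5)        # one weight tensor per initialization
    thresholds = np.array([0.5, 1.0, 1.5])      # one threshold per initialization

    # Reshaping to (num_inits, 1, 1) broadcasts each threshold across its own slice,
    # which is what the explicit np.tile in the function above constructs.
    keep = np.abs(v) > thresholds.reshape(-1, 1, 1)
    assert keep.shape == v.shape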
Example No. 13
    def test_save(self):
        state1 = self.get_state(self.model)

        mask = Mask.ones_like(self.model)
        pruned_model = PrunedModel(self.model, mask)
        pruned_model.save(self.root, Step.zero(20))

        self.assertTrue(os.path.exists(paths.model(self.root, Step.zero(20))))

        self.model.load_state_dict(torch.load(paths.model(self.root, Step.zero(20))))
        self.assertStateEqual(state1, self.get_state(self.model))
Example No. 14
    def test_ones_like(self):
        model = models.registry.get(models.registry.get_default_hparams('cifar_resnet_20').model_hparams)
        m = Mask.ones_like(model)

        for k, v in model.state_dict().items():
            if k in model.prunable_layer_names:
                self.assertIn(k, m)
                self.assertEqual(list(m[k].shape), list(v.shape))
                self.assertTrue((m[k] == 1).all())
            else:
                self.assertNotIn(k, m)
Example No. 15
    def _prune_level(self, level: int):
        new_location = self.desc.run_path(self.replicate, level)
        if Mask.exists(new_location): return

        if level == 0:
            Mask.ones_like(models.registry.get(self.desc.model_hparams)).save(
                new_location)  # At level 0 the mask is all ones => every weight is kept.
        else:
            old_location = self.desc.run_path(
                self.replicate,
                level - 1)  # Otherwise, old_location is the run_path saved at the previous level.
            model = models.registry.load(
                old_location, self.desc.train_end_step,
                self.desc.model_hparams,
                self.desc.train_outputs)  # Load train_end_step because pruning uses the trained weights.

            pruning.registry.get(self.desc.pruning_hparams)(
                model, Mask.load(old_location)
            ).save(
                new_location
            )  # registry.get returns a partial around .prune, so this prunes and saves to new_location.
Example No. 16
    def test_dict_behavior(self):
        m = Mask()
        self.assertEqual(len(m), 0)
        self.assertEqual(len(m.keys()), 0)
        self.assertEqual(len(m.values()), 0)

        m['hello'] = np.ones([2, 3])
        m['world'] = np.zeros([5, 6])
        self.assertEqual(len(m), 2)
        self.assertEqual(len(m.keys()), 2)
        self.assertEqual(len(m.values()), 2)
        self.assertEqual(set(m.keys()), set(['hello', 'world']))
        self.assertTrue(np.array_equal(np.ones([2, 3]), m['hello']))
        self.assertTrue(np.array_equal(np.zeros([5, 6]), m['world']))

        del m['hello']
        self.assertEqual(len(m), 1)
        self.assertEqual(len(m.keys()), 1)
        self.assertEqual(len(m.values()), 1)
        self.assertEqual(set(m.keys()), set(['world']))
        self.assertTrue(np.array_equal(np.zeros([5, 6]), m['world']))
Example No. 17
    def test_level3_4it_pretrain2it(self):
        self.desc.pretrain_dataset_hparams = copy.deepcopy(
            self.desc.dataset_hparams)
        self.desc.pretrain_training_hparams = copy.deepcopy(
            self.desc.training_hparams)
        self.desc.pretrain_training_hparams.training_steps = '2it'
        self.desc.training_hparams.training_steps = '4it'
        LotteryRunner(replicate=2, levels=3, desc=self.desc,
                      verbose=False).run()

        # Check that the pretrain weights are present.
        pretrain_root = self.desc.run_path(2, 'pretrain')
        self.assertLevelFilesPresent(pretrain_root,
                                     self.to_step('0it'),
                                     self.to_step('2it'),
                                     masks=False)

        # Load the pretrain and level0 start weights to ensure they're the same.
        pretrain_end_weights = paths.model(self.desc.run_path(2, 'pretrain'),
                                           self.desc.pretrain_end_step)
        pretrain_end_weights = {
            k: v.numpy()
            for k, v in torch.load(pretrain_end_weights).items()
        }

        level0_weights = paths.model(self.desc.run_path(2, 0),
                                     self.desc.train_start_step)
        level0_weights = {
            k: v.numpy()
            for k, v in torch.load(level0_weights).items()
        }

        self.assertStateEqual(pretrain_end_weights, level0_weights)

        # Evaluate each of the pruning levels.
        for level in range(0, 2):
            level_root = self.desc.run_path(2, level)
            self.assertLevelFilesPresent(level_root, self.to_step('2it'),
                                         self.to_step('4it'))

            # Ensure that the initial weights are a masked version of the level 0 weights
            # (which are identical to the weights at the end of pretraining).
            mask = Mask.load(level_root).numpy()
            level_weights = paths.model(level_root, self.desc.train_start_step)
            level_weights = {
                k: v.numpy()
                for k, v in torch.load(level_weights).items()
            }
            self.assertStateEqual(
                level_weights,
                {k: v * mask.get(k, 1)
                 for k, v in level0_weights.items()})
Example No. 18
    def _train_level(self, level: int):
        location = self.desc.run_path(self.replicate, level)
        if models.registry.exists(location, self.desc.train_end_step): return

        model = models.registry.load(self.desc.run_path(self.replicate, 0), self.desc.train_start_step,
                                     self.desc.model_hparams, self.desc.train_outputs)
        pruned_model = PrunedModel(model, Mask.load(location))
        pruned_model.save(location, self.desc.train_start_step)
        if self.verbose and get_platform().is_primary_process:
            print('-'*82 + '\nPruning Level {}\n'.format(level) + '-'*82)
        train.standard_train(pruned_model, location, self.desc.dataset_hparams, self.desc.training_hparams,
                             start_step=self.desc.train_start_step, verbose=self.verbose,
                             evaluate_every_epoch=self.evaluate_every_epoch, weight_save_steps=self.weight_save_steps)
Example No. 19
    def branch_function(self, seed: int, strategy: str = 'layerwise', start_at: str = 'rewind',
                        layers_to_ignore: str = ''):
        # Randomize the mask.
        mask = Mask.load(self.level_root)

        # Randomize while keeping the same layerwise proportions as the original mask.
        if strategy == 'layerwise': mask = Mask(shuffle_state_dict(mask, seed=seed))

        # Randomize globally throughout all prunable layers.
        elif strategy == 'global': mask = Mask(unvectorize(shuffle_tensor(vectorize(mask), seed=seed), mask))

        # Randomize evenly across all layers.
        elif strategy == 'even':
            sparsity = mask.sparsity
            for i, k in enumerate(sorted(mask.keys())):
                size = mask[k].numel()
                layer_mask = torch.where(torch.arange(size) < torch.ceil(sparsity * size),
                                         torch.ones(size), torch.zeros(size))
                mask[k] = shuffle_tensor(layer_mask, seed=seed + i).reshape(mask[k].shape)

        # Identity.
        elif strategy == 'identity': pass

        # Error.
        else: raise ValueError(f'Invalid strategy: {strategy}')

        # Reset the masks of any layers that shouldn't be pruned.
        if layers_to_ignore:
            for k in layers_to_ignore.split(','): mask[k] = torch.ones_like(mask[k])

        # Save the new mask.
        mask.save(self.branch_root)

        # Determine the start step.
        if start_at == 'init':
            start_step = self.lottery_desc.str_to_step('0ep')
            state_step = start_step
        elif start_at == 'end':
            start_step = self.lottery_desc.str_to_step('0ep')
            state_step = self.lottery_desc.train_end_step
        elif start_at == 'rewind':
            start_step = self.lottery_desc.train_start_step
            state_step = start_step
        else:
            raise ValueError(f'Invalid starting point {start_at}')

        # Train the model with the new mask.
        model = PrunedModel(models.registry.load(self.level_root, state_step, self.lottery_desc.model_hparams), mask)
        train.standard_train(model, self.branch_root, self.lottery_desc.dataset_hparams,
                             self.lottery_desc.training_hparams, start_step=start_step, verbose=self.verbose)
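The 'layerwise' and 'global' strategies above depend on shuffle_state_dict, shuffle_tensor, vectorize and unvectorize, which are not listed here. As a rough, assumed sketch of what the shuffling helpers do (a seeded permutation of mask entries, per layer or over the concatenated mask):

    import torch

    def shuffle_tensor(t: torch.Tensor, seed: int) -> torch.Tensor:
        # Permute the flattened entries of a tensor with a fixed seed, keeping its shape.
        g = torch.Generator().manual_seed(seed)
        flat = t.flatten()
        return flat[torch.randperm(flat.numel(), generator=g)].reshape(t.shape)

    def shuffle_state_dict(sd, seed: int):
        # Shuffle each tensor independently, which preserves per-layer sparsity
        # (the 'layerwise' behaviour described above). The per-key seed offset is a guess.
        return {k: shuffle_tensor(v, seed=seed + i) for i, (k, v) in enumerate(sorted(sd.items()))}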
Example No. 20
    def test_bad_inputs(self):
        m = Mask()

        with self.assertRaises(ValueError):
            m[''] = np.ones([2, 3])

        with self.assertRaises(ValueError):
            m[6] = np.ones([2, 3])

        with self.assertRaises(ValueError):
            m['hello'] = [[0, 1], [1, 0]]

        with self.assertRaises(ValueError):
            m['hello'] = np.array([[0, 1], [2, 0]])
Example No. 21
    def test_level3_2it(self):
        self.desc.training_hparams.training_steps = '2it'
        LotteryRunner(replicate=2, levels=3, desc=self.desc,
                      verbose=False).run()

        level0_weights = paths.model(self.desc.run_path(2, 0),
                                     self.to_step('0it'))
        level0_weights = {
            k: v.numpy()
            for k, v in torch.load(level0_weights).items()
        }

        for level in range(0, 4):
            level_root = self.desc.run_path(2, level)
            self.assertLevelFilesPresent(level_root, self.to_step('0it'),
                                         self.to_step('2it'))

            # Check the mask.
            pct = 0.8**level
            mask = Mask.load(level_root).numpy()

            # Check the mask itself.
            total, total_present = 0.0, 0.0
            for v in mask.values():
                total += v.size
                total_present += np.sum(v)
            self.assertTrue(np.allclose(pct, total_present / total, atol=0.01))

            # Check the sparsity report.
            with open(paths.sparsity_report(level_root)) as fp:
                sparsity_report = json.loads(fp.read())
            self.assertTrue(
                np.allclose(pct,
                            sparsity_report['unpruned'] /
                            sparsity_report['total'],
                            atol=0.01))

            # Ensure that the initial weights are a masked version of the level 0 weights.
            level_weights = paths.model(level_root, self.to_step('0it'))
            level_weights = {
                k: v.numpy()
                for k, v in torch.load(level_weights).items()
            }
            self.assertStateEqual(
                level_weights,
                {k: v * mask.get(k, 1)
                 for k, v in level0_weights.items()})
Example No. 22
    def run(self):
        if self.verbose and get_platform().is_primary_process:
            print('='*82 + f'\nTraining a Model (Replicate {self.replicate})\n' + '-'*82)
            print(self.desc.display)
            print(f'Output Location: {self.desc.run_path(self.replicate)}' + '\n' + '='*82 + '\n')
        self.desc.save(self.desc.run_path(self.replicate))

        #TODO: make mask and model init paths configurable
        init_path = os.path.join(get_platform().root, 'resnet18_lth')
        model = models.registry.load(init_path, Step.from_str('2ep218it', 1000),
                                    self.desc.model_hparams, self.desc.train_outputs)
        pruned_model = PrunedModel(model, Mask.load(init_path))

        train.standard_train(
            #models.registry.get(self.desc.model_hparams), self.desc.run_path(self.replicate),
            pruned_model, self.desc.run_path(self.replicate),
            self.desc.dataset_hparams, self.desc.training_hparams, evaluate_every_epoch=self.evaluate_every_epoch)
Example No. 23
    def test_save_load_exists(self):
        self.assertFalse(Mask.exists(self.root))
        self.assertFalse(os.path.exists(paths.mask(self.root)))

        m = Mask({'hello': np.ones([2, 3]), 'world': np.zeros([5, 6])})
        m.save(self.root)
        self.assertTrue(os.path.exists(paths.mask(self.root)))
        self.assertTrue(Mask.exists(self.root))

        m2 = Mask.load(self.root)
        self.assertEqual(len(m2), 2)
        self.assertEqual(len(m2.keys()), 2)
        self.assertEqual(len(m2.values()), 2)
        self.assertEqual(set(m2.keys()), set(['hello', 'world']))
        self.assertTrue(np.array_equal(np.ones([2, 3]), m2['hello']))
        self.assertTrue(np.array_equal(np.zeros([5, 6]), m2['world']))
Example No. 24
    def test_level0_2it(self):
        self.desc.training_hparams.training_steps = '2it'
        LotteryRunner(replicate=2, levels=0, desc=self.desc,
                      verbose=False).run()
        level_root = self.desc.run_path(2, 0)

        # Ensure the important files are there.
        self.assertLevelFilesPresent(level_root, self.to_step('0it'),
                                     self.to_step('2it'))

        # Ensure that the mask is all 1's.
        mask = Mask.load(level_root)
        for v in mask.numpy().values():
            self.assertTrue(np.all(np.equal(v, 1)))
        with open(paths.sparsity_report(level_root)) as fp:
            sparsity_report = json.loads(fp.read())
        self.assertEqual(
            sparsity_report['unpruned'] / sparsity_report['total'], 1)
Example No. 25
    def test_with_mask(self):
        mask = Mask()
        mask['layer.weight'] = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0])
        pruned_model = PrunedModel(self.model, mask)

        # Check that the forward pass gives the correct value.
        self.assertEqual(20.0, pruned_model(self.example).item())

        # Check that the appropriate weights have been zeroed out.
        self.assertTrue(np.array_equal(self.model.state_dict()['layer.weight'].numpy(),
                                       np.array([0, 0, 2, 0, 4, 0, 6, 0, 8, 0])))

        # Perform a backward pass.
        self.optimizer.zero_grad()
        pruned_model.loss_criterion(pruned_model(self.example), self.label).backward()
        self.optimizer.step()
        self.assertEqual(22.0, pruned_model(self.example).item())

        # Verify the weights.
        self.assertTrue(np.allclose(self.model.state_dict()['layer.weight'].numpy(),
                                    np.array([0.4, 0, 2.4, 0, 4.4, 0, 6.4, 0, 8.4, 0])))
Example No. 26
    def __init__(self, model: Model, mask: Mask):
        if isinstance(model, PrunedModel):
            raise ValueError('Cannot nest pruned models.')
        super(PrunedModel, self).__init__()
        self.model = model

        for k in self.model.prunable_layer_names:
            if k not in mask:
                raise ValueError('Missing mask value {}.'.format(k))
            if not np.array_equal(mask[k].shape,
                                  np.array(self.model.state_dict()[k].shape)):
                raise ValueError(
                    'Incorrect mask shape {} for tensor {}.'.format(
                        mask[k].shape, k))

        for k in mask:
            if k not in self.model.prunable_layer_names:
                raise ValueError(
                    'Key {} found in mask but is not a valid model tensor.'.
                    format(k))

        for k, v in mask.items():
            self.register_buffer(PrunedModel.to_mask_name(k), v.float())
        self._apply_mask()
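PrunedModel.to_mask_name and _apply_mask are used above but not shown. The state_dict key mask_layer___weight expected for layer.weight in Example No. 9 suggests the naming scheme prefixes mask_ and replaces dots with triple underscores; the sketch below is an assumption consistent with these examples, not the repository's exact code.

    import torch

    class PrunedModelSketch(torch.nn.Module):
        # Partial sketch: assumes self.model is set as in the __init__ above.

        @staticmethod
        def to_mask_name(name: str) -> str:
            # 'layer.weight' -> 'mask_layer___weight', matching the key seen in Example No. 9.
            return 'mask_' + name.replace('.', '___')

        def _apply_mask(self):
            # Zero out pruned weights by multiplying each parameter by its registered mask buffer.
            for name, param in self.model.named_parameters():
                mask_name = PrunedModelSketch.to_mask_name(name)
                if hasattr(self, mask_name):
                    param.data *= getattr(self, mask_name)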
Example No. 27
    def test_with_incorrect_shape(self):
        mask = Mask.ones_like(self.model)
        mask['layer.weight'] = np.ones(30)

        with self.assertRaises(ValueError):
            PrunedModel(self.model, mask)
Example No. 28
    def test_with_excess_mask_value(self):
        mask = Mask.ones_like(self.model)
        mask['layer2.weight'] = np.ones(20)

        with self.assertRaises(ValueError):
            PrunedModel(self.model, mask)
Example No. 29
    def test_with_missing_mask_value(self):
        mask = Mask.ones_like(self.model)
        del mask['layer.weight']

        with self.assertRaises(ValueError):
            PrunedModel(self.model, mask)
Example No. 30
 def test_mask_with_ones_forward(self):
     mask = Mask.ones_like(self.model)
     pruned_model = PrunedModel(self.model, mask)
     self.assertEqual(45.0, pruned_model(self.example).item())
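Several tests above use an InnerProductModel fixture that is not listed. Its behaviour can be inferred from the assertions: a fresh InnerProductModel(10) returns 45.0 on the all-ones example (0+1+...+9), masking every other weight gives 20.0 (0+2+4+6+8), and Example No. 25 shows layer.weight as a flat vector of ten values. A hypothetical reconstruction consistent with those numbers:

    import torch

    class InnerProductModel(torch.nn.Module):
        # Assumed test fixture: a bias-free layer whose weight is the 1-D vector 0, 1, ..., n-1,
        # so the forward pass is an inner product with the input.
        def __init__(self, n: int):
            super().__init__()
            self.layer = torch.nn.Linear(n, 1, bias=False)
            self.layer.weight.data = torch.arange(n, dtype=torch.float32)
            self.loss_criterion = torch.nn.MSELoss()  # assumed; the tests only call it via PrunedModel

        def forward(self, x):
            return torch.nn.functional.linear(x, self.layer.weight)

    model = InnerProductModel(10)
    example = torch.ones(1, 10)
    print(model(example).item())  # 45.0, matching test_mask_with_ones_forward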