Example #1
0
    def test_load_state_dict(self):
        # Give the layer deterministic weights and a mask that keeps every
        # other entry, so the pruned forward output has a known value.
        half_mask = Mask.ones_like(self.model)
        half_mask['layer.weight'] = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0])
        self.model.layer.weight.data = torch.arange(10, 20, dtype=torch.float32)
        pruned_model = PrunedModel(self.model, half_mask)

        def forward_value():
            # Reads whichever pruned_model is currently bound in this test.
            return pruned_model(self.example).item()

        def take_optimizer_step():
            # One step with whichever self.optimizer is currently set.
            self.optimizer.zero_grad()
            loss = pruned_model.loss_criterion(pruned_model(self.example), self.label)
            loss.backward()
            self.optimizer.step()

        self.assertEqual(70.0, forward_value())
        take_optimizer_step()
        self.assertEqual(67.0, forward_value())

        # Snapshot the trained pruned model before replacing it.
        saved_state = pruned_model.state_dict()

        # A brand-new model with an all-ones mask gives the untrained output.
        self.model = InnerProductModel(10)
        pruned_model = PrunedModel(self.model, Mask.ones_like(self.model))
        self.assertEqual(45.0, forward_value())

        # Loading the snapshot (which also carries the mask entry) restores
        # the trained output.
        pruned_model.load_state_dict(saved_state)
        self.assertEqual(67.0, forward_value())

        # Training continues from the restored state with a fresh optimizer.
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        take_optimizer_step()
        self.assertEqual(64.3, np.round(forward_value(), 1))
Example #2
0
    def _prune_level(self, level: int):
        """Create and save the mask for ``level`` unless one already exists there.

        Level 0 gets an all-ones mask for a freshly built model. Any later level
        loads the previous level's model at ``train_end_step``, applies the
        configured pruning strategy to the previous level's mask, and saves the
        resulting mask at this level's run path.
        """
        destination = self.desc.run_path(self.replicate, level)
        if Mask.exists(destination):
            return

        if level != 0:
            previous = self.desc.run_path(self.replicate, level - 1)
            trained = models.registry.load(previous, self.desc.train_end_step,
                                           self.desc.model_hparams, self.desc.train_outputs)
            prune_fn = pruning.registry.get(self.desc.pruning_hparams)
            prune_fn(trained, Mask.load(previous)).save(destination)
            return

        fresh_model = models.registry.get(self.desc.model_hparams)
        Mask.ones_like(fresh_model).save(destination)
Example #3
0
    def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None):
        """Globally magnitude-prune ``trained_model``, then randomly shuffle the mask.

        Args:
            pruning_hparams: supplies ``pruning_fraction`` (fraction of the
                entries still marked 1 in ``current_mask`` to remove) and an
                optional comma-separated ``pruning_layers_to_ignore``.
            trained_model: model whose prunable weights are ranked by magnitude.
            current_mask: existing mask; defaults to all ones for the model.

        Returns:
            A new ``Mask``. Entries for tensors that were not considered for
            pruning are copied from ``current_mask``. The whole mask is then
            passed through ``shuffle_tensor`` with a fixed seed of 42.
        """
        current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy()

        # Determine the number of weights that need to be pruned.
        number_of_remaining_weights = np.sum([np.sum(v) for v in current_mask.values()])
        number_of_weights_to_prune = np.ceil(
            pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int)

        # Determine which layers can be pruned.
        prunable_tensors = set(trained_model.prunable_layer_names)
        if pruning_hparams.pruning_layers_to_ignore:
            prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(','))

        # Get the model weights.
        weights = {k: v.clone().cpu().detach().numpy()
                   for k, v in trained_model.state_dict().items()
                   if k in prunable_tensors}

        # Sorted magnitudes of all still-unpruned prunable weights.
        unpruned = [v[current_mask[k] == 1] for k, v in weights.items()]
        sorted_magnitudes = np.sort(np.abs(np.concatenate(unpruned))) if unpruned else np.array([])
        if sorted_magnitudes.size == 0:
            # Nothing left to prune: any threshold leaves the mask unchanged.
            threshold = np.inf
        else:
            # Clamp the index so that requesting every remaining weight (e.g. a
            # pruning fraction of 1.0, or a remaining-count that also includes
            # ignored layers) prunes all of them instead of raising IndexError.
            threshold = sorted_magnitudes[min(number_of_weights_to_prune, sorted_magnitudes.size - 1)]

        # Keep only entries strictly above the threshold. NOTE(review): the
        # comparison is strict, so the entry at the threshold index (and any
        # ties) is pruned too, i.e. one more weight than requested.
        new_mask = Mask({k: np.where(np.abs(v) > threshold, current_mask[k], np.zeros_like(v))
                         for k, v in weights.items()})
        for k in current_mask:
            if k not in new_mask:
                new_mask[k] = current_mask[k]

        # Randomize globally throughout the mask. NOTE(review): new_mask also
        # holds tensors listed in pruning_layers_to_ignore (copied above), so
        # they appear to take part in the shuffle as well — confirm intended.
        new_mask = Mask(unvectorize(shuffle_tensor(vectorize(new_mask), seed=42), new_mask))

        return new_mask
Example #4
0
    def test_state_dict(self):
        # The wrapper's state dict exposes the wrapped model's weight under a
        # ``model.`` prefix plus a single entry for the mask.
        pruned_model = PrunedModel(self.model, Mask.ones_like(self.model))
        expected_keys = {'model.layer.weight', 'mask_layer___weight'}
        self.assertEqual(expected_keys, pruned_model.state_dict().keys())

    def test_mask_with_ones_backward(self):
        # With an all-ones mask, one optimizer step from the fixture state
        # should move the forward output to 44.0.
        pruned_model = PrunedModel(self.model, Mask.ones_like(self.model))

        self.optimizer.zero_grad()
        predictions = pruned_model(self.example)
        pruned_model.loss_criterion(predictions, self.label).backward()
        self.optimizer.step()

        self.assertEqual(44.0, pruned_model(self.example).item())
Example #6
0
    def branch_function(self,
                        target_model_name: str = None,
                        block_mapping: str = None,
                        start_at_step_zero: bool = False):
        """Morph this level's sparse network into ``target_model_name`` and train it.

        Args:
            target_model_name: registry name of the target architecture; it
                selects how ``block_mapping`` is parsed.
            block_mapping: block mapping string, e.g. ``0:0;1:1,2;2:3,4;3:5,6;4:7,8``.
                For some architectures it is '|'-separated, one segment per stage.
            start_at_step_zero: load source weights at step ``0it`` instead of
                ``train_start_step``; training also starts from that step.

        Raises:
            NotImplementedError: for architectures without a mapping rule.
        """
        # (dataset keyword, architecture keyword, mapping is '|'-split per stage).
        # Checked in order, first match wins, so 'vggnfc' precedes 'vgg'.
        mapping_rules = (
            ('cifar', 'resnet', False),
            ('imagenet', 'resnet', True),
            ('cifar', 'vggnfc', False),
            ('cifar', 'vgg', True),
            ('cifar', 'mobilenetv1', False),
            ('mnist', 'lenet', False),
        )
        for dataset_key, arch_key, split_stages in mapping_rules:
            if dataset_key in target_model_name and arch_key in target_model_name:
                break
        else:
            raise NotImplementedError(
                'Other mapping cases not implemented yet')

        if split_stages:
            mappings = [parse_block_mapping_for_stage(stage) for stage in block_mapping.split('|')]
        else:
            mappings = parse_block_mapping_for_stage(block_mapping)

        # Source network: the level's mask and its weights at the chosen step.
        src_mask = Mask.load(self.level_root)
        if start_at_step_zero:
            start_step = self.lottery_desc.str_to_step('0it')
        else:
            start_step = self.lottery_desc.train_start_step
        src_model = models.registry.load(self.level_root, start_step,
                                         self.lottery_desc.model_hparams)

        # Target network: same hyperparameters, different model name.
        target_hparams = copy.deepcopy(self.lottery_desc.model_hparams)
        target_hparams.model_name = target_model_name
        target_net = models.registry.get(target_hparams)
        ones_mask = Mask.ones_like(target_net)

        # Carry weights and mask over to the target via the block mapping.
        morphed_weights = change_depth(target_model_name, src_model.state_dict(),
                                       target_net.state_dict(), mappings)
        target_net.load_state_dict(morphed_weights)
        morphed_mask = change_depth(target_model_name, src_mask,
                                    ones_mask, mappings)
        pruned_target = PrunedModel(target_net, morphed_mask)

        # Persist the mask for this branch, then train from start_step.
        morphed_mask.save(self.branch_root)
        train.standard_train(pruned_target, self.branch_root,
                             self.lottery_desc.dataset_hparams,
                             self.lottery_desc.training_hparams,
                             start_step=start_step, verbose=self.verbose)
Example #7
0
    def test_ones_like(self):
        # Mask.ones_like must have an entry for exactly the prunable layers,
        # each shaped like the weight and filled with ones.
        default_hparams = models.registry.get_default_hparams('cifar_resnet_20')
        model = models.registry.get(default_hparams.model_hparams)
        ones_mask = Mask.ones_like(model)

        for name, tensor in model.state_dict().items():
            if name not in model.prunable_layer_names:
                self.assertNotIn(name, ones_mask)
                continue
            self.assertIn(name, ones_mask)
            self.assertEqual(list(ones_mask[name].shape), list(tensor.shape))
            self.assertTrue((ones_mask[name] == 1).all())

    def test_save(self):
        # Saving a PrunedModel writes a checkpoint that, loaded back into the
        # wrapped model, reproduces the state captured before saving.
        expected_state = self.get_state(self.model)
        checkpoint_step = Step.zero(20)

        PrunedModel(self.model, Mask.ones_like(self.model)).save(self.root, checkpoint_step)

        checkpoint_path = paths.model(self.root, checkpoint_step)
        self.assertTrue(os.path.exists(checkpoint_path))

        self.model.load_state_dict(torch.load(checkpoint_path))
        self.assertStateEqual(expected_state, self.get_state(self.model))
Example #9
0
    def _prune_level(self, level: int):
        """Produce this level's mask and save it at the level's run path.

        Does nothing if a mask is already saved there.
        """
        new_location = self.desc.run_path(self.replicate, level)
        if Mask.exists(new_location):
            return

        if level == 0:
            # Level 0 keeps every weight: an all-ones mask for a fresh model.
            untrained = models.registry.get(self.desc.model_hparams)
            Mask.ones_like(untrained).save(new_location)
        else:
            # Otherwise start from the previous level's run path. Pruning
            # follows training, so its weights are loaded at train_end_step.
            old_location = self.desc.run_path(self.replicate, level - 1)
            model = models.registry.load(old_location,
                                         self.desc.train_end_step,
                                         self.desc.model_hparams,
                                         self.desc.train_outputs)
            # Calling the registry's pruning callable prunes from the previous
            # mask (the original author notes it is a partial around the
            # strategy's .prune); the result is saved at new_location.
            strategy = pruning.registry.get(self.desc.pruning_hparams)
            previous_mask = Mask.load(old_location)
            strategy(model, previous_mask).save(new_location)
Example #10
0
    def prune(pruning_hparams: PruningHparams,
              trained_model: models.base.Model,
              current_mask: Mask = None,
              training_hparams: hparams.TrainingHparams = None,
              dataset_hparams: hparams.DatasetHparams = None,
              data_order_seed: int = None):
        """Globally prune ``trained_model`` by the magnitude of per-weight scores.

        Scores come from ``Strategy.get_score`` for the prunable, non-ignored
        tensors. Among the weights still marked 1 in ``current_mask``, those
        with the smallest ``|score|`` are pruned.

        Args:
            pruning_hparams: supplies ``pruning_fraction`` and the optional
                comma-separated ``pruning_layers_to_ignore``.
            trained_model: model to score and prune.
            current_mask: existing mask; defaults to all ones for the model.
            training_hparams, dataset_hparams, data_order_seed: forwarded
                unchanged to ``Strategy.get_score``.

        Returns:
            A new ``Mask``; tensors without scores keep their current mask.
        """
        current_mask = Mask.ones_like(
            trained_model) if current_mask is None else current_mask
        current_mask_numpy = current_mask.numpy()

        # Determine the number of weights that need to be pruned.
        number_of_remaining_weights = np.sum(
            [np.sum(v) for v in current_mask_numpy.values()])
        number_of_weights_to_prune = np.ceil(
            pruning_hparams.pruning_fraction *
            number_of_remaining_weights).astype(int)

        # Determine which layers can be pruned.
        prunable_tensors = set(trained_model.prunable_layer_names)
        if pruning_hparams.pruning_layers_to_ignore:
            prunable_tensors -= set(
                pruning_hparams.pruning_layers_to_ignore.split(','))

        # Get the model score. NOTE(review): assumed to map tensor name ->
        # numpy array shaped like the matching mask entry — confirm.
        scores = Strategy.get_score(trained_model, current_mask,
                                    prunable_tensors, training_hparams,
                                    dataset_hparams, data_order_seed)

        # Sorted |score| of every still-unpruned weight.
        unpruned_scores = [v[current_mask_numpy[k] == 1] for k, v in scores.items()]
        sorted_scores = (np.sort(np.abs(np.concatenate(unpruned_scores)))
                         if unpruned_scores else np.array([]))
        if sorted_scores.size == 0:
            # Nothing left to prune: any threshold leaves the mask unchanged.
            threshold = np.inf
        else:
            # Clamp so that requesting every remaining weight (e.g. a pruning
            # fraction of 1.0) prunes all of them instead of raising IndexError.
            threshold = sorted_scores[
                min(number_of_weights_to_prune, sorted_scores.size - 1)]

        # Keep entries whose |score| is strictly above the threshold; the entry
        # at the threshold index (and any ties) is pruned as well.
        new_mask = Mask({
            k: np.where(
                np.abs(v) > threshold, current_mask_numpy[k], np.zeros_like(v))
            for k, v in scores.items()
        })
        for k in current_mask_numpy:
            if k not in new_mask:
                new_mask[k] = current_mask_numpy[k]

        return new_mask
Example #11
0
    def branch_function(self,
                        seed: int,
                        strategy: str = 'sparse_global',
                        start_at: str = 'rewind',
                        layers_to_ignore: str = ''):
        """Re-prune the pretrained model to this level's sparsity with ``strategy``, then train.

        Args:
            seed: passed as the last positional argument of the pruning
                callable (presumably its ``data_order_seed`` — confirm).
            strategy: assigned to ``pruning_hparams.pruning_strategy``.
            start_at: 'init', 'end' or 'rewind'. It selects which saved weights
                are loaded (``state_step``) and where training starts
                (``start_step``).
            layers_to_ignore: comma-separated layer names whose entries in the
                new mask are reset to all ones before it is saved and used.

        Raises:
            ValueError: if ``start_at`` is not one of the accepted values.
        """
        # Determine the start step.
        if start_at == 'init':
            start_step = self.lottery_desc.str_to_step('0ep')
            state_step = start_step
        elif start_at == 'end':
            start_step = self.lottery_desc.str_to_step('0ep')
            state_step = self.lottery_desc.train_end_step
        elif start_at == 'rewind':
            start_step = self.lottery_desc.train_start_step
            state_step = start_step
        else:
            raise ValueError(f'Invalid starting point {start_at}')

        # Load the weights that will be pruned and then trained.
        model = models.registry.load(self.pretrain_root, state_step, self.lottery_desc.model_hparams)

        # Get the current level mask and derive the target pruning fraction.
        mask = Mask.load(self.level_root)
        sparsity_ratio = mask.get_sparsity_ratio()
        target_pruning_fraction = 1.0 - sparsity_ratio

        # Run pruning from an all-ones mask with the requested strategy.
        pruning_hparams = copy.deepcopy(self.lottery_desc.pruning_hparams)
        pruning_hparams.pruning_strategy = strategy
        pruning_hparams.pruning_fraction = target_pruning_fraction
        new_mask = pruning.registry.get(pruning_hparams)(
            model, Mask.ones_like(model),
            self.lottery_desc.training_hparams,
            self.lottery_desc.dataset_hparams, seed
        )

        # Reset the masks of any layers that shouldn't be pruned. This runs
        # after pruning so it applies to new_mask, the mask that is saved and
        # trained with.
        if layers_to_ignore:
            for k in layers_to_ignore.split(','):
                new_mask[k] = torch.ones_like(new_mask[k])
        new_mask.save(self.branch_root)

        repruned_model = PrunedModel(model.to(device=get_platform().cpu_device), new_mask)

        # Train the model with the new mask.
        train.standard_train(repruned_model, self.branch_root, self.lottery_desc.dataset_hparams,
                             self.lottery_desc.training_hparams, start_step=start_step, verbose=self.verbose)
Example #12
0
    def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None):
        """Magnitude-prune a model that stacks several initializations, one threshold per init.

        Every mask tensor must have the same leading dimension, which indexes
        the initialization. For each initialization, the weights still marked 1
        in ``current_mask`` across all prunable, non-ignored tensors are ranked
        by magnitude and pruned against that initialization's own threshold.

        Args:
            pruning_hparams: supplies ``pruning_fraction`` and the optional
                comma-separated ``pruning_layers_to_ignore``.
            trained_model: model whose prunable weights are ranked.
            current_mask: existing mask; defaults to all ones for the model.

        Returns:
            A new ``Mask``; tensors not considered for pruning are copied from
            ``current_mask``.

        Raises:
            ValueError: if the mask tensors disagree on the leading dimension.
        """
        current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy()

        # The leading axis of every mask tensor indexes the initialization.
        num_inits = next(iter(current_mask.values())).shape[0]
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not all(v.shape[0] == num_inits for v in current_mask.values()):
            raise ValueError('All mask tensors must share the same leading (initialization) dimension.')

        # Determine the number of weights that need to be pruned.
        # NOTE(review): assumes every initialization currently has the same
        # number of remaining weights.
        number_of_remaining_weights_per_init = np.sum([np.sum(v) for v in current_mask.values()]) // num_inits
        number_of_weights_to_prune_per_init = np.ceil(
            pruning_hparams.pruning_fraction * number_of_remaining_weights_per_init).astype(int)

        # Determine which layers can be pruned.
        prunable_tensors = set(trained_model.prunable_layer_names)
        if pruning_hparams.pruning_layers_to_ignore:
            prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(','))

        # Get the model weights.
        weights = {k: v.clone().cpu().detach().numpy()
                   for k, v in trained_model.state_dict().items()
                   if k in prunable_tensors}

        # One magnitude threshold per initialization, taken over that init's
        # still-unpruned weights. Entries with |w| <= threshold get pruned.
        thresholds = []
        for init_id in range(num_inits):
            unpruned = [v[init_id, ...][current_mask[k][init_id, ...] == 1] for k, v in weights.items()]
            magnitudes = np.sort(np.abs(np.concatenate(unpruned))) if unpruned else np.array([])
            if magnitudes.size == 0:
                # Nothing left to prune for this init; any threshold is a no-op.
                thresholds.append(np.inf)
            else:
                # Clamp so that requesting every remaining weight prunes all of
                # them instead of raising IndexError.
                index = min(number_of_weights_to_prune_per_init, magnitudes.size - 1)
                thresholds.append(magnitudes[index])
        thresholds = np.array(thresholds)

        mask_dict = {}
        for k, v in weights.items():
            # Shape (num_inits, 1, ..., 1) broadcasts each init's threshold over
            # that init's slice of v (equivalent to tiling it up to v.shape).
            threshold_tensor = thresholds.reshape(-1, *([1] * (v.ndim - 1)))
            mask_dict[k] = np.where(np.abs(v) > threshold_tensor, current_mask[k], np.zeros_like(v))
        new_mask = Mask(mask_dict)
        # Tensors that were not considered for pruning keep their current mask.
        for k in current_mask:
            if k not in new_mask:
                new_mask[k] = current_mask[k]

        return new_mask
Example #13
0
    def test_with_incorrect_shape(self):
        # A mask entry whose shape does not match the layer is rejected.
        bad_mask = Mask.ones_like(self.model)
        bad_mask['layer.weight'] = np.ones(30)
        self.assertRaises(ValueError, PrunedModel, self.model, bad_mask)
Example #14
0
    def test_with_excess_mask_value(self):
        # An extra mask entry ('layer2.weight') beyond the defaults is rejected.
        padded_mask = Mask.ones_like(self.model)
        padded_mask['layer2.weight'] = np.ones(20)
        self.assertRaises(ValueError, PrunedModel, self.model, padded_mask)
Example #15
0
    def test_with_missing_mask_value(self):
        # A mask lacking the 'layer.weight' entry is rejected.
        incomplete_mask = Mask.ones_like(self.model)
        del incomplete_mask['layer.weight']
        self.assertRaises(ValueError, PrunedModel, self.model, incomplete_mask)
Example #16
0
 def test_mask_with_ones_forward(self):
     # With an all-ones mask the forward pass on the fixture example gives 45.0.
     output = PrunedModel(self.model, Mask.ones_like(self.model))(self.example)
     self.assertEqual(45.0, output.item())
Example #17
0
    def branch_function(
            self,
            strategy: str,
            prune_fraction: float,
            prune_experiment: str = 'main',
            prune_step: str = '0ep0it',  # The step for states used to prune.
            prune_highest: bool = False,
            prune_iterations: int = 1,
            randomize_layerwise: bool = False,
            state_experiment: str = 'main',
            state_step: str = '0ep0it',  # The step of the state to use alongside the pruning mask.
            start_step: str = '0ep0it',  # The step at which to start the learning rate schedule.
            seed: int = None,
            reinitialize: bool = False):
        """Prune with a named strategy, train the pruned network, then analyze its weights.

        The mask is computed on the primary process only when none can be loaded
        from ``self.branch_root``, and is saved there; every process then loads it.

        Args:
            strategy: scoring strategy name; must match exactly one entry of
                ``strategies``.
            prune_fraction: overall pruning target. Each of the
                ``prune_iterations`` rounds uses
                ``1 - (1 - prune_fraction) ** (1 / prune_iterations)``.
            prune_experiment, prune_step: run and step whose weights are scored.
            prune_highest: passed to ``prune`` as ``not prune_highest``.
            prune_iterations: number of score-and-prune rounds.
            randomize_layerwise: shuffle the final mask per layer.
            state_experiment, state_step: run and step whose weights are trained
                under the mask (unused when ``reinitialize`` is True).
            start_step: step at which training starts.
            seed: seed for the strategy and shuffling; defaults to the replicate.
            reinitialize: train freshly initialized weights instead of loaded ones.

        Raises:
            ValueError: if ``strategy`` matches no strategy or several.
        """
        # Get the steps for each part of the process.
        iterations_per_epoch = datasets.registry.iterations_per_epoch(
            self.desc.dataset_hparams)
        prune_step = Step.from_str(prune_step, iterations_per_epoch)
        state_step = Step.from_str(state_step, iterations_per_epoch)
        start_step = Step.from_str(start_step, iterations_per_epoch)
        seed = self.replicate if seed is None else seed

        # Try to load a mask saved by a previous run of this branch.
        try:
            mask = Mask.load(self.branch_root)
        except Exception:
            # Best effort: any load failure (e.g. no mask saved yet) means the
            # mask is recomputed below. Unlike a bare ``except:``, this lets
            # KeyboardInterrupt / SystemExit propagate.
            mask = None

        # Output folder for the (currently disabled) plotting hooks below.
        result_folder = "Data_Distribution/"
        if reinitialize:
            result_folder = "Data_Distribution_Reinit/"
        elif randomize_layerwise:
            result_folder = "Data_Distribution__Randomize_Layerwise/"

        if not mask and get_platform().is_primary_process:
            # Gather the weights that will be used for pruning.
            prune_path = self.desc.run_path(self.replicate, prune_experiment)
            prune_model = models.registry.load(prune_path, prune_step,
                                               self.desc.model_hparams)

            # Ensure that a valid strategy is available.
            strategy_class = [s for s in strategies if s.valid_name(strategy)]
            if not strategy_class:
                raise ValueError(f'No such pruning strategy {strategy}')
            if len(strategy_class) > 1:
                raise ValueError('Multiple matching strategies')
            strategy_instance = strategy_class[0](strategy, self.desc, seed)

            # Run the strategy for each iteration. The per-round fraction is
            # chosen so the rounds compound to prune_fraction, assuming each
            # round removes that fraction of the remaining weights.
            mask = Mask.ones_like(prune_model)
            iteration_fraction = 1 - (1 - prune_fraction)**(
                1 / float(prune_iterations))

            if iteration_fraction > 0:
                for _ in range(prune_iterations):
                    # Make a defensive copy of the model and mask out the pruned weights.
                    prune_model2 = copy.deepcopy(prune_model)
                    with torch.no_grad():
                        for k, v in prune_model2.named_parameters():
                            v.mul_(mask.get(k, 1))

                    # Compute the scores.
                    scores = strategy_instance.score(prune_model2, mask)

                    # Prune.
                    mask = unvectorize(
                        prune(vectorize(scores),
                              iteration_fraction,
                              not prune_highest,
                              mask=vectorize(mask)), mask)

            # Shuffle randomly per layer.
            if randomize_layerwise:
                mask = shuffle_state_dict(mask, seed=seed)

            mask = Mask({k: v.clone().detach() for k, v in mask.items()})
            mask.save(self.branch_root)

            # Optional plotting hooks (disabled):
            # plot_distribution_scores(strategy_instance.score(prune_model, mask), strategy, mask, prune_iterations, reinitialize, randomize_layerwise, result_folder)
            # plot_distribution_scatter(strategy_instance.score(prune_model, mask), prune_model, strategy, mask, prune_iterations, reinitialize, randomize_layerwise, result_folder)

        # Synchronize (barrier) so the mask saved by the primary process can be
        # loaded by every process.
        get_platform().barrier()
        mask = Mask.load(self.branch_root)

        # Choose the weights to train: a fresh initialization or a saved state.
        state_path = self.desc.run_path(self.replicate, state_experiment)
        if reinitialize:
            model = models.registry.get(self.desc.model_hparams)
        else:
            model = models.registry.load(state_path, state_step,
                                         self.desc.model_hparams)

        # plot_distribution_weights(model, strategy, mask, prune_iterations, reinitialize, randomize_layerwise, result_folder)

        # Keep an unwrapped copy of the pre-training weights for the analysis below.
        original_model = copy.deepcopy(model)
        model = PrunedModel(model, mask)

        train.standard_train(model,
                             self.branch_root,
                             self.desc.dataset_hparams,
                             self.desc.training_hparams,
                             start_step=start_step,
                             verbose=self.verbose,
                             evaluate_every_epoch=self.evaluate_every_epoch)

        weights_analysis(original_model, strategy, reinitialize,
                         randomize_layerwise, "original")
        weights_analysis(model, strategy, reinitialize, randomize_layerwise,
                         "pruned")

    def branch_function(self,
                        target_model_name: str = None,
                        block_mapping: str = None,
                        start_at_step_zero: bool = False,
                        data_seed: int = 118):
        """Morph this level's sparse network, train two copies, and probe linear connectivity.

        Two identical copies of the morphed pruned model are trained with
        different data-order seeds (``data_seed + 9999`` and ``data_seed + 10001``).
        The trained copies are then linearly interpolated at 21 evenly spaced
        ``alpha`` values in [0, 1], and each interpolant is run through
        ``train.standard_train`` with ``training_steps='1ep'``.

        Args:
            target_model_name: registry name of the target architecture; it
                selects how ``block_mapping`` is parsed.
            block_mapping: block mapping string, e.g. ``0:0;1:1,2;2:3,4;3:5,6;4:7,8``.
                For some architectures it is '|'-separated, one segment per stage.
            start_at_step_zero: load source weights at step ``0it`` instead of
                ``train_start_step``; training also starts from that step.
            data_seed: base for the two data-order seeds.

        Raises:
            NotImplementedError: for architectures without a mapping rule.
        """
        # Process the mapping
        # A valid string format of a mapping is like:
        #   `0:0;1:1,2;2:3,4;3:5,6;4:7,8`
        # 'vggnfc' must be tested before the more general 'vgg'.
        if 'cifar' in target_model_name and 'resnet' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        elif 'imagenet' in target_model_name and 'resnet' in target_model_name:
            mappings = list(
                map(parse_block_mapping_for_stage, block_mapping.split('|')))
        elif 'cifar' in target_model_name and 'vggnfc' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        elif 'cifar' in target_model_name and 'vgg' in target_model_name:
            mappings = list(
                map(parse_block_mapping_for_stage, block_mapping.split('|')))
        elif 'cifar' in target_model_name and 'mobilenetv1' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        elif 'mnist' in target_model_name and 'lenet' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        else:
            raise NotImplementedError(
                'Other mapping cases not implemented yet')

        # Load source model at `train_start_step` (or step 0 if requested).
        src_mask = Mask.load(self.level_root)
        start_step = self.lottery_desc.str_to_step(
            '0it'
        ) if start_at_step_zero else self.lottery_desc.train_start_step
        src_model = models.registry.load(self.level_root, start_step,
                                         self.lottery_desc.model_hparams)

        # Create target model
        target_model_hparams = copy.deepcopy(self.lottery_desc.model_hparams)
        target_model_hparams.model_name = target_model_name
        target_model = models.registry.get(target_model_hparams)
        target_ones_mask = Mask.ones_like(target_model)

        # Do the morphism
        target_sd = change_depth(target_model_name, src_model.state_dict(),
                                 target_model.state_dict(), mappings)
        target_model.load_state_dict(target_sd)
        target_mask = change_depth(target_model_name, src_mask,
                                   target_ones_mask, mappings)
        target_model_a = PrunedModel(target_model, target_mask)
        # Copy before training so both runs start from identical weights.
        target_model_b = copy.deepcopy(target_model_a)

        def _train_copy(model, data_order_seed):
            """Save the morphed mask under seed_<seed>/ and train ``model`` there."""
            training_hparams = copy.deepcopy(self.lottery_desc.training_hparams)
            training_hparams.data_order_seed = data_order_seed
            output_dir = os.path.join(self.branch_root, f'seed_{data_order_seed}')
            target_mask.save(output_dir)
            train.standard_train(model,
                                 output_dir,
                                 self.lottery_desc.dataset_hparams,
                                 training_hparams,
                                 start_step=start_step,
                                 verbose=self.verbose)

        # Two runs that differ only in data order.
        _train_copy(target_model_a, data_seed + 9999)
        _train_copy(target_model_b, data_seed + 10001)

        # Linear connectivity between model_a and model_b
        training_hparams_c = copy.deepcopy(self.lottery_desc.training_hparams)
        training_hparams_c.training_steps = '1ep'
        for alpha in np.linspace(0, 1.0, 21):
            model_c = linear_interpolate(target_model_a, target_model_b, alpha)
            # Round before formatting: linspace yields values such as
            # 0.15000000000000002, which would otherwise leak float noise into
            # the directory name. Values that already print cleanly (0.0,
            # 0.05, 0.5, 1.0, ...) keep the same name as before.
            output_dir_c = os.path.join(self.branch_root, f'alpha_{round(float(alpha), 4)}')
            # Measure acc of model_c. NOTE(review): with training_steps='1ep'
            # this presumably also trains model_c for one epoch — confirm that
            # only the initial evaluation is meant to be used.
            train.standard_train(model_c,
                                 output_dir_c,
                                 self.lottery_desc.dataset_hparams,
                                 training_hparams_c,
                                 start_step=None,
                                 verbose=self.verbose)