Code example #1
    def branch_function(self, seed: int, strategy: str = 'layerwise', start_at: str = 'rewind',
                        layers_to_ignore: str = ''):
        # Randomize the mask.
        mask = Mask.load(self.level_root)

        # Randomize while keeping the same layerwise proportions as the original mask.
        if strategy == 'layerwise': mask = Mask(shuffle_state_dict(mask, seed=seed))

        # Randomize globally throughout all prunable layers.
        elif strategy == 'global': mask = Mask(unvectorize(shuffle_tensor(vectorize(mask), seed=seed), mask))

        # Randomize evenly across all layers.
        elif strategy == 'even':
            sparsity = mask.sparsity
            for i, k in enumerate(sorted(mask.keys())):
                n = mask[k].numel()
                # Keep the first ceil(sparsity * n) entries as ones, then shuffle them into place.
                layer_mask = torch.where(torch.arange(n) < torch.ceil(torch.as_tensor(float(sparsity) * n)),
                                         torch.ones(n), torch.zeros(n))
                mask[k] = shuffle_tensor(layer_mask, seed=seed + i).reshape(mask[k].shape)

        # Identity.
        elif strategy == 'identity': pass

        # Error.
        else: raise ValueError(f'Invalid strategy: {strategy}')

        # Reset the masks of any layers that shouldn't be pruned.
        if layers_to_ignore:
            for k in layers_to_ignore.split(','): mask[k] = torch.ones_like(mask[k])

        # Save the new mask.
        mask.save(self.branch_root)

        # Determine the start step.
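        # 'init' retrains from the weights at initialization; 'end' retrains the final weights
        # under a fresh learning-rate schedule; 'rewind' resumes from the rewind point.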
        if start_at == 'init':
            start_step = self.lottery_desc.str_to_step('0ep')
            state_step = start_step
        elif start_at == 'end':
            start_step = self.lottery_desc.str_to_step('0ep')
            state_step = self.lottery_desc.train_end_step
        elif start_at == 'rewind':
            start_step = self.lottery_desc.train_start_step
            state_step = start_step
        else:
            raise ValueError(f'Invalid starting point {start_at}')

        # Train the model with the new mask.
        model = PrunedModel(models.registry.load(self.level_root, state_step, self.lottery_desc.model_hparams), mask)
        train.standard_train(model, self.branch_root, self.lottery_desc.dataset_hparams,
                             self.lottery_desc.training_hparams, start_step=start_step, verbose=self.verbose)
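For context, the two shuffling helpers used above (shuffle_tensor and shuffle_state_dict) are not shown in this example. A minimal sketch of what they are assumed to do, permuting a tensor's elements with a seeded generator and applying that per entry of a state dict, is below; the actual library implementations may differ.

    import torch

    def shuffle_tensor(x: torch.Tensor, seed: int) -> torch.Tensor:
        # Randomly permute the elements of x, reproducibly for a given seed.
        g = torch.Generator().manual_seed(seed)
        perm = torch.randperm(x.numel(), generator=g)
        return x.flatten()[perm].reshape(x.shape)

    def shuffle_state_dict(state_dict, seed: int):
        # Shuffle each tensor independently, varying the seed per key so that
        # different layers receive different permutations.
        return {k: shuffle_tensor(v, seed=seed + i)
                for i, (k, v) in enumerate(sorted(state_dict.items()))}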
Code example #2
    def test_save_load_exists(self):
        self.assertFalse(Mask.exists(self.root))
        self.assertFalse(os.path.exists(paths.mask(self.root)))

        m = Mask({'hello': np.ones([2, 3]), 'world': np.zeros([5, 6])})
        m.save(self.root)
        self.assertTrue(os.path.exists(paths.mask(self.root)))
        self.assertTrue(Mask.exists(self.root))

        m2 = Mask.load(self.root)
        self.assertEqual(len(m2), 2)
        self.assertEqual(len(m2.keys()), 2)
        self.assertEqual(len(m2.values()), 2)
        self.assertEqual(set(m2.keys()), set(['hello', 'world']))
        self.assertTrue(np.array_equal(np.ones([2, 3]), m2['hello']))
        self.assertTrue(np.array_equal(np.zeros([5, 6]), m2['world']))
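The test above fixes the Mask contract: a dict-like container of per-parameter masks that reports whether one exists at a path and round-trips through save and load. A stripped-down, dict-backed sketch of that contract follows; the mask.pt file name and layout are assumptions for illustration, not the real Mask or paths.mask implementation.

    import os
    import torch

    class SimpleMask(dict):
        # {parameter name: 0/1 tensor}; illustrative only.
        @staticmethod
        def _path(root):
            return os.path.join(root, 'mask.pt')  # assumed file name

        def save(self, root):
            os.makedirs(root, exist_ok=True)
            torch.save({k: torch.as_tensor(v) for k, v in self.items()}, self._path(root))

        @classmethod
        def load(cls, root):
            return cls(torch.load(cls._path(root)))

        @classmethod
        def exists(cls, root):
            return os.path.exists(cls._path(root))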
Code example #3
    def branch_function(
            self,
            strategy: str,
            prune_fraction: float,
            prune_experiment: str = 'main',
            prune_step: str = '0ep0it',  # The step for states used to prune.
            prune_highest: bool = False,
            prune_iterations: int = 1,
            randomize_layerwise: bool = False,
            state_experiment: str = 'main',
            state_step: str = '0ep0it',  # The step of the state to use alongside the pruning mask.
            start_step: str = '0ep0it',  # The step at which to start the learning rate schedule.
            seed: int = None,
            reinitialize: bool = False):
        # Get the steps for each part of the process.
        iterations_per_epoch = datasets.registry.iterations_per_epoch(
            self.desc.dataset_hparams)
        prune_step = Step.from_str(prune_step, iterations_per_epoch)
        state_step = Step.from_str(state_step, iterations_per_epoch)
        start_step = Step.from_str(start_step, iterations_per_epoch)
        seed = self.replicate if seed is None else seed

        # Try to load the mask.
        try:
            mask = Mask.load(self.branch_root)
        except Exception:  # No mask has been saved for this branch yet.
            mask = None

        result_folder = "Data_Distribution/"
        if reinitialize:
            result_folder = "Data_Distribution_Reinit/"
        elif randomize_layerwise:
            result_folder = "Data_Distribution__Randomize_Layerwise/"

        if not mask and get_platform().is_primary_process:
            # Gather the weights that will be used for pruning.
            prune_path = self.desc.run_path(self.replicate, prune_experiment)
            prune_model = models.registry.load(prune_path, prune_step,
                                               self.desc.model_hparams)

            # Ensure that a valid strategy is available.
            strategy_class = [s for s in strategies if s.valid_name(strategy)]
            if not strategy_class:
                raise ValueError(f'No such pruning strategy {strategy}')
            if len(strategy_class) > 1:
                raise ValueError('Multiple matching strategies')
            strategy_instance = strategy_class[0](strategy, self.desc, seed)

            # Run the strategy for each iteration.
            mask = Mask.ones_like(prune_model)
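            # Pruning a fraction f of the surviving weights on each of prune_iterations rounds
            # removes 1 - (1 - f)**prune_iterations of the weights overall; solving that for f
            # gives the per-iteration fraction computed below.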
            iteration_fraction = 1 - (1 - prune_fraction)**(
                1 / float(prune_iterations))

            if iteration_fraction > 0:
                for it in range(0, prune_iterations):
                    # Make a defensive copy of the model and mask out the pruned weights.
                    prune_model2 = copy.deepcopy(prune_model)
                    with torch.no_grad():
                        for k, v in prune_model2.named_parameters():
                            v.mul_(mask.get(k, 1))

                    # Compute the scores.
                    scores = strategy_instance.score(prune_model2, mask)

                    # Prune.
                    mask = unvectorize(
                        prune(vectorize(scores),
                              iteration_fraction,
                              not prune_highest,
                              mask=vectorize(mask)), mask)

            # Shuffle randomly per layer.
            if randomize_layerwise: mask = shuffle_state_dict(mask, seed=seed)

            mask = Mask({k: v.clone().detach() for k, v in mask.items()})
            mask.save(self.branch_root)

            # Plot graphs (optional analysis; the calls below are left commented out).

            # plot_distribution_scores(strategy_instance.score(prune_model, mask), strategy, mask, prune_iterations, reinitialize, randomize_layerwise, result_folder)
            # plot_distribution_scatter(strategy_instance.score(prune_model, mask), prune_model, strategy, mask, prune_iterations, reinitialize, randomize_layerwise, result_folder)

        # Load the mask.
        get_platform().barrier()
        mask = Mask.load(self.branch_root)

        # Load the model whose weights will be trained (or a fresh initialization if reinitializing).
        state_path = self.desc.run_path(self.replicate, state_experiment)
        if reinitialize:
            model = models.registry.get(self.desc.model_hparams)
        else:
            model = models.registry.load(state_path, state_step,
                                         self.desc.model_hparams)

        # plot_distribution_weights(model, strategy, mask, prune_iterations, reinitialize, randomize_layerwise, result_folder)

        original_model = copy.deepcopy(model)
        model = PrunedModel(model, mask)

        train.standard_train(model,
                             self.branch_root,
                             self.desc.dataset_hparams,
                             self.desc.training_hparams,
                             start_step=start_step,
                             verbose=self.verbose,
                             evaluate_every_epoch=self.evaluate_every_epoch)

        weights_analysis(original_model, strategy, reinitialize,
                         randomize_layerwise, "original")
        weights_analysis(model, strategy, reinitialize, randomize_layerwise,
                         "pruned")