コード例 #1
0
    def __init__(self, model: Model, mask: Mask):
        """Wrap ``model`` and apply ``mask`` to it.

        ``mask`` must contain exactly the tensors named in
        ``model.prunable_layer_names``. Each mask entry must have the same
        shape as the matching tensor in ``model.state_dict()``. Every entry
        is stored, converted with ``.float()``, as a buffer named by
        ``PrunedModel.to_mask_name``. ``_apply_mask`` is called afterwards.

        Args:
            model: The network to wrap; must not already be a ``PrunedModel``.
            mask: Mapping from tensor name to mask tensor.

        Raises:
            ValueError: If ``model`` is a ``PrunedModel``, if a prunable
                tensor has no mask entry, if a mask entry's shape differs from
                its tensor's shape, or if ``mask`` has a key that is not a
                prunable tensor.
        """
        if isinstance(model, PrunedModel):
            raise ValueError('Cannot nest pruned models.')
        super(PrunedModel, self).__init__()
        self.model = model

        # Every prunable tensor needs a mask entry of identical shape.
        # Missing-key and shape checks are done per name, in order, so the
        # first offending tensor determines which error is raised.
        for name in self.model.prunable_layer_names:
            if name not in mask:
                raise ValueError('Missing mask value {}.'.format(name))
            expected_shape = np.array(self.model.state_dict()[name].shape)
            if not np.array_equal(mask[name].shape, expected_shape):
                message = 'Incorrect mask shape {} for tensor {}.'.format(
                    mask[name].shape, name)
                raise ValueError(message)

        # Conversely, the mask may not mention anything that is not prunable.
        unknown = [
            name for name in mask
            if name not in self.model.prunable_layer_names
        ]
        if unknown:
            raise ValueError(
                'Key {} found in mask but is not a valid model tensor.'.format(
                    unknown[0]))

        for name, values in mask.items():
            self.register_buffer(PrunedModel.to_mask_name(name), values.float())
        self._apply_mask()
コード例 #2
0
    def branch_function(
            self,
            strategy: str,
            prune_fraction: float,
            prune_experiment: str = 'main',
            prune_step: str = '0ep0it',  # The step for states used to prune.
            prune_highest: bool = False,
            prune_iterations: int = 1,
            randomize_layerwise: bool = False,
            state_experiment: str = 'main',
            # The step of the state to use alongside the pruning mask.
            state_step: str = '0ep0it',
            # The step at which to start the learning rate schedule.
            start_step: str = '0ep0it',
            seed: int = None,
            reinitialize: bool = False):
        """Compute (or reuse) a pruning mask, then retrain the masked model.

        If a mask can already be loaded from ``self.branch_root``, it is
        reused. Otherwise the primary process builds one and saves it there.
        It loads the ``prune_experiment`` model at ``prune_step``, scores it
        ``prune_iterations`` times with ``strategy``, and prunes after each
        scoring. Every process then waits on the platform barrier and reloads
        the saved mask. It wraps either a freshly built model
        (``reinitialize``) or the ``state_experiment`` model at ``state_step``
        in a ``PrunedModel``, and trains it from ``start_step``. Finally,
        ``weights_analysis`` is run on the pre-masking copy and on the
        trained pruned model.

        Args:
            strategy: Pruning-strategy name; exactly one entry of
                ``strategies`` must report ``valid_name(strategy)``.
            prune_fraction: Target pruning fraction, in [0, 1]. Each round
                uses ``1 - (1 - prune_fraction) ** (1 / prune_iterations)``.
                This presumably assumes ``prune`` removes that fraction of the
                still-unpruned weights; confirm against ``prune``.
            prune_experiment: Experiment whose saved state is scored.
            prune_step: Step string of the state used for scoring.
            prune_highest: Passed negated as the third positional argument of
                ``prune``. Presumably it selects pruning the highest scores
                rather than the lowest; verify against ``prune``.
            prune_iterations: Number of score-and-prune rounds; must be >= 1.
            randomize_layerwise: If True, the final mask is passed through
                ``shuffle_state_dict`` (seeded with ``seed``) before saving.
            state_experiment: Experiment whose state is retrained.
            state_step: Step string of the state that is retrained.
            start_step: Step string at which training starts.
            seed: Seed for the strategy and the shuffle; ``None`` means
                ``self.replicate``.
            reinitialize: If True, retrain a model from
                ``models.registry.get`` instead of loading saved state.

        Raises:
            ValueError: When a new mask must be computed and
                ``prune_iterations < 1``, ``prune_fraction`` is outside
                [0, 1], or ``strategy`` matches zero or several strategies.
        """
        # Resolve the step strings against this dataset's epoch length.
        iterations_per_epoch = datasets.registry.iterations_per_epoch(
            self.desc.dataset_hparams)
        prune_step = Step.from_str(prune_step, iterations_per_epoch)
        state_step = Step.from_str(state_step, iterations_per_epoch)
        start_step = Step.from_str(start_step, iterations_per_epoch)
        seed = self.replicate if seed is None else seed

        # Reuse a previously saved mask if one can be loaded. Catch Exception
        # (not a bare `except:`) so KeyboardInterrupt / SystemExit propagate.
        try:
            mask = Mask.load(self.branch_root)
        except Exception:
            mask = None

        # Output folder consumed only by the disabled plotting calls below.
        result_folder = "Data_Distribution/"
        if reinitialize:
            result_folder = "Data_Distribution_Reinit/"
        elif randomize_layerwise:
            result_folder = "Data_Distribution__Randomize_Layerwise/"

        if not mask and get_platform().is_primary_process:
            # Fail with a clear message instead of a ZeroDivisionError
            # (prune_iterations == 0) or a TypeError from comparing a complex
            # number (prune_fraction > 1 makes the power below complex).
            if prune_iterations < 1:
                raise ValueError(
                    f'prune_iterations must be >= 1, got {prune_iterations}')
            if not 0 <= prune_fraction <= 1:
                raise ValueError(
                    f'prune_fraction must be in [0, 1], got {prune_fraction}')

            # Gather the weights that will be used for pruning.
            prune_path = self.desc.run_path(self.replicate, prune_experiment)
            prune_model = models.registry.load(prune_path, prune_step,
                                               self.desc.model_hparams)

            # Exactly one registered strategy must accept this name.
            strategy_class = [s for s in strategies if s.valid_name(strategy)]
            if not strategy_class:
                raise ValueError(f'No such pruning strategy {strategy}')
            if len(strategy_class) > 1:
                raise ValueError('Multiple matching strategies')
            strategy_instance = strategy_class[0](strategy, self.desc, seed)

            # Start from an all-ones mask; every round prunes the same
            # per-round fraction.
            mask = Mask.ones_like(prune_model)
            iteration_fraction = 1 - (1 - prune_fraction)**(
                1 / float(prune_iterations))

            if iteration_fraction > 0:
                for _ in range(prune_iterations):
                    # Score a defensive copy with already-masked weights
                    # zeroed, leaving prune_model itself untouched.
                    prune_model2 = copy.deepcopy(prune_model)
                    with torch.no_grad():
                        for k, v in prune_model2.named_parameters():
                            v.mul_(mask.get(k, 1))

                    scores = strategy_instance.score(prune_model2, mask)

                    mask = unvectorize(
                        prune(vectorize(scores),
                              iteration_fraction,
                              not prune_highest,
                              mask=vectorize(mask)), mask)

            # Optionally shuffle the mask randomly within each layer.
            if randomize_layerwise:
                mask = shuffle_state_dict(mask, seed=seed)

            # Save detached copies of every mask tensor.
            mask = Mask({k: v.clone().detach() for k, v in mask.items()})
            mask.save(self.branch_root)

            # Disabled diagnostics (they write into result_folder):
            # plot_distribution_scores(strategy_instance.score(prune_model, mask), strategy, mask, prune_iterations, reinitialize, randomize_layerwise, result_folder)
            # plot_distribution_scatter(strategy_instance.score(prune_model, mask), prune_model, strategy, mask, prune_iterations, reinitialize, randomize_layerwise, result_folder)

        # Synchronize with the primary process, which saved the mask above,
        # then load that mask on every process.
        get_platform().barrier()
        mask = Mask.load(self.branch_root)

        # Choose the weights to retrain: a fresh model or a saved state.
        state_path = self.desc.run_path(self.replicate, state_experiment)
        if reinitialize:
            model = models.registry.get(self.desc.model_hparams)
        else:
            model = models.registry.load(state_path, state_step,
                                         self.desc.model_hparams)

        # plot_distribution_weights(model, strategy, mask, prune_iterations, reinitialize, randomize_layerwise, result_folder)

        # Snapshot the model before it is wrapped with the mask and trained,
        # so it can be compared with the trained pruned model below.
        original_model = copy.deepcopy(model)
        model = PrunedModel(model, mask)

        train.standard_train(model,
                             self.branch_root,
                             self.desc.dataset_hparams,
                             self.desc.training_hparams,
                             start_step=start_step,
                             verbose=self.verbose,
                             evaluate_every_epoch=self.evaluate_every_epoch)

        weights_analysis(original_model, strategy, reinitialize,
                         randomize_layerwise, "original")
        weights_analysis(model, strategy, reinitialize, randomize_layerwise,
                         "pruned")