Ejemplo n.º 1
0
    def prune(self, percentage, train_loader=None, manager=None, **kwargs):
        """Prune weights (SNIP) and nodes (SNAP) against one shared score threshold.

        Weight saliencies from ``SNIP.get_weight_saliencies`` and node saliencies
        from ``SNAP.get_weight_saliencies`` are concatenated and ranked together.
        The smallest score among the top ``(1 - percentage)`` share becomes the
        cut-off. For each population, the fraction of scores strictly below the
        cut-off is then passed to that population's ``handle_pruning`` routine.

        Args:
            percentage: overall fraction of (weights + nodes) to prune.
            train_loader: passed to both saliency computations.
            manager: forwarded to ``SNIP.handle_pruning``.
            **kwargs: accepted for interface compatibility; not used here.
        """
        # get weight saliencies (SNIP); element 0 holds the scores
        x = SNIP.get_weight_saliencies(self, train_loader)
        weight_scores = x[0]

        # get node saliencies (SNAP); element 0 holds the scores
        y = SNAP.get_weight_saliencies(self, train_loader)
        node_scores = y[0]

        # rank weights and nodes together so a single threshold applies to both
        all_scores = torch.cat([weight_scores, node_scores])

        # Number of elements to keep, clamped to [1, len(all_scores)] so torch.topk
        # always gets a valid k. Previously a value < 1 was only incremented by one,
        # so percentage > 1 could still produce k <= 0 and crash topk.
        num_params_to_keep = int(len(all_scores) * (1 - percentage))
        num_params_to_keep = min(max(num_params_to_keep, 1), len(all_scores))
        threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
        # smallest kept score: anything strictly below it is a pruning candidate
        acceptable_score = threshold[-1]

        # fraction of each population that falls below the shared threshold
        percentage_weights = (weight_scores < acceptable_score).sum().item() / len(weight_scores)
        percentage_nodes = (node_scores < acceptable_score).sum().item() / len(node_scores)

        print("fraction for pruning nodes", percentage_nodes, "fraction for pruning weights", percentage_weights)

        # prune each population with its own fraction
        SNIP.handle_pruning(self, weight_scores, x[1], x[2], manager, x[-1], percentage_weights)
        SNAP.handle_pruning(self, node_scores, y[1], y[3], percentage_nodes)
Ejemplo n.º 2
0
    def prune(self, percentage=0.0, *args, **kwargs):
        """Run the next scheduled SNIP pruning event, if one remains.

        The ``percentage`` argument is ignored. The fraction to prune is popped
        from the front of ``self.steps`` and handed to ``SNIP.prune`` as the
        ``percentage`` keyword. Extra positional ``args`` are not forwarded.
        When ``self.steps`` is empty this is a silent no-op.
        """
        if not self.steps:
            return

        # next k_i in the schedule replaces whatever percentage was passed in
        kwargs["percentage"] = self.steps.pop(0)
        SNIP.prune(self, **kwargs)
Ejemplo n.º 3
0
    def prune(self, percentage=0.0, *args, **kwargs):
        """Apply every remaining scheduled pruning step, each with a fresh SNIP criterion.

        ``self.grads_abs`` is reset to ``None`` first. Then, for each fraction
        popped from the front of ``self.steps``, a new ``SNIP(model=self.model)``
        is built and its ``prune`` is called with that fraction. The incoming
        ``percentage`` argument is ignored in favour of the schedule. Remaining
        ``args``/``kwargs`` are forwarded unchanged.
        """
        self.grads_abs = None
        while self.steps:
            # determine k_i: next fraction in the schedule
            percentage = self.steps.pop(0)

            criterion = SNIP(model=self.model)

            # Pass the fraction positionally, ahead of *args. The previous form
            # `prune(percentage=..., *args, ...)` unpacks *args first, so any
            # forwarded positional argument also bound to `percentage` and raised
            # "got multiple values for argument 'percentage'".
            # NOTE(review): assumes SNIP.prune takes `percentage` as its first
            # positional parameter, as the other prune signatures here do — confirm.
            criterion.prune(percentage, *args, **kwargs)
Ejemplo n.º 4
0
    def prune(self, percentage=0.0, *args, **kwargs):
        """Run the next scheduled SNIP pruning event, or report that none remain.

        The ``percentage`` argument is ignored. The fraction to prune is popped
        from the front of ``self.steps`` and handed to ``SNIP.prune`` as the
        ``percentage`` keyword. Extra positional ``args`` are not forwarded.
        When the schedule is exhausted, a message is printed and nothing else
        happens.
        """
        if not self.steps:
            print("finished all pruning events already")
            return

        # determine k_i: the schedule overrides the passed-in percentage.
        # (The former second `if len(self.steps) > 0` check was always true here.)
        kwargs["percentage"] = self.steps.pop(0)

        # prune
        SNIP.prune(self, **kwargs)
Ejemplo n.º 5
0
    def _write_snip(self, epoch, trainer_ns):
        """Plot SNIP weight saliencies and log the rendered image to the writer.

        A fresh ``SNIP`` is built on the trainer's model and device and evaluated
        over its training loader. The returned ``log10`` scores are plotted, and
        the figure is written to ``self._writer`` under ``"track/weight_saliency"``.

        Args:
            epoch: step value passed to ``add_image``.
            trainer_ns: object exposing ``_model``, ``_device`` and ``_train_loader``.
        """
        # only the log10 scores are plotted; the other returned values are unused
        _, _, log10, _ = SNIP(model=trainer_ns._model,
                              device=trainer_ns._device).get_weight_saliencies(
            trainer_ns._train_loader)

        fig = plt.figure()
        canvas = FigureCanvasAgg(fig)
        try:
            scores = log10
            num_scores = len(scores)

            # For very large score vectors (> 5e6 entries) plot a random subsample
            # of roughly 5e5 points, but always keep the last 1/150th of entries.
            if num_scores > 5e6:
                indices = np.random.rand(num_scores) > (1 - (5e5 / num_scores))
                indices[-int(num_scores / 150):] = 1
                scores = scores[indices]

            plt.plot(scores.cpu().numpy(), label="sorted_weight_relevance")
            plt.ylim((-11, 0))
            # nine ticks at multiples of len(scores) // 8, labelled as a percentage
            # of the (possibly subsampled) x-range
            tick_step = len(scores) // 8
            ticks = [i * tick_step for i in range(0, 9)]
            plt.xticks(ticks, [str(int(100 * tick / len(scores))) + "%" for tick in ticks])
            plt.grid()

            picture = self.plt_to_tensor(canvas, fig)
        finally:
            # pyplot keeps every figure alive until it is explicitly closed;
            # without this each call leaks one figure.
            plt.close(fig)

        self._writer.add_image("track/weight_saliency", picture, epoch)