Example #1
0
    def structured_prune(self, indices):
        """Structurally prune every entry of ``indices`` by a positive-score rule.

        For each ``(identification, parameter_name) -> scores`` entry, the
        neurons with a strictly positive score are kept and the rest are
        pruned. At least one neuron per entry is always kept so the layer is
        never disconnected. Entries flagged as first/last by
        ``lookahead_finished`` go through ``handle_outer_layers``; all others
        go through ``handle_middle_layers``, alternating the
        ``toggle_row_column`` flag on every entry. After all entries are
        processed, the method:

        * prints the matching lines of the model summary,
        * prints the overall pruned fraction, and
        * applies the model's weight mask.

        Args:
            indices: ordered mapping ``(identification, parameter_name) ->
                score tensor``. Entries ``> 0`` mark neurons to keep.
                NOTE(review): presumably a 1-D tensor per entry — confirm
                with the caller.
        """
        summed_weights = sum(
            np.prod(param.shape)
            for param_name, param in self.model.named_parameters()
            if "weight" in param_name
        )

        summed_pruned = 0
        toggle_row_column = True
        cutoff = 0
        length_nonzero = 0
        for ((_, name), grad), (first, last) in lookahead_finished(
                indices.items()):
            # keep exactly the neurons whose score is strictly positive
            binary_keep_neuron_vector = (grad > 0).float().to(self.device)
            # NOTE(review): looked up fresh on every iteration in case the
            # handlers below replace parameters/modules; confirm before caching
            corresponding_weight_parameter = [
                val for key, val in self.model.named_parameters()
                if key == name
            ][0]
            is_conv = len(corresponding_weight_parameter.shape) > 2
            module_name = name.split(".weight")[0]
            corresponding_module: nn.Module = [
                val for key, val in self.model.named_modules()
                if key == module_name
            ][0]

            # ensure not disconnecting: keep the single best-scoring neuron
            if binary_keep_neuron_vector.sum() == 0:
                best_index = torch.argmax(grad)
                binary_keep_neuron_vector[best_index] = 1

            if first or last:
                # noinspection PyTypeChecker
                length_nonzero = self.handle_outer_layers(
                    binary_keep_neuron_vector, first, is_conv, last,
                    length_nonzero, corresponding_module, name,
                    corresponding_weight_parameter)
            else:
                cutoff, length_nonzero = self.handle_middle_layers(
                    binary_keep_neuron_vector, cutoff, is_conv, length_nonzero,
                    corresponding_module, name, toggle_row_column,
                    corresponding_weight_parameter)

            cutoff, summed_pruned = self.print_layer_progress(
                cutoff, indices, length_nonzero, name, summed_pruned,
                toggle_row_column, corresponding_weight_parameter)
            toggle_row_column = not toggle_row_column

        # print only the structurally relevant lines of the model summary
        summary_keywords = ("BatchNorm", "Conv", "Linear", "AdaptiveAvg",
                            "Sequential")
        for line in str(self.model).split("\n"):
            if any(keyword in line for keyword in summary_keywords):
                print(line)
        print("final percentage after snap:", summed_pruned / summed_weights)

        self.model.apply_weight_mask()
Example #2
0
    def handle_pruning(self, all_scores, grads_abs, norm_factor, percentage):
        """Prune neurons whose normalised score falls below a global threshold.

        The threshold is the ``k``-th largest value of ``all_scores``, where
        ``k = int(len(all_scores) * (1 - percentage))`` clamped to
        ``[1, len(all_scores)]``. If ``self.model._outer_layer_pruning`` is
        false, the scores of the first and last entries of ``grads_abs`` are
        sliced off ``all_scores`` before ``k`` and the threshold are computed.

        Each entry's scores are divided by ``norm_factor``; neurons scoring
        ``>= threshold`` are kept. At least one neuron per entry is always
        kept. First/last entries go through ``handle_outer_layers``; all
        others go through ``handle_middle_layers``. The method then:

        * prints the matching lines of the model summary,
        * prints the overall pruned fraction,
        * applies the weight mask, and
        * calls ``cut_lonely_connections``.

        Args:
            all_scores: 1-D tensor of scores used to derive the global
                threshold. NOTE(review): presumably the concatenation of the
                per-entry scores in ``grads_abs`` order (the outer-layer
                slicing relies on this) — confirm with the caller.
            grads_abs: ordered mapping ``(identification, parameter_name) ->
                per-neuron score tensor``.
            norm_factor: divisor applied to each entry's scores before they
                are compared with the threshold.
            percentage: fraction of nodes to prune.
        """
        summed_weights = sum(
            np.prod(param.shape)
            for param_name, param in self.model.named_parameters()
            if "weight" in param_name
        )

        # handle outer layers: exclude their scores from the global threshold
        if not self.model._outer_layer_pruning:
            offsets = [
                len(grad)
                for (_, grad), (first, last) in lookahead_finished(
                    grads_abs.items())
                if first or last
            ]
            # explicit end index: `x[a:-0]` would silently yield an empty slice
            all_scores = all_scores[offsets[0]:len(all_scores) - offsets[1]]

        # dont prune more or less than is available; the lower bound also
        # covers percentage > 1, which would otherwise give a negative k and
        # make torch.topk raise
        num_nodes_to_keep = int(len(all_scores) * (1 - percentage))
        num_nodes_to_keep = max(1, min(num_nodes_to_keep, len(all_scores)))

        # threshold: the smallest of the top-k scores is the last one kept
        acceptable_score = torch.topk(
            all_scores, num_nodes_to_keep, sorted=True).values[-1]

        # prune
        summed_pruned = 0
        toggle_row_column = True
        cutoff = 0
        length_nonzero = 0
        for ((_, name), grad), (first, last) in lookahead_finished(
                grads_abs.items()):

            binary_keep_neuron_vector = (
                (grad / norm_factor) >= acceptable_score
            ).float().to(self.device)
            # NOTE(review): looked up fresh on every iteration in case the
            # handlers below replace parameters/modules; confirm before caching
            corresponding_weight_parameter = [
                val for key, val in self.model.named_parameters()
                if key == name
            ][0]
            is_conv = len(corresponding_weight_parameter.shape) > 2
            module_name = name.split(".weight")[0]
            corresponding_module: nn.Module = [
                val for key, val in self.model.named_modules()
                if key == module_name
            ][0]

            # ensure not disconnecting: keep the single best-scoring neuron
            if binary_keep_neuron_vector.sum() == 0:
                best_index = torch.argmax(grad)
                binary_keep_neuron_vector[best_index] = 1

            if first or last:
                # noinspection PyTypeChecker
                length_nonzero = self.handle_outer_layers(
                    binary_keep_neuron_vector, first, is_conv, last,
                    length_nonzero, corresponding_module, name,
                    corresponding_weight_parameter)
            else:
                cutoff, length_nonzero = self.handle_middle_layers(
                    binary_keep_neuron_vector, cutoff, is_conv, length_nonzero,
                    corresponding_module, name, toggle_row_column,
                    corresponding_weight_parameter)

            cutoff, summed_pruned = self.print_layer_progress(
                cutoff, grads_abs, length_nonzero, name, summed_pruned,
                toggle_row_column, corresponding_weight_parameter)
            toggle_row_column = not toggle_row_column

        # print only the structurally relevant lines of the model summary
        summary_keywords = ("BatchNorm", "Conv", "Linear", "AdaptiveAvg",
                            "Sequential")
        for line in str(self.model).split("\n"):
            if any(keyword in line for keyword in summary_keywords):
                print(line)
        print("final percentage after snap:", summed_pruned / summed_weights)

        self.model.apply_weight_mask()
        self.cut_lonely_connections()