Example No. 1
    def _apply_pruning_effect(self, layer_id, removed_filter,
                              initial_filter_count, effect_applied):
        if layer_id not in self.graph_res.name_dict:
            for sub_node_id in layer_id.split(","):
                self._apply_pruning_effect(sub_node_id, removed_filter,
                                           initial_filter_count,
                                           effect_applied)
            return
        layer = get_node_in_model(self.model,
                                  self.graph_res.name_dict[layer_id])

        # print("\tapply effect on node: {}-{}".format(layer_id, self.graph_res.name_dict[layer_id]))
        has_more = True
        if isinstance(layer, torch.nn.modules.conv.Conv2d):
            self._prune_conv_input_filters(layer, removed_filter,
                                           initial_filter_count)
            effect_applied.append(layer_id)
            has_more = False
        elif isinstance(layer, torch.nn.modules.Linear):
            self._prune_input_linear(layer, removed_filter,
                                     initial_filter_count)
            effect_applied.append(layer_id)
            has_more = False
        elif isinstance(layer, torch.nn.modules.BatchNorm2d):
            self._prune_conv_input_batchnorm(layer, removed_filter,
                                             initial_filter_count)
            effect_applied.append(layer_id)
            initial_filter_count = layer.num_features
        else:
            if layer_id in self.graph_res.special_op:
                if layer_id not in self.special_ops_prune_apply_count:
                    self.special_ops_prune_apply_count[layer_id] = 0
                    if layer_id not in self.special_ops_prune_concat_offset:
                        self.special_ops_prune_concat_offset[layer_id] = (
                            initial_filter_count - len(removed_filter))
                else:
                    self.special_ops_prune_apply_count[layer_id] += 1
                    if self.graph_res.special_op[layer_id] == "Add":
                        has_more = False
                    elif self.graph_res.special_op[layer_id] == "Concat":
                        concat_offset = self.special_ops_prune_concat_offset[
                            layer_id]
                        self.special_ops_prune_concat_offset[layer_id] = (
                            initial_filter_count - len(removed_filter) +
                            concat_offset)
                        # shift the indices so they point into the
                        # concatenated tensor
                        for i, elem_to_remove in enumerate(removed_filter):
                            removed_filter[i] = elem_to_remove + concat_offset

        if has_more:
            next_id = self.graph_res.execution_graph[layer_id]
            if len(next_id) > 0:
                for sub_node_id in next_id.split(","):
                    if sub_node_id not in effect_applied:
                        self._apply_pruning_effect(sub_node_id, removed_filter,
                                                   initial_filter_count,
                                                   effect_applied)
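
The helpers _prune_conv_input_filters, _prune_input_linear, and _prune_conv_input_batchnorm are not shown in this listing. As a rough illustration of what input-channel pruning involves, here is a minimal standalone sketch; the function name and the rebuild-and-copy strategy are assumptions, not the author's implementation:

import torch

def prune_conv_input_channels(conv, removed_filter, offset=0):
    """Rebuild `conv` without the input channels in `removed_filter`
    (shifted by `offset`); weights of the kept channels are copied over.
    Illustrative sketch only."""
    drop = {r + offset for r in removed_filter}
    keep = [i for i in range(conv.in_channels) if i not in drop]
    new_conv = torch.nn.Conv2d(len(keep), conv.out_channels,
                               kernel_size=conv.kernel_size,
                               stride=conv.stride,
                               padding=conv.padding,
                               bias=conv.bias is not None)
    with torch.no_grad():
        new_conv.weight.copy_(conv.weight[:, keep])
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)  # bias is per output channel
    return new_conv

conv = torch.nn.Conv2d(8, 16, 3)
assert prune_conv_input_channels(conv, [1, 5]).in_channels == 6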
Example No. 2
    def should_ignore_layer(self, layer_id):
        next_id = self.graph_res.execution_graph[layer_id]
        if next_id not in self.graph_res.name_dict:
            return True

        layer = get_node_in_model(self.model,
                                  self.graph_res.name_dict[next_id])

        has_more = True
        if isinstance(layer, (torch.nn.modules.conv.Conv2d,
                              torch.nn.modules.Linear)):
            has_more = False

        if has_more:
            next_id = self.graph_res.execution_graph[next_id]
            if next_id not in self.graph_res.name_dict:
                return True
            elif self.connection_count_copy[next_id] > 1:
                return True
            else:
                return self.should_ignore_layer(next_id)

        return False
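
Note that execution_graph stores each node's successors as a comma-joined string, so a branch point such as "2,3" will never be a key of name_dict and should_ignore_layer treats it as a stopping condition. A toy version of that representation (illustrative names only):

execution_graph = {"0": "1", "1": "2,3", "2": "4", "3": "4", "4": ""}

def successors(graph, node_id):
    # an empty string means the node has no successors
    nxt = graph[node_id]
    return nxt.split(",") if nxt else []

assert successors(execution_graph, "1") == ["2", "3"]
assert successors(execution_graph, "4") == []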
Example No. 3
    def prune(self, pruning_dic):
        for layer_id, filters_to_remove in pruning_dic.items():
            # print("apply pruning on node: {}-{}".format(layer_id, self.graph_res.name_dict[layer_id]))
            layer = get_node_in_model(self.model,
                                      self.graph_res.name_dict[layer_id])

            if layer is not None:
                initial_filter_count = 0
                if isinstance(layer, torch.nn.modules.conv.Conv2d):
                    initial_filter_count = self._prune_conv_output_filters(
                        layer, filters_to_remove)

                if len(filters_to_remove) > 0:
                    effect_applied = []
                    next_id = self.graph_res.execution_graph[layer_id]
                    for sub_node_id in next_id.split(","):
                        if sub_node_id not in effect_applied:
                            self._apply_pruning_effect(sub_node_id,
                                                       filters_to_remove,
                                                       initial_filter_count,
                                                       effect_applied)
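
prune expects pruning_dic to map node ids to lists of output-filter indices, and _prune_conv_output_filters (not shown) evidently returns the layer's pre-pruning filter count so the effect can be propagated downstream. A hedged sketch of such a helper, assuming a simple rebuild of the Conv2d:

import torch

def prune_conv_output_filters(conv, filters_to_remove):
    """Rebuild `conv` without the given output filters; return the new
    layer and the pre-pruning filter count. Illustrative sketch only."""
    initial_filter_count = conv.out_channels
    keep = [i for i in range(initial_filter_count)
            if i not in set(filters_to_remove)]
    new_conv = torch.nn.Conv2d(conv.in_channels, len(keep),
                               kernel_size=conv.kernel_size,
                               stride=conv.stride,
                               padding=conv.padding,
                               bias=conv.bias is not None)
    with torch.no_grad():
        new_conv.weight.copy_(conv.weight[keep])
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias[keep])
    return new_conv, initial_filter_count

new_conv, initial = prune_conv_output_filters(torch.nn.Conv2d(3, 10, 3),
                                              [0, 4, 9])
assert (new_conv.out_channels, initial) == (7, 10)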
Example No. 4
    def compute_conv_graph(self):
        conv_layers = []
        to_delete = []
        for key, val in self.conv_graph.items():
            module = get_node_in_model(self.model, self.graph_res.name_dict[key])
            if isinstance(module, torch.nn.modules.conv.Conv2d):
                conv_layers.append(key)
            elif key == self.graph_res.out_node:
                conv_layers.append(key)
            else:
                to_delete.append(key)

        for key in conv_layers:
            next_ids = self.conv_graph[key]  # avoid shadowing the builtin next()
            self.conv_graph[key] = ""
            new_next = []
            if next_ids != "":
                for elem in next_ids.split(","):
                    res = self._get_next_conv_id(conv_layers, elem)
                    new_next.extend(res)

            self.conv_graph[key] = ",".join(new_next)

        for key in to_delete:
            self.conv_graph.pop(key, None)

        self.reverse_conv_graph = {}
        for key, val in self.conv_graph.items():
            for x in val.split(","):
                if x not in self.reverse_conv_graph:
                    self.reverse_conv_graph[x] = [key]
                else:
                    self.reverse_conv_graph[x].append(key)

        # TODO: this is what we'll want to keep in our class
        temp = {}
        self.sets = []
        elem_to_del = None
        for key, val in self.reverse_conv_graph.items():
            if len(val) == 1:
                continue

            array_as_set = set(val)
            found = False
            for i, j in enumerate(self.sets):
                if len(array_as_set.intersection(j)) > 0:
                    self.sets[i] = array_as_set.union(j)
                    if key == self.graph_res.out_node:
                        elem_to_del = i
                    temp[key] = i
                    found = True
                    break

            if not found:
                self.sets.append(array_as_set)
                temp[key] = len(self.sets) - 1
                if key == self.graph_res.out_node:
                    elem_to_del = len(self.sets) - 1

        if elem_to_del is not None:
            self.ignore_list = list(self.sets[elem_to_del])
            del self.sets[elem_to_del]
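
The grouping loop above unions each multi-predecessor set into the first existing group it overlaps with; chains that span two existing groups are therefore not merged transitively in a single pass. A standalone sketch of that merge step, with assumed names:

def merge_into_groups(groups, predecessors):
    """Union `predecessors` into the first overlapping group, else start a
    new one; return the group index (mirrors the loop above)."""
    new_set = set(predecessors)
    for i, existing in enumerate(groups):
        if new_set & existing:
            groups[i] = new_set | existing
            return i
    groups.append(new_set)
    return len(groups) - 1

groups = []
merge_into_groups(groups, ["a", "b"])
merge_into_groups(groups, ["b", "c"])  # overlaps with the first group
assert groups == [{"a", "b", "c"}]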
Example No. 5
    def compute_conv_graph(self):
        conv_layers = []
        to_delete = []

        # find all the convolution layer
        for key, val in self.conv_graph.items():
            module = get_node_in_model(self.model,
                                       self.graph_res.name_dict[key])
            if isinstance(module, torch.nn.modules.conv.Conv2d):
                conv_layers.append(key)
            elif key == self.graph_res.out_node:
                conv_layers.append(key)
            else:
                to_delete.append(key)

        # find the id of the next convolution layer
        for key in conv_layers:
            next_ids = self.conv_graph[key]  # avoid shadowing the builtin next()
            # self.conv_graph[key] = ""
            new_next = []
            if next_ids != "":
                for elem in next_ids.split(","):
                    res, is_last_conv = self._get_next_conv_id(
                        conv_layers, elem)
                    if is_last_conv and self.ignore_last_conv:
                        self.ignore_list.append(key)
                    new_next.extend(res)

            if len(new_next) >= 1:
                self.conv_graph[key] = ",".join(new_next)

        for key in to_delete:
            self.conv_graph.pop(key, None)

        self.reverse_conv_graph = {}
        for key, val in self.conv_graph.items():
            for x in val.split(","):
                if x not in self.reverse_conv_graph:
                    self.reverse_conv_graph[x] = [key]
                else:
                    self.reverse_conv_graph[x].append(key)

        # TODO see if we have some impact of before split
        # for k, v in self.conv_graph.items():
        #     if len(v.split(",")) > 1 and k not in self.ignore_list:
        #         self.ignore_list.append(k)
        # TODO see if we have some impact of before merge
        # temp = {}
        # for key, val in self.reverse_conv_graph.items():
        #     for x in val:
        #         if x not in temp:
        #             temp[x] = [key]
        #         else:
        #             temp[x].append(key)
        #
        # for key, val in temp.items():
        #     if len(val) > 1:
        #         self.ignore_list.append(key)
        #end temp section

        self.sets = []
        elem_to_del = None
        for key, val in self.reverse_conv_graph.items():
            if len(val) == 1:
                continue

            array_as_set = set(val)
            found = False
            for i, j in enumerate(self.sets):
                if len(array_as_set.intersection(j)) > 0:
                    self.sets[i] = array_as_set.union(j)
                    if key == self.graph_res.out_node:
                        elem_to_del = i
                    found = True
                    break

            if not found:
                self.sets.append(array_as_set)
                if key == self.graph_res.out_node:
                    elem_to_del = len(self.sets) - 1

        if (self.ignore_last_conv
                and self.graph_res.out_node in self.reverse_conv_graph):
            self.ignore_list.extend(
                self.reverse_conv_graph[self.graph_res.out_node])

        if elem_to_del is not None:
            self.ignore_list.extend(list(self.sets[elem_to_del]))
            del self.sets[elem_to_del]
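
Both versions of compute_conv_graph build the reverse adjacency the same way; the pattern can be written more compactly with setdefault (an illustrative rewrite, not the author's code):

def reverse_graph(conv_graph):
    # invert a {node: "succ1,succ2"} adjacency into {succ: [preds]}
    rev = {}
    for key, val in conv_graph.items():
        for x in val.split(","):
            rev.setdefault(x, []).append(key)
    return rev

assert reverse_graph({"1": "3", "2": "3"}) == {"3": ["1", "2"]}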
Example No. 6
    def parse(self, node_id):
        node_name = self.graph_res.name_dict[node_id]
        if self.connection_count[node_id] > 0:
            return None

        curr_module = get_node_in_model(self.model, node_name)
        if curr_module is None:
            # default to the accumulated input so `out` is always bound,
            # even if a special op other than AveragePool/Add shows up
            out = self.forward_res[node_id]
            if (node_id in self.graph_res.special_op
                    and self.graph_res.special_op[node_id] == "AveragePool"):
                shape, pad, stride = self.graph_res.special_op_params[node_id]
                # note: pad is unpacked but not passed to avg_pool2d here
                out = F.avg_pool2d(out, kernel_size=shape, stride=stride)
        else:
            # self._pre_parse_internal(node_id)

            x = self.forward_res[node_id]
            if isinstance(curr_module, torch.nn.modules.Linear):
                x = x.view(x.size(0), -1)

            if isinstance(curr_module, torch.nn.modules.conv.Conv2d):
                self.handle_before_conv_in_forward(curr_module, node_id)

            should_not_skip = True
            if node_id in self.graph_res.special_op:
                if self.graph_res.special_op[node_id] == "Concat":
                    should_not_skip = False
                    out = x
                elif self.graph_res.special_op[node_id] == "AveragePool":
                    shape, pad, stride = self.graph_res.special_op_params[
                        node_id]
                    out = F.avg_pool2d(x, kernel_size=shape, stride=stride)
                    should_not_skip = False

            if should_not_skip:
                out = curr_module(x)

            if isinstance(curr_module, torch.nn.modules.conv.Conv2d):
                self.handle_after_conv_in_forward(curr_module, node_id, out)

        res = None
        next_nodes = self.graph_res.execution_graph[node_id]
        if len(next_nodes) == 0:
            res = out
        else:
            for next_id in self.graph_res.execution_graph[node_id].split(","):
                self.connection_count[next_id] -= 1
                if next_id in self.forward_res:
                    if (next_id in self.graph_res.special_op
                            and self.graph_res.special_op[next_id] == "Concat"):
                        self.forward_res[next_id] = torch.cat(
                            (self.forward_res[next_id], out), 1)
                    else:
                        self.forward_res[next_id] = (
                            self.forward_res[next_id] + out)
                else:
                    self.forward_res[next_id] = out

                res = self.parse(next_id)
        return res
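
parse implements a depth-first execution over the graph: a node fires only once all of its inputs have arrived (its connection count reaches zero), multi-input nodes accumulate their inputs by addition or concatenation, and the result bubbles up from the terminal node. A toy model of that control flow, with illustrative names and a stand-in computation:

def run(graph, connection_count, values, node_id):
    if connection_count[node_id] > 0:
        return None                      # wait until every input has arrived
    out = values[node_id] * 2            # stand-in for the node's computation
    nxt = graph[node_id]
    if not nxt:
        return out                       # terminal node: return the result
    result = None
    for next_id in nxt.split(","):
        connection_count[next_id] -= 1
        values[next_id] = values.get(next_id, 0) + out  # "Add"-style merge
        result = run(graph, connection_count, values, next_id)
    return result

graph = {"a": "b,c", "b": "d", "c": "d", "d": ""}
counts = {"a": 0, "b": 1, "c": 1, "d": 2}
assert run(graph, counts, {"a": 1}, "a") == 16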