Example #1
0
    def split_segment_into_three_parts(self, original_segment, pruned_segment,
                                       block_count):
        """
        Split a segment around the pruned module into three parts:
            segment_before_pruned_module, pruned_module,
            segment_after_pruned_module.

        Both segments are flattened with ``utils.model2list`` and indexed by
        position, for positions in ``range(len(pruned_segment))``. Three
        ``nn.Sequential`` containers are stored on ``self``:

        - ``original_segment_before_pruned_module``: modules of
          ``original_segment`` at positions ``< block_count``.
        - ``pruned_segment_before_pruned_module``: modules of
          ``pruned_segment`` at positions ``< block_count``.
        - ``pruned_segment_after_pruned_module``: modules of
          ``pruned_segment`` at positions ``> block_count``.

        The module at position ``block_count`` (the pruned module) is not
        placed in any container, so that the input of the pruned module can
        be stored separately. No "after" part is kept for
        ``original_segment``.

        Args:
            original_segment: the unpruned segment.
            pruned_segment: the pruned segment; its ``len`` bounds the
                positions that are considered.
            block_count: position of the pruned module inside the segment.
        """
        flat_original = utils.model2list(original_segment)
        flat_pruned = utils.model2list(pruned_segment)

        positions = range(len(pruned_segment))
        head_positions = [pos for pos in positions if pos < block_count]
        tail_positions = [pos for pos in positions if pos > block_count]

        self.original_segment_before_pruned_module = nn.Sequential(
            *[flat_original[pos] for pos in head_positions])
        self.pruned_segment_before_pruned_module = nn.Sequential(
            *[flat_pruned[pos] for pos in head_positions])
        self.pruned_segment_after_pruned_module = nn.Sequential(
            *[flat_pruned[pos] for pos in tail_positions])
Example #2
0
    def channel_selection_for_network(self):
        """
        Run channel selection over every segment of the network.

        For each segment ``index``, the original and pruned modules of
        segments ``0..index`` (inclusive) are flattened into two
        ``nn.Sequential`` containers and passed to
        ``channel_selection_for_one_segment`` together with the segment's
        auxiliary classifier and its pivot block count. The last segment uses
        ``segment_wise_trainer.final_block_count`` as its pivot; the others
        use ``settings.pivot_set[index]``.

        When every segment is done, the models are saved with
        ``index=self.num_segments``, ``self.pruning()`` is called, the models
        are saved again with ``index=self.num_segments + 1``, and the elapsed
        time is logged.

        Resuming: if ``self.current_pivot_index`` is truthy on entry,
        segments before the restart index are skipped. The restart segment
        itself is also skipped when ``self.current_block_count`` equals
        ``self.current_pivot_index``.
        """

        def save_dcp_model_as(checkpoint_index):
            # attributes are read at call time, i.e. after any updates made
            # by channel selection or pruning
            self.checkpoint.save_dcp_model(
                self.original_model,
                self.pruned_model,
                self.segment_wise_trainer.aux_fc,
                self.segment_wise_trainer.final_block_count,
                index=checkpoint_index)

        # get testing error
        self.segment_wise_trainer.val(0)
        time_start = time.time()

        # find restart segment index
        restart_index = None
        resume_pivot = self.current_pivot_index
        if resume_pivot:
            pivot_set = self.settings.pivot_set
            if resume_pivot in pivot_set:
                restart_index = pivot_set.index(resume_pivot)
            else:
                # NOTE(review): presumably the final segment, whose pivot is
                # final_block_count rather than a pivot_set entry — confirm
                restart_index = len(pivot_set)

        for index in range(self.num_segments):
            if restart_index is not None:
                if index < restart_index:
                    continue
                if (index == restart_index and
                        self.current_block_count == self.current_pivot_index):
                    continue

            is_last_segment = index == self.num_segments - 1
            if is_last_segment:
                self.current_pivot_index = \
                    self.segment_wise_trainer.final_block_count
            else:
                self.current_pivot_index = self.settings.pivot_set[index]

            # conduct channel selection on segments 0..index (inclusive)
            original_modules = []
            pruned_modules = []
            for segment_id in range(index + 1):
                original_modules.extend(utils.model2list(
                    self.segment_wise_trainer.ori_segments[segment_id]))
                pruned_modules.extend(utils.model2list(
                    self.segment_wise_trainer.pruned_segments[segment_id]))

            net_pruned = self.channel_selection_for_one_segment(
                nn.Sequential(*original_modules),
                nn.Sequential(*pruned_modules),
                self.segment_wise_trainer.aux_fc[index],
                self.current_pivot_index,
                index)

            self.logger.info(self.segment_wise_trainer.pruned_segments)
            self.logger.info(net_pruned)
            self.logger.info(self.original_model)
            self.logger.info(self.pruned_model)
            self.segment_wise_trainer.val(0)
            self.current_pivot_index = None

        save_dcp_model_as(self.num_segments)
        self.pruning()
        save_dcp_model_as(self.num_segments + 1)

        elapsed_seconds = time.time() - time_start
        self.logger.info("cost time: {}".format(
            str(datetime.timedelta(seconds=elapsed_seconds))))

    def channel_selection(self):
        """
        Conduct segment-wise fine-tuning followed by channel selection.

        For each segment ``index`` in ``range(self.num_segments)``:

        1. Fine-tune with ``segment_wise_fine_tune(index)``, unless the run
           is resuming in a state where that step is skipped (see the
           condition below).
        2. Reload the weights saved in ``model_{index:03d}_swft.pth`` under
           ``self.checkpoint.save_path``.
        3. For ``index == 0`` only, overwrite the original (baseline) model's
           layers with copies of the pruned model's layers.
        4. Run ``_seg_channel_selection`` on segments ``0..index``
           (inclusive).
        5. Refresh the segment-wise trainer and save a channel-selection
           checkpoint.

        Finally the models are saved with ``index=self.num_segments`` and the
        elapsed time is logged.

        Resuming: if ``self.current_pivot_index`` is truthy on entry,
        segments before the restart index are skipped. The restart segment is
        also skipped, once, when ``self.is_channel_selection`` is set and
        ``self.current_block_count`` equals ``self.current_pivot_index``. In
        that case the flag is cleared.
        """

        # get testing error
        self.segment_wise_trainer.val(0)
        time_start = time.time()

        restart_index = None
        # find restart segment index
        if self.current_pivot_index:
            if self.current_pivot_index in self.settings.pivot_set:
                restart_index = self.settings.pivot_set.index(
                    self.current_pivot_index)
            else:
                # NOTE(review): presumably the final segment, whose pivot is
                # final_block_count rather than a pivot_set entry — confirm
                restart_index = len(self.settings.pivot_set)

        for index in range(self.num_segments):
            if restart_index is not None:
                if index < restart_index:
                    continue
                elif index == restart_index:
                    # presumably channel selection for this segment had
                    # already completed before the restart — skip it once
                    if (self.is_channel_selection and
                            self.current_block_count ==
                            self.current_pivot_index):
                        self.is_channel_selection = False
                        continue

            if index == self.num_segments - 1:
                self.current_pivot_index = \
                    self.segment_wise_trainer.final_block_count
            else:
                self.current_pivot_index = self.settings.pivot_set[index]

            # fine tune the network with additional loss and final loss
            if (not self.is_segment_wise_finetune and
                    not self.is_channel_selection) or \
                    (self.is_segment_wise_finetune and
                     self.epoch != self.settings.segment_wise_n_epochs - 1):
                self.segment_wise_fine_tune(index)
            else:
                self.is_segment_wise_finetune = False

            # load best model
            self._load_segment_wise_checkpoint(index)

            # replace the baseline model
            if index == 0:
                self._copy_pruned_layers_into_baseline()
            self.segment_wise_trainer.val(0)

            # conduct channel selection
            # contains segments 0..index (inclusive)
            net_origin_list = []
            net_pruned_list = []
            for j in range(index + 1):
                net_origin_list += utils.model2list(
                    self.segment_wise_trainer.ori_segments[j])
                net_pruned_list += utils.model2list(
                    self.segment_wise_trainer.pruned_segments[j])

            net_origin = nn.Sequential(*net_origin_list)
            net_pruned = nn.Sequential(*net_pruned_list)

            self._seg_channel_selection(
                net_origin=net_origin,
                net_pruned=net_pruned,
                aux_fc=self.segment_wise_trainer.aux_fc[index],
                pivot_index=self.current_pivot_index,
                index=index)

            # update optimizer (refresh the trainer with the updated models
            # and the current auxiliary classifier states)
            self.segment_wise_trainer.update_model(
                self.ori_model,
                self.pruned_model,
                self._gather_aux_fc_state_dicts())

            self.checkpoint.save_checkpoint(
                self.ori_model,
                self.pruned_model,
                self.segment_wise_trainer.aux_fc,
                self.segment_wise_trainer.fc_optimizer,
                self.segment_wise_trainer.seg_optimizer,
                self.current_pivot_index,
                channel_selection=True,
                index=index,
                block_count=self.current_pivot_index)

            self.logger.info(self.ori_model)
            self.logger.info(self.pruned_model)
            self.segment_wise_trainer.val(0)
            self.current_pivot_index = None

        self.checkpoint.save_model(self.ori_model,
                                   self.pruned_model,
                                   self.segment_wise_trainer.aux_fc,
                                   self.segment_wise_trainer.final_block_count,
                                   index=self.num_segments)
        time_interval = time.time() - time_start
        log_str = "cost time: {}".format(
            str(datetime.timedelta(seconds=time_interval)))
        self.logger.info(log_str)

    def _gather_aux_fc_state_dicts(self):
        """
        Return the state dicts of ``segment_wise_trainer.aux_fc``, in order.

        Classifiers wrapped in ``nn.DataParallel`` are unwrapped through
        ``.module`` before their state dict is taken.
        """
        state_dicts = []
        for aux_fc in self.segment_wise_trainer.aux_fc:
            if isinstance(aux_fc, nn.DataParallel):
                aux_fc = aux_fc.module
            state_dicts.append(aux_fc.state_dict())
        return state_dicts

    def _load_segment_wise_checkpoint(self, index):
        """
        Reload models and auxiliary classifiers for segment ``index``.

        Reads ``model_{index:03d}_swft.pth`` from
        ``self.checkpoint.save_path``. It loads the ``"ori_model"`` and
        ``"pruned_model"`` states into ``self.ori_model`` /
        ``self.pruned_model`` and passes the ``"aux_fc"`` state to
        ``segment_wise_trainer.update_model``.

        NOTE(review): ``torch.load`` unpickles the file; this is only safe
        because the file is one this run's checkpoint directory produced.
        """
        best_model_path = os.path.join(
            self.checkpoint.save_path,
            'model_{:0>3d}_swft.pth'.format(index))
        check_point_params = torch.load(best_model_path)
        ori_model_state = check_point_params["ori_model"]
        pruned_model_state = check_point_params["pruned_model"]
        aux_fc_state = check_point_params["aux_fc"]
        self.ori_model = self.checkpoint.load_state(
            self.ori_model, ori_model_state)
        self.pruned_model = self.checkpoint.load_state(
            self.pruned_model, pruned_model_state)
        self.segment_wise_trainer.update_model(self.ori_model,
                                               self.pruned_model,
                                               aux_fc_state)

    def _copy_pruned_layers_into_baseline(self):
        """
        Overwrite the original model's layers with deep copies of the
        pruned model's layers, then refresh the segment-wise trainer.

        Layers are copied only for ``net_type`` ``'preresnet'`` and
        ``'resnet'``. For any other type only the trainer refresh happens.
        Blocks are paired by zipping both models' ``modules()``, so this
        assumes the two models share the same architecture.
        """
        if self.settings.net_type in ['preresnet']:
            self.ori_model.conv = copy.deepcopy(self.pruned_model.conv)
            for ori_module, pruned_module in zip(
                    self.ori_model.modules(),
                    self.pruned_model.modules()):
                if isinstance(ori_module, PreBasicBlock):
                    ori_module.bn1 = copy.deepcopy(pruned_module.bn1)
                    ori_module.bn2 = copy.deepcopy(pruned_module.bn2)
                    ori_module.conv1 = copy.deepcopy(pruned_module.conv1)
                    ori_module.conv2 = copy.deepcopy(pruned_module.conv2)
                    if ori_module.downsample is not None:
                        ori_module.downsample = copy.deepcopy(
                            pruned_module.downsample)
            self.ori_model.bn = copy.deepcopy(self.pruned_model.bn)
            self.ori_model.fc = copy.deepcopy(self.pruned_model.fc)
        elif self.settings.net_type in ['resnet']:
            # Copy the stem attribute-for-attribute like every other layer
            # here. This previously read ``self.pruned_model.conv``, which
            # does not match the ``conv1`` attribute being replaced.
            self.ori_model.conv1 = copy.deepcopy(self.pruned_model.conv1)
            self.ori_model.bn1 = copy.deepcopy(self.pruned_model.bn1)
            for ori_module, pruned_module in zip(
                    self.ori_model.modules(),
                    self.pruned_model.modules()):
                if isinstance(ori_module, BasicBlock):
                    ori_module.conv1 = copy.deepcopy(pruned_module.conv1)
                    ori_module.conv2 = copy.deepcopy(pruned_module.conv2)
                    ori_module.bn1 = copy.deepcopy(pruned_module.bn1)
                    ori_module.bn2 = copy.deepcopy(pruned_module.bn2)
                    if ori_module.downsample is not None:
                        ori_module.downsample = copy.deepcopy(
                            pruned_module.downsample)
                if isinstance(ori_module, Bottleneck):
                    ori_module.conv1 = copy.deepcopy(pruned_module.conv1)
                    ori_module.conv2 = copy.deepcopy(pruned_module.conv2)
                    ori_module.conv3 = copy.deepcopy(pruned_module.conv3)
                    ori_module.bn1 = copy.deepcopy(pruned_module.bn1)
                    ori_module.bn2 = copy.deepcopy(pruned_module.bn2)
                    ori_module.bn3 = copy.deepcopy(pruned_module.bn3)
                    if ori_module.downsample is not None:
                        ori_module.downsample = copy.deepcopy(
                            pruned_module.downsample)
            self.ori_model.fc = copy.deepcopy(self.pruned_model.fc)

        self.segment_wise_trainer.update_model(
            self.ori_model,
            self.pruned_model,
            self._gather_aux_fc_state_dicts())