Example #1
0
    def segment_parallelism(self, original_segment, pruned_segment):
        """
        Parallel setting for segment
        """

        self.original_segment_parallel = utils.data_parallel(original_segment, self.settings.n_gpus)
        self.pruned_segment_parallel = utils.data_parallel(pruned_segment, self.settings.n_gpus)
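All of these snippets delegate to `utils.data_parallel`, which is not shown in this section. A minimal sketch of such a helper, assuming it does nothing more than move a module to the GPU and wrap it in `nn.DataParallel` when more than one GPU is requested (the handling of plain lists of segments is likewise an assumption):

import torch.nn as nn

def data_parallel(model, n_gpus):
    # hypothetical sketch of the helper used throughout these examples
    def _wrap(m):
        m = m.cuda()
        if n_gpus > 1 and not isinstance(m, nn.DataParallel):
            m = nn.DataParallel(m, device_ids=list(range(n_gpus)))
        return m
    # segments are sometimes kept in plain Python lists, so wrap each element
    if isinstance(model, (list, tuple)):
        return [_wrap(m) for m in model]
    return _wrap(model)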
Example #2
0
    def model_parallelism(self):
        # print('**********after create auxiliary classifier**********')
        # print(self.ori_segments)
        # print(self.pruned_segments)
        # print()
        self.ori_segments = utils.data_parallel(model=self.ori_segments,
                                                n_gpus=self.settings.n_gpus)
        self.pruned_segments = utils.data_parallel(model=self.pruned_segments,
                                                   n_gpus=self.settings.n_gpus)
        self.aux_fc = utils.data_parallel(model=self.aux_fc, n_gpus=1)
Example #3
0
    def __init__(self,
                 model,
                 train_loader,
                 val_loader,
                 settings,
                 logger,
                 tensorboard_logger,
                 optimizer_state=None,
                 run_count=0):
        self.settings = settings

        self.model = utils.data_parallel(model=model,
                                         n_gpus=self.settings.n_gpus)
        self.train_loader = train_loader
        self.val_loader = val_loader

        self.criterion = nn.CrossEntropyLoss().cuda()
        self.lr = self.settings.lr
        self.optimizer = torch.optim.SGD(
            params=self.model.parameters(),
            lr=self.settings.lr,
            momentum=self.settings.momentum,
            weight_decay=self.settings.weight_decay,
            nesterov=True)
        if optimizer_state is not None:
            self.optimizer.load_state_dict(optimizer_state)

        self.logger = logger
        self.tensorboard_logger = tensorboard_logger
        self.run_count = run_count
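The constructors above read every hyper-parameter from a `settings` object. A minimal stand-in covering the fields used in this example (the attribute names come from the code; the values are illustrative assumptions, and later examples reference further fields such as segment_wise_lr and pruning_rate):

from types import SimpleNamespace

# hypothetical stand-in for the settings object the trainers expect
settings = SimpleNamespace(
    n_gpus=1,           # number of GPUs handed to utils.data_parallel
    lr=0.1,             # base learning rate for SGD
    momentum=0.9,       # SGD momentum
    weight_decay=1e-4,  # L2 regularization
)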
Example #4
0
    def __init__(self,
                 pruned_model,
                 train_loader,
                 val_loader,
                 settings,
                 logger,
                 tensorboard_logger,
                 run_count=0):
        self.pruned_model = utils.data_parallel(pruned_model, settings.n_gpus)
        self.settings = settings
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.logger = logger
        self.tensorboard_logger = tensorboard_logger
        self.criterion = nn.CrossEntropyLoss().cuda()

        self.optimizer = torch.optim.SGD(
            params=self.pruned_model.parameters(),
            lr=self.settings.lr,
            momentum=self.settings.momentum,
            weight_decay=self.settings.weight_decay,
            nesterov=True)

        self.run_count = run_count
        self.lr = self.settings.lr
        self.scalar_info = {}
    def update_model(self, pruned_model, optimizer_state=None):
        """
        Update the pruned model and rebuild its optimizer
        :param pruned_model: pruned model
        :param optimizer_state: optional optimizer state dict to restore
        """

        self.optimizer = None
        self.pruned_model = utils.data_parallel(pruned_model, self.settings.n_gpus)
        self.optimizer = torch.optim.SGD(
            params=self.pruned_model.parameters(),
            lr=self.settings.network_wise_lr,
            momentum=self.settings.momentum,
            weight_decay=self.settings.weight_decay,
            nesterov=True)
        if optimizer_state:
            self.optimizer.load_state_dict(optimizer_state)
Example #6
0
    def model_parallelism(self):
        self.segments = utils.data_parallel(model=self.segments, n_gpus=self.settings.n_gpus)
        self.aux_fc = utils.data_parallel(model=self.aux_fc, n_gpus=1)
    def _network_split(self):
        """"
            1. split the network into several segments with pre-define pivot set
            2. create auxiliary classifiers
            3. create optimizers for network segments and fcs
        """

        net_origin = None
        net_pruned = None

        if self.settings.net_type in ["preresnet", "resnet"]:
            if self.settings.net_type == "preresnet":
                net_origin = nn.Sequential(self.ori_model.conv)
                net_pruned = nn.Sequential(self.pruned_model.conv)
            elif self.settings.net_type == "resnet":
                net_head = nn.Sequential(
                    self.ori_model.conv1,
                    self.ori_model.bn1,
                    self.ori_model.relu,
                    self.ori_model.maxpool)
                net_origin = nn.Sequential(net_head)
                net_head = nn.Sequential(
                    self.pruned_model.conv1,
                    self.pruned_model.bn1,
                    self.pruned_model.relu,
                    self.pruned_model.maxpool)
                net_pruned = nn.Sequential(net_head)
            self.logger.info("init shallow head done!")

        else:
            assert False, "unsupported net_type: {}".format(self.settings.net_type)

        block_count = 0
        if self.settings.net_type in ["resnet", "preresnet"]:
            for ori_module, pruned_module in zip(self.ori_model.modules(), self.pruned_model.modules()):
                if isinstance(ori_module, (PreBasicBlock, Bottleneck, BasicBlock)):
                    self.logger.info("enter block: {}".format(type(ori_module)))
                    if net_origin is not None:
                        net_origin.add_module(str(len(net_origin)), ori_module)
                    else:
                        net_origin = nn.Sequential(ori_module)

                    if net_pruned is not None:
                        net_pruned.add_module(str(len(net_pruned)), pruned_module)
                    else:
                        net_pruned = nn.Sequential(pruned_module)
                    block_count += 1

                    # when block_count reaches a pivot, close the current segment and start a new one
                    if block_count in self.settings.pivot_set:
                        self.ori_segments.append(net_origin)
                        self.pruned_segments.append(net_pruned)
                        net_origin = None
                        net_pruned = None

        self.final_block_count = block_count
        self.ori_segments.append(net_origin)
        self.pruned_segments.append(net_pruned)

        # create auxiliary classifier
        num_classes = self.settings.n_classes
        in_channels = 0
        for i in range(len(self.pruned_segments) - 1):
            if isinstance(self.pruned_segments[i][-1], (PreBasicBlock, BasicBlock)):
                in_channels = self.pruned_segments[i][-1].conv2.out_channels
            elif isinstance(self.pruned_segments[i][-1], Bottleneck):
                in_channels = self.pruned_segments[i][-1].conv3.out_channels
            assert in_channels != 0, "in_channels is zero"

            self.aux_fc.append(AuxClassifier(in_channels=in_channels, num_classes=num_classes))

        pruned_final_fc = None
        if self.settings.net_type == "preresnet":
            pruned_final_fc = nn.Sequential(*[
                self.pruned_model.bn,
                self.pruned_model.relu,
                self.pruned_model.avg_pool,
                View(),
                self.pruned_model.fc])
        elif self.settings.net_type == "resnet":
            pruned_final_fc = nn.Sequential(*[
                self.pruned_model.avgpool,
                View(),
                self.pruned_model.fc])
        self.aux_fc.append(pruned_final_fc)

        # model parallel
        self.ori_segments = utils.data_parallel(model=self.ori_segments, n_gpus=self.settings.n_gpus)
        self.pruned_segments = utils.data_parallel(model=self.pruned_segments, n_gpus=self.settings.n_gpus)
        self.aux_fc = utils.data_parallel(model=self.aux_fc, n_gpus=1)

        # create optimizers
        for i in range(len(self.pruned_segments)):
            temp_optim = []
            # the i-th optimizer covers the parameters of segments [0..i]
            for j in range(i + 1):
                temp_optim.append({'params': self.pruned_segments[j].parameters(),
                                   'lr': self.settings.segment_wise_lr})

            # optimizer for segments and fc
            temp_seg_optim = torch.optim.SGD(
                temp_optim,
                momentum=self.settings.momentum,
                weight_decay=self.settings.weight_decay,
                nesterov=True)

            temp_fc_optim = torch.optim.SGD(
                params=self.aux_fc[i].parameters(),
                lr=self.settings.segment_wise_lr,
                momentum=self.settings.momentum,
                weight_decay=self.settings.weight_decay,
                nesterov=True)

            self.seg_optimizer.append(temp_seg_optim)
            self.fc_optimizer.append(temp_fc_optim)
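    # Note: the split above is driven entirely by settings.pivot_set -- the running
    # block count closes a segment whenever it hits a pivot, and the blocks left over
    # after the last pivot form the final segment. A hypothetical toy illustration
    # (the pivot values and block count below are assumptions, not the original settings):
    #
    #     pivot_set = [9, 18]
    #     n_blocks = 27
    #     segments, current = [], []
    #     for block_count in range(1, n_blocks + 1):
    #         current.append(block_count)
    #         if block_count in pivot_set:       # close the segment at each pivot
    #             segments.append(current)
    #             current = []
    #     segments.append(current)               # trailing blocks form the last segment
    #     [len(seg) for seg in segments]         # -> [9, 9, 9]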
    def _layer_channel_selection(self,
                                 net_origin,
                                 net_pruned,
                                 aux_fc,
                                 module,
                                 block_count,
                                 layer_name="conv2"):
        """
        Conduct channel selection for the given module
        :param net_origin: original network segments
        :param net_pruned: pruned network segments
        :param aux_fc: auxiliary fully-connected layer
        :param module: the module to be pruned
        :param block_count: index of the current block
        :param layer_name: name of the layer to be pruned
        """

        self.logger.info(
            "|===>layer-wise channel selection: block-{}-{}".format(
                block_count, layer_name))
        # layer-wise channel selection
        if layer_name == "conv2":
            layer = module.conv2
        elif layer_name == "conv3":
            layer = module.conv3
        else:
            assert False, "unsupport layer: {}".format(layer_name)

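        # swap the plain convolution for a MaskConv2d copy so individual input
        # channels can be enabled through its mask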
        if not isinstance(layer, MaskConv2d):
            temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                   out_channels=layer.out_channels,
                                   kernel_size=layer.kernel_size,
                                   stride=layer.stride,
                                   padding=layer.padding,
                                   bias=(layer.bias is not None))
            temp_conv.weight.data.copy_(layer.weight.data)

            if layer.bias is not None:
                temp_conv.bias.data.copy_(layer.bias.data)
            temp_conv.pruned_weight.data.fill_(0)
            temp_conv.d.fill_(0)

            if layer_name == "conv2":
                module.conv2 = temp_conv
            elif layer_name == "conv3":
                module.conv3 = temp_conv
            layer = temp_conv

        # define criterion
        criterion_mse = nn.MSELoss().cuda()
        criterion_softmax = nn.CrossEntropyLoss().cuda()

        # register hook
        if layer_name == "conv2":
            hook_origin = net_origin[block_count].conv2.register_forward_hook(
                self._hook_origin_feature)
            hook_pruned = module.conv2.register_forward_hook(
                self._hook_pruned_feature)
        elif layer_name == "conv3":
            hook_origin = net_origin[block_count].conv3.register_forward_hook(
                self._hook_origin_feature)
            hook_pruned = module.conv3.register_forward_hook(
                self._hook_pruned_feature)

        net_origin_parallel = utils.data_parallel(net_origin,
                                                  self.settings.n_gpus)
        net_pruned_parallel = utils.data_parallel(net_pruned,
                                                  self.settings.n_gpus)

        # avoid computing the gradient
        for params in net_origin_parallel.parameters():
            params.requires_grad = False
        for params in net_pruned_parallel.parameters():
            params.requires_grad = False

        net_origin_parallel.eval()
        net_pruned_parallel.eval()

        layer.pruned_weight.requires_grad = True
        aux_fc.cuda()
        logger_counter = 0
        record_time = utils.AverageMeter()

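        # greedy selection: each outer pass adds one more input channel, stopping
        # once the number of still-unselected channels reaches the pruning budget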
        for channel in range(layer.in_channels):
            if layer.d.eq(0).sum() <= math.floor(
                    layer.in_channels * self.settings.pruning_rate):
                break

            time_start = time.time()
            cum_grad = None
            record_selection_mse_loss = utils.AverageMeter()
            record_selection_softmax_loss = utils.AverageMeter()
            record_selection_loss = utils.AverageMeter()
            img_count = 0
            for i, (images, labels) in enumerate(self.train_loader):
                images = images.cuda()
                labels = labels.cuda()
                net_origin_parallel(images)
                output = net_pruned_parallel(images)
                softmax_loss = criterion_softmax(aux_fc(output), labels)

                origin_feature = self._concat_gpu_data(
                    self.feature_cache_origin)
                self.feature_cache_origin = {}
                pruned_feature = self._concat_gpu_data(
                    self.feature_cache_pruned)
                self.feature_cache_pruned = {}
                mse_loss = criterion_mse(pruned_feature, origin_feature)

                loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight
                loss.backward()
                record_selection_loss.update(loss.item(), images.size(0))
                record_selection_mse_loss.update(mse_loss.item(),
                                                 images.size(0))
                record_selection_softmax_loss.update(softmax_loss.item(),
                                                     images.size(0))

                if cum_grad is None:
                    cum_grad = layer.pruned_weight.grad.data.clone()
                else:
                    cum_grad.add_(layer.pruned_weight.grad.data)
                # reset the gradient so each batch contributes exactly once
                layer.pruned_weight.grad = None

                img_count += images.size(0)
                if self.settings.max_samples != -1 and img_count >= self.settings.max_samples:
                    break

            # write tensorboard log
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_LossAll".format(block_count, layer_name),
                value=record_selection_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_MSELoss".format(block_count, layer_name),
                value=record_selection_mse_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_SoftmaxLoss".format(block_count,
                                                       layer_name),
                value=record_selection_softmax_loss.avg,
                step=logger_counter)
            cum_grad.abs_()
            # per-input-channel score: kernel-wise gradient F-norms summed over output channels
            grad_fnorm = cum_grad.mul(cum_grad).sum((2, 3)).sqrt().sum(0)

            # greedily pick the not-yet-selected channel with the largest gradient norm
            while True:
                _, max_index = torch.topk(grad_fnorm, 1)
                if layer.d[max_index[0]] == 0:
                    layer.d[max_index[0]] = 1
                    layer.pruned_weight.data[:, max_index[0], :, :] = \
                        layer.weight.data[:, max_index[0], :, :].clone()
                    break
                else:
                    grad_fnorm[max_index[0]] = -1

            # average meters for the fine-tuning phase
            record_finetune_softmax_loss = utils.AverageMeter()
            record_finetune_mse_loss = utils.AverageMeter()
            record_finetune_loss = utils.AverageMeter()

            record_finetune_top1_error = utils.AverageMeter()
            record_finetune_top5_error = utils.AverageMeter()

            # define optimizer
            params_list = []
            params_list.append({
                "params": layer.pruned_weight,
                "lr": self.settings.layer_wise_lr
            })
            if layer.bias is not None:
                layer.bias.requires_grad = True
                params_list.append({"params": layer.bias, "lr": 0.001})
            optimizer = torch.optim.SGD(
                params=params_list,
                weight_decay=self.settings.weight_decay,
                momentum=self.settings.momentum,
                nesterov=True)
            img_count = 0
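            # briefly fine-tune the currently selected channels (and bias) before scoring the next one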
            for epoch in range(1):
                for i, (images, labels) in enumerate(self.train_loader):
                    images = images.cuda()
                    labels = labels.cuda()
                    features = net_pruned_parallel(images)
                    net_origin_parallel(images)
                    output = aux_fc(features)
                    softmax_loss = criterion_softmax(output, labels)

                    origin_feature = self._concat_gpu_data(
                        self.feature_cache_origin)
                    self.feature_cache_origin = {}
                    pruned_feature = self._concat_gpu_data(
                        self.feature_cache_pruned)
                    self.feature_cache_pruned = {}
                    mse_loss = criterion_mse(pruned_feature, origin_feature)

                    top1_error, _, top5_error = utils.compute_singlecrop(
                        outputs=output,
                        labels=labels,
                        loss=softmax_loss,
                        top5_flag=True,
                        mean_flag=True)

                    # update parameters
                    optimizer.zero_grad()
                    loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(layer.parameters(),
                                                   max_norm=10.0)
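                    # mask out gradients of input channels that have not been selected yet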
                    layer.pruned_weight.grad.data.mul_(
                        layer.d.unsqueeze(0).unsqueeze(2).unsqueeze(
                            3).expand_as(layer.pruned_weight))
                    optimizer.step()
                    # update record info
                    record_finetune_softmax_loss.update(
                        softmax_loss.item(), images.size(0))
                    record_finetune_mse_loss.update(mse_loss.item(),
                                                    images.size(0))
                    record_finetune_loss.update(loss.item(), images.size(0))
                    record_finetune_top1_error.update(top1_error,
                                                      images.size(0))
                    record_finetune_top5_error.update(top5_error,
                                                      images.size(0))

                    img_count += images.size(0)
                    if self.settings.max_samples != -1 and img_count >= self.settings.max_samples:
                        break

            layer.pruned_weight.grad = None
            if layer.bias is not None:
                layer.bias.requires_grad = False

            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_SoftmaxLoss".format(block_count,
                                                       layer_name),
                value=record_finetune_softmax_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Loss".format(block_count, layer_name),
                value=record_finetune_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_MSELoss".format(block_count, layer_name),
                value=record_finetune_mse_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Top1Error".format(block_count, layer_name),
                value=record_finetune_top1_error.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Top5Error".format(block_count, layer_name),
                value=record_finetune_top5_error.avg,
                step=logger_counter)

            # write log information to file
            self._write_log(
                dir_name=os.path.join(self.settings.save_path, "log"),
                file_name="log_block-{:0>2d}_{}.txt".format(
                    block_count, layer_name),
                log_str="{:d}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t\n".format(
                    int(layer.d.sum()), record_selection_loss.avg,
                    record_selection_mse_loss.avg,
                    record_selection_softmax_loss.avg,
                    record_finetune_loss.avg, record_finetune_mse_loss.avg,
                    record_finetune_softmax_loss.avg,
                    record_finetune_top1_error.avg,
                    record_finetune_top5_error.avg))
            log_str = "Block-{:0>2d}-{}\t#channels: [{:0>4d}|{:0>4d}]\t".format(
                block_count, layer_name, int(layer.d.sum()), layer.d.size(0))
            log_str += "[selection]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
                record_selection_loss.avg, record_selection_mse_loss.avg,
                record_selection_softmax_loss.avg)
            log_str += "[fine-tuning]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
                record_finetune_loss.avg, record_finetune_mse_loss.avg,
                record_finetune_softmax_loss.avg)
            log_str += "top1error: {:4f}\ttop5error: {:4f}".format(
                record_finetune_top1_error.avg, record_finetune_top5_error.avg)
            self.logger.info(log_str)

            logger_counter += 1
            time_interval = time.time() - time_start
            record_time.update(time_interval)

        for params in net_origin_parallel.parameters():
            params.requires_grad = True
        for params in net_pruned_parallel.parameters():
            params.requires_grad = True

        # remove hook
        hook_origin.remove()
        hook_pruned.remove()
        log_str = "|===>Select channel from block-{:d}_{}: time_total:{} time_avg: {}".format(
            block_count, layer_name,
            str(datetime.timedelta(seconds=record_time.sum)),
            str(datetime.timedelta(seconds=record_time.avg)))
        self.logger.info(log_str)
        log_str = "|===>fine-tuning result: loss: {:f}, mse_loss: {:f}, softmax_loss: {:f}, top1error: {:f} top5error: {:f}".format(
            record_finetune_loss.avg, record_finetune_mse_loss.avg,
            record_finetune_softmax_loss.avg, record_finetune_top1_error.avg,
            record_finetune_top5_error.avg)
        self.logger.info(log_str)

        self.logger.info("|===>remove hook")