Code example #1
    def replace_layer_with_mask_conv_resnet(self):
        """
        Replace the conv layers in ResNet blocks with MaskConv2d
        """

        for module in self.pruned_model.modules():
            if isinstance(module, (PreBasicBlock, BasicBlock, Bottleneck)):
                # replace conv2
                temp_conv = MaskConv2d(in_channels=module.conv2.in_channels,
                                       out_channels=module.conv2.out_channels,
                                       kernel_size=module.conv2.kernel_size,
                                       stride=module.conv2.stride,
                                       padding=module.conv2.padding,
                                       bias=(module.conv2.bias is not None))

                temp_conv.weight.data.copy_(module.conv2.weight.data)
                if module.conv2.bias is not None:
                    temp_conv.bias.data.copy_(module.conv2.bias.data)
                module.conv2 = temp_conv

                if isinstance(module, Bottleneck):
                    # replace conv3
                    temp_conv = MaskConv2d(
                        in_channels=module.conv3.in_channels,
                        out_channels=module.conv3.out_channels,
                        kernel_size=module.conv3.kernel_size,
                        stride=module.conv3.stride,
                        padding=module.conv3.padding,
                        bias=(module.conv3.bias is not None))

                    temp_conv.weight.data.copy_(module.conv3.weight.data)
                    if module.conv3.bias is not None:
                        temp_conv.bias.data.copy_(module.conv3.bias.data)
                    module.conv3 = temp_conv

    def _set_model(self):
        """
        Create the model and replace its conv layers with MaskConv2d
        """

        if self.settings.dataset in ["cifar10", "cifar100"]:
            if self.settings.net_type == "preresnet":
                self.pruned_model = md.PreResNet(
                    depth=self.settings.depth,
                    num_classes=self.settings.n_classes)
            else:
                assert False, "use {} data while network is {}".format(
                    self.settings.dataset, self.settings.net_type)

        elif self.settings.dataset in ["imagenet", "imagenet_mio"]:
            if self.settings.net_type == "resnet":
                self.pruned_model = md.ResNet(
                    depth=self.settings.depth,
                    num_classes=self.settings.n_classes)
            else:
                assert False, "use {} data while network is {}".format(
                    self.settings.dataset, self.settings.net_type)

        else:
            assert False, "unsupported data set: {}".format(
                self.settings.dataset)

        # replace the conv layers in resnet blocks with MaskConv2d
        if self.settings.net_type in ["preresnet", "resnet"]:
            for module in self.pruned_model.modules():
                if isinstance(module, (PreBasicBlock, BasicBlock, Bottleneck)):
                    # replace conv2
                    temp_conv = MaskConv2d(
                        in_channels=module.conv2.in_channels,
                        out_channels=module.conv2.out_channels,
                        kernel_size=module.conv2.kernel_size,
                        stride=module.conv2.stride,
                        padding=module.conv2.padding,
                        bias=(module.conv2.bias is not None))

                    temp_conv.weight.data.copy_(module.conv2.weight.data)
                    if module.conv2.bias is not None:
                        temp_conv.bias.data.copy_(module.conv2.bias.data)
                    module.conv2 = temp_conv

                    if isinstance(module, Bottleneck):
                        # replace conv3
                        temp_conv = MaskConv2d(
                            in_channels=module.conv3.in_channels,
                            out_channels=module.conv3.out_channels,
                            kernel_size=module.conv3.kernel_size,
                            stride=module.conv3.stride,
                            padding=module.conv3.padding,
                            bias=(module.conv3.bias is not None))

                        temp_conv.weight.data.copy_(module.conv3.weight.data)
                        if module.conv3.bias is not None:
                            temp_conv.bias.data.copy_(module.conv3.bias.data)
                        module.conv3 = temp_conv
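
All of these snippets rely on a MaskConv2d layer that keeps a per-input-channel selection mask next to the original weights. The class itself is not listed on this page; the sketch below is an assumption-based reconstruction, with the attribute names (weight, pruned_weight, d, float_weight) inferred only from how the examples use them.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskConv2d(nn.Conv2d):
    """Minimal sketch (assumption, not the project's actual class):
    a conv layer whose forward pass masks its input channels."""

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, bias=True):
        super(MaskConv2d, self).__init__(in_channels, out_channels, kernel_size,
                                         stride=stride, padding=padding, bias=bias)
        # 0/1 selection vector over input channels (1 = kept, 0 = pruned)
        self.register_buffer("d", torch.ones(in_channels))
        # weights actually used in the masked forward pass
        self.pruned_weight = nn.Parameter(self.weight.data.clone())
        # auxiliary copy referenced by replace_layer(..., keeping=True)
        self.register_buffer("float_weight", self.weight.data.clone())

    def forward(self, x):
        # broadcast the channel mask over the (out, in, kh, kw) weight tensor
        masked_weight = self.pruned_weight * self.d.view(1, -1, 1, 1)
        return F.conv2d(x, masked_weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)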
Code example #3
File: main.py  Project: liuyixin-louis/DCP_py3
    def replace_layer_mask_conv(self):
        """
        Replace the convolutional layers with MaskConv2d layers
        """

        block_count = 0
        if self.settings.net_type in ["preresnet", "resnet"]:
            for module in self.pruned_model.modules():
                if isinstance(module, (PreBasicBlock, BasicBlock, Bottleneck)):
                    block_count += 1
                    layer = module.conv2
                    if block_count <= self.current_block_count and not isinstance(
                            layer, MaskConv2d):
                        temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                               out_channels=layer.out_channels,
                                               kernel_size=layer.kernel_size,
                                               stride=layer.stride,
                                               padding=layer.padding,
                                               bias=(layer.bias is not None))
                        temp_conv.weight.data.copy_(layer.weight.data)

                        if layer.bias is not None:
                            temp_conv.bias.data.copy_(layer.bias.data)
                        module.conv2 = temp_conv

                    if isinstance(module, Bottleneck):
                        layer = module.conv3
                        if block_count <= self.current_block_count and not isinstance(
                                layer, MaskConv2d):
                            temp_conv = MaskConv2d(
                                in_channels=layer.in_channels,
                                out_channels=layer.out_channels,
                                kernel_size=layer.kernel_size,
                                stride=layer.stride,
                                padding=layer.padding,
                                bias=(layer.bias is not None))
                            temp_conv.weight.data.copy_(layer.weight.data)

                            if layer.bias is not None:
                                temp_conv.bias.data.copy_(layer.bias.data)
                            module.conv3 = temp_conv
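
Unlike example #1, this version only converts blocks up to self.current_block_count, so the conversion can be rolled out incrementally during block-wise pruning. A hedged sketch of a driver loop follows; the pruner object, its pruning step, and everything outside this method are assumptions for illustration.

# hypothetical driver: convert and prune one residual block at a time
num_blocks = sum(
    isinstance(m, (PreBasicBlock, BasicBlock, Bottleneck))
    for m in pruner.pruned_model.modules())

for block_index in range(1, num_blocks + 1):
    pruner.current_block_count = block_index
    # converts conv2 (and conv3 for Bottleneck) in blocks 1..block_index
    pruner.replace_layer_mask_conv()
    # ... channel selection / fine-tuning for this block would run here ...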
Code example #4
    def replace_layer_with_mask_conv(self, pruned_segment, module, layer_name,
                                     block_count):
        """
        Replace the pruned layer with mask convolution
        """

        if layer_name == "conv2":
            layer = module.conv2
        elif layer_name == "conv3":
            layer = module.conv3
        elif layer_name == "conv":
            assert self.settings.net_type in ["vgg"], "only VGG is supported"
            layer = module
        else:
            assert False, "unsupport layer: {}".format(layer_name)

        if not isinstance(layer, MaskConv2d):
            temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                   out_channels=layer.out_channels,
                                   kernel_size=layer.kernel_size,
                                   stride=layer.stride,
                                   padding=layer.padding,
                                   bias=(layer.bias is not None))
            temp_conv.weight.data.copy_(layer.weight.data)

            if layer.bias is not None:
                temp_conv.bias.data.copy_(layer.bias.data)
            # pre-initialize the pruned weights to all zeros
            temp_conv.pruned_weight.data.fill_(0)
            temp_conv.d.fill_(0)

            if layer_name == "conv2":
                module.conv2 = temp_conv
            elif layer_name == "conv3":
                module.conv3 = temp_conv
            elif layer_name == "conv":
                pruned_segment = self._replace_layer(net=pruned_segment,
                                                     layer=temp_conv,
                                                     layer_index=block_count)
            layer = temp_conv
        return pruned_segment, layer
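
When layer_name is "conv" (the VGG case), the snippet above delegates to a self._replace_layer helper that is not shown on this page. The sketch below is a hedged guess at how swapping a layer by index inside an nn.Sequential segment could look; it is not the project's actual helper.

import torch.nn as nn


def _replace_layer(net, layer, layer_index):
    # hypothetical helper: swap the module at position layer_index of a Sequential
    assert isinstance(net, nn.Sequential), "expected an nn.Sequential segment"
    modules = list(net.children())
    modules[layer_index] = layer
    return nn.Sequential(*modules)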
def replace_layer(old_layer, init_weight, init_bias=None, keeping=False):
    """
    replace specific layer of model
    :params layer: original layer
    :params init_weight: thin_weight
    :params init_bias: thin_bias
    :returns new_layer
    """

    if hasattr(old_layer, "bias") and old_layer.bias is not None:
        bias_flag = True
    else:
        bias_flag = False
    if isinstance(old_layer, MaskConv2d) and keeping:
        new_layer = MaskConv2d(init_weight.size(1),
                               init_weight.size(0),
                               kernel_size=old_layer.kernel_size,
                               stride=old_layer.stride,
                               padding=old_layer.padding,
                               bias=bias_flag)

        new_layer.pruned_weight.data.copy_(init_weight)
        if init_bias is not None:
            new_layer.bias.data.copy_(init_bias)
        new_layer.d.copy_(old_layer.d)
        new_layer.float_weight.data.copy_(old_layer.d)

    elif isinstance(old_layer, (nn.Conv2d, MaskConv2d)):
        if old_layer.groups != 1:
            new_groups = init_weight.size(0)
            in_channels = init_weight.size(0)
            out_channels = init_weight.size(0)
        else:
            new_groups = 1
            in_channels = init_weight.size(1)
            out_channels = init_weight.size(0)

        new_layer = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size=old_layer.kernel_size,
                              stride=old_layer.stride,
                              padding=old_layer.padding,
                              bias=bias_flag,
                              groups=new_groups)

        new_layer.weight.data.copy_(init_weight)
        if init_bias is not None:
            new_layer.bias.data.copy_(init_bias)

    elif isinstance(old_layer, nn.BatchNorm2d):
        assert init_bias is not None, "batch normalization needs bias"
        # init_weight packs (weight, running_mean); init_bias packs (bias, running_var)
        weight = init_weight[0]
        mean_ = init_weight[1]
        bias = init_bias[0]
        var_ = init_bias[1]
        new_layer = nn.BatchNorm2d(weight.size(0))
        new_layer.weight.data.copy_(weight)
        new_layer.bias.data.copy_(bias)
        new_layer.running_mean.copy_(mean_)
        new_layer.running_var.copy_(var_)
    elif isinstance(old_layer, nn.PReLU):
        new_layer = nn.PReLU(init_weight.size(0))
        new_layer.weight.data.copy_(init_weight)

    else:
        assert False, "unsupport layer type:" + \
                      str(type(old_layer))
    return new_layer
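
Once channel selection has fixed the mask d, replace_layer can turn a MaskConv2d back into a plain nn.Conv2d that holds only the kept input channels. The usage sketch below is hedged: select_channels and block are hypothetical names introduced here for illustration, not part of this file.

import torch


def select_channels(mask_conv):
    # hypothetical helper: gather the thin weights selected by the mask d
    keep = torch.nonzero(mask_conv.d, as_tuple=False).squeeze(1)
    thin_weight = mask_conv.pruned_weight.data.index_select(1, keep)
    thin_bias = mask_conv.bias.data if mask_conv.bias is not None else None
    return thin_weight, thin_bias


# shrink conv2 of a pruned block down to its selected input channels
thin_weight, thin_bias = select_channels(block.conv2)
block.conv2 = replace_layer(block.conv2, thin_weight, thin_bias)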
    def _layer_channel_selection(self,
                                 net_origin,
                                 net_pruned,
                                 aux_fc,
                                 module,
                                 block_count,
                                 layer_name="conv2"):
        """
        conduct channel selection for module
        :param net_origin: original network segments
        :param net_pruned: pruned network segments
        :param aux_fc: auxiliary fully-connected layer
        :param module: the module need to be pruned
        :param block_count: current block no.
        :param layer_name: the name of layer need to be pruned
        """

        self.logger.info(
            "|===>layer-wise channel selection: block-{}-{}".format(
                block_count, layer_name))
        # layer-wise channel selection
        if layer_name == "conv2":
            layer = module.conv2
        elif layer_name == "conv3":
            layer = module.conv3
        else:
            assert False, "unsupport layer: {}".format(layer_name)

        if not isinstance(layer, MaskConv2d):
            temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                   out_channels=layer.out_channels,
                                   kernel_size=layer.kernel_size,
                                   stride=layer.stride,
                                   padding=layer.padding,
                                   bias=(layer.bias is not None))
            temp_conv.weight.data.copy_(layer.weight.data)

            if layer.bias is not None:
                temp_conv.bias.data.copy_(layer.bias.data)
            temp_conv.pruned_weight.data.fill_(0)
            temp_conv.d.fill_(0)

            if layer_name == "conv2":
                module.conv2 = temp_conv
            elif layer_name == "conv3":
                module.conv3 = temp_conv
            layer = temp_conv

        # define criterion
        criterion_mse = nn.MSELoss().cuda()
        criterion_softmax = nn.CrossEntropyLoss().cuda()

        # register hook
        if layer_name == "conv2":
            hook_origin = net_origin[block_count].conv2.register_forward_hook(
                self._hook_origin_feature)
            hook_pruned = module.conv2.register_forward_hook(
                self._hook_pruned_feature)
        elif layer_name == "conv3":
            hook_origin = net_origin[block_count].conv3.register_forward_hook(
                self._hook_origin_feature)
            hook_pruned = module.conv3.register_forward_hook(
                self._hook_pruned_feature)

        net_origin_parallel = utils.data_parallel(net_origin,
                                                  self.settings.n_gpus)
        net_pruned_parallel = utils.data_parallel(net_pruned,
                                                  self.settings.n_gpus)

        # avoid computing the gradient
        for params in net_origin_parallel.parameters():
            params.requires_grad = False
        for params in net_pruned_parallel.parameters():
            params.requires_grad = False

        net_origin_parallel.eval()
        net_pruned_parallel.eval()

        layer.pruned_weight.requires_grad = True
        aux_fc.cuda()
        logger_counter = 0
        record_time = utils.AverageMeter()

        for channel in range(layer.in_channels):
            if layer.d.eq(0).sum() <= math.floor(
                    layer.in_channels * self.settings.pruning_rate):
                break

            time_start = time.time()
            cum_grad = None
            record_selection_mse_loss = utils.AverageMeter()
            record_selection_softmax_loss = utils.AverageMeter()
            record_selection_loss = utils.AverageMeter()
            img_count = 0
            for i, (images, labels) in enumerate(self.train_loader):
                images = images.cuda()
                labels = labels.cuda()
                net_origin_parallel(images)
                output = net_pruned_parallel(images)
                softmax_loss = criterion_softmax(aux_fc(output), labels)

                origin_feature = self._concat_gpu_data(
                    self.feature_cache_origin)
                self.feature_cache_origin = {}
                pruned_feature = self._concat_gpu_data(
                    self.feature_cache_pruned)
                self.feature_cache_pruned = {}
                mse_loss = criterion_mse(pruned_feature, origin_feature)

                loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight
                loss.backward()
                record_selection_loss.update(loss.item(), images.size(0))
                record_selection_mse_loss.update(mse_loss.item(),
                                                 images.size(0))
                record_selection_softmax_loss.update(softmax_loss.item(),
                                                     images.size(0))

                if cum_grad is None:
                    cum_grad = layer.pruned_weight.grad.data.clone()
                else:
                    cum_grad.add_(layer.pruned_weight.grad.data)
                # reset the accumulated gradient so each batch is counted once
                layer.pruned_weight.grad = None

                img_count += images.size(0)
                if self.settings.max_samples != -1 and img_count >= self.settings.max_samples:
                    break

            # write tensorboard log
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_LossAll".format(block_count, layer_name),
                value=record_selection_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_MSELoss".format(block_count, layer_name),
                value=record_selection_mse_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_SoftmaxLoss".format(block_count,
                                                       layer_name),
                value=record_selection_softmax_loss.avg,
                step=logger_counter)
            cum_grad.abs_()
            # per input channel: sum over output channels of the kernel-wise gradient Frobenius norm
            grad_fnorm = cum_grad.mul(cum_grad).sum((2, 3)).sqrt().sum(0)

            # greedily pick a not-yet-selected channel with the largest gradient norm
            while True:
                _, max_index = torch.topk(grad_fnorm, 1)
                if layer.d[max_index[0]] == 0:
                    layer.d[max_index[0]] = 1
                    layer.pruned_weight.data[:, max_index[
                        0], :, :] = layer.weight[:,
                                                 max_index[0], :, :].data.clone(
                                                 )
                    break
                else:
                    grad_fnorm[max_index[0]] = -1

            # average meters for the fine-tuning phase
            record_finetune_softmax_loss = utils.AverageMeter()
            record_finetune_mse_loss = utils.AverageMeter()
            record_finetune_loss = utils.AverageMeter()

            record_finetune_top1_error = utils.AverageMeter()
            record_finetune_top5_error = utils.AverageMeter()

            # define optimizer
            params_list = []
            params_list.append({
                "params": layer.pruned_weight,
                "lr": self.settings.layer_wise_lr
            })
            if layer.bias is not None:
                layer.bias.requires_grad = True
                params_list.append({"params": layer.bias, "lr": 0.001})
            optimizer = torch.optim.SGD(
                params=params_list,
                weight_decay=self.settings.weight_decay,
                momentum=self.settings.momentum,
                nesterov=True)
            img_count = 0
            for epoch in range(1):
                for i, (images, labels) in enumerate(self.train_loader):
                    images = images.cuda()
                    labels = labels.cuda()
                    features = net_pruned_parallel(images)
                    net_origin_parallel(images)
                    output = aux_fc(features)
                    softmax_loss = criterion_softmax(output, labels)

                    origin_feature = self._concat_gpu_data(
                        self.feature_cache_origin)
                    self.feature_cache_origin = {}
                    pruned_feature = self._concat_gpu_data(
                        self.feature_cache_pruned)
                    self.feature_cache_pruned = {}
                    mse_loss = criterion_mse(pruned_feature, origin_feature)

                    top1_error, _, top5_error = utils.compute_singlecrop(
                        outputs=output,
                        labels=labels,
                        loss=softmax_loss,
                        top5_flag=True,
                        mean_flag=True)

                    # update parameters
                    optimizer.zero_grad()
                    loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(layer.parameters(),
                                                   max_norm=10.0)
                    layer.pruned_weight.grad.data.mul_(
                        layer.d.unsqueeze(0).unsqueeze(2).unsqueeze(
                            3).expand_as(layer.pruned_weight))
                    optimizer.step()
                    # update record info
                    record_finetune_softmax_loss.update(
                        softmax_loss.item(), images.size(0))
                    record_finetune_mse_loss.update(mse_loss.item(),
                                                    images.size(0))
                    record_finetune_loss.update(loss.item(), images.size(0))
                    record_finetune_top1_error.update(top1_error,
                                                      images.size(0))
                    record_finetune_top5_error.update(top5_error,
                                                      images.size(0))

                    img_count += images.size(0)
                    if self.settings.max_samples != -1 and img_count >= self.settings.max_samples:
                        break

            layer.pruned_weight.grad = None
            if layer.bias is not None:
                layer.bias.requires_grad = False

            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_SoftmaxLoss".format(block_count,
                                                       layer_name),
                value=record_finetune_softmax_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Loss".format(block_count, layer_name),
                value=record_finetune_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_MSELoss".format(block_count, layer_name),
                value=record_finetune_mse_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Top1Error".format(block_count, layer_name),
                value=record_finetune_top1_error.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Top5Error".format(block_count, layer_name),
                value=record_finetune_top5_error.avg,
                step=logger_counter)

            # write log information to file
            self._write_log(
                dir_name=os.path.join(self.settings.save_path, "log"),
                file_name="log_block-{:0>2d}_{}.txt".format(
                    block_count, layer_name),
                log_str=
                "{:d}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t\n".
                format(int(layer.d.sum()), record_selection_loss.avg,
                       record_selection_mse_loss.avg,
                       record_selection_softmax_loss.avg,
                       record_finetune_loss.avg, record_finetune_mse_loss.avg,
                       record_finetune_softmax_loss.avg,
                       record_finetune_top1_error.avg,
                       record_finetune_top5_error.avg))
            log_str = "Block-{:0>2d}-{}\t#channels: [{:0>4d}|{:0>4d}]\t".format(
                block_count, layer_name, int(layer.d.sum()), layer.d.size(0))
            log_str += "[selection]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
                record_selection_loss.avg, record_selection_mse_loss.avg,
                record_selection_softmax_loss.avg)
            log_str += "[fine-tuning]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
                record_finetune_loss.avg, record_finetune_mse_loss.avg,
                record_finetune_softmax_loss.avg)
            log_str += "top1error: {:4f}\ttop5error: {:4f}".format(
                record_finetune_top1_error.avg, record_finetune_top5_error.avg)
            self.logger.info(log_str)

            logger_counter += 1
            time_interval = time.time() - time_start
            record_time.update(time_interval)

        for params in net_origin_parallel.parameters():
            params.requires_grad = True
        for params in net_pruned_parallel.parameters():
            params.requires_grad = True

        # remove hook
        hook_origin.remove()
        hook_pruned.remove()
        log_str = "|===>Select channel from block-{:d}_{}: time_total:{} time_avg: {}".format(
            block_count, layer_name,
            str(datetime.timedelta(seconds=record_time.sum)),
            str(datetime.timedelta(seconds=record_time.avg)))
        self.logger.info(log_str)
        log_str = "|===>fine-tuning result: loss: {:f}, mse_loss: {:f}, softmax_loss: {:f}, top1error: {:f} top5error: {:f}".format(
            record_finetune_loss.avg, record_finetune_mse_loss.avg,
            record_finetune_softmax_loss.avg, record_finetune_top1_error.avg,
            record_finetune_top5_error.avg)
        self.logger.info(log_str)

        self.logger.info("|===>remove hook")