def replace_layer_with_mask_conv_resnet(self):
    """
    Replace the conv layers in ResNet blocks with MaskConv2d
    """
    for module in self.pruned_model.modules():
        if isinstance(module, (PreBasicBlock, BasicBlock, Bottleneck)):
            # replace conv2
            temp_conv = MaskConv2d(in_channels=module.conv2.in_channels,
                                   out_channels=module.conv2.out_channels,
                                   kernel_size=module.conv2.kernel_size,
                                   stride=module.conv2.stride,
                                   padding=module.conv2.padding,
                                   bias=(module.conv2.bias is not None))
            temp_conv.weight.data.copy_(module.conv2.weight.data)
            if module.conv2.bias is not None:
                temp_conv.bias.data.copy_(module.conv2.bias.data)
            module.conv2 = temp_conv

            if isinstance(module, Bottleneck):
                # replace conv3
                temp_conv = MaskConv2d(in_channels=module.conv3.in_channels,
                                       out_channels=module.conv3.out_channels,
                                       kernel_size=module.conv3.kernel_size,
                                       stride=module.conv3.stride,
                                       padding=module.conv3.padding,
                                       bias=(module.conv3.bias is not None))
                temp_conv.weight.data.copy_(module.conv3.weight.data)
                if module.conv3.bias is not None:
                    temp_conv.bias.data.copy_(module.conv3.bias.data)
                module.conv3 = temp_conv
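# The replacements above assume a MaskConv2d layer that behaves like nn.Conv2d but also
# carries a per-input-channel selection mask `d`, a trainable copy `pruned_weight`, and a
# `float_weight` buffer. The class below is only a minimal interface sketch of those
# assumptions for illustration; it is not the project's actual MaskConv2d implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskConv2dSketch(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, bias=True):
        super(MaskConv2dSketch, self).__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, bias=bias)
        # trainable copy of the weight used during channel selection
        self.pruned_weight = nn.Parameter(self.weight.data.clone())
        # d[c] == 1 means input channel c has been selected (kept)
        self.register_buffer("d", torch.zeros(in_channels))
        self.register_buffer("float_weight", torch.zeros(in_channels))

    def forward(self, x):
        # convolve with the pruned (masked) weight instead of self.weight
        return F.conv2d(x, self.pruned_weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)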
def _set_model(self):
    """
    Get the model to prune and replace its conv layers with MaskConv2d
    """
    if self.settings.dataset in ["cifar10", "cifar100"]:
        if self.settings.net_type == "preresnet":
            self.pruned_model = md.PreResNet(
                depth=self.settings.depth,
                num_classes=self.settings.n_classes)
        else:
            assert False, "use {} dataset while network is {}".format(
                self.settings.dataset, self.settings.net_type)

    elif self.settings.dataset in ["imagenet", "imagenet_mio"]:
        if self.settings.net_type == "resnet":
            self.pruned_model = md.ResNet(
                depth=self.settings.depth,
                num_classes=self.settings.n_classes)
        else:
            assert False, "use {} dataset while network is {}".format(
                self.settings.dataset, self.settings.net_type)

    else:
        assert False, "unsupported dataset: {}".format(self.settings.dataset)

    # replace the conv layers in resnet blocks with MaskConv2d
    if self.settings.net_type in ["preresnet", "resnet"]:
        for module in self.pruned_model.modules():
            if isinstance(module, (PreBasicBlock, BasicBlock, Bottleneck)):
                # replace conv2
                temp_conv = MaskConv2d(
                    in_channels=module.conv2.in_channels,
                    out_channels=module.conv2.out_channels,
                    kernel_size=module.conv2.kernel_size,
                    stride=module.conv2.stride,
                    padding=module.conv2.padding,
                    bias=(module.conv2.bias is not None))
                temp_conv.weight.data.copy_(module.conv2.weight.data)
                if module.conv2.bias is not None:
                    temp_conv.bias.data.copy_(module.conv2.bias.data)
                module.conv2 = temp_conv

                if isinstance(module, Bottleneck):
                    # replace conv3
                    temp_conv = MaskConv2d(
                        in_channels=module.conv3.in_channels,
                        out_channels=module.conv3.out_channels,
                        kernel_size=module.conv3.kernel_size,
                        stride=module.conv3.stride,
                        padding=module.conv3.padding,
                        bias=(module.conv3.bias is not None))
                    temp_conv.weight.data.copy_(module.conv3.weight.data)
                    if module.conv3.bias is not None:
                        temp_conv.bias.data.copy_(module.conv3.bias.data)
                    module.conv3 = temp_conv
def replace_layer_mask_conv(self):
    """
    Replace the convolutional layer with mask convolutional layer
    """
    block_count = 0
    if self.settings.net_type in ["preresnet", "resnet"]:
        for module in self.pruned_model.modules():
            if isinstance(module, (PreBasicBlock, BasicBlock, Bottleneck)):
                block_count += 1

                layer = module.conv2
                if block_count <= self.current_block_count and not isinstance(layer, MaskConv2d):
                    temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                           out_channels=layer.out_channels,
                                           kernel_size=layer.kernel_size,
                                           stride=layer.stride,
                                           padding=layer.padding,
                                           bias=(layer.bias is not None))
                    temp_conv.weight.data.copy_(layer.weight.data)
                    if layer.bias is not None:
                        temp_conv.bias.data.copy_(layer.bias.data)
                    module.conv2 = temp_conv

                if isinstance(module, Bottleneck):
                    layer = module.conv3
                    if block_count <= self.current_block_count and not isinstance(layer, MaskConv2d):
                        temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                               out_channels=layer.out_channels,
                                               kernel_size=layer.kernel_size,
                                               stride=layer.stride,
                                               padding=layer.padding,
                                               bias=(layer.bias is not None))
                        temp_conv.weight.data.copy_(layer.weight.data)
                        if layer.bias is not None:
                            temp_conv.bias.data.copy_(layer.bias.data)
                        module.conv3 = temp_conv
def replace_layer_with_mask_conv(self, pruned_segment, module, layer_name, block_count):
    """
    Replace the pruned layer with mask convolution
    """
    if layer_name == "conv2":
        layer = module.conv2
    elif layer_name == "conv3":
        layer = module.conv3
    elif layer_name == "conv":
        assert self.settings.net_type in ["vgg"], "only support vgg"
        layer = module
    else:
        assert False, "unsupported layer: {}".format(layer_name)

    if not isinstance(layer, MaskConv2d):
        temp_conv = MaskConv2d(in_channels=layer.in_channels,
                               out_channels=layer.out_channels,
                               kernel_size=layer.kernel_size,
                               stride=layer.stride,
                               padding=layer.padding,
                               bias=(layer.bias is not None))
        temp_conv.weight.data.copy_(layer.weight.data)
        if layer.bias is not None:
            temp_conv.bias.data.copy_(layer.bias.data)
        # pre-set the pruned weight and the channel mask to all zeros
        temp_conv.pruned_weight.data.fill_(0)
        temp_conv.d.fill_(0)

        if layer_name == "conv2":
            module.conv2 = temp_conv
        elif layer_name == "conv3":
            module.conv3 = temp_conv
        elif layer_name == "conv":
            pruned_segment = self._replace_layer(net=pruned_segment,
                                                 layer=temp_conv,
                                                 layer_index=block_count)
        layer = temp_conv
    return pruned_segment, layer
def replace_layer(old_layer, init_weight, init_bias=None, keeping=False):
    """
    Replace a specific layer of the model
    :param old_layer: original layer
    :param init_weight: thin_weight
    :param init_bias: thin_bias
    :return: new_layer
    """
    if hasattr(old_layer, "bias") and old_layer.bias is not None:
        bias_flag = True
    else:
        bias_flag = False

    if isinstance(old_layer, MaskConv2d) and keeping:
        new_layer = MaskConv2d(init_weight.size(1),
                               init_weight.size(0),
                               kernel_size=old_layer.kernel_size,
                               stride=old_layer.stride,
                               padding=old_layer.padding,
                               bias=bias_flag)
        new_layer.pruned_weight.data.copy_(init_weight)
        if init_bias is not None:
            new_layer.bias.data.copy_(init_bias)
        new_layer.d.copy_(old_layer.d)
        new_layer.float_weight.data.copy_(old_layer.d)

    elif isinstance(old_layer, (nn.Conv2d, MaskConv2d)):
        if old_layer.groups != 1:
            # depth-wise convolution: keep groups equal to the channel count
            new_groups = init_weight.size(0)
            in_channels = init_weight.size(0)
            out_channels = init_weight.size(0)
        else:
            new_groups = 1
            in_channels = init_weight.size(1)
            out_channels = init_weight.size(0)
        new_layer = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size=old_layer.kernel_size,
                              stride=old_layer.stride,
                              padding=old_layer.padding,
                              bias=bias_flag,
                              groups=new_groups)
        new_layer.weight.data.copy_(init_weight)
        if init_bias is not None:
            new_layer.bias.data.copy_(init_bias)

    elif isinstance(old_layer, nn.BatchNorm2d):
        assert init_bias is not None, "batch normalization needs bias"
        weight = init_weight[0]
        mean_ = init_weight[1]
        bias = init_bias[0]
        var_ = init_bias[1]
        new_layer = nn.BatchNorm2d(weight.size(0))
        new_layer.weight.data.copy_(weight)
        new_layer.bias.data.copy_(bias)
        new_layer.running_mean.copy_(mean_)
        new_layer.running_var.copy_(var_)

    elif isinstance(old_layer, nn.PReLU):
        new_layer = nn.PReLU(init_weight.size(0))
        new_layer.weight.data.copy_(init_weight)

    else:
        assert False, "unsupported layer type: " + str(type(old_layer))

    return new_layer
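# A minimal usage sketch of replace_layer (assumption: the 16->8 channel sizes and the
# indices in `kept_idx` are illustrative only; in the pruning pipeline the kept channels
# come from the learned mask `d`). It builds a thinner nn.Conv2d keeping 8 of 16 input channels.
import torch
import torch.nn as nn

old_conv = nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=True)
kept_idx = torch.arange(8)
thin_weight = old_conv.weight.data[:, kept_idx, :, :].clone()   # shape: (32, 8, 3, 3)
thin_bias = old_conv.bias.data.clone()                          # output channels unchanged
new_conv = replace_layer(old_conv, thin_weight, thin_bias)
assert isinstance(new_conv, nn.Conv2d)
assert new_conv.in_channels == 8 and new_conv.out_channels == 32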
def _layer_channel_selection(self, net_origin, net_pruned, aux_fc, module,
                             block_count, layer_name="conv2"):
    """
    Conduct channel selection for the given module
    :param net_origin: original network segments
    :param net_pruned: pruned network segments
    :param aux_fc: auxiliary fully-connected layer
    :param module: the module to be pruned
    :param block_count: current block number
    :param layer_name: the name of the layer to be pruned
    """
    self.logger.info("|===>layer-wise channel selection: block-{}-{}".format(
        block_count, layer_name))

    # replace the layer to be pruned with MaskConv2d
    if layer_name == "conv2":
        layer = module.conv2
    elif layer_name == "conv3":
        layer = module.conv3
    else:
        assert False, "unsupported layer: {}".format(layer_name)

    if not isinstance(layer, MaskConv2d):
        temp_conv = MaskConv2d(in_channels=layer.in_channels,
                               out_channels=layer.out_channels,
                               kernel_size=layer.kernel_size,
                               stride=layer.stride,
                               padding=layer.padding,
                               bias=(layer.bias is not None))
        temp_conv.weight.data.copy_(layer.weight.data)
        if layer.bias is not None:
            temp_conv.bias.data.copy_(layer.bias.data)
        temp_conv.pruned_weight.data.fill_(0)
        temp_conv.d.fill_(0)
        if layer_name == "conv2":
            module.conv2 = temp_conv
        elif layer_name == "conv3":
            module.conv3 = temp_conv
        layer = temp_conv

    # define criterion
    criterion_mse = nn.MSELoss().cuda()
    criterion_softmax = nn.CrossEntropyLoss().cuda()

    # register hooks to cache the features of the original and pruned layers
    if layer_name == "conv2":
        hook_origin = net_origin[block_count].conv2.register_forward_hook(
            self._hook_origin_feature)
        hook_pruned = module.conv2.register_forward_hook(
            self._hook_pruned_feature)
    elif layer_name == "conv3":
        hook_origin = net_origin[block_count].conv3.register_forward_hook(
            self._hook_origin_feature)
        hook_pruned = module.conv3.register_forward_hook(
            self._hook_pruned_feature)

    net_origin_parallel = utils.data_parallel(net_origin, self.settings.n_gpus)
    net_pruned_parallel = utils.data_parallel(net_pruned, self.settings.n_gpus)

    # avoid computing gradients for the frozen parameters
    for params in net_origin_parallel.parameters():
        params.requires_grad = False
    for params in net_pruned_parallel.parameters():
        params.requires_grad = False

    net_origin_parallel.eval()
    net_pruned_parallel.eval()

    layer.pruned_weight.requires_grad = True
    aux_fc.cuda()

    logger_counter = 0
    record_time = utils.AverageMeter()

    for channel in range(layer.in_channels):
        # stop once enough channels have been selected for the target pruning rate
        if layer.d.eq(0).sum() <= math.floor(
                layer.in_channels * self.settings.pruning_rate):
            break

        time_start = time.time()
        cum_grad = None
        record_selection_mse_loss = utils.AverageMeter()
        record_selection_softmax_loss = utils.AverageMeter()
        record_selection_loss = utils.AverageMeter()
        img_count = 0

        # --- selection: accumulate gradients of the joint loss w.r.t. pruned_weight ---
        for i, (images, labels) in enumerate(self.train_loader):
            images = images.cuda()
            labels = labels.cuda()

            net_origin_parallel(images)
            output = net_pruned_parallel(images)
            softmax_loss = criterion_softmax(aux_fc(output), labels)

            origin_feature = self._concat_gpu_data(self.feature_cache_origin)
            self.feature_cache_origin = {}
            pruned_feature = self._concat_gpu_data(self.feature_cache_pruned)
            self.feature_cache_pruned = {}

            mse_loss = criterion_mse(pruned_feature, origin_feature)
            loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight
            loss.backward()

            record_selection_loss.update(loss.item(), images.size(0))
            record_selection_mse_loss.update(mse_loss.item(), images.size(0))
            record_selection_softmax_loss.update(softmax_loss.item(), images.size(0))

            if cum_grad is None:
                cum_grad = layer.pruned_weight.grad.data.clone()
            else:
                cum_grad.add_(layer.pruned_weight.grad.data)
            layer.pruned_weight.grad = None

            img_count += images.size(0)
            if self.settings.max_samples != -1 and img_count >= self.settings.max_samples:
                break

        # write tensorboard log
        self.tensorboard_logger.scalar_summary(
            tag="S-block-{}_{}_LossAll".format(block_count, layer_name),
            value=record_selection_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="S-block-{}_{}_MSELoss".format(block_count, layer_name),
            value=record_selection_mse_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="S-block-{}_{}_SoftmaxLoss".format(block_count, layer_name),
            value=record_selection_softmax_loss.avg,
            step=logger_counter)

        cum_grad.abs_()
        # calculate the Frobenius norm of the accumulated gradient per input channel
        grad_fnorm = cum_grad.mul(cum_grad).sum((2, 3)).sqrt().sum(0)

        # select the not-yet-selected channel with the largest gradient norm
        while True:
            _, max_index = torch.topk(grad_fnorm, 1)
            if layer.d[max_index[0]] == 0:
                layer.d[max_index[0]] = 1
                layer.pruned_weight.data[:, max_index[0], :, :] = \
                    layer.weight[:, max_index[0], :, :].data.clone()
                break
            else:
                grad_fnorm[max_index[0]] = -1

        # fine-tuning average meters
        record_finetune_softmax_loss = utils.AverageMeter()
        record_finetune_mse_loss = utils.AverageMeter()
        record_finetune_loss = utils.AverageMeter()
        record_finetune_top1_error = utils.AverageMeter()
        record_finetune_top5_error = utils.AverageMeter()

        # define optimizer
        params_list = []
        params_list.append({"params": layer.pruned_weight,
                            "lr": self.settings.layer_wise_lr})
        if layer.bias is not None:
            layer.bias.requires_grad = True
            params_list.append({"params": layer.bias, "lr": 0.001})
        optimizer = torch.optim.SGD(params=params_list,
                                    weight_decay=self.settings.weight_decay,
                                    momentum=self.settings.momentum,
                                    nesterov=True)

        # --- fine-tuning: update the weights of the selected channels ---
        img_count = 0
        for epoch in range(1):
            for i, (images, labels) in enumerate(self.train_loader):
                images = images.cuda()
                labels = labels.cuda()

                features = net_pruned_parallel(images)
                net_origin_parallel(images)
                output = aux_fc(features)
                softmax_loss = criterion_softmax(output, labels)

                origin_feature = self._concat_gpu_data(self.feature_cache_origin)
                self.feature_cache_origin = {}
                pruned_feature = self._concat_gpu_data(self.feature_cache_pruned)
                self.feature_cache_pruned = {}

                mse_loss = criterion_mse(pruned_feature, origin_feature)

                top1_error, _, top5_error = utils.compute_singlecrop(
                    outputs=output,
                    labels=labels,
                    loss=softmax_loss,
                    top5_flag=True,
                    mean_flag=True)

                # update parameters
                optimizer.zero_grad()
                loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight
                loss.backward()
                torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=10.0)
                # zero out the gradients of the unselected channels
                layer.pruned_weight.grad.data.mul_(
                    layer.d.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(
                        layer.pruned_weight))
                optimizer.step()

                # update record info
                record_finetune_softmax_loss.update(softmax_loss.item(), images.size(0))
                record_finetune_mse_loss.update(mse_loss.item(), images.size(0))
                record_finetune_loss.update(loss.item(), images.size(0))
                record_finetune_top1_error.update(top1_error, images.size(0))
                record_finetune_top5_error.update(top5_error, images.size(0))

                img_count += images.size(0)
                if self.settings.max_samples != -1 and img_count >= self.settings.max_samples:
                    break

        layer.pruned_weight.grad = None
        if layer.bias is not None:
            layer.bias.requires_grad = False

        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_SoftmaxLoss".format(block_count, layer_name),
            value=record_finetune_softmax_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_Loss".format(block_count, layer_name),
            value=record_finetune_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_MSELoss".format(block_count, layer_name),
            value=record_finetune_mse_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_Top1Error".format(block_count, layer_name),
            value=record_finetune_top1_error.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_Top5Error".format(block_count, layer_name),
            value=record_finetune_top5_error.avg,
            step=logger_counter)

        # write log information to file
        self._write_log(
            dir_name=os.path.join(self.settings.save_path, "log"),
            file_name="log_block-{:0>2d}_{}.txt".format(block_count, layer_name),
            log_str="{:d}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t\n".format(
                int(layer.d.sum()),
                record_selection_loss.avg,
                record_selection_mse_loss.avg,
                record_selection_softmax_loss.avg,
                record_finetune_loss.avg,
                record_finetune_mse_loss.avg,
                record_finetune_softmax_loss.avg,
                record_finetune_top1_error.avg,
                record_finetune_top5_error.avg))

        log_str = "Block-{:0>2d}-{}\t#channels: [{:0>4d}|{:0>4d}]\t".format(
            block_count, layer_name, int(layer.d.sum()), layer.d.size(0))
        log_str += "[selection]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
            record_selection_loss.avg,
            record_selection_mse_loss.avg,
            record_selection_softmax_loss.avg)
        log_str += "[fine-tuning]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
            record_finetune_loss.avg,
            record_finetune_mse_loss.avg,
            record_finetune_softmax_loss.avg)
        log_str += "top1error: {:4f}\ttop5error: {:4f}".format(
            record_finetune_top1_error.avg,
            record_finetune_top5_error.avg)
        self.logger.info(log_str)

        logger_counter += 1
        time_interval = time.time() - time_start
        record_time.update(time_interval)

    # restore requires_grad and remove hooks
    for params in net_origin_parallel.parameters():
        params.requires_grad = True
    for params in net_pruned_parallel.parameters():
        params.requires_grad = True
    hook_origin.remove()
    hook_pruned.remove()

    log_str = "|===>Select channel from block-{:d}_{}: time_total:{} time_avg: {}".format(
        block_count, layer_name,
        str(datetime.timedelta(seconds=record_time.sum)),
        str(datetime.timedelta(seconds=record_time.avg)))
    self.logger.info(log_str)
    log_str = "|===>fine-tuning result: loss: {:f}, mse_loss: {:f}, softmax_loss: {:f}, top1error: {:f} top5error: {:f}".format(
        record_finetune_loss.avg,
        record_finetune_mse_loss.avg,
        record_finetune_softmax_loss.avg,
        record_finetune_top1_error.avg,
        record_finetune_top5_error.avg)
    self.logger.info(log_str)
    self.logger.info("|===>remove hook")
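# A standalone sketch of the channel-ranking rule used in the selection loop above
# (assumption: `select_next_channel` is an illustrative helper, not part of the class).
# Input channels are ranked by the Frobenius norm of their accumulated gradient, summed
# over output channels, and the best not-yet-selected channel (d[c] == 0) is returned.
import torch

def select_next_channel(cum_grad, d):
    # cum_grad: accumulated gradient of pruned_weight, shape (out_c, in_c, kh, kw)
    # d: 0/1 selection mask over input channels, shape (in_c,)
    grad_fnorm = cum_grad.pow(2).sum((2, 3)).sqrt().sum(0)   # per-input-channel score
    grad_fnorm[d == 1] = -1                                  # exclude selected channels
    return int(torch.argmax(grad_fnorm))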