def segment_parallelism(self, original_segment, pruned_segment):
    """
    Parallelize a pair of network segments across GPUs
    :param original_segment: segment of the original network
    :param pruned_segment: corresponding segment of the pruned network
    """
    self.original_segment_parallel = utils.data_parallel(original_segment, self.settings.n_gpus)
    self.pruned_segment_parallel = utils.data_parallel(pruned_segment, self.settings.n_gpus)
def model_parallelism(self):
    """
    Parallelize the original segments, the pruned segments, and the auxiliary classifiers
    """
    self.ori_segments = utils.data_parallel(model=self.ori_segments, n_gpus=self.settings.n_gpus)
    self.pruned_segments = utils.data_parallel(model=self.pruned_segments, n_gpus=self.settings.n_gpus)
    # the auxiliary classifiers are small, so they stay on a single GPU
    self.aux_fc = utils.data_parallel(model=self.aux_fc, n_gpus=1)
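# The helper below is NOT part of this file: it is a minimal sketch of what
# `utils.data_parallel` is assumed to do, inferred from how it is called
# above -- wrap a module in `nn.DataParallel` when more than one GPU is
# requested and leave it on a single GPU otherwise. The repository's actual
# implementation may differ.
def data_parallel_sketch(model, n_gpus):
    model = model.cuda()
    if n_gpus > 1:
        # replicate the module across the first n_gpus devices
        model = nn.DataParallel(model, device_ids=list(range(n_gpus)))
    return model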
def __init__(self, model, train_loader, val_loader, settings, logger,
             tensorboard_logger, optimizer_state=None, run_count=0):
    """
    Trainer for the network
    :param optimizer_state: optional optimizer state to resume from
    :param run_count: index of the current run
    """
    self.settings = settings
    self.model = utils.data_parallel(model=model, n_gpus=self.settings.n_gpus)
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.criterion = nn.CrossEntropyLoss().cuda()
    self.lr = self.settings.lr
    self.optimizer = torch.optim.SGD(
        params=self.model.parameters(),
        lr=self.settings.lr,
        momentum=self.settings.momentum,
        weight_decay=self.settings.weight_decay,
        nesterov=True)
    if optimizer_state is not None:
        self.optimizer.load_state_dict(optimizer_state)
    self.logger = logger
    self.tensorboard_logger = tensorboard_logger
    self.run_count = run_count
def __init__(self, pruned_model, train_loader, val_loader, settings, logger,
             tensorboard_logger, run_count=0):
    self.pruned_model = utils.data_parallel(pruned_model, settings.n_gpus)
    self.settings = settings
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.logger = logger
    self.tensorboard_logger = tensorboard_logger
    self.criterion = nn.CrossEntropyLoss().cuda()
    self.optimizer = torch.optim.SGD(
        params=self.pruned_model.parameters(),
        lr=self.settings.lr,
        momentum=self.settings.momentum,
        weight_decay=self.settings.weight_decay,
        nesterov=True)
    self.run_count = run_count
    self.lr = self.settings.lr
    self.scalar_info = {}
def update_model(self, pruned_model, optimizer_state=None):
    """
    Replace the pruned model and rebuild its optimizer
    :param pruned_model: the updated pruned model
    :param optimizer_state: optional optimizer state to resume from
    """
    self.optimizer = None  # release the old optimizer before rebuilding it
    self.pruned_model = utils.data_parallel(pruned_model, self.settings.n_gpus)
    self.optimizer = torch.optim.SGD(
        params=self.pruned_model.parameters(),
        lr=self.settings.network_wise_lr,
        momentum=self.settings.momentum,
        weight_decay=self.settings.weight_decay,
        nesterov=True)
    if optimizer_state:
        self.optimizer.load_state_dict(optimizer_state)
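# Hedged usage sketch (names are hypothetical, not from this repository):
# after a pruning step produces a new model, `update_model` swaps it in and
# rebuilds the optimizer so it tracks the new parameters.
#
#   trainer.update_model(new_pruned_model)               # fresh optimizer
#   trainer.update_model(new_pruned_model, saved_state)  # also restore optimizer state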
def model_parallelism(self):
    self.segments = utils.data_parallel(model=self.segments, n_gpus=self.settings.n_gpus)
    self.aux_fc = utils.data_parallel(model=self.aux_fc, n_gpus=1)
def _network_split(self):
    """
    1. Split the network into several segments according to the pre-defined pivot set.
    2. Create auxiliary classifiers.
    3. Create optimizers for the network segments and auxiliary classifiers.
    """
    net_origin = None
    net_pruned = None
    if self.settings.net_type in ["preresnet", "resnet"]:
        if self.settings.net_type == "preresnet":
            net_origin = nn.Sequential(self.ori_model.conv)
            net_pruned = nn.Sequential(self.pruned_model.conv)
        elif self.settings.net_type == "resnet":
            net_head = nn.Sequential(
                self.ori_model.conv1,
                self.ori_model.bn1,
                self.ori_model.relu,
                self.ori_model.maxpool)
            net_origin = nn.Sequential(net_head)
            net_head = nn.Sequential(
                self.pruned_model.conv1,
                self.pruned_model.bn1,
                self.pruned_model.relu,
                self.pruned_model.maxpool)
            net_pruned = nn.Sequential(net_head)
            self.logger.info("init shallow head done!")
    else:
        assert False, "unsupported net_type: {}".format(self.settings.net_type)

    block_count = 0
    if self.settings.net_type in ["resnet", "preresnet"]:
        for ori_module, pruned_module in zip(self.ori_model.modules(),
                                             self.pruned_model.modules()):
            if isinstance(ori_module, (PreBasicBlock, Bottleneck, BasicBlock)):
                self.logger.info("enter block: {}".format(type(ori_module)))
                if net_origin is not None:
                    net_origin.add_module(str(len(net_origin)), ori_module)
                else:
                    net_origin = nn.Sequential(ori_module)
                if net_pruned is not None:
                    net_pruned.add_module(str(len(net_pruned)), pruned_module)
                else:
                    net_pruned = nn.Sequential(pruned_module)
                block_count += 1
                # when block_count reaches a pivot, close the current segment
                # and start a new one
                if block_count in self.settings.pivot_set:
                    self.ori_segments.append(net_origin)
                    self.pruned_segments.append(net_pruned)
                    net_origin = None
                    net_pruned = None
    self.final_block_count = block_count
    self.ori_segments.append(net_origin)
    self.pruned_segments.append(net_pruned)

    # create auxiliary classifiers: one per segment except the last,
    # fed by the output channels of the segment's final block
    num_classes = self.settings.n_classes
    in_channels = 0
    for i in range(len(self.pruned_segments) - 1):
        if isinstance(self.pruned_segments[i][-1], (PreBasicBlock, BasicBlock)):
            in_channels = self.pruned_segments[i][-1].conv2.out_channels
        elif isinstance(self.pruned_segments[i][-1], Bottleneck):
            in_channels = self.pruned_segments[i][-1].conv3.out_channels
        assert in_channels != 0, "in_channels is zero"
        self.aux_fc.append(AuxClassifier(in_channels=in_channels, num_classes=num_classes))

    # the last segment reuses the network's own final classifier
    pruned_final_fc = None
    if self.settings.net_type == "preresnet":
        pruned_final_fc = nn.Sequential(*[
            self.pruned_model.bn,
            self.pruned_model.relu,
            self.pruned_model.avg_pool,
            View(),
            self.pruned_model.fc])
    elif self.settings.net_type == "resnet":
        pruned_final_fc = nn.Sequential(*[
            self.pruned_model.avgpool,
            View(),
            self.pruned_model.fc])
    self.aux_fc.append(pruned_final_fc)

    # model parallelism
    self.ori_segments = utils.data_parallel(model=self.ori_segments, n_gpus=self.settings.n_gpus)
    self.pruned_segments = utils.data_parallel(model=self.pruned_segments, n_gpus=self.settings.n_gpus)
    self.aux_fc = utils.data_parallel(model=self.aux_fc, n_gpus=1)

    # create optimizers
    for i in range(len(self.pruned_segments)):
        temp_optim = []
        # the i-th optimizer updates the parameters of segments [0:i+1]
        for j in range(i + 1):
            temp_optim.append({'params': self.pruned_segments[j].parameters(),
                               'lr': self.settings.segment_wise_lr})
        # one optimizer for the segments, one for the auxiliary classifier
        temp_seg_optim = torch.optim.SGD(
            temp_optim,
            momentum=self.settings.momentum,
            weight_decay=self.settings.weight_decay,
            nesterov=True)
        temp_fc_optim = torch.optim.SGD(
            params=self.aux_fc[i].parameters(),
            lr=self.settings.segment_wise_lr,
            momentum=self.settings.momentum,
            weight_decay=self.settings.weight_decay,
            nesterov=True)
        self.seg_optimizer.append(temp_seg_optim)
        self.fc_optimizer.append(temp_fc_optim)
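# `AuxClassifier` and `View` are referenced above but not defined in this
# section. The sketches below are minimal stand-ins consistent with the call
# signature AuxClassifier(in_channels, num_classes) and with the final-
# classifier Sequential built above (pool -> flatten -> linear). They are
# assumptions, not the repository's actual implementations.
class AuxClassifierSketch(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(AuxClassifierSketch, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)       # collapse spatial dims
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)                     # flatten, like View()
        return self.fc(x)


class ViewSketch(nn.Module):
    """Assumed to flatten a 4-D feature map to 2-D for the linear layer."""
    def forward(self, x):
        return x.view(x.size(0), -1)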
def _layer_channel_selection(self, net_origin, net_pruned, aux_fc, module,
                             block_count, layer_name="conv2"):
    """
    Conduct channel selection for a given module
    :param net_origin: original network segments
    :param net_pruned: pruned network segments
    :param aux_fc: auxiliary fully-connected layer
    :param module: the module to be pruned
    :param block_count: index of the current block
    :param layer_name: name of the layer to be pruned
    """
    self.logger.info(
        "|===>layer-wise channel selection: block-{}-{}".format(
            block_count, layer_name))
    # layer-wise channel selection
    if layer_name == "conv2":
        layer = module.conv2
    elif layer_name == "conv3":
        layer = module.conv3
    else:
        assert False, "unsupported layer: {}".format(layer_name)

    # replace the target convolution with a masked convolution whose
    # channels start out unselected
    if not isinstance(layer, MaskConv2d):
        temp_conv = MaskConv2d(
            in_channels=layer.in_channels,
            out_channels=layer.out_channels,
            kernel_size=layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            bias=(layer.bias is not None))
        temp_conv.weight.data.copy_(layer.weight.data)
        if layer.bias is not None:
            temp_conv.bias.data.copy_(layer.bias.data)
        temp_conv.pruned_weight.data.fill_(0)
        temp_conv.d.fill_(0)
        if layer_name == "conv2":
            module.conv2 = temp_conv
        elif layer_name == "conv3":
            module.conv3 = temp_conv
        layer = temp_conv

    # define criteria
    criterion_mse = nn.MSELoss().cuda()
    criterion_softmax = nn.CrossEntropyLoss().cuda()

    # register hooks to cache the features of the target layer
    if layer_name == "conv2":
        hook_origin = net_origin[block_count].conv2.register_forward_hook(
            self._hook_origin_feature)
        hook_pruned = module.conv2.register_forward_hook(
            self._hook_pruned_feature)
    elif layer_name == "conv3":
        hook_origin = net_origin[block_count].conv3.register_forward_hook(
            self._hook_origin_feature)
        hook_pruned = module.conv3.register_forward_hook(
            self._hook_pruned_feature)

    net_origin_parallel = utils.data_parallel(net_origin, self.settings.n_gpus)
    net_pruned_parallel = utils.data_parallel(net_pruned, self.settings.n_gpus)

    # freeze both networks; only the pruned weight of the target layer
    # (and optionally its bias) is trainable
    for params in net_origin_parallel.parameters():
        params.requires_grad = False
    for params in net_pruned_parallel.parameters():
        params.requires_grad = False
    net_origin_parallel.eval()
    net_pruned_parallel.eval()
    layer.pruned_weight.requires_grad = True
    aux_fc.cuda()
    logger_counter = 0
    record_time = utils.AverageMeter()

    for channel in range(layer.in_channels):
        # stop once the number of still-unselected channels reaches
        # the pruning budget
        if layer.d.eq(0).sum() <= math.floor(
                layer.in_channels * self.settings.pruning_rate):
            break
        time_start = time.time()
        cum_grad = None
        record_selection_mse_loss = utils.AverageMeter()
        record_selection_softmax_loss = utils.AverageMeter()
        record_selection_loss = utils.AverageMeter()
        img_count = 0

        # selection phase: accumulate gradients of the joint loss
        # w.r.t. the masked weight
        for i, (images, labels) in enumerate(self.train_loader):
            images = images.cuda()
            labels = labels.cuda()
            net_origin_parallel(images)
            output = net_pruned_parallel(images)
            softmax_loss = criterion_softmax(aux_fc(output), labels)
            origin_feature = self._concat_gpu_data(self.feature_cache_origin)
            self.feature_cache_origin = {}
            pruned_feature = self._concat_gpu_data(self.feature_cache_pruned)
            self.feature_cache_pruned = {}
            mse_loss = criterion_mse(pruned_feature, origin_feature)
            loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight
            loss.backward()
            record_selection_loss.update(loss.item(), images.size(0))
            record_selection_mse_loss.update(mse_loss.item(), images.size(0))
            record_selection_softmax_loss.update(softmax_loss.item(), images.size(0))
            if cum_grad is None:
                cum_grad = layer.pruned_weight.grad.data.clone()
            else:
                cum_grad.add_(layer.pruned_weight.grad.data)
            layer.pruned_weight.grad = None
            img_count += images.size(0)
            if self.settings.max_samples != -1 and img_count >= self.settings.max_samples:
                break

        # write tensorboard logs for the selection phase
        self.tensorboard_logger.scalar_summary(
            tag="S-block-{}_{}_LossAll".format(block_count, layer_name),
            value=record_selection_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="S-block-{}_{}_MSELoss".format(block_count, layer_name),
            value=record_selection_mse_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="S-block-{}_{}_SoftmaxLoss".format(block_count, layer_name),
            value=record_selection_softmax_loss.avg,
            step=logger_counter)

        cum_grad.abs_()
        # channel-wise Frobenius norm of the accumulated gradient:
        # sum over kernel positions, then over output channels
        grad_fnorm = cum_grad.mul(cum_grad).sum((2, 3)).sqrt().sum(0)

        # select the not-yet-selected channel with the largest gradient norm
        while True:
            _, max_index = torch.topk(grad_fnorm, 1)
            if layer.d[max_index[0]] == 0:
                layer.d[max_index[0]] = 1
                # warm-start the selected channel from the original weight
                layer.pruned_weight.data[:, max_index[0], :, :] = \
                    layer.weight[:, max_index[0], :, :].data.clone()
                break
            else:
                grad_fnorm[max_index[0]] = -1

        # fine-tune average meters
        record_finetune_softmax_loss = utils.AverageMeter()
        record_finetune_mse_loss = utils.AverageMeter()
        record_finetune_loss = utils.AverageMeter()
        record_finetune_top1_error = utils.AverageMeter()
        record_finetune_top5_error = utils.AverageMeter()

        # define optimizer
        params_list = []
        params_list.append({
            "params": layer.pruned_weight,
            "lr": self.settings.layer_wise_lr
        })
        if layer.bias is not None:
            layer.bias.requires_grad = True
            params_list.append({"params": layer.bias, "lr": 0.001})
        optimizer = torch.optim.SGD(
            params=params_list,
            weight_decay=self.settings.weight_decay,
            momentum=self.settings.momentum,
            nesterov=True)

        # fine-tuning phase: update the selected channels only
        img_count = 0
        for epoch in range(1):
            for i, (images, labels) in enumerate(self.train_loader):
                images = images.cuda()
                labels = labels.cuda()
                features = net_pruned_parallel(images)
                net_origin_parallel(images)
                output = aux_fc(features)
                softmax_loss = criterion_softmax(output, labels)
                origin_feature = self._concat_gpu_data(self.feature_cache_origin)
                self.feature_cache_origin = {}
                pruned_feature = self._concat_gpu_data(self.feature_cache_pruned)
                self.feature_cache_pruned = {}
                mse_loss = criterion_mse(pruned_feature, origin_feature)
                top1_error, _, top5_error = utils.compute_singlecrop(
                    outputs=output, labels=labels, loss=softmax_loss,
                    top5_flag=True, mean_flag=True)

                # update parameters
                optimizer.zero_grad()
                loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight
                loss.backward()
                torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=10.0)
                # mask the gradient so that only selected channels are updated
                layer.pruned_weight.grad.data.mul_(
                    layer.d.unsqueeze(0).unsqueeze(2).unsqueeze(3)
                    .expand_as(layer.pruned_weight))
                optimizer.step()

                # update record info
                record_finetune_softmax_loss.update(softmax_loss.item(), images.size(0))
                record_finetune_mse_loss.update(mse_loss.item(), images.size(0))
                record_finetune_loss.update(loss.item(), images.size(0))
                record_finetune_top1_error.update(top1_error, images.size(0))
                record_finetune_top5_error.update(top5_error, images.size(0))
                img_count += images.size(0)
                if self.settings.max_samples != -1 and img_count >= self.settings.max_samples:
                    break

        layer.pruned_weight.grad = None
        if layer.bias is not None:
            layer.bias.requires_grad = False

        # write tensorboard logs for the fine-tuning phase
        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_SoftmaxLoss".format(block_count, layer_name),
            value=record_finetune_softmax_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_Loss".format(block_count, layer_name),
            value=record_finetune_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_MSELoss".format(block_count, layer_name),
            value=record_finetune_mse_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_Top1Error".format(block_count, layer_name),
            value=record_finetune_top1_error.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_Top5Error".format(block_count, layer_name),
            value=record_finetune_top5_error.avg,
            step=logger_counter)

        # write log information to file
        self._write_log(
            dir_name=os.path.join(self.settings.save_path, "log"),
            file_name="log_block-{:0>2d}_{}.txt".format(block_count, layer_name),
            log_str="{:d}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t\n".format(
                int(layer.d.sum()),
                record_selection_loss.avg,
                record_selection_mse_loss.avg,
                record_selection_softmax_loss.avg,
                record_finetune_loss.avg,
                record_finetune_mse_loss.avg,
                record_finetune_softmax_loss.avg,
                record_finetune_top1_error.avg,
                record_finetune_top5_error.avg))
        log_str = "Block-{:0>2d}-{}\t#channels: [{:0>4d}|{:0>4d}]\t".format(
            block_count, layer_name, int(layer.d.sum()), layer.d.size(0))
        log_str += "[selection]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
            record_selection_loss.avg,
            record_selection_mse_loss.avg,
            record_selection_softmax_loss.avg)
        log_str += "[fine-tuning]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
            record_finetune_loss.avg,
            record_finetune_mse_loss.avg,
            record_finetune_softmax_loss.avg)
        log_str += "top1error: {:4f}\ttop5error: {:4f}".format(
            record_finetune_top1_error.avg,
            record_finetune_top5_error.avg)
        self.logger.info(log_str)
        logger_counter += 1
        time_interval = time.time() - time_start
        record_time.update(time_interval)

    for params in net_origin_parallel.parameters():
        params.requires_grad = True
    for params in net_pruned_parallel.parameters():
        params.requires_grad = True

    # remove hooks
    hook_origin.remove()
    hook_pruned.remove()
    log_str = "|===>Select channel from block-{:d}_{}: time_total:{} time_avg: {}".format(
        block_count, layer_name,
        str(datetime.timedelta(seconds=record_time.sum)),
        str(datetime.timedelta(seconds=record_time.avg)))
    self.logger.info(log_str)
    log_str = ("|===>fine-tuning result: loss: {:f}, mse_loss: {:f}, "
               "softmax_loss: {:f}, top1error: {:f} top5error: {:f}").format(
        record_finetune_loss.avg,
        record_finetune_mse_loss.avg,
        record_finetune_softmax_loss.avg,
        record_finetune_top1_error.avg,
        record_finetune_top5_error.avg)
    self.logger.info(log_str)
    self.logger.info("|===>remove hook")
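# `MaskConv2d`, the hooks `_hook_origin_feature` / `_hook_pruned_feature`,
# and `_concat_gpu_data` are used above but not defined in this section. The
# sketches below illustrate one plausible implementation, inferred from
# usage (per-GPU feature caches keyed by device so DataParallel replicas can
# write concurrently; a Conv2d carrying a trainable `pruned_weight` and a 0/1
# per-input-channel mask `d`). They are assumptions, not the repository's
# actual code.
class FeatureCacheSketch:
    """Hypothetical stand-in for the pruner's feature-caching machinery."""

    def __init__(self):
        self.feature_cache_origin = {}
        self.feature_cache_pruned = {}

    def _hook_origin_feature(self, module, input, output):
        # forward hook: cache the original layer's output per GPU replica
        self.feature_cache_origin[str(output.get_device())] = output

    def _hook_pruned_feature(self, module, input, output):
        self.feature_cache_pruned[str(output.get_device())] = output

    @staticmethod
    def _concat_gpu_data(data):
        # merge per-GPU feature chunks back into one batch on the first device
        keys = sorted(data.keys(), key=int)
        return torch.cat([data[k].cuda(0) for k in keys], dim=0)


class MaskConv2dSketch(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, bias=True):
        super(MaskConv2dSketch, self).__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, bias=bias)
        # trainable weight used while channels are being selected
        self.pruned_weight = nn.Parameter(self.weight.data.clone())
        # binary indicator per input channel: 1 = selected, 0 = pruned;
        # the caller above resets it with d.fill_(0) before selection
        self.register_buffer("d", torch.ones(in_channels))

    def forward(self, x):
        # convolve with the masked weight instead of self.weight
        return nn.functional.conv2d(
            x, self.pruned_weight, self.bias,
            self.stride, self.padding, self.dilation, self.groups)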