def val(self, epoch):
    """
    validation
    :param epoch: index of epoch
    """
    top1_error = utils.AverageMeter()
    top1_loss = utils.AverageMeter()
    top5_error = utils.AverageMeter()

    self.pruned_model.eval()

    iters = len(self.val_loader)
    start_time = time.time()
    end_time = start_time

    with torch.no_grad():
        for i, (images, labels) in enumerate(self.val_loader):
            start_time = time.time()
            data_time = start_time - end_time

            if self.settings.n_gpus == 1:
                images = images.cuda()
            labels = labels.cuda()

            output, loss = self.forward(images, labels)

            single_error, single_loss, single5_error = utils.compute_singlecrop(
                outputs=output, loss=loss, labels=labels,
                top5_flag=True, mean_flag=True)

            top1_error.update(single_error, images.size(0))
            top1_loss.update(single_loss, images.size(0))
            top5_error.update(single5_error, images.size(0))

            end_time = time.time()
            iter_time = end_time - start_time

            utils.print_result(
                epoch, self.settings.network_wise_n_epochs, i + 1, iters,
                self.network_wise_lr, data_time, iter_time,
                single_error, single_loss,
                top5error=single5_error,
                mode="Validation",
                logger=self.logger)

    self.scalar_info['network_wise_fine_tune_val_top1_error'] = top1_error.avg
    self.scalar_info['network_wise_fine_tune_val_top5_error'] = top5_error.avg
    self.scalar_info['network_wise_fine_tune_val_loss'] = top1_loss.avg

    if self.tensorboard_logger is not None:
        for tag, value in self.scalar_info.items():
            self.tensorboard_logger.scalar_summary(tag, value, self.run_count)
        self.scalar_info = {}
    self.run_count += 1

    self.logger.info(
        "|===>Validation Error: {:.4f} Loss: {:.4f}, Top5 Error: {:.4f}".format(
            top1_error.avg, top1_loss.avg, top5_error.avg))
    return top1_error.avg, top1_loss.avg, top5_error.avg
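# The loops above and below lean on two helpers from `utils` whose definitions
# live outside this section. The sketch below is a minimal, assumed version
# consistent with how they are called here: `AverageMeter` keeps a running
# batch-size-weighted average, and `compute_singlecrop` turns logits into
# top-1/top-5 error percentages plus the scalar loss value (returning lists
# when given the per-segment lists used by the segment-wise loops).

import torch


class AverageMeter(object):
    """Running weighted average of a scalar (assumed sketch)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)


def compute_singlecrop(outputs, labels, loss, top5_flag=True, mean_flag=True):
    """Assumed sketch: single-crop top-1/top-5 error (%) and loss value."""
    if isinstance(outputs, (list, tuple)):
        # segment-wise case: one output/loss per auxiliary classifier
        per_segment = [compute_singlecrop(o, labels, l, top5_flag, mean_flag)
                       for o, l in zip(outputs, loss)]
        top1, losses, top5 = zip(*per_segment)
        return list(top1), list(losses), list(top5)

    batch_size = labels.size(0)
    # top-5 predictions per sample; each row holds 5 distinct class indices
    _, pred = outputs.topk(5, dim=1, largest=True, sorted=True)
    correct = pred.eq(labels.view(-1, 1).expand_as(pred))
    top1_err = 100.0 * (1.0 - correct[:, :1].float().sum().item() / batch_size)
    top5_err = 100.0 * (1.0 - correct.float().sum().item() / batch_size)
    return top1_err, loss.item(), top5_err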
def val(self, epoch):
    """
    validation
    :param epoch: index of epoch
    """
    top1_error = []
    top5_error = []
    top1_loss = []
    num_segments = len(self.pruned_segments)
    for i in range(num_segments):
        self.pruned_segments[i].eval()
        self.aux_fc[i].eval()
        top1_error.append(utils.AverageMeter())
        top5_error.append(utils.AverageMeter())
        top1_loss.append(utils.AverageMeter())

    iters = len(self.val_loader)
    start_time = time.time()
    end_time = start_time

    with torch.no_grad():
        for i, (images, labels) in enumerate(self.val_loader):
            start_time = time.time()
            data_time = start_time - end_time

            if self.settings.n_gpus == 1:
                images = images.cuda()
            labels = labels.cuda()

            outputs, losses = self.forward(images, labels)

            # compute loss and error rate
            single_error, single_loss, single5_error = utils.compute_singlecrop(
                outputs=outputs, labels=labels, loss=losses,
                top5_flag=True, mean_flag=True)

            for j in range(num_segments):
                top1_error[j].update(single_error[j], images.size(0))
                top5_error[j].update(single5_error[j], images.size(0))
                top1_loss[j].update(single_loss[j], images.size(0))

            end_time = time.time()
            iter_time = end_time - start_time

            utils.print_result(
                epoch, self.settings.segment_wise_n_epochs, i + 1, iters,
                self.segment_wise_lr, data_time, iter_time,
                single_error, single_loss,
                mode="Validation",
                logger=self.logger)

    top1_error_list, top1_loss_list, top5_error_list = self._convert_results(
        top1_error=top1_error, top1_loss=top1_loss, top5_error=top5_error)

    if self.tensorboard_logger is not None:
        for i in range(num_segments):
            self.tensorboard_logger.scalar_summary(
                "segment_wise_fine_tune_val_top1_error_{:d}".format(i),
                top1_error[i].avg, self.run_count)
            self.tensorboard_logger.scalar_summary(
                "segment_wise_fine_tune_val_top5_error_{:d}".format(i),
                top5_error[i].avg, self.run_count)
            self.tensorboard_logger.scalar_summary(
                "segment_wise_fine_tune_val_loss_{:d}".format(i),
                top1_loss[i].avg, self.run_count)
    self.run_count += 1

    self.logger.info(
        "|===>Validation Error: {:.4f}/{:.4f}, Loss: {:.4f}".format(
            top1_error[-1].avg, top5_error[-1].avg, top1_loss[-1].avg))
    return top1_error_list, top1_loss_list, top5_error_list
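# `_convert_results` is called above but defined elsewhere in the class. A
# minimal sketch consistent with its usage (flattening the per-segment
# AverageMeter lists into plain lists of averages) might look like this; the
# exact signature and body are assumptions, not confirmed by this section.

def _convert_results(self, top1_error, top1_loss, top5_error):
    """Assumed sketch: collect the .avg of every per-segment meter."""
    top1_error_list = [meter.avg for meter in top1_error]
    top1_loss_list = [meter.avg for meter in top1_loss]
    top5_error_list = [meter.avg for meter in top5_error]
    return top1_error_list, top1_loss_list, top5_error_list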
def train(self, epoch, index):
    """
    train
    :param epoch: index of epoch
    :param index: index of segment
    """
    iters = len(self.train_loader)
    self.update_lr(epoch)

    top1_error = []
    top5_error = []
    top1_loss = []
    num_segments = len(self.pruned_segments)
    for i in range(num_segments):
        if self.pruned_segments[i] is not None:
            self.pruned_segments[i].train()
        # only the segment being tuned and the final classifier are trained;
        # the remaining auxiliary classifiers stay in eval mode
        if i != index and i != num_segments - 1:
            self.aux_fc[i].eval()
        else:
            self.aux_fc[i].train()
        top1_error.append(utils.AverageMeter())
        top5_error.append(utils.AverageMeter())
        top1_loss.append(utils.AverageMeter())

    start_time = time.time()
    end_time = start_time
    for i, (images, labels) in enumerate(self.train_loader):
        start_time = time.time()
        data_time = start_time - end_time

        if self.settings.n_gpus == 1:
            images = images.cuda()
        labels = labels.cuda()

        # forward
        outputs, losses = self.forward(images, labels)
        # backward
        self.backward(losses, index)

        # compute loss and error rate
        single_error, single_loss, single5_error = utils.compute_singlecrop(
            outputs=outputs, labels=labels, loss=losses,
            top5_flag=True, mean_flag=True)

        for j in range(num_segments):
            if self.pruned_segments[j] is not None:
                top1_error[j].update(single_error[j], images.size(0))
                top5_error[j].update(single5_error[j], images.size(0))
                top1_loss[j].update(single_loss[j], images.size(0))

        end_time = time.time()
        iter_time = end_time - start_time

        utils.print_result(
            epoch, self.settings.segment_wise_n_epochs, i + 1, iters,
            self.segment_wise_lr, data_time, iter_time,
            single_error, single_loss,
            mode="Train",
            logger=self.logger)

    top1_error_list, top1_loss_list, top5_error_list = self._convert_results(
        top1_error=top1_error, top1_loss=top1_loss, top5_error=top5_error)

    if self.v_logger is not None:
        for i in range(num_segments):
            if self.pruned_segments[i] is not None:
                self.v_logger.log_scalar(
                    "segment_wise_fine_tune_train_top1_error_{:d}".format(i),
                    top1_error[i].avg, self.run_count)
                self.v_logger.log_scalar(
                    "segment_wise_fine_tune_train_top5_error_{:d}".format(i),
                    top5_error[i].avg, self.run_count)
                self.v_logger.log_scalar(
                    "segment_wise_fine_tune_train_loss_{:d}".format(i),
                    top1_loss[i].avg, self.run_count)
        self.v_logger.log_scalar("segment_wise_fine_tune_lr",
                                 self.segment_wise_lr, self.run_count)

    self.logger.info(
        "|===>Training Error: {:.4f}/{:.4f}, Loss: {:.4f}".format(
            top1_error[-1].avg, top5_error[-1].avg, top1_loss[-1].avg))
    return top1_error_list, top1_loss_list, top5_error_list
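# The segment-wise `forward`/`backward` pair used above is defined elsewhere.
# The sketch below shows the assumed contract: the input flows through the
# pruned segments in order, each segment's output is classified by its
# auxiliary head, and `backward` only backpropagates the loss of the segment
# selected by `index`. The attributes `self.criterion`, `self.seg_optimizer`,
# and `self.fc_optimizer` are assumptions introduced for illustration.

def forward(self, images, labels=None):
    """Assumed sketch: run all segments, collect per-segment outputs/losses."""
    outputs = []
    losses = []
    temp_input = images
    for i in range(len(self.pruned_segments)):
        temp_input = self.pruned_segments[i](temp_input)
        fc_output = self.aux_fc[i](temp_input)
        outputs.append(fc_output)
        if labels is not None:
            losses.append(self.criterion(fc_output, labels))  # assumed criterion
    return outputs, losses


def backward(self, losses, index):
    """Assumed sketch: update only the segment selected by `index`."""
    self.seg_optimizer[index].zero_grad()  # assumed per-segment optimizers
    self.fc_optimizer[index].zero_grad()
    losses[index].backward()
    self.seg_optimizer[index].step()
    self.fc_optimizer[index].step()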
def train(self, epoch):
    """
    training
    :param epoch: index of epoch
    """
    top1_error = utils.AverageMeter()
    top1_loss = utils.AverageMeter()
    top5_error = utils.AverageMeter()

    iters = len(self.train_loader)
    self.update_lr(epoch)

    # switch to train mode
    self.pruned_model.train()

    start_time = time.time()
    end_time = start_time
    for i, (images, labels) in enumerate(self.train_loader):
        start_time = time.time()
        data_time = start_time - end_time

        if self.settings.n_gpus == 1:
            images = images.cuda()
        labels = labels.cuda()

        output, loss = self.forward(images, labels)
        self.backward(loss)

        single_error, single_loss, single5_error = utils.compute_singlecrop(
            outputs=output, labels=labels, loss=loss,
            top5_flag=True, mean_flag=True)

        top1_error.update(single_error, images.size(0))
        top1_loss.update(single_loss, images.size(0))
        top5_error.update(single5_error, images.size(0))

        end_time = time.time()
        iter_time = end_time - start_time

        utils.print_result(
            epoch, self.settings.network_wise_n_epochs, i + 1, iters,
            self.network_wise_lr, data_time, iter_time,
            single_error, single_loss,
            top5error=single5_error,
            mode="Train",
            logger=self.logger)

    self.scalar_info['network_wise_fine_tune_train_top1_error'] = top1_error.avg
    self.scalar_info['network_wise_fine_tune_train_top5_error'] = top5_error.avg
    self.scalar_info['network_wise_fine_tune_train_loss'] = top1_loss.avg
    self.scalar_info['network_wise_fine_tune_lr'] = self.network_wise_lr

    if self.v_logger is not None:
        for tag, value in list(self.scalar_info.items()):
            self.v_logger.log_scalar(tag, value, self.run_count)
        self.scalar_info = {}

    self.logger.info(
        "|===>Training Error: {:.4f} Loss: {:.4f}, Top5 Error: {:.4f}".format(
            top1_error.avg, top1_loss.avg, top5_error.avg))
    return top1_error.avg, top1_loss.avg, top5_error.avg
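# `update_lr` is called at the top of both train loops but defined elsewhere.
# A minimal sketch, assuming the common multi-step decay policy; the names
# `settings.network_wise_step`, `settings.network_wise_lr`, `self.optimizer`,
# and the decay factor 0.1 are assumptions, not confirmed by this section.

def update_lr(self, epoch):
    """Assumed sketch: multi-step learning-rate decay."""
    gamma = 0
    for step in self.settings.network_wise_step:  # assumed list of milestones
        if epoch + 1.0 > int(step):
            gamma += 1
    lr = self.settings.network_wise_lr * (0.1 ** gamma)
    self.network_wise_lr = lr
    for param_group in self.optimizer.param_groups:  # assumed optimizer handle
        param_group['lr'] = lr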
def _layer_channel_selection(self, net_origin, net_pruned, aux_fc, module,
                             block_count, layer_name="conv2"):
    """
    conduct channel selection for module
    :param net_origin: original network segments
    :param net_pruned: pruned network segments
    :param aux_fc: auxiliary fully-connected layer
    :param module: the module to be pruned
    :param block_count: current block no.
    :param layer_name: name of the layer to be pruned
    """
    self.logger.info(
        "|===>layer-wise channel selection: block-{}-{}".format(
            block_count, layer_name))

    # layer-wise channel selection
    if layer_name == "conv2":
        layer = module.conv2
    elif layer_name == "conv3":
        layer = module.conv3
    else:
        assert False, "unsupported layer: {}".format(layer_name)

    # replace the target convolution with a masked convolution whose
    # pruned_weight starts at zero and whose channel indicator d is empty
    if not isinstance(layer, MaskConv2d):
        temp_conv = MaskConv2d(
            in_channels=layer.in_channels,
            out_channels=layer.out_channels,
            kernel_size=layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            bias=(layer.bias is not None))
        temp_conv.weight.data.copy_(layer.weight.data)
        if layer.bias is not None:
            temp_conv.bias.data.copy_(layer.bias.data)
        temp_conv.pruned_weight.data.fill_(0)
        temp_conv.d.fill_(0)
        if layer_name == "conv2":
            module.conv2 = temp_conv
        elif layer_name == "conv3":
            module.conv3 = temp_conv
        layer = temp_conv

    # define criteria
    criterion_mse = nn.MSELoss().cuda()
    criterion_softmax = nn.CrossEntropyLoss().cuda()

    # register hooks to cache the features of the target layer
    if layer_name == "conv2":
        hook_origin = net_origin[block_count].conv2.register_forward_hook(
            self._hook_origin_feature)
        hook_pruned = module.conv2.register_forward_hook(
            self._hook_pruned_feature)
    elif layer_name == "conv3":
        hook_origin = net_origin[block_count].conv3.register_forward_hook(
            self._hook_origin_feature)
        hook_pruned = module.conv3.register_forward_hook(
            self._hook_pruned_feature)

    net_origin_parallel = utils.data_parallel(net_origin, self.settings.n_gpus)
    net_pruned_parallel = utils.data_parallel(net_pruned, self.settings.n_gpus)

    # avoid computing gradients for anything except pruned_weight
    for params in net_origin_parallel.parameters():
        params.requires_grad = False
    for params in net_pruned_parallel.parameters():
        params.requires_grad = False

    net_origin_parallel.eval()
    net_pruned_parallel.eval()

    layer.pruned_weight.requires_grad = True
    aux_fc.cuda()

    logger_counter = 0
    record_time = utils.AverageMeter()
    for channel in range(layer.in_channels):
        # stop once the number of unselected channels reaches the pruning budget
        if layer.d.eq(0).sum() <= math.floor(
                layer.in_channels * self.settings.pruning_rate):
            break
        time_start = time.time()
        cum_grad = None
        record_selection_mse_loss = utils.AverageMeter()
        record_selection_softmax_loss = utils.AverageMeter()
        record_selection_loss = utils.AverageMeter()
        img_count = 0
        for i, (images, labels) in enumerate(self.train_loader):
            images = images.cuda()
            labels = labels.cuda()
            net_origin_parallel(images)
            output = net_pruned_parallel(images)
            softmax_loss = criterion_softmax(aux_fc(output), labels)

            origin_feature = self._concat_gpu_data(self.feature_cache_origin)
            self.feature_cache_origin = {}
            pruned_feature = self._concat_gpu_data(self.feature_cache_pruned)
            self.feature_cache_pruned = {}
            mse_loss = criterion_mse(pruned_feature, origin_feature)

            loss = mse_loss * self.settings.mse_weight \
                + softmax_loss * self.settings.softmax_weight
            loss.backward()

            record_selection_loss.update(loss.item(), images.size(0))
            record_selection_mse_loss.update(mse_loss.item(), images.size(0))
            record_selection_softmax_loss.update(softmax_loss.item(),
                                                 images.size(0))

            # accumulate the gradient w.r.t. pruned_weight over batches
            if cum_grad is None:
                cum_grad = layer.pruned_weight.grad.data.clone()
            else:
                cum_grad.add_(layer.pruned_weight.grad.data)
            layer.pruned_weight.grad = None

            img_count += images.size(0)
            if self.settings.max_samples != -1 and \
                    img_count >= self.settings.max_samples:
                break

        # write tensorboard log
        self.tensorboard_logger.scalar_summary(
            tag="S-block-{}_{}_LossAll".format(block_count, layer_name),
            value=record_selection_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="S-block-{}_{}_MSELoss".format(block_count, layer_name),
            value=record_selection_mse_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="S-block-{}_{}_SoftmaxLoss".format(block_count, layer_name),
            value=record_selection_softmax_loss.avg,
            step=logger_counter)

        cum_grad.abs_()
        # calculate the Frobenius norm of the gradient per input channel
        grad_fnorm = cum_grad.mul(cum_grad).sum((2, 3)).sqrt().sum(0)

        # select the unpruned channel with the largest gradient norm
        while True:
            _, max_index = torch.topk(grad_fnorm, 1)
            if layer.d[max_index[0]] == 0:
                layer.d[max_index[0]] = 1
                # initialize the newly selected channel from the original weights
                layer.pruned_weight.data[:, max_index[0], :, :] = \
                    layer.weight[:, max_index[0], :, :].data.clone()
                break
            else:
                grad_fnorm[max_index[0]] = -1

        # fine-tune average meters
        record_finetune_softmax_loss = utils.AverageMeter()
        record_finetune_mse_loss = utils.AverageMeter()
        record_finetune_loss = utils.AverageMeter()
        record_finetune_top1_error = utils.AverageMeter()
        record_finetune_top5_error = utils.AverageMeter()

        # define optimizer
        params_list = []
        params_list.append({
            "params": layer.pruned_weight,
            "lr": self.settings.layer_wise_lr})
        if layer.bias is not None:
            layer.bias.requires_grad = True
            params_list.append({"params": layer.bias, "lr": 0.001})
        optimizer = torch.optim.SGD(
            params=params_list,
            weight_decay=self.settings.weight_decay,
            momentum=self.settings.momentum,
            nesterov=True)

        img_count = 0
        for epoch in range(1):
            for i, (images, labels) in enumerate(self.train_loader):
                images = images.cuda()
                labels = labels.cuda()

                features = net_pruned_parallel(images)
                net_origin_parallel(images)
                output = aux_fc(features)
                softmax_loss = criterion_softmax(output, labels)

                origin_feature = self._concat_gpu_data(
                    self.feature_cache_origin)
                self.feature_cache_origin = {}
                pruned_feature = self._concat_gpu_data(
                    self.feature_cache_pruned)
                self.feature_cache_pruned = {}
                mse_loss = criterion_mse(pruned_feature, origin_feature)

                top1_error, _, top5_error = utils.compute_singlecrop(
                    outputs=output, labels=labels, loss=softmax_loss,
                    top5_flag=True, mean_flag=True)

                # update parameters
                optimizer.zero_grad()
                loss = mse_loss * self.settings.mse_weight \
                    + softmax_loss * self.settings.softmax_weight
                loss.backward()
                torch.nn.utils.clip_grad_norm_(layer.parameters(),
                                               max_norm=10.0)
                # mask out gradients of unselected channels so that only
                # the selected channels are updated
                layer.pruned_weight.grad.data.mul_(
                    layer.d.unsqueeze(0).unsqueeze(2).unsqueeze(3)
                    .expand_as(layer.pruned_weight))
                optimizer.step()

                # update record info
                record_finetune_softmax_loss.update(softmax_loss.item(),
                                                    images.size(0))
                record_finetune_mse_loss.update(mse_loss.item(),
                                                images.size(0))
                record_finetune_loss.update(loss.item(), images.size(0))
                record_finetune_top1_error.update(top1_error, images.size(0))
                record_finetune_top5_error.update(top5_error, images.size(0))

                img_count += images.size(0)
                if self.settings.max_samples != -1 and \
                        img_count >= self.settings.max_samples:
                    break

        layer.pruned_weight.grad = None
        if layer.bias is not None:
            layer.bias.requires_grad = False

        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_SoftmaxLoss".format(block_count, layer_name),
            value=record_finetune_softmax_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_Loss".format(block_count, layer_name),
            value=record_finetune_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_MSELoss".format(block_count, layer_name),
            value=record_finetune_mse_loss.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_Top1Error".format(block_count, layer_name),
            value=record_finetune_top1_error.avg,
            step=logger_counter)
        self.tensorboard_logger.scalar_summary(
            tag="F-block-{}_{}_Top5Error".format(block_count, layer_name),
            value=record_finetune_top5_error.avg,
            step=logger_counter)

        # write log information to file
        self._write_log(
            dir_name=os.path.join(self.settings.save_path, "log"),
            file_name="log_block-{:0>2d}_{}.txt".format(block_count,
                                                        layer_name),
            log_str="{:d}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t\n".format(
                int(layer.d.sum()),
                record_selection_loss.avg,
                record_selection_mse_loss.avg,
                record_selection_softmax_loss.avg,
                record_finetune_loss.avg,
                record_finetune_mse_loss.avg,
                record_finetune_softmax_loss.avg,
                record_finetune_top1_error.avg,
                record_finetune_top5_error.avg))

        log_str = "Block-{:0>2d}-{}\t#channels: [{:0>4d}|{:0>4d}]\t".format(
            block_count, layer_name, int(layer.d.sum()), layer.d.size(0))
        log_str += "[selection]loss: {:.4f}\tmseloss: {:.4f}\tsoftmaxloss: {:.4f}\t".format(
            record_selection_loss.avg,
            record_selection_mse_loss.avg,
            record_selection_softmax_loss.avg)
        log_str += "[fine-tuning]loss: {:.4f}\tmseloss: {:.4f}\tsoftmaxloss: {:.4f}\t".format(
            record_finetune_loss.avg,
            record_finetune_mse_loss.avg,
            record_finetune_softmax_loss.avg)
        log_str += "top1error: {:.4f}\ttop5error: {:.4f}".format(
            record_finetune_top1_error.avg,
            record_finetune_top5_error.avg)
        self.logger.info(log_str)

        logger_counter += 1
        time_interval = time.time() - time_start
        record_time.update(time_interval)

    for params in net_origin_parallel.parameters():
        params.requires_grad = True
    for params in net_pruned_parallel.parameters():
        params.requires_grad = True

    # remove hooks
    hook_origin.remove()
    hook_pruned.remove()
    log_str = "|===>Select channel from block-{:d}_{}: time_total: {} time_avg: {}".format(
        block_count, layer_name,
        str(datetime.timedelta(seconds=record_time.sum)),
        str(datetime.timedelta(seconds=record_time.avg)))
    self.logger.info(log_str)
    log_str = "|===>fine-tuning result: loss: {:f}, mse_loss: {:f}, softmax_loss: {:f}, top1error: {:f} top5error: {:f}".format(
        record_finetune_loss.avg,
        record_finetune_mse_loss.avg,
        record_finetune_softmax_loss.avg,
        record_finetune_top1_error.avg,
        record_finetune_top5_error.avg)
    self.logger.info(log_str)
    self.logger.info("|===>remove hook")
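# `_layer_channel_selection` relies on `MaskConv2d`, a convolution that keeps
# the original `weight` intact and convolves with a separate `pruned_weight`
# tensor gated by a binary channel-selection vector `d` (one entry per input
# channel). The sketch below is an assumed, minimal version consistent with
# the attributes used above (`weight`, `pruned_weight`, `d`), not the
# repository's definitive implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskConv2d(nn.Conv2d):
    """Assumed sketch: Conv2d with a selectable-input-channel weight copy."""

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, bias=True):
        super(MaskConv2d, self).__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, bias=bias)
        # binary indicator: d[c] == 1 iff input channel c has been selected
        self.register_buffer('d', torch.ones(in_channels))
        # the weights actually used in the forward pass during selection;
        # the wrapping code above zeroes them and restores selected channels
        # from `weight` one at a time
        self.pruned_weight = nn.Parameter(self.weight.data.clone())

    def forward(self, x):
        return F.conv2d(x, self.pruned_weight, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)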