import copy
import datetime
import logging
import math
import os
import sys
import time

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# The following names are project-local and assumed to be importable from the
# surrounding DCP code base (exact module paths are not shown in this file):
# Option, write_settings, get_logger, TensorboardLogger, SegmentWiseTrainer,
# LayerChannelSelection, get_model, DCPCheckPoint, CheckPoint, cal_pivot,
# block_num, utils, md, PreBasicBlock, BasicBlock, Bottleneck, MaskConv2d,
# ModelAnalyse, ResModelPrune, get_cifar_dataloader, get_imagenet_dataloader,
# get_sub_imagenet_dataloader.


class Experiment(object):
    """
    Run experiments with the pre-defined pipeline
    """

    def __init__(self, options=None, conf_path=None):
        self.settings = options or Option(conf_path)
        self.checkpoint = None
        self.train_loader = None
        self.val_loader = None
        self.original_model = None
        self.pruned_model = None
        self.aux_fc_state = None
        self.aux_fc_opt_state = None
        self.seg_opt_state = None
        self.current_pivot_index = None
        self.epoch = 0

        os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.gpu
        self.settings.set_save_path()
        write_settings(self.settings)
        self.logger = get_logger(self.settings.save_path, "dcp")
        self.tensorboard_logger = TensorboardLogger(self.settings.save_path)
        self.settings.copy_code(self.logger,
                                src=os.path.abspath('./'),
                                dst=os.path.join(self.settings.save_path, 'code'))
        self.logger.info("|===>Result will be saved at {}".format(self.settings.save_path))

        self.prepare()

    def prepare(self):
        """
        Prepare the experiment
        """
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._cal_pivot()
        self._set_checkpoint()
        self._set_trainer()
        self._set_channel_selection()
        torch.set_num_threads(4)

    def _set_gpu(self):
        """
        Initialize the seed of the random number generator
        """
        torch.manual_seed(self.settings.seed)
        torch.cuda.manual_seed(self.settings.seed)
        torch.cuda.set_device(0)
        cudnn.benchmark = True

    def _set_dataloader(self):
        """
        Create train loader and validation loader for channel pruning
        """
        if 'cifar' in self.settings.dataset:
            self.train_loader, self.val_loader = get_cifar_dataloader(
                self.settings.dataset, self.settings.batch_size,
                self.settings.n_threads, self.settings.data_path, self.logger)
        elif self.settings.dataset in ['imagenet']:
            self.train_loader, self.val_loader = get_imagenet_dataloader(
                self.settings.dataset, self.settings.batch_size,
                self.settings.n_threads, self.settings.data_path, self.logger)
        elif self.settings.dataset in ['sub_imagenet']:
            num_samples_per_category = self.settings.max_samples // 1000
            self.train_loader, self.val_loader = get_sub_imagenet_dataloader(
                self.settings.dataset, self.settings.batch_size,
                num_samples_per_category, self.settings.n_threads,
                self.settings.data_path, self.logger)

    def _set_trainer(self):
        """
        Initialize the segment-wise trainer
        """
        self.segment_wise_trainer = SegmentWiseTrainer(
            original_model=self.original_model,
            pruned_model=self.pruned_model,
            train_loader=self.train_loader,
            val_loader=self.val_loader,
            settings=self.settings,
            logger=self.logger,
            tensorboard_logger=self.tensorboard_logger)
        if self.aux_fc_state is not None:
            self.segment_wise_trainer.update_aux_fc(self.aux_fc_state)

    def _set_channel_selection(self):
        self.layer_channel_selection = LayerChannelSelection(
            self.segment_wise_trainer, self.train_loader, self.val_loader,
            self.settings, self.checkpoint, self.logger, self.tensorboard_logger)

    def _set_model(self):
        """
        Available models
        cifar: preresnet
        imagenet: resnet
        """
        self.original_model, self.test_input = get_model(
            self.settings.dataset, self.settings.net_type,
            self.settings.depth, self.settings.n_classes)
        self.pruned_model, self.test_input = get_model(
            self.settings.dataset, self.settings.net_type,
            self.settings.depth, self.settings.n_classes)

    def _set_checkpoint(self):
        """
        Load a pre-trained model or resume from a checkpoint
        """
        assert self.original_model is not None and self.pruned_model is not None, \
            "please create model first"
        self.checkpoint = DCPCheckPoint(self.settings.save_path, self.logger)
        self._load_pretrained()
        self._load_resume()

    def _load_pretrained(self):
        """
        Load a pre-trained model
        """
        if self.settings.pretrained is not None:
            check_point_params = torch.load(self.settings.pretrained)
            model_state = check_point_params['model']
            self.aux_fc_state = check_point_params['aux_fc']
            self.original_model = self.checkpoint.load_state(self.original_model, model_state)
            self.pruned_model = self.checkpoint.load_state(self.pruned_model, model_state)
            self.logger.info("|===>load pre-trained file: {}".format(self.settings.pretrained))

    def _load_resume(self):
        """
        Load a resume checkpoint
        """
        if self.settings.resume is not None:
            check_point_params = torch.load(self.settings.resume)
            original_model_state = check_point_params["original_model"]
            pruned_model_state = check_point_params["pruned_model"]
            self.aux_fc_state = check_point_params["aux_fc"]
            self.aux_fc_opt_state = check_point_params["aux_fc_opt"]
            self.seg_opt_state = check_point_params["seg_opt"]
            self.current_pivot_index = check_point_params["current_pivot"]
            self.segment_num = check_point_params["segment_num"]
            self.current_block_count = check_point_params["block_num"]
            if self.current_block_count > 0:
                self.replace_layer_mask_conv()
            self.original_model = self.checkpoint.load_state(self.original_model,
                                                             original_model_state)
            self.pruned_model = self.checkpoint.load_state(self.pruned_model,
                                                           pruned_model_state)
            self.logger.info("|===>load resume file: {}".format(self.settings.resume))

    def _cal_pivot(self):
        """
        Calculate the blocks after which the additional losses are inserted
        """
        self.num_segments, self.settings.pivot_set = cal_pivot(
            self.settings.n_losses, self.settings.net_type,
            self.settings.depth, self.logger)

    def replace_layer_mask_conv(self):
        """
        Replace convolutional layers with mask convolutional layers
        """
        block_count = 0
        if self.settings.net_type in ["preresnet", "resnet"]:
            for module in self.pruned_model.modules():
                if isinstance(module, (PreBasicBlock, BasicBlock, Bottleneck)):
                    block_count += 1
                    layer = module.conv2
                    if block_count <= self.current_block_count and \
                            not isinstance(layer, MaskConv2d):
                        temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                               out_channels=layer.out_channels,
                                               kernel_size=layer.kernel_size,
                                               stride=layer.stride,
                                               padding=layer.padding,
                                               bias=(layer.bias is not None))
                        temp_conv.weight.data.copy_(layer.weight.data)
                        if layer.bias is not None:
                            temp_conv.bias.data.copy_(layer.bias.data)
                        module.conv2 = temp_conv

                    if isinstance(module, Bottleneck):
                        layer = module.conv3
                        if block_count <= self.current_block_count and \
                                not isinstance(layer, MaskConv2d):
                            temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                                   out_channels=layer.out_channels,
                                                   kernel_size=layer.kernel_size,
                                                   stride=layer.stride,
                                                   padding=layer.padding,
                                                   bias=(layer.bias is not None))
                            temp_conv.weight.data.copy_(layer.weight.data)
                            if layer.bias is not None:
                                temp_conv.bias.data.copy_(layer.bias.data)
                            module.conv3 = temp_conv
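    # Pivot bookkeeping: pivot_set lists the block indices after which the
    # auxiliary classifiers (and their losses) are attached, and the restart
    # logic in channel_selection_for_network() below uses current_pivot_index /
    # current_block_count to skip segments that a resumed run already finished.
    # As a hedged worked example (based on the _cal_pivot formula of the second
    # Experiment class in this file; this class delegates to a project-local
    # cal_pivot helper, so treat the exact numbers as an assumption): for a
    # 27-block preresnet-56 with n_losses = 2 there are 3 segments,
    # num_block_per_segment = 27 // 3 + 1 = 10, and pivot_set = [10, 20].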
    def channel_selection_for_network(self):
        """
        Conduct channel selection
        """
        # get the initial testing error
        self.segment_wise_trainer.val(0)
        time_start = time.time()

        restart_index = None
        # find the segment index to restart from
        if self.current_pivot_index:
            if self.current_pivot_index in self.settings.pivot_set:
                restart_index = self.settings.pivot_set.index(self.current_pivot_index)
            else:
                restart_index = len(self.settings.pivot_set)

        for index in range(self.num_segments):
            if restart_index is not None:
                if index < restart_index:
                    continue
                elif index == restart_index:
                    if self.current_block_count == self.current_pivot_index:
                        continue

            if index == self.num_segments - 1:
                self.current_pivot_index = self.segment_wise_trainer.final_block_count
            else:
                self.current_pivot_index = self.settings.pivot_set[index]

            # conduct channel selection on a segment that contains the
            # sub-segments [0:index]
            original_segment_list = []
            pruned_segment_list = []
            for j in range(index + 1):
                original_segment_list += utils.model2list(
                    self.segment_wise_trainer.ori_segments[j])
                pruned_segment_list += utils.model2list(
                    self.segment_wise_trainer.pruned_segments[j])

            original_segment = nn.Sequential(*original_segment_list)
            pruned_segment = nn.Sequential(*pruned_segment_list)

            net_pruned = self.channel_selection_for_one_segment(
                original_segment, pruned_segment,
                self.segment_wise_trainer.aux_fc[index],
                self.current_pivot_index, index)

            self.logger.info(self.segment_wise_trainer.pruned_segments)
            self.logger.info(net_pruned)
            self.logger.info(self.original_model)
            self.logger.info(self.pruned_model)
            self.segment_wise_trainer.val(0)

        self.current_pivot_index = None
        self.checkpoint.save_dcp_model(self.original_model, self.pruned_model,
                                       self.segment_wise_trainer.aux_fc,
                                       self.segment_wise_trainer.final_block_count,
                                       index=self.num_segments)

        self.pruning()
        self.checkpoint.save_dcp_model(self.original_model, self.pruned_model,
                                       self.segment_wise_trainer.aux_fc,
                                       self.segment_wise_trainer.final_block_count,
                                       index=self.num_segments + 1)

        time_interval = time.time() - time_start
        log_str = "cost time: {}".format(str(datetime.timedelta(seconds=time_interval)))
        self.logger.info(log_str)

    def channel_selection_for_resnet_like_segment(self, original_segment, pruned_segment,
                                                  aux_fc, pivot_index, index):
        """
        Conduct channel selection for one segment of a resnet-like network
        """
        block_count = 0
        for module in pruned_segment.modules():
            if isinstance(module, (PreBasicBlock, BasicBlock)):
                block_count += 1
                # we will not prune the pruned blocks again
                if not isinstance(module.conv2, MaskConv2d):
                    self.layer_channel_selection.channel_selection_for_one_layer(
                        original_segment, pruned_segment, aux_fc,
                        module, block_count, "conv2")
                    self.logger.info("|===>checking layer type: {}".format(
                        type(module.conv2)))
                    self.checkpoint.save_dcp_model(
                        self.original_model, self.pruned_model,
                        self.segment_wise_trainer.aux_fc, pivot_index,
                        index=index, block_count=block_count)
                    self.checkpoint.save_dcp_checkpoint(
                        self.original_model, self.pruned_model,
                        self.segment_wise_trainer.aux_fc,
                        self.segment_wise_trainer.fc_optimizer,
                        self.segment_wise_trainer.seg_optimizer,
                        pivot_index, index=index, block_count=block_count)
            elif isinstance(module, Bottleneck):
                block_count += 1
                if not isinstance(module.conv2, MaskConv2d):
                    self.layer_channel_selection.channel_selection_for_one_layer(
                        original_segment, pruned_segment, aux_fc,
                        module, block_count, "conv2")
                    self.checkpoint.save_dcp_model(
                        self.original_model, self.pruned_model,
                        self.segment_wise_trainer.aux_fc, pivot_index,
                        index=index, block_count=block_count)
                    self.checkpoint.save_dcp_checkpoint(
                        self.original_model, self.pruned_model,
                        self.segment_wise_trainer.aux_fc,
                        self.segment_wise_trainer.fc_optimizer,
                        self.segment_wise_trainer.seg_optimizer,
                        pivot_index, index=index, block_count=block_count)
                if not isinstance(module.conv3, MaskConv2d):
                    self.layer_channel_selection.channel_selection_for_one_layer(
                        original_segment, pruned_segment, aux_fc,
                        module, block_count, "conv3")
                    self.checkpoint.save_dcp_model(
                        self.original_model, self.pruned_model,
                        self.segment_wise_trainer.aux_fc, pivot_index,
                        index=index, block_count=block_count)
                    self.checkpoint.save_dcp_checkpoint(
                        self.original_model, self.pruned_model,
                        self.segment_wise_trainer.aux_fc,
                        self.segment_wise_trainer.fc_optimizer,
                        self.segment_wise_trainer.seg_optimizer,
                        pivot_index, index=index, block_count=block_count)

    def channel_selection_for_one_segment(self, original_segment, pruned_segment,
                                          aux_fc, pivot_index, index):
        """
        Conduct channel selection for one segment
        :param original_segment: original network segment
        :param pruned_segment: pruned network segment
        :param aux_fc: auxiliary fully-connected layer
        :param pivot_index: the block index of the additional loss
        :param index: the index of the segment
        """
        if self.settings.net_type in ["preresnet", "resnet"]:
            self.channel_selection_for_resnet_like_segment(
                original_segment, pruned_segment, aux_fc, pivot_index, index)
        return pruned_segment

    def pruning(self):
        """
        Prune channels
        """
        self.logger.info("Before pruning:")
        self.logger.info(self.pruned_model)
        self.segment_wise_trainer.val(0)

        model_analyse = ModelAnalyse(self.pruned_model, self.logger)
        params_num = model_analyse.params_count()
        zero_num = model_analyse.zero_count()
        zero_rate = zero_num * 1.0 / params_num
        self.logger.info("zero rate is: {}".format(zero_rate))
        model_analyse.madds_compute(self.test_input)

        if self.settings.net_type in ["preresnet", "resnet"]:
            model_prune = ResModelPrune(model=self.pruned_model,
                                        net_type=self.settings.net_type,
                                        depth=self.settings.depth)
        else:
            assert False, "unsupported net_type: {}".format(self.settings.net_type)

        model_prune.run()

        aux_fc_state = []
        for i in range(len(self.segment_wise_trainer.aux_fc)):
            if isinstance(self.segment_wise_trainer.aux_fc[i], nn.DataParallel):
                temp_state = self.segment_wise_trainer.aux_fc[i].module.state_dict()
            else:
                temp_state = self.segment_wise_trainer.aux_fc[i].state_dict()
            aux_fc_state.append(temp_state)

        self.segment_wise_trainer.update_model(self.original_model, self.pruned_model)

        self.logger.info("After pruning:")
        self.logger.info(self.pruned_model)
        self.segment_wise_trainer.val(0)

        model_analyse = ModelAnalyse(self.pruned_model, self.logger)
        params_num = model_analyse.params_count()
        zero_num = model_analyse.zero_count()
        zero_rate = zero_num * 1.0 / params_num
        self.logger.info("zero rate is: {}".format(zero_rate))
        model_analyse.madds_compute(self.test_input)
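# A minimal usage sketch for the pipeline class above (hypothetical entry
# point, not part of the original code): the Option config is assumed to
# provide every field referenced in Experiment (gpu, dataset, net_type, depth,
# n_classes, pretrained, resume, ...). Note that the second Experiment class
# below would shadow this one if both definitions were kept in a single
# module; in the source they live in separate files.
def run_dcp_experiment(conf_path):
    experiment = Experiment(conf_path=conf_path)
    experiment.channel_selection_for_network()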
class Experiment(object):
    """
    Run experiments with pre-defined pipeline
    """

    def __init__(self, options=None, conf_path=None):
        self.settings = options or Option(conf_path)
        self.checkpoint = None
        self.train_loader = None
        self.val_loader = None
        self.ori_model = None
        self.pruned_model = None
        self.segment_wise_trainer = None
        self.aux_fc_state = None
        self.aux_fc_opt_state = None
        self.seg_opt_state = None
        self.current_pivot_index = None
        self.is_segment_wise_finetune = False
        self.is_channel_selection = False
        self.epoch = 0
        self.feature_cache_origin = {}
        self.feature_cache_pruned = {}

        os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.gpu
        self.settings.set_save_path()
        self.write_settings()
        self.logger = self.set_logger()
        self.tensorboard_logger = TensorboardLogger(self.settings.save_path)

        self.prepare()

    def write_settings(self):
        """
        Save the experimental settings to a file
        """
        with open(os.path.join(self.settings.save_path, "settings.log"), "w") as f:
            for k, v in self.settings.__dict__.items():
                f.write(str(k) + ": " + str(v) + "\n")

    def set_logger(self):
        """
        Initialize the logger
        """
        logger = logging.getLogger('channel_selection')
        file_formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
        console_formatter = logging.Formatter('%(message)s')

        # file log
        file_handler = logging.FileHandler(
            os.path.join(self.settings.save_path, "train_test.log"))
        file_handler.setFormatter(file_formatter)

        # console log
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)
        logger.setLevel(logging.INFO)
        return logger

    def prepare(self):
        """
        Prepare the experiment
        """
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._cal_pivot()
        self._set_checkpoint()
        self._set_trainer()

    def _set_gpu(self):
        """
        Initialize the seed of the random number generator
        """
        torch.manual_seed(self.settings.seed)
        torch.cuda.manual_seed(self.settings.seed)
        torch.cuda.set_device(0)
        cudnn.benchmark = True

    def _set_dataloader(self):
        """
        Create train loader and validation loader for channel pruning
        """
        if self.settings.dataset == 'cifar10':
            data_root = os.path.join(self.settings.data_path, "cifar")
            norm_mean = [0.49139968, 0.48215827, 0.44653124]
            norm_std = [0.24703233, 0.24348505, 0.26158768]
            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(norm_mean, norm_std)
            ])
            val_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(norm_mean, norm_std)
            ])
            train_dataset = datasets.CIFAR10(root=data_root,
                                             train=True,
                                             transform=train_transform,
                                             download=True)
            val_dataset = datasets.CIFAR10(root=data_root,
                                           train=False,
                                           transform=val_transform)
            self.train_loader = torch.utils.data.DataLoader(
                dataset=train_dataset,
                batch_size=self.settings.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=self.settings.n_threads)
            self.val_loader = torch.utils.data.DataLoader(
                dataset=val_dataset,
                batch_size=self.settings.batch_size,
                shuffle=False,
                pin_memory=True,
                num_workers=self.settings.n_threads)
        elif self.settings.dataset == 'imagenet':
            dataset_path = os.path.join(self.settings.data_path, "imagenet")
            traindir = os.path.join(dataset_path, "train")
            valdir = os.path.join(dataset_path, 'val')
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            self.train_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(traindir, transforms.Compose([
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ])),
                batch_size=self.settings.batch_size,
                shuffle=True,
                num_workers=self.settings.n_threads,
                pin_memory=True)
            self.val_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(valdir, transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ])),
                batch_size=self.settings.batch_size,
                shuffle=False,
                num_workers=self.settings.n_threads,
                pin_memory=True)

    def _set_trainer(self):
        """
        Initialize the segment-wise trainer
        """
        self.segment_wise_trainer = SegmentWiseTrainer(
            ori_model=self.ori_model,
            pruned_model=self.pruned_model,
            train_loader=self.train_loader,
            val_loader=self.val_loader,
            settings=self.settings,
            logger=self.logger,
            tensorboard_logger=self.tensorboard_logger)
        if self.aux_fc_state is not None:
            self.segment_wise_trainer.update_aux_fc(self.aux_fc_state,
                                                    self.aux_fc_opt_state,
                                                    self.seg_opt_state)

    def _set_model(self):
        """
        Get model
        """
        if self.settings.dataset in ["cifar10", "cifar100"]:
            if self.settings.net_type == "preresnet":
                self.ori_model = md.PreResNet(depth=self.settings.depth,
                                              num_classes=self.settings.n_classes)
                self.pruned_model = md.PreResNet(depth=self.settings.depth,
                                                 num_classes=self.settings.n_classes)
            else:
                assert False, "use {} data while network is {}".format(
                    self.settings.dataset, self.settings.net_type)
        elif self.settings.dataset in ["imagenet"]:
            if self.settings.net_type == "resnet":
                self.ori_model = md.ResNet(depth=self.settings.depth,
                                           num_classes=self.settings.n_classes)
                self.pruned_model = md.ResNet(depth=self.settings.depth,
                                              num_classes=self.settings.n_classes)
            else:
                assert False, "use {} data while network is {}".format(
                    self.settings.dataset, self.settings.net_type)
        else:
            assert False, "unsupported dataset: {}".format(self.settings.dataset)

    def _set_checkpoint(self):
        """
        Load a pre-trained model or resume from a checkpoint
        """
        assert self.ori_model is not None and self.pruned_model is not None, \
            "please create model first"
        self.checkpoint = CheckPoint(self.settings.save_path, self.logger)
        self._load_retrain()
        self._load_resume()

    def _load_retrain(self):
        """
        Load a pre-trained model
        """
        if self.settings.retrain is not None:
            check_point_params = torch.load(self.settings.retrain)
            if "ori_model" not in check_point_params:
                model_state = check_point_params
                self.ori_model = self.checkpoint.load_state(self.ori_model, model_state)
                self.pruned_model = self.checkpoint.load_state(self.pruned_model, model_state)
                self.logger.info("|===>load retrain file: {}".format(self.settings.retrain))
            else:
                ori_model_state = check_point_params["ori_model"]
                pruned_model_state = check_point_params["pruned_model"]
                # self.current_block_count = check_point_params["current_pivot"]
                self.aux_fc_state = check_point_params["aux_fc"]
                # self.replace_layer()
                self.ori_model = self.checkpoint.load_state(self.ori_model, ori_model_state)
                self.pruned_model = self.checkpoint.load_state(self.pruned_model,
                                                               pruned_model_state)
                self.logger.info("|===>load pre-trained model: {}".format(
                    self.settings.retrain))

    def _load_resume(self):
        """
        Load a resume checkpoint
        """
        if self.settings.resume is not None:
            check_point_params = torch.load(self.settings.resume)
            ori_model_state = check_point_params["ori_model"]
            pruned_model_state = check_point_params["pruned_model"]
            self.aux_fc_state = check_point_params["aux_fc"]
            self.aux_fc_opt_state = check_point_params["aux_fc_opt"]
            self.seg_opt_state = check_point_params["seg_opt"]
            self.current_pivot_index = check_point_params["current_pivot"]
            self.is_segment_wise_finetune = check_point_params["segment_wise_finetune"]
            self.is_channel_selection = check_point_params["channel_selection"]
            self.epoch = check_point_params["epoch"]
            # note: this overrides the epoch restored above
            self.epoch = self.settings.segment_wise_n_epochs
            self.current_block_count = check_point_params["current_block_count"]

            if self.is_channel_selection or \
                    (self.is_segment_wise_finetune and
                     self.current_pivot_index > self.settings.pivot_set[0]):
                self.replace_layer()
            self.ori_model = self.checkpoint.load_state(self.ori_model, ori_model_state)
            self.pruned_model = self.checkpoint.load_state(self.pruned_model,
                                                           pruned_model_state)
            self.logger.info("|===>load resume file: {}".format(self.settings.resume))

    def _cal_pivot(self):
        """
        Calculate the blocks after which the additional losses are inserted
        """
        self.num_segments = self.settings.n_losses + 1
        num_block_per_segment = (block_num[self.settings.net_type +
                                           str(self.settings.depth)] //
                                 self.num_segments) + 1
        pivot_set = []
        for i in range(self.num_segments - 1):
            pivot_set.append(num_block_per_segment * (i + 1))
        self.settings.pivot_set = pivot_set
        self.logger.info("pivot set: {}".format(pivot_set))

    def segment_wise_fine_tune(self, index):
        """
        Conduct segment-wise fine-tuning
        :param index: segment index
        """
        best_top1 = 100
        best_top5 = 100
        start_epoch = 0
        if self.is_segment_wise_finetune and self.epoch != 0:
            start_epoch = self.epoch + 1
            self.epoch = 0
        for epoch in range(start_epoch, self.settings.segment_wise_n_epochs):
            train_error, train_loss, train5_error = self.segment_wise_trainer.train(
                epoch, index)
            val_error, val_loss, val5_error = self.segment_wise_trainer.val(epoch)

            # write and print the result
            if isinstance(train_error, list):
                best_flag = False
                if best_top1 >= val_error[-1]:
                    best_top1 = val_error[-1]
                    best_top5 = val5_error[-1]
                    best_flag = True
            else:
                best_flag = False
                if best_top1 >= val_error:
                    best_top1 = val_error
                    best_top5 = val5_error
                    best_flag = True

            if best_flag:
                self.checkpoint.save_model(
                    ori_model=self.ori_model,
                    pruned_model=self.pruned_model,
                    aux_fc=self.segment_wise_trainer.aux_fc,
                    current_pivot=self.current_pivot_index,
                    segment_wise_finetune=True,
                    index=index)

            self.logger.info(
                "|===>Best Result is: Top1 Error: {:f}, Top5 Error: {:f}\n".format(
                    best_top1, best_top5))
            self.logger.info(
                "|===>Best Result is: Top1 Accuracy: {:f}, Top5 Accuracy: {:f}\n".format(
                    100 - best_top1, 100 - best_top5))

            if self.settings.dataset in ["imagenet"]:
                self.checkpoint.save_checkpoint(
                    ori_model=self.ori_model,
                    pruned_model=self.pruned_model,
                    aux_fc=self.segment_wise_trainer.aux_fc,
                    aux_fc_opt=self.segment_wise_trainer.fc_optimizer,
                    seg_opt=self.segment_wise_trainer.seg_optimizer,
                    current_pivot=self.current_pivot_index,
                    segment_wise_finetune=True,
                    index=index,
                    epoch=epoch)
            else:
                self.checkpoint.save_checkpoint(
                    ori_model=self.ori_model,
                    pruned_model=self.pruned_model,
                    aux_fc=self.segment_wise_trainer.aux_fc,
                    aux_fc_opt=self.segment_wise_trainer.fc_optimizer,
                    seg_opt=self.segment_wise_trainer.seg_optimizer,
                    current_pivot=self.current_pivot_index,
                    segment_wise_finetune=True,
                    index=index)

    def replace_layer(self):
        """
        Replace convolutional layers with mask convolutional layers
        """
        block_count = 0
        if self.settings.net_type in ["preresnet", "resnet"]:
            for module in self.pruned_model.modules():
                if isinstance(module, (PreBasicBlock, BasicBlock, Bottleneck)):
                    block_count += 1
                    layer = module.conv2
                    if block_count <= self.current_block_count and \
                            not isinstance(layer, MaskConv2d):
                        temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                               out_channels=layer.out_channels,
                                               kernel_size=layer.kernel_size,
                                               stride=layer.stride,
                                               padding=layer.padding,
                                               bias=(layer.bias is not None))
                        temp_conv.weight.data.copy_(layer.weight.data)
                        if layer.bias is not None:
                            temp_conv.bias.data.copy_(layer.bias.data)
                        module.conv2 = temp_conv

                    if isinstance(module, Bottleneck):
                        layer = module.conv3
                        if block_count <= self.current_block_count and \
                                not isinstance(layer, MaskConv2d):
                            temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                                   out_channels=layer.out_channels,
                                                   kernel_size=layer.kernel_size,
                                                   stride=layer.stride,
                                                   padding=layer.padding,
                                                   bias=(layer.bias is not None))
                            temp_conv.weight.data.copy_(layer.weight.data)
                            if layer.bias is not None:
                                temp_conv.bias.data.copy_(layer.bias.data)
                            module.conv3 = temp_conv
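    # Note on the assumed MaskConv2d interface (the class is defined elsewhere
    # in the project): besides the usual Conv2d attributes it is expected to
    # carry (a) pruned_weight, a trainable copy of the weight tensor used
    # during selection, and (b) d, a 0/1 vector over input channels marking
    # which channels have been selected. channel_selection() and
    # _layer_channel_selection() below rely only on these two attributes.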
    def channel_selection(self):
        """
        Conduct channel selection
        """
        # get the initial testing error
        self.segment_wise_trainer.val(0)
        time_start = time.time()

        restart_index = None
        # find the segment index to restart from
        if self.current_pivot_index:
            if self.current_pivot_index in self.settings.pivot_set:
                restart_index = self.settings.pivot_set.index(self.current_pivot_index)
            else:
                restart_index = len(self.settings.pivot_set)

        for index in range(self.num_segments):
            if restart_index is not None:
                if index < restart_index:
                    continue
                elif index == restart_index:
                    if self.is_channel_selection and \
                            self.current_block_count == self.current_pivot_index:
                        self.is_channel_selection = False
                        continue

            if index == self.num_segments - 1:
                self.current_pivot_index = self.segment_wise_trainer.final_block_count
            else:
                self.current_pivot_index = self.settings.pivot_set[index]

            # fine-tune the network with the additional loss and the final loss
            if (not self.is_segment_wise_finetune and not self.is_channel_selection) or \
                    (self.is_segment_wise_finetune and
                     self.epoch != self.settings.segment_wise_n_epochs - 1):
                self.segment_wise_fine_tune(index)
            else:
                self.is_segment_wise_finetune = False
                # load the best model
                best_model_path = os.path.join(self.checkpoint.save_path,
                                               'model_{:0>3d}_swft.pth'.format(index))
                check_point_params = torch.load(best_model_path)
                ori_model_state = check_point_params["ori_model"]
                pruned_model_state = check_point_params["pruned_model"]
                aux_fc_state = check_point_params["aux_fc"]
                self.ori_model = self.checkpoint.load_state(self.ori_model,
                                                            ori_model_state)
                self.pruned_model = self.checkpoint.load_state(self.pruned_model,
                                                               pruned_model_state)
                self.segment_wise_trainer.update_model(self.ori_model,
                                                       self.pruned_model,
                                                       aux_fc_state)

            # replace the baseline model
            if index == 0:
                if self.settings.net_type in ['preresnet']:
                    self.ori_model.conv = copy.deepcopy(self.pruned_model.conv)
                    for ori_module, pruned_module in zip(self.ori_model.modules(),
                                                         self.pruned_model.modules()):
                        if isinstance(ori_module, PreBasicBlock):
                            ori_module.bn1 = copy.deepcopy(pruned_module.bn1)
                            ori_module.bn2 = copy.deepcopy(pruned_module.bn2)
                            ori_module.conv1 = copy.deepcopy(pruned_module.conv1)
                            ori_module.conv2 = copy.deepcopy(pruned_module.conv2)
                            if ori_module.downsample is not None:
                                ori_module.downsample = copy.deepcopy(
                                    pruned_module.downsample)
                    self.ori_model.bn = copy.deepcopy(self.pruned_model.bn)
                    self.ori_model.fc = copy.deepcopy(self.pruned_model.fc)
                elif self.settings.net_type in ['resnet']:
                    self.ori_model.conv1 = copy.deepcopy(self.pruned_model.conv1)
                    self.ori_model.bn1 = copy.deepcopy(self.pruned_model.bn1)
                    for ori_module, pruned_module in zip(self.ori_model.modules(),
                                                         self.pruned_model.modules()):
                        if isinstance(ori_module, BasicBlock):
                            ori_module.conv1 = copy.deepcopy(pruned_module.conv1)
                            ori_module.conv2 = copy.deepcopy(pruned_module.conv2)
                            ori_module.bn1 = copy.deepcopy(pruned_module.bn1)
                            ori_module.bn2 = copy.deepcopy(pruned_module.bn2)
                            if ori_module.downsample is not None:
                                ori_module.downsample = copy.deepcopy(
                                    pruned_module.downsample)
                        if isinstance(ori_module, Bottleneck):
                            ori_module.conv1 = copy.deepcopy(pruned_module.conv1)
                            ori_module.conv2 = copy.deepcopy(pruned_module.conv2)
                            ori_module.conv3 = copy.deepcopy(pruned_module.conv3)
                            ori_module.bn1 = copy.deepcopy(pruned_module.bn1)
                            ori_module.bn2 = copy.deepcopy(pruned_module.bn2)
                            ori_module.bn3 = copy.deepcopy(pruned_module.bn3)
                            if ori_module.downsample is not None:
                                ori_module.downsample = copy.deepcopy(
                                    pruned_module.downsample)
                    self.ori_model.fc = copy.deepcopy(self.pruned_model.fc)

                aux_fc_state = []
                for i in range(len(self.segment_wise_trainer.aux_fc)):
                    if isinstance(self.segment_wise_trainer.aux_fc[i], nn.DataParallel):
                        temp_state = self.segment_wise_trainer.aux_fc[i].module.state_dict()
                    else:
                        temp_state = self.segment_wise_trainer.aux_fc[i].state_dict()
                    aux_fc_state.append(temp_state)

                self.segment_wise_trainer.update_model(self.ori_model,
                                                       self.pruned_model,
                                                       aux_fc_state)
                self.segment_wise_trainer.val(0)

            # conduct channel selection on a segment that contains the
            # sub-segments [0:index]
            net_origin_list = []
            net_pruned_list = []
            for j in range(index + 1):
                net_origin_list += utils.model2list(
                    self.segment_wise_trainer.ori_segments[j])
                net_pruned_list += utils.model2list(
                    self.segment_wise_trainer.pruned_segments[j])

            net_origin = nn.Sequential(*net_origin_list)
            net_pruned = nn.Sequential(*net_pruned_list)

            self._seg_channel_selection(net_origin=net_origin,
                                        net_pruned=net_pruned,
                                        aux_fc=self.segment_wise_trainer.aux_fc[index],
                                        pivot_index=self.current_pivot_index,
                                        index=index)

            # update the optimizer
            aux_fc_state = []
            for i in range(len(self.segment_wise_trainer.aux_fc)):
                if isinstance(self.segment_wise_trainer.aux_fc[i], nn.DataParallel):
                    temp_state = self.segment_wise_trainer.aux_fc[i].module.state_dict()
                else:
                    temp_state = self.segment_wise_trainer.aux_fc[i].state_dict()
                aux_fc_state.append(temp_state)

            self.segment_wise_trainer.update_model(self.ori_model, self.pruned_model,
                                                   aux_fc_state)

            self.checkpoint.save_checkpoint(
                self.ori_model, self.pruned_model,
                self.segment_wise_trainer.aux_fc,
                self.segment_wise_trainer.fc_optimizer,
                self.segment_wise_trainer.seg_optimizer,
                self.current_pivot_index,
                channel_selection=True,
                index=index,
                block_count=self.current_pivot_index)

            self.logger.info(self.ori_model)
            self.logger.info(self.pruned_model)
            self.segment_wise_trainer.val(0)

        self.current_pivot_index = None
        self.checkpoint.save_model(self.ori_model, self.pruned_model,
                                   self.segment_wise_trainer.aux_fc,
                                   self.segment_wise_trainer.final_block_count,
                                   index=self.num_segments)

        time_interval = time.time() - time_start
        log_str = "cost time: {}".format(str(datetime.timedelta(seconds=time_interval)))
        self.logger.info(log_str)

    def _hook_origin_feature(self, module, input, output):
        gpu_id = str(output.get_device())
        self.feature_cache_origin[gpu_id] = output

    def _hook_pruned_feature(self, module, input, output):
        gpu_id = str(output.get_device())
        self.feature_cache_pruned[gpu_id] = output

    @staticmethod
    def _concat_gpu_data(data):
        data_cat = data["0"]
        for i in range(1, len(data)):
            data_cat = torch.cat((data_cat, data[str(i)].cuda(0)))
        return data_cat

    def _layer_channel_selection(self, net_origin, net_pruned, aux_fc,
                                 module, block_count, layer_name="conv2"):
        """
        Conduct channel selection for a module
        :param net_origin: original network segments
        :param net_pruned: pruned network segments
        :param aux_fc: auxiliary fully-connected layer
        :param module: the module to be pruned
        :param block_count: current block number
        :param layer_name: the name of the layer to be pruned
        """
        self.logger.info("|===>layer-wise channel selection: block-{}-{}".format(
            block_count, layer_name))

        # layer-wise channel selection
        if layer_name == "conv2":
            layer = module.conv2
        elif layer_name == "conv3":
            layer = module.conv3
        else:
            assert False, "unsupported layer: {}".format(layer_name)

        if not isinstance(layer, MaskConv2d):
            temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                   out_channels=layer.out_channels,
                                   kernel_size=layer.kernel_size,
                                   stride=layer.stride,
                                   padding=layer.padding,
                                   bias=(layer.bias is not None))
            temp_conv.weight.data.copy_(layer.weight.data)
            if layer.bias is not None:
                temp_conv.bias.data.copy_(layer.bias.data)
            temp_conv.pruned_weight.data.fill_(0)
            temp_conv.d.fill_(0)
            if layer_name == "conv2":
                module.conv2 = temp_conv
            elif layer_name == "conv3":
                module.conv3 = temp_conv
            layer = temp_conv

        # define the criteria
        criterion_mse = nn.MSELoss().cuda()
        criterion_softmax = nn.CrossEntropyLoss().cuda()

        # register hooks to cache the features of the layer to be pruned
        if layer_name == "conv2":
            hook_origin = net_origin[block_count].conv2.register_forward_hook(
                self._hook_origin_feature)
            hook_pruned = module.conv2.register_forward_hook(
                self._hook_pruned_feature)
        elif layer_name == "conv3":
            hook_origin = net_origin[block_count].conv3.register_forward_hook(
                self._hook_origin_feature)
            hook_pruned = module.conv3.register_forward_hook(
                self._hook_pruned_feature)

        net_origin_parallel = utils.data_parallel(net_origin, self.settings.n_gpus)
        net_pruned_parallel = utils.data_parallel(net_pruned, self.settings.n_gpus)

        # avoid computing gradients for the frozen parameters
        for params in net_origin_parallel.parameters():
            params.requires_grad = False
        for params in net_pruned_parallel.parameters():
            params.requires_grad = False

        net_origin_parallel.eval()
        net_pruned_parallel.eval()

        layer.pruned_weight.requires_grad = True
        aux_fc.cuda()

        logger_counter = 0
        record_time = utils.AverageMeter()
        for channel in range(layer.in_channels):
            # stop once the target pruning rate is reached
            if layer.d.eq(0).sum() <= math.floor(
                    layer.in_channels * self.settings.pruning_rate):
                break
            time_start = time.time()
            cum_grad = None
            record_selection_mse_loss = utils.AverageMeter()
            record_selection_softmax_loss = utils.AverageMeter()
            record_selection_loss = utils.AverageMeter()
            img_count = 0
            for i, (images, labels) in enumerate(self.train_loader):
                images = images.cuda()
                labels = labels.cuda()
                net_origin_parallel(images)
                output = net_pruned_parallel(images)
                softmax_loss = criterion_softmax(aux_fc(output), labels)

                origin_feature = self._concat_gpu_data(self.feature_cache_origin)
                self.feature_cache_origin = {}
                pruned_feature = self._concat_gpu_data(self.feature_cache_pruned)
                self.feature_cache_pruned = {}

                mse_loss = criterion_mse(pruned_feature, origin_feature)
                loss = mse_loss * self.settings.mse_weight + \
                    softmax_loss * self.settings.softmax_weight
                loss.backward()

                record_selection_loss.update(loss.item(), images.size(0))
                record_selection_mse_loss.update(mse_loss.item(), images.size(0))
                record_selection_softmax_loss.update(softmax_loss.item(), images.size(0))

                # accumulate the gradient of pruned_weight over batches
                if cum_grad is None:
                    cum_grad = layer.pruned_weight.grad.data.clone()
                else:
                    cum_grad.add_(layer.pruned_weight.grad.data)
                layer.pruned_weight.grad = None

                img_count += images.size(0)
                if self.settings.max_samples != -1 and \
                        img_count >= self.settings.max_samples:
                    break

            # write tensorboard log
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_LossAll".format(block_count, layer_name),
                value=record_selection_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_MSELoss".format(block_count, layer_name),
                value=record_selection_mse_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_SoftmaxLoss".format(block_count, layer_name),
                value=record_selection_softmax_loss.avg,
                step=logger_counter)

            cum_grad.abs_()
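            # greedy selection step: for every input channel c, reduce the
            # accumulated gradient to a single score
            #     score[c] = sum_n ||cum_grad[n, c, :, :]||_F
            # (Frobenius norm over the kernel window, summed over output
            # filters n), then pick the highest-scoring channel that has not
            # been selected yet and unfreeze it by copying its original
            # weights into pruned_weight.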
            # calculate the gradient Frobenius norm per input channel
            grad_fnorm = cum_grad.mul(cum_grad).sum((2, 3)).sqrt().sum(0)

            # select the not-yet-chosen channel with the maximum gradient norm
            while True:
                _, max_index = torch.topk(grad_fnorm, 1)
                if layer.d[max_index[0]] == 0:
                    layer.d[max_index[0]] = 1
                    layer.pruned_weight.data[:, max_index[0], :, :] = \
                        layer.weight[:, max_index[0], :, :].data.clone()
                    break
                else:
                    grad_fnorm[max_index[0]] = -1

            # fine-tune average meters
            record_finetune_softmax_loss = utils.AverageMeter()
            record_finetune_mse_loss = utils.AverageMeter()
            record_finetune_loss = utils.AverageMeter()
            record_finetune_top1_error = utils.AverageMeter()
            record_finetune_top5_error = utils.AverageMeter()

            # define the optimizer
            params_list = []
            params_list.append({"params": layer.pruned_weight,
                                "lr": self.settings.layer_wise_lr})
            if layer.bias is not None:
                layer.bias.requires_grad = True
                params_list.append({"params": layer.bias, "lr": 0.001})
            optimizer = torch.optim.SGD(params=params_list,
                                        weight_decay=self.settings.weight_decay,
                                        momentum=self.settings.momentum,
                                        nesterov=True)

            img_count = 0
            for epoch in range(1):
                for i, (images, labels) in enumerate(self.train_loader):
                    images = images.cuda()
                    labels = labels.cuda()

                    features = net_pruned_parallel(images)
                    net_origin_parallel(images)
                    output = aux_fc(features)
                    softmax_loss = criterion_softmax(output, labels)

                    origin_feature = self._concat_gpu_data(self.feature_cache_origin)
                    self.feature_cache_origin = {}
                    pruned_feature = self._concat_gpu_data(self.feature_cache_pruned)
                    self.feature_cache_pruned = {}
                    mse_loss = criterion_mse(pruned_feature, origin_feature)

                    top1_error, _, top5_error = utils.compute_singlecrop(
                        outputs=output, labels=labels, loss=softmax_loss,
                        top5_flag=True, mean_flag=True)

                    # update the parameters
                    optimizer.zero_grad()
                    loss = mse_loss * self.settings.mse_weight + \
                        softmax_loss * self.settings.softmax_weight
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=10.0)
                    # zero the gradients of unselected channels so that only
                    # the selected channels are updated
                    layer.pruned_weight.grad.data.mul_(
                        layer.d.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(
                            layer.pruned_weight))
                    optimizer.step()

                    # update the record info
                    record_finetune_softmax_loss.update(softmax_loss.item(),
                                                        images.size(0))
                    record_finetune_mse_loss.update(mse_loss.item(), images.size(0))
                    record_finetune_loss.update(loss.item(), images.size(0))
                    record_finetune_top1_error.update(top1_error, images.size(0))
                    record_finetune_top5_error.update(top5_error, images.size(0))

                    img_count += images.size(0)
                    if self.settings.max_samples != -1 and \
                            img_count >= self.settings.max_samples:
                        break

            layer.pruned_weight.grad = None
            if layer.bias is not None:
                layer.bias.requires_grad = False

            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_SoftmaxLoss".format(block_count, layer_name),
                value=record_finetune_softmax_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Loss".format(block_count, layer_name),
                value=record_finetune_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_MSELoss".format(block_count, layer_name),
                value=record_finetune_mse_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Top1Error".format(block_count, layer_name),
                value=record_finetune_top1_error.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Top5Error".format(block_count, layer_name),
                value=record_finetune_top5_error.avg,
                step=logger_counter)

            # write log information to file
            self._write_log(
                dir_name=os.path.join(self.settings.save_path, "log"),
                file_name="log_block-{:0>2d}_{}.txt".format(block_count, layer_name),
                log_str="{:d}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t\n".format(
                    int(layer.d.sum()),
                    record_selection_loss.avg,
                    record_selection_mse_loss.avg,
                    record_selection_softmax_loss.avg,
                    record_finetune_loss.avg,
                    record_finetune_mse_loss.avg,
                    record_finetune_softmax_loss.avg,
                    record_finetune_top1_error.avg,
                    record_finetune_top5_error.avg))

            log_str = "Block-{:0>2d}-{}\t#channels: [{:0>4d}|{:0>4d}]\t".format(
                block_count, layer_name, int(layer.d.sum()), layer.d.size(0))
            log_str += "[selection]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
                record_selection_loss.avg,
                record_selection_mse_loss.avg,
                record_selection_softmax_loss.avg)
            log_str += "[fine-tuning]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
                record_finetune_loss.avg,
                record_finetune_mse_loss.avg,
                record_finetune_softmax_loss.avg)
            log_str += "top1error: {:4f}\ttop5error: {:4f}".format(
                record_finetune_top1_error.avg,
                record_finetune_top5_error.avg)
            self.logger.info(log_str)

            logger_counter += 1
            time_interval = time.time() - time_start
            record_time.update(time_interval)

        for params in net_origin_parallel.parameters():
            params.requires_grad = True
        for params in net_pruned_parallel.parameters():
            params.requires_grad = True

        # remove the hooks
        hook_origin.remove()
        hook_pruned.remove()
        log_str = "|===>Select channel from block-{:d}_{}: time_total: {} time_avg: {}".format(
            block_count, layer_name,
            str(datetime.timedelta(seconds=record_time.sum)),
            str(datetime.timedelta(seconds=record_time.avg)))
        self.logger.info(log_str)
        log_str = "|===>fine-tuning result: loss: {:f}, mse_loss: {:f}, " \
                  "softmax_loss: {:f}, top1error: {:f} top5error: {:f}".format(
                      record_finetune_loss.avg,
                      record_finetune_mse_loss.avg,
                      record_finetune_softmax_loss.avg,
                      record_finetune_top1_error.avg,
                      record_finetune_top5_error.avg)
        self.logger.info(log_str)
        self.logger.info("|===>remove hook")

    @staticmethod
    def _write_log(dir_name, file_name, log_str):
        """
        Write log to file
        :param dir_name: the path of the directory
        :param file_name: the name of the saved file
        :param log_str: the string to be saved
        """
        if not os.path.isdir(dir_name):
            os.mkdir(dir_name)
        with open(os.path.join(dir_name, file_name), "a+") as f:
            f.write(log_str)

    def _seg_channel_selection(self, net_origin, net_pruned, aux_fc, pivot_index, index):
        """
        Conduct channel selection for a segment
        :param net_origin: original network segments
        :param net_pruned: pruned network segments
        :param aux_fc: auxiliary fully-connected layer
        :param pivot_index: the block index of the additional loss
        :param index: the index of the segment
        """
        block_count = 0
        if self.settings.net_type in ["preresnet", "resnet"]:
            for module in net_pruned.modules():
                if isinstance(module, (PreBasicBlock, BasicBlock)):
                    block_count += 1
                    # we will not prune the pruned blocks again
                    if not isinstance(module.conv2, MaskConv2d):
                        self._layer_channel_selection(net_origin=net_origin,
                                                      net_pruned=net_pruned,
                                                      aux_fc=aux_fc,
                                                      module=module,
                                                      block_count=block_count,
                                                      layer_name="conv2")
                        self.logger.info("|===>checking layer type: {}".format(
                            type(module.conv2)))
                        self.checkpoint.save_model(
                            self.ori_model, self.pruned_model,
                            self.segment_wise_trainer.aux_fc, pivot_index,
                            channel_selection=True, index=index,
                            block_count=block_count)
                        self.checkpoint.save_checkpoint(
                            self.ori_model, self.pruned_model,
                            self.segment_wise_trainer.aux_fc,
                            self.segment_wise_trainer.fc_optimizer,
                            self.segment_wise_trainer.seg_optimizer,
                            pivot_index, channel_selection=True,
                            index=index, block_count=block_count)
                elif isinstance(module, Bottleneck):
                    block_count += 1
                    if not isinstance(module.conv2, MaskConv2d):
                        self._layer_channel_selection(net_origin=net_origin,
                                                      net_pruned=net_pruned,
                                                      aux_fc=aux_fc,
                                                      module=module,
                                                      block_count=block_count,
                                                      layer_name="conv2")
                    if not isinstance(module.conv3, MaskConv2d):
                        self._layer_channel_selection(net_origin=net_origin,
                                                      net_pruned=net_pruned,
                                                      aux_fc=aux_fc,
                                                      module=module,
                                                      block_count=block_count,
                                                      layer_name="conv3")
                    self.checkpoint.save_model(
                        self.ori_model, self.pruned_model,
                        self.segment_wise_trainer.aux_fc, pivot_index,
                        channel_selection=True, index=index,
                        block_count=block_count)
                    self.checkpoint.save_checkpoint(
                        self.ori_model, self.pruned_model,
                        self.segment_wise_trainer.aux_fc,
                        self.segment_wise_trainer.fc_optimizer,
                        self.segment_wise_trainer.seg_optimizer,
                        pivot_index, channel_selection=True,
                        index=index, block_count=block_count)
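# A minimal usage sketch for the self-contained variant above (hypothetical
# entry point, not part of the original code). Its top-level driver is
# channel_selection(); unlike the first class, this variant has no separate
# pruning() stage.
def run_channel_selection(conf_path):
    experiment = Experiment(conf_path=conf_path)
    experiment.channel_selection()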