def dynamic_prune_inactive_neural_with_feature_extractor(net,
                                                         net_name,
                                                         exp_name,
                                                         target_accuracy,
                                                         initial_prune_rate,
                                                         delta_prune_rate=0.02,
                                                         round_for_train=2,
                                                         tar_acc_gradual_decent=False,
                                                         flop_expected=None,
                                                         dataset_name='imagenet',
                                                         batch_size=conf.batch_size,
                                                         num_workers=conf.num_workers,
                                                         learning_rate=0.01,
                                                         evaluate_step=1200,
                                                         num_epoch=450,
                                                         num_epoch_after_finetune=20,
                                                         filter_preserve_ratio=0.3,
                                                         # max_filters_pruned_for_one_time=0.5,
                                                         optimizer=optim.Adam,
                                                         learning_rate_decay=False,
                                                         learning_rate_decay_factor=conf.learning_rate_decay_factor,
                                                         weight_decay=conf.weight_decay,
                                                         learning_rate_decay_epoch=conf.learning_rate_decay_epoch,
                                                         max_training_round=9999,
                                                         round=1,
                                                         top_acc=1,
                                                         max_data_to_test=10000,
                                                         **kwargs):
    '''
    :param net: net to be pruned
    :param net_name: name of the net
    :param exp_name: name of the experiment
    :param target_accuracy: accuracy the pruned net must reach before the next pruning round starts
    :param initial_prune_rate: prune rate used in the first round
    :param delta_prune_rate: increase/decrease of prune_rate after each round
    :param round_for_train: number of rounds in which inactive-filter statistics are collected and saved (to train the feature extractor)
    :param tar_acc_gradual_decent: if True, the target accuracy is lowered gradually in proportion to the FLOPs already reduced
    :param flop_expected: expected FLOPs of the pruned net; an absolute count, or a ratio (<=1) of the original FLOPs
    :param dataset_name:
    :param batch_size:
    :param num_workers:
    :param optimizer:
    :param learning_rate:
    :param evaluate_step:
    :param num_epoch: max number of epochs for retraining after each pruning round
    :param num_epoch_after_finetune: epochs to train after the pruned net has been fine-tuned
    :param filter_preserve_ratio: lower bound on the fraction of filters preserved in each layer
    :param max_filters_pruned_for_one_time:
    :param learning_rate_decay:
    :param learning_rate_decay_factor:
    :param weight_decay:
    :param learning_rate_decay_epoch:
    :param max_training_round: if the net can't reach the target accuracy within max_training_round training attempts,
                               prune_rate is reduced and the round restarts; pruning stops entirely once prune_rate drops to 0
    :param top_acc: can be 1 or 5
    :param kwargs:
    :return:
    '''
    # save the output to log
    print('save log in:' + os.path.join(conf.root_path, 'model_saved', exp_name, 'log.txt'))
    if not os.path.exists(os.path.join(conf.root_path, 'model_saved', exp_name)):
        os.makedirs(os.path.join(conf.root_path, 'model_saved', exp_name), exist_ok=True)
    sys.stdout = logger.Logger(os.path.join(conf.root_path, 'model_saved', exp_name, 'log.txt'), sys.stdout)
    sys.stderr = logger.Logger(os.path.join(conf.root_path, 'model_saved', exp_name, 'log.txt'), sys.stderr)  # redirect stderr, if necessary

    print('net:', net)
    print('net_name:', net_name)
    print('exp_name:', exp_name)
    print('target_accuracy:', target_accuracy)
    print('initial_prune_rate:', initial_prune_rate)
    print('delta_prune_rate:', delta_prune_rate)
    print('round_for_train:', round_for_train)
    print('tar_acc_gradual_decent:', tar_acc_gradual_decent)
    print('flop_expected:', flop_expected)
    print('dataset_name:', dataset_name)
    print('batch_size:', batch_size)
    print('num_workers:', num_workers)
    print('optimizer:', optimizer)
    print('learning_rate:', learning_rate)
    print('evaluate_step:', evaluate_step)
    print('num_epoch:', num_epoch)
    print('num_epoch_after_finetune:', num_epoch_after_finetune)
    print('filter_preserve_ratio:', filter_preserve_ratio)
    # print('max_filters_pruned_for_one_time:', max_filters_pruned_for_one_time)
    print('learning_rate_decay:', learning_rate_decay)
    print('learning_rate_decay_factor:', learning_rate_decay_factor)
    print('weight_decay:', weight_decay)
    print('learning_rate_decay_epoch:', learning_rate_decay_epoch)
    print('max_training_round:', max_training_round)
    print('top_acc:', top_acc)
    print('round:', round)
    print(kwargs)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('using: ', end='')
    if torch.cuda.is_available():
        print(torch.cuda.device_count(), ' * ', end='')
        print(torch.cuda.get_device_name(torch.cuda.current_device()))
    else:
        print(device)

    checkpoint_path = os.path.join(conf.root_path, 'model_saved', exp_name)
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path, exist_ok=True)

    validation_loader = data_loader.create_validation_loader(batch_size=batch_size,
                                                             num_workers=num_workers,
                                                             dataset_name=dataset_name)

    if isinstance(net, nn.DataParallel):
        net_entity = net.module
    else:
        net_entity = net

    flop_original_net = flop_before_prune = measure_model(net_entity.prune(), dataset_name)
    original_accuracy = evaluate.evaluate_net(net=net,
                                              data_loader=validation_loader,
                                              save_net=False,
                                              dataset_name=dataset_name,
                                              top_acc=top_acc)
    if tar_acc_gradual_decent is True:
        if flop_expected <= 1:
            # a ratio was passed in; convert it to an absolute FLOP count
            flop_expected = int(flop_original_net * flop_expected)
        flop_drop_expected = flop_original_net - flop_expected
        acc_drop_tolerance = original_accuracy - target_accuracy

    conv_list, filter_num, filter_num_lower_bound = get_information_for_pruned_conv(net_entity, net_name, filter_preserve_ratio)

    num_filters_to_prune_at_most = []
    for i in range(len(filter_num)):
        num_filters_to_prune_at_most += [filter_num[i] - filter_num_lower_bound[i]]

    prune_rate = initial_prune_rate
    while True:
        print('{} start round {} of filter pruning.'.format(datetime.now(), round))
        print('{} current prune_rate:{}'.format(datetime.now(), prune_rate))
        # todo: to be reconsidered
        net_entity.reset_mask()
        if round <= round_for_train:
            dead_filter_index, module_list, neural_list, FIRE_tmp = evaluate.find_useless_filters_data_version(
                net=net_entity,
                batch_size=16,  # this function needs to run on a single gpu
                percent_of_inactive_filter=prune_rate,
                dead_or_inactive='inactive',
                dataset_name=dataset_name,
                max_data_to_test=max_data_to_test,
                num_filters_to_prune_at_most=num_filters_to_prune_at_most,
            )
            num_test_images = math.ceil(max_data_to_test / 16) * 16  # warning: this is wrong if the dataset has fewer images than max_data_to_test

            if not os.path.exists(os.path.join(checkpoint_path, 'dead_neural')):
                os.makedirs(os.path.join(checkpoint_path, 'dead_neural'), exist_ok=True)

            checkpoint = {'prune_rate': prune_rate,
                          'module_list': module_list,
                          'neural_list': neural_list,
                          'state_dict': net_entity.state_dict(),
                          'num_test_images': num_test_images}
            checkpoint.update(storage.get_net_information(net_entity, dataset_name, net_name))
            torch.save(checkpoint, os.path.join(checkpoint_path, 'dead_neural/round %d.tar' % round))
        else:
            # todo: for round > round_for_train, use the filter feature extractor
            print('feature-extractor-based pruning is not implemented yet')
            return

        # prune the convolution filters
        for i in conv_list:
            # todo: to be reconsidered
            # # ensure the number of filters pruned will not be too large for one time
            # if type(max_filters_pruned_for_one_time) is list:
            #     num_filters_to_prune_max = filter_num[i] * max_filters_pruned_for_one_time[i]
            # else:
            #     num_filters_to_prune_max = filter_num[i] * max_filters_pruned_for_one_time
            # if num_filters_to_prune_max < len(dead_filter_index[i]):
            #     dead_filter_index[i] = dead_filter_index[i][:int(num_filters_to_prune_max)]

            # ensure the lower bound of filter number
            if filter_num[i] - len(dead_filter_index[i]) < filter_num_lower_bound[i]:
                raise Exception('Layer %d would drop below its filter lower bound.' % i)
            dead_filter_index[i] = dead_filter_index[i][:filter_num[i] - filter_num_lower_bound[i]]
            print('layer {}: has {} filters, prunes {} filters, remains {} filters.'.
                  format(i, filter_num[i], len(dead_filter_index[i]), filter_num[i] - len(dead_filter_index[i])))
            net_entity.mask_filters(i, dead_filter_index[i])

        flop_after_prune = measure_model(net_entity.prune(), dataset_name)
        net_compressed = (flop_after_prune != flop_before_prune)
        if net_compressed is False:
            round -= 1
            print('{} round {} did not prune any filters. Restart.'.format(datetime.now(), round + 1))
            prune_rate += 0.02
            continue

        if tar_acc_gradual_decent is True:
            # lower the target accuracy in proportion to the FLOPs reduced so far
            flop_reduced = flop_original_net - flop_after_prune
            target_accuracy = original_accuracy - acc_drop_tolerance * (flop_reduced / flop_drop_expected)
            print('{} current target accuracy:{}'.format(datetime.now(), target_accuracy))

        success = False
        training_round = 0
        while not success:
            old_net = deepcopy(net)
            training_round += 1
            success = train.train(net=net,
                                  net_name=net_name,
                                  exp_name=exp_name,
                                  num_epochs=num_epoch,
                                  target_accuracy=target_accuracy,
                                  learning_rate=learning_rate,
                                  load_net=False,
                                  evaluate_step=evaluate_step,
                                  dataset_name=dataset_name,
                                  optimizer=optimizer,
                                  batch_size=batch_size,
                                  learning_rate_decay=learning_rate_decay,
                                  learning_rate_decay_factor=learning_rate_decay_factor,
                                  weight_decay=weight_decay,
                                  learning_rate_decay_epoch=learning_rate_decay_epoch,
                                  test_net=True,
                                  top_acc=top_acc,
                                  no_grad=['mask'])
            if success:
                prune_rate += delta_prune_rate
                round += 1
                flop_before_prune = flop_after_prune
                net_entity.reset_mask()
                if num_epoch_after_finetune == 0:
                    break
                train.train(net=net,
                            net_name=net_name,
                            exp_name=exp_name,
                            num_epochs=num_epoch_after_finetune,
                            learning_rate=learning_rate,
                            load_net=False,
                            evaluate_step=evaluate_step,
                            dataset_name=dataset_name,
                            optimizer=optimizer,
                            batch_size=batch_size,
                            test_net=True,
                            top_acc=top_acc,
                            no_grad=['mask'])
            else:
                net = old_net
                if isinstance(net, nn.DataParallel):
                    net_entity = net.module
                else:
                    net_entity = net
                net_entity.reset_mask()
                if max_training_round == training_round:
                    prune_rate -= delta_prune_rate / 2
                    if prune_rate <= 0:
                        print('{} failed to prune the net, pruning stops.'.format(datetime.now()))
                        return
                    break
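# A minimal usage sketch of the pruning driver above. Assumptions: the NetWithMask
# wrapper (used in the __main__ blocks elsewhere in this repo) is importable here,
# and the hyper-parameter values below are purely illustrative.
def _demo_dynamic_prune():
    net = NetWithMask(dataset_name='cifar10', net_name='vgg16_bn')
    dynamic_prune_inactive_neural_with_feature_extractor(
        net=net,
        net_name='vgg16_bn',
        exp_name='demo_dynamic_prune_vgg16bn_cifar10',
        target_accuracy=0.93,
        initial_prune_rate=0.1,
        round_for_train=2,
        tar_acc_gradual_decent=True,
        flop_expected=0.5,  # <=1, so interpreted as a ratio of the original FLOPs
        dataset_name='cifar10')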
                              groups=conv.groups,
                              bias=(conv.bias is not None))
        self.weight = conv.weight
        if self.bias is not None:
            self.bias = conv.bias
        self.mask = mask
        self.front_mask = front_mask

    def forward(self, input):
        # mask the pruned filters (output channels) and the pruned input channels
        masked_weight = self.weight * self.mask.detach().reshape((-1, 1, 1, 1)) \
                        * self.front_mask.detach().reshape((1, -1, 1, 1))
        # a pruned filter's bias must be masked as well; guard against bias=None
        masked_bias = self.bias * self.mask.detach() if self.bias is not None else None
        out = F.conv2d(input, masked_weight, masked_bias, self.stride,
                       self.padding, self.dilation, self.groups)
        return out


if __name__ == "__main__":
    net = NetWithMask(dataset_name='imagenet', net_name='vgg16_bn')
    # tmp = net.prune()
    # net.mask_filters(layer_index=0, filter_index=[0, 2, 5])
    from framework import evaluate, data_loader
    dl = data_loader.create_validation_loader(batch_size=512, num_workers=2, dataset_name='cifar10')
    evaluate.evaluate_net(net, dl, save_net=False)
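# A self-contained sanity check (illustrative helper, not part of the original
# pipeline): zeroing entry j of `mask` must zero output channel j, which is
# exactly the computation forward() performs above.
def _mask_sanity_check():
    import torch
    import torch.nn.functional as F
    weight = torch.randn(4, 3, 3, 3)        # 4 filters over 3 input channels
    mask = torch.tensor([0., 1., 1., 1.])   # filter 0 is "pruned"
    front_mask = torch.ones(3)              # no input channel is pruned
    masked_weight = weight * mask.reshape((-1, 1, 1, 1)) * front_mask.reshape((1, -1, 1, 1))
    out = F.conv2d(torch.randn(1, 3, 8, 8), masked_weight, padding=1)
    # with no bias, output channel 0 must be identically zero
    assert torch.all(out[:, 0] == 0)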
def speed_up_pruned_net():
    fontsize = 15

    # baseline net; the vgg16_bn lines are an alternative experiment kept for
    # reference, the resnet56 checkpoints below are the ones actually measured
    # checkpoint = torch.load('../data/baseline/vgg16_bn_cifar10,accuracy=0.941.tar')
    # net_original = storage.restore_net(checkpoint)
    checkpoint = torch.load('../data/baseline/resnet56_cifar10,accuracy=0.93280.tar')
    net_original = resnet_cifar.resnet56()
    net_original.load_state_dict(checkpoint['state_dict'])

    # checkpoint = torch.load('../data/model_saved/vgg16bn_cifar10_realdata_regressor6_大幅度/checkpoint/flop=39915982,accuracy=0.93200.tar')
    checkpoint = torch.load('../data/model_saved/resnet56_cifar10_regressor_prunedBaseline2/checkpoint/flop=36145802,accuracy=0.92110.tar')
    net_pruned = storage.restore_net(checkpoint)
    net_pruned.load_state_dict(checkpoint['state_dict'])

    # batch_size = [256, 512, 1024]
    num_workers = [i for i in range(4, 5)]
    batch_size = [300, 600, 1000, 1600]
    device_list = [torch.device('cuda')]
    # device_list = [torch.device('cpu')]

    for num_worker in num_workers:
        time_original = []
        time_pruned = []
        for d in device_list:
            for bs in batch_size:
                net_original.to(d)
                net_pruned.to(d)

                dl = data_loader.create_validation_loader(batch_size=bs, num_workers=num_worker, dataset_name='cifar10')
                start_time = time.time()
                evaluate.evaluate_net(net=net_original, data_loader=dl, save_net=False, device=d)
                end_time = time.time()
                time_original.append(end_time - start_time)
                del dl

                dl = data_loader.create_validation_loader(batch_size=bs, num_workers=num_worker, dataset_name='cifar10')
                start_time = time.time()
                evaluate.evaluate_net(net=net_pruned, data_loader=dl, save_net=False, device=d)
                end_time = time.time()
                time_pruned.append(end_time - start_time)
                del dl

        print('time before pruning:', time_original)
        print('time after pruning:', time_pruned)

        acceleration = np.array(time_original) / np.array(time_pruned)
        baseline = np.ones(shape=len(batch_size))
        x_tick = range(len(baseline))

        plt.figure()
        plt.bar(x_tick, acceleration, color='blue', hatch='//')  # , label='GPU')
        # plt.bar(x_tick[len(batch_size):], acceleration[len(batch_size):], color='grey', hatch='\\', label='CPU')
        # plt.bar(x_tick, baseline, color='red', hatch='*', label='Baseline')
        plt.xticks(x_tick, batch_size, fontsize=fontsize)
        plt.yticks(fontsize=fontsize)
        for x, y in enumerate(list(acceleration)):
            plt.text(x, y + 0.1, '%.2f x' % y, ha='center', fontsize=fontsize)
        plt.ylim([0, np.max(acceleration) + 0.5])
        plt.xlabel('Batch-Size', fontsize=fontsize)
        plt.ylabel('Speed-Up', fontsize=fontsize)
        # plt.legend(loc='upper left')
        plt.savefig('resnet_gpu_speed_up.eps', format='eps')
        # plt.savefig(str(num_worker) + 'speed_up.jpg')
        plt.show()
        print()
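# Note on the timing above: CUDA kernels launch asynchronously, so wall-clock
# timing of a GPU pass is only reliable if the device is synchronized around the
# measured region. A sketch of a timing helper (hypothetical, not part of this
# repo) that could replace the raw time.time() pairs:
def _timed_evaluation(net, dl, device):
    if device.type == 'cuda':
        torch.cuda.synchronize()  # drain any previously queued work
    start = time.time()
    evaluate.evaluate_net(net=net, data_loader=dl, save_net=False, device=device)
    if device.type == 'cuda':
        torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
    return time.time() - start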
# def conversion(dataset_name, net_name, checkpoint_path='', checkpoint=None):
#     # convert older checkpoints of the cifar resnets (the fc layer used to be named `linear`)
#     checkpoint = torch.load(checkpoint_path)
#     # net = checkpoint.pop('net')
#     net = resnet_cifar.resnet32()
#     checkpoint['state_dict']['fc.weight'] = checkpoint['state_dict'].pop('linear.weight')
#     checkpoint['state_dict']['fc.bias'] = checkpoint['state_dict'].pop('linear.bias')
#     net.load_state_dict(checkpoint['state_dict'])
#     checkpoint.update(get_net_information(net, dataset_name, net_name))
#     torch.save(checkpoint, checkpoint_path)


if __name__ == "__main__":
    # conversion(checkpoint_path='../data/baseline/resnet32_cifar10,accuracy=0.92380.tar', net_name='resnet56', dataset_name='cifar10')
    checkpoint = torch.load('/home/victorfang/model_pytorch/data/baseline/resnet56_cifar10,accuracy=0.94230.tar')
    net = resnet_cifar.resnet56()
    # net.load_state_dict(checkpoint['state_dict'])
    c = get_net_information(net=net, dataset_name='cifar10', net_name='resnet56')
    net = restore_net(checkpoint, True)

    from framework import evaluate, data_loader
    evaluate.evaluate_net(net, data_loader.create_validation_loader(512, 4, 'cifar10'), False)
    # c = get_net_information(net=net, dataset_name=dataset_name, net_name='resnet50')
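# The commented-out conversion above boils down to renaming state_dict keys
# (`linear.*` -> `fc.*`). A generic sketch of the same idea (hypothetical helper,
# not part of this repo):
def _rename_state_dict_keys(state_dict, renames):
    # renames: dict mapping old key -> new key, e.g. {'linear.weight': 'fc.weight'}
    for old_key, new_key in renames.items():
        if old_key in state_dict:
            state_dict[new_key] = state_dict.pop(old_key)
    return state_dict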
# c_new = {'highest_accuracy': c_original['highest_accuracy'],
#          'state_dict': new_state_dict}
# torch.save(c_new, path)


if __name__ == "__main__":
    import torch
    from network import storage
    from framework import evaluate, data_loader

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load('/home/victorfang/model_pytorch/data/baseline/resnet56_cifar100_0.71580.tar')
    # c_sample = torch.load('/home/victorfang/model_pytorch/data/baseline/vgg16_bn_cifar10,accuracy=0.941.tar')
    net = resnet56(num_classes=100)  # cifar100 has 100 classes
    net.load_state_dict(checkpoint['state_dict'])
    net.to(device)
    # checkpoint.update(storage.get_net_information(net=net, dataset_name='tiny_imagenet', net_name='resnet18'))
    # checkpoint.pop('net')
    # torch.save(checkpoint, '/home/victorfang/model_pytorch/data/baseline/resnet18_tinyimagenet_v2_0.72990.tar')
    # net = storage.restore_net(checkpoint=checkpoint, pretrained=True)
    # net = nn.DataParallel(net)
    evaluate.evaluate_net(net=net,
                          data_loader=data_loader.create_validation_loader(512, 8, 'cifar100'),
                          save_net=False,
                          dataset_name='cifar100')
    print()
def train(net,
          net_name,
          exp_name='',
          dataset_name='imagenet',
          train_loader=None,
          validation_loader=None,
          learning_rate=conf.learning_rate,
          num_epochs=conf.num_epochs,
          batch_size=conf.batch_size,
          evaluate_step=conf.evaluate_step,
          load_net=True,
          test_net=False,
          root_path=conf.root_path,
          checkpoint_path=None,
          momentum=conf.momentum,
          num_workers=conf.num_workers,
          learning_rate_decay=False,
          learning_rate_decay_factor=conf.learning_rate_decay_factor,
          learning_rate_decay_epoch=conf.learning_rate_decay_epoch,
          weight_decay=conf.weight_decay,
          target_accuracy=1.0,
          optimizer=optim.SGD,
          top_acc=1,
          criterion=nn.CrossEntropyLoss(),  # the loss defaults to cross entropy, commonly used for multi-class classification
          no_grad=[],
          scheduler_name='MultiStepLR',
          eta_min=0,
          # todo: tmp!!!
          data_parallel=False):
    '''
    :param net: net to be trained
    :param net_name: name of the net
    :param exp_name: name of the experiment
    :param dataset_name: name of the dataset
    :param train_loader: data_loader for training. If not provided, a data_loader will be created based on dataset_name
    :param validation_loader: data_loader for validation. If not provided, a data_loader will be created based on dataset_name
    :param learning_rate: initial learning rate
    :param learning_rate_decay: boolean, if true, the learning rate will decay based on the params provided
    :param learning_rate_decay_factor: float; learning_rate *= learning_rate_decay_factor each time the rate decays
    :param learning_rate_decay_epoch: list[int], the specific epochs at which the learning rate will decay
    :param num_epochs: max number of epochs for training
    :param batch_size:
    :param evaluate_step: how often the net is tested on the validation set; at least one test per epoch is guaranteed
    :param load_net: boolean, whether to load the net from a previous checkpoint; the newest checkpoint is selected
    :param test_net: boolean, if true, the net will be tested before training.
    :param root_path:
    :param checkpoint_path:
    :param momentum:
    :param num_workers:
    :param weight_decay:
    :param target_accuracy: float; training stops once the net reaches the target accuracy
    :param optimizer:
    :param top_acc: can be 1 or 5
    :param criterion: loss function
    :param no_grad: list containing names of the modules that do not need to be trained
    :param scheduler_name:
    :param eta_min: for CosineAnnealingLR
    :return:
    '''
    success = True  # whether the trained net reaches target accuracy

    # gpu or not
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('using: ', end='')
    if torch.cuda.is_available():
        print(torch.cuda.device_count(), ' * ', end='')
        print(torch.cuda.get_device_name(torch.cuda.current_device()))
    else:
        print(device)

    # prepare the data (note: strings must be compared with ==, not `is`)
    if dataset_name == 'imagenet':
        mean = conf.imagenet['mean']
        std = conf.imagenet['std']
        train_set_path = conf.imagenet['train_set_path']
        train_set_size = conf.imagenet['train_set_size']
        validation_set_path = conf.imagenet['validation_set_path']
        default_image_size = conf.imagenet['default_image_size']
    elif dataset_name == 'cifar10':
        train_set_size = conf.cifar10['train_set_size']
        mean = conf.cifar10['mean']
        std = conf.cifar10['std']
        train_set_path = conf.cifar10['dataset_path']
        validation_set_path = conf.cifar10['dataset_path']
        default_image_size = conf.cifar10['default_image_size']
    elif dataset_name == 'tiny_imagenet':
        train_set_size = conf.tiny_imagenet['train_set_size']
        mean = conf.tiny_imagenet['mean']
        std = conf.tiny_imagenet['std']
        train_set_path = conf.tiny_imagenet['train_set_path']
        validation_set_path = conf.tiny_imagenet['validation_set_path']
        default_image_size = conf.tiny_imagenet['default_image_size']
    elif dataset_name == 'cifar100':
        train_set_size = conf.cifar100['train_set_size']
        mean = conf.cifar100['mean']
        std = conf.cifar100['std']
        train_set_path = conf.cifar100['dataset_path']
        validation_set_path = conf.cifar100['dataset_path']
        default_image_size = conf.cifar100['default_image_size']

    if train_loader is None:
        train_loader = data_loader.create_train_loader(dataset_path=train_set_path,
                                                       default_image_size=default_image_size,
                                                       mean=mean,
                                                       std=std,
                                                       batch_size=batch_size,
                                                       num_workers=num_workers,
                                                       dataset_name=dataset_name)
    if validation_loader is None:
        validation_loader = data_loader.create_validation_loader(dataset_path=validation_set_path,
                                                                 default_image_size=default_image_size,
                                                                 mean=mean,
                                                                 std=std,
                                                                 batch_size=batch_size,
                                                                 num_workers=num_workers,
                                                                 dataset_name=dataset_name)

    if checkpoint_path is None:
        checkpoint_path = os.path.join(root_path, 'model_saved', exp_name, 'checkpoint')
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path, exist_ok=True)

    # get the latest checkpoint
    lists = os.listdir(checkpoint_path)
    file_new = checkpoint_path
    if len(lists) > 0:
        lists.sort(key=lambda fn: os.path.getmtime(checkpoint_path + "/" + fn))  # sort by modification time
        file_new = os.path.join(checkpoint_path, lists[-1])  # the newest file is the latest checkpoint

    sample_num = 0
    if os.path.isfile(file_new):
        if load_net:
            checkpoint = torch.load(file_new)
            print('{} load net from previous checkpoint:{}'.format(datetime.now(), file_new))
            # net.load_state_dict(checkpoint['state_dict'])
            net = storage.restore_net(checkpoint, pretrained=True, data_parallel=data_parallel)
            sample_num = checkpoint['sample_num']

    if test_net:
        print('{} test the net'.format(datetime.now()))  # runs even when there is no previous checkpoint
        net_test = copy.deepcopy(net)
        accuracy = evaluate.evaluate_net(net_test,
                                         validation_loader,
                                         save_net=True,
                                         checkpoint_path=checkpoint_path,
                                         sample_num=sample_num,
                                         target_accuracy=target_accuracy,
                                         dataset_name=dataset_name,
                                         top_acc=top_acc,
                                         net_name=net_name,
                                         exp_name=exp_name)
        del net_test
        if accuracy >= target_accuracy:
            print('{} net reached target accuracy.'.format(datetime.now()))
            return success

    # ensure the net will be evaluated despite an inappropriate evaluate_step
    if evaluate_step > math.ceil(train_set_size / batch_size) - 1:
        evaluate_step = math.ceil(train_set_size / batch_size) - 1

    optimizer = prepare_optimizer(net, optimizer, no_grad, momentum, learning_rate, weight_decay)
    if learning_rate_decay:
        if scheduler_name == 'MultiStepLR':
            scheduler = lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=learning_rate_decay_epoch,
                                                 gamma=learning_rate_decay_factor,
                                                 last_epoch=ceil(sample_num / train_set_size))
        elif scheduler_name == 'CosineAnnealingLR':
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                       num_epochs,
                                                       eta_min=eta_min,
                                                       last_epoch=ceil(sample_num / train_set_size))

    print("{} Start training ".format(datetime.now()) + net_name + "...")
    for epoch in range(math.floor(sample_num / train_set_size), num_epochs):
        print("{} Epoch number: {}".format(datetime.now(), epoch + 1))
        net.train()
        # one loop is one epoch
        for step, data in enumerate(train_loader, 0):
            if sample_num / train_set_size == epoch + 1:  # one epoch of training finished
                net_test = copy.deepcopy(net)
                accuracy = evaluate.evaluate_net(net_test,
                                                 validation_loader,
                                                 save_net=True,
                                                 checkpoint_path=checkpoint_path,
                                                 sample_num=sample_num,
                                                 target_accuracy=target_accuracy,
                                                 dataset_name=dataset_name,
                                                 top_acc=top_acc,
                                                 net_name=net_name,
                                                 exp_name=exp_name)
                del net_test
                if accuracy >= target_accuracy:
                    print('{} net reached target accuracy.'.format(datetime.now()))
                    return success
                break

            # prepare the data
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            sample_num += int(images.shape[0])

            # if learning_rate_decay:
            #     exponential_decay_learning_rate(optimizer=optimizer,
            #                                     sample_num=sample_num,
            #                                     learning_rate_decay_factor=learning_rate_decay_factor,
            #                                     train_set_size=train_set_size,
            #                                     learning_rate_decay_epoch=learning_rate_decay_epoch,
            #                                     batch_size=batch_size)

            optimizer.zero_grad()
            # forward + backward
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if step % 60 == 0:
                print('{} loss is {}'.format(datetime.now(), float(loss.data)))

            if step % evaluate_step == 0 and step != 0:
                net_test = copy.deepcopy(net)
                accuracy = evaluate.evaluate_net(net_test,
                                                 validation_loader,
                                                 save_net=True,
                                                 checkpoint_path=checkpoint_path,
                                                 sample_num=sample_num,
                                                 target_accuracy=target_accuracy,
                                                 dataset_name=dataset_name,
                                                 top_acc=top_acc,
                                                 net_name=net_name,
                                                 exp_name=exp_name)
                del net_test
                if accuracy >= target_accuracy:
                    print('{} net reached target accuracy.'.format(datetime.now()))
                    return success
                accuracy = float(accuracy)
                print('{} continue training'.format(datetime.now()))

        if learning_rate_decay:
            scheduler.step()
            print(optimizer.state_dict()['param_groups'][0]['lr'])

    print("{} Training finished. Saving net...".format(datetime.now()))
Saving net...".format(datetime.now())) net_test = copy.deepcopy(net) flop_num = measure_flops.measure_model(net=net_test, dataset_name=dataset_name, print_flop=False) accuracy = evaluate.evaluate_net(net_test, validation_loader, save_net=True, checkpoint_path=checkpoint_path, sample_num=sample_num, target_accuracy=target_accuracy, dataset_name=dataset_name, top_acc=top_acc, net_name=net_name, exp_name=exp_name) accuracy = float(accuracy) checkpoint = { 'highest_accuracy': accuracy, 'state_dict': net.state_dict(), 'sample_num': sample_num, 'flop_num': flop_num } checkpoint.update(storage.get_net_information(net, dataset_name, net_name)) torch.save( checkpoint, '%s/flop=%d,accuracy=%.5f.tar' % (checkpoint_path, flop_num, accuracy)) print("{} net saved at sample num = {}".format(datetime.now(), sample_num)) return not success