def read_data(path, num_images=None):
    sample = []
    file_list = os.listdir(path)
    file_list.sort()
    print(file_list)
    for file_name in file_list:
        if '.tar' in file_name:
            checkpoint = torch.load(os.path.join(path, file_name))
            net = storage.restore_net(checkpoint, pretrained=True)
            # from framework import measure_flops
            # measure_flops.measure_model(net, 'cifar10', print_flop=True)
            neural_list = checkpoint['neural_list']
            try:
                module_list = checkpoint['module_list']
            except KeyError:
                module_list = checkpoint['relu_list']

            num_conv = 0  # num of conv layers in the network
            filter_weight = []
            layers = []
            for mod in net.modules():
                if isinstance(mod, torch.nn.modules.conv.Conv2d):
                    num_conv += 1
                    conv = mod
                elif isinstance(mod, torch.nn.ReLU):  # ensure the conv is followed by a relu
                    if layers != [] and layers[-1] == num_conv - 1:  # get rid of the influence from relu in fc
                        continue
                    filter_weight.append(conv.weight.data.cpu().numpy())
                    layers.append(num_conv - 1)

            filter_layer = []
            filter_label = []
            for i in range(len(filter_weight)):
                for module_key in list(neural_list.keys()):
                    if module_list[i] is module_key:  # find the neural_list statistics in layer i+1
                        dead_times = neural_list[module_key]
                        neural_num = dead_times.shape[1] * dead_times.shape[2]  # neuron num for one filter
                        # compute sum(dead_times)/(num_images*neural_num) as label for each filter
                        dead_times = np.sum(dead_times, axis=(1, 2))
                        prediction = dead_times / (neural_num * num_images)
                        filter_label += prediction.tolist()
                        filter_layer += [layers[i] for j in range(filter_weight[i].shape[0])]

            sample.append({'net': net,
                           'filter_label': filter_label,
                           'filter_layer': filter_layer,
                           'net_name': checkpoint['net_name'],
                           'dataset_name': checkpoint['dataset_name']})

    return sample
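# A minimal, self-contained sketch of the per-filter label computed above. It
# assumes `neural_list[module_key]` holds, for every filter, an HxW map counting
# how many of the `num_images` test images left each output neuron inactive;
# the FIRE label is then the fraction of inactive (neuron, image) pairs.
def _fire_label_demo():
    import numpy as np
    num_images = 4                     # hypothetical number of test images
    dead_times = np.array([            # 2 filters with 2x2 output maps
        [[4, 4], [3, 4]],              # filter 0: almost never fires
        [[0, 1], [0, 0]],              # filter 1: almost always fires
    ])
    neural_num = dead_times.shape[1] * dead_times.shape[2]
    fire = dead_times.sum(axis=(1, 2)) / (neural_num * num_images)
    return fire                        # -> array([0.9375, 0.0625])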
def read_from_checkpoint(path):
    if '.tar' in path:
        file_list = [path]  # single net
    else:
        file_list = os.listdir(path)
    filters = []
    for file_name in file_list:
        if '.tar' in file_name:
            checkpoint = torch.load(os.path.join(path, file_name))
            net = storage.restore_net(checkpoint)
            net.load_state_dict(checkpoint['state_dict'])
            filters += get_filters(net=net)
    return filters
def speed_up_regressor():
    path = '/home/victorfang/PycharmProjects/model_pytorch/model_saved/vgg16bn_tinyimagenet_prune/checkpoint'
    predictor = predict_dead_filter.predictor(name='random_forest')
    predictor.load(path='/home/disk_new/model_saved/vgg16bn_tinyimagenet_prune/')

    file_list = os.listdir(path)
    file_list.sort()
    regressor_time = []
    real_data_time = []
    for file in file_list:
        print(file)
        checkpoint_path = os.path.join(path, file)
        checkpoint = torch.load(checkpoint_path)
        net = storage.restore_net(checkpoint)
        net.load_state_dict(checkpoint['state_dict'])

        # time for the regressor
        start_time = time.time()
        evaluate.find_useless_filters_regressor_version(
            net=net,
            predictor=predictor,
            percent_of_inactive_filter=0.1,
            max_filters_pruned_for_one_time=0.2,
        )
        end_time = time.time()
        regressor_time.append(end_time - start_time)

        # time for sampled data
        start_time = time.time()
        evaluate.find_useless_filters_data_version(
            net=net,
            batch_size=24,
            dataset_name='tiny_imagenet',
            percent_of_inactive_filter=0.1,
            max_data_to_test=50000,
        )
        end_time = time.time()
        real_data_time.append(end_time - start_time)

    print(regressor_time)
    print(real_data_time)
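# Follow-up sketch: with the two timing lists collected above, the average
# acceleration of the regressor over the sampled-data estimate could be
# summarized as below (the numbers are made up, not measured results).
def _regressor_speed_up_demo():
    import numpy as np
    regressor_time = [1.8, 2.1, 1.9]        # hypothetical seconds per checkpoint
    real_data_time = [310.0, 295.5, 305.2]
    speed_up = np.array(real_data_time) / np.array(regressor_time)
    return 'mean speed-up: %.1fx' % speed_up.mean()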
def plot_dead_neuron_filter_number(neural_dead_times=8000, dataset_name='cifar10'):
    fontsize = 17
    label_fontsize = 24
    tick_fontsize = 20
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint = torch.load('../data/baseline/vgg16_bn_cifar10,accuracy=0.941.tar')
    vgg16 = storage.restore_net(checkpoint)
    vgg16.load_state_dict(checkpoint['state_dict'])

    checkpoint = torch.load('../data/baseline/resnet56_cifar10,accuracy=0.94230.tar')
    resnet56 = resnet_cifar.resnet56().to(device)
    resnet56.load_state_dict(checkpoint['state_dict'])

    vgg16_imagenet = vgg.vgg16_bn(pretrained=True).to(device)
    checkpoint = torch.load('/home/victorfang/Desktop/vgg16_bn_imagenet_deadReLU.tar')
    relu_list_imagenet = checkpoint['relu_list']
    neural_list_imagenet = checkpoint['neural_list']

    loader = data_loader.create_validation_loader(batch_size=100, num_workers=1, dataset_name=dataset_name)
    # loader = data_loader.create_validation_loader(batch_size=1000, num_workers=8, dataset_name='cifar10_trainset')

    relu_list_vgg, neural_list_vgg = evaluate.check_ReLU_alive(net=vgg16,
                                                               neural_dead_times=neural_dead_times,
                                                               data_loader=loader,
                                                               max_data_to_test=10000)
    relu_list_resnet, neural_list_resnet = evaluate.check_ReLU_alive(net=resnet56,
                                                                     neural_dead_times=neural_dead_times,
                                                                     data_loader=loader,
                                                                     max_data_to_test=10000)

    def get_statistics(net, relu_list, neural_list, neural_dead_times, sample_num=10000):
        num_conv = 0  # num of conv layers in the net
        for mod in net.modules():
            if isinstance(mod, torch.nn.modules.conv.Conv2d):
                num_conv += 1

        neural_dead_list = []  # dead counts of every neuron
        filter_dead_list = []  # dead ratio of every filter
        FIRE = []
        for i in range(num_conv):
            for relu_key in list(neural_list.keys()):
                if relu_list[i] is relu_key:  # find the neural_list statistics in layer i+1
                    dead_times = copy.deepcopy(neural_list[relu_key])
                    neural_dead_list += copy.deepcopy(dead_times).flatten().tolist()
                    neural_num = dead_times.shape[1] * dead_times.shape[2]  # neuron num for one filter
                    # compute sum(dead_times)/(sample_num*neural_num) as label for each filter
                    dead_times = np.sum(dead_times, axis=(1, 2))
                    FIRE += (dead_times / (neural_num * sample_num)).tolist()

                    # # judge dead filters by neural_dead_times and dead_filter_ratio
                    # dead_times[dead_times < neural_dead_times] = 0
                    # dead_times[dead_times >= neural_dead_times] = 1
                    # dead_times = np.sum(dead_times, axis=(1, 2))  # count the number of dead neurons for one filter
                    # dead_times = dead_times / neural_num
                    # filter_dead_list += dead_times.tolist()
                    break

        active_ratio = 1 - np.array(FIRE)
        active_filter_list = 1 - np.array(filter_dead_list)  # only meaningful if the block above is re-enabled
        neural_activated_list = (sample_num - np.array(neural_dead_list)) / sample_num

        return neural_activated_list, active_ratio  # , active_filter_list

    nal_vgg, afl_vgg = get_statistics(vgg16, relu_list_vgg, neural_list_vgg,
                                      neural_dead_times=neural_dead_times)
    nal_resnet, afl_resnet = get_statistics(resnet56, relu_list_resnet, neural_list_resnet,
                                            neural_dead_times=neural_dead_times)
    nal_imagenet, afl_imagenet = get_statistics(vgg16_imagenet, relu_list_imagenet, neural_list_imagenet,
                                                sample_num=50000, neural_dead_times=40000)

    # # cdf_of_dead_neurons
    # plt.figure()
    # plt.hist([nal_vgg, nal_resnet, nal_imagenet], cumulative=True, histtype='step', bins=1000, density=True)  # cumulative=False for pdf, True for cdf
    # # plt.hist(neural_activated_list, cumulative=True, histtype='bar', bins=20, density=True, rwidth=0.6)
    # plt.xlabel('Activation Ratio', fontsize=fontsize)
    # plt.ylabel('Ratio of Neurons', fontsize=fontsize)
    # plt.legend(['VGG-16 on CIFAR-10', 'ResNet-56 on CIFAR-10', 'VGG-16 on ImageNet'], loc='upper left', fontsize=fontsize)
    # # plt.savefig('0cdf_of_dead_neurons.jpg')
    # plt.savefig('cdf_of_dead_neurons.eps', format='eps')
    # plt.show()
    #
    # # cdf_of_inactive_filter
    # plt.figure()
    # plt.hist([afl_vgg, afl_resnet, afl_imagenet], cumulative=True, histtype='step', bins=1000, density=True)  # cumulative=False for pdf, True for cdf
    # plt.xlabel('Activation Ratio', fontsize=fontsize)
    # plt.ylabel('Ratio of Filters', fontsize=fontsize)
    # plt.legend(['VGG-16 on CIFAR-10', 'ResNet-56 on CIFAR-10', 'VGG-16 on ImageNet'], loc='upper left', fontsize=fontsize)
    # # plt.savefig('0cdf_of_inactive_filter.jpg')
    # plt.savefig('cdf_of_inactive_filter.eps', format='eps')
    # plt.show()

    # pdf_of_dead_neurons
    plt.figure()
    hist_list = []
    for nal in [nal_vgg, nal_resnet, nal_imagenet]:
        hist, bins = np.histogram(nal, bins=[0.1 * i for i in range(11)])
        hist_list.append(100 * hist / np.sum(hist))
    x_tick = np.array([5, 15, 25, 35, 45, 55, 65, 75, 85, 95])
    plt.figure()
    plt.bar(x_tick - 2, hist_list[0], color='coral', edgecolor='black', label='VGG-16 on CIFAR-10', align='center', width=2)
    plt.bar(x_tick, hist_list[1], color='cyan', edgecolor='black', label='ResNet-56 on CIFAR-10', align='center', width=2)
    plt.bar(x_tick + 2, hist_list[2], color='mediumslateblue', edgecolor='black', label='VGG-16 on ImageNet', align='center', width=2)
    plt.xticks(x_tick, x_tick, size=tick_fontsize)
    plt.yticks(size=tick_fontsize)
    plt.xlabel('Activation Ratio (%)', fontsize=label_fontsize)
    plt.ylabel('% of Neurons', fontsize=label_fontsize)
    plt.legend(loc='upper right', fontsize=fontsize)
    # plt.savefig('0pdf_of_dead_neurons.jpg')
    plt.savefig('pdf_of_dead_neurons.eps', format='eps', bbox_inches='tight')
    plt.show()

    # pdf_of_inactive_filter
    plt.figure()
    hist_list = []
    for active_ratio in [afl_vgg, afl_resnet, afl_imagenet]:
        hist, bins = np.histogram(active_ratio, bins=[0.1 * i for i in range(11)])
        hist_list.append(100 * hist / np.sum(hist))
    x_tick = np.array([5, 15, 25, 35, 45, 55, 65, 75, 85, 95])
    plt.figure()
    plt.bar(x_tick - 2, hist_list[0], color='coral', edgecolor='black', label='VGG-16 on CIFAR-10', align='center', width=2)
    plt.bar(x_tick, hist_list[1], color='cyan', edgecolor='black', label='ResNet-56 on CIFAR-10', align='center', width=2)
    plt.bar(x_tick + 2, hist_list[2], color='mediumslateblue', edgecolor='black', label='VGG-16 on ImageNet', align='center', width=2)
    plt.xticks(x_tick, x_tick, size=tick_fontsize)
    plt.yticks(size=tick_fontsize)
    plt.xlabel('Activation Ratio (%)', fontsize=label_fontsize)
    plt.ylabel('% of Filters', fontsize=label_fontsize)
    plt.legend(loc='upper right', fontsize=fontsize)
    # plt.savefig('0pdf_of_inactive_filter.jpg')
    plt.savefig('pdf_of_inactive_filter.eps', format='eps', bbox_inches='tight')
    plt.show()
    print()
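# Small worked example of the binning used for the bar charts above: activation
# ratios in [0, 1] are counted into ten 10%-wide bins and rescaled to the
# percentage of neurons (or filters) that fall in each bin.
def _activation_histogram_demo():
    import numpy as np
    activation_ratio = np.array([0.02, 0.05, 0.11, 0.48, 0.93, 0.97])
    hist, _ = np.histogram(activation_ratio, bins=[0.1 * i for i in range(11)])
    return 100 * hist / hist.sum()  # ~[33.3, 16.7, 0, 0, 16.7, 0, 0, 0, 0, 33.3]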
def speed_up_pruned_net():
    fontsize = 15

    checkpoint = torch.load('../data/baseline/vgg16_bn_cifar10,accuracy=0.941.tar')
    net_original = storage.restore_net(checkpoint)

    # the later assignments override the earlier ones: the ResNet-56 baseline/pruned pair is the one actually measured
    checkpoint = torch.load('../data/baseline/resnet56_cifar10,accuracy=0.93280.tar')
    net_original = resnet_cifar.resnet56()
    net_original.load_state_dict(checkpoint['state_dict'])

    checkpoint = torch.load('../data/model_saved/vgg16bn_cifar10_realdata_regressor6_大幅度/checkpoint/flop=39915982,accuracy=0.93200.tar')
    checkpoint = torch.load('../data/model_saved/resnet56_cifar10_regressor_prunedBaseline2/checkpoint/flop=36145802,accuracy=0.92110.tar')
    net_pruned = storage.restore_net(checkpoint)
    net_pruned.load_state_dict(checkpoint['state_dict'])

    # batch_size = [256, 512, 1024]
    num_workers = [i for i in range(4, 5)]
    batch_size = [300, 600, 1000, 1600]
    device_list = [torch.device('cuda')]
    # device_list = [torch.device('cpu')]

    for num_worker in num_workers:
        time_original = []
        time_pruned = []
        for d in device_list:
            for bs in batch_size:
                net_original.to(d)
                net_pruned.to(d)

                dl = data_loader.create_validation_loader(batch_size=bs, num_workers=num_worker, dataset_name='cifar10')
                start_time = time.time()
                evaluate.evaluate_net(net=net_original, data_loader=dl, save_net=False, device=d)
                end_time = time.time()
                time_original.append(end_time - start_time)
                del dl

                dl = data_loader.create_validation_loader(batch_size=bs, num_workers=num_worker, dataset_name='cifar10')
                start_time = time.time()
                evaluate.evaluate_net(net=net_pruned, data_loader=dl, save_net=False, device=d)
                end_time = time.time()
                time_pruned.append(end_time - start_time)
                del dl

        print('time before pruning:', time_original)
        print('time after pruning:', time_pruned)

        acceleration = np.array(time_original) / np.array(time_pruned)
        baseline = np.ones(shape=len(batch_size))
        x_tick = range(len(baseline))
        plt.figure()
        plt.bar(x_tick, acceleration, color='blue', hatch='//')  # , label='GPU')
        # plt.bar(x_tick[len(batch_size):], acceleration[len(batch_size):], color='grey', hatch='\\', label='CPU')
        # plt.bar(x_tick, baseline, color='red', hatch='*', label='Baseline')
        plt.xticks(x_tick, batch_size, fontsize=fontsize)
        plt.yticks(fontsize=fontsize)
        for x, y in enumerate(list(acceleration)):
            plt.text(x, y + 0.1, '%.2f x' % y, ha='center', fontsize=fontsize)
        plt.ylim([0, np.max(acceleration) + 0.5])
        plt.xlabel('Batch-Size', fontsize=fontsize)
        plt.ylabel('Speed-Up', fontsize=fontsize)
        # plt.legend(loc='upper left')
        plt.savefig('resnet_gpu_speed_up.eps', format='eps')
        # plt.savefig(str(num_worker) + 'speed_up.jpg')
        plt.show()
    print()
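# Minimal latency-measurement sketch mirroring the timing pattern above, using a
# toy model instead of the real nets. When timing on a GPU it can help to call
# torch.cuda.synchronize() around the clock reads, since CUDA kernels run
# asynchronously; the helper and model below are illustrative only.
def _inference_time_demo(device=None):
    import time
    import torch
    device = device or torch.device('cpu')
    net = torch.nn.Conv2d(3, 16, kernel_size=3).to(device).eval()
    inputs = torch.randn(8, 3, 32, 32).to(device)
    with torch.no_grad():
        if device.type == 'cuda':
            torch.cuda.synchronize()
        start = time.time()
        net(inputs)
        if device.type == 'cuda':
            torch.cuda.synchronize()
    return time.time() - start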
def read_data(path='/home/victorfang/Desktop/dead_filter(normal_distribution)',
              balance=False,
              regression_or_classification='regression',
              num_images=None,
              sample_num=None):
    # note that the classification branch is abandoned; the code involved might be wrong
    if regression_or_classification == 'regression':
        filter = []
        filter_layer = []
        filter_label = []
    elif regression_or_classification == 'classification':
        dead_filter = []
        dead_filter_layer = []
        living_filter = []
        living_filter_layer = []
    else:
        raise AttributeError

    file_list = os.listdir(path)
    for file_name in file_list:
        if '.tar' in file_name:
            checkpoint = torch.load(os.path.join(path, file_name))
            net = storage.restore_net(checkpoint)
            net.load_state_dict(checkpoint['state_dict'])
            neural_list = checkpoint['neural_list']
            try:
                module_list = checkpoint['module_list']
            except KeyError:
                module_list = checkpoint['relu_list']

            if regression_or_classification == 'classification':
                # neural_dead_times = checkpoint['neural_dead_times']
                neural_dead_times = 8000
                filter_FIRE = checkpoint['filter_FIRE']

            num_conv = 0  # num of conv layers in the network
            filter_num = []
            filters = []
            layers = []
            for mod in net.modules():
                if isinstance(mod, torch.nn.modules.conv.Conv2d):
                    num_conv += 1
                    conv = mod
                elif isinstance(mod, torch.nn.ReLU):  # ensure the conv is followed by a relu
                    if layers != [] and layers[-1] == num_conv - 1:  # get rid of the influence from relu in fc
                        continue
                    filter_num.append(conv.out_channels)
                    filters.append(conv)
                    layers.append(num_conv - 1)

            for i in range(len(filters)):
                for module_key in list(neural_list.keys()):
                    if module_list[i] is module_key:  # find the neural_list statistics in layer i+1
                        dead_times = neural_list[module_key]
                        neural_num = dead_times.shape[1] * dead_times.shape[2]  # neuron num for one filter
                        filter_weight = filters[i].weight.data.cpu().numpy()

                        if regression_or_classification == 'classification':
                            # judge dead filters by neural_dead_times and dead_filter_ratio
                            dead_times[dead_times < neural_dead_times] = 0
                            dead_times[dead_times >= neural_dead_times] = 1
                            dead_times = np.sum(dead_times, axis=(1, 2))  # count the number of dead neurons for one filter
                            dead_filter_index = np.where(dead_times > neural_num * filter_FIRE)[0].tolist()
                            living_filter_index = [i for i in range(filter_num[i]) if i not in dead_filter_index]

                            for ind in dead_filter_index:
                                dead_filter.append(filter_weight[ind])
                            dead_filter_layer += [i for j in range(len(dead_filter_index))]
                            for ind in living_filter_index:
                                living_filter.append(filter_weight[ind])
                            living_filter_layer += [i for j in range(len(living_filter_index))]
                        else:
                            # compute sum(dead_times)/(num_images*neural_num) as label for each filter
                            dead_times = np.sum(dead_times, axis=(1, 2))
                            prediction = dead_times / (neural_num * num_images)
                            for f in filter_weight:
                                filter.append(f)
                            filter_label += prediction.tolist()
                            filter_layer += [layers[i] for j in range(filter_weight.shape[0])]

    if regression_or_classification == 'classification' and balance is True:
        living_filter = living_filter[:len(dead_filter)]
        living_filter_layer = living_filter_layer[:len(living_filter_index)]

    if regression_or_classification == 'classification':
        return dead_filter, living_filter, dead_filter_layer, living_filter_layer
    elif regression_or_classification == 'regression':
        if sample_num is not None:
            index = random.sample([i for i in range(len(filter))], sample_num)
            filter = np.array(filter)[index].tolist()
            filter_label = np.array(filter_label)[index].tolist()
            filter_layer = np.array(filter_layer)[index].tolist()
        return filter, filter_label, filter_layer
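# Hedged sketch of how the (filter, FIRE-label) pairs returned above could feed a
# regressor. Flattening the raw kernel weights into a feature vector is an
# assumption made for illustration; it is not necessarily the feature extraction
# used by predict_dead_filter.
def _filter_regressor_demo():
    import numpy as np
    from sklearn.ensemble import RandomForestRegressor
    rng = np.random.default_rng(0)
    filters = [rng.standard_normal((3, 3, 3)) for _ in range(20)]  # fake conv kernels
    labels = rng.uniform(0, 1, size=20)                            # fake FIRE labels
    features = np.stack([f.flatten() for f in filters])
    reg = RandomForestRegressor(n_estimators=10).fit(features, labels)
    return reg.predict(features[:2])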
def train(net,
          net_name,
          exp_name='',
          dataset_name='imagenet',
          train_loader=None,
          validation_loader=None,
          learning_rate=conf.learning_rate,
          num_epochs=conf.num_epochs,
          batch_size=conf.batch_size,
          evaluate_step=conf.evaluate_step,
          load_net=True,
          test_net=False,
          root_path=conf.root_path,
          checkpoint_path=None,
          momentum=conf.momentum,
          num_workers=conf.num_workers,
          learning_rate_decay=False,
          learning_rate_decay_factor=conf.learning_rate_decay_factor,
          learning_rate_decay_epoch=conf.learning_rate_decay_epoch,
          weight_decay=conf.weight_decay,
          target_accuracy=1.0,
          optimizer=optim.SGD,
          top_acc=1,
          criterion=nn.CrossEntropyLoss(),  # loss defaults to cross-entropy, commonly used for multi-class classification
          no_grad=[],
          scheduler_name='MultiStepLR',
          eta_min=0,  # todo: tmp!!!
          data_parallel=False):
    '''
    :param net: net to be trained
    :param net_name: name of the net
    :param exp_name: name of the experiment
    :param dataset_name: name of the dataset
    :param train_loader: data_loader for training. If not provided, a data_loader will be created based on dataset_name
    :param validation_loader: data_loader for validation. If not provided, a data_loader will be created based on dataset_name
    :param learning_rate: initial learning rate
    :param learning_rate_decay: boolean, if true, the learning rate will decay based on the params provided.
    :param learning_rate_decay_factor: float. learning_rate *= learning_rate_decay_factor every time it decays.
    :param learning_rate_decay_epoch: list[int], the specific epochs at which the learning rate will decay.
    :param num_epochs: max number of epochs for training
    :param batch_size:
    :param evaluate_step: how often the net will be tested on the validation set. At least one test per epoch is guaranteed.
    :param load_net: boolean, whether to load the net from a previous checkpoint. The newest checkpoint will be selected.
    :param test_net: boolean, if true, the net will be tested before training.
    :param root_path:
    :param checkpoint_path:
    :param momentum:
    :param num_workers:
    :param weight_decay:
    :param target_accuracy: float, the training will stop once the net reaches the target accuracy
    :param optimizer:
    :param top_acc: can be 1 or 5
    :param criterion: loss function
    :param no_grad: list containing names of the modules that do not need to be trained
    :param scheduler_name:
    :param eta_min: for CosineAnnealingLR
    :return:
    '''
    success = True  # whether the trained net reaches the target accuracy

    # gpu or not
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('using: ', end='')
    if torch.cuda.is_available():
        print(torch.cuda.device_count(), ' * ', end='')
        print(torch.cuda.get_device_name(torch.cuda.current_device()))
    else:
        print(device)

    # prepare the data
    if dataset_name == 'imagenet':
        mean = conf.imagenet['mean']
        std = conf.imagenet['std']
        train_set_path = conf.imagenet['train_set_path']
        train_set_size = conf.imagenet['train_set_size']
        validation_set_path = conf.imagenet['validation_set_path']
        default_image_size = conf.imagenet['default_image_size']
    elif dataset_name == 'cifar10':
        train_set_size = conf.cifar10['train_set_size']
        mean = conf.cifar10['mean']
        std = conf.cifar10['std']
        train_set_path = conf.cifar10['dataset_path']
        validation_set_path = conf.cifar10['dataset_path']
        default_image_size = conf.cifar10['default_image_size']
    elif dataset_name == 'tiny_imagenet':
        train_set_size = conf.tiny_imagenet['train_set_size']
        mean = conf.tiny_imagenet['mean']
        std = conf.tiny_imagenet['std']
        train_set_path = conf.tiny_imagenet['train_set_path']
        validation_set_path = conf.tiny_imagenet['validation_set_path']
        default_image_size = conf.tiny_imagenet['default_image_size']
    elif dataset_name == 'cifar100':
        train_set_size = conf.cifar100['train_set_size']
        mean = conf.cifar100['mean']
        std = conf.cifar100['std']
        train_set_path = conf.cifar100['dataset_path']
        validation_set_path = conf.cifar100['dataset_path']
        default_image_size = conf.cifar100['default_image_size']

    if train_loader is None:
        train_loader = data_loader.create_train_loader(dataset_path=train_set_path,
                                                       default_image_size=default_image_size,
                                                       mean=mean,
                                                       std=std,
                                                       batch_size=batch_size,
                                                       num_workers=num_workers,
                                                       dataset_name=dataset_name)
    if validation_loader is None:
        validation_loader = data_loader.create_validation_loader(dataset_path=validation_set_path,
                                                                 default_image_size=default_image_size,
                                                                 mean=mean,
                                                                 std=std,
                                                                 batch_size=batch_size,
                                                                 num_workers=num_workers,
                                                                 dataset_name=dataset_name)

    if checkpoint_path is None:
        checkpoint_path = os.path.join(root_path, 'model_saved', exp_name, 'checkpoint')
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path, exist_ok=True)

    # get the latest checkpoint
    lists = os.listdir(checkpoint_path)
    file_new = checkpoint_path
    if len(lists) > 0:
        lists.sort(key=lambda fn: os.path.getmtime(checkpoint_path + "/" + fn))  # sort by modification time
        file_new = os.path.join(checkpoint_path, lists[-1])  # keep the newest file in file_new

    sample_num = 0
    if os.path.isfile(file_new):
        if load_net:
            checkpoint = torch.load(file_new)
            print('{} load net from previous checkpoint:{}'.format(datetime.now(), file_new))
            # net.load_state_dict(checkpoint['state_dict'])
            net = storage.restore_net(checkpoint, pretrained=True, data_parallel=data_parallel)
            sample_num = checkpoint['sample_num']

    if test_net:
        print('{} test the net'.format(datetime.now()))  # no previous checkpoint
        net_test = copy.deepcopy(net)
        accuracy = evaluate.evaluate_net(net_test,
                                         validation_loader,
                                         save_net=True,
                                         checkpoint_path=checkpoint_path,
                                         sample_num=sample_num,
                                         target_accuracy=target_accuracy,
                                         dataset_name=dataset_name,
                                         top_acc=top_acc,
                                         net_name=net_name,
                                         exp_name=exp_name)
        del net_test
        if accuracy >= target_accuracy:
            print('{} net reached target accuracy.'.format(datetime.now()))
            return success

    # ensure the net will be evaluated despite an inappropriate evaluate_step
    if evaluate_step > math.ceil(train_set_size / batch_size) - 1:
        evaluate_step = math.ceil(train_set_size / batch_size) - 1

    optimizer = prepare_optimizer(net, optimizer, no_grad, momentum, learning_rate, weight_decay)
    if learning_rate_decay:
        if scheduler_name == 'MultiStepLR':
            scheduler = lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=learning_rate_decay_epoch,
                                                 gamma=learning_rate_decay_factor,
                                                 last_epoch=ceil(sample_num / train_set_size))
        elif scheduler_name == 'CosineAnnealingLR':
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                       num_epochs,
                                                       eta_min=eta_min,
                                                       last_epoch=ceil(sample_num / train_set_size))

    print("{} Start training ".format(datetime.now()) + net_name + "...")
    for epoch in range(math.floor(sample_num / train_set_size), num_epochs):
        print("{} Epoch number: {}".format(datetime.now(), epoch + 1))
        net.train()
        # one epoch for one loop
        for step, data in enumerate(train_loader, 0):
            if sample_num / train_set_size == epoch + 1:  # one epoch of training finished
                net_test = copy.deepcopy(net)
                accuracy = evaluate.evaluate_net(net_test,
                                                 validation_loader,
                                                 save_net=True,
                                                 checkpoint_path=checkpoint_path,
                                                 sample_num=sample_num,
                                                 target_accuracy=target_accuracy,
                                                 dataset_name=dataset_name,
                                                 top_acc=top_acc,
                                                 net_name=net_name,
                                                 exp_name=exp_name)
                del net_test
                if accuracy >= target_accuracy:
                    print('{} net reached target accuracy.'.format(datetime.now()))
                    return success
                break

            # prepare the data
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            sample_num += int(images.shape[0])

            # if learning_rate_decay:
            #     exponential_decay_learning_rate(optimizer=optimizer,
            #                                     sample_num=sample_num,
            #                                     learning_rate_decay_factor=learning_rate_decay_factor,
            #                                     train_set_size=train_set_size,
            #                                     learning_rate_decay_epoch=learning_rate_decay_epoch,
            #                                     batch_size=batch_size)

            optimizer.zero_grad()
            # forward + backward
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if step % 60 == 0:
                print('{} loss is {}'.format(datetime.now(), float(loss.data)))

            if step % evaluate_step == 0 and step != 0:
                net_test = copy.deepcopy(net)
                accuracy = evaluate.evaluate_net(net_test,
                                                 validation_loader,
                                                 save_net=True,
                                                 checkpoint_path=checkpoint_path,
                                                 sample_num=sample_num,
                                                 target_accuracy=target_accuracy,
                                                 dataset_name=dataset_name,
                                                 top_acc=top_acc,
                                                 net_name=net_name,
                                                 exp_name=exp_name)
                del net_test
                if accuracy >= target_accuracy:
                    print('{} net reached target accuracy.'.format(datetime.now()))
                    return success
                accuracy = float(accuracy)
                print('{} continue training'.format(datetime.now()))

        if learning_rate_decay:
            scheduler.step()
            print(optimizer.state_dict()['param_groups'][0]['lr'])

    print("{} Training finished. Saving net...".format(datetime.now()))
    net_test = copy.deepcopy(net)
    flop_num = measure_flops.measure_model(net=net_test, dataset_name=dataset_name, print_flop=False)
    accuracy = evaluate.evaluate_net(net_test,
                                     validation_loader,
                                     save_net=True,
                                     checkpoint_path=checkpoint_path,
                                     sample_num=sample_num,
                                     target_accuracy=target_accuracy,
                                     dataset_name=dataset_name,
                                     top_acc=top_acc,
                                     net_name=net_name,
                                     exp_name=exp_name)
    accuracy = float(accuracy)
    checkpoint = {'highest_accuracy': accuracy,
                  'state_dict': net.state_dict(),
                  'sample_num': sample_num,
                  'flop_num': flop_num}
    checkpoint.update(storage.get_net_information(net, dataset_name, net_name))
    torch.save(checkpoint, '%s/flop=%d,accuracy=%.5f.tar' % (checkpoint_path, flop_num, accuracy))
    print("{} net saved at sample num = {}".format(datetime.now(), sample_num))
    return not success
import torch

from network import create_net, net_with_mask, vgg, storage
from framework import config as conf
from framework.train import name_parameters_no_grad
import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "5"
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'
# os.environ["CUDA_VISIBLE_DEVICES"] = '4,5,6,7'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# net = storage.restore_net(torch.load('/home/victorfang/model_pytorch/data/model_saved/resnet56_extractor_static_cifar10_only_gcn_1/checkpoint/flop=62577290,accuracy=0.93400.tar'), pretrained=True)
# net = storage.restore_net(torch.load('/home/victorfang/model_pytorch/data/model_saved/resnet56_extractor_static_cifar10_only_gcn_3/checkpoint/flop=62061194,accuracy=0.93410.tar'), pretrained=True)
# net = storage.restore_net(torch.load('/home/victorfang/model_pytorch/data/model_saved/resnet56_extractor_static_cifar100_2_train/checkpoint/flop=95299940,accuracy=0.70470.tar'), pretrained=True)
net = storage.restore_net(torch.load(
    '/home/victorfang/model_pytorch/data/model_saved/resnet50_extractor_static_imagenet2/checkpoint/flop=1796559732,accuracy=0.91862.tar'),
    pretrained=True)
# net = storage.restore_net(torch.load('/home/victorfang/model_pytorch/data/model_saved/resnet18_tinyimagenet_extractor_static_train/checkpoint/flop=1289487816,accuracy=0.67650.tar'), pretrained=True)

net = torch.nn.DataParallel(net)
i = 0
success = False
while not success and i < 1:
    # success = train.train(net=net,
    #                       net_name='vgg16_bn',
    #                       exp_name='vgg16_extractor_static_imagenet_train',
    #                       num_epochs=100,
    #                       learning_rate=0.001,
    #                       learning_rate_decay=True,