def flops_counter(args):
    """Speed up a fine-tuned, masked model and record its accuracy and FLOPs.

    Loads ``model_fine_tuned.pth`` + ``mask.pth`` from ``args.experiment_data_dir``,
    applies ModelSpeedup (except for AutoCompressPruner, whose compress() already
    returns a really-pruned model saved as a whole module), evaluates the result,
    updates ``performance.json`` and writes ``flops.json``.
    """
    # model speed up
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, criterion = get_data(args)
    if args.pruner != 'AutoCompressPruner':
        # Rebuild a fresh architecture; the masked weights are loaded below.
        if args.model == 'LeNet':
            model = LeNet().to(device)
        elif args.model == 'vgg16':
            model = VGG(depth=16).to(device)
        elif args.model == 'resnet18':
            model = models.resnet18(pretrained=False, num_classes=10).to(device)
        elif args.model == 'mobilenet_v2':
            model = models.mobilenet_v2(pretrained=False).to(device)

        def evaluator(model):
            # Returns whatever the project-level test() reports (accuracy).
            return test(model, device, criterion, val_loader)

        model.load_state_dict(
            torch.load(
                os.path.join(args.experiment_data_dir,
                             'model_fine_tuned.pth')))
        masks_file = os.path.join(args.experiment_data_dir, 'mask.pth')
        dummy_input = get_dummy_input(args, device)
        # Replace masked-out channels with physically smaller layers.
        m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
        m_speedup.speedup_model()
        evaluation_result = evaluator(model)
        print('Evaluation result (speed up model): %s' % evaluation_result)
        # Read-modify-write performance.json so earlier entries are preserved.
        with open(os.path.join(args.experiment_data_dir,
                               'performance.json')) as f:
            result = json.load(f)
        result['speedup'] = evaluation_result
        with open(os.path.join(args.experiment_data_dir, 'performance.json'),
                  'w+') as f:
            json.dump(result, f)
        torch.save(
            model.state_dict(),
            os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
        # FIX: was print('... %s', dir) — logging-style args do not format in print().
        print('Speed up model saved to %s' % args.experiment_data_dir)
    else:
        # AutoCompressPruner saves the entire pruned module, not a state dict.
        model = torch.load(
            os.path.join(args.experiment_data_dir, 'model_fine_tuned.pth'))
        model.eval()
    # Count FLOPs/params of the (possibly sped-up) model on a CIFAR-sized input.
    flops, params = count_flops_params(model, (1, 3, 32, 32))
    with open(os.path.join(args.experiment_data_dir, 'flops.json'),
              'w+') as f:
        json.dump({'FLOPS': int(flops), 'params': int(params)}, f)
def model_inference(config):
    """Inference/benchmark driver for a Taylor-pruned ResNet-50.

    Strips pruner bookkeeping out of a saved checkpoint, rebuilds a clean mask
    dict keyed by module name, applies the masks to the weights, then (guarded
    by module-level flags) benchmarks the masked model and the sped-up model,
    exports both to ONNX and checks the two outputs agree.

    NOTE(review): `use_mask`, `use_speedup`, `compare_results`,
    `data_loader_test` and `_logger` are read from module scope — confirm they
    are defined where this function is used. `config` itself is unused here.
    """
    model_trained = './experiment_data/resnet_bn/model_fine_tuned_first.pth'
    rn50 = resnet50()
    m_paras = torch.load(model_trained)
    ##delete mask in pth
    m_new = collections.OrderedDict()
    mask = dict()
    for key in m_paras:
        # 'weight_mask_b' entries are pruner internals; drop them entirely.
        if 'weight_mask_b' in key:
            continue
        if 'weight_mask' in key:
            if 'module_added' not in key:
                # Plain mask on a module: record it under the bare module name.
                # NOTE(review): the same tensor is stored for both 'weight' and
                # 'bias' — presumably a per-channel BN mask that applies to
                # both parameters; confirm against the pruner that wrote it.
                mask[key.replace('.weight_mask', '')] = dict()
                mask[key.replace('.weight_mask', '')]['weight'] = m_paras[key]
                mask[key.replace('.weight_mask', '')]['bias'] = m_paras[key]
            else:
                # Mask recorded on an auxiliary ReLU wrapper: re-key it onto
                # the bn3 layer it actually gates.
                mask[key.replace('.relu1.module_added.weight_mask', '.bn3')] = {}
                mask[key.replace('.relu1.module_added.weight_mask', '.bn3')]['weight'] = m_paras[key]
                mask[key.replace('.relu1.module_added.weight_mask', '.bn3')]['bias'] = m_paras[key]
                # First block of a stage also has a downsample branch whose BN
                # (downsample.1) must carry the same channel mask.
                if '0.relu1' in key:
                    mask[key.replace('relu1.module_added.weight_mask', 'downsample.1')] = {}
                    mask[key.replace('relu1.module_added.weight_mask', 'downsample.1')]['weight'] = m_paras[key]
                    mask[key.replace('relu1.module_added.weight_mask', 'downsample.1')]['bias'] = m_paras[key]
            continue
        if 'module_added' in key:
            # Remaining wrapper parameters are not part of the real model.
            continue
        elif 'module' in key:
            # Strip the DataParallel/DDP 'module.' prefix.
            m_new[key.replace('module.', '')] = m_paras[key]
        else:
            m_new[key] = m_paras[key]
    for key in mask:
        #modify the weight and bias of model with pruning
        m_new[key + '.weight'] = m_new[key + '.weight'].data.mul(
            mask[key]['weight'])
        m_new[key + '.bias'] = m_new[key + '.bias'].data.mul(mask[key]['bias'])
    rn50.load_state_dict(m_new)
    rn50.cuda()
    rn50.eval()
    # Persist the reconstructed masks so ModelSpeedup can consume them.
    torch.save(mask, 'taylor_mask.pth')
    mask_file = './taylor_mask.pth'
    dummy_input = torch.randn(64, 3, 224, 224).cuda()
    use_mask_out = use_speedup_out = None
    rn = rn50
    # Reference output from the masked (but not sped-up) model.
    rn_mask_out = rn(dummy_input)
    model = rn50
    if use_mask:
        torch.onnx.export(model,
                          dummy_input,
                          'resnet_masked_taylor_1700.onnx',
                          export_params=True,
                          opset_version=12,
                          do_constant_folding=True,
                          input_names=['inputs'],
                          output_names=['proba'],
                          dynamic_axes={
                              'inputs': [0],
                              'mask': [0]
                          },
                          keep_initializers_as_inputs=True)
        # Time 32 forward passes of the masked model.
        start = time.time()
        for _ in range(32):
            use_mask_out = model(dummy_input)
        elapsed_t = time.time() - start
        print('elapsed time when use mask: ', elapsed_t)
        _logger.info(
            'for batch size 64 and with 32 runs, the elapsed time is {}'.
            format(elapsed_t))
        print('before speed up===================')
        flops, paras = count_flops_params(model, (1, 3, 224, 224))
        _logger.info(
            'flops and parameters before speedup is {} FLOPS and {} params'.format(
                flops, paras))
    if use_speedup:
        # NOTE(review): .cuda() here is a no-op — the result is discarded and
        # dummy_input is already on CUDA.
        dummy_input.cuda()
        m_speedup = ModelSpeedup(model, dummy_input, mask_file, 'cuda')
        m_speedup.speedup_model()
        print('==' * 20)
        print('Start inference')
        torch.onnx.export(model,
                          dummy_input,
                          'resnet_taylor_1700.onnx',
                          export_params=True,
                          opset_version=12,
                          do_constant_folding=True,
                          input_names=['inputs'],
                          output_names=['proba'],
                          dynamic_axes={
                              'inputs': [0],
                              'mask': [0]
                          },
                          keep_initializers_as_inputs=True)
        # Time 32 forward passes of the sped-up model.
        start = time.time()
        for _ in range(32):
            use_speedup_out = model(dummy_input)
        elasped_t1 = time.time() - start
        print('elapsed time when use speedup: ', elasped_t1)
        _logger.info(
            'elasped time with batch_size 64 and in 32 runs is {}'.format(
                elasped_t1))
        #print('After speedup model is ',model)
        _logger.info('model structure after speedup is ====')
        _logger.info(model)
        print('=================')
        print('After speedup')
        flops, paras = count_flops_params(model, (1, 3, 224, 224))
        _logger.info(
            'After speedup flops are {} and number of parameters are {}'.format(
                flops, paras))
    if compare_results:
        # Masked and sped-up models must agree numerically.
        print(rn_mask_out)
        print('another is', use_speedup_out)
        if torch.allclose(rn_mask_out, use_speedup_out, atol=1e-6):  #-07):
            print('the outputs from use_mask and use_speedup are the same')
        else:
            raise RuntimeError(
                'the outputs from use_mask and use_speedup are different')
    # start the accuracy check
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        start = time.time()
        evaluate(model,
                 criterion,
                 data_loader_test,
                 device="cuda",
                 print_freq=20)
        print('elapsed time is ', time.time() - start)
def main(args):
    """Prune a model with the pruner selected by ``args.pruner``, then
    optionally speed it up and fine-tune it.

    Records FLOPs/params/accuracy of the original, masked, sped-up and
    fine-tuned models into ``result.json`` under ``args.experiment_data_dir``.
    """
    # prepare dataset
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, criterion = get_data(args)
    model, optimizer = get_trained_model_optimizer(args, device, train_loader,
                                                   val_loader, criterion)

    def short_term_fine_tuner(model, epochs=1):
        # Brief training used by NetAdaptPruner between pruning iterations.
        for epoch in range(epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch)

    def trainer(model, optimizer, criterion, epoch, callback):
        # Full training step used by ADMM/AutoCompress pruners.
        return train(args, model, device, train_loader, criterion, optimizer,
                     epoch=epoch, callback=callback)

    def evaluator(model):
        return test(model, device, criterion, val_loader)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance': {}}
    flops, params = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params
    evaluation_result = evaluator(model)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result

    # module types to prune, only "Conv2d" supported for channel pruning
    if args.base_algo in ['l1', 'l2']:
        op_types = ['Conv2d']
    elif args.base_algo == 'level':
        op_types = ['default']
    else:
        # FIX: previously fell through and raised NameError on op_types below.
        raise ValueError('Unsupported base_algo: %s' % args.base_algo)
    config_list = [{'sparsity': args.sparsity, 'op_types': op_types}]
    dummy_input = get_dummy_input(args, device)
    if args.pruner == 'L1FilterPruner':
        pruner = L1FilterPruner(model, config_list)
    elif args.pruner == 'L2FilterPruner':
        pruner = L2FilterPruner(model, config_list)
    elif args.pruner == 'ActivationMeanRankFilterPruner':
        pruner = ActivationMeanRankFilterPruner(model, config_list)
    elif args.pruner == 'ActivationAPoZRankFilterPruner':
        pruner = ActivationAPoZRankFilterPruner(model, config_list)
    elif args.pruner == 'NetAdaptPruner':
        pruner = NetAdaptPruner(model,
                                config_list,
                                short_term_fine_tuner=short_term_fine_tuner,
                                evaluator=evaluator,
                                base_algo=args.base_algo,
                                experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'ADMMPruner':
        # users are free to change the config here
        if args.model == 'LeNet':
            if args.base_algo in ['l1', 'l2']:
                config_list = [{
                    'sparsity': 0.8,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv2']
                }]
            elif args.base_algo == 'level':
                config_list = [{
                    'sparsity': 0.8,
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_names': ['conv2']
                }, {
                    'sparsity': 0.991,
                    'op_names': ['fc1']
                }, {
                    'sparsity': 0.93,
                    'op_names': ['fc2']
                }]
        else:
            raise ValueError('Example only implemented for LeNet.')
        pruner = ADMMPruner(model,
                            config_list,
                            trainer=trainer,
                            num_iterations=2,
                            training_epochs=2)
    elif args.pruner == 'SimulatedAnnealingPruner':
        pruner = SimulatedAnnealingPruner(
            model,
            config_list,
            evaluator=evaluator,
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'AutoCompressPruner':
        pruner = AutoCompressPruner(
            model,
            config_list,
            trainer=trainer,
            evaluator=evaluator,
            dummy_input=dummy_input,
            num_iterations=3,
            optimize_mode='maximize',
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            admm_num_iterations=30,
            admm_training_epochs=5,
            experiment_data_dir=args.experiment_data_dir)
    else:
        raise ValueError("Pruner not supported.")

    # Pruner.compress() returns the masked model
    # but for AutoCompressPruner, Pruner.compress() returns directly the pruned model
    model = pruner.compress()
    evaluation_result = evaluator(model)
    print('Evaluation result (masked model): %s' % evaluation_result)
    result['performance']['pruned'] = evaluation_result

    if args.save_model:
        pruner.export_model(
            os.path.join(args.experiment_data_dir, 'model_masked.pth'),
            os.path.join(args.experiment_data_dir, 'mask.pth'))
        # FIX: was print('... %s', dir) — value was printed as a tuple element,
        # never substituted.
        print('Masked model saved to %s' % args.experiment_data_dir)

    # model speed up
    if args.speed_up:
        if args.pruner != 'AutoCompressPruner':
            # Rebuild a clean architecture and load the exported masked weights.
            if args.model == 'LeNet':
                model = LeNet().to(device)
            elif args.model == 'vgg16':
                model = VGG(depth=16).to(device)
            elif args.model == 'resnet18':
                model = ResNet18().to(device)
            elif args.model == 'resnet50':
                model = ResNet50().to(device)
            elif args.model == 'mobilenet_v2':
                model = models.mobilenet_v2(pretrained=False).to(device)
            model.load_state_dict(
                torch.load(
                    os.path.join(args.experiment_data_dir,
                                 'model_masked.pth')))
            masks_file = os.path.join(args.experiment_data_dir, 'mask.pth')
            m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
            m_speedup.speedup_model()
            evaluation_result = evaluator(model)
            print('Evaluation result (speed up model): %s' % evaluation_result)
            result['performance']['speedup'] = evaluation_result
            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
            # FIX: same %-formatting bug as above.
            print('Speed up model saved to %s' % args.experiment_data_dir)
        flops, params = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['speedup'] = flops
        result['params']['speedup'] = params

    if args.fine_tune:
        # Per-dataset/model fine-tuning hyper-parameters.
        if args.dataset == 'mnist':
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
        elif args.dataset == 'cifar10' and args.model == 'vgg16':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.01,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.5),
                                        int(args.fine_tune_epochs * 0.75)
                                    ],
                                    gamma=0.1)
        elif args.dataset == 'cifar10' and args.model == 'resnet18':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.1,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.5),
                                        int(args.fine_tune_epochs * 0.75)
                                    ],
                                    gamma=0.1)
        elif args.dataset == 'cifar10' and args.model == 'resnet50':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.1,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.5),
                                        int(args.fine_tune_epochs * 0.75)
                                    ],
                                    gamma=0.1)
        best_acc = 0
        for epoch in range(args.fine_tune_epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch)
            scheduler.step()
            acc = evaluator(model)
            # Keep only the best checkpoint seen so far.
            if acc > best_acc:
                best_acc = acc
                torch.save(
                    model.state_dict(),
                    os.path.join(args.experiment_data_dir,
                                 'model_fine_tuned.pth'))
        print('Evaluation result (fine tuned): %s' % best_acc)
        # FIX: same %-formatting bug as above.
        print('Fined tuned model saved to %s' % args.experiment_data_dir)
        result['performance']['finetuned'] = best_acc

    with open(os.path.join(args.experiment_data_dir, 'result.json'),
              'w+') as f:
        json.dump(result, f)
def model_inference(config):
    """Inference/benchmark driver for an FPGM-pruned ResNet-50.

    Loads a fine-tuned checkpoint (dropping mask tensors and the DataParallel
    ``module.`` prefix), applies the saved compression masks, then — guarded by
    module-level flags ``use_mask`` / ``use_speedup`` / ``compare_results`` —
    benchmarks masked vs. sped-up inference, exports ONNX models, checks the
    two outputs agree, and runs an accuracy evaluation on ``data_loader_test``.

    NOTE(review): those flags, ``data_loader_test`` and the helpers are read
    from module scope; ``config`` itself is unused here.
    """
    masks_file = './speedup_test/mask_new.pth'
    shape_mask = './speedup_test/mask_new.pth'
    org_mask = './speedup_test/mask.pth'
    rn50 = models.resnet50()
    m_paras = torch.load('./speedup_test/model_fine_tuned.pth')
    ##delete mask in pth
    m_new = collections.OrderedDict()
    for key in m_paras:
        if 'mask' in key:
            # Pruner bookkeeping — not part of the real state dict.
            continue
        if 'module' in key:
            # Strip the DataParallel/DDP prefix.
            m_new[key.replace('module.', '')] = m_paras[key]
        else:
            m_new[key] = m_paras[key]
    rn50.load_state_dict(m_new)
    rn50.cuda()
    rn50.eval()
    dummy_input = torch.randn(64, 3, 224, 224).cuda()
    use_mask_out = use_speedup_out = None
    rn = rn50
    apply_compression_results(rn, org_mask, 'cuda')
    # Reference output of the masked (not yet sped-up) model.
    rn_mask_out = rn(dummy_input)
    model = rn50
    # must run use_mask before use_speedup because use_speedup modify the model
    if use_mask:
        apply_compression_results(model, masks_file, 'cuda')
        torch.onnx.export(model,
                          dummy_input,
                          'resnet_masked.onnx',
                          export_params=True,
                          opset_version=12,
                          do_constant_folding=True,
                          input_names=['inputs'],
                          output_names=['proba'],
                          dynamic_axes={
                              'inputs': [0],
                              'mask': [0]
                          },
                          keep_initializers_as_inputs=True)
        # Time 32 forward passes of the masked model.
        start = time.time()
        for _ in range(32):
            use_mask_out = model(dummy_input)
        print('elapsed time when use mask: ', time.time() - start)
        print('Model is ', model)
        print('before speed up===================')
        # print(para)
        # print(model.state_dict()[para])
        # print(model.state_dict()[para].shape)
        flops, paras = count_flops_params(model, (1, 3, 224, 224))
        print(
            'flops and parameters before speedup is {} FLOPS and {} params'.
            format(flops, paras))
    if use_speedup:
        # FIX: was a bare `dummy_input.cuda()` whose result was discarded;
        # assign it (no behavior change — the tensor is already on CUDA).
        dummy_input = dummy_input.cuda()
        m_speedup = ModelSpeedup(model, dummy_input, shape_mask, 'cuda')
        m_speedup.speedup_model()
        print('==' * 20)
        print('Start inference')
        torch.onnx.export(model,
                          dummy_input,
                          'resnet_fpgm.onnx',
                          export_params=True,
                          opset_version=12,
                          do_constant_folding=True,
                          input_names=['inputs'],
                          output_names=['proba'],
                          dynamic_axes={
                              'inputs': [0],
                              'mask': [0]
                          },
                          keep_initializers_as_inputs=True)
        # Time 32 forward passes of the sped-up model.
        start = time.time()
        for _ in range(32):
            use_speedup_out = model(dummy_input)
        print('elapsed time when use speedup: ', time.time() - start)
        print('After speedup model is ', model)
        print('=================')
        print('After speedup')
        flops, paras = count_flops_params(model, (1, 3, 224, 224))
        # FIX: message said "before speedup" (copy-paste) although this reports
        # the post-speedup numbers.
        print(
            'flops and parameters after speedup is {} FLOPS and {} params'.
            format(flops, paras))
        #for para in model.state_dict():
        # print(para)
        # print(model.state_dict()[para])
        # print(model.state_dict()[para].shape)
    if compare_results:
        # Masked and sped-up models must agree numerically.
        print(rn_mask_out)
        print('another is', use_speedup_out)
        if torch.allclose(rn_mask_out, use_speedup_out, atol=1e-6):  #-07):
            print('the outputs from use_mask and use_speedup are the same')
        else:
            raise RuntimeError(
                'the outputs from use_mask and use_speedup are different')
    # start the accuracy check
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        start = time.time()
        evaluate(model,
                 criterion,
                 data_loader_test,
                 device="cuda",
                 print_freq=20)
        print('elapsed time is ', time.time() - start)
def main(args):
    """Distributed Taylor-pruning driver for ResNet-50.

    Loads a pretrained checkpoint, wraps the model with DDP when more than one
    process is available, prunes BatchNorm/ReLU channels with TaylorPruner,
    fine-tunes with the pruner's patched optimizer, and writes ``result.json``
    from the rank-0 process.
    """
    # prepare dataset
    torch.manual_seed(0)
    model = resnet50()
    model.load_state_dict(torch.load(args.pretrained_model_dir))
    inited = init_distributed(True)  # use nccl for communication
    print('all cudas numbers are ', get_world_size())
    distributed = (get_world_size() > 1) and inited
    paral = get_world_size()
    #device = torch.device('cuda',args.local_rank) if distributed else torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    #args.rank = get_rank()
    device = set_device(args.cuda, args.local_rank)
    # write to tensorboard, one log dir per local rank
    logger = Logger("logs/" + str(args.local_rank))
    print(distributed)
    print('local rank is {}'.format(args.local_rank))
    train_loader, val_loader, criterion = get_data(args.dataset, args.data_dir,
                                                   args.batch_size,
                                                   args.test_batch_size)
    print('to distribute ', distributed)
    model = model.to(args.local_rank)
    if distributed:
        model = DDP(model, device_ids=[args.local_rank
                                       ])  #, output_device=args.local_rank)
    #elif args.mgpu:
    #    model=torch.nn.DataParallel(model).cuda()
    else:
        model = model.cuda()
        #model = torch.nn.DataParallel(model).cuda()
    criterion = criterion.to(args.local_rank)
    #for module in model.named_modules():
    #    print('%'*20)
    #    print(module)
    # module types to prune, only "BatchNorm2d" supported for channel pruning with Taylor
    config_step = {
        'bn_statics': args.bn_statistics,
        'fre': args.freq,
        'tot_pru': args.tot_pru,
        'save_path': args.experiment_data_dir
    }
    config_list = [{
        'op_types': ['BatchNorm2d', 'ReLU'],
        'must_names': 'layer',
        'include_names': ['bn1', 'bn2', 'relu1']
    }]
    dummy_input = get_dummy_input(args, args.local_rank)
    if args.pruner == 'TaylorPruner':
        pruner = TaylorPruner(device,
                              model,
                              config_list,
                              config_step,
                              dependency_aware=False,
                              optimizer=None)
    else:
        raise ValueError("Pruner not supported.")
    # Pruner.compress() returns the masked model
    model = pruner.compress()
    if args.pruner == 'TaylorPruner':
        # Mask tensors must not be updated by SGD: give them lr=0.
        params_update = [
            param[1] for param in model.named_parameters()
            if 'mask' not in param[0]
        ]
        not_updates = [
            param[1] for param in model.named_parameters()
            if 'mask' in param[0]
        ]
        #print(params_update)
        params_all = [{
            'params': params_update
        }, {
            'params': not_updates,
            'lr': 0.0
        }]
    if args.fine_tune:
        if args.dataset in ['imagenet'] and args.model == 'resnet50':
            optimizer = torch.optim.SGD(params_all,
                                        lr=0.01,
                                        momentum=0.5,
                                        weight_decay=1e-8)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.3),
                                        int(args.fine_tune_epochs * 0.6),
                                        int(args.fine_tune_epochs * 0.8)
                                    ],
                                    gamma=0.1)
            # Let the pruner update Taylor contributions on every step.
            pruner.optimizer = optimizer
            pruner.patch_optimizer(pruner.masker.calc_contributions)
            pruner.keep_org_step()
        else:
            # FIX: was a bare `raise ValueError` with no message.
            raise ValueError(
                'Fine-tuning is only configured for resnet50 on imagenet.')

    def short_term_fine_tuner(model, epochs=1):
        for epoch in range(epochs):
            train(pruner, args, model, device, train_loader, criterion,
                  optimizer, epoch, logger)

    def trainer(pruner, model, optimizer, criterion, epoch, callback):
        return train(pruner,
                     args,
                     model,
                     device,
                     train_loader,
                     criterion,
                     optimizer,
                     epoch=epoch,
                     logger=logger,
                     callback=callback)

    def evaluator(model, step):
        return test(model, device, criterion, val_loader, step, logger)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance': {}}
    #print(model)
    for module in model.named_modules():
        print('DEBUG===name of module is ', module[0])
        print('buffers are: ', [buff for buff in module[1].buffers()])
        print('para are: ', [para for para in module[1].named_parameters()])
    flops, params = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params
    evaluation_result = evaluator(model, 0)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result
    if args.local_rank == 0 and args.save_model:
        pruner.export_model(
            os.path.join(args.experiment_data_dir, 'model_masked.pth'),
            os.path.join(args.experiment_data_dir, 'mask.pth'))
        # FIX: was print('... %s', dir) — the path was never substituted.
        print('Masked model saved to %s' % args.experiment_data_dir)

    def wrapped(module):
        # Identify modules the Taylor masker wrapped (they hold the masks).
        return isinstance(module, BNTaylorPrunerMasker)

    wrap_mask = [
        module for module in model.named_modules() if wrapped(module[1])
    ]
    for idx, mm in enumerate(wrap_mask):
        print('====****' * 10, idx)
        print(mm[0])
        print(mm[1].state_dict().keys())
        print('weight mask is ', mm[1].state_dict()['weight_mask'])
        if 'bias_mask' in mm[1].state_dict():
            print('bias mask is ', mm[1].state_dict()['bias_mask'])
    if args.mgpu:
        model = torch.nn.DataParallel(model).cuda()
    print('local rank is', args.local_rank)
    if args.fine_tune:
        best_acc = 0
        for epoch in range(args.fine_tune_epochs):
            print('start fine tune for epoch {}/{}'.format(
                epoch, args.fine_tune_epochs))
            stime = time.time()
            train(pruner, args, model, args.local_rank, train_loader,
                  criterion, optimizer, epoch, logger)
            scheduler.step()
            acc = evaluator(model, epoch)
            print('end fine tune for epoch {}/{} for {} seconds'.format(
                epoch, args.fine_tune_epochs,
                time.time() - stime))
            # Only rank 0 writes checkpoints to avoid concurrent writes.
            if acc > best_acc and args.local_rank == 0:
                best_acc = acc
                torch.save(
                    model,
                    os.path.join(args.experiment_data_dir, args.model,
                                 'finetune_model.pt'))
                torch.save(
                    model.state_dict(),
                    os.path.join(args.experiment_data_dir,
                                 'model_fine_tuned.pth'))
        print('Evaluation result (fine tuned): %s' % best_acc)
        # FIX: same %-formatting bug as above.
        print('Fined tuned model saved to %s' % args.experiment_data_dir)
        result['performance']['finetuned'] = best_acc
    if args.local_rank == 0:
        with open(os.path.join(args.experiment_data_dir, 'result.json'),
                  'w+') as f:
            json.dump(result, f)
def main(args):
    """Distributed FPGM-pruning driver for torchvision ResNet-50.

    Sets up (optional) DDP, prunes Conv2d channels with ``MyPruner`` (FPGM),
    dumps the per-module masks for inspection, fine-tunes the pruned model,
    and writes ``result.json`` from the rank-0 process.
    """
    # prepare dataset
    torch.manual_seed(0)
    #device = torch.device('cuda',args.local_rank) if distributed else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = set_device(args.cuda, args.local_rank)
    inited = init_distributed(True)  # use nccl for communication
    print('all cudas numbers are ', get_world_size())
    distributed = (get_world_size() > 1) and inited
    paral = get_world_size()
    args.rank = get_rank()
    # write to tensorboard, one log dir per rank
    logger = Logger("logs/" + str(args.rank))
    print(distributed)
    #device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('device is', device)
    print('rank is {} local rank is {}'.format(args.rank, args.local_rank))
    train_loader, val_loader, criterion = get_data(args.dataset, args.data_dir,
                                                   args.batch_size,
                                                   args.test_batch_size)
    model = torchvision.models.resnet50(pretrained=True)
    model = model.cuda()
    print('to distribute ', distributed)
    if distributed:
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank)
        #model = torch.nn.DataParallel(model).cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.1,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer,
                            milestones=[
                                int(args.pretrain_epochs * 0.5),
                                int(args.pretrain_epochs * 0.75)
                            ],
                            gamma=0.1)
    criterion = criterion.cuda()
    #model, optimizer = get_trained_model_optimizer(args, device, train_loader, val_loader, criterion)

    def short_term_fine_tuner(model, epochs=1):
        for epoch in range(epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch, logger)

    def trainer(model, optimizer, criterion, epoch, callback):
        return train(args,
                     model,
                     device,
                     train_loader,
                     criterion,
                     optimizer,
                     epoch=epoch,
                     logger=logger,
                     callback=callback)

    def evaluator(model, step):
        return test(model, device, criterion, val_loader, step, logger)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance': {}}
    flops, params = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params
    evaluation_result = evaluator(model, 0)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result
    # module types to prune, only "Conv2d" supported for channel pruning
    if args.base_algo in ['l1', 'l2']:
        op_types = ['Conv2d']
    elif args.base_algo == 'level':
        op_types = ['default']
    config_list = [{
        'sparsity': args.sparsity,
        'op_types': op_types,
        'exclude_names': 'downsample'
    }]
    dummy_input = get_dummy_input(args, device)
    if args.pruner == 'FPGMPruner':
        pruner = MyPruner(model, config_list)
    else:
        raise ValueError("Pruner not supported.")
    # Pruner.compress() returns the masked model
    model = pruner.compress()
    evaluation_result = evaluator(model, 0)
    print('Evaluation result (masked model): %s' % evaluation_result)
    result['performance']['pruned'] = evaluation_result
    if args.rank == 0 and args.save_model:
        pruner.export_model(
            os.path.join(args.experiment_data_dir, 'model_masked.pth'),
            os.path.join(args.experiment_data_dir, 'mask.pth'))
        # FIX: was print('... %s', dir) — the path was never substituted.
        print('Masked model saved to %s' % args.experiment_data_dir)

    def wrapped(module):
        # Modules wrapped by the pruner carry the mask buffers we dump below.
        return isinstance(module, BNWrapper) or isinstance(
            module, PrunerModuleWrapper)

    wrap_mask = [
        module for module in model.named_modules() if wrapped(module[1])
    ]
    for mm in wrap_mask:
        print('====****' * 10)
        print(mm[0])
        print(mm[1].state_dict().keys())
        print('weight mask is ', mm[1].state_dict()['weight_mask'])
        if 'bias_mask' in mm[1].state_dict():
            print('bias mask is ', mm[1].state_dict()['bias_mask'])
    if args.fine_tune:
        if args.dataset in ['imagenet'] and args.model == 'resnet50':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.01,
                                        momentum=0.9,
                                        weight_decay=1e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.3),
                                        int(args.fine_tune_epochs * 0.6),
                                        int(args.fine_tune_epochs * 0.8)
                                    ],
                                    gamma=0.1)
        else:
            # FIX: message previously said "Pruner not supported." (copy-paste);
            # this branch actually rejects the dataset/model combination.
            raise ValueError(
                'Fine-tuning is only configured for resnet50 on imagenet.')
        best_acc = 0
        for epoch in range(args.fine_tune_epochs):
            print('start fine tune for epoch {}/{}'.format(
                epoch, args.fine_tune_epochs))
            stime = time.time()
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch, logger)
            scheduler.step()
            acc = evaluator(model, epoch)
            print('end fine tune for epoch {}/{} for {} seconds'.format(
                epoch, args.fine_tune_epochs,
                time.time() - stime))
            # Only rank 0 writes checkpoints to avoid concurrent writes.
            if acc > best_acc and args.rank == 0:
                best_acc = acc
                torch.save(
                    model,
                    os.path.join(args.experiment_data_dir, args.model,
                                 'finetune_model.pt'))
                torch.save(
                    model.state_dict(),
                    os.path.join(args.experiment_data_dir,
                                 'model_fine_tuned.pth'))
        print('Evaluation result (fine tuned): %s' % best_acc)
        # FIX: same %-formatting bug as above.
        print('Fined tuned model saved to %s' % args.experiment_data_dir)
        result['performance']['finetuned'] = best_acc
    if args.rank == 0:
        with open(os.path.join(args.experiment_data_dir, 'result.json'),
                  'w+') as f:
            json.dump(result, f)