def test_speedup_integration(self):
    """Prune torchvision models with random configs and verify ModelSpeedup.

    For every model name below and every config generator in
    ``gen_cfg_funcs``, the model is pruned with ``L1FilterPruner``. The
    exported weights are loaded into a second copy of the model, that copy is
    sped up with ``ModelSpeedup``, and the summed outputs of the masked model
    and the sped-up model must agree within ``ABSOLUTE_THRESHOLD`` (absolute)
    or ``RELATIVE_THRESHOLD`` (relative).

    The test is skipped on Windows because of its memory footprint.
    """
    # Skip this test on windows (7GB mem available) due to memory limit.
    # Note: hack trick, may be updated in the future.
    # ``sys.platform`` is 'win32' (or 'cygwin') on Windows. A plain substring
    # test for 'win' also matches macOS ('darwin') and wrongly skipped it there.
    if sys.platform.startswith(('win', 'cygwin')):
        print('Skip test_speedup_integration on windows due to memory limit!')
        return

    gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2]
    model_names = [
        'resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121',
        'densenet169',
        # 'inception_v3' inception is too large and may fail the pipeline
        'resnet50',
    ]

    for model_name in model_names:
        for gen_cfg_func in gen_cfg_funcs:
            kwargs = {'pretrained': True}
            if model_name == 'resnet50':
                # testing multiple groups
                kwargs = {'pretrained': False, 'groups': 4}
            Model = getattr(models, model_name)
            net = Model(**kwargs).to(device)
            speedup_model = Model(**kwargs).to(device)
            net.eval()  # this line is necessary
            speedup_model.eval()

            # random generate the prune config for the pruner
            cfgs = gen_cfg_func(net)
            print("Testing {} with compression config \n {}".format(
                model_name, cfgs))
            pruner = L1FilterPruner(net, cfgs)
            pruner.compress()
            pruner.export_model(MODEL_FILE, MASK_FILE)
            pruner._unwrap_model()

            state_dict = torch.load(MODEL_FILE)
            speedup_model.load_state_dict(state_dict)
            zero_bn_bias(net)
            zero_bn_bias(speedup_model)

            data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
            ms = ModelSpeedup(speedup_model, data, MASK_FILE)
            ms.speedup_model()

            speedup_model.eval()

            ori_out = net(data)
            speeded_out = speedup_model(data)
            ori_sum = torch.sum(ori_out).item()
            speeded_sum = torch.sum(speeded_out).item()
            print('Sum of the output of %s (before speedup):' %
                  model_name, ori_sum)
            print('Sum of the output of %s (after speedup):' %
                  model_name, speeded_sum)

            # Check the absolute error first, and only compute the relative
            # error for a non-zero reference, so an all-zero reference output
            # yields an assertion result instead of ZeroDivisionError.
            diff = abs(ori_sum - speeded_sum)
            assert diff < ABSOLUTE_THRESHOLD or \
                (ori_sum != 0 and diff / abs(ori_sum) < RELATIVE_THRESHOLD)
def prune_model_l1(model):
    """Prune every Conv2d layer of ``model`` with ``L1FilterPruner``.

    All Conv2d layers share the module-level ``SPARSITY``. After compression,
    the model state and masks are exported to ``MODEL_FILE`` and
    ``MASK_FILE`` respectively. Returns None.
    """
    conv_rule = {'sparsity': SPARSITY, 'op_types': ['Conv2d']}
    l1_pruner = L1FilterPruner(model, [conv_rule])
    l1_pruner.compress()
    l1_pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
def test_speedup_tupleunpack(self):
    """Regression test for issue 3645 (``TupleUnpack_Model`` speedup).

    Prunes every Conv2d layer at 50% with ``L1FilterPruner``, exports the
    masks and runs ``ModelSpeedup`` (confidence=8) on the same model.
    """
    model = TupleUnpack_Model()
    conv_half = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
    sample = torch.rand(2, 3, 224, 224)

    l1_pruner = L1FilterPruner(model, conv_half)
    l1_pruner.compress()
    # one forward pass with the masks attached before exporting
    model(sample)
    l1_pruner.export_model(MODEL_FILE, MASK_FILE)

    ModelSpeedup(model, sample, MASK_FILE, confidence=8).speedup_model()
def test_mask_conflict(self):
    """Check that ``fix_mask_conflict`` aligns masks of channel-dependent layers.

    For every model in ``model_names``, each Conv2d layer gets a random
    sparsity drawn from ``np.random.uniform(0.01, 0.99)``. The model is pruned
    with ``L1FilterPruner``, and checkpoint and masks are exported under
    ``<prefix>/masks``; the masks are then passed through
    ``fix_mask_conflict``. For every group of layers listed in
    ``channel_dependency_ground_truth[name]``, all members must end up with
    weight masks of the same size along dim 0 that prune the same indices, and
    the same pruned bias indices where both layers have a bias mask.
    """
    outdir = os.path.join(prefix, 'masks')
    os.makedirs(outdir, exist_ok=True)
    for name in model_names:
        print('Test mask conflict for %s' % name)
        model = getattr(models, name)
        net = model().to(device)
        dummy_input = torch.ones(1, 3, 224, 224).to(device)

        # random generate the prune sparsity for each layer
        cfglist = []
        for layername, layer in net.named_modules():
            if isinstance(layer, nn.Conv2d):
                # pruner cannot allow the sparsity to be 0 or 1
                sparsity = np.random.uniform(0.01, 0.99)
                cfglist.append({
                    'op_types': ['Conv2d'],
                    'op_names': [layername],
                    'sparsity': sparsity
                })

        pruner = L1FilterPruner(net, cfglist)
        pruner.compress()
        ck_file = os.path.join(outdir, '%s.pth' % name)
        mask_file = os.path.join(outdir, '%s_mask' % name)
        pruner.export_model(ck_file, mask_file)
        pruner._unwrap_model()

        # Fix the mask conflict
        fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)

        # use the channel dependency ground truth to check if
        # the mask conflict was fixed successfully
        for dset in channel_dependency_ground_truth[name]:
            lset = list(dset)
            ref_mask = fixed_mask[lset[0]]
            for layer_name in lset:
                cur_mask = fixed_mask[layer_name]
                assert ref_mask['weight'].size(0) == cur_mask['weight'].size(0)
                w_index1 = self.get_pruned_index(ref_mask['weight'])
                w_index2 = self.get_pruned_index(cur_mask['weight'])
                assert w_index1 == w_index2
                # Each mask entry is a dict, so the bias mask must be looked up
                # as a key: ``hasattr(mask_dict, 'bias')`` is always False,
                # which previously made this check dead code. A layer without
                # bias may store None here, hence the explicit None guards.
                ref_bias = ref_mask.get('bias')
                cur_bias = cur_mask.get('bias')
                if ref_bias is not None and cur_bias is not None:
                    b_index1 = self.get_pruned_index(ref_bias)
                    b_index2 = self.get_pruned_index(cur_bias)
                    assert b_index1 == b_index2
def test_speedup_integration(self):
    """Compare masked-model and sped-up-model outputs for several torchvision nets.

    Each model is pruned with a random config from
    ``generate_random_sparsity`` using ``L1FilterPruner``. The exported weights
    are loaded into a fresh copy, which is sped up with ``ModelSpeedup``. The
    summed outputs of both models on an all-ones batch must agree within
    ``RELATIVE_THRESHOLD`` (relative) or ``ABSOLUTE_THRESHOLD`` (absolute).
    """
    candidates = (
        'resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121',
        # 'inception_v3' inception is too large and may fail the pipeline
        'densenet169', 'resnet50',
    )
    for model_name in candidates:
        if model_name == 'resnet50':
            # testing multiple groups
            model_kwargs = {'pretrained': False, 'groups': 4}
        else:
            model_kwargs = {'pretrained': True}

        model_cls = getattr(models, model_name)
        net = model_cls(**model_kwargs).to(device)
        speedup_model = model_cls(**model_kwargs).to(device)
        # eval mode is necessary for both copies
        net.eval()
        speedup_model.eval()

        # prune the reference net with a random config, keep only weights
        l1_pruner = L1FilterPruner(net, generate_random_sparsity(net))
        l1_pruner.compress()
        l1_pruner.export_model(MODEL_FILE, MASK_FILE)
        l1_pruner._unwrap_model()

        speedup_model.load_state_dict(torch.load(MODEL_FILE))
        for module in (net, speedup_model):
            zero_bn_bias(module)

        batch = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
        ModelSpeedup(speedup_model, batch, MASK_FILE).speedup_model()
        speedup_model.eval()

        ori_sum = torch.sum(net(batch)).item()
        speeded_sum = torch.sum(speedup_model(batch)).item()
        print('Sum of the output of %s (before speedup):' % model_name, ori_sum)
        print('Sum of the output of %s (after speedup):' % model_name, speeded_sum)

        gap = abs(ori_sum - speeded_sum)
        assert gap / abs(ori_sum) < RELATIVE_THRESHOLD or gap < ABSOLUTE_THRESHOLD
def test_multiplication_speedup(self):
    """Model from issue 4540: speedup across an element-wise product.

    The network multiplies a gating branch (avgpool -> fc1 -> fc2 ->
    Hardsigmoid) with the main feature map before the output conv. Every
    Conv2d is pruned at 30% with ``L1FilterPruner`` and the model is then sped
    up with ``ModelSpeedup``.
    """
    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # attribute order is kept: it fixes init order and named_modules order
            self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
            self.input = torch.nn.Conv2d(3, 8, 3)
            self.bn = torch.nn.BatchNorm2d(8)
            self.fc1 = torch.nn.Conv2d(8, 16, 1)
            self.fc2 = torch.nn.Conv2d(16, 8, 1)
            self.activation = torch.nn.ReLU()
            self.scale_activation = torch.nn.Hardsigmoid()
            self.out = torch.nn.Conv2d(8, 12, 1)

        def forward(self, input):
            features = self.activation(self.bn(self.input(input)))
            hidden = self.activation(self.fc1(self.avgpool(features)))
            gate = self.scale_activation(self.fc2(hidden))
            return self.out(gate * features)

    model = Net().to(device)
    model.eval()
    im = torch.ones(1, 3, 512, 512).to(device)
    model(im)

    cfg_list = [
        {'op_types': ['Conv2d'], 'sparsity': 0.3, 'op_names': [layer_name]}
        for layer_name, layer in model.named_modules()
        if isinstance(layer, torch.nn.Conv2d)
    ]

    l1_pruner = L1FilterPruner(model, cfg_list)
    l1_pruner.compress()
    l1_pruner.export_model(MODEL_FILE, MASK_FILE)
    l1_pruner._unwrap_model()
    ModelSpeedup(model, im, MASK_FILE).speedup_model()
def test_convtranspose_model(self):
    """Prune ``TransposeModel``'s Conv2d layers and verify ModelSpeedup.

    The masked original model and a fresh ``TransposeModel`` (loaded with the
    exported weights and sped up) must produce output sums that agree within
    ``RELATIVE_THRESHOLD`` (relative) or ``ABSOLUTE_THRESHOLD`` (absolute).
    """
    ori_model = TransposeModel()
    sample = torch.rand(1, 3, 8, 8)

    l1_pruner = L1FilterPruner(ori_model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
    l1_pruner.compress()
    ori_model(sample)
    l1_pruner.export_model(MODEL_FILE, MASK_FILE)
    l1_pruner._unwrap_model()

    new_model = TransposeModel()
    new_model.load_state_dict(torch.load(MODEL_FILE))
    ModelSpeedup(new_model, sample, MASK_FILE).speedup_model()

    zero_bn_bias(ori_model)
    zero_bn_bias(new_model)
    ori_sum = torch.sum(ori_model(sample))
    speeded_sum = torch.sum(new_model(sample))
    print('Tanspose Speedup Test: ori_sum={} speedup_sum={}'.format(ori_sum, speeded_sum))

    delta = abs(ori_sum - speeded_sum)
    assert delta / abs(ori_sum) < RELATIVE_THRESHOLD or delta < ABSOLUTE_THRESHOLD
def main(args):
    """End-to-end MNIST example: pretrain, prune, speed up, fine-tune, QAT, TensorRT.

    Steps:
      1. pretrain ``NaiveModel`` (or load it from ``args.pretrained_model_dir``),
      2. prune Conv2d layers with ``L1FilterPruner`` at ``args.sparsity``
         (optionally in dependency-aware mode),
      3. speed up the pruned model with ``ModelSpeedup``,
      4. fine-tune the sped-up model for ``args.finetune_epochs`` epochs,
      5. attach ``QAT_Quantizer`` (8-bit conv1/conv2/relu1/relu2),
      6. run one epoch of quantization-aware training and export calibration,
      7. build a ``ModelSpeedupTensorRT`` engine and evaluate it.

    Checkpoints, masks and the calibration file are written under
    ``args.experiment_data_dir``.
    """
    import copy

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(args.experiment_data_dir, exist_ok=True)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, ))
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=transform),
        batch_size=64,
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=transform),
        batch_size=1000)

    # Step1. Model Pretraining
    model = NaiveModel().to(device)
    criterion = torch.nn.NLLLoss()
    optimizer = optim.Adadelta(model.parameters(), lr=args.pretrain_lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    flops, params, _ = count_flops_params(model, (1, 1, 28, 28), verbose=False)

    if args.pretrained_model_dir is None:
        args.pretrained_model_dir = os.path.join(args.experiment_data_dir, 'pretrained.pth')

        best_acc = 0
        # ``state_dict()`` returns tensors that alias the live parameters, so
        # the best epoch must be deep-copied or later epochs overwrite it.
        # Starting from a copy of the initial weights also covers 0 epochs.
        state_dict = copy.deepcopy(model.state_dict())
        for epoch in range(args.pretrain_epochs):
            train(args, model, device, train_loader, criterion, optimizer, epoch)
            scheduler.step()
            acc = test(args, model, device, criterion, test_loader)
            if acc > best_acc:
                best_acc = acc
                state_dict = copy.deepcopy(model.state_dict())

        model.load_state_dict(state_dict)
        torch.save(state_dict, args.pretrained_model_dir)
        print(f'Model saved to {args.pretrained_model_dir}')
    else:
        state_dict = torch.load(args.pretrained_model_dir)
        model.load_state_dict(state_dict)
        best_acc = test(args, model, device, criterion, test_loader)

    dummy_input = torch.randn([1000, 1, 28, 28]).to(device)
    time_cost = get_model_time_cost(model, dummy_input)

    # 125.49 M, 0.85M, 93.29, 1.1012
    print(
        f'Pretrained model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_acc: .2f}, Time Cost: {time_cost}'
    )

    # Step2. Model Pruning
    config_list = [{'sparsity': args.sparsity, 'op_types': ['Conv2d']}]

    kw_args = {}
    if args.dependency_aware:
        dummy_input = torch.randn([1000, 1, 28, 28]).to(device)
        print('Enable the dependency_aware mode')
        # note that, not all pruners support the dependency_aware mode
        kw_args['dependency_aware'] = True
        kw_args['dummy_input'] = dummy_input

    pruner = L1FilterPruner(model, config_list, **kw_args)
    model = pruner.compress()
    pruner.get_pruned_weights()

    mask_path = os.path.join(args.experiment_data_dir, 'mask.pth')
    model_path = os.path.join(args.experiment_data_dir, 'pruned.pth')
    pruner.export_model(model_path=model_path, mask_path=mask_path)
    pruner._unwrap_model()  # unwrap all modules to normal state

    # Step3. Model Speedup
    m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
    m_speedup.speedup_model()
    print('model after speedup', model)

    flops, params, _ = count_flops_params(model, dummy_input, verbose=False)
    acc = test(args, model, device, criterion, test_loader)
    time_cost = get_model_time_cost(model, dummy_input)
    print(
        f'Pruned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {acc: .2f}, Time Cost: {time_cost}'
    )

    # Step4. Model Finetuning
    optimizer = optim.Adadelta(model.parameters(), lr=args.pretrain_lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

    best_acc = 0
    # Deep copies for the same aliasing reason as in Step1. Starting from the
    # sped-up weights also avoids reloading the (differently shaped)
    # pretrained state dict when finetune_epochs == 0.
    state_dict = copy.deepcopy(model.state_dict())
    for epoch in range(args.finetune_epochs):
        train(args, model, device, train_loader, criterion, optimizer, epoch)
        scheduler.step()
        acc = test(args, model, device, criterion, test_loader)
        if acc > best_acc:
            best_acc = acc
            state_dict = copy.deepcopy(model.state_dict())

    model.load_state_dict(state_dict)
    save_path = os.path.join(args.experiment_data_dir, 'finetuned.pth')
    torch.save(state_dict, save_path)

    flops, params, _ = count_flops_params(model, dummy_input, verbose=True)
    time_cost = get_model_time_cost(model, dummy_input)

    # FLOPs 28.48 M, #Params: 0.18M, Accuracy: 89.03, Time Cost: 1.03
    print(
        f'Finetuned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_acc: .2f}, Time Cost: {time_cost}'
    )
    print(f'Model saved to {save_path}')

    # Step5. Model Quantization via QAT
    config_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_names': ['conv1']
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu1']
    }, {
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_names': ['conv2']
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu2']
    }]

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    quantizer = QAT_Quantizer(model, config_list, optimizer)
    quantizer.compress()

    # Step6. Quantization Aware Training
    # ``scheduler`` still belongs to the Step4 Adadelta optimizer, so it is not
    # stepped here; the QAT SGD optimizer keeps its constant learning rate.
    for epoch in range(1):
        train(args, model, device, train_loader, criterion, optimizer, epoch)
        acc = test(args, model, device, criterion, test_loader)
        print(f'QAT epoch {epoch} accuracy: {acc: .2f}')

    calibration_path = os.path.join(args.experiment_data_dir, 'calibration.pth')
    # NOTE(review): this exports to ``model_path`` (pruned.pth) and so
    # overwrites the Step2 checkpoint — confirm that is intended.
    calibration_config = quantizer.export_model(model_path, calibration_path)
    print("calibration_config: ", calibration_config)

    # Step7. Model Speedup
    batch_size = 32
    input_shape = (batch_size, 1, 28, 28)
    engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
    engine.compress()

    test_trt(engine, test_loader)
# The Yolo can be downloaded at https://github.com/eriklindernoren/PyTorch-YOLOv3.git
prefix = '/home/user/PyTorch-YOLOv3'  # replace this path with yours

# Load the YOLO model from its config and weights file, on the CPU
model = models.load_model(f"{prefix}/config/yolov3.cfg", f"{prefix}/yolov3.weights").cpu()
model.eval()
dummy_input = torch.rand(8, 3, 320, 320)
model(dummy_input)

# One pruning rule (sparsity 0.6) per Conv2d layer, leaving out the layers
# that not_safe_to_prune reports as unsafe to prune.
not_safe = not_safe_to_prune(model, dummy_input)
cfg_list = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d) and name not in not_safe:
        cfg_list.append({'op_types': ['Conv2d'], 'sparsity': 0.6, 'op_names': [name]})

# Prune, export weights and masks, then restore the plain modules
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
pruner._unwrap_model()

# Speed up the model with the exported masks and run it once
ms = ModelSpeedup(model, dummy_input, './mask')
ms.speedup_model()
model(dummy_input)
def main():
    """Train VGG-16 on CIFAR-10, prune 80% of conv filters, fine-tune with KD.

    1. Train the unpruned VGG-16 for 160 epochs (cosine LR) and save it to
       'vgg16_cifar10.pth'.
    2. Prune every Conv2d layer at 80% sparsity with ``L1FilterPruner``.
    3. Fine-tune for 40 epochs with knowledge distillation from the saved
       unpruned model, exporting the best (by top-1) pruned weights and masks.
    4. Reload the exported weights into a fresh VGG-16 and test it.
    """
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.Pad(4),
                             transforms.RandomCrop(32),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                         ])),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])),
        batch_size=200, shuffle=False)

    model = VGG(depth=16)
    model.to(device)

    # Train the base VGG-16 model
    print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
    for epoch in range(160):
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
        # Advance the schedule once per epoch. Passing ``epoch`` explicitly is
        # deprecated and made the learning rate lag one epoch behind the
        # cosine curve.
        lr_scheduler.step()
    torch.save(model.state_dict(), 'vgg16_cifar10.pth')

    # Test base model accuracy
    print('=' * 10 + 'Test on the original model' + '=' * 10)
    model.load_state_dict(torch.load('vgg16_cifar10.pth'))
    test(model, device, test_loader)
    # top1 = 93.51%

    # Pruning Configuration, all convolution layers are pruned out 80% filters according to the L1 norm
    configure_list = [{
        'sparsity': 0.8,
        'op_types': ['Conv2d'],
    }]

    # Prune model and test accuracy without fine tuning.
    print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
    pruner = L1FilterPruner(model, configure_list)
    model = pruner.compress()
    test(model, device, test_loader)
    # top1 = 10.00%

    # Fine tune the pruned model for 40 epochs and test accuracy
    print('=' * 10 + 'Fine tuning' + '=' * 10)
    optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
    best_top1 = 0
    # the unpruned model trained above serves as the distillation teacher
    kd_teacher_model = VGG(depth=16)
    kd_teacher_model.to(device)
    kd_teacher_model.load_state_dict(torch.load('vgg16_cifar10.pth'))
    kd = KnowledgeDistill(kd_teacher_model, kd_T=5)
    for epoch in range(40):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer_finetune, kd)
        top1 = test(model, device, test_loader)
        if top1 > best_top1:
            best_top1 = top1
            # Export the best model, 'model_path' stores state_dict of the pruned model,
            # mask_path stores mask_dict of the pruned model
            pruner.export_model(model_path='pruned_vgg16_cifar10.pth', mask_path='mask_vgg16_cifar10.pth')

    # Test the exported model
    print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
    new_model = VGG(depth=16)
    new_model.to(device)
    new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
    test(new_model, device, test_loader)
def main(args):
    """Prune a trained model with the pruner named by ``args.pruner``.

    Workflow:
      - load data and a trained model/optimizer;
      - record the original FLOPs, params and evaluation result;
      - prune with the chosen pruner;
      - optionally export the masked model (``args.save_model``);
      - optionally speed it up (``args.speed_up``) and fine-tune it
        (``args.fine_tune``).

    Collected metrics are written to ``<args.experiment_data_dir>/result.json``.

    Raises:
        ValueError: for an unsupported ``args.base_algo`` or ``args.pruner``;
            for ADMMPruner with a model other than LeNet; for an unknown
            ``args.model`` when speeding up; and for a dataset/model
            combination with no fine-tuning recipe.
    """
    # prepare dataset
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, criterion = get_data(args.dataset, args.data_dir, args.batch_size, args.test_batch_size)
    model, optimizer = get_trained_model_optimizer(args, device, train_loader, val_loader, criterion)

    def short_term_fine_tuner(model, epochs=1):
        for epoch in range(epochs):
            train(args, model, device, train_loader, criterion, optimizer, epoch)

    def trainer(model, optimizer, criterion, epoch, callback):
        return train(args, model, device, train_loader, criterion, optimizer, epoch=epoch, callback=callback)

    def evaluator(model):
        return test(model, device, criterion, val_loader)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance': {}}

    flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params

    evaluation_result = evaluator(model)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result

    # module types to prune, only "Conv2d" supported for channel pruning
    if args.base_algo in ['l1', 'l2', 'fpgm']:
        op_types = ['Conv2d']
    elif args.base_algo == 'level':
        op_types = ['default']
    else:
        # previously fell through and failed later with a NameError on op_types
        raise ValueError('Unsupported base_algo: %s' % args.base_algo)

    config_list = [{
        'sparsity': args.sparsity,
        'op_types': op_types
    }]
    dummy_input = get_dummy_input(args, device)

    if args.pruner == 'L1FilterPruner':
        pruner = L1FilterPruner(model, config_list)
    elif args.pruner == 'L2FilterPruner':
        pruner = L2FilterPruner(model, config_list)
    elif args.pruner == 'FPGMPruner':
        pruner = FPGMPruner(model, config_list)
    elif args.pruner == 'NetAdaptPruner':
        pruner = NetAdaptPruner(model, config_list, short_term_fine_tuner=short_term_fine_tuner, evaluator=evaluator,
                                base_algo=args.base_algo, experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'ADMMPruner':
        # users are free to change the config here
        if args.model == 'LeNet':
            if args.base_algo in ['l1', 'l2', 'fpgm']:
                config_list = [{
                    'sparsity': 0.8,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv2']
                }]
            elif args.base_algo == 'level':
                config_list = [{
                    'sparsity': 0.8,
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_names': ['conv2']
                }, {
                    'sparsity': 0.991,
                    'op_names': ['fc1']
                }, {
                    'sparsity': 0.93,
                    'op_names': ['fc2']
                }]
        else:
            raise ValueError('Example only implemented for LeNet.')
        pruner = ADMMPruner(model, config_list, trainer=trainer, num_iterations=2, training_epochs=2)
    elif args.pruner == 'SimulatedAnnealingPruner':
        pruner = SimulatedAnnealingPruner(
            model, config_list, evaluator=evaluator, base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate, experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'AutoCompressPruner':
        pruner = AutoCompressPruner(
            model, config_list, trainer=trainer, evaluator=evaluator, dummy_input=dummy_input,
            num_iterations=3, optimize_mode='maximize', base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate, admm_num_iterations=30, admm_training_epochs=5,
            experiment_data_dir=args.experiment_data_dir)
    else:
        raise ValueError(
            "Pruner not supported.")

    # Pruner.compress() returns the masked model
    # but for AutoCompressPruner, Pruner.compress() returns directly the pruned model
    model = pruner.compress()
    evaluation_result = evaluator(model)
    print('Evaluation result (masked model): %s' % evaluation_result)
    result['performance']['pruned'] = evaluation_result

    if args.save_model:
        pruner.export_model(
            os.path.join(args.experiment_data_dir, 'model_masked.pth'), os.path.join(args.experiment_data_dir, 'mask.pth'))
        print('Masked model saved to %s' % args.experiment_data_dir)

    # model speed up
    if args.speed_up:
        if args.pruner != 'AutoCompressPruner':
            # Rebuild a plain model and load the exported masked weights into it.
            if args.model == 'LeNet':
                model = LeNet().to(device)
            elif args.model == 'vgg16':
                model = VGG(depth=16).to(device)
            elif args.model == 'resnet18':
                model = ResNet18().to(device)
            elif args.model == 'resnet50':
                model = ResNet50().to(device)
            else:
                # otherwise the exported weights would be loaded into the
                # still-wrapped masked model
                raise ValueError('Model not supported for speed up: %s' % args.model)

            model.load_state_dict(torch.load(os.path.join(args.experiment_data_dir, 'model_masked.pth')))
            masks_file = os.path.join(args.experiment_data_dir, 'mask.pth')

            m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
            m_speedup.speedup_model()
            evaluation_result = evaluator(model)
            print('Evaluation result (speed up model): %s' % evaluation_result)
            result['performance']['speedup'] = evaluation_result

            torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
            print('Speed up model saved to %s' % args.experiment_data_dir)
        flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['speedup'] = flops
        result['params']['speedup'] = params

    if args.fine_tune:
        milestones = [int(args.fine_tune_epochs * 0.5), int(args.fine_tune_epochs * 0.75)]
        if args.dataset == 'mnist':
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
        elif args.dataset == 'cifar10' and args.model == 'vgg16':
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
        elif args.dataset == 'cifar10' and args.model in ('resnet18', 'resnet50'):
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
        else:
            # previously left ``scheduler`` undefined and crashed in the loop
            raise ValueError('No fine-tuning recipe for dataset %s with model %s'
                             % (args.dataset, args.model))

        best_acc = 0
        for epoch in range(args.fine_tune_epochs):
            train(args, model, device, train_loader, criterion, optimizer, epoch)
            scheduler.step()
            acc = evaluator(model)
            if acc > best_acc:
                best_acc = acc
                torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_fine_tuned.pth'))

        print('Evaluation result (fine tuned): %s' % best_acc)
        print('Fined tuned model saved to %s' % args.experiment_data_dir)
        result['performance']['finetuned'] = best_acc

    with open(os.path.join(args.experiment_data_dir, 'result.json'), 'w+') as f:
        json.dump(result, f)
def main():
    """Train VGG-16 on CIFAR-10, prune six conv layers (VGG-16-pruned-A), fine-tune.

    1. Train the unpruned VGG-16 for 160 epochs (cosine LR) and save it to
       'vgg16_cifar10.pth'.
    2. Prune layers feature.0/24/27/30/34/37 at 50% sparsity with
       ``L1FilterPruner``.
    3. Fine-tune for 40 epochs, exporting the best (by top-1) pruned weights
       and masks.
    4. Reload the exported weights into a fresh VGG-16 and test it.
    """
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.Pad(4),
                             transforms.RandomCrop(32),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                         ])),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])),
        batch_size=200, shuffle=False)

    model = VGG(depth=16)
    model.to(device)

    # Train the base VGG-16 model
    print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
    for epoch in range(160):
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
        # Advance the schedule once per epoch. Passing ``epoch`` explicitly is
        # deprecated and made the learning rate lag one epoch behind the
        # cosine curve.
        lr_scheduler.step()
    torch.save(model.state_dict(), 'vgg16_cifar10.pth')

    # Test base model accuracy
    print('=' * 10 + 'Test on the original model' + '=' * 10)
    model.load_state_dict(torch.load('vgg16_cifar10.pth'))
    test(model, device, test_loader)
    # top1 = 93.51%

    # Pruning Configuration, in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS',
    # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A'
    configure_list = [{
        'sparsity': 0.5,
        'op_types': ['default'],
        'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
    }]

    # Prune model and test accuracy without fine tuning.
    print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
    optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
    pruner = L1FilterPruner(model, configure_list, optimizer_finetune)
    model = pruner.compress()
    test(model, device, test_loader)
    # top1 = 88.19%

    # Fine tune the pruned model for 40 epochs and test accuracy
    print('=' * 10 + 'Fine tuning' + '=' * 10)
    best_top1 = 0
    for epoch in range(40):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer_finetune)
        top1 = test(model, device, test_loader)
        if top1 > best_top1:
            best_top1 = top1
            # Export the best model, 'model_path' stores state_dict of the pruned model,
            # mask_path stores mask_dict of the pruned model
            pruner.export_model(model_path='pruned_vgg16_cifar10.pth', mask_path='mask_vgg16_cifar10.pth')

    # Test the exported model
    print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
    new_model = VGG(depth=16)
    new_model.to(device)
    new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
    test(new_model, device, test_loader)