def get_trained_model_optimizer(args, device, train_loader, val_loader, criterion):
    """Build the model named by ``args.model`` and return it with an optimizer.

    If ``args.load_pretrained_model`` is truthy, weights are loaded from
    ``args.pretrained_model_dir`` and an optimizer with learning rate 1e-4 is
    returned; no training happens. Otherwise the model is trained for
    ``args.pretrain_epochs`` epochs, the weights of the epoch with the highest
    value returned by ``test`` are restored into the model, and, when
    ``args.save_model`` is truthy, written to
    ``<args.experiment_data_dir>/model_trained.pth``.

    Args:
        args: Namespace providing ``model``, ``load_pretrained_model``,
            ``pretrained_model_dir``, ``pretrain_epochs``, ``save_model`` and
            ``experiment_data_dir``.
        device: Device the model is moved to (also used as ``map_location``
            when loading a checkpoint).
        train_loader: Passed through to ``train``.
        val_loader: Passed through to ``test``.
        criterion: Passed through to ``train`` and ``test``.

    Returns:
        A ``(model, optimizer)`` tuple.

    Raises:
        ValueError: If ``args.model`` is not one of ``'LeNet'``, ``'vgg16'``,
            ``'resnet18'`` or ``'resnet50'``.
    """
    # Function-scope import so this block does not depend on the file header.
    import copy

    if args.model == 'LeNet':
        model = LeNet().to(device)
    elif args.model == 'vgg16':
        model = VGG(depth=16).to(device)
    elif args.model == 'resnet18':
        model = ResNet18().to(device)
    elif args.model == 'resnet50':
        model = ResNet50().to(device)
    else:
        raise ValueError("model not recognized")

    # LeNet is optimized with Adadelta; every other architecture uses SGD.
    uses_adadelta = args.model == 'LeNet'

    if args.load_pretrained_model:
        # map_location lets a checkpoint saved on another device load here.
        model.load_state_dict(torch.load(args.pretrained_model_dir, map_location=device))
        if uses_adadelta:
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1e-4)
        else:
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)
        return model, optimizer

    if uses_adadelta:
        optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
        scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    else:
        # vgg16 starts at 0.01, the ResNets at 0.1.
        base_lr = 0.01 if args.model == 'vgg16' else 0.1
        optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=5e-4)
        scheduler = MultiStepLR(
            optimizer,
            milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)],
            gamma=0.1)

    best_acc = 0
    best_epoch = 0
    # state_dict() returns references to the live parameters, so the best
    # weights must be deep-copied or later epochs would overwrite them.
    # Snapshotting up front also guarantees a valid checkpoint when no epoch
    # improves on the initial accuracy (or when pretrain_epochs is 0).
    best_state_dict = copy.deepcopy(model.state_dict())
    for epoch in range(args.pretrain_epochs):
        train(args, model, device, train_loader, criterion, optimizer, epoch)
        scheduler.step()
        acc = test(model, device, criterion, val_loader)
        if acc > best_acc:
            best_acc = acc
            best_epoch = epoch
            best_state_dict = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_state_dict)
    print('Best acc:', best_acc)
    print('Best epoch:', best_epoch)

    if args.save_model:
        torch.save(best_state_dict, os.path.join(args.experiment_data_dir, 'model_trained.pth'))
        print('Model trained saved to %s' % args.experiment_data_dir)

    return model, optimizer
def get_model_optimizer_scheduler(args, device, train_loader, test_loader, criterion):
    """Return a pre-trained model plus a fresh optimizer/scheduler for pruning.

    If ``args.pretrained_model_dir`` is ``None`` the model is trained from
    scratch for ``args.pretrain_epochs`` epochs; the weights of the epoch with
    the highest value returned by ``test`` are restored and saved to
    ``<args.experiment_data_dir>/pretrain_<args.dataset>_<args.model>.pth``.
    Otherwise the weights are loaded from ``args.pretrained_model_dir`` and
    evaluated once with ``test``.

    In both cases the returned optimizer is a new SGD instance (lr=0.01) and
    the returned scheduler is a ``MultiStepLR`` whose milestones are derived
    from ``args.pretrain_epochs``.

    Args:
        args: Namespace providing ``model``, ``pretrained_model_dir``,
            ``pretrain_epochs``, ``experiment_data_dir`` and ``dataset``.
        device: Device the model is moved to (also used as ``map_location``
            when loading a checkpoint).
        train_loader: Passed through to ``train``.
        test_loader: Passed through to ``test``.
        criterion: Passed through to ``train`` and ``test``.

    Returns:
        A ``(model, optimizer, scheduler)`` tuple.

    Raises:
        ValueError: If ``args.model`` is not ``'lenet'``, ``'vgg16'`` or
            ``'vgg19'``.
    """
    # Function-scope import so this block does not depend on the file header.
    import copy

    if args.model == 'lenet':
        model = LeNet().to(device)
    elif args.model == 'vgg16':
        model = VGG(depth=16).to(device)
    elif args.model == 'vgg19':
        model = VGG(depth=19).to(device)
    else:
        raise ValueError("model not recognized")

    if args.pretrained_model_dir is None:
        # Pre-training setup: Adadelta for LeNet, SGD for the VGG variants.
        if args.model == 'lenet':
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
        else:
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(
                optimizer,
                milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)],
                gamma=0.1)

        print('start pre-training...')
        best_acc = 0
        # state_dict() returns references to the live parameters; deep-copy so
        # the best checkpoint is not overwritten by later epochs. The initial
        # snapshot keeps the name bound even if no epoch improves.
        best_state_dict = copy.deepcopy(model.state_dict())
        for epoch in range(args.pretrain_epochs):
            train(args, model, device, train_loader, criterion, optimizer, epoch)
            scheduler.step()
            acc = test(args, model, device, criterion, test_loader)
            if acc > best_acc:
                best_acc = acc
                best_state_dict = copy.deepcopy(model.state_dict())

        model.load_state_dict(best_state_dict)

        torch.save(best_state_dict,
                   os.path.join(args.experiment_data_dir, f'pretrain_{args.dataset}_{args.model}.pth'))
        print('Model trained saved to %s' % args.experiment_data_dir)
    else:
        # map_location lets a checkpoint saved on another device load here.
        model.load_state_dict(torch.load(args.pretrained_model_dir, map_location=device))
        best_acc = test(args, model, device, criterion, test_loader)

    # Set up a new optimizer and scheduler for pruning.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = MultiStepLR(
        optimizer,
        milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)],
        gamma=0.1)

    print('Pretrained model acc:', best_acc)
    return model, optimizer, scheduler
def get_model_optimizer_scheduler(args, device, test_loader, criterion):
    """Prepare the student/teacher pair and the student's optimizer/scheduler.

    A model of architecture ``args.model`` is built, loaded with the weights at
    ``args.teacher_model_dir`` and evaluated with ``test``. Two deep copies are
    taken: one is the teacher, the other is the student. The student optionally
    loads ``args.student_model_dir`` and is then passed through
    ``ModelSpeedup`` together with ``args.mask_path``.

    Args:
        args: Namespace providing ``model``, ``teacher_model_dir``,
            ``student_model_dir``, ``mask_path`` and ``fine_tune_epochs``.
        device: Device the model is moved to.
        test_loader: Passed through to ``test``.
        criterion: Passed through to ``test``.

    Returns:
        ``(module_list, optimizer, scheduler)`` where ``module_list`` is an
        ``nn.ModuleList`` holding ``[student, teacher]`` in that order and the
        optimizer covers only the student's parameters.

    Raises:
        ValueError: If ``args.model`` is not ``'LeNet'``, ``'vgg16'`` or
            ``'vgg19'``.
        NotImplementedError: If ``args.teacher_model_dir`` is ``None``.
    """
    if args.model == 'LeNet':
        base_model = LeNet().to(device)
    elif args.model in ('vgg16', 'vgg19'):
        vgg_depth = 16 if args.model == 'vgg16' else 19
        base_model = VGG(depth=vgg_depth).to(device)
    else:
        raise ValueError("model not recognized")

    # In this example the teacher and the student share one architecture.
    # A different teacher architecture could be used instead.
    if args.teacher_model_dir is None:
        raise NotImplementedError('please load pretrained teacher model first')

    base_model.load_state_dict(torch.load(args.teacher_model_dir))
    teacher_acc = test(args, base_model, device, criterion, test_loader)

    teacher = deepcopy(base_model)
    student = deepcopy(base_model)

    # Load the pruned student checkpoint when one is supplied.
    if args.student_model_dir is not None:
        student.load_state_dict(torch.load(args.student_model_dir))

    ModelSpeedup(student, get_dummy_input(args, device), args.mask_path, device).speedup_model()

    models = nn.ModuleList([student, teacher])

    # Set up the optimizer for fine-tuning the student model.
    optimizer = torch.optim.SGD(student.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    first_milestone = int(args.fine_tune_epochs * 0.5)
    second_milestone = int(args.fine_tune_epochs * 0.75)
    scheduler = MultiStepLR(optimizer, milestones=[first_milestone, second_milestone], gamma=0.1)

    print('Pretrained teacher model acc:', teacher_acc)
    return models, optimizer, scheduler
def main(args):
    """Train or load a model, prune it, then optionally speed up and fine-tune.

    Steps, in order:

    1. Load data via ``get_data`` and obtain a trained model via
       ``get_trained_model_optimizer``.
    2. Record FLOPs, parameter count and evaluation result of the original
       model.
    3. Build the pruner named by ``args.pruner`` and call ``compress()``.
    4. Export the masked model and masks when ``args.save_model`` is set, or
       when the speed-up step will need to reload them.
    5. If ``args.speedup``: apply ``ModelSpeedup`` (skipped for
       ``AutoCompressPruner``) and record FLOPs/params again.
    6. If ``args.fine_tune``: fine-tune for ``args.fine_tune_epochs`` epochs,
       saving the weights each time the evaluation result improves.
    7. Dump all collected numbers to
       ``<args.experiment_data_dir>/result.json``.

    Args:
        args: Namespace providing ``dataset``, ``data_dir``, ``batch_size``,
            ``test_batch_size``, ``model``, ``base_algo``, ``sparsity``,
            ``pruner``, ``cool_down_rate``, ``experiment_data_dir``,
            ``save_model``, ``speedup``, ``fine_tune`` and
            ``fine_tune_epochs`` (plus whatever the helpers it is passed to
            read).

    Raises:
        ValueError: For an unsupported ``base_algo``, ``pruner``, model (in
            the ADMM and speed-up branches) or fine-tuning dataset/model
            combination.
    """
    # prepare dataset
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, criterion = get_data(
        args.dataset, args.data_dir, args.batch_size, args.test_batch_size)
    model, optimizer = get_trained_model_optimizer(args, device, train_loader, val_loader, criterion)

    def short_term_fine_tuner(model, epochs=1):
        # Uses the optimizer returned by get_trained_model_optimizer.
        for epoch in range(epochs):
            train(args, model, device, train_loader, criterion, optimizer, epoch)

    def trainer(model, optimizer, criterion, epoch):
        return train(args, model, device, train_loader, criterion, optimizer, epoch=epoch)

    def evaluator(model):
        return test(model, device, criterion, val_loader)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance': {}}

    flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params

    evaluation_result = evaluator(model)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result

    # module types to prune, only "Conv2d" supported for channel pruning
    if args.base_algo in ['l1', 'l2', 'fpgm']:
        op_types = ['Conv2d']
    elif args.base_algo == 'level':
        op_types = ['default']
    else:
        # Previously fell through and failed later with a NameError.
        raise ValueError("base_algo not supported: %r" % (args.base_algo,))

    config_list = [{
        'sparsity': args.sparsity,
        'op_types': op_types
    }]
    dummy_input = get_dummy_input(args, device)

    if args.pruner == 'L1FilterPruner':
        pruner = L1FilterPruner(model, config_list)
    elif args.pruner == 'L2FilterPruner':
        pruner = L2FilterPruner(model, config_list)
    elif args.pruner == 'FPGMPruner':
        pruner = FPGMPruner(model, config_list)
    elif args.pruner == 'NetAdaptPruner':
        pruner = NetAdaptPruner(
            model, config_list,
            short_term_fine_tuner=short_term_fine_tuner,
            evaluator=evaluator,
            base_algo=args.base_algo,
            experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'ADMMPruner':
        # users are free to change the config here
        if args.model == 'LeNet':
            if args.base_algo in ['l1', 'l2', 'fpgm']:
                config_list = [{
                    'sparsity': 0.8,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv2']
                }]
            elif args.base_algo == 'level':
                config_list = [{
                    'sparsity': 0.8,
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_names': ['conv2']
                }, {
                    'sparsity': 0.991,
                    'op_names': ['fc1']
                }, {
                    'sparsity': 0.93,
                    'op_names': ['fc2']
                }]
        else:
            raise ValueError('Example only implemented for LeNet.')
        pruner = ADMMPruner(model, config_list, trainer=trainer, num_iterations=2, epochs_per_iteration=2)
    elif args.pruner == 'SimulatedAnnealingPruner':
        pruner = SimulatedAnnealingPruner(
            model, config_list,
            evaluator=evaluator,
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'AutoCompressPruner':
        pruner = AutoCompressPruner(
            model, config_list,
            trainer=trainer,
            evaluator=evaluator,
            dummy_input=dummy_input,
            num_iterations=3,
            optimize_mode='maximize',
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            admm_num_iterations=30,
            admm_epochs_per_iteration=5,
            experiment_data_dir=args.experiment_data_dir)
    else:
        raise ValueError(
            "Pruner not supported.")

    # Pruner.compress() returns the masked model
    # but for AutoCompressPruner, Pruner.compress() returns directly the pruned model
    model = pruner.compress()
    evaluation_result = evaluator(model)
    print('Evaluation result (masked model): %s' % evaluation_result)
    result['performance']['pruned'] = evaluation_result

    masked_model_path = os.path.join(args.experiment_data_dir, 'model_masked.pth')
    masks_file = os.path.join(args.experiment_data_dir, 'mask.pth')

    # The speed-up step below reloads these two files, so they must be written
    # whenever it will run, not only when the user asked to save the model.
    reload_for_speedup = args.speedup and args.pruner != 'AutoCompressPruner'
    if args.save_model or reload_for_speedup:
        pruner.export_model(masked_model_path, masks_file)
        print('Masked model saved to %s' % args.experiment_data_dir)

    # model speedup
    if args.speedup:
        if reload_for_speedup:
            # Rebuild a fresh model and load the masked weights into it.
            if args.model == 'LeNet':
                model = LeNet().to(device)
            elif args.model == 'vgg16':
                model = VGG(depth=16).to(device)
            elif args.model == 'resnet18':
                model = ResNet18().to(device)
            elif args.model == 'resnet50':
                model = ResNet50().to(device)
            else:
                raise ValueError("model not recognized")

            model.load_state_dict(torch.load(masked_model_path))

            m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
            m_speedup.speedup_model()
            evaluation_result = evaluator(model)
            print('Evaluation result (speedup model): %s' % evaluation_result)
            result['performance']['speedup'] = evaluation_result

            torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_speedup.pth'))
            print('Speedup model saved to %s' % args.experiment_data_dir)
        flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['speedup'] = flops
        result['params']['speedup'] = params

    if args.fine_tune:
        # A new optimizer is required: `model` may be a different object than
        # the one the pre-training optimizer was built for.
        if args.dataset == 'mnist':
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
        elif args.dataset == 'cifar10' and args.model in ('vgg16', 'resnet18', 'resnet50'):
            # vgg16 fine-tunes at 0.01, the ResNets at 0.1.
            fine_tune_lr = 0.01 if args.model == 'vgg16' else 0.1
            optimizer = torch.optim.SGD(model.parameters(), lr=fine_tune_lr, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(
                optimizer,
                milestones=[int(args.fine_tune_epochs * 0.5), int(args.fine_tune_epochs * 0.75)],
                gamma=0.1)
        else:
            # Previously left `scheduler` unbound and reused a stale optimizer.
            raise ValueError(
                "Fine-tuning is only configured for mnist, or for cifar10 "
                "with vgg16/resnet18/resnet50.")

        best_acc = 0
        for epoch in range(args.fine_tune_epochs):
            train(args, model, device, train_loader, criterion, optimizer, epoch)
            scheduler.step()
            acc = evaluator(model)
            if acc > best_acc:
                best_acc = acc
                torch.save(model.state_dict(),
                           os.path.join(args.experiment_data_dir, 'model_fine_tuned.pth'))

        # Reported only when fine-tuning ran; `best_acc` does not exist otherwise.
        print('Evaluation result (fine tuned): %s' % best_acc)
        print('Fined tuned model saved to %s' % args.experiment_data_dir)
        result['performance']['finetuned'] = best_acc

    with open(os.path.join(args.experiment_data_dir, 'result.json'), 'w') as f:
        json.dump(result, f)