# Example #1
# 0
def get_trained_model_optimizer(args, device, train_loader, val_loader, criterion):
    """Build the model named by ``args.model`` and return it with an optimizer.

    Two modes, selected by ``args.load_pretrained_model``:

    * truthy: weights are loaded from ``args.pretrained_model_dir`` and the
      returned optimizer uses a small learning rate (1e-4).
    * falsy: the model is trained from scratch for ``args.pretrain_epochs``
      epochs with a per-model optimizer and LR scheduler. Afterwards the
      weights from the epoch with the highest validation accuracy are
      restored. If ``args.save_model`` is set, they are also written to
      ``<args.experiment_data_dir>/model_trained.pth``.

    Args:
        args: Parsed command-line namespace. Fields read here: ``model``,
            ``load_pretrained_model``, ``pretrained_model_dir``,
            ``pretrain_epochs``, ``save_model``, ``experiment_data_dir``.
            It is also forwarded to ``train``.
        device: ``torch.device`` the model is moved to.
        train_loader: Data loader passed to ``train``.
        val_loader: Data loader passed to ``test`` for best-epoch selection.
        criterion: Loss function passed to ``train`` and ``test``.

    Returns:
        tuple: ``(model, optimizer)``.

    Raises:
        ValueError: If ``args.model`` is not 'LeNet', 'vgg16', 'resnet18' or
            'resnet50'.
    """
    import copy

    if args.model == 'LeNet':
        model = LeNet().to(device)
    elif args.model == 'vgg16':
        model = VGG(depth=16).to(device)
    elif args.model == 'resnet18':
        model = ResNet18().to(device)
    elif args.model == 'resnet50':
        model = ResNet50().to(device)
    else:
        raise ValueError("model not recognized")

    pretrained = args.load_pretrained_model
    if pretrained:
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored onto the device actually in use.
        model.load_state_dict(
            torch.load(args.pretrained_model_dir, map_location=device))

    scheduler = None
    if args.model == 'LeNet':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         lr=1e-4 if pretrained else 1)
        if not pretrained:
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    else:
        # When training from scratch, VGG starts from a smaller learning rate
        # than the ResNets.
        from_scratch_lr = 0.01 if args.model == 'vgg16' else 0.1
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=1e-4 if pretrained else from_scratch_lr,
                                    momentum=0.9,
                                    weight_decay=5e-4)
        if not pretrained:
            # Decay the LR by 10x at 50% and 75% of pre-training.
            scheduler = MultiStepLR(
                optimizer,
                milestones=[int(args.pretrain_epochs * 0.5),
                            int(args.pretrain_epochs * 0.75)],
                gamma=0.1)

    if not pretrained:
        best_acc = 0
        best_epoch = 0
        best_state_dict = None
        for epoch in range(args.pretrain_epochs):
            train(args, model, device, train_loader, criterion, optimizer, epoch)
            scheduler.step()
            acc = test(model, device, criterion, val_loader)
            if acc > best_acc:
                best_acc = acc
                best_epoch = epoch
                # state_dict() tensors share storage with the live parameters.
                # Snapshot a copy, or later epochs overwrite the "best" weights
                # in place.
                best_state_dict = copy.deepcopy(model.state_dict())
        # If no epoch improved on the initial 0 accuracy (or there were no
        # epochs at all), keep the final weights instead of failing.
        if best_state_dict is not None:
            model.load_state_dict(best_state_dict)
        print('Best acc:', best_acc)
        print('Best epoch:', best_epoch)

        if args.save_model:
            torch.save(model.state_dict(),
                       os.path.join(args.experiment_data_dir, 'model_trained.pth'))
            print('Model trained saved to %s' % args.experiment_data_dir)

    return model, optimizer
# Example #2
# 0
def get_model_optimizer_scheduler(args, device, train_loader, test_loader,
                                  criterion):
    """Build a model, pre-train or load it, and return it with a fine-tuning optimizer.

    If ``args.pretrained_model_dir`` is None, the model is pre-trained from
    scratch for ``args.pretrain_epochs`` epochs. During this, ``train`` is
    called with ``sparse_bn=True`` when ``args.pruner == 'slim'``. The weights
    of the epoch with the best test accuracy are then restored and saved to
    ``<args.experiment_data_dir>/pretrain_<dataset>_<model>.pth``.

    Otherwise, the weights are loaded from ``args.pretrained_model_dir`` and
    evaluated once.

    In both cases a fresh SGD optimizer (lr=0.01) and MultiStepLR scheduler
    are created for the subsequent fine-tuning stage.

    Args:
        args: Parsed command-line namespace. Fields read here: ``model``,
            ``pretrained_model_dir``, ``pretrain_epochs``, ``pruner``,
            ``dataset``, ``experiment_data_dir``. It is also forwarded to
            ``train``/``test``.
        device: ``torch.device`` the model is moved to.
        train_loader: Data loader passed to ``train``.
        test_loader: Data loader passed to ``test``.
        criterion: Loss function passed to ``train`` and ``test``.

    Returns:
        tuple: ``(model, optimizer, scheduler)``.

    Raises:
        ValueError: If ``args.model`` is not 'lenet', 'vgg16' or 'vgg19'.
    """
    import copy

    if args.model == 'lenet':
        model = LeNet().to(device)
    elif args.model == 'vgg16':
        model = VGG(depth=16).to(device)
    elif args.model == 'vgg19':
        model = VGG(depth=19).to(device)
    else:
        raise ValueError("model not recognized")

    if args.pretrained_model_dir is None:
        # Pre-training optimizer/scheduler; only needed when training from scratch.
        if args.model == 'lenet':
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
        else:
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.1,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.pretrain_epochs * 0.5),
                                        int(args.pretrain_epochs * 0.75)
                                    ],
                                    gamma=0.1)

        print('start pre-training...')
        best_acc = 0
        best_state_dict = None
        for epoch in range(args.pretrain_epochs):
            train(args,
                  model,
                  device,
                  train_loader,
                  criterion,
                  optimizer,
                  epoch,
                  sparse_bn=args.pruner == 'slim')
            scheduler.step()
            acc = test(args, model, device, criterion, test_loader)
            if acc > best_acc:
                best_acc = acc
                # state_dict() tensors share storage with the live parameters.
                # Copy them, or later epochs silently overwrite the snapshot.
                best_state_dict = copy.deepcopy(model.state_dict())

        # Keep the final weights if no epoch ever improved on 0 accuracy.
        if best_state_dict is not None:
            model.load_state_dict(best_state_dict)

        torch.save(
            model.state_dict(),
            os.path.join(args.experiment_data_dir,
                         f'pretrain_{args.dataset}_{args.model}.pth'))
        print('Model trained saved to %s' % args.experiment_data_dir)

    else:
        # map_location allows loading a checkpoint saved on a different device.
        model.load_state_dict(
            torch.load(args.pretrained_model_dir, map_location=device))
        best_acc = test(args, model, device, criterion, test_loader)

    # Set up a new optimizer for fine-tuning.
    # NOTE(review): the milestones are derived from args.pretrain_epochs rather
    # than a fine-tuning epoch count -- confirm this is intended.
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.01,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer,
                            milestones=[
                                int(args.pretrain_epochs * 0.5),
                                int(args.pretrain_epochs * 0.75)
                            ],
                            gamma=0.1)

    print('Pretrained model acc:', best_acc)
    return model, optimizer, scheduler
def main(args):
    """Pre-train LeNet on MNIST, prune it with ``LevelPruner``, then fine-tune it.

    Steps:
      1. Train LeNet from scratch for ``args.epochs`` epochs (Adadelta with a
         per-epoch ``StepLR`` decay). Save the weights to
         ``pretrain_mnist_lenet.pt``.
      2. Wrap the model with a ``LevelPruner`` using ``args.sparsity`` on the
         ``'default'`` op types.
      3. Fine-tune the pruned model with SGD (lr=0.01) for ``args.epochs``
         epochs. Export weights and masks whenever test accuracy reaches a
         new best.

    Args:
        args: Parsed command-line namespace. Fields read here: ``seed``,
            ``no_cuda``, ``batch_size``, ``test_batch_size``, ``lr``,
            ``gamma``, ``epochs`` and ``sparsity``. It is also forwarded to
            ``train``.
    """
    torch.manual_seed(args.seed)
    on_gpu = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if on_gpu else "cpu")

    # DataLoader options. On GPU, both loaders also get a worker process,
    # pinned memory and shuffling.
    train_opts = {'batch_size': args.batch_size}
    eval_opts = {'batch_size': args.test_batch_size}
    if on_gpu:
        gpu_opts = {'num_workers': 1, 'pin_memory': True, 'shuffle': True}
        for opts in (train_opts, eval_opts):
            opts.update(gpu_opts)

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])

    train_set = datasets.MNIST('./data', train=True, download=True,
                               transform=preprocess)
    test_set = datasets.MNIST('./data', train=False, transform=preprocess)
    train_loader = torch.utils.data.DataLoader(train_set, **train_opts)
    test_loader = torch.utils.data.DataLoader(test_set, **eval_opts)

    model = LeNet().to(device)
    pretrain_optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # ---- stage 1: dense pre-training ----
    print('start pre-training')
    lr_decay = StepLR(pretrain_optimizer, step_size=1, gamma=args.gamma)
    for epoch_no in range(1, args.epochs + 1):
        train(args, model, device, train_loader, pretrain_optimizer, epoch_no)
        test(model, device, test_loader)
        lr_decay.step()

    torch.save(model.state_dict(), "pretrain_mnist_lenet.pt")

    # ---- stage 2: pruning ----
    print('start pruning')
    finetune_optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # A single rule: prune every 'default' op type to args.sparsity.
    rules = [{'sparsity': args.sparsity, 'op_types': ['default']}]
    pruner = LevelPruner(model, rules, finetune_optimizer)
    model = pruner.compress()

    # ---- stage 3: fine-tuning, keeping the best export ----
    best_top1 = 0
    for epoch_no in range(1, args.epochs + 1):
        pruner.update_epoch(epoch_no)
        train(args, model, device, train_loader, finetune_optimizer, epoch_no)
        top1 = test(model, device, test_loader)
        if not top1 > best_top1:
            continue
        best_top1 = top1
        # 'model_path' receives the pruned model's state_dict and 'mask_path'
        # its mask dict. The 'pruend' spelling of the file name is left as-is
        # in case other tooling reads that exact path.
        pruner.export_model(model_path='pruend_mnist_lenet.pt',
                            mask_path='mask_mnist_lenet.pt')
def _build_model_for_speedup(model_name, device):
    """Return a freshly constructed ``model_name`` network moved to ``device``.

    Used to get a plain model into which the exported masked weights are
    loaded before running ``ModelSpeedup``.

    Raises:
        ValueError: If ``model_name`` is not a supported architecture.
    """
    if model_name == 'LeNet':
        return LeNet().to(device)
    if model_name == 'vgg16':
        return VGG(depth=16).to(device)
    if model_name == 'resnet18':
        return ResNet18().to(device)
    if model_name == 'resnet50':
        return ResNet50().to(device)
    if model_name == 'mobilenet_v2':
        return models.mobilenet_v2(pretrained=False).to(device)
    raise ValueError("model not recognized")


def _get_fine_tune_optimizer_scheduler(args, model):
    """Return ``(optimizer, scheduler)`` for fine-tuning ``model``.

    MNIST uses Adadelta (lr=1) with a per-epoch StepLR decay. CIFAR-10 with
    vgg16/resnet18/resnet50 uses SGD, with MultiStepLR decays at 50% and 75%
    of ``args.fine_tune_epochs``.

    Raises:
        ValueError: If no fine-tuning recipe exists for the dataset/model pair.
    """
    if args.dataset == 'mnist':
        optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
        return optimizer, StepLR(optimizer, step_size=1, gamma=0.7)
    if args.dataset == 'cifar10' and args.model in ('vgg16', 'resnet18',
                                                    'resnet50'):
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=0.01 if args.model == 'vgg16' else 0.1,
                                    momentum=0.9,
                                    weight_decay=5e-4)
        scheduler = MultiStepLR(optimizer,
                                milestones=[
                                    int(args.fine_tune_epochs * 0.5),
                                    int(args.fine_tune_epochs * 0.75)
                                ],
                                gamma=0.1)
        return optimizer, scheduler
    raise ValueError('No fine-tuning setup for dataset %r with model %r.' %
                     (args.dataset, args.model))


def main(args):
    """Train (or load) a model, prune it, and optionally speed it up and fine-tune it.

    FLOPs, parameter counts and evaluation results for the original, pruned,
    sped-up and fine-tuned models (whichever stages run) are collected. They
    are written to ``<args.experiment_data_dir>/result.json``.

    Args:
        args: Parsed command-line namespace. Selects the model, dataset,
            pruner, ``base_algo``, sparsity and which optional stages
            (``save_model``, ``speed_up``, ``fine_tune``) run.

    Raises:
        ValueError: For an unsupported ``base_algo``, pruner, model, or
            fine-tuning dataset/model combination.
    """
    # prepare dataset
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, criterion = get_data(args)
    model, optimizer = get_trained_model_optimizer(args, device, train_loader,
                                                   val_loader, criterion)

    def short_term_fine_tuner(model, epochs=1):
        # NOTE(review): this uses the optimizer from get_trained_model_optimizer,
        # which was built over the original model's parameters. If the pruner
        # passes a different model object here, these steps will not update
        # it -- confirm against NetAdaptPruner.
        for epoch in range(epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch)

    def trainer(model, optimizer, criterion, epoch, callback):
        return train(args,
                     model,
                     device,
                     train_loader,
                     criterion,
                     optimizer,
                     epoch=epoch,
                     callback=callback)

    def evaluator(model):
        return test(model, device, criterion, val_loader)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance': {}}

    flops, params = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params

    evaluation_result = evaluator(model)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result

    # module types to prune, only "Conv2d" supported for channel pruning
    if args.base_algo in ['l1', 'l2']:
        op_types = ['Conv2d']
    elif args.base_algo == 'level':
        op_types = ['default']
    else:
        raise ValueError('base_algo not recognized: %s' % args.base_algo)

    config_list = [{'sparsity': args.sparsity, 'op_types': op_types}]
    dummy_input = get_dummy_input(args, device)

    if args.pruner == 'L1FilterPruner':
        pruner = L1FilterPruner(model, config_list)
    elif args.pruner == 'L2FilterPruner':
        pruner = L2FilterPruner(model, config_list)
    elif args.pruner == 'ActivationMeanRankFilterPruner':
        pruner = ActivationMeanRankFilterPruner(model, config_list)
    elif args.pruner == 'ActivationAPoZRankFilterPruner':
        pruner = ActivationAPoZRankFilterPruner(model, config_list)
    elif args.pruner == 'NetAdaptPruner':
        pruner = NetAdaptPruner(model,
                                config_list,
                                short_term_fine_tuner=short_term_fine_tuner,
                                evaluator=evaluator,
                                base_algo=args.base_algo,
                                experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'ADMMPruner':
        # users are free to change the config here
        if args.model == 'LeNet':
            if args.base_algo in ['l1', 'l2']:
                config_list = [{
                    'sparsity': 0.8,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv2']
                }]
            elif args.base_algo == 'level':
                config_list = [{
                    'sparsity': 0.8,
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_names': ['conv2']
                }, {
                    'sparsity': 0.991,
                    'op_names': ['fc1']
                }, {
                    'sparsity': 0.93,
                    'op_names': ['fc2']
                }]
        else:
            raise ValueError('Example only implemented for LeNet.')
        pruner = ADMMPruner(model,
                            config_list,
                            trainer=trainer,
                            num_iterations=2,
                            training_epochs=2)
    elif args.pruner == 'SimulatedAnnealingPruner':
        pruner = SimulatedAnnealingPruner(
            model,
            config_list,
            evaluator=evaluator,
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'AutoCompressPruner':
        pruner = AutoCompressPruner(
            model,
            config_list,
            trainer=trainer,
            evaluator=evaluator,
            dummy_input=dummy_input,
            num_iterations=3,
            optimize_mode='maximize',
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            admm_num_iterations=30,
            admm_training_epochs=5,
            experiment_data_dir=args.experiment_data_dir)
    else:
        raise ValueError("Pruner not supported.")

    # Pruner.compress() returns the masked model
    # but for AutoCompressPruner, Pruner.compress() returns directly the pruned model
    model = pruner.compress()
    evaluation_result = evaluator(model)
    print('Evaluation result (masked model): %s' % evaluation_result)
    result['performance']['pruned'] = evaluation_result

    # The speed-up stage reads the masked weights and masks back from disk.
    # Export them whenever that stage will run, not only when save_model is
    # set; otherwise the files would be missing or stale from a previous run.
    needs_exported_masks = (args.speed_up
                            and args.pruner != 'AutoCompressPruner')
    if args.save_model or needs_exported_masks:
        pruner.export_model(
            os.path.join(args.experiment_data_dir, 'model_masked.pth'),
            os.path.join(args.experiment_data_dir, 'mask.pth'))
        print('Masked model saved to %s' % args.experiment_data_dir)

    # model speed up
    if args.speed_up:
        if args.pruner != 'AutoCompressPruner':
            model = _build_model_for_speedup(args.model, device)
            model.load_state_dict(
                torch.load(
                    os.path.join(args.experiment_data_dir,
                                 'model_masked.pth')))
            masks_file = os.path.join(args.experiment_data_dir, 'mask.pth')

            m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
            m_speedup.speedup_model()
            evaluation_result = evaluator(model)
            print('Evaluation result (speed up model): %s' % evaluation_result)
            result['performance']['speedup'] = evaluation_result

            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
            print('Speed up model saved to %s' % args.experiment_data_dir)
        flops, params = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['speedup'] = flops
        result['params']['speedup'] = params

    if args.fine_tune:
        optimizer, scheduler = _get_fine_tune_optimizer_scheduler(args, model)
        best_acc = 0
        for epoch in range(args.fine_tune_epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch)
            scheduler.step()
            acc = evaluator(model)
            if acc > best_acc:
                best_acc = acc
                torch.save(
                    model.state_dict(),
                    os.path.join(args.experiment_data_dir,
                                 'model_fine_tuned.pth'))

        # best_acc only exists when fine-tuning ran, so report it here.
        print('Evaluation result (fine tuned): %s' % best_acc)
        print('Fine-tuned model saved to %s' % args.experiment_data_dir)
        result['performance']['finetuned'] = best_acc

    with open(os.path.join(args.experiment_data_dir, 'result.json'),
              'w') as f:
        json.dump(result, f)