Example #1
def create_model(model_name='naive'):
    assert model_name in ['naive', 'vgg16', 'vgg19']

    if model_name == 'naive':
        return NaiveModel()
    elif model_name == 'vgg16':
        return VGG(16)
    else:
        return VGG(19)
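# Usage sketch (illustrative): NaiveModel and VGG come from this example's
# own model definitions, which are assumed to be in scope.
model = create_model('vgg16')   # -> VGG(16)
model = create_model()          # default -> NaiveModel()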
Example #2
def get_model_optimizer_scheduler(args, device, test_loader, criterion):
    if args.model == 'LeNet':
        model = LeNet().to(device)
    elif args.model == 'vgg16':
        model = VGG(depth=16).to(device)
    elif args.model == 'vgg19':
        model = VGG(depth=19).to(device)
    else:
        raise ValueError("model not recognized")

    # In this example, the teacher and student share the same architecture; a different teacher architecture is also feasible.
    if args.teacher_model_dir is None:
        raise ValueError('a pretrained teacher model is required; please set args.teacher_model_dir')

    else:
        model.load_state_dict(torch.load(args.teacher_model_dir))
        best_acc = test(args, model, device, criterion, test_loader)

    model_t = deepcopy(model)
    model_s = deepcopy(model)

    if args.student_model_dir is not None:
        # load the pruned student model checkpoint
        model_s.load_state_dict(torch.load(args.student_model_dir))

    dummy_input = get_dummy_input(args, device)
    m_speedup = ModelSpeedup(model_s, dummy_input, args.mask_path, device)
    m_speedup.speedup_model()

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    module_list.append(model_t)

    # set up the optimizer for fine-tuning the student model
    optimizer = torch.optim.SGD(model_s.parameters(),
                                lr=0.01,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer,
                            milestones=[
                                int(args.fine_tune_epochs * 0.5),
                                int(args.fine_tune_epochs * 0.75)
                            ],
                            gamma=0.1)

    print('Pretrained teacher model acc:', best_acc)
    return module_list, optimizer, scheduler
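# Hypothetical call site (args/loaders assumed to exist): the returned
# ModuleList holds [student, teacher] in the order appended above.
module_list, optimizer, scheduler = get_model_optimizer_scheduler(
    args, device, test_loader, criterion)
model_s, model_t = module_list[0], module_list[1]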
Example #3
def get_trained_model_optimizer(args, device, train_loader, val_loader, criterion):
    if args.model == 'LeNet':
        model = LeNet().to(device)
        if args.load_pretrained_model:
            model.load_state_dict(torch.load(args.pretrained_model_dir))
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1e-4)
        else:
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    elif args.model == 'vgg16':
        model = VGG(depth=16).to(device)
        if args.load_pretrained_model:
            model.load_state_dict(torch.load(args.pretrained_model_dir))
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)
        else:
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(
                optimizer, milestones=[int(args.pretrain_epochs*0.5), int(args.pretrain_epochs*0.75)], gamma=0.1)
    elif args.model == 'resnet18':
        model = ResNet18().to(device)
        if args.load_pretrained_model:
            model.load_state_dict(torch.load(args.pretrained_model_dir))
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)
        else:
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(
                optimizer, milestones=[int(args.pretrain_epochs*0.5), int(args.pretrain_epochs*0.75)], gamma=0.1)
    elif args.model == 'resnet50':
        model = ResNet50().to(device)
        if args.load_pretrained_model:
            model.load_state_dict(torch.load(args.pretrained_model_dir))
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)
        else:
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(
                optimizer, milestones=[int(args.pretrain_epochs*0.5), int(args.pretrain_epochs*0.75)], gamma=0.1)
    else:
        raise ValueError("model not recognized")

    if not args.load_pretrained_model:
        best_acc = 0
        best_epoch = 0
        for epoch in range(args.pretrain_epochs):
            train(args, model, device, train_loader, criterion, optimizer, epoch)
            scheduler.step()
            acc = test(model, device, criterion, val_loader)
            if acc > best_acc:
                best_acc = acc
                best_epoch = epoch
                state_dict = model.state_dict()
        model.load_state_dict(state_dict)
        print('Best acc:', best_acc)
        print('Best epoch:', best_epoch)

        if args.save_model:
            torch.save(state_dict, os.path.join(args.experiment_data_dir, 'model_trained.pth'))
            print('Trained model saved to %s' % args.experiment_data_dir)

    return model, optimizer
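# Note: Example #12 below (main(args)) calls this helper as its first step.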
Example #4
def flops_counter(args):
    # model speedup and FLOPs counting
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, criterion = get_data(args)

    if args.pruner != 'AutoCompressPruner':
        if args.model == 'LeNet':
            model = LeNet().to(device)
        elif args.model == 'vgg16':
            model = VGG(depth=16).to(device)
        elif args.model == 'resnet18':
            model = models.resnet18(pretrained=False,
                                    num_classes=10).to(device)
        elif args.model == 'mobilenet_v2':
            model = models.mobilenet_v2(pretrained=False).to(device)
        else:
            raise ValueError("model not recognized")

        def evaluator(model):
            return test(model, device, criterion, val_loader)

        model.load_state_dict(
            torch.load(
                os.path.join(args.experiment_data_dir,
                             'model_fine_tuned.pth')))
        masks_file = os.path.join(args.experiment_data_dir, 'mask.pth')

        dummy_input = get_dummy_input(args, device)

        m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
        m_speedup.speedup_model()
        evaluation_result = evaluator(model)
        print('Evaluation result (speed up model): %s' % evaluation_result)

        with open(os.path.join(args.experiment_data_dir,
                               'performance.json')) as f:
            result = json.load(f)

        result['speedup'] = evaluation_result
        with open(os.path.join(args.experiment_data_dir, 'performance.json'),
                  'w+') as f:
            json.dump(result, f)

        torch.save(
            model.state_dict(),
            os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
        print('Speed up model saved to %s' % args.experiment_data_dir)
    else:
        model = torch.load(
            os.path.join(args.experiment_data_dir, 'model_fine_tuned.pth'))
        model.eval()
        flops, params = count_flops_params(model, (1, 3, 32, 32))
        with open(os.path.join(args.experiment_data_dir, 'flops.json'),
                  'w+') as f:
            json.dump({'FLOPS': int(flops), 'params': int(params)}, f)
Example #5
def slim_speedup(masks_file, model_checkpoint):
    device = torch.device('cuda')
    model = VGG(depth=19)
    model.to(device)
    # load the checkpoint so timing runs on the trained weights
    # (model_checkpoint was previously accepted but never used)
    model.load_state_dict(torch.load(model_checkpoint, map_location=device))
    model.eval()

    dummy_input = torch.randn(64, 3, 32, 32)
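    # `use_mask` is assumed to be a module-level flag in the original script,
    # choosing between mask simulation and real speedup below.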
    if use_mask:
        apply_compression_results(model, masks_file)
        dummy_input = dummy_input.to(device)
        start = time.time()
        for _ in range(32):
            out = model(dummy_input)
        #print(out.size(), out)
        print('mask elapsed time: ', time.time() - start)
        return
    else:
        #print("model before: ", model)
        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
        m_speedup.speedup_model()
        #print("model after: ", model)
        dummy_input = dummy_input.to(device)
        start = time.time()
        for _ in range(32):
            out = model(dummy_input)
        #print(out.size(), out)
        print('speedup elapsed time: ', time.time() - start)
        return
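# Hypothetical invocation (the mask/checkpoint paths are placeholders);
# `use_mask` must be set at module level before calling.
use_mask = False  # time the genuinely speeded-up model
slim_speedup('mask_vgg19_cifar10.pth', 'pruned_vgg19_cifar10.pth')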
Example #6
def model_inference(config):
    masks_file = config['masks_file']
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    # device = torch.device(config['device'])
    if config['model_name'] == 'vgg16':
        model = VGG(depth=16)
    elif config['model_name'] == 'vgg19':
        model = VGG(depth=19)
    elif config['model_name'] == 'lenet':
        model = LeNet()
    else:
        raise ValueError("model name not recognized")

    model.to(device)
    model.eval()

    dummy_input = torch.randn(config['input_shape']).to(device)
    use_mask_out = use_speedup_out = None
    # must run use_mask before use_speedup because use_speedup modifies the model
    if use_mask:
        apply_compression_results(model, masks_file, device)
        start = time.time()
        for _ in range(32):
            use_mask_out = model(dummy_input)
        print('elapsed time when use mask: ', time.time() - start)
    if use_speedup:
        m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
        m_speedup.speedup_model()
        start = time.time()
        for _ in range(32):
            use_speedup_out = model(dummy_input)
        print('elapsed time when use speedup: ', time.time() - start)
    if compare_results:
        if torch.allclose(use_mask_out, use_speedup_out, atol=1e-07):
            print('the outputs from use_mask and use_speedup are the same')
        else:
            raise RuntimeError(
                'the outputs from use_mask and use_speedup are different')
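# Hypothetical config for model_inference; the keys mirror those read above
# and the mask path is a placeholder. (`use_mask`, `use_speedup` and
# `compare_results` are assumed to be module-level flags.)
config = {
    'model_name': 'vgg16',
    'masks_file': 'mask_vgg16_cifar10.pth',
    'input_shape': [64, 3, 32, 32],
}
model_inference(config)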
Example #7
def model_inference(config):
    masks_file = config['masks_file']
    device = torch.device(config['device'])
    if config['model_name'] == 'unet':
        model = UNet(3, 1)
    elif config['model_name'] == 'vgg19':
        model = VGG(depth=19)
    elif config['model_name'] == 'naive':
        from model_prune_torch import NaiveModel
        model = NaiveModel()
    else:
        raise ValueError("model name not recognized")
    model.to(device)
    model.load_state_dict(torch.load(config['model_file'],
                                     map_location=device))
    model.eval()

    dummy_input = torch.randn(config['input_shape']).to(device)
    use_mask_out = use_speedup_out = None
    # must run use_mask before use_speedup because use_speedup modifies the model
    if use_mask:
        apply_compression_results(model, masks_file, device)
        start = time.time()
        for _ in range(1):
            use_mask_out = model(dummy_input)
        print('elapsed time when use mask: ', time.time() - start)
    if use_speedup:
        m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
        m_speedup.speedup_model()
        start = time.time()
        for _ in range(1):
            use_speedup_out = model(dummy_input)
        print('elapsed time when use speedup: ', time.time() - start)
    if compare_results:
        if torch.allclose(use_mask_out, use_speedup_out, atol=1e-05):
            torch.save(model, config['save_dir_for_speedup'])
            print('the outputs from use_mask and use_speedup are the same')
        else:
            raise RuntimeError(
                'the outputs from use_mask and use_speedup are different')
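# Hypothetical follow-up (config/device assumed in scope): the compare branch
# above saves the entire module with torch.save(model, ...), so the
# speeded-up network reloads directly without rebuilding it from masks.
speedup_model = torch.load(config['save_dir_for_speedup'], map_location=device)
speedup_model.eval()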
Example #8
def main():
    parser = argparse.ArgumentParser("multiple gpu with pruning")
    parser.add_argument("--epochs", type=int, default=160)
    parser.add_argument("--retrain", default=False, action="store_true")
    parser.add_argument("--parallel", default=False, action="store_true")

    args = parser.parse_args()

    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        './data.cifar10',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])),
                                               batch_size=64,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        './data.cifar10',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])),
                                              batch_size=200,
                                              shuffle=False)

    model = VGG(depth=19)
    model.to(device)
    # Train the base VGG-19 model
    if args.retrain:
        print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
        epochs = args.epochs
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=0.1,
                                    momentum=0.9,
                                    weight_decay=1e-4)
        for epoch in range(epochs):
            if epoch in [int(epochs * 0.5), int(epochs * 0.75)]:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.1
            print("epoch {}".format(epoch))
            train(model, device, train_loader, optimizer, True)
            test(model, device, test_loader)
        torch.save(model.state_dict(), 'vgg19_cifar10.pth')
    else:
        assert os.path.isfile(
            'vgg19_cifar10.pth'), "cannot find checkpoint 'vgg19_cifar10.pth'"
        model.load_state_dict(torch.load('vgg19_cifar10.pth'))
    # Test base model accuracy
    print('=' * 10 + 'Test the original model' + '=' * 10)
    test(model, device, test_loader)
    # top1 = 93.60%

    # Pruning configuration, from the paper 'Learning Efficient Convolutional Networks through Network Slimming'
    configure_list = [{
        'sparsity': 0.7,
        'op_types': ['BatchNorm2d'],
    }]

    # Prune model and test accuracy without fine tuning.
    print('=' * 10 + 'Test the pruned model before fine tune' + '=' * 10)
    pruner = SlimPruner(model, configure_list)
    model = pruner.compress()
    if args.parallel:
        if torch.cuda.device_count() > 1:
            print("use {} gpus for pruning".format(torch.cuda.device_count()))
            model = nn.DataParallel(model)
            # model = nn.DataParallel(model, device_ids=[0, 1])
        else:
            print("only detect 1 gpu, fall back")
    model.to(device)
    # Fine tune the pruned model for 40 epochs and test accuracy
    print('=' * 10 + 'Fine tuning' + '=' * 10)
    optimizer_finetune = torch.optim.SGD(model.parameters(),
                                         lr=0.001,
                                         momentum=0.9,
                                         weight_decay=1e-4)
    best_top1 = 0
    for epoch in range(40):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer_finetune)
        top1 = test(model, device, test_loader)
        if top1 > best_top1:
            best_top1 = top1
            # Export the best model, 'model_path' stores state_dict of the pruned model,
            # mask_path stores mask_dict of the pruned model
            pruner.export_model(model_path='pruned_vgg19_cifar10.pth',
                                mask_path='mask_vgg19_cifar10.pth')

    # Test the exported model
    print('=' * 10 + 'Test the exported pruned model after fine-tuning' + '=' * 10)
    new_model = VGG(depth=19)
    new_model.to(device)
    new_model.load_state_dict(torch.load('pruned_vgg19_cifar10.pth'))
    test(new_model, device, test_loader)
Example #9
def main():
    parser = argparse.ArgumentParser("multiple gpu with pruning")
    parser.add_argument("--epochs", type=int, default=160)
    parser.add_argument("--retrain", default=False, action="store_true")
    parser.add_argument("--parallel", default=False, action="store_true")

    args = parser.parse_args()
    torch.manual_seed(0)
    device = torch.device('cuda')
    train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        './data.cifar10',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])),
                                               batch_size=64,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        './data.cifar10',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])),
                                              batch_size=200,
                                              shuffle=False)

    model = VGG(depth=16)
    model.to(device)

    # Train the base VGG-16 model
    if args.retrain:
        print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=0.1,
                                    momentum=0.9,
                                    weight_decay=1e-4)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs, 0)
        for epoch in range(args.epochs):
            train(model, device, train_loader, optimizer)
            test(model, device, test_loader)
            lr_scheduler.step()
        torch.save(model.state_dict(), 'vgg16_cifar10.pth')

    # Test base model accuracy
    print('=' * 10 + 'Test on the original model' + '=' * 10)
    model.load_state_dict(torch.load('vgg16_cifar10.pth'))
    test(model, device, test_loader)
    # top1 = 93.51%

    # Pruning configuration, from the paper 'Pruning Filters for Efficient ConvNets':
    # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as in 'VGG-16-pruned-A'
    configure_list = [{
        'sparsity': 0.5,
        'op_types': ['default'],
        'op_names': [
            'feature.0', 'feature.24', 'feature.27', 'feature.30',
            'feature.34', 'feature.37'
        ]
    }]

    # Prune model and test accuracy without fine tuning.
    print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
    pruner = ActivationMeanRankFilterPruner(model, configure_list)
    model = pruner.compress()
    if args.parallel:
        if torch.cuda.device_count() > 1:
            print("use {} gpus for pruning".format(torch.cuda.device_count()))
            model = nn.DataParallel(model)
        else:
            print("only detect 1 gpu, fall back")

    model.to(device)
    test(model, device, test_loader)
    # top1 = 88.19%

    # Fine tune the pruned model for 40 epochs and test accuracy
    print('=' * 10 + 'Fine tuning' + '=' * 10)
    optimizer_finetune = torch.optim.SGD(model.parameters(),
                                         lr=0.001,
                                         momentum=0.9,
                                         weight_decay=1e-4)
    best_top1 = 0
    for epoch in range(40):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer_finetune)
        top1 = test(model, device, test_loader)
        if top1 > best_top1:
            best_top1 = top1
            # Export the best model, 'model_path' stores state_dict of the pruned model,
            # mask_path stores mask_dict of the pruned model
            pruner.export_model(model_path='pruned_vgg16_cifar10.pth',
                                mask_path='mask_vgg16_cifar10.pth')

    # Test the exported model
    print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
    new_model = VGG(depth=16)
    new_model.to(device)
    new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
    test(new_model, device, test_loader)
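# A quick sketch for discovering the `op_names` used in configure_list above,
# assuming `model` is the VGG instance built in this example: print the path
# of every Conv2d module in the network.
for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        print(name)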
Example #10
def get_model_optimizer_scheduler(args, device, train_loader, test_loader,
                                  criterion):
    if args.model == 'lenet':
        model = LeNet().to(device)
        if args.pretrained_model_dir is None:
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    elif args.model == 'vgg16':
        model = VGG(depth=16).to(device)
        if args.pretrained_model_dir is None:
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.1,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.pretrain_epochs * 0.5),
                                        int(args.pretrain_epochs * 0.75)
                                    ],
                                    gamma=0.1)
    elif args.model == 'vgg19':
        model = VGG(depth=19).to(device)
        if args.pretrained_model_dir is None:
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.1,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.pretrain_epochs * 0.5),
                                        int(args.pretrain_epochs * 0.75)
                                    ],
                                    gamma=0.1)
    else:
        raise ValueError("model not recognized")

    if args.pretrained_model_dir is None:
        print('start pre-training...')
        best_acc = 0
        for epoch in range(args.pretrain_epochs):
            train(args,
                  model,
                  device,
                  train_loader,
                  criterion,
                  optimizer,
                  epoch,
                  sparse_bn=(args.pruner == 'slim'))
            scheduler.step()
            acc = test(args, model, device, criterion, test_loader)
            if acc > best_acc:
                best_acc = acc
                state_dict = model.state_dict()

        model.load_state_dict(state_dict)
        acc = best_acc

        torch.save(
            state_dict,
            os.path.join(args.experiment_data_dir,
                         f'pretrain_{args.dataset}_{args.model}.pth'))
        print('Trained model saved to %s' % args.experiment_data_dir)

    else:
        model.load_state_dict(torch.load(args.pretrained_model_dir))
        best_acc = test(args, model, device, criterion, test_loader)

    # set up a new optimizer for fine-tuning
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.01,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer,
                            milestones=[
                                int(args.pretrain_epochs * 0.5),
                                int(args.pretrain_epochs * 0.75)
                            ],
                            gamma=0.1)

    print('Pretrained model acc:', best_acc)
    return model, optimizer, scheduler
Example #11
def main():
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        './data.cifar10',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])),
                                               batch_size=64,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        './data.cifar10',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])),
                                              batch_size=200,
                                              shuffle=False)

    model = VGG(depth=16)
    model.to(device)

    # Train the base VGG-16 model
    print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.1,
                                momentum=0.9,
                                weight_decay=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, 160, 0)
    for epoch in range(160):
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
        lr_scheduler.step()
    torch.save(model.state_dict(), 'vgg16_cifar10.pth')

    # Test base model accuracy
    print('=' * 10 + 'Test on the original model' + '=' * 10)
    model.load_state_dict(torch.load('vgg16_cifar10.pth'))
    test(model, device, test_loader)
    # top1 = 93.51%

    # Pruning configuration: prune 80% of the filters in every convolution layer according to their L1 norm
    configure_list = [{
        'sparsity': 0.8,
        'op_types': ['Conv2d'],
    }]

    # Prune model and test accuracy without fine tuning.
    print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
    pruner = L1FilterPruner(model, configure_list)
    model = pruner.compress()
    test(model, device, test_loader)
    # top1 = 10.00%

    # Fine tune the pruned model for 40 epochs and test accuracy
    print('=' * 10 + 'Fine tuning' + '=' * 10)
    optimizer_finetune = torch.optim.SGD(model.parameters(),
                                         lr=0.001,
                                         momentum=0.9,
                                         weight_decay=1e-4)
    best_top1 = 0
    kd_teacher_model = VGG(depth=16)
    kd_teacher_model.to(device)
    kd_teacher_model.load_state_dict(torch.load('vgg16_cifar10.pth'))
    kd = KnowledgeDistill(kd_teacher_model, kd_T=5)
    for epoch in range(40):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer_finetune, kd)
        top1 = test(model, device, test_loader)
        if top1 > best_top1:
            best_top1 = top1
            # Export the best model, 'model_path' stores state_dict of the pruned model,
            # mask_path stores mask_dict of the pruned model
            pruner.export_model(model_path='pruned_vgg16_cifar10.pth',
                                mask_path='mask_vgg16_cifar10.pth')

    # Test the exported model
    print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
    new_model = VGG(depth=16)
    new_model.to(device)
    new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
    test(new_model, device, test_loader)
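# For reference, a minimal sketch of the distillation objective that a helper
# like KnowledgeDistill(kd_teacher_model, kd_T=5) typically computes; this
# function and the mixing weight `alpha` are illustrative, not the example's
# actual implementation.
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, target, T=5.0, alpha=0.9):
    # soften both distributions with temperature T; scale the KL term by T^2
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction='batchmean') * (T * T)
    hard = F.cross_entropy(student_logits, target)
    return alpha * soft + (1 - alpha) * hard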
Example #12
def main(args):
    # prepare dataset
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, criterion = get_data(args)
    model, optimizer = get_trained_model_optimizer(args, device, train_loader,
                                                   val_loader, criterion)

    def short_term_fine_tuner(model, epochs=1):
        for epoch in range(epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch)

    def trainer(model, optimizer, criterion, epoch, callback):
        return train(args,
                     model,
                     device,
                     train_loader,
                     criterion,
                     optimizer,
                     epoch=epoch,
                     callback=callback)

    def evaluator(model):
        return test(model, device, criterion, val_loader)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance': {}}

    flops, params = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params

    evaluation_result = evaluator(model)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result

    # module types to prune; only "Conv2d" is supported for channel pruning
    if args.base_algo in ['l1', 'l2']:
        op_types = ['Conv2d']
    elif args.base_algo == 'level':
        op_types = ['default']

    config_list = [{'sparsity': args.sparsity, 'op_types': op_types}]
    dummy_input = get_dummy_input(args, device)

    if args.pruner == 'L1FilterPruner':
        pruner = L1FilterPruner(model, config_list)
    elif args.pruner == 'L2FilterPruner':
        pruner = L2FilterPruner(model, config_list)
    elif args.pruner == 'ActivationMeanRankFilterPruner':
        pruner = ActivationMeanRankFilterPruner(model, config_list)
    elif args.pruner == 'ActivationAPoZRankFilterPruner':
        pruner = ActivationAPoZRankFilterPruner(model, config_list)
    elif args.pruner == 'NetAdaptPruner':
        pruner = NetAdaptPruner(model,
                                config_list,
                                short_term_fine_tuner=short_term_fine_tuner,
                                evaluator=evaluator,
                                base_algo=args.base_algo,
                                experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'ADMMPruner':
        # users are free to change the config here
        if args.model == 'LeNet':
            if args.base_algo in ['l1', 'l2']:
                config_list = [{
                    'sparsity': 0.8,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv2']
                }]
            elif args.base_algo == 'level':
                config_list = [{
                    'sparsity': 0.8,
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_names': ['conv2']
                }, {
                    'sparsity': 0.991,
                    'op_names': ['fc1']
                }, {
                    'sparsity': 0.93,
                    'op_names': ['fc2']
                }]
        else:
            raise ValueError('Example only implemented for LeNet.')
        pruner = ADMMPruner(model,
                            config_list,
                            trainer=trainer,
                            num_iterations=2,
                            training_epochs=2)
    elif args.pruner == 'SimulatedAnnealingPruner':
        pruner = SimulatedAnnealingPruner(
            model,
            config_list,
            evaluator=evaluator,
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'AutoCompressPruner':
        pruner = AutoCompressPruner(
            model,
            config_list,
            trainer=trainer,
            evaluator=evaluator,
            dummy_input=dummy_input,
            num_iterations=3,
            optimize_mode='maximize',
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            admm_num_iterations=30,
            admm_training_epochs=5,
            experiment_data_dir=args.experiment_data_dir)
    else:
        raise ValueError("Pruner not supported.")

    # pruner.compress() returns the masked model; AutoCompressPruner instead
    # returns the already-pruned model directly
    model = pruner.compress()
    evaluation_result = evaluator(model)
    print('Evaluation result (masked model): %s' % evaluation_result)
    result['performance']['pruned'] = evaluation_result

    if args.save_model:
        pruner.export_model(
            os.path.join(args.experiment_data_dir, 'model_masked.pth'),
            os.path.join(args.experiment_data_dir, 'mask.pth'))
        print('Masked model saved to %s' % args.experiment_data_dir)

    # model speed up
    if args.speed_up:
        if args.pruner != 'AutoCompressPruner':
            if args.model == 'LeNet':
                model = LeNet().to(device)
            elif args.model == 'vgg16':
                model = VGG(depth=16).to(device)
            elif args.model == 'resnet18':
                model = ResNet18().to(device)
            elif args.model == 'resnet50':
                model = ResNet50().to(device)
            elif args.model == 'mobilenet_v2':
                model = models.mobilenet_v2(pretrained=False).to(device)

            model.load_state_dict(
                torch.load(
                    os.path.join(args.experiment_data_dir,
                                 'model_masked.pth')))
            masks_file = os.path.join(args.experiment_data_dir, 'mask.pth')

            m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
            m_speedup.speedup_model()
            evaluation_result = evaluator(model)
            print('Evaluation result (speed up model): %s' % evaluation_result)
            result['performance']['speedup'] = evaluation_result

            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
            print('Speed up model saved to %s' % args.experiment_data_dir)
        flops, params = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['speedup'] = flops
        result['params']['speedup'] = params

    if args.fine_tune:
        if args.dataset == 'mnist':
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
        elif args.dataset == 'cifar10' and args.model == 'vgg16':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.01,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.5),
                                        int(args.fine_tune_epochs * 0.75)
                                    ],
                                    gamma=0.1)
        elif args.dataset == 'cifar10' and args.model == 'resnet18':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.1,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.5),
                                        int(args.fine_tune_epochs * 0.75)
                                    ],
                                    gamma=0.1)
        elif args.dataset == 'cifar10' and args.model == 'resnet50':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.1,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.5),
                                        int(args.fine_tune_epochs * 0.75)
                                    ],
                                    gamma=0.1)
        else:
            raise ValueError("no fine-tuning settings for this dataset/model combination")
        best_acc = 0
        for epoch in range(args.fine_tune_epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch)
            scheduler.step()
            acc = evaluator(model)
            if acc > best_acc:
                best_acc = acc
                torch.save(
                    model.state_dict(),
                    os.path.join(args.experiment_data_dir,
                                 'model_fine_tuned.pth'))

        print('Evaluation result (fine tuned): %s' % best_acc)
        print('Fine-tuned model saved to %s' % args.experiment_data_dir)
        result['performance']['finetuned'] = best_acc

    with open(os.path.join(args.experiment_data_dir, 'result.json'),
              'w+') as f:
        json.dump(result, f)
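# Shape of the result.json written above (values illustrative):
# {"flops": {"original": ..., "speedup": ...},
#  "params": {"original": ..., "speedup": ...},
#  "performance": {"original": ..., "pruned": ..., "speedup": ..., "finetuned": ...}}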
Example #13
def main():
    torch.manual_seed(0)
    device = torch.device('cuda')
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.Pad(4),
                             transforms.RandomCrop(32),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                         ])),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])),
        batch_size=200, shuffle=False)

    model = VGG(depth=19)
    model.to(device)

    # Train the base VGG-19 model
    print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
    epochs = 160
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    for epoch in range(epochs):
        if epoch in [int(epochs * 0.5), int(epochs * 0.75)]:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
        train(model, device, train_loader, optimizer, True)
        test(model, device, test_loader)
    torch.save(model.state_dict(), 'vgg19_cifar10.pth')

    # Test base model accuracy
    print('=' * 10 + 'Test the original model' + '=' * 10)
    model.load_state_dict(torch.load('vgg19_cifar10.pth'))
    test(model, device, test_loader)
    # top1 = 93.60%

    # Pruning configuration, from the paper 'Learning Efficient Convolutional Networks through Network Slimming'
    configure_list = [{
        'sparsity': 0.7,
        'op_types': ['BatchNorm2d'],
    }]

    # Prune model and test accuracy without fine tuning.
    print('=' * 10 + 'Test the pruned model before fine tune' + '=' * 10)
    pruner = SlimPruner(model, configure_list)
    model = pruner.compress()
    test(model, device, test_loader)
    # top1 = 93.55%

    # Fine tune the pruned model for 40 epochs and test accuracy
    print('=' * 10 + 'Fine tuning' + '=' * 10)
    optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
    best_top1 = 0
    for epoch in range(40):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer_finetune)
        top1 = test(model, device, test_loader)
        if top1 > best_top1:
            best_top1 = top1
            # Export the best model, 'model_path' stores state_dict of the pruned model,
            # mask_path stores mask_dict of the pruned model
            pruner.export_model(model_path='pruned_vgg19_cifar10.pth', mask_path='mask_vgg19_cifar10.pth')

    # Test the exported model
    print('=' * 10 + 'Test the exported pruned model after fine-tuning' + '=' * 10)
    new_model = VGG(depth=19)
    new_model.to(device)
    new_model.load_state_dict(torch.load('pruned_vgg19_cifar10.pth'))
    test(new_model, device, test_loader)
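# A follow-up sketch (not part of the original example): apply the exported
# mask with ModelSpeedup to actually shrink the fine-tuned network,
# mirroring Example #5.
dummy_input = torch.randn(64, 3, 32, 32).to(device)
m_speedup = ModelSpeedup(new_model, dummy_input, 'mask_vgg19_cifar10.pth')
m_speedup.speedup_model()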