コード例 #1
0
    def test_speedup_integration(self):
        """Check that ModelSpeedup keeps the output of L1-pruned models unchanged.

        For every model name and every config generator, one network is pruned
        with L1FilterPruner, the exported weights are loaded into a second
        instance, and that instance is compacted with ModelSpeedup. The sums of
        both outputs on an all-ones batch must agree within RELATIVE_THRESHOLD
        or ABSOLUTE_THRESHOLD.
        """
        # Skip this test on Windows (7GB mem available) due to the memory limit.
        # Note: hack trick, may be updated in the future.
        # A plain substring test ('win' in sys.platform) would also match
        # 'darwin' and silently skip the test on macOS, so match the prefix.
        if sys.platform.startswith(('win', 'cygwin')):
            print(
                'Skip test_speedup_integration on windows due to memory limit!'
            )
            return

        gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2]

        for model_name in [
                'resnet18',
                'mobilenet_v2',
                'squeezenet1_1',
                'densenet121',
                'densenet169',
                # 'inception_v3' inception is too large and may fail the pipeline
                'resnet50'
        ]:

            for gen_cfg_func in gen_cfg_funcs:
                kwargs = {'pretrained': True}
                if model_name == 'resnet50':
                    # testing multiple groups
                    kwargs = {'pretrained': False, 'groups': 4}
                Model = getattr(models, model_name)
                net = Model(**kwargs).to(device)
                speedup_model = Model(**kwargs).to(device)
                net.eval()  # this line is necessary
                speedup_model.eval()
                # randomly generate the prune config for the pruner
                cfgs = gen_cfg_func(net)
                print("Testing {} with compression config \n {}".format(
                    model_name, cfgs))
                pruner = L1FilterPruner(net, cfgs)
                pruner.compress()
                pruner.export_model(MODEL_FILE, MASK_FILE)
                pruner._unwrap_model()
                state_dict = torch.load(MODEL_FILE)
                speedup_model.load_state_dict(state_dict)
                zero_bn_bias(net)
                zero_bn_bias(speedup_model)

                data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
                ms = ModelSpeedup(speedup_model, data, MASK_FILE)
                ms.speedup_model()

                speedup_model.eval()

                ori_out = net(data)
                speeded_out = speedup_model(data)
                ori_sum = torch.sum(ori_out).item()
                speeded_sum = torch.sum(speeded_out).item()
                print('Sum of the output of %s (before speedup):' % model_name,
                      ori_sum)
                print('Sum of the output of %s (after speedup):' % model_name,
                      speeded_sum)
                diff = abs(ori_sum - speeded_sum)
                # Multiply instead of dividing by abs(ori_sum): a reference sum
                # of exactly 0 must fall through to the absolute check rather
                # than raise ZeroDivisionError.
                assert (diff < RELATIVE_THRESHOLD * abs(ori_sum)) or \
                    (diff < ABSOLUTE_THRESHOLD)
コード例 #2
0
ファイル: test_model_speedup.py プロジェクト: xiaowu0162/nni
def prune_model_l1(model):
    """Prune ``model`` with L1FilterPruner and export the model and its masks.

    A single config entry targets the 'Conv2d' op type with sparsity SPARSITY.
    The export is written to MODEL_FILE (model) and MASK_FILE (masks).
    """
    conv_cfg = {'sparsity': SPARSITY, 'op_types': ['Conv2d']}
    l1_pruner = L1FilterPruner(model, [conv_cfg])
    l1_pruner.compress()
    l1_pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
コード例 #3
0
ファイル: test_model_speedup.py プロジェクト: OliverShang/nni
 def test_speedup_tupleunpack(self):
     """Prune and speed up TupleUnpack_Model (the case reported in issue3645)."""
     net = TupleUnpack_Model()
     prune_cfg = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
     sample = torch.rand(2, 3, 224, 224)

     l1_pruner = L1FilterPruner(net, prune_cfg)
     l1_pruner.compress()
     # Run one forward pass on the wrapped model before exporting.
     net(sample)
     l1_pruner.export_model(MODEL_FILE, MASK_FILE)

     speedup = ModelSpeedup(net, sample, MASK_FILE, confidence=8)
     speedup.speedup_model()
コード例 #4
0
    def test_mask_conflict(self):
        """Check that fix_mask_conflict aligns the masks of channel-dependent layers.

        Each model in ``model_names`` is pruned with a random per-layer
        sparsity, the exported masks are passed through fix_mask_conflict, and
        every layer in a dependency set from channel_dependency_ground_truth
        must end up with the same pruned indices as the first layer of the set.
        """
        outdir = os.path.join(prefix, 'masks')
        os.makedirs(outdir, exist_ok=True)
        for name in model_names:
            print('Test mask conflict for %s' % name)
            model = getattr(models, name)
            net = model().to(device)
            dummy_input = torch.ones(1, 3, 224, 224).to(device)
            # randomly generate the prune sparsity for each layer
            cfglist = []
            for layername, layer in net.named_modules():
                if isinstance(layer, nn.Conv2d):
                    # pruner cannot allow the sparsity to be 0 or 1
                    sparsity = np.random.uniform(0.01, 0.99)
                    cfg = {
                        'op_types': ['Conv2d'],
                        'op_names': [layername],
                        'sparsity': sparsity
                    }
                    cfglist.append(cfg)
            pruner = L1FilterPruner(net, cfglist)
            pruner.compress()
            ck_file = os.path.join(outdir, '%s.pth' % name)
            mask_file = os.path.join(outdir, '%s_mask' % name)
            pruner.export_model(ck_file, mask_file)
            pruner._unwrap_model()
            # Fix the mask conflict
            fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)

            # use the channel dependency ground truth to check whether
            # the mask conflict was fixed successfully
            for dset in channel_dependency_ground_truth[name]:
                lset = list(dset)
                ref_mask = fixed_mask[lset[0]]
                for layer_name in lset:
                    cur_mask = fixed_mask[layer_name]
                    assert ref_mask['weight'].size(0) == \
                        cur_mask['weight'].size(0)
                    w_index1 = self.get_pruned_index(ref_mask['weight'])
                    w_index2 = self.get_pruned_index(cur_mask['weight'])
                    assert w_index1 == w_index2
                    # The per-layer masks are indexed by key, so hasattr() can
                    # never see 'bias' and the check below would be dead code.
                    # Look the key up instead, and skip layers whose bias mask
                    # is missing or None.
                    has_bias = all(
                        'bias' in m and m['bias'] is not None
                        for m in (ref_mask, cur_mask))
                    if has_bias:
                        b_index1 = self.get_pruned_index(ref_mask['bias'])
                        b_index2 = self.get_pruned_index(cur_mask['bias'])
                        assert b_index1 == b_index2
コード例 #5
0
    def test_speedup_integration(self):
        """Check that ModelSpeedup keeps the output of L1-pruned models unchanged.

        Each listed model is pruned with a random sparsity config, the exported
        weights are loaded into a second instance which is then compacted with
        ModelSpeedup, and the output sums of both networks on an all-ones batch
        must agree within RELATIVE_THRESHOLD or ABSOLUTE_THRESHOLD.
        """
        for model_name in [
                'resnet18',
                'squeezenet1_1',
                'mobilenet_v2',
                'densenet121',
                # 'inception_v3' inception is too large and may fail the pipeline
                'densenet169',
                'resnet50'
        ]:
            kwargs = {'pretrained': True}
            if model_name == 'resnet50':
                # testing multiple groups
                kwargs = {'pretrained': False, 'groups': 4}

            Model = getattr(models, model_name)
            net = Model(**kwargs).to(device)
            speedup_model = Model(**kwargs).to(device)
            net.eval()  # this line is necessary
            speedup_model.eval()
            # randomly generate the prune config for the pruner
            cfgs = generate_random_sparsity(net)
            pruner = L1FilterPruner(net, cfgs)
            pruner.compress()
            pruner.export_model(MODEL_FILE, MASK_FILE)
            pruner._unwrap_model()
            state_dict = torch.load(MODEL_FILE)
            speedup_model.load_state_dict(state_dict)
            zero_bn_bias(net)
            zero_bn_bias(speedup_model)

            data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
            ms = ModelSpeedup(speedup_model, data, MASK_FILE)
            ms.speedup_model()

            speedup_model.eval()

            ori_out = net(data)
            speeded_out = speedup_model(data)
            ori_sum = torch.sum(ori_out).item()
            speeded_sum = torch.sum(speeded_out).item()
            print('Sum of the output of %s (before speedup):' % model_name,
                  ori_sum)
            print('Sum of the output of %s (after speedup):' % model_name,
                  speeded_sum)
            diff = abs(ori_sum - speeded_sum)
            # Multiply instead of dividing by abs(ori_sum): a reference sum of
            # exactly 0 must fall through to the absolute check rather than
            # raise ZeroDivisionError.
            assert (diff < RELATIVE_THRESHOLD * abs(ori_sum)) or \
                   (diff < ABSOLUTE_THRESHOLD)
コード例 #6
0
    def test_multiplication_speedup(self):
        """
        Prune and speed up the model from issue 4540, whose forward pass
        multiplies a pooled gating branch with the feature map.
        """
        class Net(torch.nn.Module):
            def __init__(self, ):
                super().__init__()
                # Submodules are created in the same order as before so that
                # named_modules() ordering and weight initialisation match.
                self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
                self.input = torch.nn.Conv2d(3, 8, 3)
                self.bn = torch.nn.BatchNorm2d(8)
                self.fc1 = torch.nn.Conv2d(8, 16, 1)
                self.fc2 = torch.nn.Conv2d(16, 8, 1)
                self.activation = torch.nn.ReLU()
                self.scale_activation = torch.nn.Hardsigmoid()
                self.out = torch.nn.Conv2d(8, 12, 1)

            def forward(self, input):
                features = self.activation(self.bn(self.input(input)))
                gate = self.avgpool(features)
                gate = self.activation(self.fc1(gate))
                gate = self.scale_activation(self.fc2(gate))
                return self.out(gate * features)

        net = Net().to(device)
        net.eval()
        sample = torch.ones(1, 3, 512, 512).to(device)
        net(sample)

        # One config entry per Conv2d layer, each with sparsity 0.3.
        prune_cfgs = [
            {
                'op_types': ['Conv2d'],
                'sparsity': 0.3,
                'op_names': [layer_name]
            }
            for layer_name, layer in net.named_modules()
            if isinstance(layer, torch.nn.Conv2d)
        ]

        l1_pruner = L1FilterPruner(net, prune_cfgs)
        l1_pruner.compress()
        l1_pruner.export_model(MODEL_FILE, MASK_FILE)
        l1_pruner._unwrap_model()
        speedup = ModelSpeedup(net, sample, MASK_FILE)
        speedup.speedup_model()
コード例 #7
0
ファイル: test_model_speedup.py プロジェクト: xiaowu0162/nni
 def test_convtranspose_model(self):
     """Compare TransposeModel outputs before and after ModelSpeedup.

     One instance is pruned with L1FilterPruner; its exported weights are
     loaded into a fresh instance that is then sped up. The output sums on
     the same random input must agree within RELATIVE_THRESHOLD or
     ABSOLUTE_THRESHOLD.
     """
     pruned_model = TransposeModel()
     sample = torch.rand(1, 3, 8, 8)
     l1_pruner = L1FilterPruner(
         pruned_model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
     l1_pruner.compress()
     pruned_model(sample)
     l1_pruner.export_model(MODEL_FILE, MASK_FILE)
     l1_pruner._unwrap_model()

     compact_model = TransposeModel()
     compact_model.load_state_dict(torch.load(MODEL_FILE))
     speedup = ModelSpeedup(compact_model, sample, MASK_FILE)
     speedup.speedup_model()

     zero_bn_bias(pruned_model)
     zero_bn_bias(compact_model)
     ori_sum = torch.sum(pruned_model(sample))
     speeded_sum = torch.sum(compact_model(sample))
     print('Tanspose Speedup Test: ori_sum={} speedup_sum={}'.format(ori_sum, speeded_sum))
     gap = abs(ori_sum - speeded_sum)
     assert (gap / abs(ori_sum) < RELATIVE_THRESHOLD) or \
             (gap < ABSOLUTE_THRESHOLD)
コード例 #8
0
ファイル: end2end_compression.py プロジェクト: xiaowu0162/nni
def main(args):
    """Run the end-to-end compression example on MNIST.

    Steps: pretrain (or load) the model, prune it with L1FilterPruner, speed
    it up with ModelSpeedup, fine-tune, quantize with QAT_Quantizer, and
    finally build and test a ModelSpeedupTensorRT engine.

    Args:
        args: parsed options. This function reads ``experiment_data_dir``,
            ``pretrained_model_dir``, ``pretrain_lr``, ``pretrain_epochs``,
            ``sparsity``, ``dependency_aware`` and ``finetune_epochs``, and
            forwards ``args`` to ``train`` and ``test``.
    """
    import copy  # local import: only used to snapshot the best weights

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(args.experiment_data_dir, exist_ok=True)

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=transform),
        batch_size=64,
    )
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        'data', train=False, transform=transform),
                                              batch_size=1000)

    # Step1. Model Pretraining
    model = NaiveModel().to(device)
    criterion = torch.nn.NLLLoss()
    optimizer = optim.Adadelta(model.parameters(), lr=args.pretrain_lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    flops, params, _ = count_flops_params(model, (1, 1, 28, 28), verbose=False)

    if args.pretrained_model_dir is None:
        args.pretrained_model_dir = os.path.join(args.experiment_data_dir,
                                                 'pretrained.pth')

        best_acc = 0
        # state_dict() returns references to the live parameters, so the best
        # weights must be deep-copied or later epochs would overwrite them.
        # Starting from the current weights also keeps state_dict defined when
        # no epoch improves on best_acc (or pretrain_epochs is 0).
        state_dict = copy.deepcopy(model.state_dict())
        for epoch in range(args.pretrain_epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch)
            scheduler.step()
            acc = test(args, model, device, criterion, test_loader)
            if acc > best_acc:
                best_acc = acc
                state_dict = copy.deepcopy(model.state_dict())

        model.load_state_dict(state_dict)
        torch.save(state_dict, args.pretrained_model_dir)
        print(f'Model saved to {args.pretrained_model_dir}')
    else:
        state_dict = torch.load(args.pretrained_model_dir)
        model.load_state_dict(state_dict)
        best_acc = test(args, model, device, criterion, test_loader)

    dummy_input = torch.randn([1000, 1, 28, 28]).to(device)
    time_cost = get_model_time_cost(model, dummy_input)

    # 125.49 M, 0.85M, 93.29, 1.1012
    print(
        f'Pretrained model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_acc: .2f}, Time Cost: {time_cost}'
    )

    # Step2. Model Pruning
    config_list = [{'sparsity': args.sparsity, 'op_types': ['Conv2d']}]

    kw_args = {}
    if args.dependency_aware:
        dummy_input = torch.randn([1000, 1, 28, 28]).to(device)
        print('Enable the dependency_aware mode')
        # note that, not all pruners support the dependency_aware mode
        kw_args['dependency_aware'] = True
        kw_args['dummy_input'] = dummy_input

    pruner = L1FilterPruner(model, config_list, **kw_args)
    model = pruner.compress()
    pruner.get_pruned_weights()

    mask_path = os.path.join(args.experiment_data_dir, 'mask.pth')
    model_path = os.path.join(args.experiment_data_dir, 'pruned.pth')
    pruner.export_model(model_path=model_path, mask_path=mask_path)
    pruner._unwrap_model()  # unwrap all modules to normal state

    # Step3. Model Speedup
    m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
    m_speedup.speedup_model()
    print('model after speedup', model)

    flops, params, _ = count_flops_params(model, dummy_input, verbose=False)
    acc = test(args, model, device, criterion, test_loader)
    time_cost = get_model_time_cost(model, dummy_input)
    print(
        f'Pruned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {acc: .2f}, Time Cost: {time_cost}'
    )

    # Step4. Model Finetuning
    optimizer = optim.Adadelta(model.parameters(), lr=args.pretrain_lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

    best_acc = 0
    # Snapshot the sped-up model first: the state_dict left over from Step1
    # belongs to the unpruned architecture and must not be reloaded here.
    state_dict = copy.deepcopy(model.state_dict())
    for epoch in range(args.finetune_epochs):
        train(args, model, device, train_loader, criterion, optimizer, epoch)
        scheduler.step()
        acc = test(args, model, device, criterion, test_loader)
        if acc > best_acc:
            best_acc = acc
            state_dict = copy.deepcopy(model.state_dict())

    model.load_state_dict(state_dict)
    save_path = os.path.join(args.experiment_data_dir, 'finetuned.pth')
    torch.save(state_dict, save_path)

    flops, params, _ = count_flops_params(model, dummy_input, verbose=True)
    time_cost = get_model_time_cost(model, dummy_input)

    # FLOPs 28.48 M, #Params: 0.18M, Accuracy:  89.03, Time Cost: 1.03
    print(
        f'Finetuned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_acc: .2f}, Time Cost: {time_cost}'
    )
    print(f'Model saved to {save_path}')

    # Step5. Model Quantization via QAT
    config_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {
            'weight': 8,
            'output': 8
        },
        'op_names': ['conv1']
    }, {
        'quant_types': ['output'],
        'quant_bits': {
            'output': 8
        },
        'op_names': ['relu1']
    }, {
        'quant_types': ['weight', 'output'],
        'quant_bits': {
            'weight': 8,
            'output': 8
        },
        'op_names': ['conv2']
    }, {
        'quant_types': ['output'],
        'quant_bits': {
            'output': 8
        },
        'op_names': ['relu2']
    }]

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    quantizer = QAT_Quantizer(model, config_list, optimizer)
    quantizer.compress()

    # Step6. Quantization Aware Training
    # The StepLR scheduler from Step4 is bound to the previous Adadelta
    # optimizer, so stepping it here would not affect this SGD optimizer.
    for epoch in range(1):
        train(args, model, device, train_loader, criterion, optimizer, epoch)
        test(args, model, device, criterion, test_loader)

    calibration_path = os.path.join(args.experiment_data_dir,
                                    'calibration.pth')
    # NOTE(review): model_path is still 'pruned.pth' from Step2, so this export
    # overwrites the pruned checkpoint - confirm that this is intended.
    calibration_config = quantizer.export_model(model_path, calibration_path)
    print("calibration_config: ", calibration_config)

    # Step7. Model Speedup
    batch_size = 32
    input_shape = (batch_size, 1, 28, 28)
    engine = ModelSpeedupTensorRT(model,
                                  input_shape,
                                  config=calibration_config,
                                  batchsize=batch_size)
    engine.compress()

    test_trt(engine, test_loader)
コード例 #9
0
ファイル: speedup_yolov3.py プロジェクト: yinfupai/nni
# The YOLO implementation can be downloaded from
# https://github.com/eriklindernoren/PyTorch-YOLOv3.git
prefix = '/home/user/PyTorch-YOLOv3' # replace this path with yours

# Load the YOLO model on the CPU and run one forward pass on a random batch.
cfg_path = "%s/config/yolov3.cfg" % prefix
weights_path = "%s/yolov3.weights" % prefix
model = models.load_model(cfg_path, weights_path).cpu()
model.eval()
dummy_input = torch.rand(8, 3, 320, 320)
model(dummy_input)

# Build the pruner config: one entry per Conv2d layer, skipping every layer
# that not_safe_to_prune reports.
not_safe = not_safe_to_prune(model, dummy_input)
cfg_list = [
  {'op_types':['Conv2d'], 'sparsity':0.6, 'op_names':[name]}
  for name, module in model.named_modules()
  if name not in not_safe and isinstance(module, torch.nn.Conv2d)
]

# Prune the model and export the model and its masks.
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
pruner._unwrap_model()

# Speed up the model using the exported masks, then run it once more.
ms = ModelSpeedup(model, dummy_input, './mask')
ms.speedup_model()
model(dummy_input)

コード例 #10
0
def main():
    """Train VGG-16 on CIFAR-10, prune it with L1FilterPruner and fine-tune.

    The base model is trained for 160 epochs and saved to
    'vgg16_cifar10.pth'. All Conv2d layers are then pruned at sparsity 0.8
    and the pruned model is fine-tuned for 40 epochs with a KnowledgeDistill
    helper whose teacher is the saved base model. The best fine-tuned model
    and its masks are exported and finally reloaded for a last test.
    """
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    banner = '=' * 10

    # Data pipelines: augmentation for training, plain normalisation for test.
    train_transform = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])
    train_set = datasets.CIFAR10('./data.cifar10',
                                 train=True,
                                 download=True,
                                 transform=train_transform)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=64,
                                               shuffle=True)

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])
    test_set = datasets.CIFAR10('./data.cifar10',
                                train=False,
                                transform=test_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=200,
                                              shuffle=False)

    model = VGG(depth=16)
    model.to(device)

    # Train the base VGG-16 model
    print(banner + 'Train the unpruned base model' + banner)
    base_epochs = 160
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.1,
                                momentum=0.9,
                                weight_decay=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, base_epochs, 0)
    for epoch in range(base_epochs):
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
        lr_scheduler.step(epoch)
    torch.save(model.state_dict(), 'vgg16_cifar10.pth')

    # Test base model accuracy
    print(banner + 'Test on the original model' + banner)
    model.load_state_dict(torch.load('vgg16_cifar10.pth'))
    test(model, device, test_loader)
    # top1 = 93.51%

    # Pruning configuration: 80% of the filters of every convolution layer
    # are pruned according to the L1 norm.
    configure_list = [{
        'sparsity': 0.8,
        'op_types': ['Conv2d'],
    }]

    # Prune model and test accuracy without fine tuning.
    print(banner + 'Test on the pruned model before fine tune' + banner)
    pruner = L1FilterPruner(model, configure_list)
    model = pruner.compress()
    test(model, device, test_loader)
    # top1 = 10.00%

    # Fine tune the pruned model for 40 epochs and test accuracy
    print(banner + 'Fine tuning' + banner)
    optimizer_finetune = torch.optim.SGD(model.parameters(),
                                         lr=0.001,
                                         momentum=0.9,
                                         weight_decay=1e-4)
    best_top1 = 0
    # The teacher is a fresh VGG-16 loaded with the saved base weights.
    kd_teacher_model = VGG(depth=16)
    kd_teacher_model.to(device)
    kd_teacher_model.load_state_dict(torch.load('vgg16_cifar10.pth'))
    kd = KnowledgeDistill(kd_teacher_model, kd_T=5)
    for epoch in range(40):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer_finetune, kd)
        top1 = test(model, device, test_loader)
        if top1 <= best_top1:
            continue
        best_top1 = top1
        # Export the best model, 'model_path' stores state_dict of the pruned model,
        # mask_path stores mask_dict of the pruned model
        pruner.export_model(model_path='pruned_vgg16_cifar10.pth',
                            mask_path='mask_vgg16_cifar10.pth')

    # Test the exported model
    print(banner + 'Test on the pruned model after fine tune' + banner)
    new_model = VGG(depth=16)
    new_model.to(device)
    new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
    test(new_model, device, test_loader)
コード例 #11
0
def main(args):
    """Prune a trained model with the selected pruner and record the results.

    The model is evaluated before pruning, after masking, optionally after
    speedup (``args.speed_up``) and optionally after fine-tuning
    (``args.fine_tune``). FLOPs, parameter counts and evaluation results are
    written to 'result.json' in ``args.experiment_data_dir``.

    Args:
        args: parsed options. Besides the flags above this function reads
            ``dataset``, ``data_dir``, ``batch_size``, ``test_batch_size``,
            ``base_algo``, ``sparsity``, ``pruner``, ``model``,
            ``cool_down_rate``, ``save_model`` and ``fine_tune_epochs``.

    Raises:
        ValueError: for an unsupported ``base_algo`` or ``pruner``, for
            ADMMPruner with a model other than LeNet, or when fine-tuning is
            requested for a dataset/model pair with no optimizer setting.
    """
    # prepare dataset
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, criterion = get_data(args.dataset, args.data_dir, args.batch_size, args.test_batch_size)
    model, optimizer = get_trained_model_optimizer(args, device, train_loader, val_loader, criterion)

    def short_term_fine_tuner(model, epochs=1):
        for epoch in range(epochs):
            train(args, model, device, train_loader, criterion, optimizer, epoch)

    def trainer(model, optimizer, criterion, epoch, callback):
        return train(args, model, device, train_loader, criterion, optimizer, epoch=epoch, callback=callback)

    def evaluator(model):
        return test(model, device, criterion, val_loader)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance':{}}

    flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params

    evaluation_result = evaluator(model)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result

    # module types to prune, only "Conv2d" supported for channel pruning
    if args.base_algo in ['l1', 'l2', 'fpgm']:
        op_types = ['Conv2d']
    elif args.base_algo == 'level':
        op_types = ['default']
    else:
        # Without this branch op_types would be unbound and the config below
        # would fail with a NameError.
        raise ValueError('Base algorithm not supported: %s' % args.base_algo)

    config_list = [{
        'sparsity': args.sparsity,
        'op_types': op_types
    }]
    dummy_input = get_dummy_input(args, device)

    if args.pruner == 'L1FilterPruner':
        pruner = L1FilterPruner(model, config_list)
    elif args.pruner == 'L2FilterPruner':
        pruner = L2FilterPruner(model, config_list)
    elif args.pruner == 'FPGMPruner':
        pruner = FPGMPruner(model, config_list)
    elif args.pruner == 'NetAdaptPruner':
        pruner = NetAdaptPruner(model, config_list, short_term_fine_tuner=short_term_fine_tuner, evaluator=evaluator,
                                base_algo=args.base_algo, experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'ADMMPruner':
        # users are free to change the config here
        if args.model == 'LeNet':
            if args.base_algo in ['l1', 'l2', 'fpgm']:
                config_list = [{
                    'sparsity': 0.8,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv2']
                }]
            elif args.base_algo == 'level':
                config_list = [{
                    'sparsity': 0.8,
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_names': ['conv2']
                }, {
                    'sparsity': 0.991,
                    'op_names': ['fc1']
                }, {
                    'sparsity': 0.93,
                    'op_names': ['fc2']
                }]
        else:
            raise ValueError('Example only implemented for LeNet.')
        pruner = ADMMPruner(model, config_list, trainer=trainer, num_iterations=2, training_epochs=2)
    elif args.pruner == 'SimulatedAnnealingPruner':
        pruner = SimulatedAnnealingPruner(
            model, config_list, evaluator=evaluator, base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate, experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'AutoCompressPruner':
        pruner = AutoCompressPruner(
            model, config_list, trainer=trainer, evaluator=evaluator, dummy_input=dummy_input,
            num_iterations=3, optimize_mode='maximize', base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate, admm_num_iterations=30, admm_training_epochs=5,
            experiment_data_dir=args.experiment_data_dir)
    else:
        raise ValueError(
            "Pruner not supported.")

    # Pruner.compress() returns the masked model
    # but for AutoCompressPruner, Pruner.compress() returns directly the pruned model
    model = pruner.compress()
    evaluation_result = evaluator(model)
    print('Evaluation result (masked model): %s' % evaluation_result)
    result['performance']['pruned'] = evaluation_result

    if args.save_model:
        pruner.export_model(
            os.path.join(args.experiment_data_dir, 'model_masked.pth'), os.path.join(args.experiment_data_dir, 'mask.pth'))
        print('Masked model saved to %s' % args.experiment_data_dir)

    # model speed up
    if args.speed_up:
        if args.pruner != 'AutoCompressPruner':
            # NOTE(review): this path reads 'model_masked.pth' and 'mask.pth',
            # which are only written above when args.save_model is set.
            if args.model == 'LeNet':
                model = LeNet().to(device)
            elif args.model == 'vgg16':
                model = VGG(depth=16).to(device)
            elif args.model == 'resnet18':
                model = ResNet18().to(device)
            elif args.model == 'resnet50':
                model = ResNet50().to(device)

            model.load_state_dict(torch.load(os.path.join(args.experiment_data_dir, 'model_masked.pth')))
            masks_file = os.path.join(args.experiment_data_dir, 'mask.pth')

            m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
            m_speedup.speedup_model()
            evaluation_result = evaluator(model)
            print('Evaluation result (speed up model): %s' % evaluation_result)
            result['performance']['speedup'] = evaluation_result

            torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
            print('Speed up model saved to %s' % args.experiment_data_dir)
        flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['speedup'] = flops
        result['params']['speedup'] = params

    if args.fine_tune:
        if args.dataset == 'mnist':
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
        elif args.dataset == 'cifar10' and args.model == 'vgg16':
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(
                optimizer, milestones=[int(args.fine_tune_epochs*0.5), int(args.fine_tune_epochs*0.75)], gamma=0.1)
        elif args.dataset == 'cifar10' and args.model == 'resnet18':
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(
                optimizer, milestones=[int(args.fine_tune_epochs*0.5), int(args.fine_tune_epochs*0.75)], gamma=0.1)
        elif args.dataset == 'cifar10' and args.model == 'resnet50':
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(
                optimizer, milestones=[int(args.fine_tune_epochs*0.5), int(args.fine_tune_epochs*0.75)], gamma=0.1)
        else:
            # Without this branch scheduler would be unbound and the loop
            # below would fail with a NameError after the first epoch.
            raise ValueError(
                'No fine-tune setting for dataset %s with model %s.' % (args.dataset, args.model))
        best_acc = 0
        for epoch in range(args.fine_tune_epochs):
            train(args, model, device, train_loader, criterion, optimizer, epoch)
            scheduler.step()
            acc = evaluator(model)
            if acc > best_acc:
                best_acc = acc
                torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_fine_tuned.pth'))

        # Report only when fine-tuning ran: best_acc does not exist otherwise,
        # and the NameError would prevent result.json from being written.
        print('Evaluation result (fine tuned): %s' % best_acc)
        print('Fined tuned model saved to %s' % args.experiment_data_dir)
        result['performance']['finetuned'] = best_acc

    with open(os.path.join(args.experiment_data_dir, 'result.json'), 'w+') as f:
        json.dump(result, f)
コード例 #12
0
ファイル: L1_torch_cifar10.py プロジェクト: zaraSiddiqui/nni
def main():
    """Train VGG-16 on CIFAR-10, prune selected layers with L1FilterPruner and fine-tune.

    The base model is trained for 160 epochs and saved to
    'vgg16_cifar10.pth'. Six named 'feature.*' layers are then pruned at
    sparsity 0.5 and the model is fine-tuned for 40 epochs. The best model
    and its masks are exported and finally reloaded for a last test.
    """
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    banner = '=' * 10

    # Data pipelines: augmentation for training, plain normalisation for test.
    train_transform = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])
    train_set = datasets.CIFAR10('./data.cifar10',
                                 train=True,
                                 download=True,
                                 transform=train_transform)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=64,
                                               shuffle=True)

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])
    test_set = datasets.CIFAR10('./data.cifar10',
                                train=False,
                                transform=test_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=200,
                                              shuffle=False)

    model = VGG(depth=16)
    model.to(device)

    # Train the base VGG-16 model
    print(banner + 'Train the unpruned base model' + banner)
    base_epochs = 160
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.1,
                                momentum=0.9,
                                weight_decay=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, base_epochs, 0)
    for epoch in range(base_epochs):
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
        lr_scheduler.step(epoch)
    torch.save(model.state_dict(), 'vgg16_cifar10.pth')

    # Test base model accuracy
    print(banner + 'Test on the original model' + banner)
    model.load_state_dict(torch.load('vgg16_cifar10.pth'))
    test(model, device, test_loader)
    # top1 = 93.51%

    # Pruning Configuration, in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS',
    # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A'
    pruned_layers = [
        'feature.0', 'feature.24', 'feature.27', 'feature.30',
        'feature.34', 'feature.37'
    ]
    configure_list = [{
        'sparsity': 0.5,
        'op_types': ['default'],
        'op_names': pruned_layers
    }]

    # Prune model and test accuracy without fine tuning.
    print(banner + 'Test on the pruned model before fine tune' + banner)
    # The fine-tuning optimizer is created first because it is handed to the pruner.
    optimizer_finetune = torch.optim.SGD(model.parameters(),
                                         lr=0.001,
                                         momentum=0.9,
                                         weight_decay=1e-4)
    pruner = L1FilterPruner(model, configure_list, optimizer_finetune)
    model = pruner.compress()
    test(model, device, test_loader)
    # top1 = 88.19%

    # Fine tune the pruned model for 40 epochs and test accuracy
    print(banner + 'Fine tuning' + banner)
    best_top1 = 0
    for epoch in range(40):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer_finetune)
        top1 = test(model, device, test_loader)
        if top1 <= best_top1:
            continue
        best_top1 = top1
        # Export the best model, 'model_path' stores state_dict of the pruned model,
        # mask_path stores mask_dict of the pruned model
        pruner.export_model(model_path='pruned_vgg16_cifar10.pth',
                            mask_path='mask_vgg16_cifar10.pth')

    # Test the exported model
    print(banner + 'Test on the pruned model after fine tune' + banner)
    new_model = VGG(depth=16)
    new_model.to(device)
    new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
    test(new_model, device, test_loader)