Example #1
    def test_finegrained_speedup(self):
        """ Test the speedup on the fine-grained sparsity"""
        class MLP(nn.Module):
            def __init__(self):
                super(MLP, self).__init__()
                self.fc1 = nn.Linear(1024, 1024)
                self.fc2 = nn.Linear(1024, 1024)
                self.fc3 = nn.Linear(1024, 512)
                self.fc4 = nn.Linear(512, 10)

            def forward(self, x):
                x = x.view(-1, 1024)
                x = self.fc1(x)
                x = self.fc2(x)
                x = self.fc3(x)
                x = self.fc4(x)
                return x

        model = MLP().to(device)
        dummy_input = torch.rand(16, 1, 32, 32).to(device)
        cfg_list = [{'op_types': ['Linear'], 'sparsity': 0.99}]
        pruner = LevelPruner(model, cfg_list)
        pruner.compress()
        print('Original Arch')
        print(model)
        pruner.export_model(MODEL_FILE, MASK_FILE)
        pruner._unwrap_model()
        ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=8)
        ms.speedup_model()
        print("Fine-grained speeduped model")
        print(model)
    def test_speedup_integration(self):
        # skip this test on Windows (only ~7GB memory available) due to the memory limit
        # Note: this platform check is a hack and may be updated in the future
        if 'win' in sys.platform or 'Win' in sys.platform:
            print(
                'Skip test_speedup_integration on windows due to memory limit!'
            )
            return

        Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2]

        for model_name in [
                'resnet18',
                'mobilenet_v2',
                'squeezenet1_1',
                'densenet121',
                'densenet169',
                # 'inception_v3' is too large and may fail the pipeline
                'resnet50'
        ]:

            for gen_cfg_func in Gen_cfg_funcs:
                kwargs = {'pretrained': True}
                if model_name == 'resnet50':
                    # testing multiple groups
                    kwargs = {'pretrained': False, 'groups': 4}
                Model = getattr(models, model_name)
                net = Model(**kwargs).to(device)
                speedup_model = Model(**kwargs).to(device)
                net.eval()  # eval mode is necessary for a deterministic output comparison
                speedup_model.eval()
                # randomly generate the pruning config for the pruner
                cfgs = gen_cfg_func(net)
                print("Testing {} with compression config \n {}".format(
                    model_name, cfgs))
                pruner = L1FilterPruner(net, cfgs)
                pruner.compress()
                pruner.export_model(MODEL_FILE, MASK_FILE)
                pruner._unwrap_model()
                state_dict = torch.load(MODEL_FILE)
                speedup_model.load_state_dict(state_dict)
                zero_bn_bias(net)
                zero_bn_bias(speedup_model)

                data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
                ms = ModelSpeedup(speedup_model, data, MASK_FILE)
                ms.speedup_model()

                speedup_model.eval()

                ori_out = net(data)
                speeded_out = speedup_model(data)
                ori_sum = torch.sum(ori_out).item()
                speeded_sum = torch.sum(speeded_out).item()
                print('Sum of the output of %s (before speedup):' % model_name,
                      ori_sum)
                print('Sum of the output of %s (after speedup):' % model_name,
                      speeded_sum)
                assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
                    (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
Example #3
    def test_speedup_bigmodel(self):
        prune_model_l1(BigModel())
        model = BigModel()
        apply_compression_results(model, MASK_FILE, 'cpu')
        model.eval()
        mask_out = model(dummy_input)

        model.train()
        ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=8)
        ms.speedup_model()
        assert model.training

        model.eval()
        speedup_out = model(dummy_input)
        if not torch.allclose(mask_out, speedup_out, atol=1e-07):
            print('input:', dummy_input.size(),
                  torch.abs(dummy_input).sum((2, 3)))
            print('mask_out:', mask_out)
            print('speedup_out:', speedup_out)
            raise RuntimeError('model speedup inference result is incorrect!')

        orig_model = BigModel()

        assert model.backbone2.conv1.out_channels == int(
            orig_model.backbone2.conv1.out_channels * SPARSITY)
        assert model.backbone2.conv2.in_channels == int(
            orig_model.backbone2.conv2.in_channels * SPARSITY)
        assert model.backbone2.conv2.out_channels == int(
            orig_model.backbone2.conv2.out_channels * SPARSITY)
        assert model.backbone2.fc1.in_features == int(
            orig_model.backbone2.fc1.in_features * SPARSITY)
Example #4
    def test_dependency_aware_pruning(self):
        model_zoo = ['resnet18']
        pruners = [
            L1FilterPruner, L2FilterPruner, FPGMPruner,
            TaylorFOWeightFilterPruner
        ]
        sparsity = 0.7
        cfg_list = [{'op_types': ['Conv2d'], 'sparsity': sparsity}]
        dummy_input = torch.ones(1, 3, 224, 224)
        for model_name in model_zoo:
            for pruner in pruners:
                print('Testing on ', pruner)
                ori_filters = {}
                Model = getattr(models, model_name)
                net = Model(pretrained=True, progress=False)
                # record the number of filters of each conv layer
                for name, module in net.named_modules():
                    if isinstance(module, nn.Conv2d):
                        ori_filters[name] = module.out_channels

                # for the pruners based on activations, we need to feed
                # enough data before we call the compress function.
                optimizer = torch.optim.SGD(net.parameters(),
                                            lr=0.0001,
                                            momentum=0.9,
                                            weight_decay=4e-5)
                criterion = torch.nn.CrossEntropyLoss()
                tmp_pruner = pruner(net,
                                    cfg_list,
                                    optimizer,
                                    dependency_aware=True,
                                    dummy_input=dummy_input)
                # train a single batch so that the pruner can collect the
                # statistics
                optimizer.zero_grad()
                out = net(dummy_input)
                batchsize = dummy_input.size(0)
                loss = criterion(out, torch.zeros(batchsize,
                                                  dtype=torch.int64))
                loss.backward()
                optimizer.step()

                tmp_pruner.compress()
                tmp_pruner.export_model(MODEL_FILE, MASK_FILE)
                # if we want to use the same model, we should unwrap the pruner before the speedup
                tmp_pruner._unwrap_model()
                ms = ModelSpeedup(net, dummy_input, MASK_FILE)
                ms.speedup_model()
                for name, module in net.named_modules():
                    if isinstance(module, nn.Conv2d):
                        expected = int(ori_filters[name] * (1 - sparsity))
                        filter_diff = abs(expected - module.out_channels)
                        errmsg = '%s Ori: %d, Expected: %d, Real: %d' % (
                            name, ori_filters[name], expected,
                            module.out_channels)

                        # because we are using the dependency-aware mode, the number of
                        # filters after speedup should be ori_filters[name] * (1 - sparsity)
                        print(errmsg)
                        assert filter_diff <= 1, errmsg
Example #5
def get_model(args):
    print('=> Building model..')

    if args.dataset == 'imagenet':
        n_class = 1000
    elif args.dataset == 'cifar10':
        n_class = 10
    else:
        raise NotImplementedError

    if args.model_type == 'mobilenet':
        net = MobileNet(n_class=n_class)
    elif args.model_type == 'mobilenetv2':
        net = MobileNetV2(n_class=n_class)
    elif args.model_type.startswith('resnet'):
        net = resnet.__dict__[args.model_type](pretrained=True)
        in_features = net.fc.in_features
        net.fc = nn.Linear(in_features, n_class)
    else:
        raise NotImplementedError

    if args.ckpt_path is not None:
        # the checkpoint can be a state_dict exported by amc_search.py or saved by amc_train.py
        print('=> Loading checkpoint {} ..'.format(args.ckpt_path))
        net.load_state_dict(torch.load(args.ckpt_path, torch.device('cpu')))
        if args.mask_path is not None:
            SZ = 224 if args.dataset == 'imagenet' else 32
            data = torch.randn(2, 3, SZ, SZ)
            ms = ModelSpeedup(net, data, args.mask_path, torch.device('cpu'))
            ms.speedup_model()

    net.to(args.device)
    if torch.cuda.is_available() and args.n_gpu > 1:
        net = torch.nn.DataParallel(net, list(range(args.n_gpu)))
    return net
Example #6
    def test_speedup_vgg16(self):
        prune_model_l1(vgg16())
        model = vgg16()
        model.train()
        ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), MASK_FILE)
        ms.speedup_model()

        orig_model = vgg16()
        assert model.training
        assert model.features[2].out_channels == int(orig_model.features[2].out_channels * SPARSITY)
        assert model.classifier[0].in_features == int(orig_model.classifier[0].in_features * SPARSITY)
Example #7
    def test_speedup_tupleunpack(self):
        """This test was reported in issue #3645"""
        model = TupleUnpack_Model()
        cfg_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
        dummy_input = torch.rand(2, 3, 224, 224)
        pruner = L1FilterPruner(model, cfg_list)
        pruner.compress()
        model(dummy_input)
        pruner.export_model(MODEL_FILE, MASK_FILE)
        ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=8)
        ms.speedup_model()
Example #8
    def test_dependency_aware_random_config(self):
        model_zoo = ['resnet18']
        pruners = [
            L1FilterPruner, L2FilterPruner, FPGMPruner,
            TaylorFOWeightFilterPruner, ActivationMeanRankFilterPruner,
            ActivationAPoZRankFilterPruner
        ]
        dummy_input = torch.ones(1, 3, 224, 224)
        for model_name in model_zoo:
            for pruner in pruners:
                Model = getattr(models, model_name)
                cfg_generator = [
                    generate_random_sparsity, generate_random_sparsity_v2
                ]
                for _generator in cfg_generator:
                    net = Model(pretrained=True, progress=False)
                    cfg_list = _generator(net)

                    print('\n\nModel:', model_name)
                    print('Pruner', pruner)
                    print('Config_list:', cfg_list)
                    # for the pruners based on activations, we need to feed
                    # enough data before we call the compress function.
                    optimizer = torch.optim.SGD(net.parameters(),
                                                lr=0.0001,
                                                momentum=0.9,
                                                weight_decay=4e-5)
                    criterion = torch.nn.CrossEntropyLoss()

                    if pruner in (TaylorFOWeightFilterPruner,
                                  ActivationMeanRankFilterPruner,
                                  ActivationAPoZRankFilterPruner):
                        tmp_pruner = pruner(net,
                                            cfg_list,
                                            optimizer,
                                            trainer=trainer,
                                            criterion=criterion,
                                            dependency_aware=True,
                                            dummy_input=dummy_input)
                    else:
                        tmp_pruner = pruner(net,
                                            cfg_list,
                                            dependency_aware=True,
                                            dummy_input=dummy_input)

                    tmp_pruner.compress()
                    tmp_pruner.export_model(MODEL_FILE, MASK_FILE)
                    # if we want to use the same model, we should unwrap the pruner before the speedup
                    tmp_pruner._unwrap_model()
                    ms = ModelSpeedup(net, dummy_input, MASK_FILE)
                    ms.speedup_model()
Example #9
    def test_speedup_integration(self):
        for model_name in [
                'resnet18',
                'squeezenet1_1',
                'mobilenet_v2',
                'densenet121',
                # 'inception_v3' is too large and may fail the pipeline
                'densenet169',
                'resnet50'
        ]:
            kwargs = {'pretrained': True}
            if model_name == 'resnet50':
                # testing multiple groups
                kwargs = {'pretrained': False, 'groups': 4}

            Model = getattr(models, model_name)
            net = Model(**kwargs).to(device)
            speedup_model = Model(**kwargs).to(device)
            net.eval()  # eval mode is necessary for a deterministic output comparison
            speedup_model.eval()
            # randomly generate the pruning config for the pruner
            cfgs = generate_random_sparsity(net)
            pruner = L1FilterPruner(net, cfgs)
            pruner.compress()
            pruner.export_model(MODEL_FILE, MASK_FILE)
            pruner._unwrap_model()
            state_dict = torch.load(MODEL_FILE)
            speedup_model.load_state_dict(state_dict)
            zero_bn_bias(net)
            zero_bn_bias(speedup_model)

            data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
            ms = ModelSpeedup(speedup_model, data, MASK_FILE)
            ms.speedup_model()

            speedup_model.eval()

            ori_out = net(data)
            speeded_out = speedup_model(data)
            ori_sum = torch.sum(ori_out).item()
            speeded_sum = torch.sum(speeded_out).item()
            print('Sum of the output of %s (before speedup):' % model_name,
                  ori_sum)
            print('Sum of the output of %s (after speedup):' % model_name,
                  speeded_sum)
            assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
                   (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
Example #10
def get_model_optimizer_scheduler(args, device, test_loader, criterion):
    if args.model == 'LeNet':
        model = LeNet().to(device)
    elif args.model == 'vgg16':
        model = VGG(depth=16).to(device)
    elif args.model == 'vgg19':
        model = VGG(depth=19).to(device)
    else:
        raise ValueError("model not recognized")

    # In this example, we set the architecture of teacher and student to be the same. It is feasible to set a different teacher architecture.
    if args.teacher_model_dir is None:
        raise NotImplementedError('please load pretrained teacher model first')

    else:
        model.load_state_dict(torch.load(args.teacher_model_dir))
        best_acc = test(args, model, device, criterion, test_loader)

    model_t = deepcopy(model)
    model_s = deepcopy(model)

    if args.student_model_dir is not None:
        # load the pruned student model checkpoint
        model_s.load_state_dict(torch.load(args.student_model_dir))

    dummy_input = get_dummy_input(args, device)
    m_speedup = ModelSpeedup(model_s, dummy_input, args.mask_path, device)
    m_speedup.speedup_model()

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    module_list.append(model_t)

    # set up the optimizer for fine-tuning the student model
    optimizer = torch.optim.SGD(model_s.parameters(),
                                lr=0.01,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer,
                            milestones=[
                                int(args.fine_tune_epochs * 0.5),
                                int(args.fine_tune_epochs * 0.75)
                            ],
                            gamma=0.1)

    print('Pretrained teacher model acc:', best_acc)
    return module_list, optimizer, scheduler
Example #11
    def test_multiplication_speedup(self):
        """
        Model from issue 4540.
        """
        class Net(torch.nn.Module):
            def __init__(self, ):
                super(Net, self).__init__()
                self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
                self.input = torch.nn.Conv2d(3, 8, 3)
                self.bn = torch.nn.BatchNorm2d(8)
                self.fc1 = torch.nn.Conv2d(8, 16, 1)
                self.fc2 = torch.nn.Conv2d(16, 8, 1)
                self.activation = torch.nn.ReLU()
                self.scale_activation = torch.nn.Hardsigmoid()
                self.out = torch.nn.Conv2d(8, 12, 1)

            def forward(self, input):
                input = self.activation(self.bn(self.input(input)))
                scale = self.avgpool(input)
                out1 = self.activation(self.fc1(scale))
                out1 = self.scale_activation(self.fc2(out1))
                return self.out(out1 * input)

        model = Net().to(device)
        model.eval()
        im = torch.ones(1, 3, 512, 512).to(device)
        model(im)
        cfg_list = []

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                cfg_list.append({
                    'op_types': ['Conv2d'],
                    'sparsity': 0.3,
                    'op_names': [name]
                })

        pruner = L1FilterPruner(model, cfg_list)
        pruner.compress()
        pruner.export_model(MODEL_FILE, MASK_FILE)
        pruner._unwrap_model()
        ms = ModelSpeedup(model, im, MASK_FILE)
        ms.speedup_model()
Example #12
    def test_channel_prune(self):
        orig_net = resnet18(num_classes=10).to(device)
        channel_prune(orig_net)
        state_dict = torch.load(MODEL_FILE)

        orig_net = resnet18(num_classes=10).to(device)
        orig_net.load_state_dict(state_dict)
        apply_compression_results(orig_net, MASK_FILE)
        orig_net.eval()

        net = resnet18(num_classes=10).to(device)

        net.load_state_dict(state_dict)
        net.eval()

        data = torch.randn(BATCH_SIZE, 3, 128, 128).to(device)
        ms = ModelSpeedup(net, data, MASK_FILE, confidence=8)
        ms.speedup_model()
        ms.bound_model(data)

        net.eval()

        ori_sum = orig_net(data).abs().sum().item()
        speeded_sum = net(data).abs().sum().item()

        print(ori_sum, speeded_sum)
        assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
            (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
Example #13
def model_inference(config):
    masks_file = config['masks_file']
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    # device = torch.device(config['device'])
    if config['model_name'] == 'vgg16':
        model = VGG(depth=16)
    elif config['model_name'] == 'vgg19':
        model = VGG(depth=19)
    elif config['model_name'] == 'lenet':
        model = LeNet()

    model.to(device)
    model.eval()

    dummy_input = torch.randn(config['input_shape']).to(device)
    use_mask_out = use_speedup_out = None
    # must run use_mask before use_speedup because use_speedup modifies the model
    if use_mask:
        apply_compression_results(model, masks_file, device)
        start = time.time()
        for _ in range(32):
            use_mask_out = model(dummy_input)
        print('elapsed time when use mask: ', time.time() - start)
    if use_speedup:
        m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
        m_speedup.speedup_model()
        start = time.time()
        for _ in range(32):
            use_speedup_out = model(dummy_input)
        print('elapsed time when use speedup: ', time.time() - start)
    if compare_results:
        if torch.allclose(use_mask_out, use_speedup_out, atol=1e-07):
            print('the outputs from use_mask and use_speedup are the same')
        else:
            raise RuntimeError(
                'the outputs from use_mask and use_speedup are different')
Example #14
    def test_convtranspose_model(self):
        ori_model = TransposeModel()
        dummy_input = torch.rand(1, 3, 8, 8)
        config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
        pruner = L1FilterPruner(ori_model, config_list)
        pruner.compress()
        ori_model(dummy_input)
        pruner.export_model(MODEL_FILE, MASK_FILE)
        pruner._unwrap_model()
        new_model = TransposeModel()
        state_dict = torch.load(MODEL_FILE)
        new_model.load_state_dict(state_dict)
        ms = ModelSpeedup(new_model, dummy_input, MASK_FILE)
        ms.speedup_model()
        zero_bn_bias(ori_model)
        zero_bn_bias(new_model)
        ori_out = ori_model(dummy_input)
        new_out = new_model(dummy_input)
        ori_sum = torch.sum(ori_out)
        speeded_sum = torch.sum(new_out)
        print('Transpose Speedup Test: ori_sum={} speedup_sum={}'.format(ori_sum, speeded_sum))
        assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
            (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
Example #15
# %%
# Show the original model structure.
print(model)

# %%
# Roughly test the original model inference speed.
import time
start = time.time()
model(torch.rand(128, 1, 28, 28).to(device))
print('Original Model - Elapsed Time : ', time.time() - start)

# %%
# Speedup the model and show the model structure after speedup.
from nni.compression.pytorch import ModelSpeedup
ModelSpeedup(model,
             torch.rand(10, 1, 28, 28).to(device), masks).speedup_model()
print(model)

# %%
# Roughly test the inference speed of the model after speedup.
start = time.time()
model(torch.rand(128, 1, 28, 28).to(device))
print('Speedup Model - Elapsed Time : ', time.time() - start)

# %%
# For the combined usage of ``Pruner`` mask generation with ``ModelSpeedup``,
# please refer to :doc:`Pruning Quick Start <pruning_quick_start_mnist>`.
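#
# A rough sketch of that combined flow (illustrative only; it assumes a
# one-shot pruner from ``nni.compression.pytorch.pruning`` such as
# ``L1NormPruner``, and reuses the ``model`` and ``device`` defined above):
#
# .. code-block:: python
#
#     from nni.compression.pytorch.pruning import L1NormPruner
#
#     config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
#     pruner = L1NormPruner(model, config_list)
#     _, masks = pruner.compress()   # generate the pruning masks
#     pruner._unwrap_model()         # restore the original modules before speedup
#     ModelSpeedup(model,
#                  torch.rand(10, 1, 28, 28).to(device),
#                  masks).speedup_model()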
#
# NOTE: The current implementation supports PyTorch 1.3.1 or newer.
#
# Limitations
Example #16
                                          trainer,
                                          traced_optimizer,
                                          criterion,
                                          training_batches=20)
    else:
        pruner = ActivationMeanRankPruner(model,
                                          config_list,
                                          trainer,
                                          traced_optimizer,
                                          criterion,
                                          training_batches=20)
    _, masks = pruner.compress()
    pruner.show_pruned_weights()
    pruner._unwrap_model()
    ModelSpeedup(model,
                 dummy_input=torch.rand([10, 3, 32, 32]).to(device),
                 masks_file=masks).speedup_model()
    print('\n' + '=' * 50 + ' EVALUATE THE MODEL AFTER SPEEDUP ' + '=' * 50)
    evaluator(model)

    # The optimizer used in the pruner might be patched, so it is recommended to create a new optimizer for the fine-tuning stage.
    print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
    optimizer, scheduler = optimizer_scheduler_generator(
        model, _lr=0.01, total_epoch=args.fine_tune_epochs)

    best_acc = 0.0
    g_epoch = 0
    for i in range(args.fine_tune_epochs):
        trainer(model, optimizer, criterion)
        scheduler.step()
        best_acc = max(evaluator(model), best_acc)
Example #17
# The YOLO model can be downloaded from https://github.com/eriklindernoren/PyTorch-YOLOv3.git
prefix = '/home/user/PyTorch-YOLOv3'  # replace this path with yours
# Load the YOLO model
model = models.load_model(
  "%s/config/yolov3.cfg" % prefix, 
  "%s/yolov3.weights" % prefix).cpu()
model.eval()
dummy_input = torch.rand(8, 3, 320, 320)
model(dummy_input)
# Generate the config list for the pruner
# Filter out the layers that may not be safe to prune
not_safe = not_safe_to_prune(model, dummy_input)
cfg_list = []
for name, module in model.named_modules():
  if name in not_safe:
    continue
  if isinstance(module, torch.nn.Conv2d):
    cfg_list.append({'op_types':['Conv2d'], 'sparsity':0.6, 'op_names':[name]})
# Prune the model
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
pruner._unwrap_model()
# Speedup the model
ms = ModelSpeedup(model, dummy_input, './mask')
ms.speedup_model()
model(dummy_input)

Example #18
def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(args.experiment_data_dir, exist_ok=True)

    # prepare model and data
    train_loader, test_loader, criterion = get_data(args.dataset,
                                                    args.data_dir,
                                                    args.batch_size,
                                                    args.test_batch_size)

    model, optimizer, scheduler = get_model_optimizer_scheduler(
        args, device, train_loader, test_loader, criterion)

    dummy_input = get_dummy_input(args, device)
    flops, params, results = count_flops_params(model, dummy_input)
    print(f"FLOPs: {flops}, params: {params}")

    print('start pruning...')
    model_path = os.path.join(
        args.experiment_data_dir,
        'pruned_{}_{}_{}.pth'.format(args.model, args.dataset, args.pruner))
    mask_path = os.path.join(
        args.experiment_data_dir,
        'mask_{}_{}_{}.pth'.format(args.model, args.dataset, args.pruner))

    pruner = get_pruner(model, args.pruner, device, optimizer,
                        args.dependency_aware)
    model = pruner.compress()

    if args.multi_gpu and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    if args.test_only:
        test(args, model, device, criterion, test_loader)

    best_top1 = 0
    for epoch in range(args.fine_tune_epochs):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(args, model, device, train_loader, criterion, optimizer, epoch)
        scheduler.step()
        top1 = test(args, model, device, criterion, test_loader)
        if top1 > best_top1:
            best_top1 = top1
            # Export the best model; 'model_path' stores the state_dict of the pruned model,
            # 'mask_path' stores the mask_dict of the pruned model
            pruner.export_model(model_path=model_path, mask_path=mask_path)

    if args.nni:
        nni.report_final_result(best_top1)

    if args.speed_up:
        # reload the best checkpoint for speed-up
        args.pretrained_model_dir = model_path
        model, _, _ = get_model_optimizer_scheduler(args, device, train_loader,
                                                    test_loader, criterion)
        model.eval()

        apply_compression_results(model, mask_path, device)

        # test model speed
        start = time.time()
        for _ in range(32):
            use_mask_out = model(dummy_input)
        print('elapsed time when use mask: ', time.time() - start)

        m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
        m_speedup.speedup_model()

        flops, params, results = count_flops_params(model, dummy_input)
        print(f"FLOPs: {flops}, params: {params}")

        start = time.time()
        for _ in range(32):
            use_speedup_out = model(dummy_input)
        print('elapsed time when use speedup: ', time.time() - start)

        top1 = test(args, model, device, criterion, test_loader)
Example #19
    def compress(self):
        """
        Compress the model with AutoCompress.

        Returns
        -------
        torch.nn.Module
            model with specified modules compressed.
        """
        _logger.info('Starting AutoCompress pruning...')

        sparsity_each_round = 1 - pow(1 - self._sparsity, 1 / self._num_iterations)

        for i in range(self._num_iterations):
            _logger.info('Pruning iteration: %d', i)
            _logger.info('Target sparsity this round: %s',
                         1 - pow(1 - sparsity_each_round, i + 1))

            # SimulatedAnnealingPruner
            _logger.info(
                'Generating sparsities with SimulatedAnnealingPruner...')
            SApruner = SimulatedAnnealingPruner(
                model=copy.deepcopy(self._model_to_prune),
                config_list=[
                    {"sparsity": sparsity_each_round, "op_types": ['Conv2d']}],
                evaluator=self._evaluator,
                optimize_mode=self._optimize_mode,
                base_algo=self._base_algo,
                start_temperature=self._start_temperature,
                stop_temperature=self._stop_temperature,
                cool_down_rate=self._cool_down_rate,
                perturbation_magnitude=self._perturbation_magnitude,
                experiment_data_dir=self._experiment_data_dir)
            config_list = SApruner.compress(return_config_list=True)
            _logger.info("Generated config_list : %s", config_list)

            # ADMMPruner
            _logger.info('Performing structured pruning with ADMMPruner...')
            ADMMpruner = ADMMPruner(
                model=copy.deepcopy(self._model_to_prune),
                config_list=config_list,
                criterion=self._criterion,
                trainer=self._trainer,
                num_iterations=self._admm_num_iterations,
                epochs_per_iteration=self._admm_epochs_per_iteration,
                row=self._row,
                base_algo=self._base_algo)
            ADMMpruner.compress()

            ADMMpruner.export_model(os.path.join(self._experiment_data_dir, 'model_admm_masked.pth'), os.path.join(
                self._experiment_data_dir, 'mask.pth'))

            # use speedup to physically prune the model before the next iteration,
            # because SimulatedAnnealingPruner & ADMMPruner don't take masked models
            self._model_to_prune.load_state_dict(torch.load(os.path.join(
                self._experiment_data_dir, 'model_admm_masked.pth')))

            masks_file = os.path.join(self._experiment_data_dir, 'mask.pth')
            device = next(self._model_to_prune.parameters()).device

            _logger.info('Speeding up models...')
            m_speedup = ModelSpeedup(self._model_to_prune, self._dummy_input, masks_file, device)
            m_speedup.speedup_model()

            evaluation_result = self._evaluator(self._model_to_prune)
            _logger.info('Evaluation result of the pruned model in iteration %d: %s', i, evaluation_result)

        _logger.info('----------Compression finished--------------')

        os.remove(os.path.join(self._experiment_data_dir, 'model_admm_masked.pth'))
        os.remove(os.path.join(self._experiment_data_dir, 'mask.pth'))

        return self._model_to_prune
Example #20
model.load_state_dict(torch.load('pruned_vgg19_cifar10.pth'))
model.cuda()
model.eval()
model(dummy_input)  # the first inference costs much more time (warm-up)

# mask
use_mask_out = use_speedup_out = None
apply_compression_results(model, 'mask_vgg19_cifar10.pth')
start = time.time()
for _ in range(320):
    use_mask_out = model(dummy_input)
print('elapsed time when use mask: ', time.time() - start)
print(test(model))

# speedup
m_speedup = ModelSpeedup(model, dummy_input, 'mask_vgg19_cifar10.pth')
m_speedup.speedup_model()
start = time.time()
for _ in range(320):
    use_speedup_out = model(dummy_input)
print('elapsed time when use speedup: ', time.time() - start)
print(test(model))

# check that the two outputs are consistent
if torch.allclose(use_mask_out, use_speedup_out, atol=1e-07):
    print('the outputs from use_mask and use_speedup are the same')

# save the model
trace_model = torch.jit.trace(model, torch.ones(1, 3, 32, 32).cuda())
trace_model.save('trace_vgg_cifar10.pt')
print(test(trace_model))
Example #21
def main(args):
    # prepare dataset
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, criterion = get_data(args.dataset, args.data_dir, args.batch_size, args.test_batch_size)
    model, optimizer = get_trained_model_optimizer(args, device, train_loader, val_loader, criterion)

    def short_term_fine_tuner(model, epochs=1):
        for epoch in range(epochs):
            train(args, model, device, train_loader, criterion, optimizer, epoch)

    def trainer(model, optimizer, criterion, epoch, callback):
        return train(args, model, device, train_loader, criterion, optimizer, epoch=epoch, callback=callback)

    def evaluator(model):
        return test(model, device, criterion, val_loader)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance':{}}

    flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params

    evaluation_result = evaluator(model)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result

    # module types to prune, only "Conv2d" supported for channel pruning
    if args.base_algo in ['l1', 'l2', 'fpgm']:
        op_types = ['Conv2d']
    elif args.base_algo == 'level':
        op_types = ['default']

    config_list = [{
        'sparsity': args.sparsity,
        'op_types': op_types
    }]
    dummy_input = get_dummy_input(args, device)

    if args.pruner == 'L1FilterPruner':
        pruner = L1FilterPruner(model, config_list)
    elif args.pruner == 'L2FilterPruner':
        pruner = L2FilterPruner(model, config_list)
    elif args.pruner == 'FPGMPruner':
        pruner = FPGMPruner(model, config_list)
    elif args.pruner == 'NetAdaptPruner':
        pruner = NetAdaptPruner(model, config_list, short_term_fine_tuner=short_term_fine_tuner, evaluator=evaluator,
                                base_algo=args.base_algo, experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'ADMMPruner':
        # users are free to change the config here
        if args.model == 'LeNet':
            if args.base_algo in ['l1', 'l2', 'fpgm']:
                config_list = [{
                    'sparsity': 0.8,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv2']
                }]
            elif args.base_algo == 'level':
                config_list = [{
                    'sparsity': 0.8,
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_names': ['conv2']
                }, {
                    'sparsity': 0.991,
                    'op_names': ['fc1']
                }, {
                    'sparsity': 0.93,
                    'op_names': ['fc2']
                }]
        else:
            raise ValueError('Example only implemented for LeNet.')
        pruner = ADMMPruner(model, config_list, trainer=trainer, num_iterations=2, training_epochs=2)
    elif args.pruner == 'SimulatedAnnealingPruner':
        pruner = SimulatedAnnealingPruner(
            model, config_list, evaluator=evaluator, base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate, experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'AutoCompressPruner':
        pruner = AutoCompressPruner(
            model, config_list, trainer=trainer, evaluator=evaluator, dummy_input=dummy_input,
            num_iterations=3, optimize_mode='maximize', base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate, admm_num_iterations=30, admm_training_epochs=5,
            experiment_data_dir=args.experiment_data_dir)
    else:
        raise ValueError(
            "Pruner not supported.")

    # Pruner.compress() returns the masked model,
    # but for AutoCompressPruner, compress() directly returns the pruned model
    model = pruner.compress()
    evaluation_result = evaluator(model)
    print('Evaluation result (masked model): %s' % evaluation_result)
    result['performance']['pruned'] = evaluation_result

    if args.save_model:
        pruner.export_model(
            os.path.join(args.experiment_data_dir, 'model_masked.pth'), os.path.join(args.experiment_data_dir, 'mask.pth'))
        print('Masked model saved to %s' % args.experiment_data_dir)

    # model speed up
    if args.speed_up:
        if args.pruner != 'AutoCompressPruner':
            if args.model == 'LeNet':
                model = LeNet().to(device)
            elif args.model == 'vgg16':
                model = VGG(depth=16).to(device)
            elif args.model == 'resnet18':
                model = ResNet18().to(device)
            elif args.model == 'resnet50':
                model = ResNet50().to(device)

            model.load_state_dict(torch.load(os.path.join(args.experiment_data_dir, 'model_masked.pth')))
            masks_file = os.path.join(args.experiment_data_dir, 'mask.pth')

            m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
            m_speedup.speedup_model()
            evaluation_result = evaluator(model)
            print('Evaluation result (speed up model): %s' % evaluation_result)
            result['performance']['speedup'] = evaluation_result

            torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
            print('Speed up model saved to %s' % args.experiment_data_dir)
        flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['speedup'] = flops
        result['params']['speedup'] = params

    if args.fine_tune:
        if args.dataset == 'mnist':
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
        elif args.dataset == 'cifar10' and args.model == 'vgg16':
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(
                optimizer, milestones=[int(args.fine_tune_epochs*0.5), int(args.fine_tune_epochs*0.75)], gamma=0.1)
        elif args.dataset == 'cifar10' and args.model == 'resnet18':
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(
                optimizer, milestones=[int(args.fine_tune_epochs*0.5), int(args.fine_tune_epochs*0.75)], gamma=0.1)
        elif args.dataset == 'cifar10' and args.model == 'resnet50':
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(
                optimizer, milestones=[int(args.fine_tune_epochs*0.5), int(args.fine_tune_epochs*0.75)], gamma=0.1)
        best_acc = 0
        for epoch in range(args.fine_tune_epochs):
            train(args, model, device, train_loader, criterion, optimizer, epoch)
            scheduler.step()
            acc = evaluator(model)
            if acc > best_acc:
                best_acc = acc
                torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_fine_tuned.pth'))

    print('Evaluation result (fine tuned): %s' % best_acc)
    print('Fine-tuned model saved to %s' % args.experiment_data_dir)
    result['performance']['finetuned'] = best_acc

    with open(os.path.join(args.experiment_data_dir, 'result.json'), 'w+') as f:
        json.dump(result, f)
Example #22
def run_pruning(args):
    print(args)
    torch.set_num_threads(args.n_workers)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    log = open(
        args.experiment_dir + '/pruning_{}_{}_sparsity{}_{}.log'.format(
            args.pruner_name, args.pruning_mode, args.sparsity,
            strftime("%Y%m%d%H%M", gmtime())), 'w')

    train_dataset = TrainDataset('./data/stanford-dogs/Processed/train')
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True)
    train_dataset_for_pruner = EvalDataset(
        './data/stanford-dogs/Processed/train')
    train_dataloader_for_pruner = DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False)
    valid_dataset = EvalDataset('./data/stanford-dogs/Processed/valid')
    valid_dataloader = DataLoader(valid_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False)
    test_dataset = EvalDataset('./data/stanford-dogs/Processed/test')
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False)

    model = create_model(model_type=model_type,
                         pretrained=False,
                         n_classes=n_classes,
                         input_size=input_size,
                         checkpoint=args.experiment_dir + '/' +
                         args.checkpoint_name)
    model = model.to(device)

    teacher_model = None
    if args.kd:
        teacher_model = copy.deepcopy(model)

    # evaluation before pruning
    # count_flops(model, log, device)
    initial_loss, initial_acc = run_eval(model, test_dataloader, device)
    print('Before Pruning:\nLoss: {}\nAccuracy: {}'.format(
        initial_loss, initial_acc))
    log.write('Before Pruning:\nLoss: {}\nAccuracy: {}\n'.format(
        initial_loss, initial_acc))

    # set up config list and pruner
    config_list = []
    if 'conv0' in args.pruning_mode or args.pruning_mode == 'all':
        if args.pruner_name == 'slim' or (args.pruner_name == 'agp'
                                          and args.agp_pruning_alg == 'slim'):
            config_list.append({
                'op_names':
                ['features.{}.conv.0.1'.format(x) for x in range(2, 18)],
                'sparsity':
                args.sparsity
            })
        else:
            config_list.append({
                'op_names':
                ['features.{}.conv.0.0'.format(x) for x in range(2, 18)],
                'sparsity':
                args.sparsity
            })
    if 'conv1' in args.pruning_mode or args.pruning_mode == 'all':
        if args.pruner_name == 'slim' or (args.pruner_name == 'agp'
                                          and args.agp_pruning_alg == 'slim'):
            config_list.append({
                'op_names':
                ['features.{}.conv.1.1'.format(x) for x in range(2, 18)],
                'sparsity':
                args.sparsity
            })
        else:
            config_list.append({
                'op_names':
                ['features.{}.conv.1.0'.format(x) for x in range(2, 18)],
                'sparsity':
                args.sparsity
            })
    if 'conv2' in args.pruning_mode or args.pruning_mode == 'all':
        if args.pruner_name == 'slim' or (args.pruner_name == 'agp'
                                          and args.agp_pruning_alg == 'slim'):
            config_list.append({
                'op_names':
                ['features.{}.conv.3'.format(x) for x in range(2, 18)],
                'sparsity':
                args.sparsity
            })
        else:
            config_list.append({
                'op_names':
                ['features.{}.conv.2'.format(x) for x in range(2, 18)],
                'sparsity':
                args.sparsity
            })
    print(config_list)

    kwargs = {}
    if args.pruner_name in [
            'slim', 'taylorfo', 'mean_activation', 'apoz', 'agp'
    ]:

        def trainer(model, optimizer, criterion, epoch):
            if not args.kd:
                return trainer_helper(model, criterion, optimizer,
                                      train_dataloader, device)
            else:
                return trainer_helper_with_distillation(
                    model, teacher_model, args.alpha, args.temp, optimizer,
                    train_dataloader, device)

        kwargs = {
            'trainer': trainer,
            'optimizer': torch.optim.Adam(model.parameters()),
            'criterion': nn.CrossEntropyLoss()
        }
        if args.pruner_name == 'agp':
            kwargs['pruning_algorithm'] = args.agp_pruning_alg
            kwargs['num_iterations'] = args.agp_n_iters
            kwargs['epochs_per_iteration'] = args.agp_n_epochs_per_iter
        if args.pruner_name == 'slim':
            kwargs['sparsifying_training_epochs'] = 10

    # pruning
    pruner = pruner_type_to_class[args.pruner_name](model, config_list,
                                                    **kwargs)
    pruner.compress()
    pruner.export_model(args.experiment_dir + '/model_temp.pth',
                        args.experiment_dir + '/mask_temp.pth')

    # model speedup
    pruner._unwrap_model()
    if args.speedup:
        dummy_input = torch.rand(1, 3, 224, 224).to(device)
        ms = ModelSpeedup(model, dummy_input,
                          args.experiment_dir + '/mask_temp.pth')
        ms.speedup_model()
        print(model)
        count_flops(model, log)

    intermediate_loss, intermediate_acc = run_eval(model, test_dataloader,
                                                   device)
    print('Before Finetuning:\nLoss: {}\nAccuracy: {}'.format(
        intermediate_loss, intermediate_acc))
    log.write('Before Finetuning:\nLoss: {}\nAccuracy: {}\n'.format(
        intermediate_loss, intermediate_acc))

    # finetuning
    if args.kd:
        model = run_finetune_distillation(model,
                                          teacher_model,
                                          train_dataloader,
                                          valid_dataloader,
                                          device,
                                          args.alpha,
                                          args.temp,
                                          n_epochs=args.finetune_epochs,
                                          learning_rate=args.learning_rate,
                                          weight_decay=args.weight_decay)
    else:
        model = run_finetune(model,
                             train_dataloader,
                             valid_dataloader,
                             device,
                             n_epochs=args.finetune_epochs,
                             learning_rate=args.learning_rate,
                             weight_decay=args.weight_decay)

    # final evaluation
    final_loss, final_acc = run_eval(model, test_dataloader, device)
    print('After Pruning:\nLoss: {}\nAccuracy: {}'.format(
        final_loss, final_acc))
    log.write('After Pruning:\nLoss: {}\nAccuracy: {}'.format(
        final_loss, final_acc))

    # clean up
    filePaths = [
        args.experiment_dir + '/model_temp.pth',
        args.experiment_dir + '/mask_temp.pth'
    ]
    for f in filePaths:
        if os.path.exists(f):
            os.remove(f)

    log.close()
Example #23
def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(args.experiment_data_dir, exist_ok=True)

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=transform),
        batch_size=64,
    )
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        'data', train=False, transform=transform),
                                              batch_size=1000)

    # Step1. Model Pretraining
    model = NaiveModel().to(device)
    criterion = torch.nn.NLLLoss()
    optimizer = optim.Adadelta(model.parameters(), lr=args.pretrain_lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    flops, params, _ = count_flops_params(model, (1, 1, 28, 28), verbose=False)

    if args.pretrained_model_dir is None:
        args.pretrained_model_dir = os.path.join(args.experiment_data_dir,
                                                 f'pretrained.pth')

        best_acc = 0
        for epoch in range(args.pretrain_epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch)
            scheduler.step()
            acc = test(args, model, device, criterion, test_loader)
            if acc > best_acc:
                best_acc = acc
                state_dict = model.state_dict()

        model.load_state_dict(state_dict)
        torch.save(state_dict, args.pretrained_model_dir)
        print(f'Model saved to {args.pretrained_model_dir}')
    else:
        state_dict = torch.load(args.pretrained_model_dir)
        model.load_state_dict(state_dict)
        best_acc = test(args, model, device, criterion, test_loader)

    dummy_input = torch.randn([1000, 1, 28, 28]).to(device)
    time_cost = get_model_time_cost(model, dummy_input)

    # 125.49 M, 0.85M, 93.29, 1.1012
    print(
        f'Pretrained model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_acc: .2f}, Time Cost: {time_cost}'
    )

    # Step2. Model Pruning
    config_list = [{'sparsity': args.sparsity, 'op_types': ['Conv2d']}]

    kw_args = {}
    if args.dependency_aware:
        dummy_input = torch.randn([1000, 1, 28, 28]).to(device)
        print('Enable the dependency_aware mode')
        # note that not all pruners support the dependency_aware mode
        kw_args['dependency_aware'] = True
        kw_args['dummy_input'] = dummy_input

    pruner = L1FilterPruner(model, config_list, **kw_args)
    model = pruner.compress()
    pruner.get_pruned_weights()

    mask_path = os.path.join(args.experiment_data_dir, 'mask.pth')
    model_path = os.path.join(args.experiment_data_dir, 'pruned.pth')
    pruner.export_model(model_path=model_path, mask_path=mask_path)
    pruner._unwrap_model()  # unwrap all modules to normal state

    # Step3. Model Speedup
    m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
    m_speedup.speedup_model()
    print('model after speedup', model)

    flops, params, _ = count_flops_params(model, dummy_input, verbose=False)
    acc = test(args, model, device, criterion, test_loader)
    time_cost = get_model_time_cost(model, dummy_input)
    print(
        f'Pruned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {acc: .2f}, Time Cost: {time_cost}'
    )

    # Step4. Model Finetuning
    optimizer = optim.Adadelta(model.parameters(), lr=args.pretrain_lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

    best_acc = 0
    for epoch in range(args.finetune_epochs):
        train(args, model, device, train_loader, criterion, optimizer, epoch)
        scheduler.step()
        acc = test(args, model, device, criterion, test_loader)
        if acc > best_acc:
            best_acc = acc
            state_dict = model.state_dict()

    model.load_state_dict(state_dict)
    save_path = os.path.join(args.experiment_data_dir, f'finetuned.pth')
    torch.save(state_dict, save_path)

    flops, params, _ = count_flops_params(model, dummy_input, verbose=True)
    time_cost = get_model_time_cost(model, dummy_input)

    # FLOPs 28.48 M, #Params: 0.18M, Accuracy:  89.03, Time Cost: 1.03
    print(
        f'Finetuned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_acc: .2f}, Time Cost: {time_cost}'
    )
    print(f'Model saved to {save_path}')

    # Step5. Model Quantization via QAT
    config_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {
            'weight': 8,
            'output': 8
        },
        'op_names': ['conv1']
    }, {
        'quant_types': ['output'],
        'quant_bits': {
            'output': 8
        },
        'op_names': ['relu1']
    }, {
        'quant_types': ['weight', 'output'],
        'quant_bits': {
            'weight': 8,
            'output': 8
        },
        'op_names': ['conv2']
    }, {
        'quant_types': ['output'],
        'quant_bits': {
            'output': 8
        },
        'op_names': ['relu2']
    }]

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    quantizer = QAT_Quantizer(model, config_list, optimizer)
    quantizer.compress()

    # Step6. Quantization Aware Training
    best_acc = 0
    for epoch in range(1):
        train(args, model, device, train_loader, criterion, optimizer, epoch)
        scheduler.step()
        acc = test(args, model, device, criterion, test_loader)
        if acc > best_acc:
            best_acc = acc
            state_dict = model.state_dict()

    calibration_path = os.path.join(args.experiment_data_dir,
                                    'calibration.pth')
    calibration_config = quantizer.export_model(model_path, calibration_path)
    print("calibration_config: ", calibration_config)

    # Step7. Model Speedup
    batch_size = 32
    input_shape = (batch_size, 1, 28, 28)
    engine = ModelSpeedupTensorRT(model,
                                  input_shape,
                                  config=calibration_config,
                                  batchsize=32)
    engine.compress()

    test_trt(engine, test_loader)
Example #24
def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(args.experiment_data_dir, exist_ok=True)

    # prepare model and data
    train_loader, test_loader, criterion = get_data(args.dataset,
                                                    args.data_dir,
                                                    args.batch_size,
                                                    args.test_batch_size)

    model, optimizer, scheduler = get_model_optimizer_scheduler(
        args, device, train_loader, test_loader, criterion)

    dummy_input = get_dummy_input(args, device)
    flops, params, results = count_flops_params(model, dummy_input)
    print(f"FLOPs: {flops}, params: {params}")

    print(f'start {args.pruner} pruning...')

    def trainer(model, optimizer, criterion, epoch):
        return train(args,
                     model,
                     device,
                     train_loader,
                     criterion,
                     optimizer,
                     epoch=epoch)

    pruner_cls = str2pruner[args.pruner]

    kw_args = {}
    config_list = [{'sparsity': args.sparsity, 'op_types': ['Conv2d']}]

    if args.pruner == 'level':
        config_list = [{'sparsity': args.sparsity, 'op_types': ['default']}]

    else:
        if args.dependency_aware:
            dummy_input = get_dummy_input(args, device)
            print('Enable the dependency_aware mode')
            # note that not all pruners support the dependency_aware mode
            kw_args['dependency_aware'] = True
            kw_args['dummy_input'] = dummy_input
        if args.pruner not in ('l1filter', 'l2filter', 'fpgm'):
            # these settings only work for training-aware pruners
            kw_args['trainer'] = trainer
            kw_args['optimizer'] = optimizer
            kw_args['criterion'] = criterion

        if args.pruner in ('mean_activation', 'apoz', 'taylorfo'):
            kw_args['sparsifying_training_batches'] = 1

        if args.pruner == 'slim':
            kw_args['sparsifying_training_epochs'] = 1

        if args.pruner == 'agp':
            kw_args['pruning_algorithm'] = 'l1'
            kw_args['num_iterations'] = 2
            kw_args['epochs_per_iteration'] = 1

        # Reproduce the result in the paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS':
        # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as in 'VGG-16-pruned-A'
        if args.pruner == 'slim':
            config_list = [{
                'sparsity': args.sparsity,
                'op_types': ['BatchNorm2d'],
            }]
        else:
            config_list = [{
                'sparsity':
                args.sparsity,
                'op_types': ['Conv2d'],
                'op_names': [
                    'feature.0', 'feature.24', 'feature.27', 'feature.30',
                    'feature.34', 'feature.37'
                ]
            }]

    pruner = pruner_cls(model, config_list, **kw_args)

    # Pruner.compress() returns the masked model
    model = pruner.compress()
    pruner.get_pruned_weights()

    # export the pruned model masks for model speedup
    model_path = os.path.join(
        args.experiment_data_dir,
        'pruned_{}_{}_{}.pth'.format(args.model, args.dataset, args.pruner))
    mask_path = os.path.join(
        args.experiment_data_dir,
        'mask_{}_{}_{}.pth'.format(args.model, args.dataset, args.pruner))
    pruner.export_model(model_path=model_path, mask_path=mask_path)

    if args.test_only:
        test(args, model, device, criterion, test_loader)

    if args.speed_up:
        # Unwrap all modules to normal state
        pruner._unwrap_model()
        m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
        m_speedup.speedup_model()

    print('start finetuning...')
    best_top1 = 0
    save_path = os.path.join(args.experiment_data_dir, f'finetuned.pth')
    for epoch in range(args.fine_tune_epochs):
        print('# Epoch {} #'.format(epoch))
        train(args, model, device, train_loader, criterion, optimizer, epoch)
        scheduler.step()
        top1 = test(args, model, device, criterion, test_loader)
        if top1 > best_top1:
            best_top1 = top1
            torch.save(model.state_dict(), save_path)

    flops, params, results = count_flops_params(model, dummy_input)
    print(
        f'Finetuned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_top1: .2f}'
    )

    if args.nni:
        nni.report_final_result(best_top1)