示例#1
0
def test_agp(pruning_algorithm):
    """Train a toy model under AGP pruning and check conv1 follows the schedule.

    For every epoch from the first config entry's ``start_epoch`` through its
    ``end_epoch`` (inclusive), the test does three things:

    - tells the pruner the epoch;
    - takes one SGD step on a fixed random batch;
    - compares the fraction of zeros in ``model.conv1.weight_mask`` with
      ``pruner.compute_target_sparsity`` for that config entry.

    Args:
        pruning_algorithm: forwarded unchanged to ``AGP_Pruner`` as its
            ``pruning_algorithm`` keyword argument.
    """
    model = Model()
    sgd = torch.optim.SGD(model.parameters(), lr=0.01)
    layer_configs = prune_config['agp']['config_list']

    pruner = AGP_Pruner(model,
                        layer_configs,
                        sgd,
                        pruning_algorithm=pruning_algorithm)
    pruner.compress()

    # A fixed two-sample batch of 1x28x28 inputs with class labels 0 and 1.
    inputs = torch.randn(2, 1, 28, 28)
    labels = torch.tensor([0, 1]).long()

    # Only the first config entry drives the epoch range and the target.
    schedule = layer_configs[0]
    first_epoch = schedule['start_epoch']
    last_epoch = schedule['end_epoch']

    for epoch in range(first_epoch, last_epoch + 1):
        pruner.update_epoch(epoch)

        logits = model(inputs)
        loss = F.cross_entropy(logits, labels)
        sgd.zero_grad()
        loss.backward()
        sgd.step()

        expected = pruner.compute_target_sparsity(schedule)
        mask = model.conv1.weight_mask
        observed = (mask == 0).sum().item() / mask.numel()
        # abs_tol = 0.2 tolerates the coarse sparsity steps of channel
        # pruning when the layer has only a small number of channels.
        assert math.isclose(observed, expected, abs_tol=0.2)
示例#2
0
def main():
    """Train ``Mnist`` on MNIST (CPU) for 10 epochs under an ``AGP_Pruner``.

    Uses the call-style pruner API: the pruner is constructed from the config
    list alone and then applied to the model with ``pruner(model)``.
    """
    torch.manual_seed(0)
    device = torch.device('cpu')

    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])

    training_data = datasets.MNIST('data',
                                   train=True,
                                   download=True,
                                   transform=mnist_transform)
    train_loader = torch.utils.data.DataLoader(training_data,
                                               batch_size=64,
                                               shuffle=True)

    held_out_data = datasets.MNIST('data',
                                   train=False,
                                   transform=mnist_transform)
    test_loader = torch.utils.data.DataLoader(held_out_data,
                                              batch_size=1000,
                                              shuffle=True)

    model = Mnist()

    # To try SensitivityPruner instead:
    #     pruner = SensitivityPruner(configure_list)
    configure_list = [dict(
        initial_sparsity=0,
        final_sparsity=0.8,
        start_epoch=1,
        end_epoch=10,
        frequency=1,
        op_type='default',
    )]

    pruner = AGP_Pruner(configure_list)
    # Applying the pruner via pruner.compress(model) is an equivalent option.
    pruner(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    num_epochs = 10
    for epoch in range(num_epochs):
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
        # The pruner is told the epoch number only after this epoch's
        # train/test pass has finished.
        pruner.update_epoch(epoch)
示例#3
0
def main():
    """Train ``Mnist`` on MNIST for 10 epochs under AGP pruning, then export.

    Runs on CUDA when it is available and falls back to CPU otherwise.

    After training, calls ``pruner.export_model`` with:

    - the paths ``'model.pth'``, ``'mask.pth'`` and ``'model.onnx'``;
    - the input shape ``[1, 1, 28, 28]``;
    - the selected device.
    """
    torch.manual_seed(0)
    # A hard-coded 'cuda' device made this example crash on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    trans = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000,
        shuffle=True)

    model = Mnist().to(device)
    # To try LevelPruner instead, construct it the same way as AGP_Pruner
    # below, i.e. ``LevelPruner(model, configure_list)``. Its config entries
    # may need different keys; check that pruner's documentation.
    configure_list = [{
        'initial_sparsity': 0,
        'final_sparsity': 0.8,
        'start_epoch': 0,
        'end_epoch': 10,
        'frequency': 1,
        'op_types': ['default']
    }]

    pruner = AGP_Pruner(model, configure_list)
    model = pruner.compress()
    # Move the model again after compress().
    # NOTE(review): this presumably covers tensors that compress() may create
    # on the default device (e.g. masks) — confirm before removing.
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    for epoch in range(10):
        # The pruner is told the epoch number before that epoch's train/test.
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
    pruner.export_model('model.pth', 'mask.pth', 'model.onnx', [1, 1, 28, 28],
                        device)