def test_agp(pruning_algorithm):
    """Check that AGP pruning drives conv1 to (roughly) its target sparsity.

    Runs the scheduled pruner from start_epoch through end_epoch with a tiny
    dummy batch, then compares the achieved mask sparsity against the value
    the pruner itself reports as the target.
    """
    model = Model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    config_list = prune_config['agp']['config_list']
    pruner = AGP_Pruner(model, config_list, optimizer,
                        pruning_algorithm=pruning_algorithm)
    pruner.compress()

    dummy_input = torch.randn(2, 1, 28, 28)
    dummy_target = torch.tensor([0, 1]).long()
    first_epoch = config_list[0]['start_epoch']
    last_epoch = config_list[0]['end_epoch']
    for epoch in range(first_epoch, last_epoch + 1):
        # Advance the pruning schedule before training this epoch.
        pruner.update_epoch(epoch)
        output = model(dummy_input)
        loss = F.cross_entropy(output, dummy_target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    target_sparsity = pruner.compute_target_sparsity(config_list[0])
    mask = model.conv1.weight_mask
    actual_sparsity = (mask == 0).sum().item() / mask.numel()
    # set abs_tol = 0.2, considering the sparsity error for channel pruning
    # when number of channels is small.
    assert math.isclose(actual_sparsity, target_sparsity, abs_tol=0.2)
def main():
    """Train an MNIST model on CPU while gradually pruning it with AGP.

    Fix/consistency: the original called ``AGP_Pruner(configure_list)``
    followed by ``pruner(model)`` and used the singular ``'op_type'`` config
    key, which does not match the pruner API used elsewhere in this file
    (model passed to the constructor, ``compress()`` returns the wrapped
    model, plural ``'op_types'`` taking a list). This version uses the same
    API as the CUDA example and calls ``update_epoch`` at the start of each
    epoch so the schedule is applied before training.
    """
    torch.manual_seed(0)
    device = torch.device('cpu')
    trans = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)
    model = Mnist()
    '''you can change this to SensitivityPruner to implement it
    pruner = SensitivityPruner(configure_list)
    '''
    # Pruning schedule: ramp sparsity from 0 to 0.8 between epochs 1 and 10,
    # updating the masks every epoch, on all prunable op types.
    configure_list = [{
        'initial_sparsity': 0,
        'final_sparsity': 0.8,
        'start_epoch': 1,
        'end_epoch': 10,
        'frequency': 1,
        'op_types': ['default']
    }]
    pruner = AGP_Pruner(model, configure_list)
    model = pruner.compress()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    for epoch in range(10):
        # Advance the pruning schedule before this epoch's training.
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
def main():
    """Train an MNIST model on GPU with AGP pruning, then export the result.

    Builds the data loaders, wraps the model with the scheduled pruner,
    trains for 10 epochs (advancing the pruning schedule each epoch), and
    finally exports the pruned weights, masks, and an ONNX graph.
    """
    torch.manual_seed(0)
    device = torch.device('cuda')
    trans = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))])
    train_set = datasets.MNIST('data', train=True, download=True,
                               transform=trans)
    test_set = datasets.MNIST('data', train=False, transform=trans)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000,
                                              shuffle=True)
    model = Mnist()
    model = model.to(device)
    '''you can change this to LevelPruner to implement it
    pruner = LevelPruner(configure_list)
    '''
    # Schedule: sparsity grows from 0 to 0.8 over epochs 0..10, every epoch.
    configure_list = [{
        'initial_sparsity': 0,
        'final_sparsity': 0.8,
        'start_epoch': 0,
        'end_epoch': 10,
        'frequency': 1,
        'op_types': ['default']
    }]
    pruner = AGP_Pruner(model, configure_list)
    model = pruner.compress()
    # compress() may rebuild module wrappers; move the result back to the GPU.
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    for epoch in range(10):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
    pruner.export_model('model.pth', 'mask.pth', 'model.onnx',
                        [1, 1, 28, 28], device)