# One-shot pruning to various retention levels
pruning_levels = [(0.8**i, i) for i in range(15)]
# one-shot prune at each retention level, starting from the trained checkpoint
for to_retain, idx in pruning_levels:
    print('Starting pruning, retaining {:.1%} of weights'.format(to_retain))
    net_prune = Network(trainloader, testloader, device=device)
    net_prune = net_prune.to(device)
    net_prune.load_state_dict(
        torch.load('./checkpoints/oneshot-pruning-cifar10/{}-trained'.format(
            EXPERIMENT_NAME)))
    val_acc, _ = net_prune.test()
    before_count = net_prune.param_count()

    print('Val. accuracy before pruning: {}, params: {}'.format(
        val_acc, before_count))
    pruner = SparsityPruner(net_prune)
    pruner.prune(to_retain, prune_global=True)
    after_count = net_prune.param_count()

    print('Retraining from initialization on NORB, retaining {:.1%}...'.format(
        to_retain))
    net_retrain = Network(trainloader_tr, testloader_tr, device=device)
    net_retrain = net_retrain.to(device)
    net_retrain.load_state_dict(
        torch.load('./checkpoints/oneshot-pruning-cifar10/{}-init'.format(
            EXPERIMENT_NAME)))

    pruner_retrain = SparsityPruner(net_retrain)
    pruner_retrain.masks = pruner.masks

    pruner_retrain.apply_mask(prune_global=True)
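
# For reference, a minimal sketch of what the SparsityPruner used above might
# look like, assuming magnitude pruning with binary masks. Only the interface
# (prune, masks, apply_mask, the prune_global flag) is taken from the code in
# this section; the actual implementation is not shown here and may differ.
import torch


class SparsityPrunerSketch:

    def __init__(self, net):
        self.net = net
        self.masks = {}  # parameter name -> 0/1 mask tensor

    def _prunable(self):
        # Prune weight matrices and conv kernels, not biases or norm params.
        return {n: p for n, p in self.net.named_parameters() if p.dim() > 1}

    @staticmethod
    def _threshold(scores, to_retain):
        # Magnitude below which weights are dropped; keeps the top
        # `to_retain` fraction of `scores`.
        k = int((1 - to_retain) * scores.numel())
        return scores.kthvalue(k).values if k > 0 else -1.0

    def prune(self, to_retain, prune_global=False):
        params = self._prunable()
        # Only weights that survived earlier rounds compete in this round,
        # so repeated calls compound (iterative pruning).
        alive = {n: p.detach().abs()[self.masks[n].bool()]
                 if n in self.masks else p.detach().abs().flatten()
                 for n, p in params.items()}
        if prune_global:
            # One threshold across all layers.
            cut = self._threshold(torch.cat(list(alive.values())), to_retain)
            thresholds = {n: cut for n in params}
        else:
            # Layer-wise: the same retention fraction within each tensor.
            thresholds = {n: self._threshold(a, to_retain)
                          for n, a in alive.items()}
        for n, p in params.items():
            new_mask = (p.detach().abs() > thresholds[n]).float()
            if n in self.masks:
                new_mask *= self.masks[n]  # never revive pruned weights
            self.masks[n] = new_mask
        self.apply_mask(prune_global=prune_global)

    def apply_mask(self, prune_global=False):
        # Zero out pruned weights in place; re-apply after optimizer steps
        # to keep them at zero during fine-tuning.
        with torch.no_grad():
            for n, p in self.net.named_parameters():
                if n in self.masks:
                    p.mul_(self.masks[n])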
# Iterative pruning & fine-tuning to an overall retention target
to_retain = 0.2
pruning_iter = 5
to_retain_iter = to_retain**(1 / pruning_iter)
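# With to_retain = 0.2 and pruning_iter = 5, each iteration keeps
# to_retain_iter = 0.2**(1/5) ≈ 0.725 of the surviving weights, so five
# rounds compound to roughly 0.725**5 ≈ 0.2 of the original weights.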

net_ft = Network(trainloader, testloader, device=device)
net_ft = net_ft.to(device)
net_ft.load_state_dict(
    torch.load(
        './checkpoints/iterative-pruning/{}-trained'.format(EXPERIMENT_NAME)))
val_acc, _ = net_ft.test()
before_count = net_ft.param_count()

print('Val. accuracy before pruning: {}, params: {}'.format(
    val_acc, before_count))
pruner = SparsityPruner(net_ft)

optimizer = optim.SGD(net_ft.parameters(),
                      lr=1e-3,
                      momentum=0.9,
                      weight_decay=5e-4)
train_losses, val_losses, train_accs, val_accs = [], [], [], []

for prune_epoch in range(pruning_iter):
    plt_data = (train_losses, val_losses, train_accs, val_accs)
    pruner.prune(to_retain_iter)
    after_count = net_ft.param_count()
    print('Starting pruning iteration {}, fraction of weights remaining: {}'.format(
        prune_epoch + 1, after_count / before_count))
    train_losses, val_losses, train_accs, val_accs = net_ft.train_epoch(
        prune_epoch, optimizer, plt_data)  # train_epoch signature assumed from context
# Retrain from late reset initialization
for prune_epoch in range(pruning_iter):
    pruner_path = "experiment_data/iterative-pruning-fmnist-resnet/{}-{}.p".format(
        EXPERIMENT_NAME, prune_epoch + 1)
    print('Retraining from late reset on NORB at pruning level {}...'.format(
        prune_epoch + 1))
    net_late_reset = Network(trainloader_tr, testloader_tr, device=device)
    net_late_reset = net_late_reset.to(device)
    net_late_reset.load_state_dict(
        torch.load(
            './checkpoints/iterative-pruning-fmnist-resnet/{}-trained'.format(
                EXPERIMENT_NAME)))

    with open(pruner_path, 'rb') as f:
        pruner = pickle.load(f)
    _masks = pruner.masks
    pruner_late_reset = SparsityPruner(net_late_reset)
    pruner_late_reset.masks = _masks

    train_losses, val_losses, train_accs, val_accs = [], [], [], []

    pruner_late_reset.apply_mask(prune_global=True)
    print('Params after masking: {}'.format(net_late_reset.param_count()))

    for epoch in range(N_EPOCH):
        print('Starting epoch {}'.format(epoch + 1))
        optimizer = optim.SGD(net_late_reset.parameters(),
                              lr=get_lr(epoch),
                              momentum=0.9,
                              weight_decay=1e-4)
        plt_data = (train_losses, val_losses, train_accs, val_accs)
        train_losses, val_losses, train_accs, val_accs, stop = net_late_reset.train_epoch(
            epoch, optimizer, plt_data)  # train_epoch signature assumed from context
        if stop:
            break
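
# get_lr above is not defined in this snippet; a typical step-decay schedule
# for this kind of retraining might look like the following. The base rate,
# decay factor, and milestone epochs are assumptions, not the original values.
def get_lr(epoch, base_lr=0.1, decay=0.1, milestones=(80, 120)):
    # Multiply the base learning rate by `decay` at each milestone epoch.
    lr = base_lr
    for m in milestones:
        if epoch >= m:
            lr *= decay
    return lr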

# Iterative pruning
to_retain_iter = 0.8

net_ft = Network(trainloader, testloader, device=device)
net_ft = net_ft.to(device)
net_ft.load_state_dict(
    torch.load('./checkpoints/iterative-pruning-cifar10/{}-trained'.format(
        EXPERIMENT_NAME)))
val_acc, _ = net_ft.test()
before_count = net_ft.param_count()

print('Val. accuracy before pruning: {}, params: {}'.format(
    val_acc, before_count))
pruner = SparsityPruner(net_ft)

optimizer = optim.SGD(net_ft.parameters(),
                      lr=1e-2,
                      momentum=0.9,
                      weight_decay=1e-4)
train_losses, val_losses, train_accs, val_accs = [], [], [], []

for prune_epoch in range(pruning_iter):
    plt_data = (train_losses, val_losses, train_accs, val_accs)
    pruner.prune(to_retain_iter, prune_global=True)
    after_count = net_ft.param_count()
    print('Starting pruning iteration {}, fraction of weights remaining: {}'.format(
        prune_epoch + 1, after_count / before_count))
    for ft_epoch in range(n_ft_epochs):
        train_losses, val_losses, train_accs, val_accs, stopped = net_ft.train_epoch(
            ft_epoch, optimizer, plt_data)  # train_epoch signature assumed from context
        if stopped:
            break
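
# The late-reset section above unpickles one SparsityPruner per pruning
# level, so a save step like the following presumably ran once per iteration
# of a loop like this one. The path format is taken from the load above;
# the helper itself is hypothetical.
import pickle


def save_pruner(pruner, prune_epoch):
    # Hypothetical helper: call inside the pruning loop, after
    # pruner.prune(...), so retraining runs can reload the masks later.
    pruner_path = 'experiment_data/iterative-pruning-fmnist-resnet/{}-{}.p'.format(
        EXPERIMENT_NAME, prune_epoch + 1)
    with open(pruner_path, 'wb') as f:
        pickle.dump(pruner, f)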