# Oneshot pruning to various levels
# (torch, optim, and pickle are used throughout this section; Network,
# SparsityPruner, the data loaders, device, and EXPERIMENT_NAME are assumed
# to be defined earlier.)
import pickle

import torch
import torch.optim as optim

pruning_levels = [(.8**i, i) for i in range(15)]

# One-shot pruning
for to_retain, idx in pruning_levels:
    print('Starting pruning, retaining {:.2f}%'.format(to_retain * 100))
    net_prune = Network(trainloader, testloader, device=device)
    net_prune = net_prune.to(device)
    net_prune.load_state_dict(
        torch.load('./checkpoints/oneshot-pruning-cifar10/{}-trained'.format(
            EXPERIMENT_NAME)))
    val_acc, _ = net_prune.test()
    before_count = net_prune.param_count()
    print('Before pruning: {}, params: {}'.format(val_acc, before_count))

    pruner = SparsityPruner(net_prune)
    pruner.prune(to_retain, prune_global=True)
    after_count = net_prune.param_count()

    print('Fine-tuning on NORB, retaining {}...'.format(to_retain))
    net_retrain = Network(trainloader_tr, testloader_tr, device=device)
    net_retrain = net_retrain.to(device)
    net_retrain.load_state_dict(
        torch.load('./checkpoints/oneshot-pruning-cifar10/{}-init'.format(
            EXPERIMENT_NAME)))
    pruner_retrain = SparsityPruner(net_retrain)
    pruner_retrain.masks = pruner.masks
    pruner_retrain.apply_mask(prune_global=True)
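# SparsityPruner itself is defined elsewhere in this project. As a rough
# illustration only: a minimal global magnitude pruner consistent with the
# interface used above (.prune(to_retain, prune_global=True), .masks,
# .apply_mask()) might look like the hypothetical class below.
class GlobalMagnitudePrunerSketch:
    """Hypothetical stand-in for SparsityPruner; keeps the largest-magnitude
    weights across the whole network and zeroes the rest."""

    def __init__(self, net):
        self.net = net
        self.masks = {}

    def prune(self, to_retain, prune_global=True):
        # Pool all weights, find the magnitude threshold that keeps the
        # top `to_retain` fraction, then build per-parameter masks.
        all_weights = torch.cat(
            [p.detach().abs().flatten() for p in self.net.parameters()])
        k = max(1, int(to_retain * all_weights.numel()))
        threshold = torch.topk(all_weights, k).values.min()
        for name, p in self.net.named_parameters():
            self.masks[name] = (p.detach().abs() >= threshold).float()
        self.apply_mask(prune_global=prune_global)

    def apply_mask(self, prune_global=True):
        # Zero out pruned weights in place.
        with torch.no_grad():
            for name, p in self.net.named_parameters():
                p.mul_(self.masks[name])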
# Iterative pruning & fine-tuning
to_retain = 0.2
pruning_iter = 5  # was commented out in the original, but is used below
to_retain_iter = to_retain**(1 / pruning_iter)

net_ft = Network(trainloader, testloader, device=device)
net_ft = net_ft.to(device)
net_ft.load_state_dict(
    torch.load(
        './checkpoints/iterative-pruning/{}-trained'.format(EXPERIMENT_NAME)))
val_acc, _ = net_ft.test()
before_count = net_ft.param_count()
print('Before pruning: {}, params: {}'.format(val_acc, before_count))

pruner = SparsityPruner(net_ft)
optimizer = optim.SGD(net_ft.parameters(),
                      lr=1e-3,
                      momentum=0.9,
                      weight_decay=5e-4)
train_losses, val_losses, train_accs, val_accs = [], [], [], []
for prune_epoch in range(pruning_iter):
    plt_data = (train_losses, val_losses, train_accs, val_accs)
    pruner.prune(to_retain_iter)
    after_count = net_ft.param_count()
    print('Starting pruning iteration {}, pct. remaining: {}'.format(
        prune_epoch + 1, after_count / before_count))
    # The original cell is truncated here; the remaining arguments to
    # train_epoch are assumed from the pattern in the cells below.
    train_losses, val_losses, train_accs, val_accs = net_ft.train_epoch(
        prune_epoch, optimizer, plt_data)
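# Sanity check for the compounding schedule: each of the pruning_iter rounds
# keeps to_retain**(1/pruning_iter) of the surviving weights, so after all
# rounds the overall retained fraction is exactly to_retain. With
# to_retain = 0.2 and pruning_iter = 5, each round keeps 0.2**(1/5) ~= 0.725
# of what is left, and 0.725**5 ~= 0.2 of the original weights survive.
assert abs((0.2**(1 / 5))**5 - 0.2) < 1e-9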
# Retrain from late reset initialization
for prune_epoch in range(pruning_iter):
    pruner_path = "experiment_data/iterative-pruning-fmnist-resnet/{}-{}.p".format(
        EXPERIMENT_NAME, prune_epoch + 1)
    print('Retraining from late reset on NORB at pruning level {}...'.format(
        prune_epoch + 1))
    net_late_reset = Network(trainloader_tr, testloader_tr, device=device)
    net_late_reset = net_late_reset.to(device)
    net_late_reset.load_state_dict(
        torch.load(
            './checkpoints/iterative-pruning-fmnist-resnet/{}-trained'.format(
                EXPERIMENT_NAME)))

    # Reuse the masks saved at this pruning level.
    pruner = pickle.load(open(pruner_path, 'rb'))
    _masks = pruner.masks
    pruner_late_reset = SparsityPruner(net_late_reset)
    pruner_late_reset.masks = _masks

    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    pruner_late_reset.apply_mask(prune_global=True)
    print(net_late_reset.param_count())
    for epoch in range(N_EPOCH):
        print('Starting epoch {}'.format(epoch + 1))
        optimizer = optim.SGD(net_late_reset.parameters(),
                              lr=get_lr(epoch),
                              momentum=0.9,
                              weight_decay=1e-4)
        plt_data = (train_losses, val_losses, train_accs, val_accs)
        # The original cell is truncated here; the trailing arguments and the
        # early-stopping check are assumed.
        train_losses, val_losses, train_accs, val_accs, stop = net_late_reset.train_epoch(
            epoch, optimizer, plt_data)
        if stop:
            break
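# get_lr is defined elsewhere in the notebook. As an assumption only, a
# typical step-decay schedule matching these SGD settings could look like:
def get_lr_step_decay(epoch):  # hypothetical example, not the actual get_lr
    if epoch < 30:
        return 1e-1
    if epoch < 60:
        return 1e-2
    return 1e-3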
# Iterative pruning
to_retain_iter = 0.8

net_ft = Network(trainloader, testloader, device=device)
net_ft = net_ft.to(device)
net_ft.load_state_dict(
    torch.load('./checkpoints/iterative-pruning-cifar10/{}-trained'.format(
        EXPERIMENT_NAME)))
val_acc, _ = net_ft.test()
before_count = net_ft.param_count()
print('Before pruning: {}, params: {}'.format(val_acc, before_count))

pruner = SparsityPruner(net_ft)
optimizer = optim.SGD(net_ft.parameters(),
                      lr=1e-2,
                      momentum=0.9,
                      weight_decay=1e-4)
train_losses, val_losses, train_accs, val_accs = [], [], [], []
for prune_epoch in range(pruning_iter):
    plt_data = (train_losses, val_losses, train_accs, val_accs)
    pruner.prune(to_retain_iter, prune_global=True)
    after_count = net_ft.param_count()
    print('Starting pruning iteration {}, pct. remaining: {}'.format(
        prune_epoch + 1, after_count / before_count))
    # n_ft_epochs is defined elsewhere; the original cell is truncated below,
    # so the trailing arguments and early-stopping check are assumed.
    for ft_epoch in range(n_ft_epochs):
        train_losses, val_losses, train_accs, val_accs, stopped = net_ft.train_epoch(
            ft_epoch, optimizer, plt_data)
        if stopped:
            break
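# param_count() is reported before and after pruning throughout this section.
# Because pruning here zeroes weights via masks rather than removing them, an
# equivalent count of surviving parameters would tally nonzero entries, e.g.
# this hypothetical helper:
def count_nonzero_params(net):
    return int(sum((p != 0).sum().item() for p in net.parameters()))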