"""Handles training and pruning of a shared network across sequential tasks."""
import json

import torch
import torch.nn as nn
import torch.optim as optim
import torchnet as tnt
from torch.autograd import Variable
from tqdm import tqdm

# Project-local modules (assumed layout: dataset loaders, LR utilities, and
# the SparsePruner live alongside this file).
import dataset
import utils
from prune import SparsePruner
# Variant Manager constructor: adds soft-label support, a per-run pruning
# record, and an optional dynamic pruning ratio. The commented-out
# cropped/uncropped path-based loaders from earlier revisions were superseded
# by dataset.CIFAR_loader and are omitted here.
def __init__(self, args, model, previous_masks, dataset2idx, dataset2biases,
             soft_labels=False, prune_per=None):
    self.args = args
    self.cuda = args.cuda
    self.model = model
    self.dataset2idx = dataset2idx
    self.dataset2biases = dataset2biases
    self.pruning_record = {}
    self.soft_labels = soft_labels

    if args.mode != 'check':
        # Set up data loader, criterion, and pruner.
        self.train_data_loader, self.test_data_loader = dataset.CIFAR_loader(args)
        self.criterion = nn.CrossEntropyLoss()

        # Use the dynamic pruning ratio when one is given, else fall back to
        # the static per-layer ratio from the command-line args.
        prune_ratio = prune_per if prune_per else self.args.prune_perc_per_layer
        self.pruner = SparsePruner(
            self.model, prune_ratio, previous_masks,
            self.args.train_biases, self.args.train_bn)
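# Hedged sketch (not in the original repo): one way a caller might derive the
# dynamic `prune_per` ratio consumed above. The name `linear_prune_schedule`
# and the linear decay are illustrative assumptions; any schedule that leaves
# capacity for later tasks would serve the same purpose.
def linear_prune_schedule(task_idx, total_tasks, base_ratio=0.75):
    """Prunes a smaller fraction for later tasks, since earlier tasks have
    already claimed most of the network's free capacity."""
    remaining = total_tasks - task_idx
    return base_ratio * remaining / total_tasks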
class Manager(object):
    """Handles training and pruning."""

    def __init__(self, args, model, previous_masks, dataset2idx, dataset2biases):
        self.args = args
        self.cuda = args.cuda
        self.model = model
        self.dataset2idx = dataset2idx
        self.dataset2biases = dataset2biases

        if args.mode != 'check':
            # Set up data loader, criterion, and pruner.
            # Assumes a CIFAR run is identified by its name in the train path;
            # the original bare `if "CIFAR100":` was always true.
            if 'CIFAR100' in args.train_path:
                self.train_data_loader, self.test_data_loader = \
                    dataset.CIFAR_loader(args.train_path, args.batch_size)
            else:
                if 'cropped' in args.train_path:
                    train_loader = dataset.train_loader_cropped
                    test_loader = dataset.test_loader_cropped
                else:
                    train_loader = dataset.train_loader
                    test_loader = dataset.test_loader
                self.train_data_loader = train_loader(
                    args.train_path, args.batch_size, pin_memory=args.cuda)
                self.test_data_loader = test_loader(
                    args.test_path, args.batch_size, pin_memory=args.cuda)
            self.criterion = nn.CrossEntropyLoss()
            self.pruner = SparsePruner(
                self.model, self.args.prune_perc_per_layer, previous_masks,
                self.args.train_biases, self.args.train_bn)

    def eval(self, dataset_idx, biases=None):
        """Performs evaluation."""
        if not self.args.disable_pruning_mask:
            self.pruner.apply_mask(dataset_idx)
        if biases is not None:
            self.pruner.restore_biases(biases)

        self.model.eval()
        error_meter = None

        print('Performing eval...')
        for batch, label in tqdm(self.test_data_loader, desc='Eval'):
            if self.cuda:
                batch = batch.cuda()
            # volatile=True is the legacy (pre-0.4) PyTorch way to disable
            # autograd during evaluation.
            batch = Variable(batch, volatile=True)

            output = self.model(batch)

            # Init error meter.
            if error_meter is None:
                topk = [1]
                if output.size(1) > 5:
                    topk.append(5)
                error_meter = tnt.meter.ClassErrorMeter(topk=topk)
            error_meter.add(output.data, label)

        errors = error_meter.value()
        print('Error: ' + ', '.join('@%s=%.2f' %
                                    t for t in zip(topk, errors)))
        if self.args.train_bn:
            self.model.train()
        else:
            self.model.train_nobn()
        return errors

    def do_batch(self, optimizer, batch, label):
        """Runs model for one batch."""
        if self.cuda:
            batch = batch.cuda()
            label = label.cuda()
        batch = Variable(batch)
        label = Variable(label)

        # Set grads to 0.
        self.model.zero_grad()

        # Do forward-backward.
        output = self.model(batch)
        self.criterion(output, label).backward()

        # Set fixed param grads to 0.
        if not self.args.disable_pruning_mask:
            self.pruner.make_grads_zero()

        # Update params.
        optimizer.step()

        # Set pruned weights to 0.
        if not self.args.disable_pruning_mask:
            self.pruner.make_pruned_zero()

    def do_epoch(self, epoch_idx, optimizer):
        """Trains model for one epoch."""
        for batch, label in tqdm(self.train_data_loader,
                                 desc='Epoch: %d ' % (epoch_idx)):
            self.do_batch(optimizer, batch, label)

    def save_model(self, epoch, best_accuracy, errors, savename):
        """Saves model to file."""
        base_model = self.model

        # Prepare the ckpt.
        self.dataset2idx[self.args.dataset] = self.pruner.current_dataset_idx
        self.dataset2biases[self.args.dataset] = self.pruner.get_biases()
        ckpt = {
            'args': self.args,
            'epoch': epoch,
            'accuracy': best_accuracy,
            'errors': errors,
            'dataset2idx': self.dataset2idx,
            'previous_masks': self.pruner.current_masks,
            'model': base_model,
        }
        if self.args.train_biases:
            ckpt['dataset2biases'] = self.dataset2biases

        # Save to file.
        torch.save(ckpt, savename + '.pt')

    def train(self, epochs, optimizer, save=True, savename='', best_accuracy=0):
        """Performs training."""
        error_history = []

        if self.args.cuda:
            self.model = self.model.cuda()

        for idx in range(epochs):
            epoch_idx = idx + 1
            print('Epoch: %d' % (epoch_idx))

            optimizer = utils.step_lr(epoch_idx, self.args.lr,
                                      self.args.lr_decay_every,
                                      self.args.lr_decay_factor, optimizer)
            if self.args.train_bn:
                self.model.train()
            else:
                self.model.train_nobn()
            self.do_epoch(epoch_idx, optimizer)
            errors = self.eval(self.pruner.current_dataset_idx)
            error_history.append(errors)
            accuracy = 100 - errors[0]  # Top-1 accuracy.

            # Save performance history and stats.
            with open(savename + '.json', 'w') as fout:
                json.dump({
                    'error_history': error_history,
                    'args': vars(self.args),
                }, fout)

            # Save best model, if required.
            if save and accuracy > best_accuracy:
                print('Best model so far, Accuracy: %0.2f%% -> %0.2f%%' %
                      (best_accuracy, accuracy))
                best_accuracy = accuracy
                self.save_model(epoch_idx, best_accuracy, errors, savename)

        print('Finished finetuning...')
        print('Best error/accuracy: %0.2f%%, %0.2f%%' %
              (100 - best_accuracy, best_accuracy))
        print('-' * 16)

    def prune(self):
        """Perform pruning."""
        print('Pre-prune eval:')
        self.eval(self.pruner.current_dataset_idx)

        self.pruner.prune()
        self.check(True)

        print('\nPost-prune eval:')
        errors = self.eval(self.pruner.current_dataset_idx)

        accuracy = 100 - errors[0]  # Top-1 accuracy.
        self.save_model(-1, accuracy, errors,
                        self.args.save_prefix + '_postprune')

        # Do final finetuning to improve results on pruned network.
        if self.args.post_prune_epochs:
            print('Doing some extra finetuning...')
            optimizer = optim.SGD(self.model.parameters(), lr=self.args.lr,
                                  momentum=0.9,
                                  weight_decay=self.args.weight_decay)
            self.train(self.args.post_prune_epochs, optimizer, save=True,
                       savename=self.args.save_prefix + '_final',
                       best_accuracy=accuracy)

        print('-' * 16)
        print('Pruning summary:')
        self.check(True)
        print('-' * 16)

    def check(self, verbose=False):
        """Makes sure that the layers are pruned."""
        print('Checking...')
        for layer_idx, module in enumerate(self.model.shared.modules()):
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                weight = module.weight.data
                num_params = weight.numel()
                # .item() turns the 0-dim count tensor into a plain int so the
                # %d format and the percentage division behave as intended.
                num_zero = weight.view(-1).eq(0).sum().item()
                if verbose:
                    print('Layer #%d: Pruned %d/%d (%.2f%%)' %
                          (layer_idx, num_zero, num_params,
                           100 * num_zero / num_params))
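def _example_task_cycle(args, model, previous_masks, dataset2idx, dataset2biases):
    """Hedged usage sketch, not part of the original file: one PackNet-style
    finetune -> prune cycle driven through the Manager above. The name
    `args.finetune_epochs` and the SGD hyperparameters are assumptions about
    the surrounding training script."""
    manager = Manager(args, model, previous_masks, dataset2idx, dataset2biases)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                          weight_decay=args.weight_decay)
    # Finetune on the new task, checkpointing the best model by top-1 accuracy.
    manager.train(args.finetune_epochs, optimizer, save=True,
                  savename=args.save_prefix)
    # Prune to free weights for future tasks; prune() also re-finetunes the
    # surviving weights when args.post_prune_epochs is set.
    manager.prune()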