Example #1
    def __init__(self, args, model, previous_masks, dataset2idx, dataset2biases):
        self.args = args
        self.cuda = args.cuda
        self.model = model
        self.dataset2idx = dataset2idx
        self.dataset2biases = dataset2biases

        if args.mode != 'check':
            # Set up data loader, criterion, and pruner.
            if "CIFAR100":
                self.train_data_loader, self.test_data_loader = dataset.CIFAR_loader(args.train_path, args.batch_size)

            if 'cropped' in args.train_path:
                train_loader = dataset.train_loader_cropped
                test_loader = dataset.test_loader_cropped
            else:
                train_loader = dataset.train_loader
                test_loader = dataset.test_loader
            self.train_data_loader = train_loader(
                args.train_path, args.batch_size, pin_memory=args.cuda)
            self.test_data_loader = test_loader(
                args.test_path, args.batch_size, pin_memory=args.cuda)
            self.criterion = nn.CrossEntropyLoss()

            self.pruner = SparsePruner(
                self.model, self.args.prune_perc_per_layer, previous_masks,
                self.args.train_biases, self.args.train_bn)
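
The `SparsePruner` constructed above drives the PackNet-style workflow. As a rough sketch of what layer-wise magnitude pruning amounts to (an illustration under assumptions, not the repo's actual `SparsePruner` internals), each layer zeroes its lowest-magnitude weights that are still owned by the current task:

import torch

def magnitude_prune_layer(weight, mask, current_idx, prune_perc):
    """Zero the smallest-magnitude weights owned by the current task.

    Hedged sketch only; `mask` is assumed to store, per weight, the index
    of the task that owns it, with 0 marking an already-pruned weight.
    """
    current = mask.eq(current_idx)              # weights owned by this task
    magnitudes = weight[current].abs()
    if magnitudes.numel() == 0:
        return
    cutoff_rank = max(round(prune_perc * magnitudes.numel()), 1)
    cutoff = magnitudes.kthvalue(cutoff_rank)[0]
    remove = current & weight.abs().le(cutoff)  # low-magnitude, current-task
    mask[remove] = 0                            # mark as pruned
    weight[remove] = 0.0                        # zero the weights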
Example #2
    def __init__(self, args, model, previous_masks, dataset2idx, dataset2biases, soft_labels=False, prune_per=None):
        self.args = args
        self.cuda = args.cuda
        self.model = model
        self.dataset2idx = dataset2idx
        self.dataset2biases = dataset2biases
        self.pruning_record = {}
        self.soft_labels = soft_labels

        if args.mode != 'check':
            # Set up data loader, criterion, and pruner.
            self.train_data_loader, self.test_data_loader = dataset.CIFAR_loader(args)

            self.criterion = nn.CrossEntropyLoss()
            if prune_per is not None:
                # Use a dynamic, caller-supplied pruning ratio.
                self.pruner = SparsePruner(
                    self.model, prune_per, previous_masks,
                    self.args.train_biases, self.args.train_bn)
            else:
                self.pruner = SparsePruner(
                    self.model, self.args.prune_perc_per_layer, previous_masks,
                    self.args.train_biases, self.args.train_bn)
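
For illustration, a minimal sketch of how the dynamic ratio might be passed in, assuming the enclosing class is the `Manager` of Example #3; the 0.5 heuristic is an assumption, not the repo's logic:

ratio = 0.5 if 'CIFAR100' in args.train_path else None  # assumed heuristic
manager = Manager(args, model, previous_masks, dataset2idx, dataset2biases,
                  soft_labels=True, prune_per=ratio)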
Example #3
import json

import torch
import torch.nn as nn
import torch.optim as optim
import torchnet as tnt
from tqdm import tqdm

import dataset
import utils
from prune import SparsePruner  # assumed module path for SparsePruner

class Manager(object):
    """Handles training and pruning."""

    def __init__(self, args, model, previous_masks, dataset2idx, dataset2biases):
        self.args = args
        self.cuda = args.cuda
        self.model = model
        self.dataset2idx = dataset2idx
        self.dataset2biases = dataset2biases

        if args.mode != 'check':
            # Set up data loader, criterion, and pruner.
            if "CIFAR100":
                self.train_data_loader, self.test_data_loader = dataset.CIFAR_loader(args.train_path, args.batch_size)

            if 'cropped' in args.train_path:
                train_loader = dataset.train_loader_cropped
                test_loader = dataset.test_loader_cropped
            else:
                train_loader = dataset.train_loader
                test_loader = dataset.test_loader
            self.train_data_loader = train_loader(
                args.train_path, args.batch_size, pin_memory=args.cuda)
            self.test_data_loader = test_loader(
                args.test_path, args.batch_size, pin_memory=args.cuda)
            self.criterion = nn.CrossEntropyLoss()

            self.pruner = SparsePruner(
                self.model, self.args.prune_perc_per_layer, previous_masks,
                self.args.train_biases, self.args.train_bn)

    def eval(self, dataset_idx, biases=None):
        """Performs evaluation."""
        if not self.args.disable_pruning_mask:
            self.pruner.apply_mask(dataset_idx)
        if biases is not None:
            self.pruner.restore_biases(biases)

        self.model.eval()
        error_meter = None

        print('Performing eval...')
        with torch.no_grad():  # inference only: no gradient tracking needed
            for batch, label in tqdm(self.test_data_loader, desc='Eval'):
                if self.cuda:
                    batch = batch.cuda()

                output = self.model(batch)

                # Init error meter.
                if error_meter is None:
                    topk = [1]
                    if output.size(1) > 5:
                        topk.append(5)
                    error_meter = tnt.meter.ClassErrorMeter(topk=topk)
                error_meter.add(output.data, label)

        errors = error_meter.value()
        print('Error: ' + ', '.join('@%s=%.2f' %
                                    t for t in zip(topk, errors)))
        if self.args.train_bn:
            self.model.train()
        else:
            self.model.train_nobn()
        return errors

    def do_batch(self, optimizer, batch, label):
        """Runs model for one batch."""
        if self.cuda:
            batch = batch.cuda()
            label = label.cuda()

        # Set grads to 0.
        self.model.zero_grad()

        # Do forward-backward.
        output = self.model(batch)
        self.criterion(output, label).backward()

        # Set fixed param grads to 0.
        if not self.args.disable_pruning_mask:
            self.pruner.make_grads_zero()

        # Update params.
        optimizer.step()

        # Set pruned weights to 0.
        if not self.args.disable_pruning_mask:
            self.pruner.make_pruned_zero()

    def do_epoch(self, epoch_idx, optimizer):
        """Trains model for one epoch."""
        for batch, label in tqdm(self.train_data_loader, desc='Epoch: %d ' % (epoch_idx)):
            self.do_batch(optimizer, batch, label)

    def save_model(self, epoch, best_accuracy, errors, savename):
        """Saves model to file."""
        base_model = self.model

        # Prepare the ckpt.
        self.dataset2idx[self.args.dataset] = self.pruner.current_dataset_idx
        self.dataset2biases[self.args.dataset] = self.pruner.get_biases()
        ckpt = {
            'args': self.args,
            'epoch': epoch,
            'accuracy': best_accuracy,
            'errors': errors,
            'dataset2idx': self.dataset2idx,
            'previous_masks': self.pruner.current_masks,
            'model': base_model,
        }
        if self.args.train_biases:
            ckpt['dataset2biases'] = self.dataset2biases

        # Save to file.
        torch.save(ckpt, savename + '.pt')

    def train(self, epochs, optimizer, save=True, savename='', best_accuracy=0):
        """Performs training."""
        error_history = []

        if self.args.cuda:
            self.model = self.model.cuda()

        for idx in range(epochs):
            epoch_idx = idx + 1
            print('Epoch: %d' % (epoch_idx))

            optimizer = utils.step_lr(epoch_idx, self.args.lr, self.args.lr_decay_every,
                                      self.args.lr_decay_factor, optimizer)
            if self.args.train_bn:
                self.model.train()
            else:
                self.model.train_nobn()
            self.do_epoch(epoch_idx, optimizer)
            errors = self.eval(self.pruner.current_dataset_idx)
            error_history.append(errors)
            accuracy = 100 - errors[0]  # Top-1 accuracy.

            # Save performance history and stats.
            with open(savename + '.json', 'w') as fout:
                json.dump({
                    'error_history': error_history,
                    'args': vars(self.args),
                }, fout)

            # Save best model, if required.
            if save and accuracy > best_accuracy:
                print('Best model so far, Accuracy: %0.2f%% -> %0.2f%%' %
                      (best_accuracy, accuracy))
                best_accuracy = accuracy
                self.save_model(epoch_idx, best_accuracy, errors, savename)

        print('Finished finetuning...')
        print('Best error/accuracy: %0.2f%%, %0.2f%%' %
              (100 - best_accuracy, best_accuracy))
        print('-' * 16)

    def prune(self):
        """Perform pruning."""
        print('Pre-prune eval:')
        self.eval(self.pruner.current_dataset_idx)

        self.pruner.prune()
        self.check(True)

        print('\nPost-prune eval:')
        errors = self.eval(self.pruner.current_dataset_idx)
        accuracy = 100 - errors[0]  # Top-1 accuracy.
        self.save_model(-1, accuracy, errors,
                        self.args.save_prefix + '_postprune')

        # Do final finetuning to improve results on pruned network.
        if self.args.post_prune_epochs:
            print('Doing some extra finetuning...')
            optimizer = optim.SGD(self.model.parameters(),
                                  lr=self.args.lr, momentum=0.9,
                                  weight_decay=self.args.weight_decay)
            self.train(self.args.post_prune_epochs, optimizer, save=True,
                       savename=self.args.save_prefix + '_final', best_accuracy=accuracy)

        print('-' * 16)
        print('Pruning summary:')
        self.check(True)
        print('-' * 16)

    def check(self, verbose=False):
        """Makes sure that the layers are pruned."""
        print('Checking...')
        for layer_idx, module in enumerate(self.model.shared.modules()):
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                weight = module.weight.data
                num_params = weight.numel()
                num_zero = weight.view(-1).eq(0).sum().item()
                if verbose:
                    print('Layer #%d: Pruned %d/%d (%.2f%%)' %
                          (layer_idx, num_zero, num_params, 100 * num_zero / num_params))
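
Putting the pieces together, a minimal driver sketch. The checkpoint layout, the `finetune_epochs` flag, and the 'checkpoint.pt' path are assumptions for illustration; the remaining flags mirror attributes the class actually reads:

ckpt = torch.load('checkpoint.pt')  # placeholder path; assumed checkpoint layout
manager = Manager(args, ckpt['model'], ckpt['previous_masks'],
                  ckpt['dataset2idx'], ckpt.get('dataset2biases', {}))

optimizer = optim.SGD(manager.model.parameters(), lr=args.lr,
                      momentum=0.9, weight_decay=args.weight_decay)
manager.train(args.finetune_epochs, optimizer,  # assumed flag name
              savename=args.save_prefix + '_finetuned')
manager.prune()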