Example #1
def prune(self):
    self.pruner = ActivationMeanRankFilterPruner(self.model,
                                                 self.config_list,
                                                 self.optimizer)
    self.model = self.pruner.compress()
    top_acc = 0.9
    for epoch in range(self.prune_epochs):
        self.pruner.update_epoch(epoch)
        self._train_one_epoch(epoch, self.model, self.train_loader,
                              self.optimizer)
        acc = self.test(epoch)
        if acc > top_acc:
            top_acc = acc
            print("Exporting pruned model")
            self.pruner.export_model(
                model_path='results/pruned/pruned_model.pth',
                mask_path='results/pruned/pruned_mask.pth')
Example #2
def main():
    parser = argparse.ArgumentParser("multiple gpu with pruning")
    parser.add_argument("--epochs", type=int, default=160)
    parser.add_argument("--retrain", default=False, action="store_true")
    parser.add_argument("--parallel", default=False, action="store_true")

    args = parser.parse_args()
    torch.manual_seed(0)
    device = torch.device('cuda')
    train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        './data.cifar10',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])),
                                               batch_size=64,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        './data.cifar10',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])),
                                              batch_size=200,
                                              shuffle=False)

    model = VGG(depth=16)
    model.to(device)

    # Train the base VGG-16 model
    if args.retrain:
        print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=0.1,
                                    momentum=0.9,
                                    weight_decay=1e-4)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, 160, 0)
        for epoch in range(args.epochs):
            train(model, device, train_loader, optimizer)
            test(model, device, test_loader)
            lr_scheduler.step(epoch)
        torch.save(model.state_dict(), 'vgg16_cifar10.pth')

    # Test base model accuracy
    print('=' * 10 + 'Test on the original model' + '=' * 10)
    model.load_state_dict(torch.load('vgg16_cifar10.pth'))
    test(model, device, test_loader)
    # top1 = 93.51%

    # Pruning configuration, following the paper 'Pruning Filters for Efficient ConvNets':
    # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11 and Conv_12 are pruned with 50% sparsity ('VGG-16-pruned-A')
    configure_list = [{
        'sparsity': 0.5,
        'op_types': ['default'],
        'op_names': [
            'feature.0', 'feature.24', 'feature.27', 'feature.30',
            'feature.34', 'feature.37'
        ]
    }]

    # Prune model and test accuracy without fine tuning.
    print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
    pruner = ActivationMeanRankFilterPruner(model, configure_list)
    model = pruner.compress()
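    # compress() wraps the selected layers with pruning masks; tensor shapes stay unchanged until speed-up is applied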
    if args.parallel:
        if torch.cuda.device_count() > 1:
            print("use {} gpus for pruning".format(torch.cuda.device_count()))
            model = nn.DataParallel(model)
        else:
            print("only detect 1 gpu, fall back")

    model.to(device)
    test(model, device, test_loader)
    # top1 = 88.19%

    # Fine tune the pruned model for 40 epochs and test accuracy
    print('=' * 10 + 'Fine tuning' + '=' * 10)
    optimizer_finetune = torch.optim.SGD(model.parameters(),
                                         lr=0.001,
                                         momentum=0.9,
                                         weight_decay=1e-4)
    best_top1 = 0
    for epoch in range(40):
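        # notify the pruner of the current epoch (used by pruners with epoch-dependent schedules)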
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer_finetune)
        top1 = test(model, device, test_loader)
        if top1 > best_top1:
            best_top1 = top1
            # Export the best model, 'model_path' stores state_dict of the pruned model,
            # mask_path stores mask_dict of the pruned model
            pruner.export_model(model_path='pruned_vgg16_cifar10.pth',
                                mask_path='mask_vgg16_cifar10.pth')

    # Test the exported model
    print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
    new_model = VGG(depth=16)
    new_model.to(device)
    new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
    test(new_model, device, test_loader)
Example #3
def main(args):
    # prepare dataset
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, criterion = get_data(args)
    model, optimizer = get_trained_model_optimizer(args, device, train_loader,
                                                   val_loader, criterion)

    def short_term_fine_tuner(model, epochs=1):
        for epoch in range(epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch)

    def trainer(model, optimizer, criterion, epoch, callback):
        return train(args,
                     model,
                     device,
                     train_loader,
                     criterion,
                     optimizer,
                     epoch=epoch,
                     callback=callback)

    def evaluator(model):
        return test(model, device, criterion, val_loader)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance': {}}

    flops, params = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params

    evaluation_result = evaluator(model)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result

    # module types to prune, only "Conv2d" supported for channel pruning
    if args.base_algo in ['l1', 'l2']:
        op_types = ['Conv2d']
    elif args.base_algo == 'level':
        op_types = ['default']

    config_list = [{'sparsity': args.sparsity, 'op_types': op_types}]
    dummy_input = get_dummy_input(args, device)
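    # a dummy input is needed to trace the model graph for AutoCompressPruner and ModelSpeedup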

    if args.pruner == 'L1FilterPruner':
        pruner = L1FilterPruner(model, config_list)
    elif args.pruner == 'L2FilterPruner':
        pruner = L2FilterPruner(model, config_list)
    elif args.pruner == 'ActivationMeanRankFilterPruner':
        pruner = ActivationMeanRankFilterPruner(model, config_list)
    elif args.pruner == 'ActivationAPoZRankFilterPruner':
        pruner = ActivationAPoZRankFilterPruner(model, config_list)
    elif args.pruner == 'NetAdaptPruner':
        pruner = NetAdaptPruner(model,
                                config_list,
                                short_term_fine_tuner=short_term_fine_tuner,
                                evaluator=evaluator,
                                base_algo=args.base_algo,
                                experiment_data_dir=args.experiment_data_dir)
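        # NetAdapt prunes gradually, running short_term_fine_tuner after each step and keeping the candidate the evaluator scores best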
    elif args.pruner == 'ADMMPruner':
        # users are free to change the config here
        if args.model == 'LeNet':
            if args.base_algo in ['l1', 'l2']:
                config_list = [{
                    'sparsity': 0.8,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv2']
                }]
            elif args.base_algo == 'level':
                config_list = [{
                    'sparsity': 0.8,
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_names': ['conv2']
                }, {
                    'sparsity': 0.991,
                    'op_names': ['fc1']
                }, {
                    'sparsity': 0.93,
                    'op_names': ['fc2']
                }]
        else:
            raise ValueError('Example only implemented for LeNet.')
        pruner = ADMMPruner(model,
                            config_list,
                            trainer=trainer,
                            num_iterations=2,
                            training_epochs=2)
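        # ADMM alternates weight training (via trainer) and sparsity projection for num_iterations rounds of training_epochs each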
    elif args.pruner == 'SimulatedAnnealingPruner':
        pruner = SimulatedAnnealingPruner(
            model,
            config_list,
            evaluator=evaluator,
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            experiment_data_dir=args.experiment_data_dir)
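        # simulated annealing searches for per-layer sparsities that meet the overall target, guided by evaluator and cool_down_rate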
    elif args.pruner == 'AutoCompressPruner':
        pruner = AutoCompressPruner(
            model,
            config_list,
            trainer=trainer,
            evaluator=evaluator,
            dummy_input=dummy_input,
            num_iterations=3,
            optimize_mode='maximize',
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            admm_num_iterations=30,
            admm_training_epochs=5,
            experiment_data_dir=args.experiment_data_dir)
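        # AutoCompress combines simulated annealing with ADMM in every iteration and runs speed-up internally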
    else:
        raise ValueError("Pruner not supported.")

    # Pruner.compress() returns the masked model
    # but for AutoCompressPruner, Pruner.compress() returns directly the pruned model
    model = pruner.compress()
    evaluation_result = evaluator(model)
    print('Evaluation result (masked model): %s' % evaluation_result)
    result['performance']['pruned'] = evaluation_result

    if args.save_model:
        pruner.export_model(
            os.path.join(args.experiment_data_dir, 'model_masked.pth'),
            os.path.join(args.experiment_data_dir, 'mask.pth'))
        print('Masked model saved to %s' % args.experiment_data_dir)

    # model speed up
    if args.speed_up:
        if args.pruner != 'AutoCompressPruner':
            if args.model == 'LeNet':
                model = LeNet().to(device)
            elif args.model == 'vgg16':
                model = VGG(depth=16).to(device)
            elif args.model == 'resnet18':
                model = ResNet18().to(device)
            elif args.model == 'resnet50':
                model = ResNet50().to(device)
            elif args.model == 'mobilenet_v2':
                model = models.mobilenet_v2(pretrained=False).to(device)

            model.load_state_dict(
                torch.load(
                    os.path.join(args.experiment_data_dir,
                                 'model_masked.pth')))
            masks_file = os.path.join(args.experiment_data_dir, 'mask.pth')

            m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
            m_speedup.speedup_model()
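            # speedup_model() replaces the masked layers with physically smaller ones, so the pruned model actually runs faster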
            evaluation_result = evaluator(model)
            print('Evaluation result (speed up model): %s' % evaluation_result)
            result['performance']['speedup'] = evaluation_result

            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
            print('Speed-up model saved to %s' % args.experiment_data_dir)
        flops, params = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['speedup'] = flops
        result['params']['speedup'] = params

    if args.fine_tune:
        if args.dataset == 'mnist':
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
        elif args.dataset == 'cifar10' and args.model == 'vgg16':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.01,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.5),
                                        int(args.fine_tune_epochs * 0.75)
                                    ],
                                    gamma=0.1)
        elif args.dataset == 'cifar10' and args.model == 'resnet18':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.1,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.5),
                                        int(args.fine_tune_epochs * 0.75)
                                    ],
                                    gamma=0.1)
        elif args.dataset == 'cifar10' and args.model == 'resnet50':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.1,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.5),
                                        int(args.fine_tune_epochs * 0.75)
                                    ],
                                    gamma=0.1)
        best_acc = 0
        for epoch in range(args.fine_tune_epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch)
            scheduler.step()
            acc = evaluator(model)
            if acc > best_acc:
                best_acc = acc
                torch.save(
                    model.state_dict(),
                    os.path.join(args.experiment_data_dir,
                                 'model_fine_tuned.pth'))

        print('Evaluation result (fine tuned): %s' % best_acc)
        print('Fine-tuned model saved to %s' % args.experiment_data_dir)
        result['performance']['finetuned'] = best_acc

    with open(os.path.join(args.experiment_data_dir, 'result.json'),
              'w+') as f:
        json.dump(result, f)
Example #4
optimizer = torch.optim.SGD(model.parameters(),
                            lr=0.01,
                            momentum=0.9,
                            weight_decay=1e-4)
print("start model training...")
for epoch in range(pretrain_epochs):
    train(model, device, train_data_loader, optimizer)
    test(model, device, test_data_loader)
torch.save(model.state_dict(), 'pretrained_model.pth')
print("start model pruning...")
optimizer = torch.optim.SGD(model.parameters(),
                            lr=0.001,
                            momentum=0.9,
                            weight_decay=1e-4)
best_top1 = 0
# pruner = SlimPruner(model, config_list, optimizer)
pruner = ActivationMeanRankFilterPruner(model, config_list, optimizer)
model = pruner.compress()
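# the model is now wrapped with pruning masks; the loop below fine-tunes it and exports the best checkpoint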

for epoch in range(prune_epochs):
    pruner.update_epoch(epoch)
    print("# Epoch {} #".format(epoch))
    train(model, device, train_data_loader, optimizer)
    top1 = test(model, device, test_data_loader)
    if top1 > best_top1:
        best_top1 = top1
        pruner.export_model(model_path='pruned_model.pth',
                            mask_path='pruned_mask.pth')

# re-apply the exported masks to a fresh model instance before speeding it up
from nni.compression.torch import apply_compression_results
from nni.compression.speedup.torch import ModelSpeedup
model = MobileModel().cuda()
model.eval()
apply_compression_results(model, 'pruned_mask.pth', None)
Example #5
class Trainer:
    def __init__(self, train_loader, test_loader, model, optimizer, scheduler,
                 epochs, prune_epochs, device, save_path):
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.model = model.to(device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.epochs = epochs
        self.prune_epochs = prune_epochs
        self.device = device
        self.save_path = save_path

        self.loss = CrossEntropyLoss()
        # pruning configuration: remove 10% of the filters in every Conv2d layer
        self.config_list = [{'sparsity': 0.1, 'op_types': ['Conv2d']}]

    def train(self):
        self.model.train()
        for epoch in range(self.epochs):
            for batch_idx, (inputs, labels) in enumerate(self.train_loader):
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                outputs = self.model(inputs)
                _, pred_labels = torch.max(outputs, 1)

                loss = self.loss(outputs, labels)
                acc = torch.sum(pred_labels == labels.data) / float(
                    len(labels))

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                if batch_idx % 100 == 0:
                    print(
                        "Train Epoch: {:03} [{:05}/{:05} ({:03.0f}%) \t Loss:{:.6f} Acc:{:.6f} LR: {:.6f}"
                        .format(epoch, batch_idx * len(inputs),
                                len(self.train_loader.dataset),
                                100. * batch_idx / len(self.train_loader),
                                loss.item(), acc,
                                self.optimizer.param_groups[0]['lr']))
            self.scheduler.step()
            torch.save(
                self.model.state_dict(),
                os.path.join(
                    self.save_path, '{}_mobilenetv2_epoch_{}.pth'.format(
                        config.attribute, epoch)))
            self.test(epoch)

        self.prune()
        # apply_compression_results(self.model, 'results/pruned/pruned_mask.pth', None)
        # speedup_model = ModelSpeedup(self.model, torch.randn(1, 3, 224, 224).cuda(),
        #                              'results/pruned/pruned_mask.pth', None)
        # speedup_model.speedup_model()
        # torch.save(self.model.state_dict(), 'pruned_speedup_model.pth')
    def prune(self):
        self.pruner = ActivationMeanRankFilterPruner(self.model,
                                                     self.config_list,
                                                     self.optimizer)
        self.model = self.pruner.compress()
        top_acc = 0.9
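        # checkpoints are exported only when test accuracy beats this threshold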
        for epoch in range(self.prune_epochs):
            self.pruner.update_epoch(epoch)
            self._train_one_epoch(epoch, self.model, self.train_loader,
                                  self.optimizer)
            acc = self.test(epoch)
            if acc > top_acc:
                top_acc = acc
                print("Begining prune model")
                self.pruner.export_model(
                    model_path='results/pruned/pruned_model.pth',
                    mask_path='results/pruned/pruned_mask.pth')

    def _train_one_epoch(self, epoch, model, train_loader, optimizer):
        model.train()
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            optimizer.zero_grad()
            outputs = model(inputs)
            _, pred_labels = torch.max(outputs, 1)

            loss = self.loss(outputs, labels)
            acc = torch.sum(pred_labels == labels.data) / float(len(labels))

            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(
                    "Train Epoch: {:03} [{:05}/{:05} ({:03.0f}%) \t Loss:{:.6f} Acc:{:.6f} LR: {:.6f}"
                    .format(epoch, batch_idx * len(inputs),
                            len(self.train_loader.dataset),
                            100. * batch_idx / len(self.train_loader),
                            loss.item(), acc,
                            self.optimizer.param_groups[0]['lr']))

    def test(self, epoch):
        self.model.eval()
        with torch.no_grad():
            total_acc = 0
            total_sample = 0
            for batch_idx, (inputs, labels) in enumerate(self.test_loader):
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                outputs = self.model(inputs)
                _, pred_labels = torch.max(outputs, 1)

                acc = torch.sum(pred_labels == labels.data)
                total_acc += acc
                total_sample += len(inputs)
            acc = float(total_acc) / total_sample

        print("Test Acc:", acc)
        return acc
Example #6
def train_net(net,
              device,
              epochs=100,
              batch_size=1,
              lr=0.1,
              val_percent=0.2,
              save_cp=True,
              img_scale=1):

    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True, drop_last=True)

    gene_eval_data(val_loader, dir='./data/val/')

    writer = SummaryWriter(comment='LR_{}_BS_{}_SCALE_{}'.format(lr,batch_size,img_scale))
    global_step = 0

    logging.info('''Starting training:
        Epochs:          {}
        Batch size:      {}
        Learning rate:   {}
        Training size:   {}
        Validation size: {}
        Checkpoints:     {}
        Device:          {}
        Images scaling:  {}
    '''.format(epochs,batch_size,lr,n_train,n_val,save_cp,device.type,img_scale))
    net.to(device)
    # optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max',factor=0.5, patience=20)

    criterion = dice_loss
    # criterion = nn.BCELoss()

    last_loss = 9999
    last_val_score = 0
    config_list = [{
        'sparsity': 0.5,
        'op_types': ['Conv2d']
    }]

    pruner = ActivationMeanRankFilterPruner(net, config_list, optimizer=optimizer)
    pruner.compress()
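    # attach pruning masks targeting 50% of the filters in every Conv2d layer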

    for epoch in range(epochs):
        net.train()
        pruner.update_epoch(epoch)
        epoch_loss = 0
        step = 0
        with tqdm(total=n_train, desc='Epoch {}/{}'.format(epoch + 1,epochs), unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    'Network has been defined with {} input channels, '.format(net.n_channels) + \
                    'but loaded images have {} channels. Please check that '.format(imgs.shape[1]) + \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                optimizer.zero_grad()
                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                global_step += 1
                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})
                optimizer.step()
                # del imgs
                pbar.update(imgs.shape[0])


# if global_step % (len(dataset) // ( 2* batch_size)) == 0:
        for tag, value in net.named_parameters():
            tag = tag.replace('.', '/')
            writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
            writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
        val_score = eval_net(net, val_loader, device)
        scheduler.step(val_score)
        writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

        if net.n_classes > 1:
            logging.info('Validation cross entropy: {}'.format(val_score))
            writer.add_scalar('Loss/test', val_score, global_step)
        else:
            logging.info('Train Loss: {}    Validation Dice Coeff: {} '.format(epoch_loss/n_train , val_score))
            writer.add_scalar('Dice/test', val_score, global_step)

            writer.add_images('images', imgs, global_step)
            if net.n_classes == 1:
                writer.add_images('masks/true', true_masks, global_step)
                writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.3 , global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            if last_loss > epoch_loss or last_val_score < val_score:
                last_loss = min(last_loss, epoch_loss)
                last_val_score = max(last_val_score, val_score)
                # torch.save(net.state_dict(),
                pruner.export_model("./save_pruner/pruned_model.pt","./save_pruner/pruned_mask.pt")
                print('Checkpoint {} saved !'.format(epoch + 1)+'   CP_epoch{}Trainloss{}ValDice{}.pt'.format(epoch + 1,epoch_loss/n_train, val_score))

    writer.close()