Пример #1
0
def get_trained_model_optimizer(args, device, train_loader, val_loader, criterion):
    """Build the model selected by ``args.model`` together with its optimizer.

    When ``args.load_pretrained_model`` is set, the weights are loaded from
    ``args.pretrained_model_dir`` and a small-learning-rate optimizer is
    created. Otherwise the model is trained from scratch for
    ``args.pretrain_epochs`` epochs, the weights of the best-scoring epoch
    (according to ``test`` on ``val_loader``) are restored into the model,
    and, if ``args.save_model`` is set, written to
    ``<args.experiment_data_dir>/model_trained.pth``.

    Args:
        args: parsed command-line namespace; reads ``model``,
            ``load_pretrained_model``, ``pretrained_model_dir``,
            ``pretrain_epochs``, ``save_model`` and ``experiment_data_dir``.
        device: torch device the model is moved to.
        train_loader: data loader passed to ``train``.
        val_loader: data loader passed to ``test``.
        criterion: loss passed to ``train`` and ``test``.

    Returns:
        A ``(model, optimizer)`` tuple. The learning-rate scheduler is only
        used internally during pre-training and is not returned.

    Raises:
        ValueError: if ``args.model`` is not one of the supported names.
    """
    # Local import: only needed to snapshot the best weights below.
    import copy

    if args.model == 'LeNet':
        model = LeNet().to(device)
        if args.load_pretrained_model:
            model.load_state_dict(torch.load(args.pretrained_model_dir))
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1e-4)
        else:
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    elif args.model == 'vgg16':
        model = VGG(depth=16).to(device)
        if args.load_pretrained_model:
            model.load_state_dict(torch.load(args.pretrained_model_dir))
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)
        else:
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(
                optimizer, milestones=[int(args.pretrain_epochs*0.5), int(args.pretrain_epochs*0.75)], gamma=0.1)
    elif args.model == 'resnet18':
        model = ResNet18().to(device)
        if args.load_pretrained_model:
            model.load_state_dict(torch.load(args.pretrained_model_dir))
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)
        else:
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(
                optimizer, milestones=[int(args.pretrain_epochs*0.5), int(args.pretrain_epochs*0.75)], gamma=0.1)
    elif args.model == 'resnet50':
        model = ResNet50().to(device)
        if args.load_pretrained_model:
            model.load_state_dict(torch.load(args.pretrained_model_dir))
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)
        else:
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
            scheduler = MultiStepLR(
                optimizer, milestones=[int(args.pretrain_epochs*0.5), int(args.pretrain_epochs*0.75)], gamma=0.1)
    else:
        raise ValueError("model not recognized")

    if not args.load_pretrained_model:
        best_acc = 0
        best_epoch = 0
        best_state_dict = None
        for epoch in range(args.pretrain_epochs):
            train(args, model, device, train_loader, criterion, optimizer, epoch)
            scheduler.step()
            acc = test(model, device, criterion, val_loader)
            if acc > best_acc:
                best_acc = acc
                best_epoch = epoch
                # state_dict() returns references to the live parameter
                # tensors; deep-copy so later epochs cannot overwrite the
                # snapshot of the best weights.
                best_state_dict = copy.deepcopy(model.state_dict())
        if best_state_dict is None:
            # No epoch ran or none scored above 0: keep the current weights
            # instead of failing on an unbound snapshot.
            best_state_dict = copy.deepcopy(model.state_dict())
        model.load_state_dict(best_state_dict)
        print('Best acc:', best_acc)
        print('Best epoch:', best_epoch)

        if args.save_model:
            torch.save(best_state_dict, os.path.join(args.experiment_data_dir, 'model_trained.pth'))
            print('Model trained saved to %s' % args.experiment_data_dir)

    return model, optimizer
def main(args):
    """Run the pruning example end to end and record the results.

    Steps, in order: obtain a trained model, evaluate it, prune it with the
    pruner named by ``args.pruner``, optionally export the masked model,
    optionally speed it up, optionally fine-tune it, and finally dump the
    collected FLOPs / parameter counts / evaluation results to
    ``<args.experiment_data_dir>/result.json``.

    Args:
        args: parsed command-line namespace. Reads ``dataset``, ``model``,
            ``pruner``, ``base_algo``, ``sparsity``, ``cool_down_rate``,
            ``save_model``, ``speed_up``, ``fine_tune``,
            ``fine_tune_epochs`` and ``experiment_data_dir`` (plus whatever
            the helper functions read).

    Raises:
        ValueError: for an unsupported ``base_algo`` or ``pruner``, for
            ``ADMMPruner`` with a model other than LeNet, or when
            fine-tuning is requested for a dataset/model pair that has no
            optimizer configuration.
    """
    # prepare dataset
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, criterion = get_data(args)
    model, optimizer = get_trained_model_optimizer(args, device, train_loader,
                                                   val_loader, criterion)

    def short_term_fine_tuner(model, epochs=1):
        # Uses the enclosing ``optimizer`` (looked up at call time).
        for epoch in range(epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch)

    def trainer(model, optimizer, criterion, epoch, callback):
        return train(args,
                     model,
                     device,
                     train_loader,
                     criterion,
                     optimizer,
                     epoch=epoch,
                     callback=callback)

    def evaluator(model):
        return test(model, device, criterion, val_loader)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance': {}}

    flops, params = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params

    evaluation_result = evaluator(model)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result

    # module types to prune, only "Conv2d" supported for channel pruning
    if args.base_algo in ['l1', 'l2']:
        op_types = ['Conv2d']
    elif args.base_algo == 'level':
        op_types = ['default']
    else:
        # Fail with a clear message instead of an unbound ``op_types`` below.
        raise ValueError("base_algo not supported.")

    config_list = [{'sparsity': args.sparsity, 'op_types': op_types}]
    dummy_input = get_dummy_input(args, device)

    if args.pruner == 'L1FilterPruner':
        pruner = L1FilterPruner(model, config_list)
    elif args.pruner == 'L2FilterPruner':
        pruner = L2FilterPruner(model, config_list)
    elif args.pruner == 'ActivationMeanRankFilterPruner':
        pruner = ActivationMeanRankFilterPruner(model, config_list)
    elif args.pruner == 'ActivationAPoZRankFilterPruner':
        pruner = ActivationAPoZRankFilterPruner(model, config_list)
    elif args.pruner == 'NetAdaptPruner':
        pruner = NetAdaptPruner(model,
                                config_list,
                                short_term_fine_tuner=short_term_fine_tuner,
                                evaluator=evaluator,
                                base_algo=args.base_algo,
                                experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'ADMMPruner':
        # users are free to change the config here
        if args.model == 'LeNet':
            if args.base_algo in ['l1', 'l2']:
                config_list = [{
                    'sparsity': 0.8,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv2']
                }]
            elif args.base_algo == 'level':
                config_list = [{
                    'sparsity': 0.8,
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_names': ['conv2']
                }, {
                    'sparsity': 0.991,
                    'op_names': ['fc1']
                }, {
                    'sparsity': 0.93,
                    'op_names': ['fc2']
                }]
        else:
            raise ValueError('Example only implemented for LeNet.')
        pruner = ADMMPruner(model,
                            config_list,
                            trainer=trainer,
                            num_iterations=2,
                            training_epochs=2)
    elif args.pruner == 'SimulatedAnnealingPruner':
        pruner = SimulatedAnnealingPruner(
            model,
            config_list,
            evaluator=evaluator,
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'AutoCompressPruner':
        pruner = AutoCompressPruner(
            model,
            config_list,
            trainer=trainer,
            evaluator=evaluator,
            dummy_input=dummy_input,
            num_iterations=3,
            optimize_mode='maximize',
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            admm_num_iterations=30,
            admm_training_epochs=5,
            experiment_data_dir=args.experiment_data_dir)
    else:
        raise ValueError("Pruner not supported.")

    # Pruner.compress() returns the masked model
    # but for AutoCompressPruner, Pruner.compress() returns directly the pruned model
    model = pruner.compress()
    evaluation_result = evaluator(model)
    print('Evaluation result (masked model): %s' % evaluation_result)
    result['performance']['pruned'] = evaluation_result

    if args.save_model:
        pruner.export_model(
            os.path.join(args.experiment_data_dir, 'model_masked.pth'),
            os.path.join(args.experiment_data_dir, 'mask.pth'))
        print('Masked model saved to %s' % args.experiment_data_dir)

    # model speed up
    if args.speed_up:
        if args.pruner != 'AutoCompressPruner':
            # NOTE(review): the files loaded below are only written above
            # when args.save_model is set; speed_up without save_model will
            # fail (or pick up files from an earlier run) - confirm intent.
            if args.model == 'LeNet':
                model = LeNet().to(device)
            elif args.model == 'vgg16':
                model = VGG(depth=16).to(device)
            elif args.model == 'resnet18':
                model = ResNet18().to(device)
            elif args.model == 'resnet50':
                model = ResNet50().to(device)
            elif args.model == 'mobilenet_v2':
                model = models.mobilenet_v2(pretrained=False).to(device)

            model.load_state_dict(
                torch.load(
                    os.path.join(args.experiment_data_dir,
                                 'model_masked.pth')))
            masks_file = os.path.join(args.experiment_data_dir, 'mask.pth')

            m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
            m_speedup.speedup_model()
            evaluation_result = evaluator(model)
            print('Evaluation result (speed up model): %s' % evaluation_result)
            result['performance']['speedup'] = evaluation_result

            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
            print('Speed up model saved to %s' % args.experiment_data_dir)
        flops, params = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['speedup'] = flops
        result['params']['speedup'] = params

    if args.fine_tune:
        # A fresh optimizer is required here: ``model`` may have been
        # replaced above, so the pre-training optimizer can point at stale
        # parameters.
        if args.dataset == 'mnist':
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
        elif args.dataset == 'cifar10' and args.model == 'vgg16':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.01,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.5),
                                        int(args.fine_tune_epochs * 0.75)
                                    ],
                                    gamma=0.1)
        elif args.dataset == 'cifar10' and args.model == 'resnet18':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.1,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.5),
                                        int(args.fine_tune_epochs * 0.75)
                                    ],
                                    gamma=0.1)
        elif args.dataset == 'cifar10' and args.model == 'resnet50':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.1,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.5),
                                        int(args.fine_tune_epochs * 0.75)
                                    ],
                                    gamma=0.1)
        else:
            # Previously this fell through and crashed on an unbound
            # ``scheduler`` after one wasted training epoch.
            raise ValueError(
                "Fine-tuning is not configured for dataset %r with model %r."
                % (args.dataset, args.model))
        best_acc = 0
        for epoch in range(args.fine_tune_epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch)
            scheduler.step()
            acc = evaluator(model)
            if acc > best_acc:
                best_acc = acc
                torch.save(
                    model.state_dict(),
                    os.path.join(args.experiment_data_dir,
                                 'model_fine_tuned.pth'))

        # Reported only when fine-tuning actually ran; ``best_acc`` does not
        # exist otherwise.
        print('Evaluation result (fine tuned): %s' % best_acc)
        print('Fined tuned model saved to %s' % args.experiment_data_dir)
        result['performance']['finetuned'] = best_acc

    with open(os.path.join(args.experiment_data_dir, 'result.json'),
              'w+') as f:
        json.dump(result, f)
Пример #3
0
                    help='option2: attacker for generating adv input')

# Command-line option selecting the loss schema (string, defaults to 'averaged').
parser.add_argument('--loss_schema', default='averaged', type=str,
                    help='option3: loss schema')


# Seed torch and numpy with fixed values for reproducibility.
# NOTE(review): the models below are constructed after seeding, so reordering
# these statements may change their initial weights - confirm before moving.
torch.manual_seed(66)
np.random.seed(66)


######################################### modify accordingly ##################################################
# Adversarial source models: the static models used to generate adversarial
# input images. They are instantiated once here and kept in memory for all
# training runs to speed things up.
adv_resnet18 = ResNet18()
adv_resnet50 = ResNet50()
adv_mobilenet_125 = MobileNetV2(width_mult=1.25)
adv_googlenet = GoogLeNet()


# Parallel lists: adv_model_names[i] is the name for adv_models[i].
adv_models = [adv_resnet18, adv_resnet50, adv_mobilenet_125, adv_googlenet]
adv_model_names = ['resnet18', 'resnet50', 'mobilenet_125', 'googlenet']

# Models to be adversarially trained. Only the classes are listed here; each
# model is loaded only for its own training run to save memory.
# model_names[i] is the name for model_classes[i].
model_classes = [ ResNet34, ResNet101, MobileNetV2, MobileNetV2]
model_names = [ 'resnet34', 'resnet101', 'mobilenet_1', 'mobilenet_075']
# Per-model extra value keyed by model name; presumably the width_mult passed
# to MobileNetV2 - TODO confirm against the code that builds the models.
params = {
    'mobilenet_1': 1.0,
    'mobilenet_075': 0.75,
}