Code example #1
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--dataset',
                        type=str,
                        default="mnist_logistic",
                        choices=["mnist", "mnist_logistic", "mnist_MLP",
                                 "cifar10_AlexNet", "cifar10_ResNet",
                                 "cifar10_VGG"],
                        metavar='D',
                        help='dataset/model combination to train')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--percent',
                        nargs='+',
                        type=float,
                        default=[0.8, 0.92, 0.991, 0.93],
                        metavar='P',
                        help='per-layer pruning percentages '
                             '(default: [0.8, 0.92, 0.991, 0.93])')
    parser.add_argument('--alpha',
                        type=float,
                        default=1e-4,
                        metavar='L',
                        help='l2 norm weight (default: 1e-4)')
    parser.add_argument('--a',
                        type=float,
                        default=3.7,
                        metavar='F',
                        help='SCAD norm weight (default: 3.7)')
    parser.add_argument('--rho',
                        type=float,
                        default=1e-2,
                        metavar='R',
                        help='cardinality weight (default: 1e-2)')
    parser.add_argument(
        '--l1',
        default=False,
        action='store_true',
        help='prune weights with l1 regularization instead of cardinality')
    parser.add_argument(
        '--l0',
        default=True,  # NOTE: with store_true and default=True this flag is always on and cannot be disabled from the CLI
        action='store_true',
        help='prune weights with l0 regularization instead of cardinality')
    parser.add_argument(
        '--SCAD',
        default=False,
        action='store_true',
        help='prune weights with SCAD regularization instead of cardinality')
    parser.add_argument(
        '--rscad',
        default=False,
        action='store_true',
        help='prune weights with RSCAD regularization instead of cardinality')
    parser.add_argument('--l2',
                        default=False,
                        action='store_true',
                        help='apply l2 regularization')
    parser.add_argument('--num_pre_epochs',
                        type=int,
                        default=3,
                        metavar='P',
                        help='number of epochs to pretrain (default: 3)')
    parser.add_argument('--num_epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--num_re_epochs',
                        type=int,
                        default=3,
                        metavar='R',
                        help='number of epochs to retrain (default: 3)')
    parser.add_argument('--num_test_epochs',
                        type=int,
                        default=10,
                        metavar='M',
                        help='number of epochs to test (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-3,
                        metavar='LR',
                        help='learning rate (default: 1e-3)')
    parser.add_argument('--adam_epsilon',
                        type=float,
                        default=1e-8,
                        metavar='E',
                        help='adam epsilon (default: 1e-8)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    if args.dataset == "mnist":
        args.num_pre_epochs = 3
        args.num_epochs = 30
        args.num_re_epochs = 1
        train_loader = torch.utils.data.DataLoader(datasets.MNIST(
            'data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)

        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('data',
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)

    elif args.dataset == "mnist_logistic":
        args.num_pre_epochs = 5
        args.num_epochs = 80
        args.num_re_epochs = 1
        train_loader = torch.utils.data.DataLoader(datasets.MNIST(
            'data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)

        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('data',
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)

    elif args.dataset == "mnist_MLP":
        args.num_pre_epochs = 2
        args.num_epochs = 50
        args.num_re_epochs = 1
        train_loader = torch.utils.data.DataLoader(datasets.MNIST(
            'data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)

        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('data',
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)

    else:
        args.percent = [0.8, 0.92, 0.93, 0.94, 0.95, 0.99, 0.99, 0.93]
        args.num_pre_epochs = 20
        args.num_epochs = 50
        args.num_re_epochs = 2
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                'data',
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.RandomCrop(
                        32, padding=4),  # data augmentation: zero-pad by 4 on each side, then crop back to 32x32
                    transforms.RandomHorizontalFlip(),  # random horizontal flip
                    transforms.ToTensor(),
                    transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
                                         (0.24703233, 0.24348505, 0.26158768))
                ])),
            shuffle=True,
            batch_size=args.batch_size,
            **kwargs)

        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('data',
                             train=False,
                             download=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize(
                                     (0.49139968, 0.48215827, 0.44653124),
                                     (0.24703233, 0.24348505, 0.26158768))
                             ])),
            shuffle=True,
            batch_size=args.test_batch_size,
            **kwargs)
    args.test_lamda = False
    args.plot_convergence = True
    if args.test_lamda:
        if args.l1:
            lamda = [
                2.5e-4, 3e-4, 3.5e-4, 4e-4, 4.5e-4, 5e-4, 5.5e-4, 6e-4, 6.5e-4,
                7e-4, 7.5e-4
            ]
        elif args.l0:
            lamda = [
                2.5e-4, 3e-4, 3.5e-4, 4e-4, 4.5e-4, 5e-4, 5.5e-4, 6e-4, 6.5e-4,
                7e-4, 7.5e-4
            ]
        else:
            lamda = [
                4e-4, 4.5e-4, 5e-4, 5.5e-4, 6e-4, 6.5e-4, 7e-4, 7.5e-4, 8e-4,
                8.5e-4, 9e-4, 9.5e-4, 1e-3
            ]

        for i in range(len(lamda)):
            args.alpha = lamda[i]
            print(
                "**********************the test lamda*****************************"
            )
            print('\nThe test lamda: {:.6f}\n'.format(args.alpha))
            #模型选择
            if args.dataset == "mnist":
                model = LeNet().to(device)
            elif args.dataset == "mnist_logistic":
                model = L_softmax().to(device)
            elif args.dataset == "mnist_MLP":
                model = MLP().to(device)
            elif args.dataset == "cifar10_AlexNet":
                model = AlexNet().to(device)
            elif args.dataset == "cifar10_ResNet":
                model = ResNet(ResidualBlock).to(device)
            else:
                model = VGG().to(device)

            optimizer = PruneAdam(model.named_parameters(),
                                  lr=args.lr,
                                  eps=args.adam_epsilon)

            train(args, model, device, train_loader, test_loader, optimizer)

            if (args.l1):
                mask = apply_l1_prune(model, device, args)
            elif (args.l0):
                mask = apply_l0_prune(model, device, args)
            elif (args.SCAD):
                mask = apply_l1_prune(model, device, args)  # NOTE: the SCAD branch reuses the l1 pruning mask
            elif (args.rscad):
                mask = apply_rscad_prune(model, device, args)
            else:
                mask = apply_prune(model, device, args)

            print_prune(model)

            test(args, model, device, test_loader)
            retrain(args, model, mask, device, train_loader, test_loader,
                    optimizer)
    else:

        if args.dataset == "mnist":
            print(
                '\n*********** The test model is LeNet and lamda = %f and dataset is mnist************'
                % args.alpha)
            model = LeNet().to(device)
            args.idx = 4
        elif args.dataset == "mnist_logistic":
            print(
                '\n*********** The test model is logistic and lamda = %f and dataset is mnist************'
                % args.alpha)
            args.idx = 1
            model = L_softmax().to(device)
        elif args.dataset == "mnist_MLP":
            print(
                '\n*********** The test model is MLP and lamda = %f and dataset is mnist************'
                % args.alpha)
            args.idx = 3
            model = MLP().to(device)
        elif args.dataset == "cifar10_AlexNet":
            print(
                '\n*********** The test model is AlexNet and lamda = %f and dataset is cifar10************'
                % args.alpha)
            model = AlexNet().to(device)
            args.idx = 8
        elif args.dataset == "cifar10_ResNet":
            print(
                '\n*********** The test model is ResNet and lamda = %f and dataset is cifar10************'
                % args.alpha)
            model = ResNet(ResidualBlock).to(device)
        else:
            print(
                '\n*********** The test model is VGG-16 and lamda = %f and dataset is cifar10************'
                % args.alpha)
            model = VGG().to(device)
            args.idx = 9
        optimizer = PruneAdam(model.named_parameters(),
                              lr=args.lr,
                              eps=args.adam_epsilon)
        A = train(args, model, device, train_loader, test_loader, optimizer)

        if args.plot_convergence:
            color = [
                'purple', 'red', 'blue', 'yellow', 'cyan', 'green', 'magenta',
                'black', 'gray', 'hotpink'
            ]
            marke = ['p', '*', 'o', 'x', 's', 'v', 'h', '|', 'd', '+']
            plt.figure()  # plt.figure without parentheses was a no-op
            axis_x = np.arange(1, args.num_epochs + 1, 1)
            for i in range(args.idx):
                # the per-line label is omitted; plt.legend below sets the labels
                plt.plot(axis_x, sorted(A[i], reverse=True), color=color[i],
                         linewidth=1, linestyle='--', marker=marke[i])
            plt.xlabel('iterations')
            plt.ylabel('error')
            plt.title('||theta^k+1 - z^k+1||')
            if args.dataset == "mnist":
                plt.legend(['conv1', 'conv2', 'fc1', 'fc2'])
            elif args.dataset == "mnist_MLP":
                plt.legend(['fc1', 'fc2', 'fc3'])
            elif args.dataset == "mnist_logistic":
                plt.legend(['weight'])
            elif args.dataset == "cifar10_AlexNet":
                plt.legend([
                    'conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc1', 'fc2',
                    'fc3'
                ])
            else:
                plt.legend([
                    'conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'conv6',
                    'conv7', 'conv8', 'fc1'
                ])

            plt.savefig("convergence_logistic10.jpg", dpi=600)

        if (args.l1):
            mask = apply_l1_prune(model, device, args)
        elif (args.l0):
            mask = apply_l0_prune(model, device, args)
        elif (args.SCAD):
            mask = apply_l1_prune(model, device, args)  # NOTE: the SCAD branch reuses the l1 pruning mask
        elif (args.rscad):
            mask = apply_rscad_prune(model, device, args)
        else:
            mask = apply_prune(model, device, args)

        #pickle.dump(model, open("pruning_model.dat", "wb"))
        torch.save(model, 'pruning_modelnet_logistic10.pkl')
        print_prune(model)
        test(args, model, device, test_loader)
        retrain(args, model, mask, device, train_loader, test_loader,
                optimizer)
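
The pruning helpers used above (apply_prune, apply_l1_prune, apply_l0_prune, apply_rscad_prune) come from elsewhere in the project and are not shown on this page. As a rough orientation only, a magnitude-based percentile mask of the kind these examples rely on can be sketched as below; the function name, the per-layer indexing, and the decision to prune only weight tensors are assumptions, not the project's actual implementation.

import torch

def magnitude_prune_mask_sketch(model, percent):
    """Minimal sketch (not the project's apply_prune): for each weight tensor,
    zero the smallest fraction of entries by absolute value, return the masks."""
    masks, k = {}, 0
    for name, param in model.named_parameters():
        if not name.endswith('weight'):
            continue
        p = percent[min(k, len(percent) - 1)]  # per-layer pruning ratio
        threshold = torch.quantile(param.detach().abs().flatten(), p)
        mask = (param.detach().abs() >= threshold).float()
        param.data.mul_(mask)  # zero the pruned entries in place
        masks[name] = mask
        k += 1
    return masks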
Code example #2
File: main.py  Project: rhhc/pytorch-admm-pruning
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--dataset', type=str, default="mnist", choices=["mnist", "cifar10"],
                        metavar='D', help='training dataset (mnist or cifar10)')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--percent', nargs='+', type=float, default=[0.8, 0.92, 0.991, 0.93],
                        metavar='P', help='per-layer pruning percentages (default: [0.8, 0.92, 0.991, 0.93])')
    parser.add_argument('--alpha', type=float, default=5e-4, metavar='L',
                        help='l2 norm weight (default: 5e-4)')
    parser.add_argument('--rho', type=float, default=1e-2, metavar='R',
                        help='cardinality weight (default: 1e-2)')
    parser.add_argument('--l1', default=False, action='store_true',
                        help='prune weights with l1 regularization instead of cardinality')
    parser.add_argument('--l2', default=False, action='store_true',
                        help='apply l2 regularization')
    parser.add_argument('--num_pre_epochs', type=int, default=3, metavar='P',
                        help='number of epochs to pretrain (default: 3)')
    parser.add_argument('--num_epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--num_re_epochs', type=int, default=3, metavar='R',
                        help='number of epochs to retrain (default: 3)')
    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                        help='learning rate (default: 1e-3)')
    parser.add_argument('--adam_epsilon', type=float, default=1e-8, metavar='E',
                        help='adam epsilon (default: 1e-8)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    if args.dataset == "mnist":
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('data', train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('data', train=False, transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)

    else:
        args.percent = [0.8, 0.92, 0.93, 0.94, 0.95, 0.99, 0.99, 0.93]
        args.num_pre_epochs = 5
        args.num_epochs = 20
        args.num_re_epochs = 5
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('data', train=True, download=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
                                                      (0.24703233, 0.24348505, 0.26158768))
                             ])), shuffle=True, batch_size=args.batch_size, **kwargs)

        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('data', train=False, download=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
                                                      (0.24703233, 0.24348505, 0.26158768))
                             ])), shuffle=True, batch_size=args.test_batch_size, **kwargs)

    model = LeNet().to(device) if args.dataset == "mnist" else AlexNet().to(device)
    optimizer = PruneAdam(model.named_parameters(), lr=args.lr, eps=args.adam_epsilon)

    train(args, model, device, train_loader, test_loader, optimizer)
    mask = apply_l1_prune(model, device, args) if args.l1 else apply_prune(model, device, args)
    print_prune(model)
    test(args, model, device, test_loader)
    retrain(args, model, mask, device, train_loader, test_loader, optimizer)
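
PruneAdam and retrain are defined by the project and not reproduced here. For orientation, retraining under a fixed pruning mask usually amounts to forcing the pruned entries back to zero after each optimizer step; the sketch below illustrates that idea only, using an ordinary cross-entropy loss and illustrative names rather than the project's actual signatures.

import torch
import torch.nn.functional as F

def retrain_epoch_sketch(model, mask, device, train_loader, optimizer):
    """Illustrative sketch of one masked retraining epoch: gradients flow as
    usual, then pruned weights are zeroed again after every update."""
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in mask:
                    param.mul_(mask[name])  # keep pruned entries at zero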
Code example #3
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--dataset',
                        type=str,
                        default="mnist",
                        choices=["mnist", "cifar10"],
                        metavar='D',
                        help='training dataset (mnist or cifar10)')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--percent',
                        nargs='+',
                        type=float,
                        default=[0.8, 0.92, 0.991, 0.93],
                        metavar='P',
                        help='per-layer pruning percentages '
                             '(default: [0.8, 0.92, 0.991, 0.93])')
    parser.add_argument('--alpha',
                        type=float,
                        default=5e-4,
                        metavar='L',
                        help='l2 norm weight (default: 5e-4)')
    parser.add_argument('--rho',
                        type=float,
                        default=1e-2,
                        metavar='R',
                        help='cardinality weight (default: 1e-2)')
    parser.add_argument(
        '--l1',
        default=False,
        action='store_true',
        help='prune weights with l1 regularization instead of cardinality')
    parser.add_argument('--l2',
                        default=False,
                        action='store_true',
                        help='apply l2 regularization')
    parser.add_argument('--num_pre_epochs',
                        type=int,
                        default=3,
                        metavar='P',
                        help='number of epochs to pretrain (default: 3)')
    parser.add_argument('--num_epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--num_re_epochs',
                        type=int,
                        default=3,
                        metavar='R',
                        help='number of epochs to retrain (default: 3)')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-3,
                        metavar='LR',
                        help='learning rate (default: 1e-3)')
    parser.add_argument('--adam_epsilon',
                        type=float,
                        default=1e-8,
                        metavar='E',
                        help='adam epsilon (default: 1e-8)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    parser.add_argument('--structured',
                        action='store_true',
                        default=False,
                        help='Enabling Structured Pruning')
    parser.add_argument('--test',
                        action='store_true',
                        default=False,
                        help='For Testing the current Model')
    parser.add_argument(
        '--stat',
        action='store_true',
        default=False,
        help='For showing the statistic result of the current Model')
    parser.add_argument('--n1',
                        type=int,
                        default=2,
                        metavar='N',
                        help='ReRAM OU size (row number) (default: 2)')
    parser.add_argument('--n2',
                        type=int,
                        default=2,
                        metavar='N',
                        help='ReRAM OU size (column number) (default: 2)')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    if args.dataset == "mnist":
        train_loader = torch.utils.data.DataLoader(datasets.MNIST(
            'data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)

        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('data',
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)

    else:
        args.percent = [0.8, 0.92, 0.93, 0.94, 0.95, 0.99, 0.99, 0.93]
        args.num_pre_epochs = 5
        args.num_epochs = 20
        args.num_re_epochs = 5
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            'data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
                                     (0.24703233, 0.24348505, 0.26158768))
            ])),
                                                   shuffle=True,
                                                   batch_size=args.batch_size,
                                                   **kwargs)

        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('data',
                             train=False,
                             download=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize(
                                     (0.49139968, 0.48215827, 0.44653124),
                                     (0.24703233, 0.24348505, 0.26158768))
                             ])),
            shuffle=True,
            batch_size=args.test_batch_size,
            **kwargs)

    model = LeNet().to(device) if args.dataset == "mnist" else AlexNet().to(
        device)
    optimizer = PruneAdam(model.named_parameters(),
                          lr=args.lr,
                          eps=args.adam_epsilon)

    structured_tag = "_structured{}x{}".format(
        args.n1, args.n2) if args.structured else ""

    model_file = "mnist_cnn{}.pt".format(structured_tag) if args.dataset == "mnist" \
            else 'cifar10_cnn{}.pt'.format(structured_tag)

    if args.stat or args.test:
        print("=> loading model '{}'".format(model_file))

        if os.path.isfile(model_file):
            model.load_state_dict(torch.load(model_file))
            print("=> loaded model '{}'".format(model_file))
            if args.test:
                test(args, model, device, test_loader)
            if args.stat:
                show_statistic_result(args, model)
        else:
            print("=> loading model failed '{}'".format(model_file))

    else:
        checkpoint_file = 'checkpoint{}.pth.tar'.format(
            "_mnist" if args.dataset == "mnist" else "_cifar10")

        if not os.path.isfile(checkpoint_file):
            pre_train(args, model, device, train_loader, test_loader,
                      optimizer)
            torch.save(
                {
                    'epoch': args.num_pre_epochs,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, checkpoint_file)
        else:
            print("=> loading checkpoint '{}'".format(checkpoint_file))
            checkpoint = torch.load(checkpoint_file)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}'".format(checkpoint_file))

        train(args, model, device, train_loader, test_loader, optimizer)
        mask = apply_l1_prune(model, device, args) if args.l1 else apply_prune(
            model, device, args)
        print_prune(model)
        test(args, model, device, test_loader)
        retrain(args, model, mask, device, train_loader, test_loader,
                optimizer)

        if args.save_model:
            torch.save(model.state_dict(), model_file)
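
Code example #3 adds a --structured mode with an n1 x n2 ReRAM operation-unit (OU) size, but the structured pruning routine itself lives elsewhere in the project. The block-wise idea can be sketched as follows, purely as an assumption about how OU-aligned pruning might score and drop blocks; the function and its L2-norm scoring rule are illustrative, not the project's code.

import torch

def block_prune_mask_sketch(weight, n1, n2, percent):
    """Illustrative sketch of structured OU-style pruning: score each n1 x n2
    block of a 2-D weight matrix by its L2 norm and zero the weakest blocks."""
    rows, cols = weight.shape
    mask = torch.ones_like(weight)
    scores = []
    for r in range(0, rows - n1 + 1, n1):
        for c in range(0, cols - n2 + 1, n2):
            scores.append((weight[r:r + n1, c:c + n2].norm().item(), r, c))
    scores.sort()  # weakest blocks first
    for _, r, c in scores[:int(len(scores) * percent)]:
        mask[r:r + n1, c:c + n2] = 0
    return mask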