Ejemplo n.º 1
0
# Latency accumulators, one list per setting.
# NOTE(review): the 5/10/30/60 suffixes presumably correspond to the
# vgg_<P>_eff_c10 record names assigned in the loop below -- confirm.
latency_5 = []
latency_10 = []
latency_30 = []
latency_60 = []
if args.model == "vgg":  # CIFAR10 only
    # Base VGG layer spec that every ratio is derived from.
    # NOTE(review): 'M' entries presumably mark max-pooling layers -- confirm
    # against VGG.prepare_filters.
    config = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
    ratios = np.arange(0.25, 2.1,
                       0.25)  # [0.25, 0.5 , 0.75, 1, 1.25, 1.5 , 1.75, 2]
    for ratio in ratios:
        # uniform
        # Derive a config for this ratio with neuralscale disabled, build the
        # model and record its latency and parameter count.
        # prepare_filters is invoked on the class itself, with VGG passed
        # explicitly as its first argument.
        new_config = VGG.prepare_filters(VGG,
                                         config,
                                         ratio=ratio,
                                         neuralscale=False,
                                         num_classes=num_classes)
        model = vgg11(config=new_config, num_classes=num_classes)
        latency = compute_latency(model)
        params = compute_params_(model)
        # param_uni / latency_uni are defined outside this excerpt.
        param_uni.append(params)
        latency_uni.append(latency)
        ## efficient
        # Record-name stems for each setting; they do not depend on `ratio`
        # but are re-assigned on every iteration.
        vgg_0_fname = "vgg_0_eff_c10"
        vgg_2_fname = "vgg_2_eff_c10"
        vgg_5_fname = "vgg_5_eff_c10"
        vgg_10_fname = "vgg_10_eff_c10"
        vgg_30_fname = "vgg_30_eff_c10"
        vgg_60_fname = "vgg_60_eff_c10"
        # P=0
        new_config = VGG.prepare_filters(VGG,
                                         config,
                                         ratio=ratio,
Ejemplo n.º 2
0
def main():
    """Search per-layer scaling parameters using architecture descent.

    Parses the command line, builds the CIFAR training loader and then, for
    every descent iteration (1 with ``--morph``, otherwise 15):

    1. builds the selected model in search mode,
    2. saves its initial weights to
       ``saved_models/<model>/<save>_checkpoint.t7``,
    3. trains it for ``--epochs`` epochs, pruning through ``pruner`` unless
       ``--no_prune`` is given,
    4. calls ``fit_params`` and stores the returned ``(alpha, beta)``.

    The list of fitted parameters is finally pickled to
    ``prune_record/<save>.pk`` under the key ``"param"``.

    A CUDA device is required: the model is moved with ``model.cuda()``.
    Only the training split is loaded because this script never evaluates.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'Search for optimal configuration using architecture descent')
    parser.add_argument('--batch-size',
                        type=int,
                        default=256,
                        metavar='N',
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=300,
                        metavar='N',
                        help='number of epochs to train (default: 300)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=5e-4,
                        help='weight decay (default: 5e-4)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument(
        '--dataset',
        default="CIFAR10",
        type=str,
        help='dataset for experiment, choice: CIFAR10, CIFAR100',
        choices=["CIFAR10", "CIFAR100"])
    parser.add_argument(
        '--model',
        default="vgg",
        type=str,
        help='model selection, choices: vgg, mobilenetv2, resnet18',
        choices=["vgg", "mobilenetv2", "resnet18"])
    parser.add_argument('--save', default='model', help='model and prune file')
    parser.add_argument('--no_prune',
                        dest="prune_on",
                        action='store_false',
                        default=True,
                        help='Turn off pruning')
    parser.add_argument(
        '--warmup',
        type=int,
        default=10,
        help=
        'number of warm-up or fine-tuning epochs before pruning (default: 10)')
    parser.add_argument(
        '--morph',
        dest="morph",
        action='store_true',
        default=False,
        help='Prunes only 50 percent of neurons, for comparison with MorphNet')
    args = parser.parse_args()

    ##################
    ## Data loading ##
    ##################
    kwargs = {'num_workers': 1, 'pin_memory': True}
    # Both datasets use the same normalization statistics and augmentation
    # in this script.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    if args.dataset == "CIFAR10":
        print("Using Cifar10 Dataset")
        num_classes = 10
        trainset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/',
                                                train=True,
                                                download=True,
                                                transform=transform_train)
    elif args.dataset == "CIFAR100":
        print("Using Cifar100 Dataset")
        num_classes = 100
        trainset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/',
                                                 train=True,
                                                 download=True,
                                                 transform=transform_train)
    else:
        # Defensive only: argparse `choices` already rejects anything else.
        print("Dataset does not exist! [CIFAR10, CIFAR100]")
        exit()
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)

    # The checkpoint location does not depend on the iteration.
    log_save_folder = "%s/%s" % ("saved_models", args.model)
    os.makedirs(log_save_folder, exist_ok=True)
    model_save_path = "%s/%s" % (log_save_folder, args.save) + "_checkpoint.t7"

    # Remember the requested learning rate so every iteration restarts from
    # it instead of from a hard-coded 0.1.
    # NOTE(review): the per-iteration reset suggests adjust_learning_rate
    # may modify args.lr -- confirm.
    initial_lr = args.lr

    param_list = []
    total_iter = 1 if args.morph else 15
    for iteration in range(total_iter):
        print("Iteration: {}".format(iteration))
        args.lr = initial_lr
        ###########
        ## Model ##
        ###########
        print("Setting Up Model...")
        if args.model == "vgg":
            model = vgg11(ratio=1,
                          neuralscale=True,
                          iteration=iteration,
                          num_classes=num_classes,
                          search=True)
        elif args.model == "resnet18":
            model = PreActResNet18(ratio=1,
                                   neuralscale=True,
                                   iteration=iteration,
                                   num_classes=num_classes,
                                   search=True,
                                   dataset=args.dataset)
        elif args.model == "mobilenetv2":
            model = MobileNetV2(ratio=1,
                                neuralscale=True,
                                iteration=iteration,
                                num_classes=num_classes,
                                search=True,
                                dataset=args.dataset)
        else:
            print(args.model, "model not supported")
            exit()
        print("{} set up.".format(args.model))

        # for model saving
        if args.save:
            print("Model will be saved to {}".format(model_save_path))
            save_checkpoint({'state_dict': model.state_dict()},
                            False,
                            filename=model_save_path)
        else:
            print("Save path not defined. Model will not be saved.")

        # Assume cuda is available and uses single GPU
        model.cuda()
        cudnn.benchmark = True

        # define objective
        criterion = nn.CrossEntropyLoss()

        ######################
        ## Set up pruning   ##
        ######################
        # remove updates from gate layers, because we want them to be 0 or 1 constantly
        parameters_for_update = []
        for name, m in model.named_parameters():
            if "gate" not in name:
                parameters_for_update.append(m)
            else:
                print("skipping parameter", name, "shape:", m.shape)

        total_size_params = sum(
            np.prod(par.shape) for par in parameters_for_update)
        print("Total number of parameters, w/o usage of bn consts: ",
              total_size_params)
        optimizer = optim.SGD(parameters_for_update,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

        # Extra keyword arguments for train(); the pruning engine is only
        # passed when pruning is enabled.
        train_kwargs = {'num_classes': num_classes}
        if args.prune_on:
            pruning_parameters_list = prepare_pruning_list(model)
            print("Total pruning layers:", len(pruning_parameters_list))
            prune_neurons = 0.5 if args.morph else 0.95
            train_kwargs['pruning_engine'] = pruner(
                pruning_parameters_list,
                iteration=iteration,
                prune_fname=args.save,
                classes=num_classes,
                model=args.model,
                prune_neurons=prune_neurons)

        ###############
        ## Training  ##
        ###############
        for epoch in range(1, args.epochs + 1):
            print("Epoch: {}".format(epoch))
            adjust_learning_rate(args,
                                 optimizer,
                                 epoch,
                                 search=True,
                                 warmup=args.warmup)
            # train model
            train_acc, train_loss = train(args,
                                          model,
                                          train_loader,
                                          optimizer,
                                          epoch,
                                          criterion,
                                          **train_kwargs)
            # With pruning enabled, (-1, -1) from train() ends this
            # iteration early.
            if args.prune_on and train_acc == -1 and train_loss == -1:
                break

        # =========================
        # store searched parameters
        # =========================
        alpha, beta = fit_params(
            iteration=iteration,
            prune_fname=args.save,
            classes=num_classes,
            model=args.model)  # search for params of each layer
        param_list.append([alpha, beta])

    # =======================
    # save scaling parameters
    # =======================
    # Make sure the output directory exists and that the file is closed even
    # if pickling fails.
    os.makedirs("prune_record", exist_ok=True)
    with open("prune_record/" + args.save + ".pk", "wb") as pickle_out:
        pickle.dump({"param": param_list}, pickle_out)
Ejemplo n.º 3
0
def main():
    """Train the selected model at several width ratios and log the curves.

    Parses the command line, builds the CIFAR train/test loaders and then,
    for every ratio in ``0.25, 0.5, ..., 2.0``:

    1. builds the model (NeuralScale by default; uniform scaling with
       ``--uniform``; fixed pruned filter counts with ``--morph``),
    2. trains for ``--epochs`` epochs, validating after each epoch and
       checkpointing to ``saved_models/<model>/<save>_checkpoint.t7``,
    3. pickles the per-epoch train/test accuracy and loss to
       ``saved_plots/<model>/<save>_<ratio*100>.pk``.

    A CUDA device is required: the model is moved with ``model.cuda()``.
    Every ratio writes to the same checkpoint path, so the checkpoint is
    overwritten as the sweep advances.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='Efficient Filter Scaling of Convolutional Neural Network')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=300, metavar='N',
                        help='number of epochs to train (default: 300)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--weight_decay', type=float, default=5e-4,
                        help='weight decay (default: 5e-4)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--dataset', default="CIFAR10", type=str,
                        help='dataset for experiment, choice: CIFAR10, CIFAR100',
                        choices=["CIFAR10", "CIFAR100"])
    parser.add_argument('--model', default="resnet18", type=str,
                        help='model selection, choices: vgg, mobilenetv2, resnet18',
                        choices=["vgg", "mobilenetv2", "resnet18"])
    parser.add_argument('--save', default='model',
                        help='model file')
    parser.add_argument('--prune_fname', default='filename',
                        help='prune save file')
    parser.add_argument('--descent_idx', type=int, default=14,
                        help='Iteration for Architecture Descent')
    parser.add_argument('--morph', dest="morph", action='store_true', default=False,
                        help='Prunes only 50 percent of neurons, for comparison with MorphNet')
    parser.add_argument('--uniform', dest="uniform", action='store_true', default=False,
                        help='Use uniform scaling instead of NeuralScale')

    args = parser.parse_args()

    ##################
    ## Data loading ##
    ##################
    kwargs = {'num_workers': 1, 'pin_memory': True}
    # The two datasets differ only in dataset class, root directory, class
    # count and normalization statistics.
    if args.dataset == "CIFAR10":
        print("Using Cifar10 Dataset")
        num_classes = 10
        dataset_cls = torchvision.datasets.CIFAR10
        data_root = '/DATA/data_cifar10/'
        normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x/255.0 for x in [63.0, 62.1, 66.7]])
    elif args.dataset == "CIFAR100":
        print("Using Cifar100 Dataset")
        num_classes = 100
        dataset_cls = torchvision.datasets.CIFAR100
        data_root = '/DATA/data_cifar100/'
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    else:
        # Defensive only: argparse `choices` already rejects anything else.
        print("Dataset does not exist! [CIFAR10, CIFAR100]")
        exit()

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    trainset = dataset_cls(root=data_root, train=True,
                           download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                               shuffle=True, **kwargs)
    testset = dataset_cls(root=data_root, train=False,
                          download=True, transform=transform_test)
    # NOTE(review): shuffling the test set is not needed for evaluation; it
    # is kept as in the original so existing runs stay comparable.
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                              shuffle=True, **kwargs)

    ratios = np.arange(0.25, 2.1, 0.25)  # [0.25, 0.5 , 0.75, 1, 1.25, 1.5 , 1.75, 2]

    # NeuralScale is on by default; --uniform and --morph both disable it.
    neuralscale = not (args.uniform or args.morph)

    # Filter counts used for the MorphNet comparison, keyed by
    # (model, dataset).  Combinations without an entry are unsupported.
    morph_filters = {
        ("vgg", "CIFAR10"): [64, 128, 249, 253, 268, 175, 87, 152],
        ("vgg", "CIFAR100"): [63, 125, 204, 215, 234, 174, 120, 241],
        ("resnet18", "CIFAR100"): [48, 46, 40, 41, 54, 91, 75, 73, 95, 157, 149, 149, 156, 232, 216, 140, 190],
        ("mobilenetv2", "CIFAR100"): [28, 16, 24, 21, 30, 31, 26, 56, 50, 49, 46, 83, 70, 58, 120, 101, 68, 134, 397],
    }
    pruned_filters = None
    if args.morph:
        pruned_filters = morph_filters.get((args.model, args.dataset))
        if pruned_filters is None:
            print("{} not supported for {}".format(args.dataset, args.model))
            exit()

    # Output locations do not depend on the ratio, so set them up once.
    model_save_folder = "%s/%s" % ("saved_models", args.model)
    os.makedirs(model_save_folder, exist_ok=True)
    model_save_path = "%s/%s" % (model_save_folder, args.save) + "_checkpoint.t7"

    plot_save_folder = "%s/%s" % ("saved_plots", args.model)
    os.makedirs(plot_save_folder, exist_ok=True)

    for ratio in ratios:
        print("Current ratio: {}".format(ratio))
        ###########
        ## Model ##
        ###########
        print("Setting Up Model...")
        if args.model == "vgg":
            model = vgg11(ratio=ratio, neuralscale=neuralscale, num_classes=num_classes,
                          prune_fname=args.prune_fname, descent_idx=args.descent_idx,
                          pruned_filters=pruned_filters)
        elif args.model == "resnet18":
            model = PreActResNet18(ratio=ratio, neuralscale=neuralscale, num_classes=num_classes,
                                   dataset=args.dataset, prune_fname=args.prune_fname,
                                   descent_idx=args.descent_idx, pruned_filters=pruned_filters)
        elif args.model == "mobilenetv2":
            model = MobileNetV2(ratio=ratio, neuralscale=neuralscale, num_classes=num_classes,
                                dataset=args.dataset, prune_fname=args.prune_fname,
                                descent_idx=args.descent_idx, pruned_filters=pruned_filters)
        else:
            print(args.model, "model not supported")
            exit()
        print("{} set up.".format(args.model))

        # for model saving
        if args.save:
            print("Model will be saved to {}".format(model_save_path))
            save_checkpoint({
                'state_dict': model.state_dict()
            }, False, filename=model_save_path)
        else:
            print("Save path not defined. Model will not be saved.")

        # Assume cuda is available and uses single GPU
        model.cuda()
        cudnn.benchmark = True

        # define objective
        criterion = nn.CrossEntropyLoss()

        ######################
        ## Set up pruning   ##
        ######################
        # remove updates from gate layers, because we want them to be 0 or 1 constantly
        parameters_for_update = []
        for name, m in model.named_parameters():
            if "gate" not in name:
                parameters_for_update.append(m)
            else:
                print("skipping parameter", name, "shape:", m.shape)

        total_size_params = sum(np.prod(par.shape) for par in parameters_for_update)
        print("Total number of parameters, w/o usage of bn consts: ", total_size_params)

        optimizer = optim.SGD(parameters_for_update, lr=args.lr, momentum=args.momentum,
                              weight_decay=args.weight_decay)

        ###############
        ## Training  ##
        ###############
        best_test_acc = 0
        train_acc_plt = []
        train_loss_plt = []
        test_acc_plt = []
        test_loss_plt = []
        for epoch in range(1, args.epochs + 1):
            adjust_learning_rate(args, optimizer, epoch)
            print("Epoch: {}".format(epoch))

            # train model
            train_acc, train_loss = train(args, model, train_loader, optimizer, epoch, criterion)

            # evaluate on validation set
            test_acc, test_loss = validate(args, test_loader, model, criterion, epoch, optimizer=optimizer)

            # remember best prec@1 and save checkpoint
            is_best = test_acc > best_test_acc
            best_test_acc = max(test_acc, best_test_acc)
            if args.save:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    # Store the running best, as the key name promises,
                    # rather than this epoch's accuracy.
                    'best_prec1': best_test_acc,
                }, is_best, filename=model_save_path)

            train_acc_plt.append(train_acc)
            train_loss_plt.append(train_loss)
            test_acc_plt.append(test_acc)
            test_loss_plt.append(test_loss)

        pickle_save = {
            "ratio": ratio,
            "train_acc": train_acc_plt,
            "train_loss": train_loss_plt,
            "test_acc": test_acc_plt,
            "test_loss": test_loss_plt,
        }
        # Context manager guarantees the file is closed even if dump() fails.
        with open("%s/%s_%s.pk" % (plot_save_folder, args.save, int(ratio*100)), "wb") as pickle_out:
            pickle.dump(pickle_save, pickle_out)
Ejemplo n.º 4
0
# Latency accumulators, one list per configuration family.
# Only latency_uni and latency_prune are filled in the visible part of this
# excerpt.
latency_high = []
latency_convcut = []
latency_uni = []
latency_prune = []

if args.model=="vgg":
    # Base VGG layer spec that every ratio is derived from.
    # NOTE(review): 'M' entries presumably mark max-pooling layers -- confirm
    # against VGG.prepare_filters.
    config = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
    # Per-dataset filter counts passed as `pruned_filters` below.
    # NOTE(review): config_prune stays undefined for any other dataset, so
    # the "pruned" step below would raise NameError -- confirm that callers
    # restrict args.dataset.
    if args.dataset == "CIFAR10":
        config_prune = [64, 128, 249, 253, 268, 175, 87, 152] # VGG C10
    elif args.dataset == "CIFAR100":
        config_prune = [63, 125, 204, 215, 234, 174, 120, 241] # VGG C100
    ratios = np.arange(0.25,2.1,0.25) # [0.25, 0.5 , 0.75, 1, 1.25, 1.5 , 1.75, 2]
    for ratio in ratios:
        # uniform
        # prepare_filters is invoked on the class itself, with VGG passed
        # explicitly as its first argument.
        new_config = VGG.prepare_filters(VGG, config, ratio=ratio, neuralscale=False, num_classes=num_classes)
        model = vgg11(config=new_config, num_classes=num_classes)
        latency = compute_latency(model)
        params = compute_params_(model)
        # param_uni / param_prune are defined outside this excerpt.
        param_uni.append(params)
        latency_uni.append(latency)
        # pruned
        # Same call as above, but with the dataset-specific filter counts.
        new_config = VGG.prepare_filters(VGG, config, ratio=ratio, neuralscale=False, num_classes=num_classes, pruned_filters=config_prune)
        model = vgg11(config=new_config, num_classes=num_classes)
        latency = compute_latency(model)
        params = compute_params_(model)
        param_prune.append(params)
        latency_prune.append(latency)
        ## efficient
        if args.dataset == "CIFAR100":
            fname = "vgg2_10_eff_c100"
            # fname = "vgg_eff_c100"
Ejemplo n.º 5
0
def main():
    """Fine-tune a pruned network and save its training curves.

    Steps:
      1. Parse command-line options.
      2. Build train/test data loaders for the selected dataset.
      3. For every target size in ``ratios``: build the selected model at
         ``ratio=1``, optionally load ``saved_models/<model>/<resume>_best_model.t7``,
         hand the model to ``compare_pruner`` / ``prune_neurons``, then
         fine-tune with SGD.
      4. Pickle the per-epoch accuracy/loss history to
         ``saved_plots/<model>/<save>_<ratio*100>.pk``.

    A single CUDA device is assumed to be available. The number of epochs is
    fixed per dataset (see below), so ``--epochs`` is effectively ignored.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description='Fine-tune on pruned architecture')
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs',
                        type=int,
                        default=300,
                        metavar='N',
                        help='number of epochs to train (default: 300; '
                        'overridden per dataset)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=5e-4,
                        help='weight decay (default: 5e-4)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument(
        '--dataset',
        default="CIFAR10",
        type=str,
        help='dataset for experiment, choice: CIFAR10, CIFAR100, tinyimagenet',
        choices=["CIFAR10", "CIFAR100", "tinyimagenet"])
    parser.add_argument('--data',
                        metavar='DIR',
                        default='/DATA/tiny-imagenet-200',
                        help='path to imagenet dataset')
    parser.add_argument(
        '--model',
        default="resnet18",
        type=str,
        help='model selection, choices: vgg, mobilenetv2, resnet18',
        choices=["vgg", "mobilenetv2", "resnet18"])
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--save', default='model', help='model file')
    parser.add_argument('--prune_fname',
                        default='filename',
                        help='prune save file')
    parser.add_argument('--descent_idx',
                        type=int,
                        default=14,
                        help='Iteration for Architecture Descent')
    parser.add_argument('--method',
                        type=int,
                        default=0,
                        help='sets pruning method')

    args = parser.parse_args()

    ##################
    ## Data loading ##
    ##################

    kwargs = {'num_workers': 1, 'pin_memory': True}
    if args.dataset == "CIFAR10":
        print("Using Cifar10 Dataset")
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/',
                                                train=True,
                                                download=True,
                                                transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        testset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/',
                                               train=False,
                                               download=True,
                                               transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  **kwargs)
    elif args.dataset == "CIFAR100":
        print("Using Cifar100 Dataset")
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/',
                                                 train=True,
                                                 download=True,
                                                 transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        testset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/',
                                                train=False,
                                                download=True,
                                                transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  **kwargs)
    elif args.dataset == "tinyimagenet":
        print("Using tiny-Imagenet Dataset")
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'test')

        normalize = transforms.Normalize([0.4802, 0.4481, 0.3975],
                                         [0.2302, 0.2265, 0.2262])
        train_dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomCrop(64, padding=4),
                transforms.RandomRotation(20),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None

        # pin_memory is passed explicitly below, so only the worker count
        # lives in kwargs for this dataset.
        kwargs = {'num_workers': 16}

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            pin_memory=True,
            **kwargs)

        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                valdir, transforms.Compose([
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            **kwargs)
    else:
        # Unreachable while argparse enforces `choices`; kept as a guard.
        print("Dataset does not exist! [CIFAR10, CIFAR100, tinyimagenet]")
        exit()

    # Class count and fine-tuning length are fixed per dataset; this
    # deliberately overrides whatever was passed via --epochs.
    if args.dataset == 'CIFAR10':
        num_classes = 10
        args.epochs = 40
    elif args.dataset == 'CIFAR100':
        num_classes = 100
        args.epochs = 40
    elif args.dataset == 'tinyimagenet':
        num_classes = 200
        args.epochs = 20

    # ratios = [0.25, 0.75]
    ratios = [0.75]
    pruned_filters = None
    for ratio in ratios:
        print("Current ratio: {}".format(ratio))
        ###########
        ## Model ##
        ###########
        print("Setting Up Model...")
        # The network is always built at ratio=1; `ratio` is only passed to
        # the pruner below as the target size.
        if args.model == "vgg":
            model = vgg11(ratio=1,
                          neuralscale=False,
                          num_classes=num_classes,
                          prune_fname=args.prune_fname,
                          descent_idx=args.descent_idx,
                          pruned_filters=pruned_filters,
                          search=True)
        elif args.model == "resnet18":
            model = PreActResNet18(ratio=1,
                                   neuralscale=False,
                                   num_classes=num_classes,
                                   dataset=args.dataset,
                                   prune_fname=args.prune_fname,
                                   descent_idx=args.descent_idx,
                                   pruned_filters=pruned_filters,
                                   search=True)
        elif args.model == "mobilenetv2":
            model = MobileNetV2(ratio=1,
                                neuralscale=False,
                                num_classes=num_classes,
                                dataset=args.dataset,
                                prune_fname=args.prune_fname,
                                descent_idx=args.descent_idx,
                                pruned_filters=pruned_filters,
                                search=True)
        else:
            print(args.model, "model not supported")
            exit()
        print("{} set up.".format(args.model))

        # optionally resume from a checkpoint
        if args.resume:
            model_path = "saved_models"
            os.makedirs(model_path, exist_ok=True)
            log_save_folder = "%s/%s" % (model_path, args.model)
            model_resume_path = "%s/%s" % (log_save_folder,
                                           args.resume) + "_best_model.t7"
            if os.path.isfile(model_resume_path):
                print("=> loading checkpoint '{}'".format(model_resume_path))
                checkpoint = torch.load(model_resume_path)
                model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    model_resume_path, checkpoint['epoch']))
            else:
                print(
                    "=> no checkpoint found at '{}'".format(model_resume_path))

        # Assume cuda is available and uses single GPU
        model.cuda()
        cudnn.benchmark = True

        # define objective
        criterion = nn.CrossEntropyLoss()

        ######################
        ## Set up pruning   ##
        ######################
        # remove updates from gate layers, because we want them to be 0 or 1 constantly
        parameters_for_update = []
        for name, param in model.named_parameters():
            if "gate" not in name:
                parameters_for_update.append(param)
            else:
                print("skipping parameter", name, "shape:", param.shape)

        total_size_params = sum(
            np.prod(par.shape) for par in parameters_for_update)
        print("Total number of parameters, w/o usage of bn consts: ",
              total_size_params)

        optimizer = optim.SGD(parameters_for_update,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

        pruner = compare_pruner(model, method=args.method, size=ratio)
        pruner.prune_neurons(optimizer)

        ###############
        ## Training  ##
        ###############
        best_test_acc = 0
        train_acc_plt = []
        train_loss_plt = []
        test_acc_plt = []
        test_loss_plt = []
        for epoch in range(1, args.epochs + 1):
            print("Epoch: {}".format(epoch))

            # train model
            train_acc, train_loss = train(args, model, train_loader, optimizer,
                                          epoch, criterion)

            # evaluate on validation set
            test_acc, test_loss = validate(args,
                                           test_loader,
                                           model,
                                           criterion,
                                           epoch,
                                           optimizer=optimizer)

            # track and report the best test accuracy seen so far
            best_test_acc = max(test_acc, best_test_acc)
            print(best_test_acc)

            train_acc_plt.append(train_acc)
            train_loss_plt.append(train_loss)
            test_acc_plt.append(test_acc)
            test_loss_plt.append(test_loss)

        pickle_save = {
            "ratio": ratio,
            "train_acc": train_acc_plt,
            "train_loss": train_loss_plt,
            "test_acc": test_acc_plt,
            "test_loss": test_loss_plt,
        }
        plot_path = "saved_plots"
        log_save_folder = "%s/%s" % (plot_path, args.model)
        # makedirs creates plot_path as an intermediate directory too.
        os.makedirs(log_save_folder, exist_ok=True)

        pickle_file = "%s/%s_%s.pk" % (log_save_folder, args.save,
                                       int(ratio * 100))
        # Context manager guarantees the handle is closed even if dump fails.
        with open(pickle_file, "wb") as pickle_out:
            pickle.dump(pickle_save, pickle_out)
Ejemplo n.º 6
0
 #ratios = np.arange(0.25,2.1,0.25) # [0.25, 0.5 , 0.75, 1, 1.25, 1.5 , 1.75, 2]
 # ratios = np.arange(0.25,2.1,0.25) # [0.25, 0.5 , 0.75, 1, 1.25, 1.5 , 1.75, 2]
 ratio = 1
 pruned_filters = None
 # pruned_filters = [48, 46, 40, 41, 54, 91, 75, 73, 95, 157, 149, 149, 156, 232, 216, 140, 190] # resnet18 c100
 #pruned_filters = [28, 16, 24, 21, 30, 31, 26, 56, 50, 49, 46, 83, 70, 58, 120, 101, 68, 134, 397] # mobilenetv2 c100
 # pruned_filters = [126, , 256, , 464, 488, , 494, 364, , 204, 356] # VGG C10
 # pruned_filters = [122, , 244, , 370, 388, , 430, 354, , 256, 588] # VGG C100
 ###########
 ## Model ##
 ###########
 print("Setting Up Model...")
 if args.model == "vgg":
     model = vgg11(ratio=ratio,
                   efficient_scale=False,
                   num_classes=num_classes,
                   prune_fname=args.prune_fname,
                   descent_idx=args.descent_idx,
                   pruned_filters=pruned_filters)
 elif args.model == "resnet18":
     model = PreActResNet18(ratio=ratio,
                            efficient_scale=False,
                            num_classes=num_classes,
                            dataset=args.dataset,
                            prune_fname=args.prune_fname,
                            descent_idx=args.descent_idx,
                            pruned_filters=pruned_filters,
                            search=True)
 elif args.model == "mobilenetv2":
     model = MobileNetV2(ratio=ratio,
                         efficient_scale=False,
                         num_classes=num_classes,
Ejemplo n.º 7
0
        cur_time = time.time()
        if idx > 20:  # allow 20 runs for GPU to warm-up
            latency.append(cur_time - last_time)
        last_time = cur_time
    del model
    return np.mean(latency) * 1000


# Build the selected architecture scaled with ratio=0.75 (neuralscale
# disabled) and measure its latency and parameter count.
# compute_latency returns the mean per-iteration wall time multiplied by
# 1000; compute_params_ is defined elsewhere in the file.
if args.model == "vgg":
    # Base VGG layer configuration handed to prepare_filters for scaling.
    config = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
    new_config = VGG.prepare_filters(VGG,
                                     config,
                                     ratio=0.75,
                                     neuralscale=False,
                                     num_classes=num_classes)
    model = vgg11(config=new_config, num_classes=num_classes)
    latency = compute_latency(model)
    params = compute_params_(model)
elif args.model == "resnet18":
    # Base filter counts as a list of lists; presumably one inner list per
    # stem/residual block -- TODO confirm against PreActResNet.
    filters = [[64], [64, 64], [64, 64], [128, 128], [128, 128], [256, 256],
               [256, 256], [512, 512], [512, 512]]
    new_config = PreActResNet.prepare_filters(PreActResNet,
                                              filters,
                                              ratio=0.75,
                                              neuralscale=False,
                                              num_classes=num_classes)
    model = PreActResNet18(filters=new_config,
                           num_classes=num_classes,
                           dataset=args.dataset)
    latency = compute_latency(model)
    params = compute_params_(model)
Ejemplo n.º 8
0
def main():
    """Train a base network at ratio=1 and save checkpoints and curves.

    Steps:
      1. Parse command-line options.
      2. Build train/test data loaders for the selected dataset.
      3. Build the selected model with ``ratio=1`` and write an initial
         checkpoint to ``saved_models/<model>/<save>_checkpoint.t7``.
      4. Train with SGD, excluding parameters whose name contains "gate",
         checkpointing after every epoch.
      5. Pickle the per-epoch accuracy/loss history to
         ``saved_plots/<model>/<save>_<ratio*100>.pk``.

    A single CUDA device is assumed to be available. The number of epochs is
    fixed by dataset (150 for tinyimagenet, otherwise 300), so ``--epochs``
    is effectively ignored.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'Train a base network under ratio=1 (default configuration) for pruning'
    )
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='input batch size for testing (default: 128)')
    parser.add_argument('--epochs',
                        type=int,
                        default=300,
                        metavar='N',
                        help='number of epochs to train (default: 300; '
                        'overridden per dataset)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=5e-4,
                        help='weight decay (default: 5e-4)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--lr-decay-every',
                        type=int,
                        default=100,
                        help='learning rate decay by 10 every X epochs')
    parser.add_argument('--lr-decay-scalar',
                        type=float,
                        default=0.1,
                        help='--')
    # NOTE(review): --seed is parsed but never applied in this function.
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--dataset',
        default="CIFAR10",
        type=str,
        help='dataset for experiment, choices: MNIST, CIFAR10, CIFAR100, '
        'Imagenet, tinyimagenet',
        choices=["MNIST", "CIFAR10", "CIFAR100", "Imagenet", "tinyimagenet"])
    parser.add_argument('--data',
                        metavar='DIR',
                        default='/DATA/tiny-imagenet-200',
                        help='path to tinyimagenet dataset')
    parser.add_argument(
        '--model',
        default="resnet18",
        type=str,
        help='model selection, choices: vgg, mobilenetv2, resnet18',
        choices=["vgg", "mobilenetv2", "resnet18", "mobilenet"])
    parser.add_argument('--r',
                        dest="resume",
                        action='store_true',
                        default=False,
                        help='Resume from checkpoint')
    parser.add_argument('--save', default='model', help='model file')
    parser.add_argument('--prune_fname',
                        default='filename',
                        help='prune save file')
    parser.add_argument('--descent_idx',
                        type=int,
                        default=14,
                        help='Iteration for Architecture Descent')
    parser.add_argument('--s',
                        type=float,
                        default=0.0001,
                        help='scale sparse rate (default: 0.0001)')
    # Read by the Imagenet branch below; without this declaration that
    # branch fails with AttributeError.
    parser.add_argument('--use_test_as_train',
                        action='store_true',
                        default=False,
                        help='Imagenet only: train on the validation split')

    args = parser.parse_args()

    ##################
    ## Data loading ##
    ##################

    kwargs = {'num_workers': 1, 'pin_memory': True}
    if args.dataset == "CIFAR10":
        print("Using Cifar10 Dataset")
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/',
                                                train=True,
                                                download=True,
                                                transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        testset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/',
                                               train=False,
                                               download=True,
                                               transform=transform_test)
        test_loader = torch.utils.data.DataLoader(
            testset,
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)
    elif args.dataset == "CIFAR100":
        print("Using Cifar100 Dataset")
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/',
                                                 train=True,
                                                 download=True,
                                                 transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        testset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/',
                                                train=False,
                                                download=True,
                                                transform=transform_test)
        test_loader = torch.utils.data.DataLoader(
            testset,
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)
    elif args.dataset == "Imagenet":
        print("Using Imagenet Dataset")
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None

        # pin_memory is passed explicitly below, so only the worker count
        # lives in kwargs for this dataset.
        kwargs = {'num_workers': 16}

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            pin_memory=True,
            **kwargs)

        if args.use_test_as_train:
            # Replace the training loader with one over the validation
            # split, using the deterministic eval-style preprocessing.
            train_loader = torch.utils.data.DataLoader(
                torchvision.datasets.ImageFolder(
                    valdir,
                    transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize,
                    ])),
                batch_size=args.batch_size,
                shuffle=(train_sampler is None),
                **kwargs)

        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.test_batch_size,
            shuffle=False,
            pin_memory=True,
            **kwargs)
    elif args.dataset == "tinyimagenet":
        print("Using tiny-Imagenet Dataset")
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'test')

        normalize = transforms.Normalize([0.4802, 0.4481, 0.3975],
                                         [0.2302, 0.2265, 0.2262])
        train_dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomCrop(64, padding=4),
                transforms.RandomRotation(20),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None

        kwargs = {'num_workers': 16}

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            pin_memory=True,
            **kwargs)

        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                valdir, transforms.Compose([
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.test_batch_size,
            shuffle=False,
            pin_memory=True,
            **kwargs)
    else:
        # Reached for "MNIST": argparse accepts it but no loader exists.
        print("Dataset does not exist! "
              "[CIFAR10, CIFAR100, Imagenet, tinyimagenet]")
        exit()

    if args.dataset == 'CIFAR10':
        num_classes = 10
    elif args.dataset == 'CIFAR100':
        num_classes = 100
    elif args.dataset == 'tinyimagenet':
        num_classes = 200
    elif args.dataset == 'Imagenet':
        # ILSVRC ImageNet has 1000 classes. Without this branch num_classes
        # would be unbound when the model is constructed below.
        # NOTE(review): confirm the model constructors support this dataset.
        num_classes = 1000

    ratio = 1
    pruned_filters = None
    ###########
    ## Model ##
    ###########
    print("Setting Up Model...")
    if args.model == "vgg":
        model = vgg11(ratio=ratio,
                      neuralscale=False,
                      num_classes=num_classes,
                      prune_fname=args.prune_fname,
                      descent_idx=args.descent_idx,
                      pruned_filters=pruned_filters)
    elif args.model == "resnet18":
        model = PreActResNet18(ratio=ratio,
                               neuralscale=False,
                               num_classes=num_classes,
                               dataset=args.dataset,
                               prune_fname=args.prune_fname,
                               descent_idx=args.descent_idx,
                               pruned_filters=pruned_filters,
                               search=True)
    elif args.model == "mobilenetv2":
        model = MobileNetV2(ratio=ratio,
                            neuralscale=False,
                            num_classes=num_classes,
                            dataset=args.dataset,
                            prune_fname=args.prune_fname,
                            descent_idx=args.descent_idx,
                            pruned_filters=pruned_filters,
                            search=True)
    else:
        # Reached for "mobilenet": argparse accepts it but no model exists.
        print(args.model, "model not supported")
        exit()
    print("{} set up.".format(args.model))

    # for model saving
    model_path = "saved_models"
    log_save_folder = "%s/%s" % (model_path, args.model)
    # makedirs creates model_path as an intermediate directory too.
    os.makedirs(log_save_folder, exist_ok=True)

    model_save_path = "%s/%s" % (log_save_folder, args.save) + "_checkpoint.t7"
    model_state_dict = model.state_dict()
    if args.save:
        # Save the freshly initialised weights before any training.
        print("Model will be saved to {}".format(model_save_path))
        save_checkpoint({'state_dict': model_state_dict},
                        False,
                        filename=model_save_path)
    else:
        print("Save path not defined. Model will not be saved.")

    # Assume cuda is available and uses single GPU
    model.cuda()
    cudnn.benchmark = True

    # define objective
    criterion = nn.CrossEntropyLoss()

    ######################
    ## Set up pruning   ##
    ######################
    # remove updates from gate layers, because we want them to be 0 or 1 constantly
    parameters_for_update = []
    for name, param in model.named_parameters():
        if "gate" not in name:
            parameters_for_update.append(param)
        else:
            print("skipping parameter", name, "shape:", param.shape)

    total_size_params = sum(
        np.prod(par.shape) for par in parameters_for_update)
    print("Total number of parameters, w/o usage of bn consts: ",
          total_size_params)

    optimizer = optim.SGD(parameters_for_update,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    ###############
    ## Training  ##
    ###############
    best_test_acc = 0
    train_acc_plt = []
    train_loss_plt = []
    test_acc_plt = []
    test_loss_plt = []
    # The training length is fixed by dataset and overrides --epochs.
    if num_classes == 200:  # tinyimagenet
        args.epochs = 150
    else:
        args.epochs = 300
    for epoch in range(1, args.epochs + 1):
        if num_classes == 200:  # tinyimagenet
            adjust_learning_rate_imagenet(args, optimizer, epoch, search=False)
        else:
            adjust_learning_rate(args, optimizer, epoch)

        print("Epoch: {}".format(epoch))

        # train model
        train_acc, train_loss = train(args, model, train_loader, optimizer,
                                      epoch, criterion)

        # evaluate on validation set
        test_acc, test_loss = validate(args,
                                       test_loader,
                                       model,
                                       criterion,
                                       epoch,
                                       optimizer=optimizer)

        # remember best prec@1 and save checkpoint
        is_best = test_acc > best_test_acc
        best_test_acc = max(test_acc, best_test_acc)
        model_state_dict = model.state_dict()
        if args.save:
            # NOTE(review): 'best_prec1' stores this epoch's accuracy, not
            # best_test_acc -- confirm which one consumers expect.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model_state_dict,
                    'best_prec1': test_acc,
                },
                is_best,
                filename=model_save_path)

        train_acc_plt.append(train_acc)
        train_loss_plt.append(train_loss)
        test_acc_plt.append(test_acc)
        test_loss_plt.append(test_loss)

    pickle_save = {
        "ratio": ratio,
        "train_acc": train_acc_plt,
        "train_loss": train_loss_plt,
        "test_acc": test_acc_plt,
        "test_loss": test_loss_plt,
    }
    plot_path = "saved_plots"
    log_save_folder = "%s/%s" % (plot_path, args.model)
    os.makedirs(log_save_folder, exist_ok=True)

    pickle_file = "%s/%s_%s.pk" % (log_save_folder, args.save,
                                   int(ratio * 100))
    # Context manager guarantees the handle is closed even if dump fails.
    with open(pickle_file, "wb") as pickle_out:
        pickle.dump(pickle_save, pickle_out)