Example #1
def main():
    manager = Manager.init()

    models = [["model", MobileNetV2(**manager.args.model)]]

    manager.init_model(models)
    args = manager.args
    criterion = Criterion()
    optimizer, scheduler = Optimizer(models, args.optim).init()

    args.cuda = args.cuda and torch.cuda.is_available()
    if args.cuda:
        for item in models:
            item[1].cuda()
        criterion.cuda()

    dataloader = DataLoader(args.dataloader, args.cuda)

    summary = manager.init_summary()
    trainer = Trainer(models, criterion, optimizer, scheduler, dataloader,
                      summary, args.cuda)

    for epoch in range(args.runtime.start_epoch,
                       args.runtime.num_epochs + args.runtime.start_epoch):
        try:
            print("epoch {}...".format(epoch))
            trainer.train(epoch)
            manager.save_checkpoint(models, epoch)

            if (epoch + 1) % args.runtime.test_every == 0:
                trainer.validate()
        except KeyboardInterrupt:
            print("Training had been Interrupted\n")
            break
    trainer.test()
Example #2
def main(args):
    # training_set, test_set = compile_dataset_files(args.data)
    w300_dataset = W300(args.data, train=True)

    train_dataset = w300_dataset.create_dataset(args)

    # Tensorboard
    if (os.path.exists(tb_train_dir)):
        shutil.rmtree(tb_train_dir)

    logdir = os.path.join(tb_train_dir,
                          datetime.now().strftime("%Y%m%d-%H%M%S"))

    file_writer = tf.summary.create_file_writer(logdir)

    model = MobileNetV2(input_shape=(256, 256, 3),
                        k=w300_dataset.expression_number())

    optimizer = tf.keras.optimizers.Adam()
    loss_fn = tf.keras.losses.MeanSquaredError()  # class instance; mean_squared_error is a bare function and cannot be called with no arguments

    epoch_num = 0
    with file_writer.as_default():
        for step, data in enumerate(train_dataset):
            with tf.GradientTape() as tape:
                img, exp_params, vertex = data

                epoch_step = step % len(w300_dataset)

                if epoch_step == 0:
                    epoch_num += 1
                    print("starting epoch {}".format(epoch_num))

                prediction = model(img, training=True)
                loss = loss_fn(exp_params, prediction)  # Keras losses expect (y_true, y_pred)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients,
                                          model.trainable_variables))

            if step % args.vis_freq == 0:
                print("epoch {} step {}".format(epoch_num, epoch_step))

                tf.summary.image("Training Image",
                                 img,
                                 max_outputs=1,
                                 step=step)
                summary = mesh_summary.mesh('GT Expression Mesh',
                                            vertices=vertex[0:],
                                            faces=w300_dataset.faces,
                                            step=step)
                file_writer.flush()
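
For reference, the same custom GradientTape training pattern can be distilled into a self-contained sketch that uses only public tf.keras APIs and synthetic data in place of the project-specific W300 dataset, MobileNetV2 model, and mesh summaries (the model, shapes, and data below are placeholders, not part of the original example):

import tensorflow as tf

# Synthetic stand-ins for the dataset and regression targets.
images = tf.random.normal([64, 32, 32, 3])
targets = tf.random.normal([64, 10])
dataset = tf.data.Dataset.from_tensor_slices((images, targets)).batch(16)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(32, 32, 3)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(10),
])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.MeanSquaredError()  # class instance, not the bare function

for step, (img, target) in enumerate(dataset):
    with tf.GradientTape() as tape:
        prediction = model(img, training=True)
        loss = loss_fn(target, prediction)  # (y_true, y_pred)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    if step % 2 == 0:
        print("step {} loss {:.4f}".format(step, float(loss)))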
Example #3
                      prune_fname=vgg_60_fname,
                      descent_idx=14)
        latency = compute_latency(model)
        params = compute_params_(model)
        param_60.append(params)
        latency_60.append(latency)
elif args.model == "mobilenetv2":  # CIFAR100 only
    filters = [[32], [16], [24, 24], [32, 32, 32], [64, 64, 64, 64],
               [96, 96, 96], [160, 160, 160], [320], [1280]]
    ratios = np.arange(0.25, 2.1,
                       0.25)  # [0.25, 0.5 , 0.75, 1, 1.25, 1.5 , 1.75, 2]
    for ratio in ratios:
        # uniform
        new_config = MobileNetV2.prepare_filters(MobileNetV2,
                                                 filters,
                                                 ratio=ratio,
                                                 neuralscale=False,
                                                 num_classes=num_classes)
        model = MobileNetV2(filters=new_config,
                            num_classes=num_classes,
                            dataset=args.dataset)
        latency = compute_latency(model)
        params = compute_params_(model)
        param_uni.append(params)
        latency_uni.append(latency)
        ## efficient
        mobilenetv2_0_fname = "mobilenetv2_0_eff_c100"
        mobilenetv2_2_fname = "mobilenetv2_2_eff_c100"
        mobilenetv2_5_fname = "mobilenetv2_5_eff_c100"
        mobilenetv2_10_fname = "mobilenetv2_10_eff_c100"
        mobilenetv2_30_fname = "mobilenetv2_30_eff_c100"
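
In this excerpt, prepare_filters with neuralscale=False appears to apply uniform scaling of the filter configuration by the given ratio. As a rough standalone illustration of that idea only (an assumption about the behaviour, not the repository's actual implementation):

# Uniform width scaling sketch: multiply every filter count by the ratio
# and round, keeping at least one channel per layer. Illustrative only.
filters = [[32], [16], [24, 24], [32, 32, 32], [64, 64, 64, 64],
           [96, 96, 96], [160, 160, 160], [320], [1280]]

def uniform_scale(config, ratio):
    return [[max(1, int(round(c * ratio))) for c in block] for block in config]

for ratio in [0.25, 0.5, 1.0, 2.0]:
    print(ratio, uniform_scale(filters, ratio))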
Example #4
elif args.model == "mobilenetv2":
    original_filters = [[32],[16],[24,24],[32,32,32],[64,64,64,64],[96,96,96],[160,160,160],[320],[1280]]

for idx, ratio in enumerate(ratios):
    if plot_growth:
        total_params = []
    print("Ratio: {}".format(ratio))
    for iteration in range(15):
        if args.model == "vgg":
            filters = VGG.prepare_filters(VGG, original_filters, ratio=ratio, neuralscale=True, num_classes=num_classes, prune_fname=args.prune_fname, descent_idx=iteration)
            filters = [cfg for cfg in list(filter(lambda a: a != 'M', filters))]
        elif args.model == "resnet18":
            filters = PreActResNet.prepare_filters(PreActResNet, original_filters, ratio=ratio, neuralscale=True, num_classes=num_classes, prune_fname=args.prune_fname, descent_idx=iteration)
            filters = [cfg for cfg in sum(filters,[])]
        elif args.model == "mobilenetv2":
            filters = MobileNetV2.prepare_filters(MobileNetV2, original_filters, ratio=ratio, neuralscale=True, num_classes=num_classes, prune_fname=args.prune_fname, descent_idx=iteration)
            filters = [cfg for cfg in sum(filters,[])][:-2]
        if plot_growth:
            total_params.append(filters)
                
    if plot_growth:
        # ax = sns.heatmap(np.array(total_params).T, linewidth=0.005, xticklabels=2, yticklabels=2, ax=axs[idx])
        ax = sns.heatmap(np.array(total_params).T, linewidth=0.005, xticklabels=2, ax=axs[int(idx/2)][idx%2])
        ax.set_title("Ratio={}".format(ratio))
        # if idx!=0:
        #     ax.tick_params(axis='y', which='both', width=0)
        if idx==1 or idx==3:
            ax.tick_params(axis='y', which='both', width=0, length=0)
        if idx==0 or idx==1:
            ax.tick_params(axis='x', which='both', width=0, length=0)
Example #5
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'Search for optimal configuration using architecture descent')
    parser.add_argument('--batch-size',
                        type=int,
                        default=256,
                        metavar='N',
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=300,
                        metavar='N',
                        help='number of epochs to train (default: 300)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=5e-4,
                        help='weight decay (default: 5e-4)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument(
        '--dataset',
        default="CIFAR10",
        type=str,
        help='dataset for experiment, choice: CIFAR10, CIFAR100',
        choices=["CIFAR10", "CIFAR100"])
    parser.add_argument(
        '--model',
        default="vgg",
        type=str,
        help='model selection, choices: vgg, mobilenetv2, resnet18',
        choices=["vgg", "mobilenetv2", "resnet18"])
    parser.add_argument('--save', default='model', help='model and prune file')
    parser.add_argument('--no_prune',
                        dest="prune_on",
                        action='store_false',
                        default=True,
                        help='Turn off pruning')
    parser.add_argument(
        '--warmup',
        type=int,
        default=10,
        help=
        'number of warm-up or fine-tuning epochs before pruning (default: 10)')
    parser.add_argument(
        '--morph',
        dest="morph",
        action='store_true',
        default=False,
        help='Prunes only 50 percent of neurons, for comparison with MorphNet')
    args = parser.parse_args()

    ##################
    ## Data loading ##
    ##################
    kwargs = {'num_workers': 1, 'pin_memory': True}
    if (args.dataset == "CIFAR10"):
        print("Using Cifar10 Dataset")
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/',
                                                train=True,
                                                download=True,
                                                transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        testset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/',
                                               train=False,
                                               download=True,
                                               transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  **kwargs)

    elif args.dataset == "CIFAR100":
        print("Using Cifar100 Dataset")
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/',
                                                 train=True,
                                                 download=True,
                                                 transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        testset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/',
                                                train=False,
                                                download=True,
                                                transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  **kwargs)
    elif args.dataset == "Imagenet":
        print("Using Imagenet Dataset")
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None

        kwargs = {'num_workers': 16}

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            pin_memory=True,
            **kwargs)

        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            **kwargs)
    else:
        print("Dataset does not exist! [CIFAR10, CIFAR100]")
        exit()

    if args.dataset == "CIFAR10":
        num_classes = 10
    elif args.dataset == "CIFAR100":
        num_classes = 100

    param_list = []
    if args.morph:
        total_iter = 1
    else:
        total_iter = 15
    for iteration in range(total_iter):
        print("Iteration: {}".format(iteration))
        args.lr = 0.1
        ###########
        ## Model ##
        ###########
        print("Setting Up Model...")
        if args.model == "vgg":
            model = vgg11(ratio=1,
                          neuralscale=True,
                          iteration=iteration,
                          num_classes=num_classes,
                          search=True)
        elif args.model == "resnet18":
            model = PreActResNet18(ratio=1,
                                   neuralscale=True,
                                   iteration=iteration,
                                   num_classes=num_classes,
                                   search=True,
                                   dataset=args.dataset)
        elif args.model == "mobilenetv2":
            model = MobileNetV2(ratio=1,
                                neuralscale=True,
                                iteration=iteration,
                                num_classes=num_classes,
                                search=True,
                                dataset=args.dataset)
        else:
            print(args.model, "model not supported")
            exit()
        print("{} set up.".format(args.model))

        # for model saving
        model_path = "saved_models"
        if not os.path.exists(model_path):
            os.makedirs(model_path)

        log_save_folder = "%s/%s" % (model_path, args.model)
        if not os.path.exists(log_save_folder):
            os.makedirs(log_save_folder)

        model_save_path = "%s/%s" % (log_save_folder,
                                     args.save) + "_checkpoint.t7"
        model_state_dict = model.state_dict()
        if args.save:
            print("Model will be saved to {}".format(model_save_path))
            save_checkpoint({'state_dict': model_state_dict},
                            False,
                            filename=model_save_path)
        else:
            print("Save path not defined. Model will not be saved.")

        # Assume cuda is available and uses single GPU
        model.cuda()
        cudnn.benchmark = True

        # define objective
        criterion = nn.CrossEntropyLoss()

        ######################
        ## Set up pruning   ##
        ######################
        # remove updates from gate layers, because we want them to be 0 or 1 constantly
        parameters_for_update = []
        parameters_for_update_named = []
        for name, m in model.named_parameters():
            if "gate" not in name:
                parameters_for_update.append(m)
                parameters_for_update_named.append((name, m))
            else:
                print("skipping parameter", name, "shape:", m.shape)

        total_size_params = sum(
            [np.prod(par.shape) for par in parameters_for_update])
        print("Total number of parameters, w/o usage of bn consts: ",
              total_size_params)
        optimizer = optim.SGD(parameters_for_update,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

        if args.prune_on:
            pruning_parameters_list = prepare_pruning_list(model)
            print("Total pruning layers:", len(pruning_parameters_list))
            if args.morph:
                prune_neurons = 0.5
            else:
                prune_neurons = 0.95
            pruning_engine = pruner(pruning_parameters_list,
                                    iteration=iteration,
                                    prune_fname=args.save,
                                    classes=num_classes,
                                    model=args.model,
                                    prune_neurons=prune_neurons)

        ###############
        ## Training  ##
        ###############
        for epoch in range(1, args.epochs + 1):
            print("Epoch: {}".format(epoch))
            adjust_learning_rate(args,
                                 optimizer,
                                 epoch,
                                 search=True,
                                 warmup=args.warmup)
            # train model
            if args.prune_on:
                train_acc, train_loss = train(args,
                                              model,
                                              train_loader,
                                              optimizer,
                                              epoch,
                                              criterion,
                                              pruning_engine=pruning_engine,
                                              num_classes=num_classes)
                if train_acc == -1 and train_loss == -1:
                    break
            else:
                train_acc, train_loss = train(args,
                                              model,
                                              train_loader,
                                              optimizer,
                                              epoch,
                                              criterion,
                                              num_classes=num_classes)

        # =========================
        # store searched parameters
        # =========================
        alpha, beta = fit_params(
            iteration=iteration,
            prune_fname=args.save,
            classes=num_classes,
            model=args.model)  # search for params of each layer
        param_list.append([alpha, beta])

    # =======================
    # save scaling parameters
    # =======================
    pickle_save = {
        "param": param_list,
    }

    pickle_out = open("prune_record/" + args.save + ".pk", "wb")
    pickle.dump(pickle_save, pickle_out)
    pickle_out.close()
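
The adjust_learning_rate helper called above is not included in these excerpts. A plausible step-decay version, consistent with the --lr-decay-every and --lr-decay-scalar arguments defined in Example #13, might look like the following sketch (a hypothetical stand-in, not the repository's function):

def adjust_learning_rate_sketch(args, optimizer, epoch):
    # Hypothetical step decay: scale the base lr by lr_decay_scalar once
    # every lr_decay_every epochs. Not the repository's actual schedule.
    lr = args.lr * (args.lr_decay_scalar ** (epoch // args.lr_decay_every))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return lr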
Example #6
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Efficient Filter Scaling of Convolutional Neural Network')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=300, metavar='N',
                        help='number of epochs to train (default: 300)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--weight_decay', type=float, default=5e-4,
                        help='weight decay (default: 5e-4)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--dataset', default="CIFAR10", type=str,
                            help='dataset for experiment, choice: CIFAR10, CIFAR100', choices= ["CIFAR10", "CIFAR100"])
    parser.add_argument('--model', default="resnet18", type=str,
                        help='model selection, choices: vgg, mobilenetv2, resnet18',
                        choices=["vgg", "mobilenetv2", "resnet18"])
    parser.add_argument('--save', default='model',
                        help='model file')
    parser.add_argument('--prune_fname', default='filename',
                        help='prune save file')
    parser.add_argument('--descent_idx', type=int, default=14,
                        help='Iteration for Architecture Descent')
    parser.add_argument('--morph', dest="morph", action='store_true', default=False,
                        help='Prunes only 50 percent of neurons, for comparison with MorphNet')
    parser.add_argument('--uniform', dest="uniform", action='store_true', default=False,
                        help='Use uniform scaling instead of NeuralScale')

    args = parser.parse_args()

    ##################
    ## Data loading ##
    ##################

    kwargs = {'num_workers': 1, 'pin_memory': True}
    if(args.dataset == "CIFAR10"):
        print("Using Cifar10 Dataset")
        normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                            std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/', train=True, 
                                                download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                                    shuffle=True, **kwargs)
        testset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/', train=False,
                                                download=True, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                                    shuffle=True, **kwargs)
    elif args.dataset == "CIFAR100":
        print("Using Cifar100 Dataset")
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/', train=True, 
                                                download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                                    shuffle=True, **kwargs)
        testset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/', train=False,
                                                download=True, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                                    shuffle=True, **kwargs)
    else:
        print("Dataset does not exist! [CIFAR10, CIFAR100]")
        exit()

    if args.dataset=='CIFAR10':
        num_classes = 10
    elif args.dataset=='CIFAR100':
        num_classes = 100

    ratios = np.arange(0.25,2.1,0.25) # [0.25, 0.5 , 0.75, 1, 1.25, 1.5 , 1.75, 2]
    pruned_filters = None
    neuralscale = True  # turn NeuralScale on by default
    if args.uniform:
        neuralscale = False
    if args.morph:
        neuralscale = False
        if args.model == "vgg":
            if args.dataset == "CIFAR10":
                pruned_filters = [64, 128, 249, 253, 268, 175, 87, 152] # VGG C10
            elif args.dataset == "CIFAR100":
                pruned_filters = [63, 125, 204, 215, 234, 174, 120, 241] # VGG C100
            else:
                print("{} not supported for {}".format(args.dataset, args.model))
                exit()
        elif args.model == "resnet18":
            if args.dataset == "CIFAR100":
                pruned_filters = [48, 46, 40, 41, 54, 91, 75, 73, 95, 157, 149, 149, 156, 232, 216, 140, 190] # resnet18 c100
            elif args.dataset == "tinyimagenet":
                print("Please use imagenet_ratio_swipe.py instead.")
                exit()
            else:
                print("{} not supported for {}".format(args.dataset, args.model))
                exit()
        elif args.model == "mobilenetv2":
            if args.dataset == "CIFAR100":
                pruned_filters = [28, 16, 24, 21, 30, 31, 26, 56, 50, 49, 46, 83, 70, 58, 120, 101, 68, 134, 397] # mobilenetv2 c100
            elif args.dataset == "tinyimagenet":
                print("Please use imagenet_ratio_swipe.py instead.")
                exit()
            else:
                print("{} not supported for {}".format(args.dataset, args.model))
                exit()
    for ratio in ratios:
        print("Current ratio: {}".format(ratio))
        ###########
        ## Model ##
        ###########
        print("Setting Up Model...")
        if args.model == "vgg":
            model = vgg11(ratio=ratio, neuralscale=neuralscale, num_classes=num_classes, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters)
        elif args.model == "resnet18":
            model = PreActResNet18(ratio=ratio, neuralscale=neuralscale, num_classes=num_classes, dataset=args.dataset, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters)
        elif args.model == "mobilenetv2":
            model = MobileNetV2(ratio=ratio, neuralscale=neuralscale, num_classes=num_classes, dataset=args.dataset, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters)
        else:
            print(args.model, "model not supported")
            exit()
        print("{} set up.".format(args.model))


        # for model saving
        model_path = "saved_models"
        if not os.path.exists(model_path):
            os.makedirs(model_path)

        log_save_folder = "%s/%s"%(model_path, args.model)
        if not os.path.exists(log_save_folder):
            os.makedirs(log_save_folder)

        model_save_path = "%s/%s"%(log_save_folder, args.save) + "_checkpoint.t7"
        model_state_dict = model.state_dict()
        if args.save:
            print("Model will be saved to {}".format(model_save_path))
            save_checkpoint({
                'state_dict': model_state_dict
            }, False, filename = model_save_path)
        else:
            print("Save path not defined. Model will not be saved.")

        # Assume cuda is available and uses single GPU
        model.cuda()
        cudnn.benchmark = True

        # define objective
        criterion = nn.CrossEntropyLoss()

        
        ######################
        ## Set up pruning   ##
        ######################
        # remove updates from gate layers, because we want them to be 0 or 1 constantly
        parameters_for_update = []
        parameters_for_update_named = []
        for name, m in model.named_parameters():
            if "gate" not in name:
                parameters_for_update.append(m)
                parameters_for_update_named.append((name, m))
            else:
                print("skipping parameter", name, "shape:", m.shape)

        total_size_params = sum([np.prod(par.shape) for par in parameters_for_update])
        print("Total number of parameters, w/o usage of bn consts: ", total_size_params)

        optimizer = optim.SGD(parameters_for_update, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

        ###############
        ## Training  ##
        ###############
        best_test_acc = 0
        train_acc_plt = []
        train_loss_plt = []
        test_acc_plt = []
        test_loss_plt = []
        epoch_plt = []
        for epoch in range(1, args.epochs + 1):
            adjust_learning_rate(args, optimizer, epoch)
            print("Epoch: {}".format(epoch))

            # train model
            train_acc, train_loss = train(args, model, train_loader, optimizer, epoch, criterion)

            # evaluate on validation set
            test_acc, test_loss = validate(args, test_loader, model, criterion, epoch, optimizer=optimizer)

            # remember best prec@1 and save checkpoint
            is_best = test_acc > best_test_acc
            best_test_acc = max(test_acc, best_test_acc)
            model_state_dict = model.state_dict()
            if args.save:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model_state_dict,
                    'best_prec1': test_acc,
                }, is_best, filename=model_save_path)


            train_acc_plt.append(train_acc)
            train_loss_plt.append(train_loss)
            test_acc_plt.append(test_acc)
            test_loss_plt.append(test_loss)
            epoch_plt.append(epoch)


        pickle_save = {
            "ratio": ratio,
            "train_acc": train_acc_plt,
            "train_loss": train_loss_plt,
            "test_acc": test_acc_plt,
            "test_loss": test_loss_plt,
        }
        plot_path = "saved_plots"
        if not os.path.exists(plot_path):
            os.makedirs(plot_path)

        log_save_folder = "%s/%s"%(plot_path, args.model)
        if not os.path.exists(log_save_folder):
            os.makedirs(log_save_folder)

        pickle_out = open("%s/%s_%s.pk"%(log_save_folder, args.save, int(ratio*100)),"wb")
        pickle.dump(pickle_save, pickle_out)
        pickle_out.close()
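
The per-ratio training curves pickled above can later be read back for plotting. A small sketch, assuming the saved_plots/<model>/<save>_<ratio*100>.pk layout produced by this script and matplotlib for visualization (the concrete file name below is a hypothetical placeholder):

import pickle
import matplotlib.pyplot as plt

# Hypothetical path following the naming scheme used above.
with open("saved_plots/mobilenetv2/model_100.pk", "rb") as f:
    record = pickle.load(f)

plt.plot(record["train_acc"], label="train acc")
plt.plot(record["test_acc"], label="test acc")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.title("ratio = {}".format(record["ratio"]))
plt.legend()
plt.show()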
Example #7
def load_model(model_file):
    model = MobileNetV2()
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to('cpu')
    return model
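
A minimal usage sketch for load_model: load a checkpoint on CPU and run a single forward pass in eval mode (the file name and 224x224 input size are assumptions, not taken from the original example):

import torch

model = load_model("mobilenet_v2_state_dict.pth")  # hypothetical checkpoint path
model.eval()
dummy = torch.randn(1, 3, 224, 224)  # NCHW input
with torch.no_grad():
    logits = model(dummy)
print(logits.shape)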
Example #8
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Efficient Filter Scaling of Convolutional Neural Network')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=150, metavar='N',
                        help='number of epochs to train (default: 150)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--weight_decay', type=float, default=5e-4,
                        help='weight decay (default: 5e-4)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--dataset', default="tinyimagenet", type=str,
                            help='dataset for experiment, choice: tinyimagenet', choices= ["tinyimagenet"])
    parser.add_argument('--data', metavar='DIR', default='/DATA/tiny-imagenet-200', help='path to imagenet dataset')
    parser.add_argument('--model', default="resnet18", type=str,
                        help='model selection, choices: mobilenetv2, resnet18',
                        choices=["mobilenetv2", "resnet18"])
    parser.add_argument('--save', default='model',
                        help='model file')
    parser.add_argument('--prune_fname', default='filename',
                        help='prune save file')
    parser.add_argument('--descent_idx', type=int, default=14,
                        help='Iteration for Architecture Descent')
    parser.add_argument('--morph', dest="morph", action='store_true', default=False,
                        help='Prunes only 50 percent of neurons, for comparison with MorphNet')
    parser.add_argument('--uniform', dest="uniform", action='store_true', default=False,
                        help='Use uniform scaling instead of NeuralScale')

    args = parser.parse_args()

    ##################
    ## Data loading ##
    ##################

    kwargs = {'num_workers': 1, 'pin_memory': True}
    if args.dataset == "tinyimagenet":
        print("Using tiny-Imagenet Dataset")
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'test')

        normalize = transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262])
        train_dataset = torchvision.datasets.ImageFolder(
                            traindir,
                            transforms.Compose([
                                transforms.RandomCrop(64, padding=4),
                                transforms.RandomRotation(20),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                normalize,
                            ]))

        train_sampler = None

        kwargs = {'num_workers': 16}

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
            sampler=train_sampler, pin_memory=True, **kwargs)

        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(valdir, transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.batch_size, shuffle=False, pin_memory=True, **kwargs)
    else:
        print("Dataset does not exist! [Imagenet]")
        exit()

    if args.dataset=='tinyimagenet':
        num_classes = 200
    else:
        print("Only tinyimagenet")
        exit()

    ratios = [0.25, 0.5, 0.75, 1.0]
    pruned_filters = None
    
    neuralscale = True  # turn on NeuralScale by default
    if args.uniform:
        neuralscale = False
    if args.morph:
        neuralscale = False
        if args.model == "resnet18":
            pruned_filters = [82,90,78,80,96,180,104,96,194,312,182,178,376,546,562,454,294] # resnet18 tinyimagenet
        elif args.model == "mobilenetv2":
            pruned_filters = [28, 16, 24, 24, 32, 32, 30, 64, 59, 50, 41, 96, 73, 48, 160, 69, 47, 155, 360] # mobilenetv2 tinyimagenet
        else:
            print("{} not supported.".format(args.model))
            exit()

    for ratio in ratios:
        print("Current ratio: {}".format(ratio))
        ###########
        ## Model ##
        ###########
        print("Setting Up Model...")
        if args.model == "resnet18":
            model = PreActResNet18(ratio=ratio, neuralscale=neuralscale, num_classes=num_classes, dataset=args.dataset, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters)
        elif args.model == "mobilenetv2":
            model = MobileNetV2(ratio=ratio, neuralscale=neuralscale, num_classes=num_classes, dataset=args.dataset, prune_fname=args.prune_fname, descent_idx=args.descent_idx, pruned_filters=pruned_filters)
        else:
            print(args.model, "model not supported [resnet18 mobilenetv2] only")
            exit()
        print("{} set up.".format(args.model))


        # for model saving
        model_path = "saved_models"
        if not os.path.exists(model_path):
            os.makedirs(model_path)

        log_save_folder = "%s/%s"%(model_path, args.model)
        if not os.path.exists(log_save_folder):
            os.makedirs(log_save_folder)

        model_save_path = "%s/%s"%(log_save_folder, args.save) + "_checkpoint.t7"
        model_state_dict = model.state_dict()
        if args.save:
            print("Model will be saved to {}".format(model_save_path))
            save_checkpoint({
                'state_dict': model_state_dict
            }, False, filename = model_save_path)
        else:
            print("Save path not defined. Model will not be saved.")

        # Assume cuda is available and uses single GPU
        model.cuda()
        cudnn.benchmark = True

        # define objective
        criterion = nn.CrossEntropyLoss()

        
        ######################
        ## Set up pruning   ##
        ######################
        # remove updates from gate layers, because we want them to be 0 or 1 constantly
        parameters_for_update = []
        parameters_for_update_named = []
        for name, m in model.named_parameters():
            if "gate" not in name:
                parameters_for_update.append(m)
                parameters_for_update_named.append((name, m))
            else:
                print("skipping parameter", name, "shape:", m.shape)

        total_size_params = sum([np.prod(par.shape) for par in parameters_for_update])
        print("Total number of parameters, w/o usage of bn consts: ", total_size_params)

        optimizer = optim.SGD(parameters_for_update, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

        ###############
        ## Training  ##
        ###############
        best_test_acc = 0
        train_acc_plt = []
        train_loss_plt = []
        test_acc_plt = []
        test_loss_plt = []
        epoch_plt = []
        for epoch in range(1, args.epochs + 1):
            adjust_learning_rate_imagenet(args, optimizer, epoch, search=False)
            print("Epoch: {}".format(epoch))

            # train model
            train_acc, train_loss = train(args, model, train_loader, optimizer, epoch, criterion)

            # evaluate on validation set
            test_acc, test_loss = validate(args, test_loader, model, criterion, epoch, optimizer=optimizer)

            # remember best prec@1 and save checkpoint
            is_best = test_acc > best_test_acc
            best_test_acc = max(test_acc, best_test_acc)
            model_state_dict = model.state_dict()
            if args.save:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model_state_dict,
                    'best_prec1': test_acc,
                }, is_best, filename=model_save_path)


            train_acc_plt.append(train_acc)
            train_loss_plt.append(train_loss)
            test_acc_plt.append(test_acc)
            test_loss_plt.append(test_loss)
            epoch_plt.append(epoch)


        pickle_save = {
            "ratio": ratio,
            "train_acc": train_acc_plt,
            "train_loss": train_loss_plt,
            "test_acc": test_acc_plt,
            "test_loss": test_loss_plt,
        }
        plot_path = "saved_plots"
        if not os.path.exists(plot_path):
            os.makedirs(plot_path)

        log_save_folder = "%s/%s"%(plot_path, args.model)
        if not os.path.exists(log_save_folder):
            os.makedirs(log_save_folder)

        pickle_out = open("%s/%s_%s.pk"%(log_save_folder, args.save, int(ratio*100)),"wb")
        pickle.dump(pickle_save, pickle_out)
        pickle_out.close()
Example #9
        model = PreActResNet18(ratio=ratio, neuralscale=True, num_classes=num_classes, prune_fname=fname, descent_idx=14)
        latency = compute_latency(model)
        params = compute_params_(model)
        param_high.append(params)
        latency_high.append(latency)
elif args.model == "mobilenetv2":
    filters = [[32],[16],[24,24],[32,32,32],[64,64,64,64],[96,96,96],[160,160,160],[320],[1280]]
    if args.dataset == "CIFAR100":
        filters_prune = [28, 16, 24, 21, 30, 31, 26, 56, 50, 49, 46, 83, 70, 58, 120, 101, 68, 134, 397] 
        ratios = np.arange(0.25,2.1,0.25) # [0.25, 0.5 , 0.75, 1, 1.25, 1.5 , 1.75, 2]
    elif args.dataset == "tinyimagenet":
        filters_prune = [28, 16, 24, 24, 32, 32, 30, 64, 59, 50, 41, 96, 73, 48, 160, 69, 47, 155, 360] # mobilenetv2 tinyimagenet
        ratios = [0.25,0.5,0.75,1.0]
    for ratio in ratios:
        # convcut
        new_config = MobileNetV2.prepare_filters(MobileNetV2, filters, ratio=ratio, neuralscale=False, num_classes=num_classes)
        model = MobileNetV2(filters=new_config, num_classes=num_classes, dataset=args.dataset, convcut=True)
        latency = compute_latency(model)
        params = compute_params_(model)
        param_convcut.append(params)
        latency_convcut.append(latency)
        # uniform
        new_config = MobileNetV2.prepare_filters(MobileNetV2, filters, ratio=ratio, neuralscale=False, num_classes=num_classes)
        model = MobileNetV2(filters=new_config, num_classes=num_classes, dataset=args.dataset)
        latency = compute_latency(model)
        params = compute_params_(model)
        param_uni.append(params)
        latency_uni.append(latency)
        # pruned
        new_config = MobileNetV2.prepare_filters(MobileNetV2, filters, ratio=ratio, neuralscale=False, num_classes=num_classes, pruned_filters=filters_prune)
        model = MobileNetV2(filters=new_config, num_classes=num_classes, dataset=args.dataset)
Example #10
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='Fine-tune on pruned architecture')
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs',
                        type=int,
                        default=300,
                        metavar='N',
                        help='number of epochs to train (default: 300)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=5e-4,
                        help='weight decay (default: 5e-4)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument(
        '--dataset',
        default="CIFAR10",
        type=str,
        help='dataset for experiment, choice: CIFAR10, CIFAR100, tinyimagenet',
        choices=["CIFAR10", "CIFAR100", "tinyimagenet"])
    parser.add_argument('--data',
                        metavar='DIR',
                        default='/DATA/tiny-imagenet-200',
                        help='path to imagenet dataset')
    parser.add_argument(
        '--model',
        default="resnet18",
        type=str,
        help='model selection, choices: vgg, mobilenetv2, resnet18',
        choices=["vgg", "mobilenetv2", "resnet18"])
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--save', default='model', help='model file')
    parser.add_argument('--prune_fname',
                        default='filename',
                        help='prune save file')
    parser.add_argument('--descent_idx',
                        type=int,
                        default=14,
                        help='Iteration for Architecture Descent')
    parser.add_argument('--method',
                        type=int,
                        default=0,
                        help='sets pruning method')

    args = parser.parse_args()

    ##################
    ## Data loading ##
    ##################

    kwargs = {'num_workers': 1, 'pin_memory': True}
    if (args.dataset == "CIFAR10"):
        print("Using Cifar10 Dataset")
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/',
                                                train=True,
                                                download=True,
                                                transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        testset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/',
                                               train=False,
                                               download=True,
                                               transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  **kwargs)
    elif args.dataset == "CIFAR100":
        print("Using Cifar100 Dataset")
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/',
                                                 train=True,
                                                 download=True,
                                                 transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        testset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/',
                                                train=False,
                                                download=True,
                                                transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  **kwargs)
    elif args.dataset == "tinyimagenet":
        print("Using tiny-Imagenet Dataset")
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'test')

        normalize = transforms.Normalize([0.4802, 0.4481, 0.3975],
                                         [0.2302, 0.2265, 0.2262])
        train_dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomCrop(64, padding=4),
                transforms.RandomRotation(20),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None

        kwargs = {'num_workers': 16}

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            pin_memory=True,
            **kwargs)

        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                valdir, transforms.Compose([
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            **kwargs)
    else:
        print("Dataset does not exist! [CIFAR10, MNIST, tinyimagenet]")
        exit()

    if args.dataset == 'CIFAR10':
        num_classes = 10
        args.epochs = 40
    elif args.dataset == 'CIFAR100':
        num_classes = 100
        args.epochs = 40
    elif args.dataset == 'tinyimagenet':
        num_classes = 200
        args.epochs = 20

    # ratios = [0.25, 0.75]
    ratios = [0.75]
    pruned_filters = None
    for ratio in ratios:
        print("Current ratio: {}".format(ratio))
        ###########
        ## Model ##
        ###########
        print("Setting Up Model...")
        if args.model == "vgg":
            model = vgg11(ratio=1,
                          neuralscale=False,
                          num_classes=num_classes,
                          prune_fname=args.prune_fname,
                          descent_idx=args.descent_idx,
                          pruned_filters=pruned_filters,
                          search=True)
        elif args.model == "resnet18":
            model = PreActResNet18(ratio=1,
                                   neuralscale=False,
                                   num_classes=num_classes,
                                   dataset=args.dataset,
                                   prune_fname=args.prune_fname,
                                   descent_idx=args.descent_idx,
                                   pruned_filters=pruned_filters,
                                   search=True)
        elif args.model == "mobilenetv2":
            model = MobileNetV2(ratio=1,
                                neuralscale=False,
                                num_classes=num_classes,
                                dataset=args.dataset,
                                prune_fname=args.prune_fname,
                                descent_idx=args.descent_idx,
                                pruned_filters=pruned_filters,
                                search=True)
        else:
            print(args.model, "model not supported")
            exit()
        print("{} set up.".format(args.model))

        # optionally resume from a checkpoint
        if args.resume:
            model_path = "saved_models"
            if not os.path.exists(model_path):
                os.makedirs(model_path)
            log_save_folder = "%s/%s" % (model_path, args.model)
            model_resume_path = "%s/%s" % (log_save_folder,
                                           args.resume) + "_best_model.t7"
            if os.path.isfile(model_resume_path):
                print("=> loading checkpoint '{}'".format(model_resume_path))
                checkpoint = torch.load(model_resume_path)
                best_acc1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    model_resume_path, checkpoint['epoch']))
            else:
                print(
                    "=> no checkpoint found at '{}'".format(model_resume_path))

        # Assume cuda is available and uses single GPU
        model.cuda()
        cudnn.benchmark = True

        # define objective
        criterion = nn.CrossEntropyLoss()

        ######################
        ## Set up pruning   ##
        ######################
        # remove updates from gate layers, because we want them to be 0 or 1 constantly
        parameters_for_update = []
        parameters_for_update_named = []
        for name, m in model.named_parameters():
            if "gate" not in name:
                parameters_for_update.append(m)
                parameters_for_update_named.append((name, m))
            else:
                print("skipping parameter", name, "shape:", m.shape)

        total_size_params = sum(
            [np.prod(par.shape) for par in parameters_for_update])
        print("Total number of parameters, w/o usage of bn consts: ",
              total_size_params)

        optimizer = optim.SGD(parameters_for_update,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

        pruner = compare_pruner(model, method=args.method, size=ratio)
        pruner.prune_neurons(optimizer)

        ###############
        ## Training  ##
        ###############
        best_test_acc = 0
        train_acc_plt = []
        train_loss_plt = []
        test_acc_plt = []
        test_loss_plt = []
        epoch_plt = []
        for epoch in range(1, args.epochs + 1):
            print("Epoch: {}".format(epoch))

            # train model
            train_acc, train_loss = train(args, model, train_loader, optimizer,
                                          epoch, criterion)

            # evaluate on validation set
            test_acc, test_loss = validate(args,
                                           test_loader,
                                           model,
                                           criterion,
                                           epoch,
                                           optimizer=optimizer)

            # remember best prec@1 and save checkpoint
            is_best = test_acc > best_test_acc
            best_test_acc = max(test_acc, best_test_acc)
            print(best_test_acc)

            train_acc_plt.append(train_acc)
            train_loss_plt.append(train_loss)
            test_acc_plt.append(test_acc)
            test_loss_plt.append(test_loss)
            epoch_plt.append(epoch)

        pickle_save = {
            "ratio": ratio,
            "train_acc": train_acc_plt,
            "train_loss": train_loss_plt,
            "test_acc": test_acc_plt,
            "test_loss": test_loss_plt,
        }
        plot_path = "saved_plots"
        if not os.path.exists(plot_path):
            os.makedirs(plot_path)

        log_save_folder = "%s/%s" % (plot_path, args.model)
        if not os.path.exists(log_save_folder):
            os.makedirs(log_save_folder)

        pickle_out = open(
            "%s/%s_%s.pk" % (log_save_folder, args.save, int(ratio * 100)),
            "wb")
        pickle.dump(pickle_save, pickle_out)
        pickle_out.close()
Example #11
                      descent_idx=args.descent_idx,
                      pruned_filters=pruned_filters)
    elif args.model == "resnet18":
        model = PreActResNet18(ratio=ratio,
                               efficient_scale=False,
                               num_classes=num_classes,
                               dataset=args.dataset,
                               prune_fname=args.prune_fname,
                               descent_idx=args.descent_idx,
                               pruned_filters=pruned_filters,
                               search=True)
    elif args.model == "mobilenetv2":
        model = MobileNetV2(ratio=ratio,
                            efficient_scale=False,
                            num_classes=num_classes,
                            dataset=args.dataset,
                            prune_fname=args.prune_fname,
                            descent_idx=args.descent_idx,
                            pruned_filters=pruned_filters,
                            search=True)
    else:
        print(args.model, "model not supported")
        exit()
    print("{} set up.".format(args.model))

    # for model saving
    model_path = "saved_models"
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    log_save_folder = "%s/%s" % (model_path, args.model)
    if not os.path.exists(log_save_folder):
Example #12
    new_config = PreActResNet.prepare_filters(PreActResNet,
                                              filters,
                                              ratio=0.75,
                                              neuralscale=False,
                                              num_classes=num_classes)
    model = PreActResNet18(filters=new_config,
                           num_classes=num_classes,
                           dataset=args.dataset)
    latency = compute_latency(model)
    params = compute_params_(model)
elif args.model == "mobilenetv2":
    filters = [[32], [16], [24, 24], [32, 32, 32], [64, 64, 64, 64],
               [96, 96, 96], [160, 160, 160], [320], [1280]]
    new_config = MobileNetV2.prepare_filters(MobileNetV2,
                                             filters,
                                             ratio=0.75,
                                             neuralscale=False,
                                             num_classes=num_classes)
    model = MobileNetV2(filters=new_config,
                        num_classes=num_classes,
                        dataset=args.dataset)
    latency = compute_latency(model)
    params = compute_params_(model)

ratio = 0.75

uni_test_loss = []
uni_test_acc = []

uni_test_loss_tmp = []
uni_test_acc_tmp = []
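# compute_params_ and compute_latency are called above but not defined in this
# snippet. The functions below are minimal sketches of what they might look like
# (assumptions, not the original implementations); input_size and reps are
# illustrative defaults.
import time
import torch

def compute_params_(model):
    # total number of trainable parameters
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def compute_latency(model, input_size=(1, 3, 32, 32), reps=50):
    # rough forward-pass latency in seconds, averaged over a few runs
    model.eval()
    x = torch.randn(*input_size)
    with torch.no_grad():
        model(x)  # warm-up
        start = time.time()
        for _ in range(reps):
            model(x)
    return (time.time() - start) / reps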
Exemplo n.º 13
0
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'Train a base network under ratio=1 (default configuration) for pruning'
    )
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='input batch size for testing (default: 128)')
    parser.add_argument('--epochs',
                        type=int,
                        default=300,
                        metavar='N',
                        help='number of epochs to train (default: 300)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        metavar='LR',
                        help='learning rate (default: 0.1)')
    # parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
    #                     help='learning rate (default: 0.01)')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=5e-4,
                        help='weight decay (default: 5e-4)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--lr-decay-every',
                        type=int,
                        default=100,
                        help='learning rate decay by 10 every X epochs')
    parser.add_argument('--lr-decay-scalar',
                        type=float,
                        default=0.1,
                        help='multiplicative factor applied at each learning '
                             'rate decay step (default: 0.1)')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--dataset',
        default="CIFAR10",
        type=str,
        help='dataset for experiment, choices: MNIST, CIFAR10, CIFAR100, '
             'Imagenet, tinyimagenet',
        choices=["MNIST", "CIFAR10", "CIFAR100", "Imagenet", "tinyimagenet"])
    parser.add_argument('--data',
                        metavar='DIR',
                        default='/DATA/tiny-imagenet-200',
                        help='path to tinyimagenet dataset')
    parser.add_argument(
        '--model',
        default="resnet18",
        type=str,
        help='model selection, choices: vgg, mobilenetv2, resnet18',
        choices=["vgg", "mobilenetv2", "resnet18"])
    parser.add_argument('--r',
                        dest="resume",
                        action='store_true',
                        default=False,
                        help='Resume from checkpoint')
    parser.add_argument('--save', default='model', help='model file')
    parser.add_argument('--prune_fname',
                        default='filename',
                        help='prune save file')
    parser.add_argument('--descent_idx',
                        type=int,
                        default=14,
                        help='Iteration for Architecture Descent')
    parser.add_argument('--s',
                        type=float,
                        default=0.0001,
                        help='scale sparse rate (default: 0.0001)')

    args = parser.parse_args()
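    # Illustrative invocation (the script name and data paths are assumptions,
    # not taken from this snippet):
    #   python train_base.py --dataset CIFAR100 --model mobilenetv2 \
    #          --batch-size 128 --lr 0.1 --epochs 300 --save mobilenetv2_base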

    ##################
    ## Data loading ##
    ##################

    kwargs = {'num_workers': 1, 'pin_memory': True}
    if args.dataset == "CIFAR10":
        print("Using Cifar10 Dataset")
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/',
                                                train=True,
                                                download=True,
                                                transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        testset = torchvision.datasets.CIFAR10(root='/DATA/data_cifar10/',
                                               train=False,
                                               download=True,
                                               transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  **kwargs)
    elif args.dataset == "CIFAR100":
        print("Using Cifar100 Dataset")
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/',
                                                 train=True,
                                                 download=True,
                                                 transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        testset = torchvision.datasets.CIFAR100(root='/DATA/data_cifar100/',
                                                train=False,
                                                download=True,
                                                transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  **kwargs)
    elif args.dataset == "Imagenet":
        print("Using Imagenet Dataset")
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None

        kwargs = {'num_workers': 16}

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            pin_memory=True,
            **kwargs)

        # 'use_test_as_train' is not defined in the parser above; default to False
        if getattr(args, 'use_test_as_train', False):
            train_loader = torch.utils.data.DataLoader(
                torchvision.datasets.ImageFolder(
                    valdir,
                    transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize,
                    ])),
                batch_size=args.batch_size,
                shuffle=(train_sampler is None),
                **kwargs)

        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            **kwargs)
    elif args.dataset == "tinyimagenet":
        print("Using tiny-Imagenet Dataset")
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'test')

        normalize = transforms.Normalize([0.4802, 0.4481, 0.3975],
                                         [0.2302, 0.2265, 0.2262])
        train_dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomCrop(64, padding=4),
                transforms.RandomRotation(20),
                # transforms.RandomResizedCrop(64),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None

        kwargs = {'num_workers': 16}

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            pin_memory=True,
            **kwargs)

        test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                valdir, transforms.Compose([
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            **kwargs)
    else:
        print("Dataset does not exist! [CIFAR10, MNIST, tinyimagenet]")
        exit()

    if args.dataset == 'CIFAR10':
        num_classes = 10
    elif args.dataset == 'CIFAR100':
        num_classes = 100
    elif args.dataset == 'Imagenet':
        num_classes = 1000
    elif args.dataset == 'tinyimagenet':
        num_classes = 200

    ratio = 1
    pruned_filters = None
    ###########
    ## Model ##
    ###########
    print("Setting Up Model...")
    if args.model == "vgg":
        model = vgg11(ratio=ratio,
                      neuralscale=False,
                      num_classes=num_classes,
                      prune_fname=args.prune_fname,
                      descent_idx=args.descent_idx,
                      pruned_filters=pruned_filters)
    elif args.model == "resnet18":
        model = PreActResNet18(ratio=ratio,
                               neuralscale=False,
                               num_classes=num_classes,
                               dataset=args.dataset,
                               prune_fname=args.prune_fname,
                               descent_idx=args.descent_idx,
                               pruned_filters=pruned_filters,
                               search=True)
    elif args.model == "mobilenetv2":
        model = MobileNetV2(ratio=ratio,
                            neuralscale=False,
                            num_classes=num_classes,
                            dataset=args.dataset,
                            prune_fname=args.prune_fname,
                            descent_idx=args.descent_idx,
                            pruned_filters=pruned_filters,
                            search=True)
    else:
        print(args.model, "model not supported")
        exit()
    print("{} set up.".format(args.model))

    # for model saving
    model_path = "saved_models"
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    log_save_folder = "%s/%s" % (model_path, args.model)
    if not os.path.exists(log_save_folder):
        os.makedirs(log_save_folder)

    model_save_path = "%s/%s" % (log_save_folder, args.save) + "_checkpoint.t7"
    model_state_dict = model.state_dict()
    if args.save:
        print("Model will be saved to {}".format(model_save_path))
        save_checkpoint({'state_dict': model_state_dict},
                        False,
                        filename=model_save_path)
    else:
        print("Save path not defined. Model will not be saved.")

    # Assume cuda is available and uses single GPU
    model.cuda()
    cudnn.benchmark = True

    # define objective
    criterion = nn.CrossEntropyLoss()

    ######################
    ## Set up pruning   ##
    ######################
    # remove updates from gate layers, because we want them to be 0 or 1 constantly
    parameters_for_update = []
    parameters_for_update_named = []
    for name, m in model.named_parameters():
        if "gate" not in name:
            parameters_for_update.append(m)
            parameters_for_update_named.append((name, m))
        else:
            print("skipping parameter", name, "shape:", m.shape)

    total_size_params = sum(
        [np.prod(par.shape) for par in parameters_for_update])
    print("Total number of parameters, w/o usage of bn consts: ",
          total_size_params)

    optimizer = optim.SGD(parameters_for_update,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
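    # With the defaults above (lr=0.1, lr-decay-every=100, lr-decay-scalar=0.1),
    # adjust_learning_rate below presumably applies a step schedule that drops the
    # learning rate by 10x roughly every 100 epochs (~0.1, ~0.01, ~0.001 over 300 epochs).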

    ###############
    ## Training  ##
    ###############
    best_test_acc = 0
    train_acc_plt = []
    train_loss_plt = []
    test_acc_plt = []
    test_loss_plt = []
    epoch_plt = []
    if num_classes == 200:  # tinyimagenet
        args.epochs = 150
    else:
        args.epochs = 300
    for epoch in range(1, args.epochs + 1):
        if num_classes == 200:  # tinyimagenet
            adjust_learning_rate_imagenet(args, optimizer, epoch, search=False)
        else:
            adjust_learning_rate(args, optimizer, epoch)

        print("Epoch: {}".format(epoch))

        # train model
        train_acc, train_loss = train(args, model, train_loader, optimizer,
                                      epoch, criterion)

        # evaluate on validation set
        test_acc, test_loss = validate(args,
                                       test_loader,
                                       model,
                                       criterion,
                                       epoch,
                                       optimizer=optimizer)

        # remember best prec@1 and save checkpoint
        is_best = test_acc > best_test_acc
        best_test_acc = max(test_acc, best_test_acc)
        model_state_dict = model.state_dict()
        if args.save:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model_state_dict,
                    'best_prec1': test_acc,
                },
                is_best,
                filename=model_save_path)

        train_acc_plt.append(train_acc)
        train_loss_plt.append(train_loss)
        test_acc_plt.append(test_acc)
        test_loss_plt.append(test_loss)
        epoch_plt.append(epoch)

    pickle_save = {
        "ratio": ratio,
        "train_acc": train_acc_plt,
        "train_loss": train_loss_plt,
        "test_acc": test_acc_plt,
        "test_loss": test_loss_plt,
    }
    plot_path = "saved_plots"
    if not os.path.exists(plot_path):
        os.makedirs(plot_path)

    log_save_folder = "%s/%s" % (plot_path, args.model)
    if not os.path.exists(log_save_folder):
        os.makedirs(log_save_folder)

    with open("%s/%s_%s.pk" % (log_save_folder, args.save, int(ratio * 100)),
              "wb") as pickle_out:
        pickle.dump(pickle_save, pickle_out)
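# adjust_learning_rate, adjust_learning_rate_imagenet, save_checkpoint, train and
# validate are not shown in this snippet and are assumed to be imported from
# elsewhere in the repository. The two sketches below are assumptions about the
# smallest helpers, based only on the argument names and call sites above; they
# are not the original implementations.
import shutil
import torch

def adjust_learning_rate(args, optimizer, epoch):
    """Step decay: multiply the LR by lr_decay_scalar every lr_decay_every epochs."""
    lr = args.lr * (args.lr_decay_scalar ** (epoch // args.lr_decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def save_checkpoint(state, is_best, filename='checkpoint.t7'):
    """Save the latest state and keep a separate copy of the best-performing one."""
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, filename.replace('.t7', '_best.t7'))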