Example #1
def get_model(arch_type):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if arch_type == "fc1":
        model = fc1.fc1().to(device)
    elif arch_type == "lenet5":
        model = LeNet5.LeNet5().to(device)
    elif arch_type == "alexnet":
        model = AlexNet.AlexNet().to(device)
    elif arch_type == "vgg16":
        model = vgg.vgg16().to(device)
    elif arch_type == "resnet18":
        model = resnet.resnet18().to(device)
    elif arch_type == "densenet121":
        model = densenet.densenet121().to(device)
    # To add an extra model, add a branch here
    else:
        print("\nWrong Model choice\n")
        exit()
    return model
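
A minimal usage sketch for get_model. The argparse flag name --arch_type mirrors the later examples; treat the driver below as an assumption, not the repository's actual CLI:

# Hypothetical driver for get_model (names assumed, not taken from the source)
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--arch_type", default="lenet5",
                    choices=["fc1", "lenet5", "alexnet", "vgg16",
                             "resnet18", "densenet121"])
args = parser.parse_args()

model = get_model(args.arch_type)  # moves the chosen architecture to GPU if available
print(model)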
Example #2
def main(args, ITE=0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    reinit = (args.prune_type == "reinit")
    if args.save_dir:
        utils.checkdir(
            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/"
        )
        utils.checkdir(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/"
        )
        utils.checkdir(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/"
        )
    else:
        utils.checkdir(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
        utils.checkdir(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/")
        utils.checkdir(f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/")

    # Data Loader
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    if args.dataset == "mnist":
        traindataset = datasets.MNIST('../data',
                                      train=True,
                                      download=True,
                                      transform=transform)
        testdataset = datasets.MNIST('../data',
                                     train=False,
                                     transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet

    elif args.dataset == "cifar10":
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        traindataset = datasets.CIFAR10('../data',
                                        train=True,
                                        download=True,
                                        transform=transform_train)
        testdataset = datasets.CIFAR10('../data',
                                       train=False,
                                       transform=transform_test)
        from archs.cifar10 import AlexNet, LeNet5, fc1, vgg, resnet, densenet

    elif args.dataset == "fashionmnist":
        traindataset = datasets.FashionMNIST('../data',
                                             train=True,
                                             download=True,
                                             transform=transform)
        testdataset = datasets.FashionMNIST('../data',
                                            train=False,
                                            transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet

    elif args.dataset == "cifar100":
        # NOTE: this reuses the single-channel MNIST normalization stats defined above
        traindataset = datasets.CIFAR100('../data',
                                         train=True,
                                         download=True,
                                         transform=transform)
        testdataset = datasets.CIFAR100('../data',
                                        train=False,
                                        transform=transform)
        from archs.cifar100 import AlexNet, fc1, LeNet5, vgg, resnet

    # To add extra datasets, add a branch here

    else:
        print("\nWrong Dataset choice \n")
        exit()

    if args.dataset == "cifar10":
        #trainsampler = torch.utils.data.RandomSampler(traindataset, replacement=True, num_samples=45000)  # 45K train dataset
        #train_loader = torch.utils.data.DataLoader(traindataset, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False, sampler=trainsampler)
        train_loader = torch.utils.data.DataLoader(traindataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=4)
    else:
        train_loader = torch.utils.data.DataLoader(traindataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   drop_last=False)
    #train_loader = cycle(train_loader)
    test_loader = torch.utils.data.DataLoader(testdataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=4)

    # Initialize the Hessian dataloader with a single training batch (default batch_num 1)
    for inputs, labels in train_loader:
        hessian_dataloader = (inputs, labels)
        break

    # Importing Network Architecture
    global model
    if args.arch_type == "fc1":
        model = fc1.fc1().to(device)
    elif args.arch_type == "lenet5":
        model = LeNet5.LeNet5().to(device)
    elif args.arch_type == "alexnet":
        model = AlexNet.AlexNet().to(device)
    elif args.arch_type == "vgg16":
        model = vgg.vgg16().to(device)
    elif args.arch_type == "resnet18":
        model = resnet.resnet18().to(device)
    elif args.arch_type == "densenet121":
        model = densenet.densenet121().to(device)
    # To add an extra model, add a branch here
    else:
        print("\nWrong Model choice\n")
        exit()

    model = nn.DataParallel(model)
    # Weight Initialization
    model.apply(weight_init)

    # Copying and Saving Initial State
    initial_state_dict = copy.deepcopy(model.state_dict())
    if args.save_dir:
        torch.save(
            model.state_dict(),
            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/initial_state_dict_{args.prune_type}.pth"
        )
    else:
        torch.save(
            model.state_dict(),
            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/initial_state_dict_{args.prune_type}.pth"
        )

    # global total_params
    total_params = 0
    # Layer Looper
    for name, param in model.named_parameters():
        print(name, param.size())
        total_params += param.numel()

    # Making Initial Mask
    make_mask(model, total_params)

    # Optimizer and Loss
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=1e-4)
    # Warm-up schedule; scheduler_warmup is chained with scheduler_steplr
    scheduler_steplr = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                            milestones=[0, 15],
                                                            gamma=0.1,
                                                            last_epoch=-1)
    if args.warmup:
        scheduler_warmup = GradualWarmupScheduler(
            optimizer,
            multiplier=1,
            total_epoch=50,
            after_scheduler=scheduler_steplr)  # 20K=(idx)56, 35K=70
    criterion = nn.CrossEntropyLoss()  # Default was F.nll_loss; unclear why train and test used different losses

    # Pruning
    # NOTE First Pruning Iteration is of No Compression
    best_accuracy = 0
    ITERATION = args.prune_iterations
    comp = np.zeros(ITERATION, float)
    bestacc = np.zeros(ITERATION, float)
    step = 0
    all_loss = np.zeros(args.end_iter, float)
    all_accuracy = np.zeros(args.end_iter, float)

    for _ite in range(args.start_iter, ITERATION):
        if _ite != 0:
            # NOTE: `resample` is assumed to be defined globally (e.g., parsed from the CLI)
            prune_by_percentile(args.prune_percent,
                                resample=resample,
                                reinit=reinit,
                                total_params=total_params,
                                hessian_aware=args.hessian,
                                criterion=criterion,
                                dataloader=hessian_dataloader,
                                cuda=torch.cuda.is_available())
            if reinit:
                model.apply(weight_init)
                # Alternative: re-instantiate the chosen architecture via the same
                # if/elif chain used above, instead of re-applying weight_init.
                step = 0
                for name, param in model.named_parameters():
                    if 'weight' in name:
                        param_frac = param.numel() / total_params
                        if param_frac > 0.01:
                            weight_dev = param.device
                            param.data = torch.from_numpy(
                                param.data.cpu().numpy() *
                                mask[step]).to(weight_dev)
                            step = step + 1
                step = 0
            else:
                original_initialization(mask, initial_state_dict, total_params)
            # optimizer = torch.optim.SGD([{'params': model.parameters(), 'initial_lr': 0.03}], lr=args.lr, momentum=0.9, weight_decay=1e-4)
            # scheduler_steplr = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[0, 14], gamma=0.1, last_epoch=-1)
            # scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=56, after_scheduler=scheduler_steplr)  # 20K=(idx)56, 35K=70
        print(f"\n--- Pruning Level [{ITE}:{_ite}/{ITERATION}]: ---")

        # Optimizer and Loss
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
        # Warm-up schedule; scheduler_warmup is chained with scheduler_steplr
        scheduler_steplr = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[0, 15], gamma=0.1, last_epoch=-1)
        if args.warmup:
            scheduler_warmup = GradualWarmupScheduler(
                optimizer,
                multiplier=1,
                total_epoch=50,
                after_scheduler=scheduler_steplr)  # 20K=(idx)56, 35K=70

        # Print the table of Nonzeros in each layer
        comp1 = utils.print_nonzeros(model)
        comp[_ite] = comp1
        pbar = tqdm(range(args.end_iter))  # process bar

        for iter_ in pbar:

            # Frequency for Testing
            if iter_ % args.valid_freq == 0:
                accuracy = test(model, test_loader, criterion)

                # Save Weights for each _ite
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    if args.save_dir:
                        torch.save(
                            model.state_dict(),
                            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/{_ite}_model_{args.prune_type}.pth"
                        )
                    else:
                        # torch.save(model,f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth")
                        torch.save(
                            model.state_dict(),
                            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth"
                        )

            # Training
            loss = train(model, train_loader, optimizer, criterion,
                         total_params)
            all_loss[iter_] = loss
            all_accuracy[iter_] = accuracy

            # warm up
            if args.warmup:
                scheduler_warmup.step()
            _lr = optimizer.param_groups[0]['lr']

            # Save the model during training
            if args.save_freq > 0 and iter_ % args.save_freq == 0:
                if args.save_dir:
                    torch.save(
                        model.state_dict(),
                        f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/{_ite}_model_{args.prune_type}_epoch{iter_}.pth"
                    )
                else:
                    torch.save(
                        model.state_dict(),
                        f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}_epoch{iter_}.pth"
                    )

            # Frequency for Printing Accuracy and Loss
            if iter_ % args.print_freq == 0:
                pbar.set_description(
                    f'Train Epoch: {iter_}/{args.end_iter} Loss: {loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy:.2f}% Learning Rate: {_lr:.6f}'
                )

        writer.add_scalar('Accuracy/test', best_accuracy, comp1)
        bestacc[_ite] = best_accuracy

        # Plotting Loss (Training), Accuracy (Testing), Iteration Curve
        #NOTE Loss is computed every iteration, while accuracy is computed only every {args.valid_freq} iterations, so the saved accuracy stays constant between evaluations.
        #NOTE The loss is normalized to [0,100] so it can be plotted alongside accuracy.
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 100 * (all_loss - np.min(all_loss)) /
                 np.ptp(all_loss).astype(float),
                 c="blue",
                 label="Loss")
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 all_accuracy,
                 c="red",
                 label="Accuracy")
        plt.title(
            f"Loss Vs Accuracy Vs Iterations ({args.dataset},{args.arch_type})"
        )
        plt.xlabel("Iterations")
        plt.ylabel("Loss and Accuracy")
        plt.legend()
        plt.grid(color="gray")
        if args.save_dir:
            plt.savefig(
                f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_LossVsAccuracy_{comp1}.png",
                dpi=1200)
        else:
            plt.savefig(
                f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_LossVsAccuracy_{comp1}.png",
                dpi=1200)
        plt.close()

        # Dump Plot values
        if args.save_dir:
            all_loss.dump(
                f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_all_loss_{comp1}.dat"
            )
            all_accuracy.dump(
                f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_all_accuracy_{comp1}.dat"
            )
        else:
            all_loss.dump(
                f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_loss_{comp1}.dat"
            )
            all_accuracy.dump(
                f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_accuracy_{comp1}.dat"
            )

        # Dumping mask
        if args.save_dir:
            with open(
                    f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_mask_{comp1}.pkl",
                    'wb') as fp:
                pickle.dump(mask, fp)
        else:
            with open(
                    f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_mask_{comp1}.pkl",
                    'wb') as fp:
                pickle.dump(mask, fp)

        # Making variables into 0
        best_accuracy = 0
        all_loss = np.zeros(args.end_iter, float)
        all_accuracy = np.zeros(args.end_iter, float)

    # Dumping Values for Plotting
    if args.save_dir:
        comp.dump(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_compression.dat"
        )
        bestacc.dump(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_bestaccuracy.dat"
        )
    else:
        comp.dump(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_compression.dat"
        )
        bestacc.dump(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_bestaccuracy.dat"
        )
    # Plotting
    a = np.arange(args.prune_iterations)
    plt.plot(a, bestacc, c="blue", label="Winning tickets")
    plt.title(
        f"Test Accuracy vs Unpruned Weights Percentage ({args.dataset},{args.arch_type})"
    )
    plt.xlabel("Unpruned Weights Percentage")
    plt.ylabel("test accuracy")
    plt.xticks(a, comp, rotation="vertical")
    plt.ylim(0, 100)
    plt.legend()
    plt.grid(color="gray")
    if args.save_dir:
        plt.savefig(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_AccuracyVsWeights.png",
            dpi=1200)
    else:
        plt.savefig(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_AccuracyVsWeights.png",
            dpi=1200)
    plt.close()
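
Example #2 calls several helpers that live elsewhere in the repository (utils.checkdir, make_mask, original_initialization). Below is a minimal sketch of plausible implementations, inferred only from how they are called above; the bodies are assumptions, not the repository's code:

# Sketch of the helpers assumed by Example #2 (illustrative only)
import os
import numpy as np
import torch

mask = []  # global list of per-layer masks, in named_parameters() order


def checkdir(directory):  # utils.checkdir
    os.makedirs(directory, exist_ok=True)


def make_mask(model, total_params):
    # One all-ones mask per large weight tensor, matching the
    # param_frac > 0.01 filter used in the reinit branch above.
    global mask
    mask = []
    for name, param in model.named_parameters():
        if 'weight' in name and param.numel() / total_params > 0.01:
            mask.append(np.ones_like(param.data.cpu().numpy()))


def original_initialization(mask_temp, initial_state_dict, total_params):
    # Rewind weights to their initial values; for the large, masked tensors,
    # pruned entries stay zero.
    global model
    step = 0
    for name, param in model.named_parameters():
        if 'weight' in name and param.numel() / total_params > 0.01:
            param.data = torch.from_numpy(
                mask_temp[step] *
                initial_state_dict[name].cpu().numpy()).to(param.device)
            step += 1
        else:
            param.data = initial_state_dict[name].to(param.device)

(GradualWarmupScheduler appears to follow the interface of the pytorch-gradual-warmup-lr package: the learning rate ramps up for total_epoch epochs, after which after_scheduler takes over.)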
Example #3
def main(args, ITE=0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    reinit = (args.prune_type == "reinit")

    # Data Loader
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    if args.dataset == "mnist":
        traindataset = datasets.MNIST('../data',
                                      train=True,
                                      download=True,
                                      transform=transform)
        testdataset = datasets.MNIST('../data',
                                     train=False,
                                     transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet

    elif args.dataset == "cifar10":
        traindataset = datasets.CIFAR10('../data',
                                        train=True,
                                        download=True,
                                        transform=transform)
        testdataset = datasets.CIFAR10('../data',
                                       train=False,
                                       transform=transform)
        from archs.cifar10 import AlexNet, LeNet5, fc1, vgg, resnet, densenet

    elif args.dataset == "fashionmnist":
        traindataset = datasets.FashionMNIST('../data',
                                             train=True,
                                             download=True,
                                             transform=transform)
        testdataset = datasets.FashionMNIST('../data',
                                            train=False,
                                            transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet

    elif args.dataset == "cifar100":
        traindataset = datasets.CIFAR100('../data',
                                         train=True,
                                         download=True,
                                         transform=transform)
        testdataset = datasets.CIFAR100('../data',
                                        train=False,
                                        transform=transform)
        from archs.cifar100 import AlexNet, fc1, LeNet5, vgg, resnet

    # To add extra datasets, add a branch here

    else:
        print("\nWrong Dataset choice \n")
        exit()

    train_loader = torch.utils.data.DataLoader(traindataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=0,
                                               drop_last=False)
    #train_loader = cycle(train_loader)
    test_loader = torch.utils.data.DataLoader(testdataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=0,
                                              drop_last=True)

    # Importing Network Architecture
    global model
    if args.arch_type == "fc1":
        model = fc1.fc1().to(device)
    elif args.arch_type == "lenet5":
        model = LeNet5.LeNet5().to(device)
    elif args.arch_type == "alexnet":
        model = AlexNet.AlexNet().to(device)
    elif args.arch_type == "vgg16":
        model = vgg.vgg16().to(device)
    elif args.arch_type == "resnet18":
        model = resnet.resnet18().to(device)
    elif args.arch_type == "densenet121":
        model = densenet.densenet121().to(device)
    # To add an extra model, add a branch here
    else:
        print("\nWrong Model choice\n")
        exit()

    # Weight Initialization
    model.apply(weight_init)

    # Copying and Saving Initial State
    initial_state_dict = copy.deepcopy(model.state_dict())
    utils.checkdir(f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/")
    torch.save(
        model,
        f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/initial_state_dict_{args.prune_type}.pth.tar"
    )

    # Making Initial Mask
    make_mask(model)

    # Optimizer and Loss
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()  # Default was F.nll_loss

    # Layer Looper
    for name, param in model.named_parameters():
        print(name, param.size())

    # Pruning
    # NOTE First Pruning Iteration is of No Compression
    best_accuracy = 0
    ITERATION = args.prune_iterations
    comp = np.zeros(ITERATION, float)
    bestacc = np.zeros(ITERATION, float)
    step = 0
    all_loss = np.zeros(args.end_iter, float)
    all_accuracy = np.zeros(args.end_iter, float)

    for _ite in range(args.start_iter, ITERATION):
        if _ite != 0:
            # NOTE: `resample` is assumed to be defined globally (e.g., parsed from the CLI)
            prune_by_percentile(args.prune_percent,
                                resample=resample,
                                reinit=reinit)
            if reinit:
                model.apply(weight_init)
                # Alternative: re-instantiate the chosen architecture via the same
                # if/elif chain used above, instead of re-applying weight_init.
                step = 0
                for name, param in model.named_parameters():
                    if 'weight' in name:
                        weight_dev = param.device
                        param.data = torch.from_numpy(
                            param.data.cpu().numpy() *
                            mask[step]).to(weight_dev)
                        step = step + 1
                step = 0
            else:
                original_initialization(mask, initial_state_dict)
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.lr,
                                         weight_decay=1e-4)
        print(f"\n--- Pruning Level [{ITE}:{_ite}/{ITERATION}]: ---")

        # Print the table of Nonzeros in each layer
        comp1 = utils.print_nonzeros(model)
        comp[_ite] = comp1
        pbar = tqdm(range(args.end_iter))

        for iter_ in pbar:

            # Frequency for Testing
            if iter_ % args.valid_freq == 0:
                accuracy = test(model, test_loader, criterion)

                # Save Weights
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    utils.checkdir(
                        f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/"
                    )
                    torch.save(
                        model,
                        f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth.tar"
                    )

            # Training
            if _ite == 0:
                loss = train(model, train_loader, optimizer, criterion)
                # TODO: teacher_model is never defined, so the train_with_distill
                # call below will fail until this is completed
                #teacher_model = ...
            else:
                loss = train_with_distill(model, train_loader, optimizer,
                                          teacher_model)

            # Frequency for Printing Accuracy and Loss
            if iter_ % args.print_freq == 0:
                pbar.set_description(
                    f'Train Epoch: {iter_}/{args.end_iter} Loss: {loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy:.2f}%'
                )

        writer.add_scalar('Accuracy/test', best_accuracy, comp1)
        bestacc[_ite] = best_accuracy

        # Plotting Loss (Training), Accuracy (Testing), Iteration Curve
        #NOTE Loss is computed every iteration, while accuracy is computed only every {args.valid_freq} iterations, so the saved accuracy stays constant between evaluations.
        #NOTE The loss is normalized to [0,100] so it can be plotted alongside accuracy.
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 100 * (all_loss - np.min(all_loss)) /
                 np.ptp(all_loss).astype(float),
                 c="blue",
                 label="Loss")
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 all_accuracy,
                 c="red",
                 label="Accuracy")
        plt.title(
            f"Loss Vs Accuracy Vs Iterations ({args.dataset},{args.arch_type})"
        )
        plt.xlabel("Iterations")
        plt.ylabel("Loss and Accuracy")
        plt.legend()
        plt.grid(color="gray")
        utils.checkdir(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/")
        plt.savefig(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_LossVsAccuracy_{comp1}.png",
            dpi=1200)
        plt.close()

        # Dump Plot values
        utils.checkdir(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
        all_loss.dump(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_loss_{comp1}.dat"
        )
        all_accuracy.dump(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_accuracy_{comp1}.dat"
        )

        # Dumping mask
        utils.checkdir(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
        with open(
                f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_mask_{comp1}.pkl",
                'wb') as fp:
            pickle.dump(mask, fp)

        # Making variables into 0
        best_accuracy = 0
        all_loss = np.zeros(args.end_iter, float)
        all_accuracy = np.zeros(args.end_iter, float)

    # Dumping Values for Plotting
    utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
    comp.dump(
        f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_compression.dat"
    )
    bestacc.dump(
        f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_bestaccuracy.dat"
    )

    # Plotting
    a = np.arange(args.prune_iterations)
    plt.plot(a, bestacc, c="blue", label="Winning tickets")
    plt.title(
        f"Test Accuracy vs Unpruned Weights Percentage ({args.dataset},{args.arch_type})"
    )
    plt.xlabel("Unpruned Weights Percentage")
    plt.ylabel("test accuracy")
    plt.xticks(a, comp, rotation="vertical")
    plt.ylim(0, 100)
    plt.legend()
    plt.grid(color="gray")
    utils.checkdir(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/")
    plt.savefig(
        f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_AccuracyVsWeights.png",
        dpi=1200)
    plt.close()
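
prune_by_percentile is the core of these examples but is defined elsewhere. Below is an illustrative reconstruction of layer-wise magnitude pruning consistent with how Example #3 calls it; the extra keyword arguments of Example #2 (hessian_aware, criterion, dataloader, cuda) are accepted but ignored here, so this is a sketch, not the repository's implementation:

# Sketch: zero the lowest-magnitude surviving weights in each layer (illustrative)
import numpy as np
import torch


def prune_by_percentile(percent, resample=False, reinit=False, **kwargs):
    global model, mask
    step = 0
    for name, param in model.named_parameters():
        if 'weight' in name:
            tensor = param.data.cpu().numpy()
            alive = tensor[np.nonzero(tensor)]  # surviving (nonzero) weights
            percentile_value = np.percentile(abs(alive), percent)

            # Zero weights below the percentile and record that in the mask
            new_mask = np.where(abs(tensor) < percentile_value, 0, mask[step])
            param.data = torch.from_numpy(tensor * new_mask).to(param.device)
            mask[step] = new_mask
            step += 1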
Example #4
def main(args, ITE=0):
    import pandas as pd
    pd.set_option('display.width', 400)
    pd.set_option('display.max_columns', 10)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    reinit = (args.prune_type == "reinit")
    layerwise = (args.prune_type == "layerwise")

    # Data Loader
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    transform_cifar10 = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    if args.dataset == "mnist":
        traindataset = datasets.MNIST('../data',
                                      train=True,
                                      download=True,
                                      transform=transform)
        testdataset = datasets.MNIST('../data',
                                     train=False,
                                     transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet

    elif args.dataset == "cifar10":
        traindataset = datasets.CIFAR10('../data',
                                        train=True,
                                        download=True,
                                        transform=transform_cifar10)
        testdataset = datasets.CIFAR10('../data',
                                       train=False,
                                       transform=transform_cifar10)
        from archs.cifar10 import AlexNet, LeNet5, fc1, vgg, resnet, densenet, minivgg

    elif args.dataset == "fashionmnist":
        traindataset = datasets.FashionMNIST('../data',
                                             train=True,
                                             download=True,
                                             transform=transform)
        testdataset = datasets.FashionMNIST('../data',
                                            train=False,
                                            transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet

    elif args.dataset == "cifar100":
        traindataset = datasets.CIFAR100('../data',
                                         train=True,
                                         download=True,
                                         transform=transform)
        testdataset = datasets.CIFAR100('../data',
                                        train=False,
                                        transform=transform)
        from archs.cifar100 import AlexNet, fc1, LeNet5, vgg, resnet

    # To add extra datasets, add a branch here

    else:
        print("\nWrong Dataset choice \n")
        exit()

    # Split part of the training set off for validation
    if args.early_stopping:
        print(' Splitting Validation sets ')
        trainset_size = int((1 - args.valid_size) * len(traindataset))
        valset_size = len(traindataset) - trainset_size
        trainset, valset = torch.utils.data.random_split(
            traindataset, [trainset_size, valset_size])

        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   drop_last=False)
        valid_loader = torch.utils.data.DataLoader(valset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   drop_last=False)

    else:
        print(' Eww, no validation set? ')
        train_loader = torch.utils.data.DataLoader(traindataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   drop_last=False)

    # train_loader = cycle(train_loader)
    test_loader = torch.utils.data.DataLoader(testdataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=0,
                                              drop_last=True)

    # Importing Network Architecture
    global model
    if args.arch_type == "fc1":
        model = fc1.fc1().to(device)
    elif args.arch_type == "lenet5":
        model = LeNet5.LeNet5().to(device)
    elif args.arch_type == "alexnet":
        model = AlexNet.AlexNet().to(device)
    elif args.arch_type == "vgg16":
        model = vgg.vgg16().to(device)
    elif args.arch_type == "resnet18":
        model = resnet.resnet18().to(device)
    elif args.arch_type == "densenet121":
        model = densenet.densenet121().to(device)
    # To add an extra model, add a branch here
    elif args.arch_type == "conv2":
        model = minivgg.conv2().to(device)
    elif args.arch_type == "conv4":
        model = minivgg.conv4().to(device)
    elif args.arch_type == "conv6":
        model = minivgg.conv6().to(device)

    else:
        print("\nWrong Model choice\n")
        exit()

    # Weight Initialization. Warning: this drops test accuracy, so this function is under investigation.
    model.apply(weight_init)

    # get time for file path
    import datetime
    now = datetime.datetime.now()
    now_ = now.strftime("%m%d%H%M_")  # %m, %d, %H, %M are already zero-padded

    # Copying and Saving Initial State
    print('  saving initial model... ')
    initial_state_dict = copy.deepcopy(model.state_dict())
    utils.checkdir(
        f"{os.getcwd()}/saves/{now_}{args.arch_type}/{args.dataset}/")
    torch.save(
        model,
        f"{os.getcwd()}/saves/{now_}{args.arch_type}/{args.dataset}/initial_state_dict_{args.prune_type}.pth.tar"
    )
    print(
        "  initial model saved in ",
        f"{os.getcwd()}/saves/{now_}{args.arch_type}/{args.dataset}/initial_state_dict_{args.prune_type}.pth.tar"
    )

    # Making Initial Mask
    make_mask(model)

    # Optimizer and Loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()  # Default was F.nll_loss

    # Layer Looper
    for name, param in model.named_parameters():
        print(name, param.size())

    # Pruning
    # NOTE First Pruning Iteration is of No Compression
    best_accuracy = 0
    ITERATION = args.prune_iterations
    comp = np.zeros(ITERATION, float)
    bestacc = np.zeros(ITERATION, float)
    step = 0
    all_loss = np.zeros(args.end_iter, float)
    all_vloss = np.zeros(args.end_iter, float)
    all_accuracy = np.zeros(args.end_iter, float)

    for _ite in range(args.start_iter, ITERATION):

        # Early stopping parameter for each pruning iteration
        early_stopping = EarlyStopping(
            patience=99,
            verbose=True)  # patience=99: we never actually stop early, party all night

        if _ite != 0:
            # NOTE: `resample` is assumed to be defined globally (e.g., parsed from the CLI)
            prune_by_percentile(args.prune_percent,
                                args.fc_prune_percent,
                                resample=resample,
                                reinit=reinit,
                                layerwise=layerwise,
                                if_split=args.split_conv_and_fc)
            if reinit:
                model.apply(weight_init)
                # Alternative: re-instantiate the chosen architecture via the same
                # if/elif chain used above, instead of re-applying weight_init.
                step = 0
                for name, param in model.named_parameters():
                    if 'weight' in name:
                        weight_dev = param.device
                        param.data = torch.from_numpy(
                            param.data.cpu().numpy() *
                            mask[step]).to(weight_dev)
                        step = step + 1
                step = 0
            else:
                original_initialization(mask, initial_state_dict)
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.lr,
                                         weight_decay=1e-4)

        time.sleep(0.25)
        print(f"\n--- Pruning Level [{ITE}:{_ite}/{ITERATION}]: ---")

        # Print the table of Nonzeros in each layer
        comp1 = utils.print_nonzeros(model)
        comp[_ite] = comp1

        # pbar = range(args.end_iter)
        pbar = tqdm(range(args.end_iter))

        stop_flag = False
        for iter_ in pbar:

            # Frequency for Testing
            if iter_ % args.valid_freq == 0:
                accuracy = test(model, test_loader, criterion)

                # Save Weights
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    # We don't checkpoint on test accuracy; the model is saved on validation loss via early_stopping below
                    # utils.checkdir(f"{os.getcwd()}/saves/{now_}{args.arch_type}/{args.dataset}/")
                    # torch.save(model,f"{os.getcwd()}/saves/{now_}{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth.tar")

            # Training
            loss = train(model, train_loader, optimizer, criterion)

            all_loss[iter_] = loss
            all_accuracy[iter_] = accuracy

            # Validating
            valid_loss, loss_v = validate(model, valid_loader, optimizer,
                                          criterion)
            all_vloss[iter_] = valid_loss  #loss_v

            # early stopping
            checkpoint_path = f"{os.getcwd()}/saves/{now_}{args.arch_type}/{args.dataset}/"
            save_path = f"{os.getcwd()}/saves/{now_}{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth.tar"
            # msg = early_stopping(valid_loss, model, checkpoint_path, save_path)
            early_stopping(valid_loss, model, checkpoint_path, save_path)

            # Frequency for Printing Accuracy and Loss
            if iter_ % args.print_freq == 0:
                time.sleep(0.25)
                pbar.set_description(
                    # f'Train Epoch: {iter_}/{args.end_iter} Loss: {loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy:.2f}% \t' + msg)
                    f'Train Epoch: {iter_}/{args.end_iter} Loss: {loss:.6f} V-Loss: {valid_loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy:.2f}%'
                )
                if iter_ % 5 == 4:
                    print('')

            if early_stopping.early_stop and not stop_flag:
                print("Early stopping")
                stop_flag = True
                # break

        writer.add_scalar('Accuracy/test', best_accuracy, comp1)
        bestacc[_ite] = best_accuracy

        # Plotting Loss (Training), Accuracy (Testing), Iteration Curve
        #NOTE Loss is computed every iteration, while accuracy is computed only every {args.valid_freq} iterations, so the saved accuracy stays constant between evaluations.
        #NOTE The losses are normalized to [0,100] so they can be plotted alongside accuracy.
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 100 * (all_loss - np.min(all_loss)) /
                 np.ptp(all_loss).astype(float),
                 c="blue",
                 label="Train loss")
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 100 * (all_vloss - np.min(all_vloss)) /
                 np.ptp(all_vloss).astype(float),
                 c="green",
                 label="Valid loss")
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 all_accuracy,
                 c="red",
                 label="Accuracy")
        plt.title(
            f"Loss Vs Accuracy Vs Iterations ({args.dataset},{now_}{args.arch_type})"
        )
        plt.xlabel("Iterations")
        plt.ylabel("Loss and Accuracy")
        plt.legend()
        plt.grid(color="gray")
        utils.checkdir(
            f"{os.getcwd()}/plots/lt/{now_}{args.arch_type}/{args.dataset}/")
        plt.savefig(
            f"{os.getcwd()}/plots/lt/{now_}{args.arch_type}/{args.dataset}/{args.prune_type}_LossVsAccuracy_{comp1}.png",
            dpi=300)
        plt.close()

        # Dump Plot values
        utils.checkdir(
            f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/")
        all_loss.dump(
            f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/{args.prune_type}_all_loss_{comp1}.dat"
        )
        all_accuracy.dump(
            f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/{args.prune_type}_all_accuracy_{comp1}.dat"
        )

        # Dumping mask
        utils.checkdir(
            f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/")
        with open(
                f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/{args.prune_type}_mask_{comp1}.pkl",
                'wb') as fp:
            pickle.dump(mask, fp)

        # Making variables into 0
        best_accuracy = 0
        all_loss = np.zeros(args.end_iter, float)
        all_accuracy = np.zeros(args.end_iter, float)

        # Dumping Values for Plotting
        utils.checkdir(
            f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/")
        comp.dump(
            f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/{args.prune_type}_compression.dat"
        )
        bestacc.dump(
            f"{os.getcwd()}/dumps/lt/{now_}{args.arch_type}/{args.dataset}/{args.prune_type}_bestaccuracy.dat"
        )

        # Plotting
        a = np.arange(args.prune_iterations)
        plt.plot(a, bestacc, c="blue", label="Winning tickets")
        plt.title(
            f"Test Accuracy vs Unpruned Weights Percentage ({args.dataset},{now_}{args.arch_type})"
        )
        plt.xlabel("Unpruned Weights Percentage")
        plt.ylabel("test accuracy")
        plt.xticks(a, comp, rotation="vertical")
        plt.ylim(0, 100)
        plt.legend()
        plt.grid(color="gray")
        utils.checkdir(
            f"{os.getcwd()}/plots/lt/{now_}{args.arch_type}/{args.dataset}/")
        plt.savefig(
            f"{os.getcwd()}/plots/lt/{now_}{args.arch_type}/{args.dataset}/{args.prune_type}_AccuracyVsWeights.png",
            dpi=300)
        plt.close()

    print('Training ended~~~')
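
Example #4 additionally depends on an EarlyStopping helper and a validate function that are not shown. Below is a minimal sketch consistent with the call sites above (early_stopping(valid_loss, model, checkpoint_path, save_path) and the early_stop flag); the details are assumptions:

# Sketch of an EarlyStopping helper matching Example #4's call sites (illustrative)
import os
import torch


class EarlyStopping:
    """Flags early_stop after `patience` epochs without a validation-loss
    improvement, checkpointing the best model along the way."""

    def __init__(self, patience=7, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, valid_loss, model, checkpoint_path, save_path):
        if self.best_loss is None or valid_loss < self.best_loss:
            self.best_loss = valid_loss
            self.counter = 0
            os.makedirs(checkpoint_path, exist_ok=True)
            torch.save(model, save_path)  # keep the best-so-far model
            if self.verbose:
                print(f'Validation loss improved to {valid_loss:.6f}; model saved.')
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True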