Exemplo n.º 1
0
def main():
    """
    --------------------------------------------- MAIN --------------------------------------------------------

    Instantiates the model plus loss function and defines the dataloaders for several datasets including some
    data augmentation.
    Defines the grid for a grid search on lambda_max_divrs and initial_centroid_value_multipliers which both
    have a big influence on the sparsity (and respectively accuracy) of the resulting ternary networks.
    Starts grid search.

    Relies on the module-level ``parser`` and sets the globals ``args``, ``use_cuda`` and ``device``.

    Raises:
        NotImplementedError: if ``args.model`` or ``args.dataset`` is not one of the supported names.
    """

    # Manual seed for reproducibility
    torch.manual_seed(363636)

    # Global instances shared with the rest of the module
    global args, use_cuda, device
    # Instantiating the parser
    args = parser.parse_args()
    # Global CUDA flag
    use_cuda = args.cuda and torch.cuda.is_available()
    # Defining device and device's map location
    device = torch.device("cuda" if use_cuda else "cpu")
    print('chosen device: ', device)

    # Compound-scaling factors applied to the MicroNet variants; hoisted since
    # each is used both for logging and for construction.
    d_mult = args.dw_multps[0] ** args.phi  # depth multiplier
    w_mult = args.dw_multps[1] ** args.phi  # width multiplier

    # Building the model
    if args.model == 'cifar_micronet':
        print('Building MicroNet for CIFAR-100 with depth multiplier {} and width multiplier {} ...'.format(
            d_mult, w_mult))
        model = micronet(d_mult, w_mult)

    elif args.model == 'imagenet_micronet':
        print('Building MicroNet for ImageNet with depth multiplier {} and width multiplier {} ...'.format(
            d_mult, w_mult))
        model = image_micronet(d_mult, w_mult)

    elif args.model in ('efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3', 'efficientnet-b4'):
        # All four variants only differ in the factory name suffix.
        variant = args.model[-1]
        print('Building EfficientNet-B{} ...'.format(variant))
        model = getattr(EfficientNet, 'efficientnet_b' + variant)()

    else:
        # Fail fast: without this branch an unknown model name caused a
        # NameError on `model` further below.
        raise NotImplementedError('Undefined model name %s' % args.model)

    # Log all learnable parameter names
    for name, param in model.named_parameters():
        print('\n', name)

    # Transfers model to device (GPU/CPU).
    model.to(device)

    # Defining loss function and printing CUDA information (if available)
    if use_cuda:
        print("PyTorch version: ")
        print(torch.__version__)
        print("CUDA Version: ")
        print(torch.version.cuda)
        print("cuDNN version is: ")
        print(cudnn.version())
        # Lets cuDNN benchmark conv algorithms; beneficial for fixed input sizes.
        cudnn.benchmark = True
        loss_fct = nn.CrossEntropyLoss().cuda()
    else:
        loss_fct = nn.CrossEntropyLoss()

    # Dataloaders for CIFAR, ImageNet and MNIST
    if args.dataset == 'CIFAR100':

        print('Loading CIFAR-100 data ...')
        # Per-channel mean/std given on the 0-255 scale, rescaled to 0-1.
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        # Worker processes and pinned memory only pay off when a GPU is used.
        kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root=args.data_path, train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.075),
                transforms.ToTensor(),
                normalize,
                # Cutout runs after normalization, i.e. on the tensor image.
                Cutout(n_holes=1, length=16),
            ]), download=True),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root=args.data_path, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.val_batch_size, shuffle=False, **kwargs)

    elif args.dataset == 'ImageNet':

        print('Loading ImageNet data ...')
        traindir = os.path.join(args.data_path, 'train')
        valdir = os.path.join(args.data_path, 'val')
        # Standard ImageNet normalization constants.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(args.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        # EfficientNet variants have a fixed native resolution; other models
        # validate at the user-requested size.
        if model.__class__.__name__ == 'EfficientNet' or 'efficientnet' in str(args.model):
            image_size = EfficientNet.get_image_size(args.model)
        else:
            image_size = args.image_size

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(image_size, interpolation=PIL.Image.BICUBIC),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                normalize,
            ]))
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.val_batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    elif args.dataset == 'MNIST':

        kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.data_path, train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        # NOTE(review): shuffling a validation loader is unusual — kept for
        # parity with the original; confirm it is intended.
        val_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.data_path, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
            batch_size=args.val_batch_size, shuffle=True, **kwargs)

    elif args.dataset == 'CIFAR10':

        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root=args.data_path, train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
            ]), download=True),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root=args.data_path, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.val_batch_size, shuffle=False, **kwargs)

    else:
        raise NotImplementedError('Undefined dataset name %s' % args.dataset)

    # Gridsearch on dividers for lambda_max and initial cluster center values
    for initial_c_divr in args.ini_c_divrs:
        for lambda_max_divr in args.lambda_max_divrs:
            print('lambda_max_divr: {}, initial_c_divr: {}'.format(lambda_max_divr, initial_c_divr))
            # Context manager ensures the logfile is flushed and closed even if
            # grid_search raises (the original leaked the open handle).
            with open('./model_quantization/logfiles/logfile.txt', 'a+') as logfile:
                logfile.write('lambda_max_divr: {}, initial_c_divr: {}'.format(lambda_max_divr, initial_c_divr))
            grid_search(train_loader, val_loader, model, loss_fct, lambda_max_divr, initial_c_divr)
def main():
    """
    --------------------------------------------- MAIN --------------------------------------------------------

    Builds the selected model, defines the loss function, creates train/validation
    dataloaders (with data augmentation) for CIFAR-10/100, ImageNet or MNIST and
    starts training with frozen cluster assignments.

    Relies on the module-level ``parser`` and sets the globals ``args``, ``use_cuda``
    and ``device``.

    NOTE(review): this redefines ``main`` and therefore shadows the earlier
    grid-search ``main`` in this file — confirm that is intended.

    Raises:
        NotImplementedError: if ``args.model`` or ``args.dataset`` is not one of the
            supported names, or if ``cifar_micronet`` is combined with a non-CIFAR
            dataset.
    """

    # Manual seed for reproducibility
    torch.manual_seed(363636)

    # Global instances shared with the rest of the module
    global args, use_cuda, device
    # Instantiating the parser
    args = parser.parse_args()
    # Global CUDA flag
    use_cuda = args.cuda and torch.cuda.is_available()
    # Defining device and device's map location
    device = torch.device("cuda" if use_cuda else "cpu")
    print('chosen device: ', device)

    # Compound-scaling factors applied to the MicroNet/LeNet variants; hoisted
    # since each is used both for logging and for construction.
    d_mult = args.dw_multps[0] ** args.phi  # depth multiplier
    w_mult = args.dw_multps[1] ** args.phi  # width multiplier

    # Building the model
    if args.model == 'cifar_micronet':
        print(
            'Building MicroNet for CIFAR with depth multiplier {} and width multiplier {} ...'
            .format(d_mult, w_mult))
        if args.dataset == 'CIFAR100':
            num_classes = 100
        elif args.dataset == 'CIFAR10':
            num_classes = 10
        else:
            # Fail fast: without this branch a non-CIFAR dataset caused a
            # NameError on `num_classes` below.
            raise NotImplementedError(
                'cifar_micronet requires a CIFAR dataset, got %s' % args.dataset)
        model = micronet(d_mult, w_mult, num_classes)

    elif args.model == 'image_micronet':
        print(
            'Building MicroNet for ImageNet with depth multiplier {} and width multiplier {} ...'
            .format(d_mult, w_mult))
        model = image_micronet(d_mult, w_mult)

    elif args.model in ('efficientnet-b1', 'efficientnet-b2',
                        'efficientnet-b3', 'efficientnet-b4'):
        # All four variants only differ in the factory name suffix.
        variant = args.model[-1]
        print('Building EfficientNet-B{} ...'.format(variant))
        model = getattr(EfficientNet, 'efficientnet_b' + variant)()

    elif args.model == 'lenet-5':
        print(
            'Building LeNet-5 with depth multiplier {} and width multiplier {} ...'
            .format(d_mult, w_mult))
        model = lenet5(d_multiplier=d_mult, w_multiplier=w_mult)

    else:
        # Fail fast: without this branch an unknown model name caused a
        # NameError on `model` further below.
        raise NotImplementedError('Undefined model name %s' % args.model)

    # Log all learnable parameter names
    for name, param in model.named_parameters():
        print('\n', name)

    # Transfers model to device (GPU/CPU).
    model.to(device)

    # Defining loss function and printing CUDA information (if available)
    if use_cuda:
        print("PyTorch version: ")
        print(torch.__version__)
        print("CUDA Version: ")
        print(torch.version.cuda)
        print("cuDNN version is: ")
        print(cudnn.version())
        # Lets cuDNN benchmark conv algorithms; beneficial for fixed input sizes.
        cudnn.benchmark = True
        loss_fct = nn.CrossEntropyLoss().cuda()
    else:
        loss_fct = nn.CrossEntropyLoss()

    # Dataloaders for CIFAR, ImageNet and MNIST
    if args.dataset == 'CIFAR100':

        print('Loading CIFAR-100 data ...')
        # Per-channel mean/std given on the 0-255 scale, rescaled to 0-1.
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        # Worker processes and pinned memory only pay off when a GPU is used.
        kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(
                root=args.data_path,
                train=True,
                transform=transforms.Compose([
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(32, 4),
                    transforms.ColorJitter(brightness=0.3, contrast=0.3,
                                           saturation=0.3, hue=0.075),
                    transforms.ToTensor(),
                    normalize,
                    # Cutout runs after normalization, i.e. on the tensor image.
                    Cutout(n_holes=1, length=16),
                ]),
                download=True),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(
                root=args.data_path,
                train=False,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.val_batch_size, shuffle=False, **kwargs)

    elif args.dataset == 'ImageNet':

        print('Loading ImageNet data ...')
        traindir = os.path.join(args.data_path, 'train')
        valdir = os.path.join(args.data_path, 'val')
        # Standard ImageNet normalization constants.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(args.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        # EfficientNet variants have a fixed native resolution; other models
        # validate at the user-requested size.
        if model.__class__.__name__ == 'EfficientNet' or 'efficientnet' in str(args.model):
            image_size = EfficientNet.get_image_size(args.model)
        else:
            image_size = args.image_size

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(image_size, interpolation=PIL.Image.BICUBIC),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                normalize,
            ]))
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.val_batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    elif args.dataset == 'MNIST':

        kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                args.data_path,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307, ), (0.3081, ))
                ])),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        # NOTE(review): shuffling a validation loader is unusual — kept for
        # parity with the original; confirm it is intended.
        val_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                args.data_path,
                train=False,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307, ), (0.3081, ))
                ])),
            batch_size=args.val_batch_size, shuffle=True, **kwargs)

    elif args.dataset == 'CIFAR10':

        print('Loading CIFAR-10 data ...')
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                root=args.data_path,
                train=True,
                transform=transforms.Compose([
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(32, 4),
                    transforms.ColorJitter(brightness=0.3, contrast=0.3,
                                           saturation=0.3, hue=0.075),
                    transforms.ToTensor(),
                    normalize,
                    Cutout(n_holes=1, length=16),
                ]),
                download=True),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                root=args.data_path,
                train=False,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.val_batch_size, shuffle=False, **kwargs)

    else:
        raise NotImplementedError('Undefined dataset name %s' % args.dataset)

    # Start retraining with frozen cluster assignments.
    train_w_frozen_assignment(train_loader, val_loader, model, loss_fct)