Example #1
def main():
    """
    --------------------------------------------- MAIN --------------------------------------------------------

    Instantiates the model and the loss function, and defines the dataloaders for several datasets,
    including data augmentation.
    Defines the grid for a grid search over the lambda_max and initial centroid value divisors
    (args.lambda_max_divrs and args.ini_c_divrs), both of which strongly influence the sparsity
    (and, in turn, the accuracy) of the resulting ternary networks.
    Starts the grid search.
    """

    # Manual seed for reproducibility
    torch.manual_seed(363636)
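    # Note: torch.manual_seed seeds the CPU and CUDA RNGs. For fully
    # deterministic GPU runs one could additionally disable the cuDNN
    # auto-tuner (at some speed cost), e.g.:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False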

    # Global instances
    global args, use_cuda, device
    # Instantiating the parser
    args = parser.parse_args()
    # Global CUDA flag
    use_cuda = args.cuda and torch.cuda.is_available()
    # Defining device and device's map location
    device = torch.device("cuda" if use_cuda else "cpu")
    print('chosen device: ', device)

    # Building the model
    if args.model == 'cifar_micronet':
        print('Building MicroNet for CIFAR-100 with depth multiplier {} and width multiplier {} ...'.format(
            args.dw_multps[0] ** args.phi, args.dw_multps[1] ** args.phi))
        model = micronet(args.dw_multps[0] ** args.phi, args.dw_multps[1] ** args.phi)
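        # Compound scaling in the EfficientNet sense: depth and width grow as
        # alpha**phi and beta**phi, so a single exponent phi scales the whole
        # network while alpha and beta fix the depth/width trade-off.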

    elif args.model == 'imagenet_micronet':
        print('Building MicroNet for ImageNet with depth multiplier {} and width multiplier {} ...'.format(
            args.dw_multps[0] ** args.phi, args.dw_multps[1] ** args.phi))
        model = image_micronet(args.dw_multps[0] ** args.phi, args.dw_multps[1] ** args.phi)

    elif args.model == 'efficientnet-b1':
        print('Building EfficientNet-B1 ...')
        model = EfficientNet.efficientnet_b1()

    elif args.model == 'efficientnet-b2':
        print('Building EfficientNet-B2 ...')
        model = EfficientNet.efficientnet_b2()

    elif args.model == 'efficientnet-b3':
        print('Building EfficientNet-B3 ...')
        model = EfficientNet.efficientnet_b3()

    elif args.model == 'efficientnet-b4':
        print('Building EfficientNet-B4 ...')
        model = EfficientNet.efficientnet_b4()

    else:
        raise NotImplementedError('Undefined model name %s' % args.model)

    # Printing all parameter names of the model
    for name, _ in model.named_parameters():
        print('\n', name)

    # Transfers model to device (GPU/CPU).
    model.to(device)

    # Defining loss function and printing CUDA information (if available)
    if use_cuda:
        print("PyTorch version: ")
        print(torch.__version__)
        print("CUDA Version: ")
        print(torch.version.cuda)
        print("cuDNN version is: ")
        print(cudnn.version())
        cudnn.benchmark = True
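        # cudnn.benchmark lets cuDNN time several convolution algorithms on
        # the first forward passes and pick the fastest, which pays off when
        # input shapes stay constant across batches.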
        loss_fct = nn.CrossEntropyLoss().cuda()
    else:
        loss_fct = nn.CrossEntropyLoss()

    # Dataloaders for CIFAR, ImageNet and MNIST
    if args.dataset == 'CIFAR100':

        print('Loading CIFAR-100 data ...')
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
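        # The mean/std values are the per-channel statistics of the CIFAR
        # training images, given on the 0-255 scale and rescaled to [0, 1].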

        kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root=args.data_path, train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.075),
                transforms.ToTensor(),
                normalize,
                Cutout(n_holes=1, length=16),
            ]), download=True),
            batch_size=args.batch_size, shuffle=True, **kwargs)
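        # Cutout (presumably following DeVries & Taylor, 2017) is applied
        # after normalization and zeroes one random 16x16 patch of the image
        # tensor.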

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root=args.data_path, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.val_batch_size, shuffle=False, **kwargs)

    elif args.dataset == 'ImageNet':

        print('Loading ImageNet data ...')
        traindir = os.path.join(args.data_path, 'train')
        valdir = os.path.join(args.data_path, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(args.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        if model.__class__.__name__ == 'EfficientNet' or 'efficientnet' in str(args.model):
            image_size = EfficientNet.get_image_size(args.model)

        else:
            image_size = args.image_size
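        # EfficientNet variants expect fixed native resolutions (240 for B1,
        # 260 for B2, 300 for B3, 380 for B4), hence the lookup above.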

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(image_size, interpolation=PIL.Image.BICUBIC),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                normalize,
            ]))
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.val_batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    elif args.dataset == 'MNIST':

        kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.data_path, train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=args.batch_size, shuffle=True, **kwargs)
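        # (0.1307,) and (0.3081,) are the commonly used mean/std of the MNIST
        # training set (single channel).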
        val_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.data_path, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
            batch_size=args.val_batch_size, shuffle=False, **kwargs)

    elif args.dataset == 'CIFAR10':

        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root=args.data_path, train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
            ]), download=True),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root=args.data_path, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.val_batch_size, shuffle=False, **kwargs)

    else:
        raise NotImplementedError('Undefined dataset name %s' % args.dataset)


    # Grid search over the divisors for lambda_max and the initial cluster center values
    for initial_c_divr in args.ini_c_divrs:
        for lambda_max_divr in args.lambda_max_divrs:
            print('lambda_max_divr: {}, initial_c_divr: {}'.format(lambda_max_divr, initial_c_divr))
            logfile = open('./model_quantization/logfiles/logfile.txt', 'a+')
            logfile.write('lambda_max_divr: {}, initial_c_divr: {}\n'.format(lambda_max_divr, initial_c_divr))
            logfile.close()
            grid_search(train_loader, val_loader, model, loss_fct, lambda_max_divr, initial_c_divr)

Example #2

def grid_search(train_loader, val_loader, criterion, alpha, beta):
    """
    Builds the model with the given scaling factors, sets up the optimizer and learning rate scheduler, and
    executes training and evaluation of the model. A checkpoint is created every epoch, and the best model
    is saved as well.

    Parameters:
    -----------
        train_loader:
            PyTorch Dataloader for given train dataset.
        val_loader:
            PyTorch Dataloader for given validation dataset.
        criterion:
            Loss function to use (e.g. cross entropy, MSE, ...).
        alpha:
            Scaling factor for model depth.
        beta:
            Scaling factor for model width.
    """

    # Initializing training variables
    best_acc = 0
    all_losses = []

    # Initializing log file
    logfile = open('./model_compound_scaling/logfiles/logfile.txt', 'a+')
    logfile.write('depth multiplier: {}, width multiplier: {}\n'.format(
        alpha, beta))

    # Building the model
    if args.dataset == 'CIFAR100' or args.dataset == 'CIFAR10':
        model = micronet(d_multiplier=alpha, w_multiplier=beta)

    elif args.dataset == 'ImageNet':
        model = image_micronet(d_multiplier=alpha, w_multiplier=beta)

    # If multiple GPUs are used
    if use_cuda and torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    # Transfers model to device (GPU/CPU). Device is globally initialized.
    model.to(device)

    # Defining the optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
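    # Plain SGD with Nesterov momentum: the gradient is evaluated at the
    # look-ahead point along the velocity, which typically converges somewhat
    # faster than classical momentum.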

    # KERAS like summary of the model architecture
    # summary(your_model, input_size=(channels, H, W), batch_size=-1, device="cuda")
    if use_cuda:
        if args.dataset == 'CIFAR100' or args.dataset == 'CIFAR10':
            summary(model, (3, 32, 32), batch_size=args.batch_size)
            print(model)

        elif args.dataset == 'ImageNet':
            summary(model, (3, args.image_size, args.image_size),
                    batch_size=args.batch_size)
            print(model)

    # Optionally resume from a checkpoint
    load_last_epoch = -1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            load_last_epoch = checkpoint['epoch'] - 1
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Learning rate schedulers for cifar_micronet and imagenet_micronet
    if args.dataset == 'CIFAR100' or args.dataset == 'CIFAR10':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=args.epochs,
            eta_min=0,
            last_epoch=load_last_epoch)
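        # Cosine annealing decays the learning rate per epoch following
        # eta_t = eta_min + 0.5 * (eta_0 - eta_min) * (1 + cos(pi * t / T_max)),
        # i.e. smoothly from args.lr down to 0 over args.epochs epochs.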

    elif args.dataset == 'ImageNet':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[30, 60, 90],
            gamma=0.1,
            last_epoch=load_last_epoch)
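        # MultiStepLR multiplies the learning rate by gamma = 0.1 whenever an
        # epoch milestone (30, 60, 90) is passed, the classic step schedule
        # for ImageNet trainings of around 100 epochs.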

    # START TRAINING
    start_time = time.time()
    model.train()

    for epoch in range(args.start_epoch, args.epochs):

        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))

        # Executing training process
        running_loss, running_accuracy = train(train_loader, model, criterion,
                                               optimizer, epoch)

        # Evaluation
        model.eval()
        val_loss, val_accuracy = evaluate(model, criterion, val_loader)

        # Logging the accuracies
        all_losses += [(epoch, running_loss, val_loss, running_accuracy,
                        val_accuracy)]
        print(
            'Epoch {0} running loss {1:.3f} val loss {2:.3f} running acc {3:.3f} '
            'val acc {4:.3f} time {5:.3f}'.format(*all_losses[-1],
                                                  time.time() - start_time))
        logfile.write(
            'Epoch {0} running loss {1:.3f} val loss {2:.3f} running acc {3:.3f} '
            'val acc {4:.3f} time {5:.3f}\n'.format(*all_losses[-1],
                                                    time.time() - start_time))

        # Saving checkpoint
        torch.save(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'acc': val_accuracy,
                'lr': optimizer.param_groups[0]['lr']
            }, args.resume)
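        # The checkpoint is written to args.resume, so an interrupted run can
        # be picked up again through the resume branch above.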

        # Make a lr scheduler step
        lr_scheduler.step()
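        # Since PyTorch 1.1, scheduler.step() is supposed to be called after
        # the optimizer updates of the epoch, as done here.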

        # Checking if current epoch yielded best validation accuracy
        is_best = val_accuracy > best_acc
        best_acc = max(val_accuracy, best_acc)

        # If so, saving best model state_dict
        if is_best and epoch > 0:
            torch.save(model.state_dict(),
                       './model_compound_scaling/saved_models/best_model.pt')

        # Switch back to train mode
        model.train()
        start_time = time.time()

    logfile.close()

Example #3

def main():

    # Manual seed for reproducibility
    torch.manual_seed(363636)

    # Global instances
    global args, use_cuda, device
    # Instantiating the parser
    args = parser.parse_args()
    # Global CUDA flag
    use_cuda = args.cuda and torch.cuda.is_available()
    # Defining device and device's map location
    device = torch.device("cuda" if use_cuda else "cpu")
    print('chosen device: ', device)

    # Building the model
    if args.model == 'cifar_micronet':
        print(
            'Building MicroNet for CIFAR with depth multiplier {} and width multiplier {} ...'
            .format(args.dw_multps[0]**args.phi, args.dw_multps[1]**args.phi))
        if args.dataset == 'CIFAR100':
            num_classes = 100
        elif args.dataset == 'CIFAR10':
            num_classes = 10
        model = micronet(args.dw_multps[0]**args.phi,
                         args.dw_multps[1]**args.phi, num_classes)

    elif args.model == 'image_micronet':
        print(
            'Building MicroNet for ImageNet with depth multiplier {} and width multiplier {} ...'
            .format(args.dw_multps[0]**args.phi, args.dw_multps[1]**args.phi))
        model = image_micronet(args.dw_multps[0]**args.phi,
                               args.dw_multps[1]**args.phi)

    elif args.model == 'efficientnet-b1':
        print('Building EfficientNet-B1 ...')
        model = EfficientNet.efficientnet_b1()

    elif args.model == 'efficientnet-b2':
        print('Building EfficientNet-B2 ...')
        model = EfficientNet.efficientnet_b2()

    elif args.model == 'efficientnet-b3':
        print('Building EfficientNet-B3 ...')
        model = EfficientNet.efficientnet_b3()

    elif args.model == 'efficientnet-b4':
        print('Building EfficientNet-B4 ...')
        model = EfficientNet.efficientnet_b4()

    elif args.model == 'lenet-5':
        print(
            'Building LeNet-5 with depth multiplier {} and width multiplier {} ...'
            .format(args.dw_multps[0]**args.phi, args.dw_multps[1]**args.phi))
        model = lenet5(d_multiplier=args.dw_multps[0]**args.phi,
                       w_multiplier=args.dw_multps[1]**args.phi)

    else:
        raise NotImplementedError('Undefined model name %s' % args.model)

    # Printing all parameter names of the model
    for name, _ in model.named_parameters():
        print('\n', name)

    # Transfers model to device (GPU/CPU).
    model.to(device)

    # Defining loss function and printing CUDA information (if available)
    if use_cuda:
        print("PyTorch version: ")
        print(torch.__version__)
        print("CUDA Version: ")
        print(torch.version.cuda)
        print("cuDNN version is: ")
        print(cudnn.version())
        cudnn.benchmark = True
        loss_fct = nn.CrossEntropyLoss().cuda()
    else:
        loss_fct = nn.CrossEntropyLoss()

    # Dataloaders for CIFAR, ImageNet and MNIST

    if args.dataset == 'CIFAR100':

        print('Loading CIFAR-100 data ...')
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True
        } if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root=args.data_path, train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.075),
                transforms.ToTensor(),
                normalize,
                Cutout(n_holes=1, length=16),
            ]), download=True),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root=args.data_path,
                              train=False,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),
                                  normalize,
                              ])),
            batch_size=args.val_batch_size,
            shuffle=False,
            **kwargs)

    elif args.dataset == 'ImageNet':

        print('Loading ImageNet data ...')
        traindir = os.path.join(args.data_path, 'train')
        valdir = os.path.join(args.data_path, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(args.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

        if model.__class__.__name__ == 'EfficientNet' or 'efficientnet' in str(
                args.model):
            image_size = EfficientNet.get_image_size(args.model)

        else:
            image_size = args.image_size

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(image_size, interpolation=PIL.Image.BICUBIC),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                normalize,
            ]))
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.val_batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)

    elif args.dataset == 'MNIST':

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True
        } if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.data_path, train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.data_path, train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=args.val_batch_size, shuffle=False, **kwargs)

    elif args.dataset == 'CIFAR10':

        print('Loading CIFAR-10 data ...')
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True
        } if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root=args.data_path, train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.075),
                transforms.ToTensor(),
                normalize,
                Cutout(n_holes=1, length=16),
            ]), download=True),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root=args.data_path,
                             train=False,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 normalize,
                             ])),
            batch_size=args.val_batch_size,
            shuffle=False,
            **kwargs)

    else:
        raise NotImplementedError('Undefined dataset name %s' % args.dataset)

    train_w_frozen_assignment(train_loader, val_loader, model, loss_fct)
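    # train_w_frozen_assignment is defined elsewhere in this project; judging
    # by its name it presumably trains the network while keeping the weight
    # cluster assignments frozen.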