Exemple #1
0
def main_worker(gpu, ngpus_per_node, config):
    """Per-process entry point: build the model and report its compression rate.

    As written, the function sets up (optionally distributed) execution,
    builds the model, wraps it for data-parallel execution, optionally loads
    weights from ``config.load_model``, calls ``config.prepare_pruning()``,
    prints the ratio total/nonzero over the parameters whose names are in
    ``config.conv_names`` and then terminates the process via ``sys.exit()``.

    NOTE(review): because of that unconditional ``sys.exit()``, everything
    after it (resume, data loading, optimizer/scheduler setup, the training
    loop) is currently unreachable. It is kept in place so the exit can be
    removed to restore training; see the review notes in that part.

    Args:
        gpu: GPU index stored into ``config.gpu``; ``None`` means no specific
            device was requested.
        ngpus_per_node: Number of GPUs per node. Used to derive the global
            rank and to divide ``config.batch_size`` / ``config.workers``.
        config: Project configuration object. It is read and mutated here
            (``gpu``, ``rank``, ``batch_size``, ``workers``, ``model``,
            ``smooth``, ``mixup``, ``warmup``, ``start_epoch``).

    Raises:
        SystemExit: always, after the compression rate has been printed.
    """
    global best_acc1
    config.gpu = gpu

    if config.gpu is not None:
        print("Use GPU: {} for training".format(config.gpu))

    if config.distributed:
        if config.dist_url == "env://" and config.rank == -1:
            config.rank = int(os.environ["RANK"])
        if config.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            config.rank = config.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=config.dist_backend,
                                init_method=config.dist_url,
                                world_size=config.world_size,
                                rank=config.rank)
    # create model
    if config.pretrained:
        print("=> using pre-trained model '{}'".format(config.arch))

        model = models.__dict__[config.arch](pretrained=True)
        print(model)
        # Dump module and parameter names (diagnostic output only).
        param_names = []
        module_names = []
        for name, W in model.named_modules():
            module_names.append(name)
        print(module_names)
        for name, W in model.named_parameters():
            param_names.append(name)
        print(param_names)
    else:
        print("=> creating model '{}'".format(config.arch))
        if config.arch == "alexnet_bn":
            model = AlexNet_BN()
            print(model)
            for i, (name, W) in enumerate(model.named_parameters()):
                print(name)
        else:
            model = models.__dict__[config.arch]()
            print(model)

    if config.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if config.gpu is not None:
            torch.cuda.set_device(config.gpu)
            model.cuda(config.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            config.batch_size = int(config.batch_size / ngpus_per_node)
            config.workers = int(config.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[config.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif config.gpu is not None:
        torch.cuda.set_device(config.gpu)
        model = model.cuda(config.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if config.arch.startswith('alexnet') or config.arch.startswith('vgg'):
            # Only the `features` sub-module is wrapped for these archs.
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    # Expose the (possibly wrapped) model through config; validate() below is
    # called without a model argument, so it presumably reads config.model.
    config.model = model
    # define loss function (criterion) and optimizer

    criterion = CrossEntropyLossMaybeSmooth(smooth_eps=config.smooth_eps).cuda(
        config.gpu)

    config.smooth = config.smooth_eps > 0.0
    config.mixup = config.alpha > 0.0

    # note that loading a pretrain model does not inherit optimizer info
    # will use resume to resume admm training
    if config.load_model:
        if os.path.isfile(config.load_model):
            # Compare against None explicitly: GPU index 0 is a valid device
            # but is falsy, so a plain truthiness test would skip it.
            if config.gpu is not None:
                model.load_state_dict(
                    torch.load(
                        config.load_model,
                        map_location={'cuda:0': 'cuda:{}'.format(config.gpu)}))
            else:
                model.load_state_dict(torch.load(config.load_model))
        else:
            # Report the path that was actually checked (config.load_model).
            print("=> no checkpoint found at '{}'".format(config.load_model))

    config.prepare_pruning()

    # Count zero / non-zero entries over the parameters listed in
    # config.conv_names and report total / nonzero.
    nonzero = 0
    zero = 0
    for name, W in model.named_parameters():
        if name in config.conv_names:
            W = W.cpu().detach().numpy()
            zero += np.sum(W == 0)
            nonzero += np.sum(W != 0)
    total = nonzero + zero
    print('compression rate is {}'.format(total * 1.0 / nonzero))
    # NOTE(review): unconditional exit; nothing below this point runs.
    # Presumably a temporary "report and quit" hack; confirm before removing.
    import sys
    sys.exit()

    # optionally resume from a checkpoint
    # NOTE(review): this block uses `optimizer` and `ADMM` before they are
    # assigned further down in this function, so it would fail with
    # UnboundLocalError if it were reached. It needs to move below the
    # optimizer/ADMM setup before the exit above is removed.
    if config.resume:
        ## will add logic for loading admm variables
        if os.path.isfile(config.resume):
            print("=> loading checkpoint '{}'".format(config.resume))
            checkpoint = torch.load(config.resume)
            config.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']

            ADMM.ADMM_U = checkpoint['admm']['ADMM_U']
            ADMM.ADMM_Z = checkpoint['admm']['ADMM_Z']

            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                config.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(config.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(config.data, 'train')
    valdir = os.path.join(config.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if config.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    # shuffle and sampler are mutually exclusive in DataLoader, hence shuffle
    # is only enabled when no sampler is used.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=config.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=config.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=config.batch_size,
                                             shuffle=False,
                                             num_workers=config.workers,
                                             pin_memory=True)

    # Warmup is only used outside ADMM training.
    config.warmup = (not config.admm) and config.warmup_epochs > 0
    optimizer_init_lr = config.warmup_lr if config.warmup else config.lr

    # NOTE(review): optimizer stays None for any value other than
    # 'sgd' / 'adam'; the scheduler constructors below would then fail.
    optimizer = None
    if (config.optimizer == 'sgd'):
        optimizer = torch.optim.SGD(model.parameters(),
                                    optimizer_init_lr,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    elif (config.optimizer == 'adam'):
        optimizer = torch.optim.Adam(model.parameters(), optimizer_init_lr)

    # Scheduler periods are expressed in iterations (epochs * batches per
    # epoch); presumably the scheduler is stepped once per batch in train().
    scheduler = None
    if config.lr_scheduler == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         T_max=config.epochs *
                                                         len(train_loader),
                                                         eta_min=4e-08)
    elif config.lr_scheduler == 'default':
        # sets the learning rate to the initial LR decayed by gamma every 30 epochs
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=30 * len(train_loader),
                                              gamma=0.1)
    else:
        raise Exception("unknown lr scheduler")

    if config.warmup:
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=config.lr / config.warmup_lr,
            total_iter=config.warmup_epochs * len(train_loader),
            after_scheduler=scheduler)

    # Disabled evaluate-only shortcut.
    if False:
        validate(val_loader, criterion, config)
        return
    ADMM = None

    if config.verify:
        admm.masking(config)
        admm.test_sparsity(config)
        validate(val_loader, criterion, config)
        import sys
        sys.exit()
    if config.admm:
        ADMM = admm.ADMM(config)

    if config.masked_retrain:
        # make sure small weights are pruned and confirm the acc
        admm.masking(config)
        print("before retrain starts")
        admm.test_sparsity(config)
        validate(val_loader, criterion, config)
    if config.masked_progressive:
        admm.zero_masking(config)
    for epoch in range(config.start_epoch, config.epochs):
        if config.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch

        train(train_loader, config, ADMM, criterion, optimizer, scheduler,
              epoch)

        # evaluate on validation set
        acc1 = validate(val_loader, criterion, config)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if is_best and not config.admm:  # we don't need admm to have best validation acc
            print('saving new best model {}'.format(config.save_model))
            torch.save(model.state_dict(), config.save_model)

        # Only one process per node writes the checkpoint in the
        # multiprocessing-distributed case.
        if not config.multiprocessing_distributed or (
                config.multiprocessing_distributed
                and config.rank % ngpus_per_node == 0):
            save_checkpoint(
                config, {
                    'admm': {},
                    'epoch': epoch + 1,
                    'arch': config.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
    # save last model for admm, optimizer detail is not necessary
    if config.save_model and config.admm:
        print('saving model {}'.format(config.save_model))
        torch.save(model.state_dict(), config.save_model)
    # NOTE(review): from here on the code looks like it was spliced in from a
    # different (CIFAR) script: the scheduler built here is not used by any
    # visible code, and the `else` below raises whenever
    # config.masked_retrain is false. Confirm intent before relying on it.
    if config.masked_retrain:
        print("after masked retrain")
        admm.test_sparsity(config)
        # my learning rate scheduler for cifar, following https://github.com/kuangliu/pytorch-cifar
        epoch_milestones = [150, 250, 350]
        """Set the learning rate of each parameter group to the initial lr decayed
            by gamma once the number of epoch reaches one of the milestones
        """
        # Uses train_loader, the loader defined in this function.
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[i * len(train_loader) for i in epoch_milestones],
            gamma=0.1)
    else:
        raise Exception("unknown lr scheduler")

    if config.warmup:
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=config.lr / config.warmup_lr,
            total_iter=config.warmup_epochs * len(train_loader),
            after_scheduler=scheduler)

    def train(train_loader, criterion, optimizer, epoch, config):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()

        # switch to train mode
        config.model.train()

        end = time.time()
        for i, (input, target) in enumerate(train_loader):
            # measure data loading time