Code example #1
def train(epoch):
    model.train()
    avg_loss = 0.
    train_acc = 0.
    lr = optimizer.param_groups[0]['lr']
    print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, lr))
    num_parameters = get_conv_zero_param(model)
    print('Zero parameters: {}'.format(num_parameters))
    num_parameters = sum([param.nelement() for param in model.parameters()])
    print('Parameters: {}'.format(num_parameters))
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        # Variable() wrapping (deprecated since PyTorch 0.4) is unnecessary; tensors are used directly
        output = model(data)
        optimizer.zero_grad()
        # collect the reference models' outputs without tracking gradients
        output_tc = []
        with torch.no_grad():
            for model_tc in models:
                output_tc.append(model_tc(data))
        # distillation-style objective: average the criterion over every
        # reference output (reduce comes from functools)
        loss = reduce(lambda acc, elem: acc + criterion(output, elem),
                      output_tc, 0) / len(models)
        loss.backward()
        # zeros out gradient of pruned weights
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                weight_copy = m.weight.data.abs().clone()
                mask = weight_copy.gt(0).float().cuda()
                m.weight.grad.data.mul_(mask)
        # ------------------------------------
        optimizer.step()
        avg_loss += loss.item()
        pred = output.data.max(1, keepdim=True)[1]
        train_acc += pred.eq(target.data.view_as(pred)).cpu().sum()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
        lr_scheduler.step()  # per-batch step, as required by e.g. OneCycleLR
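
All of these examples call get_conv_zero_param to report how many convolution weights have been pruned to zero, but the helper itself is never shown. A minimal sketch consistent with that usage (the implementation here is inferred from how it is called, not taken verbatim from the original projects; torch and torch.nn as nn are assumed to be imported, as in the snippets):

def get_conv_zero_param(model):
    # count convolution weights that are exactly zero, i.e. pruned away
    total = 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            total += int(m.weight.data.eq(0).sum().item())
    return total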
Code example #2
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch
    
    os.makedirs(args.save_dir, exist_ok=True)

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100


    trainset = dataloader(root='./data', train=True, download=True, transform=transform_train)
    trainloader = data.DataLoader(trainset, batch_size=args.train_batch, shuffle=True, num_workers=args.workers)

    testset = dataloader(root='./data', train=False, download=False, transform=transform_test)
    testloader = data.DataLoader(testset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)

    # Model
    print("==> creating model '{}'".format(args.arch))
    if args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
                    cardinality=args.cardinality,
                    num_classes=num_classes,
                    depth=args.depth,
                    widen_factor=args.widen_factor,
                    dropRate=args.drop,
                )
        model_ref = models.__dict__[args.arch](
                    cardinality=args.cardinality,
                    num_classes=num_classes,
                    depth=args.depth,
                    widen_factor=args.widen_factor,
                    dropRate=args.drop,
                )
    elif args.arch.startswith('densenet'):
        model = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                    growthRate=args.growthRate,
                    compressionRate=args.compressionRate,
                    dropRate=args.drop,
                )
        model_ref = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                    growthRate=args.growthRate,
                    compressionRate=args.compressionRate,
                    dropRate=args.drop,
                )
    elif args.arch.startswith('wrn'):
        model = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                    widen_factor=args.widen_factor,
                    dropRate=args.drop,
                )
        model_ref = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                    widen_factor=args.widen_factor,
                    dropRate=args.drop,
                )
    elif args.arch.endswith('resnet'):
        model = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                )
        model_ref = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                )
    else:
        model = models.__dict__[args.arch](num_classes=num_classes)
        model_ref = models.__dict__[args.arch](num_classes=num_classes)

    model.cuda()
    model_ref.cuda()

    cudnn.benchmark = True
    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) # default is 0.001

    # Resume
    title = 'cifar-10-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Getting reference model from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint file found!'
        # args.save_dir = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = args.start_epoch
        model_ref.load_state_dict(checkpoint['state_dict'])

    logger = Logger(os.path.join(args.save_dir, 'log_scratch.txt'), title=title)
    logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

    # set some weights to zero, according to model_ref ---------------------------------
    for m, m_ref in zip(model.modules(), model_ref.modules()):
        if isinstance(m, nn.Conv2d):
            weight_copy = m_ref.weight.data.abs().clone()
            mask = weight_copy.gt(0).float().cuda()  # 1 where the reference weight survived pruning
            n = mask.sum() / float(m.in_channels)
            m.weight.data.normal_(0, math.sqrt(2. / n))  # He-style re-init scaled to the surviving weights
            m.weight.data.mul_(mask)  # zero out the pruned positions

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))
        num_parameters = get_conv_zero_param(model)
        print('Zero parameters: {}'.format(num_parameters))
        num_parameters = sum([param.nelement() for param in model.parameters()])
        print('Parameters: {}'.format(num_parameters))

        train_loss, train_acc = train(trainloader, model, criterion, optimizer, epoch, use_cuda)
        test_loss, test_acc = test(testloader, model, criterion, epoch, use_cuda)

        # append logger file
        logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer' : optimizer.state_dict(),
            }, is_best, checkpoint=args.save_dir)

    logger.close()

    print('Best acc:')
    print(best_acc)
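
Example #2 and most of the later examples call adjust_learning_rate(optimizer, epoch) and then read state['lr'], where state appears to be a module-level dict mirroring the command-line arguments. Neither is defined in these snippets; a sketch of the step-decay scheme this pattern usually implements (args.schedule and args.gamma are assumed to exist, as in the snippets):

def adjust_learning_rate(optimizer, epoch):
    # multiply the learning rate by gamma at each scheduled epoch
    global state
    if epoch in args.schedule:
        state['lr'] *= args.gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = state['lr']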
Code example #3
File: kd_ticket.py Project: lilujunai/KD-ticket
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    os.makedirs(args.save_dir, exist_ok=True)

    # ######################################### Dataset ################################################
    train_transforms = transforms.Compose([
        # transforms.RandomResizedCrop(224),
        transforms.RandomResizedCrop(224, scale=(0.1, 1.0), ratio=(0.8, 1.25)),  # according to official open LTH

        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    val_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    # #################### train / valid dataset ####################
    train_dir = os.path.join(args.img_dir, 'train')
    valid_dir = os.path.join(args.img_dir, 'val')

    trainset = datasets.ImageFolder(root=train_dir, transform=train_transforms)
    devset = datasets.ImageFolder(root=valid_dir, transform=val_transforms)

    print('Total images in train, ', len(trainset))
    print('Total images in valid, ', len(devset))

    # #################### data loader ####################
    trainloader = data.DataLoader(trainset, batch_size=args.train_batch,
                                  shuffle=True, num_workers=args.workers)

    devloader = data.DataLoader(devset, batch_size=args.test_batch,
                                shuffle=False, num_workers=args.workers)

    # ######################################### Model ##################################################
    print("==> creating model ResNet={}".format(args.depth))
    if args.depth == 18:
        model = resnet18(pretrained=False)
        model_ref = resnet18(pretrained=False)
        teacher_model = resnet18(pretrained=False)

    elif args.depth == 34:
        model = resnet34(pretrained=False)
        model_ref = resnet34(pretrained=False)
        teacher_model = resnet34(pretrained=False)

    elif args.depth == 50:
        model = resnet50(pretrained=False)
        model_ref = resnet50(pretrained=False)
        teacher_model = resnet50(pretrained=False)

    elif args.depth == 101:
        model = resnet101(pretrained=False)
        model_ref = resnet101(pretrained=False)
        teacher_model = resnet101(pretrained=False)

    elif args.depth == 152:
        model = resnet152(pretrained=False)
        model_ref = resnet152(pretrained=False)
        teacher_model = resnet152(pretrained=False)
    else:
        model = resnet50(pretrained=False)  # default Res-50
        model_ref = resnet50(pretrained=False)  # default Res-50
        teacher_model = resnet50(pretrained=False)  # default Res-50

    model.cuda(device_ids[0])                           # model to train (student model)
    model_ref.cuda(device_ids[0])                       # pruned model
    teacher_model.cuda(device_ids[0])                   # teacher model: the last epoch of the unpruned training run

    # ############################### Optimizer and Loss ###############################
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum, weight_decay=args.weight_decay)

    # ************* USE APEX *************
    if USE_APEX:
        print('Use APEX !!! Initialize Model with APEX')
        model, optimizer = apex.amp.initialize(model, optimizer, loss_scale='dynamic', verbosity=0)

    # ****************** multi-GPU ******************
    model = nn.DataParallel(model, device_ids=device_ids)
    model_ref = nn.DataParallel(model_ref, device_ids=device_ids)
    teacher_model = nn.DataParallel(teacher_model, device_ids=device_ids)

    cudnn.benchmark = True
    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # ############################### Resume ###############################
    # load pruned model (model_ref), use it to mute some weights of model
    title = 'ImageNet'
    if args.resume:
        # Load checkpoint.
        print('==> Getting reference model from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
        best_acc = checkpoint['best_acc']
        start_epoch = args.start_epoch
        model_ref.load_state_dict(checkpoint['state_dict'])

    logger = Logger(os.path.join(args.save_dir, 'log_scratch.txt'), title=title)
    logger.set_names(['EPOCH', 'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

    # set some weights to zero, according to model_ref ---------------------------------
    # ############## load Lottery Ticket (initialization parameters of unpruned model) ##############
    if args.model:
        print('==> Loading init model (Lottery Ticket) from %s' % args.model)
        checkpoint = torch.load(args.model, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        if 'init' in args.model:
            start_epoch = 0
        else:
            start_epoch = checkpoint['epoch']
        print('Start Epoch ', start_epoch)
    for m, m_ref in zip(model.modules(), model_ref.modules()):
        if isinstance(m, nn.Conv2d):
            weight_copy = m_ref.weight.data.abs().clone()
            mask = weight_copy.gt(0).float().cuda()
            m.weight.data.mul_(mask)

    # ############## load parameters of teacher model ##############
    print('==> Loading teacher model (un-pruned) from %s' % args.teacher)
    checkpoint = torch.load(args.teacher, map_location=lambda storage, loc: storage)
    teacher_model.load_state_dict(checkpoint['state_dict'])
    teacher_model.eval()

    # ############################### Train and val ###############################
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))
        num_parameters = get_conv_zero_param(model)
        print('Zero parameters: {}'.format(num_parameters))
        num_parameters = sum([param.nelement() for param in model.parameters()])
        print('Parameters: {}'.format(num_parameters))

        # train model
        train_loss, train_acc = train(trainloader, model, teacher_model, optimizer, epoch, use_cuda)

        # ######## acc on validation data each epoch ########
        dev_loss, dev_acc = test(devloader, model, criterion, epoch, use_cuda)

        # append logger file
        logger.append([ epoch, state['lr'], train_loss, dev_loss, train_acc, dev_acc])

        # save model after one epoch
        # Note: save all models after one epoch, to help find the best rewind
        is_best = dev_acc > best_acc
        best_acc = max(dev_acc, best_acc)

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'acc': dev_acc,
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, is_best, checkpoint=args.save_dir, filename=str(epoch + 1)+'_checkpoint.pth.tar')

    print('Best val acc:')
    print(best_acc)

    logger.close()
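
In this example, train(trainloader, model, teacher_model, optimizer, epoch, use_cuda) receives the teacher model instead of a criterion, so the distillation objective must live inside train, which is not shown. A hypothetical sketch of the usual soft-target loss (the temperature T and the T^2 scaling are assumptions here, following Hinton et al., not confirmed by the snippet):

import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, T=4.0):
    # KL divergence between temperature-softened distributions,
    # scaled by T^2 so gradient magnitudes stay comparable across temperatures
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * (T * T)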
Code example #4
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.save_dir):
        mkdir_p(args.save_dir)

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100

    trainset = dataloader(root='../data',
                          train=True,
                          download=True,
                          transform=transform_train)
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.train_batch,
                                  shuffle=True,
                                  num_workers=args.workers)

    testset = dataloader(root='../data',
                         train=False,
                         download=False,
                         transform=transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.workers)

    # Model
    print("==> creating model '{}'".format(args.arch))
    # if args.arch.startswith('resnext'):
    #     model = models.__dict__[args.arch](
    #                 cardinality=args.cardinality,
    #                 num_classes=num_classes,
    #                 depth=args.depth,
    #                 widen_factor=args.widen_factor,
    #                 dropRate=args.drop,
    #             )
    if args.arch.startswith('densenet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            growthRate=args.growthRate,
            compressionRate=args.compressionRate,
            dropRate=args.drop,
        )
    # elif args.arch.startswith('wrn'):
    #     model = models.__dict__[args.arch](
    #                 num_classes=num_classes,
    #                 depth=args.depth,
    #                 widen_factor=args.widen_factor,
    #                 dropRate=args.drop,
    #             )
    elif args.arch.endswith('resnet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
        )
    elif args.arch.endswith('preresnet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
        )
    else:
        model = models.__dict__[args.arch](num_classes=num_classes)

    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)  # default is 0.001

    # Resume
    title = 'cifar-10-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])

    logger = Logger(os.path.join(args.save_dir, 'log_finetune.txt'),
                    title=title)
    logger.set_names([
        'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'
    ])

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, state['lr']))
        num_parameters = get_conv_zero_param(model)
        print('Zero parameters: {}'.format(num_parameters))
        num_parameters = sum(
            [param.nelement() for param in model.parameters()])
        print('Parameters: {}'.format(num_parameters))

        train_loss, train_acc = train(trainloader, model, criterion, optimizer,
                                      epoch, use_cuda)
        test_loss, test_acc = test(testloader, model, criterion, epoch,
                                   use_cuda)

        # append logger file
        logger.append(
            [state['lr'], train_loss, test_loss, train_acc, test_acc])

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            checkpoint=args.save_dir)

    logger.close()

    print('Best acc:')
    print(best_acc)
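
save_checkpoint is used by every example but never defined. Example #3 passes a filename argument, and example #6 later loads both 'checkpoint.pth.tar' and 'model_best.pth.tar', which suggests the following shape (a sketch inferred from usage, assuming os, shutil, and torch are imported):

def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    # always save the latest state; copy it aside when it is the best so far
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))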
Code example #5
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.save_dir):
        mkdir_p(args.save_dir)

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100

    trainset = dataloader(root='./data',
                          train=True,
                          download=True,
                          transform=transform_train)
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.train_batch,
                                  shuffle=True,
                                  num_workers=args.workers)

    testset = dataloader(root='./data',
                         train=False,
                         download=False,
                         transform=transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.workers)

    # Model
    print("==> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch](dataset=args.dataset)

    model = model.cuda()
    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)  # default is 0.001
    if args.use_onecycle:
        lr_scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=args.lr,
            div_factor=10,
            epochs=args.epochs,
            steps_per_epoch=len(trainloader),
            pct_start=0.1,
            final_div_factor=100)
    else:
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                      milestones=args.schedule,
                                                      gamma=args.gamma)
    # Resume
    title = 'cifar-10-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])

    logger = Logger(os.path.join(args.save_dir, 'log_finetune.txt'),
                    title=title)
    logger.set_names([
        'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'
    ])

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        lr = optimizer.param_groups[0]['lr']
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, lr))
        num_parameters = get_conv_zero_param(model)
        print('Zero parameters: {}'.format(num_parameters))
        num_parameters = sum(
            [param.nelement() for param in model.parameters()])
        print('Parameters: {}'.format(num_parameters))

        train_loss, train_acc = train(trainloader, model, criterion, optimizer,
                                      epoch, use_cuda, lr_scheduler)
        test_loss, test_acc = test(testloader, model, criterion, epoch,
                                   use_cuda)

        # append logger file
        logger.append([lr, train_loss, test_loss, train_acc, test_acc])

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            checkpoint=args.save_dir)

    logger.close()

    print('Best acc:')
    print(best_acc)
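
Note that OneCycleLR is constructed with steps_per_epoch=len(trainloader), so it must be stepped once per batch inside train (exactly what example #1's lr_scheduler.step() at the end of the batch loop does), whereas MultiStepLR with epoch-based milestones should only be stepped once per epoch; if both schedulers share the same per-batch loop, the milestones would need to be expressed in batches. A minimal per-batch stepping pattern, for reference:

for data, target in trainloader:
    loss = criterion(model(data.cuda()), target.cuda())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()  # advance the one-cycle schedule every batch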
Code example #6
File: cifar_lt_kd.py Project: lilujunai/KD-ticket
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    os.makedirs(args.save_dir, exist_ok=True)

    # ############################### Dataset ###############################
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100

    # #################### train, dev, test split ####################
    trainset = dataloader(root='./data',
                          train=True,
                          download=True,
                          transform=transform_train)  # with augmentation
    devset = dataloader(root='./data',
                        train=True,
                        download=False,
                        transform=transform_test)  # without augmentation
    testset = dataloader(root='./data',
                         train=False,
                         download=False,
                         transform=transform_test)

    num_train = len(trainset)  # should be 50000
    indices = list(range(num_train))
    split = int(0.1 * num_train)  # hold out 10% of the training set for validation

    train_idx, dev_idx = indices[split:], indices[:split]  # 45000, 5000

    trainset = data.Subset(trainset, train_idx)
    devset = data.Subset(devset, dev_idx)

    print('Total image in train, ', len(trainset))
    print('Total image in valid, ', len(devset))
    print('Total image in test, ', len(testset))

    trainloader = data.DataLoader(trainset,
                                  batch_size=args.train_batch,
                                  shuffle=True,
                                  num_workers=args.workers)

    devloader = data.DataLoader(devset,
                                batch_size=args.test_batch,
                                shuffle=False,
                                num_workers=args.workers)

    testloader = data.DataLoader(testset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.workers)

    # ############################## Model ###############################
    print("==> creating model '{}'".format(args.arch))
    if args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
            cardinality=args.cardinality,
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
        model_ref = models.__dict__[args.arch](
            cardinality=args.cardinality,
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
        teacher_model = models.__dict__[args.arch](
            cardinality=args.cardinality,
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch.startswith('densenet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            growthRate=args.growthRate,
            compressionRate=args.compressionRate,
            dropRate=args.drop,
        )
        model_ref = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            growthRate=args.growthRate,
            compressionRate=args.compressionRate,
            dropRate=args.drop,
        )
        teacher_model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            growthRate=args.growthRate,
            compressionRate=args.compressionRate,
            dropRate=args.drop,
        )
    elif args.arch.startswith('wrn'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
        model_ref = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
        teacher_model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )

    elif args.arch.startswith('resnet'):
        # using the new ResNet architecture #####################
        if args.depth == 18:
            model = models.resnet18()
            model_ref = models.resnet18()
            teacher_model = models.resnet18()
        elif args.depth == 34:
            model = models.resnet34()
            model_ref = models.resnet34()
            teacher_model = models.resnet34()
        elif args.depth == 50:
            model = models.resnet50()
            model_ref = models.resnet50()
            teacher_model = models.resnet50()
        elif args.depth == 101:
            model = models.resnet101()
            model_ref = models.resnet101()
            teacher_model = models.resnet101()

        elif args.depth == 152:
            model = models.resnet152()
            model_ref = models.resnet152()
            teacher_model = models.resnet152()
        # if not specified, default to ResNet-18
        else:
            model = models.resnet18()
            model_ref = models.resnet18()
            teacher_model = models.resnet18()

    elif args.arch.startswith('oresnet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
        )
        model_ref = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
        )
        teacher_model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
        )

    else:
        model = models.__dict__[args.arch](num_classes=num_classes)
        model_ref = models.__dict__[args.arch](num_classes=num_classes)
        teacher_model = models.__dict__[args.arch](num_classes=num_classes)

    model.cuda()  # model to train (student model)
    model_ref.cuda()  # pruned model
    teacher_model.cuda()  # teacher model: the last epoch of the unpruned training run

    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # ############################### Optimizer and Loss ###############################
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)  # default is 0.001

    # ############################### Resume ###############################
    # load pruned model (model_ref), use it to mute some weights of model
    title = 'cifar-10-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Getting reference model from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = args.start_epoch
        model_ref.load_state_dict(checkpoint['state_dict'])

    logger = Logger(os.path.join(args.save_dir, 'log_scratch.txt'),
                    title=title)
    logger.set_names([
        'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'
    ])

    # set some weights to zero, according to model_ref ---------------------------------
    # ############## load Lottery Ticket (initialization parameters of unpruned model) ##############
    if args.model:
        print('==> Loading init model (Lottery Ticket) from %s' % args.model)
        checkpoint = torch.load(args.model)
        model.load_state_dict(checkpoint['state_dict'])

    for m, m_ref in zip(model.modules(), model_ref.modules()):
        if isinstance(m, nn.Conv2d):
            weight_copy = m_ref.weight.data.abs().clone()
            mask = weight_copy.gt(0).float().cuda()
            m.weight.data.mul_(mask)

    # ############## load parameters of teacher model ##############
    print('==> Loading teacher model (un-pruned) from %s' % args.teacher)
    checkpoint = torch.load(args.teacher)
    teacher_model.load_state_dict(checkpoint['state_dict'])
    teacher_model.eval()

    # ############################### Train and val ###############################
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, state['lr']))
        num_parameters = get_conv_zero_param(model)
        print('Zero parameters: {}'.format(num_parameters))
        num_parameters = sum(
            [param.nelement() for param in model.parameters()])
        print('Parameters: {}'.format(num_parameters))

        # train model
        train_loss, train_acc = train(trainloader, model, teacher_model,
                                      optimizer, epoch, use_cuda)

        # ######## acc on validation data each epoch ########
        dev_loss, dev_acc = test(devloader, model, criterion, epoch, use_cuda)

        # append logger file
        logger.append([state['lr'], train_loss, dev_loss, train_acc, dev_acc])

        # save model after one epoch
        # Note: save all models after one epoch, to help find the best rewind
        is_best = dev_acc > best_acc
        best_acc = max(dev_acc, best_acc)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': dev_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            checkpoint=args.save_dir)

    print('Best val acc:')
    print(best_acc)
    # ############################### test ###############################
    print('Load best model ...')
    checkpoint = torch.load(os.path.join(args.save_dir, 'model_best.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])
    test_loss, test_acc = test(testloader, model, criterion, 0, use_cuda)
    logger.append([state['lr'], -1, test_loss, -1, test_acc])
    print('test acc (best val acc)')
    print(test_acc)

    print('Load last model ...')
    checkpoint = torch.load(os.path.join(args.save_dir, 'checkpoint.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])
    test_loss, test_acc = test(testloader, model, criterion, 0, use_cuda)
    logger.append([state['lr'], -1, test_loss, -1, test_acc])
    print('test acc (last epoch)')
    print(test_acc)

    logger.close()
Code example #7
File: lottery_ticket.py Project: Aloereed/DessiLBI
def main():
    if args.gamma_supp == "True":
        writer = SummaryWriter("{}/tblogs/gamma_epoch{}".format(
            args.save_dir, args.gamma_epoch))
    else:
        writer = SummaryWriter("{}/tblogs/norm_prune_sgdepoch{}_p{}".format(
            args.save_dir, args.sgd_epoch, args.percent))
    global best_acc
    start_epoch = args.start_epoch

    os.makedirs(args.save_dir, exist_ok=True)

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100
    trainset = dataloader(root='./data',
                          train=True,
                          download=True,
                          transform=transform_train)
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.train_batch,
                                  shuffle=True,
                                  num_workers=args.workers)
    testset = dataloader(root='./data',
                         train=False,
                         download=False,
                         transform=transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.workers)

    # Model
    print("==> creating model '{}'".format(args.arch))
    if args.arch.endswith('resnet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
        )
        model_ref = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
        )
    else:
        model = models.__dict__[args.arch](num_classes=num_classes)
        model_ref = models.__dict__[args.arch](num_classes=num_classes)

    model.cuda()
    model_ref.cuda()
    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Resume
    title = 'cifar-10-' + args.arch
    if args.save_dir:
        # Load checkpoint.
        if args.gamma_supp == "True":
            print('==> Getting reference model from gamma_supp..')
            assert os.path.isfile("{}/masks/epoch{}.t7".format(
                args.save_dir,
                args.gamma_epoch)), 'Error: no mask file found!'
            mask_dict = torch.load("{}/masks/epoch{}.t7".format(
                args.save_dir, args.gamma_epoch))
            best_acc = 0.0  # TODO: check it
            start_epoch = args.start_epoch
            for name, param in model_ref.named_parameters():
                if name in mask_dict:
                    mask = mask_dict[name].cuda()
                    param.data.mul_(mask)
            res_str, res_list = analysis_masks(mask_dict)
            writer.add_text("mask_analysis", res_str)
            for i, rate in enumerate(res_list):
                writer.add_scalar("layer sparse rate", rate, global_step=i)
        else:
            print('==> Getting reference model from weight..')
            if args.sgd_epoch == -1:
                print("Using SGD best weights!")
                checkpoint = torch.load("{}/pruned_p{}.pth.tar".format(
                    args.save_dir, args.percent))
            else:
                print("Using SGD weight prune @ epoch {}".format(
                    args.sgd_epoch))
                checkpoint = torch.load("{}/pruned_p{}_epoch{}.pth.tar".format(
                    args.save_dir, args.percent, args.sgd_epoch))
            best_acc = checkpoint['best_acc']
            start_epoch = args.start_epoch
            model_ref.load_state_dict(checkpoint['state_dict'])

    logger = Logger(os.path.join(
        args.save_dir,
        str(args.percent) + '_' + str(args.sgd_epoch) + '_log_scratch.txt'),
                    title=title)
    logger.set_names([
        'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'
    ])

    # set some weights to zero, according to model_ref ---------------------------------
    assert args.model
    if args.save_dir:
        print('==> Loading init model from %s' % args.model)
        checkpoint = torch.load("{}/init.pth.tar".format(args.save_dir))
        model.load_state_dict(checkpoint['state_dict'])
    # setting zeros
    for m, m_ref in zip(model.modules(), model_ref.modules()):
        if isinstance(m, nn.Conv2d):
            weight_copy = m_ref.weight.data.abs().clone()
            mask = weight_copy.gt(0).float().cuda()
            m.weight.data.mul_(mask)

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, state['lr']))
        num_zeros = get_conv_zero_param(model)
        print('Zero parameters: {}'.format(num_zeros))
        num_parameters = sum(
            [param.nelement() for param in model.parameters()])
        print('Parameters: {}'.format(num_parameters))
        print("Prune rate: {}".format(float(num_zeros) / num_parameters))
        writer.add_scalar("Prune rate", float(num_zeros) / num_parameters)
        writer.add_scalar("Parameters", num_parameters)
        writer.add_scalar("Zero parameters", num_zeros)

        train_loss, train_acc = train(trainloader, model, criterion, optimizer,
                                      epoch, use_cuda)
        test_loss, test_acc = test(testloader, model, criterion, epoch,
                                   use_cuda)

        writer.add_scalar("train_loss", train_loss, global_step=epoch)
        writer.add_scalar("train_err", 100 - train_acc, global_step=epoch)
        writer.add_scalar("test_loss", test_loss, global_step=epoch)
        writer.add_scalar("test_err", 100 - test_acc, global_step=epoch)

        # append logger file
        logger.append(
            [state['lr'], train_loss, test_loss, train_acc, test_acc])

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            checkpoint=args.save_dir)

    logger.close()

    print('Best acc:')
    print(best_acc)
    writer.add_scalar("Best_err", 100 - best_acc)
Code example #8
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch
    
    os.makedirs(args.save_dir, exist_ok=True)

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    trainloader, testloader = get_split_cifar100(args.start_class, args.end_class)

    if args.dataset == 'cifar10':
        num_classes = 10
    else:
        num_classes = 100

    # Model
    print("==> creating model '{}'".format(args.arch))
    if args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
                    cardinality=args.cardinality,
                    num_classes=num_classes,
                    depth=args.depth,
                    widen_factor=args.widen_factor,
                    dropRate=args.drop,
                )
        model_ref = models.__dict__[args.arch](
                    cardinality=args.cardinality,
                    num_classes=num_classes,
                    depth=args.depth,
                    widen_factor=args.widen_factor,
                    dropRate=args.drop,
                )
    elif args.arch.startswith('densenet'):
        model = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                    growthRate=args.growthRate,
                    compressionRate=args.compressionRate,
                    dropRate=args.drop,
                )
        model_ref = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                    growthRate=args.growthRate,
                    compressionRate=args.compressionRate,
                    dropRate=args.drop,
                )
    elif args.arch.startswith('wrn'):
        model = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                    widen_factor=args.widen_factor,
                    dropRate=args.drop,
                )
        model_ref = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                    widen_factor=args.widen_factor,
                    dropRate=args.drop,
                )
    elif args.arch.endswith('resnet'):
        model = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                )
        model_ref = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                )
    else:
        model = models.__dict__[args.arch](num_classes=num_classes)
        model_ref = models.__dict__[args.arch](num_classes=num_classes)

    model.cuda()
    model_ref.cuda()

    cudnn.benchmark = True
    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) # default is 0.001

    # Resume
    title = 'cifar-10-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Getting reference model from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint file found!'
        # args.save_dir = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = args.start_epoch
        model_ref.load_state_dict(checkpoint['state_dict'])

    logger = Logger(os.path.join(args.save_dir, 'log_scratch.txt'), title=title)
    logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

    # set some weights to zero, according to model_ref ---------------------------------
    if args.model:
        print('==> Loading init model from %s'%args.model)
        checkpoint = torch.load(args.model)
        model.load_state_dict(checkpoint['state_dict'])

    for m, m_ref in zip(model.modules(), model_ref.modules()):
        if isinstance(m, nn.Conv2d):
            weight_copy = m_ref.weight.data.abs().clone()
            mask = weight_copy.gt(0).float().cuda()
            m.weight.data.mul_(mask)

    train_losses = []
    train_acces = []
    test_losses = []
    test_acces = []

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))
        num_parameters = get_conv_zero_param(model)
        print('Zero parameters: {}'.format(num_parameters))
        num_parameters = sum([param.nelement() for param in model.parameters()])
        print('Parameters: {}'.format(num_parameters))

        train_loss, train_acc = train(trainloader, model, criterion, optimizer, epoch, use_cuda)
        test_loss, test_acc = test(testloader, model, criterion, epoch, use_cuda)

        train_losses.append(train_loss)
        train_acces.append(train_acc)
        test_acces.append(test_acc)
        test_losses.append(test_loss)

        # append logger file
        logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer' : optimizer.state_dict(),
            }, is_best, checkpoint=args.save_dir)

    logger.close()

    sns.lineplot(x=list(range(len(test_acces))), y=test_acces, color='red', label='test accuracy')
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    plt.legend()
    plt.savefig(os.path.join(args.save_dir, "test.png"))

    print('Best acc:')
    print(best_acc)
Code example #9
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.save_dir):
        mkdir_p(args.save_dir)

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    # transform_train = transforms.Compose([
    #     transforms.RandomCrop(32, padding=4),
    #     transforms.RandomHorizontalFlip(),
    #     transforms.ToTensor(),
    #     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    # ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100

    testset = dataloader(root='../data',
                         train=False,
                         download=False,
                         transform=transform_test)
    rand_sampler = torch.utils.data.RandomSampler(testset,
                                                  num_samples=1,
                                                  replacement=True)
    testloader = data.DataLoader(testset,
                                 batch_size=args.test_batch,
                                 sampler=rand_sampler,
                                 shuffle=False,
                                 num_workers=args.workers)

    # Model
    print("==> creating model '{}'".format(args.arch))
    if args.arch.endswith('resnet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
        )
    elif args.arch.endswith('preresnet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
        )
    else:
        model = models.__dict__[args.arch](num_classes=num_classes)

    # model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = False
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)  # default is 0.001

    # Resume
    title = 'cifar-10-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in checkpoint['state_dict'].items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        # load params
        model.load_state_dict(new_state_dict)
        # model.load_state_dict(checkpoint['state_dict'])

    # logger = Logger(os.path.join(args.save_dir, 'log_finetune.txt'), title=title)
    # logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])
    epoch = start_epoch
    print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))
    num_parameters = get_conv_zero_param(model)
    print('Zero parameters: {}'.format(num_parameters))
    num_parameters = sum([param.nelement() for param in model.parameters()])
    print('Parameters: {}'.format(num_parameters))

    print('Best acc:')
    print(best_acc)

    iterations = args.iterations
    # results = runThroughput(iterations, testloader, model, criterion, epoch, use_cuda)
    results = runThroughput(iterations, testloader, model)
    addResults(iterations, 1, results)
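
runThroughput and addResults are project-specific benchmarking helpers that are not shown in this example. A hypothetical minimal version of the measurement loop, under the assumption that it simply times inference-only forward passes (the real helpers may record different statistics):

import time

def runThroughput(iterations, loader, model):
    # time forward passes over the loader with gradients disabled
    model.eval()
    timings = []
    with torch.no_grad():
        for _ in range(iterations):
            for inputs, _ in loader:
                start = time.time()
                model(inputs)
                timings.append(time.time() - start)
    return timings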