Example #1
def main():
    global args, best_prec1
    args = parser.parse_args()

    args.distributed = args.world_size > 1

    if not os.path.exists(args.checkpoint_path):
        os.makedirs(args.checkpoint_path)

    # uncomment the following line to use tensorboard
    #writer = SummaryWriter(os.path.join(args.checkpoint_path, 'logs'))

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    print(args)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    # Currently does not support training on a distributed system
    backend_name = type(model).__name__  # record the backbone class name
    # split the model into two parts, i.e., a DNN trunk and an FC layer
    model, classifier = decomposeModel(model,
                                       args.n_classes,
                                       keep_pre_pooling=True)
    classifier = classifier.cuda()
    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # define a regressor to wrap the classifier and its optimizer for gradient modification
    reg_net = None
    optimizer_reg = optim.SGD(classifier.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    if args.mtype == 'baseline':
        args.dcl_refsize = 0
        reg_net = regressor.Net(classifier,
                                optimizer_reg,
                                ref_size=args.dcl_refsize,
                                backendtype=backend_name)
    elif args.mtype == 'dcl':
        reg_net = regressor.Net(classifier,
                                optimizer_reg,
                                ref_size=args.dcl_refsize,
                                backendtype=backend_name,
                                dcl_offset=args.dcl_offset,
                                dcl_window=args.dcl_window,
                                QP_margin=args.dcl_QP_margin)
    elif args.mtype == 'gem':
        reg_net = gem.Net(classifier,
                          optimizer_reg,
                          n_memories=args.gem_memsize,
                          backendtype=backend_name)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # --------
            optimizer_reg.load_state_dict(checkpoint['optimizer_reg'])
            reg_net.load_state_dict(checkpoint['reg_net'])
            reg_net.opt = optimizer_reg
            reg_net.ref_size = checkpoint['reg_net_ref_size']
            reg_net.backendtype = checkpoint['reg_net_backendtype']
            reg_net.dcl_window = checkpoint['reg_net_dcl_window']
            reg_net.ref_cnt = checkpoint['reg_net_dcl_ref_cnt']
            reg_net.ref_data = (checkpoint['reg_net_dcl_accum_grad'].cuda()
                                if checkpoint['reg_net_dcl_accum_grad'] is not None
                                else None)
            reg_net.stat_w1 = checkpoint['stat_ref_weight'].cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
        # archive the existing stats file before the resumed run appends to it
        shutil.copyfile(
            os.path.join(args.checkpoint_path, 'stat.csv'),
            os.path.join(args.checkpoint_path,
                         'stat_end_ep{}.csv'.format(args.start_epoch)))
        logger = Logger(os.path.join(args.checkpoint_path, 'stat.csv'),
                        title='ImageNet',
                        resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint_path, 'stat.csv'),
                        title='ImageNet')
        logger.set_names([
            'Epoch', 'Tr_T1Acc', 'Val_T1Acc', 'Val_T5Acc', 'Cong_Bef',
            'Cong_Aft', 'Magn', 'LR', 'Tr_loss', 'Val_loss', 'Tr_batch_time',
            'Tr_data_time', 'Val_batch_time', 'Val_data_time'
        ])

    cudnn.benchmark = True

    # setup data preprocessing procedure
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    # load training data
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)
    # load validation data
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    scores = []
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        cur_lr = adjust_learning_rate_multiopt(
            optimizer,
            optimizer_reg,
            epoch,
            lr_decay_epoch=args.lr_decay_epoch)

        # train for one epoch
        tr_loss, tr_top1, tr_top5, tr_cong_bf, tr_cong_af, tr_magn, tr_batch_time, tr_data_time = train(
            train_loader, model, criterion, optimizer, epoch, reg_net=reg_net)

        # evaluate on validation set
        val_loss, prec1, prec5, val_batch_time, val_data_time = validate(
            val_loader, model, criterion, reg_net=reg_net)
        scores.append(prec1)
        print('--max top-1 acc: {}'.format(max(scores)))

        # uncomment the following lines to use tensorboard
        #writer.add_scalars('top-1 acc', {'train top-1': tr_top1, 'val top-1': prec1}, epoch+1)
        #writer.add_scalars('congruency', {'tr_cong_bf': tr_cong_bf, 'tr_cong_af': tr_cong_af}, epoch+1)
        #writer.add_scalars('loss', {'train loss': tr_loss, 'val loss': val_loss}, epoch+1)
        logger.append([
            epoch + 1, tr_top1, prec1, prec5, tr_cong_bf, tr_cong_af, tr_magn,
            cur_lr, tr_loss, val_loss, tr_batch_time, tr_data_time,
            val_batch_time, val_data_time
        ])

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
                'optimizer_reg': optimizer_reg.state_dict(),
                'reg_net': reg_net.state_dict(),
                'reg_net_ref_size': reg_net.ref_size if args.mtype != 'gem' else 0,
                'reg_net_backendtype': reg_net.backendtype,
                'reg_net_dcl_window': reg_net.dcl_window if args.mtype != 'gem' else 0,
                'reg_net_dcl_ref_cnt': reg_net.ref_cnt if args.mtype != 'gem' else 0,
                'reg_net_dcl_accum_grad': (reg_net.ref_data.clone().cpu()
                                           if args.mtype != 'gem'
                                           and reg_net.ref_data is not None
                                           else None),
                'stat_ref_weight': reg_net.stat_w1.clone().cpu(),
            },
            is_best,
            save_dir=args.checkpoint_path)
    logger.close()
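
All three examples rely on a decomposeModel helper to split a backbone into
a feature trunk and a final FC classifier, so that only the classifier's
gradients are modified by the regressor. That helper is not shown on this
page; below is a minimal sketch of the splitting idea for torchvision-style
models, assuming the trunk is the original network with its last layer
replaced by an identity (the keep_pre_pooling flag is accepted but ignored
here, and the real implementation may differ).

import torch.nn as nn

def decomposeModel(model, n_classes, keep_pre_pooling=False):
    # hypothetical sketch: replace the final layer with an identity so the
    # trunk emits features, and return a fresh FC classifier alongside it
    # (keep_pre_pooling is accepted but ignored in this sketch)
    if hasattr(model, 'fc'):                      # e.g. ResNet
        in_features = model.fc.in_features
        model.fc = nn.Identity()
    elif hasattr(model, 'classifier'):            # e.g. AlexNet / VGG
        in_features = model.classifier[-1].in_features
        model.classifier[-1] = nn.Identity()
    else:
        raise ValueError('unsupported architecture')
    classifier = nn.Linear(in_features, n_classes)
    return model, classifier

Splitting this way is what lets each example wrap only the classifier in
regressor.Net or gem.Net while the trunk trains with a plain SGD optimizer.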
Example #2
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # uncomment the following line to use tensorboard
    # writer = SummaryWriter(os.path.join(args.checkpoint, 'logs'))

    # Data loading
    traindir = os.path.join(args.data, 'train')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            transforms.Resize(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.train_batch, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    
    # Tiny ImageNet keeps validation images in a flat directory with a
    # separate annotation file, hence the custom dataset class
    valdir = os.path.join(args.data, 'val', 'images')
    valgtfile = os.path.join(args.data, 'val', 'val_annotations.txt')
    val_dataset = TImgNetDataset(
        valdir, valgtfile,
        class_to_idx=train_loader.dataset.class_to_idx.copy(),
        transform=transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # create model
    if args.pretrained_model:
        print("=> using pre-trained model '{}'".format(args.pretrained_model))
        model = modelzoo[args.pretrained_model](pretrained=bool(args.use_pretrained))
    elif args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
                    baseWidth=args.base_width,
                    cardinality=args.cardinality,
                )
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()
    modeltype = type(model).__name__
    # split the model into two parts, i.e., a DNN trunk and an FC layer
    model, classifier = decomposeModel(model, args.n_classes)
    classifier = classifier.cuda()

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    cudnn.benchmark = True
    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # define a regressor to wrap the classifier and its optimizer for gradient modification
    reg_net = None
    optimizer_reg = optim.SGD(classifier.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    if args.mtype == 'baseline':
        args.dcl_refsize = 0
        reg_net = regressor.Net(classifier, optimizer_reg, ref_size=args.dcl_refsize, backendtype=modeltype)
    elif args.mtype == 'dcl':
        reg_net = regressor.Net(classifier, optimizer_reg, ref_size=args.dcl_refsize, backendtype=modeltype, dcl_offset=args.dcl_offset, dcl_window=args.dcl_window)
    elif args.mtype == 'gem':
        reg_net = gem.Net(classifier, optimizer_reg, n_memories=args.gem_memsize, backendtype=modeltype)

    print(args)

    # Resume
    title = 'Tiny ImageNet-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'stat.csv'), title=title, resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'stat.csv'), title=title)
        # Cong_Bef stands for congruency before gradient modification
        # Cong_Aft stands for congruency after gradient modification
        # Magn stands for magnitude
        logger.set_names(['Epoch', 'Tr_T1Acc', 'Val_T1Acc', 'Val_T5Acc', 
                                'Cong_Bef', 'Cong_Aft', 'Magn', 'LR', 'Tr_loss', 'Val_loss',
                                'Tr_batch_time', 'Tr_data_time', 'Val_batch_time', 'Val_data_time'])


    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        # adjust the learning rate for both optimizers, i.e., one for the
        # DNN trunk and the other for the regressor
        adjust_learning_rate_two(optimizer, optimizer_reg, epoch)

        print('Epoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))

        train_loss, train_top1_acc, train_top5_acc, cong_bf, cong_af, mag, tr_batch_time, tr_data_time = train(train_loader, model, criterion, optimizer, epoch, use_cuda, reg_net=reg_net)
        test_loss, test_top1_acc, test_top5_acc, te_batch_time, te_data_time = test(val_loader, model, criterion, epoch, use_cuda, reg_net=reg_net)
        print('train_loss: {:.4f}, train_top1_err: {:.2f}, train_top5_err: {:.2f}, test_loss: {:.4f}, test_top1_err: {:.2f}, test_top5_err: {:.2f}, cong_bf: {:.4f}, cong_af: {:.4f}, mag: {:.4f}'.format(
            train_loss, 100-train_top1_acc, 100-train_top5_acc,
            test_loss, 100-test_top1_acc, 100-test_top5_acc,
            cong_bf, cong_af, mag))

        # uncomment the following line to use tensorboard
        # writer.add_scalars('top-1 err', {'train top-1': 100-train_top1_acc, 'val top-1': 100-test_top1_acc}, epoch+1)
        logger.append([epoch + 1, 100-train_top1_acc, 100-test_top1_acc, 100-test_top5_acc, cong_bf, cong_af, mag, 
                        state['lr'], train_loss, test_loss,
                        tr_batch_time, tr_data_time, te_batch_time, te_data_time])

        # save model
        is_best = test_top1_acc > best_acc
        best_acc = max(test_top1_acc, best_acc)
        save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'classifier_state_dict': reg_net.state_dict(),
                'err': 100-test_top1_acc,
                'best_acc': best_acc,
                'optimizer' : optimizer.state_dict(),
                'cong_bf': cong_bf,
                'cong_af': cong_af,
                'mag': mag,
            }, is_best, checkpoint=args.checkpoint)

    logger.close()
    # logger.plot()
    # savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best err:')
    print(100-best_acc)
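
Examples #2 and #3 call adjust_learning_rate_two to decay both optimizers in
lockstep and read the current value back from a global state dict. The
schedule itself lives elsewhere in the repo; the following is a minimal
sketch of that convention, assuming args.schedule and args.gamma exist as
step-decay hyperparameters.

def adjust_learning_rate_two(optimizer, optimizer_reg, epoch):
    # hypothetical sketch: step-decay the shared learning rate and push it
    # into both optimizers; state and args are assumed module globals, as
    # in the examples on this page
    if epoch in args.schedule:         # e.g. args.schedule = [150, 225]
        state['lr'] *= args.gamma      # e.g. args.gamma = 0.1
        for opt in (optimizer, optimizer_reg):
            for param_group in opt.param_groups:
                param_group['lr'] = state['lr']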
Example #3
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    if args.arch.startswith('efficientnet'):
        transform_train = transforms.Compose([
            transforms.Resize(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100

    trainset = dataloader(root=args.datapath,
                          train=True,
                          download=False,
                          transform=transform_train)
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.train_batch,
                                  shuffle=True,
                                  num_workers=args.workers)

    testset = dataloader(root=args.datapath,
                         train=False,
                         download=False,
                         transform=transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.workers)

    # Model
    print("==> creating model '{}'".format(args.arch))
    if args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
            cardinality=args.cardinality,
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch.startswith('densenet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            growthRate=args.growthRate,
            compressionRate=args.compressionRate,
            dropRate=args.drop,
        )
    elif args.arch.startswith('wrn'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch.endswith('resnet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
        )
    elif args.arch.startswith('efficientnet'):
        model = EfficientNet.from_pretrained(args.arch,
                                             num_classes=num_classes)
    else:
        model = models.__dict__[args.arch](num_classes=num_classes)
    modeltype = type(model).__name__
    model, classifier = decomposeModel(model)
    classifier = classifier.cuda()

    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    reg_net = None
    optimizer_reg = optim.SGD(classifier.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    if args.mtype == 'baseline':
        args.dcl_refsize = 0
        reg_net = regressor.Net(classifier,
                                optimizer_reg,
                                ref_size=args.dcl_refsize,
                                backendtype=modeltype)
    elif args.mtype == 'dcl':
        reg_net = regressor.Net(classifier,
                                optimizer_reg,
                                ref_size=args.dcl_refsize,
                                backendtype=modeltype,
                                dcl_offset=args.dcl_offset,
                                dcl_window=args.dcl_window,
                                QP_margin=args.dcl_QP_margin)
    elif args.mtype == 'gem':
        reg_net = gem.Net(classifier,
                          optimizer_reg,
                          n_memories=args.gem_memsize,
                          backendtype=modeltype)

    print(args)

    # Resume
    title = args.dataset + '-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Err.',
            'Valid Err.', 'Cos Bef.', 'Cos Aft.', 'Mag'
        ])

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion, start_epoch,
                                   use_cuda)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate_two(optimizer, optimizer_reg, epoch)

        print('Epoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, state['lr']))

        train_loss, train_top1_acc, train_top5_acc, cong_bf, cong_af, mag = train(
            trainloader,
            model,
            criterion,
            optimizer,
            epoch,
            use_cuda,
            reg_net=reg_net)
        test_loss, test_top1_acc, test_top5_acc = test(testloader,
                                                       model,
                                                       criterion,
                                                       epoch,
                                                       use_cuda,
                                                       reg_net=reg_net)
        print(
            'train_loss: {:.4f}, train_top1_err: {:.2f}, train_top5_err: {:.2f}, test_loss: {:.4f}, test_top1_err: {:.2f}, test_top5_err: {:.2f}, cong_bf: {:.4f}, cong_af: {:.4f}, mag: {:.4f}'
            .format(train_loss, 100 - train_top1_acc, 100 - train_top5_acc,
                    test_loss, 100 - test_top1_acc, 100 - test_top5_acc,
                    cong_bf, cong_af, mag))

        # append logger file
        logger.append([
            state['lr'], train_loss, test_loss, 100 - train_top1_acc,
            100 - test_top1_acc, cong_bf, cong_af, mag
        ])

        # save model
        is_best = test_top1_acc > best_acc
        best_acc = max(test_top1_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'classifier_state_dict': reg_net.state_dict(),
                'acc': test_top1_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'cong_bf': cong_bf,
                'cong_af': cong_af,
                'mag': mag,
            },
            is_best,
            checkpoint=args.checkpoint)

    logger.close()

    print('Best err:')
    print(100 - best_acc)
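
Each example persists training state through save_checkpoint (Example #1
passes save_dir, Examples #2 and #3 pass checkpoint). The helper is not
shown here; below is a minimal sketch of the convention it appears to
follow, writing the latest state every epoch and copying it when is_best is
set. The filenames are assumptions.

import os
import shutil
import torch

def save_checkpoint(state, is_best, checkpoint='checkpoint',
                    filename='checkpoint.pth.tar'):
    # hypothetical sketch: always save the latest checkpoint, and keep a
    # separate copy of the best-performing one
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath,
                        os.path.join(checkpoint, 'model_best.pth.tar'))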