Example #1
def main(config):
    acc_dict = {}
    trainset, valset = CIFAR100('data/')
    train_dict, val_dict = continual_wrapper(trainset, valset, num_tasks=10)
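    # Keep only the first three tasks (0-2); config.num_tasks is assumed to be 3 here.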
    for i in range(3, 10):
        del train_dict[i]
        del val_dict[i]

    net = seresnet20_cifar100().to(config.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=config.lr,
                                momentum=config.mo)
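    # One network and one optimizer are shared across all sequential tasks.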
    for task in range(config.num_tasks):
        trainloader = torch.utils.data.DataLoader(train_dict[task],
                                                  batch_size=config.bs)
        valloader = torch.utils.data.DataLoader(val_dict[task],
                                                batch_size=config.bs)
        learn(net, trainloader, valloader, criterion, optimizer, config)
        acc_dict[task] = {}
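        # Re-evaluate every task seen so far to track catastrophic forgetting.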
        for i in range(task + 1):
            valloader = torch.utils.data.DataLoader(val_dict[i],
                                                    batch_size=config.bs)
            acc_dict[task][i] = test(net, valloader, criterion, config)[1]
    visualise(acc_dict)
Example #2
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = CIFAR100(root='../../data',
                    train=True,
                    download=True,
                    transform=transform_train,
                    train_class_num=args.train_class_num,
                    test_class_num=args.test_class_num,
                    includes_all_train_class=args.includes_all_train_class)

testset = CIFAR100(root='../../data',
                   train=False,
                   download=True,
                   transform=transform_test,
                   train_class_num=args.train_class_num,
                   test_class_num=args.test_class_num,
                   includes_all_train_class=args.includes_all_train_class)

# data loader
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=args.stage1_bs,
                                          shuffle=True,  # assumed completion: the source is truncated here (cf. Examples #3 and #5)
                                          num_workers=4)
Example #3
def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # checkpoint
    args.checkpoint = './checkpoints/cifar/' + args.arch
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = CIFAR100(root='../../data', train=True, download=True, transform=transform_train,
                        train_class_num=args.train_class_num, test_class_num=args.test_class_num,
                        includes_all_train_class=args.includes_all_train_class)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.bs, shuffle=True, num_workers=4)
    testset = CIFAR100(root='../../data', train=False, download=True, transform=transform_test,
                       train_class_num=args.train_class_num, test_class_num=args.test_class_num,
                       includes_all_train_class=args.includes_all_train_class)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.bs, shuffle=False, num_workers=4)


    # Model
    print('==> Building model..')
    net = models.__dict__[args.arch](num_classes=args.train_class_num)  # CIFAR-100
    net = net.to(device)

    if device == 'cuda':
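        # Multi-GPU via DataParallel; cudnn.benchmark autotunes conv kernels.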
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    if args.resume:
        # Load checkpoint.
        if os.path.isfile(args.resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.resume)
            net.load_state_dict(checkpoint['net'])
            # best_acc = checkpoint['acc']
            # print("BEST_ACCURACY: "+str(best_acc))
            start_epoch = checkpoint['epoch']
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            # Fall back to a fresh logger; otherwise `logger` is undefined below.
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'))
            logger.set_names(['Epoch', 'Learning Rate', 'Train Loss', 'Train Acc.', 'Test Loss', 'Test Acc.'])
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'))
        logger.set_names(['Epoch', 'Learning Rate', 'Train Loss', 'Train Acc.', 'Test Loss', 'Test Acc.'])

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

    # test(0, net, trainloader, testloader, criterion, device)
    epoch = 0  # keeps the final test call below valid when training is skipped
    if not args.evaluate:
        for epoch in range(start_epoch, start_epoch + args.es):
            print('\nEpoch: %d   Learning rate: %f' % (epoch+1, optimizer.param_groups[0]['lr']))
            adjust_learning_rate(optimizer, epoch, args.lr)
            train_loss, train_acc = train(net, trainloader, optimizer, criterion, device)
            save_model(net, None, epoch, os.path.join(args.checkpoint, 'last_model.pth'))
            test_loss, test_acc = 0, 0  # evaluation is deferred until after the loop
            logger.append([epoch + 1, optimizer.param_groups[0]['lr'], train_loss, train_acc, test_loss, test_acc])

    test(epoch, net, trainloader, testloader, criterion, device)
    logger.close()
Example #4
def main():
    global args, best_prec1, train_rec, test_rec
    args = parser.parse_args()
    args.root = "work"
    args.folder = osp.join(args.root, args.arch)
    setproctitle.setproctitle(args.arch)

    # if osp.exists(args.folder):
    #     shutil.rmtree(args.folder)

    os.makedirs(args.folder, exist_ok=True)

    if args.dataset == "cifar10":
        CIFAR = CIFAR10(args.data)
    else:
        CIFAR = CIFAR100(args.data)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        raise NotImplementedError("pre-trained is not supported on CIFAR")
        # model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](CIFAR.num_classes)

    args.distributed = args.world_size > 1
    if not args.distributed:
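        # As in the PyTorch ImageNet example, AlexNet/VGG wrap only the conv
        # features in DataParallel; their large fc classifiers parallelize poorly.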
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True  # autotune cuDNN conv kernels for fixed-size inputs

    from trainers.classification import ClassificationTrainer

    train_loader, valid_loader = CIFAR.get_loader(args)
    trainer = ClassificationTrainer(model, criterion, args, optimizer)

    if args.evaluate:
        trainer.evaluate(valid_loader, model, criterion)
        return

    from torch.optim.lr_scheduler import MultiStepLR, StepLR

    # Drop the learning rate by 10x at 50% and 75% of the scheduled epochs.
    step1 = int(args.epochs * 0.5)
    step2 = int(args.epochs * 0.75)
    lr_scheduler = MultiStepLR(optimizer, milestones=[step1, step2], gamma=0.1)

    trainer.fit(train_loader,
                valid_loader,
                start_epoch=0,
                max_epochs=args.epochs,  # match the milestone schedule above
                lr_scheduler=lr_scheduler)
Example #5
def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # checkpoint
    args.checkpoint = './checkpoints/cifar/' + args.arch
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    trainset = CIFAR100(root='../../data',
                        train=True,
                        download=True,
                        transform=transform_train,
                        train_class_num=args.train_class_num,
                        test_class_num=args.test_class_num,
                        includes_all_train_class=args.includes_all_train_class)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.bs,
                                              shuffle=True,
                                              num_workers=4)
    testset = CIFAR100(root='../../data',
                       train=False,
                       download=True,
                       transform=transform_test,
                       train_class_num=args.train_class_num,
                       test_class_num=args.test_class_num,
                       includes_all_train_class=args.includes_all_train_class)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.bs,
                                             shuffle=False,
                                             num_workers=4)

    # Model
    print('==> Building model..')
    net = Network(backbone=args.arch,
                  num_classes=args.train_class_num,
                  embed_dim=args.embed_dim)
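    # Feature dimension of the classifier head, used by CenterLoss below.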
    fea_dim = net.classifier.in_features
    net = net.to(device)

    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    criterion_softmax = nn.CrossEntropyLoss()
    criterion_centerloss = CenterLoss(num_classes=args.train_class_num,
                                      feat_dim=fea_dim).to(device)
    optimizer_softmax = optim.SGD(net.parameters(),
                                  lr=args.lr,
                                  momentum=0.9,
                                  weight_decay=5e-4)
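    # CenterLoss holds learnable class centers, so it gets its own optimizer.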
    optimizer_centerloss = torch.optim.SGD(criterion_centerloss.parameters(),
                                           lr=args.center_lr,
                                           momentum=0.9,
                                           weight_decay=5e-4)

    if args.resume:
        # Load checkpoint.
        if os.path.isfile(args.resume):
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(args.resume)
            net.load_state_dict(checkpoint['net'])
            criterion_centerloss.load_state_dict(checkpoint['centerloss'])
            # best_acc = checkpoint['acc']
            # print("BEST_ACCURACY: "+str(best_acc))
            start_epoch = checkpoint['epoch']
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                            resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            # Fall back to a fresh logger; otherwise `logger` is undefined below.
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'))
            logger.set_names([
                'Epoch', 'Total Loss', 'Softmax Loss', 'Center Loss', 'Train Acc.'
            ])
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'))
        logger.set_names([
            'Epoch', 'Total Loss', 'Softmax Loss', 'Center Loss', 'Train Acc.'
        ])

    if not args.evaluate:
        scheduler = lr_scheduler.StepLR(optimizer_softmax,
                                        step_size=30,
                                        gamma=0.1)
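        # The softmax LR decays 10x every 30 epochs; the center-loss LR stays fixed.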
        for epoch in range(start_epoch, args.es):
            print('\nEpoch: %d   Learning rate: %f' %
                  (epoch + 1, optimizer_softmax.param_groups[0]['lr']))
            train_loss, softmax_loss, center_loss, train_acc = train(
                net, trainloader, optimizer_softmax, optimizer_centerloss,
                criterion_softmax, criterion_centerloss, device)
            save_model(net, criterion_centerloss, epoch,
                       os.path.join(args.checkpoint, 'last_model.pth'))
            logger.append(
                [epoch + 1, train_loss, softmax_loss, center_loss, train_acc])
            scheduler.step()

            test(net, testloader, device)

    logger.close()
Example #6
def main():
    global best_prec1, train_rec, test_rec
    conf = get_configs()

    conf.root = "work"
    conf.folder = osp.join(conf.root, conf.arch)
    conf.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")
    setproctitle.setproctitle(conf.arch)

    os.makedirs(conf.folder, exist_ok=True)

    if conf.dataset == "cifar10":
        CIFAR = CIFAR10(conf.data)  # Datasets object
    else:
        CIFAR = CIFAR100(conf.data)

    # create model
    if conf.pretrained:
        print("=> using pre-trained model '{}'".format(conf.arch))
        raise NotImplementedError("pre-trained is not supported on CIFAR")
        # model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(conf.arch))
        model = models.__dict__[conf.arch](CIFAR.num_classes)
    # print(model.features)
    conf.distributed = conf.distributed_processes > 1
    if not conf.distributed:
        if conf.gpus > 0:
            model = nn.DataParallel(model)
        model.to(conf.device)
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                conf.lr,
                                momentum=conf.momentum,
                                weight_decay=conf.weight_decay)

    # optionally resume from a checkpoint
    if conf.resume:
        if os.path.isfile(conf.resume):
            print("=> loading checkpoint from '{}'".format(conf.resume))
            checkpoint = torch.load(conf.resume)
            conf.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint done. (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(conf.resume))

    cudnn.benchmark = True  # let cuDNN auto-tune conv kernels for fixed input sizes

    train_loader, valid_loader = CIFAR.get_loader(conf)
    trainer = ClassificationTrainer(model, criterion, conf, optimizer)

    if conf.evaluate:
        trainer.evaluate(valid_loader, model, criterion)
        return

    # Drop the learning rate by 10x at 50% and 75% of the scheduled epochs.
    step1 = int(conf.epochs * 0.5)
    step2 = int(conf.epochs * 0.75)
    lr_scheduler = MultiStepLR(optimizer, milestones=[step1, step2], gamma=0.1)

    trainer.fit(train_loader,
                valid_loader,
                start_epoch=0,
                max_epochs=conf.epochs,  # match the milestone schedule above
                lr_scheduler=lr_scheduler)