def tune_mahalanobis_hyperparams():
    def print_tuning_results(results, stypes):
        mtypes = ['FPR', 'DTERR', 'AUROC', 'AUIN', 'AUOUT']

        for stype in stypes:
            print(' OOD detection method: ' + stype)
            for mtype in mtypes:
                print(' {mtype:6s}'.format(mtype=mtype), end='')
            print('')
            for mtype in mtypes:
                print(' {val:6.2f}'.format(val=100. * results[stype][mtype]),
                      end='')
            print('\n')

    print('Tuning hyper-parameters...')
    stypes = ['mahalanobis']

    save_dir = os.path.join('output/hyperparams/', args.name, 'tmp')

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    normalizer = transforms.Normalize((125.3 / 255, 123.0 / 255, 113.9 / 255),
                                      (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if args.in_dataset == "CIFAR-10":
        trainset = torchvision.datasets.CIFAR10('./datasets/cifar10',
                                                train=True,
                                                download=True,
                                                transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=2)

        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False,
                                               download=True,
                                               transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=2)

        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        trainset = torchvision.datasets.CIFAR100('./datasets/cifar100',
                                                 train=True,
                                                 download=True,
                                                 transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=2)

        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False,
                                                download=True,
                                                transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=2)

        num_classes = 100

    valloaderOut = torch.utils.data.DataLoader(
        TinyImages(transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2)

    model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)

    checkpoint = torch.load(
        "./checkpoints/{name}/checkpoint_{epochs}.pth.tar".format(
            name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    # set information about feature extraction
    temp_x = torch.rand(2, 3, 32, 32).cuda()
    temp_x = Variable(temp_x)
    temp_list = model.feature_list(temp_x)[1]
    num_output = len(temp_list)
    feature_list = np.empty(num_output)
    count = 0
    for out in temp_list:
        feature_list[count] = out.size(1)
        count += 1

    print('get sample mean and covariance')
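    # Estimate class-conditional Gaussian statistics for every intermediate
    # feature layer: per-class feature means and a shared precision
    # (inverse covariance) matrix, as used by the Mahalanobis OOD detector.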
    sample_mean, precision = sample_estimator(model, num_classes, feature_list,
                                              trainloaderIn)

    print('train logistic regression model')
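    # Collect m in-distribution training images (label 0) and m auxiliary
    # TinyImages samples (label 1) to fit a logistic regression over the
    # layer-wise Mahalanobis scores.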
    m = 1000
    val_in = []
    val_out = []

    cnt = 0
    for data, target in trainloaderIn:
        for x in data:
            val_in.append(x.numpy())
            cnt += 1
            if cnt == m:
                break
        if cnt == m:
            break

    cnt = 0
    for data, target in valloaderOut:
        for x in data:
            val_out.append(x.numpy())
            cnt += 1
            if cnt == m:
                break
        if cnt == m:
            break

    train_lr_data = []
    train_lr_label = []
    train_lr_data.extend(val_in)
    train_lr_label.extend(np.zeros(m))
    train_lr_data.extend(val_out)
    train_lr_label.extend(np.ones(m))
    train_lr_data = torch.tensor(train_lr_data)
    train_lr_label = torch.tensor(train_lr_label)

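    # Grid-search the input-perturbation magnitude over [0, 0.004] and keep
    # the setting with the lowest FPR on this validation split.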
    best_fpr = 1.1
    best_magnitude = 0.0

    for magnitude in np.arange(0, 0.0041, 0.004 / 20):
        train_lr_Mahalanobis = []
        total = 0
        for data_index in range(
                int(np.floor(train_lr_data.size(0) / args.batch_size))):
            data = train_lr_data[total:total + args.batch_size]
            total += args.batch_size
            Mahalanobis_scores = get_Mahalanobis_score(model, data,
                                                       num_classes,
                                                       sample_mean, precision,
                                                       num_output, magnitude)
            train_lr_Mahalanobis.extend(Mahalanobis_scores)

        train_lr_Mahalanobis = np.asarray(train_lr_Mahalanobis,
                                          dtype=np.float32)

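        # Combine the per-layer Mahalanobis scores into a single confidence
        # score with a logistic regression classifier.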
        regressor = LogisticRegressionCV().fit(train_lr_Mahalanobis,
                                               train_lr_label)

        print('Logistic Regressor params:', regressor.coef_,
              regressor.intercept_)

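        # Dump per-sample confidence scores to text files; metric() reads
        # them back to compute FPR, detection error, AUROC, AUIN and AUOUT.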
        t0 = time.time()
        f1 = open(os.path.join(save_dir, "confidence_mahalanobis_In.txt"), 'w')
        f2 = open(os.path.join(save_dir, "confidence_mahalanobis_Out.txt"),
                  'w')
        ########################################In-distribution###########################################
        print("Processing in-distribution images")

        count = 0
        for i in range(int(m / args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_in[i * args.batch_size:min((i + 1) * args.batch_size, m)])
            # if j<1000: continue
            batch_size = images.shape[0]

            Mahalanobis_scores = get_Mahalanobis_score(model, images,
                                                       num_classes,
                                                       sample_mean, precision,
                                                       num_output, magnitude)

            confidence_scores = regressor.predict_proba(
                Mahalanobis_scores)[:, 1]

            for k in range(batch_size):
                f1.write("{}\n".format(-confidence_scores[k]))

            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                count, m,
                time.time() - t0))
            t0 = time.time()

    ###################################Out-of-Distributions#####################################
        t0 = time.time()
        print("Processing out-of-distribution images")
        count = 0

        for i in range(int(m / args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_out[i * args.batch_size:min((i + 1) * args.batch_size, m)])
            # if j<1000: continue
            batch_size = images.shape[0]

            Mahalanobis_scores = get_Mahalanobis_score(model, images,
                                                       num_classes,
                                                       sample_mean, precision,
                                                       num_output, magnitude)

            confidence_scores = regressor.predict_proba(
                Mahalanobis_scores)[:, 1]

            for k in range(batch_size):
                f2.write("{}\n".format(-confidence_scores[k]))

            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                count, m,
                time.time() - t0))
            t0 = time.time()

        f1.close()
        f2.close()

        results = metric(save_dir, stypes)
        print_tuning_results(results, stypes)
        fpr = results['mahalanobis']['FPR']
        if fpr < best_fpr:
            best_fpr = fpr
            best_magnitude = magnitude
            best_regressor = regressor

    print('Best Logistic Regressor params:', best_regressor.coef_,
          best_regressor.intercept_)
    print('Best magnitude', best_magnitude)

    return sample_mean, precision, best_regressor, best_magnitude
Example 2
def main():
    if args.tensorboard: configure("runs/%s"%(args.name))

    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        ])

    kwargs = {'num_workers': 1, 'pin_memory': True}

    if args.in_dataset == "CIFAR-10":
        # Data loading code
        normalizer = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./datasets/cifar10', train=True, download=True,
                             transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./datasets/cifar10', train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        lr_schedule = [50, 75, 90]
        pool_size = args.pool_size
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        # Data loading code
        normalizer = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./datasets/cifar100', train=True, download=True,
                             transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./datasets/cifar100', train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        lr_schedule = [50, 75, 90]
        pool_size = args.pool_size
        num_classes = 100
    elif args.in_dataset == "SVHN":
        # Data loading code
        normalizer = None
        train_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='train',
                                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test',
                                  transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=False, **kwargs)

        args.epochs = 20
        args.save_epoch = 2
        lr_schedule = [10, 15, 18]
        pool_size = int(len(train_loader.dataset) * 8 / args.ood_batch_size) + 1
        num_classes = 10

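    # Size of the auxiliary OOD set mined each epoch: twice the
    # in-distribution training set.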
    ood_dataset_size = len(train_loader.dataset) * 2

    print('OOD Dataset Size: ', ood_dataset_size)

    if args.auxiliary_dataset == '80m_tiny_images':
        ood_loader = torch.utils.data.DataLoader(
            TinyImages(transform=transforms.Compose(
                [transforms.ToTensor(), transforms.ToPILImage(), transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(), transforms.ToTensor()])),
                batch_size=args.ood_batch_size, shuffle=False, **kwargs)
    elif args.auxiliary_dataset == 'imagenet':
        ood_loader = torch.utils.data.DataLoader(
            ImageNet(transform=transforms.Compose(
                [transforms.ToTensor(), transforms.ToPILImage(), transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(), transforms.ToTensor()])),
                batch_size=args.ood_batch_size, shuffle=False, **kwargs)


    # create model
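    # The classifier has num_classes + 1 outputs; the extra output is
    # reserved for the auxiliary OOD class.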
    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes + 1, args.growth, reduction=args.reduce,
                             bottleneck=args.bottleneck, dropRate=args.droprate, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes + 1, widen_factor=args.width, dropRate=args.droprate, normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            assert False, "=> no checkpoint found at '{}'".format(args.resume)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    model = model.cuda()

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)

        # train for one epoch
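        # pick this epoch's auxiliary OOD samples from the TinyImages /
        # ImageNet pool via quantile-based selection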
        selected_ood_loader = select_ood(ood_loader, model, args.batch_size * 2, num_classes, pool_size, ood_dataset_size, args.quantile)

        train_ntom(train_loader, selected_ood_loader, model, criterion, num_classes, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, num_classes)

        # remember best prec@1 and save checkpoint
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, epoch + 1)
Example 3
import sys

import numpy as np

from utils import TinyImages
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision
import os

transform = transforms.Compose([
    transforms.ToTensor(),
])

ood_loader = torch.utils.data.DataLoader(
    TinyImages(transform=transforms.ToTensor()), batch_size=1, shuffle=False)

ood_loader.dataset.offset = 0

save_dir = "datasets/val_ood_data/tiny_images/0"

if not os.path.exists(save_dir):
    os.makedirs(save_dir)

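# Save the first 10,000 TinyImages samples as PNG files to build a fixed
# OOD validation set under datasets/val_ood_data/tiny_images/0.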
for i, (images, labels) in enumerate(ood_loader):
    torchvision.utils.save_image(images[0],
                                 os.path.join(save_dir, '%d.png' % i))
    if i + 1 == 10000:
        break
Example 4
def main():
    if args.tensorboard: configure("runs/%s" % (args.name))

    # Data loading code
    normalizer = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        transform_train = transforms.Compose([
            # transforms.RandomCrop(32, padding=4),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    kwargs = {'num_workers': 1, 'pin_memory': True}

    if args.in_dataset == "CIFAR-10":
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('../../data',
                             train=True,
                             download=True,
                             transform=transform_train),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)
        val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            '../../data', train=False, transform=transform_test),
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 **kwargs)
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('../../data',
                              train=True,
                              download=True,
                              transform=transform_train),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)
        val_loader = torch.utils.data.DataLoader(datasets.CIFAR100(
            '../../data', train=False, transform=transform_test),
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 **kwargs)
        num_classes = 100

    if args.ood:
        ood_loader = torch.utils.data.DataLoader(
            TinyImages(transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.ToPILImage(),
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ])),
            batch_size=args.ood_batch_size,
            shuffle=False,
            **kwargs)

    # create model
    model = dn.DenseNet3(args.layers,
                         num_classes,
                         args.growth,
                         reduction=args.reduce,
                         bottleneck=args.bottleneck,
                         dropRate=args.droprate,
                         normalizer=normalizer)

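    # Optional adversarial training: a PGD attack on the in-distribution
    # cross-entropy loss and, when OOD data is used, a second PGD attack on
    # the outlier-exposure (OE) loss.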
    if args.adv:
        attack_in = LinfPGDAttack(model=model,
                                  eps=args.epsilon,
                                  nb_iter=args.iters,
                                  eps_iter=args.iter_size,
                                  rand_init=True,
                                  loss_func='CE')
        if args.ood:
            attack_out = LinfPGDAttack(model=model,
                                       eps=args.epsilon,
                                       nb_iter=args.iters,
                                       eps_iter=args.iter_size,
                                       rand_init=True,
                                       loss_func='OE')

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.ood:
        ood_criterion = OELoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)

    if args.lr_scheduler != 'cosine_annealing' and args.lr_scheduler != 'step_decay':
        assert False, 'Not supported lr_scheduler {}'.format(args.lr_scheduler)

    if args.lr_scheduler == 'cosine_annealing':
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: cosine_annealing(
                step,
                args.epochs * len(train_loader),
                1,  # since lr_lambda computes multiplicative factor
                1e-6 / args.lr))
    else:
        scheduler = None

    for epoch in range(args.start_epoch, args.epochs):
        if args.lr_scheduler == 'step_decay':
            adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        if args.ood:
            if args.adv:
                train_ood(train_loader, ood_loader, model, criterion,
                          ood_criterion, optimizer, scheduler, epoch,
                          attack_in, attack_out)
            else:
                train_ood(train_loader, ood_loader, model, criterion,
                          ood_criterion, optimizer, scheduler, epoch)
        else:
            if args.adv:
                train(train_loader, model, criterion, optimizer, scheduler,
                      epoch, attack_in)
            else:
                train(train_loader, model, criterion, optimizer, scheduler,
                      epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                }, epoch + 1)
def main():
    if args.tensorboard: configure("runs/%s" % (args.name))

    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    kwargs = {'num_workers': 1, 'pin_memory': True}

    if args.in_dataset == "CIFAR-10":
        # Data loading code
        normalizer = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            './datasets/cifar10',
            train=True,
            download=True,
            transform=transform_train),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            './datasets/cifar10', train=False, transform=transform_test),
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 **kwargs)

        lr_schedule = [50, 75, 90]
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        # Data loading code
        normalizer = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR100(
            './datasets/cifar100',
            train=True,
            download=True,
            transform=transform_train),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        val_loader = torch.utils.data.DataLoader(datasets.CIFAR100(
            './datasets/cifar100', train=False, transform=transform_test),
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 **kwargs)

        lr_schedule = [50, 75, 90]
        num_classes = 100
    elif args.in_dataset == "SVHN":
        # Data loading code
        normalizer = None
        train_loader = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/',
            split='train',
            transform=transforms.ToTensor(),
            download=False),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        val_loader = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/',
            split='test',
            transform=transforms.ToTensor(),
            download=False),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 **kwargs)

        args.epochs = 20
        args.save_epoch = 2
        lr_schedule = [10, 15, 18]
        num_classes = 10

    out_loader = torch.utils.data.DataLoader(
        TinyImages(transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])),
        batch_size=args.ood_batch_size,
        shuffle=False,
        **kwargs)

    # create model
    if args.model_arch == 'densenet':
        base_model = dn.DenseNet3(args.layers,
                                  num_classes,
                                  args.growth,
                                  reduction=args.reduce,
                                  bottleneck=args.bottleneck,
                                  dropRate=args.droprate,
                                  normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        base_model = wn.WideResNet(args.depth,
                                   num_classes,
                                   widen_factor=args.width,
                                   dropRate=args.droprate,
                                   normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

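    # Fit GMMs (100 components, with PCA) on the in-distribution and
    # auxiliary OOD data, then reload them from the checkpoint directory.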
    gen_gmm(train_loader, out_loader, data_used=50000, PCA=True, N=[100])

    gmm = torch.load("checkpoints/{in_dataset}/{name}/".format(
        in_dataset=args.in_dataset, name=args.name) + 'in_gmm.pth.tar')

    gmm.alpha = nn.Parameter(gmm.alpha)
    gmm.mu.requires_grad = True
    gmm.logvar.requires_grad = True
    gmm.alpha.requires_grad = False

    gmm_out = torch.load("checkpoints/{in_dataset}/{name}/".format(
        in_dataset=args.in_dataset, name=args.name) + 'out_gmm.pth.tar')
    gmm_out.alpha = nn.Parameter(gmm.alpha)
    gmm_out.mu.requires_grad = True
    gmm_out.logvar.requires_grad = True
    gmm_out.alpha.requires_grad = False
    loglam = 0.
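    # Wrap the base classifier and both GMMs into a single doubly-robust
    # model operating on flattened 3x32x32 inputs (dim=3072).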
    model = gmmlib.DoublyRobustModel(base_model,
                                     gmm,
                                     gmm_out,
                                     loglam,
                                     dim=3072,
                                     classes=num_classes).cuda()

    model.loglam.requires_grad = False

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    model = model.cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define optimizer: separate parameter groups for the GMMs and the base model
    lr = args.lr
    lr_gmm = 1e-5

    param_groups = [{
        'params': model.mm.parameters(),
        'lr': lr_gmm,
        'weight_decay': 0.
    }, {
        'params': model.mm_out.parameters(),
        'lr': lr_gmm,
        'weight_decay': 0.
    }, {
        'params': model.base_model.parameters(),
        'lr': lr,
        'weight_decay': args.weight_decay
    }]

    optimizer = torch.optim.SGD(param_groups,
                                momentum=args.momentum,
                                nesterov=True)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)

        # train for one epoch
        lam = model.loglam.data.exp().item()
        train_CEDA_gmm_out(model,
                           train_loader,
                           out_loader,
                           optimizer,
                           epoch,
                           lam=lam)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                }, epoch + 1)
def tune_odin_hyperparams():
    print('Tuning hyper-parameters...')
    stypes = ['ODIN']

    save_dir = os.path.join('output/odin_hyperparams/', args.in_dataset,
                            args.name, 'tmp')

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if args.in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        trainset = torchvision.datasets.CIFAR10('./datasets/cifar10',
                                                train=True,
                                                download=True,
                                                transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True)

        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False,
                                               download=True,
                                               transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)

        num_classes = 10

    elif args.in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        trainset = torchvision.datasets.CIFAR100('./datasets/cifar100',
                                                 train=True,
                                                 download=True,
                                                 transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True)

        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False,
                                                download=True,
                                                transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)

        num_classes = 100

    elif args.in_dataset == "SVHN":

        normalizer = None
        trainloaderIn = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/',
            split='train',
            transform=transforms.ToTensor(),
            download=False),
                                                    batch_size=args.batch_size,
                                                    shuffle=True)
        testloaderIn = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/',
            split='test',
            transform=transforms.ToTensor(),
            download=False),
                                                   batch_size=args.batch_size,
                                                   shuffle=True)

        args.epochs = 20
        num_classes = 10

    valloaderOut = torch.utils.data.DataLoader(
        TinyImages(transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])),
        batch_size=args.batch_size,
        shuffle=False)

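    # Start reading TinyImages at a random offset so each run uses a
    # different slice of the auxiliary data as the OOD validation set.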
    valloaderOut.dataset.offset = np.random.randint(len(valloaderOut.dataset))

    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth,
                              num_classes,
                              widen_factor=args.width,
                              normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    checkpoint = torch.load(
        "./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(
            in_dataset=args.in_dataset, name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

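    # Collect m in-distribution test images and m auxiliary TinyImages
    # samples for tuning the ODIN perturbation magnitude.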
    m = 1000
    val_in = []
    val_out = []

    cnt = 0
    for data, target in testloaderIn:
        for x in data:
            val_in.append(x.numpy())
            cnt += 1
            if cnt == m:
                break
        if cnt == m:
            break

    cnt = 0
    for data, target in valloaderOut:
        for x in data:
            val_out.append(x.numpy())
            cnt += 1
            if cnt == m:
                break
        if cnt == m:
            break

    print('Len of val in: ', len(val_in))
    print('Len of val out: ', len(val_out))

    best_fpr = 1.1
    best_magnitude = 0.0
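    # Sweep the ODIN input-perturbation magnitude over [0, 0.004] with the
    # temperature fixed at 1000, and keep the magnitude with the lowest FPR.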
    for magnitude in np.arange(0, 0.0041, 0.004 / 20):

        t0 = time.time()
        f1 = open(os.path.join(save_dir, "confidence_ODIN_In.txt"), 'w')
        f2 = open(os.path.join(save_dir, "confidence_ODIN_Out.txt"), 'w')
        ########################################In-distribution###########################################
        print("Processing in-distribution images")

        count = 0
        for i in range(int(m / args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_in[i * args.batch_size:min((i + 1) * args.batch_size, m)])
            images = images.cuda()
            # if j<1000: continue
            batch_size = images.shape[0]

            scores = get_odin_score(images,
                                    model,
                                    temper=1000,
                                    noiseMagnitude1=magnitude)

            for k in range(batch_size):
                f1.write("{}\n".format(scores[k]))

            count += batch_size
            # print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

    ###################################Out-of-Distributions#####################################
        t0 = time.time()
        print("Processing out-of-distribution images")
        count = 0

        for i in range(int(m / args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_out[i * args.batch_size:min((i + 1) * args.batch_size, m)])
            images = images.cuda()
            # if j<1000: continue
            batch_size = images.shape[0]

            scores = get_odin_score(images,
                                    model,
                                    temper=1000,
                                    noiseMagnitude1=magnitude)

            for k in range(batch_size):
                f2.write("{}\n".format(scores[k]))

            count += batch_size
            # print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

        f1.close()
        f2.close()

        results = metric(save_dir, stypes)
        print_results(results, stypes)
        fpr = results['ODIN']['FPR']
        if fpr < best_fpr:
            best_fpr = fpr
            best_magnitude = magnitude

    return best_magnitude
Example 7
import numpy as np

from utils import TinyImages
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import utils.svhn_loader as svhn
import torchvision
import os

transform = transforms.Compose([
    transforms.ToTensor(),
    ])

ood_loader = torch.utils.data.DataLoader(TinyImages(transform=transforms.ToTensor()), batch_size=1, shuffle=False)

save_dir = "datasets/rowl_train_data/CIFAR-10"

if not os.path.exists(save_dir):
    os.makedirs(save_dir)

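# Create 11 class folders: likely the 10 CIFAR-10 classes plus one extra
# folder reserved for OOD samples.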
for i in range(11):
    os.makedirs(os.path.join(save_dir, '%02d'%i))

class_count = np.zeros(10)

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./datasets/cifar10', train=True, download=True,
                     transform=transform),
    batch_size=1, shuffle=False)