Ejemplo n.º 1
0
def main():
    """Train a network with ``num_classes + 1`` outputs using auxiliary OOD data.

    Everything is configured through the module-level ``args`` namespace.
    The function builds the in-distribution train/validation loaders, the
    auxiliary out-of-distribution loader and the model, optionally restores
    weights from ``args.resume``, and then runs the epoch loop: adjust the
    learning rate, pick the OOD samples for the epoch with ``select_ood``,
    train with ``train_ntom``, validate, and save a checkpoint every
    ``args.save_epoch`` epochs.

    Note that the SVHN branch overwrites ``args.epochs`` and
    ``args.save_epoch``.

    Raises:
        ValueError: if ``args.in_dataset``, ``args.auxiliary_dataset`` or
            ``args.model_arch`` is not one of the supported values.
        FileNotFoundError: if ``args.resume`` is set but is not a file.
    """
    if args.tensorboard:
        configure("runs/%s" % (args.name))

    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([transforms.ToTensor()])
    transform_test = transforms.Compose([transforms.ToTensor()])

    kwargs = {'num_workers': 1, 'pin_memory': True}

    # The two CIFAR variants differ only in dataset class, root and class count.
    cifar_variants = {
        "CIFAR-10": (datasets.CIFAR10, './datasets/cifar10', 10),
        "CIFAR-100": (datasets.CIFAR100, './datasets/cifar100', 100),
    }

    if args.in_dataset in cifar_variants:
        dataset_cls, root, num_classes = cifar_variants[args.in_dataset]
        # Normalization is handed to the model rather than applied in the
        # data pipeline.
        normalizer = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            dataset_cls(root, train=True, download=True,
                        transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            dataset_cls(root, train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        lr_schedule = [50, 75, 90]
        pool_size = args.pool_size
    elif args.in_dataset == "SVHN":
        normalizer = None
        train_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='train',
                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test',
                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=False, **kwargs)

        # SVHN uses a shorter schedule and overrides the CLI values.
        args.epochs = 20
        args.save_epoch = 2
        lr_schedule = [10, 15, 18]
        pool_size = int(len(train_loader.dataset) * 8 / args.ood_batch_size) + 1
        num_classes = 10
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError(
            'Not supported in-distribution dataset: {}'.format(args.in_dataset))

    ood_dataset_size = len(train_loader.dataset) * 2

    print('OOD Dataset Size: ', ood_dataset_size)

    # Both auxiliary datasets share the same augmentation pipeline.
    ood_transform = transforms.Compose([
        transforms.ToTensor(), transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(), transforms.ToTensor()])
    if args.auxiliary_dataset == '80m_tiny_images':
        ood_dataset = TinyImages(transform=ood_transform)
    elif args.auxiliary_dataset == 'imagenet':
        ood_dataset = ImageNet(transform=ood_transform)
    else:
        # Previously ood_loader stayed undefined and the epoch loop crashed.
        raise ValueError(
            'Not supported auxiliary dataset: {}'.format(args.auxiliary_dataset))
    ood_loader = torch.utils.data.DataLoader(
        ood_dataset, batch_size=args.ood_batch_size, shuffle=False, **kwargs)

    # create model
    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes + 1, args.growth,
                             reduction=args.reduce,
                             bottleneck=args.bottleneck,
                             dropRate=args.droprate, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes + 1,
                              widen_factor=args.width,
                              dropRate=args.droprate, normalizer=normalizer)
    else:
        # Explicit raise: an assert would be stripped under ``python -O``.
        raise ValueError('Not supported model arch: {}'.format(args.model_arch))

    # optionally resume from a checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            raise FileNotFoundError(
                "=> no checkpoint found at '{}'".format(args.resume))
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    model = model.cuda()

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)

        # Re-select the OOD samples used for this epoch with the current model.
        selected_ood_loader = select_ood(ood_loader, model,
                                         args.batch_size * 2, num_classes,
                                         pool_size, ood_dataset_size,
                                         args.quantile)

        # train for one epoch
        train_ntom(train_loader, selected_ood_loader, model, criterion,
                   num_classes, optimizer, epoch)

        # evaluate on validation set
        validate(val_loader, model, criterion, epoch, num_classes)

        # save a checkpoint every args.save_epoch epochs
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, epoch + 1)
def eval_mahalanobis(sample_mean, precision, regressor, magnitude):
    """Write Mahalanobis-based confidence scores for in/out test images.

    Loads the DenseNet checkpoint named by ``args.name`` / ``args.epochs``,
    scores up to ``N`` images from the in-distribution test set and from the
    out-of-distribution set ``args.out_dataset`` (adversarially perturbed
    first when ``args.adv`` is set), writes one score per line to
    ``confidence_mahalanobis_In.txt`` / ``confidence_mahalanobis_Out.txt``
    under the save directory, then computes and prints the metrics.

    Args:
        sample_mean: forwarded unchanged to ``get_Mahalanobis_score`` and to
            the attack.
        precision: forwarded unchanged to ``get_Mahalanobis_score`` and to
            the attack.
        regressor: object exposing ``predict_proba``; column 1 of its output
            is negated and written as the score.
        magnitude: forwarded unchanged to ``get_Mahalanobis_score``.

    Raises:
        ValueError: if ``args.in_dataset`` is neither "CIFAR-10" nor
            "CIFAR-100".
    """
    stypes = ['mahalanobis']

    save_dir = os.path.join('output/ood_scores/', args.out_dataset, args.name,
                            'adv' if args.adv else 'nat')
    os.makedirs(save_dir, exist_ok=True)

    # loading data sets
    normalizer = transforms.Normalize((125.3 / 255, 123.0 / 255, 113.9 / 255),
                                      (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
    transform = transforms.Compose([transforms.ToTensor()])

    if args.in_dataset == "CIFAR-10":
        dataset_cls = torchvision.datasets.CIFAR10
        root, num_classes = './datasets/cifar10', 10
    elif args.in_dataset == "CIFAR-100":
        dataset_cls = torchvision.datasets.CIFAR100
        # The train split used to point at './datasets/cifar10' here; both
        # splits now share the CIFAR-100 root.
        root, num_classes = './datasets/cifar100', 100
    else:
        # Previously num_classes stayed undefined and a NameError followed.
        raise ValueError(
            'Not supported in-distribution dataset: {}'.format(args.in_dataset))

    # The training split is not consumed in this function; it is still
    # instantiated so the download side effect of the original is kept.
    dataset_cls(root, train=True, download=True, transform=transform)

    testset = dataset_cls(root=root, train=False, download=True,
                          transform=transform)
    testloaderIn = torch.utils.data.DataLoader(testset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=2)

    model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)

    checkpoint = torch.load(
        "./checkpoints/{name}/checkpoint_{epochs}.pth.tar".format(
            name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    resize_to_32 = transforms.Compose([
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor()
    ])
    if args.out_dataset == 'SVHN':
        testsetout = svhn.SVHN('datasets/ood_datasets/svhn/',
                               split='test',
                               transform=transforms.ToTensor(),
                               download=False)
    elif args.out_dataset == 'dtd':
        testsetout = torchvision.datasets.ImageFolder(
            root="datasets/ood_datasets/dtd/images", transform=resize_to_32)
    elif args.out_dataset == 'places365':
        testsetout = torchvision.datasets.ImageFolder(
            root="datasets/ood_datasets/places365/test_subset",
            transform=resize_to_32)
    else:
        testsetout = torchvision.datasets.ImageFolder(
            "./datasets/ood_datasets/{}".format(args.out_dataset),
            transform=transform)
    testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=2)

    # set information about feature extraction
    # NOTE(review): temp_x is a CPU tensor while the model is on the GPU;
    # presumably feature_list moves its input itself - confirm.
    temp_x = torch.rand(2, 3, 32, 32)
    temp_x = Variable(temp_x)
    temp_list = model.feature_list(temp_x)[1]
    num_output = len(temp_list)

    # Maximum number of images scored per split.
    N = {"iSUN": 8925, "dtd": 5640}.get(args.out_dataset, 10000)

    def score_split(loader, in_distribution, out_file):
        """Score at most N images from ``loader`` into ``out_file``."""
        attack = None
        if args.adv:
            attack = MahalanobisLinfPGDAttack(model,
                                              eps=args.epsilon,
                                              nb_iter=args.iters,
                                              eps_iter=args.iter_size,
                                              rand_init=True,
                                              clip_min=0.,
                                              clip_max=1.,
                                              in_distribution=in_distribution,
                                              num_classes=num_classes,
                                              sample_mean=sample_mean,
                                              precision=precision,
                                              num_output=num_output,
                                              regressor=regressor)

        count = 0
        t0 = time.time()
        for images, _ in loader:
            # Trim the last batch so exactly N images are scored. The
            # out-of-distribution loop used to skip this and overshot N
            # whenever the batch size did not divide N.
            if count + images.shape[0] > N:
                images = images[:N - count]
            batch_size = images.shape[0]

            inputs = attack.perturb(images) if attack is not None else images

            Mahalanobis_scores = get_Mahalanobis_score(model, inputs,
                                                       num_classes,
                                                       sample_mean, precision,
                                                       num_output, magnitude)

            confidence_scores = regressor.predict_proba(
                Mahalanobis_scores)[:, 1]

            for k in range(batch_size):
                out_file.write("{}\n".format(-confidence_scores[k]))

            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                count, N,
                time.time() - t0))
            t0 = time.time()

            if count >= N:
                break

    # Context managers guarantee both files are closed even if scoring fails.
    with open(os.path.join(save_dir, "confidence_mahalanobis_In.txt"),
              'w') as f1, \
            open(os.path.join(save_dir, "confidence_mahalanobis_Out.txt"),
                 'w') as f2:
        print("Processing in-distribution images")
        score_split(testloaderIn, True, f1)

        print("Processing out-of-distribution images")
        score_split(testloaderOut, False, f2)

    results = metric(save_dir, stypes)

    print_results(results, stypes)
def eval_msp_and_odin():
    """Write MSP and ODIN confidence scores for in/out test images.

    Loads the DenseNet checkpoint named by ``args.name`` / ``args.epochs``,
    runs up to ``N`` images from the in-distribution test set and from the
    out-of-distribution set ``args.out_dataset`` through the model
    (adversarially perturbed first when ``args.adv`` is set), and writes the
    per-image maximum of the ``MSP`` and ``ODIN`` outputs, one per line, to
    ``confidence_{MSP,ODIN}_{In,Out}.txt`` under the save directory. The
    metrics are then computed and printed.

    Raises:
        ValueError: if ``args.in_dataset`` is neither "CIFAR-10" nor
            "CIFAR-100".
    """
    stypes = ['MSP', 'ODIN']

    save_dir = os.path.join('output/ood_scores/', args.out_dataset, args.name,
                            'adv' if args.adv else 'nat')
    os.makedirs(save_dir, exist_ok=True)

    # loading data sets
    normalizer = transforms.Normalize((125.3 / 255, 123.0 / 255, 113.9 / 255),
                                      (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
    transform = transforms.Compose([transforms.ToTensor()])

    if args.in_dataset == "CIFAR-10":
        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False,
                                               download=True,
                                               transform=transform)
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False,
                                                download=True,
                                                transform=transform)
        num_classes = 100
    else:
        # Previously num_classes stayed undefined and a NameError followed.
        raise ValueError(
            'Not supported in-distribution dataset: {}'.format(args.in_dataset))
    testloaderIn = torch.utils.data.DataLoader(testset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=2)

    model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)

    checkpoint = torch.load(
        "./checkpoints/{name}/checkpoint_{epochs}.pth.tar".format(
            name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    resize_to_32 = transforms.Compose([
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor()
    ])
    if args.out_dataset == 'SVHN':
        testsetout = svhn.SVHN('datasets/ood_datasets/svhn/',
                               split='test',
                               transform=transforms.ToTensor(),
                               download=False)
    elif args.out_dataset == 'dtd':
        testsetout = torchvision.datasets.ImageFolder(
            root="datasets/ood_datasets/dtd/images", transform=resize_to_32)
    elif args.out_dataset == 'places365':
        testsetout = torchvision.datasets.ImageFolder(
            root="datasets/ood_datasets/places365/test_subset",
            transform=resize_to_32)
    else:
        testsetout = torchvision.datasets.ImageFolder(
            "./datasets/ood_datasets/{}".format(args.out_dataset),
            transform=transform)
    testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=2)

    # Maximum number of images scored per split.
    N = {"iSUN": 8925, "dtd": 5640}.get(args.out_dataset, 10000)

    def score_split(loader, in_distribution, msp_file, odin_file):
        """Score at most N images from ``loader`` into the two files."""
        attack = None
        if args.adv:
            attack = ConfidenceLinfPGDAttack(model,
                                             eps=args.epsilon,
                                             nb_iter=args.iters,
                                             eps_iter=args.iter_size,
                                             rand_init=True,
                                             clip_min=0.,
                                             clip_max=1.,
                                             in_distribution=in_distribution,
                                             num_classes=num_classes)

        count = 0
        t0 = time.time()
        for images, _ in loader:
            # Trim the last batch so exactly N images are scored. The
            # out-of-distribution loop used to skip this and overshot N
            # whenever the batch size did not divide N.
            if count + images.shape[0] > N:
                images = images[:N - count]
            batch_size = images.shape[0]

            if attack is not None:
                images = attack.perturb(images)
            # requires_grad is set because the same inputs are handed to ODIN.
            inputs = Variable(images, requires_grad=True)

            outputs = model(inputs)

            nnOutputs = MSP(outputs, model)
            for k in range(batch_size):
                msp_file.write("{}\n".format(np.max(nnOutputs[k])))

            nnOutputs = ODIN(inputs,
                             outputs,
                             model,
                             temper=args.temperature,
                             noiseMagnitude1=args.magnitude)
            for k in range(batch_size):
                odin_file.write("{}\n".format(np.max(nnOutputs[k])))

            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                count, N,
                time.time() - t0))
            t0 = time.time()

            if count >= N:
                break

    # Context managers guarantee all four files are closed even on failure.
    with open(os.path.join(save_dir, "confidence_MSP_In.txt"), 'w') as f1, \
            open(os.path.join(save_dir, "confidence_MSP_Out.txt"), 'w') as f2, \
            open(os.path.join(save_dir, "confidence_ODIN_In.txt"), 'w') as g1, \
            open(os.path.join(save_dir, "confidence_ODIN_Out.txt"), 'w') as g2:
        print("Processing in-distribution images")
        score_split(testloaderIn, True, f1, g1)

        print("Processing out-of-distribution images")
        score_split(testloaderOut, False, f2, g2)

    results = metric(save_dir, stypes)

    print_results(results, stypes)
Ejemplo n.º 4
0
# Acceleration: GPU count and data-loading worker options.
parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.')
parser.add_argument('--prefetch',
                    type=int,
                    default=2,
                    help='Pre-fetching threads.')
args = parser.parse_args()

# Snapshot the parsed arguments as a plain dict and echo them.
state = {k: v for k, v in args._get_kwargs()}
print(state)

# Fix the torch and NumPy seeds.
torch.manual_seed(1)
np.random.seed(1)

# SVHN 'train_and_extra' split is the in-distribution training data and the
# 'test' split is the evaluation data; images are only converted to tensors.
train_data_in = svhn.SVHN('/share/data/vision-greg/svhn/',
                          split='train_and_extra',
                          transform=trn.ToTensor(),
                          download=False)
test_data = svhn.SVHN('/share/data/vision-greg/svhn/',
                      split='test',
                      transform=trn.ToTensor(),
                      download=False)
num_classes = 10

# When calibration is requested, carve a validation set out of the training
# data and tag the run with the 'calib_' indicator.
# NOTE(review): 5000 / 604388 presumably reserves about 5000 images out of
# the 604388 train+extra samples - confirm against validation_split.
calib_indicator = ''
if args.calibration:
    train_data_in, val_data = validation_split(train_data_in,
                                               val_share=5000 / 604388.)
    calib_indicator = 'calib_'
tiny_images = TinyImages(transform=trn.Compose([
    trn.ToTensor(),
Ejemplo n.º 5
0
    ]))
# Wrap the OOD dataset built just above in a shuffled loader and report its
# results under the 'Texture Calibration' heading.
ood_loader = torch.utils.data.DataLoader(ood_data,
                                         batch_size=args.test_bs,
                                         shuffle=True,
                                         num_workers=args.prefetch,
                                         pin_memory=True)

print('\n\nTexture Calibration')
get_and_print_results(ood_loader)

# /////////////// SVHN ///////////////

# SVHN test split as the OOD set: resized to 32, converted to tensors and
# normalized with the module-level mean/std.
ood_data = svhn.SVHN(root='/share/data/vision-greg/svhn/',
                     split="test",
                     transform=trn.Compose([
                         trn.Resize(32),
                         trn.ToTensor(),
                         trn.Normalize(mean, std)
                     ]),
                     download=False)
ood_loader = torch.utils.data.DataLoader(ood_data,
                                         batch_size=args.test_bs,
                                         shuffle=True,
                                         num_workers=args.prefetch,
                                         pin_memory=True)

print('\n\nSVHN Calibration')
get_and_print_results(ood_loader)

# /////////////// Places365 ///////////////

ood_data = dset.ImageFolder(
def main():
    """Train a network with ``num_classes + 1`` outputs using a PGD attack.

    Everything is configured through the module-level ``args`` namespace.
    The function builds the train/validation loaders (training images come
    from ``./datasets/row_train_data/<dataset>`` image folders), the model
    and a targeted ``LinfPGDAttack``, optionally restores weights from
    ``args.resume``, and then runs the epoch loop: adjust the learning rate,
    train with ``train_rowl``, validate, and save a checkpoint every
    ``args.save_epoch`` epochs.

    Note that the SVHN branch overwrites ``args.epochs`` and
    ``args.save_epoch``. A missing ``args.resume`` file is only reported;
    training then starts from the freshly initialised model.

    Raises:
        ValueError: if ``args.in_dataset`` or ``args.model_arch`` is not one
            of the supported values.
    """
    if args.tensorboard:
        configure("runs/%s" % (args.name))

    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([transforms.ToTensor()])

    transform_test = transforms.Compose([transforms.ToTensor()])

    kwargs = {'num_workers': 1, 'pin_memory': True}

    # name -> (train image folder, validation dataset class, validation root,
    #          number of classes)
    cifar_variants = {
        "CIFAR-10": ('./datasets/row_train_data/CIFAR-10',
                     datasets.CIFAR10, './datasets/cifar10', 10),
        "CIFAR-100": ('./datasets/row_train_data/CIFAR-100',
                      datasets.CIFAR100, './datasets/cifar100', 100),
    }

    if args.in_dataset in cifar_variants:
        train_root, val_cls, val_root, num_classes = \
            cifar_variants[args.in_dataset]
        # Normalization is handed to the model rather than applied in the
        # data pipeline.
        normalizer = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(train_root,
                                             transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            val_cls(val_root, train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        lr_schedule = [50, 75, 90]
    elif args.in_dataset == "SVHN":
        normalizer = None
        # SVHN training images are never augmented, whatever args.augment is.
        transform = transforms.Compose([transforms.ToTensor()])
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder('./datasets/row_train_data/SVHN',
                                             transform=transform),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test',
                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=False, **kwargs)

        # SVHN uses a shorter schedule and overrides the CLI values.
        args.epochs = 20
        args.save_epoch = 2
        lr_schedule = [10, 15, 18]
        num_classes = 10
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError(
            'Not supported in-distribution dataset: {}'.format(args.in_dataset))

    # create model
    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes + 1, args.growth,
                             reduction=args.reduce,
                             bottleneck=args.bottleneck,
                             dropRate=args.droprate, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes + 1,
                              widen_factor=args.width,
                              dropRate=args.droprate, normalizer=normalizer)
    else:
        # Explicit raise: an assert would be stripped under ``python -O``.
        raise ValueError('Not supported model arch: {}'.format(args.model_arch))

    attack = LinfPGDAttack(model=model, eps=args.epsilon, nb_iter=args.iters,
                           eps_iter=args.iter_size, rand_init=True,
                           targeted=True, num_classes=num_classes + 1,
                           loss_func='CE', elementwise_best=True)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    model = model.cuda()

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            # Best effort: report the missing file and train from scratch.
            print("=> no checkpoint found at '{}'".format(args.resume))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)

        # train for one epoch
        train_rowl(train_loader, model, criterion, optimizer, epoch,
                   num_classes, attack)

        # evaluate on validation set
        validate(val_loader, model, criterion, num_classes, epoch)

        # save a checkpoint every args.save_epoch epochs
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, epoch + 1)
Ejemplo n.º 7
0
                                trn.ToTensor(),
                                trn.Normalize(mean, std)
                            ]))
# Wrap the OOD dataset built just above in a shuffled loader and report its
# results under the 'Texture Detection' heading.
ood_loader = torch.utils.data.DataLoader(ood_data,
                                         batch_size=args.test_bs,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True)
print('\n\nTexture Detection')
get_and_print_results(ood_loader)

# /////////////// SVHN /////////////// # cropped and no sampling of the test set
# SVHN test split as the OOD set: converted to tensors and normalized with
# the module-level mean/std (the resize step is disabled below).
ood_data = svhn.SVHN(
    root='../data/svhn/',
    split="test",
    transform=trn.Compose([  #trn.Resize(32), 
        trn.ToTensor(),
        trn.Normalize(mean, std)
    ]),
    download=False)
ood_loader = torch.utils.data.DataLoader(ood_data,
                                         batch_size=args.test_bs,
                                         shuffle=True,
                                         num_workers=2,
                                         pin_memory=True)
print('\n\nSVHN Detection')
get_and_print_results(ood_loader)

# /////////////// Places365 ///////////////
ood_data = dset.ImageFolder(root="../data/places365/",
                            transform=trn.Compose([
                                trn.Resize(32),
def main():
    """Train a ``gmmlib.DoublyRobustModel`` on the dataset selected by ``args``.

    Everything is driven by the module-level ``args`` namespace:

    1. Build train/validation loaders for ``args.in_dataset`` ("CIFAR-10",
       "CIFAR-100" or "SVHN") and an auxiliary ``TinyImages`` loader.
    2. Build the base classifier selected by ``args.model_arch``
       ("densenet" or "wideresnet").
    3. Call ``gen_gmm`` and load the resulting ``in_gmm.pth.tar`` /
       ``out_gmm.pth.tar`` objects from
       ``checkpoints/<in_dataset>/<name>/``.
    4. Wrap base model and both mixture models in a ``DoublyRobustModel``,
       optionally resume from ``args.resume``, and train with
       ``train_CEDA_gmm_out``, validating after every epoch and saving a
       checkpoint every ``args.save_epoch`` epochs.

    Side effects: for SVHN, ``args.epochs`` and ``args.save_epoch`` are
    overwritten; when resuming, ``args.start_epoch`` is overwritten.

    Raises:
        ValueError: if ``args.in_dataset`` or ``args.model_arch`` is not one
            of the supported values.
    """
    if args.tensorboard:
        configure("runs/%s" % (args.name))

    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    kwargs = {'num_workers': 1, 'pin_memory': True}

    # `normalizer` is handed to the model constructor below instead of being
    # part of the data transforms, so the loaders yield un-normalized tensors.
    if args.in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./datasets/cifar10',
                             train=True,
                             download=True,
                             transform=transform_train),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./datasets/cifar10',
                             train=False,
                             transform=transform_test),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)

        lr_schedule = [50, 75, 90]
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./datasets/cifar100',
                              train=True,
                              download=True,
                              transform=transform_train),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./datasets/cifar100',
                              train=False,
                              transform=transform_test),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)

        lr_schedule = [50, 75, 90]
        num_classes = 100
    elif args.in_dataset == "SVHN":
        normalizer = None
        train_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/',
                      split='train',
                      transform=transforms.ToTensor(),
                      download=False),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)
        val_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/',
                      split='test',
                      transform=transforms.ToTensor(),
                      download=False),
            batch_size=args.batch_size,
            shuffle=False,
            **kwargs)

        # SVHN uses a fixed, shorter schedule regardless of the CLI values.
        args.epochs = 20
        args.save_epoch = 2
        lr_schedule = [10, 15, 18]
        num_classes = 10
    else:
        # Fail here rather than with a NameError further down.
        raise ValueError('Not supported in-distribution dataset: {}'.format(
            args.in_dataset))

    out_loader = torch.utils.data.DataLoader(
        TinyImages(transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])),
        batch_size=args.ood_batch_size,
        shuffle=False,
        **kwargs)

    # create model
    if args.model_arch == 'densenet':
        base_model = dn.DenseNet3(args.layers,
                                  num_classes,
                                  args.growth,
                                  reduction=args.reduce,
                                  bottleneck=args.bottleneck,
                                  dropRate=args.droprate,
                                  normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        base_model = wn.WideResNet(args.depth,
                                   num_classes,
                                   widen_factor=args.width,
                                   dropRate=args.droprate,
                                   normalizer=normalizer)
    else:
        # An explicit raise survives `python -O`, unlike `assert False`.
        raise ValueError('Not supported model arch: {}'.format(
            args.model_arch))

    gen_gmm(train_loader, out_loader, data_used=50000, PCA=True, N=[100])

    checkpoint_dir = "checkpoints/{in_dataset}/{name}/".format(
        in_dataset=args.in_dataset, name=args.name)

    # Means and log-variances of both mixtures are trainable; the mixing
    # weights (alpha) are registered as parameters but kept frozen.
    gmm = torch.load(checkpoint_dir + 'in_gmm.pth.tar')
    gmm.alpha = nn.Parameter(gmm.alpha)
    gmm.mu.requires_grad = True
    gmm.logvar.requires_grad = True
    gmm.alpha.requires_grad = False

    gmm_out = torch.load(checkpoint_dir + 'out_gmm.pth.tar')
    # NOTE(review): this wraps the *in-distribution* `gmm.alpha`, not
    # `gmm_out.alpha`, so the out-GMM's own mixing weights are discarded.
    # Looks like a copy-paste slip; left unchanged because fixing it alters
    # training results -- confirm intent before changing.
    gmm_out.alpha = nn.Parameter(gmm.alpha)
    gmm_out.mu.requires_grad = True
    gmm_out.logvar.requires_grad = True
    gmm_out.alpha.requires_grad = False

    loglam = 0.
    model = gmmlib.DoublyRobustModel(base_model,
                                     gmm,
                                     gmm_out,
                                     loglam,
                                     dim=3072,
                                     classes=num_classes).cuda()

    model.loglam.requires_grad = False

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    criterion = nn.CrossEntropyLoss().cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define the optimizer: the two mixture models get a small fixed learning
    # rate and no weight decay, the base classifier uses the CLI settings.
    lr = args.lr
    lr_gmm = 1e-5

    param_groups = [{
        'params': model.mm.parameters(),
        'lr': lr_gmm,
        'weight_decay': 0.
    }, {
        'params': model.mm_out.parameters(),
        'lr': lr_gmm,
        'weight_decay': 0.
    }, {
        'params': model.base_model.parameters(),
        'lr': lr,
        'weight_decay': args.weight_decay
    }]

    optimizer = torch.optim.SGD(param_groups,
                                momentum=args.momentum,
                                nesterov=True)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)

        # train for one epoch
        lam = model.loglam.data.exp().item()
        train_CEDA_gmm_out(model,
                           train_loader,
                           out_loader,
                           optimizer,
                           epoch,
                           lam=lam)

        # evaluate on validation set (result is only reported by `validate`)
        validate(val_loader, model, criterion, epoch)

        # save a checkpoint every `args.save_epoch` epochs
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                }, epoch + 1)
def tune_odin_hyperparams():
    """Grid-search the ODIN input-perturbation magnitude and return the best one.

    Uses the module-level ``args`` namespace. Up to 1000 samples are taken
    from the in-distribution test loader and from a ``TinyImages`` loader
    (started at a random offset). For each candidate magnitude in
    ``np.arange(0, 0.0041, 0.004 / 20)`` the samples are scored with
    ``get_odin_score`` (``temper=1000``), the scores are written to
    ``confidence_ODIN_In.txt`` / ``confidence_ODIN_Out.txt`` under
    ``output/odin_hyperparams/<in_dataset>/<name>/tmp`` and evaluated with
    ``metric``. The magnitude with the lowest ``results['ODIN']['FPR']`` wins.

    Side effect: for SVHN, ``args.epochs`` is set to 20, which also selects
    the checkpoint file that is loaded.

    Returns:
        The magnitude (an element of the ``np.arange`` grid, or ``0.0`` if no
        candidate improved on the initial bound) with the lowest FPR.

    Raises:
        ValueError: if ``args.in_dataset`` or ``args.model_arch`` is not one
            of the supported values.
    """
    print('Tuning hyper-parameters...')
    stypes = ['ODIN']

    save_dir = os.path.join('output/odin_hyperparams/', args.in_dataset,
                            args.name, 'tmp')
    os.makedirs(save_dir, exist_ok=True)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Only the test split is needed for tuning.
    if args.in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False,
                                               download=True,
                                               transform=transform)
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False,
                                                download=True,
                                                transform=transform)
        num_classes = 100
    elif args.in_dataset == "SVHN":
        normalizer = None
        testset = svhn.SVHN('datasets/svhn/',
                            split='test',
                            transform=transforms.ToTensor(),
                            download=False)
        args.epochs = 20
        num_classes = 10
    else:
        # Fail here rather than with a NameError further down.
        raise ValueError('Not supported in-distribution dataset: {}'.format(
            args.in_dataset))

    testloaderIn = torch.utils.data.DataLoader(testset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    valloaderOut = torch.utils.data.DataLoader(
        TinyImages(transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])),
        batch_size=args.batch_size,
        shuffle=False)

    valloaderOut.dataset.offset = np.random.randint(len(valloaderOut.dataset))

    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth,
                              num_classes,
                              widen_factor=args.width,
                              normalizer=normalizer)
    else:
        # An explicit raise survives `python -O`, unlike `assert False`.
        raise ValueError('Not supported model arch: {}'.format(
            args.model_arch))

    checkpoint = torch.load(
        "./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(
            in_dataset=args.in_dataset, name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    m = 1000
    val_in = _collect_samples(testloaderIn, m)
    val_out = _collect_samples(valloaderOut, m)

    print('Len of val in: ', len(val_in))
    print('Len of val out: ', len(val_out))

    in_path = os.path.join(save_dir, "confidence_ODIN_In.txt")
    out_path = os.path.join(save_dir, "confidence_ODIN_Out.txt")

    # FPR is at most 1, so 1.1 guarantees the first candidate is accepted.
    best_fpr = 1.1
    best_magnitude = 0.0
    for magnitude in np.arange(0, 0.0041, 0.004 / 20):
        print("Processing in-distribution images")
        _write_odin_scores(in_path, val_in, model, magnitude, args.batch_size)

        print("Processing out-of-distribution images")
        _write_odin_scores(out_path, val_out, model, magnitude,
                           args.batch_size)

        # `metric` reads the two score files back from `save_dir`.
        results = metric(save_dir, stypes)
        print_results(results, stypes)
        fpr = results['ODIN']['FPR']
        if fpr < best_fpr:
            best_fpr = fpr
            best_magnitude = magnitude

    return best_magnitude


def _collect_samples(loader, limit):
    """Return up to ``limit`` individual samples from ``loader`` as numpy arrays."""
    samples = []
    for data, _ in loader:
        for x in data:
            samples.append(x.numpy())
            if len(samples) == limit:
                return samples
    return samples


def _write_odin_scores(path, samples, model, magnitude, batch_size):
    """Score ``samples`` batch-wise with ODIN and write one score per line to ``path``."""
    with open(path, 'w') as score_file:
        # Iterate over the samples actually collected, so a short sample list
        # never yields an empty batch.
        for start in range(0, len(samples), batch_size):
            batch = samples[start:start + batch_size]
            # Stack first: building a tensor from a list of arrays is slow.
            images = torch.from_numpy(np.stack(batch)).cuda()

            scores = get_odin_score(images,
                                    model,
                                    temper=1000,
                                    noiseMagnitude1=magnitude)

            for k in range(len(batch)):
                score_file.write("{}\n".format(scores[k]))
Ejemplo n.º 10
0
def eval_ood_detector(base_dir, in_dataset, out_datasets, batch_size, method,
                      method_args, name, epochs, adv, corrupt, adv_corrupt,
                      adv_args, mode_args):
    """Write OOD scores for the in-distribution test set and each OOD dataset.

    Args:
        base_dir: root directory for the score files.
        in_dataset: "CIFAR-10", "CIFAR-100" or "SVHN".
        out_datasets: iterable of OOD dataset names. 'SVHN', 'dtd' and
            'places365' have dedicated locations; any other name is read from
            ``./datasets/ood_datasets/<name>``.
        batch_size: batch size for every loader.
        method: scoring method name passed to ``get_score``. The values
            "sofl", "rowl", "atom", "ntom", "msp", "odin" and "mahalanobis"
            are special-cased here.
        method_args: dict forwarded to ``get_score``; ``'num_classes'`` (and
            ``'num_output'`` for "mahalanobis") are added to it in place.
        name: experiment name (part of checkpoint and output paths).
        epochs: epoch number of the checkpoint to load.
        adv, corrupt, adv_corrupt: select how OOD inputs are perturbed; they
            are checked in the order adv, corrupt, adv_corrupt when
            perturbing and adv, adv_corrupt, corrupt when choosing the output
            sub-directory.
        adv_args: dict with ``'epsilon'``, ``'iters'``, ``'iter_size'`` (read
            when ``adv`` or ``adv_corrupt``) and ``'severity_level'`` (read
            when ``corrupt`` or ``adv_corrupt``).
        mode_args: dict with boolean ``'out_dist_only'`` and
            ``'in_dist_only'`` switches.

    Output files (under ``<base_dir>/<in_dataset>/<method>/<name>/<mode>``):
        ``in_scores.txt`` (one score per line), ``in_labels.txt``
        (``label prediction confidence`` per line) and, per OOD dataset,
        ``<out_dataset>/out_scores.txt``.

    The architecture and its hyper-parameters are still read from the
    module-level ``args`` (``model_arch``, ``layers``, ``depth``, ``width``).

    Raises:
        ValueError: if ``in_dataset`` or ``args.model_arch`` is not supported.
    """
    if adv:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name, 'adv',
                                   str(int(adv_args['epsilon'])))
    elif adv_corrupt:
        in_save_dir = os.path.join(base_dir, in_dataset, method,
                                   name, 'adv_corrupt',
                                   str(int(adv_args['epsilon'])))
    elif corrupt:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name,
                                   'corrupt')
    else:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name, 'nat')

    os.makedirs(in_save_dir, exist_ok=True)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False,
                                               download=True,
                                               transform=transform)
        num_classes = 10
        num_reject_classes = 5
    elif in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False,
                                                download=True,
                                                transform=transform)
        num_classes = 100
        num_reject_classes = 10
    elif in_dataset == "SVHN":
        normalizer = None
        testset = svhn.SVHN('datasets/svhn/',
                            split='test',
                            transform=transforms.ToTensor(),
                            download=False)
        num_classes = 10
        num_reject_classes = 5
    else:
        # Fail here rather than with a NameError further down.
        raise ValueError(
            'Not supported in-distribution dataset: {}'.format(in_dataset))

    testloaderIn = torch.utils.data.DataLoader(testset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=2)

    # Only "sofl" uses the per-dataset reject-class count set above; the
    # rowl/atom/ntom models have a single extra class, all others none.
    if method != "sofl":
        num_reject_classes = 0

    if method == "rowl" or method == "atom" or method == "ntom":
        num_reject_classes = 1

    method_args['num_classes'] = num_classes

    arch = args.model_arch
    uses_gmm = arch in ('densenet_ccu', 'wideresnet_ccu')

    if arch in ('densenet', 'densenet_ccu'):
        model = dn.DenseNet3(args.layers,
                             num_classes + num_reject_classes,
                             normalizer=normalizer)
    elif arch in ('wideresnet', 'wideresnet_ccu'):
        model = wn.WideResNet(args.depth,
                              num_classes + num_reject_classes,
                              widen_factor=args.width,
                              normalizer=normalizer)
    else:
        # An explicit raise survives `python -O`, unlike `assert False`.
        raise ValueError('Not supported model arch: {}'.format(arch))

    if uses_gmm:
        # The checkpoint holds the state of the whole DoublyRobustModel, so
        # the wrapper is rebuilt around `model` just to load it; scoring
        # below still runs on `model` itself.
        gmm_dir = "checkpoints/{in_dataset}/{name}/".format(
            in_dataset=in_dataset, name=name)
        gmm = torch.load(gmm_dir + 'in_gmm.pth.tar')
        gmm.alpha = nn.Parameter(gmm.alpha)
        gmm_out = torch.load(gmm_dir + 'out_gmm.pth.tar')
        # NOTE(review): wraps the in-distribution `gmm.alpha`, not
        # `gmm_out.alpha` -- looks like a copy-paste slip. Left unchanged
        # here; confirm against the training script before changing.
        gmm_out.alpha = nn.Parameter(gmm.alpha)
        whole_model = gmmlib.DoublyRobustModel(model,
                                               gmm,
                                               gmm_out,
                                               loglam=0.,
                                               dim=3072,
                                               classes=num_classes)

    checkpoint = torch.load(
        "./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(
            in_dataset=in_dataset, name=name, epochs=epochs))

    if uses_gmm:
        whole_model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    if method == "mahalanobis":
        # Probe the network with a dummy batch to count its feature outputs.
        temp_x = torch.rand(2, 3, 32, 32)
        temp_x = Variable(temp_x).cuda()
        temp_list = model.feature_list(temp_x)[1]
        num_output = len(temp_list)
        method_args['num_output'] = num_output

    if adv or adv_corrupt:
        # Every attack takes its budget from `adv_args`, so the function does
        # not depend on the global `args` for these values.
        pgd_kwargs = dict(eps=adv_args['epsilon'],
                          nb_iter=adv_args['iters'],
                          eps_iter=adv_args['iter_size'],
                          rand_init=True,
                          clip_min=0.,
                          clip_max=1.,
                          num_classes=num_classes)

        if method == "msp" or method == "odin":
            attack_out = ConfidenceLinfPGDAttack(model, **pgd_kwargs)
        elif method == "mahalanobis":
            # NOTE(review): `sample_mean`, `precision` and `regressor` are not
            # defined in this function; presumably module-level globals set
            # by the calling script -- verify.
            attack_out = MahalanobisLinfPGDAttack(model,
                                                  sample_mean=sample_mean,
                                                  precision=precision,
                                                  num_output=num_output,
                                                  regressor=regressor,
                                                  **pgd_kwargs)
        elif method == "sofl":
            attack_out = SOFLLinfPGDAttack(
                model,
                num_reject_classes=num_reject_classes,
                **pgd_kwargs)
        elif method == "rowl" or method == "atom" or method == "ntom":
            attack_out = OODScoreLinfPGDAttack(model, **pgd_kwargs)

    if not mode_args['out_dist_only']:
        t0 = time.time()

        scores_path = os.path.join(in_save_dir, "in_scores.txt")
        labels_path = os.path.join(in_save_dir, "in_labels.txt")
        with open(scores_path, 'w') as f1, open(labels_path, 'w') as g1:
            print("Processing in-distribution images")

            N = len(testloaderIn.dataset)
            count = 0
            for images, labels in testloaderIn:
                images = images.cuda()
                labels = labels.cuda()
                curr_batch_size = images.shape[0]

                inputs = images

                scores = get_score(inputs, model, method, method_args)

                for score in scores:
                    f1.write("{}\n".format(score))

                # "rowl" predicts over all outputs (including the reject
                # class); every other method only over the real classes.
                logits = model(inputs)
                if method != "rowl":
                    logits = logits[:, :num_classes]
                outputs = F.softmax(logits, dim=1).detach().cpu().numpy()
                preds = np.argmax(outputs, axis=1)
                confs = np.max(outputs, axis=1)

                for k in range(preds.shape[0]):
                    g1.write("{} {} {}\n".format(labels[k], preds[k],
                                                 confs[k]))

                count += curr_batch_size
                print("{:4}/{:4} images processed, {:.1f} seconds used.".
                      format(count, N,
                             time.time() - t0))
                t0 = time.time()

    if mode_args['in_dist_only']:
        return

    for out_dataset in out_datasets:

        out_save_dir = os.path.join(in_save_dir, out_dataset)
        os.makedirs(out_save_dir, exist_ok=True)

        testloaderOut = _build_ood_loader(out_dataset, batch_size)

        t0 = time.time()
        print("Processing out-of-distribution images")

        N = len(testloaderOut.dataset)
        count = 0
        with open(os.path.join(out_save_dir, "out_scores.txt"), 'w') as f2:
            for images, _ in testloaderOut:
                images = images.cuda()
                curr_batch_size = images.shape[0]

                if adv:
                    inputs = attack_out.perturb(images)
                elif corrupt:
                    inputs = corrupt_attack(images, model, method,
                                            method_args, False,
                                            adv_args['severity_level'])
                elif adv_corrupt:
                    corrupted_images = corrupt_attack(
                        images, model, method, method_args, False,
                        adv_args['severity_level'])
                    inputs = attack_out.perturb(corrupted_images)
                else:
                    inputs = images

                scores = get_score(inputs, model, method, method_args)

                for score in scores:
                    f2.write("{}\n".format(score))

                count += curr_batch_size
                print("{:4}/{:4} images processed, {:.1f} seconds used.".
                      format(count, N,
                             time.time() - t0))
                t0 = time.time()

    return


def _build_ood_loader(out_dataset, batch_size):
    """Return a shuffled test loader for the OOD dataset called ``out_dataset``."""
    if out_dataset == 'SVHN':
        testsetout = svhn.SVHN('datasets/ood_datasets/svhn/',
                               split='test',
                               transform=transforms.ToTensor(),
                               download=False)
    else:
        # Image-folder datasets differ only in their root directory.
        known_roots = {
            'dtd': "datasets/ood_datasets/dtd/images",
            'places365': "datasets/ood_datasets/places365/test_subset",
        }
        root = known_roots.get(
            out_dataset, "./datasets/ood_datasets/{}".format(out_dataset))
        testsetout = torchvision.datasets.ImageFolder(
            root,
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.CenterCrop(32),
                transforms.ToTensor()
            ]))
    return torch.utils.data.DataLoader(testsetout,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=2)
Ejemplo n.º 11
0
def _fgsm_adversarial(model, criterion, samples, labels, batch_size, epsilon):
    """Craft one-step FGSM adversarial copies of ``samples``.

    Args:
        model: network to attack; the caller has already put it in eval mode
            on the GPU.
        criterion: loss whose input-gradient sign gives the attack direction.
        samples: sequence of per-sample numpy arrays.
        labels: sequence of labels aligned with ``samples``.
        batch_size: number of samples per forward/backward pass.
        epsilon: step size applied along the gradient sign.

    Returns:
        A list with one numpy array per input sample, clamped to [0, 1].
    """
    adversarial = []
    for start in range(0, len(samples), batch_size):
        stop = start + batch_size
        inputs = torch.from_numpy(np.stack(samples[start:stop])).cuda()
        inputs.requires_grad_(True)
        target = torch.from_numpy(np.asarray(labels[start:stop])).cuda()

        model.zero_grad()
        loss = criterion(model(inputs), target)
        loss.backward()

        # torch.ge rather than torch.sign: a zero gradient still maps to +1,
        # so every pixel is moved by exactly +/- epsilon.
        direction = (torch.ge(inputs.grad, 0).float() - 0.5) * 2
        perturbed = torch.add(inputs.detach(), direction, alpha=epsilon)
        adversarial.extend(torch.clamp(perturbed, 0.0, 1.0).cpu().numpy())
    return adversarial


def _write_confidence_scores(handle, samples, batch_size, score_fn, regressor):
    """Score ``samples`` in batches and write one negated confidence per line.

    Args:
        handle: writable text file object.
        samples: sequence of per-sample numpy arrays.
        batch_size: number of samples scored per call to ``score_fn``.
        score_fn: callable mapping a CUDA batch tensor to the feature matrix
            expected by ``regressor``.
        regressor: fitted classifier exposing ``predict_proba``.
    """
    total = len(samples)
    processed = 0
    t0 = time.time()
    for start in range(0, total, batch_size):
        images = torch.from_numpy(np.stack(samples[start:start + batch_size])).cuda()
        confidence_scores = regressor.predict_proba(score_fn(images))[:, 1]
        handle.writelines("{}\n".format(-score) for score in confidence_scores)

        processed += images.shape[0]
        print("{:4}/{:4} images processed, {:.1f} seconds used.".format(processed, total, time.time()-t0))
        t0 = time.time()


def tune_mahalanobis_hyperparams():
    """Pick the input-perturbation magnitude for the Mahalanobis detector.

    Reads the module-level ``args`` (``in_dataset``, ``name``, ``batch_size``,
    ``model_arch``, ``layers``/``depth``/``width``, ``epochs``), loads the
    matching checkpoint, estimates class statistics with ``sample_estimator``
    and then, for each candidate magnitude, fits a logistic regression that
    separates 500 clean test images from their FGSM-perturbed copies. Each
    regressor is evaluated on a second set of 500 clean/FGSM images and the
    magnitude with the lowest FPR reported by ``metric`` is kept.

    Side effects: prints progress, overwrites
    ``confidence_mahalanobis_{In,Out}.txt`` under
    ``output/mahalanobis_hyperparams/<in_dataset>/<name>/tmp`` for every
    magnitude, and sets ``args.epochs = 20`` when the dataset is SVHN.

    Returns:
        ``(sample_mean, precision, best_regressor, best_magnitude)`` where the
        first two are the values returned by ``sample_estimator``.

    Raises:
        ValueError: if ``args.in_dataset`` or ``args.model_arch`` is not supported.
        RuntimeError: if no candidate magnitude produced a usable FPR.
    """
    print('Tuning hyper-parameters...')
    stypes = ['mahalanobis']

    save_dir = os.path.join('output/mahalanobis_hyperparams/', args.in_dataset, args.name, 'tmp')
    os.makedirs(save_dir, exist_ok=True)

    if args.in_dataset in ("CIFAR-10", "CIFAR-100"):
        normalizer = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        if args.in_dataset == "CIFAR-10":
            dataset_cls, root, num_classes = torchvision.datasets.CIFAR10, './datasets/cifar10', 10
        else:
            dataset_cls, root, num_classes = torchvision.datasets.CIFAR100, './datasets/cifar100', 100

        trainset = dataset_cls(root, train=True, download=True, transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        testset = dataset_cls(root=root, train=False, download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, num_workers=2)
    elif args.in_dataset == "SVHN":
        normalizer = None
        trainloaderIn = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='train',
                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True)
        testloaderIn = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test',
                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True)

        # The checkpoint path below is built from args.epochs, so this selects
        # the epoch-20 SVHN checkpoint.
        args.epochs = 20
        num_classes = 10
    else:
        # Without this branch an unknown dataset surfaced later as a NameError.
        raise ValueError('Not supported in-distribution dataset: {}'.format(args.in_dataset))

    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes, widen_factor=args.width, normalizer=normalizer)
    else:
        # Explicit raise instead of `assert False`, which is stripped under -O.
        raise ValueError('Not supported model arch: {}'.format(args.model_arch))

    checkpoint = torch.load("./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(in_dataset=args.in_dataset, name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    # Probe the network with a dummy batch to learn how many intermediate
    # feature maps it exposes and the channel count of each.
    temp_x = torch.rand(2, 3, 32, 32).cuda()
    temp_list = model.feature_list(temp_x)[1]
    num_output = len(temp_list)
    feature_list = np.array([out.size(1) for out in temp_list], dtype=np.float64)

    print('get sample mean and covariance')
    sample_mean, precision = sample_estimator(model, num_classes, feature_list, trainloaderIn)

    print('train logistic regression model')
    m = 500

    # Pull the first 2*m test samples: the first m train the regressor, the
    # next m validate it.
    pool_x = []
    pool_y = []
    for data, target in testloaderIn:
        pool_x.extend(data.numpy())
        pool_y.extend(target.numpy())
        if len(pool_x) >= 2*m:
            break

    train_in, val_in = pool_x[:m], pool_x[m:2*m]
    train_in_label, val_in_label = pool_y[:m], pool_y[m:2*m]

    print('In', len(train_in), len(val_in))

    criterion = nn.CrossEntropyLoss().cuda()
    adv_noise = 0.05

    # FGSM-perturbed copies of the clean images act as the "out" class.
    train_out = _fgsm_adversarial(model, criterion, train_in, train_in_label, args.batch_size, adv_noise)
    val_out = _fgsm_adversarial(model, criterion, val_in, val_in_label, args.batch_size, adv_noise)

    print('Out', len(train_out), len(val_out))

    train_lr_data = torch.from_numpy(np.stack(train_in + train_out))
    # Label 0 = clean, 1 = adversarial; only consumed by scikit-learn.
    train_lr_label = np.concatenate([np.zeros(len(train_in)), np.ones(len(train_out))])

    best_fpr = 1.1
    best_magnitude = 0.0
    best_regressor = None

    for magnitude in [0.0, 0.01, 0.005, 0.002, 0.0014, 0.001, 0.0005]:

        def score_fn(batch, magnitude=magnitude):
            return get_Mahalanobis_score(batch, model, num_classes, sample_mean, precision, num_output, magnitude)

        train_lr_Mahalanobis = []
        # Step through every sample, including a final partial batch, so the
        # feature matrix always has as many rows as train_lr_label.
        for start in range(0, train_lr_data.size(0), args.batch_size):
            data = train_lr_data[start:start + args.batch_size].cuda()
            train_lr_Mahalanobis.extend(score_fn(data))

        train_lr_Mahalanobis = np.asarray(train_lr_Mahalanobis, dtype=np.float32)
        regressor = LogisticRegressionCV(n_jobs=-1).fit(train_lr_Mahalanobis, train_lr_label)

        print('Logistic Regressor params:', regressor.coef_, regressor.intercept_)

        # Both files must be closed before metric() reads them back.
        with open(os.path.join(save_dir, "confidence_mahalanobis_In.txt"), 'w') as f1, \
                open(os.path.join(save_dir, "confidence_mahalanobis_Out.txt"), 'w') as f2:
            print("Processing in-distribution images")
            _write_confidence_scores(f1, val_in, args.batch_size, score_fn, regressor)

            print("Processing out-of-distribution images")
            _write_confidence_scores(f2, val_out, args.batch_size, score_fn, regressor)

        results = metric(save_dir, stypes)
        print_results(results, stypes)
        fpr = results['mahalanobis']['FPR']
        if fpr < best_fpr:
            best_fpr = fpr
            best_magnitude = magnitude
            best_regressor = regressor

    if best_regressor is None:
        raise RuntimeError('No magnitude produced a valid FPR; cannot select a regressor')

    print('Best Logistic Regressor params:', best_regressor.coef_, best_regressor.intercept_)
    print('Best magnitude', best_magnitude)

    return sample_mean, precision, best_regressor, best_magnitude
Ejemplo n.º 12
0
    torchvision.utils.save_image(images[0], os.path.join(save_dir, '100', '%d.png'%i))
    if i + 1 == 500:
        break

save_dir = "datasets/rowl_train_data/SVHN"

# One sub-folder per label index 00..10; makedirs also creates save_dir itself.
# exist_ok lets the script be re-run without crashing on existing folders.
# NOTE(review): folder '10' is not written by the loop below (labels index a
# 10-element counter) — presumably it is filled by the OOD loop that follows; confirm.
for i in range(11):
    os.makedirs(os.path.join(save_dir, '%02d'%i), exist_ok=True)

# Number of images written so far for each of the 10 classes; also used as the
# running file index inside each class folder.
class_count = np.zeros(10, dtype=np.int32)

train_loader = torch.utils.data.DataLoader(
    svhn.SVHN('datasets/svhn/', split='train', transform=transform, download=False),
    batch_size=1, shuffle=True)

# batch_size=1, so each iteration yields exactly one image and one label.
for (xs, ys) in train_loader:
    image = xs[0]
    label = int(ys[0])
    torchvision.utils.save_image(image, os.path.join(save_dir, '%02d'%label, '%d.png'%class_count[label]))
    class_count[label] += 1

# Randomise the starting offset of the OOD dataset before it is sampled.
ood_loader.dataset.offset = np.random.randint(len(ood_loader.dataset))

print('SVHN, Class count: ', class_count)
# Mean per-class image count, truncated to an int.
ood_count = int(np.mean(class_count))
print('SVHN, OOD count: ', ood_count)

for i, (images, labels) in enumerate(ood_loader):