Example #1
def main():
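    # Example #1 trains a GTSRB (43-class) classifier, optionally with an auxiliary
    # OOD loader (TinyImages) and PGD adversarial training; args and the imports are
    # assumed to be defined at module level, as in the rest of these examples.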
    if args.tensorboard: configure("runs/%s" % (args.name))

    # Data loading code
    normalizer = transforms.Normalize((0.3337, 0.3064, 0.3171),
                                      (0.2672, 0.2564, 0.2629))

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    kwargs = {'num_workers': 1, 'pin_memory': True}

    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder('./datasets/gtsrb/data/train',
                                         transform=transform),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        './datasets/gtsrb/data/valid', transform=transform),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             **kwargs)

    num_classes = 43

    if args.ood:
        ood_loader = torch.utils.data.DataLoader(
            TinyImages(transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.ToPILImage(),
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ])),
            batch_size=args.ood_batch_size,
            shuffle=False,
            **kwargs)

    # create model
    model = dn.DenseNet3(args.layers,
                         num_classes,
                         args.growth,
                         reduction=args.reduce,
                         bottleneck=args.bottleneck,
                         dropRate=args.droprate,
                         normalizer=normalizer)

    if args.adv:
        attack_in = LinfPGDAttack(model=model,
                                  eps=args.epsilon,
                                  nb_iter=args.iters,
                                  eps_iter=args.iter_size,
                                  rand_init=True,
                                  loss_func='CE')
        if args.ood:
            attack_out = LinfPGDAttack(model=model,
                                       eps=args.epsilon,
                                       nb_iter=args.iters,
                                       eps_iter=args.iter_size,
                                       rand_init=True,
                                       loss_func='OE')

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.ood:
        ood_criterion = OELoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)

    if args.lr_scheduler != 'cosine_annealing' and args.lr_scheduler != 'step_decay':
        assert False, 'Unsupported lr_scheduler: {}'.format(args.lr_scheduler)

    if args.lr_scheduler == 'cosine_annealing':
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: cosine_annealing(
                step,
                args.epochs * len(train_loader),
                1,  # since lr_lambda computes multiplicative factor
                1e-6 / args.lr))
    else:
        scheduler = None

    for epoch in range(args.start_epoch, args.epochs):
        if args.lr_scheduler == 'step_decay':
            adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        if args.ood:
            if args.adv:
                train_ood(train_loader, ood_loader, model, criterion,
                          ood_criterion, optimizer, scheduler, epoch,
                          attack_in, attack_out)
            else:
                train_ood(train_loader, ood_loader, model, criterion,
                          ood_criterion, optimizer, scheduler, epoch)
        else:
            if args.adv:
                train(train_loader, model, criterion, optimizer, scheduler,
                      epoch, attack_in)
            else:
                train(train_loader, model, criterion, optimizer, scheduler,
                      epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                }, epoch + 1)
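The cosine_annealing helper used with LambdaLR above is not shown in the example; a minimal sketch consistent with how it is called (a multiplier decaying from lr_max to lr_min over total_steps, with numpy imported as np) might be:

def cosine_annealing(step, total_steps, lr_max, lr_min):
    # Cosine decay of the learning-rate multiplier from lr_max down to lr_min.
    return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))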
def main():
    if args.tensorboard: configure("runs/%s" % (args.name))

    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    kwargs = {'num_workers': 1, 'pin_memory': True}

    if args.in_dataset == "CIFAR-10":
        # Data loading code
        normalizer = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            './datasets/cifar10',
            train=True,
            download=True,
            transform=transform_train),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            './datasets/cifar10', train=False, transform=transform_test),
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 **kwargs)

        lr_schedule = [50, 75, 90]
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        # Data loading code
        normalizer = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR100(
            './datasets/cifar100',
            train=True,
            download=True,
            transform=transform_train),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        val_loader = torch.utils.data.DataLoader(datasets.CIFAR100(
            './datasets/cifar100', train=False, transform=transform_test),
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 **kwargs)

        lr_schedule = [50, 75, 90]
        num_classes = 100
    elif args.in_dataset == "SVHN":
        # Data loading code
        normalizer = None
        train_loader = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/',
            split='train',
            transform=transforms.ToTensor(),
            download=False),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        val_loader = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/',
            split='test',
            transform=transforms.ToTensor(),
            download=False),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 **kwargs)

        args.epochs = 20
        args.save_epoch = 2
        lr_schedule = [10, 15, 18]
        num_classes = 10

    out_loader = torch.utils.data.DataLoader(
        TinyImages(transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])),
        batch_size=args.ood_batch_size,
        shuffle=False,
        **kwargs)

    # create model
    if args.model_arch == 'densenet':
        base_model = dn.DenseNet3(args.layers,
                                  num_classes,
                                  args.growth,
                                  reduction=args.reduce,
                                  bottleneck=args.bottleneck,
                                  dropRate=args.droprate,
                                  normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        base_model = wn.WideResNet(args.depth,
                                   num_classes,
                                   widen_factor=args.width,
                                   dropRate=args.droprate,
                                   normalizer=normalizer)
    else:
        assert False, 'Unsupported model arch: {}'.format(args.model_arch)

    gen_gmm(train_loader, out_loader, data_used=50000, PCA=True, N=[100])

    gmm = torch.load("checkpoints/{in_dataset}/{name}/".format(
        in_dataset=args.in_dataset, name=args.name) + 'in_gmm.pth.tar')
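    # Make the mixture weights a Parameter but freeze them; only the Gaussian
    # means and log-variances of the GMMs are fine-tuned.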

    gmm.alpha = nn.Parameter(gmm.alpha)
    gmm.mu.requires_grad = True
    gmm.logvar.requires_grad = True
    gmm.alpha.requires_grad = False

    gmm_out = torch.load("checkpoints/{in_dataset}/{name}/".format(
        in_dataset=args.in_dataset, name=args.name) + 'out_gmm.pth.tar')
    gmm_out.alpha = nn.Parameter(gmm.alpha)
    gmm_out.mu.requires_grad = True
    gmm_out.logvar.requires_grad = True
    gmm_out.alpha.requires_grad = False
    loglam = 0.
    model = gmmlib.DoublyRobustModel(base_model,
                                     gmm,
                                     gmm_out,
                                     loglam,
                                     dim=3072,
                                     classes=num_classes).cuda()

    model.loglam.requires_grad = False

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    model = model.cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define optimizer (separate learning rates for the GMM components and the base model)
    lr = args.lr
    lr_gmm = 1e-5

    param_groups = [{
        'params': model.mm.parameters(),
        'lr': lr_gmm,
        'weight_decay': 0.
    }, {
        'params': model.mm_out.parameters(),
        'lr': lr_gmm,
        'weight_decay': 0.
    }, {
        'params': model.base_model.parameters(),
        'lr': lr,
        'weight_decay': args.weight_decay
    }]

    optimizer = torch.optim.SGD(param_groups,
                                momentum=args.momentum,
                                nesterov=True)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)

        # train for one epoch
        lam = model.loglam.data.exp().item()
        train_CEDA_gmm_out(model,
                           train_loader,
                           out_loader,
                           optimizer,
                           epoch,
                           lam=lam)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                }, epoch + 1)
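save_checkpoint is defined elsewhere in the training scripts; a minimal sketch, assuming the per-run checkpoint layout that the other examples load from (checkpoints/{in_dataset}/{name}/checkpoint_{epoch}.pth.tar) and a module-level args as above:

def save_checkpoint(state, epoch):
    # Persist a training snapshot under the per-run checkpoint directory.
    directory = "checkpoints/{in_dataset}/{name}/".format(
        in_dataset=args.in_dataset, name=args.name)
    if not os.path.exists(directory):
        os.makedirs(directory)
    torch.save(state, os.path.join(directory, 'checkpoint_{}.pth.tar'.format(epoch)))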
def tune_odin_hyperparams():
    print('Tuning hyper-parameters...')
    stypes = ['ODIN']

    save_dir = os.path.join('output/odin_hyperparams/', args.in_dataset,
                            args.name, 'tmp')

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if args.in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        trainset = torchvision.datasets.CIFAR10('./datasets/cifar10',
                                                train=True,
                                                download=True,
                                                transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True)

        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False,
                                               download=True,
                                               transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)

        num_classes = 10

    elif args.in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        trainset = torchvision.datasets.CIFAR100('./datasets/cifar100',
                                                 train=True,
                                                 download=True,
                                                 transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True)

        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False,
                                                download=True,
                                                transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)

        num_classes = 100

    elif args.in_dataset == "SVHN":

        normalizer = None
        trainloaderIn = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/',
            split='train',
            transform=transforms.ToTensor(),
            download=False),
                                                    batch_size=args.batch_size,
                                                    shuffle=True)
        testloaderIn = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/',
            split='test',
            transform=transforms.ToTensor(),
            download=False),
                                                   batch_size=args.batch_size,
                                                   shuffle=True)

        args.epochs = 20
        num_classes = 10

    valloaderOut = torch.utils.data.DataLoader(
        TinyImages(transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])),
        batch_size=args.batch_size,
        shuffle=False)

    valloaderOut.dataset.offset = np.random.randint(len(valloaderOut.dataset))

    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth,
                              num_classes,
                              widen_factor=args.width,
                              normalizer=normalizer)
    else:
        assert False, 'Unsupported model arch: {}'.format(args.model_arch)

    checkpoint = torch.load(
        "./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(
            in_dataset=args.in_dataset, name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    m = 1000
    val_in = []
    val_out = []

    cnt = 0
    for data, target in testloaderIn:
        for x in data:
            val_in.append(x.numpy())
            cnt += 1
            if cnt == m:
                break
        if cnt == m:
            break

    cnt = 0
    for data, target in valloaderOut:
        for x in data:
            val_out.append(x.numpy())
            cnt += 1
            if cnt == m:
                break
        if cnt == m:
            break

    print('Len of val in: ', len(val_in))
    print('Len of val out: ', len(val_out))

    best_fpr = 1.1
    best_magnitude = 0.0
    for magnitude in np.arange(0, 0.0041, 0.004 / 20):

        t0 = time.time()
        f1 = open(os.path.join(save_dir, "confidence_ODIN_In.txt"), 'w')
        f2 = open(os.path.join(save_dir, "confidence_ODIN_Out.txt"), 'w')
        ########################################In-distribution###########################################
        print("Processing in-distribution images")

        count = 0
        for i in range(int(m / args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_in[i * args.batch_size:min((i + 1) * args.batch_size, m)])
            images = images.cuda()
            # if j<1000: continue
            batch_size = images.shape[0]

            scores = get_odin_score(images,
                                    model,
                                    temper=1000,
                                    noiseMagnitude1=magnitude)

            for k in range(batch_size):
                f1.write("{}\n".format(scores[k]))

            count += batch_size
            # print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

    ###################################Out-of-Distributions#####################################
        t0 = time.time()
        print("Processing out-of-distribution images")
        count = 0

        for i in range(int(m / args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_out[i * args.batch_size:min((i + 1) * args.batch_size, m)])
            images = images.cuda()
            # if j<1000: continue
            batch_size = images.shape[0]

            scores = get_odin_score(images,
                                    model,
                                    temper=1000,
                                    noiseMagnitude1=magnitude)

            for k in range(batch_size):
                f2.write("{}\n".format(scores[k]))

            count += batch_size
            # print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

        f1.close()
        f2.close()

        results = metric(save_dir, stypes)
        print_results(results, stypes)
        fpr = results['ODIN']['FPR']
        if fpr < best_fpr:
            best_fpr = fpr
            best_magnitude = magnitude

    return best_magnitude
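get_odin_score is not shown in these examples; a minimal sketch, assuming the standard ODIN recipe (temperature-scaled softmax plus a small input perturbation against the loss gradient, the same computation Example #5 performs inline):

def get_odin_score(inputs, model, temper, noiseMagnitude1):
    # Forward pass with gradient tracking on the (already CUDA) input batch.
    inputs = Variable(inputs, requires_grad=True)
    outputs = model(inputs)

    # Use the predicted class as a pseudo-label and backprop the temperature-scaled loss.
    maxIndexTemp = np.argmax(outputs.data.cpu().numpy(), axis=1)
    labels = Variable(torch.LongTensor(maxIndexTemp).cuda())
    loss = nn.CrossEntropyLoss()(outputs / temper, labels)
    loss.backward()

    # Perturb the input against the sign of its gradient and rescore.
    gradient = (torch.ge(inputs.grad.data, 0).float() - 0.5) * 2
    tempInputs = torch.add(inputs.data, gradient, alpha=-noiseMagnitude1)
    outputs = model(Variable(tempInputs)) / temper

    # The ODIN confidence is the max softmax probability of the rescored logits.
    nnOutputs = outputs.data.cpu().numpy()
    nnOutputs = nnOutputs - np.max(nnOutputs, axis=1, keepdims=True)
    nnOutputs = np.exp(nnOutputs) / np.sum(np.exp(nnOutputs), axis=1, keepdims=True)
    return np.max(nnOutputs, axis=1)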
Example #4
def eval_ood_detector(base_dir, in_dataset, out_datasets, batch_size, method,
                      method_args, name, epochs, adv, corrupt, adv_corrupt,
                      adv_args, mode_args):
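    # Score the in-distribution test set and every OOD test set with the chosen
    # method (optionally under adversarial or corruption perturbations) and write
    # the scores to text files for later metric computation.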

    if adv:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name, 'adv',
                                   str(int(adv_args['epsilon'])))
    elif adv_corrupt:
        in_save_dir = os.path.join(base_dir, in_dataset, method,
                                   name, 'adv_corrupt',
                                   str(int(adv_args['epsilon'])))
    elif corrupt:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name,
                                   'corrupt')
    else:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name, 'nat')

    if not os.path.exists(in_save_dir):
        os.makedirs(in_save_dir)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False,
                                               download=True,
                                               transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=2)
        num_classes = 10
        num_reject_classes = 5
    elif in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False,
                                                download=True,
                                                transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=2)
        num_classes = 100
        num_reject_classes = 10
    elif in_dataset == "SVHN":
        normalizer = None
        testset = svhn.SVHN('datasets/svhn/',
                            split='test',
                            transform=transforms.ToTensor(),
                            download=False)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=2)
        num_classes = 10
        num_reject_classes = 5

    if method != "sofl":
        num_reject_classes = 0

    if method == "rowl" or method == "atom" or method == "ntom":
        num_reject_classes = 1

    method_args['num_classes'] = num_classes

    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers,
                             num_classes + num_reject_classes,
                             normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth,
                              num_classes + num_reject_classes,
                              widen_factor=args.width,
                              normalizer=normalizer)
    elif args.model_arch == 'densenet_ccu':
        model = dn.DenseNet3(args.layers,
                             num_classes + num_reject_classes,
                             normalizer=normalizer)
        gmm = torch.load("checkpoints/{in_dataset}/{name}/".format(
            in_dataset=args.in_dataset, name=args.name) + 'in_gmm.pth.tar')
        gmm.alpha = nn.Parameter(gmm.alpha)
        gmm_out = torch.load("checkpoints/{in_dataset}/{name}/".format(
            in_dataset=args.in_dataset, name=args.name) + 'out_gmm.pth.tar')
        gmm_out.alpha = nn.Parameter(gmm.alpha)
        whole_model = gmmlib.DoublyRobustModel(model,
                                               gmm,
                                               gmm_out,
                                               loglam=0.,
                                               dim=3072,
                                               classes=num_classes)
    elif args.model_arch == 'wideresnet_ccu':
        model = wn.WideResNet(args.depth,
                              num_classes + num_reject_classes,
                              widen_factor=args.width,
                              normalizer=normalizer)
        gmm = torch.load("checkpoints/{in_dataset}/{name}/".format(
            in_dataset=args.in_dataset, name=args.name) + 'in_gmm.pth.tar')
        gmm.alpha = nn.Parameter(gmm.alpha)
        gmm_out = torch.load("checkpoints/{in_dataset}/{name}/".format(
            in_dataset=args.in_dataset, name=args.name) + 'out_gmm.pth.tar')
        gmm_out.alpha = nn.Parameter(gmm.alpha)
        whole_model = gmmlib.DoublyRobustModel(model,
                                               gmm,
                                               gmm_out,
                                               loglam=0.,
                                               dim=3072,
                                               classes=num_classes)
    else:
        assert False, 'Unsupported model arch: {}'.format(args.model_arch)

    checkpoint = torch.load(
        "./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(
            in_dataset=in_dataset, name=name, epochs=epochs))

    if args.model_arch == 'densenet_ccu' or args.model_arch == 'wideresnet_ccu':
        whole_model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    if method == "mahalanobis":
        temp_x = torch.rand(2, 3, 32, 32)
        temp_x = Variable(temp_x).cuda()
        temp_list = model.feature_list(temp_x)[1]
        num_output = len(temp_list)
        method_args['num_output'] = num_output

    if adv or adv_corrupt:
        epsilon = adv_args['epsilon']
        iters = adv_args['iters']
        iter_size = adv_args['iter_size']

        if method == "msp" or method == "odin":
            attack_out = ConfidenceLinfPGDAttack(model,
                                                 eps=epsilon,
                                                 nb_iter=iters,
                                                 eps_iter=iter_size,
                                                 rand_init=True,
                                                 clip_min=0.,
                                                 clip_max=1.,
                                                 num_classes=num_classes)
        elif method == "mahalanobis":
            # Assumption: the Mahalanobis statistics and the logistic regressor are
            # supplied through method_args (e.g. produced by a tuning step such as
            # tune_mahalanobis_hyperparams below).
            sample_mean = method_args['sample_mean']
            precision = method_args['precision']
            regressor = method_args['regressor']
            attack_out = MahalanobisLinfPGDAttack(
                model,
                eps=epsilon,
                nb_iter=iters,
                eps_iter=iter_size,
                rand_init=True,
                clip_min=0.,
                clip_max=1.,
                num_classes=num_classes,
                sample_mean=sample_mean,
                precision=precision,
                num_output=num_output,
                regressor=regressor)
        elif method == "sofl":
            attack_out = SOFLLinfPGDAttack(
                model,
                eps=epsilon,
                nb_iter=iters,
                eps_iter=iter_size,
                rand_init=True,
                clip_min=0.,
                clip_max=1.,
                num_classes=num_classes,
                num_reject_classes=num_reject_classes)
        elif method == "rowl":
            attack_out = OODScoreLinfPGDAttack(model,
                                               eps=epsilon,
                                               nb_iter=iters,
                                               eps_iter=iter_size,
                                               rand_init=True,
                                               clip_min=0.,
                                               clip_max=1.,
                                               num_classes=num_classes)
        elif method == "atom" or method == "ntom":
            attack_out = OODScoreLinfPGDAttack(model,
                                               eps=epsilon,
                                               nb_iter=iters,
                                               eps_iter=iter_size,
                                               rand_init=True,
                                               clip_min=0.,
                                               clip_max=1.,
                                               num_classes=num_classes)

    if not mode_args['out_dist_only']:
        t0 = time.time()

        f1 = open(os.path.join(in_save_dir, "in_scores.txt"), 'w')
        g1 = open(os.path.join(in_save_dir, "in_labels.txt"), 'w')

        ########################################In-distribution###########################################
        print("Processing in-distribution images")

        N = len(testloaderIn.dataset)
        count = 0
        for j, data in enumerate(testloaderIn):
            images, labels = data
            images = images.cuda()
            labels = labels.cuda()
            curr_batch_size = images.shape[0]

            inputs = images

            scores = get_score(inputs, model, method, method_args)

            for score in scores:
                f1.write("{}\n".format(score))

            if method == "rowl":
                outputs = F.softmax(model(inputs), dim=1)
                outputs = outputs.detach().cpu().numpy()
                preds = np.argmax(outputs, axis=1)
                confs = np.max(outputs, axis=1)
            else:
                outputs = F.softmax(model(inputs)[:, :num_classes], dim=1)
                outputs = outputs.detach().cpu().numpy()
                preds = np.argmax(outputs, axis=1)
                confs = np.max(outputs, axis=1)

            for k in range(preds.shape[0]):
                g1.write("{} {} {}\n".format(labels[k], preds[k], confs[k]))

            count += curr_batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                count, N,
                time.time() - t0))
            t0 = time.time()

        f1.close()
        g1.close()

    if mode_args['in_dist_only']:
        return

    for out_dataset in out_datasets:

        out_save_dir = os.path.join(in_save_dir, out_dataset)

        if not os.path.exists(out_save_dir):
            os.makedirs(out_save_dir)

        f2 = open(os.path.join(out_save_dir, "out_scores.txt"), 'w')

        if out_dataset == 'SVHN':
            testsetout = svhn.SVHN('datasets/ood_datasets/svhn/',
                                   split='test',
                                   transform=transforms.ToTensor(),
                                   download=False)
            testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        num_workers=2)
        elif out_dataset == 'dtd':
            testsetout = torchvision.datasets.ImageFolder(
                root="datasets/ood_datasets/dtd/images",
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.CenterCrop(32),
                    transforms.ToTensor()
                ]))
            testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        num_workers=2)
        elif out_dataset == 'places365':
            testsetout = torchvision.datasets.ImageFolder(
                root="datasets/ood_datasets/places365/test_subset",
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.CenterCrop(32),
                    transforms.ToTensor()
                ]))
            testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        num_workers=2)
        else:
            testsetout = torchvision.datasets.ImageFolder(
                "./datasets/ood_datasets/{}".format(out_dataset),
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.CenterCrop(32),
                    transforms.ToTensor()
                ]))
            testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        num_workers=2)

    ###################################Out-of-Distributions#####################################
        t0 = time.time()
        print("Processing out-of-distribution images")

        N = len(testloaderOut.dataset)
        count = 0
        for j, data in enumerate(testloaderOut):

            images, labels = data
            images = images.cuda()
            labels = labels.cuda()
            curr_batch_size = images.shape[0]

            if adv:
                inputs = attack_out.perturb(images)
            elif corrupt:
                inputs = corrupt_attack(images, model, method, method_args,
                                        False, adv_args['severity_level'])
            elif adv_corrupt:
                corrupted_images = corrupt_attack(images, model, method,
                                                  method_args, False,
                                                  adv_args['severity_level'])
                inputs = attack_out.perturb(corrupted_images)
            else:
                inputs = images

            scores = get_score(inputs, model, method, method_args)

            for score in scores:
                f2.write("{}\n".format(score))

            count += curr_batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                count, N,
                time.time() - t0))
            t0 = time.time()

        f2.close()

    return
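get_score and corrupt_attack come from the project's scoring utilities and are not reproduced here. Purely as an illustration of the interface (an assumption, not the project's implementation), a max-softmax-probability branch could look like:

def get_score(inputs, model, method, method_args):
    if method == "msp":
        with torch.no_grad():
            probs = F.softmax(model(inputs), dim=1)
        return probs.max(dim=1)[0].cpu().numpy()
    # odin, mahalanobis, sofl, rowl and atom/ntom branches are omitted in this sketch.
    raise NotImplementedError(method)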
Example #5
def eval_acc():
    print('test accuracy')

    save_dir = os.path.join('output/ood_scores/', args.out_dataset, args.name, 'adv' if args.adv else 'nat')

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    start = time.time()
    # loading data sets
    normalizer = transforms.Normalize((125.3 / 255, 123.0 / 255, 113.9 / 255), (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if args.in_dataset == "CIFAR-10":
        testset = torchvision.datasets.CIFAR10(root='../../data', train=False, download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                                   shuffle=True, num_workers=2)
        num_classes = 5
    elif args.in_dataset == "CIFAR-100":
        testset = torchvision.datasets.CIFAR100(root='../../data', train=False, download=True,
                                                transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                                   shuffle=True, num_workers=2)
        num_classes = 100

    model1 = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    model2 = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)

    checkpoint1 = torch.load(
        "./checkpoints/{name}/checkpoint_{epochs}.pth.tar".format(name=args.name + '04', epochs=args.epochs))
    model1.load_state_dict(checkpoint1['state_dict'])

    model1.eval()
    model1.cuda()

    checkpoint2 = torch.load(
        "./checkpoints/{name}/checkpoint_{epochs}.pth.tar".format(name=args.name + '59', epochs=args.epochs))
    model2.load_state_dict(checkpoint2['state_dict'])

    model2.eval()
    model2.cuda()

    nat_losses = AverageMeter()
    nat_top1 = AverageMeter()

    # test with pure-noise inputs (disabled)
    # num_gaussian_inputs = 10
    # gaussian_inputs = torch.rand((num_gaussian_inputs, 3, 32, 32))
    # print(gaussian_inputs.size())
    # output = model(gaussian_inputs)
    criterion = nn.CrossEntropyLoss()

    for i, (input, target) in enumerate(testloaderIn):
        target = target.cuda()
        # print(target)

        # Keep two independent GPU copies so each model gets its own input gradient.
        nat_input = input.detach().clone().cuda()
        input1 = Variable(nat_input.clone(), requires_grad=True)
        input2 = Variable(nat_input.clone(), requires_grad=True)

        output1 = model1(input1)
        output2 = model2(input2)
        sm_score1 = F.softmax(output1, dim=1)
        sm_score2 = F.softmax(output2, dim=1)
        # print(output1)
        # print(output2)

        temper = 1000
        noiseMagnitude1 = 0.0014
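        # ODIN-style confidence for each sub-model: temperature-scale the logits,
        # nudge the input against the loss gradient, and take the max softmax
        # probability as the confidence.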

        maxIndexTemp1 = np.argmax(output1.data.cpu().numpy(), axis=1)
        temp_output1 = output1 / temper
        pred1 = Variable(torch.LongTensor(maxIndexTemp1).cuda())
        loss1 = criterion(temp_output1, pred1)
        loss1.backward()
        gradient1 = (torch.ge(input1.grad.data, 0).float() - 0.5) * 2  # sign of the gradient
        tempInput1 = torch.add(input1.data, gradient1, alpha=-noiseMagnitude1)
        temp_output1 = model1(Variable(tempInput1))
        temp_output1 = temp_output1 / temper
        nnOutput1 = temp_output1.data.cpu().numpy()
        nnOutput1 = nnOutput1 - np.max(nnOutput1, axis=1, keepdims=True)
        nnOutput1 = np.exp(nnOutput1) / np.sum(np.exp(nnOutput1), axis=1, keepdims=True)
        odin1 = np.max(nnOutput1, axis=1)

        maxIndexTemp2 = np.argmax(output2.data.cpu().numpy(), axis=1)
        temp_output2 = output2 / temper
        pred2 = Variable(torch.LongTensor(maxIndexTemp2).cuda())
        loss2 = criterion(temp_output2, pred2)
        loss2.backward()
        gradient2 = (torch.ge(input2.grad.data, 0).float() - 0.5) * 2  # sign of the gradient
        tempInput2 = torch.add(input2.data, gradient2, alpha=-noiseMagnitude1)
        temp_output2 = model2(Variable(tempInput2))
        temp_output2 = temp_output2 / temper
        nnOutput2 = temp_output2.data.cpu().numpy()
        nnOutput2 = nnOutput2 - np.max(nnOutput2, axis=1, keepdims=True)
        nnOutput2 = np.exp(nnOutput2) / np.sum(np.exp(nnOutput2), axis=1, keepdims=True)
        odin2 = np.max(nnOutput2, axis=1)
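        # Route each example to the sub-model with the higher ODIN confidence:
        # keep model1's five logits when odin1 > odin2, otherwise model2's.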

        mask = [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0] if x else [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] for x in odin1 > odin2]
        # for x in odin1 > odin2:
        #     if x:
        #         mask.append([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
        #     else:
        #         mask.append([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
        # print(mask)

        # exit(0)

        torch.set_printoptions(precision=8, sci_mode=False)
        # print(output1.size())
        # print(output2.size())

        nat_output = torch.cat((output1, output2), dim=1) * torch.tensor(mask).float().cuda()
        # print(nat_output)
        nat_output = F.softmax(nat_output, dim=1)
        # print(nat_output)
        # exit(0)
        # print(nat_output.size())
        # print(nat_output[:10])
        # print(torch.argmax(torch.cat((output1, output2), dim=1), dim=1))
        # print(torch.argmax(nat_output, dim=1))
        # print(target[:10])
        # exit(0)

        nat_loss = criterion(nat_output, target)

        # measure accuracy and record loss
        nat_prec1 = accuracy(nat_output.data, target, topk=(1,))[0]
        nat_losses.update(nat_loss.data, input.size(0))
        nat_top1.update(nat_prec1, input.size(0))

        # # compute gradient and do SGD step
        # loss = nat_loss
        # if args.lr_scheduler == 'cosine_annealing':
        #     scheduler.step()
        # optimizer.zero_grad()
        # loss.backward()
        # optimizer.step()
        #
        # # measure elapsed time
        # batch_time.update(time.time() - end)
        # end = time.time()

        if i % args.print_freq == 0 or i == len(testloaderIn) - 1:
            print('Epoch: [{0}/{1}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                i, len(testloaderIn),
                loss=nat_losses, top1=nat_top1))
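The loop above only reports running statistics; a final summary line (not part of the original example) could be printed after it:

    print(' * Prec@1 {top1.avg:.3f} Loss {loss.avg:.4f}'.format(top1=nat_top1,
                                                                loss=nat_losses))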
Example #6
File: train.py Project: hunu12/MALCOM
def main(args):
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_idx)
    os.environ["CUDA_DEVICE"] = str(args.gpu_idx)

    np.random.seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.set_device(args.gpu_idx)

    out_file = os.path.join(args.output_dir,
                            '{}_{}.pth'.format(args.net_type, args.dataset))

    # set the transformations for training
    tfs_for_augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    if args.dataset == 'cifar10':
        train_transform = transforms.Compose(tfs_for_augmentation + [
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435,
                                                            0.2616))
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2470, 0.2435, 0.2616)),
        ])
    elif args.dataset == 'cifar100':
        train_transform = transforms.Compose(tfs_for_augmentation + [
            transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564,
                                                            0.2762)),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4865, 0.4409),
                                 (0.2673, 0.2564, 0.2762)),
        ])
    elif args.dataset == 'svhn':
        train_transform = transforms.Compose([transforms.ToTensor()])
        test_transform = transforms.Compose([transforms.ToTensor()])

    # load model
    if args.net_type == 'densenet':
        if args.dataset == 'svhn':
            model = densenet.DenseNet3(100,
                                       args.num_classes,
                                       growth_rate=12,
                                       dropRate=0.2)
        else:
            model = densenet.DenseNet3(100, args.num_classes, growth_rate=12)
    elif args.net_type == 'resnet':
        model = resnet.ResNet34(num_c=args.num_classes)
    elif args.net_type == 'vanilla':
        model = vanilla.VanillaCNN(args.num_classes)
    model.cuda()
    print('load model: ' + args.net_type)

    # load dataset
    print('load target data: ' + args.dataset)
    if args.dataset == 'svhn':
        train_loader, valid_loader = data_utils.get_dataloader(
            args.dataset,
            args.data_root,
            'train',
            train_transform,
            args.batch_size,
            valid_transform=test_transform)
    else:
        train_loader = data_utils.get_dataloader(args.dataset, args.data_root,
                                                 'train', train_transform,
                                                 args.batch_size)
    test_loader = data_utils.get_dataloader(args.dataset, args.data_root,
                                            'test', test_transform,
                                            args.batch_size)

    # define objective and optimizer
    criterion = nn.CrossEntropyLoss()
    if args.net_type == 'densenet' or args.net_type == 'vanilla':
        weight_decay = 1e-4
        milestones = [150, 225]
        gamma = 0.1
    elif args.net_type == 'resnet':
        weight_decay = 5e-4
        milestones = [60, 120, 160]
        gamma = 0.2
    if args.dataset == 'svhn' or args.net_type == 'vanilla':
        milestones = [20, 30]

    optimizer = optim.SGD(model.parameters(),
                          lr=0.1,
                          momentum=0.9,
                          weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones,
                                               gamma=gamma)

    # train
    best_loss = np.inf
    iter_cnt = 0
    for epoch in range(args.num_epochs):
        model.train()
        total, total_loss, total_step = 0, 0, 0
        for _, (data, labels) in enumerate(train_loader):
            data = data.cuda()
            labels = labels.cuda()
            total += data.size(0)

            outputs = model(data)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * data.size(0)
            iter_cnt += 1
            total_step += 1

            if args.dataset == 'svhn' and iter_cnt >= 200:
                valid_loss, _ = evaluation(model, valid_loader, criterion)
                test_loss, acc = evaluation(model, test_loader, criterion)
                print(
                    'Epoch [{:03d}/{:03d}], step [{}/{}] train loss : {:.4f}, valid loss : {:.4f}, test loss : {:.4f}, test acc : {:.2f} %'
                    .format(epoch + 1, args.num_epochs, total_step,
                            len(train_loader), total_loss / total, valid_loss,
                            test_loss, 100 * acc))
                if valid_loss < best_loss:
                    best_loss = valid_loss
                    torch.save(model, out_file)
                iter_cnt = 0
                model.train()

        if args.dataset != 'svhn':
            test_loss, acc = evaluation(model, test_loader, criterion)
            print(
                '[{:03d}/{:03d}] train loss : {:.4f}, test loss : {:.4f}, test acc : {:.2f} %'
                .format(epoch + 1, args.num_epochs, total_loss / total,
                        test_loss, 100 * acc))
            torch.save(model, out_file)

        scheduler.step()
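evaluation comes from the project's utilities; a minimal sketch consistent with how it is used above (returning mean loss and accuracy over a loader):

def evaluation(model, loader, criterion):
    # Average loss and top-1 accuracy over a data loader.
    model.eval()
    total, total_loss, correct = 0, 0.0, 0
    with torch.no_grad():
        for data, labels in loader:
            data, labels = data.cuda(), labels.cuda()
            outputs = model(data)
            total_loss += criterion(outputs, labels).item() * data.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += data.size(0)
    return total_loss / total, correct / total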
Example #7
def tune_mahalanobis_hyperparams():

    print('Tuning hyper-parameters...')
    stypes = ['mahalanobis']

    save_dir = os.path.join('output/mahalanobis_hyperparams/', args.in_dataset, args.name, 'tmp')

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if args.in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))

        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        trainset = torchvision.datasets.CIFAR10('./datasets/cifar10', train=True, download=True, transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10', train=False, download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))

        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        trainset = torchvision.datasets.CIFAR100('./datasets/cifar100', train=True, download=True, transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100', train=False, download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        num_classes = 100

    elif args.in_dataset == "SVHN":

        normalizer = None
        trainloaderIn = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='train',
                                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True)
        testloaderIn = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test',
                                  transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True)

        args.epochs = 20
        num_classes = 10

    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes, widen_factor=args.width, normalizer=normalizer)
    else:
        assert False, 'Unsupported model arch: {}'.format(args.model_arch)

    checkpoint = torch.load("./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(in_dataset=args.in_dataset, name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    # set information about feature extraction
    temp_x = torch.rand(2,3,32,32)
    temp_x = Variable(temp_x).cuda()
    temp_list = model.feature_list(temp_x)[1]
    num_output = len(temp_list)
    feature_list = np.empty(num_output)
    count = 0
    for out in temp_list:
        feature_list[count] = out.size(1)
        count += 1

    print('get sample mean and covariance')
    sample_mean, precision = sample_estimator(model, num_classes, feature_list, trainloaderIn)

    print('train logistic regression model')
    m = 500

    train_in = []
    train_in_label = []
    train_out = []

    val_in = []
    val_in_label = []
    val_out = []

    cnt = 0
    for data, target in testloaderIn:
        data = data.numpy()
        target = target.numpy()
        for x, y in zip(data, target):
            cnt += 1
            if cnt <= m:
                train_in.append(x)
                train_in_label.append(y)
            elif cnt <= 2*m:
                val_in.append(x)
                val_in_label.append(y)

            if cnt == 2*m:
                break
        if cnt == 2*m:
            break

    print('In', len(train_in), len(val_in))

    criterion = nn.CrossEntropyLoss().cuda()
    adv_noise = 0.05
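    # FGSM-style perturbations of in-distribution images serve as surrogate
    # "out" samples for fitting the logistic regressor below.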

    for i in range(int(m/args.batch_size) + 1):
        if i*args.batch_size >= m:
            break
        data = torch.tensor(train_in[i*args.batch_size:min((i+1)*args.batch_size, m)])
        target = torch.tensor(train_in_label[i*args.batch_size:min((i+1)*args.batch_size, m)])
        data = data.cuda()
        target = target.cuda()
        data, target = Variable(data), Variable(target)  # volatile is deprecated and omitted here
        output = model(data)

        model.zero_grad()
        inputs = Variable(data.data, requires_grad=True).cuda()
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()

        gradient = torch.ge(inputs.grad.data, 0)
        gradient = (gradient.float()-0.5)*2

        adv_data = torch.add(input=inputs.data, other=gradient, alpha=adv_noise)
        adv_data = torch.clamp(adv_data, 0.0, 1.0)

        train_out.extend(adv_data.cpu().numpy())

    for i in range(int(m/args.batch_size) + 1):
        if i*args.batch_size >= m:
            break
        data = torch.tensor(val_in[i*args.batch_size:min((i+1)*args.batch_size, m)])
        target = torch.tensor(val_in_label[i*args.batch_size:min((i+1)*args.batch_size, m)])
        data = data.cuda()
        target = target.cuda()
        data, target = Variable(data), Variable(target)  # volatile is deprecated and omitted here
        output = model(data)

        model.zero_grad()
        inputs = Variable(data.data, requires_grad=True).cuda()
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()

        gradient = torch.ge(inputs.grad.data, 0)
        gradient = (gradient.float()-0.5)*2

        adv_data = torch.add(input=inputs.data, other=gradient, alpha=adv_noise)
        adv_data = torch.clamp(adv_data, 0.0, 1.0)

        val_out.extend(adv_data.cpu().numpy())

    print('Out', len(train_out), len(val_out))

    train_lr_data = []
    train_lr_label = []
    train_lr_data.extend(train_in)
    train_lr_label.extend(np.zeros(m))
    train_lr_data.extend(train_out)
    train_lr_label.extend(np.ones(m))
    train_lr_data = torch.tensor(train_lr_data)
    train_lr_label = torch.tensor(train_lr_label)

    best_fpr = 1.1
    best_magnitude = 0.0

    for magnitude in [0.0, 0.01, 0.005, 0.002, 0.0014, 0.001, 0.0005]:
        train_lr_Mahalanobis = []
        total = 0
        for data_index in range(int(np.floor(train_lr_data.size(0) / args.batch_size))):
            data = train_lr_data[total : total + args.batch_size].cuda()
            total += args.batch_size
            Mahalanobis_scores = get_Mahalanobis_score(data, model, num_classes, sample_mean, precision, num_output, magnitude)
            train_lr_Mahalanobis.extend(Mahalanobis_scores)

        train_lr_Mahalanobis = np.asarray(train_lr_Mahalanobis, dtype=np.float32)
        regressor = LogisticRegressionCV(n_jobs=-1).fit(train_lr_Mahalanobis, train_lr_label)

        print('Logistic Regressor params:', regressor.coef_, regressor.intercept_)

        t0 = time.time()
        f1 = open(os.path.join(save_dir, "confidence_mahalanobis_In.txt"), 'w')
        f2 = open(os.path.join(save_dir, "confidence_mahalanobis_Out.txt"), 'w')

    ########################################In-distribution###########################################
        print("Processing in-distribution images")

        count = 0
        for i in range(int(m/args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(val_in[i * args.batch_size : min((i+1) * args.batch_size, m)]).cuda()
            # if j<1000: continue
            batch_size = images.shape[0]
            Mahalanobis_scores = get_Mahalanobis_score(images, model, num_classes, sample_mean, precision, num_output, magnitude)
            confidence_scores = regressor.predict_proba(Mahalanobis_scores)[:, 1]

            for k in range(batch_size):
                f1.write("{}\n".format(-confidence_scores[k]))

            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

    ###################################Out-of-Distributions#####################################
        t0 = time.time()
        print("Processing out-of-distribution images")
        count = 0

        for i in range(int(m/args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(val_out[i * args.batch_size : min((i+1) * args.batch_size, m)]).cuda()
            # if j<1000: continue
            batch_size = images.shape[0]

            Mahalanobis_scores = get_Mahalanobis_score(images, model, num_classes, sample_mean, precision, num_output, magnitude)

            confidence_scores = regressor.predict_proba(Mahalanobis_scores)[:, 1]

            for k in range(batch_size):
                f2.write("{}\n".format(-confidence_scores[k]))

            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

        f1.close()
        f2.close()

        results = metric(save_dir, stypes)
        print_results(results, stypes)
        fpr = results['mahalanobis']['FPR']
        if fpr < best_fpr:
            best_fpr = fpr
            best_magnitude = magnitude
            best_regressor = regressor

    print('Best Logistic Regressor params:', best_regressor.coef_, best_regressor.intercept_)
    print('Best magnitude', best_magnitude)

    return sample_mean, precision, best_regressor, best_magnitude
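A hypothetical way to consume the tuned values downstream, matching the names the mahalanobis branch of eval_ood_detector expects to find in method_args:

sample_mean, precision, regressor, magnitude = tune_mahalanobis_hyperparams()
method_args = {'sample_mean': sample_mean, 'precision': precision,
               'regressor': regressor, 'magnitude': magnitude}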