Code example #1
def get_mahalanobis_score(inputs, model, method_args):
    num_classes = method_args['num_classes']
    sample_mean = method_args['sample_mean']
    precision = method_args['precision']
    magnitude = method_args['magnitude']
    regressor = method_args['regressor']
    num_output = method_args['num_output']

    Mahalanobis_scores = get_Mahalanobis_score(inputs, model, num_classes,
                                               sample_mean, precision,
                                               num_output, magnitude)
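    # The regressor treats class 1 as OOD (OOD samples are labeled 1 when it is
    # fit), so negate that probability: higher scores then mean the input looks
    # more in-distribution.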
    scores = -regressor.predict_proba(Mahalanobis_scores)[:, 1]

    return scores
Code example #2
def tune_mahalanobis_hyperparams():
    def print_tuning_results(results, stypes):
        mtypes = ['FPR', 'DTERR', 'AUROC', 'AUIN', 'AUOUT']

        for stype in stypes:
            print(' OOD detection method: ' + stype)
            for mtype in mtypes:
                print(' {mtype:6s}'.format(mtype=mtype), end='')
            print('\n{val:6.2f}'.format(val=100. * results[stype]['FPR']),
                  end='')
            print(' {val:6.2f}'.format(val=100. * results[stype]['DTERR']),
                  end='')
            print(' {val:6.2f}'.format(val=100. * results[stype]['AUROC']),
                  end='')
            print(' {val:6.2f}'.format(val=100. * results[stype]['AUIN']),
                  end='')
            print(' {val:6.2f}\n'.format(val=100. * results[stype]['AUOUT']),
                  end='')
            print('')

    print('Tuning hyper-parameters...')
    stypes = ['mahalanobis']

    save_dir = os.path.join('output/hyperparams/', args.name, 'tmp')

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    normalizer = transforms.Normalize((125.3 / 255, 123.0 / 255, 113.9 / 255),
                                      (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
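
    # The normalizer is passed to the model constructor below, so the dataloader
    # transform only converts images to tensors.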

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if args.in_dataset == "CIFAR-10":
        trainset = torchvision.datasets.CIFAR10('./datasets/cifar10',
                                                train=True,
                                                download=True,
                                                transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=2)

        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False,
                                               download=True,
                                               transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=2)

        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        trainset = torchvision.datasets.CIFAR100('./datasets/cifar100',
                                                 train=True,
                                                 download=True,
                                                 transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=2)

        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False,
                                                download=True,
                                                transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=2)

        num_classes = 100
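
    # TinyImages supplies auxiliary images that act as the OOD validation set
    # for tuning.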

    valloaderOut = torch.utils.data.DataLoader(
        TinyImages(transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2)

    model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)

    checkpoint = torch.load(
        "./checkpoints/{name}/checkpoint_{epochs}.pth.tar".format(
            name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    # set information about feature extraction
    temp_x = torch.rand(2, 3, 32, 32)
    temp_x = Variable(temp_x).cuda()
    temp_list = model.feature_list(temp_x)[1]
    num_output = len(temp_list)
    feature_list = np.empty(num_output)
    count = 0
    for out in temp_list:
        feature_list[count] = out.size(1)
        count += 1

    print('get sample mean and covariance')
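    # sample_estimator computes per-layer class-conditional feature means and a
    # shared (tied) precision matrix, following the Mahalanobis detector of
    # Lee et al. (2018).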
    sample_mean, precision = sample_estimator(model, num_classes, feature_list,
                                              trainloaderIn)

    print('train logistic regression model')
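    # Fit the regressor on m in-distribution training images (label 0) and
    # m auxiliary OOD images (label 1).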
    m = 1000
    val_in = []
    val_out = []

    cnt = 0
    for data, target in trainloaderIn:
        for x in data:
            val_in.append(x.numpy())
            cnt += 1
            if cnt == m:
                break
        if cnt == m:
            break

    cnt = 0
    for data, target in valloaderOut:
        for x in data:
            val_out.append(x.numpy())
            cnt += 1
            if cnt == m:
                break
        if cnt == m:
            break

    train_lr_data = []
    train_lr_label = []
    train_lr_data.extend(val_in)
    train_lr_label.extend(np.zeros(m))
    train_lr_data.extend(val_out)
    train_lr_label.extend(np.ones(m))
    train_lr_data = torch.tensor(train_lr_data)
    train_lr_label = torch.tensor(train_lr_label)

    best_fpr = 1.1
    best_magnitude = 0.0
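
    # Sweep the input-perturbation magnitude passed to get_Mahalanobis_score and
    # keep the value that yields the lowest FPR on this validation split.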

    for magnitude in np.arange(0, 0.0041, 0.004 / 20):
        train_lr_Mahalanobis = []
        total = 0
        for data_index in range(
                int(np.floor(train_lr_data.size(0) / args.batch_size))):
            data = train_lr_data[total:total + args.batch_size]
            total += args.batch_size
            Mahalanobis_scores = get_Mahalanobis_score(model, data,
                                                       num_classes,
                                                       sample_mean, precision,
                                                       num_output, magnitude)
            train_lr_Mahalanobis.extend(Mahalanobis_scores)

        train_lr_Mahalanobis = np.asarray(train_lr_Mahalanobis,
                                          dtype=np.float32)

        regressor = LogisticRegressionCV().fit(train_lr_Mahalanobis,
                                               train_lr_label)

        print('Logistic Regressor params:', regressor.coef_,
              regressor.intercept_)

        t0 = time.time()
        f1 = open(os.path.join(save_dir, "confidence_mahalanobis_In.txt"), 'w')
        f2 = open(os.path.join(save_dir, "confidence_mahalanobis_Out.txt"),
                  'w')
        ########################################In-distribution###########################################
        print("Processing in-distribution images")

        count = 0
        for i in range(int(m / args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_in[i * args.batch_size:min((i + 1) * args.batch_size, m)])
            # if j<1000: continue
            batch_size = images.shape[0]

            Mahalanobis_scores = get_Mahalanobis_score(model, images,
                                                       num_classes,
                                                       sample_mean, precision,
                                                       num_output, magnitude)

            confidence_scores = regressor.predict_proba(Mahalanobis_scores)[:,
                                                                            1]

            for k in range(batch_size):
                f1.write("{}\n".format(-confidence_scores[k]))

            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                count, m,
                time.time() - t0))
            t0 = time.time()

    ###################################Out-of-Distributions#####################################
        t0 = time.time()
        print("Processing out-of-distribution images")
        count = 0

        for i in range(int(m / args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_out[i * args.batch_size:min((i + 1) * args.batch_size, m)])
            # if j<1000: continue
            batch_size = images.shape[0]

            Mahalanobis_scores = get_Mahalanobis_score(model, images,
                                                       num_classes,
                                                       sample_mean, precision,
                                                       num_output, magnitude)

            confidence_scores = regressor.predict_proba(Mahalanobis_scores)[:,
                                                                            1]

            for k in range(batch_size):
                f2.write("{}\n".format(-confidence_scores[k]))

            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                count, m,
                time.time() - t0))
            t0 = time.time()

        f1.close()
        f2.close()

        results = metric(save_dir, stypes)
        print_tuning_results(results, stypes)
        fpr = results['mahalanobis']['FPR']
        if fpr < best_fpr:
            best_fpr = fpr
            best_magnitude = magnitude
            best_regressor = regressor

    print('Best Logistic Regressor params:', best_regressor.coef_,
          best_regressor.intercept_)
    print('Best magnitude', best_magnitude)

    return sample_mean, precision, best_regressor, best_magnitude
Code example #3
def eval_mahalanobis(sample_mean, precision, regressor, magnitude):
    stypes = ['mahalanobis']

    save_dir = os.path.join('output/ood_scores/', args.out_dataset, args.name,
                            'adv' if args.adv else 'nat')

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    start = time.time()
    #loading data sets

    normalizer = transforms.Normalize((125.3 / 255, 123.0 / 255, 113.9 / 255),
                                      (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if args.in_dataset == "CIFAR-10":
        trainset = torchvision.datasets.CIFAR10('./datasets/cifar10',
                                                train=True,
                                                download=True,
                                                transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=2)

        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False,
                                               download=True,
                                               transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=2)

        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        trainset = torchvision.datasets.CIFAR100('./datasets/cifar100',
                                                 train=True,
                                                 download=True,
                                                 transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=2)

        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False,
                                                download=True,
                                                transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=2)

        num_classes = 100

    model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)

    checkpoint = torch.load(
        "./checkpoints/{name}/checkpoint_{epochs}.pth.tar".format(
            name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    if args.out_dataset == 'SVHN':
        testsetout = svhn.SVHN('datasets/ood_datasets/svhn/',
                               split='test',
                               transform=transforms.ToTensor(),
                               download=False)
        testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=2)
    elif args.out_dataset == 'dtd':
        testsetout = torchvision.datasets.ImageFolder(
            root="datasets/ood_datasets/dtd/images",
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.CenterCrop(32),
                transforms.ToTensor()
            ]))
        testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=2)
    elif args.out_dataset == 'places365':
        testsetout = torchvision.datasets.ImageFolder(
            root="datasets/ood_datasets/places365/test_subset",
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.CenterCrop(32),
                transforms.ToTensor()
            ]))
        testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=2)
    else:
        testsetout = torchvision.datasets.ImageFolder(
            "./datasets/ood_datasets/{}".format(args.out_dataset),
            transform=transform)
        testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=2)

    # set information about feature extraction
    temp_x = torch.rand(2, 3, 32, 32)
    temp_x = Variable(temp_x).cuda()
    temp_list = model.feature_list(temp_x)[1]
    num_output = len(temp_list)

    t0 = time.time()
    f1 = open(os.path.join(save_dir, "confidence_mahalanobis_In.txt"), 'w')
    f2 = open(os.path.join(save_dir, "confidence_mahalanobis_Out.txt"), 'w')
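    # Score at most N test images; iSUN and dtd have fewer than 10000 images.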
    N = 10000
    if args.out_dataset == "iSUN": N = 8925
    if args.out_dataset == "dtd": N = 5640
    ########################################In-distribution###########################################
    print("Processing in-distribution images")
    if args.adv:
        attack = MahalanobisLinfPGDAttack(model,
                                          eps=args.epsilon,
                                          nb_iter=args.iters,
                                          eps_iter=args.iter_size,
                                          rand_init=True,
                                          clip_min=0.,
                                          clip_max=1.,
                                          in_distribution=True,
                                          num_classes=num_classes,
                                          sample_mean=sample_mean,
                                          precision=precision,
                                          num_output=num_output,
                                          regressor=regressor)

    count = 0
    for j, data in enumerate(testloaderIn):

        images, _ = data
        batch_size = images.shape[0]

        if count + batch_size > N:
            images = images[:N - count]
            batch_size = images.shape[0]

        if args.adv:
            inputs = attack.perturb(images)
        else:
            inputs = images

        Mahalanobis_scores = get_Mahalanobis_score(model, inputs, num_classes,
                                                   sample_mean, precision,
                                                   num_output, magnitude)

        confidence_scores = regressor.predict_proba(Mahalanobis_scores)[:, 1]

        for k in range(batch_size):
            f1.write("{}\n".format(-confidence_scores[k]))

        count += batch_size
        print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
            count, N,
            time.time() - t0))
        t0 = time.time()

        if count == N: break


###################################Out-of-Distributions#####################################
    t0 = time.time()
    print("Processing out-of-distribution images")
    if args.adv:
        attack = MahalanobisLinfPGDAttack(model,
                                          eps=args.epsilon,
                                          nb_iter=args.iters,
                                          eps_iter=args.iter_size,
                                          rand_init=True,
                                          clip_min=0.,
                                          clip_max=1.,
                                          in_distribution=False,
                                          num_classes=num_classes,
                                          sample_mean=sample_mean,
                                          precision=precision,
                                          num_output=num_output,
                                          regressor=regressor)

    count = 0

    for j, data in enumerate(testloaderOut):

        images, labels = data
        batch_size = images.shape[0]

        if args.adv:
            inputs = attack.perturb(images)
        else:
            inputs = images

        Mahalanobis_scores = get_Mahalanobis_score(model, inputs, num_classes,
                                                   sample_mean, precision,
                                                   num_output, magnitude)

        confidence_scores = regressor.predict_proba(Mahalanobis_scores)[:, 1]

        for k in range(batch_size):
            f2.write("{}\n".format(-confidence_scores[k]))

        count += batch_size
        print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
            count, N,
            time.time() - t0))
        t0 = time.time()

        if count == N: break

    f1.close()
    f2.close()

    results = metric(save_dir, stypes)

    print_results(results, stypes)
    return
Code example #4
def tune_mahalanobis_hyperparams():

    print('Tuning hyper-parameters...')
    stypes = ['mahalanobis']

    save_dir = os.path.join('output/mahalanobis_hyperparams/', args.in_dataset, args.name, 'tmp')

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if args.in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))

        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        trainset= torchvision.datasets.CIFAR10('./datasets/cifar10', train=True, download=True, transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10', train=False, download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))

        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        trainset= torchvision.datasets.CIFAR100('./datasets/cifar100', train=True, download=True, transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100', train=False, download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        num_classes = 100

    elif args.in_dataset == "SVHN":

        normalizer = None
        trainloaderIn = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='train',
                                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True)
        testloaderIn = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test',
                                  transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True)

        args.epochs = 20
        num_classes = 10

    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes, widen_factor=args.width, normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    checkpoint = torch.load("./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(in_dataset=args.in_dataset, name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    # set information about feature extraction
    temp_x = torch.rand(2,3,32,32)
    temp_x = Variable(temp_x).cuda()
    temp_list = model.feature_list(temp_x)[1]
    num_output = len(temp_list)
    feature_list = np.empty(num_output)
    count = 0
    for out in temp_list:
        feature_list[count] = out.size(1)
        count += 1

    print('get sample mean and covariance')
    sample_mean, precision = sample_estimator(model, num_classes, feature_list, trainloaderIn)

    print('train logistic regression model')
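    # The first m in-distribution test images fit the regressor; the next m are
    # held out to validate each candidate magnitude. Their "OOD" counterparts
    # are generated below by perturbing the same images.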
    m = 500

    train_in = []
    train_in_label = []
    train_out = []

    val_in = []
    val_in_label = []
    val_out = []

    cnt = 0
    for data, target in testloaderIn:
        data = data.numpy()
        target = target.numpy()
        for x, y in zip(data, target):
            cnt += 1
            if cnt <= m:
                train_in.append(x)
                train_in_label.append(y)
            elif cnt <= 2*m:
                val_in.append(x)
                val_in_label.append(y)

            if cnt == 2*m:
                break
        if cnt == 2*m:
            break

    print('In', len(train_in), len(val_in))

    criterion = nn.CrossEntropyLoss().cuda()
    adv_noise = 0.05
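
    # Generate surrogate OOD samples with a single FGSM-style step: move each
    # image by adv_noise in the direction of the loss gradient's sign, then
    # clamp to [0, 1].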

    for i in range(int(m/args.batch_size) + 1):
        if i*args.batch_size >= m:
            break
        data = torch.tensor(train_in[i*args.batch_size:min((i+1)*args.batch_size, m)])
        target = torch.tensor(train_in_label[i*args.batch_size:min((i+1)*args.batch_size, m)])
        data = data.cuda()
        target = target.cuda()

        model.zero_grad()
        inputs = Variable(data.data, requires_grad=True).cuda()
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()

        gradient = torch.ge(inputs.grad.data, 0)
        gradient = (gradient.float()-0.5)*2

        adv_data = torch.add(input=inputs.data, other=gradient, alpha=adv_noise)
        adv_data = torch.clamp(adv_data, 0.0, 1.0)

        train_out.extend(adv_data.cpu().numpy())

    for i in range(int(m/args.batch_size) + 1):
        if i*args.batch_size >= m:
            break
        data = torch.tensor(val_in[i*args.batch_size:min((i+1)*args.batch_size, m)])
        target = torch.tensor(val_in_label[i*args.batch_size:min((i+1)*args.batch_size, m)])
        data = data.cuda()
        target = target.cuda()

        model.zero_grad()
        inputs = Variable(data.data, requires_grad=True).cuda()
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()

        gradient = torch.ge(inputs.grad.data, 0)
        gradient = (gradient.float()-0.5)*2

        adv_data = torch.add(input=inputs.data, other=gradient, alpha=adv_noise)
        adv_data = torch.clamp(adv_data, 0.0, 1.0)

        val_out.extend(adv_data.cpu().numpy())

    print('Out', len(train_out),len(val_out))

    train_lr_data = []
    train_lr_label = []
    train_lr_data.extend(train_in)
    train_lr_label.extend(np.zeros(m))
    train_lr_data.extend(train_out)
    train_lr_label.extend(np.ones(m))
    train_lr_data = torch.tensor(train_lr_data)
    train_lr_label = torch.tensor(train_lr_label)

    best_fpr = 1.1
    best_magnitude = 0.0
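
    # Candidate perturbation magnitudes; the value (and its regressor) with the
    # lowest validation FPR is kept.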

    for magnitude in [0.0, 0.01, 0.005, 0.002, 0.0014, 0.001, 0.0005]:
        train_lr_Mahalanobis = []
        total = 0
        for data_index in range(int(np.floor(train_lr_data.size(0) / args.batch_size))):
            data = train_lr_data[total : total + args.batch_size].cuda()
            total += args.batch_size
            Mahalanobis_scores = get_Mahalanobis_score(data, model, num_classes, sample_mean, precision, num_output, magnitude)
            train_lr_Mahalanobis.extend(Mahalanobis_scores)

        train_lr_Mahalanobis = np.asarray(train_lr_Mahalanobis, dtype=np.float32)
        regressor = LogisticRegressionCV(n_jobs=-1).fit(train_lr_Mahalanobis, train_lr_label)

        print('Logistic Regressor params:', regressor.coef_, regressor.intercept_)

        t0 = time.time()
        f1 = open(os.path.join(save_dir, "confidence_mahalanobis_In.txt"), 'w')
        f2 = open(os.path.join(save_dir, "confidence_mahalanobis_Out.txt"), 'w')

    ########################################In-distribution###########################################
        print("Processing in-distribution images")

        count = 0
        for i in range(int(m/args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(val_in[i * args.batch_size : min((i+1) * args.batch_size, m)]).cuda()
            # if j<1000: continue
            batch_size = images.shape[0]
            Mahalanobis_scores = get_Mahalanobis_score(images, model, num_classes, sample_mean, precision, num_output, magnitude)
            confidence_scores= regressor.predict_proba(Mahalanobis_scores)[:, 1]

            for k in range(batch_size):
                f1.write("{}\n".format(-confidence_scores[k]))

            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

    ###################################Out-of-Distributions#####################################
        t0 = time.time()
        print("Processing out-of-distribution images")
        count = 0

        for i in range(int(m/args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(val_out[i * args.batch_size : min((i+1) * args.batch_size, m)]).cuda()
            # if j<1000: continue
            batch_size = images.shape[0]

            Mahalanobis_scores = get_Mahalanobis_score(images, model, num_classes, sample_mean, precision, num_output, magnitude)

            confidence_scores= regressor.predict_proba(Mahalanobis_scores)[:, 1]

            for k in range(batch_size):
                f2.write("{}\n".format(-confidence_scores[k]))

            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

        f1.close()
        f2.close()

        results = metric(save_dir, stypes)
        print_results(results, stypes)
        fpr = results['mahalanobis']['FPR']
        if fpr < best_fpr:
            best_fpr = fpr
            best_magnitude = magnitude
            best_regressor = regressor

    print('Best Logistic Regressor params:', best_regressor.coef_, best_regressor.intercept_)
    print('Best magnitude', best_magnitude)

    return sample_mean, precision, best_regressor, best_magnitude