Example #1
0
import tikzplotlib
from tqdm import tqdm, trange
import torch.utils.data as data_utils

# Command-line options. --randseed (default 123) is passed to
# np.random.seed(...) further down in this script.
parser = argparse.ArgumentParser()
parser.add_argument('--randseed', type=int, default=123)
args = parser.parse_args()

# Standard CIFAR10 loaders (presumably all 10 classes — confirm in `dl`).
# NOTE(review): these are built before np.random.seed(args.randseed) runs, so
# any NumPy randomness inside `dl` (e.g. the val split) is not controlled by
# --randseed here.
train_loader = dl.CIFAR10(train=True, augm_flag=False)
# NOTE(review): val_loader, test_loader and targets are all reassigned by the
# dl.binary_CIFAR10 section below, so these values look unused. train_loader
# is never reassigned in the visible code, yet load_model() builds a
# 2-class network — confirm the training loader should not also be binary.
val_loader, test_loader = dl.CIFAR10(train=False, val_size=2000)
# Ground-truth labels of the whole test set: concatenate the label tensor of
# every batch and convert the result to a NumPy array.
targets = torch.cat([y for x, y in test_loader], dim=0).numpy()
# Log the dataset sizes of the train / val / test loaders.
print(len(train_loader.dataset), len(val_loader.dataset),
      len(test_loader.dataset))

# Out-of-distribution test sets (cf. the 'CIFAR10 - SVHN' / 'CIFAR10 - LSUN'
# result keys).
# NOTE(review): both names are overwritten below (dl.binary_SVHN and
# dl.LSUN_CR(..., augm_flag=False)) before any visible use; these calls look
# redundant.
test_loader_SVHN = dl.SVHN(train=False)
test_loader_LSUN = dl.LSUN_CR(train=False)

# Empty result accumulators, one independent list per
# "<in-distribution> - <out-of-distribution>" pairing; presumably appended to
# later in the script (not visible here).
tab_ood = {pairing: [] for pairing in (
    'CIFAR10 - CIFAR10',
    'CIFAR10 - SVHN',
    'CIFAR10 - LSUN',
    'CIFAR10 - FarAway',
    'CIFAR10 - Adversarial',
    'CIFAR10 - FarAwayAdv',
)}

# Calibration results: the 'DKL' entry holds a pair of separate empty lists.
tab_cal = dict(DKL=([], []))

delta = 2_000

# Seed NumPy's global RNG from --randseed before building the loaders below
# (presumably so the val split and noise samples in `dl` are reproducible —
# verify). NOTE(review): torch's RNG is not seeded here.
np.random.seed(args.randseed)
# Binary CIFAR10 validation/test loaders; val_size=1000 presumably requests
# 1000 validation samples — confirm in `dl`.
# NOTE(review): class1 and class2 are not defined in the visible code; this
# raises NameError unless they are set earlier. dl.binary_SVHN below
# hard-codes classes 3 and 9 — confirm those should match class1/class2.
val_loader, test_loader = dl.binary_CIFAR10(class1,
                                            class2,
                                            train=False,
                                            augm_flag=False,
                                            val_size=1000)
# Ground-truth labels for the test and validation sets: concatenate every
# batch's label tensor and convert to a NumPy array.
targets = torch.cat([y for x, y in test_loader], dim=0).numpy()
targets_val = torch.cat([y for x, y in val_loader], dim=0).numpy()
# Log dataset sizes. train_loader is still the dl.CIFAR10(train=True) loader
# created earlier; it is not replaced by a binary version.
print(len(train_loader.dataset), len(val_loader.dataset),
      len(test_loader.dataset))

# Binary SVHN out-of-distribution test set; the second returned value is
# discarded.
test_loader_SVHN, _ = dl.binary_SVHN(3,
                                     9,
                                     train=False,
                                     augm_flag=False,
                                     val_size=1000)
test_loader_LSUN = dl.LSUN_CR(train=False, augm_flag=False)

# Uniform-noise loaders for 'CIFAR10' (presumably CIFAR10-shaped random
# inputs, with `size` the number of samples — confirm in `dl`).
ood_loader = dl.UniformNoise('CIFAR10', size=1000)
noise_loader = dl.UniformNoise('CIFAR10', size=2000)


def load_model(path='./pretrained_models/binary_CIFAR10.pt', num_classes=2):
    """Build a ResNet-18 on the GPU, load saved weights, and set eval mode.

    Args:
        path: File containing the saved ``state_dict``. Defaults to the
            binary-CIFAR10 checkpoint this script was written for.
        num_classes: Output size of the classifier head; must match the
            checkpoint (2 for the default binary-CIFAR10 weights).

    Returns:
        The ``resnet.ResNet18`` model, moved to CUDA and in evaluation mode.
    """
    model = resnet.ResNet18(num_classes=num_classes).cuda()
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    state_dict = torch.load(path)
    model.load_state_dict(state_dict)
    # Inference mode: changes the behaviour of layers such as dropout and
    # batch norm, if the network contains them.
    model.eval()
    return model


tab_ood = {
    'CIFAR10 - CIFAR10': [],
    'CIFAR10 - SVHN': [],