def main():
    """Train a (num_classes + 1)-way classifier with NTOM-style outlier mining.

    Each epoch mines informative outliers from an auxiliary OOD pool via
    ``select_ood`` and trains on in-distribution plus mined OOD batches via
    ``train_ntom``.  Configuration comes from the module-level ``args``.
    """
    if args.tensorboard: configure("runs/%s"%(args.name))

    # Standard CIFAR-style augmentation (pad-crop + horizontal flip) if requested.
    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    if args.in_dataset == "CIFAR-10":
        # Data loading code
        # Normalization is handed to the model via `normalizer` rather than
        # applied inside the transform pipeline.
        normalizer = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                          std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./datasets/cifar10', train=True, download=True,
                             transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./datasets/cifar10', train=False,
                             transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        lr_schedule=[50, 75, 90]
        pool_size = args.pool_size
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        # Data loading code
        normalizer = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                          std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./datasets/cifar100', train=True, download=True,
                              transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./datasets/cifar100', train=False,
                              transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        lr_schedule=[50, 75, 90]
        pool_size = args.pool_size
        num_classes = 100
    elif args.in_dataset == "SVHN":
        # Data loading code
        normalizer = None
        train_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='train',
                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test',
                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=False, **kwargs)
        # SVHN uses a shortened schedule; these override the CLI values.
        args.epochs = 20
        args.save_epoch = 2
        lr_schedule=[10, 15, 18]
        # Pool sized so roughly 8x the train set is scanned per mining pass.
        pool_size = int(len(train_loader.dataset) * 8 / args.ood_batch_size) + 1
        num_classes = 10

    # Number of auxiliary OOD samples retained per mining pass.
    ood_dataset_size = len(train_loader.dataset) * 2
    print('OOD Dataset Size: ', ood_dataset_size)

    if args.auxiliary_dataset == '80m_tiny_images':
        # ToTensor -> ToPILImage round-trip lets RandomCrop/Flip run on PIL
        # images coming from the raw TinyImages binary.
        ood_loader = torch.utils.data.DataLoader(
            TinyImages(transform=transforms.Compose(
                [transforms.ToTensor(), transforms.ToPILImage(),
                 transforms.RandomCrop(32, padding=4),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor()])),
            batch_size=args.ood_batch_size, shuffle=False, **kwargs)
    elif args.auxiliary_dataset == 'imagenet':
        ood_loader = torch.utils.data.DataLoader(
            ImageNet(transform=transforms.Compose(
                [transforms.ToTensor(), transforms.ToPILImage(),
                 transforms.RandomCrop(32, padding=4),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor()])),
            batch_size=args.ood_batch_size, shuffle=False, **kwargs)

    # create model
    # The extra (num_classes + 1)-th logit is the OOD/reject class.
    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes + 1, args.growth,
                             reduction=args.reduce, bottleneck=args.bottleneck,
                             dropRate=args.droprate, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes + 1,
                              widen_factor=args.width, dropRate=args.droprate,
                              normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            assert False, "=> no checkpoint found at '{}'".format(args.resume)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    model = model.cuda()
    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)

        # train for one epoch
        # Re-mine the most informative outliers against the current model.
        selected_ood_loader = select_ood(ood_loader, model, args.batch_size * 2,
                                         num_classes, pool_size,
                                         ood_dataset_size, args.quantile)
        train_ntom(train_loader, selected_ood_loader, model, criterion,
                   num_classes, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, num_classes)

        # remember best prec@1 and save checkpoint
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, epoch + 1)
def _mahalanobis_write_scores(loader, model, num_classes, sample_mean, precision,
                              num_output, magnitude, regressor, attack, out_file, N):
    """Score up to ``N`` images from ``loader`` and write one score per line.

    Each batch is optionally perturbed with ``attack`` (None = clean images),
    mapped to per-layer Mahalanobis scores by ``get_Mahalanobis_score``, and
    reduced to a single confidence by ``regressor.predict_proba``.  Scores are
    written negated — the sign convention consumed by ``metric`` downstream.
    """
    t0 = time.time()
    count = 0
    for data in loader:
        images = data[0]
        batch_size = images.shape[0]
        # Trim the final batch so that exactly N images are scored.
        if count + batch_size > N:
            images = images[:N - count]
            batch_size = images.shape[0]
        inputs = attack.perturb(images) if attack is not None else images
        Mahalanobis_scores = get_Mahalanobis_score(model, inputs, num_classes,
                                                   sample_mean, precision,
                                                   num_output, magnitude)
        confidence_scores = regressor.predict_proba(Mahalanobis_scores)[:, 1]
        for k in range(batch_size):
            out_file.write("{}\n".format(-confidence_scores[k]))
        count += batch_size
        print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
            count, N, time.time() - t0))
        t0 = time.time()
        if count == N:
            break


def eval_mahalanobis(sample_mean, precision, regressor, magnitude):
    """Evaluate Mahalanobis-based OOD detection (args.in_dataset vs args.out_dataset).

    Args:
        sample_mean, precision: per-layer Gaussian statistics consumed by
            ``get_Mahalanobis_score``.
        regressor: fitted classifier exposing ``predict_proba`` that combines
            the per-layer Mahalanobis scores into one confidence.
        magnitude: input-preprocessing noise magnitude.

    Writes confidence_mahalanobis_{In,Out}.txt under output/ood_scores/...
    and prints the resulting detection metrics.
    """
    stypes = ['mahalanobis']

    save_dir = os.path.join('output/ood_scores/', args.out_dataset, args.name,
                            'adv' if args.adv else 'nat')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # loading data sets
    normalizer = transforms.Normalize((125.3 / 255, 123.0 / 255, 113.9 / 255),
                                      (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if args.in_dataset == "CIFAR-10":
        # NOTE: the train set/loader is not used below; kept (with download=True)
        # so the data is fetched exactly as before.
        trainset = torchvision.datasets.CIFAR10('./datasets/cifar10', train=True,
                                                download=True, transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True, num_workers=2)
        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10', train=False,
                                               download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True, num_workers=2)
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        # fix: train set previously pointed at './datasets/cifar10'; CIFAR-100
        # lives under './datasets/cifar100' (as used for the test set below).
        trainset = torchvision.datasets.CIFAR100('./datasets/cifar100', train=True,
                                                 download=True, transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True, num_workers=2)
        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100', train=False,
                                                download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True, num_workers=2)
        num_classes = 100

    model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    checkpoint = torch.load(
        "./checkpoints/{name}/checkpoint_{epochs}.pth.tar".format(
            name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.cuda()

    if args.out_dataset == 'SVHN':
        testsetout = svhn.SVHN('datasets/ood_datasets/svhn/', split='test',
                               transform=transforms.ToTensor(), download=False)
        testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                    batch_size=args.batch_size,
                                                    shuffle=True, num_workers=2)
    elif args.out_dataset == 'dtd':
        testsetout = torchvision.datasets.ImageFolder(
            root="datasets/ood_datasets/dtd/images",
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.CenterCrop(32),
                transforms.ToTensor()
            ]))
        testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                    batch_size=args.batch_size,
                                                    shuffle=True, num_workers=2)
    elif args.out_dataset == 'places365':
        testsetout = torchvision.datasets.ImageFolder(
            root="datasets/ood_datasets/places365/test_subset",
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.CenterCrop(32),
                transforms.ToTensor()
            ]))
        testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                    batch_size=args.batch_size,
                                                    shuffle=True, num_workers=2)
    else:
        testsetout = torchvision.datasets.ImageFolder(
            "./datasets/ood_datasets/{}".format(args.out_dataset),
            transform=transform)
        testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                    batch_size=args.batch_size,
                                                    shuffle=True, num_workers=2)

    # set information about feature extraction
    # fix: the probe tensor must live on the same device as the model (it was
    # left on the CPU here while the sibling evaluator moves it with .cuda()).
    temp_x = torch.rand(2, 3, 32, 32)
    temp_x = Variable(temp_x).cuda()
    temp_list = model.feature_list(temp_x)[1]
    num_output = len(temp_list)

    # Number of test images scored per split (dataset-specific test-set sizes).
    N = 10000
    if args.out_dataset == "iSUN":
        N = 8925
    if args.out_dataset == "dtd":
        N = 5640

    # `with` guarantees the score files are flushed/closed even on error.
    with open(os.path.join(save_dir, "confidence_mahalanobis_In.txt"), 'w') as f1, \
         open(os.path.join(save_dir, "confidence_mahalanobis_Out.txt"), 'w') as f2:
        ########################################In-distribution###########################################
        print("Processing in-distribution images")
        attack = None
        if args.adv:
            attack = MahalanobisLinfPGDAttack(model, eps=args.epsilon,
                                              nb_iter=args.iters,
                                              eps_iter=args.iter_size,
                                              rand_init=True, clip_min=0.,
                                              clip_max=1., in_distribution=True,
                                              num_classes=num_classes,
                                              sample_mean=sample_mean,
                                              precision=precision,
                                              num_output=num_output,
                                              regressor=regressor)
        _mahalanobis_write_scores(testloaderIn, model, num_classes, sample_mean,
                                  precision, num_output, magnitude, regressor,
                                  attack, f1, N)

        ###################################Out-of-Distributions#####################################
        print("Processing out-of-distribution images")
        attack = None
        if args.adv:
            attack = MahalanobisLinfPGDAttack(model, eps=args.epsilon,
                                              nb_iter=args.iters,
                                              eps_iter=args.iter_size,
                                              rand_init=True, clip_min=0.,
                                              clip_max=1., in_distribution=False,
                                              num_classes=num_classes,
                                              sample_mean=sample_mean,
                                              precision=precision,
                                              num_output=num_output,
                                              regressor=regressor)
        _mahalanobis_write_scores(testloaderOut, model, num_classes, sample_mean,
                                  precision, num_output, magnitude, regressor,
                                  attack, f2, N)

    results = metric(save_dir, stypes)
    print_results(results, stypes)
    return
def eval_msp_and_odin():
    """Score in- vs out-of-distribution test images with MSP and ODIN.

    Writes four score files (confidence_{MSP,ODIN}_{In,Out}.txt) under
    output/ood_scores/... and prints detection metrics via ``metric``.
    Configuration comes from the module-level ``args``.
    """
    stypes = ['MSP', 'ODIN']

    save_dir = os.path.join('output/ood_scores/', args.out_dataset, args.name,
                            'adv' if args.adv else 'nat')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # NOTE(review): `start` is assigned but never used below.
    start = time.time()
    #loading data sets
    # Normalization is passed to the model constructor, not the transform.
    normalizer = transforms.Normalize((125.3 / 255, 123.0 / 255, 113.9 / 255),
                                      (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if args.in_dataset == "CIFAR-10":
        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10', train=False,
                                               download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True, num_workers=2)
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100', train=False,
                                                download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True, num_workers=2)
        num_classes = 100

    model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    checkpoint = torch.load(
        "./checkpoints/{name}/checkpoint_{epochs}.pth.tar".format(
            name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.cuda()

    if args.out_dataset == 'SVHN':
        testsetout = svhn.SVHN('datasets/ood_datasets/svhn/', split='test',
                               transform=transforms.ToTensor(), download=False)
        testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                    batch_size=args.batch_size,
                                                    shuffle=True, num_workers=2)
    elif args.out_dataset == 'dtd':
        testsetout = torchvision.datasets.ImageFolder(
            root="datasets/ood_datasets/dtd/images",
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.CenterCrop(32),
                transforms.ToTensor()
            ]))
        testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                    batch_size=args.batch_size,
                                                    shuffle=True, num_workers=2)
    elif args.out_dataset == 'places365':
        testsetout = torchvision.datasets.ImageFolder(
            root="datasets/ood_datasets/places365/test_subset",
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.CenterCrop(32),
                transforms.ToTensor()
            ]))
        testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                    batch_size=args.batch_size,
                                                    shuffle=True, num_workers=2)
    else:
        testsetout = torchvision.datasets.ImageFolder(
            "./datasets/ood_datasets/{}".format(args.out_dataset),
            transform=transform)
        testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                    batch_size=args.batch_size,
                                                    shuffle=True, num_workers=2)

    t0 = time.time()
    f1 = open(os.path.join(save_dir, "confidence_MSP_In.txt"), 'w')
    f2 = open(os.path.join(save_dir, "confidence_MSP_Out.txt"), 'w')
    g1 = open(os.path.join(save_dir, "confidence_ODIN_In.txt"), 'w')
    g2 = open(os.path.join(save_dir, "confidence_ODIN_Out.txt"), 'w')

    # Number of test images to score (dataset-specific test-set sizes).
    N = 10000
    if args.out_dataset == "iSUN":
        N = 8925
    if args.out_dataset == "dtd":
        N = 5640

    ########################################In-distribution###########################################
    print("Processing in-distribution images")
    if args.adv:
        attack = ConfidenceLinfPGDAttack(model, eps=args.epsilon, nb_iter=args.iters,
                                         eps_iter=args.iter_size, rand_init=True,
                                         clip_min=0., clip_max=1.,
                                         in_distribution=True,
                                         num_classes=num_classes)
    count = 0
    for j, data in enumerate(testloaderIn):
        images, _ = data
        batch_size = images.shape[0]

        # Trim the final batch so that exactly N images are scored.
        if count + batch_size > N:
            images = images[:N - count]
            batch_size = images.shape[0]

        # requires_grad=True: ODIN back-propagates to the input for its
        # noise-based preprocessing step.
        # NOTE(review): `images` is not moved to the GPU here although the
        # model was `.cuda()`-ed — presumably handled inside MSP/ODIN; confirm.
        if args.adv:
            adv_images = attack.perturb(images)
            inputs = Variable(adv_images, requires_grad=True)
        else:
            inputs = Variable(images, requires_grad=True)

        outputs = model(inputs)

        nnOutputs = MSP(outputs, model)
        for k in range(batch_size):
            f1.write("{}\n".format(np.max(nnOutputs[k])))

        nnOutputs = ODIN(inputs, outputs, model, temper=args.temperature,
                         noiseMagnitude1=args.magnitude)
        for k in range(batch_size):
            g1.write("{}\n".format(np.max(nnOutputs[k])))

        count += batch_size
        print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
            count, N, time.time() - t0))
        t0 = time.time()

        if count == N:
            break

    ###################################Out-of-Distributions#####################################
    t0 = time.time()
    print("Processing out-of-distribution images")
    if args.adv:
        attack = ConfidenceLinfPGDAttack(model, eps=args.epsilon, nb_iter=args.iters,
                                         eps_iter=args.iter_size, rand_init=True,
                                         clip_min=0., clip_max=1.,
                                         in_distribution=False,
                                         num_classes=num_classes)
    count = 0
    for j, data in enumerate(testloaderOut):
        images, labels = data
        batch_size = images.shape[0]

        if args.adv:
            adv_images = attack.perturb(images)
            inputs = Variable(adv_images, requires_grad=True)
        else:
            inputs = Variable(images, requires_grad=True)

        outputs = model(inputs)

        nnOutputs = MSP(outputs, model)
        for k in range(batch_size):
            f2.write("{}\n".format(np.max(nnOutputs[k])))

        nnOutputs = ODIN(inputs, outputs, model, temper=args.temperature,
                         noiseMagnitude1=args.magnitude)
        for k in range(batch_size):
            g2.write("{}\n".format(np.max(nnOutputs[k])))

        count += batch_size
        print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
            count, N, time.time() - t0))
        t0 = time.time()

        if count == N:
            break

    f1.close()
    f2.close()
    g1.close()
    g2.close()

    results = metric(save_dir, stypes)
    print_results(results, stypes)
# Acceleration
parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.')
parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.')
args = parser.parse_args()

# Echo the parsed configuration for reproducibility.
state = {k: v for k, v in args._get_kwargs()}
print(state)

# Fixed seeds so calibration runs are repeatable.
torch.manual_seed(1)
np.random.seed(1)

# In-distribution SVHN data ('train_and_extra' combines the train and extra splits).
train_data_in = svhn.SVHN('/share/data/vision-greg/svhn/', split='train_and_extra',
                          transform=trn.ToTensor(), download=False)
test_data = svhn.SVHN('/share/data/vision-greg/svhn/', split='test',
                      transform=trn.ToTensor(), download=False)
num_classes = 10

calib_indicator = ''
if args.calibration:
    # Hold out 5000 of the 604388 train+extra images for calibration.
    train_data_in, val_data = validation_split(train_data_in, val_share=5000 / 604388.)
    calib_indicator = 'calib_'

tiny_images = TinyImages(transform=trn.Compose([
    trn.ToTensor(),
]))

# NOTE(review): `ood_data` is referenced here before any visible assignment —
# the texture (dtd) dataset definition appears to be missing from this excerpt.
ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.test_bs,
                                         shuffle=True, num_workers=args.prefetch,
                                         pin_memory=True)

print('\n\nTexture Calibration')
get_and_print_results(ood_loader)

# /////////////// SVHN ///////////////
ood_data = svhn.SVHN(root='/share/data/vision-greg/svhn/', split="test",
                     transform=trn.Compose([
                         trn.Resize(32),
                         trn.ToTensor(), trn.Normalize(mean, std)
                     ]), download=False)
ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.test_bs,
                                         shuffle=True, num_workers=args.prefetch,
                                         pin_memory=True)
print('\n\nSVHN Calibration')
get_and_print_results(ood_loader)

# /////////////// Places365 ///////////////
ood_data = dset.ImageFolder(
def main():
    """Train a ROWL-style (num_classes + 1)-way classifier with a targeted PGD attack.

    Uses pre-exported image folders under ./datasets/row_train_data for
    training, trains via ``train_rowl`` and checkpoints every
    ``args.save_epoch`` epochs.  Configuration comes from the module-level
    ``args``.
    """
    if args.tensorboard: configure("runs/%s"%(args.name))

    # Standard CIFAR-style augmentation (pad-crop + horizontal flip) if requested.
    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    if args.in_dataset == "CIFAR-10":
        # Data loading code
        # Training images come from a pre-exported folder tree
        # ("row_train_data"); validation uses the standard CIFAR test split.
        normalizer = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                          std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder('./datasets/row_train_data/CIFAR-10',
                                             transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./datasets/cifar10', train=False,
                             transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        num_classes = 10
        lr_schedule=[50, 75, 90]
    elif args.in_dataset == "CIFAR-100":
        # Data loading code
        normalizer = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                          std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder('./datasets/row_train_data/CIFAR-100',
                                             transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./datasets/cifar100', train=False,
                              transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        num_classes = 100
        lr_schedule=[50, 75, 90]
    elif args.in_dataset == "SVHN":
        # Data loading code
        normalizer = None
        transform = transforms.Compose([transforms.ToTensor(),])
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder('./datasets/row_train_data/SVHN',
                                             transform=transform),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test',
                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=False, **kwargs)
        # SVHN uses a shortened schedule; these override the CLI values.
        args.epochs = 20
        args.save_epoch = 2
        lr_schedule=[10, 15, 18]
        num_classes = 10

    # create model
    # The extra (num_classes + 1)-th logit is the OOD/reject class.
    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes + 1, args.growth,
                             reduction=args.reduce, bottleneck=args.bottleneck,
                             dropRate=args.droprate, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes + 1,
                              widen_factor=args.width, dropRate=args.droprate,
                              normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    # Targeted PGD attack over all num_classes+1 outputs, used during training.
    attack = LinfPGDAttack(model = model, eps=args.epsilon, nb_iter=args.iters,
                           eps_iter=args.iter_size, rand_init=True, targeted=True,
                           num_classes=num_classes+1, loss_func='CE',
                           elementwise_best=True)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    model = model.cuda()
    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)

        # train for one epoch
        train_rowl(train_loader, model, criterion, optimizer, epoch,
                   num_classes, attack)

        # evaluate on validation set
        # NOTE(review): argument order (num_classes, epoch) differs from other
        # trainers in this repo — verify against this file's `validate`.
        prec1 = validate(val_loader, model, criterion, num_classes, epoch)

        # remember best prec@1 and save checkpoint
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, epoch + 1)
        # (continuation of an OOD transform pipeline started above this excerpt)
        trn.ToTensor(), trn.Normalize(mean, std)
    ]))
ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.test_bs,
                                         shuffle=True, num_workers=4,
                                         pin_memory=True)
print('\n\nTexture Detection')
get_and_print_results(ood_loader)

# /////////////// SVHN ///////////////
# cropped and no sampling of the test set
ood_data = svhn.SVHN(
    root='../data/svhn/', split="test",
    transform=trn.Compose([
        #trn.Resize(32),
        trn.ToTensor(), trn.Normalize(mean, std)
    ]), download=False)
ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.test_bs,
                                         shuffle=True, num_workers=2,
                                         pin_memory=True)
print('\n\nSVHN Detection')
get_and_print_results(ood_loader)

# /////////////// Places365 ///////////////
ood_data = dset.ImageFolder(root="../data/places365/",
                            transform=trn.Compose([
                                trn.Resize(32),
def main():
    """Train a CCU-style doubly-robust model (base classifier + in/out GMMs).

    Fits GMMs over the training and TinyImages OOD data (``gen_gmm``), wraps
    the base network in ``gmmlib.DoublyRobustModel`` and trains it with
    ``train_CEDA_gmm_out``.  Configuration comes from the module-level ``args``.
    """
    if args.tensorboard: configure("runs/%s" % (args.name))

    # Standard CIFAR-style augmentation (pad-crop + horizontal flip) if requested.
    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    if args.in_dataset == "CIFAR-10":
        # Data loading code
        # Normalization is handed to the model via `normalizer` rather than
        # applied inside the transform pipeline.
        normalizer = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            './datasets/cifar10', train=True, download=True,
            transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            './datasets/cifar10', train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        lr_schedule = [50, 75, 90]
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        # Data loading code
        normalizer = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR100(
            './datasets/cifar100', train=True, download=True,
            transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(datasets.CIFAR100(
            './datasets/cifar100', train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        lr_schedule = [50, 75, 90]
        num_classes = 100
    elif args.in_dataset == "SVHN":
        # Data loading code
        normalizer = None
        train_loader = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/', split='train', transform=transforms.ToTensor(),
            download=False),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/', split='test', transform=transforms.ToTensor(),
            download=False),
            batch_size=args.batch_size, shuffle=False, **kwargs)
        # SVHN uses a shortened schedule; these override the CLI values.
        args.epochs = 20
        args.save_epoch = 2
        lr_schedule = [10, 15, 18]
        num_classes = 10

    # Auxiliary OOD pool; the ToTensor -> ToPILImage round-trip lets the PIL
    # augmentations run on the raw TinyImages data.
    out_loader = torch.utils.data.DataLoader(
        TinyImages(transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])),
        batch_size=args.ood_batch_size, shuffle=False, **kwargs)

    # create model
    if args.model_arch == 'densenet':
        base_model = dn.DenseNet3(args.layers, num_classes, args.growth,
                                  reduction=args.reduce,
                                  bottleneck=args.bottleneck,
                                  dropRate=args.droprate,
                                  normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        base_model = wn.WideResNet(args.depth, num_classes,
                                   widen_factor=args.width,
                                   dropRate=args.droprate,
                                   normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    # Fit and reload the in- and out-of-distribution GMMs; the mixture weights
    # (alpha) are frozen while the means/variances stay trainable.
    gen_gmm(train_loader, out_loader, data_used=50000, PCA=True, N=[100])
    gmm = torch.load("checkpoints/{in_dataset}/{name}/".format(
        in_dataset=args.in_dataset, name=args.name) + 'in_gmm.pth.tar')
    gmm.alpha = nn.Parameter(gmm.alpha)
    gmm.mu.requires_grad = True
    gmm.logvar.requires_grad = True
    gmm.alpha.requires_grad = False

    gmm_out = torch.load("checkpoints/{in_dataset}/{name}/".format(
        in_dataset=args.in_dataset, name=args.name) + 'out_gmm.pth.tar')
    # fix: previously `nn.Parameter(gmm.alpha)` — a copy-paste slip that
    # overwrote the out-GMM's mixture weights with the in-GMM's.
    gmm_out.alpha = nn.Parameter(gmm_out.alpha)
    gmm_out.mu.requires_grad = True
    gmm_out.logvar.requires_grad = True
    gmm_out.alpha.requires_grad = False

    loglam = 0.
    model = gmmlib.DoublyRobustModel(base_model, gmm, gmm_out, loglam,
                                     dim=3072, classes=num_classes).cuda()
    model.loglam.requires_grad = False

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    # GMM parameters get a tiny LR and no weight decay; the base network uses
    # the CLI learning rate and weight decay.
    lr = args.lr
    lr_gmm = 1e-5
    param_groups = [{
        'params': model.mm.parameters(),
        'lr': lr_gmm,
        'weight_decay': 0.
    }, {
        'params': model.mm_out.parameters(),
        'lr': lr_gmm,
        'weight_decay': 0.
    }, {
        'params': model.base_model.parameters(),
        'lr': lr,
        'weight_decay': args.weight_decay
    }]
    optimizer = torch.optim.SGD(param_groups,
                                momentum=args.momentum,
                                nesterov=True)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)

        # train for one epoch
        lam = model.loglam.data.exp().item()
        train_CEDA_gmm_out(model, train_loader, out_loader, optimizer, epoch,
                           lam=lam)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                }, epoch + 1)
def tune_odin_hyperparams():
    """Grid-search the ODIN input-noise magnitude and return the best value.

    Scores 1000 in-distribution test images against 1000 TinyImages samples
    for each magnitude in [0, 0.004] (21 steps) at temperature 1000, and
    returns the magnitude with the lowest FPR reported by ``metric``.
    """
    print('Tuning hyper-parameters...')
    stypes = ['ODIN']

    save_dir = os.path.join('output/odin_hyperparams/', args.in_dataset,
                            args.name, 'tmp')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if args.in_dataset == "CIFAR-10":
        # Normalization is handed to the model via `normalizer`.
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        trainset = torchvision.datasets.CIFAR10('./datasets/cifar10',
                                                train=True, download=True,
                                                transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True)
        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False, download=True,
                                               transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        trainset = torchvision.datasets.CIFAR100('./datasets/cifar100',
                                                 train=True, download=True,
                                                 transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True)
        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False, download=True,
                                                transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)
        num_classes = 100
    elif args.in_dataset == "SVHN":
        normalizer = None
        trainloaderIn = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/', split='train', transform=transforms.ToTensor(),
            download=False),
            batch_size=args.batch_size, shuffle=True)
        testloaderIn = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/', split='test', transform=transforms.ToTensor(),
            download=False),
            batch_size=args.batch_size, shuffle=True)
        args.epochs = 20
        num_classes = 10

    # Validation OOD pool: TinyImages with the usual augmentation round-trip.
    valloaderOut = torch.utils.data.DataLoader(
        TinyImages(transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])),
        batch_size=args.batch_size, shuffle=False)
    # Random starting offset into the (unshuffled) TinyImages stream.
    valloaderOut.dataset.offset = np.random.randint(len(valloaderOut.dataset))

    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes,
                              widen_factor=args.width, normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    checkpoint = torch.load(
        "./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(
            in_dataset=args.in_dataset, name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.cuda()

    # Collect m held-out images from each side for the grid search.
    m = 1000
    val_in = []
    val_out = []

    cnt = 0
    for data, target in testloaderIn:
        for x in data:
            val_in.append(x.numpy())
            cnt += 1
            if cnt == m:
                break
        if cnt == m:
            break

    cnt = 0
    for data, target in valloaderOut:
        for x in data:
            val_out.append(x.numpy())
            cnt += 1
            if cnt == m:
                break
        if cnt == m:
            break

    print('Len of val in: ', len(val_in))
    print('Len of val out: ', len(val_out))

    # FPR is in [0, 1], so 1.1 guarantees the first result is accepted.
    best_fpr = 1.1
    best_magnitude = 0.0

    # 21-point grid over [0, 0.004]; the score files are rewritten per step.
    for magnitude in np.arange(0, 0.0041, 0.004 / 20):
        t0 = time.time()
        f1 = open(os.path.join(save_dir, "confidence_ODIN_In.txt"), 'w')
        f2 = open(os.path.join(save_dir, "confidence_ODIN_Out.txt"), 'w')
        ########################################In-distribution###########################################
        print("Processing in-distribution images")

        count = 0
        for i in range(int(m / args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_in[i * args.batch_size:min((i + 1) * args.batch_size, m)])
            images = images.cuda()
            # if j<1000: continue
            batch_size = images.shape[0]

            # ODIN scoring at fixed temperature 1000 (only magnitude is tuned).
            scores = get_odin_score(images, model, temper=1000,
                                    noiseMagnitude1=magnitude)
            for k in range(batch_size):
                f1.write("{}\n".format(scores[k]))

            count += batch_size
            # print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

        ###################################Out-of-Distributions#####################################
        t0 = time.time()
        print("Processing out-of-distribution images")
        count = 0
        for i in range(int(m / args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_out[i * args.batch_size:min((i + 1) * args.batch_size, m)])
            images = images.cuda()
            # if j<1000: continue
            batch_size = images.shape[0]

            scores = get_odin_score(images, model, temper=1000,
                                    noiseMagnitude1=magnitude)
            for k in range(batch_size):
                f2.write("{}\n".format(scores[k]))

            count += batch_size
            # print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

        f1.close()
        f2.close()

        # Evaluate this magnitude and keep the one with the lowest FPR.
        results = metric(save_dir, stypes)
        print_results(results, stypes)
        fpr = results['ODIN']['FPR']
        if fpr < best_fpr:
            best_fpr = fpr
            best_magnitude = magnitude

    return best_magnitude
def _wrap_ccu(model, num_classes):
    """Wrap ``model`` with the in/out GMM pair of the CCU doubly-robust model.

    Loads the two fitted GMMs saved next to the checkpoint and re-registers
    their mixture weights as learnable parameters.
    """
    prefix = "checkpoints/{in_dataset}/{name}/".format(
        in_dataset=args.in_dataset, name=args.name)
    gmm = torch.load(prefix + 'in_gmm.pth.tar')
    gmm.alpha = nn.Parameter(gmm.alpha)
    gmm_out = torch.load(prefix + 'out_gmm.pth.tar')
    # BUG FIX: the original assigned nn.Parameter(gmm.alpha) here, silently
    # overwriting the out-GMM's mixture weights with the in-GMM's.
    gmm_out.alpha = nn.Parameter(gmm_out.alpha)
    return gmmlib.DoublyRobustModel(model, gmm, gmm_out, loglam=0.,
                                    dim=3072, classes=num_classes)


def eval_ood_detector(base_dir, in_dataset, out_datasets, batch_size, method,
                      method_args, name, epochs, adv, corrupt, adv_corrupt,
                      adv_args, mode_args):
    """Score in-distribution and OOD test images with a trained detector.

    Writes one score per sample (and, for the in-distribution set, the label,
    prediction and confidence) to text files under ``base_dir`` so a separate
    metric step can compute FPR/AUROC.

    Args:
        base_dir: root output directory.
        in_dataset: in-distribution dataset name
            ("CIFAR-10", "CIFAR-100" or "SVHN").
        out_datasets: iterable of OOD dataset names to evaluate.
        batch_size: evaluation batch size.
        method: scoring method ("msp", "odin", "mahalanobis", "sofl",
            "rowl", "atom", "ntom", ...).
        method_args: dict of method-specific parameters; this function fills
            in 'num_classes' (and 'num_output' for mahalanobis).
        name, epochs: identify the checkpoint to load.
        adv, corrupt, adv_corrupt: evaluate OOD data under a PGD attack,
            under corruptions, or under corruptions followed by an attack.
        adv_args: attack parameters ('epsilon', 'iters', 'iter_size',
            'severity_level').
        mode_args: flags 'in_dist_only' / 'out_dist_only' to skip one half
            of the evaluation.

    Also reads the module-level ``args`` namespace for the model architecture.
    """
    # The output directory encodes the evaluation regime.
    if adv:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name, 'adv',
                                   str(int(adv_args['epsilon'])))
    elif adv_corrupt:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name,
                                   'adv_corrupt', str(int(adv_args['epsilon'])))
    elif corrupt:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name,
                                   'corrupt')
    else:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name, 'nat')
    os.makedirs(in_save_dir, exist_ok=True)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False, download=True,
                                               transform=transform)
        num_classes = 10
        num_reject_classes = 5
    elif in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False, download=True,
                                                transform=transform)
        num_classes = 100
        num_reject_classes = 10
    elif in_dataset == "SVHN":
        normalizer = None
        testset = svhn.SVHN('datasets/svhn/', split='test',
                            transform=transforms.ToTensor(), download=False)
        num_classes = 10
        num_reject_classes = 5
    else:
        assert False, 'Unsupported in-distribution dataset: {}'.format(in_dataset)
    testloaderIn = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                               shuffle=True, num_workers=2)

    # Only SOFL keeps the dataset-specific reject-class count; ROWL/ATOM/NTOM
    # reserve exactly one extra class, every other method uses none.
    if method != "sofl":
        num_reject_classes = 0
    if method in ("rowl", "atom", "ntom"):
        num_reject_classes = 1
    method_args['num_classes'] = num_classes

    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes + num_reject_classes,
                             normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes + num_reject_classes,
                              widen_factor=args.width, normalizer=normalizer)
    elif args.model_arch == 'densenet_ccu':
        model = dn.DenseNet3(args.layers, num_classes + num_reject_classes,
                             normalizer=normalizer)
        whole_model = _wrap_ccu(model, num_classes)
    elif args.model_arch == 'wideresnet_ccu':
        model = wn.WideResNet(args.depth, num_classes + num_reject_classes,
                              widen_factor=args.width, normalizer=normalizer)
        whole_model = _wrap_ccu(model, num_classes)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    checkpoint = torch.load(
        "./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(
            in_dataset=in_dataset, name=name, epochs=epochs))
    # The CCU wrapper owns the GMM parameters, so its state dict is the one
    # that was saved; otherwise the backbone's state dict is loaded directly.
    if args.model_arch in ('densenet_ccu', 'wideresnet_ccu'):
        whole_model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.cuda()

    if method == "mahalanobis":
        # Probe once with a dummy batch to count the intermediate feature
        # maps; the Mahalanobis score needs one score per layer.
        temp_x = Variable(torch.rand(2, 3, 32, 32)).cuda()
        num_output = len(model.feature_list(temp_x)[1])
        method_args['num_output'] = num_output

    if adv or adv_corrupt:
        epsilon = adv_args['epsilon']
        iters = adv_args['iters']
        iter_size = adv_args['iter_size']
        if method == "msp" or method == "odin":
            # BUG FIX: eps_iter previously read the global args.iter_size
            # instead of the adv_args value unpacked above.
            attack_out = ConfidenceLinfPGDAttack(
                model, eps=epsilon, nb_iter=iters, eps_iter=iter_size,
                rand_init=True, clip_min=0., clip_max=1.,
                num_classes=num_classes)
        elif method == "mahalanobis":
            # BUG FIX: eps/nb_iter previously read the globals args.epsilon /
            # args.iters instead of the adv_args values unpacked above.
            # NOTE(review): sample_mean / precision / regressor are not
            # defined anywhere in this function; presumably they belong in
            # method_args — confirm against the caller before using this path.
            attack_out = MahalanobisLinfPGDAttack(
                model, eps=epsilon, nb_iter=iters, eps_iter=iter_size,
                rand_init=True, clip_min=0., clip_max=1.,
                num_classes=num_classes, sample_mean=sample_mean,
                precision=precision, num_output=num_output,
                regressor=regressor)
        elif method == "sofl":
            attack_out = SOFLLinfPGDAttack(
                model, eps=epsilon, nb_iter=iters, eps_iter=iter_size,
                rand_init=True, clip_min=0., clip_max=1.,
                num_classes=num_classes,
                num_reject_classes=num_reject_classes)
        elif method in ("rowl", "atom", "ntom"):
            attack_out = OODScoreLinfPGDAttack(
                model, eps=epsilon, nb_iter=iters, eps_iter=iter_size,
                rand_init=True, clip_min=0., clip_max=1.,
                num_classes=num_classes)

    if not mode_args['out_dist_only']:
        t0 = time.time()
        # 'with' guarantees the score/label files are closed even if scoring
        # raises part-way through (the original leaked them on error).
        with open(os.path.join(in_save_dir, "in_scores.txt"), 'w') as f1, \
             open(os.path.join(in_save_dir, "in_labels.txt"), 'w') as g1:
            ####################### In-distribution #######################
            print("Processing in-distribution images")
            N = len(testloaderIn.dataset)
            count = 0
            for images, labels in testloaderIn:
                images = images.cuda()
                labels = labels.cuda()
                curr_batch_size = images.shape[0]
                inputs = images
                scores = get_score(inputs, model, method, method_args)
                for score in scores:
                    f1.write("{}\n".format(score))
                logits = model(inputs)
                if method != "rowl":
                    # Drop any reject logits before taking the softmax.
                    logits = logits[:, :num_classes]
                outputs = F.softmax(logits, dim=1).detach().cpu().numpy()
                preds = np.argmax(outputs, axis=1)
                confs = np.max(outputs, axis=1)
                for k in range(preds.shape[0]):
                    g1.write("{} {} {}\n".format(labels[k], preds[k], confs[k]))
                count += curr_batch_size
                print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                    count, N, time.time() - t0))
                t0 = time.time()

    if mode_args['in_dist_only']:
        return

    for out_dataset in out_datasets:
        out_save_dir = os.path.join(in_save_dir, out_dataset)
        # (the original checked/created this directory twice)
        os.makedirs(out_save_dir, exist_ok=True)

        if out_dataset == 'SVHN':
            testsetout = svhn.SVHN('datasets/ood_datasets/svhn/', split='test',
                                   transform=transforms.ToTensor(),
                                   download=False)
        elif out_dataset == 'dtd':
            testsetout = torchvision.datasets.ImageFolder(
                root="datasets/ood_datasets/dtd/images",
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.CenterCrop(32),
                    transforms.ToTensor()
                ]))
        elif out_dataset == 'places365':
            testsetout = torchvision.datasets.ImageFolder(
                root="datasets/ood_datasets/places365/test_subset",
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.CenterCrop(32),
                    transforms.ToTensor()
                ]))
        else:
            testsetout = torchvision.datasets.ImageFolder(
                "./datasets/ood_datasets/{}".format(out_dataset),
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.CenterCrop(32),
                    transforms.ToTensor()
                ]))
        testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                    batch_size=batch_size,
                                                    shuffle=True,
                                                    num_workers=2)

        ####################### Out-of-distribution #######################
        t0 = time.time()
        print("Processing out-of-distribution images")
        N = len(testloaderOut.dataset)
        count = 0
        with open(os.path.join(out_save_dir, "out_scores.txt"), 'w') as f2:
            for images, labels in testloaderOut:
                images = images.cuda()
                labels = labels.cuda()
                curr_batch_size = images.shape[0]
                if adv:
                    inputs = attack_out.perturb(images)
                elif corrupt:
                    inputs = corrupt_attack(images, model, method, method_args,
                                            False, adv_args['severity_level'])
                elif adv_corrupt:
                    corrupted_images = corrupt_attack(
                        images, model, method, method_args, False,
                        adv_args['severity_level'])
                    inputs = attack_out.perturb(corrupted_images)
                else:
                    inputs = images
                scores = get_score(inputs, model, method, method_args)
                for score in scores:
                    f2.write("{}\n".format(score))
                count += curr_batch_size
                print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                    count, N, time.time() - t0))
                t0 = time.time()
    return
def tune_mahalanobis_hyperparams():
    """Tune the Mahalanobis OOD detector's hyper-parameters.

    Pipeline: estimate class-conditional feature means and precision matrices
    on the training set; carve a small train/val split out of the test set;
    generate FGSM-style adversarial copies of those images to serve as proxy
    OOD samples; fit a logistic regressor over the per-layer Mahalanobis
    scores; grid-search the input-preprocessing magnitude for lowest FPR.

    Returns:
        (sample_mean, precision, best_regressor, best_magnitude)

    Reads the module-level ``args`` namespace for dataset/model configuration
    and writes intermediate confidence files under
    output/mahalanobis_hyperparams/<in_dataset>/<name>/tmp.
    """
    print('Tuning hyper-parameters...')
    stypes = ['mahalanobis']

    # Scratch directory for the per-magnitude confidence dumps read by metric().
    save_dir = os.path.join('output/mahalanobis_hyperparams/', args.in_dataset, args.name, 'tmp')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Build train/test loaders for the chosen in-distribution dataset.
    if args.in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize((125.3/255, 123.0/255, 113.9/255),
                                          (63.0/255, 62.1/255.0, 66.7/255.0))
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        trainset = torchvision.datasets.CIFAR10('./datasets/cifar10', train=True, download=True, transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)
        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10', train=False, download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, num_workers=2)
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize((125.3/255, 123.0/255, 113.9/255),
                                          (63.0/255, 62.1/255.0, 66.7/255.0))
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        trainset = torchvision.datasets.CIFAR100('./datasets/cifar100', train=True, download=True, transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)
        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100', train=False, download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, num_workers=2)
        num_classes = 100
    elif args.in_dataset == "SVHN":
        normalizer = None
        trainloaderIn = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='train',
                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True)
        testloaderIn = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test',
                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True)
        args.epochs = 20
        num_classes = 10

    # Instantiate the backbone and load the trained checkpoint.
    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes, widen_factor=args.width, normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    checkpoint = torch.load("./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(in_dataset=args.in_dataset, name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.cuda()

    # set information about feature extaction
    # Probe with a dummy batch to learn how many intermediate feature maps
    # the network exposes and how many channels each one has.
    temp_x = torch.rand(2,3,32,32)
    temp_x = Variable(temp_x).cuda()
    temp_list = model.feature_list(temp_x)[1]
    num_output = len(temp_list)
    feature_list = np.empty(num_output)
    count = 0
    for out in temp_list:
        feature_list[count] = out.size(1)
        count += 1

    print('get sample mean and covariance')
    # Class-conditional Gaussian parameters per feature layer.
    sample_mean, precision = sample_estimator(model, num_classes, feature_list, trainloaderIn)

    print('train logistic regression model')
    m = 500  # number of images in each of the train/val splits

    # First m test images train the regressor; the next m are held out for
    # the magnitude grid search below.
    train_in = []
    train_in_label = []
    train_out = []

    val_in = []
    val_in_label = []
    val_out = []

    cnt = 0
    for data, target in testloaderIn:
        data = data.numpy()
        target = target.numpy()
        for x, y in zip(data, target):
            cnt += 1
            if cnt <= m:
                train_in.append(x)
                train_in_label.append(y)
            elif cnt <= 2*m:
                val_in.append(x)
                val_in_label.append(y)
            if cnt == 2*m:
                break
        if cnt == 2*m:
            break

    print('In', len(train_in), len(val_in))

    criterion = nn.CrossEntropyLoss().cuda()
    adv_noise = 0.05  # FGSM step size for generating proxy OOD samples

    # Generate proxy OOD training samples: one signed-gradient (FGSM-style)
    # step away from each in-distribution training image, clamped to [0, 1].
    for i in range(int(m/args.batch_size) + 1):
        if i*args.batch_size >= m:
            break
        data = torch.tensor(train_in[i*args.batch_size:min((i+1)*args.batch_size, m)])
        target = torch.tensor(train_in_label[i*args.batch_size:min((i+1)*args.batch_size, m)])
        data = data.cuda()
        target = target.cuda()
        # NOTE(review): volatile=True is ignored on PyTorch >= 0.4, and this
        # forward result is immediately overwritten below.
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)

        model.zero_grad()
        inputs = Variable(data.data, requires_grad=True).cuda()
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()

        # Signed gradient mapped to {-1, +1}.
        gradient = torch.ge(inputs.grad.data, 0)
        gradient = (gradient.float()-0.5)*2

        adv_data = torch.add(input=inputs.data, other=gradient, alpha=adv_noise)
        adv_data = torch.clamp(adv_data, 0.0, 1.0)

        train_out.extend(adv_data.cpu().numpy())

    # Same perturbation applied to the held-out split -> validation OOD set.
    for i in range(int(m/args.batch_size) + 1):
        if i*args.batch_size >= m:
            break
        data = torch.tensor(val_in[i*args.batch_size:min((i+1)*args.batch_size, m)])
        target = torch.tensor(val_in_label[i*args.batch_size:min((i+1)*args.batch_size, m)])
        data = data.cuda()
        target = target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)

        model.zero_grad()
        inputs = Variable(data.data, requires_grad=True).cuda()
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()

        gradient = torch.ge(inputs.grad.data, 0)
        gradient = (gradient.float()-0.5)*2

        adv_data = torch.add(input=inputs.data, other=gradient, alpha=adv_noise)
        adv_data = torch.clamp(adv_data, 0.0, 1.0)

        val_out.extend(adv_data.cpu().numpy())

    print('Out', len(train_out),len(val_out))

    # Regressor training set: label 0 = in-distribution, 1 = proxy OOD.
    train_lr_data = []
    train_lr_label = []
    train_lr_data.extend(train_in)
    train_lr_label.extend(np.zeros(m))
    train_lr_data.extend(train_out)
    train_lr_label.extend(np.ones(m))
    train_lr_data = torch.tensor(train_lr_data)
    train_lr_label = torch.tensor(train_lr_label)

    best_fpr = 1.1  # > any achievable FPR, so the first candidate always wins
    best_magnitude = 0.0

    # Grid search over the input-preprocessing (noise) magnitude.
    for magnitude in [0.0, 0.01, 0.005, 0.002, 0.0014, 0.001, 0.0005]:
        # Per-layer Mahalanobis scores for the regressor's training set.
        train_lr_Mahalanobis = []
        total = 0
        for data_index in range(int(np.floor(train_lr_data.size(0) / args.batch_size))):
            data = train_lr_data[total : total + args.batch_size].cuda()
            total += args.batch_size
            Mahalanobis_scores = get_Mahalanobis_score(data, model, num_classes, sample_mean, precision, num_output, magnitude)
            train_lr_Mahalanobis.extend(Mahalanobis_scores)

        train_lr_Mahalanobis = np.asarray(train_lr_Mahalanobis, dtype=np.float32)
        regressor = LogisticRegressionCV(n_jobs=-1).fit(train_lr_Mahalanobis, train_lr_label)

        print('Logistic Regressor params:', regressor.coef_, regressor.intercept_)

        t0 = time.time()
        f1 = open(os.path.join(save_dir, "confidence_mahalanobis_In.txt"), 'w')
        f2 = open(os.path.join(save_dir, "confidence_mahalanobis_Out.txt"), 'w')

        ########################################In-distribution###########################################
        print("Processing in-distribution images")
        count = 0
        for i in range(int(m/args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(val_in[i * args.batch_size : min((i+1) * args.batch_size, m)]).cuda()
            # if j<1000: continue
            batch_size = images.shape[0]

            Mahalanobis_scores = get_Mahalanobis_score(images, model, num_classes, sample_mean, precision, num_output, magnitude)
            # predict_proba[:, 1] is the OOD probability; negated so that
            # larger written scores mean more in-distribution.
            confidence_scores = regressor.predict_proba(Mahalanobis_scores)[:, 1]

            for k in range(batch_size):
                f1.write("{}\n".format(-confidence_scores[k]))

            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

        ###################################Out-of-Distributions#####################################
        t0 = time.time()
        print("Processing out-of-distribution images")
        count = 0
        for i in range(int(m/args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(val_out[i * args.batch_size : min((i+1) * args.batch_size, m)]).cuda()
            # if j<1000: continue
            batch_size = images.shape[0]

            Mahalanobis_scores = get_Mahalanobis_score(images, model, num_classes, sample_mean, precision, num_output, magnitude)
            confidence_scores = regressor.predict_proba(Mahalanobis_scores)[:, 1]

            for k in range(batch_size):
                f2.write("{}\n".format(-confidence_scores[k]))

            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

        f1.close()
        f2.close()

        # metric() reads the two confidence files written above.
        results = metric(save_dir, stypes)
        print_results(results, stypes)
        fpr = results['mahalanobis']['FPR']
        if fpr < best_fpr:
            best_fpr = fpr
            best_magnitude = magnitude
            best_regressor = regressor

    print('Best Logistic Regressor params:', best_regressor.coef_, best_regressor.intercept_)
    print('Best magnitude', best_magnitude)

    return sample_mean, precision, best_regressor, best_magnitude
torchvision.utils.save_image(images[0], os.path.join(save_dir, '100', '%d.png'%i)) if i + 1 == 500: break save_dir = "datasets/rowl_train_data/SVHN" if not os.path.exists(save_dir): os.makedirs(save_dir) for i in range(11): os.makedirs(os.path.join(save_dir, '%02d'%i)) class_count = np.zeros(10, dtype=np.int32) train_loader = torch.utils.data.DataLoader( svhn.SVHN('datasets/svhn/', split='train', transform=transform, download=False), batch_size=1, shuffle=True) for (xs, ys) in train_loader: image = xs[0] label = ys[0].numpy() torchvision.utils.save_image(image, os.path.join(save_dir, '%02d'%label, '%d.png'%class_count[label])) class_count[label] += 1 ood_loader.dataset.offset = np.random.randint(len(ood_loader.dataset)) print('SVHN, Class count: ', class_count) ood_count = int(np.mean(class_count)) print('SVHN, OOD count: ', ood_count) for i, (images, labels) in enumerate(ood_loader):