def cifar():
    """Build the CIFAR-10 test loader plus the white-box (WideResNet) and
    black-box (MobileNet) models, then run the transfer attack.

    Relies on module-level `torch`, `attack`, `epsilon` and `attack_num`.
    """
    import torchvision
    import os

    # Per-channel normalization statistics used by the pretrained models.
    mean = torch.Tensor([0.471, 0.448, 0.408])
    std = torch.Tensor([0.234, 0.239, 0.242])
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, std),
    ])

    data = torchvision.datasets.CIFAR10(root="../data", train=False,
                                        download=False, transform=transform)
    loader = torch.utils.data.DataLoader(data, batch_size=32,
                                         shuffle=False, drop_last=True)

    import models.wideresnet as models
    white_model = models.WideResNet(num_classes=10).cuda()

    import models.mobilenet as BlackModel
    black_model = BlackModel.MobileNet().cuda()
    black_model.load_state_dict(torch.load("black_model/mobilenet.p")["net"])
    # black_model.eval()

    checkpoint = torch.load(
        os.path.join('wideresnet_vs_mobilenet/result_1000', "model_best.pth.tar"))
    white_model.load_state_dict(checkpoint['state_dict'])
    white_model.eval()

    trans = attack(False, white_model, black_model, loader, epsilon,
                   attack_num, "cifar", True)
def __init__(self, model, ema_model, dataset, alpha=0.999):
    """Couple `model` with an exponential-moving-average copy `ema_model`.

    `dataset` ('cifar10' or 'cifar100') sizes the scratch model used during
    EMA updates; raises NotImplementedError for anything else.
    """
    self.model = model
    self.ema_model = ema_model
    self.alpha = alpha
    # Scratch network sized to the dataset's class count.
    if dataset == 'cifar10':
        n_classes = 10
    elif dataset == 'cifar100':
        n_classes = 100
    else:
        raise NotImplementedError
    self.tmp_model = models.WideResNet(num_classes=n_classes).cuda()
    # Weight decay tied to the learning rate (applied by the EMA step).
    self.wd = 0.02 * args.lr
    # Initialize the EMA weights to match the online model exactly.
    for src, dst in zip(self.model.parameters(), self.ema_model.parameters()):
        dst.data.copy_(src.data)
def create_model(ema=False):
    """Instantiate a 10-class WideResNet on the GPU.

    When `ema` is True the parameters are detached so they are only ever
    updated manually (never by an optimizer).
    """
    net = models.WideResNet(num_classes=10).cuda()
    if ema:
        for p in net.parameters():
            p.detach_()
    return net
def __init__(self, model, ema_model, alpha=0.999):
    """Couple `model` with an exponential-moving-average copy `ema_model`."""
    self.model = model
    self.ema_model = ema_model
    self.alpha = alpha
    # Scratch 10-class network used during EMA updates.
    self.tmp_model = models.WideResNet(num_classes=10).cuda()
    # Weight decay tied to the learning rate (applied by the EMA step).
    self.wd = 0.02 * args.lr
    # Start the EMA weights equal to the online model's weights.
    for src, dst in zip(self.model.parameters(), self.ema_model.parameters()):
        dst.data.copy_(src.data)
def create_model(ema=False):
    """Build a 10-class WideResNet on the GPU.

    For the EMA (exponential moving average) copy, all parameters are
    detached so gradients never flow into them.
    """
    net = models.WideResNet(num_classes=10).cuda()
    if ema:
        # EMA weights are maintained out-of-graph.
        for p in net.parameters():
            p.detach_()
    return net
def create_model(ema=False):
    """Create a DataParallel WideResNet (moved to GPU when `use_cuda`).

    When `ema` is True the parameters are detached for manual EMA updates.
    """
    net = nn.DataParallel(models.WideResNet(num_classes=num_classes))
    if use_cuda:
        net = net.cuda()
    if ema:
        for p in net.parameters():
            p.detach_()
    return net
def create_model(dataset, ema=False):
    """Build a CUDA WideResNet sized for `dataset`.

    Args:
        dataset: dataset name; 'cifar10' -> 10 classes, 'cifar100' -> 100.
        ema: when True, detach all parameters so they are only updated
            manually (EMA copy), never by an optimizer.

    Returns:
        The CUDA-resident WideResNet.

    Raises:
        NotImplementedError: for any unsupported dataset name.
    """
    # BUG FIX: the original consulted the global `args.dataset` and ignored
    # the `dataset` parameter entirely; honor the parameter instead.
    if dataset == 'cifar10':
        num_classes = 10
    elif dataset == 'cifar100':
        num_classes = 100
    else:
        raise NotImplementedError
    model = models.WideResNet(num_classes=num_classes)
    model = model.cuda()
    if ema:
        for param in model.parameters():
            param.detach_()
    return model
def main():
    """Train a classifier with one extra reject output (num_classes + 1) using
    a targeted PGD attack for adversarial training (`train_rowl`).

    All configuration is read from the global `args`; a checkpoint is written
    every `args.save_epoch` epochs.
    """
    if args.tensorboard: configure("runs/%s"%(args.name))
    # Standard CIFAR crop/flip augmentation when requested.
    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    kwargs = {'num_workers': 1, 'pin_memory': True}
    if args.in_dataset == "CIFAR-10":
        # Data loading code.  Normalization is handed to the model via
        # `normalizer` rather than applied in the transform pipeline.
        normalizer = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                          std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        # NOTE(review): training images come from an ImageFolder dump while
        # validation uses the torchvision dataset — confirm both share labels.
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder('./datasets/row_train_data/CIFAR-10', transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./datasets/cifar10', train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        num_classes = 10
        lr_schedule=[50, 75, 90]
    elif args.in_dataset == "CIFAR-100":
        # Data loading code
        normalizer = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                          std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder('./datasets/row_train_data/CIFAR-100', transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./datasets/cifar100', train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        num_classes = 100
        lr_schedule=[50, 75, 90]
    elif args.in_dataset == "SVHN":
        # Data loading code — SVHN trains for a shorter, fixed schedule.
        normalizer = None
        transform = transforms.Compose([transforms.ToTensor(),])
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder('./datasets/row_train_data/SVHN', transform=transform),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test', transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=False, **kwargs)
        args.epochs = 20
        args.save_epoch = 2
        lr_schedule=[10, 15, 18]
        num_classes = 10

    # create model — the extra "+ 1" logit is the reject/OOD class.
    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes + 1, args.growth,
                             reduction=args.reduce, bottleneck=args.bottleneck,
                             dropRate=args.droprate, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes + 1,
                              widen_factor=args.width, dropRate=args.droprate,
                              normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    # Targeted PGD attack used inside the training loop; it targets the
    # full label space including the reject class (num_classes + 1).
    attack = LinfPGDAttack(model = model, eps=args.epsilon, nb_iter=args.iters,
                           eps_iter=args.iter_size, rand_init=True,
                           targeted=True, num_classes=num_classes+1,
                           loss_func='CE', elementwise_best=True)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    model = model.cuda()
    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum, nesterov=True,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)
        # train for one epoch
        train_rowl(train_loader, model, criterion, optimizer, epoch,
                   num_classes, attack)
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, num_classes, epoch)
        # remember best prec@1 and save checkpoint
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, epoch + 1)
def main():
    """Train a doubly-robust (CCU) model: a base classifier plus one GMM fit
    on the in-distribution data and one fit on TinyImages (out-distribution).

    All configuration comes from the global `args`; a checkpoint is written
    every `args.save_epoch` epochs.

    Fix vs. the original: the out-distribution GMM's mixture weights were
    re-initialized from the *in*-distribution GMM (`nn.Parameter(gmm.alpha)`);
    they are now wrapped from its own tensor.
    """
    if args.tensorboard:
        configure("runs/%s" % (args.name))

    # Optional crop/flip augmentation for the training split.
    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    if args.in_dataset == "CIFAR-10":
        # Data loading code
        normalizer = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            './datasets/cifar10', train=True, download=True,
            transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            './datasets/cifar10', train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        lr_schedule = [50, 75, 90]
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        # Data loading code
        normalizer = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR100(
            './datasets/cifar100', train=True, download=True,
            transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(datasets.CIFAR100(
            './datasets/cifar100', train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        lr_schedule = [50, 75, 90]
        num_classes = 100
    elif args.in_dataset == "SVHN":
        # Data loading code — SVHN uses a shorter, fixed schedule.
        normalizer = None
        train_loader = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/', split='train',
            transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/', split='test',
            transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=False, **kwargs)
        args.epochs = 20
        args.save_epoch = 2
        lr_schedule = [10, 15, 18]
        num_classes = 10

    # TinyImages serves as the out-distribution for the second GMM.
    out_loader = torch.utils.data.DataLoader(
        TinyImages(transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])),
        batch_size=args.ood_batch_size, shuffle=False, **kwargs)

    # create model
    if args.model_arch == 'densenet':
        base_model = dn.DenseNet3(args.layers, num_classes, args.growth,
                                  reduction=args.reduce,
                                  bottleneck=args.bottleneck,
                                  dropRate=args.droprate,
                                  normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        base_model = wn.WideResNet(args.depth, num_classes,
                                   widen_factor=args.width,
                                   dropRate=args.droprate,
                                   normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    # Fit and reload the two GMMs; mixture weights (alpha) stay frozen while
    # the means and log-variances are trained.
    gen_gmm(train_loader, out_loader, data_used=50000, PCA=True, N=[100])
    gmm = torch.load("checkpoints/{in_dataset}/{name}/".format(
        in_dataset=args.in_dataset, name=args.name) + 'in_gmm.pth.tar')
    gmm.alpha = nn.Parameter(gmm.alpha)
    gmm.mu.requires_grad = True
    gmm.logvar.requires_grad = True
    gmm.alpha.requires_grad = False

    gmm_out = torch.load("checkpoints/{in_dataset}/{name}/".format(
        in_dataset=args.in_dataset, name=args.name) + 'out_gmm.pth.tar')
    # BUG FIX: wrap the out-GMM's own alpha (the original used gmm.alpha,
    # silently replacing the out-GMM's mixture weights with the in-GMM's).
    gmm_out.alpha = nn.Parameter(gmm_out.alpha)
    gmm_out.mu.requires_grad = True
    gmm_out.logvar.requires_grad = True
    gmm_out.alpha.requires_grad = False

    loglam = 0.
    model = gmmlib.DoublyRobustModel(base_model, gmm, gmm_out, loglam,
                                     dim=3072, classes=num_classes).cuda()
    model.loglam.requires_grad = False

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer — the GMM parameters
    # train with a much smaller learning rate and no weight decay.
    lr = args.lr
    lr_gmm = 1e-5
    param_groups = [{
        'params': model.mm.parameters(),
        'lr': lr_gmm,
        'weight_decay': 0.
    }, {
        'params': model.mm_out.parameters(),
        'lr': lr_gmm,
        'weight_decay': 0.
    }, {
        'params': model.base_model.parameters(),
        'lr': lr,
        'weight_decay': args.weight_decay
    }]
    optimizer = torch.optim.SGD(param_groups,
                                momentum=args.momentum,
                                nesterov=True)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)
        # train for one epoch
        lam = model.loglam.data.exp().item()
        train_CEDA_gmm_out(model, train_loader, out_loader, optimizer, epoch,
                           lam=lam)
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)
        # remember best prec@1 and save checkpoint
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                }, epoch + 1)
def __init__(self, root='', train=True, meta=True, num_meta=1000,
             corruption_prob=0, corruption_type='unif', transform=None,
             target_transform=None, download=False, seed=1):
    """CIFAR dataset with a clean held-out "meta" split and optional label noise.

    The 50k training images are split per class into `num_meta` clean meta
    examples and a remaining training split whose labels are corrupted
    according to `corruption_type` with probability `corruption_prob`.

    NOTE(review): the source had its indentation collapsed; this
    reconstruction assumes the corruption logic applies only to the
    non-meta training split — confirm against the original layout.
    """
    self.root = root
    self.transform = transform
    self.target_transform = target_transform
    self.train = train  # training set or test set
    self.meta = meta
    self.corruption_prob = corruption_prob
    self.num_meta = num_meta

    if download:
        self.download()

    if not self._check_integrity():
        raise RuntimeError('Dataset not found or corrupted.' +
                           ' You can use download=True to download it')

    # now load the picked numpy arrays
    if self.train:
        self.train_data = []
        self.train_labels = []
        self.train_coarse_labels = []
        for fentry in self.train_list:
            f = fentry[0]
            file = os.path.join(root, self.base_folder, f)
            fo = open(file, 'rb')
            if sys.version_info[0] == 2:
                entry = pickle.load(fo)
            else:
                entry = pickle.load(fo, encoding='latin1')
            self.train_data.append(entry['data'])
            # CIFAR-10 batches carry 'labels'; CIFAR-100 carries fine + coarse labels.
            if 'labels' in entry:
                self.train_labels += entry['labels']
                img_num_list = [int(self.num_meta/10)] * 10
                num_classes = 10
            else:
                self.train_labels += entry['fine_labels']
                self.train_coarse_labels += entry['coarse_labels']
                img_num_list = [int(self.num_meta/100)] * 100
                num_classes = 100
            fo.close()

        self.train_data = np.concatenate(self.train_data)
        self.train_data = self.train_data.reshape((50000, 3, 32, 32))
        self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC

        # Index of every example belonging to each class.
        data_list_val = {}
        for j in range(num_classes):
            data_list_val[j] = [i for i, label in enumerate(self.train_labels) if label == j]

        idx_to_meta = []
        idx_to_train = []
        print(img_num_list)
        # Class-balanced random split into meta vs. train indices.
        for cls_idx, img_id_list in data_list_val.items():
            np.random.shuffle(img_id_list)
            img_num = img_num_list[int(cls_idx)]
            idx_to_meta.extend(img_id_list[:img_num])
            idx_to_train.extend(img_id_list[img_num:])

        if meta is True:
            # Clean meta split: labels are kept as-is.
            self.train_data = self.train_data[idx_to_meta]
            self.train_labels = list(np.array(self.train_labels)[idx_to_meta])
        else:
            self.train_data = self.train_data[idx_to_train]
            self.train_labels = list(np.array(self.train_labels)[idx_to_train])
            if corruption_type == 'hierarchical':
                # NOTE(review): this indexes with idx_to_meta although we are
                # in the train split — looks like it should be idx_to_train;
                # confirm before relying on hierarchical corruption.
                self.train_coarse_labels = list(np.array(self.train_coarse_labels)[idx_to_meta])

            # Build the corruption matrix C: row = true label,
            # columns = probabilities of the (possibly flipped) new label.
            if corruption_type == 'unif':
                C = uniform_mix_C(self.corruption_prob, num_classes)
                print(C)
                self.C = C
            elif corruption_type == 'flip':
                C = flip_labels_C(self.corruption_prob, num_classes)
                print(C)
                self.C = C
            elif corruption_type == 'flip2':
                C = flip_labels_C_two(self.corruption_prob, num_classes)
                print(C)
                self.C = C
            elif corruption_type == 'hierarchical':
                assert num_classes == 100, 'You must use CIFAR-100 with the hierarchical corruption.'
                # Recover the coarse->fine class mapping from the labels.
                coarse_fine = []
                for i in range(20):
                    coarse_fine.append(set())
                for i in range(len(self.train_labels)):
                    coarse_fine[self.train_coarse_labels[i]].add(self.train_labels[i])
                for i in range(20):
                    coarse_fine[i] = list(coarse_fine[i])
                # Spread corruption_prob mass over sibling fine classes only.
                C = np.eye(num_classes) * (1 - corruption_prob)
                for i in range(20):
                    tmp = np.copy(coarse_fine[i])
                    for j in range(len(tmp)):
                        tmp2 = np.delete(np.copy(tmp), j)
                        C[tmp[j], tmp2] += corruption_prob * 1/len(tmp2)
                self.C = C
                print(C)
            elif corruption_type == 'clabels':
                # Relabel using a pretrained network's tempered softmax.
                net = wrn.WideResNet(40, num_classes, 2, dropRate=0.3).cuda()
                model_name = './cifar{}_labeler'.format(num_classes)
                net.load_state_dict(torch.load(model_name))
                net.eval()
            else:
                # NOTE(review): the literal braces in this message are invalid
                # .format fields, so reaching this branch would raise from
                # .format itself rather than the assert — worth fixing.
                assert False, "Invalid corruption type '{}' given. Must be in {'unif', 'flip', 'hierarchical'}".format(corruption_type)

            np.random.seed(seed)
            if corruption_type == 'clabels':
                mean = [x / 255 for x in [125.3, 123.0, 113.9]]
                std = [x / 255 for x in [63.0, 62.1, 66.7]]
                test_transform = transforms.Compose(
                    [transforms.ToTensor(), transforms.Normalize(mean, std)])
                # obtain sampling probabilities
                sampling_probs = []
                print('Starting labeling')
                for i in range((len(self.train_labels) // 64) + 1):
                    current = self.train_data[i*64:(i+1)*64]
                    current = [Image.fromarray(current[i]) for i in range(len(current))]
                    current = torch.cat([test_transform(current[i]).unsqueeze(0) for i in range(len(current))], dim=0)
                    data = V(current).cuda()
                    logits = net(data)
                    smax = F.softmax(logits / 5)  # softmax at temperature 5
                    sampling_probs.append(smax.data.cpu().numpy())
                sampling_probs = np.concatenate(sampling_probs, 0)
                print('Finished labeling 1')
                new_labeling_correct = 0
                argmax_labeling_correct = 0
                # Sample each new label from the labeler's distribution and
                # track agreement with the original labels.
                for i in range(len(self.train_labels)):
                    old_label = self.train_labels[i]
                    new_label = np.random.choice(num_classes, p=sampling_probs[i])
                    self.train_labels[i] = new_label
                    if old_label == new_label:
                        new_labeling_correct += 1
                    if old_label == np.argmax(sampling_probs[i]):
                        argmax_labeling_correct += 1
                print('Finished labeling 2')
                print('New labeling accuracy:', new_labeling_correct / len(self.train_labels))
                print('Argmax labeling accuracy:', argmax_labeling_correct / len(self.train_labels))
            else:
                # Resample each label from its corruption-matrix row.
                for i in range(len(self.train_labels)):
                    self.train_labels[i] = np.random.choice(num_classes, p=C[self.train_labels[i]])
                self.corruption_matrix = C
    else:
        # Test split: single batch file, labels left clean.
        f = self.test_list[0][0]
        file = os.path.join(root, self.base_folder, f)
        fo = open(file, 'rb')
        if sys.version_info[0] == 2:
            entry = pickle.load(fo)
        else:
            entry = pickle.load(fo, encoding='latin1')
        self.test_data = entry['data']
        if 'labels' in entry:
            self.test_labels = entry['labels']
        else:
            self.test_labels = entry['fine_labels']
        fo.close()
        self.test_data = self.test_data.reshape((10000, 3, 32, 32))
        self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
# Imports for the FreeSound semi-supervised training script.
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import models.wideresnet as models
import dataset.freesound_X as dataset
from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig, lwlrap_accumulator, load_checkpoint
from tensorboardX import SummaryWriter
from fastai.basic_data import *
from fastai.basic_train import *
from fastai.train import *
from train import SemiLoss

# 80-way WideResNet trained with a multi-label (BCE) objective.
model = models.WideResNet(num_classes=80)
# get_freesound() returns labeled/unlabeled/val/test splits plus metadata.
train_labeled_set, train_unlabeled_set, val_set, test_set, train_unlabeled_warmstart_set, num_classes, pos_weights = dataset.get_freesound(
)
labeled_trainloader = data.DataLoader(train_labeled_set,
                                      batch_size=4,
                                      shuffle=True,
                                      num_workers=0,
                                      drop_last=True)
val_loader = data.DataLoader(val_set,
                             batch_size=4,
                             shuffle=False,
                             num_workers=0)
# Multi-label loss; Adam with default hyper-parameters.
train_criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters())
def eval_ood_detector(base_dir, in_dataset, out_datasets, batch_size, method,
                      method_args, name, epochs, adv, corrupt, adv_corrupt,
                      adv_args, mode_args):
    """Score an OOD detector on in-distribution and out-of-distribution data.

    Writes per-image scores (and in-distribution labels/predictions) under
    `base_dir/<in_dataset>/<method>/<name>/<regime>/`, where the regime is
    'nat', 'corrupt', 'adv', or 'adv_corrupt'.

    Fixes vs. the original:
      * attack constructors now consistently use the local `epsilon`,
        `iters`, `iter_size` extracted from `adv_args` (the original mixed
        in the global `args.epsilon` / `args.iters` / `args.iter_size`);
      * the out-distribution GMM wraps its *own* alpha tensor (the original
        copied the in-distribution GMM's alpha);
      * removed a duplicated os.makedirs call.
    """
    # Output directory encodes the evaluation regime.
    if adv:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name, 'adv',
                                   str(int(adv_args['epsilon'])))
    elif adv_corrupt:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name,
                                   'adv_corrupt',
                                   str(int(adv_args['epsilon'])))
    elif corrupt:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name,
                                   'corrupt')
    else:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name, 'nat')
    if not os.path.exists(in_save_dir):
        os.makedirs(in_save_dir)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False, download=True,
                                               transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=2)
        num_classes = 10
        num_reject_classes = 5
    elif in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False, download=True,
                                                transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=2)
        num_classes = 100
        num_reject_classes = 10
    elif in_dataset == "SVHN":
        normalizer = None
        testset = svhn.SVHN('datasets/svhn/', split='test',
                            transform=transforms.ToTensor(), download=False)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=2)
        num_classes = 10
        num_reject_classes = 5

    # Only SOFL keeps its multiple reject logits; ROWL/ATOM/NTOM use one.
    if method != "sofl":
        num_reject_classes = 0
    if method == "rowl" or method == "atom" or method == "ntom":
        num_reject_classes = 1
    method_args['num_classes'] = num_classes

    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes + num_reject_classes,
                             normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes + num_reject_classes,
                              widen_factor=args.width, normalizer=normalizer)
    elif args.model_arch == 'densenet_ccu':
        model = dn.DenseNet3(args.layers, num_classes + num_reject_classes,
                             normalizer=normalizer)
        gmm = torch.load("checkpoints/{in_dataset}/{name}/".format(
            in_dataset=args.in_dataset, name=args.name) + 'in_gmm.pth.tar')
        gmm.alpha = nn.Parameter(gmm.alpha)
        gmm_out = torch.load("checkpoints/{in_dataset}/{name}/".format(
            in_dataset=args.in_dataset, name=args.name) + 'out_gmm.pth.tar')
        # BUG FIX: wrap the out-GMM's own alpha (the original used gmm.alpha).
        gmm_out.alpha = nn.Parameter(gmm_out.alpha)
        whole_model = gmmlib.DoublyRobustModel(model, gmm, gmm_out, loglam=0.,
                                               dim=3072, classes=num_classes)
    elif args.model_arch == 'wideresnet_ccu':
        model = wn.WideResNet(args.depth, num_classes + num_reject_classes,
                              widen_factor=args.width, normalizer=normalizer)
        gmm = torch.load("checkpoints/{in_dataset}/{name}/".format(
            in_dataset=args.in_dataset, name=args.name) + 'in_gmm.pth.tar')
        gmm.alpha = nn.Parameter(gmm.alpha)
        gmm_out = torch.load("checkpoints/{in_dataset}/{name}/".format(
            in_dataset=args.in_dataset, name=args.name) + 'out_gmm.pth.tar')
        # BUG FIX: wrap the out-GMM's own alpha (the original used gmm.alpha).
        gmm_out.alpha = nn.Parameter(gmm_out.alpha)
        whole_model = gmmlib.DoublyRobustModel(model, gmm, gmm_out, loglam=0.,
                                               dim=3072, classes=num_classes)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    checkpoint = torch.load(
        "./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(
            in_dataset=in_dataset, name=name, epochs=epochs))
    if args.model_arch == 'densenet_ccu' or args.model_arch == 'wideresnet_ccu':
        whole_model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.cuda()

    if method == "mahalanobis":
        # Probe with a dummy batch to count the feature maps the attack needs.
        temp_x = torch.rand(2, 3, 32, 32)
        temp_x = Variable(temp_x).cuda()
        temp_list = model.feature_list(temp_x)[1]
        num_output = len(temp_list)
        method_args['num_output'] = num_output

    if adv or adv_corrupt:
        epsilon = adv_args['epsilon']
        iters = adv_args['iters']
        iter_size = adv_args['iter_size']
        if method == "msp" or method == "odin":
            attack_out = ConfidenceLinfPGDAttack(model, eps=epsilon,
                                                 nb_iter=iters,
                                                 eps_iter=iter_size,
                                                 rand_init=True,
                                                 clip_min=0., clip_max=1.,
                                                 num_classes=num_classes)
        elif method == "mahalanobis":
            # NOTE(review): sample_mean/precision/regressor are expected to
            # be in scope here (loaded elsewhere); confirm at call sites.
            attack_out = MahalanobisLinfPGDAttack(
                model, eps=epsilon, nb_iter=iters, eps_iter=iter_size,
                rand_init=True, clip_min=0., clip_max=1.,
                num_classes=num_classes, sample_mean=sample_mean,
                precision=precision, num_output=num_output,
                regressor=regressor)
        elif method == "sofl":
            attack_out = SOFLLinfPGDAttack(
                model, eps=epsilon, nb_iter=iters, eps_iter=iter_size,
                rand_init=True, clip_min=0., clip_max=1.,
                num_classes=num_classes,
                num_reject_classes=num_reject_classes)
        elif method == "rowl":
            attack_out = OODScoreLinfPGDAttack(model, eps=epsilon,
                                               nb_iter=iters,
                                               eps_iter=iter_size,
                                               rand_init=True,
                                               clip_min=0., clip_max=1.,
                                               num_classes=num_classes)
        elif method == "atom" or method == "ntom":
            attack_out = OODScoreLinfPGDAttack(model, eps=epsilon,
                                               nb_iter=iters,
                                               eps_iter=iter_size,
                                               rand_init=True,
                                               clip_min=0., clip_max=1.,
                                               num_classes=num_classes)

    if not mode_args['out_dist_only']:
        t0 = time.time()
        f1 = open(os.path.join(in_save_dir, "in_scores.txt"), 'w')
        g1 = open(os.path.join(in_save_dir, "in_labels.txt"), 'w')
        ########################################In-distribution###########################################
        print("Processing in-distribution images")
        N = len(testloaderIn.dataset)
        count = 0
        for j, data in enumerate(testloaderIn):
            images, labels = data
            images = images.cuda()
            labels = labels.cuda()
            curr_batch_size = images.shape[0]
            inputs = images
            scores = get_score(inputs, model, method, method_args)
            for score in scores:
                f1.write("{}\n".format(score))
            if method == "rowl":
                outputs = F.softmax(model(inputs), dim=1)
                outputs = outputs.detach().cpu().numpy()
                preds = np.argmax(outputs, axis=1)
                confs = np.max(outputs, axis=1)
            else:
                # Reject logits (if any) are excluded from the prediction.
                outputs = F.softmax(model(inputs)[:, :num_classes], dim=1)
                outputs = outputs.detach().cpu().numpy()
                preds = np.argmax(outputs, axis=1)
                confs = np.max(outputs, axis=1)
            for k in range(preds.shape[0]):
                g1.write("{} {} {}\n".format(labels[k], preds[k], confs[k]))
            count += curr_batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                count, N, time.time() - t0))
            t0 = time.time()
        f1.close()
        g1.close()

    if mode_args['in_dist_only']:
        return

    for out_dataset in out_datasets:
        out_save_dir = os.path.join(in_save_dir, out_dataset)
        if not os.path.exists(out_save_dir):
            os.makedirs(out_save_dir)
        f2 = open(os.path.join(out_save_dir, "out_scores.txt"), 'w')

        if out_dataset == 'SVHN':
            testsetout = svhn.SVHN('datasets/ood_datasets/svhn/', split='test',
                                   transform=transforms.ToTensor(),
                                   download=False)
            testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        num_workers=2)
        elif out_dataset == 'dtd':
            testsetout = torchvision.datasets.ImageFolder(
                root="datasets/ood_datasets/dtd/images",
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.CenterCrop(32),
                    transforms.ToTensor()
                ]))
            testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        num_workers=2)
        elif out_dataset == 'places365':
            testsetout = torchvision.datasets.ImageFolder(
                root="datasets/ood_datasets/places365/test_subset",
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.CenterCrop(32),
                    transforms.ToTensor()
                ]))
            testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        num_workers=2)
        else:
            testsetout = torchvision.datasets.ImageFolder(
                "./datasets/ood_datasets/{}".format(out_dataset),
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.CenterCrop(32),
                    transforms.ToTensor()
                ]))
            testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        num_workers=2)

        ###################################Out-of-Distributions#####################################
        t0 = time.time()
        print("Processing out-of-distribution images")
        N = len(testloaderOut.dataset)
        count = 0
        for j, data in enumerate(testloaderOut):
            images, labels = data
            images = images.cuda()
            labels = labels.cuda()
            curr_batch_size = images.shape[0]
            if adv:
                inputs = attack_out.perturb(images)
            elif corrupt:
                inputs = corrupt_attack(images, model, method, method_args,
                                        False, adv_args['severity_level'])
            elif adv_corrupt:
                corrupted_images = corrupt_attack(images, model, method,
                                                  method_args, False,
                                                  adv_args['severity_level'])
                inputs = attack_out.perturb(corrupted_images)
            else:
                inputs = images
            scores = get_score(inputs, model, method, method_args)
            for score in scores:
                f2.write("{}\n".format(score))
            count += curr_batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                count, N, time.time() - t0))
            t0 = time.time()
        f2.close()
    return
def tune_odin_hyperparams():
    """Grid-search the ODIN input-perturbation magnitude.

    Uses m=1000 in-distribution test images and 1000 TinyImages samples as a
    validation out-distribution, and returns the magnitude with the lowest
    FPR as reported by `metric`.
    """
    print('Tuning hyper-parameters...')
    stypes = ['ODIN']

    save_dir = os.path.join('output/odin_hyperparams/', args.in_dataset,
                            args.name, 'tmp')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if args.in_dataset == "CIFAR-10":
        # Normalization is handed to the model via `normalizer`.
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        trainset = torchvision.datasets.CIFAR10('./datasets/cifar10',
                                                train=True, download=True,
                                                transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True)
        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False, download=True,
                                               transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        trainset = torchvision.datasets.CIFAR100('./datasets/cifar100',
                                                 train=True, download=True,
                                                 transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True)
        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False, download=True,
                                                transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)
        num_classes = 100
    elif args.in_dataset == "SVHN":
        normalizer = None
        trainloaderIn = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/', split='train',
            transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True)
        testloaderIn = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/', split='test',
            transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True)
        args.epochs = 20
        num_classes = 10

    # TinyImages as the validation out-distribution; a random offset picks a
    # different slice of the dataset each run.
    valloaderOut = torch.utils.data.DataLoader(
        TinyImages(transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])),
        batch_size=args.batch_size, shuffle=False)
    valloaderOut.dataset.offset = np.random.randint(len(valloaderOut.dataset))

    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes,
                              widen_factor=args.width, normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    checkpoint = torch.load(
        "./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(
            in_dataset=args.in_dataset, name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.cuda()

    # Collect m validation images from each side as raw numpy arrays.
    m = 1000
    val_in = []
    val_out = []

    cnt = 0
    for data, target in testloaderIn:
        for x in data:
            val_in.append(x.numpy())
            cnt += 1
            if cnt == m:
                break
        if cnt == m:
            break

    cnt = 0
    for data, target in valloaderOut:
        for x in data:
            val_out.append(x.numpy())
            cnt += 1
            if cnt == m:
                break
        if cnt == m:
            break

    print('Len of val in: ', len(val_in))
    print('Len of val out: ', len(val_out))

    # FPR is in [0, 1], so 1.1 guarantees the first grid point is accepted.
    best_fpr = 1.1
    best_magnitude = 0.0

    # 21-point grid over the perturbation magnitude [0, 0.004].
    for magnitude in np.arange(0, 0.0041, 0.004 / 20):
        t0 = time.time()
        f1 = open(os.path.join(save_dir, "confidence_ODIN_In.txt"), 'w')
        f2 = open(os.path.join(save_dir, "confidence_ODIN_Out.txt"), 'w')
        ########################################In-distribution###########################################
        print("Processing in-distribution images")
        count = 0
        for i in range(int(m / args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_in[i * args.batch_size:min((i + 1) * args.batch_size, m)])
            images = images.cuda()
            # if j<1000: continue
            batch_size = images.shape[0]
            scores = get_odin_score(images, model, temper=1000,
                                    noiseMagnitude1=magnitude)
            for k in range(batch_size):
                f1.write("{}\n".format(scores[k]))
            count += batch_size
            # print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

        ###################################Out-of-Distributions#####################################
        t0 = time.time()
        print("Processing out-of-distribution images")
        count = 0
        for i in range(int(m / args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_out[i * args.batch_size:min((i + 1) * args.batch_size, m)])
            images = images.cuda()
            # if j<1000: continue
            batch_size = images.shape[0]
            scores = get_odin_score(images, model, temper=1000,
                                    noiseMagnitude1=magnitude)
            for k in range(batch_size):
                f2.write("{}\n".format(scores[k]))
            count += batch_size
            # print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

        f1.close()
        f2.close()

        # Score this grid point and keep the magnitude with the lowest FPR.
        results = metric(save_dir, stypes)
        print_results(results, stypes)
        fpr = results['ODIN']['FPR']
        if fpr < best_fpr:
            best_fpr = fpr
            best_magnitude = magnitude

    return best_magnitude
def __init__(self):
    """Assemble the solver: select the backbone from ``args.net_type``,
    create the SGD optimizer, the per-class mixing-weight table, and the
    training criterion selected by ``args.method``.
    """
    super(Solver, self).__init__()
    global numberofclass

    # Network factory table. Lambdas keep construction lazy so only the
    # requested architecture is ever instantiated.
    net_builders = {
        'resnet': lambda: RN.ResNet(dataset=args.dataset, depth=args.depth,
                                    num_classes=numberofclass,
                                    bottleneck=args.bottleneck),
        'pyramidnet': lambda: PYRM.PyramidNet(args.dataset, args.depth,
                                              args.alpha, numberofclass,
                                              args.bottleneck),
        'wideresnet': lambda: WR.WideResNet(depth=args.depth,
                                            num_classes=numberofclass,
                                            widen_factor=args.width),
        'vggnet': lambda: VGG.vgg16(num_classes=numberofclass),
        'mobilenet': lambda: MN.mobile_half(num_classes=numberofclass),
        'shufflenet': lambda: SN.ShuffleV2(num_classes=numberofclass),
        'densenet': lambda: DN.densenet_cifar(num_classes=numberofclass),
        'resnext-2': lambda: ResNeXt29_2x64d(num_classes=numberofclass),
        'resnext-4': lambda: ResNeXt29_4x64d(num_classes=numberofclass),
        'resnext-32': lambda: ResNeXt29_32x4d(num_classes=numberofclass),
        'imagenetresnet18': lambda: multi_resnet18_kd(num_classes=numberofclass),
        'imagenetresnet34': lambda: multi_resnet34_kd(num_classes=numberofclass),
        'imagenetresnet50': lambda: multi_resnet50_kd(num_classes=numberofclass),
        'imagenetresnet101': lambda: multi_resnet101_kd(num_classes=numberofclass),
        'imagenetresnet152': lambda: multi_resnet152_kd(num_classes=numberofclass),
    }
    if args.net_type not in net_builders:
        raise Exception('unknown network architecture: {}'.format(
            args.net_type))
    self.model = net_builders[args.net_type]()

    self.optimizer = torch.optim.SGD(self.model.parameters(), args.lr,
                                     momentum=args.momentum,
                                     weight_decay=args.weight_decay,
                                     nesterov=True)

    # Per-(class, class) mixing weights; maintained manually, never touched
    # by autograd.
    self.loss_lams = torch.zeros(numberofclass, numberofclass,
                                 dtype=torch.float32).cuda()
    self.loss_lams.requires_grad = False

    # Criterion selection. 'sce' needs dataset-dependent hyper-parameters,
    # so it is handled outside the factory table.
    if args.method == 'sce':
        if args.dataset == 'cifar10':
            self.criterion = SCELoss(alpha=0.1, beta=1.0,
                                     num_classes=numberofclass)
        else:
            self.criterion = SCELoss(alpha=6.0, beta=0.1,
                                     num_classes=numberofclass)
    else:
        loss_builders = {
            'ce': lambda: nn.CrossEntropyLoss(),
            'ls': lambda: label_smooth(num_classes=numberofclass),
            'gce': lambda: generalized_cross_entropy(
                num_classes=numberofclass),
            'jo': lambda: joint_optimization(num_classes=numberofclass),
            'bootsoft': lambda: boot_soft(num_classes=numberofclass),
            'boothard': lambda: boot_hard(num_classes=numberofclass),
            'forward': lambda: Forward(num_classes=numberofclass),
            'backward': lambda: Backward(num_classes=numberofclass),
            'disturb': lambda: DisturbLabel(num_classes=numberofclass),
            # 'ols' trains with plain cross-entropy, same as 'ce'.
            'ols': lambda: nn.CrossEntropyLoss(),
        }
        if args.method in loss_builders:
            # Like the original if/elif chain, an unrecognised method leaves
            # self.criterion unset and fails on the .cuda() call below.
            self.criterion = loss_builders[args.method]()
    self.criterion = self.criterion.cuda()
def main():
    """Train an OOD-aware (NTOM) classifier on CIFAR-10/100 or SVHN.

    Builds the in-distribution train/val loaders and an auxiliary OOD loader,
    constructs a DenseNet/WideResNet with one extra class reserved for OOD
    samples (``num_classes + 1``), optionally resumes from a checkpoint, then
    alternates OOD sample selection and training each epoch, checkpointing
    every ``args.save_epoch`` epochs.
    """
    if args.tensorboard: configure("runs/%s"%(args.name))
    # Training-split augmentation is optional; the test split only converts
    # to tensors.
    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        ])
    kwargs = {'num_workers': 1, 'pin_memory': True}
    # NOTE(review): if args.in_dataset matches none of the branches below,
    # train_loader / normalizer / lr_schedule stay unbound and the code fails
    # later with a NameError — confirm upstream argument validation.
    if args.in_dataset == "CIFAR-10":
        # Data loading code. Normalization is applied inside the model via
        # the `normalizer=` constructor argument, not in the transforms.
        normalizer = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                          std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./datasets/cifar10', train=True, download=True,
                             transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./datasets/cifar10', train=False,
                             transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        lr_schedule=[50, 75, 90]
        pool_size = args.pool_size
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        # Data loading code (same normalization constants as CIFAR-10).
        normalizer = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                          std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./datasets/cifar100', train=True, download=True,
                              transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./datasets/cifar100', train=False,
                              transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        lr_schedule=[50, 75, 90]
        pool_size = args.pool_size
        num_classes = 100
    elif args.in_dataset == "SVHN":
        # Data loading code. SVHN trains with a shorter schedule; epochs,
        # save cadence and the OOD pool size are overridden here.
        normalizer = None
        train_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='train',
                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test',
                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=False, **kwargs)
        args.epochs = 20
        args.save_epoch = 2
        lr_schedule=[10, 15, 18]
        pool_size = int(len(train_loader.dataset) * 8 / args.ood_batch_size) + 1
        num_classes = 10
    # The auxiliary OOD pool is twice the size of the in-distribution
    # training set.
    ood_dataset_size = len(train_loader.dataset) * 2
    print('OOD Dataset Size: ', ood_dataset_size)
    # The ToTensor -> ToPILImage -> augment -> ToTensor chain converts the
    # raw auxiliary arrays to PIL so crop/flip can be applied.
    if args.auxiliary_dataset == '80m_tiny_images':
        ood_loader = torch.utils.data.DataLoader(
            TinyImages(transform=transforms.Compose(
                [transforms.ToTensor(), transforms.ToPILImage(),
                 transforms.RandomCrop(32, padding=4),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor()])),
            batch_size=args.ood_batch_size, shuffle=False, **kwargs)
    elif args.auxiliary_dataset == 'imagenet':
        ood_loader = torch.utils.data.DataLoader(
            ImageNet(transform=transforms.Compose(
                [transforms.ToTensor(), transforms.ToPILImage(),
                 transforms.RandomCrop(32, padding=4),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor()])),
            batch_size=args.ood_batch_size, shuffle=False, **kwargs)
    # create model: one extra output class (num_classes + 1) reserved for
    # OOD samples.
    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes + 1, args.growth,
                             reduction=args.reduce,
                             bottleneck=args.bottleneck,
                             dropRate=args.droprate,
                             normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes + 1,
                              widen_factor=args.width,
                              dropRate=args.droprate,
                              normalizer=normalizer)
    else:
        # NOTE: assert is stripped under `python -O`; kept as-is to preserve
        # behavior.
        assert False, 'Not supported model arch: {}'.format(args.model_arch)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            assert False, "=> no checkpoint found at '{}'".format(args.resume)
    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    model = model.cuda()
    cudnn.benchmark = True
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum, nesterov=True,
                                weight_decay=args.weight_decay)
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)
        # train for one epoch: first pick hard OOD samples from the
        # auxiliary pool, then train on in-distribution + selected OOD data.
        selected_ood_loader = select_ood(ood_loader, model,
                                         args.batch_size * 2, num_classes,
                                         pool_size, ood_dataset_size,
                                         args.quantile)
        train_ntom(train_loader, selected_ood_loader, model, criterion,
                   num_classes, optimizer, epoch)
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, num_classes)
        # remember best prec@1 and save checkpoint
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, epoch + 1)
def create_model(): model = nn.DataParallel(models.WideResNet(num_classes=num_classes)) if use_cuda: model.cuda() return model
def tune_mahalanobis_hyperparams():
    """Tune the noise magnitude for the Mahalanobis OOD detector.

    Pipeline: load the in-distribution dataset and a pretrained model,
    estimate per-class feature means and precisions (sample_estimator),
    synthesize "out" samples by adversarially perturbing in-distribution
    images, fit a LogisticRegressionCV on layer-wise Mahalanobis scores for
    each candidate magnitude, and keep the magnitude/regressor pair with the
    lowest FPR.

    Returns:
        (sample_mean, precision, best_regressor, best_magnitude)
    """
    print('Tuning hyper-parameters...')
    stypes = ['mahalanobis']
    save_dir = os.path.join('output/mahalanobis_hyperparams/',
                            args.in_dataset, args.name, 'tmp')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    # Build train/test loaders for the in-distribution dataset. Note the
    # transforms only convert to tensors; `normalizer` is handed to the
    # model constructor instead.
    if args.in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize((125.3/255, 123.0/255, 113.9/255),
                                          (63.0/255, 62.1/255.0, 66.7/255.0))
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        trainset = torchvision.datasets.CIFAR10('./datasets/cifar10',
                                                train=True, download=True,
                                                transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=2)
        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False, download=True,
                                               transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=2)
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize((125.3/255, 123.0/255, 113.9/255),
                                          (63.0/255, 62.1/255.0, 66.7/255.0))
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        trainset = torchvision.datasets.CIFAR100('./datasets/cifar100',
                                                 train=True, download=True,
                                                 transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=2)
        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False, download=True,
                                                transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=2)
        num_classes = 100
    elif args.in_dataset == "SVHN":
        normalizer = None
        trainloaderIn = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='train',
                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True)
        testloaderIn = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test',
                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True)
        # args.epochs also selects which checkpoint file is loaded below.
        args.epochs = 20
        num_classes = 10
    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes,
                              widen_factor=args.width, normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)
    checkpoint = torch.load(
        "./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(
            in_dataset=args.in_dataset, name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.cuda()
    # set information about feature extraction: probe the model with a dummy
    # batch to discover how many intermediate feature maps it exposes and
    # each map's channel count.
    temp_x = torch.rand(2,3,32,32)
    temp_x = Variable(temp_x).cuda()
    temp_list = model.feature_list(temp_x)[1]
    num_output = len(temp_list)
    feature_list = np.empty(num_output)
    count = 0
    for out in temp_list:
        feature_list[count] = out.size(1)
        count += 1
    print('get sample mean and covariance')
    sample_mean, precision = sample_estimator(model, num_classes,
                                              feature_list, trainloaderIn)
    print('train logistic regression model')
    m = 500
    # First m test images -> regressor training set; next m -> validation.
    train_in = []
    train_in_label = []
    train_out = []
    val_in = []
    val_in_label = []
    val_out = []
    cnt = 0
    for data, target in testloaderIn:
        data = data.numpy()
        target = target.numpy()
        for x, y in zip(data, target):
            cnt += 1
            if cnt <= m:
                train_in.append(x)
                train_in_label.append(y)
            elif cnt <= 2*m:
                val_in.append(x)
                val_in_label.append(y)
            if cnt == 2*m:
                break
        if cnt == 2*m:
            break
    print('In', len(train_in), len(val_in))
    criterion = nn.CrossEntropyLoss().cuda()
    adv_noise = 0.05
    # Generate synthetic "out" training samples: a single FGSM-style step of
    # size adv_noise in the sign of the input gradient, clamped to [0, 1].
    # NOTE(review): Variable(..., volatile=True) is a pre-0.4 PyTorch idiom
    # (ignored/deprecated on modern versions) — confirm the pinned torch
    # version before upgrading.
    for i in range(int(m/args.batch_size) + 1):
        if i*args.batch_size >= m:
            break
        data = torch.tensor(
            train_in[i*args.batch_size:min((i+1)*args.batch_size, m)])
        target = torch.tensor(
            train_in_label[i*args.batch_size:min((i+1)*args.batch_size, m)])
        data = data.cuda()
        target = target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        model.zero_grad()
        # Re-wrap with requires_grad so the input gradient can be taken.
        inputs = Variable(data.data, requires_grad=True).cuda()
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()
        gradient = torch.ge(inputs.grad.data, 0)
        gradient = (gradient.float()-0.5)*2
        adv_data = torch.add(input=inputs.data, other=gradient,
                             alpha=adv_noise)
        adv_data = torch.clamp(adv_data, 0.0, 1.0)
        train_out.extend(adv_data.cpu().numpy())
    # Same perturbation procedure for the validation "out" set.
    for i in range(int(m/args.batch_size) + 1):
        if i*args.batch_size >= m:
            break
        data = torch.tensor(
            val_in[i*args.batch_size:min((i+1)*args.batch_size, m)])
        target = torch.tensor(
            val_in_label[i*args.batch_size:min((i+1)*args.batch_size, m)])
        data = data.cuda()
        target = target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        model.zero_grad()
        inputs = Variable(data.data, requires_grad=True).cuda()
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()
        gradient = torch.ge(inputs.grad.data, 0)
        gradient = (gradient.float()-0.5)*2
        adv_data = torch.add(input=inputs.data, other=gradient,
                             alpha=adv_noise)
        adv_data = torch.clamp(adv_data, 0.0, 1.0)
        val_out.extend(adv_data.cpu().numpy())
    print('Out', len(train_out),len(val_out))
    # Regressor training set: label 0 = in-distribution, 1 = perturbed/out.
    train_lr_data = []
    train_lr_label = []
    train_lr_data.extend(train_in)
    train_lr_label.extend(np.zeros(m))
    train_lr_data.extend(train_out)
    train_lr_label.extend(np.ones(m))
    train_lr_data = torch.tensor(train_lr_data)
    train_lr_label = torch.tensor(train_lr_label)
    # Sweep candidate input-preprocessing magnitudes; keep the one giving the
    # lowest FPR on the validation split. best_fpr starts above any real FPR
    # so the first iteration always initializes best_regressor.
    best_fpr = 1.1
    best_magnitude = 0.0
    for magnitude in [0.0, 0.01, 0.005, 0.002, 0.0014, 0.001, 0.0005]:
        # Compute layer-wise Mahalanobis scores for the regressor train set.
        train_lr_Mahalanobis = []
        total = 0
        for data_index in range(
                int(np.floor(train_lr_data.size(0) / args.batch_size))):
            data = train_lr_data[total : total + args.batch_size].cuda()
            total += args.batch_size
            Mahalanobis_scores = get_Mahalanobis_score(
                data, model, num_classes, sample_mean, precision,
                num_output, magnitude)
            train_lr_Mahalanobis.extend(Mahalanobis_scores)
        train_lr_Mahalanobis = np.asarray(train_lr_Mahalanobis,
                                          dtype=np.float32)
        regressor = LogisticRegressionCV(n_jobs=-1).fit(train_lr_Mahalanobis,
                                                        train_lr_label)
        print('Logistic Regressor params:', regressor.coef_,
              regressor.intercept_)
        t0 = time.time()
        f1 = open(os.path.join(save_dir, "confidence_mahalanobis_In.txt"), 'w')
        f2 = open(os.path.join(save_dir, "confidence_mahalanobis_Out.txt"), 'w')
        ########################################In-distribution###########################################
        print("Processing in-distribution images")
        count = 0
        for i in range(int(m/args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_in[i * args.batch_size : min((i+1) * args.batch_size, m)]).cuda()
            # if j<1000: continue
            batch_size = images.shape[0]
            Mahalanobis_scores = get_Mahalanobis_score(
                images, model, num_classes, sample_mean, precision,
                num_output, magnitude)
            # P(out); negated so that higher written scores mean "more in".
            confidence_scores= regressor.predict_proba(Mahalanobis_scores)[:, 1]
            for k in range(batch_size):
                f1.write("{}\n".format(-confidence_scores[k]))
            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                count, m, time.time()-t0))
            t0 = time.time()
        ###################################Out-of-Distributions#####################################
        t0 = time.time()
        print("Processing out-of-distribution images")
        count = 0
        for i in range(int(m/args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_out[i * args.batch_size : min((i+1) * args.batch_size, m)]).cuda()
            # if j<1000: continue
            batch_size = images.shape[0]
            Mahalanobis_scores = get_Mahalanobis_score(
                images, model, num_classes, sample_mean, precision,
                num_output, magnitude)
            confidence_scores= regressor.predict_proba(Mahalanobis_scores)[:, 1]
            for k in range(batch_size):
                f2.write("{}\n".format(-confidence_scores[k]))
            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                count, m, time.time()-t0))
            t0 = time.time()
        f1.close()
        f2.close()
        # Score the written confidence files and track the best magnitude.
        results = metric(save_dir, stypes)
        print_results(results, stypes)
        fpr = results['mahalanobis']['FPR']
        if fpr < best_fpr:
            best_fpr = fpr
            best_magnitude = magnitude
            best_regressor = regressor
    print('Best Logistic Regressor params:', best_regressor.coef_,
          best_regressor.intercept_)
    print('Best magnitude', best_magnitude)
    return sample_mean, precision, best_regressor, best_magnitude