Example #1
def cifar():
    import torch
    import torchvision
    import os

    mean, std = torch.Tensor([0.471, 0.448,
                              0.408]), torch.Tensor([0.234, 0.239, 0.242])
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, std)
    ])
    data = torchvision.datasets.CIFAR10(root="../data",
                                        train=False,
                                        download=False,
                                        transform=transform)
    loader = torch.utils.data.DataLoader(data,
                                         batch_size=32,
                                         shuffle=False,
                                         drop_last=True)
    import models.wideresnet as models
    white_model = models.WideResNet(num_classes=10).cuda()

    import models.mobilenet as BlackModel
    black_model = BlackModel.MobileNet().cuda()
    black_model.load_state_dict(
        torch.load("black_model/mobilenet.p")["net"])  #
    black_model.eval()

    temp = torch.load(
        os.path.join('wideresnet_vs_mobilenet/result_1000',
                     "model_best.pth.tar"))
    white_model.load_state_dict(temp['state_dict'])
    white_model.eval()

    # `attack`, `epsilon`, and `attack_num` are assumed to be defined at module level.
    trans = attack(False, white_model, black_model, loader, epsilon,
                   attack_num, "cifar", True)
Example #2
    def __init__(self, model, ema_model, dataset, alpha=0.999):
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha

        if dataset == 'cifar10':
            self.tmp_model = models.WideResNet(num_classes=10).cuda()
        elif dataset == 'cifar100':
            self.tmp_model = models.WideResNet(num_classes=100).cuda()
        else:
            raise NotImplementedError

        self.wd = 0.02 * args.lr

        for param, ema_param in zip(self.model.parameters(),
                                    self.ema_model.parameters()):
            ema_param.data.copy_(param.data)
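
The constructor above only copies the student weights into the EMA model; the update itself is not shown. A minimal sketch of the step() such a helper typically performs (MixMatch-style weight EMA, assuming self.alpha and self.wd as defined above):

    def step(self):
        one_minus_alpha = 1.0 - self.alpha
        for param, ema_param in zip(self.model.parameters(),
                                    self.ema_model.parameters()):
            # exponential moving average of the student weights
            ema_param.data.mul_(self.alpha)
            ema_param.data.add_(param.data * one_minus_alpha)
            # customized weight decay applied to the student weights
            param.data.mul_(1 - self.wd)
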
Example #3
    def create_model(ema=False):
        model = models.WideResNet(num_classes=10)
        model = model.cuda()

        if ema:
            for param in model.parameters():
                param.detach_()

        return model
    def __init__(self, model, ema_model, alpha=0.999):
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        self.tmp_model = models.WideResNet(num_classes=10).cuda()
        self.wd = 0.02 * args.lr

        for param, ema_param in zip(self.model.parameters(), self.ema_model.parameters()):
            ema_param.data.copy_(param.data)
    def create_model(ema=False):
        model = models.WideResNet(num_classes=10)
        model = model.cuda()

        if ema:
            for param in model.parameters():
                param.detach_()
                # EMA: exponential moving average
        return model
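
A hedged usage sketch of how these helpers are typically wired together; the class name WeightEMA and the surrounding training loop are assumptions, not part of the example:

    model = create_model()
    ema_model = create_model(ema=True)   # detached copy, updated only through the EMA
    ema_optimizer = WeightEMA(model, ema_model, alpha=0.999)
    # inside the training loop, after each optimizer.step() on `model`:
    #     ema_optimizer.step()
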
Example #6
    def create_model(ema=False):
        model = nn.DataParallel(models.WideResNet(num_classes=num_classes))
        if use_cuda:
            model = model.cuda()

        if ema:
            for param in model.parameters():
                param.detach_()

        return model
Example #7
    def create_model(dataset, ema=False):

        num_classes = None
        if args.dataset == 'cifar10':
            num_classes = 10
        elif args.dataset == 'cifar100':
            num_classes = 100
        else:
            raise NotImplementedError

        model = models.WideResNet(num_classes=num_classes)
        model = model.cuda()

        if ema:
            for param in model.parameters():
                param.detach_()

        return model
def main():
    if args.tensorboard: configure("runs/%s"%(args.name))

    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        ])

    kwargs = {'num_workers': 1, 'pin_memory': True}

    if args.in_dataset == "CIFAR-10":
        # Data loading code
        normalizer = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder('./datasets/row_train_data/CIFAR-10', transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./datasets/cifar10', train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        num_classes = 10
        lr_schedule=[50, 75, 90]
    elif args.in_dataset == "CIFAR-100":
        # Data loading code
        normalizer = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder('./datasets/row_train_data/CIFAR-100', transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./datasets/cifar100', train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        num_classes = 100
        lr_schedule=[50, 75, 90]
    elif args.in_dataset == "SVHN":
        # Data loading code
        normalizer = None
        transform = transforms.Compose([transforms.ToTensor(),])
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder('./datasets/row_train_data/SVHN', transform=transform),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test',
                                  transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=False, **kwargs)

        args.epochs = 20
        args.save_epoch = 2
        lr_schedule=[10, 15, 18]
        num_classes = 10

    # create model
    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes + 1, args.growth, reduction=args.reduce,
                             bottleneck=args.bottleneck, dropRate=args.droprate, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes + 1, widen_factor=args.width, dropRate=args.droprate, normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    attack = LinfPGDAttack(model=model, eps=args.epsilon, nb_iter=args.iters,
                           eps_iter=args.iter_size, rand_init=True, targeted=True,
                           num_classes=num_classes + 1, loss_func='CE',
                           elementwise_best=True)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    model = model.cuda()

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))


    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)

        # train for one epoch
        train_rowl(train_loader, model, criterion, optimizer, epoch, num_classes, attack)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, num_classes, epoch)

        # remember best prec@1 and save checkpoint
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, epoch + 1)
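
The loop above relies on an adjust_learning_rate helper that is not shown. A minimal sketch, assuming it simply decays the base learning rate by 10x at each milestone in lr_schedule:

def adjust_learning_rate(optimizer, epoch, lr_schedule):
    # drop the learning rate by a factor of 10 at every milestone epoch
    lr = args.lr * (0.1 ** sum(epoch >= milestone for milestone in lr_schedule))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
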
def main():
    if args.tensorboard: configure("runs/%s" % (args.name))

    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    kwargs = {'num_workers': 1, 'pin_memory': True}

    if args.in_dataset == "CIFAR-10":
        # Data loading code
        normalizer = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            './datasets/cifar10',
            train=True,
            download=True,
            transform=transform_train),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            './datasets/cifar10', train=False, transform=transform_test),
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 **kwargs)

        lr_schedule = [50, 75, 90]
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        # Data loading code
        normalizer = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR100(
            './datasets/cifar100',
            train=True,
            download=True,
            transform=transform_train),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        val_loader = torch.utils.data.DataLoader(datasets.CIFAR100(
            './datasets/cifar100', train=False, transform=transform_test),
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 **kwargs)

        lr_schedule = [50, 75, 90]
        num_classes = 100
    elif args.in_dataset == "SVHN":
        # Data loading code
        normalizer = None
        train_loader = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/',
            split='train',
            transform=transforms.ToTensor(),
            download=False),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        val_loader = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/',
            split='test',
            transform=transforms.ToTensor(),
            download=False),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 **kwargs)

        args.epochs = 20
        args.save_epoch = 2
        lr_schedule = [10, 15, 18]
        num_classes = 10

    out_loader = torch.utils.data.DataLoader(
        TinyImages(transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])),
        batch_size=args.ood_batch_size,
        shuffle=False,
        **kwargs)

    # create model
    if args.model_arch == 'densenet':
        base_model = dn.DenseNet3(args.layers,
                                  num_classes,
                                  args.growth,
                                  reduction=args.reduce,
                                  bottleneck=args.bottleneck,
                                  dropRate=args.droprate,
                                  normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        base_model = wn.WideResNet(args.depth,
                                   num_classes,
                                   widen_factor=args.width,
                                   dropRate=args.droprate,
                                   normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    gen_gmm(train_loader, out_loader, data_used=50000, PCA=True, N=[100])

    gmm = torch.load("checkpoints/{in_dataset}/{name}/".format(
        in_dataset=args.in_dataset, name=args.name) + 'in_gmm.pth.tar')

    gmm.alpha = nn.Parameter(gmm.alpha)
    gmm.mu.requires_grad = True
    gmm.logvar.requires_grad = True
    gmm.alpha.requires_grad = False

    gmm_out = torch.load("checkpoints/{in_dataset}/{name}/".format(
        in_dataset=args.in_dataset, name=args.name) + 'out_gmm.pth.tar')
    gmm_out.alpha = nn.Parameter(gmm.alpha)
    gmm_out.mu.requires_grad = True
    gmm_out.logvar.requires_grad = True
    gmm_out.alpha.requires_grad = False
    loglam = 0.
    model = gmmlib.DoublyRobustModel(base_model,
                                     gmm,
                                     gmm_out,
                                     loglam,
                                     dim=3072,
                                     classes=num_classes).cuda()

    model.loglam.requires_grad = False

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    model = model.cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    lr = args.lr
    lr_gmm = 1e-5

    param_groups = [{
        'params': model.mm.parameters(),
        'lr': lr_gmm,
        'weight_decay': 0.
    }, {
        'params': model.mm_out.parameters(),
        'lr': lr_gmm,
        'weight_decay': 0.
    }, {
        'params': model.base_model.parameters(),
        'lr': lr,
        'weight_decay': args.weight_decay
    }]

    optimizer = torch.optim.SGD(param_groups,
                                momentum=args.momentum,
                                nesterov=True)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)

        # train for one epoch
        lam = model.loglam.data.exp().item()
        train_CEDA_gmm_out(model,
                           train_loader,
                           out_loader,
                           optimizer,
                           epoch,
                           lam=lam)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                }, epoch + 1)
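
save_checkpoint is likewise assumed. A minimal sketch that matches the keys written here ('epoch', 'state_dict') and the path layout these examples later load checkpoints from:

def save_checkpoint(state, epoch):
    # assumed layout: checkpoints/{in_dataset}/{name}/checkpoint_{epoch}.pth.tar
    directory = "checkpoints/{in_dataset}/{name}/".format(
        in_dataset=args.in_dataset, name=args.name)
    if not os.path.exists(directory):
        os.makedirs(directory)
    torch.save(state, os.path.join(directory,
                                   'checkpoint_{}.pth.tar'.format(epoch)))
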
Example #10
    def __init__(self, root='', train=True, meta=True, num_meta=1000,
                 corruption_prob=0, corruption_type='unif', transform=None, target_transform=None,
                 download=False, seed=1):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.meta = meta
        self.corruption_prob = corruption_prob
        self.num_meta = num_meta

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # now load the picked numpy arrays
        if self.train:
            self.train_data = []
            self.train_labels = []
            self.train_coarse_labels = []
            for fentry in self.train_list:
                f = fentry[0]
                file = os.path.join(root, self.base_folder, f)
                fo = open(file, 'rb')
                if sys.version_info[0] == 2:
                    entry = pickle.load(fo)
                else:
                    entry = pickle.load(fo, encoding='latin1')
                self.train_data.append(entry['data'])
                if 'labels' in entry:
                    self.train_labels += entry['labels']
                    img_num_list = [int(self.num_meta/10)] * 10
                    num_classes = 10
                else:
                    self.train_labels += entry['fine_labels']
                    self.train_coarse_labels += entry['coarse_labels']
                    img_num_list = [int(self.num_meta/100)] * 100
                    num_classes = 100
                fo.close()

            self.train_data = np.concatenate(self.train_data)
            self.train_data = self.train_data.reshape((50000, 3, 32, 32))
            self.train_data = self.train_data.transpose((0, 2, 3, 1))   # convert to HWC

            data_list_val = {}
            for j in range(num_classes):
                data_list_val[j] = [i for i, label in enumerate(self.train_labels) if label == j]


            idx_to_meta = []
            idx_to_train = []
            print(img_num_list)

            for cls_idx, img_id_list in data_list_val.items():
                np.random.shuffle(img_id_list)
                img_num = img_num_list[int(cls_idx)]
                idx_to_meta.extend(img_id_list[:img_num])
                idx_to_train.extend(img_id_list[img_num:])


            if meta is True:
                self.train_data = self.train_data[idx_to_meta]
                self.train_labels = list(np.array(self.train_labels)[idx_to_meta])
            else:
                self.train_data = self.train_data[idx_to_train]
                self.train_labels = list(np.array(self.train_labels)[idx_to_train])
                if corruption_type == 'hierarchical':
                    # index coarse labels with the same training-split indices as the fine labels
                    self.train_coarse_labels = list(np.array(self.train_coarse_labels)[idx_to_train])

                if corruption_type == 'unif':
                    C = uniform_mix_C(self.corruption_prob, num_classes)
                    print(C)
                    self.C = C
                elif corruption_type == 'flip':
                    C = flip_labels_C(self.corruption_prob, num_classes)
                    print(C)
                    self.C = C
                elif corruption_type == 'flip2':
                    C = flip_labels_C_two(self.corruption_prob, num_classes)
                    print(C)
                    self.C = C
                elif corruption_type == 'hierarchical':
                    assert num_classes == 100, 'You must use CIFAR-100 with the hierarchical corruption.'
                    coarse_fine = []
                    for i in range(20):
                        coarse_fine.append(set())
                    for i in range(len(self.train_labels)):
                        coarse_fine[self.train_coarse_labels[i]].add(self.train_labels[i])
                    for i in range(20):
                        coarse_fine[i] = list(coarse_fine[i])

                    C = np.eye(num_classes) * (1 - corruption_prob)

                    for i in range(20):
                        tmp = np.copy(coarse_fine[i])
                        for j in range(len(tmp)):
                            tmp2 = np.delete(np.copy(tmp), j)
                            C[tmp[j], tmp2] += corruption_prob * 1/len(tmp2)
                    self.C = C
                    print(C)
                elif corruption_type == 'clabels':
                    net = wrn.WideResNet(40, num_classes, 2, dropRate=0.3).cuda()
                    model_name = './cifar{}_labeler'.format(num_classes)
                    net.load_state_dict(torch.load(model_name))
                    net.eval()
                else:
                    assert False, "Invalid corruption type '{}' given. Must be in {'unif', 'flip', 'hierarchical'}".format(corruption_type)

                np.random.seed(seed)
                if corruption_type == 'clabels':
                    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
                    std = [x / 255 for x in [63.0, 62.1, 66.7]]

                    test_transform = transforms.Compose(
                        [transforms.ToTensor(), transforms.Normalize(mean, std)])

                    # obtain sampling probabilities
                    sampling_probs = []
                    print('Starting labeling')

                    for i in range((len(self.train_labels) // 64) + 1):
                        current = self.train_data[i*64:(i+1)*64]
                        current = [Image.fromarray(current[i]) for i in range(len(current))]
                        current = torch.cat([test_transform(current[i]).unsqueeze(0) for i in range(len(current))], dim=0)

                        data = V(current).cuda()
                        logits = net(data)
                        smax = F.softmax(logits / 5, dim=1)  # temperature of 5
                        sampling_probs.append(smax.data.cpu().numpy())


                    sampling_probs = np.concatenate(sampling_probs, 0)
                    print('Finished labeling 1')

                    new_labeling_correct = 0
                    argmax_labeling_correct = 0
                    for i in range(len(self.train_labels)):
                        old_label = self.train_labels[i]
                        new_label = np.random.choice(num_classes, p=sampling_probs[i])
                        self.train_labels[i] = new_label
                        if old_label == new_label:
                            new_labeling_correct += 1
                        if old_label == np.argmax(sampling_probs[i]):
                            argmax_labeling_correct += 1
                    print('Finished labeling 2')
                    print('New labeling accuracy:', new_labeling_correct / len(self.train_labels))
                    print('Argmax labeling accuracy:', argmax_labeling_correct / len(self.train_labels))
                else:    
                    for i in range(len(self.train_labels)):
                        self.train_labels[i] = np.random.choice(num_classes, p=C[self.train_labels[i]])
                    self.corruption_matrix = C

        else:
            f = self.test_list[0][0]
            file = os.path.join(root, self.base_folder, f)
            fo = open(file, 'rb')
            if sys.version_info[0] == 2:
                entry = pickle.load(fo)
            else:
                entry = pickle.load(fo, encoding='latin1')
            self.test_data = entry['data']
            if 'labels' in entry:
                self.test_labels = entry['labels']
            else:
                self.test_labels = entry['fine_labels']
            fo.close()
            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
            self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
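
The corruption matrices uniform_mix_C and flip_labels_C used above are built elsewhere. A minimal sketch of the usual definitions (each row i is the distribution over corrupted labels for true class i); the exact variants in the source may differ slightly:

def uniform_mix_C(corruption_prob, num_classes):
    # mix the identity with a uniform relabeling matrix
    return corruption_prob * np.full((num_classes, num_classes), 1 / num_classes) + \
        (1 - corruption_prob) * np.eye(num_classes)

def flip_labels_C(corruption_prob, num_classes, seed=1):
    # with probability corruption_prob, flip each class to one fixed other class
    np.random.seed(seed)
    C = np.eye(num_classes) * (1 - corruption_prob)
    indices = np.arange(num_classes)
    for i in range(num_classes):
        C[i][np.random.choice(indices[indices != i])] = corruption_prob
    return C
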
Example #11
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

import models.wideresnet as models
import dataset.freesound_X as dataset
from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig, lwlrap_accumulator, load_checkpoint
from tensorboardX import SummaryWriter
from fastai.basic_data import *
from fastai.basic_train import *
from fastai.train import *
from train import SemiLoss

model = models.WideResNet(num_classes=80)
train_labeled_set, train_unlabeled_set, val_set, test_set, train_unlabeled_warmstart_set, num_classes, pos_weights = dataset.get_freesound(
)

labeled_trainloader = data.DataLoader(train_labeled_set,
                                      batch_size=4,
                                      shuffle=True,
                                      num_workers=0,
                                      drop_last=True)
val_loader = data.DataLoader(val_set,
                             batch_size=4,
                             shuffle=False,
                             num_workers=0)

train_criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters())
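
Example #11 stops after constructing the model, loaders, loss, and optimizer. A minimal sketch of the training step these objects would typically drive (multi-label targets to match BCEWithLogitsLoss; this loop is an assumption, not part of the example):

model = model.cuda()
model.train()
for inputs, targets in labeled_trainloader:
    inputs, targets = inputs.cuda(), targets.cuda().float()
    optimizer.zero_grad()
    loss = train_criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
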
Example #12
def eval_ood_detector(base_dir, in_dataset, out_datasets, batch_size, method,
                      method_args, name, epochs, adv, corrupt, adv_corrupt,
                      adv_args, mode_args):

    if adv:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name, 'adv',
                                   str(int(adv_args['epsilon'])))
    elif adv_corrupt:
        in_save_dir = os.path.join(base_dir, in_dataset, method,
                                   name, 'adv_corrupt',
                                   str(int(adv_args['epsilon'])))
    elif corrupt:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name,
                                   'corrupt')
    else:
        in_save_dir = os.path.join(base_dir, in_dataset, method, name, 'nat')

    if not os.path.exists(in_save_dir):
        os.makedirs(in_save_dir)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False,
                                               download=True,
                                               transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=2)
        num_classes = 10
        num_reject_classes = 5
    elif in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False,
                                                download=True,
                                                transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=2)
        num_classes = 100
        num_reject_classes = 10
    elif in_dataset == "SVHN":
        normalizer = None
        testset = svhn.SVHN('datasets/svhn/',
                            split='test',
                            transform=transforms.ToTensor(),
                            download=False)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=2)
        num_classes = 10
        num_reject_classes = 5

    if method != "sofl":
        num_reject_classes = 0

    if method == "rowl" or method == "atom" or method == "ntom":
        num_reject_classes = 1

    method_args['num_classes'] = num_classes

    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers,
                             num_classes + num_reject_classes,
                             normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth,
                              num_classes + num_reject_classes,
                              widen_factor=args.width,
                              normalizer=normalizer)
    elif args.model_arch == 'densenet_ccu':
        model = dn.DenseNet3(args.layers,
                             num_classes + num_reject_classes,
                             normalizer=normalizer)
        gmm = torch.load("checkpoints/{in_dataset}/{name}/".format(
            in_dataset=args.in_dataset, name=args.name) + 'in_gmm.pth.tar')
        gmm.alpha = nn.Parameter(gmm.alpha)
        gmm_out = torch.load("checkpoints/{in_dataset}/{name}/".format(
            in_dataset=args.in_dataset, name=args.name) + 'out_gmm.pth.tar')
        gmm_out.alpha = nn.Parameter(gmm.alpha)
        whole_model = gmmlib.DoublyRobustModel(model,
                                               gmm,
                                               gmm_out,
                                               loglam=0.,
                                               dim=3072,
                                               classes=num_classes)
    elif args.model_arch == 'wideresnet_ccu':
        model = wn.WideResNet(args.depth,
                              num_classes + num_reject_classes,
                              widen_factor=args.width,
                              normalizer=normalizer)
        gmm = torch.load("checkpoints/{in_dataset}/{name}/".format(
            in_dataset=args.in_dataset, name=args.name) + 'in_gmm.pth.tar')
        gmm.alpha = nn.Parameter(gmm.alpha)
        gmm_out = torch.load("checkpoints/{in_dataset}/{name}/".format(
            in_dataset=args.in_dataset, name=args.name) + 'out_gmm.pth.tar')
        gmm_out.alpha = nn.Parameter(gmm.alpha)
        whole_model = gmmlib.DoublyRobustModel(model,
                                               gmm,
                                               gmm_out,
                                               loglam=0.,
                                               dim=3072,
                                               classes=num_classes)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    checkpoint = torch.load(
        "./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(
            in_dataset=in_dataset, name=name, epochs=epochs))

    if args.model_arch == 'densenet_ccu' or args.model_arch == 'wideresnet_ccu':
        whole_model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    if method == "mahalanobis":
        temp_x = torch.rand(2, 3, 32, 32)
        temp_x = Variable(temp_x).cuda()
        temp_list = model.feature_list(temp_x)[1]
        num_output = len(temp_list)
        method_args['num_output'] = num_output

    if adv or adv_corrupt:
        epsilon = adv_args['epsilon']
        iters = adv_args['iters']
        iter_size = adv_args['iter_size']

        if method == "msp" or method == "odin":
            attack_out = ConfidenceLinfPGDAttack(model,
                                                 eps=epsilon,
                                                 nb_iter=iters,
                                                 eps_iter=args.iter_size,
                                                 rand_init=True,
                                                 clip_min=0.,
                                                 clip_max=1.,
                                                 num_classes=num_classes)
        elif method == "mahalanobis":
            attack_out = MahalanobisLinfPGDAttack(model,
                                                  eps=args.epsilon,
                                                  nb_iter=args.iters,
                                                  eps_iter=iter_size,
                                                  rand_init=True,
                                                  clip_min=0.,
                                                  clip_max=1.,
                                                  num_classes=num_classes,
                                                  sample_mean=sample_mean,
                                                  precision=precision,
                                                  num_output=num_output,
                                                  regressor=regressor)
        elif method == "sofl":
            attack_out = SOFLLinfPGDAttack(
                model,
                eps=epsilon,
                nb_iter=iters,
                eps_iter=iter_size,
                rand_init=True,
                clip_min=0.,
                clip_max=1.,
                num_classes=num_classes,
                num_reject_classes=num_reject_classes)
        elif method == "rowl":
            attack_out = OODScoreLinfPGDAttack(model,
                                               eps=epsilon,
                                               nb_iter=iters,
                                               eps_iter=iter_size,
                                               rand_init=True,
                                               clip_min=0.,
                                               clip_max=1.,
                                               num_classes=num_classes)
        elif method == "atom" or method == "ntom":
            attack_out = OODScoreLinfPGDAttack(model,
                                               eps=epsilon,
                                               nb_iter=iters,
                                               eps_iter=iter_size,
                                               rand_init=True,
                                               clip_min=0.,
                                               clip_max=1.,
                                               num_classes=num_classes)

    if not mode_args['out_dist_only']:
        t0 = time.time()

        f1 = open(os.path.join(in_save_dir, "in_scores.txt"), 'w')
        g1 = open(os.path.join(in_save_dir, "in_labels.txt"), 'w')

        ########################################In-distribution###########################################
        print("Processing in-distribution images")

        N = len(testloaderIn.dataset)
        count = 0
        for j, data in enumerate(testloaderIn):
            images, labels = data
            images = images.cuda()
            labels = labels.cuda()
            curr_batch_size = images.shape[0]

            inputs = images

            scores = get_score(inputs, model, method, method_args)

            for score in scores:
                f1.write("{}\n".format(score))

            if method == "rowl":
                outputs = F.softmax(model(inputs), dim=1)
                outputs = outputs.detach().cpu().numpy()
                preds = np.argmax(outputs, axis=1)
                confs = np.max(outputs, axis=1)
            else:
                outputs = F.softmax(model(inputs)[:, :num_classes], dim=1)
                outputs = outputs.detach().cpu().numpy()
                preds = np.argmax(outputs, axis=1)
                confs = np.max(outputs, axis=1)

            for k in range(preds.shape[0]):
                g1.write("{} {} {}\n".format(labels[k], preds[k], confs[k]))

            count += curr_batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                count, N,
                time.time() - t0))
            t0 = time.time()

        f1.close()
        g1.close()

    if mode_args['in_dist_only']:
        return

    for out_dataset in out_datasets:

        out_save_dir = os.path.join(in_save_dir, out_dataset)

        if not os.path.exists(out_save_dir):
            os.makedirs(out_save_dir)

        f2 = open(os.path.join(out_save_dir, "out_scores.txt"), 'w')

        if not os.path.exists(out_save_dir):
            os.makedirs(out_save_dir)

        if out_dataset == 'SVHN':
            testsetout = svhn.SVHN('datasets/ood_datasets/svhn/',
                                   split='test',
                                   transform=transforms.ToTensor(),
                                   download=False)
            testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        num_workers=2)
        elif out_dataset == 'dtd':
            testsetout = torchvision.datasets.ImageFolder(
                root="datasets/ood_datasets/dtd/images",
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.CenterCrop(32),
                    transforms.ToTensor()
                ]))
            testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        num_workers=2)
        elif out_dataset == 'places365':
            testsetout = torchvision.datasets.ImageFolder(
                root="datasets/ood_datasets/places365/test_subset",
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.CenterCrop(32),
                    transforms.ToTensor()
                ]))
            testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        num_workers=2)
        else:
            testsetout = torchvision.datasets.ImageFolder(
                "./datasets/ood_datasets/{}".format(out_dataset),
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.CenterCrop(32),
                    transforms.ToTensor()
                ]))
            testloaderOut = torch.utils.data.DataLoader(testsetout,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        num_workers=2)

    ###################################Out-of-Distributions#####################################
        t0 = time.time()
        print("Processing out-of-distribution images")

        N = len(testloaderOut.dataset)
        count = 0
        for j, data in enumerate(testloaderOut):

            images, labels = data
            images = images.cuda()
            labels = labels.cuda()
            curr_batch_size = images.shape[0]

            if adv:
                inputs = attack_out.perturb(images)
            elif corrupt:
                inputs = corrupt_attack(images, model, method, method_args,
                                        False, adv_args['severity_level'])
            elif adv_corrupt:
                corrupted_images = corrupt_attack(images, model, method,
                                                  method_args, False,
                                                  adv_args['severity_level'])
                inputs = attack_out.perturb(corrupted_images)
            else:
                inputs = images

            scores = get_score(inputs, model, method, method_args)

            for score in scores:
                f2.write("{}\n".format(score))

            count += curr_batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(
                count, N,
                time.time() - t0))
            t0 = time.time()

        f2.close()

    return
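
get_score is defined elsewhere; for orientation, a minimal sketch of what the plain 'msp' branch typically computes (maximum softmax probability over the in-distribution classes):

def get_msp_score(inputs, model, method_args):
    num_classes = method_args['num_classes']
    with torch.no_grad():
        logits = model(inputs)[:, :num_classes]
        probs = F.softmax(logits, dim=1)
    # one score per image: the confidence of the predicted class
    return np.max(probs.cpu().numpy(), axis=1)
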
def tune_odin_hyperparams():
    print('Tuning hyper-parameters...')
    stypes = ['ODIN']

    save_dir = os.path.join('output/odin_hyperparams/', args.in_dataset,
                            args.name, 'tmp')

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if args.in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        trainset = torchvision.datasets.CIFAR10('./datasets/cifar10',
                                                train=True,
                                                download=True,
                                                transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True)

        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10',
                                               train=False,
                                               download=True,
                                               transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)

        num_classes = 10

    elif args.in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize(
            (125.3 / 255, 123.0 / 255, 113.9 / 255),
            (63.0 / 255, 62.1 / 255.0, 66.7 / 255.0))
        trainset = torchvision.datasets.CIFAR100('./datasets/cifar100',
                                                 train=True,
                                                 download=True,
                                                 transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset,
                                                    batch_size=args.batch_size,
                                                    shuffle=True)

        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100',
                                                train=False,
                                                download=True,
                                                transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)

        num_classes = 100

    elif args.in_dataset == "SVHN":

        normalizer = None
        trainloaderIn = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/',
            split='train',
            transform=transforms.ToTensor(),
            download=False),
                                                    batch_size=args.batch_size,
                                                    shuffle=True)
        testloaderIn = torch.utils.data.DataLoader(svhn.SVHN(
            'datasets/svhn/',
            split='test',
            transform=transforms.ToTensor(),
            download=False),
                                                   batch_size=args.batch_size,
                                                   shuffle=True)

        args.epochs = 20
        num_classes = 10

    valloaderOut = torch.utils.data.DataLoader(
        TinyImages(transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])),
        batch_size=args.batch_size,
        shuffle=False)

    valloaderOut.dataset.offset = np.random.randint(len(valloaderOut.dataset))

    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth,
                              num_classes,
                              widen_factor=args.width,
                              normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    checkpoint = torch.load(
        "./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(
            in_dataset=args.in_dataset, name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    m = 1000
    val_in = []
    val_out = []

    cnt = 0
    for data, target in testloaderIn:
        for x in data:
            val_in.append(x.numpy())
            cnt += 1
            if cnt == m:
                break
        if cnt == m:
            break

    cnt = 0
    for data, target in valloaderOut:
        for x in data:
            val_out.append(x.numpy())
            cnt += 1
            if cnt == m:
                break
        if cnt == m:
            break

    print('Len of val in: ', len(val_in))
    print('Len of val out: ', len(val_out))

    best_fpr = 1.1
    best_magnitude = 0.0
    for magnitude in np.arange(0, 0.0041, 0.004 / 20):

        t0 = time.time()
        f1 = open(os.path.join(save_dir, "confidence_ODIN_In.txt"), 'w')
        f2 = open(os.path.join(save_dir, "confidence_ODIN_Out.txt"), 'w')
        ########################################In-distribution###########################################
        print("Processing in-distribution images")

        count = 0
        for i in range(int(m / args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_in[i * args.batch_size:min((i + 1) * args.batch_size, m)])
            images = images.cuda()
            # if j<1000: continue
            batch_size = images.shape[0]

            scores = get_odin_score(images,
                                    model,
                                    temper=1000,
                                    noiseMagnitude1=magnitude)

            for k in range(batch_size):
                f1.write("{}\n".format(scores[k]))

            count += batch_size
            # print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

    ###################################Out-of-Distributions#####################################
        t0 = time.time()
        print("Processing out-of-distribution images")
        count = 0

        for i in range(int(m / args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(
                val_out[i * args.batch_size:min((i + 1) * args.batch_size, m)])
            images = images.cuda()
            # if j<1000: continue
            batch_size = images.shape[0]

            scores = get_odin_score(images,
                                    model,
                                    temper=1000,
                                    noiseMagnitude1=magnitude)

            for k in range(batch_size):
                f2.write("{}\n".format(scores[k]))

            count += batch_size
            # print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

        f1.close()
        f2.close()

        results = metric(save_dir, stypes)
        print_results(results, stypes)
        fpr = results['ODIN']['FPR']
        if fpr < best_fpr:
            best_fpr = fpr
            best_magnitude = magnitude

    return best_magnitude
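
get_odin_score is also assumed. A minimal sketch of the standard ODIN scoring it presumably follows: temperature-scale the logits, nudge the input against the gradient of the scaled cross-entropy, and re-score (the per-channel gradient normalization of the original ODIN code is omitted here):

def get_odin_score(inputs, model, temper, noiseMagnitude1):
    inputs = inputs.clone().requires_grad_(True)
    outputs = model(inputs)
    labels = outputs.argmax(dim=1)
    loss = F.cross_entropy(outputs / temper, labels)
    loss.backward()
    # small perturbation that increases the top-class score
    perturbed = torch.clamp(inputs - noiseMagnitude1 * inputs.grad.sign(), 0., 1.)
    with torch.no_grad():
        probs = F.softmax(model(perturbed) / temper, dim=1)
    return np.max(probs.cpu().numpy(), axis=1)
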
    def __init__(self):
        super(Solver, self).__init__()
        global numberofclass

        #define the network
        if args.net_type == 'resnet':
            self.model = RN.ResNet(dataset=args.dataset,
                                   depth=args.depth,
                                   num_classes=numberofclass,
                                   bottleneck=args.bottleneck)

        elif args.net_type == 'pyramidnet':
            self.model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha,
                                         numberofclass, args.bottleneck)

        elif args.net_type == 'wideresnet':
            self.model = WR.WideResNet(depth=args.depth,
                                       num_classes=numberofclass,
                                       widen_factor=args.width)

        elif args.net_type == 'vggnet':
            self.model = VGG.vgg16(num_classes=numberofclass)

        elif args.net_type == 'mobilenet':
            self.model = MN.mobile_half(num_classes=numberofclass)

        elif args.net_type == 'shufflenet':
            self.model = SN.ShuffleV2(num_classes=numberofclass)

        elif args.net_type == 'densenet':
            self.model = DN.densenet_cifar(num_classes=numberofclass)

        elif args.net_type == 'resnext-2':
            self.model = ResNeXt29_2x64d(num_classes=numberofclass)
        elif args.net_type == 'resnext-4':
            self.model = ResNeXt29_4x64d(num_classes=numberofclass)
        elif args.net_type == 'resnext-32':
            self.model = ResNeXt29_32x4d(num_classes=numberofclass)

        elif args.net_type == 'imagenetresnet18':
            self.model = multi_resnet18_kd(num_classes=numberofclass)
        elif args.net_type == 'imagenetresnet34':
            self.model = multi_resnet34_kd(num_classes=numberofclass)
        elif args.net_type == 'imagenetresnet50':
            self.model = multi_resnet50_kd(num_classes=numberofclass)
        elif args.net_type == 'imagenetresnet101':
            self.model = multi_resnet101_kd(num_classes=numberofclass)
        elif args.net_type == 'imagenetresnet152':
            self.model = multi_resnet152_kd(num_classes=numberofclass)
        else:
            raise Exception('unknown network architecture: {}'.format(
                args.net_type))

        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay,
                                         nesterov=True)
        self.loss_lams = torch.zeros(numberofclass,
                                     numberofclass,
                                     dtype=torch.float32).cuda()
        self.loss_lams.requires_grad = False
        #define the loss function
        if args.method == 'ce':
            self.criterion = nn.CrossEntropyLoss()
        elif args.method == 'sce':
            if args.dataset == 'cifar10':
                self.criterion = SCELoss(alpha=0.1,
                                         beta=1.0,
                                         num_classes=numberofclass)
            else:
                self.criterion = SCELoss(alpha=6.0,
                                         beta=0.1,
                                         num_classes=numberofclass)
        elif args.method == 'ls':
            self.criterion = label_smooth(num_classes=numberofclass)
        elif args.method == 'gce':
            self.criterion = generalized_cross_entropy(
                num_classes=numberofclass)
        elif args.method == 'jo':
            self.criterion = joint_optimization(num_classes=numberofclass)
        elif args.method == 'bootsoft':
            self.criterion = boot_soft(num_classes=numberofclass)
        elif args.method == 'boothard':
            self.criterion = boot_hard(num_classes=numberofclass)
        elif args.method == 'forward':
            self.criterion = Forward(num_classes=numberofclass)
        elif args.method == 'backward':
            self.criterion = Backward(num_classes=numberofclass)
        elif args.method == 'disturb':
            self.criterion = DisturbLabel(num_classes=numberofclass)
        elif args.method == 'ols':
            self.criterion = nn.CrossEntropyLoss()
        self.criterion = self.criterion.cuda()
Example #15
def main():
    if args.tensorboard: configure("runs/%s"%(args.name))

    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        ])

    kwargs = {'num_workers': 1, 'pin_memory': True}

    if args.in_dataset == "CIFAR-10":
        # Data loading code
        normalizer = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./datasets/cifar10', train=True, download=True,
                             transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./datasets/cifar10', train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        lr_schedule=[50, 75, 90]
        pool_size = args.pool_size
        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        # Data loading code
        normalizer = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x/255.0 for x in [63.0, 62.1, 66.7]])
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./datasets/cifar100', train=True, download=True,
                             transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./datasets/cifar100', train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        lr_schedule=[50, 75, 90]
        pool_size = args.pool_size
        num_classes = 100
    elif args.in_dataset == "SVHN":
        # Data loading code
        normalizer = None
        train_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='train',
                                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test',
                                  transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=False, **kwargs)

        args.epochs = 20
        args.save_epoch = 2
        lr_schedule=[10, 15, 18]
        pool_size = int(len(train_loader.dataset) * 8 / args.ood_batch_size) + 1
        num_classes = 10

    ood_dataset_size = len(train_loader.dataset) * 2

    print('OOD Dataset Size: ', ood_dataset_size)

    if args.auxiliary_dataset == '80m_tiny_images':
        ood_loader = torch.utils.data.DataLoader(
            TinyImages(transform=transforms.Compose(
                [transforms.ToTensor(), transforms.ToPILImage(), transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(), transforms.ToTensor()])),
                batch_size=args.ood_batch_size, shuffle=False, **kwargs)
    elif args.auxiliary_dataset == 'imagenet':
        ood_loader = torch.utils.data.DataLoader(
            ImageNet(transform=transforms.Compose(
                [transforms.ToTensor(), transforms.ToPILImage(), transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(), transforms.ToTensor()])),
                batch_size=args.ood_batch_size, shuffle=False, **kwargs)


    # create model
    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes + 1, args.growth, reduction=args.reduce,
                             bottleneck=args.bottleneck, dropRate=args.droprate, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes + 1, widen_factor=args.width, dropRate=args.droprate, normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            assert False, "=> no checkpoint found at '{}'".format(args.resume)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    model = model.cuda()

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)

        # mine a subset of auxiliary outliers for this epoch (selection controlled by args.quantile)
        selected_ood_loader = select_ood(ood_loader, model, args.batch_size * 2, num_classes, pool_size, ood_dataset_size, args.quantile)

        # train for one epoch on in-distribution data plus the selected outliers
        train_ntom(train_loader, selected_ood_loader, model, criterion, num_classes, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, num_classes)

        # remember best prec@1 and save checkpoint
        if (epoch + 1) % args.save_epoch == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, epoch + 1)
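The loop above calls adjust_learning_rate and select_ood, which are defined elsewhere in the original script. As a rough illustration only, a step-decay schedule with the same call signature as adjust_learning_rate (a hypothetical sketch assuming a 10x drop at each milestone in lr_schedule, not necessarily the authors' exact rule) could look like:

def adjust_learning_rate(optimizer, epoch, lr_schedule):
    # Hypothetical step decay: divide the base learning rate by 10 for every
    # milestone epoch that has already been reached.
    lr = args.lr * (0.1 ** sum(epoch >= milestone for milestone in lr_schedule))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr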
Example #16
def create_model():
    model = nn.DataParallel(models.WideResNet(num_classes=num_classes))
    if use_cuda:
        model = model.cuda()
    return model
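A hypothetical usage of this helper (not from the original code; the values below are placeholders, and torch and models are assumed to be imported as in the surrounding examples) might be:

use_cuda = torch.cuda.is_available()
num_classes = 10  # e.g. CIFAR-10

model = create_model()      # main model
ema_model = create_model()  # e.g. a second copy used as an EMA teacher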
Example #17
def tune_mahalanobis_hyperparams():

    print('Tuning hyper-parameters...')
    stypes = ['mahalanobis']

    save_dir = os.path.join('output/mahalanobis_hyperparams/', args.in_dataset, args.name, 'tmp')

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if args.in_dataset == "CIFAR-10":
        normalizer = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))

        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        trainset = torchvision.datasets.CIFAR10('./datasets/cifar10', train=True, download=True, transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        testset = torchvision.datasets.CIFAR10(root='./datasets/cifar10', train=False, download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        num_classes = 10
    elif args.in_dataset == "CIFAR-100":
        normalizer = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))

        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        trainset = torchvision.datasets.CIFAR100('./datasets/cifar100', train=True, download=True, transform=transform)
        trainloaderIn = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        testset = torchvision.datasets.CIFAR100(root='./datasets/cifar100', train=False, download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        num_classes = 100

    elif args.in_dataset == "SVHN":

        normalizer = None
        trainloaderIn = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='train',
                                      transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True)
        testloaderIn = torch.utils.data.DataLoader(
            svhn.SVHN('datasets/svhn/', split='test',
                                  transform=transforms.ToTensor(), download=False),
            batch_size=args.batch_size, shuffle=True)

        args.epochs = 20
        num_classes = 10

    if args.model_arch == 'densenet':
        model = dn.DenseNet3(args.layers, num_classes, normalizer=normalizer)
    elif args.model_arch == 'wideresnet':
        model = wn.WideResNet(args.depth, num_classes, widen_factor=args.width, normalizer=normalizer)
    else:
        assert False, 'Not supported model arch: {}'.format(args.model_arch)

    checkpoint = torch.load("./checkpoints/{in_dataset}/{name}/checkpoint_{epochs}.pth.tar".format(in_dataset=args.in_dataset, name=args.name, epochs=args.epochs))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    model.cuda()

    # set information about feature extraction
    temp_x = torch.rand(2,3,32,32)
    temp_x = Variable(temp_x).cuda()
    temp_list = model.feature_list(temp_x)[1]
    num_output = len(temp_list)
    feature_list = np.empty(num_output)
    count = 0
    for out in temp_list:
        feature_list[count] = out.size(1)
        count += 1

    print('get sample mean and covariance')
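    # sample_estimator presumably returns the class-conditional feature means and
    # a shared precision matrix for each intermediate layer, as in the standard
    # Mahalanobis-detector setup.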
    sample_mean, precision = sample_estimator(model, num_classes, feature_list, trainloaderIn)

    print('train logistic regression model')
    m = 500
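    # The first m in-distribution test images form the regressor's training split;
    # the next m are held out for validating each candidate magnitude below.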

    train_in = []
    train_in_label = []
    train_out = []

    val_in = []
    val_in_label = []
    val_out = []

    cnt = 0
    for data, target in testloaderIn:
        data = data.numpy()
        target = target.numpy()
        for x, y in zip(data, target):
            cnt += 1
            if cnt <= m:
                train_in.append(x)
                train_in_label.append(y)
            elif cnt <= 2*m:
                val_in.append(x)
                val_in_label.append(y)

            if cnt == 2*m:
                break
        if cnt == 2*m:
            break

    print('In', len(train_in), len(val_in))

    criterion = nn.CrossEntropyLoss().cuda()
    adv_noise = 0.05
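    # The two loops below craft FGSM-style perturbations of the in-distribution
    # images (a step of size adv_noise along the sign of the input gradient) and
    # use them as surrogate "out" samples for fitting and validating the regressor.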

    for i in range(int(m/args.batch_size) + 1):
        if i*args.batch_size >= m:
            break
        data = torch.tensor(train_in[i*args.batch_size:min((i+1)*args.batch_size, m)])
        target = torch.tensor(train_in_label[i*args.batch_size:min((i+1)*args.batch_size, m)])
        data = data.cuda()
        target = target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)

        model.zero_grad()
        inputs = Variable(data.data, requires_grad=True).cuda()
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()

        gradient = torch.ge(inputs.grad.data, 0)
        gradient = (gradient.float()-0.5)*2

        adv_data = torch.add(input=inputs.data, other=gradient, alpha=adv_noise)
        adv_data = torch.clamp(adv_data, 0.0, 1.0)

        train_out.extend(adv_data.cpu().numpy())

    for i in range(int(m/args.batch_size) + 1):
        if i*args.batch_size >= m:
            break
        data = torch.tensor(val_in[i*args.batch_size:min((i+1)*args.batch_size, m)])
        target = torch.tensor(val_in_label[i*args.batch_size:min((i+1)*args.batch_size, m)])
        data = data.cuda()
        target = target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)

        model.zero_grad()
        inputs = Variable(data.data, requires_grad=True).cuda()
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()

        gradient = torch.ge(inputs.grad.data, 0)
        gradient = (gradient.float()-0.5)*2

        adv_data = torch.add(input=inputs.data, other=gradient, alpha=adv_noise)
        adv_data = torch.clamp(adv_data, 0.0, 1.0)

        val_out.extend(adv_data.cpu().numpy())

    print('Out', len(train_out), len(val_out))

    train_lr_data = []
    train_lr_label = []
    train_lr_data.extend(train_in)
    train_lr_label.extend(np.zeros(m))
    train_lr_data.extend(train_out)
    train_lr_label.extend(np.ones(m))
    train_lr_data = torch.tensor(train_lr_data)
    train_lr_label = torch.tensor(train_lr_label)

    best_fpr = 1.1
    best_magnitude = 0.0
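    # Grid search over the input pre-processing magnitude: for each candidate, fit
    # a logistic regressor on layer-wise Mahalanobis scores (in-distribution vs.
    # perturbed) and keep the magnitude with the lowest FPR on the held-out split.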

    for magnitude in [0.0, 0.01, 0.005, 0.002, 0.0014, 0.001, 0.0005]:
        train_lr_Mahalanobis = []
        total = 0
        for data_index in range(int(np.floor(train_lr_data.size(0) / args.batch_size))):
            data = train_lr_data[total : total + args.batch_size].cuda()
            total += args.batch_size
            Mahalanobis_scores = get_Mahalanobis_score(data, model, num_classes, sample_mean, precision, num_output, magnitude)
            train_lr_Mahalanobis.extend(Mahalanobis_scores)

        train_lr_Mahalanobis = np.asarray(train_lr_Mahalanobis, dtype=np.float32)
        regressor = LogisticRegressionCV(n_jobs=-1).fit(train_lr_Mahalanobis, train_lr_label)

        print('Logistic Regressor params:', regressor.coef_, regressor.intercept_)

        t0 = time.time()
        f1 = open(os.path.join(save_dir, "confidence_mahalanobis_In.txt"), 'w')
        f2 = open(os.path.join(save_dir, "confidence_mahalanobis_Out.txt"), 'w')

        ######################################## In-distribution ###########################################
        print("Processing in-distribution images")

        count = 0
        for i in range(int(m/args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(val_in[i * args.batch_size : min((i+1) * args.batch_size, m)]).cuda()
            # if j<1000: continue
            batch_size = images.shape[0]
            Mahalanobis_scores = get_Mahalanobis_score(images, model, num_classes, sample_mean, precision, num_output, magnitude)
            confidence_scores = regressor.predict_proba(Mahalanobis_scores)[:, 1]

            for k in range(batch_size):
                f1.write("{}\n".format(-confidence_scores[k]))

            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

        ################################### Out-of-distribution #####################################
        t0 = time.time()
        print("Processing out-of-distribution images")
        count = 0

        for i in range(int(m/args.batch_size) + 1):
            if i * args.batch_size >= m:
                break
            images = torch.tensor(val_out[i * args.batch_size : min((i+1) * args.batch_size, m)]).cuda()
            # if j<1000: continue
            batch_size = images.shape[0]

            Mahalanobis_scores = get_Mahalanobis_score(images, model, num_classes, sample_mean, precision, num_output, magnitude)

            confidence_scores = regressor.predict_proba(Mahalanobis_scores)[:, 1]

            for k in range(batch_size):
                f2.write("{}\n".format(-confidence_scores[k]))

            count += batch_size
            print("{:4}/{:4} images processed, {:.1f} seconds used.".format(count, m, time.time()-t0))
            t0 = time.time()

        f1.close()
        f2.close()

        results = metric(save_dir, stypes)
        print_results(results, stypes)
        fpr = results['mahalanobis']['FPR']
        if fpr < best_fpr:
            best_fpr = fpr
            best_magnitude = magnitude
            best_regressor = regressor

    print('Best Logistic Regressor params:', best_regressor.coef_, best_regressor.intercept_)
    print('Best magnitude', best_magnitude)

    return sample_mean, precision, best_regressor, best_magnitude
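A hypothetical caller (not part of the original code) could use the tuned artifacts at test time roughly as follows; test_images, model, num_classes and num_output are assumed to be available from the surrounding script:

# Sketch only: score a batch of test images with the tuned detector.
sample_mean, precision, regressor, magnitude = tune_mahalanobis_hyperparams()
scores = get_Mahalanobis_score(test_images, model, num_classes,
                               sample_mean, precision, num_output, magnitude)
ood_prob = regressor.predict_proba(scores)[:, 1]  # probability of being "out"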