Example #1
0
def main():
    """Train a WideResNet with a 4-way rotation head on CIFAR-10.

    Configuration comes from ``arg_parser()``; ``args.method`` must be either
    'rot' or 'msp' (anything else raises ``ValueError``). After each epoch the
    train loss, test loss and test accuracy are printed and the model weights
    are written to ``./trained_model_<method>.pth``, overwriting the file from
    the previous epoch.
    """
    args = arg_parser()
    set_seed(args.seed)

    # CIFAR-10 train/test splits wrapped in RotDataset.
    raw_train = datasets.CIFAR10('./data/', train=True, download=True)
    raw_test = datasets.CIFAR10('./data/', train=False, download=True)
    rot_train = RotDataset(raw_train, train_mode=True)
    rot_test = RotDataset(raw_test, train_mode=False)

    # Only the two supported methods get a training loader.
    if args.method not in ('rot', 'msp'):
        raise ValueError(args.method)
    train_loader = dataloader(rot_train,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              pin_memory=True)
    test_loader = dataloader(rot_test,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers,
                             pin_memory=True)

    # Backbone with 10 CIFAR classes plus an auxiliary rotation head.
    model = WideResNet(args.layers, 10, args.widen_factor,
                       dropRate=args.droprate)
    model.rot_head = nn.Linear(128, 4)
    model = model.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.decay,
                                nesterov=True)

    epoch = 1
    while epoch <= args.epochs:
        train_loss = train(args, epoch, model, train_loader, optimizer,
                           lr_scheduler=None)
        test_loss, test_acc = test(args, model, test_loader)

        summary = 'epoch:{}, train_loss:{}, test_loss:{}, test_acc:{}'.format(
            epoch,
            round(train_loss.item(), 4),
            round(test_loss.item(), 4),
            round(test_acc, 4))
        print(summary)

        save_path = './trained_model_{}.pth'.format(args.method)
        torch.save(model.state_dict(), save_path)
        epoch += 1
Example #2
0
def main():
    """Prepare CIFAR-10 / OOD test loaders and a trained rotation model for OOD evaluation.

    The OOD test set is CIFAR-100 or SVHN, selected by ``args.ood_dataset``.
    Model weights are loaded from ``./models/trained_model_<method>.pth``.
    The scoring / AUROC step is not implemented yet (see the TODO at the end).

    Raises:
        ValueError: if ``args.ood_dataset`` is not 'cifar100' or 'svhn'.
    """
    # arg parser
    args = arg_parser()

    # set seed
    set_seed(args.seed)

    # dataset
    id_testdata = datasets.CIFAR10('./data/', train=False, download=True)
    id_testdata = RotDataset(id_testdata, train_mode=False)

    if args.ood_dataset == 'cifar100':
        ood_testdata = datasets.CIFAR100('./data/', train=False, download=True)
    elif args.ood_dataset == 'svhn':
        ood_testdata = datasets.SVHN('./data/', split='test', download=True)
    else:
        raise ValueError(args.ood_dataset)
    ood_testdata = RotDataset(ood_testdata, train_mode=False)

    # data loader
    id_test_loader = dataloader(id_testdata, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)
    ood_test_loader = dataloader(ood_testdata, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)

    # load model
    num_classes = 10
    model = WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate)
    model.rot_head = nn.Linear(128, 4)
    model = model.cuda()
    model.load_state_dict(torch.load('./models/trained_model_{}.pth'.format(args.method)))

    # TODO: score id_test_loader and ood_test_loader with args.method and
    # report the AUROC (not implemented yet).
Example #3
0
def main():
    """Compute per-sample Jacobians of the penultimate features w.r.t. CIFAR-10 test inputs.

    Builds a WideResNet with x/y-translation and rotation heads and loads the
    weights from ``args.model``, stripping any ``'module.'`` prefix. For every
    test image it computes d pen / d x, where ``pen`` is the second output of
    ``net``.

    The stacked Jacobians, of shape ``(N, 3, 32, 32, dims)``, are saved to
    ``Jacobian_rot.pt``. The rank of their mean, viewed as a
    ``(3*32*32, dims)`` matrix, is printed.

    Raises:
        Exception: if ``args.model`` is not an existing file.
    """
    # Load the model
    num_class = 10
    net = WideResNet(args.layers, num_class, args.widen_factor, dropRate=args.droprate)

    net.x_trans_head = nn.Linear(128, 3)
    net.y_trans_head = nn.Linear(128, 3)
    net.rot_head = nn.Linear(128, 4)

    if os.path.isfile(args.model):
        # Strip the 'module.' prefix that nn.DataParallel adds to parameter names.
        net.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(args.model).items()})
    else:
        raise Exception("Cannot find {0}".format(args.model))

    # Disable dropout and use running BatchNorm statistics. In train mode,
    # BatchNorm's batch statistics couple samples in the batch, so a sample's
    # "Jacobian" would include contributions from the other samples.
    net.eval()

    in_data = PerturbDataset(dset.CIFAR10('/home/jiuhai.chen/data', train=False, download=True), train_mode=False)

    train_loader_in = torch.utils.data.DataLoader(
        in_data,
        batch_size=args.test_bs,
        shuffle=False,
        num_workers=args.prefetch,
        pin_memory=False,
    )

    Jacobian = []
    # Only the first item (the untransformed image) of each sample is used.
    for x_tf_0, *_ in tqdm(train_loader_in):
        batch_size = x_tf_0.shape[0]
        x_tf_0 = x_tf_0.requires_grad_(True)
        _, pen = net(x_tf_0)
        dims = pen.shape[1]
        Jacobian_batch = torch.zeros(batch_size, 3, 32, 32, dims)

        for i in range(dims):
            # Select feature i of every sample as the output to differentiate.
            grad_tensor = torch.zeros_like(pen)
            grad_tensor[:, i] = 1
            # autograd.grad returns a fresh gradient on every call. Using
            # pen.backward() would accumulate into x_tf_0.grad, so column i
            # would hold the sum of columns 0..i.
            (grad_x,) = torch.autograd.grad(pen, x_tf_0,
                                            grad_outputs=grad_tensor,
                                            retain_graph=True)
            Jacobian_batch[..., i] = grad_x

        Jacobian.append(Jacobian_batch)

    Jacobian = torch.cat(Jacobian, dim=0)
    torch.save(Jacobian, 'Jacobian_rot.pt')
    Jacobian = Jacobian.numpy()
    Jacobian_mean = np.mean(Jacobian, axis=0)
    # matrix_rank treats a >2-D array as a stack of matrices over its last two
    # axes. Flatten the input dimensions to get the rank of the
    # (3*32*32, dims) Jacobian itself.
    print(np.linalg.matrix_rank(Jacobian_mean.reshape(-1, Jacobian_mean.shape[-1])))
Example #4
0
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.prefetch,
                                           pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=args.test_bs,
                                          shuffle=False,
                                          num_workers=args.prefetch,
                                          pin_memory=True)

# Create model
if args.model == 'allconv':
    net = AllConvNet(num_classes)
else:
    net = WideResNet(args.layers,
                     num_classes,
                     args.widen_factor,
                     dropRate=args.droprate)

# Auxiliary 4-way rotation-prediction head.
net.rot_pred = nn.Linear(128, 4)

start_epoch = 0

# Restore model if desired: scan epochs 999 down to 0 and load the newest
# checkpoint found in args.load, resuming from the epoch after it.
if args.load != '':
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(
            args.load,
            args.dataset + args.model + '_baseline_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            net.load_state_dict(torch.load(model_name))
            print('Model restored! Epoch:', i)
            start_epoch = i + 1
            # Stop at the newest checkpoint. Without this, every older
            # checkpoint would be loaded in turn, ending on the oldest one.
            break
def main():
    """Score CIFAR-10, CIFAR-100 and Gaussian-noise images and report an AUROC.

    The classifier ``fc`` parameters are taken from the checkpoint at
    ``args.modelload`` and passed to ``WideResNet_without_fc``. The network
    weights are then loaded from ``args.model``. Anomaly scores are computed
    for three sets:
      * CIFAR-10 test (in-distribution),
      * CIFAR-100 test (out-of-distribution),
      * synthetic Gaussian-noise images.

    The mean score of each set is printed. The reported AUROC compares
    in-distribution scores (label 0) with Gaussian-noise scores (label 1).

    Raises:
        Exception: if ``args.modelload`` or ``args.model`` is not a file.
        NotImplementedError: if ``args.architecture`` is not 'wrn'.
        RuntimeError: if ``args.test_time_train`` is set without
            ``args.vanilla`` (unsupported).
    """
    # Load the pretrained model only to extract its final classifier layer.
    net_trained = WideResNet(args.layers,
                             10,
                             args.widen_factor,
                             dropRate=args.droprate)
    net_trained.x_trans_head = nn.Linear(128, 3)
    net_trained.y_trans_head = nn.Linear(128, 3)
    net_trained.rot_head = nn.Linear(128, 4)

    if os.path.isfile(args.modelload):
        # Strip the 'module.' prefix that nn.DataParallel adds to parameter names.
        net_trained.load_state_dict({
            k.replace('module.', ''): v
            for k, v in torch.load(args.modelload).items()
        })
    else:
        raise Exception("Cannot find {0}".format(args.modelload))

    # Classifier weight is passed on transposed (dims 0 and 1 swapped).
    fc_weight = net_trained.state_dict()['fc.weight']
    fc_weight = torch.transpose(fc_weight, 0, 1).cuda()
    fc_bias = net_trained.state_dict()['fc.bias'].cuda()

    del net_trained

    num_classes = 10
    # Load the model
    if args.architecture == 'wrn':
        net = WideResNet_without_fc(args.layers,
                                    fc_weight,
                                    fc_bias,
                                    num_classes,
                                    args.widen_factor,
                                    dropRate=args.droprate)
    else:
        raise NotImplementedError()

    if not args.vanilla:
        net.x_trans_head = nn.Linear(128, 3)
        net.y_trans_head = nn.Linear(128, 3)
        net.rot_head = nn.Linear(128, 4)

    if os.path.isfile(args.model):
        net.load_state_dict({
            k.replace('module.', ''): v
            for k, v in torch.load(args.model).items()
        })
    else:
        raise Exception("Cannot find {0}".format(args.model))

    in_data = PerturbDataset(dset.CIFAR10('/home/jiuhai.chen/data',
                                          train=False,
                                          download=True),
                             train_mode=False)
    out_data = PerturbDataset(dset.CIFAR100('/home/jiuhai.chen/data/cifar100',
                                            train=False,
                                            download=True),
                              train_mode=False)
    # out_data = PerturbDataset(dset.SVHN('/home/jiuhai.chen/data/svhn', split='test', transform=None, target_transform=None, download=False), train_mode=False)

    in_loader = torch.utils.data.DataLoader(in_data,
                                            batch_size=args.test_bs,
                                            shuffle=False,
                                            num_workers=args.prefetch,
                                            pin_memory=False)

    out_loader = torch.utils.data.DataLoader(out_data,
                                             batch_size=args.test_bs,
                                             shuffle=False,
                                             num_workers=args.prefetch,
                                             pin_memory=False)

    if args.vanilla:
        anomaly_func = get_anomaly_scores_vanilla_msp
    elif args.test_time_train:
        raise RuntimeError("--test-time-train doesn't work very well.")
    else:
        anomaly_func = get_anomaly_scores

    print("Getting anomaly scores for the in_dist set")
    in_probs = anomaly_func(net, in_loader)

    print("Getting anomaly scores for the out_dist set")
    out_probs = anomaly_func(net, out_loader)

    print("Getting anomaly scores for Gaussian data")
    gauss_num_examples = 10000
    dummy_targets = torch.ones(gauss_num_examples)
    # Pixel values ~ N(0.5, 0.5^2), clipped to [0, 1], laid out as (N, 32, 32, 3).
    gauss_data = torch.from_numpy(
        np.float32(
            np.clip(
                np.random.normal(size=(gauss_num_examples, 32, 32, 3),
                                 scale=0.5,
                                 loc=0.5), 0, 1)))
    gauss_data = PerturbDatasetCustom(torch.utils.data.TensorDataset(
        gauss_data, dummy_targets),
                                      train_mode=False)
    gauss_loader = torch.utils.data.DataLoader(gauss_data,
                                               batch_size=args.test_bs,
                                               shuffle=True,
                                               num_workers=args.prefetch,
                                               pin_memory=False)
    gauss_probs = anomaly_func(net, gauss_loader)

    print(np.mean(out_probs), np.mean(in_probs), np.mean(gauss_probs))

    # In-distribution = 0, anomalous (Gaussian noise) = 1. Label counts follow
    # the score arrays rather than a hard-coded size. To evaluate CIFAR-100
    # instead, use out_probs in place of gauss_probs on both lines.
    ground_truths = [0] * len(in_probs) + [1] * len(gauss_probs)
    scores = np.concatenate([in_probs, gauss_probs])

    AUROC = roc_auc_score(ground_truths, scores)
    print("AUROC = {0}".format(AUROC))
def main():
    """Train a WideResNet_without_fc with auxiliary rotation/translation heads on CIFAR-10.

    The final classifier (``fc``) parameters are copied from the pretrained
    checkpoint at ``args.model``. They are passed both to
    ``WideResNet_without_fc`` and to ``train``/``test``.

    If ``args.in_class`` is not None, both splits are reduced to that single
    class. After every epoch the model is saved under ``args.save`` and the
    previous epoch's checkpoint is deleted.

    Raises:
        Exception: if ``args.model`` is not an existing file.
    """
    print("Using CIFAR 10")
    train_data_in = dset.CIFAR10('/home/jiuhai.chen/data',
                                 train=True,
                                 download=True)
    test_data = dset.CIFAR10('/home/jiuhai.chen/data',
                             train=False,
                             download=True)
    num_classes = 10

    # 0 airplane, 1 automobile, 2 bird, 3 cat, 4 deer, 5 dog, 6 frog, 7 horse, 8 ship, 9 truck
    # Compare with `is not None` so that class 0 is still honoured.
    if args.in_class is not None:
        print("Removing all but class {0} from train dataset and test dataset".
              format(args.in_class))
        for split in (train_data_in, test_data):
            keep = np.asarray(split.targets) == args.in_class
            split.data = split.data[keep]
            # Filter the labels too. Otherwise `targets` keeps its original
            # length and labels no longer line up with the remaining images.
            split.targets = np.asarray(split.targets)[keep].tolist()
    else:
        print("Keeping all classes in both train/test datasets")

    train_data_in = PerturbDataset(train_data_in, train_mode=True)
    test_data = PerturbDataset(test_data, train_mode=False)

    train_loader_in = torch.utils.data.DataLoader(train_data_in,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=args.prefetch,
                                                  pin_memory=False)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.prefetch,
                                              pin_memory=False)

    # Load the pretrained model only to extract its final classifier layer.
    net_trained = WideResNet(args.layers,
                             10,
                             args.widen_factor,
                             dropRate=args.droprate)
    net_trained.x_trans_head = nn.Linear(128, 3)
    net_trained.y_trans_head = nn.Linear(128, 3)
    net_trained.rot_head = nn.Linear(128, 4)

    if os.path.isfile(args.model):
        # Strip the 'module.' prefix that nn.DataParallel adds to parameter names.
        net_trained.load_state_dict({
            k.replace('module.', ''): v
            for k, v in torch.load(args.model).items()
        })
    else:
        raise Exception("Cannot find {0}".format(args.model))

    # Classifier weight is passed on transposed (dims 0 and 1 swapped).
    fc_weight = net_trained.state_dict()['fc.weight']
    fc_weight = torch.transpose(fc_weight, 0, 1).cuda()
    fc_bias = net_trained.state_dict()['fc.bias'].cuda()

    # The pretrained model is no longer needed.
    del net_trained

    # Create model
    net = WideResNet_without_fc(args.layers,
                                fc_weight,
                                fc_bias,
                                num_classes,
                                args.widen_factor,
                                dropRate=args.droprate)
    net.x_trans_head = nn.Linear(128, 3)
    net.y_trans_head = nn.Linear(128, 3)
    net.rot_head = nn.Linear(128, 4)

    # Get GPUs ready
    if args.ngpu > 1:
        net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    if args.ngpu > 0:
        net.cuda()
        torch.cuda.manual_seed(1)

    cudnn.benchmark = True  # fire on all cylinders

    # Set up optimization stuffs
    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    def cosine_annealing(step, total_steps, lr_max, lr_min):
        return lr_min + (lr_max - lr_min) * 0.5 * (
            1 + np.cos(step / total_steps * np.pi))

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            args.epochs * len(train_loader_in),
            1,  # since lr_lambda computes multiplicative factor
            1e-6 / args.learning_rate))

    def checkpoint_path(epoch_idx):
        """Return the checkpoint file path for epoch ``epoch_idx`` under ``args.save``."""
        return os.path.join(
            args.save,
            'layers_{0}_widenfactor_{1}_inclass_{2}_transform_trflossweight_{3}_epoch_{4}.pt'
            .format(
                str(args.layers), str(args.widen_factor), str(args.in_class),
                str(args.rot_loss_weight) + "_" + str(args.transl_loss_weight),
                str(epoch_idx)))

    # Main loop
    for epoch in range(0, args.epochs):
        state['epoch'] = epoch

        begin_epoch = time.time()

        train(net, fc_weight, fc_bias, state, train_loader_in, optimizer,
              lr_scheduler)
        test(net, fc_weight, fc_bias, state, test_loader)

        # Save model
        torch.save(net.state_dict(), checkpoint_path(epoch))

        # Let us not waste space and delete the previous model
        prev_path = checkpoint_path(epoch - 1)
        if os.path.exists(prev_path):
            os.remove(prev_path)

        # Show results
        print(
            'Epoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} | Test Accuracy {4:.3f}%'
            .format((epoch + 1), int(time.time() - begin_epoch),
                    state['train_loss'], state['test_loss'],
                    state['test_accuracy'] * 100))
Example #7
0
def main(args):
    """Evaluate OOD detection on CIFAR-10 vs. CIFAR-100/SVHN and print the AUROC.

    Loads ``./models/trained_model_<method>.pth`` into a WideResNet with a
    4-way rotation head. Every in-distribution (CIFAR-10 test) and OOD test
    image is scored, and ``roc_auc_score`` is printed with OOD as the positive
    class (higher score = more likely OOD):

      * ``'msp'``: negative maximum softmax probability.
      * ``'rot'``: ``-KL(uniform || softmax) + mean rotation cross-entropy``,
        where the mean is taken over the four rotated copies of the image.

    Raises:
        ValueError: if ``args.method`` or ``args.ood_dataset`` is unsupported.
    """
    # Validate up front. Otherwise an unknown method would leave `score`
    # undefined deep inside the scoring loop.
    if args.method not in ('msp', 'rot'):
        raise ValueError(args.method)

    # set seed
    set_seed(args.seed)

    # dataset
    id_testdata = datasets.CIFAR10('./data/', train=False, download=True)
    id_testdata = RotDataset(id_testdata, train_mode=False)

    if args.ood_dataset == 'cifar100':
        ood_testdata = datasets.CIFAR100('./data/', train=False, download=True)
    elif args.ood_dataset == 'svhn':
        ood_testdata = datasets.SVHN('./data/', split='test', download=True)
    else:
        raise ValueError(args.ood_dataset)
    ood_testdata = RotDataset(ood_testdata, train_mode=False)

    # data loader
    id_test_loader = dataloader(id_testdata,
                                batch_size=args.batch_size,
                                num_workers=args.num_workers,
                                pin_memory=True)
    ood_test_loader = dataloader(ood_testdata,
                                 batch_size=args.batch_size,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    # load model
    num_classes = 10
    model = WideResNet(args.layers,
                       num_classes,
                       args.widen_factor,
                       dropRate=args.droprate)
    model.rot_head = nn.Linear(128, 4)
    model = model.to(device)
    model.load_state_dict(
        torch.load('./models/trained_model_{}.pth'.format(args.method),
                   map_location=device))

    ## 1. calculate ood score by two methods (MSP, Rot)
    model.eval()
    scores_per_loader = []
    with torch.no_grad():  # inference only; no autograd graph needed
        for loader in (id_test_loader, ood_test_loader):
            batch_scores = []
            for x_tf_0, x_tf_90, x_tf_180, x_tf_270, _ in tqdm(loader):
                batch_size = x_tf_0.shape[0]
                # Rows are grouped by rotation: [0 deg x B, 90 x B, 180 x B, 270 x B].
                batch_x = torch.cat([x_tf_0, x_tf_90, x_tf_180, x_tf_270],
                                    0).to(device)
                logits, pen = model(batch_x)

                # Classification uses only the un-rotated copies (first B rows).
                # log_softmax avoids the -inf that softmax(...).log() gives
                # for underflowed probabilities.
                log_probs = F.log_softmax(logits[:batch_size], dim=-1)

                if args.method == 'msp':
                    score = -log_probs.exp().max(dim=-1)[0]
                else:  # 'rot'
                    batch_rot_y = torch.arange(
                        4, device=device).repeat_interleave(batch_size)
                    rot_logits = model.rot_head(pen)
                    rotation_loss = F.cross_entropy(rot_logits,
                                                    batch_rot_y,
                                                    reduction='none')
                    # Mean rotation loss over the 4 rotated copies of each image.
                    rotation_loss = rotation_loss.view(4, batch_size).mean(dim=0)
                    uniform_distribution = torch.full_like(log_probs,
                                                           1 / num_classes)
                    kl_to_uniform = F.kl_div(log_probs,
                                             uniform_distribution,
                                             reduction='none').sum(dim=-1)
                    score = -kl_to_uniform + rotation_loss

                batch_scores.append(score.float().cpu())
            scores_per_loader.append(torch.cat(batch_scores))

    id_testdata_score, ood_testdata_score = scores_per_loader

    # In-distribution = 0, OOD = 1.
    y_true = torch.cat((torch.zeros(
        len(id_testdata_score)), torch.ones(len(ood_testdata_score))), 0)
    y_score = torch.cat((id_testdata_score, ood_testdata_score), 0).float()

    ## 2. calculate AUROC by using ood scores
    print(f"dataset : {args.ood_dataset}, method : {args.method}")
    print(roc_auc_score(y_true, y_score))
Example #8
0
def main():
    """Report the AUROC of CIFAR-10 (in-distribution) vs. CIFAR-100 (OOD) anomaly scores.

    Builds a WideResNet; auxiliary translation/rotation heads are added unless
    ``args.vanilla`` is set. Weights are loaded from ``args.model``. Scores
    are also computed for synthetic Gaussian-noise images, but those are
    currently not included in the AUROC.

    Raises:
        NotImplementedError: if ``args.architecture`` is not 'wrn'.
        Exception: if ``args.model`` is not an existing file.
        RuntimeError: if ``args.test_time_train`` is set without
            ``args.vanilla`` (unsupported).
    """
    # Load the model
    if args.architecture == 'wrn':
        net = WideResNet(args.layers, 10, args.widen_factor, dropRate=args.droprate)
    else:
        raise NotImplementedError()

    if not args.vanilla:
        net.x_trans_head = nn.Linear(128, 3)
        net.y_trans_head = nn.Linear(128, 3)
        net.rot_head = nn.Linear(128, 4)

    if os.path.isfile(args.model):
        # Strip the 'module.' prefix that nn.DataParallel adds, so checkpoints
        # saved from a wrapped multi-GPU model load too. Unprefixed keys are
        # unaffected.
        net.load_state_dict({
            k.replace('module.', ''): v
            for k, v in torch.load(args.model).items()
        })
    else:
        raise Exception("Cannot find {0}".format(args.model))

    in_data = PerturbDataset(dset.CIFAR10('~/datasets/cifarpy', train=False, download=True), train_mode=False)
    out_data = PerturbDataset(dset.CIFAR100('~/datasets/cifarpy', train=False, download=True), train_mode=False)

    in_loader = torch.utils.data.DataLoader(
        in_data,
        batch_size=args.test_bs,
        shuffle=False,
        num_workers=args.prefetch,
        pin_memory=False
    )

    out_loader = torch.utils.data.DataLoader(
        out_data,
        batch_size=args.test_bs,
        shuffle=False,
        num_workers=args.prefetch,
        pin_memory=False
    )

    if args.vanilla:
        anomaly_func = get_anomaly_scores_vanilla_msp
    elif args.test_time_train:
        raise RuntimeError("--test-time-train doesn't work very well.")
    else:
        anomaly_func = get_anomaly_scores

    print("Getting anomaly scores for the in_dist set")
    in_probs = anomaly_func(net, in_loader)

    print("Getting anomaly scores for the out_dist set")
    out_probs = anomaly_func(net, out_loader)

    print("Getting anomaly scores for Gaussian data")
    gauss_num_examples = 10000
    dummy_targets = torch.ones(gauss_num_examples)
    # Pixel values ~ N(0.5, 0.5^2), clipped to [0, 1], laid out as (N, 32, 32, 3).
    gauss_data = torch.from_numpy(np.float32(np.clip(np.random.normal(size=(gauss_num_examples, 32, 32, 3), scale=0.5, loc=0.5), 0, 1)))
    gauss_data = PerturbDatasetCustom(torch.utils.data.TensorDataset(gauss_data, dummy_targets), train_mode=False)
    gauss_loader = torch.utils.data.DataLoader(gauss_data, batch_size=args.test_bs, shuffle=True, num_workers=args.prefetch, pin_memory=False)
    gauss_probs = anomaly_func(net, gauss_loader)  # currently not used in the AUROC

    # In-distribution = 0, OOD = 1. Label counts follow the score arrays
    # rather than a hard-coded dataset size.
    ground_truths = [0] * len(in_probs) + [1] * len(out_probs)
    scores = np.concatenate([in_probs, out_probs])

    AUROC = roc_auc_score(ground_truths, scores)
    print("AUROC = {0}".format(AUROC))
Example #9
0
def main():
    """Debug helper: plot CIFAR-100 (OOD) embeddings coloured by anomaly score.

    Builds and loads the model as in the other evaluation scripts. It then runs
    ``anomaly_func`` on the CIFAR-10 in-distribution set and on the CIFAR-100
    set; ``anomaly_func`` is unpacked as ``(scores, embedding, labels)``.

    The mean CIFAR-100 score is printed. A scatter plot of the first two
    embedding columns, coloured by score, is then shown.

    Raises:
        NotImplementedError: if ``args.architecture`` is not 'wrn'.
        Exception: if ``args.model`` is not an existing file.
        RuntimeError: if ``args.test_time_train`` is set without
            ``args.vanilla`` (unsupported).
    """
    # Load the model
    if args.architecture == 'wrn':
        net = WideResNet(args.layers,
                         10,
                         args.widen_factor,
                         dropRate=args.droprate)
    else:
        raise NotImplementedError()

    if not args.vanilla:
        net.x_trans_head = nn.Linear(128, 3)
        net.y_trans_head = nn.Linear(128, 3)
        net.rot_head = nn.Linear(128, 4)

    if os.path.isfile(args.model):
        # Strip the 'module.' prefix that nn.DataParallel adds to parameter names.
        net.load_state_dict({
            k.replace('module.', ''): v
            for k, v in torch.load(args.model).items()
        })
    else:
        raise Exception("Cannot find {0}".format(args.model))

    in_data = PerturbDataset(dset.CIFAR10('/home/jiuhai.chen/data',
                                          train=False,
                                          download=True),
                             train_mode=False)
    out_data = PerturbDataset(dset.CIFAR100('/home/jiuhai.chen/data/cifar100',
                                            train=False,
                                            download=True),
                              train_mode=False)
    # out_data = PerturbDataset(dset.SVHN('/home/jiuhai.chen/data/svhn', split='test', transform=None, target_transform=None, download=False), train_mode=False)

    in_loader = torch.utils.data.DataLoader(in_data,
                                            batch_size=args.test_bs,
                                            shuffle=False,
                                            num_workers=args.prefetch,
                                            pin_memory=False)

    out_loader = torch.utils.data.DataLoader(out_data,
                                             batch_size=args.test_bs,
                                             shuffle=False,
                                             num_workers=args.prefetch,
                                             pin_memory=False)

    if args.vanilla:
        anomaly_func = get_anomaly_scores_vanilla_msp
    elif args.test_time_train:
        raise RuntimeError("--test-time-train doesn't work very well.")
    else:
        anomaly_func = get_anomaly_scores

    print("Getting anomaly scores for the in_dist set")
    # In-distribution results are currently not plotted; kept for inspection.
    in_probs, in_embedded, in_y = anomaly_func(net, in_loader)

    print("Getting anomaly scores for the out_dist set")
    out_probs, out_embedded, out_y = anomaly_func(net, out_loader)
    print(np.mean(out_probs, axis=0))

    plt.scatter(out_embedded[:, 0],
                out_embedded[:, 1],
                c=out_probs,
                s=10,
                lw=0,
                cmap='RdYlGn')
    plt.colorbar()
    plt.show()