Example #1
File: test.py Project: wangkua1/fs-ood
    def evaluate(data):

        corrects = []
        for _ in tqdm(range(args['data.test_episodes'])):
            sample = load_episode(data, test_tr, args['data.test_way'],
                                  args['data.test_shot'],
                                  args['data.test_query'], device)
            corrects.append(classification_accuracy(sample, model)[0])
        acc = torch.mean(torch.cat(corrects))
        return acc.item()
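A minimal, self-contained sketch (dummy tensors, not the repo's helpers) of the aggregation pattern above: `classification_accuracy` is assumed to return a tensor of per-query correctness values, and concatenating before averaging makes every query example count equally across episodes.

import torch

# Hypothetical per-episode correctness tensors (1.0 = correct, 0.0 = wrong).
episode_corrects = [
    torch.tensor([1., 0., 1., 1.]),  # episode 1
    torch.tensor([1., 1., 0., 0.]),  # episode 2
]
# Concatenate first, then take one mean over all queries.
acc = torch.mean(torch.cat(episode_corrects))
print(acc.item())  # 0.625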
Example #2
def eval_ood_aurocs(
  ood_tensor, 
  episodic_in_data,
  tr, 
  n_way,
  n_shot,
  n_query,
  n_episodes,
  device,
  conf,
  db=False,
  out_name='',
  no_grad=True
  ):
  if ood_tensor is not None:
    N = n_way*n_query*n_episodes
    # repeat if necessary
    if len(ood_tensor) < N:
      ood_tensor = np.vstack([ood_tensor for _ in range(N//len(ood_tensor)+1)])[:N]
  metrics = defaultdict(list)
  for n in tqdm(range(n_episodes), desc='eval_ood_aurocs', dynamic_ncols=True):
    sample = load_episode(episodic_in_data, tr, n_way, n_shot, n_query, device)
    if ood_tensor is not None:
      bs = n_way*n_query
      sample['ooc_xq'] = torch.stack([tr(x) for x in ood_tensor[n*bs:(n+1)*bs]]).to(device)
    with torch.set_grad_enabled(not no_grad):
      in_score, out_score = score_batch(conf, sample)
      in_score = in_score.numpy()
      out_score = out_score.numpy()
      # if db:
      #   in_score1= conf.score(in_x).detach().cpu().numpy()
      #   out_score1=conf.score(out_x).detach().cpu().numpy()
      #   print("db --- d-score", np.sum((in_score - in_score1)**2 ), np.sum((out_score - out_score1)**2 ))
      _, auroc, _, _, fpr = show_ood_detection_results_softmax(in_score, out_score)
      metrics['aurocs'].append(auroc)
      metrics['fprs'].append(fpr)
  if db: 
    print("Avg `in_score`: ", np.mean(in_score))
    print("Avg `out_score`: ", np.mean(out_score))
    # vutils.save_image(out_x[:100], f'episodic-{out_name}.jpeg' , normalize=True, nrow=10) 
  return metrics
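`eval_ood_aurocs` needs `n_way*n_query*n_episodes` OOD images in total, so a smaller OOD pool is tiled by stacking copies and truncating. A self-contained sketch of that tiling with made-up shapes:

import numpy as np

ood_tensor = np.zeros((3, 2, 2))  # dummy pool: only 3 OOD images
N = 7                             # but the evaluation needs 7
if len(ood_tensor) < N:
    # Stack N//len + 1 copies along the first axis, then truncate to N.
    ood_tensor = np.vstack([ood_tensor for _ in range(N // len(ood_tensor) + 1)])[:N]
print(ood_tensor.shape)  # (7, 2, 2)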
Example #3
File: train.py Project: wangkua1/fs-ood
    def evaluate():

        nonlocal best_metric_value
        nonlocal patience_elapsed
        nonlocal stop
        nonlocal epoch

        corrects = []
        for _ in tqdm(range(args['data.test_episodes']),
                      desc="Epoch {:d} Val".format(epoch + 1)):
            sample = load_episode(val_data, test_tr, args['data.test_way'],
                                  args['data.test_shot'],
                                  args['data.test_query'], device)
            corrects.append(classification_accuracy(sample, model)[0])
        val_acc = torch.mean(torch.cat(corrects))
        iteration_logger.writerow({
            'global_iteration': epoch,
            'val_acc': val_acc.item()
        })
        plot_csv(iteration_logger.filename, iteration_logger.filename)

        print(f"Epoch {epoch}: Val Acc: {val_acc}")

        if val_acc > best_metric_value:
            best_metric_value = val_acc
            print("==> best model (metric = {:0.6f}), saving model...".format(
                best_metric_value))
            model.cpu()
            torch.save(model, os.path.join(args['log.exp_dir'],
                                           'best_model.pt'))
            model.to(device)
            patience_elapsed = 0

        else:
            patience_elapsed += 1
            if patience_elapsed > args['train.patience']:
                print("==> patience {:d} exceeded".format(
                    args['train.patience']))
                stop = True
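A stripped-down sketch of the early-stopping logic in `evaluate` above, with hypothetical accuracies: the patience counter resets on every improvement, and training stops once it exceeds the patience budget.

patience = 2
best, patience_elapsed, stop = 0.0, 0, False
for val_acc in [0.50, 0.55, 0.54, 0.53, 0.52]:  # made-up per-epoch values
    if val_acc > best:
        best, patience_elapsed = val_acc, 0  # improvement: reset the counter
    else:
        patience_elapsed += 1
        if patience_elapsed > patience:  # waited longer than the budget
            stop = True
            break
print(best, stop)  # 0.55 True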
Example #4
File: train.py Project: wangkua1/fs-ood
def main(args):

    device = 'cuda:0' if args['data.cuda'] else 'cpu'

    if not os.path.isdir(args['log.exp_dir']):
        os.makedirs(args['log.exp_dir'])

    # save opts
    with open(os.path.join(args['log.exp_dir'], 'args.json'), 'w') as f:
        json.dump(args, f)
        f.write('\n')

    # Logging
    iteration_fieldnames = ['global_iteration', 'val_acc']
    iteration_logger = CSVLogger(every=0,
                                 fieldnames=iteration_fieldnames,
                                 filename=os.path.join(args['log.exp_dir'],
                                                       'iteration_log.csv'))

    # Set the random seed manually for reproducibility.
    np.random.seed(args['seed'])
    torch.manual_seed(args['seed'])
    if args['data.cuda']:
        torch.cuda.manual_seed(args['seed'])

    if args['data.dataset'] == 'omniglot':
        raise NotImplementedError('omniglot is not supported here')
    elif args['data.dataset'] == 'miniimagenet':
        train_data = get_dataset('miniimagenet-train-train', args['dataroot'])
        val_data = get_dataset('miniimagenet-val', args['dataroot'])
        train_tr = get_transform(
            'cifar_augment_normalize_84'
            if args['data_augmentation'] else 'cifar_normalize')
        test_tr = get_transform('cifar_normalize')

    elif args['data.dataset'] == 'cifar100':
        train_data = get_dataset('cifar-fs-train-train')
        val_data = get_dataset('cifar-fs-val')
        train_tr = get_transform(
            'cifar_augment_normalize'
            if args['data_augmentation'] else 'cifar_normalize')
        test_tr = get_transform('cifar_normalize')
    else:
        raise ValueError(f"unknown dataset: {args['data.dataset']}")

    model = protonet.create_model(**args)

    if args['model.model_path'] != '':
        loaded = torch.load(args['model.model_path'])
        if 'Protonet' not in str(loaded.__class__):
            pretrained = ResNetClassifier(64, train_data['im_size']).to(device)
            pretrained.load_state_dict(loaded)
            model.encoder = pretrained.encoder
        else:
            model = loaded

    model = model.to(device)

    max_epoch = args['train.epochs']
    epoch = 0
    stop = False
    patience_elapsed = 0
    best_metric_value = 0.0

    def evaluate():

        nonlocal best_metric_value
        nonlocal patience_elapsed
        nonlocal stop
        nonlocal epoch

        corrects = []
        for _ in tqdm(range(args['data.test_episodes']),
                      desc="Epoch {:d} Val".format(epoch + 1)):
            sample = load_episode(val_data, test_tr, args['data.test_way'],
                                  args['data.test_shot'],
                                  args['data.test_query'], device)
            corrects.append(classification_accuracy(sample, model)[0])
        val_acc = torch.mean(torch.cat(corrects))
        iteration_logger.writerow({
            'global_iteration': epoch,
            'val_acc': val_acc.item()
        })
        plot_csv(iteration_logger.filename, iteration_logger.filename)

        print(f"Epoch {epoch}: Val Acc: {val_acc}")

        if val_acc > best_metric_value:
            best_metric_value = val_acc
            print("==> best model (metric = {:0.6f}), saving model...".format(
                best_metric_value))
            model.cpu()
            torch.save(model, os.path.join(args['log.exp_dir'],
                                           'best_model.pt'))
            model.to(device)
            patience_elapsed = 0

        else:
            patience_elapsed += 1
            if patience_elapsed > args['train.patience']:
                print("==> patience {:d} exceeded".format(
                    args['train.patience']))
                stop = True

    optim_method = getattr(optim, args['train.optim_method'])
    params = model.parameters()

    optimizer = optim_method(params,
                             lr=args['train.learning_rate'],
                             weight_decay=args['train.weight_decay'])

    scheduler = lr_scheduler.StepLR(optimizer,
                                    args['train.decay_every'],
                                    gamma=0.5)

    while epoch < max_epoch and not stop:
        evaluate()

        model.train()
        if epoch % args['ckpt_every'] == 0:
            model.cpu()
            torch.save(model,
                       os.path.join(args['log.exp_dir'], f'model_{epoch}.pt'))
            model.to(device)

        scheduler.step()

        for _ in tqdm(range(args['data.train_episodes']),
                      desc="Epoch {:d} train".format(epoch + 1)):
            sample = load_episode(train_data, train_tr, args['data.way'],
                                  args['data.shot'], args['data.query'],
                                  device)
            optimizer.zero_grad()
            loss, output = model.loss(sample)
            loss.backward()
            optimizer.step()

        epoch += 1
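The loop above pairs per-epoch `StepLR` decay with episodic updates: with `gamma=0.5` the learning rate halves every `train.decay_every` scheduler steps. A minimal sketch of the schedule in isolation (dummy parameter, hypothetical step size):

import torch
from torch import optim
from torch.optim import lr_scheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optim.SGD(params, lr=0.1)
scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
for epoch in range(6):
    optimizer.step()  # a real epoch of episodic updates would go here
    scheduler.step()
    print(epoch, optimizer.param_groups[0]['lr'])
# lr: 0.1, 0.05, 0.05, 0.025, 0.025, 0.0125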
Example #5
def main(opt):

    eval_exp_name = opt['exp_name']
    device = 'cuda:0'

    # Load data
    if opt['dataset'] == 'cifar-fs':
        train_data = get_dataset('cifar-fs-train-train', opt['dataroot'])
        val_data = get_dataset('cifar-fs-val', opt['dataroot'])
        test_data = get_dataset('cifar-fs-test', opt['dataroot'])
        tr = get_transform('cifar_resize_normalize')
        normalize = cifar_normalize
    elif opt['dataset'] == 'miniimagenet':
        train_data = get_dataset('miniimagenet-train-train', opt['dataroot'])
        val_data = get_dataset('miniimagenet-val', opt['dataroot'])
        test_data = get_dataset('miniimagenet-test', opt['dataroot'])
        tr = get_transform('cifar_resize_normalize_84')
        normalize = cifar_normalize

    np.random.seed(1234)
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)

    if opt['db']:
        ood_distributions = ['ooe', 'gaussian']
    else:
        ood_distributions = ['ooe', 'gaussian', 'svhn']
        # ood_distributions = ['ooe', 'gaussian', 'rademacher', 'texture3', 'svhn','tinyimagenet','lsun']

    ood_tensors = [('ooe', None)] + [(out_name,
                                      load_ood_data({
                                          'name': out_name,
                                          'ood_scale': 1,
                                          'n_anom': 10000,
                                      }))
                                     for out_name in ood_distributions[1:]]

    # Load trained model
    loaded = torch.load(opt['model.model_path'])
    if not isinstance(loaded, OrderedDict):
        protonet = loaded
    else:
        classifier = ResNetClassifier(64, train_data['im_size']).to(device)
        classifier.load_state_dict(loaded)
        protonet = Protonet(classifier.encoder)
    encoder = protonet.encoder
    encoder.eval()
    encoder.to(device)
    protonet.eval()
    protonet.to(device)

    # Init Confidence model
    if opt['ood_method'] == 'deep-ed-iso':
        deep_mahala_obj = DeepMahala(None,
                                     None,
                                     None,
                                     encoder,
                                     device,
                                     num_feats=encoder.depth,
                                     num_classes=train_data['n_classes'],
                                     pretrained_path="",
                                     fit=False,
                                     normalize=None)
        conf = DMConfidence(deep_mahala_obj, {
            'ls': range(encoder.depth),
            'reduction': 'max',
            'g_magnitude': 0
        }, True, 'iso').to(device)
    elif opt['ood_method'] == 'native-spp':
        conf = FSCConfidence(protonet, 'spp')
    elif opt['ood_method'] == 'oec':
        oec_opt = json.load(
            open(os.path.join(os.path.dirname(opt['oec_path']), 'args.json'),
                 'r'))
        init_sample = load_episode(train_data, tr, oec_opt['data.test_way'],
                                   oec_opt['data.test_shot'],
                                   oec_opt['data.test_query'], device)
        if oec_opt['confidence_method'] == 'oec':
            oec_conf = OECConfidence(None, protonet, init_sample, oec_opt)
        else:
            oec_conf = DeepOECConfidence(None, protonet, init_sample, oec_opt)
        oec_conf.load_state_dict(torch.load(opt['oec_path']))
        oec_conf.eval()
        oec_conf.to(device)
        conf = oec_conf

    # Turn confidence score into a threshold based classifier
    # Select threshold by "max-accuracy"
    # Select temperature by "best-calibration" in the binary problem
    # done using the meta-train set
    in_scores = []
    out_scores = []
    for n in tqdm(range(100)):
        sample = load_episode(train_data, tr, opt['data.test_way'],
                              opt['data.test_shot'], opt['data.test_query'],
                              device)
        in_score, out_score = score_batch(conf, sample)
        in_scores.append(in_score)
        out_scores.append(out_score)
    in_scores = torch.cat(in_scores)
    out_scores = torch.cat(out_scores)

    def _compute_acc(in_scores, out_scores, t):
        N = len(in_scores) + len(out_scores)
        return (torch.sum(in_scores >= t) +
                torch.sum(out_scores < t)).item() / float(N)

    best_threshold = torch.min(in_scores)
    best_acc = _compute_acc(in_scores, out_scores, best_threshold)

    for t in in_scores:
        acc = _compute_acc(in_scores, out_scores, t)
        if acc > best_acc:
            best_acc = acc
            best_threshold = t

    def _compute_confs(in_scores, out_scores, t, temp):
        in_p = torch.sigmoid((in_scores - t) / temp)
        corrects = in_p >= .5
        confs = torch.max(torch.stack([in_p, 1 - in_p]), 0)[0]
        out_p = torch.sigmoid((out_scores - t) / temp)
        corrects = torch.cat([corrects, out_p < .5])
        confs = torch.cat(
            [confs, torch.max(torch.stack([out_p, 1 - out_p]), 0)[0]])
        return confs, corrects

    def compute_eces(candidate_temps, in_scores, out_scores, best_threshold):
        eces = []
        for temp in candidate_temps:
            confs, corrects = _compute_confs(in_scores, out_scores,
                                             best_threshold, temp)
            ece = compute_ece(
                *prep_accs(confs.numpy(), corrects.numpy(), bins=20))
            eces.append(ece)
        return eces

    min_log_temp = -1
    log_interval = 2
    npts = 10

    for _ in range(opt['max_temp_select_iter']):
        print("..selecting temperature")
        candidate_temps = np.logspace(min_log_temp,
                                      min_log_temp + log_interval, npts)
        eces = compute_eces(candidate_temps, in_scores, out_scores,
                            best_threshold)
        min_idx = np.argmin(eces)
        if min_idx == 0:
            min_log_temp -= log_interval // 2
        elif min_idx == npts - 1:
            min_log_temp += log_interval // 2
        else:
            break

    best_ece = eces[min_idx]
    best_temp = candidate_temps[min_idx]

    print(
        f"Best ACC:{best_acc}, thresh:{best_threshold}, Best ECE:{best_ece}, temp:{best_temp}"
    )

    def get_95_percent_ci(std):
        """Computes the 95% confidence interval from the standard deviation."""
        return std * 1.96 / np.sqrt(opt['data.test_episodes'])

    active_supervised = defaultdict(list)
    active_augmented = defaultdict(list)
    ssl_soft = defaultdict(list)
    ssl_hard = defaultdict(list)
    # for ood_idx, curr_ood in tqdm(enumerate(all_distributions)):
    for curr_ood, ood_tensor in ood_tensors:

        in_scores = defaultdict(list)
        out_scores = defaultdict(list)
        # Compute and collect scores for all examples
        aurocs, auprs, fprs = (defaultdict(list), defaultdict(list),
                               defaultdict(list))

        for n in tqdm(range(opt['data.test_episodes'])):
            n_total_query = np.max([
                opt['data.test_query'] + opt['n_unlabeled_per_class'],
                opt['n_distractor_per_class']
            ])
            sample = load_episode(test_data, tr, opt['data.test_way'],
                                  opt['data.test_shot'], n_total_query, device)
            if curr_ood != 'ooe':
                bs = opt['data.test_way'] * opt['data.test_query']
                ridx = np.random.permutation(bs)
                sample['ooc_xq'] = torch.stack(
                    [tr(x) for x in ood_tensor[ridx]]).to(device)
                way, _, c, h, w = sample['xq'].shape
                sample['ooc_xq'] = sample['ooc_xq'].reshape(way, -1, c, h, w)
                # if curr_ood in ['gaussian', 'rademacher']:
                #   sample['ooc_xq'] *= 4

            all_xq = sample['xq'].clone()
            # Unlabelled pool
            sample['xq'] = all_xq[:, :opt['n_unlabeled_per_class']]
            # Final test queries
            sample['xq2'] = all_xq[:, opt['n_unlabeled_per_class']:
                                   opt['n_unlabeled_per_class'] +
                                   opt['data.test_query']]
            sample['ooc_xq'] = sample['ooc_xq'][:, :opt['n_distractor_per_class']]
            """
            1.  OOD classification on the 'unlabelled' set
            """
            # In vs Out
            in_score, out_score = score_batch(conf, sample)

            num_in = in_score.shape[0]
            confs, corrects = _compute_confs(in_score, out_score,
                                             best_threshold, best_temp)
            in_mask = corrects[:num_in].reshape(
                sample['xq'].size(0), sample['xq'].size(1)).float().to(device)
            out_mask = 1 - corrects[num_in:].reshape(
                sample['ooc_xq'].size(0),
                sample['ooc_xq'].size(1)).float().to(device)
            """
            2.0
            """
            budget_active = in_score.size(0)
            scores = torch.cat([in_score, out_score], -1)
            selected_inds = torch.sort(scores)[1][scores.size(0) -
                                                  budget_active:]
            selected_inds_in = selected_inds[selected_inds < in_score.size(0)]
            budget_mask = torch.zeros(in_score.size(0)).to(device)
            budget_mask.scatter_(0, selected_inds_in.to(device).long(), 1)
            budget_mask = budget_mask.reshape(
                sample['xq'].size(0), sample['xq'].size(1)).float().to(device)
            """
            2.  Add labels to the predicted unlabelled examples
            """
            # Collect the incorrectly kept OOD examples
            included_distractors = sample['ooc_xq'][out_mask.bool()]
            # Pad them to N-way multiples, and assign random labels (done simply by reshaping)
            n_way = sample['xs'].shape[0]
            im_shape = list(sample['xs'].shape[2:])
            n_res = n_way - (included_distractors.shape[0] % n_way)
            distractor_mask = torch.ones(
                [included_distractors.shape[0]]).to(device)

            zeros = torch.zeros([n_res] + im_shape).to(device)
            included_distractors = torch.cat([included_distractors, zeros])
            distractor_mask = torch.cat(
                [distractor_mask,
                 torch.zeros([n_res]).to(device)])
            # the reason we permute is to spread the padded zero across ways
            included_distractors = included_distractors.reshape(
                [-1, n_way] + im_shape).permute(1, 0, 2, 3, 4)
            distractor_mask = distractor_mask.reshape([-1,
                                                       n_way]).permute(1, 0)
            """
            2.5 SSL
            """
            # predict k-way using classifier
            n_way, n_aug_shot, n_ch, n_dim, _ = sample['xq'].shape
            lpy_dic = protonet.log_p_y(sample['xs'], sample['xq'], mask=None)
            log_p_y, target_inds = lpy_dic['log_p_y'], lpy_dic['target_inds']

            preds = log_p_y.max(-1)[1]

            def reorder(unlabelled, preds, py, make_soft=True):
                if py is not None:
                    reshaped_py = py.reshape(-1)
                n_way, n_aug_shot, n_ch, n_dim, _ = unlabelled.shape
                reshaped_unlabelled = unlabelled.reshape(
                    n_aug_shot * n_way, n_ch, n_dim, n_dim)
                reshaped_predicted_labels = preds.reshape(-1)
                unlabelled = torch.zeros(
                    (n_way, n_aug_shot * n_way, n_ch, n_dim, n_dim))
                mask = torch.zeros((n_way, n_aug_shot * n_way))
                for idx, label in enumerate(reshaped_predicted_labels):
                    unlabelled[label,
                               idx] = reshaped_unlabelled[idx]  # (n_shot, ...)
                    if make_soft:
                        mask[label, idx] = reshaped_py[idx]
                    else:
                        mask[label, idx] = 1  # (n_shot, )
                return unlabelled.to(device), mask.to(device)

            gt_in_unlabelled, gt_in_weights = reorder(sample['xq'], preds,
                                                      log_p_y.max(-1)[0].exp(),
                                                      True)
            _, in_mask_reordered = reorder(sample['xq'], preds, in_mask, True)

            # for the gt OOD ones
            lpy_dic = protonet.log_p_y(sample['xs'],
                                       sample['ooc_xq'],
                                       mask=None)
            log_p_y = lpy_dic['log_p_y']

            preds = log_p_y.max(-1)[1]
            gt_ood_unlabelled, gt_ood_weights = reorder(
                sample['ooc_xq'], preds,
                log_p_y.max(-1)[0].exp(), True)
            _, out_mask_reordered = reorder(sample['ooc_xq'], preds, out_mask,
                                            True)

            # Support + ALL unlabelled
            _ssl_soft = compute_acc(
                protonet,
                torch.cat([sample['xs'], gt_in_unlabelled, gt_ood_unlabelled],
                          1), sample['xq2'],
                torch.cat([
                    torch.ones(sample['xs'].shape[:2]).to(device),
                    gt_in_weights, gt_ood_weights
                ], 1))
            _acc_hard = compute_acc(
                protonet,
                torch.cat([sample['xs'], gt_in_unlabelled, gt_ood_unlabelled],
                          1), sample['xq2'],
                torch.cat([
                    torch.ones(sample['xs'].shape[:2]).to(device),
                    in_mask_reordered * gt_in_weights,
                    out_mask_reordered * gt_ood_weights
                ], 1))
            """
            3.  Evaluate k-way accuracy after adding examples
            """
            _active_supervised = compute_acc(protonet, sample['xs'],
                                             sample['xq2'], None)

            # Support + Budgeted unlabelled
            _active_augmented = compute_acc(
                protonet, torch.cat([sample['xs'], sample['xq']], 1),
                sample['xq2'],
                torch.cat([
                    torch.ones(sample['xs'].shape[:2]).to(device), budget_mask
                ], 1))

            ssl_soft[curr_ood].append(_ssl_soft)
            ssl_hard[curr_ood].append(_acc_hard)
            active_supervised[curr_ood].append(_active_supervised)
            active_augmented[curr_ood].append(_active_augmented)

    if not os.path.exists(opt['output_dir']):
        os.makedirs(opt['output_dir'])

    pickle.dump((ssl_soft, ssl_hard, active_supervised, active_augmented),
                open(
                    os.path.join(opt['output_dir'],
                                 f'eval_active_{eval_exp_name}.pkl'), 'wb'))

    print("===> Aggregating results")
    aggr_args = namedtuple('Arg', ('exp_dir', 'f_acq'))(
        exp_dir=opt['output_dir'], f_acq='conv4')
    aggregate_eval_active.main(aggr_args)
    print('===> Done')
    sys.exit()
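The threshold selection above sweeps every in-distribution score as a candidate cut and keeps the one that maximizes binary in-vs-out accuracy. A self-contained sketch with dummy scores (`_compute_acc` from the example is inlined as `compute_acc`):

import torch

in_scores = torch.tensor([0.875, 0.75, 0.25])  # made-up in-distribution scores
out_scores = torch.tensor([0.5, 0.125, 0.0])   # made-up OOD scores

def compute_acc(t):
    # Accuracy of the rule: predict "in-distribution" iff score >= t.
    n = len(in_scores) + len(out_scores)
    return (torch.sum(in_scores >= t) + torch.sum(out_scores < t)).item() / float(n)

best_t = torch.min(in_scores)
best_acc = compute_acc(best_t)
for t in in_scores:
    acc = compute_acc(t)
    if acc > best_acc:
        best_acc, best_t = acc, t
print(best_t.item(), best_acc)  # 0.25 0.8333...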
Example #6
def main(opt):

    device = 'cuda:0'  # assumed: `device` is defined elsewhere in the source file

    # Logging
    trace_file = os.path.join(opt['output_dir'],
                              '{}_trace.txt'.format(opt['exp_name']))

    # Load data
    if opt['dataset'] == 'cifar-fs':
        train_data = get_dataset('cifar-fs-train-train', opt['dataroot'])
        val_data = get_dataset('cifar-fs-val', opt['dataroot'])
        test_data = get_dataset('cifar-fs-test', opt['dataroot'])
        tr = get_transform('cifar_resize_normalize')
        normalize = cifar_normalize
    elif opt['dataset'] == 'miniimagenet':
        train_data = get_dataset('miniimagenet-train-train', opt['dataroot'])
        val_data = get_dataset('miniimagenet-val', opt['dataroot'])
        test_data = get_dataset('miniimagenet-test', opt['dataroot'])
        tr = get_transform('cifar_resize_normalize_84')
        normalize = cifar_normalize

    if opt['input_regularization'] == 'oe':
        reg_data = load_ood_data({
            'name': 'tinyimages',
            'ood_scale': 1,
            'n_anom': 50000,
        })

    if not opt['ooe_only']:
        if opt['db']:
            ood_distributions = ['ooe', 'gaussian']
        else:
            ood_distributions = [
                'ooe', 'gaussian', 'rademacher', 'texture3', 'svhn',
                'tinyimagenet', 'lsun'
            ]
            if opt['input_regularization'] == 'oe':
                ood_distributions.append('tinyimages')

        ood_tensors = [('ooe', None)] + [(out_name,
                                          load_ood_data({
                                              'name': out_name,
                                              'ood_scale': 1,
                                              'n_anom': 10000,
                                          }))
                                         for out_name in ood_distributions[1:]]

    # Load trained model
    loaded = torch.load(opt['model.model_path'])
    if not isinstance(loaded, OrderedDict):
        fs_model = loaded
    else:
        classifier = ResNetClassifier(64, train_data['im_size']).to(device)
        classifier.load_state_dict(loaded)
        fs_model = Protonet(classifier.encoder)
    fs_model.eval()
    fs_model = fs_model.to(device)

    # Init Confidence Methods
    if opt['confidence_method'] == 'oec':
        init_sample = load_episode(train_data, tr, opt['data.test_way'],
                                   opt['data.test_shot'],
                                   opt['data.test_query'], device)
        conf_model = OECConfidence(None, fs_model, init_sample, opt)
    elif opt['confidence_method'] == 'deep-oec':
        init_sample = load_episode(train_data, tr, opt['data.test_way'],
                                   opt['data.test_shot'],
                                   opt['data.test_query'], device)
        conf_model = DeepOECConfidence(None, fs_model, init_sample, opt)
    elif opt['confidence_method'] == 'dm-iso':
        encoder = fs_model.encoder
        deep_mahala_obj = DeepMahala(None,
                                     None,
                                     None,
                                     encoder,
                                     device,
                                     num_feats=encoder.depth,
                                     num_classes=train_data['n_classes'],
                                     pretrained_path="",
                                     fit=False,
                                     normalize=None)

        conf_model = DMConfidence(deep_mahala_obj, {
            'ls': range(encoder.depth),
            'reduction': 'max',
            'g_magnitude': .1
        }, True, 'iso')

    if opt['pretrained_oec_path']:
        conf_model.load_state_dict(torch.load(opt['pretrained_oec_path']))

    conf_model.to(device)
    print(conf_model)

    optimizer = optim.Adam(conf_model.confidence_parameters(),
                           lr=opt['lr'],
                           weight_decay=opt['wd'])
    scheduler = StepLR(optimizer,
                       step_size=opt['lrsche_step_size'],
                       gamma=opt['lrsche_gamma'])

    num_param = sum(p.numel() for p in conf_model.confidence_parameters())
    print(f"Learning Confidence, Number of Parameters -- {num_param}")

    if conf_model.pretrain_parameters() is not None:
        pretrain_optimizer = optim.Adam(conf_model.pretrain_parameters(),
                                        lr=10)
        pretrain_iter = 100

    start_idx = 0
    if opt['resume']:
        last_ckpt_path = os.path.join(opt['output_dir'], 'last_ckpt.pt')
        if os.path.exists(last_ckpt_path):
            try:
                last_ckpt = torch.load(last_ckpt_path)
                if 'conf_model' in last_ckpt:
                    conf_model = last_ckpt['conf_model']
                else:
                    sd = last_ckpt['conf_model_sd']
                    conf_model.load_state_dict(sd)
                optimizer = last_ckpt['optimizer']
                pretrain_optimizer = last_ckpt['pretrain_optimizer']
                scheduler = last_ckpt['scheduler']
                start_idx = last_ckpt['outer_idx']
                conf_model.to(device)
            except EOFError:
                print(
                    "\n\nResuming but got EOF error, starting from init..\n\n")

    wandb.run.name = opt['exp_name']
    wandb.run.save()
    # try:
    wandb.watch(conf_model)
    # except: # resuming a run
    #     pass

    # Eval and Logging
    confs = {
        opt['confidence_method']: conf_model,
    }
    if opt['confidence_method'] == 'oec':
        confs['ed'] = FSCConfidence(fs_model, 'ed')
    elif opt['confidence_method'] == 'deep-oec':
        encoder = fs_model.encoder
        deep_mahala_obj = DeepMahala(None,
                                     None,
                                     None,
                                     encoder,
                                     device,
                                     num_feats=encoder.depth,
                                     num_classes=train_data['n_classes'],
                                     pretrained_path="",
                                     fit=False,
                                     normalize=None)
        confs['dm'] = DMConfidence(deep_mahala_obj, {
            'ls': range(encoder.depth),
            'reduction': 'max',
            'g_magnitude': 0
        }, True, 'iso').to(device)
    # Temporal Ensemble for Evaluation
    if opt['n_ensemble'] > 1:
        nets = [deepcopy(conf_model) for _ in range(opt['n_ensemble'])]
        confs['mixture-' + opt['confidence_method']] = Ensemble(
            nets, 'mixture')
        confs['poe-' + opt['confidence_method']] = Ensemble(nets, 'poe')
        ensemble_update_interval = opt['eval_every_outer'] // opt['n_ensemble']

    iteration_fieldnames = ['global_iteration']
    for c in confs:
        iteration_fieldnames += [
            f'{c}_train_ooe', f'{c}_val_ooe', f'{c}_test_ooe', f'{c}_ood'
        ]
    iteration_logger = CSVLogger(every=0,
                                 fieldnames=iteration_fieldnames,
                                 filename=os.path.join(opt['output_dir'],
                                                       'iteration_log.csv'))

    best_val_ooe = 0
    PATIENCE = 5  # Number of evaluations to wait
    waited = 0

    progress_bar = tqdm(range(start_idx, opt['train_iter']))
    for outer_idx in progress_bar:
        sample = load_episode(train_data, tr, opt['data.test_way'],
                              opt['data.test_shot'], opt['data.test_query'],
                              device)

        conf_model.train()
        if opt['full_supervision']:  # sanity check
            conf_model.support(sample['xs'])
            in_score = conf_model.score(sample['xq'], detach=False).squeeze()
            out_score = conf_model.score(sample['ooc_xq'],
                                         detach=False).squeeze()
            out_scores = [out_score]
            for curr_ood, ood_tensor in ood_tensors:
                if curr_ood == 'ooe':
                    continue
                start = outer_idx % (len(ood_tensor) // 2)
                stop = min(
                    start + sample['xq'].shape[0] * sample['xq'].shape[0],
                    len(ood_tensor) // 2)
                oxq = torch.stack([tr(x)
                                   for x in ood_tensor[start:stop]]).to(device)
                o = conf_model.score(oxq, detach=False).squeeze()
                out_scores.append(o)
            #
            out_score = torch.cat(out_scores)
            in_score = in_score.repeat(len(ood_tensors))
            loss, acc = compute_loss_bce(in_score,
                                         out_score,
                                         mean_center=False)
        else:
            conf_model.support(sample['xs'])
            if opt['interpolate']:
                half_n_way = sample['xq'].shape[0] // 2
                interp = .5 * (sample['xq'][:half_n_way] +
                               sample['xq'][half_n_way:2 * half_n_way])
                sample['ooc_xq'][:half_n_way] = interp

            if opt['input_regularization'] == 'oe':
                # Reshape ooc_xq
                nw, nq, c, h, w = sample['ooc_xq'].shape
                sample['ooc_xq'] = sample['ooc_xq'].view(1, nw * nq, c, h, w)
                oe_bs = int(nw * nq * opt['input_regularization_percent'])

                start = (outer_idx * oe_bs) % len(reg_data)
                end = np.min([start + oe_bs, len(reg_data)])
                oe_batch = torch.stack([tr(x) for x in reg_data[start:end]
                                        ]).to(device)
                oe_batch = oe_batch.unsqueeze(0)
                sample['ooc_xq'][:, :oe_batch.shape[1]] = oe_batch

            if opt['in_out_1_batch']:
                inps = torch.cat([sample['xq'], sample['ooc_xq']], 1)
                scores = conf_model.score(inps, detach=False).squeeze()
                in_score, out_score = scores[:sample['xq'].shape[1]], scores[
                    sample['xq'].shape[1]:]
            else:
                in_score = conf_model.score(sample['xq'],
                                            detach=False).squeeze()
                out_score = conf_model.score(sample['ooc_xq'],
                                             detach=False).squeeze()

            loss, acc = compute_loss_bce(in_score,
                                         out_score,
                                         mean_center=False)

        if (conf_model.pretrain_parameters() is not None
                and outer_idx < pretrain_iter):
            pretrain_optimizer.zero_grad()
            loss.backward()
            pretrain_optimizer.step()
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()

        progress_bar.set_postfix(loss='{:.3e}'.format(loss),
                                 acc='{:.3e}'.format(acc))

        # Update Ensemble
        if opt['n_ensemble'] > 1 and outer_idx % ensemble_update_interval == 0:
            update_ind = (outer_idx //
                          ensemble_update_interval) % opt['n_ensemble']
            if opt['db']:
                print(f"===> Updating Ensemble: {update_ind}")
            confs['mixture-' +
                  opt['confidence_method']].nets[update_ind] = deepcopy(
                      conf_model)
            confs['poe-' +
                  opt['confidence_method']].nets[update_ind] = deepcopy(
                      conf_model)

        # AUROC eval
        if outer_idx % opt['eval_every_outer'] == 0:
            if not opt['eval_in_train']:
                conf_model.eval()

            # Eval..
            stats_dict = {'global_iteration': outer_idx}
            for conf_name, conf in confs.items():
                conf.eval()
                # OOE eval
                ooe_aurocs = {}
                for split, in_data in [('train', train_data),
                                       ('val', val_data), ('test', test_data)]:
                    auroc = np.mean(
                        eval_ood_aurocs(
                            None,
                            in_data,
                            tr,
                            opt['data.test_way'],
                            opt['data.test_shot'],
                            opt['data.test_query'],
                            opt['data.test_episodes'],
                            device,
                            conf,
                            no_grad=not opt['confidence_method'].startswith(
                                'dm'))['aurocs'])
                    ooe_aurocs[split] = auroc
                    print_str = '{}, iter: {} ({}), auroc: {:.3e}'.format(
                        conf_name, outer_idx, split, ooe_aurocs[split])
                    _print_and_log(print_str, trace_file)
                stats_dict[f'{conf_name}_train_ooe'] = ooe_aurocs['train']
                stats_dict[f'{conf_name}_val_ooe'] = ooe_aurocs['val']
                stats_dict[f'{conf_name}_test_ooe'] = ooe_aurocs['test']

                # OOD eval
                if not opt['ooe_only']:
                    aurocs = []
                    for curr_ood, ood_tensor in ood_tensors:
                        auroc = np.mean(
                            eval_ood_aurocs(
                                ood_tensor,
                                test_data,
                                tr,
                                opt['data.test_way'],
                                opt['data.test_shot'],
                                opt['data.test_query'],
                                opt['data.test_episodes'],
                                device,
                                conf,
                                no_grad=not opt['confidence_method'].
                                startswith('dm'))['aurocs'])
                        aurocs.append(auroc)

                        print_str = '{}, iter: {} ({}), auroc: {:.3e}'.format(
                            conf_name, outer_idx, curr_ood, auroc)
                        _print_and_log(print_str, trace_file)

                    mean_ood_auroc = np.mean(aurocs)
                    print_str = '{}, iter: {} (OOD_mean), auroc: {:.3e}'.format(
                        conf_name, outer_idx, mean_ood_auroc)
                    _print_and_log(print_str, trace_file)

                    stats_dict[f'{conf_name}_ood'] = mean_ood_auroc

            iteration_logger.writerow(stats_dict)
            plot_csv(iteration_logger.filename, iteration_logger.filename)
            wandb.log(stats_dict)

            if stats_dict[f'{opt["confidence_method"]}_val_ooe'] > best_val_ooe:
                best_val_ooe = stats_dict[f'{opt["confidence_method"]}_val_ooe']
                conf_model.cpu()
                torch.save(
                    conf_model.state_dict(),
                    os.path.join(opt['output_dir'],
                                 opt['exp_name'] + '_conf_best.pt'))
                conf_model.to(device)
                # Ckpt ensemble
                if opt['n_ensemble'] > 1:
                    ensemble = confs['mixture-' + opt['confidence_method']]
                    ensemble.cpu()
                    torch.save(
                        ensemble.state_dict(),
                        os.path.join(opt['output_dir'],
                                     opt['exp_name'] + '_ensemble_best.pt'))
                    ensemble.to(device)
                waited = 0
            else:
                waited += 1
                if waited >= PATIENCE:
                    print("PATIENCE exceeded...exiting")
                    sys.exit()
            # For `resume`
            conf_model.cpu()
            torch.save(
                {
                    'conf_model_sd': conf_model.state_dict(),
                    'optimizer': optimizer,
                    'pretrain_optimizer': pretrain_optimizer
                    if conf_model.pretrain_parameters() is not None else None,
                    'scheduler': scheduler,
                    'outer_idx': outer_idx,
                }, os.path.join(opt['output_dir'], 'last_ckpt.pt'))
            conf_model.to(device)
            conf_model.train()
    sys.exit()
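The temporal ensemble above is maintained by writing snapshots of `conf_model` round-robin into a fixed pool of `n_ensemble` slots, one snapshot every `ensemble_update_interval` iterations. A minimal sketch of just that bookkeeping (`Model` is a hypothetical stand-in; the repo wraps the pool in its `Ensemble` class):

from copy import deepcopy

class Model:
    def __init__(self):
        self.step = 0  # stand-in for real parameters

n_ensemble, update_interval = 3, 100
model = Model()
nets = [deepcopy(model) for _ in range(n_ensemble)]
for outer_idx in range(1000):
    model.step = outer_idx  # stand-in for one training update
    if outer_idx % update_interval == 0:
        # Cycle through the slots, so the pool holds the n_ensemble most
        # recent snapshots, spaced update_interval iterations apart.
        update_ind = (outer_idx // update_interval) % n_ensemble
        nets[update_ind] = deepcopy(model)
print([net.step for net in nets])  # [900, 700, 800]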
Example #7
def main():
    
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--dm_path', type=str, default='')
    parser.add_argument('--oec_path', type=str, default='')
    parser.add_argument('--episodic_ood_eval', type=int, default=0)
    parser.add_argument('--episodic_in_distr', type=str, default='meta-test', choices=['meta-test', 'meta-train'])
    # DM 
    parser.add_argument('--dm_g_magnitude', type=float, default=0)
    parser.add_argument('--dm_ls', type=str, default='-')
    parser.add_argument('--db', type=int, default=0)
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--n_episodes', type=int, default=100)
    parser.add_argument('--n_ways', type=int, default=5)
    parser.add_argument('--n_shots', type=int, default=5)
    # Required
    parser.add_argument('--dataroot', required=True)
    parser.add_argument('--output_dir', required=True)
    parser.add_argument('--dataset', required=True, choices=['mnist','cifar10', 'cifar100', 'cifar-fs', 'cifar-64', 'miniimagenet'])
    parser.add_argument('--ood_methods', type=str, required=True, help='comma-separated list of method names, e.g. `mpp,DM-all`')
    ## Pretrained model paths 
    parser.add_argument('--fsmodel_path', required=True)
    parser.add_argument('--fsmodel_name', required=True, type=str, choices=['protonet', 'maml','baseline','baseline-pn'])
    parser.add_argument('--classifier_path', required=True)
    parser.add_argument('--glow_dir', required=True)
    parser.add_argument('--ooe_only', type=int, default=0)
    
    args = parser.parse_args()
    use_cuda = True

    mkdir(args.output_dir)

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")


    
    if args.dataset == 'mnist':
      test_data = get_dataset('mnist-test', args.dataroot)
      out_list = ['gaussian', 'rademacher', 'texture3', 'svhn', 'notMNIST']
      tr = get_transform('mnist_resize_normalize')
      
    if args.dataset.startswith('cifar'):
      out_list = ['gaussian', 'rademacher', 'texture3', 'svhn','tinyimagenet','lsun']
      # out_list = ['svhn']
      normalize = cifar_normalize
      if args.dataset == 'cifar10':
        train_data = get_dataset('cifar10-train', args.dataroot)
        test_data = get_dataset('cifar10-test', args.dataroot)
        
      if args.dataset == 'cifar100':
        train_data = get_dataset('cifar100-train', args.dataroot)
        test_data = get_dataset('cifar100-test', args.dataroot)
        
      if args.dataset == 'cifar-fs':
        train_data = get_dataset('cifar-fs-train-train', args.dataroot)
        test_data = get_dataset('cifar-fs-test', args.dataroot)
        
      if args.dataset == 'cifar-64':
        assert args.db 
        train_data = get_dataset('cifar-fs-train-train', args.dataroot)
        test_data = get_dataset('cifar-fs-train-test', args.dataroot)

      tr = get_transform('cifar_resize_glow_preproc') if args.ood_methods.split(',')[0].startswith('glow') else get_transform('cifar_resize_normalize') 
        
      
    if args.dataset == 'miniimagenet':
      train_data = get_dataset('miniimagenet-train-train', args.dataroot)
      test_data = get_dataset('miniimagenet-test', args.dataroot)
      out_list = ['gaussian', 'rademacher', 'texture3', 'svhn','tinyimagenet','lsun']
      tr =  get_transform('cifar_resize_glow_preproc') if args.ood_methods.split(',')[0].startswith('glow') else get_transform('cifar_resize_normalize_84') 
      normalize = cifar_normalize

    # Models
    classifier = None
    glow = None
    fs_model = None
    ## FS Model
    if args.fsmodel_name in ['protonet', 'maml']:
      assert args.fsmodel_path != '-'
      fs_model = torch.load(args.fsmodel_path)
      encoder = fs_model.encoder
    ## Classifier
    elif args.fsmodel_name in ['baseline','baseline-pn'] :
      assert args.classifier_path != '-' 
      classifier = ResNetClassifier(train_data['n_classes'], train_data['im_size']).to(device)
      classifier.load_state_dict(torch.load(args.classifier_path))
      encoder = classifier.encoder
      if args.fsmodel_name == 'baseline':
        fs_model = BaselineFinetune(encoder, args.n_ways,args.n_shots,loss_type='dist')
      else:
        fs_model = Protonet(encoder)
    
    fs_model.to(device)
    fs_model.eval()
    args.num_feats = encoder.depth
    encoder.to(device)
    encoder.eval()

    if args.classifier_path != '-' and classifier is None: # for non-FS methods
      classifier = ResNetClassifier(train_data['n_classes'], train_data['im_size']).to(device)
      classifier.load_state_dict(torch.load(args.classifier_path))


    if args.glow_dir != '-':
      # Load Glow
      glow_name = list(filter(lambda s: 'glow_model' in s, os.listdir(args.glow_dir)))[0]
      with open(os.path.join(args.glow_dir ,'hparams.json')) as json_file:  
          hparams = json.load(json_file)
      # Notice Glow is 32,32,3 even for miniImageNet
      glow = Glow((32,32,3), hparams['hidden_channels'], hparams['K'], hparams['L'], hparams['actnorm_scale'],
           hparams['flow_permutation'], hparams['flow_coupling'], hparams['LU_decomposed'], train_data['n_classes'], hparams['learn_top'], hparams['y_condition'])
      glow.load_state_dict(torch.load(os.path.join(args.glow_dir, glow_name)))
      glow.set_actnorm_init()
      glow = glow.to(device)
      glow = glow.eval()

      
    # Verify Acc (just making sure models are loaded properly)
    if classifier is not None and not args.ood_methods.split(',')[0].startswith('glow'):
        preds = classifier(torch.stack([tr(x) for x in train_data['x'][:args.test_batch_size]]).to(device)).max(-1)[1]
        print("Train Acc: ", (preds.detach().cpu().numpy() == np.array(train_data['y'])[:args.test_batch_size]).mean())
        preds = classifier(torch.stack([tr(x) for x in test_data['x'][:args.test_batch_size]]).to(device)).max(-1)[1]
        print("Test Acc: ", (preds.detach().cpu().numpy() == np.array(test_data['y'])[:args.test_batch_size]).mean())


    # Confidence functions for OOD
    confidence_funcs = OrderedDict() # (name, (func, use_support, kwargs))
    for ood_method in args.ood_methods.split(','):
      no_grad = True

      if ood_method.startswith('DM'):
        deep_mahala_obj = DeepMahala(train_data['x'], train_data['y'], tr, encoder, device,num_feats=args.num_feats, num_classes=train_data['n_classes'], pretrained_path=args.dm_path, fit=True, normalize=normalize)

      if ood_method.startswith('deep-ed'): 
        no_grad=False
        deep_mahala_obj = DeepMahala(train_data['x'], train_data['y'], tr, encoder, device,num_feats=args.num_feats, num_classes=train_data['n_classes'], pretrained_path=args.dm_path, fit=False, normalize=normalize)

      if ood_method == 'MPP':
        confidence_funcs['MPP'] = BaseConfidence(lambda x:mpp(classifier, x))
      elif ood_method == 'Ensemble-MPP':
        nets = []
        class PModel(nn.Module):
          def __init__(self, logp_model):
            super(PModel, self).__init__()
            self.logp_model = logp_model
          def forward(self, x):
            return self.logp_model(x).exp()
            
        for i in range(5):
          _dir = os.path.dirname(args.classifier_path)
          _fname = os.path.basename(args.classifier_path)
          path = os.path.join(_dir[:-1]+f"{i}", _fname)
          model = ResNetClassifier(train_data['n_classes'], train_data['im_size'])
          model.load_state_dict(torch.load(path))
          model = PModel(model)
          model.eval()
          nets.append(model.to(device))
        ensemble = Ensemble(nets)
        confidence_funcs['Ensemble-MPP'] = BaseConfidence(lambda x:ensemble(x).max(-1)[0])
      elif ood_method == 'DM-last':
        confidence_funcs['DM-last'] = DMConfidence(deep_mahala_obj, {'ls':[args.num_feats - 1],'reduction':'max'}, False).to(device)
      elif ood_method == 'DM-all':
        confidence_funcs['DM-all'] = DMConfidence(deep_mahala_obj, {'ls':[i for i in range(args.num_feats)],'reduction':'max'}, False).to(device)
      elif ood_method == 'glow-ll':
        confidence_funcs['glow-ll'] = BaseConfidence(lambda x:-glow(x)[1])
      elif ood_method == 'glow-lr':
        from test_glow_ood import ll_to_png_code_ratio
        confidence_funcs['glow-lr'] = BaseConfidence(lambda x:ll_to_png_code_ratio(x, glow))
      elif ood_method == 'native-spp' and args.episodic_ood_eval:
        if args.fsmodel_name in ['maml','baseline']:
          no_grad=False
        confidence_funcs['native-spp'] = FSCConfidence(fs_model, 'spp')
      elif ood_method == 'native-ed' and args.episodic_ood_eval:
        confidence_funcs['native-ed'] = FSCConfidence(fs_model, 'ed')
      elif ood_method.startswith('deep-ed') and args.episodic_ood_eval:
        if args.dm_ls == '-':
          ls = range(args.num_feats)
        else:
          ls = [int(l) for l in args.dm_ls.split(',')]
        kwargs = {
          'ls':ls,
          'reduction':'max',
          'g_magnitude': args.dm_g_magnitude
        }
        dm_conf = DMConfidence(deep_mahala_obj, kwargs, True, ood_method.split('-')[-1])
        dm_conf.to(device)
        confidence_funcs[ood_method] = dm_conf
      elif ood_method == 'dkde' and args.episodic_ood_eval:
        confidence_funcs['dkde'] = DKDEConfidence(encoder)
      elif ood_method == 'oec' and args.episodic_ood_eval:
        oec_opt = json.load(
                open(os.path.join(os.path.dirname(args.oec_path), 'args.json'), 'r')
            )

        init_sample = load_episode(train_data, tr, oec_opt['data.test_way'], oec_opt['data.test_shot'], oec_opt['data.test_query'], device)
        if oec_opt['confidence_method'] == 'oec':
          oec_conf = OECConfidence(None, fs_model, init_sample, oec_opt)
        else:
          oec_conf = DeepOECConfidence(None, fs_model, init_sample, oec_opt)
        oec_conf.load_state_dict(
              torch.load(args.oec_path)
            )
        oec_conf.eval()
        oec_conf.to(device)
        confidence_funcs['oec'] =  oec_conf
      elif ood_method == 'oec-ensemble' and args.episodic_ood_eval: # not much more effective than 'oec'
        oec_opt = json.load(
                open(os.path.join(os.path.dirname(args.oec_path), 'args.json'), 'r')
            )
        oec_confs = []
        for e in range(5):
          init_sample = load_episode(train_data, tr, oec_opt['data.test_way'], oec_opt['data.test_shot'], oec_opt['data.test_query'], device)
          if oec_opt['confidence_method'] == 'oec':
            oec_conf = OECConfidence(None, fs_model, init_sample, oec_opt)
          else:
            oec_conf = DeepOECConfidence(None, fs_model, init_sample, oec_opt)
          # Find ckpt 
          cdir = os.path.dirname(args.oec_path)[:-1]+f"{e}"
          fname = list(filter(lambda s:s.endswith('conf_best.pt'), os.listdir(cdir)))[0]
          oec_conf.load_state_dict(
                torch.load(os.path.join(
                  cdir, fname))
              )
          oec_conf.eval()
          oec_conf.to(device)    
          oec_confs.append(oec_conf)
        confidence_funcs['oec'] =  Ensemble(oec_confs)
      else:
        raise ValueError(f'ood_method not implemented, or typo in name: {ood_method}')

 
    
    auroc_data = defaultdict(list)
    auroc_95ci_data = defaultdict(list)
    fpr_data = defaultdict(list)
    fpr_95ci_data = defaultdict(list)

    # Classic OOD evaluation
    if not args.episodic_ood_eval:
      for out_name in out_list:
        ooc_config = {
            'name': out_name,
            'ood_scale': 1,
            'n_anom': 5000,
            'cuda': False
        }
        ood_tensor = load_ood_data(ooc_config)
        assert len(ood_tensor) <= len(test_data['x'])
        in_scores = defaultdict(list)
        out_scores = defaultdict(list)

        with torch.no_grad():
          for i in tqdm(range(0, len(ood_tensor), args.test_batch_size)):
            stop = min(args.test_batch_size, len(ood_tensor[i:]))
            in_x = torch.stack([tr(x) for x in test_data['x'][i:i+stop]]).to(device)
            out_x = torch.stack([tr(x) for x in ood_tensor[i:i+stop]]).to(device)
            for c, f in confidence_funcs.items():
              in_scores[c].append(f.score(in_x))
              out_scores[c].append(f.score(out_x))
        # save ood images for debugging
        vutils.save_image(out_x[:100], f'non-episodic-{out_name}.png' , normalize=True, nrow=10) 
                
        for c in confidence_funcs:
          auroc = show_ood_detection_results_softmax(torch.cat(in_scores[c]).cpu().numpy(), torch.cat(out_scores[c]).cpu().numpy())[1]
          print(out_name, c, ': ', auroc)
          auroc_data[c].append(auroc)
        auroc_data['dset'].append(out_name)
      pandas.DataFrame(auroc_data).to_csv(os.path.join(args.output_dir,f'md_auroc_{args.ood_methods}.csv'))
    else:
      cifar_meta_train_data = get_dataset('cifar-fs-train-test', args.dataroot)
      cifar_meta_test_data = get_dataset('cifar-fs-test', args.dataroot)
      
      # OOD Eval
      if args.episodic_in_distr == 'meta-train':
        episodic_in_data = train_data
      else:
        episodic_in_data = test_data

      episodic_ood = ['ooe','cifar-fs-test', 'cifar-fs-train-test']  

      ood_tensors = [None] + [load_ood_data({
                      'name': out_name,
                      'ood_scale': 1,
                      'n_anom': 10000,
                    }) for out_name in episodic_ood[1:] + out_list]
      if args.ooe_only:
        all_oods = [('ooe', None)]
      else:
        all_oods = zip(episodic_ood + out_list, ood_tensors)
      for out_name, ood_tensor in all_oods:
        n_query = 15
        metrics_dic = defaultdict(list)
        for c, f in confidence_funcs.items():
          metrics_dic[c] = eval_ood_aurocs(
                      ood_tensor,
                      episodic_in_data,
                      tr, 
                      args.n_ways,
                      args.n_shots,
                      n_query,
                      args.n_episodes,
                      device,
                      f,
                      db=args.db,
                      out_name=out_name,
                      no_grad=no_grad
                      )
        
        for c in confidence_funcs:
          auroc = np.mean(metrics_dic[c]['aurocs'])
          auroc_95ci = np.std(metrics_dic[c]['aurocs']) * 1.96 / np.sqrt(args.n_episodes)
          auroc_data[c].append(auroc)
          auroc_95ci_data[c].append(auroc_95ci)
          print(out_name, c, 'auroc: ', auroc, ',', auroc_95ci)
          fpr = np.mean(metrics_dic[c]['fprs'])
          fpr_95ci = np.std(metrics_dic[c]['fprs']) * 1.96 / np.sqrt(args.n_episodes)
          fpr_data[c].append(fpr)
          fpr_95ci_data[c].append(fpr_95ci)
          print(out_name, c, 'fpr: ', fpr, ',', fpr_95ci)
          
        auroc_data['dset'].append(out_name)
        fpr_data['dset'].append(out_name)
        auroc_95ci_data['dset'].append(out_name)
        fpr_95ci_data['dset'].append(out_name)
      pandas.DataFrame(auroc_data).to_csv(os.path.join(args.output_dir,f'{args.tag}_episodic_{args.episodic_in_distr}_{args.dm_path.split(".")[0]}_{args.ood_methods}_auroc.csv'))
      pandas.DataFrame(fpr_data).to_csv(os.path.join(args.output_dir,f'{args.tag}_episodic_{args.episodic_in_distr}_{args.dm_path.split(".")[0]}_{args.ood_methods}_fpr.csv'))
      pandas.DataFrame(auroc_95ci_data).to_csv(os.path.join(args.output_dir,f'{args.tag}_episodic_{args.episodic_in_distr}_{args.dm_path.split(".")[0]}_{args.ood_methods}_auroc_95ci.csv'))
      pandas.DataFrame(fpr_95ci_data).to_csv(os.path.join(args.output_dir,f'{args.tag}_episodic_{args.episodic_in_distr}_{args.dm_path.split(".")[0]}_{args.ood_methods}_fpr_95ci.csv'))
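For reference, the 95% confidence intervals above follow the standard formula std * 1.96 / sqrt(n_episodes), consistent with the `get_95_percent_ci` helper in Example #5. A small sketch with made-up per-episode AUROCs:

import numpy as np

aurocs = np.array([0.91, 0.88, 0.93, 0.90, 0.89])  # hypothetical per-episode AUROCs
ci95 = np.std(aurocs) * 1.96 / np.sqrt(len(aurocs))
print(f'{np.mean(aurocs):.3f} +/- {ci95:.4f}')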