def evaluate(data):
    """Return the mean few-shot classification accuracy of ``model`` on ``data``.

    Draws ``args['data.test_episodes']`` episodes from ``data`` using the
    test-time way/shot/query settings and the ``test_tr`` transform. For each
    episode it takes the first output of ``classification_accuracy``,
    concatenates these across episodes and returns their mean as a Python
    number.

    Relies on the free names ``args``, ``test_tr``, ``device`` and ``model``
    from the surrounding scope.
    """
    per_episode = [
        classification_accuracy(
            load_episode(data, test_tr, args['data.test_way'],
                         args['data.test_shot'], args['data.test_query'],
                         device),
            model)[0]
        for _ in tqdm(range(args['data.test_episodes']))
    ]
    return torch.mean(torch.cat(per_episode)).item()
def eval_ood_aurocs(ood_tensor,
                    episodic_in_data,
                    tr,
                    n_way,
                    n_shot,
                    n_query,
                    n_episodes,
                    device,
                    conf,
                    db=False,
                    out_name='',
                    no_grad=True):
    """Compute per-episode OOD-detection AUROC / FPR for a confidence scorer.

    Args:
        ood_tensor: external OOD pool, or ``None`` to keep the distractors
            (``sample['ooc_xq']``) that ``load_episode`` already provides.
            When given and shorter than ``n_way * n_query * n_episodes`` rows,
            it is tiled with ``np.vstack`` and truncated to that length. Each
            episode then consumes the next ``n_way * n_query`` rows, each
            passed through ``tr``.
        episodic_in_data, tr, n_way, n_shot, n_query: forwarded to
            ``load_episode`` to build each in-distribution episode.
        n_episodes: number of episodes to evaluate.
        device: device the OOD queries are moved to.
        conf: confidence model passed to ``score_batch``.
        db: if true, print the mean in/out scores of each episode.
        out_name: unused.
        no_grad: when true, scoring runs with autograd disabled.

    Returns:
        ``defaultdict(list)`` with keys ``'aurocs'`` and ``'fprs'``, one entry
        per episode, taken from ``show_ood_detection_results_softmax``.
    """
    queries_per_episode = n_way * n_query
    if ood_tensor is not None:
        needed = queries_per_episode * n_episodes
        # Tile the pool so that every episode gets a distinct slice.
        if len(ood_tensor) < needed:
            n_copies = needed // len(ood_tensor) + 1
            ood_tensor = np.vstack([ood_tensor] * n_copies)[:needed]

    metrics = defaultdict(list)
    episodes = tqdm(range(n_episodes), desc='eval_ood_aurocs', dynamic_ncols=True)
    for episode_idx in episodes:
        sample = load_episode(episodic_in_data, tr, n_way, n_shot, n_query, device)
        if ood_tensor is not None:
            lo = episode_idx * queries_per_episode
            ood_batch = ood_tensor[lo:lo + queries_per_episode]
            sample['ooc_xq'] = torch.stack([tr(x) for x in ood_batch]).to(device)

        with torch.set_grad_enabled(not no_grad):
            in_score, out_score = score_batch(conf, sample)
        in_np, out_np = in_score.numpy(), out_score.numpy()

        _, auroc, _, _, fpr = show_ood_detection_results_softmax(in_np, out_np)
        metrics['aurocs'].append(auroc)
        metrics['fprs'].append(fpr)

        if db:
            print("Avg `in_score`: ", np.mean(in_np))
            print("Avg `out_score`: ", np.mean(out_np))
    return metrics
def evaluate(): nonlocal best_metric_value nonlocal patience_elapsed nonlocal stop nonlocal epoch corrects = [] for _ in tqdm(range(args['data.test_episodes']), desc="Epoch {:d} Val".format(epoch + 1)): sample = load_episode(val_data, test_tr, args['data.test_way'], args['data.test_shot'], args['data.test_query'], device) corrects.append(classification_accuracy(sample, model)[0]) val_acc = torch.mean(torch.cat(corrects)) iteration_logger.writerow({ 'global_iteration': epoch, 'val_acc': val_acc.item() }) plot_csv(iteration_logger.filename, iteration_logger.filename) print(f"Epoch {epoch}: Val Acc: {val_acc}") if val_acc > best_metric_value: best_metric_value = val_acc print("==> best model (metric = {:0.6f}), saving model...".format( best_metric_value)) model.cpu() torch.save(model, os.path.join(args['log.exp_dir'], 'best_model.pt')) model.to(device) patience_elapsed = 0 else: patience_elapsed += 1 if patience_elapsed > args['train.patience']: print("==> patience {:d} exceeded".format( args['train.patience'])) stop = True
def main(args):
    """Train a prototypical network episodically with early stopping.

    Outputs written into ``args['log.exp_dir']`` (created if missing):
    ``args.json`` (the options), ``iteration_log.csv`` plus its plot,
    ``model_{epoch}.pt`` every ``args['ckpt_every']`` epochs and
    ``best_model.pt`` whenever validation accuracy improves.

    Args:
        args: flat dict of dotted option names (e.g. ``'data.dataset'``,
            ``'train.epochs'``) to values. It is dumped verbatim as JSON.

    Raises:
        NotImplementedError: if ``args['data.dataset']`` is neither
            ``'miniimagenet'`` nor ``'cifar100'``. This is a ``RuntimeError``
            subclass, matching what the previous bare ``raise`` produced.
    """
    device = 'cuda:0' if args['data.cuda'] else 'cpu'
    os.makedirs(args['log.exp_dir'], exist_ok=True)

    # Save the options next to the run outputs.
    with open(os.path.join(args['log.exp_dir'], 'args.json'), 'w') as f:
        json.dump(args, f)
        f.write('\n')

    # Logging
    iteration_fieldnames = ['global_iteration', 'val_acc']
    iteration_logger = CSVLogger(every=0,
                                 fieldnames=iteration_fieldnames,
                                 filename=os.path.join(args['log.exp_dir'],
                                                       'iteration_log.csv'))

    # Set the random seed manually for reproducibility.
    np.random.seed(args['seed'])
    torch.manual_seed(args['seed'])
    if args['data.cuda']:
        torch.cuda.manual_seed(args['seed'])

    dataset = args['data.dataset']
    if dataset == 'miniimagenet':
        train_data = get_dataset('miniimagenet-train-train', args['dataroot'])
        val_data = get_dataset('miniimagenet-val', args['dataroot'])
        train_tr = get_transform('cifar_augment_normalize_84'
                                 if args['data_augmentation'] else 'cifar_normalize')
        test_tr = get_transform('cifar_normalize')
    elif dataset == 'cifar100':
        # Pass dataroot like every other get_dataset call does.
        train_data = get_dataset('cifar-fs-train-train', args['dataroot'])
        val_data = get_dataset('cifar-fs-val', args['dataroot'])
        train_tr = get_transform('cifar_augment_normalize'
                                 if args['data_augmentation'] else 'cifar_normalize')
        test_tr = get_transform('cifar_normalize')
    else:
        # Covers 'omniglot' too. A bare `raise` here failed with
        # "No active exception to reraise" instead of naming the problem.
        raise NotImplementedError(f"Unsupported dataset: {dataset!r}")

    model = protonet.create_model(**args)

    if args['model.model_path'] != '':
        loaded = torch.load(args['model.model_path'])
        if 'Protonet' not in str(loaded.__class__):
            # Not a pickled Protonet: treat it as a ResNetClassifier state dict
            # and reuse only its encoder.
            pretrained = ResNetClassifier(64, train_data['im_size']).to(device)
            pretrained.load_state_dict(loaded)
            model.encoder = pretrained.encoder
        else:
            model = loaded

    model = model.to(device)

    max_epoch = args['train.epochs']
    epoch = 0
    stop = False
    patience_elapsed = 0
    best_metric_value = 0.0

    def evaluate():
        """Validate, log, save the best model and update early-stopping state."""
        nonlocal best_metric_value
        nonlocal patience_elapsed
        nonlocal stop
        nonlocal epoch
        corrects = []
        for _ in tqdm(range(args['data.test_episodes']),
                      desc="Epoch {:d} Val".format(epoch + 1)):
            sample = load_episode(val_data, test_tr, args['data.test_way'],
                                  args['data.test_shot'], args['data.test_query'],
                                  device)
            corrects.append(classification_accuracy(sample, model)[0])
        val_acc = torch.mean(torch.cat(corrects))
        iteration_logger.writerow({
            'global_iteration': epoch,
            'val_acc': val_acc.item()
        })
        plot_csv(iteration_logger.filename, iteration_logger.filename)

        print(f"Epoch {epoch}: Val Acc: {val_acc}")

        if val_acc > best_metric_value:
            best_metric_value = val_acc
            print("==> best model (metric = {:0.6f}), saving model...".format(
                best_metric_value))
            model.cpu()
            torch.save(model, os.path.join(args['log.exp_dir'], 'best_model.pt'))
            model.to(device)
            patience_elapsed = 0
        else:
            patience_elapsed += 1
            if patience_elapsed > args['train.patience']:
                print("==> patience {:d} exceeded".format(
                    args['train.patience']))
                stop = True

    optim_method = getattr(optim, args['train.optim_method'])
    optimizer = optim_method(model.parameters(),
                             lr=args['train.learning_rate'],
                             weight_decay=args['train.weight_decay'])
    scheduler = lr_scheduler.StepLR(optimizer,
                                    args['train.decay_every'],
                                    gamma=0.5)

    while epoch < max_epoch:
        evaluate()
        if stop:
            # Early stopping fired: another epoch would never be validated.
            break
        model.train()
        if epoch % args['ckpt_every'] == 0:
            model.cpu()
            torch.save(model, os.path.join(args['log.exp_dir'], f'model_{epoch}.pt'))
            model.to(device)
        for _ in tqdm(range(args['data.train_episodes']),
                      desc="Epoch {:d} train".format(epoch + 1)):
            sample = load_episode(train_data, train_tr, args['data.way'],
                                  args['data.shot'], args['data.query'], device)
            optimizer.zero_grad()
            loss, _ = model.loss(sample)
            loss.backward()
            optimizer.step()
        # Step the LR schedule after this epoch's optimizer steps. Stepping
        # before them shifts the StepLR decay one epoch early.
        scheduler.step()
        epoch += 1

    if not stop:
        # The weights produced by the final epoch have not been validated yet;
        # this may still save them as best_model.pt.
        evaluate()
def main(opt):
    """Evaluate OOD-filtered semi-supervised and active few-shot learning.

    Steps:
      1. Build an OOD confidence scorer ``conf`` chosen by ``opt['ood_method']``.
      2. On 100 meta-train episodes, turn its scores into a binary in/out
         classifier: pick the threshold that maximizes accuracy, then the
         sigmoid temperature that minimizes ECE.
      3. For every OOD source and ``opt['data.test_episodes']`` test episodes,
         split the queries into an unlabelled pool (``xq``), held-out test
         queries (``xq2``) and distractors (``ooc_xq``). Score pool and
         distractors, then record ``compute_acc`` on ``xq2`` for four support
         sets:
           * soft-weighted SSL;
           * SSL with the soft weights gated by the thresholded keep/discard
             decision;
           * the labelled support alone;
           * support plus a score-ranked ("active") subset of the pool.
      4. Pickle the four result dicts to
         ``opt['output_dir']/eval_active_{exp_name}.pkl``, run
         ``aggregate_eval_active.main`` on that directory and exit the process.

    Raises:
        ValueError: for an unsupported ``opt['dataset']`` or
            ``opt['ood_method']``. Previously these surfaced later as NameError.
    """
    eval_exp_name = opt['exp_name']
    device = 'cuda:0'

    # Load data (the validation split is not used by this evaluation).
    if opt['dataset'] == 'cifar-fs':
        train_data = get_dataset('cifar-fs-train-train', opt['dataroot'])
        test_data = get_dataset('cifar-fs-test', opt['dataroot'])
        tr = get_transform('cifar_resize_normalize')
    elif opt['dataset'] == 'miniimagenet':
        train_data = get_dataset('miniimagenet-train-train', opt['dataroot'])
        test_data = get_dataset('miniimagenet-test', opt['dataroot'])
        tr = get_transform('cifar_resize_normalize_84')
    else:
        raise ValueError(f"Unsupported dataset: {opt['dataset']!r}")

    np.random.seed(1234)
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)

    if opt['db']:
        ood_distributions = ['ooe', 'gaussian']
    else:
        ood_distributions = ['ooe', 'gaussian', 'svhn']
    # 'ooe' distractors are the ones load_episode puts in sample['ooc_xq'];
    # the other sources are loaded once up front.
    ood_tensors = [('ooe', None)] + [(out_name, load_ood_data({
        'name': out_name,
        'ood_scale': 1,
        'n_anom': 10000,
    })) for out_name in ood_distributions[1:]]

    # Load the trained model: anything that is not an OrderedDict is used as
    # is (presumably a pickled Protonet). Otherwise it is treated as a
    # ResNetClassifier state dict whose encoder gets wrapped in a Protonet.
    loaded = torch.load(opt['model.model_path'])
    if not isinstance(loaded, OrderedDict):
        protonet = loaded
    else:
        classifier = ResNetClassifier(64, train_data['im_size']).to(device)
        classifier.load_state_dict(loaded)
        protonet = Protonet(classifier.encoder)
    encoder = protonet.encoder
    encoder.eval()
    encoder.to(device)
    protonet.eval()
    protonet.to(device)

    # Init confidence model
    if opt['ood_method'] == 'deep-ed-iso':
        deep_mahala_obj = DeepMahala(None,
                                     None,
                                     None,
                                     encoder,
                                     device,
                                     num_feats=encoder.depth,
                                     num_classes=train_data['n_classes'],
                                     pretrained_path="",
                                     fit=False,
                                     normalize=None)
        conf = DMConfidence(deep_mahala_obj, {
            'ls': range(encoder.depth),
            'reduction': 'max',
            'g_magnitude': 0
        }, True, 'iso').to(device)
    elif opt['ood_method'] == 'native-spp':
        conf = FSCConfidence(protonet, 'spp')
    elif opt['ood_method'] == 'oec':
        oec_args_path = os.path.join(os.path.dirname(opt['oec_path']), 'args.json')
        with open(oec_args_path, 'r') as f:
            oec_opt = json.load(f)
        init_sample = load_episode(train_data, tr, oec_opt['data.test_way'],
                                   oec_opt['data.test_shot'],
                                   oec_opt['data.test_query'], device)
        if oec_opt['confidence_method'] == 'oec':
            oec_conf = OECConfidence(None, protonet, init_sample, oec_opt)
        else:
            oec_conf = DeepOECConfidence(None, protonet, init_sample, oec_opt)
        oec_conf.load_state_dict(torch.load(opt['oec_path']))
        oec_conf.eval()
        oec_conf.to(device)
        conf = oec_conf
    else:
        raise ValueError(f"Unsupported ood_method: {opt['ood_method']!r}")

    # Turn the confidence score into a threshold-based in/out classifier:
    # threshold by max accuracy, temperature by best calibration (min ECE),
    # both selected on the meta-train set.
    in_scores = []
    out_scores = []
    for _ in tqdm(range(100)):
        sample = load_episode(train_data, tr, opt['data.test_way'],
                              opt['data.test_shot'], opt['data.test_query'],
                              device)
        in_score, out_score = score_batch(conf, sample)
        in_scores.append(in_score)
        out_scores.append(out_score)
    in_scores = torch.cat(in_scores)
    out_scores = torch.cat(out_scores)

    def _compute_acc(in_scores, out_scores, t):
        """Accuracy of 'in iff score >= t' over the in and out scores."""
        N = len(in_scores) + len(out_scores)
        return (torch.sum(in_scores >= t) +
                torch.sum(out_scores < t)).item() / float(N)

    best_threshold = torch.min(in_scores)
    best_acc = _compute_acc(in_scores, out_scores, best_threshold)
    for t in in_scores:
        acc = _compute_acc(in_scores, out_scores, t)
        if acc > best_acc:
            best_acc = acc
            best_threshold = t

    def _compute_confs(in_scores, out_scores, t, temp):
        """Return (confidence, correct) for p_in = sigmoid((score - t) / temp).

        In-scores come first, then out-scores. An in example is correct when
        p_in >= .5, an out example when p_in < .5. The confidence is
        max(p_in, 1 - p_in).
        """
        in_p = torch.sigmoid((in_scores - t) / temp)
        corrects = in_p >= .5
        confs = torch.max(torch.stack([in_p, 1 - in_p]), 0)[0]
        out_p = torch.sigmoid((out_scores - t) / temp)
        corrects = torch.cat([corrects, out_p < .5])
        confs = torch.cat(
            [confs, torch.max(torch.stack([out_p, 1 - out_p]), 0)[0]])
        return confs, corrects

    def compute_eces(candidate_temps, in_scores, out_scores, best_threshold):
        """ECE (20 bins) of the binary classifier for each candidate temperature."""
        eces = []
        for temp in candidate_temps:
            confs, corrects = _compute_confs(in_scores, out_scores,
                                             best_threshold, temp)
            eces.append(
                compute_ece(*prep_accs(confs.numpy(), corrects.numpy(), bins=20)))
        return eces

    # Log-spaced grid search that slides the window while the minimum sits on
    # its edge. It runs at least once so that best_temp is always defined.
    min_log_temp = -1
    log_interval = 2
    npts = 10
    for _ in range(max(1, opt['max_temp_select_iter'])):
        print("..selecting temperature")
        candidate_temps = np.logspace(min_log_temp, min_log_temp + log_interval,
                                      npts)
        eces = compute_eces(candidate_temps, in_scores, out_scores,
                            best_threshold)
        min_idx = np.argmin(eces)
        if min_idx == 0:
            min_log_temp -= log_interval // 2
        elif min_idx == npts - 1:
            min_log_temp += log_interval // 2
        else:
            break
    best_ece = eces[min_idx]
    best_temp = candidate_temps[min_idx]
    print(
        f"Best ACC:{best_acc}, thresh:{best_threshold}, Best ECE:{best_ece}, temp:{best_temp}"
    )

    def reorder(unlabelled, preds, py, make_soft=True):
        """Scatter pool items into rows given by their predicted class.

        ``unlabelled`` is unpacked as (n_way, n_per_way, C, H, W) and reshaped
        with H used twice, so H == W is required. Returns a zero-padded tensor
        of shape (n_way, n_per_way * n_way, C, H, W) in which flattened item
        ``idx`` sits at ``[preds[idx], idx]``. Also returns an
        (n_way, n_per_way * n_way) weight mask holding ``py[idx]``
        (``make_soft``) or 1 at those slots.
        """
        if py is not None:
            reshaped_py = py.reshape(-1)
        n_way, n_aug_shot, n_ch, n_dim, _ = unlabelled.shape
        reshaped_unlabelled = unlabelled.reshape(n_aug_shot * n_way, n_ch,
                                                 n_dim, n_dim)
        reshaped_predicted_labels = preds.reshape(-1)
        unlabelled = torch.zeros(
            (n_way, n_aug_shot * n_way, n_ch, n_dim, n_dim))
        mask = torch.zeros((n_way, n_aug_shot * n_way))
        for idx, label in enumerate(reshaped_predicted_labels):
            unlabelled[label, idx] = reshaped_unlabelled[idx]
            if make_soft:
                mask[label, idx] = reshaped_py[idx]
            else:
                mask[label, idx] = 1
        return unlabelled.to(device), mask.to(device)

    active_supervised = defaultdict(list)
    active_augmented = defaultdict(list)
    ssl_soft = defaultdict(list)
    ssl_hard = defaultdict(list)

    n_unlabeled = opt['n_unlabeled_per_class']
    n_test_query = opt['data.test_query']
    n_distractor = opt['n_distractor_per_class']
    # Enough queries per class for the unlabelled pool plus the held-out test
    # queries, and (for 'ooe') for the distractor slice.
    n_total_query = max(n_test_query + n_unlabeled, n_distractor)

    for curr_ood, ood_tensor in ood_tensors:
        for _ in tqdm(range(opt['data.test_episodes'])):
            sample = load_episode(test_data, tr, opt['data.test_way'],
                                  opt['data.test_shot'], n_total_query, device)
            if curr_ood != 'ooe':
                # Draw a fresh random subset of the whole external pool each
                # episode. The old `permutation(way * query)` always reused the
                # same leading images. Draw at least n_distractor per way so
                # this branch matches the 'ooe' branch after slicing below.
                n_ood = opt['data.test_way'] * max(n_test_query, n_distractor)
                ridx = np.random.permutation(len(ood_tensor))[:n_ood]
                sample['ooc_xq'] = torch.stack(
                    [tr(x) for x in ood_tensor[ridx]]).to(device)
                way, _, c, h, w = sample['xq'].shape
                sample['ooc_xq'] = sample['ooc_xq'].reshape(way, -1, c, h, w)

            all_xq = sample['xq'].clone()
            sample['xq'] = all_xq[:, :n_unlabeled]  # unlabelled pool
            sample['xq2'] = all_xq[:, n_unlabeled:n_unlabeled + n_test_query]  # final test queries
            sample['ooc_xq'] = sample['ooc_xq'][:, :n_distractor]

            # 1. In/out decision on the unlabelled pool and the distractors.
            in_score, out_score = score_batch(conf, sample)
            num_in = in_score.shape[0]
            _, corrects = _compute_confs(in_score, out_score, best_threshold,
                                         best_temp)
            pool_shape = (sample['xq'].size(0), sample['xq'].size(1))
            ood_shape = (sample['ooc_xq'].size(0), sample['ooc_xq'].size(1))
            # 1 where a pool example is (correctly) kept as in-distribution.
            in_mask = corrects[:num_in].reshape(*pool_shape).float().to(device)
            # 1 where a distractor is (wrongly) kept as in-distribution.
            out_mask = 1 - corrects[num_in:].reshape(*ood_shape).float().to(device)

            # 2. Active baseline: take the num_in highest scores across pool and
            # distractors, and mark which pool items were selected.
            budget_active = num_in
            scores = torch.cat([in_score, out_score], -1)
            selected_inds = torch.sort(scores)[1][scores.size(0) - budget_active:]
            selected_inds_in = selected_inds[selected_inds < num_in]
            budget_mask = torch.zeros(num_in).to(device)
            budget_mask.scatter_(0, selected_inds_in.to(device).long(), 1)
            budget_mask = budget_mask.reshape(*pool_shape).float().to(device)

            # 3. SSL: pseudo-label the pool and the distractors with the
            # few-shot classifier and weight them by the predicted probability.
            lpy_dic = protonet.log_p_y(sample['xs'], sample['xq'], mask=None)
            log_p_y = lpy_dic['log_p_y']
            preds = log_p_y.max(-1)[1]
            gt_in_unlabelled, gt_in_weights = reorder(
                sample['xq'], preds, log_p_y.max(-1)[0].exp(), True)
            _, in_mask_reordered = reorder(sample['xq'], preds, in_mask, True)

            lpy_dic = protonet.log_p_y(sample['xs'], sample['ooc_xq'], mask=None)
            log_p_y = lpy_dic['log_p_y']
            preds = log_p_y.max(-1)[1]
            gt_ood_unlabelled, gt_ood_weights = reorder(
                sample['ooc_xq'], preds, log_p_y.max(-1)[0].exp(), True)
            _, out_mask_reordered = reorder(sample['ooc_xq'], preds, out_mask,
                                            True)

            support_ones = torch.ones(sample['xs'].shape[:2]).to(device)
            # Support + ALL pseudo-labelled pool and distractor examples.
            ssl_support = torch.cat(
                [sample['xs'], gt_in_unlabelled, gt_ood_unlabelled], 1)
            _ssl_soft = compute_acc(
                protonet, ssl_support, sample['xq2'],
                torch.cat([support_ones, gt_in_weights, gt_ood_weights], 1))
            _acc_hard = compute_acc(
                protonet, ssl_support, sample['xq2'],
                torch.cat([
                    support_ones, in_mask_reordered * gt_in_weights,
                    out_mask_reordered * gt_ood_weights
                ], 1))

            # 4. k-way accuracy for the support alone and the active selection.
            _active_supervised = compute_acc(protonet, sample['xs'],
                                             sample['xq2'], None)
            _active_augmented = compute_acc(
                protonet, torch.cat([sample['xs'], sample['xq']], 1),
                sample['xq2'], torch.cat([support_ones, budget_mask], 1))

            ssl_soft[curr_ood].append(_ssl_soft)
            ssl_hard[curr_ood].append(_acc_hard)
            active_supervised[curr_ood].append(_active_supervised)
            active_augmented[curr_ood].append(_active_augmented)

    os.makedirs(opt['output_dir'], exist_ok=True)
    out_path = os.path.join(opt['output_dir'], f'eval_active_{eval_exp_name}.pkl')
    with open(out_path, 'wb') as f:
        pickle.dump((ssl_soft, ssl_hard, active_supervised, active_augmented), f)

    print("===> Aggregating results")
    aggr_args = namedtuple('Arg', ('exp_dir', 'f_acq'))(exp_dir=opt['output_dir'],
                                                        f_acq='conv4')
    aggregate_eval_active.main(aggr_args)
    print('===> Done')
    sys.exit()
def _load_state_into(target, saved):
    """Load checkpointed state ``saved`` into an optimizer or LR scheduler.

    ``saved`` is normally a state dict. Checkpoints written by older versions
    of ``main`` pickled the optimizer/scheduler objects themselves, so anything
    exposing ``state_dict()`` is converted first.
    """
    state = saved.state_dict() if hasattr(saved, 'state_dict') else saved
    target.load_state_dict(state)


def main(opt):
    """Train an episodic OOD confidence model on top of a few-shot classifier.

    Training:
      * Loads the few-shot model from ``opt['model.model_path']`` and builds
        the confidence model chosen by ``opt['confidence_method']`` ('oec',
        'deep-oec' or 'dm-iso').
      * Optimizes its ``confidence_parameters()`` with Adam. If
        ``pretrain_parameters()`` is not None, those are optimized instead for
        the first 100 iterations.
      * The loss is ``compute_loss_bce`` between query scores (``xq``) and
        distractor scores (``ooc_xq``).

    Evaluation, every ``opt['eval_every_outer']`` iterations:
      * Logs AUROCs for every tracked scorer to ``iteration_log.csv``, the
        trace file and wandb.
      * Saves ``{exp_name}_conf_best.pt`` (and the ensemble) when the meta-val
        OOE AUROC improves.
      * Writes a resumable ``last_ckpt.pt``.

    Ends the process via ``sys.exit()`` when done or when patience runs out.

    NOTE(review): ``device`` is not bound in this function; it is assumed to
    be a module-level global -- confirm.

    Raises:
        ValueError: for an unsupported ``opt['dataset']`` or
            ``opt['confidence_method']``.
    """
    # Logging
    trace_file = os.path.join(opt['output_dir'],
                              '{}_trace.txt'.format(opt['exp_name']))

    # Load data
    if opt['dataset'] == 'cifar-fs':
        train_data = get_dataset('cifar-fs-train-train', opt['dataroot'])
        val_data = get_dataset('cifar-fs-val', opt['dataroot'])
        test_data = get_dataset('cifar-fs-test', opt['dataroot'])
        tr = get_transform('cifar_resize_normalize')
    elif opt['dataset'] == 'miniimagenet':
        train_data = get_dataset('miniimagenet-train-train', opt['dataroot'])
        val_data = get_dataset('miniimagenet-val', opt['dataroot'])
        test_data = get_dataset('miniimagenet-test', opt['dataroot'])
        tr = get_transform('cifar_resize_normalize_84')
    else:
        raise ValueError(f"Unsupported dataset: {opt['dataset']!r}")

    if opt['input_regularization'] == 'oe':
        # Outlier-exposure pool mixed into the distractor queries during training.
        reg_data = load_ood_data({
            'name': 'tinyimages',
            'ood_scale': 1,
            'n_anom': 50000,
        })

    if not opt['ooe_only']:
        if opt['db']:
            ood_distributions = ['ooe', 'gaussian']
        else:
            ood_distributions = [
                'ooe', 'gaussian', 'rademacher', 'texture3', 'svhn',
                'tinyimagenet', 'lsun'
            ]
        if opt['input_regularization'] == 'oe':
            ood_distributions.append('tinyimages')

        ood_tensors = [('ooe', None)] + [(out_name, load_ood_data({
            'name': out_name,
            'ood_scale': 1,
            'n_anom': 10000,
        })) for out_name in ood_distributions[1:]]

    # Load the trained few-shot model: anything that is not an OrderedDict is
    # used as is. Otherwise it is a ResNetClassifier state dict whose encoder
    # is wrapped in a Protonet.
    loaded = torch.load(opt['model.model_path'])
    if not isinstance(loaded, OrderedDict):
        fs_model = loaded
    else:
        classifier = ResNetClassifier(64, train_data['im_size']).to(device)
        classifier.load_state_dict(loaded)
        fs_model = Protonet(classifier.encoder)
    fs_model.eval()
    fs_model = fs_model.to(device)

    # Init confidence method
    if opt['confidence_method'] in ('oec', 'deep-oec'):
        init_sample = load_episode(train_data, tr, opt['data.test_way'],
                                   opt['data.test_shot'], opt['data.test_query'],
                                   device)
        conf_cls = (OECConfidence if opt['confidence_method'] == 'oec'
                    else DeepOECConfidence)
        conf_model = conf_cls(None, fs_model, init_sample, opt)
    elif opt['confidence_method'] == 'dm-iso':
        encoder = fs_model.encoder
        deep_mahala_obj = DeepMahala(None,
                                     None,
                                     None,
                                     encoder,
                                     device,
                                     num_feats=encoder.depth,
                                     num_classes=train_data['n_classes'],
                                     pretrained_path="",
                                     fit=False,
                                     normalize=None)
        conf_model = DMConfidence(deep_mahala_obj, {
            'ls': range(encoder.depth),
            'reduction': 'max',
            'g_magnitude': .1
        }, True, 'iso')
    else:
        raise ValueError(
            f"Unsupported confidence_method: {opt['confidence_method']!r}")

    if opt['pretrained_oec_path']:
        conf_model.load_state_dict(torch.load(opt['pretrained_oec_path']))

    conf_model.to(device)
    print(conf_model)

    optimizer = optim.Adam(conf_model.confidence_parameters(),
                           lr=opt['lr'],
                           weight_decay=opt['wd'])
    scheduler = StepLR(optimizer,
                       step_size=opt['lrsche_step_size'],
                       gamma=opt['lrsche_gamma'])

    num_param = sum(p.numel() for p in conf_model.confidence_parameters())
    print(f"Learning Confidence, Number of Parameters -- {num_param}")

    pretrain_optimizer = None
    pretrain_iter = 100
    if conf_model.pretrain_parameters() is not None:
        pretrain_optimizer = optim.Adam(conf_model.pretrain_parameters(), lr=10)

    # Early-stopping state (also restored on resume).
    best_val_ooe = 0
    PATIENCE = 5  # Number of evaluations to wait
    waited = 0

    start_idx = 0
    if opt['resume']:
        last_ckpt_path = os.path.join(opt['output_dir'], 'last_ckpt.pt')
        if os.path.exists(last_ckpt_path):
            try:
                last_ckpt = torch.load(last_ckpt_path)
            except EOFError:
                print(
                    "\n\nResuming but got EOF error, starting from init..\n\n")
            else:
                # Load weights INTO the existing conf_model, and optimizer/
                # scheduler state INTO the existing objects. Swapping in the
                # unpickled optimizer would leave it updating tensors that
                # conf_model does not own, so a resumed run would not train.
                if 'conf_model' in last_ckpt:
                    sd = last_ckpt['conf_model'].state_dict()
                else:
                    sd = last_ckpt['conf_model_sd']
                conf_model.load_state_dict(sd)
                _load_state_into(optimizer, last_ckpt['optimizer'])
                if (pretrain_optimizer is not None
                        and last_ckpt.get('pretrain_optimizer') is not None):
                    _load_state_into(pretrain_optimizer,
                                     last_ckpt['pretrain_optimizer'])
                _load_state_into(scheduler, last_ckpt['scheduler'])
                start_idx = last_ckpt['outer_idx']
                best_val_ooe = last_ckpt.get('best_val_ooe', best_val_ooe)
                waited = last_ckpt.get('waited', waited)
                conf_model.to(device)

    wandb.run.name = opt['exp_name']
    wandb.run.save()
    wandb.watch(conf_model)

    # Scorers evaluated alongside the trained one.
    confs = {
        opt['confidence_method']: conf_model,
    }
    if opt['confidence_method'] == 'oec':
        confs['ed'] = FSCConfidence(fs_model, 'ed')
    elif opt['confidence_method'] == 'deep-oec':
        encoder = fs_model.encoder
        deep_mahala_obj = DeepMahala(None,
                                     None,
                                     None,
                                     encoder,
                                     device,
                                     num_feats=encoder.depth,
                                     num_classes=train_data['n_classes'],
                                     pretrained_path="",
                                     fit=False,
                                     normalize=None)
        confs['dm'] = DMConfidence(deep_mahala_obj, {
            'ls': range(encoder.depth),
            'reduction': 'max',
            'g_magnitude': 0
        }, True, 'iso').to(device)

    # Temporal ensemble for evaluation: snapshots of conf_model refreshed
    # round-robin during training.
    if opt['n_ensemble'] > 1:
        nets = [deepcopy(conf_model) for _ in range(opt['n_ensemble'])]
        confs['mixture-' + opt['confidence_method']] = Ensemble(nets, 'mixture')
        confs['poe-' + opt['confidence_method']] = Ensemble(nets, 'poe')
        # Guard against 0 when n_ensemble > eval_every_outer (modulo by zero).
        ensemble_update_interval = max(
            1, opt['eval_every_outer'] // opt['n_ensemble'])

    iteration_fieldnames = ['global_iteration']
    for c in confs:
        iteration_fieldnames += [
            f'{c}_train_ooe', f'{c}_val_ooe', f'{c}_test_ooe', f'{c}_ood'
        ]
    iteration_logger = CSVLogger(every=0,
                                 fieldnames=iteration_fieldnames,
                                 filename=os.path.join(opt['output_dir'],
                                                       'iteration_log.csv'))

    # DM-based scorers are evaluated with autograd enabled.
    eval_no_grad = not opt['confidence_method'].startswith('dm')
    val_key = f'{opt["confidence_method"]}_val_ooe'

    progress_bar = tqdm(range(start_idx, opt['train_iter']))
    for outer_idx in progress_bar:
        sample = load_episode(train_data, tr, opt['data.test_way'],
                              opt['data.test_shot'], opt['data.test_query'],
                              device)

        conf_model.train()
        if opt['full_supervision']:  # sanity check
            conf_model.support(sample['xs'])
            in_score = conf_model.score(sample['xq'], detach=False).squeeze()
            out_score = conf_model.score(sample['ooc_xq'],
                                         detach=False).squeeze()
            out_scores = [out_score]
            for curr_ood, ood_tensor in ood_tensors:
                if curr_ood == 'ooe':
                    continue
                start = outer_idx % (len(ood_tensor) // 2)
                # NOTE(review): shape[0] * shape[0] is way * way; way * query
                # (shape[0] * shape[1]) may have been intended.
                end = min(start + sample['xq'].shape[0] * sample['xq'].shape[0],
                          len(ood_tensor) // 2)
                oxq = torch.stack([tr(x) for x in ood_tensor[start:end]
                                   ]).to(device)
                o = conf_model.score(oxq, detach=False).squeeze()
                out_scores.append(o)
            # NOTE(review): the concatenation below is disabled in the
            # original code. Only the OOE distractor scores enter the loss,
            # while in_score is repeated once per OOD source -- confirm intent.
            # out_score = torch.cat(out_scores)
            in_score = in_score.repeat(len(ood_tensors))
            loss, acc = compute_loss_bce(in_score, out_score, mean_center=False)
        else:
            conf_model.support(sample['xs'])
            if opt['interpolate']:
                # Replace the first half of the distractor ways with midpoints
                # of pairs of in-distribution queries.
                half_n_way = sample['xq'].shape[0] // 2
                interp = .5 * (sample['xq'][:half_n_way] +
                               sample['xq'][half_n_way:2 * half_n_way])
                sample['ooc_xq'][:half_n_way] = interp

            if opt['input_regularization'] == 'oe':
                # Flatten ooc_xq to a single row, then overwrite its leading
                # entries with the next slice of the outlier-exposure pool.
                nw, nq, c, h, w = sample['ooc_xq'].shape
                sample['ooc_xq'] = sample['ooc_xq'].view(1, nw * nq, c, h, w)
                oe_bs = int(nw * nq * opt['input_regularization_percent'])
                start = (outer_idx * oe_bs) % len(reg_data)
                end = np.min([start + oe_bs, len(reg_data)])
                oe_batch = torch.stack([tr(x) for x in reg_data[start:end]
                                        ]).to(device)
                oe_batch = oe_batch.unsqueeze(0)
                sample['ooc_xq'][:, :oe_batch.shape[1]] = oe_batch

            if opt['in_out_1_batch']:
                inps = torch.cat([sample['xq'], sample['ooc_xq']], 1)
                scores = conf_model.score(inps, detach=False).squeeze()
                in_score = scores[:sample['xq'].shape[1]]
                out_score = scores[sample['xq'].shape[1]:]
            else:
                in_score = conf_model.score(sample['xq'], detach=False).squeeze()
                out_score = conf_model.score(sample['ooc_xq'],
                                             detach=False).squeeze()
            loss, acc = compute_loss_bce(in_score, out_score, mean_center=False)

        if pretrain_optimizer is not None and outer_idx < pretrain_iter:
            pretrain_optimizer.zero_grad()
            loss.backward()
            pretrain_optimizer.step()
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

        progress_bar.set_postfix(loss='{:.3e}'.format(loss),
                                 acc='{:.3e}'.format(acc))

        # Update ensemble (round-robin snapshot of the current model).
        if opt['n_ensemble'] > 1 and outer_idx % ensemble_update_interval == 0:
            update_ind = (outer_idx //
                          ensemble_update_interval) % opt['n_ensemble']
            if opt['db']:
                print(f"===> Updating Ensemble: {update_ind}")
            confs['mixture-' + opt['confidence_method']].nets[update_ind] = deepcopy(
                conf_model)
            confs['poe-' + opt['confidence_method']].nets[update_ind] = deepcopy(
                conf_model)

        # AUROC eval
        if outer_idx % opt['eval_every_outer'] == 0:
            stats_dict = {'global_iteration': outer_idx}
            for conf_name, conf in confs.items():
                # Honour eval_in_train for the model being trained. The old
                # unconditional conf.eval() overrode that flag.
                if conf is not conf_model or not opt['eval_in_train']:
                    conf.eval()

                # OOE eval
                ooe_aurocs = {}
                for split, in_data in [('train', train_data),
                                       ('val', val_data), ('test', test_data)]:
                    auroc = np.mean(
                        eval_ood_aurocs(None,
                                        in_data,
                                        tr,
                                        opt['data.test_way'],
                                        opt['data.test_shot'],
                                        opt['data.test_query'],
                                        opt['data.test_episodes'],
                                        device,
                                        conf,
                                        no_grad=eval_no_grad)['aurocs'])
                    ooe_aurocs[split] = auroc
                    print_str = '{}, iter: {} ({}), auroc: {:.3e}'.format(
                        conf_name, outer_idx, split, ooe_aurocs[split])
                    _print_and_log(print_str, trace_file)
                stats_dict[f'{conf_name}_train_ooe'] = ooe_aurocs['train']
                stats_dict[f'{conf_name}_val_ooe'] = ooe_aurocs['val']
                stats_dict[f'{conf_name}_test_ooe'] = ooe_aurocs['test']

                # OOD eval
                if not opt['ooe_only']:
                    aurocs = []
                    for curr_ood, ood_tensor in ood_tensors:
                        auroc = np.mean(
                            eval_ood_aurocs(ood_tensor,
                                            test_data,
                                            tr,
                                            opt['data.test_way'],
                                            opt['data.test_shot'],
                                            opt['data.test_query'],
                                            opt['data.test_episodes'],
                                            device,
                                            conf,
                                            no_grad=eval_no_grad)['aurocs'])
                        aurocs.append(auroc)
                        print_str = '{}, iter: {} ({}), auroc: {:.3e}'.format(
                            conf_name, outer_idx, curr_ood, auroc)
                        _print_and_log(print_str, trace_file)

                    mean_ood_auroc = np.mean(aurocs)
                    print_str = '{}, iter: {} (OOD_mean), auroc: {:.3e}'.format(
                        conf_name, outer_idx, mean_ood_auroc)
                    _print_and_log(print_str, trace_file)

                    stats_dict[f'{conf_name}_ood'] = mean_ood_auroc

            iteration_logger.writerow(stats_dict)
            plot_csv(iteration_logger.filename, iteration_logger.filename)
            wandb.log(stats_dict)

            if stats_dict[val_key] > best_val_ooe:
                # Record the new best. It was never updated before, so every
                # eval counted as "best" and patience could not fire.
                best_val_ooe = stats_dict[val_key]
                conf_model.cpu()
                torch.save(
                    conf_model.state_dict(),
                    os.path.join(opt['output_dir'],
                                 opt['exp_name'] + '_conf_best.pt'))
                conf_model.to(device)
                # Ckpt ensemble
                if opt['n_ensemble'] > 1:
                    ensemble = confs['mixture-' + opt['confidence_method']]
                    ensemble.cpu()
                    torch.save(
                        ensemble.state_dict(),
                        os.path.join(opt['output_dir'],
                                     opt['exp_name'] + '_ensemble_best.pt'))
                    ensemble.to(device)
                waited = 0
            else:
                waited += 1
                if waited >= PATIENCE:
                    print("PATIENCE exceeded...exiting")
                    sys.exit()

            # For `resume`: store state dicts (not live objects) so they can
            # be loaded into freshly constructed optimizer/scheduler objects.
            conf_model.cpu()
            torch.save(
                {
                    'conf_model_sd': conf_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'pretrain_optimizer': (pretrain_optimizer.state_dict()
                                           if pretrain_optimizer is not None
                                           else None),
                    'scheduler': scheduler.state_dict(),
                    'outer_idx': outer_idx,
                    'best_val_ooe': best_val_ooe,
                    'waited': waited,
                }, os.path.join(opt['output_dir'], 'last_ckpt.pt'))
            conf_model.to(device)
            conf_model.train()
    sys.exit()
def _build_arg_parser():
    """Return the command-line parser for the OOD evaluation script."""
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--dm_path', type=str, default='')
    parser.add_argument('--oec_path', type=str, default='')
    parser.add_argument('--episodic_ood_eval', type=int, default=0)
    parser.add_argument('--episodic_in_distr', type=str, default='meta-test',
                        choices=['meta-test', 'meta-train'])
    # DM
    parser.add_argument('--dm_g_magnitude', type=float, default=0)
    parser.add_argument('--dm_ls', type=str, default='-')
    parser.add_argument('--db', type=int, default=0)
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--n_episodes', type=int, default=100)
    parser.add_argument('--n_ways', type=int, default=5)
    parser.add_argument('--n_shots', type=int, default=5)
    # Queries per class in episodic evaluation (previously hard-coded to 15).
    parser.add_argument('--n_query', type=int, default=15)
    # Required
    parser.add_argument('--dataroot', required=True)
    parser.add_argument('--output_dir', required=True)
    parser.add_argument('--dataset', required=True,
                        choices=['mnist', 'cifar10', 'cifar100', 'cifar-fs', 'cifar-64', 'miniimagenet'])
    parser.add_argument('--ood_methods', type=str, required=True,
                        help='comma separated list of method names e.g., `mpp,DM-all')
    # Pretrained model paths
    parser.add_argument('--fsmodel_path', required=True)
    parser.add_argument('--fsmodel_name', required=True, type=str,
                        choices=['protonet', 'maml', 'baseline', 'baseline-pn'])
    parser.add_argument('--classifier_path', required=True)
    parser.add_argument('--glow_dir', required=True)
    parser.add_argument('--ooe_only', type=int, default=0)
    return parser


def _ci95_halfwidth(values):
    """Return the half-width of a normal-approximation 95% CI for the mean of *values*.

    The standard error of the mean is ``std / sqrt(n)``; the half-width is
    1.96 standard errors.
    """
    values = np.asarray(values, dtype=float)
    return 1.96 * np.std(values) / np.sqrt(len(values))


def _make_mpp_ensemble_confidence(classifier_path, train_data, device):
    """Build a max-probability confidence over a 5-member classifier ensemble.

    Member ``i`` (0..4) is loaded from ``<dir[:-1]><i>/<fname>`` where
    ``<dir>/<fname>`` is *classifier_path*, i.e. the sibling run directories are
    assumed to differ only in their last character.
    """

    class PModel(nn.Module):
        # Exponentiates the wrapped model's output (presumably log-probabilities).
        def __init__(self, logp_model):
            super(PModel, self).__init__()
            self.logp_model = logp_model

        def forward(self, x):
            return self.logp_model(x).exp()

    ckpt_dir = os.path.dirname(classifier_path)
    ckpt_fname = os.path.basename(classifier_path)
    nets = []
    for i in range(5):
        path = os.path.join(ckpt_dir[:-1] + f"{i}", ckpt_fname)
        member = ResNetClassifier(train_data['n_classes'], train_data['im_size'])
        member.load_state_dict(torch.load(path))
        member = PModel(member)
        member.eval()
        # Every member must be registered; otherwise Ensemble(...) gets no members.
        nets.append(member.to(device))
    ensemble = Ensemble(nets)
    return BaseConfidence(lambda x: ensemble(x).max(-1)[0])


def _read_oec_opt(oec_path):
    """Load the options dict stored in ``args.json`` next to the OEC checkpoint."""
    with open(os.path.join(os.path.dirname(oec_path), 'args.json'), 'r') as f:
        return json.load(f)


def _load_oec_confidence(oec_opt, ckpt_path, fs_model, train_data, tr, device):
    """Construct an OEC / deep-OEC confidence module and load *ckpt_path* into it.

    The constructor receives an episode sampled from *train_data* using the
    way/shot/query values stored in *oec_opt*.
    """
    init_sample = load_episode(train_data, tr, oec_opt['data.test_way'],
                               oec_opt['data.test_shot'], oec_opt['data.test_query'], device)
    conf_cls = OECConfidence if oec_opt['confidence_method'] == 'oec' else DeepOECConfidence
    oec_conf = conf_cls(None, fs_model, init_sample, oec_opt)
    oec_conf.load_state_dict(torch.load(ckpt_path))
    oec_conf.eval()
    oec_conf.to(device)
    return oec_conf


def main():
    """Evaluate OOD confidence scores and write AUROC / FPR tables to CSV.

    Parses the command line, loads the in-distribution data, the few-shot
    model / classifier / Glow checkpoints and one confidence function per name
    in ``--ood_methods``, then either

    * (``--episodic_ood_eval 0``) scores batches of the test split against each
      OOD set in ``out_list`` and writes ``md_auroc_<methods>.csv``, or
    * (``--episodic_ood_eval 1``) runs ``--n_episodes`` few-shot episodes per
      OOD source through ``eval_ood_aurocs`` and writes mean AUROC / FPR and
      their 95% CI half-widths to four CSV files.

    Raises:
        ValueError: if a name in ``--ood_methods`` is unknown, or is an
            episodic-only method used without ``--episodic_ood_eval``.
    """
    args = _build_arg_parser().parse_args()
    use_cuda = True
    mkdir(args.output_dir)
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    first_method_is_glow = args.ood_methods.split(',')[0].startswith('glow')

    # --- In-distribution data, OOD set names, input transform ---------------
    # NOTE(review): the mnist branch sets neither `train_data` nor `normalize`;
    # anything that reads them (baseline classifier, Glow, DM / deep-ed methods,
    # Ensemble-MPP) will raise NameError for mnist -- confirm this is intended.
    if args.dataset == 'mnist':
        test_data = get_dataset('mnist-test', args.dataroot)
        out_list = ['gaussian', 'rademacher', 'texture3', 'svhn', 'notMNIST']
        tr = get_transform('mnist_resize_normalize')
    if args.dataset.startswith('cifar'):
        out_list = ['gaussian', 'rademacher', 'texture3', 'svhn', 'tinyimagenet', 'lsun']
        normalize = cifar_normalize
        if args.dataset == 'cifar10':
            train_data = get_dataset('cifar10-train', args.dataroot)
            test_data = get_dataset('cifar10-test', args.dataroot)
        if args.dataset == 'cifar100':
            train_data = get_dataset('cifar100-train', args.dataroot)
            test_data = get_dataset('cifar100-test', args.dataroot)
        if args.dataset == 'cifar-fs':
            train_data = get_dataset('cifar-fs-train-train', args.dataroot)
            test_data = get_dataset('cifar-fs-test', args.dataroot)
        if args.dataset == 'cifar-64':
            assert args.db
            train_data = get_dataset('cifar-fs-train-train', args.dataroot)
            test_data = get_dataset('cifar-fs-train-test', args.dataroot)
        tr = (get_transform('cifar_resize_glow_preproc') if first_method_is_glow
              else get_transform('cifar_resize_normalize'))
    if args.dataset == 'miniimagenet':
        train_data = get_dataset('miniimagenet-train-train', args.dataroot)
        test_data = get_dataset('miniimagenet-test', args.dataroot)
        out_list = ['gaussian', 'rademacher', 'texture3', 'svhn', 'tinyimagenet', 'lsun']
        tr = (get_transform('cifar_resize_glow_preproc') if first_method_is_glow
              else get_transform('cifar_resize_normalize_84'))
        normalize = cifar_normalize

    # --- Models --------------------------------------------------------------
    classifier = None
    glow = None
    fs_model = None
    if args.fsmodel_name in ['protonet', 'maml']:
        # FS model checkpoint stores the whole module.
        assert args.fsmodel_path != '-'
        fs_model = torch.load(args.fsmodel_path)
        encoder = fs_model.encoder
    elif args.fsmodel_name in ['baseline', 'baseline-pn']:
        # Few-shot model is built on top of a pretrained classifier's encoder.
        assert args.classifier_path != '-'
        classifier = ResNetClassifier(train_data['n_classes'], train_data['im_size']).to(device)
        classifier.load_state_dict(torch.load(args.classifier_path))
        encoder = classifier.encoder
        if args.fsmodel_name == 'baseline':
            fs_model = BaselineFinetune(encoder, args.n_ways, args.n_shots, loss_type='dist')
        else:
            fs_model = Protonet(encoder)
    fs_model.to(device)
    fs_model.eval()
    args.num_feats = encoder.depth
    encoder.to(device)
    encoder.eval()

    if args.classifier_path != '-' and classifier is None:  # for non-FS methods
        classifier = ResNetClassifier(train_data['n_classes'], train_data['im_size']).to(device)
        classifier.load_state_dict(torch.load(args.classifier_path))
    if classifier is not None:
        # The classifier is only used for inference below (accuracy check, MPP);
        # keep train-mode layers (e.g. BatchNorm/Dropout) out of the scores.
        classifier.eval()

    if args.glow_dir != '-':
        glow_name = [s for s in os.listdir(args.glow_dir) if 'glow_model' in s][0]
        with open(os.path.join(args.glow_dir, 'hparams.json')) as json_file:
            hparams = json.load(json_file)
        # Notice Glow is 32,32,3 even for miniImageNet
        glow = Glow((32, 32, 3), hparams['hidden_channels'], hparams['K'], hparams['L'],
                    hparams['actnorm_scale'], hparams['flow_permutation'],
                    hparams['flow_coupling'], hparams['LU_decomposed'],
                    train_data['n_classes'], hparams['learn_top'], hparams['y_condition'])
        glow.load_state_dict(torch.load(os.path.join(args.glow_dir, glow_name)))
        glow.set_actnorm_init()
        glow = glow.to(device)
        glow = glow.eval()

    # Verify Acc (just making sure models are loaded properly)
    if classifier is not None and not first_method_is_glow:
        n_check = args.test_batch_size
        with torch.no_grad():
            for label, split in (("Train Acc: ", train_data), ("Test Acc: ", test_data)):
                batch = torch.stack([tr(x) for x in split['x'][:n_check]]).to(device)
                preds = classifier(batch).max(-1)[1]
                print(label, (preds.cpu().numpy() == np.array(split['y'])[:n_check]).mean())

    # --- Confidence functions ------------------------------------------------
    # name -> confidence object exposing .score(); evaluated in insertion order.
    confidence_funcs = OrderedDict()
    # name -> `no_grad` flag for eval_ood_aurocs. Tracked per method so a
    # gradient-requiring method is not forced into no-grad mode just because
    # a different method happened to be listed last.
    no_grad_by_name = {}
    for ood_method in args.ood_methods.split(','):
        no_grad = True
        if ood_method.startswith('DM'):
            deep_mahala_obj = DeepMahala(
                train_data['x'], train_data['y'], tr, encoder, device,
                num_feats=args.num_feats, num_classes=train_data['n_classes'],
                pretrained_path=args.dm_path, fit=True, normalize=normalize)
        elif ood_method.startswith('deep-ed'):
            no_grad = False
            deep_mahala_obj = DeepMahala(
                train_data['x'], train_data['y'], tr, encoder, device,
                num_feats=args.num_feats, num_classes=train_data['n_classes'],
                pretrained_path=args.dm_path, fit=False, normalize=normalize)

        if ood_method == 'MPP':
            name, conf_fn = 'MPP', BaseConfidence(lambda x: mpp(classifier, x))
        elif ood_method == 'Ensemble-MPP':
            name = 'Ensemble-MPP'
            conf_fn = _make_mpp_ensemble_confidence(args.classifier_path, train_data, device)
        elif ood_method == 'DM-last':
            name = 'DM-last'
            conf_fn = DMConfidence(deep_mahala_obj,
                                   {'ls': [args.num_feats - 1], 'reduction': 'max'},
                                   False).to(device)
        elif ood_method == 'DM-all':
            name = 'DM-all'
            conf_fn = DMConfidence(deep_mahala_obj,
                                   {'ls': list(range(args.num_feats)), 'reduction': 'max'},
                                   False).to(device)
        elif ood_method == 'glow-ll':
            name, conf_fn = 'glow-ll', BaseConfidence(lambda x: -glow(x)[1])
        elif ood_method == 'glow-lr':
            from test_glow_ood import ll_to_png_code_ratio
            name = 'glow-lr'
            conf_fn = BaseConfidence(lambda x: ll_to_png_code_ratio(x, glow))
        elif ood_method == 'native-spp' and args.episodic_ood_eval:
            if args.fsmodel_name in ['maml', 'baseline']:
                no_grad = False
            name, conf_fn = 'native-spp', FSCConfidence(fs_model, 'spp')
        elif ood_method == 'native-ed' and args.episodic_ood_eval:
            name, conf_fn = 'native-ed', FSCConfidence(fs_model, 'ed')
        elif ood_method.startswith('deep-ed') and args.episodic_ood_eval:
            if args.dm_ls == '-':
                ls = range(args.num_feats)
            else:
                ls = [int(tok) for tok in args.dm_ls.split(',')]
            kwargs = {'ls': ls, 'reduction': 'max', 'g_magnitude': args.dm_g_magnitude}
            conf_fn = DMConfidence(deep_mahala_obj, kwargs, True, ood_method.split('-')[-1])
            conf_fn.to(device)
            name = ood_method
        elif ood_method == 'dkde' and args.episodic_ood_eval:
            name, conf_fn = 'dkde', DKDEConfidence(encoder)
        elif ood_method == 'oec' and args.episodic_ood_eval:
            oec_opt = _read_oec_opt(args.oec_path)
            name = 'oec'
            conf_fn = _load_oec_confidence(oec_opt, args.oec_path, fs_model, train_data, tr, device)
        elif ood_method == 'oec-ensemble' and args.episodic_ood_eval:
            # Not much more effective than 'oec'; registered under the 'oec' name.
            oec_opt = _read_oec_opt(args.oec_path)
            members = []
            for e in range(5):
                # Sibling run directories differ only in their last character.
                cdir = os.path.dirname(args.oec_path)[:-1] + f"{e}"
                fname = [s for s in os.listdir(cdir) if s.endswith('conf_best.pt')][0]
                members.append(_load_oec_confidence(
                    oec_opt, os.path.join(cdir, fname), fs_model, train_data, tr, device))
            name, conf_fn = 'oec', Ensemble(members)
        else:
            raise ValueError(
                f"Unknown ood_method {ood_method!r} (episodic-only methods also "
                f"require --episodic_ood_eval)")
        confidence_funcs[name] = conf_fn
        no_grad_by_name[name] = no_grad

    auroc_data = defaultdict(list)
    auroc_95ci_data = defaultdict(list)
    fpr_data = defaultdict(list)
    fpr_95ci_data = defaultdict(list)

    if not args.episodic_ood_eval:
        # --- Classic (non-episodic) OOD evaluation ---------------------------
        bs = args.test_batch_size
        for out_name in out_list:
            ooc_config = {'name': out_name, 'ood_scale': 1, 'n_anom': 5000, 'cuda': False}
            ood_tensor = load_ood_data(ooc_config)
            assert len(ood_tensor) <= len(test_data['x'])
            in_scores = defaultdict(list)
            out_scores = defaultdict(list)
            with torch.no_grad():
                for i in tqdm(range(0, len(ood_tensor), bs)):
                    # Same number of in- and out-distribution samples per batch.
                    n = min(bs, len(ood_tensor) - i)
                    in_x = torch.stack([tr(x) for x in test_data['x'][i:i + n]]).to(device)
                    out_x = torch.stack([tr(x) for x in ood_tensor[i:i + n]]).to(device)
                    for c, f in confidence_funcs.items():
                        in_scores[c].append(f.score(in_x))
                        out_scores[c].append(f.score(out_x))
            # Save (the last batch of) OOD images for debugging.
            vutils.save_image(out_x[:100], f'non-episodic-{out_name}.png', normalize=True, nrow=10)
            for c in confidence_funcs:
                auroc = show_ood_detection_results_softmax(
                    torch.cat(in_scores[c]).cpu().numpy(),
                    torch.cat(out_scores[c]).cpu().numpy())[1]
                print(out_name, c, ': ', auroc)
                auroc_data[c].append(auroc)
            auroc_data['dset'].append(out_name)
            # Rewritten after every OOD set so partial results survive a crash.
            pandas.DataFrame(auroc_data).to_csv(
                os.path.join(args.output_dir, f'md_auroc_{args.ood_methods}.csv'))
    else:
        # --- Episodic OOD evaluation -----------------------------------------
        if args.episodic_in_distr == 'meta-train':
            episodic_in_data = train_data
        else:
            episodic_in_data = test_data
        episodic_ood = ['ooe', 'cifar-fs-test', 'cifar-fs-train-test']
        # OOD sets are loaded even with --ooe_only, as before, so the global RNG
        # state seen by the OOE episodes matches a full run (synthetic sets such
        # as 'gaussian' presumably draw random numbers while loading).
        ood_tensors = [None] + [
            load_ood_data({'name': out_name, 'ood_scale': 1, 'n_anom': 10000})
            for out_name in episodic_ood[1:] + out_list]
        if args.ooe_only:
            all_oods = [('ooe', None)]
        else:
            all_oods = zip(episodic_ood + out_list, ood_tensors)

        dm_stem = args.dm_path.split(".")[0]
        csv_stem = f'{args.tag}_episodic_{args.episodic_in_distr}_{dm_stem}_{args.ood_methods}'
        tables = ((auroc_data, 'auroc'), (fpr_data, 'fpr'),
                  (auroc_95ci_data, 'auroc_95ci'), (fpr_95ci_data, 'fpr_95ci'))
        n_query = args.n_query
        for out_name, ood_tensor in all_oods:
            metrics_dic = {}
            for c, f in confidence_funcs.items():
                metrics_dic[c] = eval_ood_aurocs(
                    ood_tensor, episodic_in_data, tr, args.n_ways, args.n_shots, n_query,
                    args.n_episodes, device, f, db=args.db, out_name=out_name,
                    no_grad=no_grad_by_name[c])
            for c in confidence_funcs:
                aurocs = metrics_dic[c]['aurocs']
                auroc = np.mean(aurocs)
                auroc_95ci = _ci95_halfwidth(aurocs)
                auroc_data[c].append(auroc)
                auroc_95ci_data[c].append(auroc_95ci)
                print(out_name, c, 'auroc: ', auroc, ',', auroc_95ci)
                fprs = metrics_dic[c]['fprs']
                fpr = np.mean(fprs)
                fpr_95ci = _ci95_halfwidth(fprs)
                fpr_data[c].append(fpr)
                fpr_95ci_data[c].append(fpr_95ci)
                print(out_name, c, 'fpr: ', fpr, ',', fpr_95ci)
            for table, suffix in tables:
                table['dset'].append(out_name)
                # Rewritten after every OOD source so partial results survive a crash.
                pandas.DataFrame(table).to_csv(
                    os.path.join(args.output_dir, f'{csv_stem}_{suffix}.csv'))