def test_process(self, testset):
    args = self.args
    record = np.zeros((args.num_test_episodes, 2))  # loss and acc
    label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
    label = label.type(torch.LongTensor)
    if torch.cuda.is_available():
        label = label.cuda()
    test_sampler = CategoriesSampler(testset.label,
                                     args.num_test_episodes,
                                     args.eval_way,
                                     args.eval_shot + args.eval_query)
    test_loader = DataLoader(dataset=testset,
                             batch_sampler=test_sampler,
                             num_workers=args.num_workers,
                             pin_memory=True)
    for i, batch in tqdm(enumerate(test_loader, 1), total=len(test_loader)):
        data = batch[0]
        data = data.to(self.args.device)
        logits = self.model(data)
        loss = F.cross_entropy(logits, label)
        acc = count_acc(logits, label)
        record[i - 1, 0] = loss.item()
        record[i - 1, 1] = acc
    assert i == record.shape[0]
    vl, _ = compute_confidence_interval(record[:, 0])
    va, vap = compute_confidence_interval(record[:, 1])
    print('{} way {} shot, Test acc={:.4f} + {:.4f}\n'.format(
        args.eval_way, args.eval_shot, va, vap))
    return vl, va, vap
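# For reference, a minimal sketch of the compute_confidence_interval helper used
# throughout this file, assuming it returns the mean and a 95% confidence
# half-width (1.96 * standard error); the repo's actual implementation may differ.
def compute_confidence_interval(data):
    a = 1.0 * np.array(data)
    m = np.mean(a)                               # mean over episodes
    pm = 1.96 * (np.std(a) / np.sqrt(len(a)))    # 95% CI half-width
    return m, pm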
def evaluate_fsl(self, data_loader):
    # restore model args
    args = self.args
    # evaluation mode
    self.model.eval()
    record = np.zeros((args.num_eval_episodes, 2))  # loss and acc
    label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
    label = label.type(torch.LongTensor)
    if torch.cuda.is_available():
        label = label.cuda()
    print('{} best epoch {}, best val acc={:.4f} + {:.4f}'.format(
        args.test_mode, self.trlog['max_acc_epoch'],
        self.trlog['max_acc'], self.trlog['max_acc_interval']))
    with torch.no_grad():
        for i, batch in tqdm(enumerate(data_loader, 1)):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            p = args.eval_shot * args.eval_way
            data_shot, data_query = data[:p], data[p:]
            logits = self.model.forward_fsl(data_shot, data_query)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            record[i - 1, 0] = loss.item()
            record[i - 1, 1] = acc
    assert i == record.shape[0]
    vl, _ = compute_confidence_interval(record[:, 0])
    va, vap = compute_confidence_interval(record[:, 1])
    # back to train mode
    self.model.train()
    return vl, va, vap
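# A minimal sketch of the count_acc helper assumed above: the fraction of
# argmax predictions that match the labels (the actual helper may differ).
def count_acc(logits, label):
    pred = torch.argmax(logits, dim=1)
    return (pred == label).float().mean().item()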
def evaluate_gfsl(self, fsl_loader, gfsl_loader, gfsl_dataset):
    # restore model args
    args = self.args
    # evaluation mode
    self.model.eval()
    record = np.zeros((args.num_eval_episodes, 5))  # loss, joint acc, HM, seen acc, unseen acc
    label_unseen_query = torch.arange(
        min(args.eval_way, self.valset.num_class)).repeat(args.eval_query).long()
    if torch.cuda.is_available():
        label_unseen_query = label_unseen_query.cuda()
    print('{} best epoch {}, best val acc={:.4f} + {:.4f}'.format(
        args.test_mode, self.trlog['max_acc_epoch'],
        self.trlog['max_acc'], self.trlog['max_acc_interval']))
    with torch.no_grad():
        for i, batch in tqdm(enumerate(zip(gfsl_loader, fsl_loader), 1)):
            if torch.cuda.is_available():
                data_seen, data_unseen = batch[0][0].cuda(), batch[1][0].cuda()
                seen_label, unseen_label = batch[0][1].cuda(), batch[1][1].cuda()
            else:
                data_seen, data_unseen = batch[0][0], batch[1][0]
                seen_label, unseen_label = batch[0][1], batch[1][1]
            p2 = args.eval_shot * args.eval_way
            data_unseen_shot, data_unseen_query = data_unseen[:p2], data_unseen[p2:]
            label_unseen_shot, _ = unseen_label[:p2], unseen_label[p2:]
            whole_query = torch.cat([data_seen, data_unseen_query], 0)
            whole_label = torch.cat(
                [seen_label, label_unseen_query + gfsl_dataset.num_class])
            logits_s, logits_u = self.model.forward_generalized(
                data_unseen_shot, whole_query)
            # compute unbiased accuracy over the joint label space
            new_logits = torch.cat([logits_s, logits_u], 1)
            record[i - 1, 0] = F.cross_entropy(new_logits, whole_label).item()
            record[i - 1, 1] = count_acc(new_logits, whole_label)
            # compute harmonic mean of seen/unseen accuracy
            HM_nobias, SA_nobias, UA_nobias = count_acc_harmonic_low_shot_joint(
                new_logits, whole_label, seen_label.shape[0])
            record[i - 1, 2:] = np.array([HM_nobias, SA_nobias, UA_nobias])
            del logits_s, logits_u, new_logits
            torch.cuda.empty_cache()
    assert i == record.shape[0]
    vl, _ = compute_confidence_interval(record[:, 0])
    va, vap = compute_confidence_interval(record[:, 2])
    # back to train mode
    self.model.train()
    return vl, va, vap
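# A hedged sketch of count_acc_harmonic_low_shot_joint as used above, assuming
# the first num_seen query rows are seen-class queries and the rest are
# unseen-class queries; it returns the harmonic mean together with the two
# per-split accuracies. The repo's actual helper may differ.
def count_acc_harmonic_low_shot_joint(logits, label, num_seen):
    seen_acc = count_acc(logits[:num_seen], label[:num_seen])
    unseen_acc = count_acc(logits[num_seen:], label[num_seen:])
    hm = 2 * seen_acc * unseen_acc / (seen_acc + unseen_acc + 1e-12)  # avoid 0/0
    return hm, seen_acc, unseen_acc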
def evaluate_fsl(self):
    args = self.args
    record = np.zeros((args.num_eval_episodes, 2))  # loss and acc
    label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
    label = label.type(torch.LongTensor)
    if torch.cuda.is_available():
        label = label.cuda()
    for i, batch in tqdm(enumerate(self.test_fsl_loader, 1)):
        if torch.cuda.is_available():
            data, _ = [_.cuda() for _ in batch]
        else:
            data = batch[0]
        p = args.eval_shot * args.eval_way
        data_shot, data_query = data[:p], data[p:]
        with torch.no_grad():
            logits = self.model.forward_fsl(data_shot, data_query)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
        record[i - 1, 0] = loss.item()
        record[i - 1, 1] = acc
        del loss, logits
        torch.cuda.empty_cache()
    assert i == record.shape[0]
    print('-'.join([args.model_class, args.model_path]))
    self.trlog['acc_mean'], self.trlog['acc_interval'] = \
        compute_confidence_interval(record[:, 1])
    print('FSL {}-way Acc {:.5f} + {:.5f}'.format(
        args.eval_way, self.trlog['acc_mean'], self.trlog['acc_interval']))
def evaluate_test(self):
    # restore model args
    args = self.args
    # load the best checkpoint and switch to evaluation mode
    self.model.load_state_dict(
        torch.load(osp.join(self.args.save_path, 'max_acc.pth'))['params'])
    self.model.eval()
    record = np.zeros((10000, 2))  # loss and acc
    label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
    label = label.type(torch.LongTensor)
    if torch.cuda.is_available():
        label = label.cuda()
    print('best epoch {}, best val acc={:.4f} + {:.4f}'.format(
        self.trlog['max_acc_epoch'],
        self.trlog['max_acc'],
        self.trlog['max_acc_interval']))
    with torch.no_grad():
        for i, batch in tqdm(enumerate(self.test_loader, 1)):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            logits = self.model(data)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            record[i - 1, 0] = loss.item()
            record[i - 1, 1] = acc
    assert i == record.shape[0]
    vl, _ = compute_confidence_interval(record[:, 0])
    va, vap = compute_confidence_interval(record[:, 1])
    self.trlog['test_acc'] = va
    self.trlog['test_acc_interval'] = vap
    self.trlog['test_loss'] = vl
    print('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
        self.trlog['max_acc_epoch'],
        self.trlog['max_acc'],
        self.trlog['max_acc_interval']))
    print('Test acc={:.4f} + {:.4f}\n'.format(
        self.trlog['test_acc'], self.trlog['test_acc_interval']))
    return vl, va, vap
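# The label tensors above rely on the sampler emitting queries with class
# indices cycling per position; an illustrative example (values assumed):
way, query = 5, 15
label = torch.arange(way, dtype=torch.int16).repeat(query)
# -> tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, ...]) with 75 entries, so logits
# row k must score the (k % way)-th class of the episode.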
def evaluate(self, data_loader):
    # restore model args
    args = self.args
    # evaluation mode
    self.model.eval()
    record = np.zeros((args.num_eval_episodes, 2))  # loss and acc
    label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
    label = label.type(torch.LongTensor)
    if torch.cuda.is_available():
        label = label.cuda()
    with torch.no_grad():
        for i, batch in tqdm(enumerate(data_loader, 1),
                             total=len(data_loader),
                             desc='eval procedure'):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            logits = self.model(data)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            record[i - 1, 0] = loss.item()
            record[i - 1, 1] = acc
    assert i == record.shape[0]
    vl, _ = compute_confidence_interval(record[:, 0])
    va, vap = compute_confidence_interval(record[:, 1])
    # back to train mode (keep BN frozen if requested)
    self.model.train()
    if self.args.fix_BN:
        self.model.encoder.eval()
    return vl, va, vap
def evaluate_test(self, use_max_tst=False):
    # restore model args
    args = self.args
    # pick the checkpoint and trlog keys for the selected criterion
    if use_max_tst:
        assert args.tst_criterion != '', 'Please specify a criterion'
        fname = osp.join(self.args.save_path, 'max_tst_criterion.pth')
        criterion = args.tst_criterion
        max_acc_epoch = 'max_tst_criterion_epoch'
        max_acc = 'max_tst_criterion'
        max_acc_interval = 'max_tst_criterion_interval'
        test_acc = 'test_acc_at_max_criterion'
        test_acc_interval = 'test_acc_interval_at_max_criterion'
        test_loss = 'test_loss_at_max_criterion'
    else:
        fname = osp.join(self.args.save_path, 'max_acc.pth')
        criterion = 'SupervisedAcc'
        max_acc_epoch = 'max_acc_epoch'
        max_acc = 'max_acc'
        max_acc_interval = 'max_acc_interval'
        test_acc = 'test_acc'
        test_acc_interval = 'test_acc_interval'
        test_loss = 'test_loss'
    print('\nCriterion selected: {}'.format(criterion))
    print('Reloading model from {}'.format(fname))
    self.model.load_state_dict(torch.load(fname)['params'])
    self.model.eval()
    record = np.zeros((10000, 2))   # loss and acc
    metrics = defaultdict(list)     # all other metrics
    label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
    label = label.type(torch.LongTensor)
    if torch.cuda.is_available():
        label = label.cuda()
    all_labels = torch.arange(
        args.eval_way, device=label.device).repeat(args.eval_shot + args.eval_query)
    max_validation_str = 'Maximum value of valid_{} {:.4f} + {:.4f} reached at Epoch {}\n'.format(
        criterion, self.trlog[max_acc], self.trlog[max_acc_interval],
        self.trlog[max_acc_epoch])
    print(max_validation_str)
    with torch.no_grad():
        for i, batch in tqdm(enumerate(self.test_loader, 1)):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            embeddings, logits = self.model(data, return_feature=True)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            record[i - 1, 0] = loss.item()
            record[i - 1, 1] = acc
            if args.tst_free:
                # TST-free part: clustering losses on the embeddings
                embeddings_dict = self.model.get_embeddings_dict(embeddings, all_labels)
                clustering_losses = tst_free.clustering_loss(
                    embeddings_dict,
                    args.sinkhorn_reg,
                    'wasserstein',
                    sqrt_temperature=np.sqrt(args.temperature),
                    normalize_by_dim=False,
                    clustering_iterations=20,
                    sinkhorn_iterations=20,
                    sinkhorn_iterations_warmstart=4,
                    sanity_check=False)
                for key, val in clustering_losses.items():
                    metrics[key].append(val)
            if args.debug_fast:
                print('Debug fast, breaking TEST after 1 mini-batch')
                record = record[:1]
                break
    assert i == record.shape[0]
    vl, _ = compute_confidence_interval(record[:, 0])
    va, vap = compute_confidence_interval(record[:, 1])
    metric_summaries = {key: compute_confidence_interval(val)
                        for key, val in metrics.items()}
    self.trlog[test_acc] = va
    self.trlog[test_acc_interval] = vap
    self.trlog[test_loss] = vl
    summary_lines = [max_validation_str]
    summary_lines.append('test_SupervisedAcc {:.4f} + {:.4f} (ep{})'.format(
        self.trlog[test_acc], self.trlog[test_acc_interval],
        self.trlog[max_acc_epoch]))
    for key, (mean, half_width) in metric_summaries.items():
        summary_lines.append('test_{} {:.4f} + {:.4f} (ep{})'.format(
            key, mean, half_width, self.trlog[max_acc_epoch]))
    self.trlog['TST'] = metric_summaries
    summary_lines_str = '\n'.join(summary_lines)
    print('\n{}'.format(summary_lines_str))
    with open(osp.join(self.args.save_path,
                       'summary_max_{}.txt'.format(criterion)), 'w') as f:
        f.write(summary_lines_str)
def evaluate(self, data_loader):
    # restore model args
    args = self.args
    # evaluation mode
    self.model.eval()
    accuracies = []
    losses = []
    metrics = OrderedDict()
    label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
    label = label.type(torch.LongTensor)
    if torch.cuda.is_available():
        label = label.cuda()
    all_labels = torch.arange(
        args.eval_way, device=label.device).repeat(args.eval_shot + args.eval_query)
    with torch.no_grad():
        for i, batch in enumerate(data_loader, 1):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            # data contains both support and query sets
            # (typically 25 + 75 for 5-shot 5-way 15-query)
            embeddings, logits = self.model(data, return_feature=True)
            if args.tst_free:
                # transductive metrics computed from the logits
                embeddings_logits_dict = self.model.get_embeddings_dict(
                    embeddings, all_labels, logits)
                for sinkhorn_reg_str in args.sinkhorn_reg:  # loop over all regularizations
                    sinkhorn_reg_float = float(sinkhorn_reg_str)
                    transductive_losses = tst_free.transductive_from_logits(
                        embeddings_logits_dict, regularization=sinkhorn_reg_float)
                    for key, val in transductive_losses.items():
                        key += '_reg{}'.format(sinkhorn_reg_str)
                        metrics.setdefault(key, [])
                        metrics[key].append(val)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            losses.append(loss.item())
            accuracies.append(acc)
            if args.tst_free:
                # TST-free part: clustering losses on the embeddings
                embeddings_dict = self.model.get_embeddings_dict(embeddings, all_labels)
                for sinkhorn_reg_str in args.sinkhorn_reg:  # loop over all regularizations
                    sinkhorn_reg_float = float(sinkhorn_reg_str)
                    clustering_losses = tst_free.clustering_loss(
                        embeddings_dict,
                        sinkhorn_reg_float,
                        'wasserstein',
                        sqrt_temperature=np.sqrt(args.temperature),
                        normalize_by_dim=False,
                        clustering_iterations=20,
                        sinkhorn_iterations=20,
                        sinkhorn_iterations_warmstart=4,
                        sanity_check=False)
                    for key, val in clustering_losses.items():
                        key += '_reg{}'.format(sinkhorn_reg_str)
                        metrics.setdefault(key, [])
                        metrics[key].append(val)
            if args.debug_fast:
                print('Debug fast, breaking eval after 1 mini-batch')
                break
    assert i == len(losses) and i == len(accuracies)
    vl, _ = compute_confidence_interval(losses)
    va, vap = compute_confidence_interval(accuracies)
    metric_summaries = {key: compute_confidence_interval(val)
                        for key, val in metrics.items()}
    # back to train mode (keep BN frozen if requested)
    self.model.train()
    if self.args.fix_BN:
        self.model.encoder.eval()
    if args.tst_free:
        return vl, va, vap, metric_summaries
    return vl, va, vap
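# Hypothetical call site for evaluate() with TST-free metrics enabled; the
# trainer/loader names here are illustrative, not from the original file:
vl, va, vap, summaries = trainer.evaluate(val_loader)
for key, (mean, half_width) in summaries.items():
    print('valid_{} {:.4f} + {:.4f}'.format(key, mean, half_width))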
def evaluate_test(self):
    # restore model args
    args = self.args
    # load the best checkpoint and switch to evaluation mode
    self.model.load_state_dict(
        torch.load(osp.join(self.args.save_path, 'max_acc.pth'))['params'])
    self.model.eval()
    if args.test_mode == 'FSL':
        record = np.zeros((10000, 2))  # loss and acc
        label = torch.arange(args.eval_way).repeat(args.eval_query).type(torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()
        with torch.no_grad():
            for i, batch in enumerate(self.test_fsl_loader, 1):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                p = args.eval_shot * args.eval_way
                data_shot, data_query = data[:p], data[p:]
                logits = self.model.forward_fsl(data_shot, data_query)
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                record[i - 1, 0] = loss.item()
                record[i - 1, 1] = acc
        assert i == record.shape[0]
        vl, _ = compute_confidence_interval(record[:, 0])
        va, vap = compute_confidence_interval(record[:, 1])
        self.trlog['test_acc'] = va
        self.trlog['test_acc_interval'] = vap
        self.trlog['test_loss'] = vl
        print('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
            self.trlog['max_acc_epoch'],
            self.trlog['max_acc'],
            self.trlog['max_acc_interval']))
        print('Test acc={:.4f} + {:.4f}\n'.format(
            self.trlog['test_acc'], self.trlog['test_acc_interval']))
    else:
        record = np.zeros((10000, 5))  # loss, joint acc, HM, seen acc, unseen acc
        label_unseen_query = torch.arange(
            min(args.eval_way, self.valset.num_class)).repeat(args.eval_query).long()
        if torch.cuda.is_available():
            label_unseen_query = label_unseen_query.cuda()
        with torch.no_grad():
            for i, batch in tqdm(
                    enumerate(zip(self.test_gfsl_loader, self.test_fsl_loader), 1)):
                if torch.cuda.is_available():
                    data_seen, data_unseen = batch[0][0].cuda(), batch[1][0].cuda()
                    seen_label, unseen_label = batch[0][1].cuda(), batch[1][1].cuda()
                else:
                    data_seen, data_unseen = batch[0][0], batch[1][0]
                    seen_label, unseen_label = batch[0][1], batch[1][1]
                p2 = args.eval_shot * args.eval_way
                data_unseen_shot, data_unseen_query = data_unseen[:p2], data_unseen[p2:]
                label_unseen_shot, _ = unseen_label[:p2], unseen_label[p2:]
                whole_query = torch.cat([data_seen, data_unseen_query], 0)
                whole_label = torch.cat(
                    [seen_label, label_unseen_query + self.traintestset.num_class])
                logits_s, logits_u = self.model.forward_generalized(
                    data_unseen_shot, whole_query)
                # compute unbiased accuracy over the joint label space
                new_logits = torch.cat([logits_s, logits_u], 1)
                record[i - 1, 0] = F.cross_entropy(new_logits, whole_label).item()
                record[i - 1, 1] = count_acc(new_logits, whole_label)
                # compute harmonic mean of seen/unseen accuracy
                HM_nobias, SA_nobias, UA_nobias = count_acc_harmonic_low_shot_joint(
                    new_logits, whole_label, seen_label.shape[0])
                record[i - 1, 2:] = np.array([HM_nobias, SA_nobias, UA_nobias])
                del logits_s, logits_u, new_logits
                torch.cuda.empty_cache()
        m_list = []
        p_list = []
        for i in range(5):
            m1, p1 = compute_confidence_interval(record[:, i])
            m_list.append(m1)
            p_list.append(p1)
        self.trlog['test_loss'] = m_list[0]
        self.trlog['test_acc'] = m_list[1]
        self.trlog['test_acc_interval'] = p_list[1]
        self.trlog['test_HM_acc'] = m_list[2]
        self.trlog['test_HM_acc_interval'] = p_list[2]
        self.trlog['test_HMSeen_acc'] = m_list[3]
        self.trlog['test_HMSeen_acc_interval'] = p_list[3]
        self.trlog['test_HMUnseen_acc'] = m_list[4]
        self.trlog['test_HMUnseen_acc_interval'] = p_list[4]
        print('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
            self.trlog['max_acc_epoch'],
            self.trlog['max_acc'],
            self.trlog['max_acc_interval']))
        print('Test HM acc={:.4f} + {:.4f}\n'.format(
            self.trlog['test_HM_acc'], self.trlog['test_HM_acc_interval']))
        print('GFSL {}-way Acc w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, m_list[1], p_list[1]))
        print('GFSL {}-way HM w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, m_list[2], p_list[2]))
        print('GFSL {}-way HMSeen w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, m_list[3], p_list[3]))
        print('GFSL {}-way HMUnseen w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, m_list[4], p_list[4]))
def evaluate_gfsl(self):
    args = self.args
    label_unseen_query = torch.arange(args.eval_way).repeat(args.eval_query).long()
    if torch.cuda.is_available():
        label_unseen_query = label_unseen_query.cuda()
    generalized_few_shot_acc = np.zeros((args.num_eval_episodes, 2))
    generalized_few_shot_delta = np.zeros((args.num_eval_episodes, 4))
    generalized_few_shot_hmeanacc = np.zeros((args.num_eval_episodes, 6))
    generalized_few_shot_hmeanmap = np.zeros((args.num_eval_episodes, 6))
    generalized_few_shot_ausuc = np.zeros((args.num_eval_episodes, 1))
    AUC_record = []
    for i, batch in tqdm(
            enumerate(zip(self.test_gfsl_loader, self.test_fsl_loader), 1)):
        if torch.cuda.is_available():
            data_seen, data_unseen = batch[0][0].cuda(), batch[1][0].cuda()
            seen_label, unseen_label = batch[0][1].cuda(), batch[1][1].cuda()
        else:
            data_seen, data_unseen = batch[0][0], batch[1][0]
            seen_label, unseen_label = batch[0][1], batch[1][1]
        p2 = args.eval_shot * args.eval_way
        data_unseen_shot, data_unseen_query = data_unseen[:p2], data_unseen[p2:]
        label_unseen_shot, _ = unseen_label[:p2], unseen_label[p2:]
        whole_query = torch.cat([data_seen, data_unseen_query], 0)
        whole_label = torch.cat(
            [seen_label, label_unseen_query + self.trainset.num_class])
        if args.model_class in ['CLS', 'Castle', 'ACastle']:
            with torch.no_grad():
                logits_s, logits_u = self.model.forward_generalized(
                    data_unseen_shot, whole_query)
        elif args.model_class in ['ProtoNet']:
            with torch.no_grad():
                logits_s, logits_u = self.model.forward_generalized(
                    data_unseen_shot, whole_query, self.model.seen_proto)
        # unbiased logits over the joint label space
        new_logits = torch.cat([logits_s, logits_u], 1)
        if 'acc' in self.criteria or 'hmeanacc' in self.criteria or 'delta' in self.criteria:
            new_logits_acc_biased = torch.cat(
                [logits_s - self.best_bias_acc, logits_u], 1)
        if 'hmeanmap' in self.criteria:
            new_logits_map_biased = torch.cat(
                [logits_s - self.best_bias_map, logits_u], 1)
        # Criterion: Acc
        if 'acc' in self.criteria:
            generalized_few_shot_acc[i - 1, 0] = count_acc(new_logits, whole_label)
            # compute calibrated (biased) accuracy
            generalized_few_shot_acc[i - 1, 1] = count_acc(
                new_logits_acc_biased, whole_label)
        # Criterion: Delta
        if 'delta' in self.criteria:
            # delta values for unbiased logits
            unbiased_delta1, unbiased_delta2 = count_delta_value(
                new_logits, whole_label, seen_label.shape[0],
                self.trainset.num_class)
            # delta values for calibrated logits
            biased_delta1, biased_delta2 = count_delta_value(
                new_logits_acc_biased, whole_label, seen_label.shape[0],
                self.trainset.num_class)
            generalized_few_shot_delta[i - 1, :] = np.array(
                [unbiased_delta1, unbiased_delta2, biased_delta1, biased_delta2])
        # Criterion: harmonic mean of accuracies
        if 'hmeanacc' in self.criteria:
            HM_nobias, SA_nobias, UA_nobias = count_acc_harmonic_low_shot_joint(
                new_logits, whole_label, seen_label.shape[0])
            HM, SA, UA = count_acc_harmonic_low_shot_joint(
                new_logits_acc_biased, whole_label, seen_label.shape[0])
            generalized_few_shot_hmeanacc[i - 1, :] = np.array(
                [HM_nobias, SA_nobias, UA_nobias, HM, SA, UA])
        # Criterion: harmonic mean of mAP
        if 'hmeanmap' in self.criteria:
            HM_nobias, SA_nobias, UA_nobias = count_acc_harmonic_MAP(
                new_logits, whole_label, seen_label.shape[0], 'macro')
            HM, SA, UA = count_acc_harmonic_MAP(
                new_logits_map_biased, whole_label, seen_label.shape[0], 'macro')
            generalized_few_shot_hmeanmap[i - 1, :] = np.array(
                [HM_nobias, SA_nobias, UA_nobias, HM, SA, UA])
        # Criterion: AUSUC
        if 'ausuc' in self.criteria:
            generalized_few_shot_ausuc[i - 1, 0], temp_auc_record = Compute_AUSUC(
                logits_s.detach().cpu().numpy(),
                logits_u.detach().cpu().numpy(),
                whole_label.cpu().numpy(),
                np.arange(self.trainset.num_class),
                self.trainset.num_class + np.arange(args.eval_way))
            AUC_record.append(temp_auc_record)
        del logits_s, logits_u, new_logits
        torch.cuda.empty_cache()
    self.AUC_record = AUC_record
    print('-'.join([args.model_class, args.model_path]))
    if 'acc' in self.criteria:
        self.trlog['acc_mean'], self.trlog['acc_interval'] = \
            compute_confidence_interval(generalized_few_shot_acc[:, 0])
        self.trlog['acc_biased_mean'], self.trlog['acc_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_acc[:, 1])
        print('GFSL {}-way Acc w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['acc_mean'], self.trlog['acc_interval']))
        print('GFSL {}-way Acc w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['acc_biased_mean'],
            self.trlog['acc_biased_interval']))
    if 'delta' in self.criteria:
        self.trlog['delta1_mean'], self.trlog['delta1_interval'] = \
            compute_confidence_interval(generalized_few_shot_delta[:, 0])
        self.trlog['delta2_mean'], self.trlog['delta2_interval'] = \
            compute_confidence_interval(generalized_few_shot_delta[:, 1])
        self.trlog['delta1_biased_mean'], self.trlog['delta1_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_delta[:, 2])
        self.trlog['delta2_biased_mean'], self.trlog['delta2_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_delta[:, 3])
        print('GFSL {}-way Delta1 w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['delta1_mean'], self.trlog['delta1_interval']))
        print('GFSL {}-way Delta1 w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['delta1_biased_mean'],
            self.trlog['delta1_biased_interval']))
        print('GFSL {}-way Delta2 w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['delta2_mean'], self.trlog['delta2_interval']))
        print('GFSL {}-way Delta2 w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['delta2_biased_mean'],
            self.trlog['delta2_biased_interval']))
    if 'hmeanacc' in self.criteria:
        self.trlog['HM_mean'], self.trlog['HM_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanacc[:, 0])
        self.trlog['S2All_mean'], self.trlog['S2All_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanacc[:, 1])
        self.trlog['U2All_mean'], self.trlog['U2All_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanacc[:, 2])
        self.trlog['HM_biased_mean'], self.trlog['HM_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanacc[:, 3])
        self.trlog['S2All_biased_mean'], self.trlog['S2All_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanacc[:, 4])
        self.trlog['U2All_biased_mean'], self.trlog['U2All_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanacc[:, 5])
        print('GFSL {}-way HM_mean w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['HM_mean'], self.trlog['HM_interval']))
        print('GFSL {}-way HM_mean w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['HM_biased_mean'],
            self.trlog['HM_biased_interval']))
        print('GFSL {}-way S2All_mean w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['S2All_mean'], self.trlog['S2All_interval']))
        print('GFSL {}-way S2All_mean w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['S2All_biased_mean'],
            self.trlog['S2All_biased_interval']))
        print('GFSL {}-way U2All_mean w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['U2All_mean'], self.trlog['U2All_interval']))
        print('GFSL {}-way U2All_mean w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['U2All_biased_mean'],
            self.trlog['U2All_biased_interval']))
    if 'hmeanmap' in self.criteria:
        self.trlog['HM_map_mean'], self.trlog['HM_map_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanmap[:, 0])
        self.trlog['S2All_map_mean'], self.trlog['S2All_map_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanmap[:, 1])
        self.trlog['U2All_map_mean'], self.trlog['U2All_map_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanmap[:, 2])
        self.trlog['HM_map_biased_mean'], self.trlog['HM_map_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanmap[:, 3])
        self.trlog['S2All_map_biased_mean'], self.trlog['S2All_map_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanmap[:, 4])
        self.trlog['U2All_map_biased_mean'], self.trlog['U2All_map_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanmap[:, 5])
        print('GFSL {}-way HM_map_mean w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['HM_map_mean'], self.trlog['HM_map_interval']))
        print('GFSL {}-way HM_map_mean w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['HM_map_biased_mean'],
            self.trlog['HM_map_biased_interval']))
        print('GFSL {}-way S2All_map_mean w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['S2All_map_mean'],
            self.trlog['S2All_map_interval']))
        print('GFSL {}-way S2All_map_mean w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['S2All_map_biased_mean'],
            self.trlog['S2All_map_biased_interval']))
        print('GFSL {}-way U2All_map_mean w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['U2All_map_mean'],
            self.trlog['U2All_map_interval']))
        print('GFSL {}-way U2All_map_mean w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['U2All_map_biased_mean'],
            self.trlog['U2All_map_biased_interval']))
    if 'ausuc' in self.criteria:
        self.trlog['AUSUC_mean'], self.trlog['AUSUC_interval'] = \
            compute_confidence_interval(generalized_few_shot_ausuc[:, 0])
        print('GFSL {}-way AUSUC {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['AUSUC_mean'], self.trlog['AUSUC_interval']))
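# self.best_bias_acc / self.best_bias_map above are assumed to be scalar
# calibration offsets chosen beforehand. A hypothetical sketch of how such a
# bias could be searched (not part of the original file): sweep a scalar
# subtracted from the seen logits and keep the value that maximizes the
# validation harmonic mean.
def search_calibration_bias(logits_s, logits_u, label, num_seen, candidates):
    best_bias, best_hm = 0.0, -1.0
    for b in candidates:
        calibrated = torch.cat([logits_s - b, logits_u], 1)
        hm, _, _ = count_acc_harmonic_low_shot_joint(calibrated, label, num_seen)
        if hm > best_hm:
            best_hm, best_bias = hm, b
    return best_bias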
# --- per-episode loop body of the SVM-based evaluation (excerpt): embed the
# episode, fit a linear SVM on the support embeddings, score the query set ---
if torch.cuda.is_available():
    data, _ = [_.cuda() for _ in batch]
else:
    data = batch[0]
data_emb = model(data, is_emb=True)
if args.centralize:
    data_emb = data_emb - class_mean
if args.normalize:
    data_emb = F.normalize(data_emb, dim=1, p=2)
split_index = args.way * args.shot
data_shot, data_query = data_emb[:split_index], data_emb[split_index:]
SVM = LinearSVC(C=best_c,
                multi_class='crammer_singer',
                dual=False,
                max_iter=5000).fit(data_shot.cpu().numpy(), shot_label)
prediction = SVM.predict(data_query.cpu().numpy())
acc = np.mean(prediction == query_label)
test_acc_record[i - 1] = acc
# print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100))

# --- after the loop: aggregate and report ---
m, pm = compute_confidence_interval(test_acc_record)
ensemble_result.append('{:.4f} + {:.4f}'.format(m, pm))
print('{} way {} shot, Test acc={:.4f} + {:.4f}, best_C: {}'.format(
    args.way, args.shot, m, pm, best_c))
print('ensemble result: {}'.format(','.join(ensemble_result)))
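# shot_label / query_label above are assumed to follow the same cycling class
# layout as the label tensors elsewhere in this file; a hypothetical
# construction (illustrative only, not from the original file):
shot_label = np.tile(np.arange(args.way), args.shot)    # support labels, classes cycling
query_label = np.tile(np.arange(args.way), args.query)  # query labels, classes cycling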
def open_evaluate(self, data_loader):
    # restore model args
    args = self.args
    # evaluation mode
    self.model.eval()
    record = np.zeros((args.num_test_episodes, 4))  # loss, acc, SnaTCHer AUROC, distance AUROC
    label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
    label = label.type(torch.LongTensor)
    if torch.cuda.is_available():
        label = label.cuda()
    print('Evaluating ... best epoch {}, SnaTCHer={:.4f} + {:.4f} acc={:.4f} + {:.4f}'.format(
        self.trlog['max_auc_epoch'], self.trlog['max_auc'],
        self.trlog['max_auc_interval'], self.trlog['acc'],
        self.trlog['acc_interval']))
    with torch.no_grad():
        # enumerate from 1 so that record[i - 1] fills the array from the front
        # (starting at 0 would write the first episode into record[-1])
        for i, batch in enumerate(tqdm(data_loader), 1):
            data, _ = [_.cuda() for _ in batch]
            logits = self.para_model(data)
            logits = logits.reshape([-1, args.eval_way + args.open_eval_way, args.way])
            klogits = logits[:, :args.eval_way, :].reshape(-1, args.way)
            ulogits = logits[:, args.eval_way:, :].reshape(-1, args.way)
            loss = F.cross_entropy(klogits, label)
            acc = count_acc(klogits, label)

            # Distance baseline: negative max logit as the open-set score
            kdist = -(klogits.max(1)[0])
            udist = -(ulogits.max(1)[0])
            kdist = kdist.cpu().detach().numpy()
            udist = udist.cpu().detach().numpy()
            dist_auroc = calc_auroc(kdist, udist)

            # SnaTCHer
            instance_embs = self.para_model.instance_embs
            support_idx = self.para_model.support_idx
            query_idx = self.para_model.query_idx
            support = instance_embs[support_idx.flatten()].view(*(support_idx.shape + (-1,)))
            query = instance_embs[query_idx.flatten()].view(*(query_idx.shape + (-1,)))
            emb_dim = support.shape[-1]
            support = support[:, :, :args.way].contiguous()
            # get mean of the support
            bproto = support.mean(dim=1)  # Ntask x NK x d
            proto = self.para_model.slf_attn(bproto, bproto, bproto)
            kquery = query[:, :, :args.way].contiguous()
            uquery = query[:, :, args.way:].contiguous()
            snatch_known = []
            for j in range(75):  # number of known queries (eval_way * eval_query)
                pproto = bproto.clone().detach()
                c = klogits.argmax(1)[j]                                       # Algorithm 1 Line 1
                pproto[0][c] = kquery.reshape(-1, emb_dim)[j]                  # Algorithm 1 Line 2
                pproto = self.para_model.slf_attn(pproto, pproto, pproto)[0]   # Algorithm 1 Line 3
                pdiff = (pproto - proto).pow(2).sum(-1).sum() / 64.0
                # pdiff: d_SnaTCHer in Algorithm 1
                snatch_known.append(pdiff)
            snatch_unknown = []
            for j in range(ulogits.shape[0]):
                pproto = bproto.clone().detach()
                c = ulogits.argmax(1)[j]                                       # Algorithm 1 Line 1
                pproto[0][c] = uquery.reshape(-1, emb_dim)[j]                  # Algorithm 1 Line 2
                pproto = self.para_model.slf_attn(pproto, pproto, pproto)[0]   # Algorithm 1 Line 3
                pdiff = (pproto - proto).pow(2).sum(-1).sum() / 64.0
                # pdiff: d_SnaTCHer in Algorithm 1
                snatch_unknown.append(pdiff)
            pkdiff = torch.stack(snatch_known).cpu().detach().numpy()
            pudiff = torch.stack(snatch_unknown).cpu().detach().numpy()
            snatch_auroc = calc_auroc(pkdiff, pudiff)

            record[i - 1, 0] = loss.item()
            record[i - 1, 1] = acc
            record[i - 1, 2] = snatch_auroc
            record[i - 1, 3] = dist_auroc
    vl, _ = compute_confidence_interval(record[:, 0])
    va, vap = compute_confidence_interval(record[:, 1])
    auc_sna, auc_sna_p = compute_confidence_interval(record[:, 2])
    auc_dist, auc_dist_p = compute_confidence_interval(record[:, 3])
    print('acc: {:.4f} + {:.4f} Dist: {:.4f} + {:.4f} SnaTCHer: {:.4f} + {:.4f}'
          .format(va, vap, auc_dist, auc_dist_p, auc_sna, auc_sna_p))
    # back to train mode (keep BN frozen if requested)
    self.model.train()
    if self.args.fix_BN:
        self.model.encoder.eval()
    return vl, va, vap, auc_sna, auc_sna_p
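# A hedged sketch of the calc_auroc helper assumed above: the AUROC of
# separating unknown-class scores from known-class scores, where a higher
# score means "more likely unknown" (built on sklearn; the actual helper
# may differ).
from sklearn.metrics import roc_auc_score

def calc_auroc(known_scores, unknown_scores):
    y_true = np.concatenate([np.zeros(len(known_scores)),
                             np.ones(len(unknown_scores))])
    y_score = np.concatenate([known_scores, unknown_scores])
    return roc_auc_score(y_true, y_score)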