def evaluate_test(self):
    # restore model args
    args = self.args
    # evaluation mode
    self.model.load_state_dict(
        torch.load(osp.join(self.args.save_path, 'max_acc.pth'))['params'])
    self.model.eval()

    if args.test_mode == 'FSL':
        record = np.zeros((10000, 2))  # loss and acc
        label = torch.arange(args.eval_way).repeat(args.eval_query).type(
            torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()
        with torch.no_grad():
            for i, batch in enumerate(self.test_fsl_loader, 1):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                p = args.eval_shot * args.eval_way
                data_shot, data_query = data[:p], data[p:]
                logits = self.model.forward_fsl(data_shot, data_query)
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                record[i - 1, 0] = loss.item()
                record[i - 1, 1] = acc
        assert i == record.shape[0]

        vl, _ = compute_confidence_interval(record[:, 0])
        va, vap = compute_confidence_interval(record[:, 1])
        self.trlog['test_acc'] = va
        self.trlog['test_acc_interval'] = vap
        self.trlog['test_loss'] = vl

        print('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
            self.trlog['max_acc_epoch'],
            self.trlog['max_acc'],
            self.trlog['max_acc_interval']))
        print('Test acc={:.4f} + {:.4f}\n'.format(
            self.trlog['test_acc'],
            self.trlog['test_acc_interval']))
    else:
        record = np.zeros((10000, 5))  # loss and acc
        label_unseen_query = torch.arange(
            min(args.eval_way, self.valset.num_class)).repeat(args.eval_query).long()
        if torch.cuda.is_available():
            label_unseen_query = label_unseen_query.cuda()
        with torch.no_grad():
            for i, batch in tqdm(
                    enumerate(zip(self.test_gfsl_loader, self.test_fsl_loader), 1)):
                if torch.cuda.is_available():
                    data_seen, data_unseen, seen_label, unseen_label = (
                        batch[0][0].cuda(), batch[1][0].cuda(),
                        batch[0][1].cuda(), batch[1][1].cuda())
                else:
                    data_seen, data_unseen, seen_label, unseen_label = (
                        batch[0][0], batch[1][0], batch[0][1], batch[1][1])
                p2 = args.eval_shot * args.eval_way
                data_unseen_shot, data_unseen_query = data_unseen[:p2], data_unseen[p2:]
                label_unseen_shot, _ = unseen_label[:p2], unseen_label[p2:]
                whole_query = torch.cat([data_seen, data_unseen_query], 0)
                whole_label = torch.cat([
                    seen_label,
                    label_unseen_query + self.traintestset.num_class
                ])
                logits_s, logits_u = self.model.forward_generalized(
                    data_unseen_shot, whole_query)
                # compute un-biased accuracy
                new_logits = torch.cat([logits_s, logits_u], 1)
                record[i - 1, 0] = F.cross_entropy(new_logits, whole_label).item()
                record[i - 1, 1] = count_acc(new_logits, whole_label)
                # compute harmonic mean
                HM_nobias, SA_nobias, UA_nobias = count_acc_harmonic_low_shot_joint(
                    torch.cat([logits_s, logits_u], 1), whole_label,
                    seen_label.shape[0])
                record[i - 1, 2:] = np.array([HM_nobias, SA_nobias, UA_nobias])
                del logits_s, logits_u, new_logits
                torch.cuda.empty_cache()

        m_list = []
        p_list = []
        for i in range(5):
            m1, p1 = compute_confidence_interval(record[:, i])
            m_list.append(m1)
            p_list.append(p1)

        self.trlog['test_loss'] = m_list[0]
        self.trlog['test_acc'] = m_list[1]
        self.trlog['test_acc_interval'] = p_list[1]
        self.trlog['test_HM_acc'] = m_list[2]
        self.trlog['test_HM_acc_interval'] = p_list[2]
        self.trlog['test_HMSeen_acc'] = m_list[3]
        self.trlog['test_HMSeen_acc_interval'] = p_list[3]
        self.trlog['test_HMUnseen_acc'] = m_list[4]
        self.trlog['test_HMUnseen_acc_interval'] = p_list[4]

        print('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
            self.trlog['max_acc_epoch'],
            self.trlog['max_acc'],
            self.trlog['max_acc_interval']))
        print('Test HM acc={:.4f} + {:.4f}\n'.format(
            self.trlog['test_HM_acc'],
            self.trlog['test_HM_acc_interval']))
        print('GFSL {}-way Acc w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, m_list[1], p_list[1]))
        print('GFSL {}-way HM w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, m_list[2], p_list[2]))
        print('GFSL {}-way HMSeen w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, m_list[3], p_list[3]))
        print('GFSL {}-way HMUnseen w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, m_list[4], p_list[4]))
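# --- Helper sketch (not part of the original excerpt) -------------------------
# `compute_confidence_interval` is imported from the repo's utilities and not
# shown here. A minimal sketch, assuming it returns the mean and the 95%
# confidence half-width (matching the inline 1.96 * std / sqrt(n) computation
# used in the SnaTCHer evaluation further below); the repo's implementation may
# differ in detail:
import numpy as np

def compute_confidence_interval(data):
    """Return (mean, 95% confidence half-width) over a 1-D array of episode scores."""
    a = 1.0 * np.array(data)
    m = np.mean(a)
    pm = 1.96 * (np.std(a) / np.sqrt(len(a)))
    return m, pm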
def train(self):
    args = self.args
    # start GFSL training
    for epoch in range(1, args.max_epoch + 1):
        self.train_epoch += 1
        self.model.train()
        tl1 = Averager()
        tl2 = Averager()
        ta = Averager()

        start_tm = time.time()
        for _, batch in enumerate(
                zip(self.train_fsl_loader, self.train_gfsl_loader)):
            self.train_step += 1
            if torch.cuda.is_available():
                support_data, support_label = batch[0][0].cuda(), batch[0][1].cuda()
                query_data, query_label = batch[1][0].cuda(), batch[1][1].cuda()
            else:
                support_data, support_label = batch[0][0], batch[0][1]
                query_data, query_label = batch[1][0], batch[1][1]

            data_tm = time.time()
            self.dt.add(data_tm - start_tm)

            logits = self.model(support_data, query_data, support_label)
            loss = F.cross_entropy(
                logits,
                query_label.view(-1, 1).repeat(1, args.num_tasks).view(-1))
            tl2.add(loss.item())

            forward_tm = time.time()
            self.ft.add(forward_tm - data_tm)

            acc = count_acc(
                logits,
                query_label.view(-1, 1).repeat(1, args.num_tasks).view(-1))
            tl1.add(loss.item())
            ta.add(acc)

            self.optimizer.zero_grad()
            loss.backward()
            backward_tm = time.time()
            self.bt.add(backward_tm - forward_tm)

            self.optimizer.step()
            self.lr_scheduler.step()
            optimizer_tm = time.time()
            self.ot.add(optimizer_tm - backward_tm)

            self.try_logging(tl1, tl2, ta)
            # refresh start_tm
            start_tm = time.time()
            del logits, loss
            torch.cuda.empty_cache()

        self.try_evaluate(epoch)
        print('ETA:{}/{}'.format(
            self.timer.measure(),
            self.timer.measure(self.train_epoch / args.max_epoch)))

    torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
    self.save_model('epoch-last')
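# --- Helper sketch (not part of the original excerpt) -------------------------
# `Averager` keeps a running mean of the per-batch losses and accuracies logged
# above. A minimal sketch, assuming only the add()/item() interface used in
# these trainers:
class Averager:
    def __init__(self):
        self.n = 0
        self.v = 0.0

    def add(self, x):
        # incremental running mean
        self.v = (self.v * self.n + x) / (self.n + 1)
        self.n += 1

    def item(self):
        return self.v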
def train(self):
    args = self.args
    self.model.train()
    if self.args.fix_BN:
        self.model.encoder.eval()

    # start FSL training
    label, label_aux = self.prepare_label()

    for epoch in range(1, args.max_epoch + 1):
        self.train_epoch += 1
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()

        tl1 = Averager()
        tl2 = Averager()
        ta = Averager()

        start_tm = time.time()
        for batch in self.train_loader:
            self.train_step += 1

            if torch.cuda.is_available():
                data, gt_label = [_.cuda() for _ in batch]
            else:
                data, gt_label = batch[0], batch[1]

            data_tm = time.time()
            self.dt.add(data_tm - start_tm)

            # get saved centers
            logits, reg_logits = self.para_model(data)
            if reg_logits is not None:
                loss = F.cross_entropy(logits, label)
                total_loss = loss + args.balance * F.cross_entropy(reg_logits, label_aux)
            else:
                loss = F.cross_entropy(logits, label)
                total_loss = loss
            tl2.add(loss.item())

            forward_tm = time.time()
            self.ft.add(forward_tm - data_tm)
            acc = count_acc(logits, label)

            tl1.add(total_loss.item())
            ta.add(acc)

            self.optimizer.zero_grad()
            total_loss.backward()
            backward_tm = time.time()
            self.bt.add(backward_tm - forward_tm)

            self.optimizer.step()
            optimizer_tm = time.time()
            self.ot.add(optimizer_tm - backward_tm)

            # refresh start_tm
            start_tm = time.time()

        self.lr_scheduler.step()
        self.try_evaluate(epoch)
        print('ETA:{}/{}'.format(
            self.timer.measure(),
            self.timer.measure(self.train_epoch / args.max_epoch)))

    torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
    self.save_model('epoch-last')
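# --- Helper sketch (not part of the original excerpt) -------------------------
# `self.prepare_label()` is defined on the trainer and not shown in this excerpt.
# A plausible sketch for an N-way episode, assuming the usual episodic layout
# (query labels for the main loss, support + query labels for the auxiliary
# regularization loss); the repo's version may differ:
def prepare_label(self):
    args = self.args
    # labels for the query set
    label = torch.arange(args.way, dtype=torch.long).repeat(args.query)
    # auxiliary labels over support + query (for the regularization term)
    label_aux = torch.arange(args.way, dtype=torch.long).repeat(args.shot + args.query)
    if torch.cuda.is_available():
        label, label_aux = label.cuda(), label_aux.cuda()
    return label, label_aux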
def validate(args, model, val_loader, epoch, trlog=None):
    model.eval()
    global writer
    vl_dist, va_dist, vl_sim, va_sim = Averager(), Averager(), Averager(), Averager()
    if trlog is not None:
        print('[Dist] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_dist_epoch'], trlog['max_acc_dist']))
        print('[Sim] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_sim_epoch'], trlog['max_acc_sim']))

    # test performance with Few-Shot
    label = torch.arange(args.num_val_class).repeat(args.query).long()
    if torch.cuda.is_available():
        label = label.cuda()

    with torch.no_grad():
        for i, batch in tqdm(enumerate(val_loader, 1), total=len(val_loader)):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data, _ = batch
            # 16-way test
            data_shot, data_query = data[:args.num_val_class], data[args.num_val_class:]
            logits_dist, logits_sim = model.forward_proto(
                data_shot, data_query, args.num_val_class)
            loss_dist = F.cross_entropy(logits_dist, label)
            acc_dist = count_acc(logits_dist, label)
            loss_sim = F.cross_entropy(logits_sim, label)
            acc_sim = count_acc(logits_sim, label)
            vl_dist.add(loss_dist.item())
            va_dist.add(acc_dist)
            vl_sim.add(loss_sim.item())
            va_sim.add(acc_sim)

    vl_dist = vl_dist.item()
    va_dist = va_dist.item()
    vl_sim = vl_sim.item()
    va_sim = va_sim.item()
    print('epoch {}, val, loss_dist={:.4f} acc_dist={:.4f} loss_sim={:.4f} acc_sim={:.4f}'
          .format(epoch, vl_dist, va_dist, vl_sim, va_sim))

    if trlog is not None:
        writer.add_scalar('data/val_loss_dist', float(vl_dist), epoch)
        writer.add_scalar('data/val_acc_dist', float(va_dist), epoch)
        writer.add_scalar('data/val_loss_sim', float(vl_sim), epoch)
        writer.add_scalar('data/val_acc_sim', float(va_sim), epoch)

        if va_dist > trlog['max_acc_dist']:
            trlog['max_acc_dist'] = va_dist
            trlog['max_acc_dist_epoch'] = epoch
            save_model('max_acc_dist', model, args)
            save_checkpoint(True)
        if va_sim > trlog['max_acc_sim']:
            trlog['max_acc_sim'] = va_sim
            trlog['max_acc_sim_epoch'] = epoch
            save_model('max_acc_sim', model, args)
            save_checkpoint(True)

        trlog['val_loss_dist'].append(vl_dist)
        trlog['val_acc_dist'].append(va_dist)
        trlog['val_loss_sim'].append(vl_sim)
        trlog['val_acc_sim'].append(va_sim)

    return trlog
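# --- Helper sketch (not part of the original excerpt) -------------------------
# `count_acc` is imported from the repo's utilities. A minimal sketch, assuming
# it returns the fraction of queries whose argmax logit matches the label:
def count_acc(logits, label):
    pred = torch.argmax(logits, dim=1)
    return (pred == label).float().mean().item()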
def evaluate_test(self):
    # restore model args
    emb_dim = self.emb_dim
    args = self.args
    weights = torch.load(osp.join(self.args.save_path, self.args.weight_name))
    model_weights = weights['params']
    self.missing_keys, self.unexpected_keys = self.model.load_state_dict(
        model_weights, strict=False)
    self.model.eval()

    test_steps = 600
    self.record = np.zeros((test_steps, 2))  # loss and acc
    self.auroc_record = np.zeros((test_steps, 10))

    way = args.closed_way
    # labels for the closed-set (known) queries
    label = torch.arange(way).repeat(args.eval_query).long()
    if torch.cuda.is_available():
        label = label.cuda()

    for i, batch in tqdm(enumerate(self.test_loader, 1)):
        if i > test_steps:
            break
        if torch.cuda.is_available():
            data, dlabel = [_.cuda() for _ in batch]
        else:
            data, dlabel = batch[0], batch[1]
        self.probe_data = data
        self.probe_dlabel = dlabel

        with torch.no_grad():
            _ = self.para_model(data)
            instance_embs = self.para_model.probe_instance_embs
            support_idx = self.para_model.probe_support_idx
            query_idx = self.para_model.probe_query_idx

            support = instance_embs[support_idx.flatten()].view(
                *(support_idx.shape + (-1, )))
            query = instance_embs[query_idx.flatten()].view(
                *(query_idx.shape + (-1, )))
            emb_dim = support.shape[-1]

            support = support[:, :, :way].contiguous()
            # get mean of the support
            bproto = support.mean(dim=1)  # Ntask x NK x d
            proto = bproto

            kquery = query[:, :, :way].contiguous()   # known (closed-set) queries
            uquery = query[:, :, way:].contiguous()   # unknown (open-set) queries

            # adapt the prototypes with the set-to-set transformation
            proto = self.para_model.slf_attn(proto, proto, proto)
            proto = proto[0]

            klogits = -(kquery.reshape(-1, 1, emb_dim) - proto).pow(2).sum(2) / 64.0
            ulogits = -(uquery.reshape(-1, 1, emb_dim) - proto).pow(2).sum(2) / 64.0

            loss = F.cross_entropy(klogits, label)
            acc = count_acc(klogits, label)

            """ Probability """
            known_prob = F.softmax(klogits, 1).max(1)[0]
            unknown_prob = F.softmax(ulogits, 1).max(1)[0]
            known_scores = 1 - known_prob.cpu().detach().numpy()
            unknown_scores = 1 - unknown_prob.cpu().detach().numpy()
            auroc = calc_auroc(known_scores, unknown_scores)

            """ Distance """
            kdist = -(klogits.max(1)[0]).cpu().detach().numpy()
            udist = -(ulogits.max(1)[0]).cpu().detach().numpy()
            dist_auroc = calc_auroc(kdist, udist)

            """ Snatcher """
            snatch_known = []
            for j in range(klogits.shape[0]):
                pproto = bproto.clone().detach()
                """ Algorithm 1 Line 1 """
                c = klogits.argmax(1)[j]
                """ Algorithm 1 Line 2 """
                pproto[0][c] = kquery.reshape(-1, emb_dim)[j]
                """ Algorithm 1 Line 3 """
                pproto = self.para_model.slf_attn(pproto, pproto, pproto)[0]
                pdiff = (pproto - proto).pow(2).sum(-1).sum() / 64.0
                """ pdiff: d_SnaTCHer in Algorithm 1 """
                snatch_known.append(pdiff)

            snatch_unknown = []
            for j in range(ulogits.shape[0]):
                pproto = bproto.clone().detach()
                """ Algorithm 1 Line 1 """
                c = ulogits.argmax(1)[j]
                """ Algorithm 1 Line 2 """
                pproto[0][c] = uquery.reshape(-1, emb_dim)[j]
                """ Algorithm 1 Line 3 """
                pproto = self.para_model.slf_attn(pproto, pproto, pproto)[0]
                pdiff = (pproto - proto).pow(2).sum(-1).sum() / 64.0
                """ pdiff: d_SnaTCHer in Algorithm 1 """
                snatch_unknown.append(pdiff)

            pkdiff = torch.stack(snatch_known).cpu().detach().numpy()
            pudiff = torch.stack(snatch_unknown).cpu().detach().numpy()
            snatch_auroc = calc_auroc(pkdiff, pudiff)

        self.record[i - 1, 0] = loss.item()
        self.record[i - 1, 1] = acc
        self.auroc_record[i - 1, 0] = auroc
        self.auroc_record[i - 1, 1] = snatch_auroc
        self.auroc_record[i - 1, 2] = dist_auroc

        if i % 100 == 0:
            # running mean and 95% confidence interval over the episodes so far
            vdata = 1.0 * np.array(self.record[:, 1])[:i]
            va = np.mean(vdata)
            vap = 1.96 * (np.std(vdata) / np.sqrt(i))

            audata = np.array(self.auroc_record[:, 0], np.float32)[:i]
            aua = np.mean(audata)
            auap = 1.96 * (np.std(audata) / np.sqrt(i))

            sdata = np.array(self.auroc_record[:, 1], np.float32)[:i]
            sa = np.mean(sdata)
            sap = 1.96 * (np.std(sdata) / np.sqrt(i))

            ddata = np.array(self.auroc_record[:, 2], np.float32)[:i]
            da = np.mean(ddata)
            dap = 1.96 * (np.std(ddata) / np.sqrt(i))

            print("acc: {:.4f} + {:.4f} Prob: {:.4f} + {:.4f} Dist: {:.4f} + {:.4f} SnaTCHer: {:.4f} + {:.4f}"
                  .format(va, vap, aua, auap, da, dap, sa, sap))
    return
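# --- Helper sketch (not part of the original excerpt) -------------------------
# `calc_auroc` is defined elsewhere in the repo. A minimal sketch, assuming it
# computes the AUROC of separating unknown from known queries, where a higher
# score means "more likely unknown" (the callers above pass 1 - max prob or the
# negative max logit):
import numpy as np
from sklearn.metrics import roc_auc_score

def calc_auroc(known_scores, unknown_scores):
    scores = np.concatenate([known_scores, unknown_scores])
    labels = np.concatenate([np.zeros_like(known_scores), np.ones_like(unknown_scores)])
    return roc_auc_score(labels, scores)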
def evaluate_gfsl(self):
    args = self.args
    label_unseen_query = torch.arange(args.eval_way).repeat(args.eval_query).long()
    if torch.cuda.is_available():
        label_unseen_query = label_unseen_query.cuda()

    generalized_few_shot_acc = np.zeros((args.num_eval_episodes, 2))
    generalized_few_shot_delta = np.zeros((args.num_eval_episodes, 4))
    generalized_few_shot_hmeanacc = np.zeros((args.num_eval_episodes, 6))
    generalized_few_shot_hmeanmap = np.zeros((args.num_eval_episodes, 6))
    generalized_few_shot_ausuc = np.zeros((args.num_eval_episodes, 1))
    AUC_record = []

    for i, batch in tqdm(
            enumerate(zip(self.test_gfsl_loader, self.test_fsl_loader), 1)):
        if torch.cuda.is_available():
            data_seen, data_unseen, seen_label, unseen_label = (
                batch[0][0].cuda(), batch[1][0].cuda(),
                batch[0][1].cuda(), batch[1][1].cuda())
        else:
            data_seen, data_unseen, seen_label, unseen_label = (
                batch[0][0], batch[1][0], batch[0][1], batch[1][1])

        p2 = args.eval_shot * args.eval_way
        data_unseen_shot, data_unseen_query = data_unseen[:p2], data_unseen[p2:]
        label_unseen_shot, _ = unseen_label[:p2], unseen_label[p2:]
        whole_query = torch.cat([data_seen, data_unseen_query], 0)
        whole_label = torch.cat(
            [seen_label, label_unseen_query + self.trainset.num_class])

        if args.model_class in ['CLS', 'Castle', 'ACastle']:
            with torch.no_grad():
                logits_s, logits_u = self.model.forward_generalized(
                    data_unseen_shot, whole_query)
        elif args.model_class in ['ProtoNet']:
            with torch.no_grad():
                logits_s, logits_u = self.model.forward_generalized(
                    data_unseen_shot, whole_query, self.model.seen_proto)

        # compute un-biased accuracy
        new_logits = torch.cat([logits_s, logits_u], 1)
        if 'acc' in self.criteria or 'hmeanacc' in self.criteria or 'delta' in self.criteria:
            new_logits_acc_biased = torch.cat(
                [logits_s - self.best_bias_acc, logits_u], 1)
        if 'hmeanmap' in self.criteria:
            new_logits_map_biased = torch.cat(
                [logits_s - self.best_bias_map, logits_u], 1)

        # Criterion: Acc
        if 'acc' in self.criteria:
            generalized_few_shot_acc[i - 1, 0] = count_acc(new_logits, whole_label)
            # compute biased accuracy
            generalized_few_shot_acc[i - 1, 1] = count_acc(
                new_logits_acc_biased, whole_label)

        if 'delta' in self.criteria:
            # compute delta value for un-biased logits
            unbiased_delta1, unbiased_delta2 = count_delta_value(
                new_logits, whole_label, seen_label.shape[0],
                self.trainset.num_class)
            # compute delta value for calibrated (biased) logits
            biased_delta1, biased_delta2 = count_delta_value(
                new_logits_acc_biased, whole_label, seen_label.shape[0],
                self.trainset.num_class)
            generalized_few_shot_delta[i - 1, :] = np.array(
                [unbiased_delta1, unbiased_delta2, biased_delta1, biased_delta2])

        if 'hmeanacc' in self.criteria:
            # compute harmonic mean
            HM_nobias, SA_nobias, UA_nobias = count_acc_harmonic_low_shot_joint(
                new_logits, whole_label, seen_label.shape[0])
            HM, SA, UA = count_acc_harmonic_low_shot_joint(
                new_logits_acc_biased, whole_label, seen_label.shape[0])
            generalized_few_shot_hmeanacc[i - 1, :] = np.array(
                [HM_nobias, SA_nobias, UA_nobias, HM, SA, UA])

        if 'hmeanmap' in self.criteria:
            # compute harmonic mean of mean average precision
            HM_nobias, SA_nobias, UA_nobias = count_acc_harmonic_MAP(
                new_logits, whole_label, seen_label.shape[0], 'macro')
            HM, SA, UA = count_acc_harmonic_MAP(
                new_logits_map_biased, whole_label, seen_label.shape[0], 'macro')
            generalized_few_shot_hmeanmap[i - 1, :] = np.array(
                [HM_nobias, SA_nobias, UA_nobias, HM, SA, UA])

        if 'ausuc' in self.criteria:
            # compute AUSUC
            generalized_few_shot_ausuc[i - 1, 0], temp_auc_record = Compute_AUSUC(
                logits_s.detach().cpu().numpy(),
                logits_u.detach().cpu().numpy(),
                whole_label.cpu().numpy(),
                np.arange(self.trainset.num_class),
                self.trainset.num_class + np.arange(args.eval_way))
            AUC_record.append(temp_auc_record)

        del logits_s, logits_u, new_logits
        torch.cuda.empty_cache()

    self.AUC_record = AUC_record
    print('-'.join([args.model_class, args.model_path]))

    if 'acc' in self.criteria:
        self.trlog['acc_mean'], self.trlog['acc_interval'] = \
            compute_confidence_interval(generalized_few_shot_acc[:, 0])
        self.trlog['acc_biased_mean'], self.trlog['acc_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_acc[:, 1])
        print('GFSL {}-way Acc w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['acc_mean'], self.trlog['acc_interval']))
        print('GFSL {}-way Acc w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['acc_biased_mean'],
            self.trlog['acc_biased_interval']))

    if 'delta' in self.criteria:
        self.trlog['delta1_mean'], self.trlog['delta1_interval'] = \
            compute_confidence_interval(generalized_few_shot_delta[:, 0])
        self.trlog['delta2_mean'], self.trlog['delta2_interval'] = \
            compute_confidence_interval(generalized_few_shot_delta[:, 1])
        self.trlog['delta1_biased_mean'], self.trlog['delta1_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_delta[:, 2])
        self.trlog['delta2_biased_mean'], self.trlog['delta2_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_delta[:, 3])
        print('GFSL {}-way Delta1 w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['delta1_mean'], self.trlog['delta1_interval']))
        print('GFSL {}-way Delta1 w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['delta1_biased_mean'],
            self.trlog['delta1_biased_interval']))
        print('GFSL {}-way Delta2 w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['delta2_mean'], self.trlog['delta2_interval']))
        print('GFSL {}-way Delta2 w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['delta2_biased_mean'],
            self.trlog['delta2_biased_interval']))

    if 'hmeanacc' in self.criteria:
        self.trlog['HM_mean'], self.trlog['HM_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanacc[:, 0])
        self.trlog['S2All_mean'], self.trlog['S2All_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanacc[:, 1])
        self.trlog['U2All_mean'], self.trlog['U2All_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanacc[:, 2])
        self.trlog['HM_biased_mean'], self.trlog['HM_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanacc[:, 3])
        self.trlog['S2All_biased_mean'], self.trlog['S2All_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanacc[:, 4])
        self.trlog['U2All_biased_mean'], self.trlog['U2All_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanacc[:, 5])
        print('GFSL {}-way HM_mean w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['HM_mean'], self.trlog['HM_interval']))
        print('GFSL {}-way HM_mean w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['HM_biased_mean'],
            self.trlog['HM_biased_interval']))
        print('GFSL {}-way S2All_mean w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['S2All_mean'], self.trlog['S2All_interval']))
        print('GFSL {}-way S2All_mean w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['S2All_biased_mean'],
            self.trlog['S2All_biased_interval']))
        print('GFSL {}-way U2All_mean w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['U2All_mean'], self.trlog['U2All_interval']))
        print('GFSL {}-way U2All_mean w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['U2All_biased_mean'],
            self.trlog['U2All_biased_interval']))

    if 'hmeanmap' in self.criteria:
        self.trlog['HM_map_mean'], self.trlog['HM_map_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanmap[:, 0])
        self.trlog['S2All_map_mean'], self.trlog['S2All_map_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanmap[:, 1])
        self.trlog['U2All_map_mean'], self.trlog['U2All_map_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanmap[:, 2])
        self.trlog['HM_map_biased_mean'], self.trlog['HM_map_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanmap[:, 3])
        self.trlog['S2All_map_biased_mean'], self.trlog['S2All_map_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanmap[:, 4])
        self.trlog['U2All_map_biased_mean'], self.trlog['U2All_map_biased_interval'] = \
            compute_confidence_interval(generalized_few_shot_hmeanmap[:, 5])
        print('GFSL {}-way HM_map_mean w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['HM_map_mean'], self.trlog['HM_map_interval']))
        print('GFSL {}-way HM_map_mean w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['HM_map_biased_mean'],
            self.trlog['HM_map_biased_interval']))
        print('GFSL {}-way S2All_map_mean w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['S2All_map_mean'],
            self.trlog['S2All_map_interval']))
        print('GFSL {}-way S2All_map_mean w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['S2All_map_biased_mean'],
            self.trlog['S2All_map_biased_interval']))
        print('GFSL {}-way U2All_map_mean w/o Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['U2All_map_mean'],
            self.trlog['U2All_map_interval']))
        print('GFSL {}-way U2All_map_mean w/ Bias {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['U2All_map_biased_mean'],
            self.trlog['U2All_map_biased_interval']))

    if 'ausuc' in self.criteria:
        self.trlog['AUSUC_mean'], self.trlog['AUSUC_interval'] = \
            compute_confidence_interval(generalized_few_shot_ausuc[:, 0])
        print('GFSL {}-way AUSUC {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['AUSUC_mean'], self.trlog['AUSUC_interval']))
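# --- Helper sketch (not part of the original excerpt) -------------------------
# `count_acc_harmonic_low_shot_joint` belongs to the repo's GFSL utilities. A
# plausible sketch, assuming the first `seen_count` queries come from seen
# classes and the rest from unseen classes, and that it returns the harmonic
# mean of the two joint-space accuracies plus the per-split accuracies; the
# repo's implementation may differ:
def count_acc_harmonic_low_shot_joint(logits, label, seen_count):
    pred = torch.argmax(logits, dim=1)
    correct = (pred == label).float()
    seen_acc = correct[:seen_count].mean().item()
    unseen_acc = correct[seen_count:].mean().item()
    hm = 2.0 * seen_acc * unseen_acc / (seen_acc + unseen_acc + 1e-12)
    return hm, seen_acc, unseen_acc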
model.train()
tl = Averager()
ta = Averager()

for i, batch in enumerate(train_loader, 1):
    global_count = global_count + 1
    if torch.cuda.is_available():
        data, label = [_.cuda() for _ in batch]
        label = label.type(torch.cuda.LongTensor)
    else:
        data, label = batch
        label = label.type(torch.LongTensor)

    logits = model(data)
    loss = criterion(logits, label)
    acc = count_acc(logits, label)

    writer.add_scalar('data/loss', float(loss), global_count)
    writer.add_scalar('data/acc', float(acc), global_count)
    if (i - 1) % 100 == 0:
        print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'.format(
            epoch, i, len(train_loader), loss.item(), acc))

    tl.add(loss.item())
    ta.add(acc)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

tl = tl.item()
ta = ta.item()
def train(self):
    args = self.args
    self.model.train()
    if self.args.fix_BN:
        self.model.encoder.eval()

    # start FSL training
    label, label_aux = self.prepare_label()

    for epoch in range(1, args.max_epoch + 1):
        self.train_epoch += 1
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()

        tl1 = Averager()
        tl2 = Averager()
        ta = Averager()

        start_tm = time.time()
        for batch in tqdm(self.train_loader):
            data, gt_label = [_.cuda() for _ in batch]

            data_tm = time.time()
            self.dt.add(data_tm - start_tm)

            # get saved centers
            logits, reg_logits = self.para_model(data)
            logits = logits.view(-1, args.way)

            # margin-based open-set terms: compare the similarity to the
            # ground-truth class against the learned margin
            oh_query = torch.nn.functional.one_hot(label, args.way)
            sims = logits
            temp = (sims * oh_query).sum(-1)
            e_sim_p = temp - self.model.margin
            e_sim_p_pos = F.relu(e_sim_p)
            e_sim_p_neg = F.relu(-e_sim_p)
            l_open_margin = e_sim_p_pos.mean(-1)
            l_open = e_sim_p_neg.mean(-1)

            if reg_logits is not None:
                loss = F.cross_entropy(logits, label)
                total_loss = loss + args.balance * F.cross_entropy(reg_logits, label_aux)
            else:
                loss = F.cross_entropy(logits, label)
                total_loss = loss
            total_loss = total_loss + args.open_balance * l_open
            tl2.add(loss.item())

            forward_tm = time.time()
            self.ft.add(forward_tm - data_tm)
            acc = count_acc(logits, label)

            tl1.add(total_loss.item())
            ta.add(acc)

            self.optimizer.zero_grad()
            total_loss.backward(retain_graph=True)
            self.optimizer_margin.zero_grad()
            l_open_margin.backward()
            self.optimizer.step()
            self.optimizer_margin.step()

            backward_tm = time.time()
            self.bt.add(backward_tm - forward_tm)
            optimizer_tm = time.time()
            self.ot.add(optimizer_tm - backward_tm)

            # refresh start_tm
            start_tm = time.time()

        print('lr: {:.4f} Total_loss: {:.4f} ce_loss {:.4f} l_open: {:.4f} R: {:.4f}\n'.format(
            self.optimizer_margin.param_groups[0]['lr'],
            total_loss.item(), loss.item(), l_open.item(),
            self.model.margin.item()))

        self.lr_scheduler.step()
        self.lr_scheduler_margin.step()
        self.try_evaluate(epoch)
        print('ETA:{}/{}'.format(
            self.timer.measure(),
            self.timer.measure(self.train_epoch / args.max_epoch)))

    torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
    self.save_model('epoch-last')
def open_evaluate(self, data_loader):
    # restore model args
    args = self.args
    # evaluation mode
    self.model.eval()
    record = np.zeros((args.num_test_episodes, 4))  # loss and acc
    label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
    label = label.type(torch.LongTensor)
    if torch.cuda.is_available():
        label = label.cuda()

    print('Evaluating ... best epoch {}, SnaTCHer={:.4f} + {:.4f} acc={:.4f} + {:.4f}'.format(
        self.trlog['max_auc_epoch'], self.trlog['max_auc'],
        self.trlog['max_auc_interval'], self.trlog['acc'],
        self.trlog['acc_interval']))

    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader), 1):
            data, _ = [_.cuda() for _ in batch]
            logits = self.para_model(data)
            logits = logits.reshape(
                [-1, args.eval_way + args.open_eval_way, args.way])
            klogits = logits[:, :args.eval_way, :].reshape(-1, args.way)
            ulogits = logits[:, args.eval_way:, :].reshape(-1, args.way)
            loss = F.cross_entropy(klogits, label)
            acc = count_acc(klogits, label)

            """ Distance """
            kdist = -(klogits.max(1)[0]).cpu().detach().numpy()
            udist = -(ulogits.max(1)[0]).cpu().detach().numpy()
            dist_auroc = calc_auroc(kdist, udist)

            """ Snatcher """
            instance_embs = self.para_model.instance_embs
            support_idx = self.para_model.support_idx
            query_idx = self.para_model.query_idx

            support = instance_embs[support_idx.flatten()].view(
                *(support_idx.shape + (-1,)))
            query = instance_embs[query_idx.flatten()].view(
                *(query_idx.shape + (-1,)))
            emb_dim = support.shape[-1]

            support = support[:, :, :args.way].contiguous()
            # get mean of the support
            bproto = support.mean(dim=1)  # Ntask x NK x d
            proto = self.para_model.slf_attn(bproto, bproto, bproto)

            kquery = query[:, :, :args.way].contiguous()
            uquery = query[:, :, args.way:].contiguous()

            snatch_known = []
            for j in range(klogits.shape[0]):
                pproto = bproto.clone().detach()
                """ Algorithm 1 Line 1 """
                c = klogits.argmax(1)[j]
                """ Algorithm 1 Line 2 """
                pproto[0][c] = kquery.reshape(-1, emb_dim)[j]
                """ Algorithm 1 Line 3 """
                pproto = self.para_model.slf_attn(pproto, pproto, pproto)[0]
                pdiff = (pproto - proto).pow(2).sum(-1).sum() / 64.0
                """ pdiff: d_SnaTCHer in Algorithm 1 """
                snatch_known.append(pdiff)

            snatch_unknown = []
            for j in range(ulogits.shape[0]):
                pproto = bproto.clone().detach()
                """ Algorithm 1 Line 1 """
                c = ulogits.argmax(1)[j]
                """ Algorithm 1 Line 2 """
                pproto[0][c] = uquery.reshape(-1, emb_dim)[j]
                """ Algorithm 1 Line 3 """
                pproto = self.para_model.slf_attn(pproto, pproto, pproto)[0]
                pdiff = (pproto - proto).pow(2).sum(-1).sum() / 64.0
                """ pdiff: d_SnaTCHer in Algorithm 1 """
                snatch_unknown.append(pdiff)

            pkdiff = torch.stack(snatch_known).cpu().detach().numpy()
            pudiff = torch.stack(snatch_unknown).cpu().detach().numpy()
            snatch_auroc = calc_auroc(pkdiff, pudiff)

            record[i - 1, 0] = loss.item()
            record[i - 1, 1] = acc
            record[i - 1, 2] = snatch_auroc
            record[i - 1, 3] = dist_auroc

    vl, _ = compute_confidence_interval(record[:, 0])
    va, vap = compute_confidence_interval(record[:, 1])
    auc_sna, auc_sna_p = compute_confidence_interval(record[:, 2])
    auc_dist, auc_dist_p = compute_confidence_interval(record[:, 3])
    print("acc: {:.4f} + {:.4f} Dist: {:.4f} + {:.4f} SnaTCHer: {:.4f} + {:.4f}"
          .format(va, vap, auc_dist, auc_dist_p, auc_sna, auc_sna_p))

    # train mode
    self.model.train()
    if self.args.fix_BN:
        self.model.encoder.eval()

    return vl, va, vap, auc_sna, auc_sna_p
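# --- Refactoring sketch (not part of the original excerpt) --------------------
# The per-query SnaTCHer distance is computed identically in `evaluate_test` and
# `open_evaluate` above. A hypothetical helper that factors out that block (same
# arithmetic, same 1/64 temperature) could look like the following; the repo
# itself keeps the computation inline:
def snatcher_distance(bproto, proto, query_emb, cls, slf_attn, temperature=64.0):
    # replace the predicted class prototype with the query embedding ...
    pproto = bproto.clone().detach()
    pproto[0][cls] = query_emb
    # ... re-run the task-adaptation transformer ...
    pproto = slf_attn(pproto, pproto, pproto)[0]
    # ... and measure how much the prototype set moved (d_SnaTCHer)
    return (pproto - proto).pow(2).sum(-1).sum() / temperature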