import time
from copy import deepcopy
from collections import defaultdict

import numpy as np
import torch

# Project-specific helpers such as get_att_proto, get_att_proto_vote, get_proto,
# relation_nn, construct_propagation_graph, get_real_propagation_graph and
# get_max_spanning_tree_kruskal are imported from the project's own modules.


def update_acc(coef_lst, proto_lst, acc_lst, embs, meta_labels):
    """Update the accuracy meter of every prototype source whose coefficient is active."""
    n_query = len(embs)
    for coef, proto, acc in zip(coef_lst, proto_lst, acc_lst):
        if coef > 0:
            logits = -euclidean_dist(embs, proto, transform=True).view(
                n_query, len(proto))
            top_fs = obtain_accuracy(logits, meta_labels, (1, ))
            acc.update(top_fs[0].item(), n_query)
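# --- Assumed project utilities ------------------------------------------------
# euclidean_dist, obtain_accuracy, AverageMeter and time_string come from the
# project's utility modules in the real code base.  The definitions below are
# minimal sketches reconstructed only from how they are called in this file;
# the exact signatures (in particular the meaning of the `transform` flag) are
# assumptions, not the original implementations.


def euclidean_dist(x, y, transform=True):
    """Pairwise squared Euclidean distance between rows of x and rows of y (sketch).

    Returns an (n_x, n_y) matrix.  The `transform` flag is accepted only for
    signature compatibility; its exact role in the original helper is assumed.
    """
    n_x, n_y = x.size(0), y.size(0)
    xx = x.unsqueeze(1).expand(n_x, n_y, -1)
    yy = y.unsqueeze(0).expand(n_x, n_y, -1)
    return torch.pow(xx - yy, 2).sum(-1)


def obtain_accuracy(logits, target, topk=(1, )):
    """Standard top-k accuracy in percent, one tensor per requested k (sketch)."""
    maxk = max(topk)
    _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)
    correct = pred.t().eq(target.view(1, -1).expand_as(pred.t()))
    return [
        correct[:k].reshape(-1).float().sum(0).mul_(100.0 / target.size(0))
        for k in topk
    ]


class AverageMeter:
    """Tracks the most recent value and the running average (sketch)."""

    def __init__(self):
        self.val, self.avg, self.sum, self.count = 0.0, 0.0, 0.0, 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def time_string():
    """Human-readable timestamp used in the log lines (sketch)."""
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())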
def leveltrain(lr_scheduler, emb_model, att_par, att_chi, criterion, optimizer, logger,
               dataloader, hierarchy_info, wordid_level_label, epoch, args, mode,
               n_support, pp_buffer=None):
    args = deepcopy(args)
    ancestors, parents, descendants, children = hierarchy_info
    losses, acc1 = AverageMeter(), AverageMeter()
    acc_base, acc_par, acc_chi = AverageMeter(), AverageMeter(), AverageMeter()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    num_device = len(emb_model.device_ids)
    if mode == "train":
        emb_model.train()
        att_par.train()
        att_chi.train()
    elif mode == "test":
        emb_model.eval()
        att_par.eval()
        att_chi.eval()
        metaval_accuracies = []
    else:
        raise TypeError("invalid mode {:}".format(mode))
    for batch_idx, (img_idx, imgs, labels, levels, wordids) in enumerate(dataloader):
        if mode == "train" and lr_scheduler:
            lr_scheduler.step()
        cpu_levels = levels.tolist()
        cpu_wordids = wordids.tolist()
        all_levels = list(set(cpu_levels))
        all_levels.sort()
        # get idx, label, meta_label for every level
        lvls_wordids, lvls_meta_labels, cls_support_idxs, cls_query_idxs = [], [], [], []
        for lvl in all_levels:
            lvl_wordids = sorted(
                set([wordid for level, wordid in zip(cpu_levels, cpu_wordids) if level == lvl]))
            lvl_idx = [idx for idx, wd in enumerate(cpu_wordids) if wd in lvl_wordids]
            lvl_labels_dicts = {x: i for i, x in enumerate(lvl_wordids)}
            grouped_s_idxs, grouped_q_idxs = [], []
            idxs_dict = defaultdict(list)
            for i in lvl_idx:
                wordid = cpu_wordids[i]
                idxs_dict[wordid].append(i)
            idxs_dict = dict(sorted(idxs_dict.items()))
            # for non-few-shot classes, the support and query sets are the same
            if lvl < max(all_levels):
                for wordid, idxs in idxs_dict.items():
                    grouped_s_idxs.append(torch.IntTensor(idxs))
                grouped_q_idxs = grouped_s_idxs
            else:
                for wordid, idxs in idxs_dict.items():
                    grouped_s_idxs.append(torch.IntTensor(idxs[:n_support]))
                    grouped_q_idxs.append(torch.IntTensor(idxs[n_support:]))
            support_idxs = torch.cat(grouped_s_idxs, dim=0).tolist()
            query_idxs = torch.cat(grouped_q_idxs, dim=0).tolist()
            lvl_meta_labels = torch.LongTensor(
                [lvl_labels_dicts[cpu_wordids[x]] for x in query_idxs])
            lvl_meta_labels = lvl_meta_labels.cuda(non_blocking=True)
            # lvls_wordids, cls_support_idxs, cls_query_idxs and all_level_classes share the
            # same order: low level to high level, and within each level from small to large wordid
            lvls_wordids.append(lvl_wordids)
            lvls_meta_labels.append(lvl_meta_labels)
            cls_support_idxs.extend(grouped_s_idxs)
            cls_query_idxs.extend(grouped_q_idxs)
        all_level_classes = [c for sublist in lvls_wordids for c in sublist]
        lvls_embs = emb_model(imgs)
        if pp_buffer:
            if mode == "train":
                proto_base = pp_buffer.pp_running[all_level_classes]
            elif mode == "test":
                proto_base_lst = []
                for idx, cls in enumerate(all_level_classes):
                    if torch.sum(pp_buffer.pp_running[cls]).item() > 0:  # common classes
                        s_i = cls_support_idxs[idx]
                        pp_new = 0.5 * pp_buffer.pp_running[cls] + 0.5 * torch.mean(
                            lvls_embs[s_i.long()], dim=0)
                        proto_base_lst.append(pp_new)
                    else:  # non-common classes
                        s_i = cls_support_idxs[idx]
                        proto_base_lst.append(torch.mean(lvls_embs[s_i.long()], dim=0))
                proto_base = torch.stack(proto_base_lst, dim=0)
            else:
                raise ValueError("invalid mode {}".format(mode))
        else:
            proto_base_lst = [
                torch.mean(lvls_embs[s_i.long()], dim=0) for s_i in cls_support_idxs
            ]
            proto_base = torch.stack(proto_base_lst, dim=0)
        if not args.coef_anc:
            proto_par = 0
        else:
            proto_par = get_att_proto(all_level_classes, parents, proto_base, att_par, args.n_hop)
        if not args.coef_chi:
            proto_chi = 0
        else:
            proto_chi = get_att_proto(all_level_classes, children, proto_base,
                                      att_chi, args.n_hop)
        coef_lst = [args.coef_base, args.coef_anc, args.coef_chi]
        proto_lst = [proto_base, proto_par, proto_chi]
        acc_lst = [acc_base, acc_par, acc_chi]
        # p_final
        if "avg" in args.training_strategy:
            denominator = args.coef_base + args.coef_anc + args.coef_chi
            final_proto = (args.coef_base / denominator) * proto_base + (
                args.coef_anc / denominator) * proto_par + (
                    args.coef_chi / denominator) * proto_chi
        elif "weighted" in args.training_strategy:
            # TODO: hardcoded weights
            final_proto = 0.3 * proto_base + 0.7 * proto_par
            # final_proto = 0.45 * proto_base + 0.45 * proto_par + 0.1 * proto_chi
        elif "relation" in args.training_strategy:
            cat_proto = torch.cat(
                [proto for coef, proto in zip(coef_lst, proto_lst) if coef > 0], dim=-1)
            final_proto = relation_nn(cat_proto)
        else:
            raise ValueError(
                "undefined training_strategy {}, no proto_weight info inside".format(
                    args.training_strategy))
        # classification over every level
        loss_lvls = []
        for i, lvl in enumerate(all_levels):
            lvl_wordid = lvls_wordids[i]
            idx_start = all_level_classes.index(lvl_wordid[0])
            idx_end = all_level_classes.index(lvl_wordid[-1])
            protos = final_proto[idx_start:idx_end + 1]
            query_idx = torch.cat(cls_query_idxs[idx_start:idx_end + 1], dim=0).long()
            lvl_imgs_emb = lvls_embs[query_idx]
            lvl_meta_labels = lvls_meta_labels[i]
            logits = -euclidean_dist(lvl_imgs_emb, protos, transform=True).view(
                len(query_idx), len(protos))
            loss = criterion(logits, lvl_meta_labels)
            loss_lvls.append(loss)
            if lvl == max(all_levels):
                top_fs = obtain_accuracy(logits, lvl_meta_labels.data, (1, ))
                acc1.update(top_fs[0].item(), len(query_idx))
                fine_proto_lst = []
                for c, p in zip(coef_lst, proto_lst):
                    if c > 0:
                        fine_proto_lst.append(p[idx_start:idx_end + 1])
                    else:
                        fine_proto_lst.append(0)
                update_acc(coef_lst, fine_proto_lst, acc_lst, lvl_imgs_emb, lvl_meta_labels)
            if "multi" in args.training_strategy:
                for coef, proto in zip(coef_lst, proto_lst):
                    if coef > 0:
                        logits = -euclidean_dist(lvl_imgs_emb, proto, transform=True).view(
                            len(query_idx), len(proto))
                        loss = criterion(logits, lvl_meta_labels)
                        loss_lvls.append(loss)
        if "mean" in args.training_strategy:
            loss = sum(loss_lvls) / len(loss_lvls)
        elif "single" in args.training_strategy or "multi" in args.training_strategy:
            loss = sum(loss_lvls)
        elif "selective" in args.training_strategy:
            loss = sum(loss_lvls[:-1]) / 10 + loss_lvls[-1]
        else:
            raise ValueError(
                "undefined loss type info in training_strategy : {}".format(
                    args.training_strategy))
        losses.update(loss.item(), len(query_idx))
        if mode == 'train':
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        elif mode == "test":
            metaval_accuracies.append(top_fs[0].item())
            if batch_idx + 1 == len(dataloader):
                metaval_accuracies = np.array(metaval_accuracies)
                stds = np.std(metaval_accuracies, 0)
                ci95 = 1.96 * stds / np.sqrt(batch_idx + 1)
                logger.print("ci95 is : {:}".format(ci95))
        else:
            raise ValueError('Invalid mode = {:}'.format(mode))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if (mode == "train" and ((batch_idx % args.log_interval == 0) or (batch_idx + 1 == len(dataloader)))) \
                or (mode == "test" and (batch_idx + 1 == len(dataloader))):
            Tstring = 'TIME[{data_time.val:.2f} ({data_time.avg:.2f}) {batch_time.val:.2f} ({batch_time.avg:.2f})]'.format(
                data_time=data_time, batch_time=batch_time)
            Sstring = '{:} {:} [Epoch={:03d}/{:03d}] [{:03d}/{:03d}]'.format(
                time_string(), mode, epoch, args.epochs, batch_idx, len(dataloader))
            Astring = ('loss=({:.3f}, {:.3f}), loss_lvls:{}, loss_min:{:.2f}, loss_max:{:.2f}, '
                       'loss_mean:{:.2f}, loss_var:{:.2f}, acc@1=({:.1f}, {:.1f}), '
                       'acc@base=({:.1f}, {:.1f}), acc@par=({:.1f}, {:.1f}), '
                       'acc@chi=({:.1f}, {:.1f})').format(
                           losses.val, losses.avg, [l.item() for l in loss_lvls],
                           min(loss_lvls).item(), max(loss_lvls).item(),
                           torch.mean(torch.stack(loss_lvls)).item(),
                           torch.var(torch.stack(loss_lvls)).item(), acc1.val, acc1.avg,
                           acc_base.val, acc_base.avg, acc_par.val, acc_par.avg,
                           acc_chi.val, acc_chi.avg)
            Cstring = 'p_base_weight : {:.4f}; p_par_weight : {:.4f}; p_chi_weight : {:.4f}'.format(
                args.coef_base, args.coef_anc, args.coef_chi)
            logger.print('{:} {:} {:} {:} \n'.format(Sstring, Tstring, Astring, Cstring))
    return losses, acc1, acc_base, acc_par, acc_chi
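# `get_att_proto` is provided elsewhere in the project; the function below is a
# hypothetical sketch of an attention-weighted neighbour prototype for the
# signature used in `leveltrain` only (the call in `test_no_hierarchy` uses a
# different signature).  It assumes `parents`/`children` maps a wordid to the
# wordids of its neighbours and that the attention module scores one class
# prototype against a stack of candidate prototypes; both are assumptions.
def att_proto_sketch(all_level_classes, neighbour_map, proto_base, att_module, n_hop):
    """Attention-weighted neighbour prototype per class (hypothetical sketch)."""
    class2idx = {cls: i for i, cls in enumerate(all_level_classes)}
    out = []
    for cls in all_level_classes:
        # collect in-episode neighbours reachable within n_hop hops
        frontier, neighbours = {cls}, set()
        for _ in range(n_hop):
            frontier = {nb for c in frontier for nb in neighbour_map.get(c, [])}
            neighbours |= frontier
        nb_idx = [class2idx[c] for c in neighbours if c in class2idx]
        if not nb_idx:
            # no in-episode neighbours: fall back to the class's own prototype
            out.append(proto_base[class2idx[cls]])
            continue
        nb_protos = proto_base[nb_idx]                   # (k, d)
        query = proto_base[class2idx[cls]].unsqueeze(0)  # (1, d)
        scores = att_module(query, nb_protos)            # assumed to return (1, k) scores
        weights = torch.softmax(scores, dim=-1)
        out.append((weights @ nb_protos).squeeze(0))
    return torch.stack(out, dim=0)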
def train_model(lr_scheduler, model, criterion, optimizer, logger, dataloader, epoch, args,
                mode, n_support, n_query, proto_bag):
    args = deepcopy(args)
    losses, acc1 = AverageMeter(), AverageMeter()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    num_device = len(model.device_ids)
    if mode == "train":
        model.train()
    elif mode == "test":
        model.eval()
        metaval_accuracies = []
    else:
        raise TypeError("invalid mode {:}".format(mode))
    for batch_idx, (feas, labels) in enumerate(dataloader):
        if mode == "train":
            lr_scheduler.step()
        cpu_labels = labels.cpu().tolist()
        idxs_dict = defaultdict(list)
        for i, l in enumerate(cpu_labels):
            idxs_dict[l].append(i)
        idxs_dict = dict(sorted(idxs_dict.items()))
        grouped_s_idxs, grouped_q_idxs = [], []
        for lab, idxs in idxs_dict.items():
            grouped_s_idxs.append(torch.LongTensor(idxs[:n_support]))
            grouped_q_idxs.append(torch.LongTensor(idxs[n_support:]))
        query_idxs = torch.cat(grouped_q_idxs, dim=0).tolist()
        embs = model(feas)
        ## use the mean of the first n_support embeddings per class as the prototype
        ## (are these prototypes from the test dataset? why do we only use two?)
        # proto_lst = [torch.mean(embs[s_idxs], dim=0) for s_idxs in grouped_s_idxs]
        # proto = torch.stack(proto_lst, dim=0)
        proto = get_proto(args, model, proto_bag)
        # classification
        ## get query features after the linear transform (does the data loader sample randomly?)
        query_emb = embs[query_idxs]
        logits = -euclidean_dist(query_emb, proto, transform=True).view(
            len(query_idxs), len(proto))
        query_labels = labels[query_idxs].cuda(non_blocking=True)
        loss = criterion(logits, query_labels)
        losses.update(loss.item(), len(query_idxs))
        top_fs = obtain_accuracy(logits, query_labels, (1, ))
        acc1.update(top_fs[0].item(), len(query_idxs))
        if mode == 'train':
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        elif mode == "test":
            metaval_accuracies.append(top_fs[0].item())
            if batch_idx + 1 == len(dataloader):
                metaval_accuracies = np.array(metaval_accuracies)
                stds = np.std(metaval_accuracies, 0)
                ci95 = 1.96 * stds / np.sqrt(batch_idx + 1)
                logger.print("ci95 is : {:}".format(ci95))
        else:
            raise ValueError('Invalid mode = {:}'.format(mode))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if (mode == "train" and ((batch_idx % args.log_interval == 0) or (batch_idx + 1 == len(dataloader)))) \
                or (mode == "test" and (batch_idx + 1 == len(dataloader))):
            Tstring = 'TIME[{data_time.val:.2f} ({data_time.avg:.2f}) {batch_time.val:.2f} ({batch_time.avg:.2f})]'.format(
                data_time=data_time, batch_time=batch_time)
            Sstring = '{:} {:} [Epoch={:03d}/{:03d}] [{:03d}/{:03d}]'.format(
                time_string(), mode, epoch, args.epochs, batch_idx, len(dataloader))
            Astring = 'loss=({:.3f}, {:.3f}), acc@1=({:.1f}, {:.1f})'.format(
                losses.val, losses.avg, acc1.val, acc1.avg)
            logger.print('{:} {:} {:} \n'.format(Sstring, Tstring, Astring))
    return losses, acc1
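# The commented-out lines in `train_model` compute per-class prototypes as the
# mean of each class's support embeddings before `get_proto` (defined elsewhere
# in the project) swaps in prototypes drawn from `proto_bag`.  The helper below
# is a hypothetical, standalone sketch of that commented-out computation and is
# not used by the original code.
def episode_prototypes(embs, grouped_s_idxs):
    """Mean support embedding per class, stacked in class order (sketch)."""
    return torch.stack(
        [torch.mean(embs[s_idxs], dim=0) for s_idxs in grouped_s_idxs], dim=0)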
def graph_nohier_test(mymodel, train_hierarchy_info, criterion, logger, dataloader, epoch,
                      args, n_support, pp_running, n_hop):
    args = deepcopy(args)
    losses, acc1 = AverageMeter(), AverageMeter()
    acc_base, acc_prop = AverageMeter(), AverageMeter()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    num_device = len(mymodel.emb_model.device_ids)
    mymodel.emb_model.eval()
    mymodel.propagation_net.eval()
    if mymodel.mst_net:
        mymodel.mst_net.eval()
    metaval_accuracies = []
    for batch_idx, (_, imgs, labels) in enumerate(dataloader):
        assert len(set(labels.tolist())) == args.classes_per_it_tr
        embs = mymodel.emb_model(imgs)
        n_train_classes = len(train_hierarchy_info[0])
        target_classes = list(set(labels.tolist()))
        test_pp_list = []
        grouped_s_idxs, grouped_q_idxs = [], []
        for cls in target_classes:
            all_idxs = (labels == cls).nonzero().view(-1).tolist()
            s_idx = all_idxs[:n_support]
            q_idx = all_idxs[n_support:]
            grouped_s_idxs.append(torch.IntTensor(s_idx))
            grouped_q_idxs.append(torch.IntTensor(q_idx))
            test_pp_list.append(torch.mean(embs[s_idx], dim=0))
        test_pp = torch.stack(test_pp_list, dim=0)
        # get nodes
        label2metalabel, forward_adj, backward_adj = construct_propagation_graph(
            target_classes, train_hierarchy_info, test_pp, pp_running[:n_train_classes],
            args.n_chi, args.n_hop, 5, 0.5)
        base_proto = pp_running[list(label2metalabel.keys())]  # indexing with a list creates a clone
        for cls, pp in zip(target_classes, test_pp_list):
            base_proto[label2metalabel[cls]] = pp
        if "mst" in args.training_strategy:
            features = base_proto
            forward_adj = forward_adj.cuda()
            backward_adj = backward_adj.cuda()
            if mymodel.mst_net:
                distance_forward = mymodel.mst_net(features, forward_adj)
                distance_backward = mymodel.mst_net(features, backward_adj)
            else:
                node_num, feat_dim = features.size()
                q = features.view(1, node_num, feat_dim)
                k = features.view(node_num, 1, feat_dim)
                distance_forward = distance_backward = torch.norm(q - k, p='fro', dim=2)
            # print("[before mst]-forward_adj is {}".format(forward_adj))
            # print("[before mst]-backward_adj is {}".format(backward_adj))
            # get edges
            forward_adj = get_max_spanning_tree_kruskal(forward_adj, distance_forward)
            backward_adj = get_max_spanning_tree_kruskal(backward_adj, distance_backward)
            # print("[after mst]-forward_adj is {}".format(forward_adj))
            # print("[after mst]-backward_adj is {}".format(backward_adj))
        # propagation
        if "single-round" in args.training_strategy:
            adj_lst = [forward_adj for i in range(args.n_hop)] + [backward_adj for i in range(args.n_hop)]
        elif "multi-round" in args.training_strategy:
            adj_lst = []
            for i in range(args.n_hop):
                adj_lst.append(forward_adj)
                adj_lst.append(backward_adj)
        elif "allaround" in args.training_strategy:
            all_adj = forward_adj + backward_adj
            adj_lst = [all_adj for i in range(args.n_hop)]
        elif "only-forward" in args.training_strategy:
            adj_lst = [forward_adj for i in range(args.n_hop)]
        elif "only-backward" in args.training_strategy:
            adj_lst = [backward_adj for i in range(args.n_hop)]
        else:
            raise ValueError("invalid training_strategy for adj : {}".format(args.training_strategy))
        prop_proto = mymodel.propagation_net(base_proto, adj_lst * num_device)
        target_prop_proto = prop_proto[:len(target_classes)]
        target_base_proto = base_proto[:len(target_classes)]
        if args.coef_base == -1 and args.coef_prop == -1:
            if epoch == -1:
                final_proto = target_prop_proto
        else:
            coef_norm = args.coef_base + args.coef_prop
            final_proto = target_base_proto * args.coef_base / coef_norm \
                + target_prop_proto * args.coef_prop / coef_norm
        query_idxs = torch.cat(grouped_q_idxs, dim=0).tolist()
        query_meta_labels = torch.LongTensor(
            [label2metalabel[labels[i].item()] for i in query_idxs]).cuda(non_blocking=True)
        logits = -euclidean_dist(embs[query_idxs], final_proto, transform=True).view(
            len(query_idxs), len(target_classes))
        loss = criterion(logits, query_meta_labels)
        losses.update(loss.item(), len(query_idxs))
        top_fs = obtain_accuracy(logits, query_meta_labels, (1, ))
        acc1.update(top_fs[0].item(), len(query_idxs))
        update_acc([args.coef_base, args.coef_prop], [target_base_proto, target_prop_proto],
                   [acc_base, acc_prop], embs[query_idxs], query_meta_labels)
        metaval_accuracies.append(top_fs[0].item())
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx + 1 == len(dataloader):
            metaval_accuracies = np.array(metaval_accuracies)
            stds = np.std(metaval_accuracies, 0)
            ci95 = 1.96 * stds / np.sqrt(batch_idx + 1)
            logger.print("ci95 is : {:}".format(ci95))
            Tstring = 'TIME[{data_time.val:.2f} ({data_time.avg:.2f}) {batch_time.val:.2f} ({batch_time.avg:.2f})]'.format(
                data_time=data_time, batch_time=batch_time)
            Sstring = '{:} {:} [Epoch={:03d}/{:03d}] [{:03d}/{:03d}]'.format(
                time_string(), "test", epoch, args.epochs, batch_idx, len(dataloader))
            Astring = 'loss=({:.3f}, {:.3f}), acc@1=({:.1f}, {:.1f}), acc@base=({:.1f}, {:.1f}), acc@prop=({:.1f}, {:.1f})'.format(
                losses.val, losses.avg, acc1.val, acc1.avg, acc_base.val, acc_base.avg,
                acc_prop.val, acc_prop.avg)
            Cstring = 'p_base_weight : {:.4f}; p_prop_weight : {:.4f}'.format(
                args.coef_base, args.coef_prop)
            logger.print('{:} {:} {:} {:} \n'.format(Sstring, Tstring, Astring, Cstring))
    return losses, acc1, acc_base, acc_prop
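# `get_max_spanning_tree_kruskal` comes from the project's graph utilities; the
# function below is a hypothetical sketch of the idea implied by its call sites:
# among the edges present in `adj`, keep a cycle-free subset with maximum total
# weight (Kruskal's rule with a union-find) and return an adjacency of the same
# shape containing only the kept edges.  The exact behaviour of the original
# helper (e.g. how it treats edge direction) is an assumption.
def max_spanning_tree_kruskal_sketch(adj, weight):
    """Prune `adj` to a maximum-weight spanning forest (hypothetical sketch)."""
    parent = list(range(adj.size(0)))

    def find(x):
        # union-find with path halving
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    # candidate edges are the non-zero entries of the adjacency, heaviest first
    edges = (adj > 0).nonzero(as_tuple=False).tolist()
    edges.sort(key=lambda e: weight[e[0], e[1]].item(), reverse=True)
    kept = torch.zeros_like(adj)
    for i, j in edges:
        ri, rj = find(i), find(j)
        if ri != rj:  # adding this edge does not create a cycle
            parent[ri] = rj
            kept[i, j] = adj[i, j]
    return kept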
def graphtrain(lr_scheduler, mymodel, criterion, optimizer, logger, dataloader,
               hierarchy_info, epoch, args, mode, n_support, pp_running, n_hop):
    args = deepcopy(args)
    losses, acc1 = AverageMeter(), AverageMeter()
    acc_base, acc_prop = AverageMeter(), AverageMeter()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    num_device = len(mymodel.emb_model.device_ids)
    if mode == "train":
        mymodel.emb_model.train()
        mymodel.propagation_net.train()
        if mymodel.mst_net:
            mymodel.mst_net.train()
        parents, children = hierarchy_info
    elif mode == "test":
        mymodel.emb_model.eval()
        mymodel.propagation_net.eval()
        if mymodel.mst_net:
            mymodel.mst_net.eval()
        parents, children, all_train_classes = hierarchy_info
        metaval_accuracies = []
    else:
        raise TypeError("invalid mode {:}".format(mode))
    for batch_idx, (_, imgs, labels) in enumerate(dataloader):
        if mode == "train" and len(imgs) == args.batch_size and len(
                set(labels.tolist())) != args.classes_per_it_tr:
            # curriculum many-shot learning; do not update the few-shot loss and accuracy
            mymodel.emb_model.module.use_classifier = True
            prediction = mymodel.emb_model(imgs)
            mymodel.emb_model.module.use_classifier = False
            labels = labels.cuda()
            loss = criterion(prediction, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            if mode == "train" and lr_scheduler:
                if not type(lr_scheduler) == list:
                    lr_scheduler = [lr_scheduler]
                for l_s in lr_scheduler:
                    l_s.step()
            embs = mymodel.emb_model(imgs)
            with torch.no_grad():
                target_classes = list(set(labels.tolist()))
                # get nodes
                label2metalabel, forward_adj, backward_adj = get_real_propagation_graph(
                    target_classes, hierarchy_info, args.n_hop, 5, 0.5
                )  # if the number of children or parents exceeds 5, only select half of them
            base_proto = pp_running[list(label2metalabel.keys())]  # indexing with a list creates a clone
            grouped_s_idxs, grouped_q_idxs, target_idx = [], [], []
            for cls in target_classes:
                all_idxs = (labels == cls).nonzero().view(-1).tolist()
                s_idx = all_idxs[:n_support]
                q_idx = all_idxs[n_support:]
                grouped_s_idxs.append(torch.IntTensor(s_idx))
                grouped_q_idxs.append(torch.IntTensor(q_idx))
                base_proto[label2metalabel[cls]] = torch.mean(embs[s_idx], dim=0)
            if "mst" in args.training_strategy:
                # get edges
                features = base_proto
                forward_adj = forward_adj.cuda()
                backward_adj = backward_adj.cuda()
                if mymodel.mst_net:
                    distance_forward = -mymodel.mst_net(features, forward_adj)
                    distance_backward = -mymodel.mst_net(features, backward_adj)
                else:
                    # euclidean distance
                    node_num, feat_dim = features.size()
                    q = features.view(1, node_num, feat_dim)
                    k = features.view(node_num, 1, feat_dim)
                    distance_forward = distance_backward = torch.norm(q - k, p='fro', dim=2)
                # print("[before mst]-forward_adj is {}".format(forward_adj))
                # print("[before mst]-backward_adj is {}".format(backward_adj))
                forward_adj = get_max_spanning_tree_kruskal(forward_adj, distance_forward)
                backward_adj = get_max_spanning_tree_kruskal(backward_adj, distance_backward)
                # print("[after mst]-forward_adj is {}".format(forward_adj))
                # print("[after mst]-backward_adj is {}".format(backward_adj))
            # propagation
            if "single-round" in args.training_strategy:
                adj_lst = [forward_adj for i in range(args.n_hop)] + [backward_adj for i in range(args.n_hop)]
            elif "multi-round" in args.training_strategy:
                adj_lst = []
                for i in range(args.n_hop):
                    adj_lst.append(forward_adj)
                    adj_lst.append(backward_adj)
            elif "allaround" in args.training_strategy:
                all_adj = forward_adj + backward_adj
                adj_lst = [all_adj for i in range(args.n_hop)]
            elif "only-forward" in args.training_strategy:
                adj_lst = [forward_adj for i in range(args.n_hop)]
            elif "only-backward" in args.training_strategy:
                adj_lst = [backward_adj for i in range(args.n_hop)]
            else:
                raise ValueError(
                    "invalid training_strategy for adj : {}".format(args.training_strategy))
            prop_proto = mymodel.propagation_net(base_proto, adj_lst * num_device)
            target_prop_proto = prop_proto[:len(target_classes)]
            target_base_proto = base_proto[:len(target_classes)]
            if args.coef_base == -1 and args.coef_prop == -1:
                if epoch == -1:
                    final_proto = target_prop_proto
                else:
                    final_proto = target_base_proto * (1 - epoch / args.epochs) \
                        + target_prop_proto * epoch / args.epochs
            else:
                coef_norm = args.coef_base + args.coef_prop
                final_proto = target_base_proto * args.coef_base / coef_norm \
                    + target_prop_proto * args.coef_prop / coef_norm
            query_idxs = torch.cat(grouped_q_idxs, dim=0).tolist()
            query_meta_labels = torch.LongTensor(
                [label2metalabel[labels[i].item()] for i in query_idxs]).cuda(non_blocking=True)
            logits = -euclidean_dist(embs[query_idxs], final_proto, transform=True).view(
                len(query_idxs), len(target_classes))
            loss = criterion(logits, query_meta_labels)
            losses.update(loss.item(), len(query_idxs))
            if mode == "train":
                # train on the classes selected in the graph
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                top_fs = obtain_accuracy(logits, query_meta_labels, (1, ))
                acc1.update(top_fs[0].item(), len(query_idxs))
                update_acc([args.coef_base, args.coef_prop],
                           [target_base_proto, target_prop_proto],
                           [acc_base, acc_prop], embs[query_idxs], query_meta_labels)
            elif mode == "test":
                with torch.no_grad():
                    # only the few-shot level is tested, but the loss is still the graph loss
                    top_fs = obtain_accuracy(logits, query_meta_labels, (1, ))
                    acc1.update(top_fs[0].item(), len(query_idxs))
                    update_acc([args.coef_base, args.coef_prop],
                               [target_base_proto, target_prop_proto],
                               [acc_base, acc_prop], embs[query_idxs], query_meta_labels)
                    metaval_accuracies.append(top_fs[0].item())
                    if batch_idx + 1 == len(dataloader):
                        metaval_accuracies = np.array(metaval_accuracies)
                        stds = np.std(metaval_accuracies, 0)
                        ci95 = 1.96 * stds / np.sqrt(batch_idx + 1)
                        logger.print("ci95 is : {:}".format(ci95))
            else:
                raise ValueError("undefined mode : {}".format(mode))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if (mode == "train" and ((batch_idx % args.log_interval == 0) or (batch_idx + 1 == len(dataloader)))) \
                or (mode == "test" and (batch_idx + 1 == len(dataloader))):
            Tstring = 'TIME[{data_time.val:.2f} ({data_time.avg:.2f}) {batch_time.val:.2f} ({batch_time.avg:.2f})]'.format(
                data_time=data_time, batch_time=batch_time)
            Sstring = '{:} {:} [Epoch={:03d}/{:03d}] [{:03d}/{:03d}]'.format(
                time_string(), mode, epoch, args.epochs, batch_idx, len(dataloader))
            Astring = 'loss=({:.3f}, {:.3f}), acc@1=({:.1f}, {:.1f}), acc@base=({:.1f}, {:.1f}), acc@prop=({:.1f}, {:.1f})'.format(
                losses.val, losses.avg, acc1.val, acc1.avg, acc_base.val, acc_base.avg,
                acc_prop.val, acc_prop.avg)
            Cstring = 'p_base_weight : {:.4f}; p_prop_weight : {:.4f}'.format(
                args.coef_base, args.coef_prop)
            logger.print('{:} {:} {:} {:} \n'.format(Sstring, Tstring, Astring, Cstring))
    return losses, acc1, acc_base, acc_prop
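# `mymodel.propagation_net` is defined elsewhere in the project; the module
# below is a minimal hypothetical sketch of one common design for such a step:
# for every adjacency in `adj_lst`, each prototype is mixed with a
# row-normalised weighted sum of its neighbours' (linearly transformed)
# prototypes.  The real network, and the reason `adj_lst` is replicated
# `num_device` times for DataParallel scattering, are assumptions here.
import torch.nn as nn


class PropagationNetSketch(nn.Module):
    """Multi-hop prototype propagation over a list of adjacencies (sketch)."""

    def __init__(self, feat_dim, alpha=0.5):
        super().__init__()
        self.linear = nn.Linear(feat_dim, feat_dim)
        self.alpha = alpha  # how much of the neighbour message to mix in

    def forward(self, features, adj_lst):
        x = features
        for adj in adj_lst:
            # row-normalise the adjacency so each node averages over its neighbours
            deg = adj.sum(dim=1, keepdim=True).clamp(min=1e-6)
            msg = (adj / deg) @ self.linear(x)
            x = (1 - self.alpha) * x + self.alpha * torch.relu(msg)
        return x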
def test_no_hierarchy(lr_scheduler, par_ratio, emb_model, att_par, att_chi, criterion,
                      optimizer, logger, dataloader, hierarchy_info, wordid_level_label,
                      epoch, args, mode, n_support, pp_buffer, prop_proto):
    args = deepcopy(args)
    ancestors, parents, descendants, children = hierarchy_info
    losses, acc1 = AverageMeter(), AverageMeter()
    acc_base, acc_par, acc_chi = AverageMeter(), AverageMeter(), AverageMeter()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    num_device = len(emb_model.device_ids)
    par_ratio = par_ratio / 10
    # logger.print("par_ratio is {}".format(par_ratio))
    if mode == "test":
        emb_model.eval()
        att_par.eval()
        att_chi.eval()
        metaval_accuracies = []
    else:
        raise TypeError("invalid mode {:}".format(mode))
    for batch_idx, (img_idx, imgs, labels, levels, wordids) in enumerate(dataloader):
        cpu_levels = levels.tolist()
        cpu_wordids = wordids.tolist()
        all_levels = list(set(cpu_levels))
        all_levels.sort()
        lvls_embs = emb_model(imgs)
        lvls_wordids = []
        # get idx, label, meta_label for every level
        for lvl in all_levels:
            lvl_wordids = sorted(
                set([wordid for level, wordid in zip(cpu_levels, cpu_wordids) if level == lvl]))
            lvls_wordids.append(lvl_wordids)
            if lvl == max(all_levels):
                lvl_idx = [idx for idx, wd in enumerate(cpu_wordids) if wd in lvl_wordids]
                lvl_labels_dicts = {x: i for i, x in enumerate(lvl_wordids)}
                grouped_s_idxs, grouped_q_idxs, few_shot_pp = [], [], []
                idxs_dict = defaultdict(list)
                for i in lvl_idx:
                    wordid = cpu_wordids[i]
                    idxs_dict[wordid].append(i)
                idxs_dict = dict(sorted(idxs_dict.items()))
                for wordid, idxs in idxs_dict.items():
                    s_idxs = idxs[:n_support]
                    grouped_s_idxs.append(torch.IntTensor(s_idxs))
                    grouped_q_idxs.append(torch.IntTensor(idxs[n_support:]))
                    few_shot_pp.append(torch.mean(lvls_embs[s_idxs], dim=0))
                support_idxs = torch.cat(grouped_s_idxs, dim=0).tolist()
                query_idxs = torch.cat(grouped_q_idxs, dim=0).tolist()
                lvl_meta_labels = torch.LongTensor(
                    [lvl_labels_dicts[cpu_wordids[x]] for x in query_idxs])
                lvl_meta_labels = lvl_meta_labels.cuda(non_blocking=True)
        # lvls_wordids, cls_support_idxs, cls_query_idxs and all_level_classes share the
        # same order: low level to high level, and within each level from small to large wordid
        all_level_classes = [c for sublist in lvls_wordids for c in sublist]
        '''
        if pp_buffer:
            if mode == "test":
                proto_lst = []
                for idx, cls in enumerate(all_level_classes):
                    if torch.sum(pp_buffer.pp_running[cls]).item() > 0:  # common classes
                        proto_lst.append(pp_buffer.pp_running[cls])
            else:
                raise ValueError("invalid mode {}".format(mode))
        else:
            raise ValueError("invalid : no buffer ")
        '''
        candidates = pp_buffer.pp_running
        voting = False
        # voting = True
        n_nei = 3
        # classification over every level
        loss_lvls = []
        for i, lvl in enumerate(all_levels):
            if lvl == max(all_levels):
                if voting:
                    final_proto = get_att_proto_vote(lvls_embs, grouped_s_idxs, few_shot_pp,
                                                     candidates, prop_proto, att_par,
                                                     par_ratio, n_nei)
                else:
                    final_proto = get_att_proto(few_shot_pp, candidates, prop_proto, att_par,
                                                par_ratio, n_nei)
                lvl_imgs_emb = lvls_embs[query_idxs]
                logits = -euclidean_dist(lvl_imgs_emb, final_proto, transform=True).view(
                    len(query_idxs), len(final_proto))
                loss = criterion(logits, lvl_meta_labels)
                loss_lvls.append(loss)
                top_fs = obtain_accuracy(logits, lvl_meta_labels.data, (1, ))
                acc1.update(top_fs[0].item(), len(query_idxs))
        loss = sum(loss_lvls)
        losses.update(loss.item(), len(query_idxs))
        if mode == 'train':
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        elif mode == "test":
            metaval_accuracies.append(top_fs[0].item())
            if batch_idx + 1 == len(dataloader):
                metaval_accuracies = np.array(metaval_accuracies)
                stds = np.std(metaval_accuracies, 0)
                ci95 = 1.96 * stds / np.sqrt(batch_idx + 1)
                # logger.print("ci95 is : {:}".format(ci95))
        else:
            raise ValueError('Invalid mode = {:}'.format(mode))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        '''
        if (mode == "train" and ((batch_idx % args.log_interval == 0) or (batch_idx + 1 == len(dataloader)))) \
                or (mode == "test" and (batch_idx + 1 == len(dataloader))):
            Tstring = 'TIME[{data_time.val:.2f} ({data_time.avg:.2f}) {batch_time.val:.2f} ({batch_time.avg:.2f})]'.format(
                data_time=data_time, batch_time=batch_time)
            Sstring = '{:} {:} [Epoch={:03d}/{:03d}] [{:03d}/{:03d}]'.format(
                time_string(), mode, epoch, args.epochs, batch_idx, len(dataloader))
            Astring = ('loss=({:.3f}, {:.3f}), loss_lvls:{}, loss_min:{:.2f}, loss_max:{:.2f}, '
                       'loss_mean:{:.2f}, loss_var:{:.2f}, acc@1=({:.2f}, {:.2f}), '
                       'acc@base=({:.1f}, {:.1f}), acc@par=({:.1f}, {:.1f}), '
                       'acc@chi=({:.1f}, {:.1f})').format(
                           losses.val, losses.avg, [l.item() for l in loss_lvls],
                           min(loss_lvls).item(), max(loss_lvls).item(),
                           torch.mean(torch.stack(loss_lvls)).item(),
                           torch.var(torch.stack(loss_lvls)).item(), acc1.val, acc1.avg,
                           acc_base.val, acc_base.avg, acc_par.val, acc_par.avg,
                           acc_chi.val, acc_chi.avg)
            Cstring = 'p_base_weight : {:.4f}; p_par_weight : {:.4f}; p_chi_weight : {:.4f}'.format(
                args.coef_base, args.coef_anc, args.coef_chi)
            logger.print('{:} {:} {:} {:} \n'.format(Sstring, Tstring, Astring, Cstring))
        '''
    return losses, acc1, ci95