Example #1
def leveltrain(lr_scheduler,
               emb_model,
               att_par,
               att_chi,
               criterion,
               optimizer,
               logger,
               dataloader,
               hierarchy_info,
               wordid_level_label,
               epoch,
               args,
               mode,
               n_support,
               pp_buffer=None):

    args = deepcopy(args)
    ancestors, parents, descendants, children = hierarchy_info
    losses, acc1 = AverageMeter(), AverageMeter()
    acc_base, acc_par, acc_chi = AverageMeter(), AverageMeter(), AverageMeter()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    num_device = len(emb_model.device_ids)

    if mode == "train":
        emb_model.train()
        att_par.train()
        att_chi.train()
    elif mode == "test":
        emb_model.eval()
        att_par.eval()
        att_chi.eval()
        metaval_accuracies = []
    else:
        raise TypeError("invalid mode {:}".format(mode))

    for batch_idx, (img_idx, imgs, labels, levels,
                    wordids) in enumerate(dataloader):
        if mode == "train" and lr_scheduler:
            lr_scheduler.step()
        cpu_levels = levels.tolist()
        cpu_wordids = wordids.tolist()
        all_levels = list(set(cpu_levels))
        all_levels.sort()

        # get idx, label, meta_label for every level
        lvls_wordids, lvls_meta_labels, cls_support_idxs, cls_query_idxs = [], [], [], []
        for lvl in all_levels:
            lvl_wordids = sorted(
                set([
                    wordid for level, wordid in zip(cpu_levels, cpu_wordids)
                    if level == lvl
                ]))
            lvl_idx = [
                idx for idx, wd in enumerate(cpu_wordids) if wd in lvl_wordids
            ]
            lvl_labels_dicts = {x: i for i, x in enumerate(lvl_wordids)}

            grouped_s_idxs, grouped_q_idxs = [], []
            idxs_dict = defaultdict(list)
            for i in lvl_idx:
                wordid = cpu_wordids[i]
                idxs_dict[wordid].append(i)
            idxs_dict = dict(sorted(idxs_dict.items()))
            # for non-few-shot classes, support and query set are the same
            if lvl < max(all_levels):
                for wordid, idxs in idxs_dict.items():
                    grouped_s_idxs.append(torch.IntTensor(idxs))
                grouped_q_idxs = grouped_s_idxs
            else:
                for wordid, idxs in idxs_dict.items():
                    grouped_s_idxs.append(torch.IntTensor(idxs[:n_support]))
                    grouped_q_idxs.append(torch.IntTensor(idxs[n_support:]))
            support_idxs = torch.cat(grouped_s_idxs, dim=0).tolist()
            query_idxs = torch.cat(grouped_q_idxs, dim=0).tolist()
            lvl_meta_labels = torch.LongTensor(
                [lvl_labels_dicts[cpu_wordids[x]] for x in query_idxs])
            lvl_meta_labels = lvl_meta_labels.cuda(non_blocking=True)
            # lvls_wordids, cls_support_idxs, cls_query_idxs, and
            # all_level_classes share one ordering: levels from low to high,
            # and within each level, wordids from small to large
            lvls_wordids.append(lvl_wordids)
            lvls_meta_labels.append(lvl_meta_labels)
            cls_support_idxs.extend(grouped_s_idxs)
            cls_query_idxs.extend(grouped_q_idxs)
        all_level_classes = [c for sublist in lvls_wordids for c in sublist]

        lvls_embs = emb_model(imgs)
        if pp_buffer:
            if mode == "train":
                proto_base = pp_buffer.pp_running[all_level_classes]
            elif mode == "test":
                proto_base_lst = []
                for idx, cls in enumerate(all_level_classes):
                    s_i = cls_support_idxs[idx]
                    support_mean = torch.mean(lvls_embs[s_i.long()], dim=0)
                    if torch.sum(pp_buffer.pp_running[cls]).item() > 0:
                        # common classes: blend the running prototype
                        # with the episode's support mean
                        proto_base_lst.append(
                            0.5 * pp_buffer.pp_running[cls] + 0.5 * support_mean)
                    else:
                        # non-common classes: support mean only
                        proto_base_lst.append(support_mean)
                proto_base = torch.stack(proto_base_lst, dim=0)
            else:
                raise ValueError("invalid mode {}".format(mode))
        else:
            proto_base_lst = [
                torch.mean(lvls_embs[s_i.long()], dim=0)
                for s_i in cls_support_idxs
            ]
            proto_base = torch.stack(proto_base_lst, dim=0)

        if not args.coef_anc:
            proto_par = 0
        else:
            proto_par = get_att_proto(all_level_classes, parents, proto_base,
                                      att_par, args.n_hop)
        if not args.coef_chi:
            proto_chi = 0
        else:
            proto_chi = get_att_proto(all_level_classes, children, proto_base,
                                      att_chi, args.n_hop)

        coef_lst = [args.coef_base, args.coef_anc, args.coef_chi]
        proto_lst = [proto_base, proto_par, proto_chi]
        acc_lst = [acc_base, acc_par, acc_chi]
        # p_final
        if "avg" in args.training_strategy:
            denominator = args.coef_base + args.coef_anc + args.coef_chi
            final_proto = (args.coef_base / denominator) * proto_base + (
                args.coef_anc / denominator) * proto_par + (
                    args.coef_chi / denominator) * proto_chi
        elif "weighted" in args.training_strategy:
            # TODO: these mixing weights are hard-coded
            final_proto = 0.3 * proto_base + 0.7 * proto_par
            #final_proto = 0.45 * proto_base + 0.45 * proto_par + 0.1 * proto_chi
        elif "relation" in args.training_strategy:
            cat_proto = torch.cat([
                proto for coef, proto in zip(coef_lst, proto_lst) if coef > 0
            ],
                                  dim=-1)
            final_proto = relation_nn(cat_proto)
        else:
            raise ValueError(
                "undefined training_strategy {}: no prototype-weighting scheme specified".
                format(args.training_strategy))

        # classification over every level
        loss_lvls = []
        for i, lvl in enumerate(all_levels):
            lvl_wordid = lvls_wordids[i]
            idx_start = all_level_classes.index(lvl_wordid[0])
            idx_end = all_level_classes.index(lvl_wordid[-1])
            protos = final_proto[idx_start:idx_end + 1]
            query_idx = torch.cat(cls_query_idxs[idx_start:idx_end + 1],
                                  dim=0).long()
            lvl_imgs_emb = lvls_embs[query_idx]
            lvl_meta_labels = lvls_meta_labels[i]
            logits = -euclidean_dist(lvl_imgs_emb, protos,
                                     transform=True).view(
                                         len(query_idx), len(protos))
            loss = criterion(logits, lvl_meta_labels)
            loss_lvls.append(loss)
            if lvl == max(all_levels):
                top_fs = obtain_accuracy(logits, lvl_meta_labels.data, (1, ))
                acc1.update(top_fs[0].item(), len(query_idx))
                fine_proto_lst = []
                for c, p in zip(coef_lst, proto_lst):
                    if c > 0:
                        fine_proto_lst.append(p[idx_start:idx_end + 1])
                    else:
                        fine_proto_lst.append(0)
                update_acc(coef_lst, fine_proto_lst, acc_lst, lvl_imgs_emb,
                           query_idx, lvl_meta_labels)
            if "multi" in args.training_strategy:
                for coef, proto in zip(coef_lst, proto_lst):
                    if coef > 0:
                        logits = -euclidean_dist(
                            lvl_imgs_emb, proto, transform=True).view(
                                len(query_idx), len(proto))
                        loss = criterion(logits, lvl_meta_labels)
                        loss_lvls.append(loss)
        if "mean" in args.training_strategy:
            loss = sum(loss_lvls) / len(loss_lvls)
        elif "single" in args.training_strategy or "multi" in args.training_strategy:
            loss = sum(loss_lvls)
        elif "selective" in args.training_strategy:
            # down-weight the coarser levels by 10x relative to the finest level
            loss = sum(loss_lvls[:-1]) / 10 + loss_lvls[-1]
        else:
            raise ValueError(
                "undefined loss type info in training_strategy : {}".format(
                    args.training_strategy))

        losses.update(loss.item(), len(query_idx))

        if mode == 'train':
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        elif mode == "test":
            metaval_accuracies.append(top_fs[0].item())
            if batch_idx + 1 == len(dataloader):
                metaval_accuracies = np.array(metaval_accuracies)
                stds = np.std(metaval_accuracies, 0)
                ci95 = 1.96 * stds / np.sqrt(batch_idx + 1)
                logger.print("ci95 is : {:}".format(ci95))
        else:
            raise ValueError('Invalid mode = {:}'.format(mode))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if (mode=="train" and ((batch_idx % args.log_interval == 0) or (batch_idx + 1 == len(dataloader)))) \
        or (mode=="test" and (batch_idx + 1 == len(dataloader))):
            Tstring = 'TIME[{data_time.val:.2f} ({data_time.avg:.2f}) {batch_time.val:.2f} ({batch_time.avg:.2f})]'.format(
                data_time=data_time, batch_time=batch_time)
            Sstring = '{:} {:} [Epoch={:03d}/{:03d}] [{:03d}/{:03d}]'.format(
                time_string(), mode, epoch, args.epochs, batch_idx,
                len(dataloader))
            Astring = 'loss=({:.3f}, {:.3f}), loss_lvls:{}, loss_min:{:.2f}, loss_max:{:.2f}, loss_mean:{:.2f}, loss_var:{:.2f}, acc@1=({:.1f}, {:.1f}), acc@base=({:.1f}, {:.1f}), acc@par=({:.1f}, {:.1f}), acc@chi=({:.1f}, {:.1f})'.format(
                losses.val, losses.avg, [l.item() for l in loss_lvls],
                min(loss_lvls).item(),
                max(loss_lvls).item(),
                torch.mean(torch.stack(loss_lvls)).item(),
                torch.var(torch.stack(loss_lvls)).item(), acc1.val, acc1.avg,
                acc_base.val, acc_base.avg, acc_par.val, acc_par.avg,
                acc_chi.val, acc_chi.avg)
            Cstring = 'p_base_weight : {:.4f}; p_par_weight : {:.4f}; p_chi_weight : {:.4f}'.format(
                args.coef_base, args.coef_anc, args.coef_chi)

            logger.print('{:} {:} {:} {:} \n'.format(Sstring, Tstring, Astring,
                                                     Cstring))
    return losses, acc1, acc_base, acc_par, acc_chi
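Example #1 scores queries with the negated Euclidean distance to each class prototype. The helper euclidean_dist is not defined on this page; below is a minimal sketch of the pairwise squared-distance computation it presumably performs. The _sketch suffix marks it as illustrative rather than the repo's API, and the transform=True flag used in the calls above is project-specific and ignored here.

import torch

def euclidean_dist_sketch(x, y):
    # x: (n, d) query embeddings; y: (m, d) prototypes.
    # Returns the (n, m) matrix of squared Euclidean distances.
    n, m, d = x.size(0), y.size(0), x.size(1)
    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)
    return torch.pow(x - y, 2).sum(2)

# usage mirroring the calls above: larger logit = closer prototype
# logits = -euclidean_dist_sketch(lvl_imgs_emb, protos)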
Example #2

def graph_nohier_test(mymodel, train_hierarchy_info, criterion, logger,
                      dataloader, epoch, args, n_support, pp_running, n_hop):
  args = deepcopy(args)
  losses, acc1 = AverageMeter(), AverageMeter()
  acc_base, acc_prop = AverageMeter(), AverageMeter()
  data_time, batch_time, end  = AverageMeter(), AverageMeter(), time.time()
  num_device = len(mymodel.emb_model.device_ids)
  
  mymodel.emb_model.eval(); mymodel.propagation_net.eval()
  if mymodel.mst_net: mymodel.mst_net.eval()
  metaval_accuracies = []

  for batch_idx, (_, imgs, labels) in enumerate(dataloader):
    assert len(set(labels.tolist())) == args.classes_per_it_tr
    embs  = mymodel.emb_model(imgs)
    n_train_classes = len(train_hierarchy_info[0])
    target_classes  = list( set(labels.tolist()))
    test_pp_list = []
    grouped_s_idxs, grouped_q_idxs = [], []
    for cls in target_classes:
      all_idxs = (labels == cls).nonzero().view(-1).tolist()
      s_idx = all_idxs[:n_support]; q_idx = all_idxs[n_support:]
      grouped_s_idxs.append(torch.IntTensor(s_idx)); grouped_q_idxs.append(torch.IntTensor(q_idx))
      test_pp_list.append(torch.mean( embs[s_idx], dim=0))
    test_pp = torch.stack(test_pp_list, dim=0)
    # get nodes
    label2metalabel, forward_adj, backward_adj = construct_propagation_graph(target_classes, train_hierarchy_info, test_pp, pp_running[:n_train_classes], args.n_chi, args.n_hop, 5, 0.5)
    base_proto = pp_running[list( label2metalabel.keys() )] # indexing with a list of indices creates a copy
    for cls, pp in zip(target_classes, test_pp_list):
      base_proto[ label2metalabel[cls] ] = pp 
  
    if "mst" in args.training_strategy:
      features = base_proto
      forward_adj = forward_adj.cuda(); backward_adj = backward_adj.cuda()
      if mymodel.mst_net: 
        distance_forward  = mymodel.mst_net(features, forward_adj)
        distance_backward = mymodel.mst_net(features, backward_adj)
      else:
        node_num, feat_dim = features.size()
        q = features.view(1, node_num, feat_dim)
        k = features.view(node_num, 1, feat_dim)
        distance_forward = distance_backward = torch.norm(q - k, p='fro', dim=2)

      #print("[before mst]-foward_adj is {}".format(forward_adj))
      #print("[before mst]-backward_adj is {}".format(backward_adj))
      # get edges
      forward_adj  = get_max_spanning_tree_kruskal(forward_adj, distance_forward)
      backward_adj = get_max_spanning_tree_kruskal(backward_adj, distance_backward)
      #print("[after mst]-foward_adj is {}".format(forward_adj))
      #print("[after mst]-backward_adj is {}".format(backward_adj))
    # propagation 
    if "single-round" in args.training_strategy:
      adj_lst = [forward_adj for i in range(args.n_hop)] + [backward_adj for i in range(args.n_hop)]
    elif "multi-round" in args.training_strategy:
      adj_lst = []
      for i in range(args.n_hop):
        adj_lst.append(forward_adj)
        adj_lst.append(backward_adj)
    elif "allaround" in args.training_strategy:
      all_adj = forward_adj + backward_adj 
      adj_lst = [all_adj for i in range(args.n_hop)] 
    elif "only-forward" in args.training_strategy:
      adj_lst = [forward_adj for i in range(args.n_hop)]
    elif "only-backward" in args.training_strategy:
      adj_lst = [backward_adj for i in range(args.n_hop)]
    else: raise ValueError("invalid training_strategy for adj : {}".format(args.training_strategy))
    prop_proto  = mymodel.propagation_net(base_proto, adj_lst*num_device)
    target_prop_proto  = prop_proto[:len(target_classes)]
    target_base_proto  = base_proto[:len(target_classes)]
    if args.coef_base == -1 and args.coef_prop == -1:
      if epoch == -1:
        final_proto = target_prop_proto
      else:
        # blend by training progress, mirroring graphtrain below
        final_proto = target_base_proto * (1 - epoch / args.epochs) + target_prop_proto * epoch / args.epochs
    else:
      coef_norm = args.coef_base + args.coef_prop
      final_proto = target_base_proto * args.coef_base / coef_norm + target_prop_proto * args.coef_prop / coef_norm
    query_idxs  = torch.cat(grouped_q_idxs, dim=0).tolist()
    query_meta_labels  =  torch.LongTensor( [label2metalabel[labels[i].item()] for i in query_idxs] ).cuda(non_blocking=True)
    logits      = - euclidean_dist(embs[query_idxs], final_proto, transform=True).view(len(query_idxs), len(target_classes))
    loss        = criterion(logits, query_meta_labels)
    losses.update(loss.item(), len(query_idxs))

    top_fs       = obtain_accuracy(logits, query_meta_labels, (1,))
    acc1.update(top_fs[0].item(), len(query_idxs))
    update_acc([args.coef_base, args.coef_prop], [target_base_proto, target_prop_proto], [acc_base, acc_prop], embs[query_idxs], query_meta_labels)
    metaval_accuracies.append(top_fs[0].item())

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()
    if batch_idx + 1 == len(dataloader):
      metaval_accuracies = np.array(metaval_accuracies)
      stds = np.std(metaval_accuracies, 0)
      ci95 = 1.96*stds/np.sqrt(batch_idx + 1)
      logger.print("ci95 is : {:}".format(ci95))
      Tstring = 'TIME[{data_time.val:.2f} ({data_time.avg:.2f}) {batch_time.val:.2f} ({batch_time.avg:.2f})]'.format(data_time=data_time, batch_time=batch_time)
      Sstring = '{:} {:} [Epoch={:03d}/{:03d}] [{:03d}/{:03d}]'.format(time_string(), "test", epoch, args.epochs, batch_idx, len(dataloader))
      Astring = 'loss=({:.3f}, {:.3f}), acc@1=({:.1f}, {:.1f}), acc@base=({:.1f}, {:.1f}), acc@prop=({:.1f}, {:.1f})'.format(losses.val, losses.avg, acc1.val, acc1.avg, acc_base.val, acc_base.avg, acc_prop.val, acc_prop.avg)
      Cstring = 'p_base_weight : {:.4f}; p_prop_weight : {:.4f}'.format(args.coef_base, args.coef_prop)
      logger.print('{:} {:} {:} {:} \n'.format(Sstring, Tstring, Astring, Cstring))
  return losses, acc1, acc_base, acc_prop
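Both graph examples prune the candidate adjacency to a maximum spanning tree before propagation. The repo's get_max_spanning_tree_kruskal is not shown on this page; the sketch below is a minimal Kruskal variant with union-find, assuming a dense 0/1 adjacency mask and a dense edge-score matrix (illustrative only; the real helper may use a different edge format).

import torch

def max_spanning_tree_kruskal_sketch(adj, weight):
    # adj: (N, N) 0/1 mask of candidate edges; weight: (N, N) edge scores.
    # Keeps only the edges of a maximum spanning forest (heaviest first).
    n = adj.size(0)
    parent = list(range(n))

    def find(x):  # union-find with path halving
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    edges = [(weight[i, j].item(), i, j)
             for i in range(n) for j in range(n) if adj[i, j] > 0]
    edges.sort(reverse=True)  # heaviest edges first -> maximum spanning tree
    mst = torch.zeros_like(adj)
    for w, i, j in edges:
        ri, rj = find(i), find(j)
        if ri != rj:  # the edge joins two components, so no cycle is formed
            parent[ri] = rj
            mst[i, j] = 1
    return mst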
Example #3
def train_model(lr_scheduler, model, criterion, optimizer, logger, dataloader, epoch, args, mode, n_support, n_query,proto_bag):
  args = deepcopy(args)
  losses, acc1 = AverageMeter(), AverageMeter()
  data_time, batch_time, end  = AverageMeter(), AverageMeter(), time.time()
  num_device = len(model.device_ids)

  if mode == "train":
    model.train()
  elif mode == "test":
    model.eval()
    metaval_accuracies = []
  else: raise TypeError("invalid mode {:}".format(mode))

  for batch_idx, (feas, labels) in enumerate(dataloader):
    if mode=="train":
      lr_scheduler.step()

    cpu_labels = labels.cpu().tolist()
    idxs_dict = defaultdict(list)
    for i, l in enumerate(cpu_labels):
      idxs_dict[l].append(i)
    idxs_dict = dict(sorted(idxs_dict.items()))
    grouped_s_idxs, grouped_q_idxs = [], []
    for lab, idxs in idxs_dict.items():
      grouped_s_idxs.append(torch.LongTensor(idxs[:n_support]))
      grouped_q_idxs.append(torch.LongTensor(idxs[n_support:]))
    query_idxs   = torch.cat(grouped_q_idxs, dim=0).tolist()
 
    embs  = model(feas)
    ## the prototype could be the mean of the first n_support embeddings
    ## (kept commented for reference); here it comes from proto_bag instead
    #proto_lst = [torch.mean(embs[s_idxs], dim=0) for s_idxs in grouped_s_idxs]
    #proto = torch.stack(proto_lst, dim=0)
    proto = get_proto(args, model, proto_bag)

    # classification
    ## gather query features after the linear transform
    query_emb    = embs[query_idxs]
    logits       = - euclidean_dist(query_emb, proto, transform=True).view(len(query_idxs), len(proto))
    query_labels = labels[query_idxs].cuda(non_blocking=True)
    loss         = criterion(logits, query_labels)
    losses.update(loss.item(), len(query_idxs))  

    top_fs        = obtain_accuracy(logits, query_labels, (1,))
    acc1.update(top_fs[0].item(),  len(query_idxs))

    if mode == 'train':
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
    elif mode=="test":
      metaval_accuracies.append(top_fs[0].item())
      if batch_idx + 1 == len(dataloader):
        metaval_accuracies = np.array(metaval_accuracies)
        stds = np.std(metaval_accuracies, 0)
        ci95 = 1.96*stds/np.sqrt(batch_idx + 1)
        logger.print("ci95 is : {:}".format(ci95))
    else: raise ValueError('Invalid mode = {:}'.format( mode ))

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()
    if (mode=="train" and ((batch_idx % args.log_interval == 0) or (batch_idx + 1 == len(dataloader)))) \
    or (mode=="test" and (batch_idx + 1 == len(dataloader))):
      Tstring = 'TIME[{data_time.val:.2f} ({data_time.avg:.2f}) {batch_time.val:.2f} ({batch_time.avg:.2f})]'.format(data_time=data_time, batch_time=batch_time)
      Sstring = '{:} {:} [Epoch={:03d}/{:03d}] [{:03d}/{:03d}]'.format(time_string(), mode, epoch, args.epochs, batch_idx, len(dataloader))
      Astring = 'loss=({:.3f}, {:.3f}), acc@1=({:.1f}, {:.1f})'.format(losses.val, losses.avg, acc1.val, acc1.avg)

      logger.print('{:} {:} {:} \n'.format(Sstring, Tstring, Astring))
  
  return losses, acc1
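All the loops above share two utilities that are not displayed on this page: AverageMeter for running statistics and obtain_accuracy for top-k accuracy. Below are minimal sketches consistent with how they are used (update(val, n), the .val/.avg attributes, and a returned list of top-k percentages); the names are illustrative, not the repo's exact implementations.

import torch

class AverageMeterSketch:
    # Tracks the latest value and a count-weighted running average.
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def obtain_accuracy_sketch(logits, targets, topk=(1,)):
    # Returns a list holding the top-k accuracy (in %) for each k in topk.
    maxk = max(topk)
    _, pred = logits.topk(maxk, dim=1)  # (batch, maxk) predicted classes
    pred = pred.t()                     # (maxk, batch)
    correct = pred.eq(targets.view(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum() * 100.0 / targets.size(0)
            for k in topk]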
Example #4
def graphtrain(lr_scheduler, mymodel, criterion, optimizer, logger, dataloader,
               hierarchy_info, epoch, args, mode, n_support, pp_running,
               n_hop):
    args = deepcopy(args)
    losses, acc1 = AverageMeter(), AverageMeter()
    acc_base, acc_prop = AverageMeter(), AverageMeter()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    num_device = len(mymodel.emb_model.device_ids)

    if mode == "train":
        mymodel.emb_model.train()
        mymodel.propagation_net.train()
        if mymodel.mst_net: mymodel.mst_net.train()
        parents, children = hierarchy_info
    elif mode == "test":
        mymodel.emb_model.eval()
        mymodel.propagation_net.eval()
        if mymodel.mst_net: mymodel.mst_net.eval()
        parents, children, all_train_classes = hierarchy_info
        metaval_accuracies = []
    else:
        raise TypeError("invalid mode {:}".format(mode))
    for batch_idx, (_, imgs, labels) in enumerate(dataloader):
        if mode == "train" and len(imgs) == args.batch_size and len(
                set(labels.tolist())) != args.classes_per_it_tr:
            # curriculum many-shot learning: do not update the few-shot loss/acc meters
            mymodel.emb_model.module.use_classifier = True
            prediction = mymodel.emb_model(imgs)
            mymodel.emb_model.module.use_classifier = False
            labels = labels.cuda()
            loss = criterion(prediction, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            if mode == "train" and lr_scheduler:
                if not type(lr_scheduler) == list:
                    lr_scheduler = [lr_scheduler]
                for l_s in lr_scheduler:
                    l_s.step()
            embs = mymodel.emb_model(imgs)
            with torch.no_grad():
                target_classes = list(set(labels.tolist()))
                # get nodes
                label2metalabel, forward_adj, backward_adj = get_real_propagation_graph(
                    target_classes, hierarchy_info, args.n_hop, 5, 0.5
                )  # if the num of children or parents > 5, only select half of them
                base_proto = pp_running[list(label2metalabel.keys())]
                # indexing with a list of indices creates a copy
                grouped_s_idxs, grouped_q_idxs, target_idx = [], [], []
                for cls in target_classes:
                    all_idxs = (labels == cls).nonzero().view(-1).tolist()
                    s_idx = all_idxs[:n_support]
                    q_idx = all_idxs[n_support:]
                    grouped_s_idxs.append(torch.IntTensor(s_idx))
                    grouped_q_idxs.append(torch.IntTensor(q_idx))
                    base_proto[label2metalabel[cls]] = torch.mean(embs[s_idx],
                                                                  dim=0)

                if "mst" in args.training_strategy:
                    # get edges
                    features = base_proto
                    forward_adj = forward_adj.cuda()
                    backward_adj = backward_adj.cuda()
                    if mymodel.mst_net:
                        distance_forward = -mymodel.mst_net(
                            features, forward_adj)
                        distance_backward = -mymodel.mst_net(
                            features, backward_adj)
                    else:  # euclidean distance
                        node_num, feat_dim = features.size()
                        q = features.view(1, node_num, feat_dim)
                        k = features.view(node_num, 1, feat_dim)
                        distance_forward = distance_backward = torch.norm(
                            q - k, p='fro', dim=2)

                    #print("[before mst]-foward_adj is {}".format(forward_adj))
                    #print("[before mst]-backward_adj is {}".format(backward_adj))
                    forward_adj = get_max_spanning_tree_kruskal(
                        forward_adj, distance_forward)
                    backward_adj = get_max_spanning_tree_kruskal(
                        backward_adj, distance_backward)
                    #print("[after mst]-foward_adj is {}".format(forward_adj))
                    #print("[after mst]-backward_adj is {}".format(backward_adj))
            # propagation
            if "single-round" in args.training_strategy:
                adj_lst = [forward_adj for i in range(args.n_hop)
                           ] + [backward_adj for i in range(args.n_hop)]
            elif "multi-round" in args.training_strategy:
                adj_lst = []
                for i in range(args.n_hop):
                    adj_lst.append(forward_adj)
                    adj_lst.append(backward_adj)
            elif "allaround" in args.training_strategy:
                all_adj = forward_adj + backward_adj
                adj_lst = [all_adj for i in range(args.n_hop)]
            elif "only-forward" in args.training_strategy:
                adj_lst = [forward_adj for i in range(args.n_hop)]
            elif "only-backward" in args.training_strategy:
                adj_lst = [backward_adj for i in range(args.n_hop)]
            else:
                raise ValueError(
                    "invalid training_strategy for adj : {}".format(
                        args.training_strategy))
            prop_proto = mymodel.propagation_net(base_proto,
                                                 adj_lst * num_device)
            target_prop_proto = prop_proto[:len(target_classes)]
            target_base_proto = base_proto[:len(target_classes)]
            if args.coef_base == -1 and args.coef_prop == -1:
                if epoch == -1:
                    final_proto = target_prop_proto
                else:
                    final_proto = target_base_proto * (
                        1 - epoch /
                        args.epochs) + target_prop_proto * epoch / args.epochs
            else:
                coef_norm = args.coef_base + args.coef_prop
                final_proto = target_base_proto * args.coef_base / coef_norm + target_prop_proto * args.coef_prop / coef_norm
            query_idxs = torch.cat(grouped_q_idxs, dim=0).tolist()
            query_meta_labels = torch.LongTensor([
                label2metalabel[labels[i].item()] for i in query_idxs
            ]).cuda(non_blocking=True)
            logits = -euclidean_dist(
                embs[query_idxs], final_proto, transform=True).view(
                    len(query_idxs), len(target_classes))
            loss = criterion(logits, query_meta_labels)
            losses.update(loss.item(), len(query_idxs))

            if mode == "train":
                # train on selected classes in the graph
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                top_fs = obtain_accuracy(logits, query_meta_labels, (1, ))
                acc1.update(top_fs[0].item(), len(query_idxs))
                update_acc([args.coef_base, args.coef_prop],
                           [target_base_proto, target_prop_proto],
                           [acc_base, acc_prop], embs[query_idxs],
                           query_meta_labels)

            elif mode == "test":
                with torch.no_grad():
                    # only the few-shot level is evaluated; the loss is still the graph loss
                    top_fs = obtain_accuracy(logits, query_meta_labels, (1, ))
                    acc1.update(top_fs[0].item(), len(query_idxs))
                    update_acc([args.coef_base, args.coef_prop],
                               [target_base_proto, target_prop_proto],
                               [acc_base, acc_prop], embs[query_idxs],
                               query_meta_labels)
                    metaval_accuracies.append(top_fs[0].item())
                    if batch_idx + 1 == len(dataloader):
                        metaval_accuracies = np.array(metaval_accuracies)
                        stds = np.std(metaval_accuracies, 0)
                        ci95 = 1.96 * stds / np.sqrt(batch_idx + 1)
                        logger.print("ci95 is : {:}".format(ci95))

            else:
                raise ValueError("undefined mode : {}".format(mode))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if (mode=="train" and ((batch_idx % args.log_interval == 0) or (batch_idx + 1 == len(dataloader)))) \
         or (mode=="test" and (batch_idx + 1 == len(dataloader))):
            Tstring = 'TIME[{data_time.val:.2f} ({data_time.avg:.2f}) {batch_time.val:.2f} ({batch_time.avg:.2f})]'.format(
                data_time=data_time, batch_time=batch_time)
            Sstring = '{:} {:} [Epoch={:03d}/{:03d}] [{:03d}/{:03d}]'.format(
                time_string(), mode, epoch, args.epochs, batch_idx,
                len(dataloader))
            Astring = 'loss=({:.3f}, {:.3f}), acc@1=({:.1f}, {:.1f}), acc@base=({:.1f}, {:.1f}), acc@prop=({:.1f}, {:.1f})'.format(
                losses.val, losses.avg, acc1.val, acc1.avg, acc_base.val,
                acc_base.avg, acc_prop.val, acc_prop.avg)
            Cstring = 'p_base_weight : {:.4f}; p_prop_weight : {:.4f}'.format(
                args.coef_base, args.coef_prop)
            logger.print('{:} {:} {:} {:} \n'.format(Sstring, Tstring, Astring,
                                                     Cstring))
    return losses, acc1, acc_base, acc_prop
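graphtrain delegates message passing to mymodel.propagation_net, whose architecture is not shown here. As a rough mental model only (an assumption, not the repo's network), each adjacency in adj_lst drives one aggregation hop over the prototype nodes; a parameter-free mean-aggregation stand-in looks like this:

import torch

def propagate_sketch(features, adj_lst):
    # features: (N, d) node prototypes; adj_lst: list of (N, N) adjacencies.
    # Applies one mean-aggregation hop per adjacency, adding self-loops so
    # every node keeps a share of its own representation.
    h = features
    for adj in adj_lst:
        a = adj + torch.eye(adj.size(0), device=adj.device, dtype=adj.dtype)
        deg = a.sum(dim=1, keepdim=True).clamp(min=1.0)
        h = (a @ h) / deg  # average each node with its neighbours
    return h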