# Example 1
# 0
def main():
    """Build positive-sample graphs from HGT attention weights and save them."""
    args = parse_args()
    print(args)
    set_random_seed(args.seed)
    device = get_device(args.device)

    data, g, _, labels, predict_ntype, train_idx, val_idx, test_idx, _ = \
        load_data(args.dataset)
    g = g.to(device)
    labels = labels.tolist()
    # Treat validation nodes as labeled as well.
    train_idx = torch.cat([train_idx, val_idx])
    add_node_feat(g, 'one-hot')

    # (N, T_pos) neighbors that share each node's label
    label_neigh = sample_label_neighbors(labels, args.num_samples)
    # Attention weights computed by HGT: one tensor (N, T_pos) per meta-path,
    # plus one overall tensor at the end.
    attn_pos = calc_attn_pos(g, data.num_classes, predict_ntype,
                             args.num_samples, device, args)

    # Destination ids: every target node repeated T_pos times.
    dst = torch.repeat_interleave(g.nodes(predict_ntype), args.num_samples).cpu()

    # One positive-sample graph per meta-path.
    graphs = [
        dgl.graph((p.view(1, -1).squeeze(dim=0), dst)) for p in attn_pos[:-1]
    ]

    # Overall positive-sample graph; optionally override training nodes with
    # label-based neighbors.
    overall = attn_pos[-1]
    if args.use_label:
        overall[train_idx] = label_neigh[train_idx]
    graphs.append(dgl.graph((overall.view(1, -1).squeeze(dim=0), dst)))

    dgl.save_graphs(args.save_graph_path, graphs)
    print('正样本图已保存到', args.save_graph_path)
# Example 2
# 0
def train(args):
    """Train KGCN on the user-item/knowledge graph and report ranking metrics."""
    set_random_seed(args.seed)
    device = get_device(args.device)
    g, author_rank, field_ids, true_relevance = load_rank_data(device)
    field_paper = recall_paper(g.cpu(), field_ids, args.num_recall)
    data = RatingKnowledgeGraphDataset()
    user_item_graph = data.user_item_graph
    # Fix every entity's neighborhood to a constant size (sampled with replacement).
    knowledge_graph = dgl.sampling.sample_neighbors(
        data.knowledge_graph, data.knowledge_graph.nodes(), args.neighbor_size, replace=True
    )

    sampler = MultiLayerNeighborSampler([args.neighbor_size] * args.num_hops)
    all_edges = torch.arange(user_item_graph.num_edges())
    train_loader = KGCNEdgeDataLoader(
        user_item_graph, all_edges, sampler, knowledge_graph,
        device=device, batch_size=args.batch_size
    )

    model = KGCN(args.num_hidden, args.neighbor_size, 'sum', args.num_hops, *data.get_num()).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    for epoch in range(args.epochs):
        model.train()
        epoch_losses = []
        for _, pair_graph, blocks in train_loader:
            pred = model(pair_graph, blocks)
            batch_loss = F.binary_cross_entropy(pred, pair_graph.edata['label'])
            epoch_losses.append(batch_loss.item())
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
        mean_loss = sum(epoch_losses) / len(epoch_losses)
        print('Epoch {:d} | Loss {:.4f}'.format(epoch, mean_loss))
        print(METRICS_STR.format(*evaluate(
            model, g, knowledge_graph, sampler, field_ids, author_rank, true_relevance, field_paper
        )))
# Example 3
# 0
def train(args):
    """Full-batch training of HGConv on a heterogeneous graph."""
    set_random_seed(args.seed)
    device = get_device(args.device)
    data, g, _, labels, predict_ntype, train_idx, val_idx, test_idx, _ = \
        load_data(args.dataset, device)
    add_node_feat(g, 'one-hot')

    in_dims = {ntype: g.nodes[ntype].data['feat'].shape[1] for ntype in g.ntypes}
    model = HGConvFull(
        in_dims, args.num_hidden, data.num_classes, args.num_heads,
        g.ntypes, g.canonical_etypes, predict_ntype,
        args.num_layers, args.dropout, args.residual
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    warnings.filterwarnings('ignore', 'Setting attributes on ParameterDict is not supported')
    for epoch in range(args.epochs):
        model.train()
        logits = model(g, g.ndata['feat'])
        loss = F.cross_entropy(logits[train_idx], labels[train_idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Full-batch forward passes are memory-hungry; release cached blocks.
        torch.cuda.empty_cache()
        print(('Epoch {:d} | Loss {:.4f} | ' + METRICS_STR).format(
            epoch, loss.item(),
            *evaluate_full(model, g, labels, train_idx, val_idx, test_idx)))
# Example 4
# 0
def train(args):
    """Mini-batch training of R-HGNN with multi-layer neighbor sampling.

    Trains on the target node type with cross-entropy, anneals the learning
    rate per optimizer step, evaluates periodically, and optionally saves
    the final model.
    """
    set_random_seed(args.seed)
    device = get_device(args.device)
    data, g, _, labels, predict_ntype, train_idx, val_idx, test_idx, evaluator = \
        load_data(args.dataset, device)
    add_node_feat(g, 'pretrained', args.node_embed_path, True)

    # Fan-out grows by one per layer: [neighbor_size, neighbor_size + 1, ...].
    sampler = MultiLayerNeighborSampler(
        list(range(args.neighbor_size, args.neighbor_size + args.num_layers)))
    train_loader = NodeDataLoader(g, {predict_ntype: train_idx},
                                  sampler,
                                  device=device,
                                  batch_size=args.batch_size)
    # Loader over ALL target nodes, used for inference during evaluation.
    loader = NodeDataLoader(g, {predict_ntype: g.nodes(predict_ntype)},
                            sampler,
                            device=device,
                            batch_size=args.batch_size)

    model = RHGNN(
        {ntype: g.nodes[ntype].data['feat'].shape[1]
         for ntype in g.ntypes}, args.num_hidden, data.num_classes,
        args.num_rel_hidden, args.num_rel_hidden, args.num_heads, g.ntypes,
        g.canonical_etypes, predict_ntype, args.num_layers,
        args.dropout).to(device)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    # One scheduler step per batch, hence T_max = batches_per_epoch * epochs.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     T_max=len(train_loader) *
                                                     args.epochs,
                                                     eta_min=args.lr / 100)
    warnings.filterwarnings(
        'ignore', 'Setting attributes on ParameterDict is not supported')
    for epoch in range(args.epochs):
        model.train()
        losses = []
        for input_nodes, output_nodes, blocks in tqdm(train_loader):
            batch_logits = model(blocks, blocks[0].srcdata['feat'])
            batch_labels = labels[output_nodes[predict_ntype]]
            loss = F.cross_entropy(batch_logits, batch_labels)
            losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            torch.cuda.empty_cache()
        print('Epoch {:d} | Loss {:.4f}'.format(epoch,
                                                sum(losses) / len(losses)))
        if epoch % args.eval_every == 0 or epoch == args.epochs - 1:
            print(
                METRICS_STR.format(*evaluate(
                    model, loader, g, labels, data.num_classes, predict_ntype,
                    train_idx, val_idx, test_idx, evaluator)))
    if args.save_path:
        # Move to CPU before saving so the checkpoint loads on any device.
        torch.save(model.cpu().state_dict(), args.save_path)
        print('模型已保存到', args.save_path)
# Example 5
# 0
def train(args):
    """Mini-batch training of HGT with neighbor sampling and a OneCycleLR schedule.

    Trains on the target node type with cross-entropy, evaluates periodically
    over all target nodes, and optionally saves the final model.
    """
    set_random_seed(args.seed)
    device = get_device(args.device)
    data, g, _, labels, predict_ntype, train_idx, val_idx, test_idx, evaluator = \
        load_data(args.dataset, device)
    add_node_feat(g, args.node_feat, args.node_embed_path)

    sampler = MultiLayerNeighborSampler([args.neighbor_size] * args.num_layers)
    train_loader = NodeDataLoader(g, {predict_ntype: train_idx},
                                  sampler,
                                  device=device,
                                  batch_size=args.batch_size)
    # Loader over ALL target nodes, used for inference during evaluation.
    loader = NodeDataLoader(g, {predict_ntype: g.nodes(predict_ntype)},
                            sampler,
                            device=device,
                            batch_size=args.batch_size)

    model = HGT(
        {ntype: g.nodes[ntype].data['feat'].shape[1]
         for ntype in g.ntypes}, args.num_hidden, data.num_classes,
        args.num_heads, g.ntypes, g.canonical_etypes, predict_ntype,
        args.num_layers, args.dropout).to(device)
    # No explicit lr: the one-cycle schedule drives the learning rate,
    # stepping once per batch.
    optimizer = optim.AdamW(model.parameters(), eps=1e-6)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        args.max_lr,
        epochs=args.epochs,
        steps_per_epoch=len(train_loader),
        pct_start=0.05,
        anneal_strategy='linear',
        final_div_factor=10.0)
    warnings.filterwarnings(
        'ignore', 'Setting attributes on ParameterDict is not supported')
    for epoch in range(args.epochs):
        model.train()
        losses = []
        for input_nodes, output_nodes, blocks in tqdm(train_loader):
            batch_logits = model(blocks, blocks[0].srcdata['feat'])
            batch_labels = labels[output_nodes[predict_ntype]]
            loss = F.cross_entropy(batch_logits, batch_labels)
            losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            torch.cuda.empty_cache()
        print('Epoch {:d} | Loss {:.4f}'.format(epoch,
                                                sum(losses) / len(losses)))
        if epoch % args.eval_every == 0 or epoch == args.epochs - 1:
            print(
                METRICS_STR.format(*evaluate(
                    model, loader, g, labels, data.num_classes, predict_ntype,
                    train_idx, val_idx, test_idx, evaluator)))
    if args.save_path:
        # Move to CPU before saving so the checkpoint loads on any device.
        torch.save(model.cpu().state_dict(), args.save_path)
        print('模型已保存到', args.save_path)
# Example 6
# 0
def train(args):
    """Full-batch RHCO training with joint contrastive and classification loss.

    After training, predictions are refined by label propagation ("smoothing")
    over the positive-sample graph and final test metrics are printed.
    """
    set_random_seed(args.seed)
    device = get_device(args.device)
    data, g, features, labels, predict_ntype, train_idx, val_idx, test_idx, _ = \
        load_data(args.dataset, device)
    add_node_feat(g, 'one-hot')

    # Meta-path graphs plus the overall positive-sample graph.
    (*mgs, pos_g), _ = dgl.load_graphs(args.pos_graph_path)
    mgs = [mg.to(device) for mg in mgs]
    if args.use_data_pos:
        pos_v, pos_u = data.pos
        pos_g = dgl.graph((pos_u, pos_v), device=device)
    # Dense (N, N) indicator matrix: pos[i, j] == 1 for positive pairs in data.pos.
    pos = torch.zeros((g.num_nodes(predict_ntype), g.num_nodes(predict_ntype)),
                      dtype=torch.int,
                      device=device)
    pos[data.pos] = 1

    model = RHCOFull(
        {ntype: g.nodes[ntype].data['feat'].shape[1]
         for ntype in g.ntypes}, args.num_hidden, data.num_classes,
        args.num_rel_hidden, args.num_heads, g.ntypes,
        g.canonical_etypes, predict_ntype, args.num_layers, args.dropout,
        len(mgs), args.tau, args.lambda_).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     T_max=args.epochs,
                                                     eta_min=args.lr / 100)
    alpha = args.contrast_weight
    warnings.filterwarnings(
        'ignore', 'Setting attributes on ParameterDict is not supported')
    for epoch in range(args.epochs):
        model.train()
        contrast_loss, logits = model(g, g.ndata['feat'], mgs, features, pos)
        clf_loss = F.cross_entropy(logits[train_idx], labels[train_idx])
        # Convex combination of the two objectives.
        loss = alpha * contrast_loss + (1 - alpha) * clf_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        torch.cuda.empty_cache()
        print(('Epoch {:d} | Loss {:.4f} | ' + METRICS_STR).format(
            epoch, loss.item(),
            *evaluate(model, g, labels, train_idx, val_idx, test_idx)))

    model.eval()
    # Inference only: disable autograd so the full-batch forward pass does not
    # build a computation graph (the original kept gradients alive needlessly).
    with torch.no_grad():
        _, base_pred = model(g, g.ndata['feat'], mgs, features, pos)
    mask = torch.cat([train_idx, val_idx])
    logits = smooth(base_pred, pos_g, labels, mask, args)
    _, _, test_acc, _, _, test_f1 = calc_metrics(logits, labels, train_idx,
                                                 val_idx, test_idx)
    print('After smoothing: Test Acc {:.4f} | Test Macro-F1 {:.4f}'.format(
        test_acc, test_f1))
# Example 7
# 0
def train(args):
    """Train a linear base model and apply Correct & Smooth post-processing.

    Features are standardized per dimension before training; constant feature
    columns are guarded so they become zeros instead of NaNs.
    """
    set_random_seed(args.seed)
    device = get_device(args.device)
    data, _, feat, labels, _, train_idx, val_idx, test_idx, evaluator = \
        load_data(args.dataset, device)
    # Standardize features. The original divided by std directly, which
    # produces NaN columns whenever a feature is constant (std == 0);
    # replace zero std with 1 so those columns standardize to 0.
    std = feat.std(dim=0)
    std[std == 0] = 1
    feat = (feat - feat.mean(dim=0)) / std
    # Label-propagation graph
    if args.dataset in ('acm', 'dblp'):
        pos_v, pos_u = data.pos
        pg = dgl.graph((pos_u, pos_v), device=device)
    else:
        pg = dgl.load_graphs(args.prop_graph)[0][-1].to(device)

    if args.dataset == 'oag-venue':
        # NOTE(review): -1 appears to mark unlabeled nodes; mapping them to
        # class 0 assumes they are excluded from metrics downstream — confirm.
        labels[labels == -1] = 0

    base_model = nn.Linear(feat.shape[1], data.num_classes).to(device)
    train_base_model(base_model, feat, labels, train_idx, val_idx, test_idx,
                     evaluator, args)
    correct_and_smooth(base_model, pg, feat, labels, train_idx, val_idx,
                       test_idx, evaluator, args)
# Example 8
# 0
def train(args):
    """Full-batch training of HGT with a one-cycle learning-rate schedule."""
    set_random_seed(args.seed)
    device = get_device(args.device)
    data, g, _, labels, predict_ntype, train_idx, val_idx, test_idx, _ = \
        load_data(args.dataset, device)
    add_node_feat(g, 'one-hot')

    in_dims = {ntype: g.nodes[ntype].data['feat'].shape[1] for ntype in g.ntypes}
    model = HGTFull(
        in_dims, args.num_hidden, data.num_classes, args.num_heads,
        g.ntypes, g.canonical_etypes, predict_ntype,
        args.num_layers, args.dropout
    ).to(device)
    optimizer = optim.AdamW(model.parameters(), eps=1e-6)
    # Full-batch training takes one optimizer step per epoch, so
    # steps_per_epoch=1.
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, args.max_lr, epochs=args.epochs, steps_per_epoch=1,
        pct_start=0.05, anneal_strategy='linear', final_div_factor=10.0
    )
    warnings.filterwarnings('ignore', 'Setting attributes on ParameterDict is not supported')
    for epoch in range(args.epochs):
        model.train()
        logits = model(g, g.ndata['feat'])
        loss = F.cross_entropy(logits[train_idx], labels[train_idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        torch.cuda.empty_cache()
        print(('Epoch {:d} | Loss {:.4f} | ' + METRICS_STR).format(
            epoch, loss.item(),
            *evaluate_full(model, g, labels, train_idx, val_idx, test_idx)))
    if args.save_path:
        torch.save(model.cpu().state_dict(), args.save_path)
        print('模型已保存到', args.save_path)
# Example 9
# 0
def train(args):
    """Full-batch training of R-GCN; only the target node type has input features."""
    set_random_seed(args.seed)
    device = get_device(args.device)
    data, g, features, labels, predict_ntype, train_idx, val_idx, test_idx, evaluator = \
        load_data(args.dataset, device, reverse_self=False)

    num_nodes = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
    model = RGCN(
        features.shape[1], args.num_hidden, data.num_classes, [predict_ntype],
        num_nodes, g.etypes, predict_ntype, args.num_layers, args.dropout
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # The model expects a per-ntype feature dict.
    features = {predict_ntype: features}
    for epoch in range(args.epochs):
        model.train()
        logits = model(g, features)
        train_loss = F.cross_entropy(logits[train_idx], labels[train_idx])
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        print(('Epoch {:d} | Loss {:.4f} | ' + METRICS_STR).format(
            epoch, train_loss.item(),
            *evaluate(model, g, features, labels, train_idx, val_idx, test_idx, evaluator)
        ))
# Example 10
# 0
def train(args):
    """Fine-tune a contrastive SciBERT on paper title–keyword pairs.

    Trains with in-batch negatives and a linear warmup/decay schedule,
    evaluates on the validation split each epoch, and saves the final model.
    """
    set_random_seed(args.seed)
    device = get_device(args.device)

    raw_file = DATA_DIR / 'oag/cs/mag_papers.txt'
    train_dataset = OAGCSContrastDataset(raw_file, split='train')
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate)
    valid_dataset = OAGCSContrastDataset(raw_file, split='valid')
    # Validation should not be shuffled: the original used shuffle=True, which
    # randomized the in-batch negative sets and made validation accuracy
    # nondeterministic across runs.
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate)

    model = ContrastiveSciBERT(args.num_hidden, args.tau, device).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    total_steps = len(train_loader) * args.epochs
    # 10% of total steps are used for linear warmup.
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=total_steps * 0.1, num_training_steps=total_steps
    )
    for epoch in range(args.epochs):
        model.train()
        losses, scores = [], []
        for titles, keywords in tqdm(train_loader):
            logits, loss = model(titles, keywords)
            # In-batch contrastive target: the i-th title matches the i-th keyword.
            labels = torch.arange(len(titles), device=device)
            losses.append(loss.item())
            scores.append(score(logits, labels))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
        val_score = evaluate(valid_loader, model, device)
        print('Epoch {:d} | Loss {:.4f} | Train Acc {:.4f} | Val Acc {:.4f}'.format(
            epoch, sum(losses) / len(losses), sum(scores) / len(scores), val_score
        ))
    model_save_path = MODEL_DIR / 'scibert.pt'
    torch.save(model.state_dict(), model_save_path)
    print('模型已保存到', model_save_path)
# Example 11
# 0
def train(args):
    """Mini-batch contrastive + classification training (RHCO-style models).

    Jointly minimizes a contrastive loss over positive-sample graphs and a
    cross-entropy classification loss, weighted by args.contrast_weight,
    checkpointing the model every epoch.
    """
    set_random_seed(args.seed)
    device = get_device(args.device)
    data, g, _, labels, predict_ntype, train_idx, val_idx, test_idx, evaluator = \
        load_data(args.dataset, device)
    add_node_feat(g, 'pretrained', args.node_embed_path, True)
    features = g.nodes[predict_ntype].data['feat']

    # Meta-path graphs plus the overall positive-sample graph.
    (*mgs, pos_g), _ = dgl.load_graphs(args.pos_graph_path)
    mgs = [mg.to(device) for mg in mgs]
    pos_g = pos_g.to(device)
    pos = pos_g.in_edges(pos_g.nodes())[0].view(pos_g.num_nodes(), -1)  # (N, T_pos) positive-sample ids of each target node
    # pos_g.edges() cannot be used here: the sources must be ordered by destination id

    id_loader = DataLoader(train_idx, batch_size=args.batch_size)
    # Sampler that also pulls each node's positive samples into the batch.
    loader = NodeDataLoader(
        g, {predict_ntype: train_idx}, PositiveSampler([args.neighbor_size] * args.num_layers, pos),
        device=device, batch_size=args.batch_size
    )
    sampler = PositiveSampler([None], pos)
    mg_loaders = [
        NodeDataLoader(mg, train_idx, sampler, device=device, batch_size=args.batch_size)
        for mg in mgs
    ]
    pos_loader = NodeDataLoader(pos_g, train_idx, sampler, device=device, batch_size=args.batch_size)

    model_class = get_model_class(args.model)
    model = model_class(
        {ntype: g.nodes[ntype].data['feat'].shape[1] for ntype in g.ntypes},
        args.num_hidden, data.num_classes, args.num_rel_hidden, args.num_heads,
        g.ntypes, g.canonical_etypes, predict_ntype, args.num_layers, args.dropout,
        len(mgs), args.tau, args.lambda_
    ).to(device)
    if args.load_path:
        model.load_state_dict(torch.load(args.load_path, map_location=device))
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=len(loader) * args.epochs, eta_min=args.lr / 100
    )
    alpha = args.contrast_weight
    for epoch in range(args.epochs):
        model.train()
        losses = []
        # All loaders iterate over train_idx with the same batch size, so the
        # zipped batches refer to the same target nodes.
        for (batch, (_, _, blocks), *mg_blocks, (_, _, pos_blocks)) in tqdm(zip(id_loader, loader, *mg_loaders, pos_loader)):
            mg_feats = [features[i] for i, _, _ in mg_blocks]
            mg_blocks = [b[0] for _, _, b in mg_blocks]
            pos_block = pos_blocks[0]
            # pos_block.num_dst_nodes() = batch_size + number of positive samples
            batch_pos = torch.zeros(pos_block.num_dst_nodes(), batch.shape[0], dtype=torch.int, device=device)
            # Set 1 at each (src, dst) in-edge pair of the first batch_size dst nodes.
            batch_pos[pos_block.in_edges(torch.arange(batch.shape[0], device=device))] = 1
            contrast_loss, logits = model(blocks, blocks[0].srcdata['feat'], mg_blocks, mg_feats, batch_pos.t())
            clf_loss = F.cross_entropy(logits, labels[batch])
            loss = alpha * contrast_loss + (1 - alpha) * clf_loss
            losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            torch.cuda.empty_cache()
        print('Epoch {:d} | Loss {:.4f}'.format(epoch, sum(losses) / len(losses)))
        # Checkpoint every epoch so training can resume after interruption.
        torch.save(model.state_dict(), args.save_path)
        if epoch % args.eval_every == 0 or epoch == args.epochs - 1:
            print(METRICS_STR.format(*evaluate(
                model, g, mgs, args.neighbor_size, args.batch_size, device,
                labels, train_idx, val_idx, test_idx, evaluator
            )))
    torch.save(model.state_dict(), args.save_path)
    print('模型已保存到', args.save_path)
# Example 12
# 0
def train(args):
    """Train R-HGNN author embeddings with a field-based triplet ranking loss.

    For each field, samples (field, true author, false author) triplets,
    computes author embeddings via neighbor-sampled mini-batches, and
    minimizes a triplet margin loss against the field's pretrained feature.
    The model and the inferred author embeddings are saved at the end.
    """
    set_random_seed(args.seed)
    device = get_device(args.device)
    g, author_rank, field_ids, true_relevance = load_rank_data(device)
    # Embed authors into the same space as the pretrained field features.
    out_dim = g.nodes['field'].data['feat'].shape[1]
    add_node_feat(g, 'pretrained', args.node_embed_path, use_raw_id=True)
    field_paper = recall_paper(g.cpu(), field_ids,
                               args.num_recall)  # {field_id: [paper_id]}

    sampler = MultiLayerNeighborSampler([args.neighbor_size] * args.num_layers)
    sampler.set_output_context(to_dgl_context(device))
    triplet_collator = TripletNodeCollator(g, sampler)

    model = RHGNN(
        {ntype: g.nodes[ntype].data['feat'].shape[1]
         for ntype in g.ntypes}, args.num_hidden, out_dim, args.num_rel_hidden,
        args.num_rel_hidden, args.num_heads, g.ntypes, g.canonical_etypes,
        'author', args.num_layers, args.dropout).to(device)
    if args.load_path:
        model.load_state_dict(torch.load(args.load_path, map_location=device))
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # One scheduler step per field per epoch.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     T_max=len(field_ids) *
                                                     args.epochs,
                                                     eta_min=args.lr / 100)
    warnings.filterwarnings(
        'ignore', 'Setting attributes on ParameterDict is not supported')
    for epoch in range(args.epochs):
        model.train()
        losses = []
        for f in tqdm(field_ids):
            # Negative candidates: authors who wrote recalled papers of this
            # field but do not appear in its ground-truth ranking.
            false_author_ids = list(
                set(g.in_edges(field_paper[f], etype='writes')[0].tolist()) -
                set(author_rank[f]))
            triplets = sample_triplets(f, author_rank[f], false_author_ids,
                                       args.num_triplets).to(device)
            aid, blocks = triplet_collator.collate(triplets)
            author_embeds = model(blocks, blocks[0].srcdata['feat'])
            # L2-normalize so the margin operates on unit-norm embeddings.
            author_embeds = author_embeds / author_embeds.norm(dim=1,
                                                               keepdim=True)
            # Map global author ids to row positions in author_embeds.
            aid_map = {a: i for i, a in enumerate(aid.tolist())}
            anchor = g.nodes['field'].data['feat'][triplets[:, 0]]
            positive = author_embeds[[
                aid_map[a] for a in triplets[:, 1].tolist()
            ]]
            negative = author_embeds[[
                aid_map[a] for a in triplets[:, 2].tolist()
            ]]
            loss = F.triplet_margin_loss(anchor, positive, negative,
                                         args.margin)

            losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            torch.cuda.empty_cache()
        print('Epoch {:d} | Loss {:.4f}'.format(epoch,
                                                sum(losses) / len(losses)))
        # Checkpoint every epoch so training can resume after interruption.
        torch.save(model.state_dict(), args.model_save_path)
        if epoch % args.eval_every == 0 or epoch == args.epochs - 1:
            print(
                METRICS_STR.format(*evaluate(
                    model, g, out_dim, sampler, args.batch_size, device,
                    field_ids, field_paper, author_rank, true_relevance)))
    torch.save(model.state_dict(), args.model_save_path)
    print('模型已保存到', args.model_save_path)

    embeds = infer(model, g, 'author', out_dim, sampler, args.batch_size,
                   device)
    author_embed_save_path = DATA_DIR / 'rank/author_embed.pkl'
    torch.save(embeds.cpu(), author_embed_save_path)
    print('学者嵌入已保存到', author_embed_save_path)
# Example 13
# 0
def train(args):
    """Mini-batch training of HeCo over network-schema and meta-path views.

    Optimizes only the contrastive objective; classification quality is
    measured by the periodic evaluation.
    """
    set_random_seed(args.seed)
    device = get_device(args.device)
    data, g, _, labels, predict_ntype, train_idx, val_idx, test_idx, evaluator = \
        load_data(args.dataset, device)
    add_node_feat(g, 'pretrained', args.node_embed_path)
    features = g.nodes[predict_ntype].data['feat']
    # Relations pointing INTO the target node type (network-schema view).
    relations = [r for r in g.canonical_etypes if r[2] == predict_ntype]

    # Meta-path graphs plus the overall positive-sample graph.
    (*mgs, pos_g), _ = dgl.load_graphs(args.pos_graph_path)
    mgs = [mg.to(device) for mg in mgs]
    pos_g = pos_g.to(device)
    pos = pos_g.in_edges(pos_g.nodes())[0].view(pos_g.num_nodes(),
                                                -1)  # (N, T_pos) positive-sample ids of each target node

    id_loader = DataLoader(train_idx, batch_size=args.batch_size)
    # Sampler that also pulls each node's positive samples into the batch.
    sampler = PositiveSampler([None], pos)
    loader = NodeDataLoader(g, {predict_ntype: train_idx},
                            sampler,
                            device=device,
                            batch_size=args.batch_size)
    mg_loaders = [
        NodeDataLoader(mg,
                       train_idx,
                       sampler,
                       device=device,
                       batch_size=args.batch_size) for mg in mgs
    ]
    pos_loader = NodeDataLoader(pos_g,
                                train_idx,
                                sampler,
                                device=device,
                                batch_size=args.batch_size)

    model = HeCo(
        {ntype: g.nodes[ntype].data['feat'].shape[1]
         for ntype in g.ntypes}, args.num_hidden, args.feat_drop,
        args.attn_drop, relations, args.tau, args.lambda_).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        model.train()
        losses = []
        # All loaders iterate over train_idx with the same batch size, so the
        # zipped batches refer to the same target nodes.
        for (batch, (_, _, blocks), *mg_blocks, (_, _, pos_blocks)) in tqdm(
                zip(id_loader, loader, *mg_loaders, pos_loader)):
            block = blocks[0]
            mg_feats = [features[i] for i, _, _ in mg_blocks]
            mg_blocks = [b[0] for _, _, b in mg_blocks]
            pos_block = pos_blocks[0]
            batch_pos = torch.zeros(pos_block.num_dst_nodes(),
                                    batch.shape[0],
                                    dtype=torch.int,
                                    device=device)
            # Set 1 at each (src, dst) in-edge pair of the first
            # batch_size dst nodes, building the per-batch positive mask.
            batch_pos[pos_block.in_edges(
                torch.arange(batch.shape[0], device=device))] = 1
            loss, _ = model(block, block.srcdata['feat'], mg_blocks, mg_feats,
                            batch_pos.t())
            losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            torch.cuda.empty_cache()
        print('Epoch {:d} | Loss {:.4f}'.format(epoch,
                                                sum(losses) / len(losses)))
        if epoch % args.eval_every == 0 or epoch == args.epochs - 1:
            print(
                METRICS_STR.format(*evaluate(
                    model, mgs, features, device, labels, data.num_classes,
                    train_idx, val_idx, test_idx, evaluator)))