Example #1
def train(args):
    set_random_seed(args.seed)
    device = get_device(args.device)
    data, g, _, labels, predict_ntype, train_idx, val_idx, test_idx, evaluator = \
        load_data(args.dataset, device)
    add_node_feat(g, 'pretrained', args.node_embed_path, True)

    sampler = MultiLayerNeighborSampler(
        list(range(args.neighbor_size, args.neighbor_size + args.num_layers)))
    train_loader = NodeDataLoader(g, {predict_ntype: train_idx},
                                  sampler,
                                  device=device,
                                  batch_size=args.batch_size)
    loader = NodeDataLoader(g, {predict_ntype: g.nodes(predict_ntype)},
                            sampler,
                            device=device,
                            batch_size=args.batch_size)

    model = RHGNN(
        {ntype: g.nodes[ntype].data['feat'].shape[1]
         for ntype in g.ntypes}, args.num_hidden, data.num_classes,
        args.num_rel_hidden, args.num_rel_hidden, args.num_heads, g.ntypes,
        g.canonical_etypes, predict_ntype, args.num_layers,
        args.dropout).to(device)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     T_max=len(train_loader) *
                                                     args.epochs,
                                                     eta_min=args.lr / 100)
    warnings.filterwarnings(
        'ignore', 'Setting attributes on ParameterDict is not supported')
    for epoch in range(args.epochs):
        model.train()
        losses = []
        for input_nodes, output_nodes, blocks in tqdm(train_loader):
            batch_logits = model(blocks, blocks[0].srcdata['feat'])
            batch_labels = labels[output_nodes[predict_ntype]]
            loss = F.cross_entropy(batch_logits, batch_labels)
            losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            torch.cuda.empty_cache()
        print('Epoch {:d} | Loss {:.4f}'.format(epoch,
                                                sum(losses) / len(losses)))
        if epoch % args.eval_every == 0 or epoch == args.epochs - 1:
            print(
                METRICS_STR.format(*evaluate(
                    model, loader, g, labels, data.num_classes, predict_ntype,
                    train_idx, val_idx, test_idx, evaluator)))
    if args.save_path:
        torch.save(model.cpu().state_dict(), args.save_path)
        print('Model saved to', args.save_path)
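For context, train(args) reads a flat namespace of hyperparameters. A hypothetical argparse wiring consistent with the attributes accessed above might look like the sketch below; every flag name and default here is an assumption, not taken from the original repository.

import argparse

def build_args():
    # Hypothetical CLI for train(args) in Example #1; flag names mirror the
    # attributes read above (args.seed, args.lr, ...); defaults are illustrative.
    parser = argparse.ArgumentParser(description='RHGNN training (sketch)')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--dataset', required=True)
    parser.add_argument('--node-embed-path', required=True)
    parser.add_argument('--neighbor-size', type=int, default=10)
    parser.add_argument('--num-layers', type=int, default=2)
    parser.add_argument('--batch-size', type=int, default=1024)
    parser.add_argument('--num-hidden', type=int, default=64)
    parser.add_argument('--num-rel-hidden', type=int, default=8)
    parser.add_argument('--num-heads', type=int, default=8)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--weight-decay', type=float, default=0.0)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--eval-every', type=int, default=10)
    parser.add_argument('--save-path', default='')
    return parser.parse_args()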
Example #2
def train(args):
    set_random_seed(args.seed)
    device = get_device(args.device)
    data, g, _, labels, predict_ntype, train_idx, val_idx, test_idx, evaluator = \
        load_data(args.dataset, device)
    add_node_feat(g, args.node_feat, args.node_embed_path)

    sampler = MultiLayerNeighborSampler([args.neighbor_size] * args.num_layers)
    train_loader = NodeDataLoader(g, {predict_ntype: train_idx},
                                  sampler,
                                  device=device,
                                  batch_size=args.batch_size)
    loader = NodeDataLoader(g, {predict_ntype: g.nodes(predict_ntype)},
                            sampler,
                            device=device,
                            batch_size=args.batch_size)

    model = HGT(
        {ntype: g.nodes[ntype].data['feat'].shape[1]
         for ntype in g.ntypes}, args.num_hidden, data.num_classes,
        args.num_heads, g.ntypes, g.canonical_etypes, predict_ntype,
        args.num_layers, args.dropout).to(device)
    optimizer = optim.AdamW(model.parameters(), eps=1e-6)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        args.max_lr,
        epochs=args.epochs,
        steps_per_epoch=len(train_loader),
        pct_start=0.05,
        anneal_strategy='linear',
        final_div_factor=10.0)
    warnings.filterwarnings(
        'ignore', 'Setting attributes on ParameterDict is not supported')
    for epoch in range(args.epochs):
        model.train()
        losses = []
        for input_nodes, output_nodes, blocks in tqdm(train_loader):
            batch_logits = model(blocks, blocks[0].srcdata['feat'])
            batch_labels = labels[output_nodes[predict_ntype]]
            loss = F.cross_entropy(batch_logits, batch_labels)
            losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            torch.cuda.empty_cache()
        print('Epoch {:d} | Loss {:.4f}'.format(epoch,
                                                sum(losses) / len(losses)))
        if epoch % args.eval_every == 0 or epoch == args.epochs - 1:
            print(
                METRICS_STR.format(*evaluate(
                    model, loader, g, labels, data.num_classes, predict_ntype,
                    train_idx, val_idx, test_idx, evaluator)))
    if args.save_path:
        torch.save(model.cpu().state_dict(), args.save_path)
        print('Model saved to', args.save_path)
Example #3
def train(args):
    set_random_seed(args.seed)
    device = get_device(args.device)
    g, labels, num_classes, train_idx, val_idx, test_idx, evaluator = \
        load_data(args.ogb_path, device)
    load_pretrained_node_embed(g, args.node_embed_path)
    g = g.to(device)

    sampler = MultiLayerNeighborSampler(
        list(range(args.neighbor_size, args.neighbor_size + args.num_layers))
    )
    train_loader = NodeDataLoader(g, {'paper': train_idx}, sampler, device=device, batch_size=args.batch_size)
    val_loader = NodeDataLoader(g, {'paper': val_idx}, sampler, device=device, batch_size=args.batch_size)
    test_loader = NodeDataLoader(g, {'paper': test_idx}, sampler, device=device, batch_size=args.batch_size)

    model = RHGNN(
        {ntype: g.nodes[ntype].data['feat'].shape[1] for ntype in g.ntypes},
        args.num_hidden, num_classes, args.num_rel_hidden, args.num_rel_hidden, args.num_heads,
        g.ntypes, g.canonical_etypes, 'paper', args.num_layers, args.dropout, residual=args.residual
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=len(train_loader) * args.epochs, eta_min=args.lr / 100
    )
    warnings.filterwarnings('ignore', 'Setting attributes on ParameterDict is not supported')
    for epoch in range(args.epochs):
        model.train()
        logits, train_labels, losses = [], [], []
        for input_nodes, output_nodes, blocks in tqdm(train_loader):
            batch_labels = labels[output_nodes['paper']]
            batch_logits = model(blocks, blocks[0].srcdata['feat'])
            loss = F.cross_entropy(batch_logits, batch_labels.squeeze(dim=1))

            logits.append(batch_logits.detach().cpu())
            train_labels.append(batch_labels.detach().cpu())
            losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            torch.cuda.empty_cache()

        train_acc = accuracy(torch.cat(logits, dim=0), torch.cat(train_labels, dim=0), evaluator)
        val_acc = evaluate(val_loader, device, model, labels, evaluator)
        test_acc = evaluate(test_loader, device, model, labels, evaluator)
        print('Epoch {:d} | Train Loss {:.4f} | Train Acc {:.4f} | Val Acc {:.4f} | Test Acc {:.4f}'.format(
            epoch, torch.tensor(losses).mean().item(), train_acc, val_acc, test_acc
        ))
    # embed = model.inference(g, g.ndata['feat'], device, args.batch_size)
    # test_acc = accuracy(embed[test_idx], labels[test_idx], evaluator)
    test_acc = evaluate(test_loader, device, model, labels, evaluator)
    print('Test Acc {:.4f}'.format(test_acc))
Example #4
    def inference(self, g, feats, device, batch_size):
        """离线推断所有顶点的最终嵌入(不使用邻居采样)

        :param g: DGLGraph 异构图
        :param feats: Dict[str, tensor(N_i, d_in_i)] 顶点类型到输入顶点特征的映射
        :param device: torch.device
        :param batch_size: int 批大小
        :return: tensor(N_i, d_out) 待预测顶点的最终嵌入
        """
        g.ndata['emb'] = {ntype: self.fc_in[ntype](feat) for ntype, feat in feats.items()}
        for layer in self.layers:
            embeds = {
                ntype: torch.zeros(g.num_nodes(ntype), self.d, device=device)
                for ntype in g.ntypes
            }
            sampler = MultiLayerFullNeighborSampler(1)
            loader = NodeDataLoader(
                g, {ntype: g.nodes(ntype) for ntype in g.ntypes}, sampler, device=device,
                batch_size=batch_size, shuffle=True
            )
            for input_nodes, output_nodes, blocks in tqdm(loader):
                block = blocks[0]
                h = layer(block, block.srcdata['emb'])
                for ntype in h:
                    embeds[ntype][output_nodes[ntype]] = h[ntype]
            g.ndata['emb'] = embeds
        return self.classifier(g.nodes[self.predict_ntype].data['emb'])
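A possible way to call the offline inference method above once training is done; this is a sketch, with model, g, labels, test_idx and device assumed to exist as in the training examples.

# Sketch: full-neighbor, layer-by-layer inference, so the result has no sampling noise.
model.eval()
with torch.no_grad():
    feats = {ntype: g.nodes[ntype].data['feat'] for ntype in g.ntypes}
    logits = model.inference(g, feats, device, batch_size=1024)
    pred = logits[test_idx].argmax(dim=1)
    print('Test Acc {:.4f}'.format((pred == labels[test_idx]).float().mean().item()))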
Example #5
    def inference(self, mode='validation', verbose=False):
        assert mode in ['validation', 'testing'], "got mode {}".format(mode)
        from dgl.dataloading import NodeDataLoader, MultiLayerNeighborSampler
        self.eval()
        if mode == 'testing':
            sampler = MultiLayerNeighborSampler([None])
        else:
            sampler = MultiLayerNeighborSampler(self.fans)
        g = self.cpu_graph
        kwargs = {
            'batch_size': 64,
            'shuffle': True,
            'drop_last': False,
            'num_workers': 6,
        }
        dataloader = NodeDataLoader(g, th.arange(g.number_of_nodes()), sampler,
                                    **kwargs)
        if verbose:
            dataloader = tqdm(dataloader)

        x = self.embedding.weight
        x = th.cat((self.W1(x[:self.num_users]), self.W2(x[self.num_users:])),
                   dim=0)

        # Within a layer, iterate over nodes in batches
        for input_nodes, output_nodes, blocks in dataloader:
            block = blocks[0].to(commons.device)
            h = self.forward_block(block, x[input_nodes])
            self.check_point[output_nodes] = h

        if verbose:
            print('Inference Done Successfully')
Example #6
    def inference(self, mode='validation', verbose=False):
        assert mode in ['validation', 'testing'], "got mode {}".format(mode)
        from dgl.dataloading import NodeDataLoader, MultiLayerNeighborSampler
        self.eval()
        if mode == 'testing':
            sampler = MultiLayerNeighborSampler([None] * self.num_layers)
        else:
            sampler = MultiLayerNeighborSampler(self.fans)

        g = self.cpu_graph
        kwargs = {
            'batch_size': 1024,
            'shuffle': True,
            'drop_last': False,
            'num_workers': commons.workers,
        }

        dataloader = NodeDataLoader(g, th.arange(g.number_of_nodes()), sampler,
                                    **kwargs)
        # Within a layer, iterate over nodes in batches
        if verbose:
            dataloader = tqdm(dataloader)
        for input_nodes, output_nodes, blocks in dataloader:
            blocks = [x.to(commons.device) for x in blocks]
            users = th.arange(output_nodes.shape[0]).long().to(self.device)
            d1 = th.zeros((0, )).long().to(self.device)
            d2 = th.zeros((0, )).long().to(self.device)
            h = self.forward_blocks(blocks, users, d1, d2)[0]
            self.check_point[output_nodes] = h
        if verbose:
            print('Inference Done Successfully')
Example #7
def calc_attn_pos(g, num_classes, predict_ntype, num_samples, device, args):
    """使用预训练的HGT模型计算的注意力权重选择目标顶点的正样本。"""
    # 第1层只保留AB边,第2层只保留BA边,其中A是目标顶点类型,B是中间顶点类型
    num_neighbors = [{}, {}]
    # 形如ABA的元路径,其中A是目标顶点类型
    metapaths = []
    rev_etype = {
        e: next(re for rs, re, rd in g.canonical_etypes
                if rs == d and rd == s and re != e)
        for s, e, d in g.canonical_etypes
    }
    for s, e, d in g.canonical_etypes:
        if d == predict_ntype:
            re = rev_etype[e]
            num_neighbors[0][re] = num_neighbors[1][e] = 10
            metapaths.append((re, e))
    for i in range(len(num_neighbors)):
        d = dict.fromkeys(g.etypes, 0)
        d.update(num_neighbors[i])
        num_neighbors[i] = d
    sampler = MultiLayerNeighborSampler(num_neighbors)
    loader = NodeDataLoader(g, {predict_ntype: g.nodes(predict_ntype)},
                            sampler,
                            device=device,
                            batch_size=args.batch_size)

    model = HGT(
        {ntype: g.nodes[ntype].data['feat'].shape[1]
         for ntype in g.ntypes}, args.num_hidden, num_classes, args.num_heads,
        g.ntypes, g.canonical_etypes, predict_ntype, 2,
        args.dropout).to(device)
    model.load_state_dict(torch.load(args.hgt_model_path, map_location=device))

    # One positive-sample graph G_ABA per metapath A-B-A, plus an overall positive-sample graph G_pos
    pos = [
        torch.zeros(g.num_nodes(predict_ntype),
                    num_samples,
                    dtype=torch.long,
                    device=device) for _ in range(len(metapaths) + 1)
    ]
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in tqdm(loader):
            _ = model(blocks, blocks[0].srcdata['feat'])
            # List[tensor(N_src, N_dst)]
            attn = [
                calc_attn(mp, model, blocks, device).t() for mp in metapaths
            ]
            for i in range(len(attn)):
                _, nid = torch.topk(attn[i], num_samples)  # (N_dst, T_pos)
                # nid holds source node ids in blocks[0]; map them to node ids in the original heterogeneous graph
                pos[i][output_nodes[predict_ntype]] = input_nodes[
                    predict_ntype][nid]
            _, nid = torch.topk(sum(attn), num_samples)
            pos[-1][
                output_nodes[predict_ntype]] = input_nodes[predict_ntype][nid]
    return [p.cpu() for p in pos]
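Example #13 below loads positive-sample graphs with dgl.load_graphs(args.pos_graph_path). A hedged sketch of glue that could connect the two: each (N, T_pos) id tensor becomes a graph whose in-edges of node v are v's positive samples, which matches the destination-ordered layout Example #13 recovers via in_edges(pos_g.nodes()). The construction below is an assumption, not code from the source.

# Sketch: build and save positive-sample graphs from calc_attn_pos() output.
pos = calc_attn_pos(g, num_classes, predict_ntype, num_samples, device, args)
pos_graphs = [
    dgl.graph((
        p.flatten(),  # sources: the T_pos positive samples of each node, row-major
        torch.arange(p.shape[0]).repeat_interleave(p.shape[1])  # destinations
    ))
    for p in pos
]
dgl.save_graphs(args.pos_graph_path, pos_graphs)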
Example #8
def infer(model, g, ntype, out_dim, sampler, batch_size, device):
    model.eval()
    embeds = torch.zeros((g.num_nodes(ntype), out_dim), device=device)
    loader = NodeDataLoader(g, {ntype: g.nodes(ntype)},
                            sampler,
                            device=device,
                            batch_size=batch_size)
    for _, output_nodes, blocks in tqdm(loader):
        embeds[output_nodes[ntype]] = model(blocks, blocks[0].srcdata['feat'])
    embeds = embeds / embeds.norm(dim=1, keepdim=True)
    return embeds
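Since infer() L2-normalizes each row, a plain matrix product between rows yields cosine similarities. A small sketch of top-k retrieval on the returned embeddings; query_idx is a hypothetical index tensor, not from the source.

# Sketch: rows of `embeds` are unit vectors, so embeds @ embeds.t() is cosine similarity.
embeds = infer(model, g, ntype, out_dim, sampler, batch_size, device)
sim = embeds[query_idx] @ embeds.t()      # (Q, N) cosine similarities
topk_sim, topk_nid = sim.topk(10, dim=1)  # 10 most similar nodes per query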
Example #9
def init_dataloaders(args,
                     g,
                     train_idx,
                     test_idx,
                     target_idx,
                     device,
                     use_ddp=False):
    fanouts = [int(fanout) for fanout in args.fanout.split(',')]
    sampler = MultiLayerNeighborSampler(fanouts)

    train_loader = NodeDataLoader(g,
                                  target_idx[train_idx],
                                  sampler,
                                  use_ddp=use_ddp,
                                  device=device,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  drop_last=False)

    # The datasets do not have a validation subset, use the train subset
    val_loader = NodeDataLoader(g,
                                target_idx[train_idx],
                                sampler,
                                use_ddp=use_ddp,
                                device=device,
                                batch_size=args.batch_size,
                                shuffle=False,
                                drop_last=False)

    # -1 for sampling all neighbors
    test_sampler = MultiLayerNeighborSampler([-1] * len(fanouts))
    test_loader = NodeDataLoader(g,
                                 target_idx[test_idx],
                                 test_sampler,
                                 use_ddp=use_ddp,
                                 device=device,
                                 batch_size=32,
                                 shuffle=False,
                                 drop_last=False)

    return train_loader, val_loader, test_loader
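One detail when use_ddp=True: each rank iterates only its own shard of seed nodes, and DGL's data loader expects set_epoch() to be called once per epoch so the shuffle differs across epochs. A sketch following DGL's documented DDP pattern; num_epochs is a placeholder.

# Sketch: per-epoch reshuffling of the per-rank shards under DDP.
for epoch in range(num_epochs):
    train_loader.set_epoch(epoch)
    for input_nodes, output_nodes, blocks in train_loader:
        ...  # forward/backward as in the training loops above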
Example #10
    def inference(self, g, feats, device, batch_size):
        """离线推断所有顶点的最终嵌入(不使用邻居采样)

        :param g: DGLGraph 异构图
        :param feats: Dict[str, tensor(N_i, d_in_i)] 顶点类型到输入顶点特征的映射
        :param device: torch.device
        :param batch_size: int 批大小
        :return: tensor(N_i, d_out) 待预测顶点的最终嵌入
        """
        feats = {
            (stype, etype, dtype): self.fc_in[dtype](feats[dtype].to(device))
            for stype, etype, dtype in self.etypes
        }
        rel_feats = {rel: emb.flatten() for rel, emb in self.rel_embed.items()}
        for layer in self.layers:
            # TODO: memory usage here is too high
            embeds = {
                (stype, etype, dtype): torch.zeros(g.num_nodes(dtype), self._d)
                for stype, etype, dtype in g.canonical_etypes
            }
            sampler = MultiLayerFullNeighborSampler(1)
            loader = NodeDataLoader(
                g, {ntype: torch.arange(g.num_nodes(ntype)) for ntype in g.ntypes}, sampler,
                batch_size=batch_size, shuffle=True
            )
            for input_nodes, output_nodes, blocks in tqdm(loader):
                block = blocks[0].to(device)
                in_feats = {
                    (s, e, d): feats[(s, e, d)][input_nodes[d]].to(device)
                    for s, e, d in feats
                }
                h, rel_embeds = layer(block, in_feats, rel_feats)
                for s, e, d in h:
                    embeds[(s, e, d)][output_nodes[d]] = h[(s, e, d)].cpu()
            feats = embeds
            rel_feats = rel_embeds
        feats = {r: feat.to(device) for r, feat in feats.items()}

        out_feats = {ntype: torch.zeros(g.num_nodes(ntype), self._d) for ntype in g.ntypes}
        for ntype in set(d for _, _, d in feats):
            dst_feats = {e: feats[(s, e, d)] for s, e, d in feats if d == ntype}
            dst_rel_feats = {e: rel_feats[e] for s, e, d in feats if d == ntype}
            for batch in DataLoader(torch.arange(g.num_nodes(ntype)), batch_size=batch_size):
                out_feats[ntype][batch] = self.rel_fusing[ntype](
                    {e: dst_feats[e][batch] for e in dst_rel_feats}, dst_rel_feats
                )
        return self.classifier(out_feats[self.predict_ntype])
Example #11
def main():
    data = UserItemDataset()
    g = data[0]
    train_idx = g.nodes['user'].data['train_mask'].nonzero(as_tuple=True)[0]

    sampler = MultiLayerFullNeighborSampler(2)
    dataloader = NodeDataLoader(g, {'user': train_idx}, sampler, batch_size=256)

    model = RGCN(10, 20, 5, g.etypes)
    opt = optim.Adam(model.parameters())

    for epoch in range(10):
        model.train()
        losses = []
        for input_nodes, output_nodes, blocks in dataloader:
            logits = model(blocks, blocks[0].srcdata['feat'])['user']
            loss = F.cross_entropy(logits, blocks[-1].dstnodes['user'].data['label'])
            losses.append(loss.item())
            opt.zero_grad()
            loss.backward()
            opt.step()
        print('Epoch {:d} | Loss {:.4f}'.format(epoch + 1, torch.tensor(losses).mean().item()))
Example #12
def main():
    data = CiteseerGraphDataset()
    g = data[0]
    train_idx = g.ndata['train_mask'].nonzero(as_tuple=True)[0]

    sampler = MultiLayerFullNeighborSampler(2)
    dataloader = NodeDataLoader(g, train_idx, sampler, batch_size=32)

    model = GCN(g.ndata['feat'].shape[1], 100, data.num_classes)
    optimizer = optim.Adam(model.parameters())
    for epoch in range(30):
        model.train()
        losses = []
        for input_nodes, output_nodes, blocks in dataloader:
            logits = model(blocks, blocks[0].srcdata['feat'])
            loss = F.cross_entropy(logits, blocks[-1].dstdata['label'])
            losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Epoch {:d} | Loss {:.4f}'.format(
            epoch + 1,
            torch.tensor(losses).mean().item()))
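Example #12 uses a GCN class it does not define. Below is a minimal sketch consistent with MultiLayerFullNeighborSampler(2) above, assuming a standard two-layer dgl.nn.GraphConv model that consumes one block per layer.

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv

class GCN(nn.Module):
    """Two-layer GCN over DGL message-flow blocks (one block per layer)."""

    def __init__(self, in_dim, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, num_classes)

    def forward(self, blocks, x):
        h = F.relu(self.conv1(blocks[0], x))
        return self.conv2(blocks[1], h)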
Example #13
def train(args):
    set_random_seed(args.seed)
    device = get_device(args.device)
    data, g, _, labels, predict_ntype, train_idx, val_idx, test_idx, evaluator = \
        load_data(args.dataset, device)
    add_node_feat(g, 'pretrained', args.node_embed_path, True)
    features = g.nodes[predict_ntype].data['feat']

    (*mgs, pos_g), _ = dgl.load_graphs(args.pos_graph_path)
    mgs = [mg.to(device) for mg in mgs]
    pos_g = pos_g.to(device)
    pos = pos_g.in_edges(pos_g.nodes())[0].view(pos_g.num_nodes(), -1)  # (N, T_pos) positive-sample ids for each target node
    # Cannot use pos_g.edges() here: the source ids must be grouped by destination id,
    # which is exactly the ordering in_edges(pos_g.nodes()) returns

    id_loader = DataLoader(train_idx, batch_size=args.batch_size)
    loader = NodeDataLoader(
        g, {predict_ntype: train_idx}, PositiveSampler([args.neighbor_size] * args.num_layers, pos),
        device=device, batch_size=args.batch_size
    )
    sampler = PositiveSampler([None], pos)
    mg_loaders = [
        NodeDataLoader(mg, train_idx, sampler, device=device, batch_size=args.batch_size)
        for mg in mgs
    ]
    pos_loader = NodeDataLoader(pos_g, train_idx, sampler, device=device, batch_size=args.batch_size)

    model_class = get_model_class(args.model)
    model = model_class(
        {ntype: g.nodes[ntype].data['feat'].shape[1] for ntype in g.ntypes},
        args.num_hidden, data.num_classes, args.num_rel_hidden, args.num_heads,
        g.ntypes, g.canonical_etypes, predict_ntype, args.num_layers, args.dropout,
        len(mgs), args.tau, args.lambda_
    ).to(device)
    if args.load_path:
        model.load_state_dict(torch.load(args.load_path, map_location=device))
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=len(loader) * args.epochs, eta_min=args.lr / 100
    )
    alpha = args.contrast_weight
    for epoch in range(args.epochs):
        model.train()
        losses = []
        for (batch, (_, _, blocks), *mg_blocks, (_, _, pos_blocks)) in tqdm(zip(id_loader, loader, *mg_loaders, pos_loader)):
            mg_feats = [features[i] for i, _, _ in mg_blocks]
            mg_blocks = [b[0] for _, _, b in mg_blocks]
            pos_block = pos_blocks[0]
            # pos_block.num_dst_nodes() = batch_size + number of positive samples
            batch_pos = torch.zeros(pos_block.num_dst_nodes(), batch.shape[0], dtype=torch.int, device=device)
            batch_pos[pos_block.in_edges(torch.arange(batch.shape[0], device=device))] = 1
            contrast_loss, logits = model(blocks, blocks[0].srcdata['feat'], mg_blocks, mg_feats, batch_pos.t())
            clf_loss = F.cross_entropy(logits, labels[batch])
            loss = alpha * contrast_loss + (1 - alpha) * clf_loss
            losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            torch.cuda.empty_cache()
        print('Epoch {:d} | Loss {:.4f}'.format(epoch, sum(losses) / len(losses)))
        torch.save(model.state_dict(), args.save_path)
        if epoch % args.eval_every == 0 or epoch == args.epochs - 1:
            print(METRICS_STR.format(*evaluate(
                model, g, mgs, args.neighbor_size, args.batch_size, device,
                labels, train_idx, val_idx, test_idx, evaluator
            )))
    torch.save(model.state_dict(), args.save_path)
    print('Model saved to', args.save_path)
Example #14
def train(args):
    set_random_seed(args.seed)
    device = get_device(args.device)
    data, g, _, labels, predict_ntype, train_idx, val_idx, test_idx, evaluator = \
        load_data(args.dataset, device)
    add_node_feat(g, 'pretrained', args.node_embed_path)
    features = g.nodes[predict_ntype].data['feat']
    relations = [r for r in g.canonical_etypes if r[2] == predict_ntype]

    (*mgs, pos_g), _ = dgl.load_graphs(args.pos_graph_path)
    mgs = [mg.to(device) for mg in mgs]
    pos_g = pos_g.to(device)
    pos = pos_g.in_edges(pos_g.nodes())[0].view(pos_g.num_nodes(), -1)  # (N, T_pos) positive-sample ids for each target node

    id_loader = DataLoader(train_idx, batch_size=args.batch_size)
    sampler = PositiveSampler([None], pos)
    loader = NodeDataLoader(g, {predict_ntype: train_idx},
                            sampler,
                            device=device,
                            batch_size=args.batch_size)
    mg_loaders = [
        NodeDataLoader(mg,
                       train_idx,
                       sampler,
                       device=device,
                       batch_size=args.batch_size) for mg in mgs
    ]
    pos_loader = NodeDataLoader(pos_g,
                                train_idx,
                                sampler,
                                device=device,
                                batch_size=args.batch_size)

    model = HeCo(
        {ntype: g.nodes[ntype].data['feat'].shape[1]
         for ntype in g.ntypes}, args.num_hidden, args.feat_drop,
        args.attn_drop, relations, args.tau, args.lambda_).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        model.train()
        losses = []
        for (batch, (_, _, blocks), *mg_blocks, (_, _, pos_blocks)) in tqdm(
                zip(id_loader, loader, *mg_loaders, pos_loader)):
            block = blocks[0]
            mg_feats = [features[i] for i, _, _ in mg_blocks]
            mg_blocks = [b[0] for _, _, b in mg_blocks]
            pos_block = pos_blocks[0]
            batch_pos = torch.zeros(pos_block.num_dst_nodes(),
                                    batch.shape[0],
                                    dtype=torch.int,
                                    device=device)
            batch_pos[pos_block.in_edges(
                torch.arange(batch.shape[0], device=device))] = 1
            loss, _ = model(block, block.srcdata['feat'], mg_blocks, mg_feats,
                            batch_pos.t())
            losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            torch.cuda.empty_cache()
        print('Epoch {:d} | Loss {:.4f}'.format(epoch,
                                                sum(losses) / len(losses)))
        if epoch % args.eval_every == 0 or epoch == args.epochs - 1:
            print(
                METRICS_STR.format(*evaluate(
                    model, mgs, features, device, labels, data.num_classes,
                    train_idx, val_idx, test_idx, evaluator)))