Example #1
0
def main():
    """Evaluate a saved RHCO checkpoint and refine its predictions by smoothing.

    Loads the dataset and the pretrained node features, and restores the
    model weights from ``args.model_path``. It then computes base predictions
    with ``model.get_embeds`` and smooths them over ``pos_g``, treating the
    train+val labels as known. Finally it prints test accuracy and macro-F1.
    """
    args = parse_args()
    print(args)
    device = get_device(args.device)
    (data, g, _, labels, predict_ntype,
     train_idx, val_idx, test_idx, evaluator) = load_data(args.dataset, device)
    add_node_feat(g, 'pretrained', args.node_embed_path, True)
    if args.dataset == 'oag-venue':
        # This dataset's -1 labels are remapped to class 0 before use.
        labels[labels == -1] = 0

    graphs, _ = dgl.load_graphs(args.pos_graph_path)
    # Every graph except the last goes to the model; the last is the
    # graph used for smoothing.
    *mgs, pos_g = graphs
    pos_g = pos_g.to(device)

    in_dims = {nt: g.nodes[nt].data['feat'].shape[1] for nt in g.ntypes}
    model = RHCO(
        in_dims, args.num_hidden, data.num_classes, args.num_rel_hidden,
        args.num_heads, g.ntypes, g.canonical_etypes, predict_ntype,
        args.num_layers, args.dropout, len(mgs), args.tau, args.lambda_,
    ).to(device)
    state_dict = torch.load(args.model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    base_pred = model.get_embeds(
        g, mgs, args.neighbor_size, args.batch_size, device)
    # Train and validation nodes are the ones whose labels smoothing may use.
    known_idx = torch.cat([train_idx, val_idx])
    logits = smooth(base_pred, pos_g, labels, known_idx, args)
    _, _, test_acc, _, _, test_f1 = calc_metrics(
        logits, labels, train_idx, val_idx, test_idx, evaluator)
    print('After smoothing: Test Acc {:.4f} | Test Macro-F1 {:.4f}'.format(
        test_acc, test_f1))
Example #2
0
def train(args):
    """Train RHCOFull end to end, then refine its predictions by smoothing.

    Each epoch minimizes ``alpha * contrast_loss + (1 - alpha) * clf_loss``,
    where ``alpha = args.contrast_weight``. It uses Adam with a cosine-annealed
    learning rate and prints the loss together with the metrics from
    ``evaluate``. After training, the model's predictions are smoothed over
    ``pos_g`` with the train+val labels as known nodes, and test accuracy and
    macro-F1 are printed.

    Args:
        args: parsed command-line options (seed, device, dataset, graph path
            and the model/optimizer hyper-parameters read below).
    """
    set_random_seed(args.seed)
    device = get_device(args.device)
    data, g, features, labels, predict_ntype, train_idx, val_idx, test_idx, _ = \
        load_data(args.dataset, device)
    add_node_feat(g, 'one-hot')

    # Every loaded graph except the last is fed to the model (mgs); the last
    # one is the graph used for smoothing (pos_g).
    (*mgs, pos_g), _ = dgl.load_graphs(args.pos_graph_path)
    mgs = [mg.to(device) for mg in mgs]
    if args.use_data_pos:
        # Rebuild the smoothing graph from data.pos; edges run from
        # data.pos[1] to data.pos[0].
        pos_v, pos_u = data.pos
        pos_g = dgl.graph((pos_u, pos_v), device=device)
    else:
        # Like mgs, the loaded graph must live on `device`; otherwise smooth()
        # would mix it with device-resident predictions.
        pos_g = pos_g.to(device)

    # Dense indicator over target nodes: pos[i, j] = 1 for every index pair
    # listed in data.pos, 0 elsewhere.
    num_targets = g.num_nodes(predict_ntype)
    pos = torch.zeros((num_targets, num_targets), dtype=torch.int,
                      device=device)
    pos[data.pos] = 1

    model = RHCOFull(
        {ntype: g.nodes[ntype].data['feat'].shape[1]
         for ntype in g.ntypes}, args.num_hidden, data.num_classes,
        args.num_rel_hidden, args.num_heads, g.ntypes,
        g.canonical_etypes, predict_ntype, args.num_layers, args.dropout,
        len(mgs), args.tau, args.lambda_).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     T_max=args.epochs,
                                                     eta_min=args.lr / 100)
    alpha = args.contrast_weight
    warnings.filterwarnings(
        'ignore', 'Setting attributes on ParameterDict is not supported')
    for epoch in range(args.epochs):
        model.train()
        contrast_loss, logits = model(g, g.ndata['feat'], mgs, features, pos)
        clf_loss = F.cross_entropy(logits[train_idx], labels[train_idx])
        # Weighted mix of the contrastive and the classification objectives.
        loss = alpha * contrast_loss + (1 - alpha) * clf_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        torch.cuda.empty_cache()
        print(('Epoch {:d} | Loss {:.4f} | ' + METRICS_STR).format(
            epoch, loss.item(),
            *evaluate(model, g, labels, train_idx, val_idx, test_idx)))

    model.eval()
    # Inference only: do not record an autograd graph for this full pass.
    with torch.no_grad():
        _, base_pred = model(g, g.ndata['feat'], mgs, features, pos)
    mask = torch.cat([train_idx, val_idx])
    logits = smooth(base_pred, pos_g, labels, mask, args)
    _, _, test_acc, _, _, test_f1 = calc_metrics(logits, labels, train_idx,
                                                 val_idx, test_idx)
    print('After smoothing: Test Acc {:.4f} | Test Macro-F1 {:.4f}'.format(
        test_acc, test_f1))
Example #3
0
def correct_and_smooth(base_model, g, feats, labels, train_idx, val_idx,
                       test_idx, evaluator, args):
    """Post-process a base model's predictions with Correct & Smooth.

    Args:
        base_model: trained model; ``base_model(feats)`` yields per-class scores.
        g: graph passed to the CorrectAndSmooth module.
        feats: input features for ``base_model``.
        labels: integer class labels for all nodes.
        train_idx, val_idx, test_idx: node index tensors for each split.
        evaluator: forwarded to ``calc_metrics``.
        args: C&S hyper-parameters (numbers of layers, alphas, norms, scale).

    Prints test accuracy and macro-F1 of the corrected and smoothed predictions.
    """
    print('Training C&S...')
    base_model.eval()
    # C&S expects class probabilities, so softmax must be applied here.
    # This is inference only, so no autograd graph is needed.
    with torch.no_grad():
        base_pred = base_model(feats).softmax(dim=1)

    cs = CorrectAndSmooth(args.num_correct_layers, args.correct_alpha,
                          args.correct_norm, args.num_smooth_layers,
                          args.smooth_alpha, args.smooth_norm, args.scale)
    # Train and validation labels are treated as known by C&S.
    mask = torch.cat([train_idx, val_idx])
    # Take the one-hot width from the predictions so the two always match,
    # even when the highest class id never occurs in `labels`.
    y_onehot = F.one_hot(labels, num_classes=base_pred.shape[1]).float()
    logits = cs(g, y_onehot, base_pred, mask)
    _, _, test_acc, _, _, test_f1 = calc_metrics(logits, labels, train_idx,
                                                 val_idx, test_idx, evaluator)
    print('Test Acc {:.4f} | Test Macro-F1 {:.4f}'.format(test_acc, test_f1))
Example #4
0
def evaluate(model, mgs, feat, device, labels, num_classes, train_idx, val_idx,
             test_idx, evaluator):
    """Evaluate the model's embeddings with a trained linear classifier.

    Trains a linear layer on the frozen embeddings for 200 epochs. The result
    is ``calc_metrics`` on the logits from the epoch with the highest
    validation accuracy.

    Args:
        model: model exposing ``get_embeds(mgs, feats)``.
        mgs: graphs passed to ``get_embeds``; ``feat`` is repeated once per graph.
        feat: input features.
        device: device for the classifier.
        labels: integer class labels for all nodes.
        num_classes: number of output classes of the classifier.
        train_idx, val_idx, test_idx: node index tensors for each split.
        evaluator: forwarded to ``calc_metrics``.

    Returns:
        Whatever ``calc_metrics`` returns for the best-validation logits.
    """
    model.eval()
    # Freeze the embeddings so the classifier's backward() never reaches
    # into the model's graph.
    embeds = model.get_embeds(mgs, [feat] * len(mgs)).detach()

    clf = nn.Linear(embeds.shape[1], num_classes).to(device)
    optimizer = optim.Adam(clf.parameters(), lr=0.05)
    # Start below any achievable accuracy so at least one epoch is recorded.
    best_acc, best_logits = float('-inf'), None
    for _ in trange(200):
        clf.train()
        logits = clf(embeds)
        loss = F.cross_entropy(logits[train_idx], labels[train_idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            clf.eval()
            logits = clf(embeds)
            predict = logits.argmax(dim=1)
            val_acc = accuracy(predict[val_idx], labels[val_idx])
            # Track the running best; without updating best_acc, every epoch
            # would overwrite the result.
            if val_acc > best_acc:
                best_acc, best_logits = val_acc, logits
    return calc_metrics(best_logits, labels, train_idx, val_idx, test_idx,
                        evaluator)
Example #5
0
def evaluate(model, g, features, labels, train_idx, val_idx, test_idx, evaluator):
    """Run a full forward pass in eval mode and score the logits.

    Returns:
        The result of ``calc_metrics`` on ``model(g, features)``.
    """
    model.eval()
    # Evaluation only: skip recording an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(g, features)
    return calc_metrics(logits, labels, train_idx, val_idx, test_idx, evaluator)
Example #6
0
def evaluate(model, g, mgs, neighbor_size, batch_size, device, labels,
             train_idx, val_idx, test_idx, evaluator):
    """Compute embeddings in eval mode and score them with ``calc_metrics``.

    ``neighbor_size``, ``batch_size`` and ``device`` are forwarded unchanged
    to ``model.get_embeds``.
    """
    model.eval()
    # Evaluation only: no autograd graph is needed for the embedding pass.
    with torch.no_grad():
        embeds = model.get_embeds(g, mgs, neighbor_size, batch_size, device)
    return calc_metrics(embeds, labels, train_idx, val_idx, test_idx, evaluator)
Example #7
0
def evaluate(model, g, labels, train_idx, val_idx, test_idx):
    """Compute embeddings in eval mode and score them with ``calc_metrics``.

    Note: this leaves the model in eval mode; callers that keep training must
    switch back with ``model.train()``.
    """
    model.eval()
    # Evaluation only: no autograd graph is needed for the embedding pass.
    with torch.no_grad():
        embeds = model.get_embeds(g)
    return calc_metrics(embeds, labels, train_idx, val_idx, test_idx)