# Example no. 1
# 0
            is_sorted=True)
        adj_t = adj_t.to_symmetric()
        torch.save(adj_t, path)
    adj_t = gcn_norm(adj_t, add_self_loops=False)
    if args.low_memory:
        adj_t = adj_t.to(torch.half)
    print(f'Done! [{time.perf_counter() - t:.2f}s]')

    train_idx = dataset.get_idx_split('train')
    valid_idx = dataset.get_idx_split('valid')
    test_idx = dataset.get_idx_split('test')

    y_train = torch.from_numpy(dataset.paper_label[train_idx]).to(torch.long)
    y_valid = torch.from_numpy(dataset.paper_label[valid_idx]).to(torch.long)

    model = LabelPropagation(args.num_layers, args.alpha)

    N, C = dataset.num_papers, dataset.num_classes

    t = time.perf_counter()
    print('Propagating labels...', end=' ', flush=True)
    if args.low_memory:
        y = torch.zeros(N, C, dtype=torch.half)
        y[train_idx] = F.one_hot(y_train, C).to(torch.half)
        out = model(y, adj_t, post_step=lambda x: x)
        y_pred = out.argmax(dim=-1)
    else:
        y = torch.zeros(N, C)
        y[train_idx] = F.one_hot(y_train, C).to(torch.float)
        out = model(y, adj_t)
        y_pred = out.argmax(dim=-1)
import os.path as osp

from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
import torch_geometric.transforms as T
from torch_geometric.nn import LabelPropagation

# Label-propagation baseline on ogbn-arxiv: propagate the training-split labels
# over the graph, then report accuracy on the validation and test splits.

root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'OGB')
# Make the graph undirected and store its adjacency as a sparse tensor
# (`data.adj_t`), which is what the LabelPropagation call below consumes.
dataset = PygNodePropPredDataset('ogbn-arxiv',
                                 root,
                                 transform=T.Compose([
                                     T.ToUndirected(),
                                     T.ToSparseTensor(),
                                 ]))
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name='ogbn-arxiv')
data = dataset[0]

# `mask` selects the training nodes whose labels seed the propagation
# (see the PyG LabelPropagation docs for the exact semantics).
model = LabelPropagation(num_layers=3, alpha=0.9)
out = model(data.y, data.adj_t, mask=split_idx['train'])

# keepdim=True keeps a trailing axis on the predictions; presumably this
# matches the (num_nodes, 1) label layout the OGB Evaluator expects -- confirm.
y_pred = out.argmax(dim=-1, keepdim=True)


def _split_accuracy(split):
    """Return the Evaluator accuracy of ``y_pred`` on the nodes of *split*."""
    idx = split_idx[split]
    return evaluator.eval({
        'y_true': data.y[idx],
        'y_pred': y_pred[idx],
    })['acc']


val_acc = _split_accuracy('valid')
test_acc = _split_accuracy('test')

print(f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')