Example #1
import os.path as osp

import torch
import torch.nn.functional as F
from tqdm import tqdm
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import SAGEConv

root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'products')
dataset = PygNodePropPredDataset('ogbn-products', root)
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name='ogbn-products')
data = dataset[0]

train_idx = split_idx['train']
train_loader = NeighborSampler(data.edge_index,
                               node_idx=train_idx,
                               sizes=[15, 10, 5],
                               batch_size=1024,
                               shuffle=True,
                               num_workers=12)
subgraph_loader = NeighborSampler(data.edge_index,
                                  node_idx=None,
                                  sizes=[-1],
                                  batch_size=4096,
                                  shuffle=False,
                                  num_workers=12)


class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.num_layers = num_layers
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))
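
    # The snippet breaks off after the constructor; below is a minimal sketch
    # of the matching forward pass. The ReLU/dropout placement is an
    # assumption borrowed from PyG's standard ogbn-products example.
    def forward(self, x, adjs):
        # ``adjs`` holds one (edge_index, e_id, size) triple per sampled hop.
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            if i != self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return F.log_softmax(x, dim=-1)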
Example #2
    def train_dataloader(self):
        return NeighborSampler(self.data.adj_t, node_idx=self.data.train_mask,
                               sizes=[25, 10], return_e_id=False,
                               transform=self.convert_batch, batch_size=1024,
                               shuffle=True, num_workers=6,
                               persistent_workers=True)
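
    # ``convert_batch`` is referenced above but not shown. A hedged sketch of
    # what it plausibly looks like; the ``Batch`` namedtuple and its field
    # names are assumptions, modeled on the PyTorch Lightning GraphSAGE
    # example this snippet resembles. Assumes, at module level:
    #     Batch = namedtuple('Batch', ['x', 'y', 'adjs_t'])
    def convert_batch(self, batch_size, n_id, adjs):
        # Gather features/labels for the sampled nodes and keep only the
        # sparse adjacency of each hop (edge ids were disabled via
        # return_e_id=False).
        return Batch(
            x=self.data.x[n_id],
            y=self.data.y[n_id[:batch_size]].squeeze(),
            adjs_t=[adj_t for adj_t, _, _ in adjs],
        )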
Example #3
import torch
import torch.nn.functional as F
from ray import train
from ray.air import session
from torch_geometric.loader import NeighborSampler


def train_loop_per_worker(train_loop_config):
    dataset = train_loop_config["dataset_fn"]()
    batch_size = train_loop_config["batch_size"]
    num_epochs = train_loop_config["num_epochs"]

    data = dataset[0]
    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
    train_idx = train_idx.split(
        train_idx.size(0) //
        session.get_world_size())[session.get_world_rank()]

    train_loader = NeighborSampler(
        data.edge_index,
        node_idx=train_idx,
        sizes=[25, 10],
        batch_size=batch_size,
        shuffle=True,
    )

    # Disable distributed sampler since the train_loader has already been split above.
    train_loader = train.torch.prepare_data_loader(train_loader,
                                                   add_dist_sampler=False)

    # Do validation on rank 0 worker only.
    if session.get_world_rank() == 0:
        subgraph_loader = NeighborSampler(data.edge_index,
                                          node_idx=None,
                                          sizes=[-1],
                                          batch_size=2048,
                                          shuffle=False)
        subgraph_loader = train.torch.prepare_data_loader(
            subgraph_loader, add_dist_sampler=False)

    model = SAGE(dataset.num_features, 256, dataset.num_classes)
    model = train.torch.prepare_model(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    x, y = data.x.to(train.torch.get_device()), data.y.to(
        train.torch.get_device())

    for epoch in range(num_epochs):
        model.train()

        # ``batch_size`` is the number of target samples in the current batch.
        # ``n_id`` holds the ids of all nodes used in the computation; it is
        # needed to pull in the features of just the nodes required by the
        # current batch.
        # ``adjs`` is a list of ``(edge_index, e_id, size)`` tuples, one per
        # layer, where ``edge_index`` represents the edges of the sampled
        # bipartite subgraph, ``e_id`` holds the ids of those edges in the
        # full graph, and ``size`` holds the shape of the subgraph.
        # See ``torch_geometric.loader.neighbor_sampler.NeighborSampler`` for more info.
        for batch_size, n_id, adjs in train_loader:
            optimizer.zero_grad()
            out = model(x[n_id], adjs)
            loss = F.nll_loss(out, y[n_id[:batch_size]])
            loss.backward()
            optimizer.step()

        if session.get_world_rank() == 0:
            print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")

        train_accuracy = validation_accuracy = test_accuracy = None

        # Do validation on rank 0 worker only.
        if session.get_world_rank() == 0:
            model.eval()
            with torch.no_grad():
                out = model.module.test(x, subgraph_loader)
            res = out.argmax(dim=-1) == data.y
            train_accuracy = int(res[data.train_mask].sum()) / int(
                data.train_mask.sum())
            validation_accuracy = int(res[data.val_mask].sum()) / int(
                data.val_mask.sum())
            test_accuracy = int(res[data.test_mask].sum()) / int(
                data.test_mask.sum())

        session.report(
            dict(
                train_accuracy=train_accuracy,
                validation_accuracy=validation_accuracy,
                test_accuracy=test_accuracy,
            ))
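
# A hedged sketch of how a loop like this is typically launched with Ray AIR's
# TorchTrainer; ``make_reddit_dataset`` is a hypothetical zero-arg factory.
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

def make_reddit_dataset():  # hypothetical dataset factory
    from torch_geometric.datasets import Reddit
    return Reddit('../data/Reddit')

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={
        "dataset_fn": make_reddit_dataset,
        "batch_size": 1024,
        "num_epochs": 3,
    },
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()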
Example #4
import os.path as osp

import torch
import torch.nn as nn

from tqdm import tqdm
from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import SAGEConv
from torch_geometric.nn import DeepGraphInfomax

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')
dataset = Reddit(path)
data = dataset[0]

train_loader = NeighborSampler(data.edge_index,
                               node_idx=None,
                               sizes=[10, 10, 25],
                               batch_size=256,
                               shuffle=True,
                               num_workers=12)

test_loader = NeighborSampler(data.edge_index,
                              node_idx=None,
                              sizes=[10, 10, 25],
                              batch_size=256,
                              shuffle=False,
                              num_workers=12)


class Encoder(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super(Encoder, self).__init__()
        self.convs = torch.nn.ModuleList([
            SAGEConv(in_channels, hidden_channels),
            SAGEConv(hidden_channels, hidden_channels),
            SAGEConv(hidden_channels, hidden_channels),
        ])
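        # NOTE: the snippet is cut off here. What follows is a hedged sketch,
        # assuming per-layer PReLU activations and the usual DeepGraphInfomax
        # wiring, as in PyG's inductive Deep Graph Infomax example.
        self.activations = torch.nn.ModuleList(
            [nn.PReLU(hidden_channels) for _ in range(3)])

    def forward(self, x, adjs):
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            x = self.activations[i](x)
        return x


# The summary readout averages node embeddings; the corruption permutes node
# features, which is the standard negative sampling for DGI.
model = DeepGraphInfomax(
    hidden_channels=512,
    encoder=Encoder(dataset.num_features, 512),
    summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
    corruption=lambda x, *args: (x[torch.randperm(x.size(0))], *args),
)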
Example #5
import os.path as osp
import random
import shutil
import sys

import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import GATConv, SAGEConv


def test_neighbor_sampler_on_cora():
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dataset = Planetoid(root, 'Cora')
    data = dataset[0]

    batch = torch.arange(10)
    loader = NeighborSampler(data.edge_index,
                             sizes=[-1, -1, -1],
                             node_idx=batch,
                             batch_size=10)

    class SAGE(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()

            self.convs = torch.nn.ModuleList()
            self.convs.append(SAGEConv(in_channels, 16))
            self.convs.append(SAGEConv(16, 16))
            self.convs.append(SAGEConv(16, out_channels))

        def batch(self, x, adjs):
            for i, (edge_index, _, size) in enumerate(adjs):
                x_target = x[:size[1]]  # Target nodes are always placed first.
                x = self.convs[i]((x, x_target), edge_index)
            return x

        def full(self, x, edge_index):
            for conv in self.convs:
                x = conv(x, edge_index)
            return x

    model = SAGE(dataset.num_features, dataset.num_classes)

    _, n_id, adjs = next(iter(loader))
    out1 = model.batch(data.x[n_id], adjs)
    out2 = model.full(data.x, data.edge_index)[batch]
    assert torch.allclose(out1, out2)

    class GAT(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()

            self.convs = torch.nn.ModuleList()
            self.convs.append(GATConv(in_channels, 16, heads=2))
            self.convs.append(GATConv(32, 16, heads=2))
            self.convs.append(GATConv(32, out_channels, heads=2, concat=False))

        def batch(self, x, adjs):
            for i, (edge_index, _, size) in enumerate(adjs):
                x_target = x[:size[1]]  # Target nodes are always placed first.
                x = self.convs[i]((x, x_target), edge_index)
            return x

        def full(self, x, edge_index):
            for conv in self.convs:
                x = conv(x, edge_index)
            return x

    model = GAT(dataset.num_features, dataset.num_classes)

    _, n_id, adjs = next(iter(loader))
    out1 = model.batch(data.x[n_id], adjs)
    out2 = model.full(data.x, data.edge_index)[batch]
    assert torch.allclose(out1, out2)

    shutil.rmtree(root)
Example #6
import os.path as osp

import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import SAGEConv

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')
dataset = Reddit(path)
data = dataset[0]

train_loader = NeighborSampler(data.edge_index,
                               node_idx=data.train_mask,
                               sizes=[25, 10],
                               batch_size=1024,
                               shuffle=True,
                               num_workers=12)
subgraph_loader = NeighborSampler(data.edge_index,
                                  node_idx=None,
                                  sizes=[-1],
                                  batch_size=1024,
                                  shuffle=False,
                                  num_workers=12)


class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(SAGE, self).__init__()

        self.num_layers = 2
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))
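
    # The example is truncated here; below is a sketch of the layer-wise
    # inference routine that ``subgraph_loader`` (sizes=[-1]) is typically
    # paired with, modeled on PyG's Reddit GraphSAGE example.
    def inference(self, x_all, device):
        # Compute representations of *all* nodes one layer at a time, so only
        # a single layer's activations have to fit on the GPU at once.
        for i, conv in enumerate(self.convs):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = conv((x, x_target), edge_index)
                if i != self.num_layers - 1:
                    x = F.relu(x)
                xs.append(x[:batch_size].cpu())
            x_all = torch.cat(xs, dim=0)
        return x_all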
Example #7
import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from torch_geometric.datasets import Reddit
from torch_geometric.loader import ClusterData, ClusterLoader, NeighborSampler
from torch_geometric.nn import SAGEConv

dataset = Reddit('../data/Reddit')
data = dataset[0]

cluster_data = ClusterData(data,
                           num_parts=1500,
                           recursive=False,
                           save_dir=dataset.processed_dir)
train_loader = ClusterLoader(cluster_data,
                             batch_size=20,
                             shuffle=True,
                             num_workers=12)

subgraph_loader = NeighborSampler(data.edge_index,
                                  sizes=[-1],
                                  batch_size=1024,
                                  shuffle=False,
                                  num_workers=12)


class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.convs = ModuleList(
            [SAGEConv(in_channels, 128),
             SAGEConv(128, out_channels)])

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != len(self.convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return F.log_softmax(x, dim=-1)
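

# The example cuts off after the model definition; a hedged sketch of the
# corresponding training step over ``train_loader``, in the style of PyG's
# Cluster-GCN Reddit example:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

def train():
    model.train()
    for batch in train_loader:  # each batch is a subgraph built from clusters
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        optimizer.step()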
Example #8
                (h.squeeze(0), m), dim=1))) + prev
            out = out.squeeze(0)
        out = F.relu(out)

        return out


model = AutoEncoder(node_hidden_dim, Encoder()).to(dev)

data.train_mask = data.val_mask = data.test_mask = None
tr_data, val_data, ts_data = model.split_edges(data,
                                               val_ratio=val_ratio,
                                               test_ratio=test_ratio)
tr_loader = NeighborSampler(tr_data,
                            size=[5] * num_step_message_passing,
                            num_hops=num_step_message_passing,
                            batch_size=batch_size,
                            bipartite=False,
                            shuffle=True)
val_loader = NeighborSampler(val_data,
                             size=[5] * num_step_message_passing,
                             num_hops=num_step_message_passing,
                             batch_size=batch_size,
                             bipartite=False)
ts_loader = NeighborSampler(ts_data,
                            size=[5] * num_step_message_passing,
                            num_hops=num_step_message_passing,
                            batch_size=batch_size,
                            bipartite=False)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

Example #9
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel
from torch_geometric.loader import NeighborSampler


def run(rank, world_size, dataset):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

    data = dataset[0]
    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
    train_idx = train_idx.split(train_idx.size(0) // world_size)[rank]

    train_loader = NeighborSampler(data.edge_index,
                                   node_idx=train_idx,
                                   sizes=[25, 10],
                                   batch_size=1024,
                                   shuffle=True,
                                   num_workers=0)

    if rank == 0:
        subgraph_loader = NeighborSampler(data.edge_index,
                                          node_idx=None,
                                          sizes=[-1],
                                          batch_size=2048,
                                          shuffle=False,
                                          num_workers=6)

    torch.manual_seed(12345)
    model = SAGE(dataset.num_features, 256, dataset.num_classes).to(rank)
    model = DistributedDataParallel(model, device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    x, y = data.x.to(rank), data.y.to(rank)

    for epoch in range(1, 21):
        model.train()

        for batch_size, n_id, adjs in train_loader:
            adjs = [adj.to(rank) for adj in adjs]

            optimizer.zero_grad()
            out = model(x[n_id], adjs)
            loss = F.nll_loss(out, y[n_id[:batch_size]])
            loss.backward()
            optimizer.step()

        dist.barrier()

        if rank == 0:
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

        if rank == 0 and epoch % 5 == 0:  # We evaluate on a single GPU for now
            model.eval()
            with torch.no_grad():
                out = model.module.inference(x, rank, subgraph_loader)
            res = out.argmax(dim=-1) == data.y
            acc1 = int(res[data.train_mask].sum()) / int(data.train_mask.sum())
            acc2 = int(res[data.val_mask].sum()) / int(data.val_mask.sum())
            acc3 = int(res[data.test_mask].sum()) / int(data.test_mask.sum())
            print(f'Train: {acc1:.4f}, Val: {acc2:.4f}, Test: {acc3:.4f}')

        dist.barrier()

    dist.destroy_process_group()
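

# The usual launcher for ``run`` via ``torch.multiprocessing``; the Reddit
# dataset is assumed here, matching the examples above.
if __name__ == '__main__':
    from torch_geometric.datasets import Reddit

    dataset = Reddit('../data/Reddit')
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(run, args=(world_size, dataset),
                                nprocs=world_size, join=True)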