import os.path as osp

import torch
import torch.nn.functional as F
from tqdm import tqdm
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator

from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import SAGEConv

root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'products')
dataset = PygNodePropPredDataset('ogbn-products', root)
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name='ogbn-products')
data = dataset[0]

train_idx = split_idx['train']
train_loader = NeighborSampler(data.edge_index, node_idx=train_idx,
                               sizes=[15, 10, 5], batch_size=1024,
                               shuffle=True, num_workers=12)
subgraph_loader = NeighborSampler(data.edge_index, node_idx=None, sizes=[-1],
                                  batch_size=4096, shuffle=False,
                                  num_workers=12)


class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.num_layers = num_layers
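For reference, a minimal, self-contained sketch (synthetic random graph; the node and edge counts are arbitrary) of what such loaders yield per iteration, namely a ``(batch_size, n_id, adjs)`` triple with one bipartite adjacency per sampling hop:

import torch
from torch_geometric.loader import NeighborSampler

num_nodes = 100
edge_index = torch.randint(num_nodes, (2, 500))  # random directed edges
loader = NeighborSampler(edge_index, node_idx=torch.arange(num_nodes),
                         sizes=[15, 10, 5], batch_size=32, shuffle=True)

batch_size, n_id, adjs = next(iter(loader))
print(batch_size)            # number of seed (target) nodes in this batch
print(n_id.shape)            # all nodes needed to compute their embeddings
for edge_index_hop, e_id, size in adjs:  # one tuple per hop, outermost first
    print(size)              # (num_source_nodes, num_target_nodes)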
def train_dataloader(self):
    return NeighborSampler(self.data.adj_t, node_idx=self.data.train_mask,
                           sizes=[25, 10], return_e_id=False,
                           transform=self.convert_batch, batch_size=1024,
                           shuffle=True, num_workers=6,
                           persistent_workers=True)
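The ``convert_batch`` hook passed as ``transform=`` above is not shown; a plausible sketch follows (the ``Batch`` container and its field names are assumptions, not the author's code). As a datamodule method it would close over ``self.data`` instead of taking ``data`` explicitly:

from typing import List, NamedTuple

import torch
from torch_sparse import SparseTensor


class Batch(NamedTuple):
    x: torch.Tensor
    y: torch.Tensor
    adjs_t: List[SparseTensor]


def convert_batch(data, batch_size, n_id, adjs):
    # Gather features for every sampled node, labels for the seed nodes only
    # (they come first in ``n_id``), and the per-hop adjacency matrices.
    return Batch(
        x=data.x[n_id],
        y=data.y[n_id[:batch_size]].squeeze(),
        adjs_t=[adj_t for adj_t, _, _ in adjs],
    )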
import torch
import torch.nn.functional as F
import ray.train.torch  # makes ``train.torch`` available
from ray import train
from ray.air import session

from torch_geometric.loader import NeighborSampler

# Assumes a ``SAGE`` model class defined in the same module.


def train_loop_per_worker(train_loop_config):
    dataset = train_loop_config["dataset_fn"]()
    batch_size = train_loop_config["batch_size"]
    num_epochs = train_loop_config["num_epochs"]

    data = dataset[0]
    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
    train_idx = train_idx.split(
        train_idx.size(0) // session.get_world_size())[session.get_world_rank()]

    train_loader = NeighborSampler(
        data.edge_index,
        node_idx=train_idx,
        sizes=[25, 10],
        batch_size=batch_size,
        shuffle=True,
    )

    # Disable the distributed sampler since ``train_loader`` has already been
    # split across workers above.
    train_loader = train.torch.prepare_data_loader(train_loader,
                                                   add_dist_sampler=False)

    # Run validation on the rank 0 worker only.
    if session.get_world_rank() == 0:
        subgraph_loader = NeighborSampler(data.edge_index, node_idx=None,
                                          sizes=[-1], batch_size=2048,
                                          shuffle=False)
        subgraph_loader = train.torch.prepare_data_loader(
            subgraph_loader, add_dist_sampler=False)

    model = SAGE(dataset.num_features, 256, dataset.num_classes)
    model = train.torch.prepare_model(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    x, y = data.x.to(train.torch.get_device()), data.y.to(
        train.torch.get_device())

    for epoch in range(num_epochs):
        model.train()

        # ``batch_size`` is the number of seed samples in the current batch.
        # ``n_id`` holds the ids of all nodes used in the computation; it is
        # needed to pull in just the features required by the current batch.
        # ``adjs`` is a list of ``(edge_index, e_id, size)`` tuples, one per
        # layer, where ``edge_index`` represents the edges of the sampled
        # subgraph, ``e_id`` holds the ids of those edges in the full graph,
        # and ``size`` holds the shape of the bipartite subgraph.
        # See ``torch_geometric.loader.neighbor_sampler.NeighborSampler``
        # for more info.
        for batch_size, n_id, adjs in train_loader:
            optimizer.zero_grad()
            out = model(x[n_id], adjs)
            loss = F.nll_loss(out, y[n_id[:batch_size]])
            loss.backward()
            optimizer.step()

        if session.get_world_rank() == 0:
            print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")

        train_accuracy = validation_accuracy = test_accuracy = None
        # Run validation on the rank 0 worker only.
        if session.get_world_rank() == 0:
            model.eval()
            with torch.no_grad():
                out = model.module.test(x, subgraph_loader)
            res = out.argmax(dim=-1) == data.y
            train_accuracy = int(res[data.train_mask].sum()) / int(
                data.train_mask.sum())
            validation_accuracy = int(res[data.val_mask].sum()) / int(
                data.val_mask.sum())
            test_accuracy = int(res[data.test_mask].sum()) / int(
                data.test_mask.sum())

        session.report(
            dict(
                train_accuracy=train_accuracy,
                validation_accuracy=validation_accuracy,
                test_accuracy=test_accuracy,
            ))
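For context, a sketch of how this loop might be launched (Ray AIR-era ``TorchTrainer`` API assumed; the dataset function and scaling values are placeholders, not taken from the original):

from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

from torch_geometric.datasets import Reddit

trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={
        "dataset_fn": lambda: Reddit("./data/Reddit"),  # any PyG dataset fn
        "batch_size": 1024,
        "num_epochs": 10,
    },
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
print(result.metrics)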
import os.path as osp

import torch
import torch.nn as nn
from tqdm import tqdm

from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import SAGEConv
from torch_geometric.nn import DeepGraphInfomax

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')
dataset = Reddit(path)
data = dataset[0]

train_loader = NeighborSampler(data.edge_index, node_idx=None,
                               sizes=[10, 10, 25], batch_size=256,
                               shuffle=True, num_workers=12)
test_loader = NeighborSampler(data.edge_index, node_idx=None,
                              sizes=[10, 10, 25], batch_size=256,
                              shuffle=False, num_workers=12)


class Encoder(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super(Encoder, self).__init__()
        # One SAGE layer per sampling hop (three hops in the loaders above).
        self.convs = torch.nn.ModuleList([
            SAGEConv(in_channels, hidden_channels),
            SAGEConv(hidden_channels, hidden_channels),
            SAGEConv(hidden_channels, hidden_channels),
        ])
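One way to wrap the encoder with PyG's ``DeepGraphInfomax`` (a sketch; the hidden size, summary, and corruption functions follow the standard DGI recipe and are assumptions, not necessarily this example's exact choices):

import torch
from torch_geometric.nn import DeepGraphInfomax


def corruption(x, *args):
    # Negative samples: shuffle node features, keep the graph structure.
    return (x[torch.randperm(x.size(0))], *args)


model = DeepGraphInfomax(
    hidden_channels=512,
    encoder=Encoder(dataset.num_features, 512),
    summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
    corruption=corruption,
)
# Training contrasts real against corrupted encodings:
#   pos_z, neg_z, summary = model(...)
#   loss = model.loss(pos_z, neg_z, summary)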
import os.path as osp
import random
import shutil
import sys

import torch

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import GATConv, SAGEConv


def test_neighbor_sampler_on_cora():
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dataset = Planetoid(root, 'Cora')
    data = dataset[0]

    batch = torch.arange(10)
    loader = NeighborSampler(data.edge_index, sizes=[-1, -1, -1],
                             node_idx=batch, batch_size=10)

    class SAGE(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.convs = torch.nn.ModuleList()
            self.convs.append(SAGEConv(in_channels, 16))
            self.convs.append(SAGEConv(16, 16))
            self.convs.append(SAGEConv(16, out_channels))

        def batch(self, x, adjs):
            for i, (edge_index, _, size) in enumerate(adjs):
                x_target = x[:size[1]]  # Target nodes are always placed first.
                x = self.convs[i]((x, x_target), edge_index)
            return x

        def full(self, x, edge_index):
            for conv in self.convs:
                x = conv(x, edge_index)
            return x

    model = SAGE(dataset.num_features, dataset.num_classes)

    _, n_id, adjs = next(iter(loader))
    out1 = model.batch(data.x[n_id], adjs)
    out2 = model.full(data.x, data.edge_index)[batch]
    assert torch.allclose(out1, out2)

    class GAT(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.convs = torch.nn.ModuleList()
            self.convs.append(GATConv(in_channels, 16, heads=2))
            self.convs.append(GATConv(32, 16, heads=2))
            self.convs.append(GATConv(32, out_channels, heads=2, concat=False))

        def batch(self, x, adjs):
            for i, (edge_index, _, size) in enumerate(adjs):
                x_target = x[:size[1]]  # Target nodes are always placed first.
                x = self.convs[i]((x, x_target), edge_index)
            return x

        def full(self, x, edge_index):
            for conv in self.convs:
                x = conv(x, edge_index)
            return x

    model = GAT(dataset.num_features, dataset.num_classes)

    _, n_id, adjs = next(iter(loader))
    out1 = model.batch(data.x[n_id], adjs)
    out2 = model.full(data.x, data.edge_index)[batch]
    assert torch.allclose(out1, out2)

    shutil.rmtree(root)
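A tiny standalone check (hypothetical toy graph) of the "target nodes are always placed first" convention that both ``batch`` methods rely on: the first ``batch_size`` entries of ``n_id`` are the seed nodes themselves, which is why ``x[:size[1]]`` selects the targets:

import torch
from torch_geometric.loader import NeighborSampler

edge_index = torch.tensor([[1, 2, 3], [0, 0, 1]])  # edges into targets 0 and 1
loader = NeighborSampler(edge_index, sizes=[-1],
                         node_idx=torch.tensor([0, 1]), batch_size=2)
# With a single hop, the loader yields one ``adj`` rather than a list.
batch_size, n_id, adj = next(iter(loader))
assert torch.equal(n_id[:batch_size], torch.tensor([0, 1]))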
import os.path as osp

import torch
import torch.nn.functional as F
from tqdm import tqdm

from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import SAGEConv

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')
dataset = Reddit(path)
data = dataset[0]

train_loader = NeighborSampler(data.edge_index, node_idx=data.train_mask,
                               sizes=[25, 10], batch_size=1024, shuffle=True,
                               num_workers=12)
subgraph_loader = NeighborSampler(data.edge_index, node_idx=None, sizes=[-1],
                                  batch_size=1024, shuffle=False,
                                  num_workers=12)


class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(SAGE, self).__init__()

        self.num_layers = 2
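The class above is cut off; a sketch of the layer-wise inference method such examples typically pair with ``subgraph_loader`` (following the standard PyG pattern and assuming the layers live in ``self.convs``):

@torch.no_grad()
def inference(self, x_all, subgraph_loader, device):
    # Evaluate one layer at a time over all nodes, so every node sees its
    # full neighborhood (sizes=[-1]) without exponential fan-out.
    for i in range(self.num_layers):
        xs = []
        for batch_size, n_id, adj in subgraph_loader:
            edge_index, _, size = adj.to(device)
            x = x_all[n_id].to(device)
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            if i != self.num_layers - 1:
                x = F.relu(x)
            xs.append(x.cpu())
        x_all = torch.cat(xs, dim=0)
    return x_all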
import torch
import torch.nn.functional as F
from torch.nn import ModuleList

from torch_geometric.datasets import Reddit
from torch_geometric.loader import ClusterData, ClusterLoader, NeighborSampler
from torch_geometric.nn import SAGEConv

dataset = Reddit('../data/Reddit')
data = dataset[0]

cluster_data = ClusterData(data, num_parts=1500, recursive=False,
                           save_dir=dataset.processed_dir)
train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True,
                             num_workers=12)
subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1],
                                  batch_size=1024, shuffle=False,
                                  num_workers=12)


class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.convs = ModuleList(
            [SAGEConv(in_channels, 128),
             SAGEConv(128, out_channels)])

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != len(self.convs) - 1:
                x = F.relu(x)
        return F.log_softmax(x, dim=-1)
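A sketch of the training step this setup implies (device and learning rate are assumptions): each ``ClusterLoader`` batch is a small self-contained subgraph, so a plain full-batch-style forward pass applies directly:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)


def train():
    model.train()
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        # Only the training nodes that fell into this cluster contribute.
        loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        optimizer.step()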
            (h.squeeze(0), m), dim=1))) + prev
        out = out.squeeze(0)
        out = F.relu(out)
        return out


model = AutoEncoder(node_hidden_dim, Encoder()).to(dev)
data.train_mask = data.val_mask = data.test_mask = None
tr_data, val_data, ts_data = model.split_edges(data, val_ratio=val_ratio,
                                               test_ratio=test_ratio)

tr_loader = NeighborSampler(tr_data, size=[5] * num_step_message_passing,
                            num_hops=num_step_message_passing,
                            batch_size=batch_size, bipartite=False,
                            shuffle=True)
val_loader = NeighborSampler(val_data, size=[5] * num_step_message_passing,
                             num_hops=num_step_message_passing,
                             batch_size=batch_size, bipartite=False)
ts_loader = NeighborSampler(ts_data, size=[5] * num_step_message_passing,
                            num_hops=num_step_message_passing,
                            batch_size=batch_size, bipartite=False)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
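The ``size=``/``num_hops=``/``bipartite=`` keywords and ``split_edges`` reflect an older NeighborSampler API; in current PyG the equivalent edge split is usually done with ``RandomLinkSplit`` (a sketch, with the ratios taken from the variables above):

from torch_geometric.transforms import RandomLinkSplit

transform = RandomLinkSplit(num_val=val_ratio, num_test=test_ratio)
tr_data, val_data, ts_data = transform(data)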
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel

from torch_geometric.loader import NeighborSampler

# Assumes a ``SAGE`` model class defined in the same module.


def run(rank, world_size, dataset):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

    data = dataset[0]
    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
    train_idx = train_idx.split(train_idx.size(0) // world_size)[rank]

    train_loader = NeighborSampler(data.edge_index, node_idx=train_idx,
                                   sizes=[25, 10], batch_size=1024,
                                   shuffle=True, num_workers=0)

    if rank == 0:
        subgraph_loader = NeighborSampler(data.edge_index, node_idx=None,
                                          sizes=[-1], batch_size=2048,
                                          shuffle=False, num_workers=6)

    torch.manual_seed(12345)
    model = SAGE(dataset.num_features, 256, dataset.num_classes).to(rank)
    model = DistributedDataParallel(model, device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    x, y = data.x.to(rank), data.y.to(rank)

    for epoch in range(1, 21):
        model.train()

        for batch_size, n_id, adjs in train_loader:
            adjs = [adj.to(rank) for adj in adjs]
            optimizer.zero_grad()
            out = model(x[n_id], adjs)
            loss = F.nll_loss(out, y[n_id[:batch_size]])
            loss.backward()
            optimizer.step()

        dist.barrier()

        if rank == 0:
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

        if rank == 0 and epoch % 5 == 0:  # Evaluate on a single GPU for now.
            model.eval()
            with torch.no_grad():
                out = model.module.inference(x, rank, subgraph_loader)
            res = out.argmax(dim=-1) == data.y

            acc1 = int(res[data.train_mask].sum()) / int(data.train_mask.sum())
            acc2 = int(res[data.val_mask].sum()) / int(data.val_mask.sum())
            acc3 = int(res[data.test_mask].sum()) / int(data.test_mask.sum())
            print(f'Train: {acc1:.4f}, Val: {acc2:.4f}, Test: {acc3:.4f}')

        dist.barrier()

    dist.destroy_process_group()
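A sketch of how ``run`` is typically launched, one process per GPU, via ``torch.multiprocessing.spawn`` (the Reddit dataset here is an arbitrary choice for illustration):

import torch
import torch.multiprocessing as mp

from torch_geometric.datasets import Reddit

if __name__ == '__main__':
    dataset = Reddit('../data/Reddit')  # load once in the parent process
    world_size = torch.cuda.device_count()
    # ``spawn`` passes the process rank as the first argument to ``run``.
    mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)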