def test_neighbor_loader_with_imbalanced_sampler():
    zeros = torch.zeros(10, dtype=torch.long)
    ones = torch.ones(90, dtype=torch.long)

    y = torch.cat([zeros, ones], dim=0)
    edge_index = torch.empty((2, 0), dtype=torch.long)
    data = Data(edge_index=edge_index, y=y, num_nodes=y.size(0))

    torch.manual_seed(12345)
    sampler = ImbalancedSampler(data)
    loader = NeighborLoader(data,
                            batch_size=10,
                            sampler=sampler,
                            num_neighbors=[-1])

    y = torch.cat([batch.y for batch in loader])

    histogram = y.bincount()
    prob = histogram / histogram.sum()

    assert histogram.sum() == data.num_nodes
    assert prob.min() > 0.4 and prob.max() < 0.6

    # Test with label tensor as input:
    torch.manual_seed(12345)
    sampler = ImbalancedSampler(data.y)
    loader = NeighborLoader(data,
                            batch_size=10,
                            sampler=sampler,
                            num_neighbors=[-1])

    assert torch.allclose(y, torch.cat([batch.y for batch in loader]))
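
# A minimal sketch of the idea behind ImbalancedSampler (our assumption about
# its behavior, not its actual implementation): node indices are drawn with
# replacement, weighted by inverse class frequency, so minority classes are
# oversampled toward a roughly uniform class distribution:
def imbalanced_indices_sketch(y):
    weight = 1.0 / y.bincount()[y]  # rare classes receive larger weights
    return torch.multinomial(weight, num_samples=y.numel(), replacement=True)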
Example #2
def test_homogeneous_neighbor_loader(directed):
    torch.manual_seed(12345)

    data = Data()

    data.x = torch.arange(100)
    data.edge_index = get_edge_index(100, 100, 500)
    data.edge_attr = torch.arange(500)

    loader = NeighborLoader(data,
                            num_neighbors=[5] * 2,
                            batch_size=20,
                            directed=directed)

    assert str(loader) == 'NeighborLoader()'
    assert len(loader) == 5

    for batch in loader:
        assert isinstance(batch, Data)

        assert len(batch) == 4
        assert batch.x.size(0) <= 100
        assert batch.batch_size == 20
        assert batch.x.min() >= 0 and batch.x.max() < 100
        assert batch.edge_index.min() >= 0
        assert batch.edge_index.max() < batch.num_nodes
        assert batch.edge_attr.min() >= 0
        assert batch.edge_attr.max() < 500

        assert is_subset(batch.edge_index, data.edge_index, batch.x, batch.x)

        # Test for isolated nodes (there shouldn't exist any):
        assert batch.edge_index.view(-1).unique().numel() == batch.num_nodes
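
# The get_edge_index helper used by these tests is not shown on this page; a
# plausible definition (an assumption inferred from the call sites) is:
def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
    # Draw random source and destination nodes for the requested edge count:
    row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long)
    col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long)
    return torch.stack([row, col], dim=0)
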
def test_pna_conv_get_degree_histogram():
    edge_index = torch.tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]])
    data = Data(num_nodes=5, edge_index=edge_index)
    loader = NeighborLoader(
        data,
        num_neighbors=[-1],
        input_nodes=None,
        batch_size=5,
        shuffle=False,
    )
    deg_hist = PNAConv.get_degree_histogram(loader)
    deg_hist_ref = torch.tensor([1, 2, 1, 1])
    assert torch.equal(deg_hist_ref, deg_hist)

    edge_index_1 = torch.tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]])
    edge_index_2 = torch.tensor([[1, 1, 2, 2, 0, 3, 3], [2, 3, 3, 1, 1, 0, 2]])
    edge_index_3 = torch.tensor([[1, 3, 2, 0, 0, 4, 2], [2, 0, 4, 1, 1, 0, 3]])
    edge_index_4 = torch.tensor([[0, 1, 2, 4, 0, 1, 3], [2, 3, 3, 1, 1, 0, 2]])

    data_1 = Data(num_nodes=5,
                  edge_index=edge_index_1)  # deg_hist = [1, 2 ,1 ,1]
    data_2 = Data(num_nodes=5, edge_index=edge_index_2)  # deg_hist = [1, 1, 3]
    data_3 = Data(num_nodes=5, edge_index=edge_index_3)  # deg_hist = [0, 3, 2]
    data_4 = Data(num_nodes=5, edge_index=edge_index_4)  # deg_hist = [1, 1, 3]

    loader = DataLoader(
        [data_1, data_2, data_3, data_4],
        batch_size=1,
        shuffle=False,
    )
    deg_hist = PNAConv.get_degree_histogram(loader)
    deg_hist_ref = torch.tensor([3, 7, 9, 1])
    assert torch.equal(deg_hist_ref, deg_hist)
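
# For intuition, a hedged re-implementation of what PNAConv.get_degree_histogram
# computes (an assumption about its internals: it bins the in-degree of every
# node across all batches yielded by the loader):
def degree_histogram_sketch(loader):
    from torch_geometric.utils import degree
    hists = []
    for batch in loader:
        d = degree(batch.edge_index[1], num_nodes=batch.num_nodes,
                   dtype=torch.long)
        hists.append(d.bincount())
    out = torch.zeros(max(h.numel() for h in hists), dtype=torch.long)
    for h in hists:
        out[:h.numel()] += h
    return out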
Example #4
def dataloader(self, mask: Tensor, shuffle: bool, num_workers: int = 6):
    return NeighborLoader(self.data,
                          num_neighbors=[10, 10],
                          input_nodes=('paper', mask),
                          batch_size=1024,
                          shuffle=shuffle,
                          num_workers=num_workers,
                          persistent_workers=num_workers > 0)
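
# Hypothetical usage of the helper above (self.data and the mask attribute are
# assumptions from context, not shown in this snippet):
#
#     train_loader = self.dataloader(self.data['paper'].train_mask, shuffle=True)
#
# Guarding persistent_workers with num_workers > 0 matters: torch's DataLoader
# rejects persistent_workers=True when num_workers == 0.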
def test_temporal_custom_neighbor_loader_on_cora(get_dataset, FeatureStore,
                                                 GraphStore):
    # Initialize dataset (once):
    dataset = get_dataset(name='Cora')
    data = dataset[0]

    # Initialize feature store, graph store, and reference:
    feature_store = FeatureStore()
    graph_store = GraphStore()
    hetero_data = HeteroData()

    feature_store.put_tensor(data.x,
                             group_name='paper',
                             attr_name='x',
                             index=None)
    hetero_data['paper'].x = data.x

    feature_store.put_tensor(torch.arange(data.num_nodes),
                             group_name='paper',
                             attr_name='time',
                             index=None)
    hetero_data['paper'].time = torch.arange(data.num_nodes)

    num_nodes = data.x.size(dim=0)
    graph_store.put_edge_index(edge_index=data.edge_index,
                               edge_type=('paper', 'to', 'paper'),
                               layout='coo',
                               size=(num_nodes, num_nodes))
    hetero_data['paper', 'to', 'paper'].edge_index = data.edge_index

    loader1 = NeighborLoader(hetero_data,
                             num_neighbors=[-1, -1],
                             input_nodes='paper',
                             time_attr='time',
                             batch_size=128)

    loader2 = NeighborLoader(
        (feature_store, graph_store),
        num_neighbors=[-1, -1],
        input_nodes=TensorAttr(group_name='paper', attr_name='x'),
        time_attr='time',
        batch_size=128,
    )

    for batch1, batch2 in zip(loader1, loader2):
        assert torch.equal(batch1['paper'].time, batch2['paper'].time)
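
# For reference, a TensorAttr fully addresses a tensor inside a feature store;
# reading back what was stored above is symmetric to put_tensor (a sketch,
# assuming the FeatureStore fixture implements get_tensor):
#
#     attr = TensorAttr(group_name='paper', attr_name='time')
#     time = feature_store.get_tensor(attr)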
def test_heterogeneous_neighbor_loader_on_cora(directed):
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dataset = Planetoid(root, 'Cora')
    data = dataset[0]
    data.edge_weight = torch.rand(data.num_edges)

    hetero_data = HeteroData()
    hetero_data['paper'].x = data.x
    hetero_data['paper'].n_id = torch.arange(data.num_nodes)
    hetero_data['paper', 'paper'].edge_index = data.edge_index
    hetero_data['paper', 'paper'].edge_weight = data.edge_weight

    split_idx = torch.arange(5, 8)

    loader = NeighborLoader(hetero_data,
                            num_neighbors=[-1, -1],
                            batch_size=split_idx.numel(),
                            input_nodes=('paper', split_idx),
                            directed=directed)
    assert len(loader) == 1

    hetero_batch = next(iter(loader))
    batch_size = hetero_batch['paper'].batch_size

    if not directed:
        n_id, _, _, e_mask = k_hop_subgraph(split_idx,
                                            num_hops=2,
                                            edge_index=data.edge_index,
                                            num_nodes=data.num_nodes)

        n_id = n_id.sort()[0]
        assert n_id.tolist() == hetero_batch['paper'].n_id.sort()[0].tolist()
        assert hetero_batch['paper', 'paper'].num_edges == int(e_mask.sum())

    class GNN(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GraphConv(in_channels, hidden_channels)
            self.conv2 = GraphConv(hidden_channels, out_channels)

        def forward(self, x, edge_index, edge_weight):
            x = self.conv1(x, edge_index, edge_weight).relu()
            x = self.conv2(x, edge_index, edge_weight).relu()
            return x

    model = GNN(dataset.num_features, 16, dataset.num_classes)
    hetero_model = to_hetero(model, hetero_data.metadata())

    out1 = model(data.x, data.edge_index, data.edge_weight)[split_idx]
    out2 = hetero_model(hetero_batch.x_dict, hetero_batch.edge_index_dict,
                        hetero_batch.edge_weight_dict)['paper'][:batch_size]
    assert torch.allclose(out1, out2, atol=1e-6)

    try:
        shutil.rmtree(root)
    except PermissionError:
        pass
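
# A standalone illustration of k_hop_subgraph's return values on a tiny
# undirected path graph 0-1-2-3 (independent of the Cora test above):
toy_edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                               [1, 0, 2, 1, 3, 2]])
toy_subset, _, _, toy_mask = k_hop_subgraph(torch.tensor([0]), num_hops=2,
                                            edge_index=toy_edge_index,
                                            num_nodes=4)
# toy_subset == tensor([0, 1, 2]): the seed node plus its 2-hop neighborhood.
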
def test_temporal_heterogeneous_neighbor_loader_on_cora(get_dataset):
    dataset = get_dataset(name='Cora')
    data = dataset[0]

    hetero_data = HeteroData()
    hetero_data['paper'].x = data.x
    hetero_data['paper'].time = torch.arange(data.num_nodes)
    hetero_data['paper', 'paper'].edge_index = data.edge_index

    loader = NeighborLoader(hetero_data, num_neighbors=[-1, -1],
                            input_nodes='paper', time_attr='time',
                            batch_size=1)

    for batch in loader:
        mask = batch['paper'].time[0] >= batch['paper'].time[1:]
        assert torch.all(mask)
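
# Restating the temporal constraint verified above: with batch_size=1, the
# first entry of batch['paper'].time is the seed node's timestamp, and
# neighbor sampling only visits nodes whose timestamp does not exceed it:
#
#     seed_time, neighbor_times = batch['paper'].time[0], batch['paper'].time[1:]
#     assert bool((neighbor_times <= seed_time).all())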
def test_homogeneous_neighbor_loader_on_cora(directed):
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dataset = Planetoid(root, 'Cora')
    data = dataset[0]
    data.n_id = torch.arange(data.num_nodes)
    data.edge_weight = torch.rand(data.num_edges)

    split_idx = torch.arange(5, 8)

    loader = NeighborLoader(data,
                            num_neighbors=[-1, -1],
                            batch_size=split_idx.numel(),
                            input_nodes=split_idx,
                            directed=directed)
    assert len(loader) == 1

    batch = next(iter(loader))
    batch_size = batch.batch_size

    if not directed:
        n_id, _, _, e_mask = k_hop_subgraph(split_idx,
                                            num_hops=2,
                                            edge_index=data.edge_index,
                                            num_nodes=data.num_nodes)

        assert n_id.sort()[0].tolist() == batch.n_id.sort()[0].tolist()
        assert batch.num_edges == int(e_mask.sum())

    class GNN(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GraphConv(in_channels, hidden_channels)
            self.conv2 = GraphConv(hidden_channels, out_channels)

        def forward(self, x, edge_index, edge_weight):
            x = self.conv1(x, edge_index, edge_weight).relu()
            x = self.conv2(x, edge_index, edge_weight).relu()
            return x

    model = GNN(dataset.num_features, 16, dataset.num_classes)

    out1 = model(data.x, data.edge_index, data.edge_weight)[split_idx]
    out2 = model(batch.x, batch.edge_index, batch.edge_weight)[:batch_size]
    assert torch.allclose(out1, out2, atol=1e-6)

    shutil.rmtree(root)
Example #9
def test_basic_gnn_inference(get_dataset, jk):
    dataset = get_dataset(name='Cora')
    data = dataset[0]

    model = GraphSAGE(dataset.num_features, hidden_channels=16, num_layers=2,
                      out_channels=dataset.num_classes, jk=jk)
    model.eval()

    out1 = model(data.x, data.edge_index)
    assert out1.size() == (data.num_nodes, dataset.num_classes)

    loader = NeighborLoader(data, num_neighbors=[-1], batch_size=128)
    out2 = model.inference(loader)
    assert out1.size() == out2.size()
    assert torch.allclose(out1, out2, atol=1e-4)

    assert 'n_id' not in data
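
# A hedged sketch of the layer-wise computation behind model.inference (it is
# simplified: JK handling and non-linearities are omitted, and batches are
# assumed to carry an n_id vector mapping sampled nodes back to the full graph):
@torch.no_grad()
def inference_sketch(model, loader, x_all):
    for conv in model.convs:
        xs = []
        for batch in loader:
            x = conv(x_all[batch.n_id], batch.edge_index)
            xs.append(x[:batch.batch_size])  # keep seed nodes only
        x_all = torch.cat(xs, dim=0)
    return x_all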
Example #10
def test_neighbor_loader_with_imbalanced_sampler():
    zeros = torch.zeros(10, dtype=torch.long)
    ones = torch.ones(90, dtype=torch.long)

    y = torch.cat([zeros, ones], dim=0)
    edge_index = torch.empty((2, 0), dtype=torch.long)
    data = Data(edge_index=edge_index, y=y, num_nodes=y.size(0))

    torch.manual_seed(12345)
    sampler = ImbalancedSampler(data)
    loader = NeighborLoader(data,
                            batch_size=10,
                            sampler=sampler,
                            num_neighbors=[-1])

    ys: List[Tensor] = []
    for batch in loader:
        ys.append(batch.y)

    histogram = torch.cat(ys).bincount()
    prob = histogram / histogram.sum()

    assert histogram.sum() == data.num_nodes
    assert prob.min() > 0.4 and prob.max() < 0.6
Example #11
import copy
import os.path as osp

import torch
from tqdm import tqdm

from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')
dataset = Reddit(path)

# Already send node features/labels to GPU for faster access during sampling:
data = dataset[0].to(device, 'x', 'y')

kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}
train_loader = NeighborLoader(data, input_nodes=data.train_mask,
                              num_neighbors=[25, 10], shuffle=True, **kwargs)

subgraph_loader = NeighborLoader(copy.copy(data), input_nodes=None,
                                 num_neighbors=[-1], shuffle=False, **kwargs)

# No need to maintain these features during evaluation:
del subgraph_loader.data.x, subgraph_loader.data.y
# Add global node index information.
subgraph_loader.data.num_nodes = data.num_nodes
subgraph_loader.data.n_id = torch.arange(data.num_nodes)


class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.convs = torch.nn.ModuleList()
Example #12
def test_heterogeneous_neighbor_loader(directed):
    torch.manual_seed(12345)

    data = HeteroData()

    data['paper'].x = torch.arange(100)
    data['author'].x = torch.arange(100, 300)

    data['paper', 'paper'].edge_index = get_edge_index(100, 100, 500)
    data['paper', 'paper'].edge_attr = torch.arange(500)
    data['paper', 'author'].edge_index = get_edge_index(100, 200, 1000)
    data['paper', 'author'].edge_attr = torch.arange(500, 1500)
    data['author', 'paper'].edge_index = get_edge_index(200, 100, 1000)
    data['author', 'paper'].edge_attr = torch.arange(1500, 2500)

    r1, c1 = data['paper', 'paper'].edge_index
    r2, c2 = data['paper', 'author'].edge_index + torch.tensor([[0], [100]])
    r3, c3 = data['author', 'paper'].edge_index + torch.tensor([[100], [0]])
    full_adj = SparseTensor(
        row=torch.cat([r1, r2, r3]),
        col=torch.cat([c1, c2, c3]),
        value=torch.arange(2500),
        sparse_sizes=(300, 300),  # explicit sizes; inference from indices is unsafe
    )

    batch_size = 20
    loader = NeighborLoader(data,
                            num_neighbors=[10] * 2,
                            input_nodes='paper',
                            batch_size=batch_size,
                            directed=directed)
    assert str(loader) == 'NeighborLoader()'
    assert len(loader) == (100 + batch_size - 1) // batch_size

    for batch in loader:
        assert isinstance(batch, HeteroData)

        # Test node type selection:
        assert set(batch.node_types) == {'paper', 'author'}

        assert len(batch['paper']) == 2
        assert batch['paper'].x.size(0) <= 100
        assert batch['paper'].batch_size == batch_size
        assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100

        assert len(batch['author']) == 1
        assert batch['author'].x.size(0) <= 200
        assert batch['author'].x.min() >= 100 and batch['author'].x.max() < 300

        # Test edge type selection:
        assert set(batch.edge_types) == {('paper', 'to', 'paper'),
                                         ('paper', 'to', 'author'),
                                         ('author', 'to', 'paper')}

        assert len(batch['paper', 'paper']) == 2
        row, col = batch['paper', 'paper'].edge_index
        value = batch['paper', 'paper'].edge_attr
        assert row.min() >= 0 and row.max() < batch['paper'].num_nodes
        assert col.min() >= 0 and col.max() < batch['paper'].num_nodes
        assert value.min() >= 0 and value.max() < 500
        if not directed:
            adj = full_adj[batch['paper'].x, batch['paper'].x]
            assert adj.nnz() == row.size(0)
            assert torch.allclose(row.unique(), adj.storage.row().unique())
            assert torch.allclose(col.unique(), adj.storage.col().unique())
            assert torch.allclose(value.unique(), adj.storage.value().unique())

        assert is_subset(batch['paper', 'paper'].edge_index,
                         data['paper', 'paper'].edge_index, batch['paper'].x,
                         batch['paper'].x)

        assert len(batch['paper', 'author']) == 2
        row, col = batch['paper', 'author'].edge_index
        value = batch['paper', 'author'].edge_attr
        assert row.min() >= 0 and row.max() < batch['paper'].num_nodes
        assert col.min() >= 0 and col.max() < batch['author'].num_nodes
        assert value.min() >= 500 and value.max() < 1500
        if not directed:
            adj = full_adj[batch['paper'].x, batch['author'].x]
            assert adj.nnz() == row.size(0)
            assert torch.allclose(row.unique(), adj.storage.row().unique())
            assert torch.allclose(col.unique(), adj.storage.col().unique())
            assert torch.allclose(value.unique(), adj.storage.value().unique())

        assert is_subset(batch['paper', 'author'].edge_index,
                         data['paper', 'author'].edge_index, batch['paper'].x,
                         batch['author'].x - 100)

        assert len(batch['author', 'paper']) == 2
        row, col = batch['author', 'paper'].edge_index
        value = batch['author', 'paper'].edge_attr
        assert row.min() >= 0 and row.max() < batch['author'].num_nodes
        assert col.min() >= 0 and col.max() < batch['paper'].num_nodes
        assert value.min() >= 1500 and value.max() < 2500
        if not directed:
            adj = full_adj[batch['author'].x, batch['paper'].x]
            assert adj.nnz() == row.size(0)
            assert torch.allclose(row.unique(), adj.storage.row().unique())
            assert torch.allclose(col.unique(), adj.storage.col().unique())
            assert torch.allclose(value.unique(), adj.storage.value().unique())

        assert is_subset(batch['author', 'paper'].edge_index,
                         data['author', 'paper'].edge_index,
                         batch['author'].x - 100, batch['paper'].x)

        # Test for isolated nodes (there shouldn't exist any):
        n_id = torch.cat([batch['paper'].x, batch['author'].x])
        row, col, _ = full_adj[n_id, n_id].coo()
        assert torch.cat([row, col]).unique().numel() == n_id.numel()
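
# The full_adj[idx_row, idx_col] pattern above selects an induced submatrix; a
# tiny illustration (independent of the data above):
toy_adj = SparseTensor(row=torch.tensor([0, 1, 2]),
                       col=torch.tensor([1, 2, 0]),
                       value=torch.tensor([10, 11, 12]),
                       sparse_sizes=(3, 3))
toy_sub = toy_adj[torch.tensor([0, 1]), torch.tensor([1, 2])]
# toy_sub.nnz() == 2: the edges (0->1) and (1->2) survive the selection.
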
path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/OGB')
transform = T.ToUndirected(merge=True)
dataset = OGB_MAG(path, preprocess='metapath2vec', transform=transform)

# Already send node features/labels to GPU for faster access during sampling:
data = dataset[0].to(device, 'x', 'y')

train_input_nodes = ('paper', data['paper'].train_mask)
val_input_nodes = ('paper', data['paper'].val_mask)
kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}

if not args.use_hgt_loader:
    train_loader = NeighborLoader(data,
                                  num_neighbors=[10] * 2,
                                  shuffle=True,
                                  input_nodes=train_input_nodes,
                                  **kwargs)
    val_loader = NeighborLoader(data,
                                num_neighbors=[10] * 2,
                                input_nodes=val_input_nodes,
                                **kwargs)
else:
    train_loader = HGTLoader(data,
                             num_samples=[1024] * 4,
                             shuffle=True,
                             input_nodes=train_input_nodes,
                             **kwargs)
    val_loader = HGTLoader(data,
                           num_samples=[1024] * 4,
                           input_nodes=val_input_nodes,
                           **kwargs)
def run(args: argparse.ArgumentParser) -> None:

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('BENCHMARK STARTS')
    for dataset_name in args.datasets:
        assert dataset_name in supported_sets.keys(), \
            f"Dataset {dataset_name} isn't supported."
        print(f'Dataset: {dataset_name}')
        dataset, num_classes = get_dataset(dataset_name, args.root,
                                           args.use_sparse_tensor)
        data = dataset.to(device)
        hetero = dataset_name == 'ogbn-mag'
        mask = ('paper', None) if dataset_name == 'ogbn-mag' else None
        degree = None

        if dataset_name == 'ogbn-mag':
            inputs_channels = data['paper'].num_features
        else:
            inputs_channels = dataset.num_features

        for model_name in args.models:
            if model_name not in supported_sets[dataset_name]:
                print(f'Configuration of {dataset_name} + {model_name} '
                      f'not supported. Skipping.')
                continue
            print(f'Evaluation bench for {model_name}:')

            for batch_size in args.eval_batch_sizes:
                if not hetero:
                    subgraph_loader = NeighborLoader(
                        data,
                        num_neighbors=[-1],  # layer-wise inference
                        input_nodes=mask,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=args.num_workers,
                    )

                for layers in args.num_layers:
                    if hetero:
                        subgraph_loader = NeighborLoader(
                            data,
                            # Batch-wise inference:
                            num_neighbors=[args.hetero_num_neighbors] * layers,
                            input_nodes=mask,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=args.num_workers,
                        )

                    for hidden_channels in args.num_hidden_channels:
                        print('----------------------------------------------')
                        print(
                            f'Batch size={batch_size}, '
                            f'Layers amount={layers}, '
                            f'Num_neighbors={subgraph_loader.num_neighbors}, '
                            f'Hidden features size={hidden_channels}')
                        params = {
                            'inputs_channels': inputs_channels,
                            'hidden_channels': hidden_channels,
                            'output_channels': num_classes,
                            'num_heads': args.num_heads,
                            'num_layers': layers,
                        }

                        if model_name == 'pna':
                            if degree is None:
                                degree = PNAConv.get_degree_histogram(
                                    subgraph_loader)
                                print(f'Calculated degree for {dataset_name}.')
                            params['degree'] = degree

                        model = get_model(
                            model_name, params,
                            metadata=data.metadata() if hetero else None)
                        model = model.to(device)
                        model.eval()

                        for _ in range(args.warmup):
                            model.inference(subgraph_loader, device,
                                            progress_bar=True)
                        if args.experimental_mode:
                            with torch_geometric.experimental_mode():
                                with timeit():
                                    model.inference(subgraph_loader, device,
                                                    progress_bar=True)
                        else:
                            with timeit():
                                model.inference(subgraph_loader, device,
                                                progress_bar=True)

                        if args.profile:
                            with torch_profile():
                                model.inference(subgraph_loader, device,
                                                progress_bar=True)
                            rename_profile_file(
                                model_name, dataset_name, str(batch_size),
                                str(layers), str(hidden_channels),
                                str(subgraph_loader.num_neighbors))
def test_custom_neighbor_loader(FeatureStore, GraphStore):
    # Initialize feature store, graph store, and reference:
    feature_store = FeatureStore()
    graph_store = GraphStore()
    data = HeteroData()

    # Set up node features:
    x = torch.arange(100)
    data['paper'].x = x
    feature_store.put_tensor(x, group_name='paper', attr_name='x', index=None)

    x = torch.arange(100, 300)
    data['author'].x = x
    feature_store.put_tensor(x, group_name='author', attr_name='x', index=None)

    # Set up edge indices:

    # COO:
    edge_index = get_edge_index(100, 100, 500)
    data['paper', 'to', 'paper'].edge_index = edge_index
    coo = (edge_index[0], edge_index[1])
    graph_store.put_edge_index(edge_index=coo,
                               edge_type=('paper', 'to', 'paper'),
                               layout='coo',
                               size=(100, 100))

    # CSR:
    edge_index = get_edge_index(100, 200, 1000)
    data['paper', 'to', 'author'].edge_index = edge_index
    csr = SparseTensor.from_edge_index(edge_index).csr()[:2]
    graph_store.put_edge_index(edge_index=csr,
                               edge_type=('paper', 'to', 'author'),
                               layout='csr',
                               size=(100, 200))

    # CSC:
    edge_index = get_edge_index(200, 100, 1000)
    data['author', 'to', 'paper'].edge_index = edge_index
    csc = SparseTensor(row=edge_index[1], col=edge_index[0]).csr()[-2::-1]
    graph_store.put_edge_index(edge_index=csc,
                               edge_type=('author', 'to', 'paper'),
                               layout='csc',
                               size=(200, 100))

    # COO (sorted):
    edge_index = get_edge_index(200, 200, 100)
    edge_index = edge_index[:, edge_index[1].argsort()]
    data['author', 'to', 'author'].edge_index = edge_index
    coo = (edge_index[0], edge_index[1])
    graph_store.put_edge_index(edge_index=coo,
                               edge_type=('author', 'to', 'author'),
                               layout='coo',
                               size=(200, 200),
                               is_sorted=True)

    # Construct neighbor loaders:
    loader1 = NeighborLoader(data,
                             batch_size=20,
                             input_nodes=('paper', range(100)),
                             num_neighbors=[-1] * 2)

    loader2 = NeighborLoader((feature_store, graph_store),
                             batch_size=20,
                             input_nodes=('paper', range(100)),
                             num_neighbors=[-1] * 2)

    assert str(loader1) == str(loader2)
    assert len(loader1) == len(loader2)

    for batch1, batch2 in zip(loader1, loader2):
        assert len(batch1) == len(batch2)
        assert batch1['paper'].batch_size == batch2['paper'].batch_size

        # Mapped indices of neighbors may be differently sorted:
        assert torch.allclose(batch1['paper'].x.sort()[0],
                              batch2['paper'].x.sort()[0])
        assert torch.allclose(batch1['author'].x.sort()[0],
                              batch2['author'].x.sort()[0])

        assert (batch1['paper', 'to', 'paper'].edge_index.size() == batch2[
            'paper', 'to', 'paper'].edge_index.size())
        assert (batch1['paper', 'to', 'author'].edge_index.size() == batch2[
            'paper', 'to', 'author'].edge_index.size())
        assert (batch1['author', 'to', 'paper'].edge_index.size() == batch2[
            'author', 'to', 'paper'].edge_index.size())
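
# A quick reading of the CSC trick used above (our interpretation, not
# documented API behavior): swapping row and col builds the transpose, whose
# CSR triple (rowptr, col, value) equals the original's (colptr, row, value);
# the [-2::-1] slice then selects (row, colptr), the pair passed for
# layout='csc'. Verified on a tiny graph:
toy_ei = torch.tensor([[0, 2, 1],
                       [1, 0, 1]])
toy_row, toy_colptr = SparseTensor(row=toy_ei[1], col=toy_ei[0]).csr()[-2::-1]
# toy_colptr == tensor([0, 1, 3]); toy_row == tensor([2, 0, 1]).
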
def run(args: argparse.ArgumentParser) -> None:
    for dataset_name in args.datasets:
        print(f"Dataset: {dataset_name}")
        root = osp.join(args.root, dataset_name)

        if dataset_name == 'mag':
            transform = T.ToUndirected(merge=True)
            dataset = OGB_MAG(root=root, transform=transform)
            train_idx = ('paper', dataset[0]['paper'].train_mask)
            eval_idx = ('paper', None)
            neighbor_sizes = args.hetero_neighbor_sizes
        else:
            dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root)
            split_idx = dataset.get_idx_split()
            train_idx = split_idx['train']
            eval_idx = None
            neighbor_sizes = args.homo_neighbor_sizes

        data = dataset[0].to(args.device)

        for num_neighbors in neighbor_sizes:
            print(f'Training sampling with {num_neighbors} neighbors')
            for batch_size in args.batch_sizes:
                train_loader = NeighborLoader(
                    data,
                    num_neighbors=num_neighbors,
                    input_nodes=train_idx,
                    batch_size=batch_size,
                    shuffle=True,
                    num_workers=args.num_workers,
                )
                runtimes = []
                num_iterations = 0
                for _ in range(args.runs):
                    start = default_timer()
                    for batch in tqdm.tqdm(train_loader):
                        num_iterations += 1
                    stop = default_timer()
                    runtimes.append(round(stop - start, 3))
                average_time = round(sum(runtimes) / args.runs, 3)
                print(f'batch size={batch_size}, iterations={num_iterations}, '
                      f'runtimes={runtimes}, average runtime={average_time}')

        print('Evaluation sampling with all neighbors')
        for batch_size in args.eval_batch_sizes:
            subgraph_loader = NeighborLoader(
                data,
                num_neighbors=[-1],
                input_nodes=eval_idx,
                batch_size=batch_size,
                shuffle=False,
                num_workers=args.num_workers,
            )
            runtimes = []
            num_iterations = 0
            for _ in range(args.runs):
                start = default_timer()
                for batch in tqdm.tqdm(subgraph_loader):
                    num_iterations += 1
                stop = default_timer()
                runtimes.append(round(stop - start, 3))
            average_time = round(sum(runtimes) / args.runs, 3)
            print(f'batch size={batch_size}, iterations={num_iterations}, '
                  f'runtimes={runtimes}, average runtime={average_time}')