def test_to_hetero_with_basic_model():
    x_dict = {
        'paper': torch.randn(100, 16),
        'author': torch.randn(100, 16),
    }
    edge_index_dict = {
        ('paper', 'cites', 'paper'):
        torch.randint(100, (2, 200), dtype=torch.long),
        ('paper', 'written_by', 'author'):
        torch.randint(100, (2, 200), dtype=torch.long),
        ('author', 'writes', 'paper'):
        torch.randint(100, (2, 200), dtype=torch.long),
    }

    metadata = list(x_dict.keys()), list(edge_index_dict.keys())

    model = GraphSAGE((-1, -1), 32, num_layers=3)
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert isinstance(out, dict) and len(out) == 2

    model = GAT((-1, -1), 32, num_layers=3, add_self_loops=False)
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert isinstance(out, dict) and len(out) == 2
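
# A minimal sketch (not part of the test above): the same `metadata` tuple can
# also be obtained from a `HeteroData` object via `data.metadata()`; the node
# and edge type names below mirror the dictionaries used in the test.
import torch
from torch_geometric.data import HeteroData

data = HeteroData()
data['paper'].x = torch.randn(100, 16)
data['author'].x = torch.randn(100, 16)
data['paper', 'cites', 'paper'].edge_index = torch.randint(100, (2, 200))
data['paper', 'written_by', 'author'].edge_index = torch.randint(100, (2, 200))
data['author', 'writes', 'paper'].edge_index = torch.randint(100, (2, 200))

metadata = data.metadata()  # == (node_types, edge_types)
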
def test_to_hetero_and_rgcn_equal_output():
    torch.manual_seed(1234)

    # Run `RGCN`:
    x = torch.randn(10, 16)  # 6 paper nodes, 4 author nodes
    adj = (torch.rand(10, 10) > 0.5)
    adj[6:, 6:] = False
    edge_index = adj.nonzero(as_tuple=False).t().contiguous()
    row, col = edge_index

    # 0 = paper<->paper, 1 = paper->author, 2 = author->paper
    edge_type = torch.full((edge_index.size(1), ), -1, dtype=torch.long)
    edge_type[(row < 6) & (col < 6)] = 0
    edge_type[(row < 6) & (col >= 6)] = 1
    edge_type[(row >= 6) & (col < 6)] = 2
    assert edge_type.min() == 0

    conv = RGCNConv(16, 32, num_relations=3)
    out1 = conv(x, edge_index, edge_type)

    # Run `to_hetero`:
    x_dict = {
        'paper': x[:6],
        'author': x[6:],
    }
    edge_index_dict = {
        ('paper', '_', 'paper'):
        edge_index[:, edge_type == 0],
        ('paper', '_', 'author'):
        edge_index[:, edge_type == 1] - torch.tensor([[0], [6]]),
        ('author', '_', 'paper'):
        edge_index[:, edge_type == 2] - torch.tensor([[6], [0]]),
    }

    node_types, edge_types = list(x_dict.keys()), list(edge_index_dict.keys())

    adj_t_dict = {
        key: SparseTensor.from_edge_index(edge_index).t()
        for key, edge_index in edge_index_dict.items()
    }

    model = to_hetero(RGCN(16, 32), (node_types, edge_types))

    # Set model weights:
    for i, edge_type in enumerate(edge_types):
        weight = model.conv['__'.join(edge_type)].lin.weight
        weight.data = conv.weight[i].data.t()
    for i, node_type in enumerate(node_types):
        model.lin[node_type].weight.data = conv.root.data.t()
        model.lin[node_type].bias.data = conv.bias.data

    out2 = model(x_dict, edge_index_dict)
    out2 = torch.cat([out2['paper'], out2['author']], dim=0)
    assert torch.allclose(out1, out2, atol=1e-6)

    out3 = model(x_dict, adj_t_dict)
    out3 = torch.cat([out3['paper'], out3['author']], dim=0)
    assert torch.allclose(out1, out3, atol=1e-6)
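
# Sketch of the index remapping used above (an illustrative helper, not part
# of the test): global node ids 0..5 are 'paper' and 6..9 are 'author', so the
# per-type edge indices are obtained by subtracting each endpoint's offset.
import torch

offsets = {'paper': 0, 'author': 6}


def to_local(edge_index, src_type, dst_type):
    # Shift global row/col indices into the local index space of each type.
    shift = torch.tensor([[offsets[src_type]], [offsets[dst_type]]])
    return edge_index - shift

# e.g. `to_local(edge_index[:, edge_type == 1], 'paper', 'author')` reproduces
# the ('paper', '_', 'author') entry of `edge_index_dict` above.
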
@pytest.mark.parametrize('directed', [True, False])
def test_heterogeneous_neighbor_loader_on_cora(directed):
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dataset = Planetoid(root, 'Cora')
    data = dataset[0]
    data.edge_weight = torch.rand(data.num_edges)

    hetero_data = HeteroData()
    hetero_data['paper'].x = data.x
    hetero_data['paper'].n_id = torch.arange(data.num_nodes)
    hetero_data['paper', 'paper'].edge_index = data.edge_index
    hetero_data['paper', 'paper'].edge_weight = data.edge_weight

    split_idx = torch.arange(5, 8)

    loader = NeighborLoader(hetero_data,
                            num_neighbors=[-1, -1],
                            batch_size=split_idx.numel(),
                            input_nodes=('paper', split_idx),
                            directed=directed)
    assert len(loader) == 1

    hetero_batch = next(iter(loader))
    batch_size = hetero_batch['paper'].batch_size

    if not directed:
        n_id, _, _, e_mask = k_hop_subgraph(split_idx,
                                            num_hops=2,
                                            edge_index=data.edge_index,
                                            num_nodes=data.num_nodes)

        n_id = n_id.sort()[0]
        assert n_id.tolist() == hetero_batch['paper'].n_id.sort()[0].tolist()
        assert hetero_batch['paper', 'paper'].num_edges == int(e_mask.sum())

    class GNN(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GraphConv(in_channels, hidden_channels)
            self.conv2 = GraphConv(hidden_channels, out_channels)

        def forward(self, x, edge_index, edge_weight):
            x = self.conv1(x, edge_index, edge_weight).relu()
            x = self.conv2(x, edge_index, edge_weight).relu()
            return x

    model = GNN(dataset.num_features, 16, dataset.num_classes)
    hetero_model = to_hetero(model, hetero_data.metadata())

    out1 = model(data.x, data.edge_index, data.edge_weight)[split_idx]
    out2 = hetero_model(hetero_batch.x_dict, hetero_batch.edge_index_dict,
                        hetero_batch.edge_weight_dict)['paper'][:batch_size]
    assert torch.allclose(out1, out2, atol=1e-6)

    try:
        shutil.rmtree(root)
    except PermissionError:
        pass
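
# Illustrative sketch (not part of the test): how `NeighborLoader` batches are
# typically consumed with a `to_hetero` model; the optimizer is assumed to
# exist and the objective below is only a placeholder.
def train_one_epoch(hetero_model, loader, optimizer):
    hetero_model.train()
    for batch in loader:
        optimizer.zero_grad()
        out = hetero_model(batch.x_dict, batch.edge_index_dict,
                           batch.edge_weight_dict)['paper']
        seed_out = out[:batch['paper'].batch_size]  # predictions for seed nodes
        loss = seed_out.pow(2).mean()  # placeholder objective
        loss.backward()
        optimizer.step()
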
Example #4
def test_hgt_loader_on_cora(get_dataset):
    dataset = get_dataset(name='Cora')
    data = dataset[0]
    data.edge_weight = torch.rand(data.num_edges)

    hetero_data = HeteroData()
    hetero_data['paper'].x = data.x
    hetero_data['paper'].n_id = torch.arange(data.num_nodes)
    hetero_data['paper', 'paper'].edge_index = data.edge_index
    hetero_data['paper', 'paper'].edge_weight = data.edge_weight

    split_idx = torch.arange(5, 8)

    # Sample the complete two-hop neighborhood:
    loader = HGTLoader(hetero_data,
                       num_samples=[data.num_nodes] * 2,
                       batch_size=split_idx.numel(),
                       input_nodes=('paper', split_idx))
    assert len(loader) == 1

    hetero_batch = next(iter(loader))
    batch_size = hetero_batch['paper'].batch_size

    n_id, _, _, e_mask = k_hop_subgraph(split_idx,
                                        num_hops=2,
                                        edge_index=data.edge_index,
                                        num_nodes=data.num_nodes)

    n_id = n_id.sort()[0]
    assert n_id.tolist() == hetero_batch['paper'].n_id.sort()[0].tolist()
    assert hetero_batch['paper', 'paper'].num_edges == int(e_mask.sum())

    class GNN(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GraphConv(in_channels, hidden_channels)
            self.conv2 = GraphConv(hidden_channels, out_channels)

        def forward(self, x, edge_index, edge_weight):
            x = self.conv1(x, edge_index, edge_weight).relu()
            x = self.conv2(x, edge_index, edge_weight).relu()
            return x

    model = GNN(dataset.num_features, 16, dataset.num_classes)
    hetero_model = to_hetero(model, hetero_data.metadata())

    out1 = model(data.x, data.edge_index, data.edge_weight)[split_idx]
    out2 = hetero_model(hetero_batch.x_dict, hetero_batch.edge_index_dict,
                        hetero_batch.edge_weight_dict)['paper'][:batch_size]
    assert torch.allclose(out1, out2, atol=1e-6)
Example #5
def test_to_hetero_with_gcn():
    metadata = (['paper'], [('paper', '0', 'paper'), ('paper', '1', 'paper')])
    x_dict = {'paper': torch.randn(100, 16)}
    edge_index_dict = {
        ('paper', '0', 'paper'): torch.randint(100, (2, 200)),
        ('paper', '1', 'paper'): torch.randint(100, (2, 200)),
    }

    model = GCN()
    model = to_hetero(model, metadata, debug=False)
    print(model)
    out = model(x_dict, edge_index_dict)
    assert isinstance(out, dict) and len(out) == 1
    assert out['paper'].size() == (100, 64)
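
# Hypothetical base model consistent with the asserted (100, 64) output above;
# the actual `GCN` class is defined elsewhere in the test file and may differ.
# Since every edge type here connects 'paper' to 'paper', plain `GCNConv`
# layers can be converted by `to_hetero` without bipartite support.
import torch
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(16, 32)
        self.conv2 = GCNConv(32, 64)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)
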
Example #6
    def __init__(
        self,
        metadata: Tuple[List[NodeType], List[EdgeType]],
        hidden_channels: int,
        out_channels: int,
        dropout: float,
    ):
        super().__init__()
        self.save_hyperparameters()

        model = GNN(hidden_channels, out_channels, dropout)
        # Convert the homogeneous GNN model to a heterogeneous variant in
        # which distinct parameters are learned for each node and edge type.
        self.model = to_hetero(model, metadata, aggr='sum')

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()
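
    # Illustrative companion step (an assumption, not shown in the original
    # snippet): how the converted model and the `Accuracy` metrics above would
    # typically be used; assumes 'paper' nodes carry labels and a batch size
    # from a neighbor sampler, and that `F` is `torch.nn.functional`.
    def training_step(self, batch, batch_idx):
        batch_size = batch['paper'].batch_size
        out = self.model(batch.x_dict, batch.edge_index_dict)['paper']
        y_hat, y = out[:batch_size], batch['paper'].y[:batch_size]
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(y_hat.softmax(dim=-1), y)
        self.log('train_acc', self.train_acc, prog_bar=True)
        return loss

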
def test_graph_level_to_hetero():
    x_dict = {
        'paper': torch.randn(100, 16),
        'author': torch.randn(100, 16),
    }
    edge_index_dict = {
        ('paper', 'written_by', 'author'):
        torch.randint(100, (2, 200), dtype=torch.long),
        ('author', 'writes', 'paper'):
        torch.randint(100, (2, 200), dtype=torch.long),
    }
    batch_dict = {
        'paper': torch.zeros(100, dtype=torch.long),
        'author': torch.zeros(100, dtype=torch.long),
    }

    metadata = list(x_dict.keys()), list(edge_index_dict.keys())

    model = GraphLevelGNN()
    model = to_hetero(model, metadata, aggr='mean', debug=False)
    out = model(x_dict, edge_index_dict, batch_dict)
    assert out.size() == (1, 64)
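
# Hypothetical base model with the same call signature (node features, edge
# indices, batch vector); the actual `GraphLevelGNN` used above is defined
# elsewhere in the test file and may differ.
import torch
from torch_geometric.nn import SAGEConv, global_mean_pool


class GraphLevelGNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = SAGEConv(16, 64)

    def forward(self, x, edge_index, batch):
        x = self.conv(x, edge_index).relu()
        return global_mean_pool(x, batch)  # one 64-dim vector per graph
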
Example #8
def test_to_hetero():
    metadata = (['paper', 'author'], [('paper', 'cites', 'paper'),
                                      ('paper', 'written_by', 'author'),
                                      ('author', 'writes', 'paper')])

    x_dict = {'paper': torch.randn(100, 16), 'author': torch.randn(100, 16)}
    edge_index_dict = {
        ('paper', 'cites', 'paper'):
        torch.randint(100, (2, 200), dtype=torch.long),
        ('paper', 'written_by', 'author'):
        torch.randint(100, (2, 200), dtype=torch.long),
        ('author', 'writes', 'paper'):
        torch.randint(100, (2, 200), dtype=torch.long),
    }
    edge_attr_dict = {
        ('paper', 'cites', 'paper'): torch.randn(200, 8),
        ('paper', 'written_by', 'author'): torch.randn(200, 8),
        ('author', 'writes', 'paper'): torch.randn(200, 8),
    }

    model = Net1()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_attr_dict)
    assert isinstance(out, tuple) and len(out) == 2
    assert isinstance(out[0], dict) and len(out[0]) == 2
    assert out[0]['paper'].size() == (100, 32)
    assert out[0]['author'].size() == (100, 32)
    assert isinstance(out[1], dict) and len(out[1]) == 3
    assert out[1][('paper', 'cites', 'paper')].size() == (200, 16)
    assert out[1][('paper', 'written_by', 'author')].size() == (200, 16)
    assert out[1][('author', 'writes', 'paper')].size() == (200, 16)

    for aggr in ['sum', 'mean', 'min', 'max', 'mul']:
        model = Net2()
        model = to_hetero(model, metadata, aggr=aggr, debug=False)
        out = model(x_dict, edge_index_dict)
        assert isinstance(out, dict) and len(out) == 2
        assert out['paper'].size() == (100, 32)
        assert out['author'].size() == (100, 32)

    model = Net3()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict, edge_attr_dict)
    assert isinstance(out, dict) and len(out) == 2
    assert out['paper'].size() == (100, 32)
    assert out['author'].size() == (100, 32)

    model = Net4()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert isinstance(out, dict) and len(out) == 2
    assert out['paper'].size() == (100, 32)
    assert out['author'].size() == (100, 32)

    model = Net5(num_layers=2)
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert isinstance(out, dict) and len(out) == 2
    assert out['paper'].size() == (100, 16)
    assert out['author'].size() == (100, 16)

    model = Net6(num_layers=2)
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert isinstance(out, dict) and len(out) == 2
    assert out['paper'].size() == (100, 16)
    assert out['author'].size() == (100, 16)

    model = Net7()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert isinstance(out, dict) and len(out) == 2
    assert out['paper'].size() == (100, 32)
    assert out['author'].size() == (100, 32)

    model = Net8()
    model = to_hetero(model, metadata, debug=False)
    out = model({'paper': torch.randn(4, 8), 'author': torch.randn(8, 16)})
    assert isinstance(out, dict) and len(out) == 2
    assert out['paper'].size() == (4, 32)
    assert out['author'].size() == (8, 32)
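
# Hypothetical shape of a node-level-only model such as `Net8` above: a single
# lazily-initialized linear layer, which `to_hetero` duplicates per node type
# so that 'paper' (8 features) and 'author' (16 features) each learn their own
# weight matrix.
import torch
from torch_geometric.nn import Linear


class Net8(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = Linear(-1, 32)  # in_channels inferred on the first forward

    def forward(self, x):
        return self.lin(x)
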
Example #9
    def __init__(self, hidden_channels):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)


def test_hetero_transformer_self_loop_error():
    to_hetero(ModelLoops(), metadata=(['a'], [('a', 'to', 'a')]))
    with pytest.raises(ValueError, match="incorrect message passing"):
        to_hetero(ModelLoops(),
                  metadata=(['a', 'b'], [('a', 'to', 'b'), ('b', 'to', 'a')]))


def test_to_hetero():
    x_dict = {
        'paper': torch.randn(100, 16),
        'author': torch.randn(100, 16),
    }
    edge_index_dict = {
        ('paper', 'cites', 'paper'):
        torch.randint(100, (2, 200), dtype=torch.long),
        ('paper', 'written_by', 'author'):
        torch.randint(100, (2, 200), dtype=torch.long),
        ('author', 'writes', 'paper'):
        torch.randint(100, (2, 200), dtype=torch.long),
    }
    adj_t_dict = {}
    for edge_type, (row, col) in edge_index_dict.items():
        adj_t_dict[edge_type] = SparseTensor(row=col,
                                             col=row,
                                             sparse_sizes=(100, 100))
    edge_attr_dict = {
        ('paper', 'cites', 'paper'): torch.randn(200, 8),
        ('paper', 'written_by', 'author'): torch.randn(200, 8),
        ('author', 'writes', 'paper'): torch.randn(200, 8),
    }

    metadata = list(x_dict.keys()), list(edge_index_dict.keys())

    model = Net1()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_attr_dict)
    assert isinstance(out, tuple) and len(out) == 2
    assert isinstance(out[0], dict) and len(out[0]) == 2
    assert out[0]['paper'].size() == (100, 32)
    assert out[0]['author'].size() == (100, 32)
    assert isinstance(out[1], dict) and len(out[1]) == 3
    assert out[1][('paper', 'cites', 'paper')].size() == (200, 16)
    assert out[1][('paper', 'written_by', 'author')].size() == (200, 16)
    assert out[1][('author', 'writes', 'paper')].size() == (200, 16)
    assert sum(p.numel() for p in model.parameters()) == 1520

    for aggr in ['sum', 'mean', 'min', 'max', 'mul']:
        model = Net2()
        model = to_hetero(model, metadata, aggr=aggr, debug=False)
        assert sum(p.numel() for p in model.parameters()) == 5824

        out1 = model(x_dict, edge_index_dict)
        assert isinstance(out1, dict) and len(out1) == 2
        assert out1['paper'].size() == (100, 32)
        assert out1['author'].size() == (100, 32)

        out2 = model(x_dict, adj_t_dict)
        assert isinstance(out2, dict) and len(out2) == 2
        for node_type in x_dict.keys():
            assert torch.allclose(out1[node_type], out2[node_type], atol=1e-6)

    model = Net3()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict, edge_attr_dict)
    assert isinstance(out, dict) and len(out) == 2
    assert out['paper'].size() == (100, 32)
    assert out['author'].size() == (100, 32)

    model = Net4()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert isinstance(out, dict) and len(out) == 2
    assert out['paper'].size() == (100, 32)
    assert out['author'].size() == (100, 32)

    model = Net5(num_layers=2)
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert isinstance(out, dict) and len(out) == 2
    assert out['paper'].size() == (100, 16)
    assert out['author'].size() == (100, 16)

    model = Net6(num_layers=2)
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert isinstance(out, dict) and len(out) == 2
    assert out['paper'].size() == (100, 16)
    assert out['author'].size() == (100, 16)

    model = Net7()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert isinstance(out, dict) and len(out) == 2
    assert out['paper'].size() == (100, 32)
    assert out['author'].size() == (100, 32)

    model = Net8()
    model = to_hetero(model, metadata, debug=False)
    out = model({'paper': torch.randn(4, 8), 'author': torch.randn(8, 16)})
    assert isinstance(out, dict) and len(out) == 2
    assert out['paper'].size() == (4, 32)
    assert out['author'].size() == (8, 32)

    model = Net9()
    model = to_hetero(model, metadata, debug=False)
    out = model({'paper': torch.randn(4, 16), 'author': torch.randn(8, 16)})
    assert isinstance(out, dict) and len(out) == 2
    assert out['paper'].size() == (4, 16)
    assert out['author'].size() == (8, 16)

    model = Net10()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert isinstance(out, dict) and len(out) == 2
    assert out['paper'].size() == (100, 32)
    assert out['author'].size() == (100, 32)
    train_loader = HGTLoader(data,
                             num_samples=[1024] * 4,
                             shuffle=True,
                             input_nodes=train_input_nodes,
                             **kwargs)
    val_loader = HGTLoader(data,
                           num_samples=[1024] * 4,
                           input_nodes=val_input_nodes,
                           **kwargs)

model = Sequential('x, edge_index', [
    (SAGEConv((-1, -1), 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (SAGEConv((-1, -1), 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (Linear(-1, dataset.num_classes), 'x -> x'),
])
model = to_hetero(model, data.metadata(), aggr='sum').to(device)


@torch.no_grad()
def init_params():
    # Initialize lazy parameters via forwarding a single batch to the model:
    batch = next(iter(train_loader))
    batch = batch.to(device, 'edge_index')
    model(batch.x_dict, batch.edge_index_dict)
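

# Typical usage (an assumption based on the comment above): lazily-initialized
# parameters must be materialized by one forward pass before they are handed
# to the optimizer.
init_params()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)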


def train():
    model.train()

    total_examples = total_loss = 0
    for batch in tqdm(train_loader):
Example #13
    def __init__(self, metadata, hidden_channels, num_layers, output_channels):
        super().__init__()
        self.model = to_hetero(
            GraphSAGE((-1, -1), hidden_channels, num_layers, output_channels),
            metadata)
Example #14
    def __init__(self, metadata, hidden_channels, num_layers, output_channels,
                 num_heads):
        super().__init__()
        self.model = to_hetero(
            GAT((-1, -1), hidden_channels, num_layers, output_channels,
                add_self_loops=False, heads=num_heads), metadata)