def test_to_hetero_with_basic_model():
    """Convert stock `GraphSAGE` and `GAT` models with `to_hetero`.

    Each converted model is run once on a random paper/author graph and
    must return a dictionary holding one entry per node type.
    """
    node_feats = {
        'paper': torch.randn(100, 16),
        'author': torch.randn(100, 16),
    }
    relations = {
        ('paper', 'cites', 'paper'):
        torch.randint(100, (2, 200), dtype=torch.long),
        ('paper', 'written_by', 'author'):
        torch.randint(100, (2, 200), dtype=torch.long),
        ('author', 'writes', 'paper'):
        torch.randint(100, (2, 200), dtype=torch.long),
    }
    metadata = list(node_feats), list(relations)

    # Models are built lazily, one at a time, right before conversion.
    builders = (
        lambda: GraphSAGE((-1, -1), 32, num_layers=3),
        lambda: GAT((-1, -1), 32, num_layers=3, add_self_loops=False),
    )
    for build in builders:
        hetero_model = to_hetero(build(), metadata, debug=False)
        result = hetero_model(node_feats, relations)
        assert isinstance(result, dict)
        assert len(result) == 2
def test_to_hetero_and_rgcn_equal_output():
    """Compare `RGCNConv` against a `to_hetero`-converted `RGCN` model.

    Both models receive the same weights and the same graph (once as a
    typed homogeneous graph, once split into per-type dictionaries), so
    their outputs are expected to match up to `atol=1e-6`.
    """
    torch.manual_seed(1234)

    # Homogeneous reference: 6 paper nodes followed by 4 author nodes.
    feats = torch.randn(10, 16)
    dense_adj = torch.rand(10, 10) > 0.5
    dense_adj[6:, 6:] = False  # Drop all author -> author entries.
    edge_index = dense_adj.nonzero(as_tuple=False).t().contiguous()
    src, dst = edge_index

    # 0 = paper<->paper, 1 = paper->author, 2 = author->paper
    src_is_paper = src < 6
    dst_is_paper = dst < 6
    rel_id = torch.full((edge_index.size(1), ), -1, dtype=torch.long)
    rel_id[src_is_paper & dst_is_paper] = 0
    rel_id[src_is_paper & ~dst_is_paper] = 1
    rel_id[~src_is_paper & dst_is_paper] = 2
    assert rel_id.min() == 0

    conv = RGCNConv(16, 32, num_relations=3)
    expected = conv(feats, edge_index, rel_id)

    # Heterogeneous view of the same graph; author indices are shifted by
    # 6 so that they start at zero within their own node type.
    x_dict = {
        'paper': feats[:6],
        'author': feats[6:],
    }
    edge_index_dict = {
        ('paper', '_', 'paper'):
        edge_index[:, rel_id == 0],
        ('paper', '_', 'author'):
        edge_index[:, rel_id == 1] - torch.tensor([[0], [6]]),
        ('author', '_', 'paper'):
        edge_index[:, rel_id == 2] - torch.tensor([[6], [0]]),
    }
    node_types = list(x_dict)
    edge_types = list(edge_index_dict)
    adj_t_dict = {
        rel: SparseTensor.from_edge_index(index).t()
        for rel, index in edge_index_dict.items()
    }

    model = to_hetero(RGCN(16, 32), (node_types, edge_types))

    # Copy the reference weights into the converted model.
    for idx, rel in enumerate(edge_types):
        rel_weight = model.conv['__'.join(rel)].lin.weight
        rel_weight.data = conv.weight[idx].data.t()
    for node_type in node_types:
        root_lin = model.lin[node_type]
        root_lin.weight.data = conv.root.data.t()
        root_lin.bias.data = conv.bias.data

    def stacked(out_dict):
        # Restore the homogeneous node ordering: papers first, then authors.
        return torch.cat([out_dict['paper'], out_dict['author']], dim=0)

    from_edge_index = stacked(model(x_dict, edge_index_dict))
    assert torch.allclose(expected, from_edge_index, atol=1e-6)

    from_adj_t = stacked(model(x_dict, adj_t_dict))
    assert torch.allclose(expected, from_adj_t, atol=1e-6)
def test_heterogeneous_neighbor_loader_on_cora(directed):
    """Check `NeighborLoader` on Cora wrapped as a single-type `HeteroData`.

    A full two-hop neighborhood is sampled around three seed nodes. When
    `directed` is false, the sampled nodes and edge count are compared
    against `k_hop_subgraph`. In both modes, a `to_hetero`-converted GNN
    evaluated on the mini-batch must reproduce the homogeneous model's
    output on the seed nodes.

    Args:
        directed: Forwarded to `NeighborLoader(directed=...)`.
    """
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    try:
        dataset = Planetoid(root, 'Cora')
        data = dataset[0]
        data.edge_weight = torch.rand(data.num_edges)

        hetero_data = HeteroData()
        hetero_data['paper'].x = data.x
        hetero_data['paper'].n_id = torch.arange(data.num_nodes)
        hetero_data['paper', 'paper'].edge_index = data.edge_index
        hetero_data['paper', 'paper'].edge_weight = data.edge_weight

        split_idx = torch.arange(5, 8)

        loader = NeighborLoader(hetero_data, num_neighbors=[-1, -1],
                                batch_size=split_idx.numel(),
                                input_nodes=('paper', split_idx),
                                directed=directed)
        assert len(loader) == 1

        hetero_batch = next(iter(loader))
        batch_size = hetero_batch['paper'].batch_size

        if not directed:
            n_id, _, _, e_mask = k_hop_subgraph(split_idx, num_hops=2,
                                                edge_index=data.edge_index,
                                                num_nodes=data.num_nodes)

            n_id = n_id.sort()[0]
            sampled_n_id = hetero_batch['paper'].n_id.sort()[0]
            assert n_id.tolist() == sampled_n_id.tolist()
            num_sampled_edges = hetero_batch['paper', 'paper'].num_edges
            assert num_sampled_edges == int(e_mask.sum())

        class GNN(torch.nn.Module):
            def __init__(self, in_channels, hidden_channels, out_channels):
                super().__init__()
                self.conv1 = GraphConv(in_channels, hidden_channels)
                self.conv2 = GraphConv(hidden_channels, out_channels)

            def forward(self, x, edge_index, edge_weight):
                x = self.conv1(x, edge_index, edge_weight).relu()
                x = self.conv2(x, edge_index, edge_weight).relu()
                return x

        model = GNN(dataset.num_features, 16, dataset.num_classes)
        hetero_model = to_hetero(model, hetero_data.metadata())

        out1 = model(data.x, data.edge_index, data.edge_weight)[split_idx]
        out2 = hetero_model(hetero_batch.x_dict,
                            hetero_batch.edge_index_dict,
                            hetero_batch.edge_weight_dict)
        out2 = out2['paper'][:batch_size]
        assert torch.allclose(out1, out2, atol=1e-6)
    finally:
        # Clean up even when an assertion above fails; otherwise the
        # dataset directory would be left behind on every failing run.
        try:
            shutil.rmtree(root)
        except (FileNotFoundError, PermissionError):
            # Best-effort cleanup: the directory may never have been
            # created, or the OS may refuse to delete it.
            pass
def test_hgt_loader_on_cora(get_dataset):
    """Check `HGTLoader` on Cora wrapped as a single-type `HeteroData`.

    The loader is asked for the complete two-hop neighborhood of three
    seed nodes. The sampled nodes and edge count are compared against
    `k_hop_subgraph`, and a `to_hetero`-converted GNN evaluated on the
    mini-batch must match the homogeneous model on the seed nodes.

    Args:
        get_dataset: Callable invoked as `get_dataset(name='Cora')` to
            obtain the dataset.
    """
    class GNN(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GraphConv(in_channels, hidden_channels)
            self.conv2 = GraphConv(hidden_channels, out_channels)

        def forward(self, x, edge_index, edge_weight):
            hidden = self.conv1(x, edge_index, edge_weight).relu()
            return self.conv2(hidden, edge_index, edge_weight).relu()

    dataset = get_dataset(name='Cora')
    graph = dataset[0]
    graph.edge_weight = torch.rand(graph.num_edges)

    hetero_graph = HeteroData()
    papers = hetero_graph['paper']
    papers.x = graph.x
    papers.n_id = torch.arange(graph.num_nodes)
    citations = hetero_graph['paper', 'paper']
    citations.edge_index = graph.edge_index
    citations.edge_weight = graph.edge_weight

    seeds = torch.arange(5, 8)

    # Sample the complete two-hop neighborhood:
    loader = HGTLoader(
        hetero_graph,
        num_samples=[graph.num_nodes] * 2,
        batch_size=seeds.numel(),
        input_nodes=('paper', seeds),
    )
    assert len(loader) == 1

    batch = next(iter(loader))
    num_seeds = batch['paper'].batch_size

    expected_n_id, _, _, expected_e_mask = k_hop_subgraph(
        seeds, num_hops=2, edge_index=graph.edge_index,
        num_nodes=graph.num_nodes)
    expected_n_id = expected_n_id.sort()[0]
    sampled_n_id = batch['paper'].n_id.sort()[0]
    assert expected_n_id.tolist() == sampled_n_id.tolist()
    assert batch['paper', 'paper'].num_edges == int(expected_e_mask.sum())

    homo_model = GNN(dataset.num_features, 16, dataset.num_classes)
    hetero_model = to_hetero(homo_model, hetero_graph.metadata())

    homo_out = homo_model(graph.x, graph.edge_index, graph.edge_weight)
    homo_out = homo_out[seeds]
    hetero_out = hetero_model(batch.x_dict, batch.edge_index_dict,
                              batch.edge_weight_dict)
    hetero_out = hetero_out['paper'][:num_seeds]
    assert torch.allclose(homo_out, hetero_out, atol=1e-6)
def test_to_hetero_with_gcn():
    """Convert `GCN` on a graph with one node type and two self-relations.

    The converted model must return a single-entry dictionary whose
    `'paper'` tensor has size `(100, 64)`.
    """
    first_rel = ('paper', '0', 'paper')
    second_rel = ('paper', '1', 'paper')
    metadata = (['paper'], [first_rel, second_rel])

    x_dict = {'paper': torch.randn(100, 16)}
    edge_index_dict = {
        first_rel: torch.randint(100, (2, 200)),
        second_rel: torch.randint(100, (2, 200)),
    }

    hetero_gcn = to_hetero(GCN(), metadata, debug=False)
    print(hetero_gcn)

    result = hetero_gcn(x_dict, edge_index_dict)
    assert isinstance(result, dict)
    assert len(result) == 1
    assert result['paper'].size() == (100, 64)
def __init__(
    self,
    metadata: Tuple[List[NodeType], List[EdgeType]],
    hidden_channels: int,
    out_channels: int,
    dropout: float,
):
    """Build the heterogeneous model and the per-stage accuracy metrics.

    Args:
        metadata: Node types and edge types, forwarded to `to_hetero`.
        hidden_channels: Forwarded to `GNN`.
        out_channels: Forwarded to `GNN`.
        dropout: Forwarded to `GNN`.
    """
    super().__init__()
    self.save_hyperparameters()

    # Convert the homogeneous GNN model to a heterogeneous variant in
    # which distinct parameters are learned for each node and edge type.
    homogeneous_gnn = GNN(hidden_channels, out_channels, dropout)
    self.model = to_hetero(homogeneous_gnn, metadata, aggr='sum')

    # One independent metric object per stage.
    for stage in ('train', 'val', 'test'):
        setattr(self, f'{stage}_acc', Accuracy())
def test_graph_level_to_hetero():
    """Convert `GraphLevelGNN` and check its output has size `(1, 64)`.

    Every node of both types is assigned to batch index zero.
    """
    num_nodes = 100
    x_dict = {
        'paper': torch.randn(num_nodes, 16),
        'author': torch.randn(num_nodes, 16),
    }
    edge_index_dict = {
        ('paper', 'written_by', 'author'):
        torch.randint(num_nodes, (2, 200), dtype=torch.long),
        ('author', 'writes', 'paper'):
        torch.randint(num_nodes, (2, 200), dtype=torch.long),
    }
    batch_dict = {
        node_type: torch.zeros(num_nodes, dtype=torch.long)
        for node_type in ('paper', 'author')
    }
    metadata = list(x_dict), list(edge_index_dict)

    hetero_model = to_hetero(GraphLevelGNN(), metadata, aggr='mean',
                             debug=False)
    pooled = hetero_model(x_dict, edge_index_dict, batch_dict)
    assert pooled.size() == (1, 64)
def test_to_hetero():
    """Run `to_hetero` over the `Net1` ... `Net8` fixtures.

    Each converted model is evaluated on a random paper/author graph and
    the type and size of its outputs are checked. `Net2` is converted once
    for every supported cross-relation aggregation.
    """
    metadata = (['paper', 'author'], [('paper', 'cites', 'paper'),
                                      ('paper', 'written_by', 'author'),
                                      ('author', 'writes', 'paper')])

    x_dict = {'paper': torch.randn(100, 16), 'author': torch.randn(100, 16)}
    edge_index_dict = {
        ('paper', 'cites', 'paper'):
        torch.randint(100, (2, 200), dtype=torch.long),
        ('paper', 'written_by', 'author'):
        torch.randint(100, (2, 200), dtype=torch.long),
        ('author', 'writes', 'paper'):
        torch.randint(100, (2, 200), dtype=torch.long),
    }
    edge_attr_dict = {
        ('paper', 'cites', 'paper'): torch.randn(200, 8),
        ('paper', 'written_by', 'author'): torch.randn(200, 8),
        ('author', 'writes', 'paper'): torch.randn(200, 8),
    }

    def assert_node_outputs(out, paper_size, author_size):
        # Shared check: one output tensor per node type, with given sizes.
        assert isinstance(out, dict) and len(out) == 2
        assert out['paper'].size() == paper_size
        assert out['author'].size() == author_size

    # `Net1` returns a (node outputs, edge outputs) pair.
    model = Net1()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_attr_dict)
    assert isinstance(out, tuple) and len(out) == 2
    assert_node_outputs(out[0], (100, 32), (100, 32))
    assert isinstance(out[1], dict) and len(out[1]) == 3
    assert out[1][('paper', 'cites', 'paper')].size() == (200, 16)
    assert out[1][('paper', 'written_by', 'author')].size() == (200, 16)
    assert out[1][('author', 'writes', 'paper')].size() == (200, 16)

    for aggr in ['sum', 'mean', 'min', 'max', 'mul']:
        model = Net2()
        # Pass the loop variable through; a hard-coded 'mean' here would
        # leave the remaining aggregations untested.
        model = to_hetero(model, metadata, aggr=aggr, debug=False)
        out = model(x_dict, edge_index_dict)
        assert_node_outputs(out, (100, 32), (100, 32))

    model = Net3()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict, edge_attr_dict)
    assert_node_outputs(out, (100, 32), (100, 32))

    model = Net4()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert_node_outputs(out, (100, 32), (100, 32))

    model = Net5(num_layers=2)
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert_node_outputs(out, (100, 16), (100, 16))

    model = Net6(num_layers=2)
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert_node_outputs(out, (100, 16), (100, 16))

    model = Net7()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert_node_outputs(out, (100, 32), (100, 32))

    # `Net8` is called with node features only, of differing input widths.
    model = Net8()
    model = to_hetero(model, metadata, debug=False)
    out = model({'paper': torch.randn(4, 8), 'author': torch.randn(8, 16)})
    assert_node_outputs(out, (4, 32), (8, 32))
def __init__(self, hidden_channels):
    """Create the heterogeneous encoder and the edge decoder.

    Args:
        hidden_channels: Width passed to both `GNNEncoder` arguments and
            to `EdgeDecoder`.
    """
    super().__init__()
    # The metadata comes from the `data` object of the enclosing scope.
    homogeneous_encoder = GNNEncoder(hidden_channels, hidden_channels)
    self.encoder = to_hetero(homogeneous_encoder, data.metadata(),
                             aggr='sum')
    self.decoder = EdgeDecoder(hidden_channels)
def test_hetero_transformer_self_loop_error():
    """Check how `to_hetero` treats `ModelLoops` for two metadata layouts.

    With a single node type and one `a -> a` relation the conversion must
    succeed; with two node types connected only across types it must
    raise a `ValueError` mentioning "incorrect message passing".
    """
    single_type = (['a'], [('a', 'to', 'a')])
    to_hetero(ModelLoops(), metadata=single_type)

    two_types = (['a', 'b'], [('a', 'to', 'b'), ('b', 'to', 'a')])
    with pytest.raises(ValueError, match="incorrect message passing"):
        to_hetero(ModelLoops(), metadata=two_types)
def test_to_hetero():
    """Run `to_hetero` over the `Net1` ... `Net10` fixtures.

    Each converted model is evaluated on a random paper/author graph and
    the type and size of its outputs are checked. `Net2` is converted once
    for every supported cross-relation aggregation, and its output on
    `edge_index` inputs is compared with its output on transposed
    `SparseTensor` inputs.
    """
    x_dict = {
        'paper': torch.randn(100, 16),
        'author': torch.randn(100, 16),
    }
    edge_index_dict = {
        ('paper', 'cites', 'paper'):
        torch.randint(100, (2, 200), dtype=torch.long),
        ('paper', 'written_by', 'author'):
        torch.randint(100, (2, 200), dtype=torch.long),
        ('author', 'writes', 'paper'):
        torch.randint(100, (2, 200), dtype=torch.long),
    }
    # Same connectivity as `edge_index_dict`, with rows and columns swapped.
    adj_t_dict = {
        edge_type: SparseTensor(row=col, col=row, sparse_sizes=(100, 100))
        for edge_type, (row, col) in edge_index_dict.items()
    }
    edge_attr_dict = {
        ('paper', 'cites', 'paper'): torch.randn(200, 8),
        ('paper', 'written_by', 'author'): torch.randn(200, 8),
        ('author', 'writes', 'paper'): torch.randn(200, 8),
    }

    metadata = list(x_dict.keys()), list(edge_index_dict.keys())

    def assert_node_outputs(out, paper_size, author_size):
        # Shared check: one output tensor per node type, with given sizes.
        assert isinstance(out, dict) and len(out) == 2
        assert out['paper'].size() == paper_size
        assert out['author'].size() == author_size

    # `Net1` returns a (node outputs, edge outputs) pair.
    model = Net1()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_attr_dict)
    assert isinstance(out, tuple) and len(out) == 2
    assert_node_outputs(out[0], (100, 32), (100, 32))
    assert isinstance(out[1], dict) and len(out[1]) == 3
    assert out[1][('paper', 'cites', 'paper')].size() == (200, 16)
    assert out[1][('paper', 'written_by', 'author')].size() == (200, 16)
    assert out[1][('author', 'writes', 'paper')].size() == (200, 16)
    assert sum(p.numel() for p in model.parameters()) == 1520

    for aggr in ['sum', 'mean', 'min', 'max', 'mul']:
        model = Net2()
        # Pass the loop variable through; a hard-coded 'mean' here would
        # leave the remaining aggregations untested.
        model = to_hetero(model, metadata, aggr=aggr, debug=False)
        assert sum(p.numel() for p in model.parameters()) == 5824

        out1 = model(x_dict, edge_index_dict)
        assert_node_outputs(out1, (100, 32), (100, 32))

        out2 = model(x_dict, adj_t_dict)
        assert isinstance(out2, dict) and len(out2) == 2
        for node_type in x_dict:
            assert torch.allclose(out1[node_type], out2[node_type],
                                  atol=1e-6)

    model = Net3()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict, edge_attr_dict)
    assert_node_outputs(out, (100, 32), (100, 32))

    model = Net4()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert_node_outputs(out, (100, 32), (100, 32))

    model = Net5(num_layers=2)
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert_node_outputs(out, (100, 16), (100, 16))

    model = Net6(num_layers=2)
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert_node_outputs(out, (100, 16), (100, 16))

    model = Net7()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert_node_outputs(out, (100, 32), (100, 32))

    # `Net8` and `Net9` are called with node features only.
    model = Net8()
    model = to_hetero(model, metadata, debug=False)
    out = model({'paper': torch.randn(4, 8), 'author': torch.randn(8, 16)})
    assert_node_outputs(out, (4, 32), (8, 32))

    model = Net9()
    model = to_hetero(model, metadata, debug=False)
    out = model({'paper': torch.randn(4, 16), 'author': torch.randn(8, 16)})
    assert_node_outputs(out, (4, 16), (8, 16))

    model = Net10()
    model = to_hetero(model, metadata, debug=False)
    out = model(x_dict, edge_index_dict)
    assert_node_outputs(out, (100, 32), (100, 32))
shuffle=True, input_nodes=train_input_nodes, **kwargs) val_loader = HGTLoader(data, num_samples=[1024] * 4, input_nodes=val_input_nodes, **kwargs) model = Sequential('x, edge_index', [ (SAGEConv((-1, -1), 64), 'x, edge_index -> x'), ReLU(inplace=True), (SAGEConv((-1, -1), 64), 'x, edge_index -> x'), ReLU(inplace=True), (Linear(-1, dataset.num_classes), 'x -> x'), ]) model = to_hetero(model, data.metadata(), aggr='sum').to(device) @torch.no_grad() def init_params(): # Initialize lazy parameters via forwarding a single batch to the model: batch = next(iter(train_loader)) batch = batch.to(device, 'edge_index') model(batch.x_dict, batch.edge_index_dict) def train(): model.train() total_examples = total_loss = 0 for batch in tqdm(train_loader):
def __init__(self, metadata, hidden_channels, num_layers, output_channels):
    """Wrap a `GraphSAGE` model converted with `to_hetero`.

    Args:
        metadata: Forwarded to `to_hetero`.
        hidden_channels: Second positional argument of `GraphSAGE`.
        num_layers: Third positional argument of `GraphSAGE`.
        output_channels: Fourth positional argument of `GraphSAGE`.
    """
    super().__init__()
    # `(-1, -1)` is passed as the input-channel argument.
    homogeneous_sage = GraphSAGE(
        (-1, -1),
        hidden_channels,
        num_layers,
        output_channels,
    )
    self.model = to_hetero(homogeneous_sage, metadata)
def __init__(self, metadata, hidden_channels, num_layers, output_channels,
             num_heads):
    """Wrap a `GAT` model converted with `to_hetero`.

    Args:
        metadata: Forwarded to `to_hetero`.
        hidden_channels: Second positional argument of `GAT`.
        num_layers: Third positional argument of `GAT`.
        output_channels: Fourth positional argument of `GAT`.
        num_heads: Passed to `GAT` as `heads`.
    """
    super().__init__()
    # `(-1, -1)` is passed as the input-channel argument; self-loops are
    # explicitly disabled.
    homogeneous_gat = GAT(
        (-1, -1),
        hidden_channels,
        num_layers,
        output_channels,
        add_self_loops=False,
        heads=num_heads,
    )
    self.model = to_hetero(homogeneous_gat, metadata)