def test_hetero_conv(aggr):
    """Check HeteroConv with a different conv per relation and a given ``aggr``.

    ``aggr`` is supplied by the caller (presumably a pytest parametrization
    outside this view). With a reducing ``aggr`` each node type gets a single
    ``(num_nodes, 64)`` output; with ``aggr=None`` the per-relation outputs
    are stacked along a new dimension instead.
    """
    data = HeteroData()
    data['paper'].x = torch.randn(50, 32)
    data['author'].x = torch.randn(30, 64)
    data['paper', 'paper'].edge_index = get_edge_index(50, 50, 200)
    data['paper', 'author'].edge_index = get_edge_index(50, 30, 100)
    data['author', 'paper'].edge_index = get_edge_index(30, 50, 100)
    data['paper', 'paper'].edge_weight = torch.rand(200)

    conv = HeteroConv(
        {
            ('paper', 'to', 'paper'): GCNConv(-1, 64),
            ('author', 'to', 'paper'): SAGEConv((-1, -1), 64),
            # 'paper' -> 'author' connects two different node types, so
            # self-loops would pair unrelated nodes. HeteroConv rejects
            # modules with add_self_loops=True on such relations.
            ('paper', 'to', 'author'): GATConv((-1, -1), 64,
                                               add_self_loops=False),
        },
        aggr=aggr,
    )
    assert len(list(conv.parameters())) > 0
    assert str(conv) == 'HeteroConv(num_relations=3)'

    out = conv(data.x_dict, data.edge_index_dict,
               edge_weight_dict=data.edge_weight_dict)

    assert len(out) == 2
    if aggr is not None:
        assert out['paper'].size() == (50, 64)
        assert out['author'].size() == (30, 64)
    else:
        # 'paper' receives messages from two relations, 'author' from one.
        assert out['paper'].size() == (50, 2, 64)
        assert out['author'].size() == (30, 1, 64)
def test_hetero_conv_with_dot_syntax_node_types():
    """HeteroConv should work when a node-type name contains a '.' character."""
    paper = 'src.paper'

    data = HeteroData()
    data[paper].x = torch.randn(50, 32)
    data['author'].x = torch.randn(30, 64)
    edge_specs = (
        (paper, paper, 50, 50, 200),
        (paper, 'author', 50, 30, 100),
        ('author', paper, 30, 50, 100),
    )
    for src, dst, num_src, num_dst, num_edges in edge_specs:
        data[src, dst].edge_index = get_edge_index(num_src, num_dst, num_edges)
    data[paper, paper].edge_weight = torch.rand(200)

    relation_convs = {
        (paper, 'to', paper): GCNConv(-1, 64),
        ('author', 'to', paper): SAGEConv((-1, -1), 64),
        (paper, 'to', 'author'): GATConv((-1, -1), 64, add_self_loops=False),
    }
    conv = HeteroConv(relation_convs)
    assert len(list(conv.parameters())) > 0
    assert str(conv) == 'HeteroConv(num_relations=3)'

    out = conv(data.x_dict, data.edge_index_dict,
               edge_weight_dict=data.edge_weight_dict)
    assert len(out) == 2
    expected_shapes = {paper: (50, 64), 'author': (30, 64)}
    for node_type, shape in expected_shapes.items():
        assert out[node_type].size() == shape
def __init__(self, hidden_channels, outer_out_channels, inner_out_channels,
             num_layers, batch_size):
    """Build ``num_layers`` HeteroConv layers of NNConv over the x_i / x_j graph.

    Args:
        hidden_channels: Feature size used by every NNConv (stored as ``self.dim``).
        outer_out_channels: Output size of ``self.lin``.
        inner_out_channels: Output size of ``self.lin_i`` and ``self.lin_j``.
        num_layers: Number of HeteroConv layers appended to ``self.convs``.
        batch_size: Stored on ``self.batch_size``; not used in this method.
    """
    super().__init__()
    self.batch_size = batch_size
    self.dim = hidden_channels

    # Edge network. It takes 5 input features (presumably edge_attr has 5
    # columns; confirm) and outputs dim * dim values for a dim -> dim NNConv.
    # One instance is shared by every NNConv below, across all relations and
    # all layers, so these weights are tied.
    # NOTE(review): confirm that weight tying is intended.
    edge_nn = Sequential(Linear(5, 128), ReLU(),
                         Linear(128, self.dim * self.dim))

    def _nnconv():
        # One dim -> dim NNConv that uses the shared edge network.
        return NNConv(self.dim, self.dim, edge_nn, aggr='mean')

    self.convs = torch.nn.ModuleList()
    for _ in range(num_layers):
        # Each relation is listed exactly once. A repeated key in a dict
        # literal silently discards the earlier value, which would build
        # modules that are never used.
        conv = HeteroConv(
            {
                ('x_i', 'inner_edge_i', 'x_i'): _nnconv(),
                ('x_j', 'inner_edge_j', 'x_j'): _nnconv(),
                ('x_i', 'outer_edge_ij', 'x_j'): _nnconv(),
                ('x_j', 'outer_edge_ji', 'x_i'): _nnconv(),
            },
            aggr='sum',
        )
        self.convs.append(conv)

    self.lin = Linear(self.dim, outer_out_channels)
    self.lin_i = Linear(self.dim, inner_out_channels)
    self.lin_j = Linear(self.dim, inner_out_channels)
def __init__(self, hidden_channels, outer_out_channels, inner_out_channels,
             num_layers, batch_size, num_node_types, num_heads):
    """Build ``num_layers`` HeteroConv layers of GATv2Conv over the x_i / x_j graph.

    Args:
        hidden_channels: Per-head output size of every GATv2Conv.
        outer_out_channels: Output size of ``self.lin``.
        inner_out_channels: Output size of ``self.lin_i`` and ``self.lin_j``.
        num_layers: Number of HeteroConv layers appended to ``self.convs``.
        batch_size: Stored on ``self.batch_size``; not used in this method.
        num_node_types: Accepted for interface compatibility; not used here.
        num_heads: Number of attention heads (also stored as ``self.heads``).
    """
    super().__init__()
    self.batch_size = batch_size
    self.hidden_channels = hidden_channels
    self.heads = num_heads

    def _gat(bipartite):
        # Self-loops are only meaningful when source and destination are the
        # same node type. On x_i <-> x_j relations they would pair unrelated
        # nodes, and HeteroConv rejects add_self_loops=True there.
        if bipartite:
            return GATv2Conv(-1, self.hidden_channels, heads=num_heads,
                             add_self_loops=False)
        return GATv2Conv(-1, self.hidden_channels, heads=num_heads)

    self.convs = torch.nn.ModuleList()
    for _ in range(num_layers):
        # Each relation is listed exactly once. A repeated key in a dict
        # literal silently discards the earlier value.
        conv = HeteroConv(
            {
                ('x_i', 'inner_edge_i', 'x_i'): _gat(bipartite=False),
                ('x_j', 'inner_edge_j', 'x_j'): _gat(bipartite=False),
                ('x_i', 'outer_edge_ij', 'x_j'): _gat(bipartite=True),
                ('x_j', 'outer_edge_ji', 'x_i'): _gat(bipartite=True),
            },
            aggr='sum',
        )
        self.convs.append(conv)

    # NOTE(review): GATv2Conv is built without concat=False. If it uses PyG's
    # default of concatenating heads, it outputs hidden_channels * num_heads
    # features, while these layers expect hidden_channels. Presumably forward
    # reduces over heads (self.heads is stored); confirm.
    self.lin = Linear(self.hidden_channels, outer_out_channels)
    self.lin_i = Linear(self.hidden_channels, inner_out_channels)
    self.lin_j = Linear(self.hidden_channels, inner_out_channels)
def __init__(self, metadata, hidden_channels, out_channels, num_layers):
    """Create ``num_layers`` HeteroConv layers followed by a Linear head.

    Args:
        metadata: Only ``metadata[1]`` is read; presumably the edge-type list
            from ``HeteroData.metadata()``, i.e. (node_types, edge_types).
        hidden_channels: Output size of every SAGEConv, and input size of
            ``self.lin``.
        out_channels: Output size of ``self.lin``.
        num_layers: Number of HeteroConv layers in ``self.convs``.
    """
    super().__init__()

    def build_layer():
        # One SAGEConv per edge type. (-1, -1) is passed as the
        # (source, target) input sizes, which are presumably inferred lazily.
        per_relation = {
            relation: SAGEConv((-1, -1), hidden_channels)
            for relation in metadata[1]
        }
        return HeteroConv(per_relation)

    self.convs = torch.nn.ModuleList([build_layer() for _ in range(num_layers)])
    self.lin = Linear(hidden_channels, out_channels)
def test_hetero_conv_with_custom_conv():
    """Check HeteroConv wrapping a user-defined conv, called with an extra ``pos_dict``."""
    data = HeteroData()
    node_specs = (('paper', 50, 32), ('author', 30, 64))
    for node_type, num_nodes, num_features in node_specs:
        data[node_type].x = torch.randn(num_nodes, num_features)
        data[node_type].pos = torch.randn(num_nodes, 3)

    edge_specs = (
        ('paper', 'paper', 50, 50, 200),
        ('paper', 'author', 50, 30, 100),
        ('author', 'paper', 30, 50, 100),
    )
    for src, dst, num_src, num_dst, num_edges in edge_specs:
        data[src, dst].edge_index = get_edge_index(num_src, num_dst, num_edges)

    conv = HeteroConv({edge_type: CustomConv(64) for edge_type in data.edge_types})
    # pos_dict is passed as an extra positional argument after the edge indices.
    out = conv(data.x_dict, data.edge_index_dict, data.pos_dict)

    assert len(out) == 2
    for node_type, num_nodes, _ in node_specs:
        assert out[node_type].size() == (num_nodes, 64)
def _create_output_gate_parameters_and_layers(self):
    """Create the output-gate layers and parameters.

    Sets:
        self.conv_o: HeteroConv with one SAGEConv per edge type in
            ``self.metadata[1]``, each producing ``self.out_channels`` features.
        self.W_o: dict of node type -> ``(in_channels, out_channels)`` weight,
            built from ``self.in_channels_dict``.
        self.b_o: dict of node type -> ``(1, out_channels)`` bias.

    ``torch.nn.Module`` does not look inside plain dicts for parameters, so
    each tensor in ``W_o`` / ``b_o`` is also registered on the module. This
    puts it in ``parameters()`` (and therefore the optimizer), in
    ``state_dict()``, and under module conversions such as ``.to()``.
    Checkpoints saved without these registrations lack the matching keys
    (load them with ``strict=False``).
    """
    self.conv_o = HeteroConv({
        edge_type: SAGEConv(in_channels=(-1, -1),
                            out_channels=self.out_channels,
                            bias=self.bias)
        for edge_type in self.metadata[1]
    })

    # torch.Tensor(...) allocates uninitialised storage. Presumably these
    # tensors are initialised elsewhere in the class; confirm.
    self.W_o = {
        node_type: Parameter(torch.Tensor(in_channels, self.out_channels))
        for node_type, in_channels in self.in_channels_dict.items()
    }
    self.b_o = {
        node_type: Parameter(torch.Tensor(1, self.out_channels))
        for node_type in self.in_channels_dict
    }

    # Keep the dicts, since lookups by node type go through them, and also
    # register each tensor so the module owns it. Parameter names may not
    # contain '.', but node-type names may (e.g. 'src.paper'), so '.' is
    # replaced in the registered name only.
    # NOTE(review): other gates built the same way likely need the same fix.
    for prefix, params in (('W_o', self.W_o), ('b_o', self.b_o)):
        for node_type, param in params.items():
            safe_name = str(node_type).replace('.', '#')
            self.register_parameter(f'{prefix}__{safe_name}', param)
def test_hetero_conv_self_loop_error():
    """A self-loop-adding conv is accepted only on same-type relations."""
    # Source and destination share a type, so construction must succeed.
    HeteroConv({('a', 'to', 'a'): MessagePassingLoops()})

    # Source and destination differ ('a' -> 'b'), so construction must fail.
    bipartite_relation = ('a', 'to', 'b')
    with pytest.raises(ValueError, match="incorrect message passing"):
        HeteroConv({bipartite_relation: MessagePassingLoops()})