def test_hetero_conv(aggr):
    """Check ``HeteroConv`` over three relation types with the given ``aggr``.

    Builds a small graph with 50 ``paper`` and 30 ``author`` nodes, runs one
    heterogeneous convolution (forwarding per-relation edge weights), and
    checks the per-node-type output shapes.

    When ``aggr`` is ``None`` the per-relation outputs are expected to be
    stacked along dim 1 rather than reduced. ``paper`` receives two relations
    (paper->paper, author->paper) and ``author`` receives one (paper->author).
    """
    data = HeteroData()
    data['paper'].x = torch.randn(50, 32)
    data['author'].x = torch.randn(30, 64)
    data['paper', 'paper'].edge_index = get_edge_index(50, 50, 200)
    data['paper', 'author'].edge_index = get_edge_index(50, 30, 100)
    data['author', 'paper'].edge_index = get_edge_index(30, 50, 100)
    data['paper', 'paper'].edge_weight = torch.rand(200)

    conv = HeteroConv(
        {
            ('paper', 'to', 'paper'): GCNConv(-1, 64),
            ('author', 'to', 'paper'): SAGEConv((-1, -1), 64),
            # Source and destination node types differ on this edge type, so
            # self-loops must be disabled: adding i->i loops across two node
            # types is meaningless, and HeteroConv raises a ValueError for
            # self-loop-adding convs on such edges (see
            # test_hetero_conv_self_loop_error).
            ('paper', 'to', 'author'): GATConv((-1, -1), 64,
                                               add_self_loops=False),
        },
        aggr=aggr)

    assert len(list(conv.parameters())) > 0
    assert str(conv) == 'HeteroConv(num_relations=3)'

    out = conv(data.x_dict,
               data.edge_index_dict,
               edge_weight_dict=data.edge_weight_dict)

    assert len(out) == 2
    if aggr is not None:
        assert out['paper'].size() == (50, 64)
        assert out['author'].size() == (30, 64)
    else:
        assert out['paper'].size() == (50, 2, 64)
        assert out['author'].size() == (30, 1, 64)
# Example no. 2
def test_hetero_conv_with_dot_syntax_node_types():
    """Node-type names containing a ``.`` must work as HeteroConv keys.

    ``'src.paper'`` is used both as a self-relation and as the source and
    destination of cross relations with ``'author'``. The test checks that
    the convolution produces 64-channel outputs for both node types.
    """
    src, dst = 'src.paper', 'author'
    num_src, num_dst = 50, 30

    data = HeteroData()
    data[src].x = torch.randn(num_src, 32)
    data[dst].x = torch.randn(num_dst, 64)
    data[src, src].edge_index = get_edge_index(num_src, num_src, 200)
    data[src, dst].edge_index = get_edge_index(num_src, num_dst, 100)
    data[dst, src].edge_index = get_edge_index(num_dst, num_src, 100)
    data[src, src].edge_weight = torch.rand(200)

    relations = {
        (src, 'to', src): GCNConv(-1, 64),
        (dst, 'to', src): SAGEConv((-1, -1), 64),
        (src, 'to', dst): GATConv((-1, -1), 64, add_self_loops=False),
    }
    conv = HeteroConv(relations)

    assert len(list(conv.parameters())) > 0
    assert str(conv) == 'HeteroConv(num_relations=3)'

    out = conv(data.x_dict, data.edge_index_dict,
               edge_weight_dict=data.edge_weight_dict)

    assert len(out) == 2
    for node_type, num_nodes in ((src, num_src), (dst, num_dst)):
        assert out[node_type].size() == (num_nodes, 64)
# Example no. 3
    def __init__(self, hidden_channels, outer_out_channels, inner_out_channels,
                 num_layers, batch_size):
        """Build a stack of NNConv-based heterogeneous layers and output heads.

        Args:
            hidden_channels: Feature size used for every ``NNConv`` input and
                output; stored as ``self.dim``.
            outer_out_channels: Output size of ``self.lin``.
            inner_out_channels: Output size of ``self.lin_i`` and
                ``self.lin_j``.
            num_layers: Number of ``HeteroConv`` layers in ``self.convs``.
            batch_size: Stored on ``self.batch_size``; not used here.
        """
        super().__init__()

        self.batch_size = batch_size
        self.dim = hidden_channels

        # Edge network: maps 5-dimensional edge inputs to dim * dim values,
        # which NNConv uses as a per-edge weight matrix. Named edge_nn so it
        # does not shadow the conventional ``torch.nn`` alias.
        edge_nn = Sequential(Linear(5, 128), ReLU(),
                             Linear(128, self.dim * self.dim))

        # Each relation is listed exactly once. A repeated key in a dict
        # literal silently keeps only its last value, so duplicates would just
        # build NNConv modules that are immediately thrown away.
        edge_types = (
            ('x_i', 'inner_edge_i', 'x_i'),
            ('x_j', 'inner_edge_j', 'x_j'),
            ('x_i', 'outer_edge_ij', 'x_j'),
            ('x_j', 'outer_edge_ji', 'x_i'),
        )

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            # NOTE(review): the same ``edge_nn`` instance is shared by every
            # relation and every layer, tying their edge-network weights --
            # confirm this sharing is intended.
            conv = HeteroConv(
                {
                    edge_type: NNConv(self.dim, self.dim, edge_nn,
                                      aggr='mean')
                    for edge_type in edge_types
                },
                aggr='sum')
            self.convs.append(conv)

        self.lin = Linear(self.dim, outer_out_channels)

        self.lin_i = Linear(self.dim, inner_out_channels)
        self.lin_j = Linear(self.dim, inner_out_channels)
# Example no. 4
    def __init__(self, hidden_channels, outer_out_channels, inner_out_channels,
                 num_layers, batch_size, num_node_types, num_heads):
        """Build a stack of GATv2Conv-based heterogeneous layers and heads.

        Args:
            hidden_channels: Per-head output size of every ``GATv2Conv``.
            outer_out_channels: Output size of ``self.lin``.
            inner_out_channels: Output size of ``self.lin_i`` and
                ``self.lin_j``.
            num_layers: Number of ``HeteroConv`` layers in ``self.convs``.
            batch_size: Stored on ``self.batch_size``; not used here.
            num_node_types: Accepted for interface compatibility; not used
                here.
            num_heads: Attention heads per ``GATv2Conv``; stored as
                ``self.heads``.
        """
        super().__init__()

        self.batch_size = batch_size
        self.hidden_channels = hidden_channels
        self.heads = num_heads

        # Each relation is listed exactly once. A repeated key in a dict
        # literal silently keeps only its last value, so duplicates would just
        # build modules that are immediately thrown away.
        edge_types = (
            ('x_i', 'inner_edge_i', 'x_i'),
            ('x_j', 'inner_edge_j', 'x_j'),
            ('x_i', 'outer_edge_ij', 'x_j'),
            ('x_j', 'outer_edge_ji', 'x_i'),
        )

        # NOTE(review): with heads > 1 and GATv2Conv's default concat=True,
        # each conv emits hidden_channels * num_heads features, while
        # self.lin / self.lin_i / self.lin_j expect hidden_channels inputs --
        # confirm forward() reduces the heads.
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv(
                {
                    # Self-loops are only meaningful when the source and
                    # destination node types match. On cross-type edges they
                    # would link unrelated nodes, and HeteroConv rejects them.
                    edge_type: GATv2Conv(
                        -1, self.hidden_channels, heads=num_heads,
                        add_self_loops=edge_type[0] == edge_type[-1])
                    for edge_type in edge_types
                },
                aggr='sum')
            self.convs.append(conv)

        self.lin = Linear(self.hidden_channels, outer_out_channels)

        self.lin_i = Linear(self.hidden_channels, inner_out_channels)
        self.lin_j = Linear(self.hidden_channels, inner_out_channels)
    def __init__(self, metadata, hidden_channels, out_channels, num_layers):
        """Stack ``num_layers`` HeteroConv layers followed by a linear head.

        Args:
            metadata: Graph metadata; ``metadata[1]`` is iterated to obtain
                the edge types, each of which gets its own ``SAGEConv``.
            hidden_channels: Output size of every ``SAGEConv``.
            out_channels: Output size of the final ``Linear`` layer.
            num_layers: Number of heterogeneous convolution layers.
        """
        super().__init__()

        def make_layer():
            # One SAGEConv per edge type; (-1, -1) lets PyG infer the source
            # and destination input sizes lazily.
            return HeteroConv({
                et: SAGEConv((-1, -1), hidden_channels) for et in metadata[1]
            })

        self.convs = torch.nn.ModuleList(
            [make_layer() for _ in range(num_layers)])
        self.lin = Linear(hidden_channels, out_channels)
def test_hetero_conv_with_custom_conv():
    """HeteroConv with a custom conv per edge type and an extra input dict.

    ``data.pos_dict`` is passed as a third positional argument next to the
    feature and edge-index dicts. The test checks that each node type ends
    up with a 64-channel output.
    """
    num_nodes = {'paper': 50, 'author': 30}

    data = HeteroData()
    data['paper'].x = torch.randn(50, 32)
    data['paper'].pos = torch.randn(50, 3)
    data['author'].x = torch.randn(30, 64)
    data['author'].pos = torch.randn(30, 3)
    for src, dst, num_edges in (('paper', 'paper', 200),
                                ('paper', 'author', 100),
                                ('author', 'paper', 100)):
        data[src, dst].edge_index = get_edge_index(
            num_nodes[src], num_nodes[dst], num_edges)

    conv = HeteroConv({
        edge_type: CustomConv(64)
        for edge_type in data.edge_types
    })
    out = conv(data.x_dict, data.edge_index_dict, data.pos_dict)

    assert len(out) == 2
    for node_type, count in num_nodes.items():
        assert out[node_type].size() == (count, 64)
# Example no. 7
    def _create_output_gate_parameters_and_layers(self):
        """Create the output-gate convolution and per-node-type parameters.

        Sets:
            self.conv_o: ``HeteroConv`` with one ``SAGEConv`` (output size
                ``self.out_channels``, bias flag ``self.bias``) per edge type
                in ``self.metadata[1]``.
            self.W_o: Per-node-type weight of shape
                ``(in_channels, self.out_channels)``, keyed like
                ``self.in_channels_dict``.
            self.b_o: Per-node-type bias of shape ``(1, self.out_channels)``.
        """
        self.conv_o = HeteroConv({
            edge_type: SAGEConv(in_channels=(-1, -1),
                                out_channels=self.out_channels,
                                bias=self.bias)
            for edge_type in self.metadata[1]
        })

        # ParameterDict registers the tensors on the module. A plain dict of
        # Parameters is invisible to self.parameters() (and therefore to
        # optimizers), to state_dict(), and to .to(device).
        # NOTE(review): ParameterDict keys may not contain '.'; confirm node
        # type names never do.
        # torch.Tensor(...) allocates uninitialized memory; presumably the
        # values are initialized elsewhere -- confirm.
        self.W_o = torch.nn.ParameterDict({
            node_type: Parameter(torch.Tensor(in_channels, self.out_channels))
            for node_type, in_channels in self.in_channels_dict.items()
        })
        self.b_o = torch.nn.ParameterDict({
            node_type: Parameter(torch.Tensor(1, self.out_channels))
            for node_type in self.in_channels_dict
        })
# Example no. 8
def test_hetero_conv_self_loop_error():
    """HeteroConv accepts MessagePassingLoops only on a same-type edge.

    Construction must succeed for ``('a', 'to', 'a')``. For
    ``('a', 'to', 'b')`` it must raise a ``ValueError`` whose message
    matches "incorrect message passing".
    """
    same_type_edge = ('a', 'to', 'a')
    cross_type_edge = ('a', 'to', 'b')

    HeteroConv({same_type_edge: MessagePassingLoops()})

    with pytest.raises(ValueError, match="incorrect message passing"):
        HeteroConv({cross_type_edge: MessagePassingLoops()})