Beispiel #1
0
def test_custom_explain_message():
    """Check that a user-supplied ``explain_message`` hook is used in explain mode.

    A replacement ``explain_message`` is bound onto one ``SAGEConv`` instance
    and ``explain`` is switched on. The hook validates the message tensor it
    receives against the ``x_i`` / ``x_j`` arguments, then stores both on the
    layer. After the forward pass, the stored tensors are compared with a
    manual gather of the node features. If the hook was never called, reading
    ``layer.x_i`` raises ``AttributeError`` and the test fails.
    """
    feats = torch.randn(4, 8)
    # Six directed edges over 4 nodes: 0<->1, 1<->2, 2<->3.
    edges = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
    layer = SAGEConv(8, 32)

    def explain_message(self, inputs, x_i, x_j):
        # NOTE(review): keep the names x_i / x_j. Presumably MessagePassing
        # populates hook arguments by parameter name; confirm against PyG.
        assert isinstance(self, SAGEConv)
        expected_shape = (6, 8)  # one 8-dim message per edge
        assert inputs.size() == expected_shape
        assert inputs.size() == x_i.size() == x_j.size()
        # The incoming messages are expected to equal x_j.
        assert torch.allclose(inputs, x_j)
        # Stash the gathered features so the outer test can inspect them.
        self.x_i, self.x_j = x_i, x_j
        return inputs

    # Bind the plain function as a method of this specific layer instance.
    bound_hook = explain_message.__get__(layer, MessagePassing)
    layer.explain_message = bound_hook
    layer.explain = True

    layer(feats, edges)

    # Row 0 of the edge index is compared with x_j and row 1 with x_i.
    src, dst = edges
    assert torch.allclose(layer.x_i, feats[dst])
    assert torch.allclose(layer.x_j, feats[src])