def test_custom_explain_message():
    """Check that a user-supplied ``explain_message`` hook is used by the layer.

    A custom hook is bound to a single ``SAGEConv`` instance and explain mode
    is switched on. After one forward pass, the tensors the hook captured are
    compared against the node features gathered by the second row
    (``x_i``) and the first row (``x_j``) of ``edge_index``.
    """
    feats = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
    layer = SAGEConv(8, 32)

    # NOTE: the parameter names are kept as-is on purpose; the layer may look
    # up the hook's arguments by name (confirm against MessagePassing).
    def explain_message(self, inputs, x_i, x_j):
        # The hook must receive the layer instance it was bound to.
        assert isinstance(self, SAGEConv)

        # One row per edge (6 edges), each with 8 input features.
        assert inputs.size() == (6, 8)
        assert inputs.size() == x_i.size()
        assert x_i.size() == x_j.size()

        # For this layer the raw messages are expected to equal `x_j`.
        assert torch.allclose(inputs, x_j)

        # Stash the per-edge features so the outer test can inspect them.
        self.x_i = x_i
        self.x_j = x_j
        return inputs

    # Bind the plain function to this one instance via the descriptor
    # protocol, so only `layer` gets the custom hook.
    layer.explain_message = explain_message.__get__(layer, MessagePassing)
    layer.explain = True

    layer(feats, edge_index)

    row_0, row_1 = edge_index[0], edge_index[1]
    assert torch.allclose(layer.x_i, feats[row_1])
    assert torch.allclose(layer.x_j, feats[row_0])