def test_revgnn_forward_inverse(num_groups):
    """Check that ``GroupAddRev`` frees its input and can invert its output.

    A ``SAGEConv`` operating on ``32 // num_groups`` channels is wrapped in
    ``GroupAddRev`` with ``num_groups`` groups. The test then checks three
    things:

    * its ``repr`` string;
    * that after the forward pass the storage of the input tensor ``h`` has
      size zero;
    * that ``conv.inverse`` applied to the output reproduces the original
      input within ``atol=0.001``.

    Args:
        num_groups: Number of channel groups. It is assumed to divide 32
            evenly (presumably supplied by a parametrize decorator not shown
            here — confirm).
    """
    x = torch.randn(4, 32)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

    lin = Linear(32, 32)
    conv = SAGEConv(32 // num_groups, 32 // num_groups)
    conv = GroupAddRev(conv, num_groups=num_groups)
    assert str(conv) == (f'GroupAddRev(SAGEConv({32 // num_groups}, '
                         f'{32 // num_groups}, aggr=mean), '
                         f'num_groups={num_groups})')

    h = lin(x)
    # Snapshot the input before the forward pass, since the test expects
    # h's storage to be emptied by `conv(...)`. Detaching first keeps the
    # copy out of the autograd graph.
    h_o = h.detach().clone()

    out = conv(h, edge_index)

    # `Tensor.storage()` returns a deprecated TypedStorage (warns on
    # PyTorch >= 2.0). Prefer `untyped_storage()` when it exists. Its size is
    # in bytes rather than elements, which is irrelevant for a zero check.
    if hasattr(h, 'untyped_storage'):
        storage_size = h.untyped_storage().size()
    else:
        storage_size = h.storage().size()
    assert storage_size == 0

    h_rev = conv.inverse(out, edge_index)
    assert torch.allclose(h_o, h_rev, atol=0.001)