def test_revgnn_forward_inverse(num_groups):
    """Check that ``GroupAddRev`` wrapping a ``SAGEConv`` is invertible.

    A ``SAGEConv`` operating on ``32 // num_groups`` channels is wrapped in
    ``GroupAddRev``. The test checks three things:

    - the module's string representation;
    - that the forward pass leaves the input tensor with empty storage;
    - that ``conv.inverse`` reconstructs the original input from the output,
      within an absolute tolerance of 1e-3.

    Args:
        num_groups: Number of channel groups. Presumably supplied by a
            ``pytest.mark.parametrize`` decorator outside this view and
            assumed to divide 32 evenly -- TODO confirm.
    """
    channels = 32
    group_channels = channels // num_groups

    x = torch.randn(4, channels)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

    lin = Linear(channels, channels)
    conv = SAGEConv(group_channels, group_channels)
    conv = GroupAddRev(conv, num_groups=num_groups)
    assert str(conv) == (f'GroupAddRev(SAGEConv({group_channels}, '
                         f'{group_channels}, aggr=mean), '
                         f'num_groups={num_groups})')

    h = lin(x)
    # Snapshot the input for the final comparison. Detach first so the copy
    # itself is not recorded in the autograd graph.
    h_o = h.detach().clone()

    out = conv(h, edge_index)
    # The forward pass must leave `h` with empty storage.
    # `Tensor.storage()` returns the deprecated TypedStorage (and warns) on
    # PyTorch >= 2.0, so use `untyped_storage()` when it exists. Both report
    # size 0 for an emptied storage.
    if hasattr(h, 'untyped_storage'):
        assert h.untyped_storage().size() == 0
    else:
        assert h.storage().size() == 0

    h_rev = conv.inverse(out, edge_index)
    assert torch.allclose(h_o, h_rev, atol=0.001)