import pytest
import torch

from torch_geometric.nn import GroupAddRev, Linear, SAGEConv


@pytest.mark.parametrize('num_groups', [2, 4])  # must evenly divide the 32 channels
def test_revgnn_multi_backward(num_groups):
    x = torch.randn(4, 32)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

    lin = Linear(32, 32)
    # GroupAddRev splits the feature dimension into num_groups groups, so the
    # wrapped convolution operates on 32 // num_groups channels per group:
    conv = SAGEConv(32 // num_groups, 32 // num_groups)
    conv = GroupAddRev(conv, num_groups=num_groups, num_bwd_passes=4)

    h = lin(x)
    out = conv(h, edge_index)
    target = out.mean()

    # num_bwd_passes=4 keeps enough state to backpropagate through the
    # reversible block four times:
    target.backward(retain_graph=True)
    target.backward(retain_graph=True)
    torch.autograd.grad(outputs=target,
                        inputs=[h] + list(conv.parameters()),
                        retain_graph=True)
    torch.autograd.grad(outputs=target, inputs=[h] + list(conv.parameters()))
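For context, GroupAddRev exists to make deep GNNs memory-efficient: because each block is reversible, its input can be reconstructed from its output during the backward pass instead of being kept in memory. Below is a minimal sketch of stacking several reversible blocks; the depth and hidden size are illustrative assumptions, not part of the test above.

import torch
from torch_geometric.nn import GroupAddRev, Linear, SAGEConv

num_groups, hidden = 2, 32
lin = Linear(hidden, hidden)
convs = torch.nn.ModuleList([
    GroupAddRev(SAGEConv(hidden // num_groups, hidden // num_groups),
                num_groups=num_groups)
    for _ in range(4)  # four reversible blocks (illustrative depth)
])

x = torch.randn(4, hidden)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

h = lin(x)
for conv in convs:
    h = conv(h, edge_index)  # block inputs are recomputed on backward
h.mean().backward()  # a single backward pass needs no extra num_bwd_passes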
Example #2
from typing import Union

from torch import Tensor
from torch.optim import Adam

from torch_geometric.nn import SAGEConv
from torch_geometric.typing import Adj, OptPairTensor


class SAGEConvActor:
    # Bundles a single SAGEConv layer with its own local optimizer, so the
    # layer can be updated independently of the rest of the model.
    def __init__(self, in_channels, out_channels):
        self.conv = SAGEConv(in_channels, out_channels)
        self.local_optim = Adam(self.conv.parameters(), lr=0.01)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj):
        return self.conv(x, edge_index)

    def zero_grad(self):
        self.local_optim.zero_grad()

    def step(self):
        self.local_optim.step()
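A minimal sketch of driving one training step with this actor; the toy features, labels, and loss below are placeholder assumptions, not part of the original example.

import torch
import torch.nn.functional as F

# Hypothetical toy data, only for illustration:
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
y = torch.randn(4, 8)

actor = SAGEConvActor(in_channels=16, out_channels=8)

# One local optimization step: clear grads, forward, backward, update.
actor.zero_grad()
out = actor.forward(x, edge_index)
loss = F.mse_loss(out, y)
loss.backward()
actor.step()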