def test_revgnn_multi_backward(num_groups):
    """Run repeated gradient passes through a ``GroupAddRev``-wrapped SAGEConv.

    The wrapper is built with ``num_bwd_passes=4``. Exactly four gradient
    computations are then run over the same graph:

    * two ``backward()`` calls;
    * two ``torch.autograd.grad`` calls.

    Every pass except the final one keeps the graph alive with
    ``retain_graph=True``.
    """
    features = torch.randn(4, 32)
    edges = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

    # Keep construction order (Linear before SAGEConv) so parameter
    # initialization draws from the RNG in the same sequence.
    projection = Linear(32, 32)
    group_dim = 32 // num_groups
    rev_conv = GroupAddRev(
        SAGEConv(group_dim, group_dim),
        num_groups=num_groups,
        num_bwd_passes=4,
    )

    hidden = projection(features)
    loss = rev_conv(hidden, edges).mean()

    # Two accumulating backward passes; each retains the graph for reuse.
    for _ in range(2):
        loss.backward(retain_graph=True)

    # Two functional gradient passes; only the last may free the graph.
    grad_inputs = [hidden, *rev_conv.parameters()]
    torch.autograd.grad(outputs=loss, inputs=grad_inputs, retain_graph=True)
    torch.autograd.grad(outputs=loss, inputs=grad_inputs)
class SAGEConvActor:
    """Hold a ``SAGEConv`` layer together with its own Adam optimizer.

    Args:
        in_channels: Forwarded to ``SAGEConv`` as its ``in_channels``.
        out_channels: Forwarded to ``SAGEConv`` as its ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        conv = SAGEConv(in_channels, out_channels)
        self.conv = conv
        # The optimizer covers only this layer's parameters; lr is fixed at 0.01.
        self.local_optim = Adam(conv.parameters(), lr=0.01)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj):
        """Apply the wrapped convolution to ``x`` with ``edge_index``."""
        return self.conv(x, edge_index)

    def zero_grad(self):
        """Clear gradients through the local optimizer."""
        self.local_optim.zero_grad()

    def step(self):
        """Take one optimizer step on the layer's parameters."""
        self.local_optim.step()