class Block(torch.nn.Module):
    """Two-layer dense GraphSAGE block with jumping-knowledge aggregation.

    Both ``DenseSAGEConv`` outputs go through ReLU, are combined by
    ``JumpingKnowledge(mode)`` and are then projected to ``out_channels`` by a
    linear layer.

    Args:
        in_channels: Size of each input node feature.
        hidden_channels: Output size of the first convolution.
        out_channels: Output size of the second convolution and of the block.
        mode: Jumping-knowledge mode, ``'cat'`` or ``'max'``
            (case-insensitive). ``'max'`` stacks the two layer outputs, so it
            requires ``hidden_channels == out_channels``. ``'lstm'`` is not
            supported because this block never supplies the ``channels`` /
            ``num_layers`` arguments ``JumpingKnowledge`` needs for it.

    Raises:
        ValueError: If ``mode`` is not ``'cat'``/``'max'``, or if ``mode`` is
            ``'max'`` and ``hidden_channels != out_channels``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, mode='cat'):
        super(Block, self).__init__()
        # JumpingKnowledge lower-cases its mode internally; normalise here too
        # so the Linear input width below matches the mode actually applied.
        mode = mode.lower()
        if mode not in ('cat', 'max'):
            raise ValueError(f"mode must be 'cat' or 'max', got {mode!r}")
        if mode == 'max' and hidden_channels != out_channels:
            raise ValueError(
                "mode='max' requires hidden_channels == out_channels "
                f"(got {hidden_channels} and {out_channels})")
        self.conv1 = DenseSAGEConv(in_channels, hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, out_channels)
        self.jump = JumpingKnowledge(mode)
        if mode == 'cat':
            # Concatenation widens the features to hidden + out channels.
            self.lin = Linear(hidden_channels + out_channels, out_channels)
        else:
            self.lin = Linear(out_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise both convolutions and the output projection."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.lin.reset_parameters()

    # x: [batch_size, num_nodes, in_channels]
    def forward(self, x, adj, mask=None, add_loop=True):
        """Apply both convolutions, combine them and project the result.

        ``mask`` and ``add_loop`` are forwarded unchanged to each
        ``DenseSAGEConv`` call.
        """
        # x1: [batch_size, num_nodes, hidden_channels]
        x1 = F.relu(self.conv1(x, adj, mask, add_loop))
        # x2: [batch_size, num_nodes, out_channels]
        x2 = F.relu(self.conv2(x1, adj, mask, add_loop))
        # [batch_size, num_nodes, out_channels]
        return self.lin(self.jump([x1, x2]))
class Block_2hop(torch.nn.Module):
    """Two-hop dense GraphSAGE block with optional jumping knowledge.

    Each convolution is followed by ReLU and L2 normalisation over the last
    dimension. With ``jp`` enabled, both normalised layer outputs are
    concatenated by ``JumpingKnowledge('cat')``, linearly projected to
    ``out_channels`` and passed through ReLU. Otherwise the second layer's
    output is returned directly.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, jp=False):
        super(Block_2hop, self).__init__()
        self.conv1 = DenseSAGEConv(in_channels, hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, out_channels)
        self.jp = jp
        if not self.jp:
            # No jumping-knowledge head is needed.
            return
        self.jump = JumpingKnowledge('cat')
        self.lin = Linear(hidden_channels + out_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise the convolutions and, if present, the JK projection."""
        for conv in (self.conv1, self.conv2):
            conv.reset_parameters()
        if self.jp:
            self.lin.reset_parameters()

    def forward(self, x, adj, mask=None, add_loop=True):
        """Run both hops and return either the last layer or the JK output."""
        per_layer = []
        h = x
        for conv in (self.conv1, self.conv2):
            h = F.normalize(F.relu(conv(h, adj, mask, add_loop)), p=2, dim=-1)
            per_layer.append(h)
        if self.jp:
            return F.relu(self.lin(self.jump(per_layer)))
        return h
class Block_1hop(torch.nn.Module):
    """Single-hop dense GraphSAGE block.

    Only 1-hop neighbours are aggregated, so jumping knowledge never applies.
    ``hidden_channels`` and ``jp`` are accepted but unused; the signature
    mirrors ``Block_2hop``'s.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, jp=False):
        super(Block_1hop, self).__init__()
        # A single layer maps straight from input to output width.
        self.conv1 = DenseSAGEConv(in_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise the only convolution."""
        self.conv1.reset_parameters()

    def forward(self, x, adj, mask=None, add_loop=True):
        """Convolve once, apply ReLU, then L2-normalise over the last dim."""
        aggregated = self.conv1(x, adj, mask, add_loop)
        return F.normalize(F.relu(aggregated), p=2, dim=-1)
class Block_2hop(torch.nn.Module):
    """Two-hop dense GraphSAGE block with optional jumping-knowledge head.

    Each convolution is followed by ReLU and L2 normalisation over the last
    dimension. By default the second layer's output is returned.

    NOTE(review): this class re-defines an earlier ``Block_2hop`` that accepts
    ``jp``. ``jp`` is accepted here as well, so callers written against either
    signature keep working. With ``jp=True`` both normalised layer outputs are
    concatenated along the last dimension, projected to ``out_channels`` and
    passed through ReLU, matching the earlier definition.

    Args:
        in_channels: Size of each input node feature.
        hidden_channels: Output size of the first convolution.
        out_channels: Output size of the second convolution and of the block.
        jp: If True, add the concat + linear + ReLU jumping-knowledge head.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, jp=False):
        super(Block_2hop, self).__init__()
        self.conv1 = DenseSAGEConv(in_channels, hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, out_channels)
        self.jp = jp
        if self.jp:
            # Concatenation widens features to hidden + out channels.
            self.lin = torch.nn.Linear(hidden_channels + out_channels,
                                       out_channels)

    def reset_parameters(self):
        """Re-initialise the convolutions and, if present, the JK projection."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        if self.jp:
            self.lin.reset_parameters()

    def forward(self, x, adj, mask=None, add_loop=True):
        """Run both hops; return the last layer, or the JK output if ``jp``."""
        x1 = F.relu(self.conv1(x, adj, mask, add_loop))
        x1 = F.normalize(x1, p=2, dim=-1)
        x2 = F.relu(self.conv2(x1, adj, mask, add_loop))
        x2 = F.normalize(x2, p=2, dim=-1)
        if self.jp:
            return F.relu(self.lin(torch.cat([x1, x2], dim=-1)))
        return x2
class Block(torch.nn.Module):
    """Two-layer dense GraphSAGE block with a concatenating skip projection.

    Both ReLU-activated convolution outputs are concatenated along the last
    dimension and linearly projected to ``out_channels``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super(Block, self).__init__()
        self.conv1 = DenseSAGEConv(in_channels, hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, out_channels)
        concat_width = hidden_channels + out_channels
        self.lin = torch.nn.Linear(concat_width, out_channels)

    def reset_parameters(self):
        """Re-initialise both convolutions and the output projection."""
        for module in (self.conv1, self.conv2, self.lin):
            module.reset_parameters()

    def forward(self, x, adj, mask=None, add_loop=True):
        """Apply both convolutions, concatenate them and project the result."""
        first = F.relu(self.conv1(x, adj, mask, add_loop))
        second = F.relu(self.conv2(first, adj, mask, add_loop))
        stacked = torch.cat((first, second), dim=-1)
        return self.lin(stacked)
class Block(torch.nn.Module):
    """Two-layer dense GraphSAGE block with jumping-knowledge aggregation.

    Both ``DenseSAGEConv`` outputs go through ReLU, are combined by
    ``JumpingKnowledge(mode)`` and are then projected to ``out_channels`` by a
    linear layer.

    Args:
        in_channels: Size of each input node feature.
        hidden_channels: Output size of the first convolution.
        out_channels: Output size of the second convolution and of the block.
        mode: Jumping-knowledge mode, ``'cat'`` or ``'max'``
            (case-insensitive). ``'max'`` stacks the two layer outputs, so it
            requires ``hidden_channels == out_channels``. ``'lstm'`` is not
            supported because this block never supplies the ``channels`` /
            ``num_layers`` arguments ``JumpingKnowledge`` needs for it.

    Raises:
        ValueError: If ``mode`` is not ``'cat'``/``'max'``, or if ``mode`` is
            ``'max'`` and ``hidden_channels != out_channels``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, mode='cat'):
        super().__init__()
        # JumpingKnowledge lower-cases its mode internally; normalise here too
        # so the Linear input width below matches the mode actually applied.
        mode = mode.lower()
        if mode not in ('cat', 'max'):
            raise ValueError(f"mode must be 'cat' or 'max', got {mode!r}")
        if mode == 'max' and hidden_channels != out_channels:
            raise ValueError(
                "mode='max' requires hidden_channels == out_channels "
                f"(got {hidden_channels} and {out_channels})")
        self.conv1 = DenseSAGEConv(in_channels, hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, out_channels)
        self.jump = JumpingKnowledge(mode)
        if mode == 'cat':
            # Concatenation widens the features to hidden + out channels.
            self.lin = Linear(hidden_channels + out_channels, out_channels)
        else:
            self.lin = Linear(out_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise both convolutions and the output projection."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, x, adj, mask=None):
        """Apply both convolutions, combine them and project the result.

        ``mask`` is forwarded unchanged to each ``DenseSAGEConv`` call.
        """
        x1 = F.relu(self.conv1(x, adj, mask))
        x2 = F.relu(self.conv2(x1, adj, mask))
        return self.lin(self.jump([x1, x2]))