コード例 #1
0
class Block(torch.nn.Module):
    """Two-layer dense GraphSAGE block with jumping-knowledge aggregation.

    Two ``DenseSAGEConv`` layers (each followed by ReLU) are applied in
    sequence; their outputs are combined with ``JumpingKnowledge(mode)`` and
    projected to ``out_channels`` by a final ``Linear`` layer.

    Args:
        in_channels: Size of each input node feature.
        hidden_channels: Output width of the first convolution.
        out_channels: Output width of the second convolution and the block.
        mode: Aggregation mode forwarded to ``JumpingKnowledge``. ``'cat'``
            (compared case-insensitively) concatenates both layer outputs;
            for any other mode the aggregated features are assumed to have
            ``out_channels`` width. NOTE(review): ``'max'`` stacks the layer
            outputs, so it likely needs ``hidden_channels == out_channels``
            — confirm before using it with differing widths.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, mode='cat'):
        # Zero-argument super() is bound to this class itself. The explicit
        # ``super(Block, self)`` form looks ``Block`` up at call time and
        # raises TypeError once that module-level name is rebound to another
        # class.
        super().__init__()

        self.conv1 = DenseSAGEConv(in_channels, hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, out_channels)
        self.jump = JumpingKnowledge(mode)
        # JumpingKnowledge normalises ``mode`` with ``lower()``, so compare
        # the same way. Otherwise e.g. 'CAT' would concatenate features while
        # the projection below expected only ``out_channels`` inputs.
        if mode.lower() == 'cat':
            # Concatenation yields hidden_channels + out_channels features.
            self.lin = Linear(hidden_channels + out_channels, out_channels)
        else:
            self.lin = Linear(out_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise both convolutions and the output projection."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, x, adj, mask=None, add_loop=True):
        """Apply both convolutions, aggregate them, and project.

        Args:
            x: Node features, ``[batch_size, num_nodes, in_channels]``.
            adj: Dense adjacency tensor passed to both convolutions
                (presumably ``[batch_size, num_nodes, num_nodes]`` — TODO
                confirm).
            mask: Optional node mask forwarded to both convolutions.
            add_loop: Forwarded to both convolutions.

        Returns:
            Tensor of shape ``[batch_size, num_nodes, out_channels]``.
        """
        # x1: [batch_size, num_nodes, hidden_channels]
        x1 = F.relu(self.conv1(x, adj, mask, add_loop))
        # x2: [batch_size, num_nodes, out_channels]
        x2 = F.relu(self.conv2(x1, adj, mask, add_loop))
        # [batch_size, num_nodes, out_channels]
        return self.lin(self.jump([x1, x2]))
コード例 #2
0
class Block_2hop(torch.nn.Module):
    """Two-layer dense GraphSAGE block with per-layer L2 normalisation.

    Each ``DenseSAGEConv`` output is passed through ReLU and then
    L2-normalised along the last (feature) axis.

    Args:
        in_channels: Size of each input node feature.
        hidden_channels: Output width of the first convolution.
        out_channels: Output width of the second convolution and the block.
        jp: When True, the two layer outputs are concatenated with
            ``JumpingKnowledge('cat')``, projected back to ``out_channels``
            and passed through ReLU. When False, the second layer's
            normalised output is returned directly.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, jp=False):
        # Zero-argument super() is bound to this class itself. The explicit
        # ``super(Block_2hop, self)`` form looks the name up at call time and
        # raises TypeError once ``Block_2hop`` is rebound to another class.
        super().__init__()

        self.conv1 = DenseSAGEConv(in_channels, hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, out_channels)
        self.jp = jp
        if self.jp:
            self.jump = JumpingKnowledge('cat')
            # Concatenation yields hidden_channels + out_channels features.
            self.lin = Linear(hidden_channels + out_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise the convolutions and, if present, the projection."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        if self.jp:
            self.lin.reset_parameters()

    def forward(self, x, adj, mask=None, add_loop=True):
        """Apply both normalised convolutions, optionally with jumping knowledge.

        Args:
            x: Node features passed to the first convolution.
            adj: Dense adjacency tensor passed to both convolutions.
            mask: Optional node mask forwarded to both convolutions.
            add_loop: Forwarded to both convolutions.

        Returns:
            Node embeddings with ``out_channels`` features.
        """
        x1 = F.relu(self.conv1(x, adj, mask, add_loop))
        x1 = F.normalize(x1, p=2, dim=-1)
        x2 = F.relu(self.conv2(x1, adj, mask, add_loop))
        x2 = F.normalize(x2, p=2, dim=-1)
        if self.jp:
            return F.relu(self.lin(self.jump([x1, x2])))
        return x2
コード例 #3
0
class Block_1hop(torch.nn.Module):
    """Single dense GraphSAGE layer followed by ReLU and L2 normalisation.

    Only 1-hop neighbours are aggregated, so there is exactly one layer and
    jumping knowledge never applies. ``hidden_channels`` and ``jp`` are
    accepted but ignored — presumably so the signature matches the 2-hop
    blocks.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, jp=False):
        super().__init__()
        # The lone convolution maps directly to the output width.
        self.conv1 = DenseSAGEConv(in_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise the convolution's parameters."""
        self.conv1.reset_parameters()

    def forward(self, x, adj, mask=None, add_loop=True):
        """Return the ReLU-activated, L2-normalised convolution output.

        Args:
            x: Node features passed to the convolution.
            adj: Dense adjacency tensor passed to the convolution.
            mask: Optional node mask forwarded to the convolution.
            add_loop: Forwarded to the convolution.
        """
        activated = F.relu(self.conv1(x, adj, mask, add_loop))
        return F.normalize(activated, p=2, dim=-1)
コード例 #4
0
class Block_2hop(torch.nn.Module):
    """Two stacked dense GraphSAGE layers, each followed by ReLU and L2 norm.

    Args:
        in_channels: Size of each input node feature.
        hidden_channels: Output width of the first convolution.
        out_channels: Output width of the second convolution and the block.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.conv1 = DenseSAGEConv(in_channels, hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise both convolutions, first layer first."""
        for layer in (self.conv1, self.conv2):
            layer.reset_parameters()

    def forward(self, x, adj, mask=None, add_loop=True):
        """Run both layers in order; each output is ReLU'd then L2-normalised.

        Args:
            x: Node features passed to the first convolution.
            adj: Dense adjacency tensor passed to both convolutions.
            mask: Optional node mask forwarded to both convolutions.
            add_loop: Forwarded to both convolutions.

        Returns:
            The second layer's normalised output.
        """
        h = x
        for layer in (self.conv1, self.conv2):
            h = F.normalize(F.relu(layer(h, adj, mask, add_loop)), p=2, dim=-1)
        return h
コード例 #5
0
class Block(torch.nn.Module):
    """Two dense GraphSAGE layers whose outputs are concatenated and projected.

    Args:
        in_channels: Size of each input node feature.
        hidden_channels: Output width of the first convolution.
        out_channels: Output width of the second convolution and the block.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        # Zero-argument super() is bound to this class itself. The explicit
        # ``super(Block, self)`` form looks ``Block`` up at call time and
        # raises TypeError once that module-level name is rebound to another
        # class.
        super().__init__()

        self.conv1 = DenseSAGEConv(in_channels, hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, out_channels)

        # Input width is the concatenation of both layer outputs.
        self.lin = torch.nn.Linear(hidden_channels + out_channels,
                                   out_channels)

    def reset_parameters(self):
        """Re-initialise both convolutions and the output projection."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, x, adj, mask=None, add_loop=True):
        """Apply both convolutions, concatenate along features, and project.

        Args:
            x: Node features passed to the first convolution.
            adj: Dense adjacency tensor passed to both convolutions.
            mask: Optional node mask forwarded to both convolutions.
            add_loop: Forwarded to both convolutions.

        Returns:
            Projection of ``cat([x1, x2], dim=-1)`` to ``out_channels``.
        """
        x1 = F.relu(self.conv1(x, adj, mask, add_loop))
        x2 = F.relu(self.conv2(x1, adj, mask, add_loop))
        return self.lin(torch.cat([x1, x2], dim=-1))
コード例 #6
0
class Block(torch.nn.Module):
    """Two-layer dense GraphSAGE block with jumping-knowledge aggregation.

    Two ``DenseSAGEConv`` layers (each followed by ReLU) are applied in
    sequence; their outputs are combined with ``JumpingKnowledge(mode)`` and
    projected to ``out_channels`` by a final ``Linear`` layer.

    Args:
        in_channels: Size of each input node feature.
        hidden_channels: Output width of the first convolution.
        out_channels: Output width of the second convolution and the block.
        mode: Aggregation mode forwarded to ``JumpingKnowledge``. ``'cat'``
            (compared case-insensitively) concatenates both layer outputs;
            for any other mode the aggregated features are assumed to have
            ``out_channels`` width. NOTE(review): ``'max'`` stacks the layer
            outputs, so it likely needs ``hidden_channels == out_channels``
            — confirm before using it with differing widths.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, mode='cat'):
        super().__init__()

        self.conv1 = DenseSAGEConv(in_channels, hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, out_channels)
        self.jump = JumpingKnowledge(mode)
        # JumpingKnowledge normalises ``mode`` with ``lower()``, so compare
        # the same way. Otherwise e.g. 'CAT' would concatenate features while
        # the projection below expected only ``out_channels`` inputs.
        if mode.lower() == 'cat':
            # Concatenation yields hidden_channels + out_channels features.
            self.lin = Linear(hidden_channels + out_channels, out_channels)
        else:
            self.lin = Linear(out_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise both convolutions and the output projection."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, x, adj, mask=None):
        """Apply both convolutions, aggregate them, and project.

        Args:
            x: Node features passed to the first convolution.
            adj: Dense adjacency tensor passed to both convolutions.
            mask: Optional node mask forwarded to both convolutions.

        Returns:
            Node embeddings with ``out_channels`` features.
        """
        x1 = F.relu(self.conv1(x, adj, mask))
        x2 = F.relu(self.conv2(x1, adj, mask))
        return self.lin(self.jump([x1, x2]))