Example #1
class Block(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, mode="cat"):
        super(Block, self).__init__()

        # self.conv1 = DenseSAGEConv(in_channels, hidden_channels)
        # self.conv2 = DenseSAGEConv(hidden_channels, out_channels)

        # self.conv1 = DenseGCNConv(in_channels, hidden_channels)
        # self.conv2 = DenseGCNConv(hidden_channels, out_channels)

        nn1 = torch.nn.Sequential(
            Linear(in_channels, hidden_channels),
            ReLU(),
            Linear(hidden_channels, hidden_channels),
        )

        nn2 = torch.nn.Sequential(
            Linear(hidden_channels, out_channels),
            ReLU(),
            Linear(out_channels, out_channels),
        )

        self.conv1 = DenseGINConv(nn1, train_eps=True)
        self.conv2 = DenseGINConv(nn2, train_eps=True)

        self.jump = JumpingKnowledge(mode)
        if mode == "cat":
            self.lin = Linear(hidden_channels + out_channels, out_channels)
        else:
            self.lin = Linear(out_channels, out_channels)
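For orientation, here is a minimal, self-contained sketch of how a DenseGINConv layer like the two built above is constructed and applied to a dense mini-batch (assuming torch and torch_geometric are installed; the shapes and sizes are illustrative, not taken from the example):

import torch
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import DenseGINConv

# GIN wraps an arbitrary MLP that transforms the aggregated node features.
mlp = Sequential(Linear(16, 32), ReLU(), Linear(32, 32))
conv = DenseGINConv(mlp, train_eps=True)

x = torch.randn(4, 10, 16)                      # [batch_size, num_nodes, in_channels]
adj = torch.randint(0, 2, (4, 10, 10)).float()  # dense adjacency per graph
out = conv(x, adj)                              # -> [4, 10, 32]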
Example #2
def test_dense_gin_conv():
    in_channels, out_channels = (16, 32)
    nn = Seq(Lin(in_channels, 32), ReLU(), Lin(32, out_channels))
    sparse_conv = GINConv(nn)
    dense_conv = DenseGINConv(nn)
    # Re-create with a trainable eps; this is the instance used below.
    dense_conv = DenseGINConv(nn, train_eps=True)
    assert dense_conv.__repr__() == (
        'DenseGINConv(nn=Sequential(\n'
        '  (0): Linear(in_features=16, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=32, bias=True)\n'
        '))')

    x = torch.randn((5, in_channels))
    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4],
                               [1, 2, 0, 2, 0, 1, 4, 3]])

    sparse_out = sparse_conv(x, edge_index)

    # Pad with one dummy node so the five nodes split into two graphs of
    # three nodes each, matching the dense [batch, num_nodes, channels] layout.
    x = torch.cat([x, x.new_zeros(1, in_channels)],
                  dim=0).view(2, 3, in_channels)
    adj = torch.Tensor([
        [[0, 1, 1], [1, 0, 1], [1, 1, 0]],
        [[0, 1, 0], [1, 0, 0], [0, 0, 0]],
    ])
    mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool)

    dense_out = dense_conv(x, adj, mask)
    assert dense_out.size() == (2, 3, out_channels)
    # Drop the padded dummy node before comparing against the sparse output.
    dense_out = dense_out.view(6, out_channels)[:-1]
    assert torch.allclose(sparse_out, dense_out, atol=1e-04)
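A related utility, assuming torch_geometric is available: torch_geometric.utils.to_dense_adj builds the dense adjacency tensor that DenseGINConv expects directly from an edge_index and a batch vector, instead of writing adj out by hand as in the test above.

import torch
from torch_geometric.utils import to_dense_adj

edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4],
                           [1, 2, 0, 2, 0, 1, 4, 3]])
batch = torch.tensor([0, 0, 0, 1, 1])   # nodes 0-2 form graph 0, nodes 3-4 form graph 1
adj = to_dense_adj(edge_index, batch)   # -> [2, 3, 3], zero-padded to the largest graph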
Example #3
class Block(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, mode='cat'):
        super(Block, self).__init__()

        # self.conv1 = DenseSAGEConv(in_channels, hidden_channels)
        # self.conv2 = DenseSAGEConv(hidden_channels, out_channels)

        # self.conv1 = DenseGCNConv(in_channels, hidden_channels)
        # self.conv2 = DenseGCNConv(hidden_channels, out_channels)

        nn1 = torch.nn.Sequential(Linear(in_channels, hidden_channels), ReLU(),
                                  Linear(hidden_channels, hidden_channels))

        nn2 = torch.nn.Sequential(Linear(hidden_channels, out_channels),
                                  ReLU(), Linear(out_channels, out_channels))

        self.conv1 = DenseGINConv(nn1, train_eps=True)
        self.conv2 = DenseGINConv(nn2, train_eps=True)

        self.jump = JumpingKnowledge(mode)
        if mode == 'cat':
            self.lin = Linear(hidden_channels + out_channels, out_channels)
        else:
            self.lin = Linear(out_channels, out_channels)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, x, adj, mask=None, add_loop=True):
        x1 = F.relu(self.conv1(x, adj, mask, add_loop))
        x2 = F.relu(self.conv2(x1, adj, mask, add_loop))
        return self.lin(self.jump([x1, x2]))
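The mode == 'cat' branch in the constructor matches what JumpingKnowledge('cat') does in forward: it concatenates the per-layer outputs along the feature dimension, which is why the final linear layer must accept hidden_channels + out_channels features. A small self-contained illustration with made-up sizes:

import torch
from torch_geometric.nn import JumpingKnowledge

jump = JumpingKnowledge('cat')
x1 = torch.randn(2, 3, 64)    # conv1 output, hidden_channels = 64
x2 = torch.randn(2, 3, 32)    # conv2 output, out_channels = 32
out = jump([x1, x2])          # -> [2, 3, 96] = [2, 3, hidden_channels + out_channels]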
Example #4
class Encoder(torch.nn.Module):
    def __init__(self, in_dim, out_dim, dim=32):
        super(Encoder, self).__init__()

        nn1 = Sequential(Linear(in_dim, dim), ReLU(), Linear(dim, dim))
        self.conv1 = DenseGINConv(nn1)
        self.bn1 = torch.nn.BatchNorm1d(in_dim)

        nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv2 = DenseGINConv(nn2)
        self.bn2 = torch.nn.BatchNorm1d(in_dim)

        nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv3 = DenseGINConv(nn3)
        self.bn3 = torch.nn.BatchNorm1d(in_dim)

        self.fc1 = Linear(dim, dim)
        self.fc2 = Linear(dim, out_dim)
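This example stops at __init__. As a rough, generic sketch (not the original forward pass), stacked DenseGINConv layers like these are typically followed by a readout over the node dimension and the fully connected head:

import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import DenseGINConv

dim = 32
conv1 = DenseGINConv(Sequential(Linear(8, dim), ReLU(), Linear(dim, dim)))
conv2 = DenseGINConv(Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)))
fc1, fc2 = Linear(dim, dim), Linear(dim, 10)

x = torch.randn(4, 20, 8)                       # [batch, nodes, in_dim]
adj = torch.randint(0, 2, (4, 20, 20)).float()
h = F.relu(conv1(x, adj))
h = F.relu(conv2(h, adj))
h = h.mean(dim=1)                               # mean readout over the node dimension
out = fc2(F.relu(fc1(h)))                       # graph-level output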
Example #5
    def _gcn(self, name, input_dim, hidden_dim, bias, activation='relu'):
        # Factory helper: returns a dense GraphSAGE layer, or a dense GIN
        # layer whose MLP uses the requested activation.
        if name == 'SAGE':
            return DenseSAGEConv(input_dim,
                                 hidden_dim,
                                 normalize=True,
                                 bias=bias)
        else:
            nn1 = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                                self._activation(activation),
                                nn.Linear(hidden_dim, hidden_dim))
            return DenseGINConv(nn1)
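Both layer types returned by this factory share the dense (x, adj, mask) calling convention, so they can be swapped freely downstream; a quick self-contained check (sizes are illustrative):

import torch
import torch.nn as nn
from torch_geometric.nn import DenseGINConv, DenseSAGEConv

input_dim, hidden_dim = 16, 64
sage = DenseSAGEConv(input_dim, hidden_dim, normalize=True, bias=True)
gin = DenseGINConv(nn.Sequential(nn.Linear(input_dim, hidden_dim),
                                 nn.ReLU(),
                                 nn.Linear(hidden_dim, hidden_dim)))

x = torch.randn(2, 5, input_dim)
adj = torch.randint(0, 2, (2, 5, 5)).float()
assert sage(x, adj).shape == gin(x, adj).shape == (2, 5, hidden_dim)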
Example #6
class Decoder(torch.nn.Module):
    def __init__(self, in_dim, out_dim, dim=32):
        super(Decoder, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.id_mat = torch.eye(in_dim)
        self.register_buffer('id', self.id_mat)

        nn1 = Sequential(Linear(in_dim + 2, dim), ReLU(), Linear(dim, dim))
        self.conv1 = DenseGINConv(nn1)
        self.bn1 = torch.nn.BatchNorm1d(in_dim)

        nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv2 = DenseGINConv(nn2)
        self.bn2 = torch.nn.BatchNorm1d(in_dim)

        nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv3 = DenseGINConv(nn3)
        self.bn3 = torch.nn.BatchNorm1d(in_dim)

        self.fc1 = Linear(dim, dim)
        self.fc2 = Linear(dim, out_dim)
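The register_buffer call above keeps the identity matrix out of the optimizer while still storing it in the state_dict and moving it with the module across devices. A minimal, self-contained illustration (the class name here is hypothetical):

import torch

class WithBuffer(torch.nn.Module):
    def __init__(self, n):
        super().__init__()
        # Registered buffers follow .to()/.cuda() and are serialized,
        # but they are not parameters and receive no gradient updates.
        self.register_buffer('id', torch.eye(n))

m = WithBuffer(4)
print(list(m.parameters()))      # [] -> the identity matrix is not trainable
print(list(m.state_dict()))      # ['id'] -> but it is part of the saved state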
Example #7
def test_dense_gin_conv_with_broadcasting():
    batch_size, num_nodes, channels = 8, 3, 16
    nn = Seq(Lin(channels, channels), ReLU(), Lin(channels, channels))
    conv = DenseGINConv(nn)

    x = torch.randn(batch_size, num_nodes, channels)
    adj = torch.Tensor([
        [0, 1, 1],
        [1, 0, 1],
        [1, 1, 0],
    ])

    assert conv(x, adj).size() == (batch_size, num_nodes, channels)
    mask = torch.tensor([1, 1, 1], dtype=torch.bool)
    assert conv(x, adj, mask).size() == (batch_size, num_nodes, channels)
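Assuming the dense GIN layer broadcasts a 2-D adjacency over the batch dimension (which is what this test exercises via the output shape), the broadcast result should also match an explicitly expanded adjacency. A short numerical check along those lines:

import torch
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import DenseGINConv

conv = DenseGINConv(Sequential(Linear(16, 16), ReLU(), Linear(16, 16)))
x = torch.randn(8, 3, 16)
adj = torch.tensor([[0., 1., 1.],
                    [1., 0., 1.],
                    [1., 1., 0.]])

out_broadcast = conv(x, adj)                                 # shared 2-D adjacency
out_explicit = conv(x, adj.unsqueeze(0).expand(8, -1, -1))   # explicit batch dimension
assert torch.allclose(out_broadcast, out_explicit)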