Example #1
    def forward(self, x, edge_index, batch):
        # forward() half of the model whose __init__ appears in Example #3.
        x = self.lin(x)
        for conv in self.convs:
            x = conv(x, edge_index)

        # First stage: soft-assign the nodes of each graph to 10 clusters.
        x, S1 = self.mem1(x, batch)
        x = F.leaky_relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # Second stage: collapse the 10 clusters into one graph-level vector.
        x, S2 = self.mem2(x)

        # Return class log-probabilities plus the auxiliary clustering loss.
        return (
            F.log_softmax(x.squeeze(1), dim=-1),
            MemPooling.kl_loss(S1) + MemPooling.kl_loss(S2),
        )
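
The forward pass returns two values: class log-probabilities and the sum of the KL clustering losses from both pooling stages, so a training step has to add the two terms together. A minimal sketch of such a step; `model`, `optimizer`, `loader`, and the use of NLL loss are assumptions, not part of the example:

# Hypothetical training step; `model`, `optimizer` and `loader` are assumed.
import torch.nn.functional as F

for data in loader:
    optimizer.zero_grad()
    out, kl = model(data.x, data.edge_index, data.batch)
    loss = F.nll_loss(out, data.y.view(-1)) + kl  # task loss + clustering loss
    loss.backward()
    optimizer.step()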
Example #2
import torch

from torch_geometric.nn import MemPooling
from torch_geometric.utils import to_dense_batch


def test_mem_pool():
    mpool = MemPooling(4, 8, heads=3, num_clusters=2)
    assert mpool.__repr__() == 'MemPooling(4, 8, heads=3, num_clusters=2)'

    # 17 nodes spread over 5 graphs of 3-4 nodes each.
    x = torch.randn(17, 4)
    batch = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4])
    _, mask = to_dense_batch(x, batch)  # only the dense node mask is needed

    out, S = mpool(x, batch)
    loss = MemPooling.kl_loss(S)

    # Dense output: (num_graphs, num_clusters, out_channels).
    assert out.size() == (5, 2, 8)
    # Padded slots get no assignment mass; each real node contributes a
    # softmax row summing to one, so the total equals the node count.
    assert S[~mask].sum() == 0
    assert round(S[mask].sum().item()) == x.size(0)
    assert float(loss) > 0
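
The assertions pin down the invariants of the soft assignment matrix S: padded slots carry no mass, while every real node contributes a softmax row summing to one. A standalone sketch of the same invariants (the batch layout is copied from the test; the printed shape follows from the largest graph having 4 nodes):

# Standalone check of the assignment-matrix invariants asserted above.
import torch

from torch_geometric.nn import MemPooling
from torch_geometric.utils import to_dense_batch

x = torch.randn(17, 4)
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4])
_, mask = to_dense_batch(x, batch)

_, S = MemPooling(4, 8, heads=3, num_clusters=2)(x, batch)
print(S.shape)               # torch.Size([5, 4, 2]): graphs x max nodes x clusters
print(S.sum(dim=-1)[mask])   # ~1.0 for every real node (softmax over clusters)
print(S.sum(dim=-1)[~mask])  # 0.0 on padded slots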
Example #3
    # __init__ half of the model whose forward() appears in Example #1.
    def __init__(self, in_channels, hidden_channels, out_channels, dropout):
        super().__init__()
        self.dropout = dropout

        self.lin = Linear(in_channels, hidden_channels)

        # Two residual ('res+') GAT blocks refine the node embeddings.
        self.convs = torch.nn.ModuleList()
        for _ in range(2):
            conv = GATConv(hidden_channels, hidden_channels, dropout=dropout)
            norm = BatchNorm1d(hidden_channels)
            act = LeakyReLU()
            self.convs.append(
                DeepGCNLayer(conv, norm, act, block='res+', dropout=dropout))

        # Two-stage pooling: nodes -> 10 clusters of width 80, then those
        # clusters -> a single out_channels-wide graph vector.
        self.mem1 = MemPooling(hidden_channels, 80, heads=5, num_clusters=10)
        self.mem2 = MemPooling(80, out_channels, heads=5, num_clusters=1)
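
The two MemPooling stages reduce a variable-size graph in two hops: mem1 soft-clusters any number of nodes into 10 cluster vectors of width 80, and mem2 collapses those into a single out_channels-wide graph embedding. A standalone sketch of just that pipeline (the toy sizes are assumptions):

# Two-stage pooling pipeline in isolation; sizes chosen for illustration.
import torch

from torch_geometric.nn import MemPooling

hidden_channels, out_channels = 32, 2
mem1 = MemPooling(hidden_channels, 80, heads=5, num_clusters=10)
mem2 = MemPooling(80, out_channels, heads=5, num_clusters=1)

x = torch.randn(23, hidden_channels)       # 23 nodes in total
batch = torch.tensor([0] * 10 + [1] * 13)  # two graphs: 10 and 13 nodes
h, _ = mem1(x, batch)                      # h: [2, 10, 80], dense per graph
out, _ = mem2(h)                           # out: [2, 1, 2]
print(out.squeeze(1).shape)                # [2, 2]: one embedding per graph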
Example #4
import torch

from torch_geometric.nn import MemPooling
from torch_geometric.utils import to_dense_batch


def test_mem_pool():
    mpool1 = MemPooling(4, 8, heads=3, num_clusters=2)
    assert mpool1.__repr__() == 'MemPooling(4, 8, heads=3, num_clusters=2)'
    mpool2 = MemPooling(8, 4, heads=2, num_clusters=1)

    x = torch.randn(17, 4)
    batch = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4])
    _, mask = to_dense_batch(x, batch)

    out1, S = mpool1(x, batch)
    loss = MemPooling.kl_loss(S)
    # Anomaly detection guards the backward pass against NaNs/invalid ops.
    with torch.autograd.set_detect_anomaly(True):
        loss.backward()
    # The second stage pools the already-dense output; no batch vector needed.
    out2, _ = mpool2(out1)

    assert out1.size() == (5, 2, 8)
    assert out2.size() == (5, 1, 4)
    assert S[~mask].sum() == 0
    assert round(S[mask].sum().item()) == x.size(0)
    assert float(loss) > 0
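
Since kl_loss is differentiable with respect to the pooling layer's parameters, the clustering can be trained on its own, which is what the backward call above exercises. A minimal sketch of driving it through an optimizer (the Adam setup and iteration count are assumptions):

# Hypothetical optimization loop over the clustering loss alone.
import torch

from torch_geometric.nn import MemPooling

mpool = MemPooling(4, 8, heads=3, num_clusters=2)
optimizer = torch.optim.Adam(mpool.parameters(), lr=0.01)

x = torch.randn(17, 4)
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4])

for _ in range(5):
    optimizer.zero_grad()
    _, S = mpool(x, batch)
    loss = MemPooling.kl_loss(S)  # sharpens soft assignments (DEC-style target)
    loss.backward()
    optimizer.step()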