Example #1
class GCNWithJK(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super(GCNWithJK, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GCNConv(hidden, hidden))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [x]
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            xs += [x]
        x = self.jump(xs)
        x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
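Example #1 (and the other graph-classification snippets below) leaves its imports and training loop implicit. The sketch below shows one way the surrounding setup could look; it is only an illustration, and the TUDataset choice, hyperparameters and optimizer are assumptions rather than part of the original snippet.

# Illustrative setup for GCNWithJK; not part of the original example.
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, JumpingKnowledge, global_mean_pool

dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')        # assumed dataset
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = GCNWithJK(dataset, num_layers=3, hidden=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for data in loader:
    optimizer.zero_grad()
    out = model(data)                  # per-graph log-probabilities
    loss = F.nll_loss(out, data.y)     # matches the log_softmax output
    loss.backward()
    optimizer.step()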
Example #2
class DiffPool(torch.nn.Module):
    def __init__(self,
                 args,
                 num_nodes=10,
                 num_layers=4,
                 hidden=16,
                 ratio=0.25):
        super(DiffPool, self).__init__()

        self.args = args
        num_features = self.args.filters_3

        self.att = DenseAttentionModule(self.args)

        num_nodes = ceil(ratio * num_nodes)
        self.embed_block1 = Block(num_features, hidden, hidden)
        self.pool_block1 = Block(num_features, hidden, num_nodes)

        self.embed_blocks = torch.nn.ModuleList()
        self.pool_blocks = torch.nn.ModuleList()
        for i in range((num_layers // 2) - 1):
            num_nodes = ceil(ratio * num_nodes)
            self.embed_blocks.append(Block(hidden, hidden, hidden))
            self.pool_blocks.append(Block(hidden, hidden, num_nodes))
        self.jump = JumpingKnowledge(mode="cat")
        self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden)
        self.lin2 = Linear(hidden, num_features)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.pool_block1.reset_parameters()
        for block1, block2 in zip(self.embed_blocks, self.pool_blocks):
            block1.reset_parameters()
            block2.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x, adj, mask):
        s = self.pool_block1(x, adj, mask, add_loop=True)
        x = F.relu(self.embed_block1(x, adj, mask, add_loop=True))

        xs = [self.att(x, mask)]
        x, adj, _, _ = dense_diff_pool(x, adj, s, mask)

        for i, (embed,
                pool) in enumerate(zip(self.embed_blocks, self.pool_blocks)):
            s = pool(x, adj)
            x = F.relu(embed(x, adj))
            xs.append(self.att(x))
            if i < (len(self.embed_blocks) - 1):
                x, adj, _, _ = dense_diff_pool(x, adj, s)

        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x

    def __repr__(self):
        return self.__class__.__name__
Example #3
class DiffPool(torch.nn.Module):
    def __init__(self, dataset, num_pools, hidden, ratio=0.25):
        super(DiffPool, self).__init__()

        self.num_pools, self.hidden = num_pools, hidden

        num_nodes = ceil(ratio * dataset[0].num_nodes)
        self.embed_block1 = Block(dataset.num_features, hidden, hidden)
        self.pool_block1 = Block(dataset.num_features, hidden, num_nodes)

        self.embed_blocks = torch.nn.ModuleList()
        self.pool_blocks = torch.nn.ModuleList()
        for i in range(num_pools - 1):
            num_nodes = ceil(ratio * num_nodes)
            self.embed_blocks.append(Block(hidden, hidden, hidden))
            self.pool_blocks.append(Block(hidden, hidden, num_nodes))

        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.pool_block1.reset_parameters()
        for embed_block, pool_block in zip(self.embed_blocks,
                                           self.pool_blocks):
            embed_block.reset_parameters()
            pool_block.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        # x:[batch_size,num_nodes,in_channels]
        x, adj, mask = data.x, data.adj, data.mask

        # s: [batch_size, num_nodes, c_num_nodes]
        s = self.pool_block1(x, adj, mask, add_loop=True)
        # x: [batch_size, num_nodes, hidden]
        x = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        xs = [x.mean(dim=1)]
        # x:[batch_size, c_num_nodes, hidden]
        x, adj, _, _ = dense_diff_pool(x, adj, s, mask)
        # adj: [batch_size,c_num_nodes, c_num_nodes]
        for i, (embed_block, pool_block) in enumerate(
                zip(self.embed_blocks, self.pool_blocks)):
            # s: [batch_size,c_num_nodes, cc_num_nodes]
            s = pool_block(x, adj)
            # x: [batch_size,c_num_nodes,hidden]
            x = F.relu(embed_block(x, adj))
            xs.append(x.mean(dim=1))
            if i < len(self.embed_blocks) - 1:
                # x: [batch_size,cc_num_nodes, hidden]
                x, adj, _, _ = dense_diff_pool(x, adj, s)
                # adj: [batch_size,cc_num_nodes,cc_num_nodes]
        x = self.jump(xs)  # x: [batch_size, (len(self.embed_blocks) + 1) * hidden]
        x = F.relu(self.lin1(x))  # x: [batch_size,hidden]
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)  # x: [batch_size,dataset.num_classes]
        return F.log_softmax(x, dim=-1)
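The dense DiffPool variants in this collection read dense tensors (x, adj, mask) and call dense_diff_pool, so they expect a dense mini-batch pipeline rather than the usual sparse loader, and they rely on a Block helper defined elsewhere. A minimal sketch of that data-side assumption (the dataset name and max_nodes are illustrative, not taken from the snippets):

# Illustrative dense data pipeline for the DiffPool-style models above.
import torch_geometric.transforms as T
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DenseDataLoader

max_nodes = 150  # assumed cap on graph size
dataset = TUDataset(root='/tmp/PROTEINS_dense', name='PROTEINS',
                    pre_transform=T.ToDense(max_nodes),
                    pre_filter=lambda data: data.num_nodes <= max_nodes)
loader = DenseDataLoader(dataset, batch_size=32, shuffle=True)

for data in loader:
    # data.x:    [batch_size, max_nodes, num_features]
    # data.adj:  [batch_size, max_nodes, max_nodes]
    # data.mask: [batch_size, max_nodes], True for real nodes
    break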
Example #4
class Coarsening(torch.nn.Module):
    def __init__(self, dataset, hidden, ratio=0.25): # we only use 1 layer for coarsening
        super(Coarsening, self).__init__()

        # self.embed_block1 = GNNBlock(dataset.num_features, hidden, hidden)
        self.embed_block1 = DenseGCNConv(dataset.num_features, hidden)
        self.coarse_block1 = CoarsenBlock(hidden, ratio)
        self.embed_block2 = DenseGCNConv(hidden, dataset.num_features)

        self.jump = JumpingKnowledge(mode='cat')

        self.lin1 = Linear(hidden + dataset.num_features, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.coarse_block1.reset_parameters()
        self.embed_block2.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, epsilon=0.01, opt_epochs=100):
        x, adj, mask = data.x, data.adj, data.mask
        batch_num_nodes = data.mask.sum(-1)
        x1 = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        # xs = [x1.mean(dim=1)]
        coarse_x, new_adj, S = self.coarse_block1(x1, adj, batch_num_nodes)
        xs = [coarse_x.mean(dim=1)]
        x2 = torch.tanh(self.embed_block2(coarse_x, new_adj, mask, add_loop=True))
        xs.append(x2.mean(dim=1))

        opt_loss = 0.0
        for i in range(len(x)):
            # Compare the non-zero rows of the original features of each graph
            # against the reconstructed features via a Sinkhorn (optimal
            # transport) loss.
            x3 = self.get_nonzero_rows(x[i])
            opt_loss += sinkhorn_loss_default(x3, x2[i], epsilon, niter=opt_epochs)
        return xs, new_adj, S, opt_loss

    def predict(self, xs):
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def get_nonzero_rows(self, M):  # M is a matrix
        # torch.nonzero() was unreliable in PyTorch 1.2.0, so instead we sort
        # the row sums in descending order and keep the N rows that are > 0.
        row_sums = M.sum(-1)
        _, MM_ind = row_sums.sort(descending=True)
        N = (row_sums > 0).sum()
        return M[MM_ind[:N]]


    def __repr__(self):
        return self.__class__.__name__
Example #5
class DiffPool(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden, ratio=0.25):
        super(DiffPool, self).__init__()

        num_nodes = ceil(ratio * dataset[0].num_nodes)
        self.embed_block1 = Block(dataset.num_features, hidden, hidden)
        self.pool_block1 = Block(dataset.num_features, hidden, num_nodes)

        self.embed_blocks = torch.nn.ModuleList()
        self.pool_blocks = torch.nn.ModuleList()
        for i in range((num_layers // 2) - 1):
            num_nodes = ceil(ratio * num_nodes)
            self.embed_blocks.append(Block(hidden, hidden, hidden))
            self.pool_blocks.append(Block(hidden, hidden, num_nodes))

        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.pool_block1.reset_parameters()
        for embed_block, pool_block in zip(self.embed_blocks,
                                           self.pool_blocks):
            embed_block.reset_parameters()
            pool_block.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, adj, mask = data.x, data.adj, data.mask
        link_losses = 0.
        ent_losses = 0.
        s = self.pool_block1(x, adj, mask, add_loop=True)
        x = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        xs = [x.mean(dim=1)]
        x, adj, link_loss, ent_loss = dense_diff_pool(x, adj, s, mask)
        link_losses += link_loss
        ent_losses += ent_loss
        for i, (embed_block, pool_block) in enumerate(
                zip(self.embed_blocks, self.pool_blocks)):
            s = pool_block(x, adj)
            x = F.relu(embed_block(x, adj))
            xs.append(x.mean(dim=1))
            if i < len(self.embed_blocks) - 1:
                x, adj, link_loss, ent_loss = dense_diff_pool(x, adj, s)
                link_losses += link_loss
                ent_losses += ent_loss
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)

        return F.log_softmax(x, dim=-1), link_losses + ent_losses

    def __repr__(self):
        return self.__class__.__name__
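Since this DiffPool variant returns the auxiliary link-prediction and entropy terms alongside the class log-probabilities, a training loop has to fold them into the objective. A hedged sketch of one way to do that (model, loader and the 0.1 weighting factor are assumptions, not part of the snippet):

# Sketch: combining classification loss with the returned DiffPool regularizers.
import torch
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # `model` is the DiffPool above

model.train()
for data in loader:                       # a DenseDataLoader, as sketched earlier
    optimizer.zero_grad()
    out, aux_loss = model(data)           # log-probs and summed link + entropy losses
    loss = F.nll_loss(out, data.y.view(-1)) + 0.1 * aux_loss
    loss.backward()
    optimizer.step()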
Example #6
class GATJK(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2,
                 dropout=0.5, heads=2, jk_type='max'):
        super(GATJK, self).__init__()

        self.convs = nn.ModuleList()
        self.convs.append(
            GATConv(in_channels, hidden_channels, heads=heads, concat=True))

        self.bns = nn.ModuleList()
        self.bns.append(nn.BatchNorm1d(hidden_channels*heads))
        for _ in range(num_layers - 2):
            self.convs.append(
                GATConv(hidden_channels * heads, hidden_channels,
                        heads=heads, concat=True))
            self.bns.append(nn.BatchNorm1d(hidden_channels * heads))

        self.convs.append(
            GATConv(hidden_channels*heads, hidden_channels, heads=heads))

        self.dropout = dropout
        self.activation = F.elu # note: uses elu

        self.jump = JumpingKnowledge(jk_type, channels=hidden_channels*heads, num_layers=1)
        if jk_type == 'cat':
            self.final_project = nn.Linear(hidden_channels*heads*num_layers, out_channels)
        else: # max or lstm
            self.final_project = nn.Linear(hidden_channels*heads, out_channels)


    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        self.jump.reset_parameters()
        self.final_project.reset_parameters()


    def forward(self, data):
        x = data.graph['node_feat']
        xs = []
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, data.graph['edge_index'])
            x = self.bns[i](x)
            x = self.activation(x)
            xs.append(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, data.graph['edge_index'])
        xs.append(x)
        x = self.jump(xs)
        x = self.final_project(x)
        return x
Example #7
class GIN(nn.Module):
    def __init__(self, dataset, num_layers, hidden, train_eps=False, mode='cat'):
        super().__init__()
        self.conv1 = GINConv(nn.Sequential(
            nn.Linear(dataset.num_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.BatchNorm1d(hidden),
        ), train_eps=train_eps)
        self.convs = nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(
                GINConv(nn.Sequential(
                    nn.Linear(hidden, hidden),
                    nn.ReLU(),
                    nn.Linear(hidden, hidden),
                    nn.ReLU(),
                    nn.BatchNorm1d(hidden),
                ), train_eps=train_eps))
        self.jump = JumpingKnowledge(mode)
        if mode == 'cat':
            self.lin1 = nn.Linear(num_layers * hidden, hidden)
        else:
            self.lin1 = nn.Linear(hidden, hidden)
        self.lin2 = nn.Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        xs = [x]
        for conv in self.convs:
            x = conv(x, edge_index)
            xs += [x]
        x = self.jump(xs)
        x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x

    def __repr__(self):
        return self.__class__.__name__
Example #8
class DiffPool(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super(DiffPool, self).__init__()

        num_nodes = ceil(0.25 * dataset[0].num_nodes)
        self.embed_block1 = Block(dataset.num_features, hidden, hidden)
        self.pool_block1 = Block(dataset.num_features, hidden, num_nodes)

        self.embed_blocks = torch.nn.ModuleList()
        self.pool_blocks = torch.nn.ModuleList()
        for i in range((num_layers // 2) - 1):
            num_nodes = ceil(0.25 * num_nodes)
            self.embed_blocks.append(Block(hidden, hidden, hidden))
            self.pool_blocks.append(Block(hidden, hidden, num_nodes))

        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.pool_block1.reset_parameters()
        for block1, block2 in zip(self.embed_blocks, self.pool_blocks):
            block1.reset_parameters()
            block2.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, adj, mask = data.x, data.adj, data.mask

        s = self.pool_block1(x, adj, mask, add_loop=True)
        x = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        xs = [x.mean(dim=1)]
        x, adj, _, _ = dense_diff_pool(x, adj, s, mask)

        for embed, pool in zip(self.embed_blocks, self.pool_blocks):
            s = pool(x, adj)
            x = F.relu(embed(x, adj))
            xs.append(x.mean(dim=1))
            x, adj, _, _ = dense_diff_pool(x, adj, s)

        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
Example #9
class JKNet(torch.nn.Module):
    def __init__(self,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 num_layers,
                 dropout,
                 mode='cat'):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels, cached=False))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(
                GCNConv(hidden_channels, hidden_channels, cached=False))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.jump = JumpingKnowledge(mode)
        if mode == 'cat':
            self.lin1 = Linear(num_layers * hidden_channels, hidden_channels)
        else:
            self.lin1 = Linear(hidden_channels, hidden_channels)

        self.lin2 = Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x, adj_t):
        xs = []
        for i, conv in enumerate(self.convs):
            x = conv(x, adj_t)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            xs += [x]

        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)
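Across these snippets JumpingKnowledge only changes how the per-layer representations collected in xs are merged: 'cat' concatenates them along the feature dimension, 'max' takes an element-wise maximum, and 'lstm' learns per-layer attention scores with a bidirectional LSTM (which is why that mode also needs channels and num_layers). A small shape check, using nothing beyond the PyTorch Geometric API:

import torch
from torch_geometric.nn import JumpingKnowledge

num_nodes, hidden, num_layers = 5, 16, 3
xs = [torch.randn(num_nodes, hidden) for _ in range(num_layers)]

print(JumpingKnowledge('cat')(xs).shape)   # torch.Size([5, 48]) = num_layers * hidden
print(JumpingKnowledge('max')(xs).shape)   # torch.Size([5, 16])
print(JumpingKnowledge('lstm', channels=hidden, num_layers=num_layers)(xs).shape)
# torch.Size([5, 16])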
Example #10
class ModelEnsemble(torch.nn.Module):
    def __init__(self, num_features, num_class):
        super(ModelEnsemble, self).__init__()
        self.JK = JumpingKnowledge(mode='lstm',
                                   channels=num_class,
                                   num_layers=1)
        # self.linear = Linear(num_features, num_class)

    def reset_parameters(self):
        self.JK.reset_parameters()

    def forward(self, x):
        x = self.JK(x)
        return log_softmax(x, dim=-1)
Example #11
class GraphSAGEWithJK(torch.nn.Module):
    def __init__(self, num_input_features, num_layers, hidden, mode='cat'):
        super(GraphSAGEWithJK, self).__init__()
        self.conv1 = SAGEConv(num_input_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))
        self.jump = JumpingKnowledge(mode)
        if mode == 'cat':
            self.lin1 = Linear(3 * num_layers * hidden, hidden)
        else:
            self.lin1 = Linear(3 * hidden, hidden)
        self.lin2 = Linear(hidden, 2)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))

        xs = [x]
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            xs += [x]
        x = self.jump(xs)

        x = torch.cat([
            global_add_pool(x, batch),
            global_mean_pool(x, batch),
            global_max_pool(x, batch),
        ], dim=1)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
Example #12
class GCNWithJK(torch.nn.Module):
    def __init__(self,
                 num_features,
                 output_channels,
                 num_layers=3,
                 nb_neurons=128,
                 mode='cat',
                 **kwargs):
        super(GCNWithJK, self).__init__()
        self.conv1 = GCNConv(num_features, nb_neurons)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GCNConv(nb_neurons, nb_neurons))
        self.jump = JumpingKnowledge(mode)
        if mode == 'cat':
            self.lin1 = Linear(num_layers * nb_neurons, nb_neurons)
        else:
            self.lin1 = Linear(nb_neurons, nb_neurons)
        self.lin2 = Linear(nb_neurons, output_channels)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, target_size, **kwargs):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [x]
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            xs += [x]
        x = self.jump(xs)
        x = global_mean_pool(x, batch, size=target_size)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
Example #13
class Graclus(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super(Graclus, self).__init__()
        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden, aggr='mean'))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0 and i < len(self.convs) - 1:
                cluster = graclus(edge_index, num_nodes=x.size(0))
                data = Batch(x=x, edge_index=edge_index, batch=batch)
                data = max_pool(cluster, data)
                x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
Example #14
class TopK(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super(TopK, self).__init__()
        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')
        self.convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden, aggr='mean'))
            self.pools.append(TopKPooling(hidden, ratio=0.8))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv, pool in zip(self.convs, self.pools):
            conv.reset_parameters()
            pool.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, (conv, pool) in enumerate(zip(self.convs, self.pools)):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0:
                x, edge_index, _, batch, _ = pool(x, edge_index, batch=batch)
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
Example #15
class EdgeCIN0(torch.nn.Module):
    """
    A variant of CIN0 operating up to edge level. It may optionally ignore two_cell features.

    This model is based on
    https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/gin.py
    """
    def __init__(self,
                 num_input_features,
                 num_classes,
                 num_layers,
                 hidden,
                 dropout_rate: float = 0.5,
                 jump_mode=None,
                 nonlinearity='relu',
                 include_top_features=True,
                 update_top_features=True,
                 readout='sum'):
        super(EdgeCIN0, self).__init__()

        self.max_dim = 1
        self.include_top_features = include_top_features
        # If the top features are included, then they can be updated by a network.
        self.update_top_features = include_top_features and update_top_features
        self.dropout_rate = dropout_rate
        self.jump_mode = jump_mode
        self.convs = torch.nn.ModuleList()
        self.update_top_nns = torch.nn.ModuleList()
        self.nonlinearity = nonlinearity
        self.pooling_fn = get_pooling_fn(readout)
        conv_nonlinearity = get_nonlinearity(nonlinearity, return_module=True)
        for i in range(num_layers):
            layer_dim = num_input_features if i == 0 else hidden
            v_conv_update = Sequential(Linear(layer_dim, hidden),
                                       conv_nonlinearity(),
                                       Linear(hidden, hidden),
                                       conv_nonlinearity(), BN(hidden))
            e_conv_update = Sequential(Linear(layer_dim, hidden),
                                       conv_nonlinearity(),
                                       Linear(hidden, hidden),
                                       conv_nonlinearity(), BN(hidden))
            v_conv_up = Sequential(Linear(layer_dim * 2, layer_dim),
                                   conv_nonlinearity(), BN(layer_dim))
            e_conv_down = Sequential(Linear(layer_dim * 2, layer_dim),
                                     conv_nonlinearity(), BN(layer_dim))
            e_conv_inp_dim = layer_dim * 2 if include_top_features else layer_dim
            e_conv_up = Sequential(Linear(e_conv_inp_dim, layer_dim),
                                   conv_nonlinearity(), BN(layer_dim))
            self.convs.append(
                EdgeCINConv(layer_dim,
                            layer_dim,
                            v_conv_up,
                            e_conv_down,
                            e_conv_up,
                            v_conv_update,
                            e_conv_update,
                            train_eps=False))
            if self.update_top_features and i < num_layers - 1:
                self.update_top_nns.append(
                    Sequential(Linear(layer_dim, hidden), conv_nonlinearity(),
                               Linear(hidden, hidden), conv_nonlinearity(),
                               BN(hidden)))

        self.jump = JumpingKnowledge(
            jump_mode) if jump_mode is not None else None
        if jump_mode == 'cat':
            self.lin1 = Linear(num_layers * hidden, hidden)
        else:
            self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        if self.jump_mode is not None:
            self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        for net in self.update_top_nns:
            net.reset_parameters()

    def pool_complex(self, xs, data):
        # All complexes have nodes so we can extract the batch size from cochains[0]
        batch_size = data.cochains[0].batch.max() + 1
        # The MP output is of shape [message_passing_dim, batch_size, feature_dim]
        pooled_xs = torch.zeros(self.max_dim + 1,
                                batch_size,
                                xs[0].size(-1),
                                device=batch_size.device)
        for i in range(len(xs)):
            # It's very important that size is supplied.
            pooled_xs[i, :, :] = self.pooling_fn(xs[i],
                                                 data.cochains[i].batch,
                                                 size=batch_size)
        return pooled_xs

    def jump_complex(self, jump_xs):
        # Perform JumpingKnowledge at each level of the complex
        xs = []
        for jumpx in jump_xs:
            xs += [self.jump(jumpx)]
        return xs

    def forward(self, data: ComplexBatch):
        model_nonlinearity = get_nonlinearity(self.nonlinearity,
                                              return_module=False)
        xs, jump_xs = None, None
        for c, conv in enumerate(self.convs):
            params = data.get_all_cochain_params(
                max_dim=self.max_dim,
                include_top_features=self.include_top_features)
            xs = conv(*params)
            # If we are at the last convolutional layer, we do not need to update the
            # top features afterwards. We also check that two_cell features do indeed
            # exist in this batch before doing this.
            if self.update_top_features and c < len(
                    self.convs) - 1 and 2 in data.cochains:
                top_x = self.update_top_nns[c](data.cochains[2].x)
                data.set_xs(xs + [top_x])
            else:
                data.set_xs(xs)

            if self.jump_mode is not None:
                if jump_xs is None:
                    jump_xs = [[] for _ in xs]
                for i, x in enumerate(xs):
                    jump_xs[i] += [x]

        if self.jump_mode is not None:
            xs = self.jump_complex(jump_xs)

        pooled_xs = self.pool_complex(xs, data)
        x = pooled_xs.sum(dim=0)

        x = model_nonlinearity(self.lin1(x))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.lin2(x)
        return x

    def __repr__(self):
        return self.__class__.__name__
Example #16
class EdgeOrient(torch.nn.Module):
    """
    A model for edge-defined signals taking edge orientation into account.
    """
    def __init__(self,
                 num_input_features,
                 num_classes,
                 num_layers,
                 hidden,
                 dropout_rate: float = 0.0,
                 jump_mode=None,
                 nonlinearity='id',
                 readout='sum',
                 fully_invar=False):
        super(EdgeOrient, self).__init__()

        self.max_dim = 1
        self.fully_invar = fully_invar
        orient = not self.fully_invar
        self.dropout_rate = dropout_rate
        self.jump_mode = jump_mode
        self.convs = torch.nn.ModuleList()
        self.nonlinearity = nonlinearity
        self.pooling_fn = get_pooling_fn(readout)
        for i in range(num_layers):
            layer_dim = num_input_features if i == 0 else hidden
            # NOTE: biases must be set to False; otherwise the model is not equivariant.
            update_up = Linear(layer_dim, hidden, bias=False)
            update_down = Linear(layer_dim, hidden, bias=False)
            update = Linear(layer_dim, hidden, bias=False)

            self.convs.append(
                OrientedConv(dim=1,
                             up_msg_size=layer_dim,
                             down_msg_size=layer_dim,
                             update_up_nn=update_up,
                             update_down_nn=update_down,
                             update_nn=update,
                             act_fn=get_nonlinearity(nonlinearity,
                                                     return_module=False),
                             orient=orient))
        self.jump = JumpingKnowledge(
            jump_mode) if jump_mode is not None else None
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        if self.jump_mode is not None:
            self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data: CochainBatch, include_partial=False):
        if self.fully_invar:
            data.x = torch.abs(data.x)
        for c, conv in enumerate(self.convs):
            x = conv(data)
            data.x = x

        cell_pred = x

        # To obtain orientation invariance, we take the absolute value of the
        # features, unless we already did that before the first layer.
        batch_size = data.batch.max() + 1
        if not self.fully_invar:
            x = torch.abs(x)
        x = self.pooling_fn(x, data.batch, size=batch_size)

        # At this point we have invariance: we can use any non-linearity we like.
        # Here, independently from previous non-linearities, we choose ReLU.
        # Note that this makes the model non-linear even when employing identity
        # in previous layers.
        x = torch.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.lin2(x)

        if include_partial:
            return x, cell_pred
        return x

    def __repr__(self):
        return self.__class__.__name__
Example #17
class Net(torch.nn.Module):
    def __init__(self, num_classes, gnn_layers, embed_dim, hidden_dim,
                 jk_layer, process_step, dropout):
        super(Net, self).__init__()

        self.dropout = dropout
        self.convs = torch.nn.ModuleList()
        self.embedding = Embedding(6, embed_dim)

        for i in range(gnn_layers):
            if i == 0:
                self.convs.append(
                    AGGINConv(Sequential(Linear(2 * embed_dim + 2,
                                                hidden_dim), ReLU(),
                                         Linear(hidden_dim, hidden_dim),
                                         ReLU(), BN(hidden_dim)),
                              train_eps=True))
            else:
                self.convs.append(
                    AGGINConv(Sequential(Linear(hidden_dim,
                                                hidden_dim), ReLU(),
                                         Linear(hidden_dim, hidden_dim),
                                         ReLU(), BN(hidden_dim)),
                              train_eps=True))

        if jk_layer.isdigit():
            jk_layer = int(jk_layer)
            self.jk = JumpingKnowledge(mode='lstm',
                                       channels=hidden_dim,
                                       num_layers=jk_layer)
            self.s2s = (Set2Set(hidden_dim, processing_steps=process_step))
            self.fc1 = Linear(2 * hidden_dim, hidden_dim)
            self.fc2 = Linear(hidden_dim, int(hidden_dim / 2))
            self.fc3 = Linear(int(hidden_dim / 2), num_classes)
        elif jk_layer == 'cat':
            self.jk = JumpingKnowledge(mode=jk_layer)
            self.s2s = (Set2Set(gnn_layers * hidden_dim,
                                processing_steps=process_step))
            self.fc1 = Linear(2 * gnn_layers * hidden_dim, hidden_dim)
            self.fc2 = Linear(hidden_dim, int(hidden_dim / 2))
            self.fc3 = Linear(int(hidden_dim / 2), num_classes)
        elif jk_layer == 'max':
            self.jk = JumpingKnowledge(mode=jk_layer)
            self.s2s = (Set2Set(hidden_dim, processing_steps=process_step))
            self.fc1 = Linear(2 * hidden_dim, hidden_dim)
            self.fc2 = Linear(hidden_dim, int(hidden_dim / 2))
            self.fc3 = Linear(int(hidden_dim / 2), num_classes)

    def reset_parameters(self):
        self.embedding.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jk.reset_parameters()
        self.s2s.reset_parameters()
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        self.fc3.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Embedding the categorical values from Gene expression and Node type
        xc = x[:, :2].type(torch.long)
        ems = self.embedding(xc)
        ems = ems.view(-1, ems.shape[1] * ems.shape[2])
        x = torch.cat((ems, x[:, 2:]), dim=1)

        xs = []
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            xs += [x]
        x = self.jk(xs)
        x = self.s2s(x, batch)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)

        return logits
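The sizes of the fully connected head in this example follow from Set2Set doubling its input dimension: its output concatenates the attention query with the attended readout, so fc1 takes 2 * hidden_dim (or 2 * gnn_layers * hidden_dim in 'cat' mode, where the JK output itself is gnn_layers * hidden_dim wide). A quick shape check of that behaviour (the numbers are arbitrary):

import torch
from torch_geometric.nn import Set2Set

hidden_dim = 32
s2s = Set2Set(hidden_dim, processing_steps=4)
x = torch.randn(10, hidden_dim)               # 10 nodes
batch = torch.zeros(10, dtype=torch.long)     # all nodes belong to one graph
print(s2s(x, batch).shape)                    # torch.Size([1, 64]) = 2 * hidden_dim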
Example #18
class EdgeMPNN(torch.nn.Module):
    """
    An MPNN operating in the line graph.
    """
    def __init__(self,
                 num_input_features,
                 num_classes,
                 num_layers,
                 hidden,
                 dropout_rate: float = 0.0,
                 jump_mode=None,
                 nonlinearity='relu',
                 readout='sum',
                 fully_invar=True):
        super(EdgeMPNN, self).__init__()

        self.max_dim = 1
        self.dropout_rate = dropout_rate
        self.fully_invar = fully_invar
        orient = not self.fully_invar
        self.jump_mode = jump_mode
        self.convs = torch.nn.ModuleList()
        self.nonlinearity = nonlinearity
        self.pooling_fn = get_pooling_fn(readout)
        for i in range(num_layers):
            layer_dim = num_input_features if i == 0 else hidden
            # We pass this lambda function to discard upper adjacencies
            update_up = lambda x: 0
            update_down = Linear(layer_dim, hidden, bias=False)
            update = Linear(layer_dim, hidden, bias=False)
            self.convs.append(
                OrientedConv(dim=1,
                             up_msg_size=layer_dim,
                             down_msg_size=layer_dim,
                             update_up_nn=update_up,
                             update_down_nn=update_down,
                             update_nn=update,
                             act_fn=get_nonlinearity(nonlinearity,
                                                     return_module=False),
                             orient=orient))
        self.jump = JumpingKnowledge(
            jump_mode) if jump_mode is not None else None
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        if self.jump_mode is not None:
            self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data: CochainBatch, include_partial=False):
        if self.fully_invar:
            data.x = torch.abs(data.x)
        for c, conv in enumerate(self.convs):
            x = conv(data)
            data.x = x
        cell_pred = x

        batch_size = data.batch.max() + 1
        if not self.fully_invar:
            x = torch.abs(x)
        x = self.pooling_fn(x, data.batch, size=batch_size)

        # At this point we have invariance: we can use any non-linearity we like.
        # Here, independently from previous non-linearities, we choose ReLU.
        # Note that this makes the model non-linear even when employing identity
        # in previous layers.
        x = torch.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.lin2(x)

        if include_partial:
            return x, cell_pred
        return x

    def __repr__(self):
        return self.__class__.__name__
Example #19
class MultiLayerCoarsening(torch.nn.Module):
    def __init__(self, dataset, hidden, num_layers=2, ratio=0.5):
        super(MultiLayerCoarsening, self).__init__()

        self.embed_block1 = DenseGCNConv(dataset.num_features, hidden)
        self.coarse_block1 = CoarsenBlock(hidden, ratio)
        self.embed_block2 = DenseGCNConv(hidden, dataset.num_features)
        # self.embed_block2 = GNNBlock(hidden, hidden, dataset.num_features)

        self.num_layers = num_layers

        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(hidden + dataset.num_features, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.coarse_block1.reset_parameters()
        self.embed_block2.reset_parameters()

        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, epsilon=0.01, opt_epochs=100):
        x, adj, mask = data.x, data.adj, data.mask
        batch_num_nodes = data.mask.sum(-1)

        new_adjs = [adj]
        Ss = []

        x1 = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        xs = [x1.mean(dim=1)]
        new_adj = adj
        coarse_x = x1
        # coarse_x, new_adj, S = self.coarse_block1(x1, adj, batch_num_nodes)
        # new_adjs.append(new_adj)
        # Ss.append(S)

        for i in range(self.num_layers):
            coarse_x, new_adj, S = self.coarse_block1(coarse_x, new_adj,
                                                      batch_num_nodes)
            new_adjs.append(new_adj)
            Ss.append(S)
        x2 = self.embed_block2(
            coarse_x, new_adj, mask, add_loop=True
        )  # no ReLU here, otherwise x2 could end up all zero
        xs.append(x2.mean(dim=1))
        opt_loss = 0.0
        for i in range(len(x)):
            x3 = self.get_nonzero_rows(x[i])
            x4 = self.get_nonzero_rows(x2[i])
            if x3.size()[0] == 0:
                continue
            if x4.size()[0] == 0:
                # opt_loss += sinkhorn_loss_default(x3, x2[i], epsilon, niter=opt_epochs).float()
                continue
            opt_loss += sinkhorn_loss_default(x3,
                                              x4,
                                              epsilon,
                                              niter=opt_epochs).float()

        return xs, new_adjs, Ss, opt_loss

    def get_nonzero_rows(self, M):  # M is a matrix
        # torch.nonzero() was unreliable in PyTorch 1.2.0, so instead we sort the
        # absolute row sums in descending order and keep the N rows that are > 0.
        row_sums = torch.abs(M.sum(-1))
        _, MM_ind = row_sums.sort(descending=True)
        N = (row_sums > 0).sum()
        return M[MM_ind[:N]]

    def predict(self, xs):
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
Example #20
class MultiLayerCoarsening(torch.nn.Module):
    def __init__(self, dataset, hidden, num_layers=2, ratio=0.5):
        super(MultiLayerCoarsening, self).__init__()

        self.embed_block1 = DenseGCNConv(dataset.num_features, hidden)
        self.coarse_block1 = CoarsenBlock(hidden, ratio)
        self.embed_block2 = DenseGCNConv(hidden, dataset.num_features)
        # self.embed_block2 = GNNBlock(hidden, hidden, dataset.num_features)

        self.num_layers = num_layers

        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(hidden * num_layers, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.coarse_block1.reset_parameters()
        self.embed_block2.reset_parameters()

        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, epsilon=0.01, opt_epochs=100):
        x, adj, mask = data.x, data.adj, data.mask
        batch_num_nodes = data.mask.sum(-1)

        new_adjs = [adj]
        Ss = []

        x1 = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        # xs = [x1.mean(dim=1)]
        xs = []
        coarse_x, new_adj, S = self.coarse_block1(x1, adj, batch_num_nodes)
        new_adjs.append(new_adj)
        Ss.append(S)
        # x2 = F.relu(self.embed_block1(coarse_x, new_adj, mask, add_loop=True))
        xs.append(coarse_x.mean(dim=1))

        x2 = self.embed_block2(coarse_x, new_adj, mask,
                               add_loop=True)  # no ReLU here, otherwise x2 could end up all zero
        # xs.append(x2.mean(dim=1))

        for i in range(self.num_layers - 1):
            x1 = F.relu(self.embed_block1(F.relu(x2), new_adj, mask, add_loop=True))
            coarse_x, new_adj, S = self.coarse_block1(x1, new_adj, batch_num_nodes)
            new_adjs.append(new_adj)
            Ss.append(S)
            xs.append(coarse_x.mean(dim=1))
            # no ReLU here either, otherwise x2 could end up all zero
            x2 = self.embed_block2(coarse_x, new_adj, mask, add_loop=True)
        # xs.append(x2.mean(dim=1))
        opt_loss = 0.0
        for i in range(len(x)):
            x3 = self.get_nonzero_rows(x[i])
            x4 = self.get_nonzero_rows(x2[i])
            if x3.size()[0] == 0:
                continue
            if x4.size()[0] == 0:
                opt_loss += sinkhorn_loss_default(x3, x2[i], epsilon,
                                                  niter=opt_epochs).float()
                continue
            opt_loss += sinkhorn_loss_default(x3, x4, epsilon,
                                              niter=opt_epochs).float()

        return xs, new_adjs, Ss, opt_loss

    def get_nonzero_rows(self, M):  # M is a matrix
        # torch.nonzero() was unreliable in PyTorch 1.2.0, so instead we sort the
        # absolute row sums in descending order and keep the N rows that are > 0.
        row_sums = torch.abs(M.sum(-1))
        _, MM_ind = row_sums.sort(descending=True)
        N = (row_sums > 0).sum()
        return M[MM_ind[:N]]

    def predict(self, xs):
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)


    def test(self, train_z, train_y, test_z, test_y, solver='lbfgs',
             multi_class='auto', *args, **kwargs):
        r"""Evaluates latent space quality via a logistic regression downstream
        task."""
        clf = LogisticRegression(solver=solver, multi_class=multi_class, *args,
                                 **kwargs).fit(train_z.detach().cpu().numpy(),
                                               train_y.detach().cpu().numpy())
        return clf.score(test_z.detach().cpu().numpy(),
                         test_y.detach().cpu().numpy())

    def __repr__(self):
        return self.__class__.__name__
Example #21
class JK(nn.Module):
    def __init__(self,
                 nfeat,
                 nhid,
                 nclass,
                 dropout=0.5,
                 lr=0.01,
                 weight_decay=5e-4,
                 n_edge=1,
                 with_relu=True,
                 drop=False,
                 with_bias=True,
                 device=None):

        super(JK, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.hidden_sizes = [nhid]
        self.nclass = int(nclass)
        self.dropout = dropout
        self.lr = lr
        self.drop = drop
        if not with_relu:
            self.weight_decay = 0
        else:
            self.weight_decay = weight_decay
        self.with_relu = with_relu
        self.with_bias = with_bias
        self.n_edge = n_edge
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None
        self.gate = Parameter(torch.rand(1))  # learnable gate, initialized uniformly in [0, 1]
        # self.beta = Parameter(torch.Tensor(self.n_edge))
        nclass = int(nclass)
        """JK from torch-geometric"""
        num_features = nfeat
        dim = nhid
        nn1 = Sequential(
            Linear(num_features, dim),
            ReLU(),
        )
        self.gc1 = GINConv(nn1)
        self.bn1 = torch.nn.BatchNorm1d(dim)

        nn2 = Sequential(
            Linear(dim, dim),
            ReLU(),
        )
        self.gc2 = GINConv(nn2)
        nn3 = Sequential(
            Linear(dim, dim),
            ReLU(),
        )
        self.gc3 = GINConv(nn3)

        self.jump = JumpingKnowledge(mode='cat')  # 'cat', 'lstm', 'max'
        self.bn2 = torch.nn.BatchNorm1d(dim)
        # self.fc1 = Linear(dim*3, dim)
        self.fc2 = Linear(dim * 2, int(nclass))

    def forward(self, x, adj):
        """we don't change the edge_index, just update the edge_weight;
        some edge_weight are regarded as removed if it equals to zero"""
        x = x.to_dense()
        edge_index = adj._indices()
        """GJK-Nets"""
        if self.attention:
            adj = self.att_coef(x, adj, i=0)
        x1 = F.relu(
            self.gc1(x, edge_index=edge_index, edge_weight=adj._values()))
        if self.attention:  # if attention=True, use attention mechanism
            adj_2 = self.att_coef(x1, adj, i=1)
            adj_values = self.gate * adj._values() + (
                1 - self.gate) * adj_2._values()
        else:
            adj_values = adj._values()
        x1 = F.dropout(x1, self.dropout, training=self.training)
        x2 = F.relu(self.gc2(x1, edge_index=edge_index,
                             edge_weight=adj_values))
        # x2 = self.bn1(x2)
        # if self.attention:  # if attention=True, use attention mechanism
        #     adj_3 = self.att_coef(x2, adj, i=1)
        #     adj_values = self.gate * adj_2._values() + (1 - self.gate) * adj_3._values()
        # else:
        #     adj_values = adj._values()
        x2 = F.dropout(x2, self.dropout, training=self.training)
        # x3 = F.relu(self.gc2(x2, edge_index=edge_index, edge_weight=adj_values))
        # x3 = F.dropout(x3, self.dropout, training=self.training)

        x_last = self.jump([x1, x2])
        x_last = F.dropout(x_last, self.dropout, training=self.training)
        x_last = self.fc2(x_last)

        return F.log_softmax(x_last, dim=1)

    def initialize(self):
        self.gc1.reset_parameters()
        self.gc2.reset_parameters()
        self.fc2.reset_parameters()
        try:
            self.jump.reset_parameters()
            self.fc1.reset_parameters()
            self.gc3.reset_parameters()
        except:
            pass

    def att_coef(self, fea, edge_index, is_lil=False, i=0):
        if not is_lil:
            edge_index = edge_index._indices()
        else:
            edge_index = edge_index.tocoo()

        n_node = fea.shape[0]
        row = edge_index[0].cpu().data.numpy()[:]
        col = edge_index[1].cpu().data.numpy()[:]
        # row, col = edge_index[0], edge_index[1]

        fea_copy = fea.cpu().data.numpy()
        sim_matrix = cosine_similarity(X=fea_copy,
                                       Y=fea_copy)  # try cosine similarity
        sim = sim_matrix[row, col]
        sim[sim < 0.1] = 0
        # print('dropped {} edges'.format(1-sim.nonzero()[0].shape[0]/len(sim)))

        # """use jaccard for binary features and cosine for numeric features"""
        # fea_start, fea_end = fea[edge_index[0]], fea[edge_index[1]]
        # isbinray = np.array_equal(fea_copy, fea_copy.astype(bool))  # check is the fea are binary
        # np.seterr(divide='ignore', invalid='ignore')
        # if isbinray:
        #     fea_start, fea_end = fea_start.T, fea_end.T
        #     sim = jaccard_score(fea_start, fea_end, average=None)  # similarity scores of each edge
        # else:
        #     fea_copy[np.isinf(fea_copy)] = 0
        #     fea_copy[np.isnan(fea_copy)] = 0
        #     sim_matrix = cosine_similarity(X=fea_copy, Y=fea_copy)  # try cosine similarity
        #     sim = sim_matrix[edge_index[0], edge_index[1]]
        #     sim[sim < 0.01] = 0
        """build a attention matrix"""
        att_dense = lil_matrix((n_node, n_node), dtype=np.float32)
        att_dense[row, col] = sim
        if att_dense[0, 0] == 1:
            att_dense = att_dense - sp.diags(
                att_dense.diagonal(), offsets=0, format="lil")
        # normalize so that each row sums to 1
        att_dense_norm = normalize(att_dense, axis=1, norm='l1')
        """add learnable dropout, make character vector"""
        if self.drop:
            character = np.vstack(
                (att_dense_norm[row, col].A1, att_dense_norm[col, row].A1))
            character = torch.from_numpy(character.T)
            drop_score = self.drop_learn_1(character)
            drop_score = torch.sigmoid(
                drop_score
            )  # do not use softmax since we only have one element
            mm = torch.nn.Threshold(0.5, 0)
            drop_score = mm(drop_score)
            mm_2 = torch.nn.Threshold(-0.49, 1)
            drop_score = mm_2(-drop_score)
            drop_decision = drop_score.clone().requires_grad_()
            # print('rate of left edges', drop_decision.sum().data/drop_decision.shape[0])
            drop_matrix = lil_matrix((n_node, n_node), dtype=np.float32)
            drop_matrix[row,
                        col] = drop_decision.cpu().data.numpy().squeeze(-1)
            att_dense_norm = att_dense_norm.multiply(
                drop_matrix.tocsr())  # update, remove the 0 edges

        # add self-loop weights (self-loops are only added at the first layer)
        if att_dense_norm[0, 0] == 0:
            degree = (att_dense_norm != 0).sum(1).A1
            # degree = degree.squeeze(-1).squeeze(-1)
            lam = 1 / (degree + 1)  # degree +1 is to add itself
            self_weight = sp.diags(np.array(lam), offsets=0, format="lil")
            att = att_dense_norm + self_weight  # add the self loop
        else:
            att = att_dense_norm

        att_adj = edge_index
        att_edge_weight = att[row, col]
        att_edge_weight = np.exp(att_edge_weight)  # exponentiate, similar in spirit to a softmax
        att_edge_weight = torch.tensor(np.array(att_edge_weight)[0],
                                       dtype=torch.float32).cuda()

        shape = (n_node, n_node)
        new_adj = torch.sparse.FloatTensor(att_adj, att_edge_weight, shape)
        return new_adj

    def add_loop_sparse(self, adj, fill_value=1):
        # build a sparse identity matrix and add it to the adjacency matrix
        row = torch.arange(adj.shape[0], dtype=torch.int64)
        i = torch.stack((row, row), dim=0)
        v = torch.ones(adj.shape[0], dtype=torch.float32)
        shape = adj.shape
        I_n = torch.sparse.FloatTensor(i, v, shape)
        return adj + I_n.to(self.device)

    def fit(
        self,
        features,
        adj,
        labels,
        idx_train,
        idx_val=None,
        idx_test=None,
        train_iters=81,
        att_0=None,
        attention=False,
        model_name=None,
        initialize=True,
        verbose=False,
        normalize=False,
        patience=500,
    ):
        '''
        Train the GCN model; when idx_val is not None, pick the best model
        according to the validation loss.
        '''
        self.sim = None
        self.attention = attention
        self.idx_test = idx_test
        # self.device = self.gc1.weight.device

        if initialize:
            self.initialize()

        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features,
                                                    adj,
                                                    labels,
                                                    device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        # normalize = False # we don't need normalize here, the norm is conducted in the GCN (self.gcn1) model
        # if normalize:
        #     if utils.is_sparse_tensor(adj):
        #         adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
        #     else:
        #         adj_norm = utils.normalize_adj_tensor(adj)
        # else:
        #     adj_norm = adj

        adj = self.add_loop_sparse(adj)
        """Make the coefficient D^{-1/2}(A+I)D^{-1/2}"""
        self.adj_norm = adj
        self.features = features
        self.labels = labels

        if idx_val is None:
            self._train_without_val(labels, idx_train, train_iters, verbose)
        else:
            if patience < train_iters:
                self._train_with_early_stopping(labels, idx_train, idx_val,
                                                train_iters, patience, verbose)
            else:
                self._train_with_val(labels, idx_train, idx_val, train_iters,
                                     verbose)

    def _train_without_val(self, labels, idx_train, train_iters, verbose):
        self.train()
        optimizer = optim.Adam(self.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)
        for i in range(train_iters):
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(
                output[idx_train], labels[idx_train], weight=None
            )  # `weight` would hold optional per-class weights; unused here
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(
                    i, loss_train.item()))

        self.eval()
        output = self.forward(self.features, self.adj_norm)
        self.output = output

    def _train_with_val(self, labels, idx_train, idx_val, train_iters,
                        verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        best_loss_val = 100
        best_acc_val = 0

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            # pred = output[self.idx_test].max(1)[1]
            # acc_test =accuracy(output[self.idx_test], labels[self.idx_test])
            # acc_test = pred.eq(labels[self.idx_test]).sum().item() / self.idx_test.shape[0]

            self.eval()
            output = self.forward(self.features, self.adj_norm)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])

            if verbose and i % 20 == 0:
                print('Epoch {}, training loss: {}, val acc: {}'.format(
                    i, loss_train.item(), acc_val))

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())

        if verbose:
            print(
                '=== picking the best model according to the performance on validation ==='
            )
        self.load_state_dict(weights)

    def _train_with_early_stopping(self, labels, idx_train, idx_val,
                                   train_iters, patience, verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        early_stopping = patience
        best_loss_val = 100

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            self.eval()
            output = self.forward(self.features, self.adj_norm)

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(
                    i, loss_train.item()))

            loss_val = F.nll_loss(output[idx_val], labels[idx_val])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())
                patience = early_stopping
            else:
                patience -= 1
            if i > early_stopping and patience <= 0:
                break

        if verbose:
            print('=== early stopping at {0}, loss_val = {1} ==='.format(
                i, best_loss_val))
        self.load_state_dict(weights)

    def test(self, idx_test, model_name=None):
        # self.model_name = model_name
        self.eval()
        output = self.predict()
        # output = self.output
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:", "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test, output

    def _set_parameters(self):
        # TODO
        pass

    def predict(self, features=None, adj=None):
        '''By default, inputs are unnormalized data'''

        # self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features,
                                                adj,
                                                device=self.device)

            self.features = features
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            return self.forward(self.features, self.adj_norm)
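A standalone sketch (pure torch, not part of the model above) of the two chained torch.nn.Threshold calls used in the attention-based edge dropping: together they act as a hard step at 0.5, keeping edges whose learned score exceeds 0.5 and zeroing out the rest.

import torch

drop_score = torch.sigmoid(torch.tensor([-2.0, -0.1, 0.1, 2.0]))  # ~[0.12, 0.48, 0.52, 0.88]

step1 = torch.nn.Threshold(0.5, 0)    # values <= 0.5 become 0, larger values pass through
step2 = torch.nn.Threshold(-0.49, 1)  # applied to the negated scores below

kept = step1(drop_score)   # [0.00, 0.00, 0.52, 0.88]
mask = step2(-kept)        # negated survivors are <= -0.5 and become 1; zeros stay 0
print(mask)                # 0 for dropped edges, 1 for kept edges (score > 0.5)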
class Graclus(torch.nn.Module):
    def __init__(self, num_features, num_classes, num_layers, hidden, pooling_type,
                 no_cat=False, encode_edge=False):
        super(Graclus, self).__init__()
        self.encode_edge = encode_edge
        if encode_edge:
            self.conv1 = GCNConv(hidden, aggr='add')
        else:
            self.conv1 = GraphConv(num_features, hidden, aggr='add')

        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden, aggr='add'))

        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        if no_cat:
            self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)
        self.pooling_type = pooling_type
        self.no_cat = no_cat

        self.atom_encoder = AtomEncoder(emb_dim=hidden)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if self.encode_edge:
            x = self.atom_encoder(x)
            x = self.conv1(x, edge_index, data.edge_attr)
        else:
            x = self.conv1(x, edge_index)
        x = F.relu(x)
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if self.pooling_type != 'none':
                if self.pooling_type == 'complement':
                    complement = batched_negative_edges(edge_index=edge_index, batch=batch, force_undirected=True)
                    cluster = graclus(complement, num_nodes=x.size(0))
                elif self.pooling_type == 'graclus':
                    cluster = graclus(edge_index, num_nodes=x.size(0))
                data = Batch(x=x, edge_index=edge_index, batch=batch)
                data = max_pool(cluster, data)
                x, edge_index, batch = data.x, data.edge_index, data.batch

        if not self.no_cat:
            x = self.jump(xs)
        else:
            x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x
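A minimal toy run of the graclus/max_pool coarsening step used inside the loop above, assuming the standard torch_geometric API (the graclus call requires the torch-cluster package); the custom batched_negative_edges helper is only needed for the 'complement' pooling mode and is not used here.

import torch
from torch_geometric.data import Batch
from torch_geometric.nn import graclus, max_pool

edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])     # a 4-node path graph (both directions)
x = torch.randn(4, 8)
batch = torch.zeros(4, dtype=torch.long)            # a single graph in the batch

data = Batch(x=x, edge_index=edge_index, batch=batch)
cluster = graclus(edge_index, num_nodes=x.size(0))  # greedy node matching (torch-cluster)
data = max_pool(cluster, data)                      # feature-wise max inside each cluster
print(data.x.shape, data.edge_index.shape)          # roughly half as many nodes and edges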
Example #23
class MincutPool(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden, ratio=0.1):
        super(MincutPool, self).__init__()

        num_nodes = ceil(ratio * dataset[0].num_nodes)
        self.embed_block1 = Block(dataset.num_features, hidden, hidden)
        self.pool_block1 = Linear(hidden, num_nodes)

        self.embed_blocks = torch.nn.ModuleList()
        self.pool_blocks = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            num_nodes = ceil(ratio * num_nodes)
            self.embed_blocks.append(Block(hidden, hidden, hidden))
            self.pool_blocks.append(Linear(hidden, num_nodes))
        self.embed_final = Block(hidden, hidden, hidden)

        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(hidden * (num_layers + 1), hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.pool_block1.reset_parameters()
        for embed_block, pool_block in zip(self.embed_blocks,
                                           self.pool_blocks):
            embed_block.reset_parameters()
            pool_block.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, adj, mask = data.x, data.adj, data.mask
        mincut_losses = 0.
        ortho_losses = 0.
        x = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        s = self.pool_block1(x)
        xs = [x.mean(dim=1)]
        x, adj, mincut_loss, ortho_loss = dense_mincut_pool(x, adj, s, mask)
        mincut_losses += mincut_loss
        ortho_losses += ortho_loss
        for i, (embed_block, pool_block) in enumerate(
                zip(self.embed_blocks, self.pool_blocks)):
            x = F.relu(embed_block(x, adj))
            s = pool_block(x)
            xs.append(x.mean(dim=1))
            if i < len(self.embed_blocks):  # always true here; pooling follows every inner block
                x, adj, mincut_loss, ortho_loss = dense_mincut_pool(x, adj, s)
                mincut_losses += mincut_loss
                ortho_losses += ortho_loss

        x = F.relu(self.embed_final(x, adj, add_loop=True))
        xs.append(x.mean(dim=1))
        x = self.jump(xs)

        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        # print(mincut_losses+ortho_losses)
        return F.log_softmax(x, dim=-1), mincut_losses + ortho_losses

    def __repr__(self):
        return self.__class__.__name__
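For reference, a minimal standalone call of dense_mincut_pool with random dense inputs and illustrative shapes, showing the two auxiliary losses that the forward pass above accumulates and returns next to the log-probabilities.

import torch
from torch_geometric.nn import dense_mincut_pool

B, N, F_dim, C = 2, 10, 16, 4                 # graphs, nodes, features, clusters
x = torch.randn(B, N, F_dim)
adj = torch.rand(B, N, N)
s = torch.randn(B, N, C)                      # cluster assignment logits from the pool block

x_p, adj_p, mincut_loss, ortho_loss = dense_mincut_pool(x, adj, s)
print(x_p.shape, adj_p.shape)                 # torch.Size([2, 4, 16]) torch.Size([2, 4, 4])
print(mincut_loss.item(), ortho_loss.item())  # scalars to be added to the training loss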
Example #24
class GCNWithJK(torch.nn.Module):
    def __init__(self, num_layers, num_input_features, hidden, mode='cat'):
        super(GCNWithJK, self).__init__()
        self.conv1 = GCNConv(num_input_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GCNConv(hidden, hidden))

        self.jump = JumpingKnowledge(mode)
        if mode == 'cat':  # concatenation
            self.lin1 = Linear(3 * num_layers * hidden, hidden)
        else:
            self.lin1 = Linear(3 * hidden, hidden)
        self.lin2 = Linear(hidden, 2)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        # data: Batch(batch=[num_nodes_in_batch],
        #               edge_attr=[2*num_nodes_in_batch,num_edge_features_per_edge],
        #               edge_index=[2,2*num_nodes_in_batch],
        #               pos=[num_nodes_in_batch,2],
        #               x=[num_nodes_in_batch, num_input_features_per_node],
        #               y=[num_graphs_in_batch, num_classes])
        # example: Batch(batch=[2490], edge_attr=[4980,1], edge_index=[2,4980], pos=[2490,2], x=[2490,33], y=[32,2])

        x, edge_index, batch = data.x, data.edge_index, data.batch
        # x.shape: torch.Size([num_nodes_in_batch, num_input_features_per_node])
        # edge_index.shape: torch.Size([2, 2*num_nodes_in_batch])
        # batch.shape: torch.Size([num_nodes_in_batch])
        # example:  x.shape = torch.Size([2490,33])
        #           edge_index.shape = torch.Size([2,4980])
        #           batch.shape = torch.Size([2490])

        x = F.relu(self.conv1(x, edge_index))
        # x.shape: torch.Size([num_nodes_in_batch, hidden])
        # example:  x.shape = torch.Size([2490,66])
        xs = [x]
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            xs += [x]  # xs: list containing layer-wise representations
        x = self.jump(
            xs
        )  # aggregate information across different layers (concatenation)
        # x.shape: torch.Size([num_nodes_in_batch, num_layers * hidden])
        # example: x.shape = torch.Size([2490, 3*66])

        x = torch.cat([
            global_add_pool(x, batch),
            global_mean_pool(x, batch),
            global_max_pool(x, batch)
        ],
                      dim=1)
        # x.shape: torch.Size([num_graphs_in_batch, 3*num_layers * hidden])
        # example: x.shape = torch.Size([32, 3*3*66])

        x = F.relu(self.lin1(x))
        # x.shape: torch.Size([num_graphs_in_batch, hidden])
        # example: x.shape = torch.Size([32, 66])
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        # x.shape: torch.Size([num_graphs_in_batch, num_classes])
        # example: x.shape = torch.Size([32, 2])
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
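The 3 * num_layers * hidden input size of lin1 comes from concatenating three graph readouts of the JK-concatenated node features; a small standalone shape check with the example sizes from the comments above:

import torch
from torch_geometric.nn import (JumpingKnowledge, global_add_pool,
                                global_max_pool, global_mean_pool)

num_layers, hidden, num_nodes = 3, 66, 10
batch = torch.zeros(num_nodes, dtype=torch.long)                  # one graph
xs = [torch.randn(num_nodes, hidden) for _ in range(num_layers)]  # per-layer node features

x = JumpingKnowledge('cat')(xs)                 # [10, num_layers * hidden]
x = torch.cat([global_add_pool(x, batch),
               global_mean_pool(x, batch),
               global_max_pool(x, batch)], dim=1)
print(x.shape)                                  # torch.Size([1, 594]) = [1, 3*3*66]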
Example #25
class CIN0(torch.nn.Module):
    """
    A cellular version of GIN.

    This model is based on
    https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/gin.py
    """
    def __init__(self,
                 num_input_features,
                 num_classes,
                 num_layers,
                 hidden,
                 dropout_rate: float = 0.5,
                 max_dim: int = 2,
                 jump_mode=None,
                 nonlinearity='relu',
                 readout='sum'):
        super(CIN0, self).__init__()

        self.max_dim = max_dim
        self.dropout_rate = dropout_rate
        self.jump_mode = jump_mode
        self.convs = torch.nn.ModuleList()
        self.nonlinearity = nonlinearity
        self.pooling_fn = get_pooling_fn(readout)
        conv_nonlinearity = get_nonlinearity(nonlinearity, return_module=True)
        for i in range(num_layers):
            layer_dim = num_input_features if i == 0 else hidden
            conv_update = Sequential(Linear(layer_dim,
                                            hidden), conv_nonlinearity(),
                                     Linear(hidden, hidden),
                                     conv_nonlinearity(), BN(hidden))
            conv_up = Sequential(Linear(layer_dim * 2, layer_dim),
                                 conv_nonlinearity(), BN(layer_dim))
            conv_down = Sequential(Linear(layer_dim * 2, layer_dim),
                                   conv_nonlinearity(), BN(layer_dim))
            self.convs.append(
                CINConv(layer_dim,
                        layer_dim,
                        conv_up,
                        conv_down,
                        conv_update,
                        train_eps=False,
                        max_dim=self.max_dim))
        self.jump = JumpingKnowledge(
            jump_mode) if jump_mode is not None else None
        if jump_mode == 'cat':
            self.lin1 = Linear(num_layers * hidden, hidden)
        else:
            self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        if self.jump_mode is not None:
            self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def pool_complex(self, xs, data):
        # All complexes have nodes so we can extract the batch size from cochains[0]
        batch_size = data.cochains[0].batch.max() + 1
        # The MP output is of shape [message_passing_dim, batch_size, feature_dim]
        pooled_xs = torch.zeros(self.max_dim + 1,
                                batch_size,
                                xs[0].size(-1),
                                device=batch_size.device)
        for i in range(len(xs)):
            # It's very important that size is supplied.
            pooled_xs[i, :, :] = self.pooling_fn(xs[i],
                                                 data.cochains[i].batch,
                                                 size=batch_size)
        return pooled_xs

    def jump_complex(self, jump_xs):
        # Perform JumpingKnowledge at each level of the complex
        xs = []
        for jumpx in jump_xs:
            xs += [self.jump(jumpx)]
        return xs

    def forward(self, data: ComplexBatch):
        model_nonlinearity = get_nonlinearity(self.nonlinearity,
                                              return_module=False)
        xs, jump_xs = None, None
        for c, conv in enumerate(self.convs):
            params = data.get_all_cochain_params(max_dim=self.max_dim)
            xs = conv(*params)
            data.set_xs(xs)

            if self.jump_mode is not None:
                if jump_xs is None:
                    jump_xs = [[] for _ in xs]
                for i, x in enumerate(xs):
                    jump_xs[i] += [x]

        if self.jump_mode is not None:
            xs = self.jump_complex(jump_xs)
        pooled_xs = self.pool_complex(xs, data)
        x = pooled_xs.sum(dim=0)

        x = model_nonlinearity(self.lin1(x))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.lin2(x)
        return x

    def __repr__(self):
        return self.__class__.__name__
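A standalone sketch of the per-dimension readout performed by pool_complex, under the assumption that readout='sum' makes get_pooling_fn return a scatter-sum pool compatible with torch_geometric.nn.global_add_pool; the size argument plays the role highlighted in the comment above.

import torch
from torch_geometric.nn import global_add_pool

max_dim, batch_size, feat = 2, 4, 8
xs = [torch.randn(20, feat), torch.randn(30, feat), torch.randn(12, feat)]  # nodes, edges, 2-cells
batches = [torch.randint(0, batch_size, (x.size(0),)) for x in xs]

pooled = torch.zeros(max_dim + 1, batch_size, feat)
for i, (x, b) in enumerate(zip(xs, batches)):
    # `size` fixes the number of graphs even if a dimension is empty for some of them
    pooled[i] = global_add_pool(x, b, size=batch_size)
print(pooled.shape)  # torch.Size([3, 4, 8]); summing over dim 0 gives the complex embedding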
Example #26
class SparseCIN(torch.nn.Module):
    """
    A cellular version of GIN.

    This model is based on
    https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/gin.py
    """
    def __init__(self,
                 num_input_features,
                 num_classes,
                 num_layers,
                 hidden,
                 dropout_rate: float = 0.5,
                 max_dim: int = 2,
                 jump_mode=None,
                 nonlinearity='relu',
                 readout='sum',
                 train_eps=False,
                 final_hidden_multiplier: int = 2,
                 use_coboundaries=False,
                 readout_dims=(0, 1, 2),
                 final_readout='sum',
                 apply_dropout_before='lin2',
                 graph_norm='bn'):
        super(SparseCIN, self).__init__()

        self.max_dim = max_dim
        if readout_dims is not None:
            self.readout_dims = tuple(
                [dim for dim in readout_dims if dim <= max_dim])
        else:
            self.readout_dims = list(range(max_dim + 1))
        self.final_readout = final_readout
        self.dropout_rate = dropout_rate
        self.apply_dropout_before = apply_dropout_before
        self.jump_mode = jump_mode
        self.convs = torch.nn.ModuleList()
        self.nonlinearity = nonlinearity
        self.pooling_fn = get_pooling_fn(readout)
        self.graph_norm = get_graph_norm(graph_norm)
        act_module = get_nonlinearity(nonlinearity, return_module=True)
        for i in range(num_layers):
            layer_dim = num_input_features if i == 0 else hidden
            self.convs.append(
                SparseCINConv(up_msg_size=layer_dim,
                              down_msg_size=layer_dim,
                              boundary_msg_size=layer_dim,
                              passed_msg_boundaries_nn=None,
                              passed_msg_up_nn=None,
                              passed_update_up_nn=None,
                              passed_update_boundaries_nn=None,
                              train_eps=train_eps,
                              max_dim=self.max_dim,
                              hidden=hidden,
                              act_module=act_module,
                              layer_dim=layer_dim,
                              graph_norm=self.graph_norm,
                              use_coboundaries=use_coboundaries))
        self.jump = JumpingKnowledge(
            jump_mode) if jump_mode is not None else None
        self.lin1s = torch.nn.ModuleList()
        for _ in range(max_dim + 1):
            if jump_mode == 'cat':
                # these layers have no bias, so a level that is absent in the batch
                # contributes an exact zero instead of a bias term
                self.lin1s.append(
                    Linear(num_layers * hidden,
                           final_hidden_multiplier * hidden,
                           bias=False))
            else:
                self.lin1s.append(
                    Linear(hidden, final_hidden_multiplier * hidden))
        self.lin2 = Linear(final_hidden_multiplier * hidden, num_classes)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        if self.jump_mode is not None:
            self.jump.reset_parameters()
        for lin in self.lin1s:
            lin.reset_parameters()
        self.lin2.reset_parameters()

    def pool_complex(self, xs, data):
        # All complexes have nodes so we can extract the batch size from cochains[0]
        batch_size = data.cochains[0].batch.max() + 1
        # print(batch_size)
        # The MP output is of shape [message_passing_dim, batch_size, feature_dim]
        pooled_xs = torch.zeros(self.max_dim + 1,
                                batch_size,
                                xs[0].size(-1),
                                device=batch_size.device)
        for i in range(len(xs)):
            # It's very important that size is supplied.
            pooled_xs[i, :, :] = self.pooling_fn(xs[i],
                                                 data.cochains[i].batch,
                                                 size=batch_size)

        new_xs = []
        for i in range(self.max_dim + 1):
            new_xs.append(pooled_xs[i])
        return new_xs

    def jump_complex(self, jump_xs):
        # Perform JumpingKnowledge at each level of the complex
        xs = []
        for jumpx in jump_xs:
            xs += [self.jump(jumpx)]
        return xs

    def forward(self, data: ComplexBatch, include_partial=False):
        act = get_nonlinearity(self.nonlinearity, return_module=False)

        xs, jump_xs = None, None
        res = {}
        for c, conv in enumerate(self.convs):
            params = data.get_all_cochain_params(max_dim=self.max_dim,
                                                 include_down_features=False)
            start_to_process = 0
            # if i == len(self.convs) - 2:
            #     start_to_process = 1
            # if i == len(self.convs) - 1:
            #     start_to_process = 2
            xs = conv(*params, start_to_process=start_to_process)
            data.set_xs(xs)

            if include_partial:
                for k in range(len(xs)):
                    res[f"layer{c}_{k}"] = xs[k]

            if self.jump_mode is not None:
                if jump_xs is None:
                    jump_xs = [[] for _ in xs]
                for i, x in enumerate(xs):
                    jump_xs[i] += [x]

        if self.jump_mode is not None:
            xs = self.jump_complex(jump_xs)

        xs = self.pool_complex(xs, data)
        # Select the dimensions we want at the end.
        xs = [xs[i] for i in self.readout_dims]

        if include_partial:
            for k in range(len(xs)):
                res[f"pool_{k}"] = xs[k]

        new_xs = []
        for i, x in enumerate(xs):
            if self.apply_dropout_before == 'lin1':
                x = F.dropout(x, p=self.dropout_rate, training=self.training)
            new_xs.append(act(self.lin1s[self.readout_dims[i]](x)))

        x = torch.stack(new_xs, dim=0)

        if self.apply_dropout_before == 'final_readout':
            x = F.dropout(x, p=self.dropout_rate, training=self.training)
        if self.final_readout == 'mean':
            x = x.mean(0)
        elif self.final_readout == 'sum':
            x = x.sum(0)
        else:
            raise NotImplementedError
        if self.apply_dropout_before not in ['lin1', 'final_readout']:
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        x = self.lin2(x)

        if include_partial:
            res['out'] = x
            return x, res
        return x

    def __repr__(self):
        return self.__class__.__name__
Example #27
class DEA_GNN_JK(torch.nn.Module):
    def __init__(self, num_nodes, embed_dim, 
                 gnn_in_dim, gnn_hidden_dim, gnn_out_dim, gnn_num_layers, 
                 mlp_in_dim, mlp_hidden_dim, mlp_out_dim=1, mlp_num_layers=2, 
                 dropout=0.5, gnn_batchnorm=False, mlp_batchnorm=False, K=2, jk_mode='max'):
        super(DEA_GNN_JK, self).__init__()
        
        assert jk_mode in ['max','sum','mean','lstm','cat']
        # Embedding
        self.emb = torch.nn.Embedding(num_nodes, embedding_dim=embed_dim)

        # GNN 
        convs_list = [TAGConv(gnn_in_dim, gnn_hidden_dim, K)]
        for i in range(gnn_num_layers-2):
            convs_list.append(TAGConv(gnn_hidden_dim, gnn_hidden_dim, K))
        convs_list.append(TAGConv(gnn_hidden_dim, gnn_out_dim, K))
        self.convs = torch.nn.ModuleList(convs_list)

        # MLP
        lins_list = [torch.nn.Linear(mlp_in_dim, mlp_hidden_dim)]
        for i in range(mlp_num_layers-2):
            lins_list.append(torch.nn.Linear(mlp_hidden_dim, mlp_hidden_dim))
        lins_list.append(torch.nn.Linear(mlp_hidden_dim, mlp_out_dim))
        self.lins = torch.nn.ModuleList(lins_list)

        # Batchnorm
        self.gnn_batchnorm = gnn_batchnorm
        self.mlp_batchnorm = mlp_batchnorm
        if self.gnn_batchnorm:
            self.gnn_bns = torch.nn.ModuleList([torch.nn.BatchNorm1d(gnn_hidden_dim) for i in range(gnn_num_layers)])
        
        if self.mlp_batchnorm:
            self.mlp_bns = torch.nn.ModuleList([torch.nn.BatchNorm1d(mlp_hidden_dim) for i in range(mlp_num_layers-1)])

        self.jk_mode = jk_mode
        if self.jk_mode in ['max', 'lstm', 'cat']:
            self.jk = JumpingKnowledge(mode=self.jk_mode, channels=gnn_hidden_dim, num_layers=gnn_num_layers)

        self.dropout = dropout
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.emb.weight)  
        for conv in self.convs:
            conv.reset_parameters()
        for lin in self.lins:
            lin.reset_parameters()
        if self.gnn_batchnorm:
            for bn in self.gnn_bns:
                bn.reset_parameters()
        if self.mlp_batchnorm:
            for bn in self.mlp_bns:
                bn.reset_parameters()
        if self.jk_mode in ['max', 'lstm', 'cat']:
            self.jk.reset_parameters()

    def forward(self, x_feature, adj_t, edge_label_index):
        
        if x_feature is not None:
            out = torch.cat([self.emb.weight, x_feature], dim=1)
        else:
            out = self.emb.weight

        out_list = []
        for i in range(len(self.convs)):
            out = self.convs[i](out, adj_t)
            if self.gnn_batchnorm:
                out = self.gnn_bns[i](out)
            out = F.relu(out)
            out = F.dropout(out, p=self.dropout, training=self.training)
            out_list += [out]

        if self.jk_mode in ['max', 'lstm', 'cat']:
            out = self.jk(out_list)
        elif self.jk_mode == 'mean':
            out_stack = torch.stack(out_list, dim=0)
            out = torch.mean(out_stack, dim=0)
        elif self.jk_mode == 'sum':
            out_stack = torch.stack(out_list, dim=0)
            out = torch.sum(out_stack, dim=0)

        gnn_embed = out[edge_label_index,:]
        embed_product = gnn_embed[0, :, :] * gnn_embed[1, :, :]
        out = embed_product

        for i in range(len(self.lins)-1):
            out = self.lins[i](out)
            if self.mlp_batchnorm:
                out = self.mlp_bns[i](out)
            out = F.relu(out)
            out = F.dropout(out, p=self.dropout, training=self.training)
        out = self.lins[-1](out).squeeze(1)

        return out
    
    def loss(self, y_pred, y_true):
        return self.loss_fn(y_pred, y_true)
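The link scoring in forward indexes the node embeddings with edge_label_index (shape [2, num_links]) and multiplies the two endpoint embeddings elementwise before the MLP; a standalone sketch of that decoder with made-up sizes:

import torch

num_nodes, dim = 100, 32
h = torch.randn(num_nodes, dim)                          # node embeddings from the GNN
edge_label_index = torch.randint(0, num_nodes, (2, 5))   # 5 candidate links (source, target)

pair = h[edge_label_index]                               # [2, 5, dim]
embed_product = pair[0] * pair[1]                        # [5, dim], elementwise product
score = torch.nn.Linear(dim, 1)(embed_product).squeeze(1)
print(score.shape)                                       # torch.Size([5]), one logit per link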
Example #28
class Net(torch.nn.Module):
    def __init__(self, num_classes, num_layers, feat_dim, embed_dim, jk_layer,
                 process_step, dropout):
        super(Net, self).__init__()

        self.dropout = dropout
        self.num_layers = num_layers
        self.convs = torch.nn.ModuleList()

        for i in range(num_layers):
            if i == 0:
                self.convs.append(
                    AGGINConv(Sequential(Linear(feat_dim, embed_dim), ReLU(),
                                         Linear(embed_dim, embed_dim), ReLU(),
                                         BN(embed_dim)),
                              train_eps=True,
                              dropout=self.dropout))
            else:
                self.convs.append(
                    AGGINConv(Sequential(Linear(embed_dim, embed_dim), ReLU(),
                                         Linear(embed_dim, embed_dim), ReLU(),
                                         BN(embed_dim)),
                              train_eps=True,
                              dropout=self.dropout))

        if jk_layer.isdigit():
            jk_layer = int(jk_layer)
            self.jump = JumpingKnowledge(mode='lstm',
                                         channels=embed_dim,
                                         num_layers=jk_layer)
            self.gpl = (Set2Set(embed_dim, processing_steps=process_step))
            self.fc1 = Linear(2 * embed_dim, embed_dim)
            # self.fc1 = Linear(embed_dim, embed_dim)
            self.fc2 = Linear(embed_dim, num_classes)
        elif jk_layer == 'cat':
            self.jump = JumpingKnowledge(mode=jk_layer)
            self.gpl = (Set2Set(num_layers * embed_dim,
                                processing_steps=process_step))
            self.fc1 = Linear(2 * embed_dim, embed_dim)
            # self.fc1 = Linear(num_layers * embed_dim, embed_dim)
            self.fc2 = Linear(embed_dim, num_classes)
        elif jk_layer == 'max':
            self.jump = JumpingKnowledge(mode=jk_layer)
            self.gpl = (Set2Set(embed_dim, processing_steps=process_step))
            self.fc1 = Linear(2 * embed_dim, embed_dim)
            # self.fc1 = Linear(embed_dim, embed_dim)
            self.fc2 = Linear(embed_dim, num_classes)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        self.gpl.reset_parameters()
        self.jump.reset_parameters()
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        xs = []
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            xs += [x]
        x = self.jump(xs)
        x = self.gpl(x, batch)
        # x = global_max_pool(x, batch)
        x = F.relu(self.fc1(x))
        # x = F.dropout(x, p=self.dropout, training=self.training)
        logits = self.fc2(x)

        return logits
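Set2Set(in_channels, processing_steps) returns a graph embedding of size 2 * in_channels, which is why fc1 takes 2 * embed_dim inputs above; a quick standalone shape check:

import torch
from torch_geometric.nn import Set2Set

embed_dim, num_nodes = 64, 10
x = torch.randn(num_nodes, embed_dim)
batch = torch.zeros(num_nodes, dtype=torch.long)   # a single graph

readout = Set2Set(embed_dim, processing_steps=3)
out = readout(x, batch)
print(out.shape)                                   # torch.Size([1, 128]) = [1, 2 * embed_dim]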
Example #29
class SOMPool(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden, ratio=0.25):
        super(SOMPool, self).__init__()

        num_nodes = ceil(ratio * dataset[0].num_nodes)
        self.num_nodes = hidden
        self.embed_block1 = Block(dataset.num_features, hidden, hidden)
        self.pool_block1 = Block(dataset.num_features, hidden, num_nodes)
        self.dimnum = dataset.num_features
        self.embed_blocks = torch.nn.ModuleList()
        self.pool_blocks = torch.nn.ModuleList()
        for i in range((num_layers // 2) - 1):
            num_nodes = ceil(ratio * num_nodes)
            self.embed_blocks.append(Block(hidden, hidden, hidden))
            self.pool_blocks.append(Block(hidden, hidden, num_nodes))

        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.embed_block1.reset_parameters()
        self.pool_block1.reset_parameters()
        for embed_block, pool_block in zip(self.embed_blocks,
                                           self.pool_blocks):
            embed_block.reset_parameters()
            pool_block.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        #print(data)
        x, adj, mask = data.x, data.adj, data.mask

        somnum = ceil(sqrt(self.num_nodes / 1.5))
        som = MiniSom(somnum,
                      somnum,
                      self.dimnum,
                      sigma=0.3,
                      learning_rate=0.5)
        tempdata = x.reshape(-1, self.dimnum)
        tempdata = tempdata.cpu().numpy()
        som.train_batch(tempdata, 100)
        qnt = som.quantization(tempdata)
        qnt = torch.from_numpy(qnt).float().to(x.device)
        qnt = qnt.reshape(adj.size()[0], -1, self.dimnum)
        #print(qnt.size())
        #print(adj.size())
        s = self.pool_block1(qnt, adj, mask, add_loop=True)
        x = F.relu(self.embed_block1(x, adj, mask, add_loop=True))
        xs = [x.mean(dim=1)]
        x, adj, _, _ = dense_diff_pool(x, adj, s, mask)

        for i, (embed_block, pool_block) in enumerate(
                zip(self.embed_blocks, self.pool_blocks)):
            s = pool_block(x, adj)
            x = F.relu(embed_block(x, adj))
            xs.append(x.mean(dim=1))
            if i < len(self.embed_blocks) - 1:
                x, adj, _, _ = dense_diff_pool(x, adj, s)

        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
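For completeness, the same kind of standalone shape sketch for dense_diff_pool, which the pooling steps above call; it returns the pooled features, the pooled adjacency, and two auxiliary losses (discarded above via `_, _`).

import torch
from torch_geometric.nn import dense_diff_pool

B, N, F_dim, C = 2, 10, 16, 3                  # graphs, nodes, features, clusters
x = torch.randn(B, N, F_dim)
adj = torch.rand(B, N, N)
s = torch.randn(B, N, C)                       # soft cluster assignment logits

x_p, adj_p, link_loss, ent_loss = dense_diff_pool(x, adj, s)
print(x_p.shape, adj_p.shape)                  # torch.Size([2, 3, 16]) torch.Size([2, 3, 3])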