Exemplo n.º 1
0
def test_asap():
    """Smoke-test ASAPooling with two scorer GNNs and both self-loop modes."""
    in_channels = 16
    # Dense 4-node graph: every ordered node pair except self-loops.
    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]])
    num_nodes = edge_index.max().item() + 1
    x = torch.randn((num_nodes, in_channels))

    for conv_cls in [GraphConv, GCNConv]:
        # Halve the graph, no self-loops added to the coarsened edge set.
        pool = ASAPooling(in_channels, ratio=0.5, GNN=conv_cls,
                          add_self_loops=False)
        assert repr(pool) == 'ASAPooling(16, ratio=0.5)'
        out = pool(x, edge_index)
        assert out[0].size() == (num_nodes // 2, in_channels)
        assert out[1].size() == (2, 2)

        # Same ratio, this time with self-loops (more edges in the output).
        pool = ASAPooling(in_channels, ratio=0.5, GNN=conv_cls,
                          add_self_loops=True)
        assert repr(pool) == 'ASAPooling(16, ratio=0.5)'
        out = pool(x, edge_index)
        assert out[0].size() == (num_nodes // 2, in_channels)
        assert out[1].size() == (2, 4)

        # ratio=2 — expect exactly two nodes kept after pooling.
        pool = ASAPooling(in_channels, ratio=2, GNN=conv_cls,
                          add_self_loops=False)
        assert repr(pool) == 'ASAPooling(16, ratio=2)'
        out = pool(x, edge_index)
        assert out[0].size() == (2, in_channels)
        assert out[1].size() == (2, 2)
Exemplo n.º 2
0
    def __init__(self, num_vocab, max_seq_len, node_encoder, emb_dim, num_layers, hidden, ratio=0.8, dropout=0, num_class=0):
        """ASAP network: a GraphConv stack interleaved with ASAPooling stages.

        When ``num_class > 0`` a single classification head is built;
        otherwise one vocab-sized head per output position is created
        (up to ``max_seq_len`` heads).
        """
        super(ASAP, self).__init__()

        self.num_class = num_class
        self.max_seq_len = max_seq_len
        self.node_encoder = node_encoder

        # Input projection, then (num_layers - 1) hidden-to-hidden convs.
        self.conv1 = GraphConv(emb_dim, hidden, aggr='mean')
        self.convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden, aggr='mean'))
        # One pooling stage per two convolution layers.
        for _ in range(num_layers // 2):
            self.pools.append(ASAPooling(hidden, ratio, dropout=dropout))
        # 'cat' jumping knowledge concatenates all layer outputs,
        # hence the (num_layers * hidden)-wide input to lin1.
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)

        if self.num_class > 0:  # classification head
            self.graph_pred_linear = torch.nn.Linear(hidden, self.num_class)
        else:  # one linear head per sequence position
            self.graph_pred_linear_list = torch.nn.ModuleList()
            for _ in range(max_seq_len):
                self.graph_pred_linear_list.append(torch.nn.Linear(hidden, num_vocab))
Exemplo n.º 3
0
    def __init__(self, config):
        """GIN encoder: stacked GINConv + BatchNorm1d, an optional pooling
        layer selected by ``config.pooling_type``, and two FC output layers.
        """
        super(GIN, self).__init__()
        self.config = config

        self.gin_convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer_idx in range(self.config.num_layers - 1):
            # Only the first layer maps raw features; the rest stay in
            # hidden_dim.
            in_dim = (self.config.num_feature_dim if layer_idx == 0
                      else self.config.hidden_dim)
            mlp = Sequential(
                Linear(in_dim, self.config.hidden_dim),
                ReLU(),
                Linear(self.config.hidden_dim, self.config.hidden_dim))
            self.gin_convs.append(GINConv(mlp))
            self.batch_norms.append(
                torch.nn.BatchNorm1d(self.config.hidden_dim))

        # Choose the pooling operator; an unknown type leaves pool1 unset,
        # matching the original if/elif chain with no else branch.
        pool_classes = {
            "sagpool": SAGPooling,
            "topk": TopKPooling,
            "asa": ASAPooling,
        }
        pool_cls = pool_classes.get(self.config.pooling_type)
        if pool_cls is not None:
            self.pool1 = pool_cls(self.config.hidden_dim,
                                  ratio=self.config.poolratio)

        self.fc1 = Linear(self.config.hidden_dim, self.config.hidden_dim)
        self.fc2 = Linear(self.config.hidden_dim, self.config.embed_dim)
Exemplo n.º 4
0
    def poollayer(self, pooltype):
        """Build and return a pair of graph pooling layers of the given type.

        Args:
            pooltype: One of 'TopKPool', 'EdgePool', 'ASAPool', 'SAGPool'.

        Returns:
            Tuple ``(pool1, pool2)``; both are also stored on ``self``.

        Raises:
            ValueError: If ``pooltype`` is not a supported name.
        """
        self.pooltype = pooltype

        if self.pooltype == 'TopKPool':
            self.pool1 = TopKPooling(1024)
            self.pool2 = TopKPooling(1024)
        elif self.pooltype == 'EdgePool':
            self.pool1 = EdgePooling(1024)
            self.pool2 = EdgePooling(1024)
        elif self.pooltype == 'ASAPool':
            self.pool1 = ASAPooling(1024)
            self.pool2 = ASAPooling(1024)
        elif self.pooltype == 'SAGPool':
            self.pool1 = SAGPooling(1024)
            self.pool2 = SAGPooling(1024)
        else:
            # Previously this only printed a warning and then fell through to
            # the return, yielding an AttributeError (pool1/pool2 undefined)
            # or silently reusing stale pools from an earlier call. Fail fast.
            raise ValueError('Such graph pool method is not implemented!!')

        return self.pool1, self.pool2
Exemplo n.º 5
0
 def __init__(self, dataset, num_layers, hidden, ratio=0.8, dropout=0):
     """ASAP model: GraphConv stack with an ASAPooling stage every other
     layer, 'cat' jumping knowledge, and a two-layer classifier head.
     """
     super().__init__()
     self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')
     self.convs = torch.nn.ModuleList()
     self.pools = torch.nn.ModuleList()
     for _ in range(num_layers - 1):
         self.convs.append(GraphConv(hidden, hidden, aggr='mean'))
     # One pooling stage per two convolution layers.
     for _ in range(num_layers // 2):
         self.pools.append(ASAPooling(hidden, ratio, dropout=dropout))
     # 'cat' JK concatenates every layer's output: num_layers * hidden wide.
     self.jump = JumpingKnowledge(mode='cat')
     self.lin1 = Linear(num_layers * hidden, hidden)
     self.lin2 = Linear(hidden, dataset.num_classes)
Exemplo n.º 6
0
    def __init__(self, config):
        """Two-layer GCN encoder with an optional pooling stage and an
        embedding projection head.
        """
        super(GCN, self).__init__()

        self.config = config

        self.gc1 = GCNConv(self.config.num_feature_dim, self.config.hidden)
        self.gc2 = GCNConv(self.config.hidden, self.config.hidden)

        # Choose the pooling operator; an unknown type leaves pool1 unset,
        # matching the original if/elif chain with no else branch.
        pool_classes = {
            "sagpool": SAGPooling,
            "topk": TopKPooling,
            "asa": ASAPooling,
        }
        pool_cls = pool_classes.get(self.config.pooling_type)
        if pool_cls is not None:
            self.pool1 = pool_cls(self.config.hidden,
                                  ratio=self.config.poolratio)

        self.fc = nn.Linear(self.config.hidden, self.config.embed_dim)
 def __init__(self,
              num_features,
              num_classes,
              num_layers,
              hidden,
              ratio=0.8,
              dropout=0):
     """ChebConv-based ASAP model: two input convolutions expanding to
     ``hidden``, then a ChebConv stack with ASAPooling every other layer.
     """
     super(ASAP, self).__init__()
     # Expand features in two steps: num_features -> hidden/2 -> hidden.
     self.conv1 = ChebConv(num_features, hidden // 2, K=4, aggr='mean')
     self.conv2 = ChebConv(hidden // 2, hidden, K=4, aggr='mean')
     self.convs = torch.nn.ModuleList()
     self.pools = torch.nn.ModuleList()
     for _ in range(num_layers - 1):
         self.convs.append(ChebConv(hidden, hidden, K=2, aggr='mean'))
     # One pooling stage per two convolution layers.
     for _ in range(num_layers // 2):
         self.pools.append(ASAPooling(hidden, ratio, dropout=dropout))
     # 'cat' JK concatenates every layer's output: num_layers * hidden wide.
     self.jump = JumpingKnowledge(mode='cat')
     self.lin1 = Linear(num_layers * hidden, hidden)
     self.lin2 = Linear(hidden, num_classes)