Beispiel #1
0
    def __init__(self,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 num_layers,
                 dropout,
                 gnn_type='gcn'):
        """Build a stack of ``num_layers`` DGL graph-convolution layers.

        Layer widths go ``in_channels -> hidden_channels -> ... ->
        out_channels``; with ``num_layers == 1`` a single
        ``in_channels -> out_channels`` layer is built.

        Args:
            in_channels: Input node feature size.
            hidden_channels: Output size of every layer except the last.
            out_channels: Output size of the last layer.
            num_layers: Number of convolution layers; must be at least 1.
            dropout: Dropout probability, stored as ``self.dropout``
                (presumably applied in ``forward`` -- not visible here).
            gnn_type: ``'gcn'`` for ``GraphConv`` with ``norm='none'`` or
                ``'gat'`` for single-head ``GATConv``.

        Raises:
            ValueError: If ``num_layers < 1`` or ``gnn_type`` is unknown.
        """
        super(GCN, self).__init__()

        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        if gnn_type == 'gat':
            # Single attention head: inputs are sized as (heads x dim) with
            # heads == 1, so the head axis is presumably flattened/squeezed
            # between layers (in forward) -- confirm.
            def make_conv(in_dim, out_dim):
                return nn.GATConv(in_dim, out_dim, 1)
        elif gnn_type == 'gcn':
            def make_conv(in_dim, out_dim):
                return nn.GraphConv(in_dim, out_dim, norm='none')
        else:
            # Previously an unknown type silently produced an empty model.
            raise ValueError(
                f"gnn_type must be 'gcn' or 'gat', got {gnn_type!r}")

        # Feature width at every layer boundary; consecutive pairs give each
        # layer's (in, out) sizes, built in order so parameter init order
        # matches the original layer-by-layer construction.
        dims = ([in_channels] + [hidden_channels] * (num_layers - 1)
                + [out_channels])
        self.convs = torch.nn.ModuleList(
            [make_conv(d_in, d_out) for d_in, d_out in zip(dims, dims[1:])])

        self.dropout = dropout
Beispiel #2
0
 def __init__(self, in_channels, out_channels, hidden_channels, num_etypes, num_layers, num_heads, dropout, pred_ntype):
     """Build a per-edge-type GAT encoder with batch norms, skips and an MLP head.

     Each layer holds one ``dglnn.GATConv`` per edge type (``num_etypes`` of
     them), a ``BatchNorm1d(hidden_channels)`` and a linear skip projection.
     The first layer takes ``in_channels`` inputs; the following
     ``num_layers - 1`` layers take ``hidden_channels`` inputs (so at least
     one layer is always built).

     Args:
         in_channels: Input node feature size.
         out_channels: Output size of the final MLP layer.
         hidden_channels: Hidden width. Each GATConv emits ``num_heads`` heads
             of ``hidden_channels // num_heads`` features while later layers
             consume ``hidden_channels`` inputs, so the division must be exact.
         num_etypes: Number of edge types; one GATConv is built per type per layer.
         num_layers: Number of GAT layers.
         num_heads: Attention heads per GATConv.
         dropout: Dropout probability for the MLP head and ``self.dropout``.
         pred_ntype: Stored as ``self.pred_ntype`` (presumably the node type
             whose outputs are predicted -- confirm against forward).

     Raises:
         ValueError: If ``hidden_channels`` is not a multiple of ``num_heads``.
     """
     super().__init__()
     if num_heads < 1 or hidden_channels % num_heads != 0:
         raise ValueError(
             f"hidden_channels ({hidden_channels}) must be divisible by "
             f"num_heads ({num_heads})")
     head_channels = hidden_channels // num_heads

     self.convs = nn.ModuleList()
     self.norms = nn.ModuleList()
     self.skips = nn.ModuleList()

     # Input width of each layer: in_channels for the first, hidden after.
     # A non-positive repeat count yields [], so num_layers <= 1 -> one layer.
     layer_in_channels = [in_channels] + [hidden_channels] * (num_layers - 1)
     for layer_in in layer_in_channels:
         # One attention conv per edge type; zero in-degree nodes are allowed.
         self.convs.append(nn.ModuleList([
             dglnn.GATConv(layer_in, head_channels, num_heads, allow_zero_in_degree=True)
             for _ in range(num_etypes)
         ]))
         self.norms.append(nn.BatchNorm1d(hidden_channels))
         self.skips.append(nn.Linear(layer_in, hidden_channels))

     self.mlp = nn.Sequential(
         nn.Linear(hidden_channels, hidden_channels),
         nn.BatchNorm1d(hidden_channels),
         nn.ReLU(),
         nn.Dropout(dropout),
         nn.Linear(hidden_channels, out_channels)
     )
     self.dropout = nn.Dropout(dropout)

     self.hidden_channels = hidden_channels
     self.pred_ntype = pred_ntype
     self.num_etypes = num_etypes
Beispiel #3
0
 def __init__(self, etypes, in_feats, n_hidden, n_classes, n_heads=4, dropout=0.5):
     """Build a three-layer heterogeneous GAT followed by a linear classifier.

     Each layer is a ``dglnn.HeteroGraphConv`` wrapping one ``dglnn.GATConv``
     per edge type. The first layer takes ``in_feats`` inputs, the other two
     take ``n_hidden`` inputs; every GATConv emits ``n_heads`` heads of
     ``n_hidden // n_heads`` features.

     Args:
         etypes: Iterable of edge-type keys. It is materialised once so a
             one-shot iterator still populates all three layers.
         in_feats: Input feature size.
         n_hidden: Hidden size; must be divisible by ``n_heads``.
         n_classes: Output size of the final linear layer.
         n_heads: Attention heads per GATConv.
         dropout: Probability for ``self.dropout`` (default 0.5, the value
             previously hard-coded).

     Raises:
         ValueError: If ``n_hidden`` is not a multiple of ``n_heads``.
     """
     super().__init__()
     if n_heads < 1 or n_hidden % n_heads != 0:
         raise ValueError(
             f"n_hidden ({n_hidden}) must be divisible by n_heads ({n_heads})")
     # Iterated once per layer below; a generator would otherwise be
     # exhausted after the first layer.
     etypes = list(etypes)
     head_feats = n_hidden // n_heads

     self.layers = nn.ModuleList()
     for layer_in in (in_feats, n_hidden, n_hidden):
         self.layers.append(
             dglnn.HeteroGraphConv({
                 etype: dglnn.GATConv(layer_in, head_feats, n_heads)
                 for etype in etypes
             }))
     self.dropout = nn.Dropout(dropout)
     # TODO: Should be HeteroLinear (original author's note).
     self.linear = nn.Linear(n_hidden, n_classes)
Beispiel #4
0
def test_gat_conv_bi(g, idtype):
    """GATConv on a (src, dst) feature pair: check output and attention shapes.

    ``nn.GATConv(5, 2, 4)`` maps 5 input features using 4 heads of size 2, so
    the output must be ``(num_dst_nodes, 4, 2)`` and the attention scores
    ``(num_edges, 4, 1)``.
    """
    graph = g.astype(idtype)
    key = jax.random.PRNGKey(2666)
    layer = nn.GATConv(5, 2, 4)

    # Independent random features for source and destination nodes.
    src_feat = F.randn((graph.number_of_src_nodes(), 5))
    dst_feat = F.randn((graph.number_of_dst_nodes(), 5))
    feat = (src_feat, dst_feat)

    params = layer.init(key, graph, feat)
    out = layer.apply(params, graph, feat)
    assert out.shape == (graph.number_of_dst_nodes(), 4, 2)

    # Re-initialise with attention output requested and check its shape.
    params = layer.init(key, graph, feat, get_attention=True)
    _, attn = layer.apply(params, graph, feat, get_attention=True)
    assert attn.shape == (graph.number_of_edges(), 4, 1)
Beispiel #5
0
 def __init__(self, gnn_size, in_features, attn_heads):
     """Create the GATConv layers ``conv1`` and ``conv2``.

     Args:
         gnn_size: Indexable of layer output sizes; ``gnn_size[0]`` is used
             for ``conv1`` and ``gnn_size[1]`` for ``conv2``.
         in_features: Input node feature size for ``conv1``.
         attn_heads: Number of attention heads in each GATConv.
     """
     super(GCN, self).__init__()
     self.conv1 = dglnn.GATConv(in_features, gnn_size[0], attn_heads)
     # NOTE(review): GATConv yields one gnn_size[0]-wide output per head, but
     # conv2 expects gnn_size[0] inputs; presumably heads are averaged or
     # otherwise reduced between layers in forward -- confirm.
     self.conv2 = dglnn.GATConv(gnn_size[0], gnn_size[1], attn_heads)