Example no. 1
 def __init__(self, in_channels, out_channels, msg_dim, time_enc):
     super(GraphAttentionEmbedding, self).__init__()
     self.time_enc = time_enc
     edge_dim = msg_dim + time_enc.out_channels
     self.conv = TransformerConv(in_channels,
                                 out_channels // 2,
                                 heads=2,
                                 dropout=0.1,
                                 edge_dim=edge_dim)
     self.reset_parameters()
Example no. 2
    def __init__(self,
                 in_channels,
                 num_classes,
                 hidden_channels,
                 num_layers,
                 heads,
                 dropout=0.3):
        super().__init__()

        self.label_emb = MaskLabel(num_classes, in_channels)

        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()
        for i in range(1, num_layers + 1):
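            # Hidden layers concatenate their heads back to hidden_channels;
            # the final layer averages them (concat=False) and emits num_classes logits.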
            if i < num_layers:
                out_channels = hidden_channels // heads
                concat = True
            else:
                out_channels = num_classes
                concat = False
            conv = TransformerConv(in_channels,
                                   out_channels,
                                   heads,
                                   concat=concat,
                                   beta=True,
                                   dropout=dropout)
            self.convs.append(conv)
            in_channels = hidden_channels

            if i < num_layers:
                self.norms.append(torch.nn.LayerNorm(hidden_channels))
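The snippet stops after building the layer and normalization lists. Below is a minimal sketch of how such a stack could be driven; this forward pass is an illustration rather than the original code, and the MaskLabel injection of known training labels is only indicated by a comment because its exact call is not part of the snippet.

    def forward(self, x, edge_index):
        # Label embeddings from self.label_emb would typically be added to x for the
        # label-masked training nodes before message passing (omitted here).
        for conv, norm in zip(self.convs[:-1], self.norms):
            x = norm(conv(x, edge_index)).relu()
        return self.convs[-1](x, edge_index)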
Example no. 3
 def __init__(self):
     super().__init__()
     self.attention_layer = TransformerConv(64, 64, edge_dim=64, root_weight=False)
     # self.reduce = ff(96)
     self.phi_v = fff(128)
     self.atom_fc = ff(64)
     self.edge_update = EdgeUpdate()
     self.bond_fc = ff(64)
Example no. 4
 def __init__(self, in_channels, out_channels, msg_dim, time_enc):
     super().__init__()
     self.time_enc = time_enc
     edge_dim = msg_dim + time_enc.out_channels
     self.conv = TransformerConv(in_channels,
                                 out_channels // 2,
                                 heads=2,
                                 dropout=0.1,
                                 edge_dim=edge_dim)
Example no. 5
class GraphAttentionEmbedding(torch.nn.Module):
    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super(GraphAttentionEmbedding, self).__init__()
        self.time_enc = time_enc
        edge_dim = msg_dim + time_enc.out_channels
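        # Two heads of out_channels // 2 features each are concatenated, restoring out_channels.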
        self.conv = TransformerConv(in_channels,
                                    out_channels // 2,
                                    heads=2,
                                    dropout=0.1,
                                    edge_dim=edge_dim)
        self.reset_parameters()

    def reset_parameters(self):
        self.time_enc.reset_parameters()
        self.conv.reset_parameters()

    def forward(self, x, last_update, edge_index, t, msg):
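        # Relative time between each source node's last memory update and the event time,
        # time-encoded and concatenated with the raw message to form the edge features.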
        rel_t = last_update[edge_index[0]] - t
        rel_t_enc = self.time_enc(rel_t.to(x.dtype))
        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)
        return self.conv(x, edge_index, edge_attr)
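The class above comes from a TGN-style pipeline, where the time encoder and node memory are supplied externally. Below is a small sketch of how it could be exercised with toy tensors; the TimeEncoder stand-in and all sizes are illustrative assumptions, not part of the original source.

import torch
from torch_geometric.nn import TransformerConv


class TimeEncoder(torch.nn.Module):
    # Minimal stand-in for a TGN time encoder: a single linear map followed by cos.
    def __init__(self, out_channels):
        super().__init__()
        self.out_channels = out_channels
        self.lin = torch.nn.Linear(1, out_channels)

    def reset_parameters(self):
        self.lin.reset_parameters()

    def forward(self, t):
        return self.lin(t.view(-1, 1)).cos()


time_enc = TimeEncoder(out_channels=10)
emb = GraphAttentionEmbedding(in_channels=100, out_channels=100,
                              msg_dim=16, time_enc=time_enc)

x = torch.randn(5, 100)                           # current node states
last_update = torch.zeros(5)                      # last memory-update time per node
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
t = torch.tensor([10.0, 20.0, 30.0])              # event timestamp per edge
msg = torch.randn(3, 16)                          # raw message per edge
out = emb(x, last_update, edge_index, t, msg)     # -> shape [5, 100]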
Example no. 6
def test_transformer_conv():
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = TransformerConv(8, 32, heads=2)
    assert conv.__repr__() == 'TransformerConv(8, 32, heads=2)'
    out = conv(x1, edge_index)
    assert out.size() == (4, 64)
    assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)

    t = '(Tensor, Tensor, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x1, edge_index).tolist() == out.tolist()

    t = '(Tensor, SparseTensor, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)

    adj = adj.sparse_resize((4, 2))
    conv = TransformerConv((8, 16), 32, heads=2)
    assert conv.__repr__() == 'TransformerConv((8, 16), 32, heads=2)'

    out = conv((x1, x2), edge_index)
    assert out.size() == (2, 64)
    assert torch.allclose(conv((x1, x2), adj.t()), out, atol=1e-6)

    t = '(PairTensor, Tensor, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit((x1, x2), edge_index).tolist() == out.tolist()

    t = '(PairTensor, SparseTensor, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit((x1, x2), adj.t()), out, atol=1e-6)
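The (PairTensor, ...) cases above exercise bipartite message passing, where source and target nodes have different feature sizes. A stand-alone illustration of that case (separate from the test; values are arbitrary):

import torch
from torch_geometric.nn import TransformerConv

x_src = torch.randn(4, 8)    # 4 source nodes with 8 features
x_dst = torch.randn(2, 16)   # 2 target nodes with 16 features
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])  # source -> target

conv = TransformerConv((8, 16), 32, heads=2)
out = conv((x_src, x_dst), edge_index)
print(out.size())  # torch.Size([2, 64]): two heads of 32 features, concatenated, per target node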
Example no. 7
    def __init__(
        self,
        embed_dim: int,
        num_layers: int,
        heads: int = 8,
        normalization: str = "batch",
        feed_forward_hidden: int = 512,
        pooling_method: str = "mean",
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.heads = heads
        assert (self.embed_dim % self.heads) == 0
        self.pooling_func = get_pooling_func(pooling_method)
        self.norm_class = get_normalization_class(normalization)
        self.gnn_layer_list = nn.ModuleList()
        self.norm_list = nn.ModuleList()
        for i in range(self.num_layers):
            gnn_layer = TransformerConv(
                in_channels=self.embed_dim,
                out_channels=self.embed_dim // self.heads,
                heads=self.heads,
                edge_dim=self.embed_dim,
            )
            self.gnn_layer_list.append(gnn_layer)
            self.norm_list.append(GraphNorm(in_channels=self.embed_dim))

        self.feed_forward = Sequential(
            "x, batch",
            [
                (nn.Linear(self.embed_dim, feed_forward_hidden), "x -> x"),
                nn.GELU(),
                # (GraphNorm(in_channels=feed_forward_hidden), "x, batch -> x"),
                nn.Linear(feed_forward_hidden, self.embed_dim),
            ],
        )
        self.ff_norm = GraphNorm(in_channels=self.embed_dim)
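Only the constructor is shown above. A plausible forward pass is sketched below under the assumption of residual connections around each attention block and around the feed-forward block, with graph-level pooling at the end; it is an illustration, not the original implementation, and it assumes pooling_func follows the global_mean_pool(x, batch) calling convention.

    def forward(self, x, edge_index, edge_attr, batch):
        for conv, norm in zip(self.gnn_layer_list, self.norm_list):
            # Edge-aware multi-head attention, residual connection, GraphNorm.
            x = norm(x + conv(x, edge_index, edge_attr), batch)
        # Position-wise feed-forward block with its own residual and GraphNorm.
        x = self.ff_norm(x + self.feed_forward(x, batch), batch)
        # Pool node embeddings into a single embedding per graph.
        return self.pooling_func(x, batch)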
Example no. 8
def test_transformer_conv():
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = TransformerConv(8, 32, heads=2, beta=True)
    assert conv.__repr__() == 'TransformerConv(8, 32, heads=2)'
    out = conv(x1, edge_index)
    assert out.size() == (4, 64)
    assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)

    t = '(Tensor, Tensor, NoneType, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x1, edge_index).tolist() == out.tolist()

    t = '(Tensor, SparseTensor, NoneType, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)

    # Test `return_attention_weights`.
    result = conv(x1, edge_index, return_attention_weights=True)
    assert result[0].tolist() == out.tolist()
    assert result[1][0].size() == (2, 4)
    assert result[1][1].size() == (4, 2)
    assert result[1][1].min() >= 0 and result[1][1].max() <= 1
    assert conv._alpha is None

    result = conv(x1, adj.t(), return_attention_weights=True)
    assert torch.allclose(result[0], out, atol=1e-6)
    assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 4
    assert conv._alpha is None

    t = ('(Tensor, Tensor, NoneType, bool) -> '
         'Tuple[Tensor, Tuple[Tensor, Tensor]]')
    jit = torch.jit.script(conv.jittable(t))
    result = jit(x1, edge_index, return_attention_weights=True)
    assert result[0].tolist() == out.tolist()
    assert result[1][0].size() == (2, 4)
    assert result[1][1].size() == (4, 2)
    assert result[1][1].min() >= 0 and result[1][1].max() <= 1
    assert conv._alpha is None

    t = '(Tensor, SparseTensor, NoneType, bool) -> Tuple[Tensor, SparseTensor]'
    jit = torch.jit.script(conv.jittable(t))
    result = jit(x1, adj.t(), return_attention_weights=True)
    assert torch.allclose(result[0], out, atol=1e-6)
    assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 4
    assert conv._alpha is None

    adj = adj.sparse_resize((4, 2))
    conv = TransformerConv((8, 16), 32, heads=2, beta=True)
    assert conv.__repr__() == 'TransformerConv((8, 16), 32, heads=2)'

    out = conv((x1, x2), edge_index)
    assert out.size() == (2, 64)
    assert torch.allclose(conv((x1, x2), adj.t()), out, atol=1e-6)

    t = '(PairTensor, Tensor, NoneType, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit((x1, x2), edge_index).tolist() == out.tolist()

    t = '(PairTensor, SparseTensor, NoneType, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit((x1, x2), adj.t()), out, atol=1e-6)
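A stand-alone illustration (not part of the test) of the attention-weight API exercised above: with return_attention_weights=True the layer also returns the edge index and the per-head softmax coefficients.

import torch
from torch_geometric.nn import TransformerConv

x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

conv = TransformerConv(8, 32, heads=2)
out, (att_edge_index, alpha) = conv(x, edge_index, return_attention_weights=True)
# out: [4, 64], att_edge_index: [2, 4], alpha: [4, 2] with values in [0, 1]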
Example no. 9
 def __init__(self,
              input_feat_dim,
              node_dim1,
              node_dim2,
              encode_dim,
              dropout=0.2,
              adj_drop=0.2,
              encoder='GraphConv',
              decoder='concatDec',
              sigmoid=True,
              n_nodes=50,
              hour_emb=100,
              week_emb=100,
              ARMAConv_num_stacks=1,
              ARMAConv_num_layers=1,
              TransformerConv_heads=1):
     '''
     input_feat_dim - input feature dimension
     node_dim1, node_dim2 - the node embedding dimensions of the two GCN layers
     encode_dim - final graph embedding dimension
     adj_drop - graph edge dropout rate
     decoder - choose from 'bilinearDec' and 'concatDec' (an MLP decoder)
     sigmoid - whether the edge-prediction output is passed through a sigmoid to map it into (0, 1)
     n_nodes - total number of nodes in the graph
     hour_emb, week_emb - the dimensions of the hour and week embeddings
     '''
     super().__init__()
     self.n_nodes = n_nodes
     self.dropout = dropout
     self.adj_drop = adj_drop
     self.sigmoid = sigmoid
     self.node_dim2 = node_dim2
     self.dropout_layer = torch.nn.Dropout(dropout)
     self.encoder = encoder
     # encode
     if encoder == 'SAGE':
         self.conv1 = SAGEConv(input_feat_dim, node_dim1)  #, concat = True)
         self.conv2 = SAGEConv(node_dim1, node_dim2)  #, concat = True)
     elif encoder == 'GraphConv':
         self.conv1 = GraphConv(input_feat_dim, node_dim1, aggr='mean')
         self.conv2 = GraphConv(node_dim1, node_dim2, aggr='mean')
     elif encoder == 'TAGConv':
         self.conv1 = TAGConv(input_feat_dim, node_dim1)
         self.conv2 = TAGConv(node_dim1, node_dim2)
     elif encoder == 'SGConv':
         self.conv1 = SGConv(input_feat_dim, node_dim1)
         self.conv2 = SGConv(node_dim1, node_dim2)
     elif encoder == 'ARMAConv':
         self.conv1 = ARMAConv(input_feat_dim,
                               node_dim1,
                               num_stacks=ARMAConv_num_stacks,
                               num_layers=ARMAConv_num_layers)
         self.conv2 = ARMAConv(node_dim1,
                               node_dim2,
                               num_stacks=ARMAConv_num_stacks,
                               num_layers=ARMAConv_num_layers)
     elif encoder == 'TransformerConv':
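          # Heads are concatenated by default, so the second layer takes
          # node_dim1 * TransformerConv_heads inputs; edge_dim=1 lets scalar
          # edge weights act as edge features.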
         self.conv1 = TransformerConv(input_feat_dim,
                                      node_dim1,
                                      heads=TransformerConv_heads,
                                      edge_dim=1)
         self.conv2 = TransformerConv(node_dim1 * TransformerConv_heads,
                                      node_dim2,
                                      heads=1,
                                      edge_dim=1)
     else:
         raise NotImplementedError
     self.hour_embedding = nn.Embedding(24, hour_emb)
     self.week_embedding = nn.Embedding(7, week_emb)
     self.fc = torch.nn.Linear(n_nodes * node_dim2 + hour_emb + week_emb,
                               encode_dim)
     # decode
     self.fc2 = torch.nn.Linear(encode_dim + hour_emb + week_emb,
                                n_nodes * node_dim2)
     if decoder == 'bilinearDec':
         self.decoder = bilinearDec(node_dim2)
     elif decoder == 'concatDec':
         self.decoder = concatDec(node_dim2, node_dim1, dropout)
     else:
         raise NotImplementedError
Example no. 10
        def __init__(self):
            super(Net,
                  self).__init__(crit, y_post_processor, output_post_processor,
                                 cal_acc, likelihood_fitting, args)

            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            class MLP(torch.nn.Module):
                def __init__(self, hcs_list, act=self.act, clean_out=False):
                    super(MLP, self).__init__()
                    mlp = []
                    for i in range(1, len(hcs_list)):
                        mlp.append(
                            torch.nn.Linear(hcs_list[i - 1], hcs_list[i]))
                        mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                        mlp.append(act)

                    if clean_out:
                        self.mlp = torch.nn.Sequential(*mlp[:-2])
                    else:
                        self.mlp = torch.nn.Sequential(*mlp)

                def forward(self, x):
                    return self.mlp(x)

#             class GRUConv(torch.nn.Module):
#                 def __init__(self,hcs = self.hcs, act = self.act):
#                     super(GRUConv, self).__init__()
#                     self.act = act
#                     self.hcs = hcs
#                     self.GRU = torch.nn.GRUCell(self.hcs*2,self.hcs)

#                     self.scatter_norm = scatter_norm(self.hcs)
#                     self.lin_CoC_msg = MLP([N_scatter_feats*self.hcs, self.hcs],clean_out = True)
#                     self.lin_CoC_self = MLP([self.hcs, self.hcs],clean_out = True)

#                     self.CoC_batch_norm = torch.nn.BatchNorm1d(self.hcs)

#                     self.lin_x_msg = MLP([self.hcs, self.hcs],clean_out = True)
#                     self.lin_x_self = MLP([self.hcs, self.hcs],clean_out = True)

#                     self.x_batch_norm = torch.nn.BatchNorm1d(self.hcs)

#                 def forward(self, x, CoC, h, batch):
#                     h = self.act( self.GRU( torch.cat([CoC[batch], x], dim=1), h) )

#                     msg = self.lin_CoC_msg( self.scatter_norm(h, batch) )
#                     CoC = self.lin_CoC_self(CoC)

#                     CoC = self.act( self.CoC_batch_norm(msg+CoC) )

#                     h = self.act( self.GRU( torch.cat([x, CoC[batch]], dim=1), h) )

#                     msg = self.lin_x_msg(h)
#                     x = self.lin_x_self(x)

#                     x = self.act( self.x_batch_norm(msg+x) )
#                     return x, CoC, h

#             class AttConv(torch.nn.Module):
#                 def __init__(self,in_hcs = [self.hcs, self.hcs], out_hcs = self.hcs, heads = 1):
#                     super(AttConv,self).__init__()

#                     self.heads = heads
#                     self.out_hcs = out_hcs

#                     self.lin_key = torch.nn.Linear(in_hcs[0], heads*out_hcs)
#                     self.lin_query = torch.nn.Linear(in_hcs[1], heads*out_hcs)
#                     self.lin_value = torch.nn.Linear(in_hcs[0], heads*out_hcs)

#                     self.sqrt_d = torch.sqrt(out_hcs)

#                     self.reset_parameters()

#                 def reset_parameters(self):
#                     self.lin_key.reset_parameters()
#                     self.lin_query.reset_parameters()
#                     self.lin_value.reset_parameters()

#                 def forward(self, x, CoC, batch):
#                     key = self.lin_key(x).view(-1,self.heads,self.out_hcs)
#                     query = self.lin_query(CoC

            class scatter_norm(torch.nn.Module):
                def __init__(self, hcs):
                    super(scatter_norm, self).__init__()
                    self.batch_norm = torch.nn.BatchNorm1d(N_scatter_feats *
                                                           hcs)

                def forward(self, x, batch):
                    return self.batch_norm(
                        scatter_distribution(x, batch, dim=0))

            N_x_feats = N_dom_feats  # + 4*(N_dom_feats + 1)
            N_CoC_feats = N_scatter_feats * N_x_feats + 3

            self.scatter_norm = scatter_norm(N_x_feats)
            self.x_encoder = MLP([N_x_feats, self.hcs])
            self.CoC_encoder = MLP([N_CoC_feats, self.hcs])

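            # With the default concat=True, TransformerConv outputs heads * out_channels
            # = N_metalayers * hcs features, which matches the decoder's input width below.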
            self.TConv = TransformerConv(in_channels=[self.hcs, self.hcs],
                                         out_channels=self.hcs,
                                         heads=N_metalayers)

            self.decoder = MLP(
                [(N_metalayers) * self.hcs, 3 * self.hcs, self.hcs, N_outputs],
                clean_out=True)