def __init__(self, in_channels, out_channels, msg_dim, time_enc):
    super(GraphAttentionEmbedding, self).__init__()
    self.time_enc = time_enc
    edge_dim = msg_dim + time_enc.out_channels
    self.conv = TransformerConv(in_channels, out_channels // 2, heads=2,
                                dropout=0.1, edge_dim=edge_dim)
    self.reset_parameters()
def __init__(self, in_channels, num_classes, hidden_channels, num_layers,
             heads, dropout=0.3):
    super().__init__()

    self.label_emb = MaskLabel(num_classes, in_channels)

    self.convs = torch.nn.ModuleList()
    self.norms = torch.nn.ModuleList()
    for i in range(1, num_layers + 1):
        if i < num_layers:
            out_channels = hidden_channels // heads
            concat = True
        else:
            out_channels = num_classes
            concat = False
        conv = TransformerConv(in_channels, out_channels, heads,
                               concat=concat, beta=True, dropout=dropout)
        self.convs.append(conv)
        in_channels = hidden_channels

        if i < num_layers:
            self.norms.append(torch.nn.LayerNorm(hidden_channels))
def __init__(self):
    super().__init__()
    self.attention_layer = TransformerConv(64, 64, edge_dim=64,
                                           root_weight=False)
    # self.reduce = ff(96)
    self.phi_v = fff(128)
    self.atom_fc = ff(64)
    self.edge_update = EdgeUpdate()
    self.bond_fc = ff(64)
def __init__(self, in_channels, out_channels, msg_dim, time_enc):
    super().__init__()
    self.time_enc = time_enc
    edge_dim = msg_dim + time_enc.out_channels
    self.conv = TransformerConv(in_channels, out_channels // 2, heads=2,
                                dropout=0.1, edge_dim=edge_dim)
class GraphAttentionEmbedding(torch.nn.Module):
    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super(GraphAttentionEmbedding, self).__init__()
        self.time_enc = time_enc
        # Edge features are the concatenation of the time encoding and the
        # raw message, so the edge dimension is the sum of the two.
        edge_dim = msg_dim + time_enc.out_channels
        self.conv = TransformerConv(in_channels, out_channels // 2, heads=2,
                                    dropout=0.1, edge_dim=edge_dim)
        self.reset_parameters()

    def reset_parameters(self):
        self.time_enc.reset_parameters()
        self.conv.reset_parameters()

    def forward(self, x, last_update, edge_index, t, msg):
        # Encode the time elapsed since each source node was last updated.
        rel_t = last_update[edge_index[0]] - t
        rel_t_enc = self.time_enc(rel_t.to(x.dtype))
        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)
        return self.conv(x, edge_index, edge_attr)
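# Hedged usage sketch (not from the original source): how a class like
# GraphAttentionEmbedding above might be instantiated and called. The
# SimpleTimeEncoder below is a hypothetical stand-in that only mirrors what
# the snippet assumes about `time_enc` (an `out_channels` attribute and a
# `reset_parameters()` method); all dimensions are illustrative.
import torch
from torch_geometric.nn import TransformerConv


class SimpleTimeEncoder(torch.nn.Module):
    """Maps a relative timestamp to a learned cosine feature vector."""
    def __init__(self, out_channels):
        super().__init__()
        self.out_channels = out_channels
        self.lin = torch.nn.Linear(1, out_channels)

    def reset_parameters(self):
        self.lin.reset_parameters()

    def forward(self, t):
        return self.lin(t.view(-1, 1)).cos()


time_enc = SimpleTimeEncoder(out_channels=10)
emb = GraphAttentionEmbedding(in_channels=100, out_channels=100,
                              msg_dim=16, time_enc=time_enc)

x = torch.randn(6, 100)                           # node states
last_update = torch.zeros(6, dtype=torch.long)    # last-update timestamps
edge_index = torch.tensor([[0, 1, 2], [3, 4, 5]])
t = torch.tensor([1, 2, 3])                       # event timestamps
msg = torch.randn(3, 16)                          # raw messages per edge
out = emb(x, last_update, edge_index, t, msg)     # -> shape [6, 100]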
import torch
from torch_sparse import SparseTensor

from torch_geometric.nn import TransformerConv


def test_transformer_conv():
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = TransformerConv(8, 32, heads=2)
    assert conv.__repr__() == 'TransformerConv(8, 32, heads=2)'
    out = conv(x1, edge_index)
    assert out.size() == (4, 64)
    assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)

    t = '(Tensor, Tensor, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x1, edge_index).tolist() == out.tolist()

    t = '(Tensor, SparseTensor, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)

    adj = adj.sparse_resize((4, 2))
    conv = TransformerConv((8, 16), 32, heads=2)
    assert conv.__repr__() == 'TransformerConv((8, 16), 32, heads=2)'
    out = conv((x1, x2), edge_index)
    assert out.size() == (2, 64)
    assert torch.allclose(conv((x1, x2), adj.t()), out, atol=1e-6)

    t = '(PairTensor, Tensor, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit((x1, x2), edge_index).tolist() == out.tolist()

    t = '(PairTensor, SparseTensor, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit((x1, x2), adj.t()), out, atol=1e-6)
def __init__(
    self,
    embed_dim: int,
    num_layers: int,
    heads: int = 8,
    normalization: str = "batch",
    feed_forward_hidden: int = 512,
    pooling_method: str = "mean",
) -> None:
    super().__init__()
    self.embed_dim = embed_dim
    self.num_layers = num_layers
    self.heads = heads
    assert (self.embed_dim % self.heads) == 0

    self.pooling_func = get_pooling_func(pooling_method)
    self.norm_class = get_normalization_class(normalization)

    self.gnn_layer_list = nn.ModuleList()
    self.norm_list = nn.ModuleList()
    for i in range(self.num_layers):
        gnn_layer = TransformerConv(
            in_channels=self.embed_dim,
            out_channels=self.embed_dim // self.heads,
            heads=self.heads,
            edge_dim=self.embed_dim,
        )
        self.gnn_layer_list.append(gnn_layer)
        self.norm_list.append(GraphNorm(in_channels=self.embed_dim))

    self.feed_forward = Sequential(
        "x, batch",
        [
            (nn.Linear(self.embed_dim, feed_forward_hidden), "x -> x"),
            nn.GELU(),
            # (GraphNorm(in_channels=feed_forward_hidden), "x, batch -> x"),
            nn.Linear(feed_forward_hidden, self.embed_dim),
        ],
    )
    self.ff_norm = GraphNorm(in_channels=self.embed_dim)
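# Hedged sketch (not from the original source): one plausible forward pass for
# the encoder initialised above, assuming residual connections around both the
# attention and feed-forward sublayers, as in a standard transformer block.
# The original model's forward method may differ (e.g. in where ff_norm is
# applied or how pooling_func is used).
def forward(self, x, edge_index, edge_attr, batch):
    for conv, norm in zip(self.gnn_layer_list, self.norm_list):
        # Multi-head graph attention over edge features, residual, GraphNorm.
        x = norm(x + conv(x, edge_index, edge_attr), batch)
        # Position-wise feed-forward block, residual, GraphNorm.
        x = self.ff_norm(x + self.feed_forward(x, batch), batch)
    return x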
import torch
from torch_sparse import SparseTensor

from torch_geometric.nn import TransformerConv


def test_transformer_conv():
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = TransformerConv(8, 32, heads=2, beta=True)
    assert conv.__repr__() == 'TransformerConv(8, 32, heads=2)'
    out = conv(x1, edge_index)
    assert out.size() == (4, 64)
    assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)

    t = '(Tensor, Tensor, NoneType, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x1, edge_index).tolist() == out.tolist()

    t = '(Tensor, SparseTensor, NoneType, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)

    # Test `return_attention_weights`.
    result = conv(x1, edge_index, return_attention_weights=True)
    assert result[0].tolist() == out.tolist()
    assert result[1][0].size() == (2, 4)
    assert result[1][1].size() == (4, 2)
    assert result[1][1].min() >= 0 and result[1][1].max() <= 1
    assert conv._alpha is None

    result = conv(x1, adj.t(), return_attention_weights=True)
    assert torch.allclose(result[0], out, atol=1e-6)
    assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 4
    assert conv._alpha is None

    t = ('(Tensor, Tensor, NoneType, bool) -> '
         'Tuple[Tensor, Tuple[Tensor, Tensor]]')
    jit = torch.jit.script(conv.jittable(t))
    result = jit(x1, edge_index, return_attention_weights=True)
    assert result[0].tolist() == out.tolist()
    assert result[1][0].size() == (2, 4)
    assert result[1][1].size() == (4, 2)
    assert result[1][1].min() >= 0 and result[1][1].max() <= 1
    assert conv._alpha is None

    t = '(Tensor, SparseTensor, NoneType, bool) -> Tuple[Tensor, SparseTensor]'
    jit = torch.jit.script(conv.jittable(t))
    result = jit(x1, adj.t(), return_attention_weights=True)
    assert torch.allclose(result[0], out, atol=1e-6)
    assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 4
    assert conv._alpha is None

    adj = adj.sparse_resize((4, 2))
    conv = TransformerConv((8, 16), 32, heads=2, beta=True)
    assert conv.__repr__() == 'TransformerConv((8, 16), 32, heads=2)'
    out = conv((x1, x2), edge_index)
    assert out.size() == (2, 64)
    assert torch.allclose(conv((x1, x2), adj.t()), out, atol=1e-6)

    t = '(PairTensor, Tensor, NoneType, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit((x1, x2), edge_index).tolist() == out.tolist()

    t = '(PairTensor, SparseTensor, NoneType, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit((x1, x2), adj.t()), out, atol=1e-6)
def __init__(self, input_feat_dim, node_dim1, node_dim2, encode_dim,
             dropout=0.2, adj_drop=0.2, encoder='GraphConv',
             decoder='concatDec', sigmoid=True, n_nodes=50, hour_emb=100,
             week_emb=100, ARMAConv_num_stacks=1, ARMAConv_num_layers=1,
             TransformerConv_heads=1):
    '''
    input_feat_dim - input feature dimension
    node_dim1, node_dim2 - node embedding dimensions of the two GCN layers
    encode_dim - final graph node embedding dimension
    adj_drop - graph edge dropout rate
    decoder - either 'bilinearDec' or 'concatDec' (an MLP decoder)
    sigmoid - whether the edge-prediction output is passed through a sigmoid to (0, 1)
    n_nodes - total number of nodes in the graph
    hour_emb, week_emb - dimensions of the hour and week embeddings
    '''
    super().__init__()
    self.n_nodes = n_nodes
    self.dropout = dropout
    self.adj_drop = adj_drop
    self.sigmoid = sigmoid
    self.node_dim2 = node_dim2
    self.dropout_layer = torch.nn.Dropout(dropout)
    self.encoder = encoder

    # Encoder: two graph convolution layers of the chosen type.
    if encoder == 'SAGE':
        self.conv1 = SAGEConv(input_feat_dim, node_dim1)  # , concat=True)
        self.conv2 = SAGEConv(node_dim1, node_dim2)  # , concat=True)
    elif encoder == 'GraphConv':
        self.conv1 = GraphConv(input_feat_dim, node_dim1, aggr='mean')
        self.conv2 = GraphConv(node_dim1, node_dim2, aggr='mean')
    elif encoder == 'TAGConv':
        self.conv1 = TAGConv(input_feat_dim, node_dim1)
        self.conv2 = TAGConv(node_dim1, node_dim2)
    elif encoder == 'SGConv':
        self.conv1 = SGConv(input_feat_dim, node_dim1)
        self.conv2 = SGConv(node_dim1, node_dim2)
    elif encoder == 'ARMAConv':
        self.conv1 = ARMAConv(input_feat_dim, node_dim1,
                              num_stacks=ARMAConv_num_stacks,
                              num_layers=ARMAConv_num_layers)
        self.conv2 = ARMAConv(node_dim1, node_dim2,
                              num_stacks=ARMAConv_num_stacks,
                              num_layers=ARMAConv_num_layers)
    elif encoder == 'TransformerConv':
        self.conv1 = TransformerConv(input_feat_dim, node_dim1,
                                     heads=TransformerConv_heads, edge_dim=1)
        self.conv2 = TransformerConv(node_dim1 * TransformerConv_heads,
                                     node_dim2, heads=1, edge_dim=1)
    else:
        raise NotImplementedError

    self.hour_embedding = nn.Embedding(24, hour_emb)
    self.week_embedding = nn.Embedding(7, week_emb)
    self.fc = torch.nn.Linear(n_nodes * node_dim2 + hour_emb + week_emb,
                              encode_dim)

    # Decoder.
    self.fc2 = torch.nn.Linear(encode_dim + hour_emb + week_emb,
                               n_nodes * node_dim2)
    if decoder == 'bilinearDec':
        self.decoder = bilinearDec(node_dim2)
    elif decoder == 'concatDec':
        self.decoder = concatDec(node_dim2, node_dim1, dropout)
    else:
        raise NotImplementedError
def __init__(self):
    super(Net, self).__init__(crit, y_post_processor, output_post_processor,
                              cal_acc, likelihood_fitting, args)
    self.act = torch.nn.SiLU()
    self.hcs = N_hcs

    class MLP(torch.nn.Module):
        def __init__(self, hcs_list, act=self.act, clean_out=False):
            super(MLP, self).__init__()
            mlp = []
            for i in range(1, len(hcs_list)):
                mlp.append(torch.nn.Linear(hcs_list[i - 1], hcs_list[i]))
                mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                mlp.append(act)
            if clean_out:
                self.mlp = torch.nn.Sequential(*mlp[:-2])
            else:
                self.mlp = torch.nn.Sequential(*mlp)

        def forward(self, x):
            return self.mlp(x)

    # class GRUConv(torch.nn.Module):
    #     def __init__(self, hcs=self.hcs, act=self.act):
    #         super(GRUConv, self).__init__()
    #         self.act = act
    #         self.hcs = hcs
    #         self.GRU = torch.nn.GRUCell(self.hcs * 2, self.hcs)
    #         self.scatter_norm = scatter_norm(self.hcs)
    #         self.lin_CoC_msg = MLP([N_scatter_feats * self.hcs, self.hcs], clean_out=True)
    #         self.lin_CoC_self = MLP([self.hcs, self.hcs], clean_out=True)
    #         self.CoC_batch_norm = torch.nn.BatchNorm1d(self.hcs)
    #         self.lin_x_msg = MLP([self.hcs, self.hcs], clean_out=True)
    #         self.lin_x_self = MLP([self.hcs, self.hcs], clean_out=True)
    #         self.x_batch_norm = torch.nn.BatchNorm1d(self.hcs)
    #
    #     def forward(self, x, CoC, h, batch):
    #         h = self.act(self.GRU(torch.cat([CoC[batch], x], dim=1), h))
    #         msg = self.lin_CoC_msg(self.scatter_norm(h, batch))
    #         CoC = self.lin_CoC_self(CoC)
    #         CoC = self.act(self.CoC_batch_norm(msg + CoC))
    #         h = self.act(self.GRU(torch.cat([x, CoC[batch]], dim=1), h))
    #         msg = self.lin_x_msg(h)
    #         x = self.lin_x_self(x)
    #         x = self.act(self.x_batch_norm(msg + x))
    #         return x, CoC, h

    # class AttConv(torch.nn.Module):
    #     def __init__(self, in_hcs=[self.hcs, self.hcs], out_hcs=self.hcs, heads=1):
    #         super(AttConv, self).__init__()
    #         self.heads = heads
    #         self.out_hcs = out_hcs
    #         self.lin_key = torch.nn.Linear(in_hcs[0], heads * out_hcs)
    #         self.lin_query = torch.nn.Linear(in_hcs[1], heads * out_hcs)
    #         self.lin_value = torch.nn.Linear(in_hcs[0], heads * out_hcs)
    #         self.sqrt_d = torch.sqrt(out_hcs)
    #         self.reset_parameters()
    #
    #     def reset_parameters(self):
    #         self.lin_key.reset_parameters()
    #         self.lin_query.reset_parameters()
    #         self.lin_value.reset_parameters()
    #
    #     def forward(self, x, CoC, batch):
    #         key = self.lin_key(x).view(-1, self.heads, self.out_hcs)
    #         query = self.lin_query(CoC

    class scatter_norm(torch.nn.Module):
        def __init__(self, hcs):
            super(scatter_norm, self).__init__()
            self.batch_norm = torch.nn.BatchNorm1d(N_scatter_feats * hcs)

        def forward(self, x, batch):
            return self.batch_norm(scatter_distribution(x, batch, dim=0))

    N_x_feats = N_dom_feats  # + 4*(N_dom_feats + 1)
    N_CoC_feats = N_scatter_feats * N_x_feats + 3

    self.scatter_norm = scatter_norm(N_x_feats)
    self.x_encoder = MLP([N_x_feats, self.hcs])
    self.CoC_encoder = MLP([N_CoC_feats, self.hcs])

    self.TConv = TransformerConv(in_channels=[self.hcs, self.hcs],
                                 out_channels=self.hcs, heads=N_metalayers)

    self.decoder = MLP([N_metalayers * self.hcs, 3 * self.hcs, self.hcs,
                        N_outputs], clean_out=True)