def Load_model(name, args):
    """Return the ``Net`` class for the fully connected "AttGNN" model.

    Node features are augmented with per-graph pairwise statistics and with
    their offset from the graph's charge-weighted centre, encoded by an MLP,
    passed through ``args['N_metalayers']`` ``AttGNN`` layers, pooled per
    graph (mean/var/max/min) and decoded.

    Args:
        name: Model name, forwarded to ``Loss_Functions``. A name ending in
            ``'NLLH'`` switches on likelihood fitting.
        args: Config dict. This function reads ``'N_dom_feats'``,
            ``'N_outputs'``, ``'N_metalayers'`` and ``'N_hcs'``; the whole
            dict is also passed to ``Loss_Functions`` and ``customModule``.

    Returns:
        The ``Net`` class (not an instance).
    """
    from FunctionCollection import Loss_Functions, customModule
    from typing import List, Optional
    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast

    # Per-segment statistics of ``src`` grouped by ``index``: returns
    # [mean, var, max, min] concatenated along dim 1. With ``unbiased`` the
    # variance is divided by (count - 1), clamped to at least 1.
    # NOTE(review): if ``out`` is passed, the variance and the max/min
    # reductions are all written into that same tensor; nothing here
    # passes ``out``.
    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:
        if out is not None:
            dim_size = out.size(dim)
        if dim < 0:
            dim = src.dim() + dim
        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1
        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)
        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)
        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]
        return torch.cat([mean, var, maximum, minimum], dim=1)

    # For each graph (assumes nodes are contiguous per graph, in the order
    # of ``graph_node_counts`` -- TODO confirm against the data loader),
    # build all pairwise feature differences, replace the last three columns
    # by a unit direction plus the distance, and summarise every node's row
    # with mean/std/min/max. Output width is 4 * (F + 1) for F input columns.
    @torch.jit.script
    def x_feature_constructor(x, graph_node_counts):
        tmp = []
        a: List[int] = graph_node_counts.tolist()
        for tmp_x in x.split(a):
            tmp_x = tmp_x.unsqueeze(1) - tmp_x
            cart = tmp_x[:, :, -3:]
            rho = torch.norm(cart, p=2, dim=-1).unsqueeze(2)
            # squeeze(-1) keeps an (n, n) mask even when n == 1, and the guard
            # skips the masked write when every distance is zero (e.g. a
            # single-node graph), matching the sibling model variants.
            rho_mask = rho.squeeze(-1) != 0
            if rho_mask.sum() != 0:
                cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
            tmp_x = torch.cat([cart, rho, tmp_x[:, :, :-3]], dim=2)
            # NOTE(review): std(1) is unbiased, so a one-node graph yields NaN.
            tmp.append(
                torch.cat([
                    tmp_x.mean(1),
                    tmp_x.std(1),
                    tmp_x.min(1)[0],
                    tmp_x.max(1)[0]
                ], dim=1))
        return torch.cat(tmp, 0)

    N_dom_feats = args['N_dom_feats']
    N_scatter_feats = 4  # number of statistics returned by scatter_distribution
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']
    N_hcs = args['N_hcs']

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name[-4:] == 'NLLH'

    class Net(customModule):
        def __init__(self):
            super(Net, self).__init__(crit, y_post_processor,
                                      output_post_processor, cal_acc,
                                      likelihood_fitting, args)
            print(
                "This model assumes Charge is at index 0 and position is the last three"
            )
            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            class MLP(torch.nn.Module):
                # Linear -> BatchNorm1d -> act for each consecutive pair in
                # hcs_list. no_final_act drops only the trailing activation;
                # the last BatchNorm1d is kept.
                def __init__(self, hcs_list, act=self.act, no_final_act=False):
                    super(MLP, self).__init__()
                    mlp = []
                    for i in range(1, len(hcs_list)):
                        mlp.append(
                            torch.nn.Linear(hcs_list[i - 1], hcs_list[i]))
                        mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                        mlp.append(act)
                    if not no_final_act:
                        self.mlp = torch.nn.Sequential(*mlp)
                    else:
                        self.mlp = torch.nn.Sequential(*mlp[:-1])

                def forward(self, x):
                    return self.mlp(x)

            class AttGNN(torch.nn.Module):
                # Fully connected per-graph message passing: every node gets
                # mean/std/min/max of messages built from its pairwise
                # differences and sums with all nodes of the same graph.
                def __init__(self, hcs_in, hcs_out, act=self.act):
                    super(AttGNN, self).__init__()
                    self.self_mlp = MLP([hcs_in, hcs_in, hcs_out])
                    self.msg_mlp = torch.nn.Linear(2 * hcs_in, hcs_out)
                    self.msg_mlp2 = MLP([4 * hcs_out, 2 * hcs_out, hcs_out])

                def forward(self, x, graph_node_counts):
                    li: List[int] = graph_node_counts.tolist()
                    tmp = []
                    for tmp_x in x.split(li):
                        # (n, n, 2 * hcs_in): pairwise differences and sums.
                        tmp_x = torch.cat([
                            tmp_x.unsqueeze(1) - tmp_x,
                            tmp_x.unsqueeze(1) + tmp_x
                        ], dim=2)
                        tmp_x = self.msg_mlp(tmp_x)
                        # NOTE(review): msg_mlp2 contains BatchNorm1d and is
                        # applied per graph, so its batch statistics come from
                        # one graph's nodes; in training mode a one-node graph
                        # makes BatchNorm1d raise.
                        tmp.append(
                            self.msg_mlp2(
                                torch.cat([
                                    tmp_x.mean(1),
                                    tmp_x.std(1),
                                    tmp_x.min(1)[0],
                                    tmp_x.max(1)[0]
                                ], dim=1)))
                    return self.self_mlp(x) + torch.cat(tmp, 0)

            # x (N_dom_feats) + x_feature_constructor (4 * (N_dom_feats + 1))
            # + CoC_edge_attr (4) + CoC[batch] (3), as concatenated in forward.
            N_x_feats = N_dom_feats + 4 * (N_dom_feats + 1) + 4 + 3
            self.x_encoder = MLP([N_x_feats, self.hcs, self.hcs])
            self.convs = torch.nn.ModuleList()
            for _ in range(N_metalayers):
                self.convs.append(AttGNN(self.hcs, self.hcs))
            self.decoder = MLP([
                N_scatter_feats * self.hcs, N_scatter_feats // 2 * self.hcs,
                self.hcs, N_outputs
            ], no_final_act=True)

        def return_CoC_and_edge_attr(self, x, batch):
            """Return the per-graph centre of charge and node->centre attributes.

            Charge is read from ``x[:, 0]`` and position from ``x[:, -3:]``.
            ``CoC`` holds one charge-weighted mean position per graph;
            ``CoC_edge_attr`` is ``[unit direction (3), distance (1)]`` from
            each node's graph centre to the node.
            """
            pos = x[:, -3:]
            charge = x[:, 0].view(-1, 1)
            CoC = scatter_sum(pos * charge, batch, dim=0) / scatter_sum(
                charge, batch, dim=0)
            cart = pos - CoC[batch]
            rho = torch.norm(cart, p=2, dim=1).view(-1, 1)
            # squeeze(1) keeps a 1-D mask even for a single node.
            rho_mask = rho.squeeze(1) != 0
            cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
            CoC_edge_attr = torch.cat([cart.type_as(x), rho.type_as(x)], dim=1)
            return CoC, CoC_edge_attr

        def forward(self, data):
            x, batch = data.x.float(), data.batch
            CoC, CoC_edge_attr = self.return_CoC_and_edge_attr(x, batch)
            _, graph_node_counts = batch.unique(return_counts=True)
            x = torch.cat([
                x,
                x_feature_constructor(x, graph_node_counts), CoC_edge_attr,
                CoC[batch]
            ], dim=1)
            x = self.x_encoder(x)
            for conv in self.convs:
                x = conv(x, graph_node_counts)
            return self.decoder(scatter_distribution(x, batch, dim=0))

    return Net
def Load_model(name, args):
    """Return the ``Net`` class for the GRU / centre-of-charge model.

    Node features are extended with:
    - per-graph pairwise statistics;
    - features of time-ordered neighbours;
    - the offset from the graph's charge-weighted centre (CoC).

    A learned attention score then keeps a subset of nodes. After that,
    ``args['N_metalayers']`` rounds alternate between:
    - updating a per-node GRU state from (CoC, node) pairs;
    - updating the per-graph CoC state from statistics of those node states.

    Args:
        name: Model name, forwarded to ``Loss_Functions``. A name ending in
            ``'NLLH'`` switches on likelihood fitting.
        args: Config dict. This function reads ``'N_edge_feats'``,
            ``'N_dom_feats'``, ``'N_outputs'``, ``'N_metalayers'`` and
            ``'N_hcs'``; the whole dict is also passed on.

    Returns:
        The ``Net`` class (not an instance).
    """
    from FunctionCollection import Loss_Functions, customModule, edge_feature_constructor
    from typing import List, Optional
    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max, scatter_softmax
    from torch_scatter.utils import broadcast

    # Per-segment statistics of ``src`` grouped by ``index``: returns
    # [mean, var, max, min] concatenated along dim 1. With ``unbiased`` the
    # variance is divided by (count - 1), clamped to at least 1.
    # NOTE(review): if ``out`` is passed, the variance and the max/min
    # reductions are all written into that same tensor; nothing here
    # passes ``out``.
    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:
        if out is not None:
            dim_size = out.size(dim)
        if dim < 0:
            dim = src.dim() + dim
        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1
        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)
        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)
        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]
        return torch.cat([mean, var, maximum, minimum], dim=1)

    # For each graph (assumes nodes are contiguous per graph, in the order
    # of ``graph_node_counts`` -- TODO confirm), build all pairwise feature
    # differences, replace the last three columns by a unit direction plus
    # the distance, and summarise every node's row with mean/std/min/max.
    # Output width is 4 * (F + 1) for F input columns.
    @torch.jit.script
    def x_feature_constructor(x, graph_node_counts):
        tmp = []
        a: List[int] = graph_node_counts.tolist()
        for tmp_x in x.split(a):
            tmp_x = tmp_x.unsqueeze(1) - tmp_x
            cart = tmp_x[:, :, -3:]
            rho = torch.norm(cart, p=2, dim=-1).unsqueeze(2)
            # squeeze(-1) keeps an (n, n) mask even when n == 1.
            rho_mask = rho.squeeze(-1) != 0
            if rho_mask.sum() != 0:
                cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
            tmp_x = torch.cat([cart, rho, tmp_x[:, :, :-3]], dim=2)
            # NOTE(review): std(1) is unbiased, so a one-node graph yields NaN.
            tmp.append(
                torch.cat([
                    tmp_x.mean(1),
                    tmp_x.std(1),
                    tmp_x.min(1)[0],
                    tmp_x.max(1)[0]
                ], dim=1))
        return torch.cat(tmp, 0)

    # Build a (2, N) edge index linking each node to its successor in time
    # within its own graph. The last node of every graph links back to that
    # graph's first node, so every node is the target of exactly one edge.
    # Columns are sorted by target node.
    @torch.jit.script
    def time_edge_indeces(t, batch: torch.Tensor):
        time_sort = torch.argsort(t)
        graph_ids, graph_node_counts = torch.unique(batch, return_counts=True)
        # Nodes grouped by graph, time-ordered within each graph.
        batch_time_sort = torch.cat(
            [time_sort[batch[time_sort] == i] for i in graph_ids])
        time_edge_index = torch.cat([
            batch_time_sort[:-1].view(1, -1),
            batch_time_sort[1:].view(1, -1)
        ], dim=0)
        # Columns that cross from one graph into the next (absent when the
        # batch holds a single graph).
        tmp_ind = (torch.cumsum(graph_node_counts, 0) - 1)[:-1]
        li: List[int] = (tmp_ind + 1).tolist()
        last_start = 0  # position of the last graph's first node
        if len(li) > 0:
            # Redirect each crossing edge back to the start of its own graph.
            time_edge_index[1, tmp_ind] = time_edge_index[0, [0] + li[:-1]]
            last_start = li[-1]
        # Closing edge of the last graph, read from batch_time_sort so it also
        # works for a single-graph batch or a one-node last graph (both
        # previously raised IndexError).
        closing_edge = torch.cat([
            batch_time_sort[-1].view(1, 1),
            batch_time_sort[last_start].view(1, 1)
        ])
        time_edge_index = torch.cat([time_edge_index, closing_edge], dim=1)
        time_edge_index = time_edge_index[:, torch.argsort(time_edge_index[1])]
        return time_edge_index

    N_edge_feats = args['N_edge_feats']
    N_dom_feats = args['N_dom_feats']
    N_scatter_feats = 4  # number of statistics returned by scatter_distribution
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']
    N_hcs = args['N_hcs']

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name[-4:] == 'NLLH'

    class Net(customModule):
        def __init__(self):
            super(Net, self).__init__(crit, y_post_processor,
                                      output_post_processor, cal_acc,
                                      likelihood_fitting, args)
            self.act = torch.nn.SiLU()
            self.hcs = N_hcs
            # x (D) + x_feature_constructor (4 * (D + 1)) + edge_attr
            # (N_edge_feats) + time-predecessor features (D) + CoC_edge_attr (4).
            N_x_feats = 2 * N_dom_feats + 4 * (N_dom_feats + 1) + N_edge_feats + 4
            self.x_encoder = torch.nn.Linear(N_x_feats, self.hcs)
            # CoC position (3) + per-graph statistics of the raw node features.
            self.CoC_encoder = torch.nn.Linear(3 + N_scatter_feats * N_x_feats,
                                               self.hcs)

            class MLP(torch.nn.Module):
                # Linear -> BatchNorm1d -> act for each consecutive pair in
                # hcs_list. no_final_act drops only the trailing activation;
                # the last BatchNorm1d is kept.
                def __init__(self, hcs_list, act=self.act, no_final_act=False):
                    super(MLP, self).__init__()
                    mlp = []
                    for i in range(1, len(hcs_list)):
                        mlp.append(
                            torch.nn.Linear(hcs_list[i - 1], hcs_list[i]))
                        mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                        mlp.append(act)
                    if not no_final_act:
                        self.mlp = torch.nn.Sequential(*mlp)
                    else:
                        self.mlp = torch.nn.Sequential(*mlp[:-1])

                def forward(self, x):
                    return self.mlp(x)

            # Module creation order is kept as-is: it fixes parameter
            # initialisation order and state_dict keys.
            self.GRUCells = torch.nn.ModuleList()
            self.lins_CoC_msg = torch.nn.ModuleList()
            self.lins_CoC_self = torch.nn.ModuleList()
            self.CoC_batch_norm = torch.nn.ModuleList()
            self.lins_x_msg = torch.nn.ModuleList()
            self.lins_x_self = torch.nn.ModuleList()
            self.x_batch_norm = torch.nn.ModuleList()
            for _ in range(N_metalayers):
                self.GRUCells.append(torch.nn.GRUCell(self.hcs * 2, self.hcs))
                self.lins_CoC_msg.append(
                    torch.nn.Linear(N_scatter_feats * self.hcs, self.hcs))
                self.lins_CoC_self.append(torch.nn.Linear(self.hcs, self.hcs))
                self.CoC_batch_norm.append(torch.nn.BatchNorm1d(self.hcs))
                self.lins_x_msg.append(torch.nn.Linear(self.hcs, self.hcs))
                self.lins_x_self.append(torch.nn.Linear(self.hcs, self.hcs))
                self.x_batch_norm.append(torch.nn.BatchNorm1d(self.hcs))

            self.decoders = torch.nn.ModuleList()
            self.decoder_batch_norms = torch.nn.ModuleList()
            self.decoders.append(
                torch.nn.Linear((1 + N_scatter_feats) * self.hcs, self.hcs))
            self.decoder_batch_norms.append(torch.nn.BatchNorm1d(self.hcs))
            self.decoders.append(torch.nn.Linear(self.hcs, self.hcs))
            self.decoder_batch_norms.append(torch.nn.BatchNorm1d(self.hcs))

            self.att = MLP([N_x_feats, self.hcs, 1], no_final_act=True)
            self.decoder = torch.nn.Linear(self.hcs, N_outputs)

        def forward(self, data):
            x, batch = data.x.float(), data.batch
            # Column layout used below: charge x[:, 0], time x[:, 1],
            # position x[:, -3:].
            pos = x[:, -3:]
            charge = x[:, 0].view(-1, 1)
            _, graph_node_counts = batch.unique(return_counts=True)
            time_edge_index = time_edge_indeces(x[:, 1], batch)
            edge_attr = edge_feature_constructor(x, time_edge_index)

            # Charge-weighted centre per graph and each node's offset from it.
            CoC = scatter_sum(pos * charge, batch, dim=0) / scatter_sum(
                charge, batch, dim=0)
            cart = pos - CoC[batch, :3]
            rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)
            rho_mask = rho.squeeze(1) != 0
            cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
            CoC_edge_attr = torch.cat([cart.type_as(x), rho.type_as(x)], dim=1)

            # time_edge_index is sorted by target, one incoming edge per node,
            # so row 0 lists each node's time predecessor.
            x = torch.cat([
                x,
                x_feature_constructor(x, graph_node_counts), edge_attr,
                x[time_edge_index[0]], CoC_edge_attr
            ], dim=1)
            CoC = torch.cat([CoC, scatter_distribution(x, batch, dim=0)], dim=1)
            n_graphs = CoC.size(0)

            # Keep nodes whose attention exceeds the per-graph mean plus the
            # per-graph variance (column 1 of scatter_distribution).
            # NOTE(review): mean + variance mixes units; std may be intended.
            att = scatter_softmax(
                graph_node_counts[batch].view(-1, 1) / 100 * self.att(x),
                batch,
                dim=0)
            att_d = scatter_distribution(att, batch, dim=0)
            mask = att.squeeze(1) > att_d[batch, 0] + att_d[batch, 1]
            x = x[mask]
            batch = batch[mask]

            x = self.act(self.x_encoder(x))
            CoC = self.act(self.CoC_encoder(CoC))
            h = torch.zeros((x.shape[0], self.hcs)).type_as(x)
            for i in range(N_metalayers):
                h = self.act(self.GRUCells[i](torch.cat([CoC[batch], x], dim=1),
                                              h))
                # dim_size keeps one row per graph even if the mask removed
                # every node of a graph (e.g. a one-node graph); without it
                # the row count no longer matches CoC when that graph is last.
                CoC = self.act(self.CoC_batch_norm[i](
                    self.lins_CoC_msg[i](scatter_distribution(
                        h, batch, dim=0, dim_size=n_graphs)) +
                    self.lins_CoC_self[i](CoC)))
                h = self.act(self.GRUCells[i](torch.cat([CoC[batch], x], dim=1),
                                              h))
                x = self.act(self.x_batch_norm[i](self.lins_x_msg[i](h) +
                                                  self.lins_x_self[i](x)))
            CoC = torch.cat([
                CoC,
                scatter_distribution(x, batch, dim=0, dim_size=n_graphs)
            ], dim=1)
            for batch_norm, lin in zip(self.decoder_batch_norms, self.decoders):
                CoC = self.act(batch_norm(lin(CoC)))
            return self.decoder(CoC)

    return Net
def Load_model(name, args):
    """Return the ``Net`` class for the "FullConv" centre-of-charge model.

    Per graph, a charge-weighted centre (CoC) and batch-normalised
    statistics of the raw node features form the initial graph state.
    Nodes are then updated by ``args['N_metalayers']`` fully connected
    ``FullConv`` layers. After every layer, batch-normalised per-graph
    statistics of the node states are appended to the graph state, and the
    final concatenation is decoded by an MLP.

    Args:
        name: Model name, forwarded to ``Loss_Functions``. A name ending in
            ``'NLLH'`` switches on likelihood fitting.
        args: Config dict. This function reads ``'N_edge_feats'``,
            ``'N_dom_feats'``, ``'N_outputs'``, ``'N_metalayers'``,
            ``'N_hcs'`` and ``'features'`` (a ``', '``-separated string whose
            entry count must equal ``'N_dom_feats'``).

    Returns:
        The ``Net`` class (not an instance).
    """
    from FunctionCollection import Loss_Functions, customModule
    import pytorch_lightning as pl
    from typing import Optional
    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast

    # Per-segment statistics of ``src`` grouped by ``index``: returns
    # [sum, mean, var, max, min] concatenated along dim 1 (hence
    # N_scatter_feats = 5 below). With ``unbiased`` the variance is divided
    # by (count - 1), clamped to at least 1.
    # NOTE(review): if ``out`` is passed, the variance and the max/min
    # reductions are all written into that same tensor; nothing here
    # passes ``out``.
    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:
        if out is not None:
            dim_size = out.size(dim)
        if dim < 0:
            dim = src.dim() + dim
        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1
        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        # Keep the raw per-segment sum before it is turned into a mean.
        summ = tmp.clone()
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)
        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)
        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]
        return torch.cat([summ, mean, var, maximum, minimum], dim=1)

    N_edge_feats = args['N_edge_feats']  #6
    N_dom_feats = args['N_dom_feats']  #6
    N_scatter_feats = 5  # statistics returned by scatter_distribution above
    # N_targets = args['N_targets']
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  #10
    N_hcs = args['N_hcs']  #32
    print(
        "if features = x, Charge should be at x[:,-5], time at x[:,-4] and pos at x[:,-3:]"
    )
    # NOTE(review): assert is stripped under ``python -O``.
    assert N_dom_feats == len(args['features'].split(', '))
    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = True if name[-4:] == 'NLLH' else False

    class Net(customModule):
        def __init__(self):
            super(Net, self).__init__(crit, y_post_processor,
                                      output_post_processor, cal_acc,
                                      likelihood_fitting, args)
            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            class MLP(torch.nn.Module):
                # Linear -> BatchNorm1d -> act for each consecutive pair in
                # hcs_list. clean_out drops the trailing BatchNorm1d and act,
                # so the MLP ends in a plain Linear layer.
                def __init__(self, hcs_list, act=self.act, clean_out=False):
                    super(MLP, self).__init__()
                    mlp = []
                    for i in range(1, len(hcs_list)):
                        mlp.append(
                            torch.nn.Linear(hcs_list[i - 1], hcs_list[i]))
                        mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                        mlp.append(act)
                    if clean_out:
                        self.mlp = torch.nn.Sequential(*mlp[:-2])
                    else:
                        self.mlp = torch.nn.Sequential(*mlp)

                def forward(self, x):
                    return self.mlp(x)

            class FullConv(torch.nn.Module):
                # Fully connected per-graph convolution. For every node, build
                # messages from its pairwise differences and sums with all nodes
                # of the same graph, summarise them with mean/std/min/max, and
                # add the result to a self-transform of the node.
                def __init__(self, hcs_in, hcs_out, act=self.act):
                    super(FullConv, self).__init__()
                    self.self_mlp = MLP([hcs_in, hcs_out])
                    self.msg_mlp = torch.nn.Sequential(
                        torch.nn.Linear(2 * hcs_in, hcs_out), act)
                    self.msg_mlp2 = MLP([4 * hcs_out, hcs_out])

                def forward(self, x, graph_node_counts):
                    # Splitting by counts assumes nodes are contiguous per graph
                    # and ordered like graph_node_counts -- TODO confirm.
                    li: List[int] = graph_node_counts.tolist()
                    tmp = []
                    for tmp_x in x.split(li):
                        # (n, n, 2 * hcs_in): pairwise differences and sums.
                        tmp_x = torch.cat([
                            tmp_x.unsqueeze(1) - tmp_x,
                            tmp_x.unsqueeze(1) + tmp_x
                        ], dim=2)
                        tmp_x = self.msg_mlp(tmp_x)
                        # NOTE(review): std(1) is unbiased, so a one-node graph
                        # yields NaN here.
                        tmp.append(
                            torch.cat([
                                tmp_x.mean(1),
                                tmp_x.std(1),
                                tmp_x.min(1)[0],
                                tmp_x.max(1)[0]
                            ], dim=1))
                    return self.self_mlp(x) + self.msg_mlp2(torch.cat(tmp, 0))

            class scatter_norm(torch.nn.Module):
                # BatchNorm1d over the per-graph scatter_distribution
                # statistics of a (nodes, hcs) tensor.
                def __init__(self, hcs):
                    super(scatter_norm, self).__init__()
                    self.batch_norm = torch.nn.BatchNorm1d(N_scatter_feats * hcs)

                def forward(self, x, batch):
                    return self.batch_norm(
                        scatter_distribution(x, batch, dim=0))

            N_x_feats = N_dom_feats  # + 4*(N_dom_feats + 1)
            # CoC position (3) + statistics of the raw node features.
            N_CoC_feats = N_scatter_feats * N_x_feats + 3

            # Module creation order fixes parameter initialisation order and
            # state_dict keys.
            self.scatter_norm = scatter_norm(N_x_feats)
            self.x_encoder = MLP([N_x_feats, self.hcs])
            self.CoC_encoder = MLP([N_CoC_feats, self.hcs])
            self.convs = torch.nn.ModuleList()
            self.scatter_norms = torch.nn.ModuleList()
            for _ in range(N_metalayers):
                self.convs.append(FullConv(self.hcs, self.hcs))
                self.scatter_norms.append(scatter_norm(self.hcs))
            # Decoder input: encoded CoC (hcs) + N_scatter_feats * hcs
            # statistics appended after each FullConv layer.
            self.decoder = MLP([(1 + N_scatter_feats * N_metalayers) * self.hcs,
                                self.hcs, N_outputs],
                               clean_out=True)

        def forward(self, data):
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            ###############
            x = x.float()
            ###############
            # Charge-weighted centre per graph: charge x[:, -5], position
            # x[:, -3:].
            CoC = scatter_sum(x[:, -3:] * x[:, -5].view(-1, 1), batch,
                              dim=0) / scatter_sum(
                                  x[:, -5].view(-1, 1), batch, dim=0)
            CoC = torch.cat([CoC, self.scatter_norm(x, batch)], dim=1)
            x = self.x_encoder(x)
            CoC = self.CoC_encoder(CoC)
            graph_ids, graph_node_counts = batch.unique(return_counts=True)
            # The loop variable shadows the scatter_norm class name locally.
            for conv, scatter_norm in zip(self.convs, self.scatter_norms):
                x = conv(x, graph_node_counts)
                CoC = torch.cat([CoC, scatter_norm(x, batch)], dim=1)
            CoC = self.decoder(CoC)
            return CoC

    return Net
def Load_model(name, args):
    """Return the ``Net`` class for the MetaLayer (edge/node/global) GNN.

    Node features are extended with statistics of their incoming edge
    features. Then ``args['N_metalayers']`` ``torch_geometric`` MetaLayers
    update the edge, node and global (per-graph) states. The initial and
    every updated global state are concatenated together with per-graph
    statistics of the un-encoded node features, and the result is decoded
    by a Linear/SiLU stack.

    Args:
        name: Model name, forwarded to ``Loss_Functions``. A name ending in
            ``'NLLH'`` switches on likelihood fitting.
        args: Config dict. This function reads ``'N_edge_feats'``,
            ``'N_dom_feats'``, ``'N_outputs'``, ``'N_metalayers'`` and
            ``'N_hcs'``; the whole dict is also passed on.

    Returns:
        The ``Net`` class (not an instance).
    """
    from FunctionCollection import Loss_Functions, customModule
    from typing import Optional
    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast

    # Per-segment statistics of ``src`` grouped by ``index``: returns
    # [mean, var, max, min] concatenated along dim 1. With ``unbiased`` the
    # variance is divided by (count - 1), clamped to at least 1. Pass
    # ``dim_size`` to get one row per segment even when trailing segments
    # receive no elements.
    # NOTE(review): if ``out`` is passed, the variance and the max/min
    # reductions are all written into that same tensor; nothing here
    # passes ``out``.
    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:
        if out is not None:
            dim_size = out.size(dim)
        if dim < 0:
            dim = src.dim() + dim
        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1
        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)
        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)
        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]
        return torch.cat([mean, var, maximum, minimum], dim=1)

    N_edge_feats = args['N_edge_feats']
    N_dom_feats = args['N_dom_feats']
    N_scatter_feats = 4  # number of statistics returned by scatter_distribution
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']
    N_hcs = args['N_hcs']

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name[-4:] == 'NLLH'

    class Net(customModule):
        def __init__(self):
            super(Net, self).__init__(crit, y_post_processor,
                                      output_post_processor, cal_acc,
                                      likelihood_fitting, args)
            self.act = torch.nn.SiLU()
            self.hcs = N_hcs
            # Raw node features + statistics of each incoming-edge feature.
            N_x_feats = N_dom_feats + N_scatter_feats * N_edge_feats
            N_u_feats = N_scatter_feats * N_x_feats
            self.x_encoder = torch.nn.Linear(N_x_feats, self.hcs)
            self.edge_attr_encoder = torch.nn.Linear(N_edge_feats, self.hcs)
            self.u_encoder = torch.nn.Linear(N_u_feats, self.hcs)

            class EdgeModel(torch.nn.Module):
                # Edge update from [src, dest, edge_attr, u] (4 * hcs wide).
                def __init__(self, hcs, act):
                    super(EdgeModel, self).__init__()
                    self.hcs = hcs
                    self.act = act
                    self.lins = torch.nn.ModuleList()
                    for _ in range(2):
                        self.lins.append(
                            torch.nn.Linear(4 * self.hcs, 4 * self.hcs))
                    self.decoder = torch.nn.Linear(4 * self.hcs, self.hcs)

                def forward(self, src, dest, edge_attr, u, batch):
                    # src, dest: [E, hcs] endpoint features; edge_attr: [E, hcs];
                    # u: [B, hcs]; batch: graph id of each edge as supplied by
                    # MetaLayer (presumably batch[row] -- see PyG docs).
                    out = torch.cat([src, dest, edge_attr, u[batch]], dim=1)
                    for lin in self.lins:
                        out = self.act(lin(out))
                    return self.act(self.decoder(out))

            class NodeModel(torch.nn.Module):
                # Node update: per-edge messages from the source node, their
                # statistics aggregated at the target node, then combined with
                # the node and global states ((2 + 2 * N_scatter_feats) * hcs).
                def __init__(self, hcs, act):
                    super(NodeModel, self).__init__()
                    self.hcs = hcs
                    self.act = act
                    self.lins1 = torch.nn.ModuleList()
                    for _ in range(2):
                        self.lins1.append(
                            torch.nn.Linear(2 * self.hcs, 2 * self.hcs))
                    self.lins2 = torch.nn.ModuleList()
                    for _ in range(2):
                        self.lins2.append(
                            torch.nn.Linear(
                                (2 + 2 * N_scatter_feats) * self.hcs,
                                (2 + 2 * N_scatter_feats) * self.hcs))
                    self.decoder = torch.nn.Linear(
                        (2 + 2 * N_scatter_feats) * self.hcs, self.hcs)

                def forward(self, x, edge_index, edge_attr, u, batch):
                    row, col = edge_index
                    out = torch.cat([x[row], edge_attr], dim=1)
                    for lin in self.lins1:
                        out = self.act(lin(out))
                    # dim_size keeps one row per node even if the highest-index
                    # nodes receive no edges (otherwise the cat below fails).
                    out = scatter_distribution(out, col, dim=0,
                                               dim_size=x.size(0))  # 8*hcs
                    out = torch.cat([x, out, u[batch]], dim=1)
                    for lin in self.lins2:
                        out = self.act(lin(out))
                    return self.act(self.decoder(out))

            class GlobalModel(torch.nn.Module):
                # Global update from u and per-graph node statistics
                # ((1 + N_scatter_feats) * hcs wide).
                def __init__(self, hcs, act):
                    super(GlobalModel, self).__init__()
                    self.hcs = hcs
                    self.act = act
                    self.lins1 = torch.nn.ModuleList()
                    for _ in range(2):
                        self.lins1.append(
                            torch.nn.Linear((1 + N_scatter_feats) * self.hcs,
                                            (1 + N_scatter_feats) * self.hcs))
                    self.decoder = torch.nn.Linear(
                        (1 + N_scatter_feats) * self.hcs, self.hcs)

                def forward(self, x, edge_index, edge_attr, u, batch):
                    out = torch.cat(
                        [u, scatter_distribution(x, batch, dim=0)],
                        dim=1)  # 5*hcs
                    for lin in self.lins1:
                        out = self.act(lin(out))
                    return self.act(self.decoder(out))

            from torch_geometric.nn import MetaLayer
            # Module creation order fixes parameter initialisation order and
            # state_dict keys.
            self.ops = torch.nn.ModuleList()
            for _ in range(N_metalayers):
                self.ops.append(
                    MetaLayer(EdgeModel(self.hcs, self.act),
                              NodeModel(self.hcs, self.act),
                              GlobalModel(self.hcs, self.act)))
            self.decoders = torch.nn.ModuleList()
            # Input: initial u + one u per MetaLayer + raw node statistics.
            self.decoders.append(
                torch.nn.Linear((1 + N_metalayers) * self.hcs +
                                N_scatter_feats * N_x_feats, 10 * self.hcs))
            self.decoders.append(torch.nn.Linear(10 * self.hcs, 8 * self.hcs))
            self.decoders.append(torch.nn.Linear(8 * self.hcs, 6 * self.hcs))
            self.decoders.append(torch.nn.Linear(6 * self.hcs, 4 * self.hcs))
            self.decoders.append(torch.nn.Linear(4 * self.hcs, 2 * self.hcs))
            self.decoder = torch.nn.Linear(2 * self.hcs, N_outputs)

        def forward(self, data):
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            # Append statistics of each node's incoming edge features;
            # dim_size keeps one row per node even when the last nodes have
            # no incoming edges.
            x = torch.cat([
                x,
                scatter_distribution(edge_attr, edge_index[1], dim=0,
                                     dim_size=x.size(0))
            ], dim=1)
            u = scatter_distribution(x, batch, dim=0)
            # Nothing below modifies x in place, so no copy is needed.
            x0 = x
            x = self.act(self.x_encoder(x))
            edge_attr = self.act(self.edge_attr_encoder(edge_attr))
            u = self.act(self.u_encoder(u))
            global_states = [u]
            for op in self.ops:
                x, edge_attr, u = op(x, edge_index, edge_attr, u, batch)
                global_states.append(u)
            global_states.append(scatter_distribution(x0, batch, dim=0))
            out = torch.cat(global_states, dim=1)
            for lin in self.decoders:
                out = self.act(lin(out))
            return self.decoder(out)

    return Net
def Load_model(name, args):
    """Build the GRU/RBF "centre of charge" model with time-ordered edges.

    Node features must hold charge at ``x[:, -5]``, time at ``x[:, -4]`` and
    position at ``x[:, -3:]`` (this is printed as a reminder at build time).

    Args:
        name: model/loss identifier forwarded to ``Loss_Functions``; a name
            ending in ``'NLLH'`` enables likelihood fitting.
        args: configuration dict. Reads ``'N_edge_feats'``, ``'N_dom_feats'``,
            ``'N_outputs'``, ``'N_metalayers'``, ``'N_hcs'`` and ``'features'``
            (a ``', '``-separated list whose length must equal
            ``N_dom_feats``).

    Returns:
        The ``Net`` class (not an instance).
    """
    from FunctionCollection import Loss_Functions, customModule
    import pytorch_lightning as pl
    from typing import List, Optional
    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:
        # Per-group [mean, variance, max, min] of `src`, concatenated on dim 1.
        if out is not None:
            dim_size = out.size(dim)
        if dim < 0:
            dim = src.dim() + dim
        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1
        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)
        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)
        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]
        return torch.cat([mean, var, maximum, minimum], dim=1)

    @torch.jit.script
    def time_edge_indeces(t, batch: torch.Tensor):
        # Build, per graph, a directed cycle through its nodes in time order.
        # Every node is the target of exactly one edge, and the result is
        # sorted by target, so column i of the output points INTO node i.
        time_sort = torch.argsort(t)
        graph_ids, graph_node_counts = torch.unique(batch, return_counts=True)
        # Node indices ordered by time, grouped graph by graph.
        batch_time_sort = torch.cat(
            [time_sort[batch[time_sort] == i] for i in graph_ids])
        # Offset (into batch_time_sort) of each graph's first/last node.
        graph_ends = torch.cumsum(graph_node_counts, 0)
        graph_starts = graph_ends - graph_node_counts
        # Chain consecutive nodes: position k -> k + 1.
        time_edge_index = torch.cat([
            batch_time_sort[:-1].view(1, -1),
            batch_time_sort[1:].view(1, -1)
        ], dim=0)
        # The chain edge leaving the last node of every graph except the final
        # one would cross into the next graph; point it back at the first node
        # of its own graph instead.
        last_of_graph = (graph_ends - 1)[:-1]
        time_edge_index[1, last_of_graph] = batch_time_sort[graph_starts[:-1]]
        # Close the cycle of the final graph. Indexing batch_time_sort directly
        # (rather than time_edge_index) keeps this valid for a single graph or
        # a single-node final graph.
        closing_edge = torch.cat([
            batch_time_sort[-1].reshape(1, 1),
            batch_time_sort[graph_starts[-1]].reshape(1, 1)
        ])
        time_edge_index = torch.cat([time_edge_index, closing_edge], dim=1)
        return time_edge_index[:, torch.argsort(time_edge_index[1])]

    N_edge_feats = args['N_edge_feats']  # 6
    N_dom_feats = args['N_dom_feats']  # 6
    N_scatter_feats = 4  # mean, var, max, min from scatter_distribution
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  # 10
    N_hcs = args['N_hcs']  # 32
    print(
        "if features = x, Charge should be at x[:,-5], time at x[:,-4] and pos at x[:,-3:]"
    )
    assert N_dom_feats == len(args['features'].split(', '))

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name[-4:] == 'NLLH'

    class Net(customModule):
        def __init__(self):
            super(Net, self).__init__(crit, y_post_processor,
                                      output_post_processor, cal_acc,
                                      likelihood_fitting, args)
            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            def gaussian(alpha):
                phi = torch.exp(-1 * alpha.pow(2))
                return phi

            class RBF(torch.nn.Module):
                """Radial-basis layer with learnable centres and widths."""

                def __init__(self, in_features, out_features,
                             basis_func=gaussian):
                    super(RBF, self).__init__()
                    self.in_features = in_features
                    self.out_features = out_features
                    self.centres = torch.nn.Parameter(
                        torch.Tensor(out_features, in_features))
                    self.sigmas = torch.nn.Parameter(
                        torch.Tensor(out_features))
                    self.basis_func = basis_func
                    self.reset_parameters()

                def reset_parameters(self):
                    torch.nn.init.normal_(self.centres, 0, 1)
                    torch.nn.init.constant_(self.sigmas, 10)

                def forward(self, x):
                    size = (x.size(0), self.out_features, self.in_features)
                    x = x.unsqueeze(1).expand(size)
                    c = self.centres.unsqueeze(0).expand(size)
                    # Squared distance scaled by sigma (gaussian squares again).
                    distances = (x - c).pow(2).sum(-1) / self.sigmas.unsqueeze(0)
                    return self.basis_func(distances)

            class MLP(torch.nn.Module):
                """Stack of Linear -> BatchNorm1d -> act blocks.

                With ``clean_out`` the final BatchNorm1d and activation are
                dropped so the last layer is a plain Linear.
                """

                def __init__(self, hcs_list, act=self.act, clean_out=False):
                    super(MLP, self).__init__()
                    mlp = []
                    for i in range(1, len(hcs_list)):
                        mlp.append(torch.nn.Linear(hcs_list[i - 1], hcs_list[i]))
                        mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                        mlp.append(act)
                    if clean_out:
                        self.mlp = torch.nn.Sequential(*mlp[:-2])
                    else:
                        self.mlp = torch.nn.Sequential(*mlp)

                def forward(self, x):
                    return self.mlp(x)

            class RBF_scatter(torch.nn.Module):
                """Normalise, apply RBF per node, then sum-pool per graph."""

                def __init__(self, in_features, out_features):
                    super(RBF_scatter, self).__init__()
                    self.RBF_batch_norm = torch.nn.BatchNorm1d(in_features)
                    self.RBF = RBF(in_features, out_features)
                    self.sum_batch_norm = torch.nn.BatchNorm1d(out_features)

                def forward(self, x, batch):
                    return self.sum_batch_norm(
                        scatter_sum(self.RBF(self.RBF_batch_norm(x)),
                                    batch,
                                    dim=0))

            class GRUConv(torch.nn.Module):
                """One round of node <-> centre-of-charge message passing."""

                def __init__(self, hcs=self.hcs, act=self.act):
                    super(GRUConv, self).__init__()
                    self.act = act
                    self.hcs = hcs
                    self.GRU = torch.nn.GRUCell(self.hcs * 2, self.hcs)
                    # RBF_scatter emits 4 * hcs features (== N_scatter_feats * hcs here).
                    self.lin_CoC_msg = MLP([N_scatter_feats * self.hcs, self.hcs],
                                           clean_out=True)
                    self.lin_CoC_self = MLP([self.hcs, self.hcs], clean_out=True)
                    self.RBF_scatter = RBF_scatter(self.hcs, 4 * self.hcs)
                    self.CoC_batch_norm = torch.nn.BatchNorm1d(self.hcs)
                    self.lin_x_msg = MLP([4 * self.hcs, self.hcs], clean_out=True)
                    self.lin_x_self = MLP([self.hcs, self.hcs], clean_out=True)
                    self.x_batch_norm = torch.nn.BatchNorm1d(self.hcs)

                def forward(self, x, CoC, h, batch):
                    h = self.GRU(torch.cat([CoC[batch], x], dim=1), h)
                    msg = self.lin_CoC_msg(self.RBF_scatter(h, batch))
                    CoC = self.lin_CoC_self(CoC)
                    CoC = self.act(self.CoC_batch_norm(msg + CoC))
                    h = self.GRU(torch.cat([x, CoC[batch]], dim=1), h)
                    msg = self.lin_x_msg(self.RBF_scatter.RBF(h))
                    x = self.lin_x_self(x)
                    x = self.act(self.x_batch_norm(msg + x))
                    return x, CoC, h

            # Widths of the tensors assembled in forward():
            #   x[previous node in time]      N_dom_feats
            #   CoC position of own graph     3
            #   node -> CoC edge attr         4  (unit vector + distance)
            #   time-edge attr                N_dom_feats + 1
            #                                 (unit vector + distance + diff of x[:, :-3])
            #   x itself                      N_dom_feats
            # The trailing three pieces are also summarised into the CoC input.
            self.N_x_to_CoC_feats = 4 + (N_dom_feats + 1) + N_dom_feats
            N_x_feats = N_dom_feats + 3 + self.N_x_to_CoC_feats
            N_CoC_feats = 3 + N_scatter_feats * self.N_x_to_CoC_feats

            self.x_encoder = MLP([N_x_feats, 2 * self.hcs, self.hcs])
            self.CoC_encoder = MLP([N_CoC_feats, 2 * self.hcs, self.hcs])
            self.convs = torch.nn.ModuleList()
            for _ in range(N_metalayers):
                self.convs.append(GRUConv())
            self.decoder_RBF_scatter = RBF_scatter(self.hcs, 4 * self.hcs)
            self.decoder = MLP(
                [(1 + N_scatter_feats) * self.hcs, 3 * self.hcs, N_outputs],
                clean_out=True)

        def forward(self, data):
            x, batch = data.x, data.batch
            x = x.float()

            time_edge_index = time_edge_indeces(x[:, -4], batch)
            time_edge_attr = self.edge_feature_constructor(x, time_edge_index)
            CoC, CoC_edge_attr = self.return_CoC_and_edge_attr(x, batch)

            x = torch.cat([
                x[time_edge_index[0]], CoC[batch], CoC_edge_attr,
                time_edge_attr, x
            ], dim=1)
            CoC = torch.cat([
                CoC,
                scatter_distribution(x[:, -self.N_x_to_CoC_feats:], batch, dim=0)
            ], dim=1)

            x = self.x_encoder(x)
            CoC = self.CoC_encoder(CoC)

            h = torch.zeros((x.shape[0], self.hcs), device=self.device)
            for i in range(N_metalayers):
                x, CoC, h = self.convs[i](x, CoC, h, batch)

            CoC = torch.cat([CoC, self.decoder_RBF_scatter(x, batch)], dim=1)
            return self.decoder(CoC)

        def return_CoC_and_edge_attr(self, x, batch):
            """Charge-weighted centre per graph and node->centre edge features."""
            pos = x[:, -3:]
            charge = x[:, -5].view(-1, 1)
            CoC = scatter_sum(pos * charge, batch, dim=0) / scatter_sum(
                charge, batch, dim=0)
            cart = pos - CoC[batch]
            rho = torch.norm(cart, p=2, dim=1).view(-1, 1)
            # view(-1) (not squeeze) keeps the mask 1-D even for a single node.
            rho_mask = rho.view(-1) != 0
            cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
            CoC_edge_attr = torch.cat([cart.type_as(x), rho.type_as(x)], dim=1)
            return CoC, CoC_edge_attr

        def edge_feature_constructor(self, x, edge_index):
            """Unit direction, distance and non-position feature diff per edge."""
            (frm, to) = edge_index
            pos = x[:, -3:]
            cart = pos[frm] - pos[to]
            rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)
            rho_mask = rho.view(-1) != 0
            cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
            diff = x[to, :-3] - x[frm, :-3]
            return torch.cat([cart.type_as(pos), rho, diff], dim=1)

    return Net
def Load_model(name, args):
    """Build the GRU model with attention pooling ("version a").

    Node features must hold charge at ``x[:, -5]``, time at ``x[:, -4]`` and
    position at ``x[:, -3:]``.

    Args:
        name: model/loss identifier forwarded to ``Loss_Functions``; a name
            ending in ``'NLLH'`` enables likelihood fitting.
        args: configuration dict. Reads ``'N_edge_feats'``, ``'N_dom_feats'``,
            ``'N_outputs'``, ``'N_metalayers'``, ``'N_hcs'`` and ``'features'``.

    Returns:
        The ``Net`` class (not an instance). ``Net.forward`` accepts either a
        single torch_geometric ``Data``/``Batch`` or one or more raw
        node-feature tensors.
    """
    from FunctionCollection import Loss_Functions, customModule, Print
    import pytorch_lightning as pl
    from typing import Optional
    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max, scatter_softmax
    from torch_scatter.utils import broadcast

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:
        # Per-group [sum, mean, variance, max, min], concatenated on dim 1.
        if out is not None:
            dim_size = out.size(dim)
        if dim < 0:
            dim = src.dim() + dim
        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1
        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        summ = tmp.clone()
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)
        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)
        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]
        return torch.cat([summ, mean, var, maximum, minimum], dim=1)

    N_edge_feats = args['N_edge_feats']  # 6
    N_dom_feats = args['N_dom_feats']  # 6
    N_scatter_feats = 5  # sum, mean, var, max, min
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  # 10
    N_hcs = args['N_hcs']  # 32
    print(
        "if features = x, Charge should be at x[:,-5], time at x[:,-4] and pos at x[:,-3:]"
    )
    assert N_dom_feats == len(args['features'].split(', '))

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name[-4:] == 'NLLH'

    class Net(customModule):
        def __init__(self):
            super(Net, self).__init__(crit, y_post_processor,
                                      output_post_processor, cal_acc,
                                      likelihood_fitting, args)
            self.act = torch.nn.GELU()
            self.hcs = N_hcs

            class MLP(torch.nn.Module):
                """Linear -> BatchNorm1d -> act blocks; ``clean_out`` drops the last BN+act."""

                def __init__(self, hcs_list, act=self.act, clean_out=False):
                    super(MLP, self).__init__()
                    mlp = []
                    for i in range(1, len(hcs_list)):
                        mlp.append(torch.nn.Linear(hcs_list[i - 1], hcs_list[i]))
                        mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                        mlp.append(act)
                    if clean_out:
                        self.mlp = torch.nn.Sequential(*mlp[:-2])
                    else:
                        self.mlp = torch.nn.Sequential(*mlp)

                def forward(self, x):
                    return self.mlp(x)

            class GRUConv(torch.nn.Module):
                """One round of node <-> centre-of-charge message passing."""

                def __init__(self, hcs=self.hcs, act=self.act):
                    super(GRUConv, self).__init__()
                    self.act = act
                    self.hcs = hcs
                    self.GRU = torch.nn.GRUCell(self.hcs * 2, self.hcs)
                    self.scatter_norm = scatter_att(self.hcs)
                    self.lin_CoC_msg = MLP([self.hcs, self.hcs], clean_out=True)
                    self.lin_CoC_self = MLP([self.hcs, self.hcs], clean_out=True)
                    self.CoC_batch_norm = torch.nn.BatchNorm1d(self.hcs)
                    self.lin_x_msg = MLP([self.hcs, self.hcs], clean_out=True)
                    self.lin_x_self = MLP([self.hcs, self.hcs], clean_out=True)
                    self.x_batch_norm = torch.nn.BatchNorm1d(self.hcs)

                def forward(self, x, CoC, h, batch):
                    h = self.GRU(torch.cat([CoC[batch], x], dim=1), h)
                    msg = self.lin_CoC_msg(self.scatter_norm(h, batch, CoC))
                    CoC = self.lin_CoC_self(CoC)
                    CoC = self.act(self.CoC_batch_norm(msg + CoC))
                    h = self.GRU(torch.cat([x, CoC[batch]], dim=1), h)
                    msg = self.lin_x_msg(h)
                    x = self.lin_x_self(x)
                    x = self.act(self.x_batch_norm(msg + x))
                    return x, CoC, h

            class scatter_norm(torch.nn.Module):
                """Batch-normalised per-graph distribution statistics."""

                def __init__(self, hcs):
                    super(scatter_norm, self).__init__()
                    self.batch_norm = torch.nn.BatchNorm1d(N_scatter_feats * hcs)

                def forward(self, x, batch):
                    return self.batch_norm(scatter_distribution(x, batch, dim=0))

            class scatter_att(torch.nn.Module):
                """Attention-weighted per-graph sum of node features.

                Attention logits come from ``[x, CoC[batch]]``; ``in_features``
                is that concatenated width (``2 * hcs`` by default).
                """

                def __init__(self, hcs, in_features=None):
                    super(scatter_att, self).__init__()
                    if in_features is None:
                        in_features = 2 * hcs
                    self.att_lin = torch.nn.Linear(int(in_features), 1)

                def forward(self, x, batch, CoC):
                    att = scatter_softmax(
                        self.att_lin(torch.cat([x, CoC[batch]], dim=1)),
                        batch,
                        dim=0)
                    return scatter_sum(att * x, batch, dim=0)

            N_x_feats = N_dom_feats
            N_CoC_feats = N_scatter_feats * N_x_feats + 3
            self.scatter_norm = scatter_norm(N_x_feats)
            self.x_encoder = MLP([N_x_feats, self.hcs])
            self.CoC_encoder = MLP([N_CoC_feats, self.hcs])
            self.convs = torch.nn.ModuleList()
            for _ in range(N_metalayers):
                self.convs.append(GRUConv())
            # Pools x (hcs) against the running `out` ((1 + N_metalayers) * hcs),
            # so the attention input is (2 + N_metalayers) * hcs wide. Passing the
            # width explicitly avoids building it from a float hcs.
            self.scatter_norm2 = scatter_att(
                self.hcs, in_features=(2 + N_metalayers) * self.hcs)
            self.decoder = MLP([(2 + N_metalayers) * self.hcs, 3 * self.hcs,
                                self.hcs, N_outputs],
                               clean_out=True)

        def _unpack_inputs(self, data):
            """Return ``(x, batch)`` from ``forward``'s positional arguments.

            A single non-tensor argument is treated as a torch_geometric
            ``Data``/``Batch``. Otherwise every argument is a node-feature
            tensor: 2-D tensors are one graph each, higher-rank tensors are
            split along dim 0 into one graph per slice.
            """
            if len(data) == 1 and not torch.is_tensor(data[0]):
                graph = data[0]
            else:
                from torch_geometric.data import Data, Batch
                datalist = []
                for tensor in data:
                    if tensor.dim() > 2:
                        for tmp_x in tensor:
                            datalist.append(Data(x=tmp_x.squeeze()))
                    else:
                        datalist.append(Data(x=tensor.squeeze()))
                graph = Batch.from_data_list(datalist)
            x = graph.x
            batch = getattr(graph, 'batch', None)
            if batch is None:
                # A lone Data object: every node belongs to graph 0.
                batch = torch.zeros(x.shape[0], device=x.device, dtype=torch.int64)
            return x, batch

        def forward(self, *data):
            x, batch = self._unpack_inputs(data)
            x = x.float()

            charge = x[:, -5].unsqueeze(-1)
            CoC = scatter_sum(x[:, -3:] * charge, batch, dim=0) / scatter_sum(
                charge, batch, dim=0)
            # Graphs with zero total charge give 0/0; place their centre at 0.
            CoC[CoC.isnan()] = 0
            CoC = torch.cat([CoC, self.scatter_norm(x, batch)], dim=1)

            x = self.x_encoder(x)
            CoC = self.CoC_encoder(CoC)

            h = torch.zeros((x.shape[0], self.hcs), device=self.device)
            out = CoC.clone()
            for conv in self.convs:
                x, CoC, h = conv(x, CoC, h, batch)
                out = torch.cat([out, CoC.clone()], dim=1)

            CoC = torch.cat([out, self.scatter_norm2(x, batch, out)], dim=1)
            return self.decoder(CoC)

    return Net
def Load_model(name, args):
    """Build the MetaLayer + GRUCell centre-of-charge model.

    Charge is read from ``x[:, 0]`` and position from ``x[:, -3:]``.

    Args:
        name: model/loss identifier forwarded to ``Loss_Functions``; a name
            ending in ``'NLLH'`` enables likelihood fitting.
        args: configuration dict. Reads ``'N_edge_feats'``, ``'N_dom_feats'``,
            ``'N_outputs'``, ``'N_metalayers'`` and ``'N_hcs'``.

    Returns:
        The ``Net`` class (not an instance).
    """
    from FunctionCollection import Loss_Functions, customModule, edge_feature_constructor
    import pytorch_lightning as pl
    from typing import Optional
    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:
        # Per-group [mean, variance, max, min], concatenated on dim 1.
        if out is not None:
            dim_size = out.size(dim)
        if dim < 0:
            dim = src.dim() + dim
        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1
        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)
        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)
        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]
        return torch.cat([mean, var, maximum, minimum], dim=1)

    N_edge_feats = args['N_edge_feats']  # 6
    N_dom_feats = args['N_dom_feats']  # 6
    N_scatter_feats = 4  # mean, var, max, min
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  # 10
    N_hcs = args['N_hcs']  # 32

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name[-4:] == 'NLLH'

    class Net(customModule):
        def __init__(self):
            super(Net, self).__init__(crit, y_post_processor,
                                      output_post_processor, cal_acc,
                                      likelihood_fitting, args)
            self.act = torch.nn.SiLU()
            self.hcs = N_hcs
            # Raw node features plus the statistics of incoming edge attributes.
            N_x_feats = N_dom_feats + N_scatter_feats * N_edge_feats

            self.x_encoder = torch.nn.Linear(N_x_feats, self.hcs)
            self.edge_attr_encoder = torch.nn.Linear(N_edge_feats, self.hcs)
            self.CoC_encoder = torch.nn.Linear(3 + N_scatter_feats * N_x_feats,
                                               self.hcs)

            class EdgeModel(torch.nn.Module):
                """Update edges from [src, dest, edge_attr, CoC of the graph]."""

                def __init__(self, hcs, act):
                    super(EdgeModel, self).__init__()
                    self.hcs = hcs
                    self.act = act
                    self.lins = torch.nn.ModuleList()
                    for _ in range(2):
                        self.lins.append(
                            torch.nn.Linear(4 * self.hcs, 4 * self.hcs))
                    self.decoder = torch.nn.Linear(4 * self.hcs, self.hcs)

                def forward(self, src, dest, edge_attr, CoC, batch):
                    # MetaLayer passes batch[row] here, i.e. one graph id per edge.
                    out = torch.cat([src, dest, edge_attr, CoC[batch]], dim=1)
                    for lin in self.lins:
                        out = self.act(lin(out))
                    return self.act(self.decoder(out))

            class NodeModel(torch.nn.Module):
                """Update nodes from aggregated incoming messages and CoC."""

                def __init__(self, hcs, act):
                    super(NodeModel, self).__init__()
                    self.hcs = hcs
                    self.act = act
                    self.lins1 = torch.nn.ModuleList()
                    for _ in range(2):
                        self.lins1.append(
                            torch.nn.Linear(2 * self.hcs, 2 * self.hcs))
                    self.lins2 = torch.nn.ModuleList()
                    for _ in range(2):
                        self.lins2.append(
                            torch.nn.Linear(
                                (2 + 2 * N_scatter_feats) * self.hcs,
                                (2 + 2 * N_scatter_feats) * self.hcs))
                    self.decoder = torch.nn.Linear(
                        (2 + 2 * N_scatter_feats) * self.hcs, self.hcs)

                def forward(self, x, edge_index, edge_attr, CoC, batch):
                    row, col = edge_index
                    out = torch.cat([x[row], edge_attr], dim=1)
                    for lin in self.lins1:
                        out = self.act(lin(out))
                    # dim_size keeps one row per node even when the last nodes
                    # receive no edges; otherwise the cat below would fail.
                    out = scatter_distribution(out, col, dim=0,
                                               dim_size=x.size(0))  # 8*hcs
                    out = torch.cat([x, out, CoC[batch]], dim=1)
                    for lin in self.lins2:
                        out = self.act(lin(out))
                    return self.act(self.decoder(out))

            class GlobalModel(torch.nn.Module):
                """Identity global update: CoC is refined separately in Net."""

                def __init__(self, hcs, act):
                    super(GlobalModel, self).__init__()

                def forward(self, x, edge_index, edge_attr, CoC, batch):
                    return CoC

            from torch_geometric.nn import MetaLayer
            self.ops = torch.nn.ModuleList()
            self.GRUCells = torch.nn.ModuleList()
            self.lins1 = torch.nn.ModuleList()
            self.lins2 = torch.nn.ModuleList()
            self.lins3 = torch.nn.ModuleList()
            for _ in range(N_metalayers):
                self.ops.append(
                    MetaLayer(EdgeModel(self.hcs, self.act),
                              NodeModel(self.hcs, self.act),
                              GlobalModel(self.hcs, self.act)))
                self.GRUCells.append(
                    torch.nn.GRUCell(self.hcs * 2 + 4 + N_x_feats, self.hcs))
                self.lins1.append(
                    torch.nn.Linear((1 + N_scatter_feats) * self.hcs,
                                    (1 + N_scatter_feats) * self.hcs))
                self.lins2.append(
                    torch.nn.Linear((1 + N_scatter_feats) * self.hcs, self.hcs))
                self.lins3.append(torch.nn.Linear(2 * self.hcs, self.hcs))

            self.decoders = torch.nn.ModuleList()
            self.decoders.append(torch.nn.Linear(self.hcs, self.hcs))
            self.decoders.append(torch.nn.Linear(self.hcs, self.hcs))
            self.decoder = torch.nn.Linear(self.hcs, N_outputs)

        def forward(self, data):
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            pos = x[:, -3:]
            # Append statistics of each node's incoming edge attributes; one row
            # per node is guaranteed by dim_size.
            x = torch.cat([
                x,
                scatter_distribution(edge_attr, edge_index[1], dim=0,
                                     dim_size=x.size(0))
            ], dim=1)

            charge = x[:, 0].view(-1, 1)
            CoC = scatter_sum(pos * charge, batch, dim=0) / scatter_sum(
                charge, batch, dim=0)
            CoC = torch.cat([CoC, scatter_distribution(x, batch, dim=0)], dim=1)

            # One virtual edge per node: node i -> CoC of its graph.
            CoC_edge_index = torch.cat([
                torch.arange(x.shape[0]).view(1, -1).type_as(batch),
                batch.view(1, -1)
            ], dim=0)
            cart = pos[CoC_edge_index[0]] - CoC[CoC_edge_index[1], :3]
            del pos
            rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)
            rho_mask = rho.view(-1) != 0
            cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
            CoC_edge_attr = torch.cat(
                [cart.type_as(x), rho.type_as(x), x[CoC_edge_index[0]]], dim=1)

            x = self.act(self.x_encoder(x))
            edge_attr = self.act(self.edge_attr_encoder(edge_attr))
            CoC = self.act(self.CoC_encoder(CoC))

            h = torch.zeros((x.shape[0], self.hcs)).type_as(x)
            for i, op in enumerate(self.ops):
                x, edge_attr, CoC = op(x, edge_index, edge_attr, CoC, batch)
                h = self.act(self.GRUCells[i](torch.cat([
                    CoC[CoC_edge_index[1]], x[CoC_edge_index[0]], CoC_edge_attr
                ], dim=1), h))
                CoC = self.act(self.lins1[i](torch.cat(
                    [CoC, scatter_distribution(h, batch, dim=0)], dim=1)))
                CoC = self.act(self.lins2[i](CoC))
                h = self.act(self.GRUCells[i](torch.cat([
                    CoC[CoC_edge_index[1]], x[CoC_edge_index[0]], CoC_edge_attr
                ], dim=1), h))
                x = self.act(self.lins3[i](torch.cat([x, h], dim=1)))

            for lin in self.decoders:
                CoC = self.act(lin(CoC))
            return self.decoder(CoC)

    return Net
def Load_model(name, args):
    """Build a pooled-statistics baseline model.

    Every node is encoded by a small MLP. The encodings are summarised per
    graph (sum, mean, variance, max, min), batch-normalised and decoded.

    Args:
        name: model/loss identifier forwarded to ``Loss_Functions``; a name
            ending in ``'NLLH'`` enables likelihood fitting.
        args: configuration dict. Reads ``'N_edge_feats'``, ``'N_dom_feats'``,
            ``'N_outputs'``, ``'N_metalayers'``, ``'N_hcs'`` and ``'features'``.

    Returns:
        The ``Net`` class (not an instance).
    """
    from FunctionCollection import Loss_Functions, customModule
    import pytorch_lightning as pl
    from typing import Optional
    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:
        # Per-group [sum, mean, variance, max, min], concatenated on dim 1.
        if out is not None:
            dim_size = out.size(dim)
        if dim < 0:
            dim = src.dim() + dim
        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1
        group_sizes = scatter_sum(
            torch.ones(index.size(), dtype=src.dtype, device=src.device),
            index, count_dim, dim_size=dim_size)
        index = broadcast(index, src, dim)
        totals = scatter_sum(src, index, dim, dim_size=dim_size)
        summ = totals.clone()
        group_sizes = broadcast(group_sizes, totals, dim).clamp(1)
        mean = totals.div(group_sizes)
        deviation = src - mean.gather(dim, index)
        var = scatter_sum(deviation * deviation, index, dim, out, dim_size)
        if unbiased:
            group_sizes = group_sizes.sub(1).clamp_(1)
        var = var.div(group_sizes)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]
        return torch.cat([summ, mean, var, maximum, minimum], dim=1)

    N_edge_feats = args['N_edge_feats']  # 6
    N_dom_feats = args['N_dom_feats']  # 6
    N_scatter_feats = 5
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  # 10
    N_hcs = args['N_hcs']  # 32
    print(
        "if features = x, Charge should be at x[:,-5], time at x[:,-4] and pos at x[:,-3:]"
    )
    assert N_dom_feats == len(args['features'].split(', '))

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name[-4:] == 'NLLH'

    class Net(customModule):
        def __init__(self):
            super(Net, self).__init__(crit, y_post_processor,
                                      output_post_processor, cal_acc,
                                      likelihood_fitting, args)
            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            def gaussian(alpha):
                return torch.exp(-1 * alpha.pow(2))

            class RBF(torch.nn.Module):
                """Radial-basis layer (defined here but not instantiated)."""

                def __init__(self, in_features, out_features,
                             basis_func=gaussian):
                    super(RBF, self).__init__()
                    self.in_features = in_features
                    self.out_features = out_features
                    self.centres = torch.nn.Parameter(
                        torch.Tensor(out_features, in_features))
                    self.log_sigmas = torch.nn.Parameter(
                        torch.Tensor(out_features))
                    self.basis_func = basis_func
                    self.reset_parameters()

                def reset_parameters(self):
                    torch.nn.init.normal_(self.centres, 0, 1)
                    torch.nn.init.constant_(self.log_sigmas, 0)

                def forward(self, x):
                    shape = (x.size(0), self.out_features, self.in_features)
                    points = x.unsqueeze(1).expand(shape)
                    centres = self.centres.unsqueeze(0).expand(shape)
                    widths = torch.exp(self.log_sigmas).unsqueeze(0)
                    scaled = (points - centres).pow(2).sum(-1).pow(0.5) / widths
                    return self.basis_func(scaled)

            class MLP(torch.nn.Module):
                """Linear -> BatchNorm1d -> act stack; ``clean_out`` trims the last BN+act."""

                def __init__(self, hcs_list, act=self.act, clean_out=False):
                    super(MLP, self).__init__()
                    layers = []
                    for width_in, width_out in zip(hcs_list[:-1], hcs_list[1:]):
                        layers += [
                            torch.nn.Linear(width_in, width_out),
                            torch.nn.BatchNorm1d(width_out),
                            act,
                        ]
                    kept = layers[:-2] if clean_out else layers
                    self.mlp = torch.nn.Sequential(*kept)

                def forward(self, x):
                    return self.mlp(x)

            class RBF_scatter(torch.nn.Module):
                """RBF features pooled per graph (defined here but not instantiated)."""

                def __init__(self, in_features, out_features,
                             basis_func=gaussian):
                    super(RBF_scatter, self).__init__()
                    self.RBF_batch_norm = torch.nn.BatchNorm1d(in_features)
                    self.RBF = RBF(in_features, out_features, basis_func)
                    self.sum_batch_norm = torch.nn.BatchNorm1d(
                        out_features * N_scatter_feats)

                def forward(self, x, batch):
                    features = self.RBF(self.RBF_batch_norm(x))
                    return self.sum_batch_norm(
                        scatter_distribution(features, batch, dim=0))

            pooled_width = N_scatter_feats * 2 * self.hcs
            self.encoder = MLP([N_dom_feats, 2 * self.hcs])
            self.batch_norm = torch.nn.BatchNorm1d(pooled_width)
            self.decoder = MLP([pooled_width, 5 * self.hcs, self.hcs, N_outputs],
                               clean_out=True)

        def forward(self, data):
            node_feats = data.x.float()
            pooled = scatter_distribution(self.encoder(node_feats),
                                          data.batch,
                                          dim=0)
            return self.decoder(self.batch_norm(pooled))

    return Net
def Load_model(name, args):
    """Build the TransformerConv centre-of-charge model.

    Each graph gets one virtual "centre of charge" (CoC) node that attends to
    all of the graph's real nodes through a multi-head ``TransformerConv``.

    Node features must hold charge at ``x[:, -5]``, time at ``x[:, -4]`` and
    position at ``x[:, -3:]``.

    Args:
        name: model/loss identifier forwarded to ``Loss_Functions``; a name
            ending in ``'NLLH'`` enables likelihood fitting.
        args: configuration dict. Reads ``'N_edge_feats'``, ``'N_dom_feats'``,
            ``'N_outputs'``, ``'N_metalayers'`` (used as the number of
            attention heads), ``'N_hcs'`` and ``'features'``.

    Returns:
        The ``Net`` class (not an instance).
    """
    from FunctionCollection import Loss_Functions, customModule
    import pytorch_lightning as pl
    from typing import Optional
    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast
    # Previously referenced as `torch_geometric.nn.TransformerConv` without any
    # import of torch_geometric, which raised NameError when building Net.
    from torch_geometric.nn import TransformerConv

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:
        # Per-group [sum, mean, variance, max, min], concatenated on dim 1.
        if out is not None:
            dim_size = out.size(dim)
        if dim < 0:
            dim = src.dim() + dim
        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1
        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        summ = tmp.clone()
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)
        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)
        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]
        return torch.cat([summ, mean, var, maximum, minimum], dim=1)

    N_edge_feats = args['N_edge_feats']  # 6
    N_dom_feats = args['N_dom_feats']  # 6
    N_scatter_feats = 5  # sum, mean, var, max, min
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  # 10
    N_hcs = args['N_hcs']  # 32
    print(
        "if features = x, Charge should be at x[:,-5], time at x[:,-4] and pos at x[:,-3:]"
    )
    assert N_dom_feats == len(args['features'].split(', '))

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name[-4:] == 'NLLH'

    class Net(customModule):
        def __init__(self):
            super(Net, self).__init__(crit, y_post_processor,
                                      output_post_processor, cal_acc,
                                      likelihood_fitting, args)
            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            class MLP(torch.nn.Module):
                """Linear -> BatchNorm1d -> act blocks; ``clean_out`` drops the last BN+act."""

                def __init__(self, hcs_list, act=self.act, clean_out=False):
                    super(MLP, self).__init__()
                    mlp = []
                    for i in range(1, len(hcs_list)):
                        mlp.append(torch.nn.Linear(hcs_list[i - 1], hcs_list[i]))
                        mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                        mlp.append(act)
                    if clean_out:
                        self.mlp = torch.nn.Sequential(*mlp[:-2])
                    else:
                        self.mlp = torch.nn.Sequential(*mlp)

                def forward(self, x):
                    return self.mlp(x)

            class scatter_norm(torch.nn.Module):
                """Batch-normalised per-graph distribution statistics."""

                def __init__(self, hcs):
                    super(scatter_norm, self).__init__()
                    self.batch_norm = torch.nn.BatchNorm1d(N_scatter_feats * hcs)

                def forward(self, x, batch):
                    return self.batch_norm(scatter_distribution(x, batch, dim=0))

            N_x_feats = N_dom_feats
            N_CoC_feats = N_scatter_feats * N_x_feats + 3
            self.scatter_norm = scatter_norm(N_x_feats)
            self.x_encoder = MLP([N_x_feats, self.hcs])
            self.CoC_encoder = MLP([N_CoC_feats, self.hcs])
            # concat=True (default) makes the output heads * hcs wide.
            self.TConv = TransformerConv(in_channels=(self.hcs, self.hcs),
                                         out_channels=self.hcs,
                                         heads=N_metalayers)
            self.decoder = MLP(
                [N_metalayers * self.hcs, 3 * self.hcs, self.hcs, N_outputs],
                clean_out=True)

        def forward(self, data):
            x, batch = data.x, data.batch
            x = x.float()

            charge = x[:, -5].view(-1, 1)
            CoC = scatter_sum(x[:, -3:] * charge, batch, dim=0) / scatter_sum(
                charge, batch, dim=0)
            CoC = torch.cat([CoC, self.scatter_norm(x, batch)], dim=1)

            x = self.x_encoder(x)
            CoC = self.CoC_encoder(CoC)

            # CoC nodes occupy rows [0, n_graphs); real nodes follow them.
            n_graphs = CoC.shape[0]
            CoC_x = torch.cat([CoC, x], dim=0)
            edge_index = self.return_edge_index(batch)
            CoC_x = self.TConv(CoC_x, edge_index)
            return self.decoder(CoC_x[:n_graphs])

        def return_edge_index(self, batch):
            """Edges from every real node (offset by #graphs) to its graph's CoC node."""
            offset = int(batch.max()) + 1
            # Build on batch's device so the cat below works on GPU as well.
            frm = torch.arange(offset,
                               offset + batch.shape[0],
                               dtype=torch.long,
                               device=batch.device).view(1, -1)
            to = batch.view(1, -1)
            return torch.cat([frm, to], dim=0).contiguous()

    return Net
def Load_model(name, args):
    """Build and return a GATConv / temporal-GCN graph model class.

    NOTE(review): ``Net.forward`` currently returns the GATConv node output
    together with its attention weights ``(x, edge_index, alpha)``.  That looks
    like a debugging path.  The recurrent TGCN + decoder path after that return
    is unreachable; see the note inside ``forward`` before re-enabling it.

    Args:
        name: model name forwarded to ``Loss_Functions``.  A name ending in
            ``'NLLH'`` turns on likelihood fitting.
        args: hyper-parameter dict.  Reads ``'N_edge_feats'``,
            ``'N_dom_feats'``, ``'N_outputs'``, ``'N_metalayers'`` and
            ``'N_hcs'``, plus whatever ``Loss_Functions`` reads.

    Returns:
        The ``Net`` class (not an instance).
    """
    from FunctionCollection import Loss_Functions, customModule
    import pytorch_lightning as pl
    from torch_geometric_temporal import nn
    from typing import Optional
    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast
    from torch_geometric.nn import GATConv

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor, index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:
        # Per-group summary of ``src`` grouped by ``index`` along ``dim``:
        # [mean, variance, max, min] concatenated along dim 1, so the feature
        # width grows by a factor of 4.  Variance is unbiased when requested.
        # Groups with no members get count 1 to avoid division by zero.
        if out is not None:
            dim_size = out.size(dim)
        if dim < 0:
            dim = src.dim() + dim
        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1
        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)
        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)
        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]
        return torch.cat([mean, var, maximum, minimum], dim=1)

    def time_chain_edge_index(t, batch):
        """Return an undirected chain linking each graph's nodes in ``t`` order.

        Nodes are sorted by ``t``, then grouped by graph id.  Each node is
        linked to its successor within the same graph, and the links are
        mirrored so the result is undirected.
        """
        time_sort = torch.argsort(t)
        # The sort by graph id must be stable.  Otherwise the time order
        # established above is not guaranteed to survive inside each graph.
        batch_time_sort = time_sort[torch.sort(batch[time_sort], stable=True)[1]]
        time_edge_index = torch.stack(
            [batch_time_sort[:-1], batch_time_sort[1:]], dim=0)
        # Drop the links that jump from the last node of one graph to the
        # first node of the next one.  The mask is created on the edges'
        # device because the boundary indices come from ``batch``.  Indexing
        # a CPU mask with CUDA indices raises.
        _, graph_node_counts = batch.unique(return_counts=True)
        keep = torch.ones(time_edge_index.shape[1], dtype=torch.bool,
                          device=time_edge_index.device)
        keep[(torch.cumsum(graph_node_counts, 0) - 1)[:-1]] = False
        time_edge_index = time_edge_index[:, keep]
        return torch.cat([time_edge_index, time_edge_index.flip(0)], dim=1)

    N_edge_feats = args['N_edge_feats']  # 6
    N_dom_feats = args['N_dom_feats']  # 6
    N_scatter_feats = 4  # mean, var, max, min from scatter_distribution
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  # 10
    N_hcs = args['N_hcs']  # 32

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name[-4:] == 'NLLH'

    class Net(customModule):
        def __init__(self):
            super(Net, self).__init__(crit, y_post_processor,
                                      output_post_processor, cal_acc,
                                      likelihood_fitting, args)
            self.act = torch.nn.SiLU()
            self.hcs = N_hcs
            # Raw node features plus the 4 summary statistics of every
            # incoming-edge attribute.
            N_x_feats = N_dom_feats + N_scatter_feats * N_edge_feats
            self.x_encoder = torch.nn.Linear(N_x_feats, self.hcs)
            self.conv = nn.recurrent.temporalgcn.TGCN(self.hcs, self.hcs)
            self.decoder = torch.nn.Linear(self.hcs * 4, N_outputs)
            self.GATConv = GATConv(self.hcs, self.hcs, 3, add_self_loops=False)

        def forward(self, data):
            x, edge_attr, edge_index, batch = (data.x, data.edge_attr,
                                               data.edge_index, data.batch)
            # Append per-node statistics of incoming edge attributes.
            # ``dim_size`` keeps one row per node even when the
            # highest-index nodes have no incoming edges.
            x = torch.cat(
                [x, scatter_distribution(edge_attr, edge_index[1], dim=0,
                                         dim_size=x.size(0))], dim=1)
            x = self.act(self.x_encoder(x))
            x, (e, w) = self.GATConv(x, edge_index,
                                     return_attention_weights=True)
            # NOTE(review): debugging early return of the attention weights.
            # Everything below is unreachable.
            return x, e, w
            # NOTE(review): GATConv here concatenates 3 heads, presumably
            # giving 3 * hcs features, while ``self.conv`` expects hcs inputs.
            # Adjust one of them before re-enabling this path.
            # Column 1 of x is assumed to be the hit time; confirm with the
            # data layout.
            time_edge_index = time_chain_edge_index(data.x[:, 1], batch)
            h = self.act(self.conv(x, time_edge_index))
            for _ in range(N_metalayers):
                h = self.act(self.conv(x, time_edge_index, H=h))
            x = scatter_distribution(h, batch, dim=0)
            return self.decoder(x)

    return Net
def Load_model(name, args):
    """Build and return a MetaLayer + temporal-GCN graph model class.

    Processing steps in ``Net.forward``:

    1. Each node is augmented with summary statistics (mean, var, max, min)
       of its incoming edge attributes.
    2. A per-graph global state is formed from statistics of the augmented
       nodes.
    3. Nodes, edges and the global state are refined by ``N_metalayers``
       MetaLayer blocks.  In parallel, a TGCN hidden state is updated over
       the graph edges plus a per-graph time chain.
    4. The global state after every layer, together with statistics of the
       final TGCN state, is decoded to ``N_outputs`` values per graph.

    Args:
        name: model name forwarded to ``Loss_Functions``.  A name ending in
            ``'NLLH'`` turns on likelihood fitting.
        args: hyper-parameter dict.  Reads ``'N_edge_feats'``,
            ``'N_dom_feats'``, ``'N_outputs'``, ``'N_metalayers'`` and
            ``'N_hcs'``, plus whatever ``Loss_Functions`` reads.

    Returns:
        The ``Net`` class (not an instance).
    """
    from FunctionCollection import Loss_Functions, customModule
    import pytorch_lightning as pl
    from torch_geometric_temporal import nn
    from typing import Optional
    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast
    from torch_geometric.nn import MetaLayer

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor, index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:
        # Per-group summary of ``src`` grouped by ``index`` along ``dim``:
        # [mean, variance, max, min] concatenated along dim 1, so the feature
        # width grows by a factor of 4.  Variance is unbiased when requested.
        # Groups with no members get count 1 to avoid division by zero.
        if out is not None:
            dim_size = out.size(dim)
        if dim < 0:
            dim = src.dim() + dim
        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1
        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)
        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)
        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]
        return torch.cat([mean, var, maximum, minimum], dim=1)

    def time_chain_edge_index(t, batch):
        """Return an undirected chain linking each graph's nodes in ``t`` order.

        Nodes are sorted by ``t``, then grouped by graph id.  Each node is
        linked to its successor within the same graph, and the links are
        mirrored so the result is undirected.
        """
        time_sort = torch.argsort(t)
        # The sort by graph id must be stable.  Otherwise the time order
        # established above is not guaranteed to survive inside each graph.
        batch_time_sort = time_sort[torch.sort(batch[time_sort], stable=True)[1]]
        time_edge_index = torch.stack(
            [batch_time_sort[:-1], batch_time_sort[1:]], dim=0)
        # Drop the links that jump from the last node of one graph to the
        # first node of the next one.  The mask is created on the edges'
        # device because the boundary indices come from ``batch``.  Indexing
        # a CPU mask with CUDA indices raises.
        _, graph_node_counts = batch.unique(return_counts=True)
        keep = torch.ones(time_edge_index.shape[1], dtype=torch.bool,
                          device=time_edge_index.device)
        keep[(torch.cumsum(graph_node_counts, 0) - 1)[:-1]] = False
        time_edge_index = time_edge_index[:, keep]
        return torch.cat([time_edge_index, time_edge_index.flip(0)], dim=1)

    N_edge_feats = args['N_edge_feats']  # 6
    N_dom_feats = args['N_dom_feats']  # 6
    N_scatter_feats = 4  # mean, var, max, min from scatter_distribution
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  # 10
    N_hcs = args['N_hcs']  # 32

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name[-4:] == 'NLLH'

    class EdgeModel(torch.nn.Module):
        """Edge update: MLP over [src, dest, edge_attr, u[graph]] (4*hcs -> hcs)."""

        def __init__(self, hcs, act):
            super(EdgeModel, self).__init__()
            self.hcs = hcs
            self.act = act
            self.lins = torch.nn.ModuleList(
                [torch.nn.Linear(4 * hcs, 4 * hcs) for _ in range(2)])
            self.decoder = torch.nn.Linear(4 * hcs, hcs)

        def forward(self, src, dest, edge_attr, u, batch):
            # Here ``batch`` is the graph id of each edge's source node, as
            # supplied by MetaLayer.
            out = torch.cat([src, dest, edge_attr, u[batch]], dim=1)
            for lin in self.lins:
                out = self.act(lin(out))
            return self.act(self.decoder(out))

    class NodeModel(torch.nn.Module):
        """Node update from incoming-message statistics and the global state."""

        def __init__(self, hcs, act):
            super(NodeModel, self).__init__()
            self.hcs = hcs
            self.act = act
            self.lins1 = torch.nn.ModuleList(
                [torch.nn.Linear(2 * hcs, 2 * hcs) for _ in range(2)])
            # Width: x (hcs) + 4 stats of 2*hcs messages + u (hcs).
            width = (2 + 2 * N_scatter_feats) * hcs
            self.lins2 = torch.nn.ModuleList(
                [torch.nn.Linear(width, width) for _ in range(2)])
            self.decoder = torch.nn.Linear(width, hcs)

        def forward(self, x, edge_index, edge_attr, u, batch):
            row, col = edge_index
            out = torch.cat([x[row], edge_attr], dim=1)
            for lin in self.lins1:
                out = self.act(lin(out))
            # ``dim_size`` keeps one row per node even when the highest-index
            # nodes receive no messages.  Without it the cat below fails.
            out = scatter_distribution(out, col, dim=0, dim_size=x.size(0))
            out = torch.cat([x, out, u[batch]], dim=1)
            for lin in self.lins2:
                out = self.act(lin(out))
            return self.act(self.decoder(out))

    class GlobalModel(torch.nn.Module):
        """Global update from u and per-graph node statistics (5*hcs -> hcs)."""

        def __init__(self, hcs, act):
            super(GlobalModel, self).__init__()
            self.hcs = hcs
            self.act = act
            width = (1 + N_scatter_feats) * hcs
            self.lins1 = torch.nn.ModuleList(
                [torch.nn.Linear(width, width) for _ in range(2)])
            self.decoder = torch.nn.Linear(width, hcs)

        def forward(self, x, edge_index, edge_attr, u, batch):
            out = torch.cat([u, scatter_distribution(x, batch, dim=0)], dim=1)
            for lin in self.lins1:
                out = self.act(lin(out))
            return self.act(self.decoder(out))

    class Net(customModule):
        def __init__(self):
            super(Net, self).__init__(crit, y_post_processor,
                                      output_post_processor, cal_acc,
                                      likelihood_fitting, args)
            self.act = torch.nn.SiLU()
            self.hcs = N_hcs
            # Raw node features plus the 4 statistics of each edge attribute.
            N_x_feats = N_dom_feats + N_scatter_feats * N_edge_feats
            N_u_feats = N_scatter_feats * N_x_feats
            self.x_encoder = torch.nn.Linear(N_x_feats, self.hcs)
            self.edge_attr_encoder = torch.nn.Linear(N_edge_feats, self.hcs)
            self.u_encoder = torch.nn.Linear(N_u_feats, self.hcs)
            self.tgcn_starter = nn.recurrent.temporalgcn.TGCN(
                N_x_feats, self.hcs)
            self.ops = torch.nn.ModuleList()
            self.tgcns = torch.nn.ModuleList()
            for _ in range(N_metalayers):
                self.ops.append(
                    MetaLayer(EdgeModel(self.hcs, self.act),
                              NodeModel(self.hcs, self.act),
                              GlobalModel(self.hcs, self.act)))
                self.tgcns.append(
                    nn.recurrent.temporalgcn.TGCN(self.hcs, self.hcs))
            # Input: the global state after the encoder and after each layer,
            # plus 4 statistics of the final TGCN state.
            widths = [(1 + N_metalayers) * self.hcs + N_scatter_feats * self.hcs,
                      10 * self.hcs, 8 * self.hcs, 6 * self.hcs,
                      4 * self.hcs, 2 * self.hcs]
            self.decoders = torch.nn.ModuleList(
                [torch.nn.Linear(w_in, w_out)
                 for w_in, w_out in zip(widths[:-1], widths[1:])])
            self.decoder = torch.nn.Linear(2 * self.hcs, N_outputs)

        def forward(self, data):
            x, edge_attr, edge_index, batch = (data.x, data.edge_attr,
                                               data.edge_index, data.batch)
            x = torch.cat(
                [x, scatter_distribution(edge_attr, edge_index[1], dim=0,
                                         dim_size=x.size(0))], dim=1)
            u = scatter_distribution(x, batch, dim=0)
            # The TGCN runs over the graph edges plus the per-graph time
            # chain.  Column 1 of x is assumed to be the hit time; confirm
            # with the data layout.
            time_edge_index = torch.cat(
                [edge_index, time_chain_edge_index(x[:, 1], batch)], dim=1)
            h = self.tgcn_starter(x, time_edge_index)

            x = self.act(self.x_encoder(x))
            edge_attr = self.act(self.edge_attr_encoder(edge_attr))
            u = self.act(self.u_encoder(u))

            global_states = [u]
            for op, tgcn in zip(self.ops, self.tgcns):
                x, edge_attr, u = op(x, edge_index, edge_attr, u, batch)
                h = self.act(tgcn(x, time_edge_index, H=h))
                global_states.append(u)
            out = torch.cat(
                global_states + [scatter_distribution(h, batch, dim=0)], dim=1)
            for lin in self.decoders:
                out = self.act(lin(out))
            return self.decoder(out)

    return Net
def Load_model(name, args):
    """Build and return a GRU-based model over a per-graph cyclic time chain.

    Within each graph, nodes are ordered by time and linked into a cycle:
    each node points to its successor, and the last node points back to the
    first.  Every node is then augmented with two things:

    * the features ``edge_feature_constructor`` computes for its single
      incoming cycle edge;
    * the features of its predecessor node.

    The augmented nodes are refined by ``N_metalayers`` GRUConv layers and
    summarised per graph.

    Args:
        name: model name forwarded to ``Loss_Functions``.  A name ending in
            ``'NLLH'`` turns on likelihood fitting.
        args: hyper-parameter dict.  Reads ``'N_edge_feats'``,
            ``'N_dom_feats'``, ``'N_outputs'``, ``'N_metalayers'`` and
            ``'N_hcs'``, plus whatever ``Loss_Functions`` reads.

    Returns:
        The ``Net`` class (not an instance).
    """
    from FunctionCollection import Loss_Functions, customModule, edge_feature_constructor
    import pytorch_lightning as pl
    from typing import Optional
    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor, index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:
        # Per-group summary of ``src`` grouped by ``index`` along ``dim``:
        # [mean, variance, max, min] concatenated along dim 1, so the feature
        # width grows by a factor of 4.  Variance is unbiased when requested.
        # Groups with no members get count 1 to avoid division by zero.
        if out is not None:
            dim_size = out.size(dim)
        if dim < 0:
            dim = src.dim() + dim
        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1
        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)
        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)
        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]
        return torch.cat([mean, var, maximum, minimum], dim=1)

    def cyclic_time_edge_index(t, batch):
        """Return edges linking each graph's nodes into a cycle in ``t`` order.

        Each node gets exactly one incoming edge, from its time predecessor
        within its graph.  A graph's first node gets its edge from the graph's
        last node, and a single-node graph gets a self-loop.  Edges are
        returned sorted by target node.

        This also works for a batch containing a single graph, or whose last
        graph has a single node.  The previous index-patching version raised
        IndexError in both cases.
        """
        time_sort = torch.argsort(t)
        _, graph_node_counts = batch.unique(return_counts=True)
        # A stable sort by graph id keeps the time order within each graph.
        batch_time_sort = time_sort[torch.sort(batch[time_sort], stable=True)[1]]
        # Position of each node's successor in ``batch_time_sort``: the next
        # position, except that each graph's last position wraps to its first.
        ends = torch.cumsum(graph_node_counts, 0) - 1
        starts = ends - graph_node_counts + 1
        nxt = torch.arange(1, batch_time_sort.numel() + 1,
                           device=batch_time_sort.device)
        nxt[ends] = starts
        dst = batch_time_sort[nxt]
        time_edge_index = torch.stack([batch_time_sort, dst], dim=0)
        return time_edge_index[:, torch.argsort(dst)]

    N_edge_feats = args['N_edge_feats']  # 6
    N_dom_feats = args['N_dom_feats']  # 6
    N_scatter_feats = 4  # mean, var, max, min from scatter_distribution
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  # 10
    N_hcs = args['N_hcs']  # 32

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name[-4:] == 'NLLH'

    class GRUConv(torch.nn.Module):
        """GRU-gated node update that mixes in the predecessor's hidden state."""

        def __init__(self, hcs, act):
            super(GRUConv, self).__init__()
            self.hcs = hcs
            self.act = act
            self.GRU = torch.nn.GRUCell(self.hcs, self.hcs)
            self.lin = torch.nn.Linear(2 * self.hcs, self.hcs)

        def forward(self, x, edge_index, edge_attr, batch, h):
            # ``edge_attr`` and ``batch`` are accepted for interface symmetry
            # but are not used.  Row i of ``edge_index`` is taken to be the
            # edge into node i.  That holds for the target-sorted cycle built
            # by ``cyclic_time_edge_index``.
            frm = edge_index[0]
            h = self.act(self.GRU(x, h))
            x = torch.cat([x, h[frm]], dim=1)
            return self.act(self.lin(x))

    class Net(customModule):
        def __init__(self):
            super(Net, self).__init__(crit, y_post_processor,
                                      output_post_processor, cal_acc,
                                      likelihood_fitting, args)
            self.act = torch.nn.SiLU()
            self.hcs = N_hcs
            # Own features + incoming-edge features + predecessor features.
            N_x_feats = 2 * N_dom_feats + N_edge_feats
            self.x_encoder = torch.nn.Linear(N_x_feats, self.hcs)
            self.GRUConvs = torch.nn.ModuleList(
                [GRUConv(self.hcs, self.act) for _ in range(N_metalayers)])
            self.decoders = torch.nn.ModuleList([
                torch.nn.Linear(N_scatter_feats * self.hcs, self.hcs),
                torch.nn.Linear(self.hcs, self.hcs),
            ])
            self.decoder = torch.nn.Linear(self.hcs, N_outputs)

        def forward(self, data):
            x, batch = data.x, data.batch
            # Column 1 of x is assumed to be the hit time; confirm with the
            # data layout.
            time_edge_index = cyclic_time_edge_index(x[:, 1], batch)
            # One edge per node, sorted by target, so row i of ``edge_attr``
            # and of ``x[time_edge_index[0]]`` belongs to node i (assuming
            # node ids are 0..N-1).
            edge_attr = edge_feature_constructor(x, time_edge_index)
            x = torch.cat([x, edge_attr, x[time_edge_index[0]]], dim=1)
            x = self.act(self.x_encoder(x))
            # NOTE(review): every GRUConv receives this same all-zeros ``h``.
            # GRUConv does not return its updated hidden state, so no
            # recurrent state flows between layers.  Confirm this is intended.
            h = torch.zeros((x.shape[0], self.hcs)).type_as(x)
            for conv in self.GRUConvs:
                x = conv(x, time_edge_index, edge_attr, batch, h)
            out = scatter_distribution(x, batch, dim=0)
            for lin in self.decoders:
                out = self.act(lin(out))
            return self.decoder(out)

    return Net