Code Example #1
def Load_model(name, args):
    from FunctionCollection import Loss_Functions, customModule, edge_feature_constructor
    import pytorch_lightning as pl

    from typing import List, Optional  # List is required by the TorchScript List[int] annotations below

    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max, scatter_softmax
    from torch_scatter.utils import broadcast
    import torch.nn.functional as F
    #     from torch_geometric.utils import softmax

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)

        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([mean, var, maximum, minimum], dim=1)
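    # Worked example (illustrative, not from the original source): with
    #   src   = torch.tensor([[1.], [3.], [10.]])
    #   index = torch.tensor([0, 0, 1])
    # scatter_distribution(src, index, dim=0) returns
    #   [[ 2.,  2.,  3.,  1.],
    #    [10.,  0., 10., 10.]]
    # i.e. [mean | var | max | min] concatenated per group along dim=1.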

    @torch.jit.script
    def x_feature_constructor(x, graph_node_counts):
        tmp = []
        a: List[int] = graph_node_counts.tolist()
        for tmp_x in x.split(a):
            tmp_x = tmp_x.unsqueeze(1) - tmp_x

            cart = tmp_x[:, :, -3:]

            rho = torch.norm(cart, p=2, dim=-1).unsqueeze(2)
            rho_mask = rho.squeeze() != 0
            cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
            tmp_x = torch.cat([cart, rho, tmp_x[:, :, :-3]], dim=2)

            tmp.append(
                torch.cat([
                    tmp_x.mean(1),
                    tmp_x.std(1),
                    tmp_x.min(1)[0],
                    tmp_x.max(1)[0]
                ],
                          dim=1))
        return torch.cat(tmp, 0)
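    # x_feature_constructor summarises, for every node, its pairwise
    # differences to all nodes in the same graph: a unit direction and
    # distance for the last three (position) columns plus raw differences for
    # the remaining columns, reduced to per-node mean/std/min/max. This yields
    # 4 * (F + 1) features per node at the cost of a dense [N, N, F + 1]
    # buffer per graph, so memory grows quadratically with graph size.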

    @torch.jit.script
    def time_edge_indeces(t, batch: torch.Tensor):
        time_sort = torch.argsort(t)
        graph_ids, graph_node_counts = torch.unique(batch, return_counts=True)
        batch_time_sort = torch.cat(
            [time_sort[batch[time_sort] == i] for i in graph_ids])
        time_edge_index = torch.cat([
            batch_time_sort[:-1].view(1, -1), batch_time_sort[1:].view(1, -1)
        ],
                                    dim=0)

        tmp_ind = (torch.cumsum(graph_node_counts, 0) - 1)[:-1]
        li: List[int] = (tmp_ind + 1).tolist()
        time_edge_index[1, tmp_ind] = time_edge_index[0, [0] + li[:-1]]
        time_edge_index = torch.cat([
            time_edge_index,
            torch.cat([
                time_edge_index[1, -1].view(1, 1),
                time_edge_index[0, (tmp_ind + 1)[-1]].view(1, 1)
            ])
        ],
                                    dim=1)

        time_edge_index = time_edge_index[:, torch.argsort(time_edge_index[1])]
        return time_edge_index
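    # time_edge_indeces chains each graph's nodes in ascending time order and
    # wraps the last edge back to the graph's first node, so every node has
    # exactly one incoming "previous hit in time" edge; the final argsort
    # orders the edges by target node.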

    N_edge_feats = args['N_edge_feats']  #6
    N_dom_feats = args['N_dom_feats']  #6
    N_scatter_feats = 4
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  #10
    N_hcs = args['N_hcs']  #32
    #     p_dropout = args['dropout']

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name.endswith('NLLH')

    class Net(customModule):
        def __init__(self):
            super().__init__(crit, y_post_processor, output_post_processor,
                             cal_acc, likelihood_fitting, args)
            print(
                "This model assumes Charge is at index 0 and position is the last three"
            )

            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            #             self.dropout = torch.nn.Dropout(p=p_dropout)

            class MLP(torch.nn.Module):
                def __init__(self, hcs_list, act=self.act, no_final_act=False):
                    super(MLP, self).__init__()
                    mlp = []
                    for i in range(1, len(hcs_list)):
                        mlp.append(
                            torch.nn.Linear(hcs_list[i - 1], hcs_list[i]))
                        mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                        mlp.append(act)
                    if not no_final_act:
                        self.mlp = torch.nn.Sequential(*mlp)
                    else:
                        self.mlp = torch.nn.Sequential(*mlp[:-1])

                def forward(self, x):
                    return self.mlp(x)

#             class AttGNN(torch.nn.Module):
#                 def __init__(self, hcs_in, hcs_out, act = self.act):
#                     super(AttGNN, self).__init__()

# #                     self.att_mlp = torch.nn.Sequential(torch.nn.Linear(2*hcs_in,hcs_in),
# #                                                        act,
# #                                                        torch.nn.Linear(hcs_in,1))
#                     self.att_mlp = torch.nn.Linear(2*hcs_in,1)

#                     self.self_mlp = MLP([hcs_in,hcs_in,hcs_out])
#                     self.msg_mlp = MLP([hcs_in,hcs_in,hcs_out])

#                 def forward(self, x, graph_node_counts):
#                     li : List[int] = graph_node_counts.tolist()
#                     tmp = []
#                     for tmp_x, msg in zip(x.split(li), self.msg_mlp(x).split(li)):
#                         tmp_x = torch.cat([tmp_x.unsqueeze(1) - tmp_x, tmp_x.unsqueeze(1) + tmp_x],dim=2)
#                         tmp.append(torch.matmul(self.att_mlp(tmp_x).squeeze(),msg))
#                     return self.self_mlp(x) + torch.cat(tmp,0)

            class AttGNN(torch.nn.Module):
                def __init__(self, hcs_in, hcs_out, act=self.act):
                    super(AttGNN, self).__init__()

                    self.self_mlp = MLP([hcs_in, hcs_in, hcs_out])
                    self.msg_mlp = torch.nn.Linear(
                        2 * hcs_in, hcs_out)  #MLP([2*hcs_in,hcs_in,hcs_out])
                    self.msg_mlp2 = MLP([4 * hcs_out, 2 * hcs_out, hcs_out])

                def forward(self, x, graph_node_counts):
                    #                     print(graph_node_counts.max().item(), graph_node_counts.float().mean().item(), graph_node_counts.float().std().item(), graph_node_counts.min().item())
                    li: List[int] = graph_node_counts.tolist()
                    tmp = []
                    for tmp_x in x.split(li):
                        tmp_x = torch.cat([
                            tmp_x.unsqueeze(1) - tmp_x,
                            tmp_x.unsqueeze(1) + tmp_x
                        ],
                                          dim=2)
                        tmp_x = self.msg_mlp(tmp_x)
                        tmp.append(
                            self.msg_mlp2(
                                torch.cat([
                                    tmp_x.mean(1),
                                    tmp_x.std(1),
                                    tmp_x.min(1)[0],
                                    tmp_x.max(1)[0]
                                ],
                                          dim=1)))
                    out = self.self_mlp(x) + torch.cat(tmp, 0)
                    del tmp, tmp_x
                    return out
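            # AttGNN forms all-pairs messages per graph (the difference and sum
            # of every node pair), so time and memory scale as O(N^2) in the
            # number of nodes of each graph.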

            N_x_feats = N_dom_feats + 4 * (N_dom_feats + 1) + 4 + 3

            #             self.att = MLP([N_x_feats,self.hcs,1],no_final_act=True)

            self.x_encoder = MLP([N_x_feats, self.hcs, self.hcs])
            self.convs = torch.nn.ModuleList()
            for i in range(N_metalayers):
                self.convs.append(AttGNN(self.hcs, self.hcs))

            self.decoder = MLP([
                N_scatter_feats * self.hcs, N_scatter_feats // 2 * self.hcs,
                self.hcs, N_outputs
            ],
                               no_final_act=True)
#             self.beta = torch.nn.Parameter(torch.ones(1)*10)

        def return_CoC_and_edge_attr(self, x, batch):
            pos = x[:, -3:]
            charge = x[:, 0].view(-1, 1)

            # Define central nodes at Center of Charge:
            CoC = scatter_sum(pos * charge, batch, dim=0) / scatter_sum(
                charge, batch, dim=0)

            # Define edge_attr for those edges:
            cart = pos - CoC[batch]
            rho = torch.norm(cart, p=2, dim=1).view(-1, 1)
            rho_mask = rho.squeeze() != 0
            cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
            CoC_edge_attr = torch.cat([cart.type_as(x), rho.type_as(x)], dim=1)
            return CoC, CoC_edge_attr
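        # The "CoC" node is the charge-weighted mean position (center of
        # charge) of each graph; CoC_edge_attr carries the unit vector and
        # distance from every hit to its graph's CoC.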

        def forward(self, data):
            x, batch = data.x.float(), data.batch

            CoC, CoC_edge_attr = self.return_CoC_and_edge_attr(x, batch)

            graph_ids, graph_node_counts = batch.unique(return_counts=True)

            x = torch.cat([
                x,
                x_feature_constructor(x, graph_node_counts), CoC_edge_attr,
                CoC[batch]
            ],
                          dim=1)

            #             att = scatter_softmax(graph_node_counts[batch].view(-1,1)/100*self.att(x),batch,dim=0)
            #             att_d = scatter_distribution(att,batch,dim=0)
            #             mask = att.squeeze() > att_d[batch,0] + att_d[batch,1]
            # #             mask = att.squeeze() > att_d[batch,2] - 1000*att_d[batch,1]/graph_node_counts[batch]

            #             x = x[mask]
            #             batch = batch[mask]
            #             graph_ids, graph_node_counts = batch.unique(return_counts=True)

            x = self.x_encoder(x)
            for conv in self.convs:
                x = conv(x, graph_node_counts)

            return self.decoder(scatter_distribution(x, batch, dim=0))

    return Net
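
# Usage sketch (hedged): Loss_Functions and customModule come from the
# project's FunctionCollection, and the arg values and model name below are
# illustrative only; the real `args` dict may need further keys consumed by
# those helpers.
#
#   args = {'N_edge_feats': 6, 'N_dom_feats': 6, 'N_outputs': 3,
#           'N_metalayers': 10, 'N_hcs': 32}
#   Net = Load_model('zenith_NLLH', args)  # an 'NLLH' suffix enables likelihood fitting
#   model = Net()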
Code Example #2
def Load_model(name, args):
    from FunctionCollection import Loss_Functions, customModule, edge_feature_constructor
    import pytorch_lightning as pl

    from typing import List, Optional  # List is required by the TorchScript List[int] annotations below

    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max, scatter_softmax
    from torch_scatter.utils import broadcast

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)

        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([mean, var, maximum, minimum], dim=1)

    @torch.jit.script
    def x_feature_constructor(x, graph_node_counts):
        tmp = []
        a: List[int] = graph_node_counts.tolist()
        for tmp_x in x.split(a):
            tmp_x = tmp_x.unsqueeze(1) - tmp_x

            cart = tmp_x[:, :, -3:]

            rho = torch.norm(cart, p=2, dim=-1).unsqueeze(2)
            rho_mask = rho.squeeze() != 0
            if rho_mask.sum() != 0:
                cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
            tmp_x = torch.cat([cart, rho, tmp_x[:, :, :-3]], dim=2)

            tmp.append(
                torch.cat([
                    tmp_x.mean(1),
                    tmp_x.std(1),
                    tmp_x.min(1)[0],
                    tmp_x.max(1)[0]
                ],
                          dim=1))
        return torch.cat(tmp, 0)

    @torch.jit.script
    def time_edge_indeces(t, batch: torch.Tensor):
        time_sort = torch.argsort(t)
        graph_ids, graph_node_counts = torch.unique(batch, return_counts=True)
        batch_time_sort = torch.cat(
            [time_sort[batch[time_sort] == i] for i in graph_ids])
        time_edge_index = torch.cat([
            batch_time_sort[:-1].view(1, -1), batch_time_sort[1:].view(1, -1)
        ],
                                    dim=0)

        tmp_ind = (torch.cumsum(graph_node_counts, 0) - 1)[:-1]
        li: List[int] = (tmp_ind + 1).tolist()
        time_edge_index[1, tmp_ind] = time_edge_index[0, [0] + li[:-1]]
        time_edge_index = torch.cat([
            time_edge_index,
            torch.cat([
                time_edge_index[1, -1].view(1, 1),
                time_edge_index[0, (tmp_ind + 1)[-1]].view(1, 1)
            ])
        ],
                                    dim=1)

        time_edge_index = time_edge_index[:, torch.argsort(time_edge_index[1])]
        return time_edge_index

    N_edge_feats = args['N_edge_feats']  #6
    N_dom_feats = args['N_dom_feats']  #6
    N_scatter_feats = 4
    #     N_targets = args['N_targets']
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  #10
    N_hcs = args['N_hcs']  #32
    #Possibly add (edge/node/global)_layers

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name.endswith('NLLH')

    class Net(customModule):
        def __init__(self):
            super().__init__(crit, y_post_processor, output_post_processor,
                             cal_acc, likelihood_fitting, args)

            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            N_x_feats = 2 * N_dom_feats + 4 * (N_dom_feats + 1) + N_edge_feats + 4

            self.x_encoder = torch.nn.Linear(N_x_feats, self.hcs)

            self.CoC_encoder = torch.nn.Linear(3 + N_scatter_feats * N_x_feats,
                                               self.hcs)

            class MLP(torch.nn.Module):
                def __init__(self, hcs_list, act=self.act, no_final_act=False):
                    super(MLP, self).__init__()
                    mlp = []
                    for i in range(1, len(hcs_list)):
                        mlp.append(
                            torch.nn.Linear(hcs_list[i - 1], hcs_list[i]))
                        mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                        mlp.append(act)

                    if not no_final_act:
                        self.mlp = torch.nn.Sequential(*mlp)
                    else:
                        self.mlp = torch.nn.Sequential(*mlp[:-1])

                def forward(self, x):
                    return self.mlp(x)

#             class Conversation_x(torch.nn.Module):
#                 def __init__(self,hcs,act):
#                     super(Conversation, self).__init__()
#                     self.act = act
#                     self.hcs = hcs
#                     self.GRU = torch.nn.GRUCell(self.hcs*2 + 4 + N_x_feats,self.hcs)

#                 def forward(self, x, edge_index, edge_attr, batch, h):
#                     (frm, to) = edge_index

#                     h = self.act( self.GRU( torch.cat([x[to],x[frm],edge_attr],dim=1), h ) )
#                     x = self.act( self.lin_msg1( torch.cat([x,scatter_distribution(h, to, dim=0)],dim=1) ) )

#                     h = self.act( self.GRU( torch.cat([x[frm],x[to],edge_attr],dim=1), h) )
#                     x = self.act( self.lin_msg2( torch.cat([x,scatter_distribution(h, frm, dim=0)],dim=1) ) )

            self.GRUCells = torch.nn.ModuleList()

            self.lins_CoC_msg = torch.nn.ModuleList()
            self.lins_CoC_self = torch.nn.ModuleList()
            self.CoC_batch_norm = torch.nn.ModuleList()

            self.lins_x_msg = torch.nn.ModuleList()
            self.lins_x_self = torch.nn.ModuleList()
            self.x_batch_norm = torch.nn.ModuleList()

            for i in range(N_metalayers):
                self.GRUCells.append(torch.nn.GRUCell(self.hcs * 2, self.hcs))

                self.lins_CoC_msg.append(
                    torch.nn.Linear(N_scatter_feats * self.hcs, self.hcs))
                self.lins_CoC_self.append(torch.nn.Linear(self.hcs, self.hcs))
                self.CoC_batch_norm.append(torch.nn.BatchNorm1d(self.hcs))

                self.lins_x_msg.append(torch.nn.Linear(self.hcs, self.hcs))
                self.lins_x_self.append(torch.nn.Linear(self.hcs, self.hcs))
                self.x_batch_norm.append(torch.nn.BatchNorm1d(self.hcs))

            self.decoders = torch.nn.ModuleList()
            self.decoder_batch_norms = torch.nn.ModuleList()

            self.decoders.append(
                torch.nn.Linear((1 + N_scatter_feats) * self.hcs, self.hcs))
            self.decoder_batch_norms.append(torch.nn.BatchNorm1d(self.hcs))

            self.decoders.append(torch.nn.Linear(self.hcs, self.hcs))
            self.decoder_batch_norms.append(torch.nn.BatchNorm1d(self.hcs))

            self.att = MLP([N_x_feats, self.hcs, 1], no_final_act=True)

            self.decoder = torch.nn.Linear(self.hcs, N_outputs)

        def forward(self, data):
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            ###############
            x = x.float()
            ###############
            pos = x[:, -3:]

            graph_ids, graph_node_counts = batch.unique(return_counts=True)

            time_edge_index = time_edge_indeces(x[:, 1], batch)

            edge_attr = edge_feature_constructor(x, time_edge_index)

            # Define central nodes at Center of Charge:
            CoC = scatter_sum(pos * x[:, 0].view(-1, 1), batch,
                              dim=0) / scatter_sum(
                                  x[:, 0].view(-1, 1), batch, dim=0)

            # Define edge_attr for those edges:
            cart = pos - CoC[batch]  # pos already holds the last three columns; CoC has exactly three here
            del pos
            rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)
            rho_mask = rho.squeeze() != 0
            cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
            CoC_edge_attr = torch.cat([cart.type_as(x), rho.type_as(x)], dim=1)

            x = torch.cat([
                x,
                x_feature_constructor(x, graph_node_counts), edge_attr,
                x[time_edge_index[0]], CoC_edge_attr
            ],
                          dim=1)

            CoC = torch.cat([CoC, scatter_distribution(x, batch, dim=0)],
                            dim=1)

            att = scatter_softmax(graph_node_counts[batch].view(-1, 1) / 100 *
                                  self.att(x),
                                  batch,
                                  dim=0)
            att_d = scatter_distribution(att, batch, dim=0)
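            # att_d columns are [mean, var, max, min] per graph: keep only the
            # nodes whose attention weight exceeds their graph's mean + variance.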
            mask = att.squeeze() > att_d[batch, 0] + att_d[batch, 1]

            x = x[mask]
            batch = batch[mask]

            x = self.act(self.x_encoder(x))
            CoC = self.act(self.CoC_encoder(CoC))

            h = torch.zeros((x.shape[0], self.hcs)).type_as(x)

            for i in range(N_metalayers):
                h = self.act(self.GRUCells[i](torch.cat([CoC[batch], x],
                                                        dim=1), h))

                CoC = self.act(
                    self.CoC_batch_norm[i](self.lins_CoC_msg[i](
                        scatter_distribution(h, batch, dim=0)) +
                                           self.lins_CoC_self[i](CoC)))

                h = self.act(self.GRUCells[i](torch.cat([CoC[batch], x],
                                                        dim=1), h))

                x = self.act(self.x_batch_norm[i](self.lins_x_msg[i](h) +
                                                  self.lins_x_self[i](x)))

            CoC = torch.cat([CoC, scatter_distribution(x, batch, dim=0)],
                            dim=1)

            for batch_norm, lin in zip(self.decoder_batch_norms,
                                       self.decoders):
                CoC = self.act(batch_norm(lin(CoC)))

            CoC = self.decoder(CoC)
            return CoC

    return Net
Code Example #3
def Load_model(name, args):
    from FunctionCollection import Loss_Functions, customModule
    import pytorch_lightning as pl

    from typing import List, Optional  # List is used by the List[int] annotations below

    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        summ = tmp.clone()
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)

        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([summ, mean, var, maximum, minimum], dim=1)
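    # Unlike the variant in the other examples, this scatter_distribution also
    # returns the per-group sum, giving 5 statistics per feature (hence
    # N_scatter_feats = 5 below).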

    N_edge_feats = args['N_edge_feats']  #6
    N_dom_feats = args['N_dom_feats']  #6
    N_scatter_feats = 5
    #     N_targets = args['N_targets']
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  #10
    N_hcs = args['N_hcs']  #32

    print(
        "if features = x, Charge should be at x[:,-5], time at x[:,-4] and pos at x[:,-3:]"
    )
    assert N_dom_feats == len(args['features'].split(', '))

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name.endswith('NLLH')

    class Net(customModule):
        def __init__(self):
            super().__init__(crit, y_post_processor, output_post_processor,
                             cal_acc, likelihood_fitting, args)

            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            class MLP(torch.nn.Module):
                def __init__(self, hcs_list, act=self.act, clean_out=False):
                    super(MLP, self).__init__()
                    mlp = []
                    for i in range(1, len(hcs_list)):
                        mlp.append(
                            torch.nn.Linear(hcs_list[i - 1], hcs_list[i]))
                        mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                        mlp.append(act)

                    if clean_out:
                        self.mlp = torch.nn.Sequential(*mlp[:-2])
                    else:
                        self.mlp = torch.nn.Sequential(*mlp)

                def forward(self, x):
                    return self.mlp(x)

            class FullConv(torch.nn.Module):
                def __init__(self, hcs_in, hcs_out, act=self.act):
                    super(FullConv, self).__init__()
                    self.self_mlp = MLP([hcs_in, hcs_out])
                    self.msg_mlp = torch.nn.Sequential(
                        torch.nn.Linear(2 * hcs_in, hcs_out), act)
                    self.msg_mlp2 = MLP([4 * hcs_out, hcs_out])

                def forward(self, x, graph_node_counts):
                    li: List[int] = graph_node_counts.tolist()
                    tmp = []
                    for tmp_x in x.split(li):
                        tmp_x = torch.cat([
                            tmp_x.unsqueeze(1) - tmp_x,
                            tmp_x.unsqueeze(1) + tmp_x
                        ],
                                          dim=2)
                        tmp_x = self.msg_mlp(tmp_x)
                        tmp.append(
                            torch.cat([
                                tmp_x.mean(1),
                                tmp_x.std(1),
                                tmp_x.min(1)[0],
                                tmp_x.max(1)[0]
                            ],
                                      dim=1))
                    return self.self_mlp(x) + self.msg_mlp2(torch.cat(tmp, 0))
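            # FullConv uses the same dense all-pairs message scheme as AttGNN
            # in example #1, except the second MLP is applied once to the
            # concatenated batch instead of inside the per-graph loop.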

            class scatter_norm(torch.nn.Module):
                def __init__(self, hcs):
                    super(scatter_norm, self).__init__()
                    self.batch_norm = torch.nn.BatchNorm1d(N_scatter_feats *
                                                           hcs)

                def forward(self, x, batch):
                    return self.batch_norm(
                        scatter_distribution(x, batch, dim=0))

            N_x_feats = N_dom_feats  # + 4*(N_dom_feats + 1)
            N_CoC_feats = N_scatter_feats * N_x_feats + 3

            self.scatter_norm = scatter_norm(N_x_feats)
            self.x_encoder = MLP([N_x_feats, self.hcs])
            self.CoC_encoder = MLP([N_CoC_feats, self.hcs])

            self.convs = torch.nn.ModuleList()
            self.scatter_norms = torch.nn.ModuleList()
            for _ in range(N_metalayers):
                self.convs.append(FullConv(self.hcs, self.hcs))
                self.scatter_norms.append(scatter_norm(self.hcs))

            self.decoder = MLP([(1 + N_scatter_feats * N_metalayers) *
                                self.hcs, self.hcs, N_outputs],
                               clean_out=True)

        def forward(self, data):
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            ###############
            x = x.float()
            ###############

            CoC = scatter_sum(x[:, -3:] * x[:, -5].view(-1, 1), batch,
                              dim=0) / scatter_sum(
                                  x[:, -5].view(-1, 1), batch, dim=0)
            CoC = torch.cat([CoC, self.scatter_norm(x, batch)], dim=1)

            x = self.x_encoder(x)
            CoC = self.CoC_encoder(CoC)

            graph_ids, graph_node_counts = batch.unique(return_counts=True)

            for conv, scatter_norm in zip(self.convs, self.scatter_norms):
                x = conv(x, graph_node_counts)
                CoC = torch.cat([CoC, scatter_norm(x, batch)], dim=1)

            CoC = self.decoder(CoC)

            return CoC

    return Net
Code Example #4
def Load_model(name, args):
    from FunctionCollection import Loss_Functions, customModule
    import pytorch_lightning as pl

    from typing import Optional

    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)

        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([mean, var, maximum, minimum], dim=1)

    N_edge_feats = args['N_edge_feats']  #6
    N_dom_feats = args['N_dom_feats']  #6
    N_scatter_feats = 4
    #     N_targets = args['N_targets']
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  #10
    N_hcs = args['N_hcs']  #32
    #Possibly add (edge/node/global)_layers

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name.endswith('NLLH')

    class Net(customModule):
        def __init__(self):
            super().__init__(crit, y_post_processor, output_post_processor,
                             cal_acc, likelihood_fitting, args)

            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            N_x_feats = N_dom_feats + N_scatter_feats * N_edge_feats
            N_u_feats = N_scatter_feats * N_x_feats

            self.x_encoder = torch.nn.Linear(N_x_feats, self.hcs)
            self.edge_attr_encoder = torch.nn.Linear(N_edge_feats, self.hcs)
            self.u_encoder = torch.nn.Linear(N_u_feats, self.hcs)

            class EdgeModel(torch.nn.Module):
                def __init__(self, hcs, act):
                    super(EdgeModel, self).__init__()
                    self.hcs = hcs
                    self.act = act
                    self.lins = torch.nn.ModuleList()
                    for i in range(2):
                        self.lins.append(
                            torch.nn.Linear(4 * self.hcs, 4 * self.hcs))
                    self.decoder = torch.nn.Linear(4 * self.hcs, self.hcs)

                def forward(self, src, dest, edge_attr, u, batch):
                    # x: [N, F_x], where N is the number of nodes.
                    # src, dest: [E, F_x], where E is the number of edges.
                    # edge_attr: [E, F_e]
                    # edge_index: [2, E] with max entry N - 1.
                    # u: [B, F_u], where B is the number of graphs.
                    # batch: [N] with max entry B - 1.
                    out = torch.cat([src, dest, edge_attr, u[batch]], dim=1)
                    for lin in self.lins:
                        out = self.act(lin(out))
                    return self.act(self.decoder(out))

            class NodeModel(torch.nn.Module):
                def __init__(self, hcs, act):
                    super(NodeModel, self).__init__()
                    self.hcs = hcs
                    self.act = act
                    self.lins1 = torch.nn.ModuleList()
                    for i in range(2):
                        self.lins1.append(
                            torch.nn.Linear(2 * self.hcs, 2 * self.hcs))
                    self.lins2 = torch.nn.ModuleList()
                    for i in range(2):
                        self.lins2.append(
                            torch.nn.Linear(
                                (2 + 2 * N_scatter_feats) * self.hcs,
                                (2 + 2 * N_scatter_feats) * self.hcs))
                    self.decoder = torch.nn.Linear(
                        (2 + 2 * N_scatter_feats) * self.hcs, self.hcs)

                def forward(self, x, edge_index, edge_attr, u, batch):
                    row, col = edge_index
                    out = torch.cat([x[row], edge_attr], dim=1)
                    for lin in self.lins1:
                        out = self.act(lin(out))
                    out = scatter_distribution(out, col, dim=0)  #8*hcs
                    out = torch.cat([x, out, u[batch]], dim=1)
                    for lin in self.lins2:
                        out = self.act(lin(out))
                    return self.act(self.decoder(out))

            class GlobalModel(torch.nn.Module):
                def __init__(self, hcs, act):
                    super(GlobalModel, self).__init__()
                    self.hcs = hcs
                    self.act = act
                    self.lins1 = torch.nn.ModuleList()
                    for i in range(2):
                        self.lins1.append(
                            torch.nn.Linear((1 + N_scatter_feats) * self.hcs,
                                            (1 + N_scatter_feats) * self.hcs))
                    self.decoder = torch.nn.Linear(
                        (1 + N_scatter_feats) * self.hcs, self.hcs)

                def forward(self, x, edge_index, edge_attr, u, batch):
                    out = torch.cat(
                        [u, scatter_distribution(x, batch, dim=0)],
                        dim=1)  # 5*hcs
                    for lin in self.lins1:
                        out = self.act(lin(out))
                    return self.act(self.decoder(out))

            from torch_geometric.nn import MetaLayer
            self.ops = torch.nn.ModuleList()
            for i in range(N_metalayers):
                self.ops.append(
                    MetaLayer(EdgeModel(self.hcs, self.act),
                              NodeModel(self.hcs, self.act),
                              GlobalModel(self.hcs, self.act)))
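            # torch_geometric.nn.MetaLayer applies its EdgeModel, NodeModel and
            # GlobalModel in that order on every call, threading (x, edge_attr, u).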

            self.decoders = torch.nn.ModuleList()
            self.decoders.append(
                torch.nn.Linear((1 + N_metalayers) * self.hcs +
                                N_scatter_feats * N_x_feats, 10 * self.hcs))
            self.decoders.append(torch.nn.Linear(10 * self.hcs, 8 * self.hcs))
            self.decoders.append(torch.nn.Linear(8 * self.hcs, 6 * self.hcs))
            self.decoders.append(torch.nn.Linear(6 * self.hcs, 4 * self.hcs))
            self.decoders.append(torch.nn.Linear(4 * self.hcs, 2 * self.hcs))

            self.decoder = torch.nn.Linear(2 * self.hcs, N_outputs)

        def forward(self, data):
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch

            x = torch.cat(
                [x, scatter_distribution(edge_attr, edge_index[1], dim=0)],
                dim=1)
            u = scatter_distribution(x, batch, dim=0)

            x0 = x.clone()

            x = self.act(self.x_encoder(x))
            edge_attr = self.act(self.edge_attr_encoder(edge_attr))
            u = self.act(self.u_encoder(u))
            out = u.clone()

            for i, op in enumerate(self.ops):
                x, edge_attr, u = op(x, edge_index, edge_attr, u, batch)
                out = torch.cat([out, u.clone()], dim=1)

            out = torch.cat([out, scatter_distribution(x0, batch, dim=0)],
                            dim=1)
            for lin in self.decoders:
                out = self.act(lin(out))

            out = self.decoder(out)
            return out

    return Net


#     assert 'load' in args, "Specify bool 'load' in args, for loading a checkpoint"
#     if args['load']:
#         return Net
#     else:
#         return Net()

# def Load_model(name, args):
#     from FunctionCollection import Loss_Functions
#     import pytorch_lightning as pl

#     from typing import Optional

#     import torch
#     from torch_scatter import scatter_sum, scatter_min, scatter_max
#     from torch_scatter.utils import broadcast

#     @torch.jit.script
#     def scatter_distribution(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
#                             out: Optional[torch.Tensor] = None,
#                             dim_size: Optional[int] = None,
#                             unbiased: bool = True) -> torch.Tensor:

#         if out is not None:
#             dim_size = out.size(dim)

#         if dim < 0:
#             dim = src.dim() + dim

#         count_dim = dim
#         if index.dim() <= dim:
#             count_dim = index.dim() - 1

#         ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
#         count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

#         index = broadcast(index, src, dim)
#         tmp = scatter_sum(src, index, dim, dim_size=dim_size)
#         count = broadcast(count, tmp, dim).clamp(1)
#         mean = tmp.div(count)

#         var = (src - mean.gather(dim, index))
#         var = var * var
#         var = scatter_sum(var, index, dim, out, dim_size)

#         if unbiased:
#             count = count.sub(1).clamp_(1)
#         var = var.div(count)
#         maximum = scatter_max(src, index, dim, out, dim_size)[0]
#         minimum = scatter_min(src, index, dim, out, dim_size)[0]

#         return torch.cat([mean,var,maximum,minimum],dim=1)

#     N_edge_feats = args['N_edge_feats'] #6
#     N_dom_feats = args['N_dom_feats']#6
#     N_scatter_feats = 4
#     N_targets = args['N_targets']
#     N_metalayers = args['N_metalayers'] #10
#     N_hcs = args['N_hcs'] #32
#     #Possibly add (edge/node/global)_layers

#     crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(name, args)
#     likelihood_fitting = True if name[-4:] == 'NLLH' else False
#     if args['wandb_activated']:
#         import wandb

#     class Net(pl.LightningModule):
#         def __init__(self):
#             super(Net, self).__init__()

#             self.crit = crit
#             self.y_post_processor = y_post_processor
#             self.output_post_processor = output_post_processor
#             self.cal_acc = cal_acc
#             self.likelihood_fitting = likelihood_fitting

#             self.act = torch.nn.SiLU()
#             self.hcs = N_hcs

#             N_x_feats = N_dom_feats + N_scatter_feats*N_edge_feats
#             N_u_feats = N_scatter_feats*N_x_feats

#             self.x_encoder = torch.nn.Linear(N_x_feats,self.hcs)
#             self.edge_attr_encoder = torch.nn.Linear(N_edge_feats,self.hcs)
#             self.u_encoder = torch.nn.Linear(N_u_feats,self.hcs)

#             class EdgeModel(torch.nn.Module):
#                 def __init__(self,hcs,act):
#                     super(EdgeModel, self).__init__()
#                     self.hcs = hcs
#                     self.act = act
#                     self.lins = torch.nn.ModuleList()
#                     for i in range(2):
#                         self.lins.append(torch.nn.Linear(4*self.hcs,4*self.hcs))
#                     self.decoder = torch.nn.Linear(4*self.hcs,self.hcs)

#                 def forward(self, src, dest, edge_attr, u, batch):
#                     # x: [N, F_x], where N is the number of nodes.
#                     # src, dest: [E, F_x], where E is the number of edges.
#                     # edge_attr: [E, F_e]
#                     # edge_index: [2, E] with max entry N - 1.
#                     # u: [B, F_u], where B is the number of graphs.
#                     # batch: [N] with max entry B - 1.
#                     out = torch.cat([src, dest, edge_attr, u[batch]], dim=1)
#                     for lin in self.lins:
#                         out = self.act(lin(out))
#                     return self.act(self.decoder(out))

#             class NodeModel(torch.nn.Module):
#                 def __init__(self,hcs,act):
#                     super(NodeModel, self).__init__()
#                     self.hcs = hcs
#                     self.act = act
#                     self.lins1 = torch.nn.ModuleList()
#                     for i in range(2):
#                         self.lins1.append(torch.nn.Linear(2*self.hcs,2*self.hcs))
#                     self.lins2 = torch.nn.ModuleList()
#                     for i in range(2):
#                         self.lins2.append(torch.nn.Linear((2 + 2*N_scatter_feats)*self.hcs,(2 + 2*N_scatter_feats)*self.hcs))
#                     self.decoder = torch.nn.Linear((2 + 2*N_scatter_feats)*self.hcs,self.hcs)

#                 def forward(self, x, edge_index, edge_attr, u, batch):
#                     row, col = edge_index
#                     out = torch.cat([x[row],edge_attr],dim=1)
#                     for lin in self.lins1:
#                         out = self.act(lin(out))
#                     out = scatter_distribution(out,col,dim=0) #8*hcs
#                     out = torch.cat([x,out,u[batch]],dim=1)
#                     for lin in self.lins2:
#                         out = self.act(lin(out))
#                     return self.act(self.decoder(out))

#             class GlobalModel(torch.nn.Module):
#                 def __init__(self,hcs,act):
#                     super(GlobalModel, self).__init__()
#                     self.hcs = hcs
#                     self.act = act
#                     self.lins1 = torch.nn.ModuleList()
#                     for i in range(2):
#                         self.lins1.append(torch.nn.Linear((1 + N_scatter_feats)*self.hcs,(1 + N_scatter_feats)*self.hcs))
#                     self.decoder = torch.nn.Linear((1 + N_scatter_feats)*self.hcs,self.hcs)

#                 def forward(self, x, edge_index, edge_attr, u, batch):
#                     out = torch.cat([u,scatter_distribution(x,batch,dim=0)],dim=1) # 5*hcs
#                     for lin in self.lins1:
#                         out = self.act(lin(out))
#                     return self.act(self.decoder(out))

#             from torch_geometric.nn import MetaLayer
#             self.ops = torch.nn.ModuleList()
#             for i in range(N_metalayers):
#                 self.ops.append(MetaLayer(EdgeModel(self.hcs,self.act), NodeModel(self.hcs,self.act), GlobalModel(self.hcs,self.act)))

#             self.decoders = torch.nn.ModuleList()
#             self.decoders.append(torch.nn.Linear((1+N_metalayers)*self.hcs + N_scatter_feats*N_x_feats,10*self.hcs))
#             self.decoders.append(torch.nn.Linear(10*self.hcs,8*self.hcs))
#             self.decoders.append(torch.nn.Linear(8*self.hcs,6*self.hcs))
#             self.decoders.append(torch.nn.Linear(6*self.hcs,4*self.hcs))
#             self.decoders.append(torch.nn.Linear(4*self.hcs,2*self.hcs))

#             self.decoder = torch.nn.Linear(2*self.hcs,N_targets)

#         def forward(self,data):
#             x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch

#             x = torch.cat([x,scatter_distribution(edge_attr,edge_index[1],dim=0)],dim=1)
#             u = torch.cat([scatter_distribution(x,batch,dim=0)],dim=1)

#             x0 = x.clone()

#             x = self.act(self.x_encoder(x))
#             edge_attr = self.act(self.edge_attr_encoder(edge_attr))
#             u = self.act(self.u_encoder(u))
#             out = u.clone()

#             for i, op in enumerate(self.ops):
#                 x, edge_attr, u = op(x, edge_index, edge_attr, u, batch)
#                 out = torch.cat([out,u.clone()],dim=1)

#             out = torch.cat([out,scatter_distribution(x0,batch,dim=0)],dim=1)
#             for lin in self.decoders:
#                 out = self.act(lin(out))

#             out = self.decoder(out)
#             return out

#         def configure_optimizers(self):
#             from FunctionCollection import lambda_lr
#             optimizer = torch.optim.Adam(self.parameters(), lr=args['lr'])
#             scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda_lr)
#             return [optimizer], [scheduler]

#         def training_step(self, data, batch_idx):
#             label = self.y_post_processor(data.y)
#             if self.likelihood_fitting:
#                 output, cov = self.output_post_processor( self(data) )
#                 loss = self.crit(output, cov, label)
#             else:
#                 output = self.output_post_processor( self(data) )
#                 loss = self.crit(output, label)

#             acc = self.cal_acc(output, label)
#             if args['wandb_activated']:
#                 wandb.log({"Train Loss": loss.item(),
#                            "Train Acc": acc})
#             return {'loss': loss}

#         def validation_step(self, data, batch_idx):
#             label = self.y_post_processor(data.y)
#             if self.likelihood_fitting:
#                 output, cov = self.output_post_processor( self(data) )
#             else:
#                 output = self.output_post_processor( self(data) )

#             acc = self.cal_acc(output, label)
#             if args['wandb_activated']:
#                 wandb.log({"val_batch_acc": acc})
#             return {'val_batch_acc': torch.tensor(acc,dtype=torch.float)}

#         def validation_epoch_end(self, outputs):
#             avg_acc = torch.stack([x['val_batch_acc'] for x in outputs]).mean()
#             if args['wandb_activated']:
#                 wandb.log({"Val Acc": avg_acc})
#             return {'Val Acc': avg_acc}

#         def test_step(self, data, batch_idx):
#             label = self.y_post_processor(data.y)
#             if self.likelihood_fitting:
#                 output, cov = self.output_post_processor( self(data) )
#             else:
#                 output = self.output_post_processor( self(data) )
#             acc = self.cal_acc(output, label)
#             if args['wandb_activated']:
#                 wandb.log({"Test Acc": acc})
#             return {'test_batch_acc': torch.tensor(acc,dtype=torch.float)}

#         def test_epoch_end(self, outputs):
#             avg_acc = torch.stack([x['test_batch_acc'] for x in outputs]).mean()
#             if args['wandb_activated']:
#                 wandb.log({"Test Acc": avg_acc})
#             return {'Test Acc': avg_acc}

#     return Net()
Code Example #5
def Load_model(name, args):
    from FunctionCollection import Loss_Functions, customModule
    import pytorch_lightning as pl

    from typing import List, Optional  # List is required by the TorchScript List[int] annotations below

    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)

        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([mean, var, maximum, minimum], dim=1)

    @torch.jit.script
    def time_edge_indeces(t, batch: torch.Tensor):
        time_sort = torch.argsort(t)
        graph_ids, graph_node_counts = torch.unique(batch, return_counts=True)
        batch_time_sort = torch.cat(
            [time_sort[batch[time_sort] == i] for i in graph_ids])
        time_edge_index = torch.cat([
            batch_time_sort[:-1].view(1, -1), batch_time_sort[1:].view(1, -1)
        ],
                                    dim=0)

        tmp_ind = (torch.cumsum(graph_node_counts, 0) - 1)[:-1]
        li: List[int] = (tmp_ind + 1).tolist()
        time_edge_index[1, tmp_ind] = time_edge_index[0, [0] + li[:-1]]
        time_edge_index = torch.cat([
            time_edge_index,
            torch.cat([
                time_edge_index[1, -1].view(1, 1),
                time_edge_index[0, (tmp_ind + 1)[-1]].view(1, 1)
            ])
        ],
                                    dim=1)

        time_edge_index = time_edge_index[:, torch.argsort(time_edge_index[1])]
        return time_edge_index

    N_edge_feats = args['N_edge_feats']  #6
    N_dom_feats = args['N_dom_feats']  #6
    N_scatter_feats = 4
    #     N_targets = args['N_targets']
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  #10
    N_hcs = args['N_hcs']  #32

    print(
        "if features = x, Charge should be at x[:,-5], time at x[:,-4] and pos at x[:,-3:]"
    )
    assert N_dom_feats == len(args['features'].split(', '))

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = name.endswith('NLLH')

    class Net(customModule):
        def __init__(self):
            super().__init__(crit, y_post_processor, output_post_processor,
                             cal_acc, likelihood_fitting, args)

            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            def gaussian(alpha):
                phi = torch.exp(-1 * alpha.pow(2))
                return phi

            class RBF(torch.nn.Module):
                def __init__(self,
                             in_features,
                             out_features,
                             basis_func=gaussian):
                    super(RBF, self).__init__()
                    self.in_features = in_features
                    self.out_features = out_features
                    self.centres = torch.nn.Parameter(
                        torch.Tensor(out_features, in_features))
                    #                     self.log_sigmas = torch.nn.Parameter(torch.Tensor(out_features))
                    self.sigmas = torch.nn.Parameter(
                        torch.Tensor(out_features))
                    self.basis_func = basis_func
                    self.reset_parameters()

                def reset_parameters(self):
                    torch.nn.init.normal_(self.centres, 0, 1)
                    #                     torch.nn.init.constant_(self.log_sigmas, 2.3)
                    torch.nn.init.constant_(self.sigmas, 10)

                def forward(self, x):
                    size = (x.size(0), self.out_features, self.in_features)
                    x = x.unsqueeze(1).expand(size)
                    c = self.centres.unsqueeze(0).expand(size)
                    #                     distances = (x - c).pow(2).sum(-1).pow(0.5) / torch.exp(self.log_sigmas).unsqueeze(0)
                    distances = (x -
                                 c).pow(2).sum(-1) / self.sigmas.unsqueeze(0)
                    #                     print(distances.mean(), distances.std(), distances.min(), distances.max())
                    #                     return self.basis_func(distances)
                    return self.basis_func(distances)
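            # With the default gaussian() basis, each output feature evaluates
            # exp(-(||x - c||^2 / sigma)^2) for a learned centre c and scale
            # sigma; note that `distances` holds squared norms, not norms.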

            class MLP(torch.nn.Module):
                def __init__(self, hcs_list, act=self.act, clean_out=False):
                    super(MLP, self).__init__()
                    mlp = []
                    for i in range(1, len(hcs_list)):
                        mlp.append(
                            torch.nn.Linear(hcs_list[i - 1], hcs_list[i]))
                        mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                        mlp.append(act)

                    if clean_out:
                        self.mlp = torch.nn.Sequential(*mlp[:-2])
                    else:
                        self.mlp = torch.nn.Sequential(*mlp)

                def forward(self, x):
                    return self.mlp(x)

            class RBF_scatter(torch.nn.Module):
                def __init__(self, in_features, out_features):
                    super(RBF_scatter, self).__init__()
                    self.RBF_batch_norm = torch.nn.BatchNorm1d(in_features)
                    self.RBF = RBF(in_features, out_features)
                    self.sum_batch_norm = torch.nn.BatchNorm1d(out_features)

                def forward(self, x, batch):
                    return self.sum_batch_norm(
                        scatter_sum(self.RBF(self.RBF_batch_norm(x)),
                                    batch,
                                    dim=0))

            class GRUConv(torch.nn.Module):
                def __init__(self, hcs=self.hcs, act=self.act):
                    super(GRUConv, self).__init__()
                    self.act = act
                    self.hcs = hcs
                    self.GRU = torch.nn.GRUCell(self.hcs * 2, self.hcs)

                    self.lin_CoC_msg = MLP(
                        [N_scatter_feats * self.hcs, self.hcs], clean_out=True)
                    self.lin_CoC_self = MLP([self.hcs, self.hcs],
                                            clean_out=True)

                    self.RBF_scatter = RBF_scatter(self.hcs, 4 * self.hcs)

                    self.CoC_batch_norm = torch.nn.BatchNorm1d(self.hcs)

                    self.lin_x_msg = MLP([4 * self.hcs, self.hcs],
                                         clean_out=True)
                    self.lin_x_self = MLP([self.hcs, self.hcs], clean_out=True)

                    self.x_batch_norm = torch.nn.BatchNorm1d(self.hcs)

                def forward(self, x, CoC, h, batch):
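                    # Two-phase exchange: nodes update the shared GRU state and
                    # aggregate (via RBF features) into the CoC node, then the
                    # refreshed CoC is broadcast back to update the node states.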
                    h = self.GRU(torch.cat([CoC[batch], x], dim=1), h)

                    msg = self.lin_CoC_msg(self.RBF_scatter(h, batch))
                    CoC = self.lin_CoC_self(CoC)

                    CoC = self.act(self.CoC_batch_norm(msg + CoC))

                    h = self.GRU(torch.cat([x, CoC[batch]], dim=1), h)

                    msg = self.lin_x_msg(self.RBF_scatter.RBF(h))
                    x = self.lin_x_self(x)

                    x = self.act(self.x_batch_norm(msg + x))
                    return x, CoC, h

            N_x_feats = (2 * (N_dom_feats + 4 * (N_dom_feats + 1))
                         + N_dom_feats + 4 * (N_dom_feats + 1) + 1 + 4 + 3)
            self.N_x_to_CoC_feats = (N_dom_feats + 4 * (N_dom_feats + 1) + 4
                                     + N_dom_feats + 4 * (N_dom_feats + 1) + 1)
            N_CoC_feats = 3 + N_scatter_feats * self.N_x_to_CoC_feats

            self.x_encoder = MLP([N_x_feats, 2 * self.hcs, self.hcs])
            #             self.att = MLP([N_x_feats,1])
            self.CoC_encoder = MLP([N_CoC_feats, 2 * self.hcs, self.hcs])

            self.convs = torch.nn.ModuleList()
            for i in range(N_metalayers):
                self.convs.append(GRUConv())

#             self.decoderRBF = RBF(self.hcs,4*self.hcs)
#             self.decoder = MLP([self.hcs,self.hcs,self.hcs,N_outputs],clean_out=True)
            self.decoder_RBF_scatter = RBF_scatter(self.hcs, 4 * self.hcs)
            self.decoder = MLP(
                [(1 + N_scatter_feats) * self.hcs, 3 * self.hcs, N_outputs],
                clean_out=True)
#             self.decoder = MLP([(1+N_scatter_feats)*self.hcs,self.hcs,self.hcs])

#             self.decoders = torch.nn.ModuleList()
#             for _ in range(N_outputs):
#                 self.decoders.append(torch.nn.Linear(self.hcs,1))

        def forward(self, data):
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            ###############
            x = x.float()
            ###############

            time_edge_index = time_edge_indeces(x[:, -4], batch)

            time_edge_attr = self.edge_feature_constructor(x, time_edge_index)

            CoC, CoC_edge_attr = self.return_CoC_and_edge_attr(x, batch)

            x = torch.cat([
                x[time_edge_index[0]], CoC[batch], CoC_edge_attr,
                time_edge_attr, x
            ],
                          dim=1)

            CoC = torch.cat([
                CoC,
                scatter_distribution(
                    x[:, -self.N_x_to_CoC_feats:], batch, dim=0)
            ],
                            dim=1)

            x = self.x_encoder(x)
            CoC = self.CoC_encoder(CoC)

            h = torch.zeros((x.shape[0], self.hcs), device=self.device)

            for i in range(N_metalayers):
                x, CoC, h = self.convs[i](x, CoC, h, batch)

            CoC = torch.cat([CoC, self.decoder_RBF_scatter(x, batch)], dim=1)
            #             CoC = torch.cat([CoC,scatter_distribution(x, batch, dim=0)],dim=1)

            CoC = self.decoder(CoC)

            #             out = []
            #             for mlp in self.decoders:
            #                 out.append(mlp(CoC))
            #             CoC = torch.cat(out,dim=1)

            return CoC

        def return_CoC_and_edge_attr(self, x, batch):
            pos = x[:, -3:]
            charge = x[:, -5].view(-1, 1)

            # Define central nodes at Center of Charge:
            CoC = scatter_sum(pos * charge, batch, dim=0) / scatter_sum(
                charge, batch, dim=0)

            # Define edge_attr for those edges:
            cart = pos - CoC[batch]
            rho = torch.norm(cart, p=2, dim=1).view(-1, 1)
            rho_mask = rho.squeeze() != 0
            cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
            CoC_edge_attr = torch.cat([cart.type_as(x), rho.type_as(x)], dim=1)
            return CoC, CoC_edge_attr

        def edge_feature_constructor(self, x, edge_index):
            (frm, to) = edge_index
            pos = x[:, -3:]
            cart = pos[frm] - pos[to]

            rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)
            rho_mask = rho.squeeze() != 0
            cart[rho_mask] = cart[rho_mask] / rho[rho_mask]

            diff = x[to, :-3] - x[frm, :-3]

            return torch.cat([cart.type_as(pos), rho, diff], dim=1)

    return Net
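Below is a minimal standalone sketch (not part of the original listing; toy shapes and values are assumptions) of the center-of-charge reduction that return_CoC_and_edge_attr performs above: a charge-weighted mean position per graph via scatter_sum, followed by the same zero-guarded unit-vector construction.

import torch
from torch_scatter import scatter_sum

pos = torch.randn(6, 3)                   # node positions, x[:, -3:]
charge = torch.rand(6, 1)                 # node charge, x[:, -5]
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs, three nodes each

# charge-weighted mean position per graph
CoC = scatter_sum(pos * charge, batch, dim=0) / scatter_sum(charge, batch, dim=0)

# unit direction and distance from every node to its graph's CoC,
# guarding the rho == 0 case exactly as the model does
cart = pos - CoC[batch]
rho = torch.norm(cart, p=2, dim=1).view(-1, 1)
rho_mask = rho.squeeze() != 0
cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
print(CoC.shape, cart.shape, rho.shape)  # (2, 3), (6, 3), (6, 1)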
Code example #6
def Load_model(name, args):
    from FunctionCollection import Loss_Functions, customModule, Print
    import pytorch_lightning as pl

    from typing import Optional

    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max, scatter_softmax
    from torch_scatter.utils import broadcast

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        summ = tmp.clone()
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)

        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([summ, mean, var, maximum, minimum], dim=1)

    N_edge_feats = args['N_edge_feats']  #6
    N_dom_feats = args['N_dom_feats']  #6
    N_scatter_feats = 5
    #     N_targets = args['N_targets']
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  #10
    N_hcs = args['N_hcs']  #32

    print(
        "if features = x, Charge should be at x[:,-5], time at x[:,-4] and pos at x[:,-3:]"
    )
    assert N_dom_feats == len(args['features'].split(', '))

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = True if name[-4:] == 'NLLH' else False

    class Net(customModule):
        def __init__(self):
            super(Net,
                  self).__init__(crit, y_post_processor, output_post_processor,
                                 cal_acc, likelihood_fitting, args)

            self.act = torch.nn.GELU()
            self.hcs = N_hcs

            class MLP(torch.nn.Module):
                def __init__(self, hcs_list, act=self.act, clean_out=False):
                    super(MLP, self).__init__()
                    mlp = []
                    for i in range(1, len(hcs_list)):
                        mlp.append(
                            torch.nn.Linear(hcs_list[i - 1], hcs_list[i]))
                        mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                        mlp.append(act)

                    if clean_out:
                        self.mlp = torch.nn.Sequential(*mlp[:-2])
                    else:
                        self.mlp = torch.nn.Sequential(*mlp)

                def forward(self, x):
                    return self.mlp(x)

            class GRUConv(torch.nn.Module):
                def __init__(self, hcs=self.hcs, act=self.act):
                    super(GRUConv, self).__init__()
                    self.act = act
                    self.hcs = hcs
                    self.GRU = torch.nn.GRUCell(self.hcs * 2, self.hcs)

                    # self.scatter_norm = scatter_norm(self.hcs)
                    # self.lin_CoC_msg = MLP([N_scatter_feats*self.hcs, self.hcs],clean_out = True)
                    self.scatter_norm = scatter_att(self.hcs)  #version a
                    self.lin_CoC_msg = MLP([self.hcs, self.hcs],
                                           clean_out=True)  #version a

                    self.lin_CoC_self = MLP([self.hcs, self.hcs],
                                            clean_out=True)

                    self.CoC_batch_norm = torch.nn.BatchNorm1d(self.hcs)

                    self.lin_x_msg = MLP([self.hcs, self.hcs], clean_out=True)
                    self.lin_x_self = MLP([self.hcs, self.hcs], clean_out=True)

                    self.x_batch_norm = torch.nn.BatchNorm1d(self.hcs)

                def forward(self, x, CoC, h, batch):
                    h = self.GRU(torch.cat([CoC[batch], x], dim=1), h)

                    # msg = self.lin_CoC_msg( self.scatter_norm(h, batch) )
                    msg = self.lin_CoC_msg(self.scatter_norm(h, batch,
                                                             CoC))  #version a
                    CoC = self.lin_CoC_self(CoC)

                    CoC = self.act(self.CoC_batch_norm(msg + CoC))

                    h = self.GRU(torch.cat([x, CoC[batch]], dim=1), h)

                    msg = self.lin_x_msg(h)
                    x = self.lin_x_self(x)

                    x = self.act(self.x_batch_norm(msg + x))
                    return x, CoC, h

            class scatter_norm(torch.nn.Module):  #Original
                def __init__(self, hcs):
                    super(scatter_norm, self).__init__()
                    self.batch_norm = torch.nn.BatchNorm1d(N_scatter_feats *
                                                           hcs)

                def forward(self, x, batch):
                    return self.batch_norm(
                        scatter_distribution(x, batch, dim=0))

            class scatter_att(torch.nn.Module):  #Version a
                def __init__(self, hcs):
                    super(scatter_att, self).__init__()
                    self.att_lin = torch.nn.Linear(int(2 * hcs), 1)

                def forward(self, x, batch, CoC):
                    att = scatter_softmax(self.att_lin(
                        torch.cat([x, CoC[batch]], dim=1)),
                                          batch,
                                          dim=0)
                    return scatter_sum(att * x, batch, dim=0)

            N_x_feats = N_dom_feats  # + 4*(N_dom_feats + 1)
            N_CoC_feats = N_scatter_feats * N_x_feats + 3

            self.scatter_norm = scatter_norm(N_x_feats)
            self.x_encoder = MLP([N_x_feats, self.hcs])
            self.CoC_encoder = MLP([N_CoC_feats, self.hcs])

            self.convs = torch.nn.ModuleList()
            for _ in range(N_metalayers):
                self.convs.append(GRUConv())

            # self.scatter_norm2 = scatter_norm(self.hcs)
            self.scatter_norm2 = scatter_att(self.hcs * (2 + N_metalayers) //
                                             2)  #version a

            # self.decoder = MLP([(1+N_scatter_feats)*self.hcs,3*self.hcs,self.hcs,N_outputs],clean_out=True)
            self.decoder = MLP([(2 + N_metalayers) * self.hcs, 3 * self.hcs,
                                self.hcs, N_outputs],
                               clean_out=True)  #version a

#         def forward(self,data):
#             x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch

        def forward(self, *data):
            # *data may carry raw tensors (e.g. from an exporter); wrap them
            # into a torch_geometric Batch before unpacking.
            if isinstance(data, tuple):
                from torch_geometric.data import Data, Batch
                datalist = []
                for x in data:
                    if x.dim() > 2:
                        for tmp_x in x:
                            datalist.append(Data(x=tmp_x.squeeze()))
                    else:
                        datalist.append(Data(x=x.squeeze()))
                data = Batch.from_data_list(datalist)
            try:
                x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            except AttributeError:
                # a bare tensor: treat it as a single graph
                x = data
                batch = torch.zeros(x.shape[0],
                                    device=self.device,
                                    dtype=torch.int64)
            ###############
            x = x.float()
            ###############

            CoC = scatter_sum(x[:, -3:] * x[:, -5].unsqueeze(-1), batch,
                              dim=0) / scatter_sum(
                                  x[:, -5].unsqueeze(-1), batch, dim=0)
            CoC[CoC.isnan()] = 0
            CoC = torch.cat([CoC, self.scatter_norm(x, batch)], dim=1)

            x = self.x_encoder(x)
            CoC = self.CoC_encoder(CoC)

            h = torch.zeros((x.shape[0], self.hcs), device=self.device)
            out = CoC.clone()  #version a
            for conv in self.convs:
                x, CoC, h = conv(x, CoC, h, batch)
                out = torch.cat([out, CoC.clone()], dim=1)  #version a

            # CoC = torch.cat([CoC,self.scatter_norm2(x, batch, CoC)],dim=1)
            CoC = torch.cat([out, self.scatter_norm2(x, batch, out)],
                            dim=1)  #version a

            CoC = self.decoder(CoC)
            return CoC

    return Net
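The "version a" scatter_att pooling above replaces the fixed scatter statistics with learned, per-graph attention. A standalone sketch of just that mechanism (toy sizes are assumptions):

import torch
from torch_scatter import scatter_softmax, scatter_sum

hcs = 4
x = torch.randn(5, hcs)                # node embeddings
CoC = torch.randn(2, hcs)              # one global embedding per graph
batch = torch.tensor([0, 0, 1, 1, 1])

att_lin = torch.nn.Linear(2 * hcs, 1)
scores = att_lin(torch.cat([x, CoC[batch]], dim=1))  # (5, 1) raw scores
att = scatter_softmax(scores, batch, dim=0)          # normalized within each graph
pooled = scatter_sum(att * x, batch, dim=0)          # (2, hcs) weighted sum
print(att.sum(), pooled.shape)  # weights sum to 1 per graph -> total 2.0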
Code example #7
def Load_model(name, args):
    from FunctionCollection import Loss_Functions, customModule, edge_feature_constructor
    import pytorch_lightning as pl

    from typing import Optional

    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast
    #     from torch_geometric.nn import GATConv, FeaStConv

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)

        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([mean, var, maximum, minimum], dim=1)

    N_edge_feats = args['N_edge_feats']  #6
    N_dom_feats = args['N_dom_feats']  #6
    N_scatter_feats = 4
    #     N_targets = args['N_targets']
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  #10
    N_hcs = args['N_hcs']  #32
    #Possibly add (edge/node/global)_layers

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = True if name[-4:] == 'NLLH' else False

    class Net(customModule):
        def __init__(self):
            super(Net,
                  self).__init__(crit, y_post_processor, output_post_processor,
                                 cal_acc, likelihood_fitting, args)

            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            N_x_feats = N_dom_feats + N_scatter_feats * N_edge_feats
            #             N_u_feats = N_scatter_feats*N_x_feats

            self.x_encoder = torch.nn.Linear(N_x_feats, self.hcs)
            self.edge_attr_encoder = torch.nn.Linear(N_edge_feats, self.hcs)
            self.CoC_encoder = torch.nn.Linear(3 + N_scatter_feats * N_x_feats,
                                               self.hcs)

            class EdgeModel(torch.nn.Module):
                def __init__(self, hcs, act):
                    super(EdgeModel, self).__init__()
                    self.hcs = hcs
                    self.act = act
                    self.lins = torch.nn.ModuleList()
                    for i in range(2):
                        self.lins.append(
                            torch.nn.Linear(4 * self.hcs, 4 * self.hcs))
                    self.decoder = torch.nn.Linear(4 * self.hcs, self.hcs)

                def forward(self, src, dest, edge_attr, CoC, batch):
                    # x: [N, F_x], where N is the number of nodes.
                    # src, dest: [E, F_x], where E is the number of edges.
                    # edge_attr: [E, F_e]
                    # edge_index: [2, E] with max entry N - 1.
                    # u: [B, F_u], where B is the number of graphs.
                    # batch: [N] with max entry B - 1.
                    out = torch.cat([src, dest, edge_attr, CoC[batch]], dim=1)
                    for lin in self.lins:
                        out = self.act(lin(out))
                    return self.act(self.decoder(out))

            class NodeModel(torch.nn.Module):
                def __init__(self, hcs, act):
                    super(NodeModel, self).__init__()
                    self.hcs = hcs
                    self.act = act
                    self.lins1 = torch.nn.ModuleList()
                    for i in range(2):
                        self.lins1.append(
                            torch.nn.Linear(2 * self.hcs, 2 * self.hcs))
                    self.lins2 = torch.nn.ModuleList()
                    for i in range(2):
                        self.lins2.append(
                            torch.nn.Linear(
                                (2 + 2 * N_scatter_feats) * self.hcs,
                                (2 + 2 * N_scatter_feats) * self.hcs))
                    self.decoder = torch.nn.Linear(
                        (2 + 2 * N_scatter_feats) * self.hcs, self.hcs)

                def forward(self, x, edge_index, edge_attr, CoC, batch):
                    row, col = edge_index
                    out = torch.cat([x[row], edge_attr], dim=1)
                    for lin in self.lins1:
                        out = self.act(lin(out))
                    out = scatter_distribution(out, col, dim=0)  #8*hcs
                    out = torch.cat([x, out, CoC[batch]], dim=1)
                    for lin in self.lins2:
                        out = self.act(lin(out))
                    return self.act(self.decoder(out))

            class GlobalModel(torch.nn.Module):
                def __init__(self, hcs, act):
                    super(GlobalModel, self).__init__()

                def forward(self, x, edge_index, edge_attr, CoC, batch):
                    return CoC

            from torch_geometric.nn import MetaLayer
            self.ops = torch.nn.ModuleList()
            self.GRUCells = torch.nn.ModuleList()
            self.lins1 = torch.nn.ModuleList()
            self.lins2 = torch.nn.ModuleList()
            self.lins3 = torch.nn.ModuleList()
            for i in range(N_metalayers):
                self.ops.append(
                    MetaLayer(EdgeModel(self.hcs, self.act),
                              NodeModel(self.hcs, self.act),
                              GlobalModel(self.hcs, self.act)))
                self.GRUCells.append(
                    torch.nn.GRUCell(self.hcs * 2 + 4 + N_x_feats, self.hcs))
                self.lins1.append(
                    torch.nn.Linear((1 + N_scatter_feats) * self.hcs,
                                    (1 + N_scatter_feats) * self.hcs))
                self.lins2.append(
                    torch.nn.Linear((1 + N_scatter_feats) * self.hcs,
                                    self.hcs))
                self.lins3.append(torch.nn.Linear(2 * self.hcs, self.hcs))

            self.decoders = torch.nn.ModuleList()
            self.decoders.append(torch.nn.Linear(self.hcs, self.hcs))
            self.decoders.append(torch.nn.Linear(self.hcs, self.hcs))

            self.decoder = torch.nn.Linear(self.hcs, N_outputs)

        def forward(self, data):
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            pos = x[:, -3:]

            x = torch.cat(
                [x, scatter_distribution(edge_attr, edge_index[1], dim=0)],
                dim=1)

            CoC = scatter_sum(pos * x[:, 0].view(-1, 1), batch,
                              dim=0) / scatter_sum(
                                  x[:, 0].view(-1, 1), batch, dim=0)
            CoC = torch.cat([CoC, scatter_distribution(x, batch, dim=0)],
                            dim=1)

            CoC_edge_index = torch.cat([
                torch.arange(x.shape[0]).view(1, -1).type_as(batch),
                batch.view(1, -1)
            ],
                                       dim=0)

            cart = pos[CoC_edge_index[0], -3:] - CoC[CoC_edge_index[1], :3]
            del pos

            rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)
            rho_mask = rho.squeeze() != 0
            cart[rho_mask] = cart[rho_mask] / rho[rho_mask]

            CoC_edge_attr = torch.cat(
                [cart.type_as(x),
                 rho.type_as(x), x[CoC_edge_index[0]]], dim=1)

            x = self.act(self.x_encoder(x))
            edge_attr = self.act(self.edge_attr_encoder(edge_attr))
            CoC = self.act(self.CoC_encoder(CoC))

            #             u = torch.zeros( (batch.max() + 1, self.hcs) ).type_as(x)
            h = torch.zeros((x.shape[0], self.hcs)).type_as(x)

            for i, op in enumerate(self.ops):
                x, edge_attr, CoC = op(x, edge_index, edge_attr, CoC, batch)
                h = self.act(self.GRUCells[i](torch.cat([
                    CoC[CoC_edge_index[1]], x[CoC_edge_index[0]], CoC_edge_attr
                ],
                                                        dim=1), h))
                CoC = self.act(self.lins1[i](torch.cat(
                    [CoC, scatter_distribution(h, batch, dim=0)], dim=1)))
                CoC = self.act(self.lins2[i](CoC))
                h = self.act(self.GRUCells[i](torch.cat([
                    CoC[CoC_edge_index[1]], x[CoC_edge_index[0]], CoC_edge_attr
                ],
                                                        dim=1), h))
                x = self.act(self.lins3[i](torch.cat([x, h], dim=1)))

            for lin in self.decoders:
                CoC = self.act(lin(CoC))

            CoC = self.decoder(CoC)
            return CoC

    return Net
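MetaLayer threads (x, edge_index, edge_attr, u, batch) through its edge, node, and global sub-models in that order, which is why every model above accepts the full argument tuple. A reduced sketch of that contract with toy sub-models (sizes are assumptions):

import torch
from torch_scatter import scatter_mean
from torch_geometric.nn import MetaLayer

hcs = 8

class TinyEdgeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(3 * hcs, hcs)

    def forward(self, src, dest, edge_attr, u, batch):
        # called once per edge with the two endpoint embeddings
        return self.lin(torch.cat([src, dest, edge_attr], dim=1))

class TinyNodeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(2 * hcs, hcs)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # aggregate the updated edge features onto their target nodes
        agg = scatter_mean(edge_attr, edge_index[1], dim=0, dim_size=x.size(0))
        return self.lin(torch.cat([x, agg], dim=1))

op = MetaLayer(TinyEdgeModel(), TinyNodeModel())  # global model is optional
x = torch.randn(4, hcs)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
edge_attr = torch.randn(3, hcs)
x, edge_attr, u = op(x, edge_index, edge_attr)
print(x.shape, edge_attr.shape, u)  # torch.Size([4, 8]) torch.Size([3, 8]) None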
Code example #8
def Load_model(name, args):
    from FunctionCollection import Loss_Functions, customModule
    import pytorch_lightning as pl

    from typing import Optional

    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        summ = tmp.clone()
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)

        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([summ, mean, var, maximum, minimum], dim=1)

    N_edge_feats = args['N_edge_feats']  #6
    N_dom_feats = args['N_dom_feats']  #6
    N_scatter_feats = 5
    #     N_targets = args['N_targets']
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  #10
    N_hcs = args['N_hcs']  #32

    print(
        "if features = x, Charge should be at x[:,-5], time at x[:,-4] and pos at x[:,-3:]"
    )
    assert N_dom_feats == len(args['features'].split(', '))

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = True if name[-4:] == 'NLLH' else False

    class Net(customModule):
        def __init__(self):
            super(Net,
                  self).__init__(crit, y_post_processor, output_post_processor,
                                 cal_acc, likelihood_fitting, args)

            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            def gaussian(alpha):
                phi = torch.exp(-1 * alpha.pow(2))
                return phi

            class RBF(torch.nn.Module):
                def __init__(self,
                             in_features,
                             out_features,
                             basis_func=gaussian):
                    super(RBF, self).__init__()
                    self.in_features = in_features
                    self.out_features = out_features
                    self.centres = torch.nn.Parameter(
                        torch.Tensor(out_features, in_features))
                    self.log_sigmas = torch.nn.Parameter(
                        torch.Tensor(out_features))
                    #                     self.sigmas = torch.nn.Parameter(torch.Tensor(out_features))
                    self.basis_func = basis_func
                    self.reset_parameters()

                def reset_parameters(self):
                    torch.nn.init.normal_(self.centres, 0, 1)
                    torch.nn.init.constant_(self.log_sigmas, 0)
#                     torch.nn.init.constant_(self.sigmas, 10)

                def forward(self, x):
                    size = (x.size(0), self.out_features, self.in_features)
                    x = x.unsqueeze(1).expand(size)
                    c = self.centres.unsqueeze(0).expand(size)
                    distances = (x - c).pow(2).sum(-1).pow(0.5) / torch.exp(
                        self.log_sigmas).unsqueeze(0)
                    #                     distances = (x - c).pow(2).sum(-1) / self.sigmas.unsqueeze(0)
                    #                     print(distances.mean(), distances.std(), distances.min(), distances.max())
                    return self.basis_func(distances)

            class MLP(torch.nn.Module):
                def __init__(self, hcs_list, act=self.act, clean_out=False):
                    super(MLP, self).__init__()
                    mlp = []
                    for i in range(1, len(hcs_list)):
                        mlp.append(
                            torch.nn.Linear(hcs_list[i - 1], hcs_list[i]))
                        mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                        mlp.append(act)

                    if clean_out:
                        self.mlp = torch.nn.Sequential(*mlp[:-2])
                    else:
                        self.mlp = torch.nn.Sequential(*mlp)

                def forward(self, x):
                    return self.mlp(x)

            class RBF_scatter(torch.nn.Module):
                def __init__(self,
                             in_features,
                             out_features,
                             basis_func=gaussian):
                    super(RBF_scatter, self).__init__()
                    self.RBF_batch_norm = torch.nn.BatchNorm1d(in_features)
                    self.RBF = RBF(in_features, out_features, basis_func)
                    self.sum_batch_norm = torch.nn.BatchNorm1d(out_features *
                                                               N_scatter_feats)

                def forward(self, x, batch):
                    return self.sum_batch_norm(
                        scatter_distribution(self.RBF(self.RBF_batch_norm(x)),
                                             batch,
                                             dim=0))

            N_x_feats = N_dom_feats  # + 4*(N_dom_feats + 1)

            #             self.RBF_scatter = RBF_scatter(N_x_feats,2*self.hcs, basis_func=gaussian)
            self.encoder = MLP([N_dom_feats, 2 * self.hcs])
            self.batch_norm = torch.nn.BatchNorm1d(N_scatter_feats * 2 *
                                                   self.hcs)
            self.decoder = MLP([
                N_scatter_feats * 2 * self.hcs, 5 * self.hcs, self.hcs,
                N_outputs
            ],
                               clean_out=True)

        def forward(self, data):
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            ###############
            x = x.float()
            ###############
            #             x = self.RBF_scatter(x,batch)
            x = self.batch_norm(
                scatter_distribution(self.encoder(x), batch, dim=0))
            x = self.decoder(x)

            return x

    return Net
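Since each variant in this listing is built around scatter_distribution, a quick sanity sketch is useful: on a toy batch, the 5-statistic version above (columns sum | mean | var | max | min, matching N_scatter_feats = 5) should reproduce the per-graph statistics computed by hand. This assumes the kernel is lifted out of Load_model so it can be called directly:

import torch

src = torch.tensor([[1.0], [2.0], [3.0], [10.0]])
index = torch.tensor([0, 0, 0, 1])

out = scatter_distribution(src, index, dim=0)
# graph 0: sum=6, mean=2, unbiased var=1, max=3, min=1
print(out[0])  # tensor([6., 2., 1., 3., 1.])
# graph 1 has a single node: the unbiased count clamps to 1, so var=0
print(out[1])  # tensor([10., 10., 0., 10., 10.])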
Code example #9
def Load_model(name, args):
    from FunctionCollection import Loss_Functions, customModule
    import pytorch_lightning as pl
    
    from typing import Optional

    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast
    import torch_geometric  # needed below for torch_geometric.nn.TransformerConv

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                            out: Optional[torch.Tensor] = None,
                            dim_size: Optional[int] = None,
                            unbiased: bool = True) -> torch.Tensor:

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        summ = tmp.clone()
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)

        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([summ,mean,var,maximum,minimum],dim=1)
    
    N_edge_feats = args['N_edge_feats'] #6
    N_dom_feats = args['N_dom_feats']#6
    N_scatter_feats = 5
#     N_targets = args['N_targets']
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers'] #10
    N_hcs = args['N_hcs'] #32
    
    print("if features = x, Charge should be at x[:,-5], time at x[:,-4] and pos at x[:,-3:]")
    assert N_dom_feats == len(args['features'].split(', '))
    
    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(name, args)
    likelihood_fitting = True if name[-4:] == 'NLLH' else False
    
    class Net(customModule):
        def __init__(self):
            super(Net, self).__init__(crit, y_post_processor, output_post_processor, cal_acc, likelihood_fitting, args)

            self.act = torch.nn.SiLU()
            self.hcs = N_hcs
            
            class MLP(torch.nn.Module):
                def __init__(self, hcs_list, act = self.act, clean_out = False):
                    super(MLP, self).__init__()
                    mlp = []
                    for i in range(1,len(hcs_list)):
                        mlp.append(torch.nn.Linear(hcs_list[i-1], hcs_list[i]))
                        mlp.append(torch.nn.BatchNorm1d(hcs_list[i]))
                        mlp.append(act)
                    
                    if clean_out:
                        self.mlp = torch.nn.Sequential(*mlp[:-2])
                    else:
                        self.mlp = torch.nn.Sequential(*mlp)
                def forward(self, x):
                    return self.mlp(x)
            
#             class GRUConv(torch.nn.Module):
#                 def __init__(self,hcs = self.hcs, act = self.act):
#                     super(GRUConv, self).__init__()
#                     self.act = act
#                     self.hcs = hcs
#                     self.GRU = torch.nn.GRUCell(self.hcs*2,self.hcs)
                    
#                     self.scatter_norm = scatter_norm(self.hcs)
#                     self.lin_CoC_msg = MLP([N_scatter_feats*self.hcs, self.hcs],clean_out = True)
#                     self.lin_CoC_self = MLP([self.hcs, self.hcs],clean_out = True)
                    
#                     self.CoC_batch_norm = torch.nn.BatchNorm1d(self.hcs)
                    
#                     self.lin_x_msg = MLP([self.hcs, self.hcs],clean_out = True)
#                     self.lin_x_self = MLP([self.hcs, self.hcs],clean_out = True)
                    
#                     self.x_batch_norm = torch.nn.BatchNorm1d(self.hcs)

#                 def forward(self, x, CoC, h, batch):
#                     h = self.act( self.GRU( torch.cat([CoC[batch], x], dim=1), h) )
                    
#                     msg = self.lin_CoC_msg( self.scatter_norm(h, batch) )
#                     CoC = self.lin_CoC_self(CoC)
                    
#                     CoC = self.act( self.CoC_batch_norm(msg+CoC) )
                    
#                     h = self.act( self.GRU( torch.cat([x, CoC[batch]], dim=1), h) )
                    
#                     msg = self.lin_x_msg(h)
#                     x = self.lin_x_self(x)
                    
#                     x = self.act( self.x_batch_norm(msg+x) )
#                     return x, CoC, h
            
#             class AttConv(torch.nn.Module):
#                 def __init__(self,in_hcs = [self.hcs, self.hcs], out_hcs = self.hcs, heads = 1):
#                     super(AttConv,self).__init__()
                    
#                     self.heads = heads
#                     self.out_hcs = out_hcs
                    
#                     self.lin_key = torch.nn.Linear(in_hcs[0], heads*out_hcs)
#                     self.lin_query = torch.nn.Linear(in_hcs[1], heads*out_hcs)
#                     self.lin_value = torch.nn.Linear(in_hcs[0], heads*out_hcs)
                    
#                     self.sqrt_d = torch.sqrt(out_hcs)
                    
#                     self.reset_parameters()
                    
#                 def reset_parameters(self):
#                     self.lin_key.reset_parameters()
#                     self.lin_query.reset_parameters()
#                     self.lin_value.reset_parameters()
                
#                 def forward(self, x, CoC, batch):
#                     key = self.lin_key(x).view(-1,self.heads,self.out_hcs)
#                     query = self.lin_query(CoC
            
            class scatter_norm(torch.nn.Module):
                def __init__(self, hcs):
                    super(scatter_norm, self).__init__()
                    self.batch_norm = torch.nn.BatchNorm1d(N_scatter_feats*hcs)
                def forward(self, x, batch):
                    return self.batch_norm(scatter_distribution(x,batch,dim=0))


            N_x_feats = N_dom_feats# + 4*(N_dom_feats + 1)
            N_CoC_feats = N_scatter_feats*N_x_feats + 3

            self.scatter_norm = scatter_norm(N_x_feats)
            self.x_encoder = MLP([N_x_feats,self.hcs])
            self.CoC_encoder = MLP([N_CoC_feats,self.hcs])
            
            self.TConv = torch_geometric.nn.TransformerConv(in_channels = [self.hcs,self.hcs],
                                                            out_channels = self.hcs,
                                                            heads = N_metalayers)
            
            self.decoder = MLP([(N_metalayers)*self.hcs,3*self.hcs,self.hcs,N_outputs],clean_out=True)

        def forward(self,data):
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            ###############
            x = x.float()
            ###############

            CoC = scatter_sum( x[:,-3:]*x[:,-5].view(-1,1), batch, dim=0) / scatter_sum(x[:,-5].view(-1,1), batch, dim=0)
            CoC = torch.cat([CoC,self.scatter_norm(x,batch)],dim=1)

            x = self.x_encoder(x)
            CoC = self.CoC_encoder(CoC)
            
            CoC_x = torch.cat([CoC,x],dim=0)
            
            edge_index = self.return_edge_index(batch)
            
            CoC_x = self.TConv(CoC_x, edge_index)

            CoC = self.decoder(CoC_x[batch.unique()])

            return CoC
        
        def return_edge_index(self,batch):
            offset = batch.max() + 1
            frm = torch.arange(offset, offset + batch.shape[0],dtype=torch.long).view(1,-1)
            to = batch.view(1,-1)
            return torch.cat([frm,to],dim=0).contiguous()
    return Net
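The bipartite trick in return_edge_index above is easy to misread: the B center-of-charge rows sit at indices 0..B-1 of the stacked CoC_x, so each real node (offset by B) sends exactly one edge to its graph's CoC row, and CoC_x[batch.unique()] recovers the B pooled rows afterwards. A toy run (values assumed):

import torch

batch = torch.tensor([0, 0, 1, 1, 1])
offset = int(batch.max()) + 1  # B = 2 graphs stacked in front of the nodes
frm = torch.arange(offset, offset + batch.shape[0], dtype=torch.long).view(1, -1)
to = batch.view(1, -1)
edge_index = torch.cat([frm, to], dim=0).contiguous()
print(edge_index)
# tensor([[2, 3, 4, 5, 6],
#         [0, 0, 1, 1, 1]])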
Code example #10
def Load_model(name, args):
    from FunctionCollection import Loss_Functions, customModule
    import pytorch_lightning as pl

    from torch_geometric_temporal import nn

    from typing import Optional

    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast
    from torch_geometric.nn import GATConv

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)

        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([mean, var, maximum, minimum], dim=1)

    N_edge_feats = args['N_edge_feats']  #6
    N_dom_feats = args['N_dom_feats']  #6
    N_scatter_feats = 4
    #     N_targets = args['N_targets']
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  #10
    N_hcs = args['N_hcs']  #32
    #Possibly add (edge/node/global)_layers

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = True if name[-4:] == 'NLLH' else False

    class Net(customModule):
        def __init__(self):
            super(Net,
                  self).__init__(crit, y_post_processor, output_post_processor,
                                 cal_acc, likelihood_fitting, args)

            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            N_x_feats = N_dom_feats + N_scatter_feats * N_edge_feats

            self.x_encoder = torch.nn.Linear(N_x_feats, self.hcs)
            self.conv = nn.recurrent.temporalgcn.TGCN(self.hcs, self.hcs)
            self.decoder = torch.nn.Linear(self.hcs * 4, N_outputs)

            # concat=False averages the 3 attention heads, keeping the output
            # hcs-wide so it matches the TGCN layers below
            self.GATConv = GATConv(self.hcs,
                                   self.hcs,
                                   3,
                                   concat=False,
                                   add_self_loops=False)

        def forward(self, data):
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch

            x = torch.cat(
                [x, scatter_distribution(edge_attr, edge_index[1], dim=0)],
                dim=1)
            #             x0 = x.clone()

            time_sort = torch.argsort(x[:, 1])
            batch_time_sort = time_sort[torch.argsort(batch[time_sort])]
            time_edge_index = torch.cat([
                batch_time_sort[:-1].view(1, -1), batch_time_sort[1:].view(
                    1, -1)
            ],
                                        dim=0)
            graph_ids, graph_node_counts = batch.unique(return_counts=True)
            tmp_bool = torch.ones(time_edge_index.shape[1], dtype=torch.bool)
            tmp_bool[(torch.cumsum(graph_node_counts, 0) - 1)[:-1]] = False
            time_edge_index = time_edge_index[:, tmp_bool]
            time_edge_index = torch.cat(
                [time_edge_index, time_edge_index.flip(0)], dim=1)

            x = self.act(self.x_encoder(x))

            # single GAT pass; e and w expose the attention edges and weights
            # for inspection if needed
            x, (e, w) = self.GATConv(x,
                                     edge_index,
                                     return_attention_weights=True)

            h = self.act(self.conv(x, time_edge_index))
            for i in range(N_metalayers):
                h = self.act(self.conv(x, time_edge_index, H=h))
            x = scatter_distribution(h, batch, dim=0)
            x = self.decoder(x)

            return x

    return Net
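The time-edge construction in the forward pass above chains nodes in temporal order, masks out the links that would cross graph boundaries, and then symmetrizes. Stepping through it on toy data (values assumed):

import torch

t = torch.tensor([0.3, 0.1, 0.2, 0.9, 0.5])  # per-node times
batch = torch.tensor([0, 0, 0, 1, 1])

time_sort = torch.argsort(t)
batch_time_sort = time_sort[torch.argsort(batch[time_sort])]  # time order within graphs
time_edge_index = torch.cat([batch_time_sort[:-1].view(1, -1),
                             batch_time_sort[1:].view(1, -1)], dim=0)

graph_ids, graph_node_counts = batch.unique(return_counts=True)
tmp_bool = torch.ones(time_edge_index.shape[1], dtype=torch.bool)
tmp_bool[(torch.cumsum(graph_node_counts, 0) - 1)[:-1]] = False  # cut cross-graph links
time_edge_index = time_edge_index[:, tmp_bool]
time_edge_index = torch.cat([time_edge_index, time_edge_index.flip(0)], dim=1)
print(time_edge_index)
# tensor([[1, 2, 4, 2, 0, 3],
#         [2, 0, 3, 1, 2, 4]])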
Code example #11
def Load_model(name, args):
    from FunctionCollection import Loss_Functions, customModule
    import pytorch_lightning as pl

    from torch_geometric_temporal import nn

    from typing import Optional

    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)

        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([mean, var, maximum, minimum], dim=1)

    N_edge_feats = args['N_edge_feats']  #6
    N_dom_feats = args['N_dom_feats']  #6
    N_scatter_feats = 4
    #     N_targets = args['N_targets']
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  #10
    N_hcs = args['N_hcs']  #32
    #Possibly add (edge/node/global)_layers

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = True if name[-4:] == 'NLLH' else False

    class Net(customModule):
        def __init__(self):
            super(Net,
                  self).__init__(crit, y_post_processor, output_post_processor,
                                 cal_acc, likelihood_fitting, args)

            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            N_x_feats = N_dom_feats + N_scatter_feats * N_edge_feats
            N_u_feats = N_scatter_feats * N_x_feats

            self.x_encoder = torch.nn.Linear(N_x_feats, self.hcs)
            self.edge_attr_encoder = torch.nn.Linear(N_edge_feats, self.hcs)
            self.u_encoder = torch.nn.Linear(N_u_feats, self.hcs)

            class EdgeModel(torch.nn.Module):
                def __init__(self, hcs, act):
                    super(EdgeModel, self).__init__()
                    self.hcs = hcs
                    self.act = act
                    self.lins = torch.nn.ModuleList()
                    for i in range(2):
                        self.lins.append(
                            torch.nn.Linear(4 * self.hcs, 4 * self.hcs))
                    self.decoder = torch.nn.Linear(4 * self.hcs, self.hcs)

                def forward(self, src, dest, edge_attr, u, batch):
                    # x: [N, F_x], where N is the number of nodes.
                    # src, dest: [E, F_x], where E is the number of edges.
                    # edge_attr: [E, F_e]
                    # edge_index: [2, E] with max entry N - 1.
                    # u: [B, F_u], where B is the number of graphs.
                    # batch: [N] with max entry B - 1.
                    out = torch.cat([src, dest, edge_attr, u[batch]], dim=1)
                    for lin in self.lins:
                        out = self.act(lin(out))
                    return self.act(self.decoder(out))

            class NodeModel(torch.nn.Module):
                def __init__(self, hcs, act):
                    super(NodeModel, self).__init__()
                    self.hcs = hcs
                    self.act = act
                    self.lins1 = torch.nn.ModuleList()
                    for i in range(2):
                        self.lins1.append(
                            torch.nn.Linear(2 * self.hcs, 2 * self.hcs))
                    self.lins2 = torch.nn.ModuleList()
                    for i in range(2):
                        self.lins2.append(
                            torch.nn.Linear(
                                (2 + 2 * N_scatter_feats) * self.hcs,
                                (2 + 2 * N_scatter_feats) * self.hcs))
                    self.decoder = torch.nn.Linear(
                        (2 + 2 * N_scatter_feats) * self.hcs, self.hcs)

                def forward(self, x, edge_index, edge_attr, u, batch):
                    row, col = edge_index
                    out = torch.cat([x[row], edge_attr], dim=1)
                    for lin in self.lins1:
                        out = self.act(lin(out))
                    out = scatter_distribution(out, col, dim=0)  #8*hcs
                    out = torch.cat([x, out, u[batch]], dim=1)
                    for lin in self.lins2:
                        out = self.act(lin(out))
                    return self.act(self.decoder(out))

            class GlobalModel(torch.nn.Module):
                def __init__(self, hcs, act):
                    super(GlobalModel, self).__init__()
                    self.hcs = hcs
                    self.act = act
                    self.lins1 = torch.nn.ModuleList()
                    for i in range(2):
                        self.lins1.append(
                            torch.nn.Linear((1 + N_scatter_feats) * self.hcs,
                                            (1 + N_scatter_feats) * self.hcs))
                    self.decoder = torch.nn.Linear(
                        (1 + N_scatter_feats) * self.hcs, self.hcs)

                def forward(self, x, edge_index, edge_attr, u, batch):
                    out = torch.cat(
                        [u, scatter_distribution(x, batch, dim=0)],
                        dim=1)  # 5*hcs
                    for lin in self.lins1:
                        out = self.act(lin(out))
                    return self.act(self.decoder(out))

            self.tgcn_starter = nn.recurrent.temporalgcn.TGCN(
                N_x_feats, self.hcs)

            from torch_geometric.nn import MetaLayer
            self.ops = torch.nn.ModuleList()
            self.tgcns = torch.nn.ModuleList()
            for i in range(N_metalayers):
                self.ops.append(
                    MetaLayer(EdgeModel(self.hcs, self.act),
                              NodeModel(self.hcs, self.act),
                              GlobalModel(self.hcs, self.act)))
                self.tgcns.append(
                    nn.recurrent.temporalgcn.TGCN(self.hcs, self.hcs))

            self.decoders = torch.nn.ModuleList()
            self.decoders.append(
                torch.nn.Linear(
                    (1 + N_metalayers) * self.hcs + N_scatter_feats * self.hcs,
                    10 * self.hcs))
            self.decoders.append(torch.nn.Linear(10 * self.hcs, 8 * self.hcs))
            self.decoders.append(torch.nn.Linear(8 * self.hcs, 6 * self.hcs))
            self.decoders.append(torch.nn.Linear(6 * self.hcs, 4 * self.hcs))
            self.decoders.append(torch.nn.Linear(4 * self.hcs, 2 * self.hcs))

            self.decoder = torch.nn.Linear(2 * self.hcs, N_outputs)

        def forward(self, data):
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch

            x = torch.cat(
                [x, scatter_distribution(edge_attr, edge_index[1], dim=0)],
                dim=1)
            u = torch.cat([scatter_distribution(x, batch, dim=0)], dim=1)

            time_sort = torch.argsort(x[:, 1])
            batch_time_sort = time_sort[torch.argsort(batch[time_sort])]
            time_edge_index = torch.cat([
                batch_time_sort[:-1].view(1, -1), batch_time_sort[1:].view(
                    1, -1)
            ],
                                        dim=0)
            graph_ids, graph_node_counts = batch.unique(return_counts=True)
            tmp_bool = torch.ones(time_edge_index.shape[1], dtype=torch.bool)
            tmp_bool[(torch.cumsum(graph_node_counts, 0) - 1)[:-1]] = False
            time_edge_index = time_edge_index[:, tmp_bool]
            time_edge_index = torch.cat(
                [time_edge_index, time_edge_index.flip(0)], dim=1)
            time_edge_index = torch.cat([edge_index, time_edge_index], dim=1)

            h = self.tgcn_starter(x, time_edge_index)

            x = self.act(self.x_encoder(x))
            edge_attr = self.act(self.edge_attr_encoder(edge_attr))
            u = self.act(self.u_encoder(u))
            out = u.clone()

            for i in range(N_metalayers):
                x, edge_attr, u = self.ops[i](x, edge_index, edge_attr, u,
                                              batch)
                h = self.act(self.tgcns[i](x, time_edge_index, H=h))
                out = torch.cat([out, u.clone()], dim=1)

            out = torch.cat([out, scatter_distribution(h, batch, dim=0)],
                            dim=1)
            for lin in self.decoders:
                out = self.act(lin(out))

            out = self.decoder(out)
            return out

    return Net
Code example #12
def Load_model(name, args):
    from FunctionCollection import Loss_Functions, customModule, edge_feature_constructor
    import pytorch_lightning as pl

    from typing import Optional

    import torch
    from torch_scatter import scatter_sum, scatter_min, scatter_max
    from torch_scatter.utils import broadcast

    @torch.jit.script
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)

        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([mean, var, maximum, minimum], dim=1)

    N_edge_feats = args['N_edge_feats']  #6
    N_dom_feats = args['N_dom_feats']  #6
    N_scatter_feats = 4
    #     N_targets = args['N_targets']
    N_outputs = args['N_outputs']
    N_metalayers = args['N_metalayers']  #10
    N_hcs = args['N_hcs']  #32
    #Possibly add (edge/node/global)_layers

    crit, y_post_processor, output_post_processor, cal_acc = Loss_Functions(
        name, args)
    likelihood_fitting = True if name[-4:] == 'NLLH' else False

    class Net(customModule):
        def __init__(self):
            super(Net,
                  self).__init__(crit, y_post_processor, output_post_processor,
                                 cal_acc, likelihood_fitting, args)

            self.act = torch.nn.SiLU()
            self.hcs = N_hcs

            N_x_feats = 2 * N_dom_feats + N_edge_feats

            self.x_encoder = torch.nn.Linear(N_x_feats, self.hcs)

            #             class Conversation(torch.nn.Module):
            #                 def __init__(self,hcs,act):
            #                     super(Conversation, self).__init__()
            #                     self.act = act
            #                     self.hcs = hcs
            #                     self.GRU = torch.nn.GRUCell(3*self.hcs,self.hcs)
            #                     self.lin_msg1 = torch.nn.Linear((1+N_scatter_feats)*self.hcs,self.hcs)
            #                     self.lin_msg2 = torch.nn.Linear((1+N_scatter_feats)*self.hcs,self.hcs)

            #                 def forward(self, x, edge_index, edge_attr, batch, h):
            #                     (frm, to) = edge_index

            #                     h = self.act( self.GRU( torch.cat([x[to],x[frm],edge_attr],dim=1), h ) )
            #                     x = self.act( self.lin_msg1( torch.cat([x,scatter_distribution(h, to, dim=0)],dim=1) ) )

            #                     h = self.act( self.GRU( torch.cat([x[frm],x[to],edge_attr],dim=1), h) )
            #                     x = self.act( self.lin_msg2( torch.cat([x,scatter_distribution(h, frm, dim=0)],dim=1) ) )
            #                     return x

            class GRUConv(torch.nn.Module):
                def __init__(self, hcs, act):
                    super(GRUConv, self).__init__()
                    self.hcs = hcs
                    self.act = act

                    self.GRU = torch.nn.GRUCell(self.hcs, self.hcs)
                    self.lin = torch.nn.Linear(2 * self.hcs, self.hcs)

                def forward(self, x, edge_index, edge_attr, batch, h):
                    (frm, to) = edge_index
                    h = self.act(self.GRU(x, h))
                    x = torch.cat([x, h[frm]], dim=1)
                    x = self.act(self.lin(x))
                    return x

            self.GRUConvs = torch.nn.ModuleList()
            #             self.ConvConvs = torch.nn.ModuleList()
            for i in range(N_metalayers):
                self.GRUConvs.append(GRUConv(self.hcs, self.act))
#                 self.ConvConvs.append(Conversation(self.hcs, self.act))

            self.decoders = torch.nn.ModuleList()
            self.decoders.append(torch.nn.Linear(4 * self.hcs, self.hcs))
            self.decoders.append(torch.nn.Linear(self.hcs, self.hcs))

            self.decoder = torch.nn.Linear(self.hcs, N_outputs)

        def forward(self, data):
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch

            time_sort = torch.argsort(x[:, 1])
            graph_ids, graph_node_counts = batch.unique(return_counts=True)
            batch_time_sort = torch.cat(
                [time_sort[batch[time_sort] == i] for i in graph_ids])
            time_edge_index = torch.cat([
                batch_time_sort[:-1].view(1, -1), batch_time_sort[1:].view(
                    1, -1)
            ],
                                        dim=0)

            tmp_ind = (torch.cumsum(graph_node_counts, 0) - 1)[:-1]
            # replace each cross-graph chain link with a wrap-around edge from
            # the last node of graph i back to the first node of graph i
            time_edge_index[1, tmp_ind] = time_edge_index[
                0, [0] + (tmp_ind + 1).tolist()[:-1]]
            # add the wrap-around edge for the final graph
            time_edge_index = torch.cat([
                time_edge_index,
                torch.cat([
                    time_edge_index[1, -1].view(1, 1),
                    time_edge_index[0, (tmp_ind + 1)[-1]].view(1, 1)
                ])
            ],
                                        dim=1)

            time_edge_index = time_edge_index[:,
                                              torch.argsort(time_edge_index[1])]

            edge_attr = edge_feature_constructor(x, time_edge_index)

            x = torch.cat([x, edge_attr, x[time_edge_index[0]]], dim=1)

            x = self.act(self.x_encoder(x))

            h = torch.zeros((x.shape[0], self.hcs)).type_as(x)

            for i, conv in enumerate(self.GRUConvs):
                x = conv(x, time_edge_index, edge_attr, batch, h)

            out = scatter_distribution(x, batch, dim=0)

            for lin in self.decoders:
                out = self.act(lin(out))

            out = self.decoder(out)
            return out

    return Net
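All the variants in this listing expose the same factory contract, so one driver serves them all. A hypothetical invocation (every value below, including the model name and feature string, is an assumption; Loss_Functions in FunctionCollection must recognize the name):

args = {
    'N_edge_feats': 6,
    'N_dom_feats': 6,
    'N_outputs': 3,
    'N_metalayers': 10,
    'N_hcs': 32,
    'features': 'charge, time, x, y, z, aux',  # placeholder: 6 names, charge/time/pos layout
}

Net = Load_model('example_NLLH', args)  # a name ending in 'NLLH' toggles likelihood fitting
model = Net()
# out = model(batch)  # `batch` would be a torch_geometric.data.Batch with x/edge_index/batch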