Ejemplo n.º 1
0
    def __init__(self,
                 n_features,
                 n_classes,
                 num_layers,
                 hidden_gcn,
                 hidden_fc,
                 edge_index,
                 mode='cat',
                 graph_agg_mode='mean'):
        """Build a stack of GCNConv layers followed by a two-layer linear head.

        Args:
            n_features: input node feature width for the first GCNConv.
            n_classes: output width of the final linear layer ``lin2``.
            num_layers: total number of GCNConv layers (``conv1`` plus
                ``num_layers - 1`` layers in ``convs``).
            hidden_gcn: output width of every GCNConv layer.
            hidden_fc: width of the hidden linear layer ``lin1``.
            edge_index: stored unchanged on ``self.edge_index``; presumably a
                fixed connectivity reused by ``forward`` -- TODO confirm.
            mode: ``'cat'`` sizes ``lin1`` for the concatenation of all
                ``num_layers`` GCN outputs; any other value sizes it for a
                single GCN output.
            graph_agg_mode: ``'set2set'`` creates a 2-step ``Set2Set`` module
                and doubles ``lin1``'s input width (Set2Set returns twice its
                input width); any other value creates no pooling module here.
        """
        super().__init__()
        self.edge_index = edge_index
        self.conv1 = GCNConv(n_features, hidden_gcn)
        self.convs = torch.nn.ModuleList()
        self.graph_agg_mode = graph_agg_mode
        self.mode = mode

        # Remaining hidden_gcn -> hidden_gcn layers after conv1.
        for i in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_gcn, hidden_gcn))

        if mode == 'cat':
            # lin1 consumes all num_layers GCN outputs concatenated.
            if graph_agg_mode == 'set2set':
                self.lin1 = Linear(2 * num_layers * hidden_gcn, hidden_fc)
                self.set2setpooling = Set2Set(num_layers * hidden_gcn, 2)
            else:
                self.lin1 = Linear(num_layers * hidden_gcn, hidden_fc)

        else:
            # lin1 consumes a single hidden_gcn-wide representation.
            if graph_agg_mode == 'set2set':
                self.lin1 = Linear(hidden_gcn * 2, hidden_fc)
                self.set2setpooling = Set2Set(hidden_gcn, 2)
            else:
                self.lin1 = Linear(hidden_gcn, hidden_fc)

        self.lin2 = Linear(hidden_fc, n_classes)
Ejemplo n.º 2
0
    def __init__(self,
                 num_layer,
                 emb_dim,
                 num_tasks,
                 JK="last",
                 drop_ratio=0,
                 graph_pooling="mean",
                 gnn_type="gin"):
        """Build a GNN encoder, a projection head, a pooling readout and a predictor.

        Args:
            num_layer: number of GNN layers; must be at least 2.
            emb_dim: node embedding width.
            num_tasks: output width of ``graph_pred_linear``.
            JK: jumping-knowledge mode forwarded to ``GNN``; ``"concat"``
                makes the pooling/prediction layers expect
                ``(num_layer + 1) * emb_dim`` features.
            drop_ratio: dropout ratio forwarded to ``GNN``.
            graph_pooling: ``"sum"``, ``"mean"``, ``"max"``, ``"attention"``
                or ``"set2set<k>"`` where the single trailing digit ``k`` is
                the number of Set2Set processing steps (e.g. ``"set2set3"``).
            gnn_type: forwarded to ``GNN``.

        Raises:
            ValueError: if ``num_layer < 2`` or ``graph_pooling`` is unknown.
        """
        super(GNN_graphCL, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn = GNN(num_layer, emb_dim, JK, drop_ratio, gnn_type=gnn_type)
        # NOTE(review): the projection head takes emb_dim inputs even when
        # JK == "concat"; confirm what width is actually fed to it in forward.
        self.proj_head = nn.Sequential(nn.Linear(emb_dim, 128),
                                       nn.ReLU(inplace=True),
                                       nn.Linear(128, 128))

        # Select the graph pooling (readout) function.
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            if self.JK == "concat":
                self.pool = GlobalAttention(
                    gate_nn=torch.nn.Linear((self.num_layer + 1) * emb_dim, 1))
            else:
                self.pool = GlobalAttention(
                    gate_nn=torch.nn.Linear(emb_dim, 1))
        elif graph_pooling[:-1] == "set2set":
            # The last character is parsed as the number of processing steps.
            set2set_iter = int(graph_pooling[-1])
            if self.JK == "concat":
                self.pool = Set2Set((self.num_layer + 1) * emb_dim,
                                    set2set_iter)
            else:
                self.pool = Set2Set(emb_dim, set2set_iter)
        else:
            raise ValueError("Invalid graph pooling type.")

        # Set2Set doubles the pooled width, so the predictor input scales by mult.
        if graph_pooling[:-1] == "set2set":
            self.mult = 2
        else:
            self.mult = 1

        # Graph-level prediction head with num_tasks outputs.
        if self.JK == "concat":
            self.graph_pred_linear = torch.nn.Linear(
                self.mult * (self.num_layer + 1) * self.emb_dim,
                self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.mult * self.emb_dim,
                                                     self.num_tasks)
Ejemplo n.º 3
0
    def __init__(self,
                 atom_vertex_dim,
                 atom_edge_dim,
                 orbital_vertex_dim=NotImplemented,
                 orbital_edge_dim=NotImplemented,
                 output_dim=NotImplemented,
                 mp_step=6,
                 s2s_step=6):
        """Build paired atom-level and orbital-level NNConv networks plus cross layers.

        Each ``*_dim`` argument is indexed as a 2-element sequence: ``[0]`` is
        the raw input width fed to the ``*_lin0``/``*_lin1`` encoders and
        ``[1]`` is the hidden width they project to.

        NOTE(review): the ``NotImplemented`` defaults are indexed / multiplied
        below, so omitting ``orbital_vertex_dim``, ``orbital_edge_dim`` or
        ``output_dim`` raises TypeError -- they are effectively required.

        Args:
            atom_vertex_dim: (input, hidden) widths of atom node features.
            atom_edge_dim: (input, hidden) widths of atom edge features.
            orbital_vertex_dim: (input, hidden) widths of orbital node features.
            orbital_edge_dim: (input, hidden) widths of orbital edge features.
            output_dim: output width of ``cross_lin0``.
            mp_step: stored on ``self.mp_step``; presumably the number of
                message-passing iterations in forward -- TODO confirm.
            s2s_step: processing steps for every Set2Set module.
        """
        super(MultiNet, self).__init__()
        self.atom_vertex_dim = atom_vertex_dim
        self.atom_edge_dim = atom_edge_dim
        self.orbital_vertex_dim = orbital_vertex_dim
        self.orbital_edge_dim = orbital_edge_dim
        self.output_dim = output_dim
        self.mp_step = mp_step
        self.s2s_step = s2s_step

        # atom net
        # Edge network emits hidden**2 values per edge, which NNConv reshapes
        # into a per-edge weight matrix.
        atom_edge_gc = nn.Sequential(nn.Linear(atom_edge_dim[1], atom_vertex_dim[1] ** 2), nn.Dropout(0.2))

        self.atom_vertex_conv = NNConv(atom_vertex_dim[1], atom_vertex_dim[1], atom_edge_gc, aggr="mean", root_weight=True)
        self.atom_vertex_gru = nn.GRU(atom_vertex_dim[1], atom_vertex_dim[1])

        self.atom_s2s = Set2Set(atom_vertex_dim[1], processing_steps=s2s_step)
        # Encoders: raw width [0] -> hidden width [1] for nodes and edges.
        self.atom_lin0 = nn.Sequential(nn.Linear(atom_vertex_dim[0], 2 * atom_vertex_dim[0]), nn.CELU(),
                                       nn.Linear(2 * atom_vertex_dim[0], atom_vertex_dim[1]), nn.CELU())
        self.atom_lin1 = nn.Sequential(nn.Linear(atom_edge_dim[0], 2 * atom_edge_dim[0]), nn.CELU(),
                                       nn.Linear(2 * atom_edge_dim[0], atom_edge_dim[1]), nn.CELU())
        # Input 2 * hidden matches the Set2Set output width.
        self.atom_lin2 = nn.Sequential(nn.Linear(2 * atom_vertex_dim[1], 4 * atom_vertex_dim[1]), nn.CELU())

        # orbital net (mirrors the atom net)
        orbital_edge_gc = nn.Sequential(nn.Linear(orbital_edge_dim[1], orbital_vertex_dim[1] ** 2), nn.Dropout(0.2))

        self.orbital_vertex_conv = NNConv(orbital_vertex_dim[1], orbital_vertex_dim[1], orbital_edge_gc, aggr="mean", root_weight=True)
        self.orbital_vertex_gru = nn.GRU(orbital_vertex_dim[1], orbital_vertex_dim[1])

        self.orbital_s2s = Set2Set(orbital_vertex_dim[1], processing_steps=s2s_step)
        self.orbital_lin0 = nn.Sequential(nn.Linear(orbital_vertex_dim[0], 2 * orbital_vertex_dim[0]), nn.CELU(),
                                          nn.Linear(2 * orbital_vertex_dim[0], orbital_vertex_dim[1]), nn.CELU())
        self.orbital_lin1 = nn.Sequential(nn.Linear(orbital_edge_dim[0], 2 * orbital_edge_dim[0]), nn.CELU(),
                                          nn.Linear(2 * orbital_edge_dim[0], orbital_edge_dim[1]), nn.CELU())
        self.orbital_lin2 = nn.Sequential(nn.Linear(2 * orbital_vertex_dim[1], 4 * orbital_vertex_dim[1]), nn.CELU())

        # cross net
        # Input width equals atom_lin2 output + orbital_lin2 output.
        self.cross_lin0 = nn.Sequential(
            nn.Linear(4 * atom_vertex_dim[1] + 4 * orbital_vertex_dim[1], 4 * output_dim),
            nn.CELU(),
            nn.Linear(4 * output_dim, output_dim)
        )
        # Orbital -> atom path projects to half the atom hidden width; presumably
        # so the Set2Set output (2x) matches atom_vertex_dim[1] for the GRU,
        # which requires atom_vertex_dim[1] to be even -- confirm.
        self.cross_o2a_lin = nn.Sequential(nn.Linear(orbital_vertex_dim[1], 2 * orbital_vertex_dim[1]), nn.CELU(),
                                           nn.Linear(2 * orbital_vertex_dim[1], int(atom_vertex_dim[1] / 2)), nn.CELU())
        self.cross_o2a_s2s = Set2Set(int(atom_vertex_dim[1] / 2), processing_steps=s2s_step)
        self.cross_o2a_gru = nn.GRU(atom_vertex_dim[1], atom_vertex_dim[1])
        # Atom -> orbital path projects directly to the orbital hidden width.
        self.cross_a2o_lin = nn.Sequential(nn.Linear(atom_vertex_dim[1], 2 * atom_vertex_dim[1]), nn.CELU(),
                                           nn.Linear(2 * atom_vertex_dim[1], orbital_vertex_dim[1]), nn.CELU())
        self.cross_a2o_gru = nn.GRU(orbital_vertex_dim[1], orbital_vertex_dim[1])
Ejemplo n.º 4
0
 def __init__(self, num_of_megnetblock: int) -> None:
     """Build pre-blocks, a stack of FullMegnetBlocks, Set2Set readouts and an output head.

     Args:
         num_of_megnetblock: number of ``FullMegnetBlock`` modules to stack.

     NOTE(review): the widths 71, 100, 32, 128 and 41 are hard-coded; 71/100
     are presumably the raw atom/bond feature widths, and 128 presumably
     equals the two Set2Set outputs (2 * 32 each) concatenated -- confirm
     against forward.
     """
     super().__init__()
     self.atom_preblock = ff(71)
     self.bond_preblock = ff(100)
     # self.firstblock = FirstMegnetBlock()
     self.fullblocks = torch.nn.ModuleList(
         [FullMegnetBlock() for i in range(num_of_megnetblock)])
     # self.fullblocks = torch.nn.ModuleList(
     # [EncoderBlock() for i in range(num_of_megnetblock)])
     # Separate Set2Set readouts (3 steps each) for vertices and edges.
     self.set2set_v = Set2Set(in_channels=32, processing_steps=3)
     self.set2set_e = Set2Set(in_channels=32, processing_steps=3)
     self.output_layer = ff_output(input_dim=128, output_dim=41)
Ejemplo n.º 5
0
    def __init__(self):
        """Build an NNConv + GRU message-passing network with a Set2Set readout.

        All widths are hard-coded: 8 input node features, 5-dim edge inputs
        to the edge network, a 256-wide hidden state and 242 outputs from
        ``lin2``. ``lin_edge``, ``node_embb`` and ``lin6`` are extra
        embedding/prediction modules whose use is not visible here.
        """
        super(Net, self).__init__()
        internal_dim = 256
        self.lin0 = torch.nn.Linear(8, internal_dim)

        # Edge network: maps each edge feature vector to a flattened
        # internal_dim x internal_dim weight matrix for NNConv.
        m_nn = Sequential(Linear(5, 128), ReLU(),
                          Linear(128, internal_dim * internal_dim))
        self.conv = NNConv(internal_dim,
                           internal_dim,
                           m_nn,
                           aggr='mean',
                           root_weight=False)
        self.gru = GRU(internal_dim, internal_dim)

        self.set2set = Set2Set(internal_dim, processing_steps=3)
        # Set2Set output is 2 * internal_dim wide.
        self.lin1 = torch.nn.Linear(2 * internal_dim, internal_dim)
        self.lin2 = torch.nn.Linear(internal_dim, 242)

        self.lin_edge = nn.Embedding(8, 128)
        self.node_embb = nn.Embedding(5, 5)

        # NOTE(review): input is 2 * internal_dim; presumably a concatenation
        # of two node states (e.g. an edge's endpoints) -- confirm in forward.
        self.lin6 = nn.Sequential(
            nn.Linear(2 * internal_dim, 128),  # edge_attr_size,
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 1)  # edge_attr_size,
        )
Ejemplo n.º 6
0
 def __init__(self,
              node_input_dim=18,
              edge_input_dim=7,
              hidden_dim0=128,
              hidden_dim1=64,
              hidden_dim2=32,
              hidden_dim3=32,
              hidden_dim4=16,
              hidden_dim5=16,
              output_dim=1):
     """Build a linear input layer, five residual message-passing blocks and a Set2Set head.

     The node width shrinks through the blocks as
     ``hidden_dim0 -> hidden_dim1 -> ... -> hidden_dim5``; every block
     receives the raw ``edge_input_dim`` edge features.

     Args:
         node_input_dim: raw node feature width.
         edge_input_dim: raw edge feature width passed to every block.
         hidden_dim0..hidden_dim5: successive node widths.
         output_dim: width of the final prediction.
     """
     super(BGNN, self).__init__()
     self.ffnn = torch.nn.Linear(node_input_dim, hidden_dim0)
     self.resmpblock0 = ResidualMessagePassingBlock(hidden_dim0,
                                                    hidden_dim1,
                                                    edge_input_dim)
     self.resmpblock1 = ResidualMessagePassingBlock(hidden_dim1,
                                                    hidden_dim2,
                                                    edge_input_dim)
     self.resmpblock2 = ResidualMessagePassingBlock(hidden_dim2,
                                                    hidden_dim3,
                                                    edge_input_dim)
     self.resmpblock3 = ResidualMessagePassingBlock(hidden_dim3,
                                                    hidden_dim4,
                                                    edge_input_dim)
     self.resmpblock4 = ResidualMessagePassingBlock(hidden_dim4,
                                                    hidden_dim5,
                                                    edge_input_dim)
     self.set2set = Set2Set(hidden_dim5, processing_steps=3)
     # Set2Set output is twice its input width.
     self.ffnn_out = torch.nn.Linear(hidden_dim5 * 2, output_dim)
Ejemplo n.º 7
0
    def __init__(self,
                 latent_dim,
                 output_dim,
                 num_node_feats,
                 num_edge_feats,
                 max_lv=3,
                 act_func='elu',
                 msg_aggregate_type='mean',
                 dropout=None):
        """Build an NNConv + GRU MPNN embedding with a Set2Set readout.

        Args:
            latent_dim: hidden node width.
            output_dim: embedding width if > 0; otherwise ``latent_dim`` is
                used as the embedding width.
            num_node_feats: raw node feature width.
            num_edge_feats: raw edge feature width (input of the edge MLP).
            max_lv: stored on ``self.max_lv``; presumably the number of
                message-passing rounds -- TODO confirm in forward.
            act_func: key into ``NONLINEARITIES`` and the edge MLP's
                nonlinearity.
            msg_aggregate_type: NNConv aggregation; ``'sum'`` is translated
                to ``'add'``.
            dropout: forwarded to the base-class constructor.
        """
        if output_dim > 0:
            embed_dim = output_dim
        else:
            embed_dim = latent_dim
        super(MPNN, self).__init__(embed_dim, dropout)
        # Accept 'sum' as an alias for the 'add' aggregation name.
        if msg_aggregate_type == 'sum':
            msg_aggregate_type = 'add'
        self.max_lv = max_lv
        # self.embed_dim is assumed to be set by the base class from embed_dim.
        self.readout = nn.Linear(2 * latent_dim, self.embed_dim)
        self.lin0 = torch.nn.Linear(num_node_feats, latent_dim)
        # Edge MLP emits a flattened latent_dim x latent_dim matrix per edge.
        net = MLP(input_dim=num_edge_feats,
                  hidden_dims=[128, latent_dim * latent_dim],
                  nonlinearity=act_func)
        self.conv = NNConv(latent_dim,
                           latent_dim,
                           net,
                           aggr=msg_aggregate_type,
                           root_weight=False)

        self.act_func = NONLINEARITIES[act_func]
        self.gru = nn.GRU(latent_dim, latent_dim)
        self.set2set = Set2Set(latent_dim, processing_steps=3)
Ejemplo n.º 8
0
    def __init__(self, args, device):
        """Build an NNConv + GRU network with a Set2Set head and two extra heads.

        Args:
            args: config object; reads ``node_dim``, ``edge_dim``,
                ``hidden_dim``, ``processing_steps`` and ``depth``.
            device: stored on ``self.device``.

        ``lin1``/``lin2`` form the Set2Set-based head with 1 output;
        ``lin3``/``lin4`` and ``lin5``/``lin6`` are two further
        hidden_dim -> 36 -> 2 heads. All modules are then initialised via
        ``self.apply(init_weights)``.
        """
        super(Net, self).__init__()

        self.args = args
        self.device = device

        node_dim = self.args.node_dim
        edge_dim = self.args.edge_dim
        hidden_dim = self.args.hidden_dim
        processing_steps = self.args.processing_steps
        self.depth = self.args.depth

        self.lin0 = torch.nn.Linear(node_dim, hidden_dim)
        # NOTE: this local name shadows any module-level ``nn`` inside this method.
        nn = Sequential(Linear(edge_dim, hidden_dim * 2), ReLU(),
                        Linear(hidden_dim * 2, hidden_dim * hidden_dim))
        self.conv = NNConv(hidden_dim, hidden_dim, nn, aggr='mean')
        self.gru = GRU(hidden_dim, hidden_dim)

        self.set2set = Set2Set(hidden_dim, processing_steps=processing_steps)
        self.lin1 = torch.nn.Linear(2 * hidden_dim, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, 1)

        self.lin3 = torch.nn.Linear(hidden_dim, 36)
        self.lin4 = torch.nn.Linear(36, 2)

        self.lin5 = torch.nn.Linear(hidden_dim, 36)
        self.lin6 = torch.nn.Linear(36, 2)

        # Runs after every submodule is registered so all of them are initialised.
        self.apply(init_weights)
Ejemplo n.º 9
0
    def __init__(self,
                 node_input_dim=15,
                 edge_input_dim=5,
                 output_dim=1,
                 node_hidden_dim=64,
                 edge_hidden_dim=128,
                 num_step_message_passing=6,
                 num_step_set2set=6):
        """Build an NNConv + GRU MPNN with a Set2Set readout and a 2-layer head.

        Args:
            node_input_dim: raw node feature width.
            edge_input_dim: raw edge feature width.
            output_dim: width of the final prediction.
            node_hidden_dim: hidden node width.
            edge_hidden_dim: hidden width of the edge network.
            num_step_message_passing: stored for use by forward.
            num_step_set2set: Set2Set processing steps.
        """
        super(MPNN, self).__init__()
        self.num_step_message_passing = num_step_message_passing
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        # Edge network emits a flattened node_hidden_dim x node_hidden_dim
        # weight matrix per edge for NNConv.
        edge_network = nn.Sequential(
            nn.Linear(edge_input_dim, edge_hidden_dim), nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim))
        self.conv = NNConv(node_hidden_dim,
                           node_hidden_dim,
                           edge_network,
                           aggr='mean',
                           root_weight=False)
        self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)

        self.set2set = Set2Set(node_hidden_dim,
                               processing_steps=num_step_set2set)
        # Set2Set output is 2 * node_hidden_dim wide.
        self.lin1 = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        self.lin2 = nn.Linear(node_hidden_dim, output_dim)
Ejemplo n.º 10
0
    def __init__(self, num_input_features: int, nodes_out: int, graph_out: int,
                 num_layers: int, num_towers: int, hidden_u: int, out_u: int, hidden_gru: int,
                 type: str, debug_model=False):
        """Build an SMP network with node-level and graph-level output heads.

        Args:
            num_input_features: raw node feature width (one extra input
                channel is added for the local context).
            nodes_out: width of the node-level head ``final_node``.
            graph_out: width of the graph-level head ``final_graph``.
            num_layers: number of SMP layers (each with its own BatchNorm).
            num_towers: forwarded to every SMP layer.
            hidden_u: channel width of the local contexts.
            out_u: width after ``NodeExtractor``; input width of the GRU.
            hidden_gru: GRU hidden width; also the width of both heads.
            type: ``'A'`` selects ``TypeASMPLayer``; any other value selects
                ``TypeBSMPLayer``. (Shadows the builtin ``type`` in this method.)
            debug_model: stored on ``self.debug_model``.
        """
        super().__init__()
        num_input_u = 1 + num_input_features

        self.debug_model = debug_model

        self.edge_counter = EdgeCounter()
        self.initial_lin_u = nn.Linear(num_input_u, hidden_u)

        self.extractor = NodeExtractor(hidden_u, out_u)

        self.gru = nn.GRU(out_u, hidden_gru)
        self.convs = nn.ModuleList([])
        self.batch_norm_u = nn.ModuleList([])
        for i in range(0, num_layers):
            self.batch_norm_u.append(BatchNorm(hidden_u, use_x=False))
            conv = (TypeASMPLayer if type == 'A' else TypeBSMPLayer)(in_features=hidden_u, out_features=hidden_u,
                                                                     num_towers=num_towers)
            self.convs.append(conv)

        # Process the extracted node features
        # max_n is passed as Set2Set's processing_steps; presumably the largest
        # graph size in the benchmark -- confirm.
        max_n = 19
        self.set2set = Set2Set(hidden_gru, max_n)

        self.final_node = nn.Sequential(nn.Linear(hidden_gru, hidden_gru), nn.LeakyReLU(),
                                        nn.Linear(hidden_gru, hidden_gru), nn.LeakyReLU(),
                                        nn.Linear(hidden_gru, nodes_out))

        # Input is 2 * hidden_gru: the Set2Set output width.
        self.final_graph = nn.Sequential(nn.Linear(2 * hidden_gru, hidden_gru), nn.ReLU(),
                                         nn.BatchNorm1d(hidden_gru),
                                         nn.Linear(hidden_gru, hidden_gru), nn.LeakyReLU(),
                                         nn.BatchNorm1d(hidden_gru),
                                         nn.Linear(hidden_gru, graph_out))
Ejemplo n.º 11
0
 def __init__(self):
     """Build an NNConv + GRU message-passing network with a Set2Set readout.

     Relies on a module-level ``dim`` (hidden width) defined outside this
     block. The edge network takes 5-dim edge features and emits a flattened
     ``dim x dim`` weight matrix per edge; NNConv uses its default
     aggregation. ``fc1``/``fc2`` map the ``2 * dim`` Set2Set output to a
     single value.

     NOTE(review): the GRU is created with ``batch_first=True``; confirm this
     matches the tensor layout used in forward.
     """
     super(Net, self).__init__()
     nn = Sequential(Linear(5, 128), ReLU(), Linear(128, dim * dim))
     self.conv = NNConv(dim, dim, nn, root_weight=False)
     self.gru = GRU(dim, dim, batch_first=True)
     # Set2Set's signature is (in_channels, processing_steps): passing ``dim``
     # as a second positional argument collided with processing_steps=3 and
     # raised TypeError at construction time.
     self.set2set = Set2Set(dim, processing_steps=3)
     self.fc1 = torch.nn.Linear(2 * dim, dim)
     self.fc2 = torch.nn.Linear(dim, 1)
Ejemplo n.º 12
0
    def __init__(self, num_classes, gnn_layers, embed_dim, hidden_dim,
                 jk_layer, process_step, dropout):
        """Build an AGGIN stack with jumping knowledge, Set2Set and a 3-layer head.

        Args:
            num_classes: output width of ``fc3``.
            gnn_layers: number of AGGINConv layers.
            embed_dim: width of the 6-entry ``Embedding`` table.
            hidden_dim: width of every AGGINConv layer.
            jk_layer: a digit string (e.g. ``"3"``) selects LSTM jumping
                knowledge over that many layers; ``'cat'`` concatenates all
                layer outputs; ``'max'`` takes an element-wise max.
            process_step: Set2Set processing steps.
            dropout: stored on ``self.dropout``.

        Raises:
            ValueError: if ``jk_layer`` is none of the supported forms.
        """
        super(Net, self).__init__()

        self.dropout = dropout
        self.convs = torch.nn.ModuleList()
        self.embedding = Embedding(6, embed_dim)

        for i in range(gnn_layers):
            # The first layer takes 2 * embed_dim + 2 input features; the
            # rest map hidden_dim -> hidden_dim.
            in_dim = 2 * embed_dim + 2 if i == 0 else hidden_dim
            self.convs.append(
                AGGINConv(Sequential(Linear(in_dim, hidden_dim), ReLU(),
                                     Linear(hidden_dim, hidden_dim),
                                     ReLU(), BN(hidden_dim)),
                          train_eps=True))

        if jk_layer.isdigit():
            jk_layer = int(jk_layer)
            # PyG's JumpingKnowledge takes ``num_layers`` (not ``gnn_layers``)
            # for the LSTM mode; the wrong keyword raised TypeError.
            self.jk = JumpingKnowledge(mode='lstm',
                                       channels=hidden_dim,
                                       num_layers=jk_layer)
            s2s_dim = hidden_dim
        elif jk_layer == 'cat':
            self.jk = JumpingKnowledge(mode=jk_layer)
            # Concatenation yields gnn_layers * hidden_dim features per node.
            s2s_dim = gnn_layers * hidden_dim
        elif jk_layer == 'max':
            self.jk = JumpingKnowledge(mode=jk_layer)
            s2s_dim = hidden_dim
        else:
            # Previously this fell through silently, leaving jk/s2s/fc unset.
            raise ValueError(
                "jk_layer must be a digit string, 'cat' or 'max'; got {!r}".format(jk_layer))

        self.s2s = Set2Set(s2s_dim, processing_steps=process_step)
        # Set2Set output is twice its input width.
        self.fc1 = Linear(2 * s2s_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = Linear(hidden_dim // 2, num_classes)
Ejemplo n.º 13
0
 def __init__(self, in_dim, edge_in_dim, hidden_dim=32, depth=3):
     """Build a linear encoder, ``depth`` message-passing Blocks and a Set2Set head.

     Args:
         in_dim: raw node feature width.
         edge_in_dim: raw edge feature width passed to every Block.
         hidden_dim: hidden node width.
         depth: number of ``Block`` modules; also stored on ``self.depth``.
     """
     super(TrimNet, self).__init__()
     self.depth = depth
     self.lin0 = Linear(in_dim, hidden_dim)
     self.convs = Sequential(
         *[Block(hidden_dim, edge_in_dim) for i in range(depth)])
     self.set2set = Set2Set(hidden_dim, processing_steps=3)
     # Set2Set output is 2 * hidden_dim; a single output value per graph.
     self.lin1 = torch.nn.Linear(2 * hidden_dim, 1)
    def __init__(
        self,
        num_tasks,
        num_layer=5,
        emb_dim=300,
        gnn_type="gin",
        virtual_node=True,
        residual=False,
        drop_ratio=0.5,
        jk="last",
        graph_pooling="mean",
    ):
        """Build a node-level GNN, a graph pooling readout and a linear predictor.

        Args:
            num_tasks: output width of ``graph_pred_linear``.
            num_layer: number of GNN layers; must be at least 2.
            emb_dim: node embedding width.
            gnn_type: forwarded to the node GNN.
            virtual_node: selects ``GNN_node_Virtualnode`` when true,
                ``GNN_node`` otherwise.
            residual, drop_ratio, jk: forwarded to the node GNN.
            graph_pooling: ``"sum"``, ``"mean"``, ``"max"``, ``"attention"``
                or ``"set2set"`` (2 processing steps).

        Raises:
            ValueError: if ``num_layer <= 1`` (checked before
                ``super().__init__``) or ``graph_pooling`` is unknown.
        """
        if num_layer <= 1:
            raise ValueError("Number of GNN layers must be greater than 1.")

        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.jk = jk
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        # GNN to generate node embeddings
        gnn_cls = GNN_node_Virtualnode if virtual_node else GNN_node
        self.gnn_node = gnn_cls(
            num_layer,
            emb_dim,
            jk=jk,
            drop_ratio=drop_ratio,
            residual=residual,
            gnn_type=gnn_type,
        )

        # Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2 * emb_dim),
                torch.nn.BatchNorm1d(2 * emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2 * emb_dim, 1),
            ))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # Set2Set doubles the pooled width; other poolings keep emb_dim.
        adj = 2 if graph_pooling == "set2set" else 1
        self.graph_pred_linear = torch.nn.Linear(adj * self.emb_dim,
                                                 self.num_tasks)
Ejemplo n.º 15
0
 def __init__(self, dataset, num_layers, hidden):
     """Build a SAGEConv stack with a Set2Set readout and a 2-layer classifier.

     Args:
         dataset: object providing ``num_features`` and ``num_classes``.
         num_layers: total SAGEConv layers (``conv1`` plus ``num_layers - 1``).
         hidden: width of every SAGEConv layer.
     """
     super(Set2SetNet, self).__init__()
     self.conv1 = SAGEConv(dataset.num_features, hidden)
     self.convs = torch.nn.ModuleList()
     for i in range(num_layers - 1):
         self.convs.append(SAGEConv(hidden, hidden))
     self.set2set = Set2Set(hidden, processing_steps=4)
     # Set2Set output is 2 * hidden wide.
     self.lin1 = Linear(2 * hidden, hidden)
     self.lin2 = Linear(hidden, dataset.num_classes)
Ejemplo n.º 16
0
    def __init__(self, num_features):
        """Build two SAGEConv layers (-> 8 -> 16), a Set2Set readout and a 2-output head.

        Args:
            num_features: raw node feature width.
        """
        super(Set2SetNet, self).__init__()
        self.conv1 = SAGEConv(num_features, 8)
        self.conv2 = SAGEConv(8, 16)

        self.set2set = Set2Set(16, processing_steps=4)

        # Set2Set output is 2 * 16 wide; the head produces 2 outputs.
        # self.fc = torch.nn.Linear(2 * 16, 1)
        self.fc = torch.nn.Linear(2 * 16, 2)
Ejemplo n.º 17
0
    def __init__(self, num_features, dim, num_layers=1):
        """Build an NNConv + GRU encoder with a Set2Set readout.

        Args:
            num_features: raw node feature width.
            dim: hidden node width.
            num_layers: accepted but not used in this constructor.
        """
        super(SUPEncoder, self).__init__()
        self.lin0 = torch.nn.Linear(num_features, dim)

        # Edge network: 5 edge features -> flattened dim x dim weight matrix.
        nnu = nn.Sequential(nn.Linear(5, 128), nn.ReLU(), nn.Linear(128, dim * dim))
        self.conv = NNConv(dim, dim, nnu, aggr='mean', root_weight=False)
        self.gru = nn.GRU(dim, dim)

        self.set2set = Set2Set(dim, processing_steps=3)
Ejemplo n.º 18
0
    def __init__(self,
                 D,
                 C,
                 G=0,
                 E=1,
                 Q=96,
                 task='graph',
                 aggr='add',
                 pooltype='max',
                 conclayers=True):
        """Build two NNConv layers, a fusion MLP, optional Set2Set pooling and a final MLP.

        Args:
            D: node feature width (kept constant through both conv layers).
            C: number of output classes.
            G: global feature width; when > 0 it is added to the final MLP input.
            E: edge feature width.
            Q: latent width.
            task: stored on ``self.task``.
            aggr: NNConv aggregation for both conv layers.
            pooltype: ``'s2s'`` additionally creates Set2Set pooling plus a
                2Q -> Q projection; other values create no pooling module here.
            conclayers: when true, the fusion layer takes both conv outputs
                concatenated (2D wide), otherwise a single D-wide output.
        """

        super(NNNet, self).__init__()

        self.D = D  # node feature dimension
        self.E = E  # edge feature dimension
        self.G = G  # global feature dimension
        self.C = C  # number output classes

        self.Q = Q  # latent dimension

        self.task = task
        self.pooltype = pooltype
        self.conclayers = conclayers

        # Convolution layers
        # nn with size [-1, num_edge_features] x [-1, in_channels * out_channels]
        self.conv1 = NNConv(in_channels=D,
                            out_channels=D,
                            nn=MLP([E, D * D]),
                            aggr=aggr)
        self.conv2 = NNConv(in_channels=D,
                            out_channels=D,
                            nn=MLP([E, D * D]),
                            aggr=aggr)

        # "Fusion" layer taking in conv layer outputs
        if self.conclayers:
            self.lin1 = MLP([D + D, Q])
        else:
            self.lin1 = MLP([D, Q])

        # Set2Set pooling operation produces always output with 2 x input dimension
        # => use linear layer to project down
        if pooltype == 's2s':
            self.S2Spool = Set2Set(in_channels=Q,
                                   processing_steps=3,
                                   num_layers=1)
            self.S2Slin = Linear(2 * Q, Q)

        # Final MLP input width: latent Q, plus G global features when present.
        if (self.G > 0):
            self.Z = Q + self.G
        else:
            self.Z = Q

        # Final layers concatenating everything
        self.mlp1 = MLP([self.Z, self.Z, self.C])
Ejemplo n.º 19
0
    def __init__(self, num_features, dim):
        """Build an NNConv + GRU encoder with a Set2Set readout.

        Args:
            num_features: raw node feature width.
            dim: hidden node width.
        """
        super().__init__()
        self.lin0 = nn.Linear(num_features, dim)

        # Edge network: 5 edge features -> flattened dim x dim weight matrix.
        mlp = nn.Sequential(nn.Linear(5, 128), nn.ReLU(),
                            nn.Linear(128, dim * dim))
        self.conv = NNConv(dim, dim, mlp, aggr='mean', root_weight=False)
        self.gru = nn.GRU(dim, dim)

        self.set2set = Set2Set(dim, processing_steps=3)
Ejemplo n.º 20
0
    def __init__(self,
                 num_classes: int,
                 num_layer=5,
                 emb_dim=300,
                 node_encoder=None,
                 edge_encoder_ctor: torch.nn.Module = None,
                 residual=False,
                 drop_ratio=0.5,
                 JK="last",
                 graph_pooling="mean",
                 max_seq_len=1):
        """Build a GCN node encoder, a graph pooling readout and per-position predictors.

        Args:
            num_classes: output width of every prediction head.
            num_layer: number of GNN layers; must be at least 2.
            emb_dim: node embedding width.
            node_encoder, edge_encoder_ctor, residual, drop_ratio, JK:
                forwarded to ``GNN_node``.
            graph_pooling: ``"sum"``, ``"mean"``, ``"max"``, ``"attention"``
                or ``"set2set"`` (2 processing steps).
            max_seq_len: number of prediction heads in
                ``graph_pred_linear_list`` (one per output position).

        Raises:
            ValueError: if ``num_layer < 2`` or ``graph_pooling`` is unknown.
        """

        super(GCN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.graph_pooling = graph_pooling
        self.max_seq_len = max_seq_len
        self.num_classes = num_classes

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        # GNN to generate node embeddings
        self.gnn_node = GNN_node(num_layer,
                                 emb_dim=emb_dim,
                                 JK=JK,
                                 node_encoder=node_encoder,
                                 edge_encoder_ctor=edge_encoder_ctor,
                                 drop_ratio=drop_ratio,
                                 residual=residual)

        # Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2 *
                                emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),
                torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # Set2Set returns 2 * emb_dim features per graph while the other
        # poolings keep emb_dim; size the heads to the pooled width so the
        # set2set configuration does not hit a shape mismatch.
        # NOTE(review): assumes GNN_node outputs emb_dim features for every JK
        # mode -- confirm for JK="concat".
        pooled_dim = 2 * emb_dim if self.graph_pooling == "set2set" else emb_dim
        self.graph_pred_linear_list = torch.nn.ModuleList(
            [torch.nn.Linear(pooled_dim, self.num_classes)
             for _ in range(max_seq_len)])
Ejemplo n.º 21
0
    def __init__(self, args, num_node_features, num_edge_features):
        """Build a configurable message-passing GNN with graph pooling and an FFN head.

        Args:
            args: config object; reads ``depth``, ``hidden_size``, ``dropout``,
                ``gnn_type``, ``graph_pool``, ``tetra`` and ``task``, and is
                passed on to the conv-layer constructors.
            num_node_features: raw node feature width.
            num_edge_features: raw edge feature width.

        Raises:
            ValueError: if ``args.gnn_type`` is not ``'gin'``, ``'gcn'`` or
                ``'dmpnn'`` (detected while building layers, i.e. when
                ``depth > 0``), or if ``args.graph_pool`` is unsupported.
        """
        super(GNN, self).__init__()

        self.depth = args.depth
        self.hidden_size = args.hidden_size
        self.dropout = args.dropout
        self.gnn_type = args.gnn_type
        self.graph_pool = args.graph_pool
        self.tetra = args.tetra
        self.task = args.task

        # Input encoders: D-MPNN initialises edge states from concatenated
        # node + edge features; other types encode nodes and edges separately.
        if self.gnn_type == 'dmpnn':
            self.edge_init = nn.Linear(num_node_features + num_edge_features, self.hidden_size)
            self.edge_to_node = DMPNNConv(args)
        else:
            self.node_init = nn.Linear(num_node_features, self.hidden_size)
            self.edge_init = nn.Linear(num_edge_features, self.hidden_size)

        # layers
        self.convs = torch.nn.ModuleList()

        for _ in range(self.depth):
            if self.gnn_type == 'gin':
                self.convs.append(GINEConv(args))
            elif self.gnn_type == 'gcn':
                self.convs.append(GCNConv(args))
            elif self.gnn_type == 'dmpnn':
                self.convs.append(DMPNNConv(args))
            else:
                # The exception was previously constructed but never raised,
                # silently producing a model with no conv layers.
                raise ValueError('Undefined GNN type called {}'.format(self.gnn_type))

        if self.tetra:
            self.tetra_update = get_tetra_update(args)

        # graph pooling
        if self.graph_pool == "sum":
            self.pool = global_add_pool
        elif self.graph_pool == "mean":
            self.pool = global_mean_pool
        elif self.graph_pool == "max":
            self.pool = global_max_pool
        elif self.graph_pool == "attn":
            self.pool = GlobalAttention(
                gate_nn=torch.nn.Sequential(torch.nn.Linear(self.hidden_size, 2 * self.hidden_size),
                                            torch.nn.BatchNorm1d(2 * self.hidden_size),
                                            torch.nn.ReLU(),
                                            torch.nn.Linear(2 * self.hidden_size, 1)))
        elif self.graph_pool == "set2set":
            self.pool = Set2Set(self.hidden_size, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # ffn: Set2Set doubles the pooled width; single output value.
        self.mult = 2 if self.graph_pool == "set2set" else 1
        self.ffn = nn.Linear(self.mult * self.hidden_size, 1)
Ejemplo n.º 22
0
    def __init__(self):
        """Build three GCNConv layers (-> 256 -> 128 -> 32), a Set2Set readout and a 4-layer MLP.

        Relies on a module-level ``dataset`` providing ``num_features``.
        The MLP maps the 64-wide Set2Set output (2 * 32) down to 1 value.
        """
        super(Net, self).__init__()

        self.conv1 = GCNConv(dataset.num_features, 256, cached=False)
        self.conv2 = GCNConv(256, 128, cached=False)
        self.conv3 = GCNConv(128, 32, cached=False)
        self.set2set = Set2Set(32, processing_steps=1)
        # 64 = Set2Set output width for 32 input channels.
        self.linear1 = torch.nn.Linear(64, 256)
        self.linear2 = torch.nn.Linear(256, 128)
        self.linear3 = torch.nn.Linear(128, 32)
        self.linear4 = torch.nn.Linear(32, 1)
 def __init__(self,option):
     """Build a Henrion MPNN conv, a Set2Set readout and a 4-class log-softmax head.

     Args:
         option: config object; reads jet/particle feature widths, ``hid1``,
             ``hid2``, ``num_step_message_passing`` and ``dropout``.

     NOTE(review): only ``jet_in_dim``, ``hid2`` and ``self.dropout`` are used
     below; the other unpacked values (and ``edge_out_dim``, which is read
     from ``option.hid1``) are unused here.
     """
     super(Henrion_MPNN, self).__init__()
     jet_in_dim, jet_edge_in_dim, particle_in_dim, particle_edge_in_dim = \
         option.jet_features,option.jet_edge_features,option.particle_features, option.particle_edge_features
     hid1,hid2,  edge_out_dim, num_step_message_passing,self.dropout = \
          option.hid1, option.hid2, option.hid1, option.num_step_message_passing, option.dropout
     self.conv1 = Henrion_MPNNConv(jet_in_dim, hid2,dropout=self.dropout)
     self.set2set = Set2Set(hid2, processing_steps=3)
     # Set2Set output is 2 * hid2 wide; final layer has 4 outputs.
     self.mlp1 = nn.Linear(2*hid2, hid2)
     self.mlp2 = nn.Linear(hid2,4)
     self.logsoftmax=nn.LogSoftmax(dim=-1)
Ejemplo n.º 24
0
    def __init__(self, num_input_features: int, nodes_out: int, graph_out: int,
                 num_layers: int, num_towers: int, hidden_u: int, out_u: int,
                 hidden_gru: int, layer_type: str):
        """Build an SMP-family network with node-level and graph-level output heads.

        Args:
            num_input_features: number of node features.
            nodes_out: number of output features at each node's level
                (3 on the benchmark).
            graph_out: number of output features at the graph level
                (3 on the benchmark).
            num_layers: number of SMP layers, each paired with a BatchNorm.
            num_towers: inside each SMP layer, use towers to reduce the number
                of parameters.
            hidden_u: number of channels in the local contexts.
            out_u: number of channels after extraction of node features.
            hidden_gru: number of channels inside the gated recurrent unit.
            layer_type: 'SMP', 'FastSMP' or 'SimplifiedFastSMP'; any other
                value raises KeyError from the lookup below.
        """
        super().__init__()
        # One extra input channel on top of the raw node features.
        num_input_u = 1 + num_input_features

        self.edge_counter = EdgeCounter()
        self.initial_lin_u = nn.Linear(num_input_u, hidden_u)

        self.extractor = NodeExtractor(hidden_u, out_u)

        layer_type_dict = {
            'SMP': SMPLayer,
            'FastSMP': FastSMPLayer,
            'SimplifiedFastSMP': SimplifiedFastSMPLayer
        }
        conv_layer = layer_type_dict[layer_type]

        self.gru = nn.GRU(out_u, hidden_gru)
        self.convs = nn.ModuleList([])
        self.batch_norm_u = nn.ModuleList([])
        for i in range(0, num_layers):
            self.batch_norm_u.append(BatchNorm(hidden_u, use_x=False))
            conv = conv_layer(in_features=hidden_u,
                              out_features=hidden_u,
                              num_towers=num_towers,
                              use_x=False)
            self.convs.append(conv)

        # Process the extracted node features
        # max_n is passed as Set2Set's processing_steps; presumably the largest
        # graph size in the benchmark -- confirm.
        max_n = 19
        self.set2set = Set2Set(hidden_gru, max_n)

        self.final_node = nn.Sequential(nn.Linear(hidden_gru, hidden_gru),
                                        nn.LeakyReLU(),
                                        nn.Linear(hidden_gru, hidden_gru),
                                        nn.LeakyReLU(),
                                        nn.Linear(hidden_gru, nodes_out))

        # Input is 2 * hidden_gru: the Set2Set output width.
        self.final_graph = nn.Sequential(nn.Linear(2 * hidden_gru, hidden_gru),
                                         nn.ReLU(), nn.BatchNorm1d(hidden_gru),
                                         nn.Linear(hidden_gru, hidden_gru),
                                         nn.LeakyReLU(),
                                         nn.BatchNorm1d(hidden_gru),
                                         nn.Linear(hidden_gru, graph_out))
Ejemplo n.º 25
0
    def __init__(self):
        """Build an NNConv + GRU network with a Set2Set readout and a 1-output head.

        Relies on module-level ``dataset`` (for ``num_features``) and ``dim``
        (hidden width) defined outside this block.
        """
        super(Net, self).__init__()
        self.lin0 = torch.nn.Linear(dataset.num_features, dim)

        # Edge network: 5 edge features -> flattened dim x dim weight matrix.
        # (This local ``nn`` shadows any module-level ``nn`` in this method.)
        nn = Sequential(Linear(5, 128), ReLU(), Linear(128, dim * dim))
        self.conv = NNConv(dim, dim, nn, aggr='mean')
        self.gru = GRU(dim, dim)

        self.set2set = Set2Set(dim, processing_steps=3)
        self.lin1 = torch.nn.Linear(2 * dim, dim)
        self.lin2 = torch.nn.Linear(dim, 1)
Ejemplo n.º 26
0
    def __init__(self, out_features: int, n_bu: int, n_td: int, C: int,
                 M: int):
        """Build bottom-up / top-down HTMM encoders with a Set2Set readout.

        Args:
            out_features: output width of the final linear layer.
            n_bu: first argument to ``UniformBottomUpHTMM``; when 0 the
                bottom-up branch is disabled (``self.bu`` is None).
            n_td: first argument to ``TopDownHTMM``; when 0 the top-down
                branch is disabled (``self.td`` is None).
            C: forwarded unchanged to both HTMM constructors.
            M: forwarded unchanged to both HTMM constructors.
        """
        super(GraphHTMN, self).__init__()
        self.bu = UniformBottomUpHTMM(n_bu, C, M) if n_bu > 0 else None
        self.td = TopDownHTMM(n_td, C, M) if n_td > 0 else None

        # Fixed (requires_grad=False) matrix; its column count sets the pooled width.
        self.contrastive = nn.Parameter(_contrastive_matrix(n_bu + n_td),
                                        requires_grad=False)
        # Pass processing_steps / num_layers by keyword: recent PyG releases
        # declare Set2Set(in_channels, processing_steps, **kwargs), where a
        # third positional argument raises TypeError. Older releases accept
        # num_layers as a keyword too.
        self.pooling = Set2Set(self.contrastive.size(1), processing_steps=2,
                               num_layers=1)
        # Set2Set output is twice its input width.
        self.output = nn.Linear(2 * self.contrastive.size(1), out_features)
Ejemplo n.º 27
0
    def __init__(self):
        """Build three GraphConv layers, a Set2Set readout and a 3-layer classifier.

        Relies on module-level ``args`` (for ``nhid``) and ``dataset`` (for
        ``num_features`` / ``num_classes``) defined outside this block.
        """
        super(Net, self).__init__()

        nhid = args.nhid

        self.conv1 = GraphConv(dataset.num_features, nhid)
        self.conv2 = GraphConv(nhid, nhid)
        self.conv3 = GraphConv(nhid, nhid)
        # Pools nhid * 3 features per node -- presumably the three conv outputs
        # concatenated (confirm in forward) -- with 10 processing steps.
        self.pool = Set2Set(nhid * 3, 10)

        # nhid * 6 = Set2Set output width for nhid * 3 inputs.
        self.lin1 = torch.nn.Linear(nhid * 6, nhid)
        self.lin2 = torch.nn.Linear(nhid, int(nhid / 2))
        self.lin3 = torch.nn.Linear(int(nhid / 2), dataset.num_classes)
    def __init__(self, lrate):
        """Build an NNConv + GRU regressor together with its Adam optimiser and MSE loss.

        Relies on module-level ``dataset`` and ``dim`` defined outside this block.

        Args:
            lrate: learning rate for the Adam optimiser.
        """
        super(Net, self).__init__()
        self.lin0 = torch.nn.Linear(dataset.num_features, dim)

        # Edge network: 5 edge features -> flattened dim x dim weight matrix.
        nn = Sequential(Linear(5, 128), ReLU(), Linear(128, dim * dim))
        self.conv = NNConv(dim, dim, nn, aggr='mean')
        self.gru = GRU(dim, dim)

        self.set2set = Set2Set(dim, processing_steps=3)
        self.lin1 = torch.nn.Linear(2 * dim, dim)
        self.lin2 = torch.nn.Linear(dim, 1)
        # Created after all layers are registered so the optimiser sees every parameter.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lrate)
        self.criterion = torch.nn.MSELoss()
    def __init__(self,option):
        """Build an edge-aware GAT conv, a Set2Set readout and a 4-class log-softmax head.

        Args:
            option: config object; reads feature widths, ``hid1``, ``hid2``,
                ``heads``, ``num_step_message_passing`` and ``dropout``.

        NOTE(review): ``node_in_dim``, ``hid1`` and ``edge_out_dim`` are
        unpacked but not used below.
        """
        super(EGAT, self).__init__()
        jet_in_dim, node_in_dim, hid1,hid2, edge_in_dim, edge_out_dim, num_step_message_passing,self.dropout = \
            option.jet_features, option.particle_features, option.hid1, option.hid2, option.particle_edge_features,\
            option.hid1, option.num_step_message_passing, option.dropout
        heads = option.heads
        self.conv1 = EGatConv(jet_in_dim, hid2, edge_in_dim, heads, num_step_message_passing,dropout=self.dropout)

        self.set2set = Set2Set(hid2, processing_steps=3)

        # Set2Set output is 2 * hid2 wide; final layer has 4 outputs.
        self.mlp1 = nn.Linear(2*hid2, hid2)
        self.mlp2 = nn.Linear(hid2,4)
        self.logsoftmax=nn.LogSoftmax(dim=-1)
Ejemplo n.º 30
0
    def __init__(self):
        """Build an NNConv + GRU network with a Set2Set readout and a 1-output head.

        Relies on module-level ``dataset`` and ``dim`` defined outside this block.
        """
        super(Net, self).__init__()
        self.lin0 = torch.nn.Linear(dataset.num_features, dim)

        # This nn is the network used in message computation for each node;
        # the 5 corresponds to the number of edge features.
        # Note: in this script the transform step turns edge_attr into 5
        # features instead of the original 4.
        nn = Sequential(Linear(5, 128), ReLU(), Linear(128, dim * dim))
        self.conv = NNConv(dim, dim, nn, aggr='mean')
        self.gru = GRU(dim, dim)

        self.set2set = Set2Set(dim, processing_steps=3)
        self.lin1 = torch.nn.Linear(2 * dim, dim)
        self.lin2 = torch.nn.Linear(dim, 1)