Esempio n. 1
0
 def __init__(
     self,
     hidden_size=64,
     num_layer=2,
     readout="avg",
     layernorm: bool = False,
     set2set_lstm_layer: int = 3,
     set2set_iter: int = 6,
 ):
     """Build a stack of equal-width GCN layers plus a graph-level readout.

     Args:
         hidden_size: Input and output width of every ``GCNLayer``.
         num_layer: Number of stacked ``GCNLayer`` modules. ReLU is used as
             the activation on every layer except the last one.
         readout: ``"avg"`` (AvgPooling), ``"set2set"`` (Set2Set followed by
             a ``2 * hidden_size -> hidden_size`` linear projection stored as
             ``self.linear``) or ``"root"`` (identity callable; see HACK).
         layernorm: When true, a non-affine ``nn.LayerNorm`` is stored as
             ``self.ln``.
         set2set_lstm_layer: ``n_layers`` passed to Set2Set.
         set2set_iter: ``n_iters`` passed to Set2Set.

     Raises:
         NotImplementedError: If ``readout`` is not one of the options above.
     """
     super(UnsupervisedGCN, self).__init__()

     gcn_layers = []
     for idx in range(num_layer):
         is_last = idx == num_layer - 1
         gcn_layers.append(
             GCNLayer(
                 in_feats=hidden_size,
                 out_feats=hidden_size,
                 # The final layer is left without an activation.
                 activation=None if is_last else F.relu,
                 residual=False,
                 batchnorm=False,
                 dropout=0.0,
             )
         )
     self.layers = nn.ModuleList(gcn_layers)

     if readout == "avg":
         self.readout = AvgPooling()
     elif readout == "set2set":
         self.readout = Set2Set(
             hidden_size, n_iters=set2set_iter, n_layers=set2set_lstm_layer
         )
         # Set2Set output is 2 * hidden_size wide; project it back down.
         self.linear = nn.Linear(2 * hidden_size, hidden_size)
     elif readout == "root":
         # HACK: process outside the model part
         self.readout = lambda _, x: x
     else:
         raise NotImplementedError

     self.layernorm = layernorm
     if layernorm:
         self.ln = nn.LayerNorm(hidden_size, elementwise_affine=False)
Esempio n. 2
0
    def __init__(
        self,
        node_input_dim=15,
        edge_input_dim=5,
        output_dim=12,
        node_hidden_dim=64,
        edge_hidden_dim=128,
        num_step_message_passing=6,
        num_step_set2set=6,
        num_layer_set2set=3,
    ):
        """Create the submodules of the MPNN model.

        Builds, in order: an input projection ``lin0``, an edge-conditioned
        ``NNConv`` (``conv``), a ``GRU`` (``gru``), a ``Set2Set`` readout and
        a two-layer linear head (``lin1`` -> ``lin2``).

        Args:
            node_input_dim: Width of raw node features.
            edge_input_dim: Width of raw edge features.
            output_dim: Width of the final prediction.
            node_hidden_dim: Hidden node-state width used throughout.
            edge_hidden_dim: Hidden width of the edge network.
            num_step_message_passing: Stored as
                ``self.num_step_message_passing``; presumably the number of
                conv/GRU rounds run in ``forward`` (not visible here).
            num_step_set2set: Set2Set ``n_iters``.
            num_layer_set2set: Set2Set ``n_layers``.
        """
        super(MPNNModel, self).__init__()
        self.num_step_message_passing = num_step_message_passing

        hidden = node_hidden_dim
        self.lin0 = nn.Linear(node_input_dim, hidden)

        # The edge network outputs hidden * hidden values per edge, presumably
        # reshaped by NNConv into a per-edge (hidden x hidden) weight matrix.
        self.conv = NNConv(
            in_feats=hidden,
            out_feats=hidden,
            edge_func=nn.Sequential(
                nn.Linear(edge_input_dim, edge_hidden_dim),
                nn.ReLU(),
                nn.Linear(edge_hidden_dim, hidden * hidden),
            ),
            aggregator_type="sum",
        )
        self.gru = nn.GRU(hidden, hidden)

        self.set2set = Set2Set(hidden, num_step_set2set, num_layer_set2set)

        # Set2Set doubles the feature width, hence 2 * hidden on the way in.
        self.lin1 = nn.Linear(2 * hidden, hidden)
        self.lin2 = nn.Linear(hidden, output_dim)
Esempio n. 3
0
 def __init__(self,
              in_node_feats,
              in_edge_feats,
              node_hidden_dim,
              edge_hidden_dim,
              num_step_message_passing,
              num_step_set2set,
              num_layer_set2set,
              n_tasks,
              regressor_hidden_feats=128,
              dropout=0.):
     """Build one shared MPNN encoder plus a per-task encoder and readout.

     Args:
         in_node_feats: Width of input node features.
         in_edge_feats: Width of input edge features.
         node_hidden_dim: Hidden node width of every ``MPNNGNN``.
         edge_hidden_dim: Hidden edge width of every ``MPNNGNN``.
         num_step_message_passing: Message-passing steps of every ``MPNNGNN``.
         num_step_set2set: Set2Set ``n_iters`` for each task readout.
         num_layer_set2set: Set2Set ``n_layers`` for each task readout.
         n_tasks: Number of tasks; one encoder/readout pair is appended per task.
         regressor_hidden_feats: Forwarded to the base class.
         dropout: Forwarded to the base class.
     """
     # Each task readout runs Set2Set over 2 * node_hidden_dim inputs
     # (presumably shared + task-specific node features concatenated), giving
     # 4 * node_hidden_dim features for the base-class regressor.
     super(MPNNRegressorBypass, self).__init__(
         readout_feats=4 * node_hidden_dim,
         n_tasks=n_tasks,
         regressor_hidden_feats=regressor_hidden_feats,
         dropout=dropout)

     gnn_args = (in_node_feats, in_edge_feats, node_hidden_dim,
                 edge_hidden_dim, num_step_message_passing)
     self.shared_gnn = MPNNGNN(*gnn_args)
     # NOTE(review): task_gnns / readouts are not created here; presumably
     # the base class initialises them as module lists — confirm.
     for _ in range(n_tasks):
         self.task_gnns.append(MPNNGNN(*gnn_args))
         self.readouts.append(
             Set2Set(2 * node_hidden_dim, num_step_set2set, num_layer_set2set))
Esempio n. 4
0
    def __init__(self, num_tasks = 1, num_layers = 5, emb_dim = 300, gnn_type = 'gin',
                 virtual_node = True, residual = False, drop_ratio = 0, JK = "last",
                 graph_pooling = "sum"):
        '''
        Graph-level predictor: node encoder -> graph pooling -> linear head.

            num_tasks (int): number of labels to be predicted
            num_layers (int): number of GNN layers; must be at least 2
            emb_dim (int): node embedding width
            gnn_type (str): layer type forwarded to the node encoder
            virtual_node (bool): whether to add virtual node or not
            residual (bool): forwarded to the node encoder
            drop_ratio (float): dropout ratio forwarded to the node encoder
            JK (str): forwarded to the node encoder as ``JK``
            graph_pooling (str): one of "sum", "mean", "max", "attention",
                "set2set"

        Raises:
            ValueError: if ``num_layers < 2`` or ``graph_pooling`` is unknown.
        '''
        super(GNN, self).__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if virtual_node:
            self.gnn_node = GNN_node_Virtualnode(num_layers, emb_dim, JK = JK,
                                                 drop_ratio = drop_ratio,
                                                 residual = residual,
                                                 gnn_type = gnn_type)
        else:
            self.gnn_node = GNN_node(num_layers, emb_dim, JK = JK, drop_ratio = drop_ratio,
                                     residual = residual, gnn_type = gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = SumPooling()
        elif self.graph_pooling == "mean":
            self.pool = AvgPooling()
        elif self.graph_pooling == "max":
            # Must be an instance like every other branch: storing the bare
            # class would make ``self.pool(...)`` call the constructor
            # instead of pooling.
            self.pool = MaxPooling()
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttentionPooling(
                gate_nn = nn.Sequential(nn.Linear(emb_dim, 2*emb_dim),
                                        nn.BatchNorm1d(2*emb_dim),
                                        nn.ReLU(),
                                        nn.Linear(2*emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, n_iters = 2, n_layers = 2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # Set2Set doubles the pooled feature width; every other pooling
        # keeps emb_dim.
        if self.graph_pooling == "set2set":
            pooled_dim = 2 * self.emb_dim
        else:
            pooled_dim = self.emb_dim
        self.graph_pred_linear = nn.Linear(pooled_dim, self.num_tasks)
Esempio n. 5
0
    def __init__(self,
                 node_input_dim=42,
                 edge_input_dim=10,
                 node_hidden_dim=42,
                 edge_hidden_dim=42,
                 num_step_message_passing=6,
                 interaction='dot',
                 num_step_set2_set=2,
                 num_layer_set2set=1,
                 ):
        """Build the solute/solvent encoders, Set2Set readouts and MLP head.

        Args:
            node_input_dim: Width of raw node features.
            edge_input_dim: Width of raw edge features.
            node_hidden_dim: Hidden node width of both GatherModel encoders.
            edge_hidden_dim: Stored on the instance (see NOTE below).
            num_step_message_passing: Forwarded to both GatherModel encoders.
            interaction: Stored as ``self.interaction``; its use is not
                visible here.
            num_step_set2_set: Set2Set ``n_iters`` for both readouts.
            num_layer_set2set: Set2Set ``n_layers`` for both readouts.
        """
        super(CIGINModel, self).__init__()

        # Hyper-parameters kept on the instance.
        self.node_input_dim = node_input_dim
        self.node_hidden_dim = node_hidden_dim
        self.edge_input_dim = edge_input_dim
        self.edge_hidden_dim = edge_hidden_dim
        self.num_step_message_passing = num_step_message_passing
        self.interaction = interaction

        gather_args = (
            self.node_input_dim,
            self.edge_input_dim,
            self.node_hidden_dim,
            # NOTE(review): GatherModel's 4th positional parameter is
            # edge_hidden_dim, but edge_input_dim is passed here, so this
            # constructor's edge_hidden_dim argument never reaches the
            # encoders. Left unchanged because fixing it changes parameter
            # shapes (and breaks saved checkpoints) — confirm intent.
            self.edge_input_dim,
            self.num_step_message_passing,
        )
        self.solute_gather = GatherModel(*gather_args)
        self.solvent_gather = GatherModel(*gather_args)

        # MLP head: 8 * node_hidden_dim -> 256 -> 128 -> 1.
        self.fc1 = nn.Linear(8 * self.node_hidden_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)
        # NOTE(review): hard-coded input width 80; its origin is not visible
        # in this constructor — confirm.
        self.imap = nn.Linear(80, 1)

        self.num_step_set2set = num_step_set2_set
        self.num_layer_set2set = num_layer_set2set
        readout_in = 2 * node_hidden_dim
        self.set2set_solute = Set2Set(readout_in, self.num_step_set2set, self.num_layer_set2set)
        self.set2set_solvent = Set2Set(readout_in, self.num_step_set2set, self.num_layer_set2set)
Esempio n. 6
0
    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 global_feats,
                 node_out_feats=64,
                 edge_hidden_feats=128,
                 global_hidden_feats=32,
                 n_tasks=1,
                 num_step_message_passing=6,
                 num_step_set2set=6,
                 num_layer_set2set=3,
                 output_f=None):
        """Build an MPNN encoder, Set2Set readout, global-feature MLP and head.

        Args:
            node_in_feats: Width of input node features.
            edge_in_feats: Width of input edge features.
            global_feats: Width of the graph-level (global) feature vector.
            node_out_feats: Node embedding width produced by ``MPNNGNN``.
            edge_hidden_feats: Hidden edge width inside ``MPNNGNN``.
            global_hidden_feats: Width of the global-feature MLP.
            n_tasks: Number of outputs of the prediction head.
            num_step_message_passing: Forwarded to ``MPNNGNN``.
            num_step_set2set: Set2Set ``n_iters``.
            num_layer_set2set: Set2Set ``n_layers``.
            output_f: Stored as ``self.output_f``; its use is not visible here.
        """
        super(MPNNPredictor, self).__init__()

        self.gnn = MPNNGNN(node_in_feats=node_in_feats,
                           node_out_feats=node_out_feats,
                           edge_in_feats=edge_in_feats,
                           edge_hidden_feats=edge_hidden_feats,
                           num_step_message_passing=num_step_message_passing)
        self.readout = Set2Set(input_dim=node_out_feats,
                               n_iters=num_step_set2set,
                               n_layers=num_layer_set2set)

        global_layers = [
            nn.Linear(global_feats, global_hidden_feats), nn.ReLU(),
            nn.Linear(global_hidden_feats, global_hidden_feats), nn.ReLU(),
        ]
        self.global_subnet = nn.Sequential(*global_layers)

        # Head input = Set2Set output (2 * node_out_feats) plus the global
        # embedding — presumably concatenated in forward.
        head_in = 2 * node_out_feats + global_hidden_feats
        self.predict = nn.Sequential(
            nn.Linear(head_in, node_out_feats),
            nn.ReLU(),
            nn.Linear(node_out_feats, n_tasks),
        )

        self.output_f = output_f
Esempio n. 7
0
 def __init__(self,
              in_node_feats,
              in_edge_feats,
              node_hidden_dim,
              edge_hidden_dim,
              num_step_message_passing,
              num_step_set2set,
              num_layer_set2set,
              n_tasks,
              regressor_hidden_feats=128,
              dropout=0.):
     """Build an MPNN encoder with a Set2Set readout on top of the base regressor.

     Args:
         in_node_feats: Width of input node features.
         in_edge_feats: Width of input edge features.
         node_hidden_dim: Hidden node width of ``MPNNGNN``.
         edge_hidden_dim: Hidden edge width of ``MPNNGNN``.
         num_step_message_passing: Forwarded to ``MPNNGNN``.
         num_step_set2set: Set2Set ``n_iters``.
         num_layer_set2set: Set2Set ``n_layers``.
         n_tasks: Forwarded to the base class.
         regressor_hidden_feats: Forwarded to the base class.
         dropout: Forwarded to the base class.
     """
     # Set2Set over node_hidden_dim features yields twice that width, which
     # is what the base-class regressor is told to expect.
     readout_width = 2 * node_hidden_dim
     super(MPNNRegressor, self).__init__(
         readout_feats=readout_width,
         n_tasks=n_tasks,
         regressor_hidden_feats=regressor_hidden_feats,
         dropout=dropout,
     )
     self.gnn = MPNNGNN(in_node_feats, in_edge_feats, node_hidden_dim,
                        edge_hidden_dim, num_step_message_passing)
     self.readout = Set2Set(node_hidden_dim, num_step_set2set,
                            num_layer_set2set)
Esempio n. 8
0
 def __init__(self,
              node_input_dim=42,
              edge_input_dim=10,
              node_hidden_dim=42,
              edge_hidden_dim=42,
              num_step_message_passing=6,
              ):
     """Create the input projection, Set2Set, message layer and NNConv.

     Args:
         node_input_dim: Width of raw node features.
         edge_input_dim: Width of raw edge features.
         node_hidden_dim: Hidden node width used by every layer.
         edge_hidden_dim: Hidden width of the edge network fed to NNConv.
         num_step_message_passing: Stored as
             ``self.num_step_message_passing``; its use is not visible here.
     """
     super(GatherModel, self).__init__()
     self.num_step_message_passing = num_step_message_passing

     hid = node_hidden_dim
     self.lin0 = nn.Linear(node_input_dim, hid)
     self.set2set = Set2Set(hid, n_iters=2, n_layers=1)
     self.message_layer = nn.Linear(2 * hid, hid)

     # Produces hid * hid values per edge for the edge-conditioned NNConv.
     edge_func = nn.Sequential(
         nn.Linear(edge_input_dim, edge_hidden_dim),
         nn.ReLU(),
         nn.Linear(edge_hidden_dim, hid * hid),
     )
     self.conv = NNConv(
         in_feats=hid,
         out_feats=hid,
         edge_func=edge_func,
         aggregator_type='sum',
         residual=True,
     )
Esempio n. 9
0
 def __init__(self, in_feats, out_feats, n_hidden, n_iter_readout):
     """Build the GGNN submodules.

     Args:
         in_feats: Input node feature width.
         out_feats: Output width of the graph convolutions.
         n_hidden: Hidden width of the auxiliary linear and LSTM layers.
         n_iter_readout: Number of Set2Set iterations.

     No submodule is pinned to a device here; move the whole model with
     ``model.to(device)`` so every layer ends up on the same device.
     """
     super(GGNN, self).__init__()
     self.in_feats = in_feats
     self.out_feats = out_feats
     self.hidden = n_hidden
     # NOTE(review): 6 edge types is hard-coded (edge_net below also takes
     # 6 inputs, presumably the same count) — confirm against the dataset.
     self.conv1 = GatedGraphConv(in_feats=in_feats,
                                 out_feats=out_feats,
                                 n_etypes=6,
                                 n_steps=5)
     self.conv2 = GraphConv(in_feats=in_feats, out_feats=out_feats)
     self.edge_net = nn.Linear(6, n_hidden)
     self.feat_net = nn.Linear(in_feats, n_hidden)
     self.lin_1 = nn.Linear(out_feats, n_hidden)
     self.lstm = nn.LSTM(input_size=n_hidden,
                         hidden_size=n_hidden,
                         num_layers=3,
                         batch_first=True)
     self.lin_2 = nn.Linear(in_features=n_hidden, out_features=3)
     # Set2Set output is 2 * out_feats wide.
     self.predict = nn.Linear(2 * out_feats, 3)
     self.dropout = nn.Dropout(p=0.1)
     self.set2set = Set2Set(input_dim=out_feats,
                            n_iters=n_iter_readout,
                            n_layers=3)
Esempio n. 10
0
    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_out_feats=64,
                 edge_hidden_feats=128,
                 num_step_message_passing=6,
                 num_step_set2set=6,
                 num_layer_set2set=3):
        """Build an MPNNGNN encoder, a Set2Set readout and a projection block.

        Args:
            node_in_feats: Width of input node features.
            edge_in_feats: Width of input edge features.
            node_out_feats: Node embedding width; also the width produced by
                ``self.process``.
            edge_hidden_feats: Hidden edge width inside ``MPNNGNN``.
            num_step_message_passing: Forwarded to ``MPNNGNN``.
            num_step_set2set: Set2Set ``n_iters``.
            num_layer_set2set: Set2Set ``n_layers``.
        """
        super(MPNN_encoder, self).__init__()

        gnn_kwargs = dict(
            node_in_feats=node_in_feats,
            node_out_feats=node_out_feats,
            edge_in_feats=edge_in_feats,
            edge_hidden_feats=edge_hidden_feats,
            num_step_message_passing=num_step_message_passing,
        )
        self.gnn = MPNNGNN(**gnn_kwargs)
        self.readout = Set2Set(input_dim=node_out_feats,
                               n_iters=num_step_set2set,
                               n_layers=num_layer_set2set)

        # Set2Set output (2 * node_out_feats) is projected back to node_out_feats.
        readout_width = 2 * node_out_feats
        self.process = nn.Sequential(
            nn.Linear(readout_width, node_out_feats),
            nn.ReLU(),
            nn.Dropout(p=0.2),
        )
Esempio n. 11
0
    def __init__(
        self,
        positional_embedding_size=32,
        max_node_freq=8,
        max_edge_freq=8,
        max_degree=128,
        freq_embedding_size=32,
        degree_embedding_size=32,
        output_dim=32,
        node_hidden_dim=32,
        edge_hidden_dim=32,
        num_layers=6,
        num_heads=4,
        num_step_set2set=6,
        num_layer_set2set=3,
        norm=False,
        gnn_model="mpnn",
        degree_input=False,
        lstm_as_gate=False,
    ):
        """Build a graph encoder: a chosen GNN backbone plus a Set2Set readout.

        Args:
            positional_embedding_size: Contributes to the node input width.
            max_node_freq: Stored as ``self.max_node_freq``.
            max_edge_freq: Stored as ``self.max_edge_freq``.
            max_degree: Upper bound for the degree embedding table
                (``max_degree + 1`` entries) when ``degree_input`` is true.
            freq_embedding_size: Contributes to the edge input width.
            degree_embedding_size: Degree embedding width; added to the node
                input width when ``degree_input`` is true.
            output_dim: Output width of the backbone (mpnn/gin) and of
                ``lin_readout``.
            node_hidden_dim: Hidden node width of the backbone and readout.
            edge_hidden_dim: Hidden edge width (mpnn backbone only).
            num_layers: Number of backbone layers / message-passing steps.
            num_heads: Attention heads (gat backbone only).
            num_step_set2set: Set2Set ``n_iters``.
            num_layer_set2set: Set2Set ``n_layers``.
            norm: Stored as ``self.norm``; its use is not visible here.
            gnn_model: Backbone type: ``"mpnn"``, ``"gat"`` or ``"gin"``.
            degree_input: Whether degree embeddings are part of node input.
            lstm_as_gate: Forwarded to the mpnn backbone.

        Raises:
            NotImplementedError: If ``gnn_model`` is not a supported backbone.
        """
        super(GraphEncoder, self).__init__()

        if degree_input:
            node_input_dim = positional_embedding_size + degree_embedding_size + 1
        else:
            node_input_dim = positional_embedding_size + 1
        edge_input_dim = freq_embedding_size + 1
        if gnn_model == "mpnn":
            self.gnn = UnsupervisedMPNN(
                output_dim=output_dim,
                node_input_dim=node_input_dim,
                node_hidden_dim=node_hidden_dim,
                edge_input_dim=edge_input_dim,
                edge_hidden_dim=edge_hidden_dim,
                num_step_message_passing=num_layers,
                lstm_as_gate=lstm_as_gate,
            )
        elif gnn_model == "gat":
            self.gnn = UnsupervisedGAT(
                node_input_dim=node_input_dim,
                node_hidden_dim=node_hidden_dim,
                edge_input_dim=edge_input_dim,
                num_layers=num_layers,
                num_heads=num_heads,
            )
        elif gnn_model == "gin":
            self.gnn = UnsupervisedGIN(
                num_layers=num_layers,
                num_mlp_layers=2,
                input_dim=node_input_dim,
                hidden_dim=node_hidden_dim,
                output_dim=output_dim,
                final_dropout=0.5,
                learn_eps=False,
                graph_pooling_type="sum",
                neighbor_pooling_type="sum",
                use_selayer=False,
            )
        else:
            # Fail fast: otherwise self.gnn is never set and the error only
            # surfaces later as an AttributeError.
            raise NotImplementedError(
                "Unsupported gnn_model: {!r}".format(gnn_model))
        self.gnn_model = gnn_model

        self.max_node_freq = max_node_freq
        self.max_edge_freq = max_edge_freq
        self.max_degree = max_degree
        self.degree_input = degree_input

        if degree_input:
            self.degree_embedding = nn.Embedding(num_embeddings=max_degree + 1, embedding_dim=degree_embedding_size)

        self.set2set = Set2Set(node_hidden_dim, num_step_set2set, num_layer_set2set)
        # Set2Set output is 2 * node_hidden_dim wide.
        self.lin_readout = nn.Sequential(
            nn.Linear(2 * node_hidden_dim, node_hidden_dim),
            nn.ReLU(),
            nn.Linear(node_hidden_dim, output_dim),
        )
        self.norm = norm