Example no. 1
    def __init__(
        self,
        node_input_dim=15,
        edge_input_dim=5,
        output_dim=12,
        node_hidden_dim=64,
        edge_hidden_dim=128,
        num_step_message_passing=6,
        num_step_set2set=6,
        num_layer_set2set=3,
    ):
        super(MPNNModel, self).__init__()

        self.num_step_message_passing = num_step_message_passing
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        edge_network = nn.Sequential(
            nn.Linear(edge_input_dim, edge_hidden_dim),
            nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim),
        )
        self.conv = NNConv(
            in_feats=node_hidden_dim,
            out_feats=node_hidden_dim,
            edge_func=edge_network,
            aggregator_type="sum",
        )
        self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)

        self.set2set = Set2Set(node_hidden_dim, num_step_set2set,
                               num_layer_set2set)

        self.lin1 = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        self.lin2 = nn.Linear(node_hidden_dim, output_dim)
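The snippet above only shows the constructor. A plausible forward pass for these modules, following the common DGL MPNN wiring (lin0 projection, NNConv/GRU message-passing loop, Set2Set readout, then the lin1/lin2 head), is sketched below; `F` stands for `torch.nn.functional`, and the exact readout wiring is an assumption rather than part of the original example.

    def forward(self, g, n_feat, e_feat):
        # Sketch only: assumed forward pass matching the constructor above.
        out = F.relu(self.lin0(n_feat))              # (V, node_hidden_dim)
        h = out.unsqueeze(0)                         # GRU hidden state, (1, V, H)
        for _ in range(self.num_step_message_passing):
            m = F.relu(self.conv(g, out, e_feat))    # edge-conditioned messages
            out, h = self.gru(m.unsqueeze(0), h)     # gated node update
            out = out.squeeze(0)
        readout = self.set2set(g, out)               # (B, 2 * node_hidden_dim)
        readout = F.relu(self.lin1(readout))
        return self.lin2(readout)                    # (B, output_dim)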
Example no. 2
    def __init__(
        self,
        output_dim=32,
        node_input_dim=32,
        node_hidden_dim=32,
        edge_input_dim=32,
        edge_hidden_dim=32,
        num_step_message_passing=6,
        lstm_as_gate=False,
    ):
        super(UnsupervisedMPNN, self).__init__()

        self.num_step_message_passing = num_step_message_passing
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        edge_network = nn.Sequential(
            nn.Linear(edge_input_dim, edge_hidden_dim),
            nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim),
        )
        self.conv = NNConv(
            in_feats=node_hidden_dim,
            out_feats=node_hidden_dim,
            edge_func=edge_network,
            aggregator_type="sum",
        )
        self.lstm_as_gate = lstm_as_gate
        if lstm_as_gate:
            self.lstm = nn.LSTM(node_hidden_dim, node_hidden_dim)
        else:
            self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)
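As above, only the constructor is shown. Below is a sketch (an assumption, not part of the original snippet) of the message-passing loop implied by the `lstm_as_gate` flag: with `lstm_as_gate=True` the update also carries an LSTM cell state, otherwise the plain GRU update is used. `F` stands for `torch.nn.functional`.

    def forward(self, g, n_feat, e_feat):
        # Sketch only: assumed loop showing how the GRU/LSTM gate would be used.
        out = F.relu(self.lin0(n_feat))              # (V, node_hidden_dim)
        h = out.unsqueeze(0)                         # hidden state, (1, V, H)
        c = torch.zeros_like(h)                      # cell state for the LSTM path
        for _ in range(self.num_step_message_passing):
            m = F.relu(self.conv(g, out, e_feat))
            if self.lstm_as_gate:
                out, (h, c) = self.lstm(m.unsqueeze(0), (h, c))
            else:
                out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)
        return out                                   # (V, node_hidden_dim)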
Example no. 3
    def __init__(self,
                 node_hidden_dim,
                 edge_hidden_dim,
                 n_classes,
                 node_input_dim=23,
                 edge_input_dim=19,
                 num_step_message_passing=6,
                 num_step_set2set=6,
                 num_layer_set2set=3):
        super(Mol2NumNet_regressor, self).__init__()
        self.n_classes = n_classes
        self.num_step_message_passing = num_step_message_passing
        self.lin_0 = nn.Linear(node_input_dim, node_hidden_dim)
        edge_net = nn.Sequential(
            nn.Linear(edge_input_dim, edge_hidden_dim), nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim))
        self.conv_2 = NNConv(in_feats=node_hidden_dim,
                             out_feats=node_hidden_dim,
                             edge_func=edge_net,
                             aggregator_type="sum")
        self.gru = nn.GRU(
            node_hidden_dim,
            node_hidden_dim,
            num_layers=1,
        )
        self.set2set = Set2Set(input_dim=node_hidden_dim,
                               n_iters=num_step_set2set,
                               n_layers=num_layer_set2set)
        self.lin2 = nn.Linear(node_hidden_dim, node_hidden_dim)
        self.lin3 = nn.Linear(2 * node_hidden_dim, 2 * node_hidden_dim)
        self.predict = nn.Linear(2 * node_hidden_dim, n_classes)
        self.dropout = nn.Dropout(p=0.1)
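This example also stops at the constructor. One plausible (speculative) way the remaining layers could be wired, assuming `lin2` acts per node before the Set2Set readout and `lin3`/`predict` form the graph-level head:

    def forward(self, g, n_feat, e_feat):
        # Speculative sketch; the real layer ordering is not shown in the snippet.
        out = F.relu(self.lin_0(n_feat))             # (V, node_hidden_dim)
        h = out.unsqueeze(0)
        for _ in range(self.num_step_message_passing):
            m = F.relu(self.conv_2(g, out, e_feat))
            out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)
        out = F.relu(self.lin2(out))                 # per-node transform (assumed)
        readout = self.set2set(g, out)               # (B, 2 * node_hidden_dim)
        readout = self.dropout(F.relu(self.lin3(readout)))
        return self.predict(readout)                 # (B, n_classes)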
Example no. 4
    def __init__(self, node_in_feats, edge_in_feats, node_out_feats=64,
                 edge_hidden_feats=128, num_step_message_passing=6):
        super(MPNNGNN, self).__init__()

        self.project_node_feats = nn.Sequential(
            nn.Linear(node_in_feats, node_out_feats),
            nn.ReLU()
        )
        self.num_step_message_passing = num_step_message_passing
        edge_network = nn.Sequential(
            nn.Linear(edge_in_feats, edge_hidden_feats),
            nn.ReLU(),
            nn.Linear(edge_hidden_feats, node_out_feats * node_out_feats)
        )
        self.gnn_layer = NNConv(
            in_feats=node_out_feats,
            out_feats=node_out_feats,
            edge_func=edge_network,
            aggregator_type='sum'
        )
        self.gru = nn.GRU(node_out_feats, node_out_feats)
Example no. 5
    def __init__(self,
                 node_input_dim=42,
                 edge_input_dim=10,
                 node_hidden_dim=42,
                 edge_hidden_dim=42,
                 num_step_message_passing=6,
                 ):
        super(GatherModel, self).__init__()
        self.num_step_message_passing = num_step_message_passing
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        self.set2set = Set2Set(node_hidden_dim, 2, 1)
        self.message_layer = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        edge_network = nn.Sequential(
            nn.Linear(edge_input_dim, edge_hidden_dim), nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim))
        self.conv = NNConv(in_feats=node_hidden_dim,
                           out_feats=node_hidden_dim,
                           edge_func=edge_network,
                           aggregator_type='sum',
                           residual=True
                           )
Example no. 6
class MPNNGNN(nn.Module):
    """MPNN.

    MPNN is introduced in `Neural Message Passing for Quantum Chemistry
    <https://arxiv.org/abs/1704.01212>`__.

    This class performs message passing in MPNN and returns the updated node representations.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    node_out_feats : int
        Size for the output node representations. Default to 64.
    edge_in_feats : int
        Size for the input edge features.
    edge_hidden_feats : int
        Size for the hidden edge representations. Default to 128.
    num_step_message_passing : int
        Number of message passing steps. Default to 6.
    """
    def __init__(self, node_in_feats, edge_in_feats, node_out_feats=64,
                 edge_hidden_feats=128, num_step_message_passing=6):
        super(MPNNGNN, self).__init__()

        self.project_node_feats = nn.Sequential(
            nn.Linear(node_in_feats, node_out_feats),
            nn.ReLU()
        )
        self.num_step_message_passing = num_step_message_passing
        edge_network = nn.Sequential(
            nn.Linear(edge_in_feats, edge_hidden_feats),
            nn.ReLU(),
            nn.Linear(edge_hidden_feats, node_out_feats * node_out_feats)
        )
        self.gnn_layer = NNConv(
            in_feats=node_out_feats,
            out_feats=node_out_feats,
            edge_func=edge_network,
            aggregator_type='sum'
        )
        self.gru = nn.GRU(node_out_feats, node_out_feats)

    def reset_parameters(self):
        """Reinitialize model parameters."""
        self.project_node_feats[0].reset_parameters()
        self.gnn_layer.reset_parameters()
        for layer in self.gnn_layer.edge_nn:
            if isinstance(layer, nn.Linear):
                layer.reset_parameters()
        self.gru.reset_parameters()

    def forward(self, g, node_feats, edge_feats):
        """Performs message passing and updates node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges in the batch of graphs.

        Returns
        -------
        node_feats : float32 tensor of shape (V, node_out_feats)
            Output node representations.
        """
        node_feats = self.project_node_feats(node_feats) # (V, node_out_feats)
        hidden_feats = node_feats.unsqueeze(0)           # (1, V, node_out_feats)

        for _ in range(self.num_step_message_passing):
            node_feats = F.relu(self.gnn_layer(g, node_feats, edge_feats))
            node_feats, hidden_feats = self.gru(node_feats.unsqueeze(0), hidden_feats)
            node_feats = node_feats.squeeze(0)

        return node_feats
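A minimal usage sketch for this class on a toy graph; the feature sizes below are arbitrary placeholders, not values from the original repository.

import dgl
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 0]))       # toy graph: 3 nodes, 3 directed edges
node_feats = torch.randn(3, 30)              # (V, node_in_feats)
edge_feats = torch.randn(3, 11)              # (E, edge_in_feats)

model = MPNNGNN(node_in_feats=30, edge_in_feats=11)
out = model(g, node_feats, edge_feats)       # (3, 64), i.e. (V, node_out_feats)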
Example no. 7
class GatherModel(nn.Module):
    """
    MPNN from
    `Neural Message Passing for Quantum Chemistry <https://arxiv.org/abs/1704.01212>`__
    Parameters
    ----------
    node_input_dim : int
        Dimension of input node feature, default to be 42.
    edge_input_dim : int
        Dimension of input edge feature, default to be 10.
    node_hidden_dim : int
        Dimension of node feature in hidden layers, default to be 42.
    edge_hidden_dim : int
        Dimension of edge feature in hidden layers, default to be 42.
    num_step_message_passing : int
        Number of message passing steps, default to be 6.
    """

    def __init__(self,
                 node_input_dim=42,
                 edge_input_dim=10,
                 node_hidden_dim=42,
                 edge_hidden_dim=42,
                 num_step_message_passing=6,
                 ):
        super(GatherModel, self).__init__()
        self.num_step_message_passing = num_step_message_passing
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        self.set2set = Set2Set(node_hidden_dim, 2, 1)
        self.message_layer = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        edge_network = nn.Sequential(
            nn.Linear(edge_input_dim, edge_hidden_dim), nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim))
        self.conv = NNConv(in_feats=node_hidden_dim,
                           out_feats=node_hidden_dim,
                           edge_func=edge_network,
                           aggregator_type='sum',
                           residual=True
                           )

    def forward(self, g, n_feat, e_feat):
        """Returns the node embeddings after message passing phase.
        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph for molecule(s)
        n_feat : tensor of dtype float32 and shape (B1, D1)
            Node features. B1 for number of nodes and D1 for
            the node feature size.
        e_feat : tensor of dtype float32 and shape (B2, D2)
            Edge features. B2 for number of edges and D2 for
            the edge feature size.
        Returns
        -------
        res : float32 tensor of shape (B1, D1)
            Updated node representations, with the initial node features
            added back as a residual.
        """

        init = n_feat.clone()                # keep the raw input for the residual
        out = F.relu(self.lin0(n_feat))
        for i in range(self.num_step_message_passing):
            if e_feat is not None:
                # edge-conditioned message passing over the molecular graph
                m = torch.relu(self.conv(g, out, e_feat))
            else:
                # no edge features: use only the residual branch of NNConv
                m = torch.relu(self.conv.bias + self.conv.res_fc(out))
            out = self.message_layer(torch.cat([m, out], dim=1))
        return out + init
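A minimal usage sketch with the constructor defaults (toy sizes, for illustration only). Note that the final `out + init` residual requires node_input_dim to equal node_hidden_dim, as in the default 42/42 configuration.

import dgl
import torch

g = dgl.graph(([0, 1], [1, 0]))              # toy 2-node, 2-edge graph
n_feat = torch.randn(2, 42)                  # (V, node_input_dim)
e_feat = torch.randn(2, 10)                  # (E, edge_input_dim)

model = GatherModel()
res = model(g, n_feat, e_feat)               # (2, 42), i.e. (V, node_hidden_dim)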