def __init__(
    self,
    node_input_dim=15,
    edge_input_dim=5,
    output_dim=12,
    node_hidden_dim=64,
    edge_hidden_dim=128,
    num_step_message_passing=6,
    num_step_set2set=6,
    num_layer_set2set=3,
):
    """Create the sub-layers of the MPNN graph-level model.

    Sub-modules are built in this order: input projection ``lin0``, an
    edge-conditioned ``NNConv`` fed by a two-layer edge MLP, a ``GRU`` node
    updater, a ``Set2Set`` readout and the ``lin1`` -> ``lin2`` output head.

    Args:
        node_input_dim: Size of the raw node features (input of ``lin0``).
        edge_input_dim: Size of the raw edge features (input of the edge MLP).
        output_dim: Output width of ``lin2``.
        node_hidden_dim: Width of node representations inside the model.
        edge_hidden_dim: Hidden width of the edge MLP.
        num_step_message_passing: Stored on ``self``; presumably read by
            ``forward`` (not shown here).
        num_step_set2set: Second positional argument passed to ``Set2Set``.
        num_layer_set2set: Third positional argument passed to ``Set2Set``.
    """
    super(MPNNModel, self).__init__()
    self.num_step_message_passing = num_step_message_passing

    # Sub-modules are created in a fixed order, so parameter initialisation
    # draws from the RNG in a fixed sequence.
    self.lin0 = nn.Linear(in_features=node_input_dim, out_features=node_hidden_dim)

    # The edge MLP emits node_hidden_dim * node_hidden_dim values per edge;
    # presumably NNConv reshapes them into a per-edge weight matrix.
    flat_weight_size = node_hidden_dim * node_hidden_dim
    edge_mlp = nn.Sequential(
        nn.Linear(in_features=edge_input_dim, out_features=edge_hidden_dim),
        nn.ReLU(),
        nn.Linear(in_features=edge_hidden_dim, out_features=flat_weight_size),
    )
    self.conv = NNConv(
        in_feats=node_hidden_dim,
        out_feats=node_hidden_dim,
        edge_func=edge_mlp,
        aggregator_type="sum",
    )

    self.gru = nn.GRU(input_size=node_hidden_dim, hidden_size=node_hidden_dim)

    # Graph-level readout, followed by the two-layer output head whose input
    # width is 2 * node_hidden_dim.
    self.set2set = Set2Set(node_hidden_dim, num_step_set2set, num_layer_set2set)
    readout_width = 2 * node_hidden_dim
    self.lin1 = nn.Linear(in_features=readout_width, out_features=node_hidden_dim)
    self.lin2 = nn.Linear(in_features=node_hidden_dim, out_features=output_dim)
def __init__(
    self,
    output_dim=32,
    node_input_dim=32,
    node_hidden_dim=32,
    edge_input_dim=32,
    edge_hidden_dim=32,
    num_step_message_passing=6,
    lstm_as_gate=False,
):
    """Create the sub-layers of the unsupervised MPNN encoder.

    Args:
        output_dim: Accepted for interface compatibility; this constructor
            does not use it.
        node_input_dim: Size of the raw node features (input of ``lin0``).
        node_hidden_dim: Width of node representations inside the model.
        edge_input_dim: Size of the raw edge features.
        edge_hidden_dim: Hidden width of the edge MLP.
        num_step_message_passing: Stored on ``self``; presumably read by
            ``forward`` (not shown here).
        lstm_as_gate: If true, build ``self.lstm`` (an ``nn.LSTM``);
            otherwise build ``self.gru`` (an ``nn.GRU``).
    """
    super(UnsupervisedMPNN, self).__init__()
    self.num_step_message_passing = num_step_message_passing
    self.lin0 = nn.Linear(in_features=node_input_dim, out_features=node_hidden_dim)

    hidden = node_hidden_dim
    # Layers of the edge MLP; the final layer outputs hidden * hidden values
    # per edge (presumably reshaped by NNConv into a weight matrix).
    edge_mlp_layers = [
        nn.Linear(in_features=edge_input_dim, out_features=edge_hidden_dim),
        nn.ReLU(),
        nn.Linear(in_features=edge_hidden_dim, out_features=hidden * hidden),
    ]
    self.conv = NNConv(
        in_feats=hidden,
        out_feats=hidden,
        edge_func=nn.Sequential(*edge_mlp_layers),
        aggregator_type="sum",
    )

    # Exactly one recurrent cell is created; which attribute exists depends on the flag.
    self.lstm_as_gate = lstm_as_gate
    if not lstm_as_gate:
        self.gru = nn.GRU(input_size=hidden, hidden_size=hidden)
    else:
        self.lstm = nn.LSTM(input_size=hidden, hidden_size=hidden)
def __init__(self, node_hidden_dim, edge_hidden_dim, n_classes,
             node_input_dim=23, edge_input_dim=19,
             num_step_message_passing=6, num_step_set2set=6,
             num_layer_set2set=3):
    """Create the sub-layers of the Mol2Num regression network.

    Args:
        node_hidden_dim: Width of node representations inside the model.
        edge_hidden_dim: Hidden width of the edge MLP.
        n_classes: Output width of the ``predict`` layer; also stored on ``self``.
        node_input_dim: Size of the raw node features (input of ``lin_0``).
        edge_input_dim: Size of the raw edge features.
        num_step_message_passing: Stored on ``self``; presumably read by
            ``forward`` (not shown here).
        num_step_set2set: Passed to ``Set2Set`` as ``n_iters``.
        num_layer_set2set: Passed to ``Set2Set`` as ``n_layers``.
    """
    super(Mol2NumNet_regressor, self).__init__()
    self.n_classes = n_classes
    self.num_step_message_passing = num_step_message_passing

    h = node_hidden_dim
    self.lin_0 = nn.Linear(in_features=node_input_dim, out_features=h)

    # The edge MLP produces h * h values per edge for the NNConv below.
    edge_mlp = nn.Sequential(
        nn.Linear(in_features=edge_input_dim, out_features=edge_hidden_dim),
        nn.ReLU(),
        nn.Linear(in_features=edge_hidden_dim, out_features=h * h),
    )
    self.conv_2 = NNConv(in_feats=h, out_feats=h,
                         edge_func=edge_mlp, aggregator_type="sum")
    self.gru = nn.GRU(input_size=h, hidden_size=h, num_layers=1)
    self.set2set = Set2Set(input_dim=h, n_iters=num_step_set2set,
                           n_layers=num_layer_set2set)

    # Head layers; lin3 and predict work on the 2*h-wide vectors.
    doubled = 2 * h
    self.lin2 = nn.Linear(in_features=h, out_features=h)
    self.lin3 = nn.Linear(in_features=doubled, out_features=doubled)
    self.predict = nn.Linear(in_features=doubled, out_features=n_classes)
    self.dropout = nn.Dropout(p=0.1)
def __init__(self, node_in_feats, edge_in_feats, node_out_feats=64,
             edge_hidden_feats=128, num_step_message_passing=6):
    """Create the sub-layers of the MPNN message-passing encoder.

    Args:
        node_in_feats: Size of the input node features.
        edge_in_feats: Size of the input edge features.
        node_out_feats: Width of the node representations produced.
        edge_hidden_feats: Hidden width of the edge MLP.
        num_step_message_passing: Stored on ``self``; presumably read by
            ``forward`` (not shown here).
    """
    super(MPNNGNN, self).__init__()
    width = node_out_feats

    # Input projection: Linear followed by ReLU.
    input_projection = [nn.Linear(in_features=node_in_feats, out_features=width), nn.ReLU()]
    self.project_node_feats = nn.Sequential(*input_projection)
    self.num_step_message_passing = num_step_message_passing

    # Edge MLP emitting width * width values per edge for NNConv.
    edge_layers = [
        nn.Linear(in_features=edge_in_feats, out_features=edge_hidden_feats),
        nn.ReLU(),
        nn.Linear(in_features=edge_hidden_feats, out_features=width * width),
    ]
    self.gnn_layer = NNConv(
        in_feats=width,
        out_feats=width,
        edge_func=nn.Sequential(*edge_layers),
        aggregator_type='sum',
    )
    self.gru = nn.GRU(input_size=width, hidden_size=width)
def __init__(self,
             node_input_dim=42,
             edge_input_dim=10,
             node_hidden_dim=42,
             edge_hidden_dim=42,
             num_step_message_passing=6,
             ):
    """Create the sub-layers of the gather (MPNN) encoder.

    Args:
        node_input_dim: Size of the input node features.
        edge_input_dim: Size of the input edge features.
        node_hidden_dim: Width of node representations inside the model.
        edge_hidden_dim: Hidden width of the edge MLP.
        num_step_message_passing: Stored on ``self``; presumably read by
            ``forward`` (not shown here).
    """
    super(GatherModel, self).__init__()
    self.num_step_message_passing = num_step_message_passing

    hidden = node_hidden_dim
    self.lin0 = nn.Linear(in_features=node_input_dim, out_features=hidden)

    set2set_iters, set2set_layers = 2, 1
    self.set2set = Set2Set(hidden, set2set_iters, set2set_layers)

    # Mixes the concatenation [message, current state] back down to `hidden`.
    self.message_layer = nn.Linear(in_features=2 * hidden, out_features=hidden)

    edge_mlp = nn.Sequential(
        nn.Linear(in_features=edge_input_dim, out_features=edge_hidden_dim),
        nn.ReLU(),
        nn.Linear(in_features=edge_hidden_dim, out_features=hidden * hidden),
    )
    self.conv = NNConv(
        in_feats=hidden,
        out_feats=hidden,
        edge_func=edge_mlp,
        aggregator_type='sum',
        residual=True,
    )
class MPNNGNN(nn.Module):
    """MPNN.

    MPNN is introduced in `Neural Message Passing for Quantum Chemistry
    <https://arxiv.org/abs/1704.01212>`__.

    This class performs message passing in MPNN and returns the updated node
    representations.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_out_feats : int
        Size for the output node representations. Default to 64.
    edge_hidden_feats : int
        Size for the hidden edge representations. Default to 128.
    num_step_message_passing : int
        Number of message passing steps. Default to 6.
    """
    def __init__(self, node_in_feats, edge_in_feats, node_out_feats=64,
                 edge_hidden_feats=128, num_step_message_passing=6):
        super(MPNNGNN, self).__init__()

        self.project_node_feats = nn.Sequential(
            nn.Linear(node_in_feats, node_out_feats),
            nn.ReLU()
        )
        self.num_step_message_passing = num_step_message_passing
        # Maps each edge feature to node_out_feats * node_out_feats values;
        # presumably NNConv reshapes them into a per-edge weight matrix.
        edge_network = nn.Sequential(
            nn.Linear(edge_in_feats, edge_hidden_feats),
            nn.ReLU(),
            nn.Linear(edge_hidden_feats, node_out_feats * node_out_feats)
        )
        self.gnn_layer = NNConv(
            in_feats=node_out_feats,
            out_feats=node_out_feats,
            edge_func=edge_network,
            aggregator_type='sum'
        )
        self.gru = nn.GRU(node_out_feats, node_out_feats)

    def reset_parameters(self):
        """Reinitialize model parameters."""
        self.project_node_feats[0].reset_parameters()
        self.gnn_layer.reset_parameters()
        # Reset the Linear layers of the edge network passed as ``edge_func``.
        # Walking the submodules avoids depending on the private attribute
        # name NNConv uses to store ``edge_func``.
        # NOTE(review): assumes NNConv, built here without ``residual``, owns
        # no other nn.Linear submodules -- verify against the installed DGL.
        for layer in self.gnn_layer.modules():
            if isinstance(layer, nn.Linear):
                layer.reset_parameters()
        self.gru.reset_parameters()

    def forward(self, g, node_feats, edge_feats):
        """Performs message passing and updates node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges in the batch of graphs.

        Returns
        -------
        node_feats : float32 tensor of shape (V, node_out_feats)
            Output node representations.
        """
        node_feats = self.project_node_feats(node_feats)  # (V, node_out_feats)
        # The GRU runs one time step per message-passing round; the node
        # dimension is used as its batch dimension, hence the unsqueeze(0).
        hidden_feats = node_feats.unsqueeze(0)            # (1, V, node_out_feats)

        for _ in range(self.num_step_message_passing):
            node_feats = F.relu(self.gnn_layer(g, node_feats, edge_feats))
            node_feats, hidden_feats = self.gru(node_feats.unsqueeze(0), hidden_feats)
            node_feats = node_feats.squeeze(0)

        return node_feats
class GatherModel(nn.Module):
    """
    MPNN from
    `Neural Message Passing for Quantum Chemistry <https://arxiv.org/abs/1704.01212>`

    Parameters
    ----------
    node_input_dim : int
        Dimension of input node feature, default to be 42.
    edge_input_dim : int
        Dimension of input edge feature, default to be 10.
    node_hidden_dim : int
        Dimension of node feature in hidden layers, default to be 42.
        ``forward`` adds the input node features back onto its output, so
        this should match ``node_input_dim``.
    edge_hidden_dim : int
        Dimension of edge feature in hidden layers, default to be 42.
    num_step_message_passing : int
        Number of message passing steps, default to be 6.
    """

    def __init__(self,
                 node_input_dim=42,
                 edge_input_dim=10,
                 node_hidden_dim=42,
                 edge_hidden_dim=42,
                 num_step_message_passing=6,
                 ):
        super(GatherModel, self).__init__()
        self.num_step_message_passing = num_step_message_passing
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        # NOTE(review): ``set2set`` is not used by ``forward`` below; it is
        # kept so existing state dicts / parameter layouts remain unchanged.
        self.set2set = Set2Set(node_hidden_dim, 2, 1)
        # Combines [message, current state] back to node_hidden_dim.
        self.message_layer = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        edge_network = nn.Sequential(
            nn.Linear(edge_input_dim, edge_hidden_dim), nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim))
        self.conv = NNConv(in_feats=node_hidden_dim,
                           out_feats=node_hidden_dim,
                           edge_func=edge_network,
                           aggregator_type='sum',
                           residual=True
                           )

    def forward(self, g, n_feat, e_feat):
        """Returns the node embeddings after message passing phase.

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph for molecule(s)
        n_feat : tensor of dtype float32 and shape (B1, D1)
            Node features. B1 for number of nodes and D1 for the node feature size.
        e_feat : tensor of dtype float32 and shape (B2, D2), or None
            Edge features. B2 for number of edges and D2 for the edge feature size.
            If None, the edge-conditioned convolution is skipped and every step
            uses only the convolution's ``bias`` plus its ``res_fc`` branch.

        Returns
        -------
        res : tensor of shape (B1, node_hidden_dim)
            Node features after message passing, with the input ``n_feat``
            added back as a residual (requires D1 to match node_hidden_dim).
        """
        # Clone so the residual uses the original input values.
        init = n_feat.clone()
        out = F.relu(self.lin0(n_feat))
        for _ in range(self.num_step_message_passing):
            if e_feat is not None:
                m = torch.relu(self.conv(g, out, e_feat))
            else:
                # No edge features: use only the convolution's bias and
                # residual projection of the current node state.
                m = torch.relu(self.conv.bias + self.conv.res_fc(out))
            out = self.message_layer(torch.cat([m, out], dim=1))
        return out + init