def forward(self, mol_graph):
    fatoms, fbonds, agraph, bgraph, scope = mol_graph
    fatoms = create_var(fatoms)
    fbonds = create_var(fbonds)
    agraph = create_var(agraph)
    bgraph = create_var(bgraph)

    binput = self.W_i(fbonds)
    message = nn.ReLU()(binput)

    for _ in range(self.depth - 1):
        nei_message = index_select_ND(message, 0, bgraph)
        nei_message = nei_message.sum(dim=1)
        nei_message = self.W_h(nei_message)
        message = nn.ReLU()(binput + nei_message)

    nei_message = index_select_ND(message, 0, agraph)
    nei_message = nei_message.sum(dim=1)
    ainput = torch.cat([fatoms, nei_message], dim=1)
    atom_hiddens = nn.ReLU()(self.W_o(ainput))

    mol_vecs = []
    for st, le in scope:
        # mean-pool the atom hidden vectors belonging to each molecule
        mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
        mol_vecs.append(mol_vec)

    mol_vecs = torch.stack(mol_vecs, dim=0)
    return mol_vecs
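# Several of the forward passes in this file rely on two small helpers,
# create_var and index_select_ND. A minimal sketch of both is given below,
# assuming the usual JT-VAE conventions (wrap tensors for autograd/GPU;
# gather rows of a 2-D source by an index tensor of arbitrary shape). The
# repository's actual implementations may differ in detail.
import torch

def create_var(tensor, requires_grad=False):
    # Sketch: move to GPU when available and optionally track gradients.
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return tensor.requires_grad_(requires_grad)

def index_select_ND(source, dim, index):
    # Gather rows of `source` along `dim` with an index tensor of arbitrary
    # shape; the result has shape index.shape + source.shape[1:].
    index_size = index.size()
    suffix_dim = source.size()[1:]
    final_size = index_size + suffix_dim
    target = source.index_select(dim, index.view(-1))
    return target.view(final_size)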
def forward(self, fnode, fmess, node_graph, mess_graph, scope):
    fnode = create_var(fnode)
    fmess = create_var(fmess)
    node_graph = create_var(node_graph)
    mess_graph = create_var(mess_graph)
    messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

    fnode = self.embedding(fnode)
    fmess = index_select_ND(fnode, 0, fmess)
    messages = self.GRU(messages, fmess, mess_graph)

    mess_nei = index_select_ND(messages, 0, node_graph)
    node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
    node_vecs = self.outputNN(node_vecs)

    max_len = max([x for _, x in scope])
    batch_vecs = []
    for st, le in scope:
        cur_vecs = node_vecs[st:st + le]
        cur_vecs = F.pad(cur_vecs, (0, 0, 0, max_len - le))
        batch_vecs.append(cur_vecs)

    tree_vecs = torch.stack(batch_vecs, dim=0)
    return tree_vecs, messages
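# The F.pad call above zero-pads only the node dimension: the first pair
# (0, 0) applies to the last (feature) dimension, the second pair to the
# first (row) dimension. A small worked example of that padding:
import torch
import torch.nn.functional as F

cur_vecs = torch.ones(3, 4)             # 3 nodes, hidden size 4
padded = F.pad(cur_vecs, (0, 0, 0, 2))  # append 2 all-zero rows
assert padded.shape == (5, 4)
assert padded[3:].abs().sum().item() == 0.0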
def forward(self, fnode, fmess, node_graph, mess_graph, scope):
    fnode = create_var(fnode)
    fmess = create_var(fmess)
    node_graph = create_var(node_graph)
    mess_graph = create_var(mess_graph)
    messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

    fnode = self.embedding(fnode)
    fmess = index_select_ND(fnode, 0, fmess)
    messages = self.GRU(messages, fmess, mess_graph)

    mess_nei = index_select_ND(messages, 0, node_graph)
    node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
    node_vecs = self.outputNN(node_vecs)

    batch_vecs = []
    for st, le in scope:
        cur_vecs = node_vecs[st]  # root is the first node of each tree
        batch_vecs.append(cur_vecs)

    tree_vecs = torch.stack(batch_vecs, dim=0)
    return tree_vecs, messages
def forward(self, fatoms, fbonds, agraph, bgraph, scope, tree_message):
    # tree_message[0] is assumed to be the zero vector
    fatoms = create_var(fatoms)
    fbonds = create_var(fbonds)
    agraph = create_var(agraph)
    bgraph = create_var(bgraph)

    binput = self.W_i(fbonds)
    graph_message = F.relu(binput)

    for i in range(self.depth - 1):
        message = torch.cat([tree_message, graph_message], dim=0)
        nei_message = index_select_ND(message, 0, bgraph)
        nei_message = nei_message.sum(dim=1)  # assumes tree_message[0] == vec(0)
        nei_message = self.W_h(nei_message)
        graph_message = F.relu(binput + nei_message)

    message = torch.cat([tree_message, graph_message], dim=0)
    nei_message = index_select_ND(message, 0, agraph)
    nei_message = nei_message.sum(dim=1)
    ainput = torch.cat([fatoms, nei_message], dim=1)
    atom_hiddens = F.relu(self.W_o(ainput))

    mol_vecs = []
    for st, le in scope:
        mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
        mol_vecs.append(mol_vec)

    mol_vecs = torch.stack(mol_vecs, dim=0)
    return mol_vecs
def forward(self, fnode, fmess, node_graph, mess_graph, scope):
    fnode = create_var(fnode)
    fmess = create_var(fmess)
    node_graph = create_var(node_graph)
    mess_graph = create_var(mess_graph)
    messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

    fnode = self.embedding(fnode)
    fmess1 = index_select_ND(fnode, 0, fmess[:, 0])
    fmess2 = self.E_pos(fmess[:, 1])
    fmess = self.inputNN(torch.cat([fmess1, fmess2], dim=-1))
    messages = self.GRU(messages, fmess, mess_graph)

    mess_nei = index_select_ND(messages, 0, node_graph)
    node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
    node_vecs = self.outputNN(node_vecs)

    batch_vecs = []
    for st, le in scope:
        cur_vecs = node_vecs[st]  # root is the first node of each tree
        batch_vecs.append(cur_vecs)

    tree_vecs = torch.stack(batch_vecs, dim=0)
    return tree_vecs, messages
def encode_node(self, tree_tensors, hatom, node_idx):
    """
    Return the node embeddings learned by the tree MPN, given the tree
    tensors, the learned atom embeddings, and the indices of the nodes
    to be encoded.
    """
    hnode, hmess, agraph, bgraph = self.embed_tree(tree_tensors, hatom)
    hnode = index_select_ND(hnode, 0, node_idx)
    agraph = index_select_ND(agraph, 0, node_idx)
    hnode, _ = self.mpn(hnode, hmess, agraph, bgraph, self.depthT, self.W_t, self.outputNode)
    return hnode
def embed_tree(self, tree_tensors, hatom):
    fnode, fmess, agraph, bgraph, cgraph, dgraph, _ = tree_tensors
    finput = self.embedding(fnode)

    hnode = index_select_ND(hatom, 0, dgraph).sum(dim=1)
    hnode = self.W_i(torch.cat([finput, hnode], dim=-1))

    hmess1 = hnode.index_select(index=fmess[:, 0], dim=0)
    hmess2 = index_select_ND(hatom, 0, cgraph).sum(dim=1)
    hmess = self.W_g(torch.cat([hmess1, hmess2], dim=-1))

    return hnode, hmess, agraph, bgraph
def forward(self, tree_tensors, graph_tensors, orders):
    tensors = self.embed_graph(graph_tensors)
    hatom, _ = self.mpn(*tensors, self.depthG, self.W_a, self.outputAtom)
    hatom[0, :] = hatom[0, :] * 0  # zero out the padding row

    tensors = self.embed_tree(tree_tensors, hatom)
    hnode, _ = self.mpn(*tensors, self.depthT, self.W_t, self.outputNode)
    hnode[0, :] = hnode[0, :] * 0  # zero out the padding row

    revise_nodes = [[edge[1] for edge in order] for order in orders]
    revise_nodes = create_pad_tensor(revise_nodes).to(device).long()
    embedding = index_select_ND(hnode, 0, revise_nodes).sum(dim=1)
    return embedding, hnode, hatom
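# create_pad_tensor is assumed to right-pad the per-order node-index lists
# with 0 (the padding row zeroed out above) so they can be stacked into one
# tensor. A minimal sketch, not necessarily the repository's exact code:
import torch

def create_pad_tensor(alist):
    # Pad every index list to a common length with index 0, then stack.
    max_len = max(len(a) for a in alist) + 1
    return torch.LongTensor([a + [0] * (max_len - len(a)) for a in alist])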
def embed_tree(self, tree_tensors, hatom):
    """
    Prepare the embeddings for tree message passing, incorporating the
    learned atom embeddings into the tree node embeddings.

    Args:
        tree_tensors: The data of the junction tree.
        hatom: The atom embeddings learned through graph message passing.
    """
    fnode, fmess, agraph, bgraph, cgraph, dgraph, _ = tree_tensors
    finput = self.embedding(fnode)

    # combine atom embeddings with node embeddings
    hnode = index_select_ND(hatom, 0, dgraph).sum(dim=1)
    hnode = self.W_i(torch.cat([finput, hnode], dim=-1))

    # combine atom embeddings with edge embeddings
    hmess1 = hnode.index_select(index=fmess[:, 0], dim=0)
    hmess2 = index_select_ND(hatom, 0, cgraph).sum(dim=1)
    hmess = self.W_g(torch.cat([hmess1, hmess2], dim=-1))

    return hnode, hmess, agraph, bgraph
def forward(self, holder, depth):
    fnode = create_var(holder[0])
    fmess = create_var(holder[1])
    node_graph = create_var(holder[2])
    mess_graph = create_var(holder[3])
    scope = holder[4]

    fnode = self.embedding(fnode)
    x = index_select_ND(fnode, 0, fmess)
    h = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

    mask = torch.ones(h.size(0), 1)
    mask[0] = 0  # first vector is padding
    mask = create_var(mask)

    for it in range(depth):
        h_nei = index_select_ND(h, 0, mess_graph)
        h = GRU(x, h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
        h = h * mask

    mess_nei = index_select_ND(h, 0, node_graph)
    node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
    root_vecs = [node_vecs[st] for st, le in scope]
    return torch.stack(root_vecs, dim=0)
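# The functional GRU(x, h_nei, W_z, W_r, U_r, W_h) used above performs one
# gated update step over the neighbor messages. A sketch consistent with the
# module-style implementation later in this file (an assumption, not
# necessarily the repository's verbatim helper):
import torch

def GRU(x, h_nei, W_z, W_r, U_r, W_h):
    hidden_size = x.size(-1)
    sum_h = h_nei.sum(dim=1)
    z = torch.sigmoid(W_z(torch.cat([x, sum_h], dim=1)))         # update gate
    r_1 = W_r(x).view(-1, 1, hidden_size)
    r_2 = U_r(h_nei)
    r = torch.sigmoid(r_1 + r_2)                                 # per-neighbor reset gates
    sum_gated_h = (r * h_nei).sum(dim=1)
    pre_h = torch.tanh(W_h(torch.cat([x, sum_gated_h], dim=1)))  # candidate state
    return (1.0 - z) * sum_h + z * pre_h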
def mpn(self, hnode, hmess, agraph, bgraph, depth, W_m, W_n):
    """
    Return the node embeddings and message embeddings learned through
    message passing.

    Args:
        hnode: initial node embeddings
        hmess: initial message embeddings
        agraph: message adjacency matrix for nodes
            (`agraph[i, j] = 1` means node i is connected with message j)
        bgraph: message adjacency matrix for messages
            (`bgraph[i, j] = 1` means message i is connected with message j)
        depth: depth of message passing
        W_m, W_n: functions used in message passing
    """
    messages = MPNN(hmess, bgraph, W_m, depth, self.hidden_size)
    mess_nei = index_select_ND(messages, 0, agraph)
    node_vecs = torch.cat((hnode, mess_nei.sum(dim=1)), dim=-1)
    node_vecs = W_n(node_vecs)
    return node_vecs, messages
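# MPNN itself is not shown in this file. One plausible realization,
# consistent with the call signature MPNN(hmess, bgraph, W_m, depth,
# hidden_size) and with the loop-style encoders above (a hedged sketch;
# the actual implementation may differ):
import torch

def MPNN(fmess, bgraph, W_m, depth, hidden_size):
    # Iteratively refine message embeddings for `depth` steps.
    messages = torch.zeros(fmess.size(0), hidden_size, device=fmess.device)
    for _ in range(depth):
        nei_message = index_select_ND(messages, 0, bgraph).sum(dim=1)
        messages = W_m(torch.cat([fmess, nei_message], dim=-1))
        messages[0, :] = messages[0, :] * 0  # row 0 is the padding message
    return messages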
def forward(self, h, x, mess_graph):
    mask = torch.ones(h.size(0), 1)
    mask[0] = 0  # first vector is padding
    mask = create_var(mask)

    for it in range(self.depth):
        h_nei = index_select_ND(h, 0, mess_graph)
        sum_h = h_nei.sum(dim=1)
        z_input = torch.cat([x, sum_h], dim=1)
        z = torch.sigmoid(self.W_z(z_input))

        r_1 = self.W_r(x).view(-1, 1, self.hidden_size)
        r_2 = self.U_r(h_nei)
        r = torch.sigmoid(r_1 + r_2)

        gated_h = r * h_nei
        sum_gated_h = gated_h.sum(dim=1)
        h_input = torch.cat([x, sum_gated_h], dim=1)
        pre_h = torch.tanh(self.W_h(h_input))

        h = (1.0 - z) * sum_h + z * pre_h
        h = h * mask

    return h
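# For reference, the shapes implied by this forward pass suggest a
# constructor along the following lines, where input_size is the feature
# dimension of x (an assumption; the class and argument names here are
# hypothetical, not the repository's verbatim code):
import torch.nn as nn

class GraphGRU(nn.Module):
    def __init__(self, input_size, hidden_size, depth):
        super().__init__()
        self.hidden_size = hidden_size
        self.depth = depth
        self.W_z = nn.Linear(input_size + hidden_size, hidden_size)  # update gate
        self.W_r = nn.Linear(input_size, hidden_size, bias=False)    # reset gate (input part)
        self.U_r = nn.Linear(hidden_size, hidden_size)               # reset gate (hidden part)
        self.W_h = nn.Linear(input_size + hidden_size, hidden_size)  # candidate state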
def forward(self, atom_feature_matrix, bond_feature_matrix, atom_adjacency_graph,
            bond_adjacency_graph, scope, tree_mess):
    """
    Description: Implements the forward pass for encoding the candidate molecular
    subgraphs, for the graph decoding step. (Section 2.5)

    Args:
        atom_feature_matrix: torch.tensor (shape: num_atoms x ATOM_FEATURE_DIM)
            Matrix of atom features for all the atoms, over all the molecules, in the dataset.

        bond_feature_matrix: torch.tensor (shape: num_bonds x ATOM_FEATURE_DIM + BOND_FEATURE_DIM)
            Matrix of bond features for all the bonds, over all the molecules, in the dataset.

        atom_adjacency_graph: torch.tensor (shape: num_atoms x MAX_NUM_NEIGHBORS(=6))
            For every atom across the training dataset, this atom_graph gives the bond idxs
            of all the bonds in which it is present. An atom can be present in at most
            MAX_NUM_NEIGHBORS(=6) bonds.

        bond_adjacency_graph: torch.tensor (shape: num_bonds x MAX_NUM_NEIGHBORS(=6))
            For every non-ring bond (cluster-node) across the training dataset, this bond_graph
            gives the bond idxs of those non-ring bonds (cluster-nodes) to which it is connected
            in the "cluster-graph".

        scope: List[Tuple(int, int)]
            List of (start_idx, num_atoms) tuples, used to extract the atom features for a
            particular molecule in the dataset from the atom_feature_matrix.

    Returns:
        mol_vecs: torch.tensor (shape: num_candidate_subgraphs x hidden_size)
            The encodings of all the candidate subgraphs, for scoring purposes. (Section 2.5)
    """
    # create PyTorch Variables
    atom_feature_matrix = create_var(atom_feature_matrix)
    bond_feature_matrix = create_var(bond_feature_matrix)
    atom_adjacency_graph = create_var(atom_adjacency_graph)
    bond_adjacency_graph = create_var(bond_adjacency_graph)

    static_messages = self.W_i(bond_feature_matrix)

    # apply ReLU activation for timestep t = 0
    graph_message = nn.ReLU()(static_messages)

    # implement message passing for timesteps t = 1 to T (depth)
    for timestep in range(self.depth - 1):
        message = torch.cat([tree_mess, graph_message], dim=0)

        # obtain messages from all the "inward edges"
        neighbor_message_vecs = index_select_ND(message, 0, bond_adjacency_graph)

        # sum up all the "inward edge" message vectors
        neighbor_message_vecs_sum = neighbor_message_vecs.sum(dim=1)

        # multiply with the weight matrix for the hidden layer
        neighbor_message = self.W_h(neighbor_message_vecs_sum)

        # message at timestep t + 1
        graph_message = nn.ReLU()(static_messages + neighbor_message)

    message = torch.cat([tree_mess, graph_message], dim=0)

    # aggregate the neighbor messages for each atom
    neighbor_message_vecs = index_select_ND(message, 0, atom_adjacency_graph)
    neighbor_message_atom_matrix = neighbor_message_vecs.sum(dim=1)

    # concatenate atom feature vector and neighbor hidden message vector
    atom_input_matrix = torch.cat([atom_feature_matrix, neighbor_message_atom_matrix], dim=1)
    atom_hidden_layer_synaptic_input = nn.ReLU()(self.W_o(atom_input_matrix))

    # list to store the corresponding molecule vectors for each molecule
    mol_vecs = []
    for start_idx, length in scope:
        # mean over the atom hidden vectors of this candidate subgraph
        mol_vec = atom_hidden_layer_synaptic_input[start_idx:start_idx + length].mean(dim=0)
        mol_vecs.append(mol_vec)

    mol_vecs = torch.stack(mol_vecs, dim=0)
    return mol_vecs
def forward(self, node_wid_list, edge_node_idx_list, node_message_graph, mess_adjacency_graph, scope):
    """
    Args:
        node_wid_list: torch.LongTensor (shape: num_nodes)
            The list of wids, i.e. idxs of the corresponding cluster-vocabulary items,
            for all the nodes.

        edge_node_idx_list: torch.LongTensor (shape: num_edges)
            The list of idxs of the initial node of each edge.

        node_message_graph: torch.LongTensor (shape: num_nodes x MAX_NUM_NEIGHBORS)
            For each node, the list of idxs of all "inward hidden edge message vectors",
            for purposes of node feature aggregation.

        mess_adjacency_graph: torch.LongTensor (shape: num_edges x MAX_NUM_NEIGHBORS)
            For each edge, the list of idxs of all "inward hidden edge message vectors",
            for purposes of message aggregation.

        scope: List[Tuple(int, int)]
            The list of (start_idx, len) tuples used to segregate the node features of
            each junction-tree.

    Returns:
        tree_vecs: torch.tensor (shape: batch_size x hidden_size)
            The hidden vectors of the root nodes of all the junction-trees, across the
            entire dataset.
    """
    # create PyTorch variables
    node_wid_list = create_var(node_wid_list)
    edge_node_idx_list = create_var(edge_node_idx_list)
    node_message_graph = create_var(node_message_graph)
    mess_adjacency_graph = create_var(mess_adjacency_graph)

    # hidden vectors for all the edges
    messages = create_var(torch.zeros(mess_adjacency_graph.size(0), self.hidden_size))

    # obtain the node feature embeddings
    node_feature_embeddings = self.embedding(node_wid_list)

    # for each edge, obtain the embedding of its initial node
    initial_node_features = index_select_ND(node_feature_embeddings, 0, edge_node_idx_list)

    # obtain the hidden vectors for all the edges using the GRU
    messages = self.GRU(messages, initial_node_features, mess_adjacency_graph)

    # for each node, sum up all the neighboring message vectors
    node_neighbor_mess_vecs = index_select_ND(messages, 0, node_message_graph)
    node_neighbor_mess_vecs_sum = node_neighbor_mess_vecs.sum(dim=1)

    # for each node, concatenate its feature embedding with the summed neighbor messages
    node_vecs_synaptic_input = torch.cat([node_feature_embeddings, node_neighbor_mess_vecs_sum], dim=-1)

    # apply the output neural network layer
    node_vecs = self.outputNN(node_vecs_synaptic_input)

    # collect the feature vector of the root node of every junction-tree
    root_vecs = []
    for start_idx, _ in scope:
        # the root is, by design, the first node in a junction-tree's node list
        root_vecs.append(node_vecs[start_idx])

    # stack the root vectors into a 2-D tensor
    tree_vecs = torch.stack(root_vecs, dim=0)
    return tree_vecs, messages
def forward(self, atom_layer_input, bond_layer_input, atom_adjacency_graph,
            atom_bond_adjacency_graph, bond_atom_adjacency_graph):
    """
    Args:
        atom_layer_input: torch.tensor (shape: batch_size x atom_feature_dim)
            The matrix containing feature vectors for all the atoms across the entire batch.
            * atom_feature_dim = len(ELEM_LIST) + 6 + 5 + 4 + 1

        bond_layer_input: torch.tensor (shape: batch_size x bond_feature_dim)
            The matrix containing feature vectors for all the bonds across the entire batch.
            * bond_feature_dim = 5 + 6

        atom_adjacency_graph: torch.tensor (shape: num_atoms x MAX_NUM_NEIGHBORS(=6))
            For each atom across the entire batch, the idxs of neighboring atoms.

        atom_bond_adjacency_graph: torch.tensor (shape: num_atoms x MAX_NUM_NEIGHBORS(=6))
            For each atom across the entire batch, the idxs of all the bonds in which it
            is the initial atom.

        bond_atom_adjacency_graph: torch.tensor (shape: num_bonds x 2)
            For each bond across the entire batch, the idxs of the 2 atoms of which the
            bond is composed.
    """
    # implement edge gate computation
    edge_gate_x = torch.index_select(input=atom_layer_input, dim=0, index=bond_atom_adjacency_graph[:, 0])
    edge_gate_y = torch.index_select(input=atom_layer_input, dim=0, index=bond_atom_adjacency_graph[:, 1])

    assert bond_layer_input.shape[0] == edge_gate_x.shape[0]
    assert bond_layer_input.shape[0] == edge_gate_y.shape[0]

    edge_gate_synaptic_input = self.A(bond_layer_input) + self.B(edge_gate_x) + self.C(edge_gate_y)

    # apply sigmoid activation for computing edge gates
    edge_gates = torch.sigmoid(edge_gate_synaptic_input)

    # apply ReLU activation for computing new bond features
    # (batch normalization and residual connections were tried here but are disabled)
    bond_layer_output = F.relu(edge_gate_synaptic_input)

    # implement node feature computation:
    # for each atom, aggregate the feature vectors of the neighboring atoms
    atom_neighbor_features_tensor = self.V(index_select_ND(atom_layer_input, 0, atom_adjacency_graph))

    # for each atom, get the edge gates for the corresponding neighbor atom features
    atom_neighbor_edge_gates_tensor = index_select_ND(edge_gates, 0, atom_bond_adjacency_graph)
    assert atom_neighbor_edge_gates_tensor.shape == atom_neighbor_features_tensor.shape

    # for each atom, multiply the edge gates with the corresponding neighbor atom feature vectors
    atom_neighbor_message_tensor = atom_neighbor_edge_gates_tensor * atom_neighbor_features_tensor
    atom_neighbor_message_sum = atom_neighbor_message_tensor.sum(dim=1)
    assert atom_neighbor_message_sum.shape[0] == atom_layer_input.shape[0]

    atom_features_synaptic_input = self.U(atom_layer_input) + atom_neighbor_message_sum

    # apply ReLU activation for computing new atom features
    # (batch normalization and residual connections were tried here but are disabled)
    atom_layer_output = F.relu(atom_features_synaptic_input)

    return atom_layer_output, bond_layer_output
def forward(self, cand_batch, tree_mess):
    fatoms, fbonds = [], []
    in_bonds, all_bonds = [], []
    mess_dict, all_mess = {}, [create_var(torch.zeros(self.hidden_size))]  # ensure index 0 is vec(0)
    total_atoms = 0
    scope = []

    for e, vec in tree_mess.items():
        mess_dict[e] = len(all_mess)
        all_mess.append(vec)

    for mol, all_nodes, ctr_node in cand_batch:
        n_atoms = mol.GetNumAtoms()
        ctr_bid = ctr_node.idx

        for atom in mol.GetAtoms():
            fatoms.append(atom_features(atom))
            in_bonds.append([])

        for bond in mol.GetBonds():
            a1 = bond.GetBeginAtom()
            a2 = bond.GetEndAtom()
            x = a1.GetIdx() + total_atoms
            y = a2.GetIdx() + total_atoms

            # here x_nid, y_nid could be 0
            x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
            x_bid = all_nodes[x_nid - 1].idx if x_nid > 0 else -1
            y_bid = all_nodes[y_nid - 1].idx if y_nid > 0 else -1

            bfeature = bond_features(bond)

            b = len(all_mess) + len(all_bonds)  # bond idx offset by len(all_mess)
            all_bonds.append((x, y))
            fbonds.append(torch.cat([fatoms[x], bfeature], 0))
            in_bonds[y].append(b)

            b = len(all_mess) + len(all_bonds)
            all_bonds.append((y, x))
            fbonds.append(torch.cat([fatoms[y], bfeature], 0))
            in_bonds[x].append(b)

            if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
                if (x_bid, y_bid) in mess_dict:
                    mess_idx = mess_dict[(x_bid, y_bid)]
                    in_bonds[y].append(mess_idx)
                if (y_bid, x_bid) in mess_dict:
                    mess_idx = mess_dict[(y_bid, x_bid)]
                    in_bonds[x].append(mess_idx)

        scope.append((total_atoms, n_atoms))
        total_atoms += n_atoms

    total_bonds = len(all_bonds)
    total_mess = len(all_mess)
    fatoms = torch.stack(fatoms, 0)
    fbonds = torch.stack(fbonds, 0)
    agraph = torch.zeros(total_atoms, MAX_NB).long()
    bgraph = torch.zeros(total_bonds, MAX_NB).long()
    tree_message = torch.stack(all_mess, dim=0)

    for a in range(total_atoms):
        for i, b in enumerate(in_bonds[a]):
            agraph[a, i] = b

    for b1 in range(total_bonds):
        x, y = all_bonds[b1]
        for i, b2 in enumerate(in_bonds[x]):  # b2 is offset by len(all_mess)
            if b2 < total_mess or all_bonds[b2 - total_mess][0] != y:
                bgraph[b1, i] = b2

    fatoms = create_var(fatoms)
    fbonds = create_var(fbonds)
    agraph = create_var(agraph)
    bgraph = create_var(bgraph)

    binput = self.W_i(fbonds)
    graph_message = nn.ReLU()(binput)

    for i in range(self.depth - 1):
        message = torch.cat([tree_message, graph_message], dim=0)
        nei_message = index_select_ND(message, 0, bgraph)
        nei_message = nei_message.sum(dim=1)
        nei_message = self.W_h(nei_message)
        graph_message = nn.ReLU()(binput + nei_message)

    message = torch.cat([tree_message, graph_message], dim=0)
    nei_message = index_select_ND(message, 0, agraph)
    nei_message = nei_message.sum(dim=1)
    ainput = torch.cat([fatoms, nei_message], dim=1)
    atom_hiddens = nn.ReLU()(self.W_o(ainput))

    mol_vecs = []
    for st, le in scope:
        mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
        mol_vecs.append(mol_vec)

    mol_vecs = torch.stack(mol_vecs, dim=0)
    return mol_vecs
def mpn(self, hnode, hmess, agraph, bgraph, depth, W_m, W_n):
    messages = MPNN(hmess, bgraph, W_m, depth, self.hidden_size)
    mess_nei = index_select_ND(messages, 0, agraph)
    node_vecs = torch.cat((hnode, mess_nei.sum(dim=1)), dim=-1)
    node_vecs = W_n(node_vecs)
    return node_vecs, messages
def encode_node(self, tree_tensors, hatom, node_idx):
    hnode, hmess, agraph, bgraph = self.embed_tree(tree_tensors, hatom)
    hnode = index_select_ND(hnode, 0, node_idx)
    agraph = index_select_ND(agraph, 0, node_idx)
    hnode, _ = self.mpn(hnode, hmess, agraph, bgraph, self.depthT, self.W_t, self.outputNode)
    return hnode