def mol2dgl(smiles_batch):
    """Convert a batch of SMILES strings into a list of ``DGLGraph`` objects.

    One graph is built per SMILES string. Every atom becomes a node (added
    with its RDKit atom index), and every bond becomes two directed edges,
    begin -> end and end -> begin, that share the same bond feature vector.

    Node data:
        ``'x'``: ``torch.stack`` of ``atom_features(atom)`` over all atoms.

    Edge data (only set when the molecule has at least one bond):
        ``'x'``: ``torch.stack`` of ``bond_features(bond)``, one row per
            directed edge, in insertion order.
        ``'src_x'``: zero tensor with ``len(edges)`` rows and ``ATOM_FDIM``
            columns, created via ``atom_x.new`` (same dtype/device as
            ``atom_x``).

    Args:
        smiles_batch: iterable of SMILES strings; each is passed to
            ``get_mol``.

    Returns:
        list of ``DGLGraph``, in the same order as ``smiles_batch``.
    """
    graph_list = []
    for smiles in smiles_batch:
        atom_feature_list = []
        bond_feature_list = []
        graph = DGLGraph()
        # NOTE(review): no check that get_mol succeeded; an unparsable SMILES
        # presumably fails on the attribute access below — confirm.
        mol = get_mol(smiles)

        for atom in mol.GetAtoms():
            graph.add_node(atom.GetIdx())
            atom_feature_list.append(atom_features(atom))

        for bond in mol.GetBonds():
            begin_idx = bond.GetBeginAtom().GetIdx()
            end_idx = bond.GetEndAtom().GetIdx()
            features = bond_features(bond)
            # Store the bond in both directions with identical features so
            # edge row i of 'x' always matches the i-th added edge.
            graph.add_edge(begin_idx, end_idx)
            bond_feature_list.append(features)
            graph.add_edge(end_idx, begin_idx)
            bond_feature_list.append(features)

        atom_x = torch.stack(atom_feature_list)
        graph.set_n_repr({'x': atom_x})
        # torch.stack rejects an empty list, so bond-less molecules get no
        # edge data at all.
        if bond_feature_list:
            bond_x = torch.stack(bond_feature_list)
            graph.set_e_repr({
                'x': bond_x,
                'src_x': atom_x.new(len(bond_feature_list), ATOM_FDIM).zero_(),
            })
        graph_list.append(graph)
    return graph_list
def mol2dgl(cand_batch, mol_tree_batch):
    """Build candidate-molecule graphs and map junction-tree edges onto them.

    For every ``(mol, mol_tree, ctr_node_id)`` entry of ``cand_batch`` a
    ``DGLGraph`` is built with one node per atom and two directed edges per
    bond (both directions share the bond's feature vector).

    For each bond, the atom-map numbers of its two atoms are translated to
    batch-level tree node ids. Map number ``k > 0`` selects
    ``mol_tree.nodes[k - 1]['idx']``; ``0`` means "no tree node". When both
    ids exist and differ, each direction ``(a, b)`` found in
    ``mol_tree_batch.edge_list`` is recorded together with the corresponding
    directed edge of the candidate graph.

    Candidate-graph node ids in the mapping lists are offset by the total
    node count of all previously processed candidates, so they index into the
    concatenation of the returned graphs.

    Args:
        cand_batch: iterable of ``(mol, mol_tree, ctr_node_id)`` tuples.
        mol_tree_batch: object whose ``edge_list`` supports ``in`` tests
            with ``(src_tree_id, dst_tree_id)`` tuples.

    Returns:
        ``(cand_graphs, tree_mess_source_edges, tree_mess_target_edges,
        tree_mess_target_nodes)`` where:

        * ``cand_graphs`` holds one ``DGLGraph`` per candidate, with node data
          ``'x'`` and, if there are bonds, edge data ``'x'`` and ``'src_x'``
          (zeros with ``ATOM_FDIM`` columns);
        * ``tree_mess_source_edges`` holds tree edges ``(src, dst)``;
        * ``tree_mess_target_edges`` holds the offset candidate-graph edges
          they map onto;
        * ``tree_mess_target_nodes`` holds the offset destination node of
          each such edge.
    """
    graphs = []
    src_tree_edges = []   # edges of the tree batch ...
    dst_graph_edges = []  # ... mapped onto these candidate-graph edges
    dst_graph_nodes = []  # destination node of each mapped graph edge
    offset = 0            # nodes contained in all earlier candidate graphs

    def batch_tree_id(tree, map_num):
        # Atom-map number k > 0 refers to tree node k - 1; 0 means unmapped.
        return tree.nodes[map_num - 1]['idx'] if map_num > 0 else -1

    for mol, mol_tree, ctr_node_id in cand_batch:
        atom_feats = []
        bond_feats = []
        # NOTE(review): the centre node's batch id is looked up but never used
        # afterwards; the lookup is kept so an invalid id still fails here.
        _ctr_bid = mol_tree.nodes[ctr_node_id]['idx']
        g = DGLGraph()

        for atom in mol.GetAtoms():
            atom_feats.append(atom_features(atom))
            g.add_node(atom.GetIdx())

        for bond in mol.GetBonds():
            begin_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
            u, v = begin_atom.GetIdx(), end_atom.GetIdx()
            feat = bond_features(bond)
            for a, b in ((u, v), (v, u)):
                g.add_edge(a, b)
                bond_feats.append(feat)

            x_num, y_num = begin_atom.GetAtomMapNum(), end_atom.GetAtomMapNum()
            x_bid = batch_tree_id(mol_tree, x_num)
            y_bid = batch_tree_id(mol_tree, y_num)
            if x_bid < 0 or y_bid < 0 or x_bid == y_bid:
                continue

            gu, gv = u + offset, v + offset
            for s, t, gs, gt in ((x_bid, y_bid, gu, gv),
                                 (y_bid, x_bid, gv, gu)):
                if (s, t) in mol_tree_batch.edge_list:
                    dst_graph_edges.append((gs, gt))
                    src_tree_edges.append((s, t))
                    dst_graph_nodes.append(gt)

        offset += len(g.nodes)
        atom_x = torch.stack(atom_feats)
        g.set_n_repr({'x': atom_x})
        if bond_feats:
            g.set_e_repr({
                'x': torch.stack(bond_feats),
                'src_x': atom_x.new(len(bond_feats), ATOM_FDIM).zero_(),
            })
        graphs.append(g)

    return graphs, src_tree_edges, dst_graph_edges, dst_graph_nodes