def __init__(self, mols: "Tuple[Chem.Mol, Chem.Mol, Chem.Mol]", args: Namespace):
    """
    Computes the graph structure and featurization of a reaction.

    The graph is fully connected: every ordered pair of distinct atoms is an
    edge. Edge features are the reactant and product 3D distances of the
    pair, and the per-edge target ``y`` is the transition-state 3D distance.

    :param mols: A ``(reactant, transition state, product)`` triple of RDKit
        molecules. All three need a 3D conformer, and the same atom indices
        are used to address all three (assumes a shared atom ordering --
        TODO confirm against the data loader).
    :param args: Arguments. Not used by this constructor.
    """
    self.n_atoms = 0  # number of atoms (left at 0 by this constructor)
    self.n_bonds = 0  # number of bonds (left at 0 by this constructor)
    self.f_atoms = []  # mapping from atom index to atom features
    self.f_bonds = []  # mapping from edge index to [reactant dist, product dist]
    self.a2b = []  # mapping from atom index to incoming bond indices
    self.b2a = []  # mapping from bond index to the index of the atom the bond is coming from
    self.b2revb = []  # mapping from bond index to the index of the reverse bond
    self.parity_atoms = []  # mapping from atom index to CW (+1), CCW (-1) or undefined tetra (0)
    self.edge_index = []  # list of (source, target) tuples, one per directed edge
    self.y = []  # transition-state distance for each directed edge

    # extract reactant, ts, product
    r_mol, ts_mol, p_mol = mols
    n_atoms = r_mol.GetNumAtoms()

    # 3D distance matrices (the topological ones were computed but never used)
    D_r = Chem.Get3DDistanceMatrix(r_mol)
    D_p = Chem.Get3DDistanceMatrix(p_mol)
    D_ts = Chem.Get3DDistanceMatrix(ts_mol)

    for a1 in range(n_atoms):
        # Node features come from the reactant only
        self.f_atoms.append(atom_features(r_mol.GetAtomWithIdx(a1)))

        # Edge features: both directions of every unordered pair are stored
        # back to back, so edge_index, f_bonds and y stay aligned.
        for a2 in range(a1 + 1, n_atoms):
            self.edge_index.extend([(a1, a2), (a2, a1)])
            self.f_bonds.append([D_r[a1][a2], D_p[a1][a2]])
            self.f_bonds.append([D_r[a2][a1], D_p[a2][a1]])
            self.y.extend([D_ts[a1][a2], D_ts[a2][a1]])
def prepare_batch(batch_mols, MAX_SIZE):
    """Build zero-padded node, edge and coordinate arrays for a batch.

    Args:
        batch_mols: Sequence of ``(reactant, product)`` RDKit molecule pairs.
            Both molecules need a 3D conformer.
        MAX_SIZE: Number of atom slots allocated per molecule; molecules are
            zero-padded up to this size.

    Returns:
        A tuple ``(batch_dict, batch_mols)``. ``batch_dict`` holds

        * ``"nodes"``: ``(batch, MAX_SIZE, num_elements + 1)`` one-hot element
          type plus atomic number / 10,
        * ``"edges"``: ``(batch, MAX_SIZE, MAX_SIZE, 3)`` with channels
          (aromatic bond, bonded, exp(-mean 3D distance)),
        * ``"sizes"``: atom count of each reactant,
        * ``"coordinates"``: reactant atom positions.

    Relies on the module-level ``elements`` and ``num_elements``.
    """
    MAX_D = 10.  # cap on the averaged topological distance

    size = len(batch_mols)
    V = np.zeros((size, MAX_SIZE, num_elements + 1), dtype=np.float32)
    E = np.zeros((size, MAX_SIZE, MAX_SIZE, 3), dtype=np.float32)
    sizes = np.zeros(size, dtype=np.int32)
    coordinates = np.zeros((size, MAX_SIZE, 3), dtype=np.float32)

    for bx in range(size):
        reactant, product = batch_mols[bx]
        N_atoms = reactant.GetNumAtoms()
        sizes[bx] = int(N_atoms)

        # Mean of the reactant and product topological (bond-count) distances.
        # It equals 1 only when the pair is directly bonded in both.
        D = (Chem.GetDistanceMatrix(reactant) + Chem.GetDistanceMatrix(product)) / 2
        D[D > MAX_D] = MAX_D

        # exp(-mean 3D distance); note the distance is NOT squared.
        D_3D_rbf = np.exp(
            -((Chem.Get3DDistanceMatrix(reactant) + Chem.Get3DDistanceMatrix(product)) / 2))

        for i in range(N_atoms):
            # Edge features
            for j in range(N_atoms):
                E[bx, i, j, 2] = D_3D_rbf[i][j]
                if D[i][j] == 1.:  # stays bonded
                    if reactant.GetBondBetweenAtoms(i, j).GetIsAromatic():
                        E[bx, i, j, 0] = 1.
                    E[bx, i, j, 1] = 1.

            # Coordinates are taken from the reactant conformer
            pos = reactant.GetConformer().GetAtomPosition(i)
            coordinates[bx, i, :] = np.asarray([pos.x, pos.y, pos.z])

            # Node features: one-hot element type, then atomic number / 10
            atom = reactant.GetAtomWithIdx(i)
            e_ix = elements.index(atom.GetSymbol())
            V[bx, i, e_ix] = 1.
            V[bx, i, num_elements] = atom.GetAtomicNum() / 10.

    batch_dict = {
        "nodes": V,
        "edges": E,
        "sizes": sizes,
        "coordinates": coordinates
    }
    return batch_dict, batch_mols
def prepare_batch(batch_mols):
    """Build zero-padded TensorFlow constants for a batch of reactions.

    Args:
        batch_mols: Sequence of ``(reactant, ts, product)`` RDKit molecule
            triples. All three need a 3D conformer.

    Returns:
        Dict of ``tf.constant`` tensors:

        * ``"nodes"``: ``(batch, max_size, num_elements + 1)`` one-hot element
          type plus atomic number / 10 (from the reactant),
        * ``"edges"``: ``(batch, max_size, max_size, 3)`` with channels
          (aromatic bond, bonded, exp(-mean 3D distance)),
        * ``"sizes"``: atom count of each reactant,
        * ``"coordinates"``: transition-state atom positions.

    Relies on the module-level ``max_size``, ``elements`` and
    ``num_elements``.
    """
    MAX_D = 10.  # cap on the averaged topological distance

    size = len(batch_mols)
    V = np.zeros((size, max_size, num_elements + 1), dtype=np.float32)
    E = np.zeros((size, max_size, max_size, 3), dtype=np.float32)
    sizes = np.zeros(size, dtype=np.int32)
    coordinates = np.zeros((size, max_size, 3), dtype=np.float32)

    for bx in range(size):
        reactant, ts, product = batch_mols[bx]
        N_atoms = reactant.GetNumAtoms()
        sizes[bx] = int(N_atoms)

        # Mean of the reactant and product topological (bond-count) distances.
        # It equals 1 only when the pair is directly bonded in both.
        D = (Chem.GetDistanceMatrix(reactant) + Chem.GetDistanceMatrix(product)) / 2
        D[D > MAX_D] = MAX_D

        # exp(-mean 3D distance); note the distance is NOT squared.
        D_3D_rbf = np.exp(
            -((Chem.Get3DDistanceMatrix(reactant) + Chem.Get3DDistanceMatrix(product)) / 2))

        for i in range(N_atoms):
            # Edge features
            for j in range(N_atoms):
                E[bx, i, j, 2] = D_3D_rbf[i][j]
                if D[i][j] == 1.:  # stays bonded
                    if reactant.GetBondBetweenAtoms(i, j).GetIsAromatic():
                        E[bx, i, j, 0] = 1.
                    E[bx, i, j, 1] = 1.

            # Coordinates are taken from the transition-state conformer
            pos = ts.GetConformer().GetAtomPosition(i)
            coordinates[bx, i, :] = np.asarray([pos.x, pos.y, pos.z])

            # Node features: one-hot element type, then atomic number / 10
            atom = reactant.GetAtomWithIdx(i)
            e_ix = elements.index(atom.GetSymbol())
            V[bx, i, e_ix] = 1.
            V[bx, i, num_elements] = atom.GetAtomicNum() / 10.

    batch_dict = {
        "nodes": tf.constant(V),
        "edges": tf.constant(E),
        "sizes": tf.constant(sizes),
        "coordinates": tf.constant(coordinates)
    }
    return batch_dict
def xyz2ac(atomic_num_list, xyz):
    """Derive an atom-connectivity matrix from atomic numbers and coordinates.

    A proto-molecule is built with ``get_proto_mol``, given a conformer from
    ``xyz``, and two atoms are marked as connected when their distance does
    not exceed the sum of their covalent radii, each scaled by 1.24.

    Returns the ``(n, n)`` integer connectivity matrix and the molecule
    carrying the new conformer.
    """
    import numpy as np

    mol = get_proto_mol(atomic_num_list)

    # Attach the supplied coordinates as a conformer.
    conformer = Chem.Conformer(mol.GetNumAtoms())
    for atom_idx in range(mol.GetNumAtoms()):
        point = xyz[atom_idx]
        conformer.SetAtomPosition(atom_idx, (point[0], point[1], point[2]))
    mol.AddConformer(conformer)

    d_mat = Chem.Get3DDistanceMatrix(mol)
    periodic_table = Chem.GetPeriodicTable()

    num_atoms = len(atomic_num_list)
    ac = np.zeros((num_atoms, num_atoms), dtype=int)
    for i in range(num_atoms):
        reach_i = periodic_table.GetRcovalent(mol.GetAtomWithIdx(i).GetAtomicNum()) * 1.24
        for j in range(i + 1, num_atoms):
            reach_j = periodic_table.GetRcovalent(mol.GetAtomWithIdx(j).GetAtomicNum()) * 1.24
            if d_mat[i, j] <= reach_i + reach_j:
                # keep the matrix symmetric
                ac[i, j] = ac[j, i] = 1
    return ac, mol
def make_fingerprints(data, length=512, verbose=False):
    """Create the standard set of fingerprint objects and apply each to ``data``.

    ``length`` is forwarded as the bit size to the Morgan, Avalon and RDKit
    fingerprints. When ``verbose`` is true the name of each fingerprint is
    printed before it is applied. Returns the list of ``fingerprint`` objects
    after ``apply_fp(data)`` has been called on each.
    """
    # (generator callable, display name) pairs, in the order they are applied.
    recipes = [
        (Chem.rdMolDescriptors.GetHashedTopologicalTorsionFingerprintAsBitVect, "Torsion "),
        (lambda x: GetMorganFingerprintAsBitVect(x, 2, nBits=length), "Morgan"),
        (FingerprintMol, "Estate (1995)"),
        (lambda x: GetAvalonFP(x, nBits=length), "Avalon bit based (2006)"),
        (lambda x: np.append(GetAvalonFP(x, nBits=length), Descriptors.MolWt(x)),
         "Avalon+mol. weight"),
        (lambda x: GetErGFingerprint(x), "ErG fingerprint (2006)"),
        (lambda x: RDKFingerprint(x, fpSize=length), "RDKit fingerprint"),
        (lambda x: MACCSkeys.GenMACCSKeys(x), "MACCS fingerprint"),
        (lambda x: get_fingerprint(x, fp_type='pubchem'), "PubChem"),
        # (lambda x: get_fingerprint(x, fp_type='FP4'), "FP4"),
        (lambda x: Generate.Gen2DFingerprint(x, Gobbi_Pharm2D.factory,
                                             dMat=Chem.Get3DDistanceMatrix(x)),
         "3D pharmacophore"),
    ]
    fp_list = [fingerprint(generator, label) for generator, label in recipes]

    for fp in fp_list:
        if verbose:
            print("doing", fp.name)
        fp.apply_fp(data)
    return fp_list
def _get_distance_matrix(self, combo: Chem.Mol, A: Union[Chem.Mol, np.ndarray], B: Union[Chem.Mol, np.ndarray]) -> np.ndarray:
    """
    Called by ``_find_closest`` and ``_determine_mergers_novel_ringcore_pair`` in collapse ring (for expansion).
    This is a distance matrix blanked so it is only distances to other fragment

    :param combo: molecule whose 3D distance matrix is computed.
    :param A: either a molecule, in which case its atoms are taken to be
        indices ``0 .. A.GetNumAtoms() - 1`` of ``combo``, or an explicit
        array of atom indices of ``combo``.
    :param B: either a molecule, in which case its atoms are taken to follow
        those of ``A`` in ``combo`` (indices offset by ``A.GetNumAtoms()``),
        or an explicit array of atom indices of ``combo``.
    :return: the 3D distance matrix of ``combo`` after
        ``_nan_fill_submatrix`` has been applied for the A indices and for
        the B indices.
    :raises ValueError: if ``B`` is a molecule but ``A`` is an index array,
        because the index offset of ``B`` cannot then be derived.
    """
    # TODO move to base once made.
    A_is_mol = isinstance(A, Chem.Mol)
    if A_is_mol:
        A_idxs = np.arange(A.GetNumAtoms())
    else:
        A_idxs = np.array(A)

    if isinstance(B, Chem.Mol):
        if not A_is_mol:
            # Previously this path failed with an AttributeError on None.
            raise ValueError('B was given as a Chem.Mol, so A must be a Chem.Mol too: '
                             'the indices of B are offset by the atom count of A.')
        B_idxs = np.arange(B.GetNumAtoms()) + A.GetNumAtoms()
    else:
        B_idxs = np.array(B)

    # make matrix
    distance_matrix = Chem.Get3DDistanceMatrix(combo)
    # nan fill the self values
    self._nan_fill_submatrix(distance_matrix, A_idxs)
    self._nan_fill_submatrix(distance_matrix, B_idxs)
    return distance_matrix
def get_AC(mol, covalent_factor=1.3):
    """
    Generate the adjacency matrix from atoms and coordinates.

    AC is a (num_atoms, num_atoms) matrix where 1 marks a covalent bond and
    0 marks no bond. Two atoms are considered bonded when their 3D distance
    does not exceed the sum of their covalent radii, each multiplied by
    ``covalent_factor``.

    covalent_factor - 1.3 is an arbitrary factor

    args:
        mol - rdkit molobj with 3D conformer

    optional
        covalent_factor - increase covalent bond length threshold with factor

    returns:
        AC - adjacency matrix (symmetric, zero diagonal)
    """
    # Calculate distance matrix
    dMat = Chem.Get3DDistanceMatrix(mol)

    pt = Chem.GetPeriodicTable()
    num_atoms = mol.GetNumAtoms()

    # Look each scaled radius up once per atom instead of once per atom pair.
    radii = [pt.GetRcovalent(atom.GetAtomicNum()) * covalent_factor
             for atom in mol.GetAtoms()]

    AC = np.zeros((num_atoms, num_atoms), dtype=int)
    for i in range(num_atoms):
        for j in range(i + 1, num_atoms):
            if dMat[i, j] <= radii[i] + radii[j]:
                AC[i, j] = 1
                AC[j, i] = 1
    return AC
def join_overclose(self, mol: Chem.RWMol, to_check, cutoff=2.2):  # was 1.8
    """
    Bond each atom in ``to_check`` to every other atom of ``mol`` that lies
    within an element-adapted cutoff distance.

    The cutoff is used as is when either atom is a dummy (``*``) or when both
    are carbon. Otherwise it becomes
    ``cutoff - 1.36 + Rcov(atom_i) + Rcov(atom_j)`` (1.36 is presumably twice
    the carbon covalent radius -- TODO confirm).

    :param mol: molecule with a 3D conformer; bonds are added through
        ``self._add_bond_if_possible``.
    :param to_check: list of atoms indices that need joining (but not to each other)
    :param cutoff: distance cutoff for a C-C pair
    :return: None
    """
    pt = Chem.GetPeriodicTable()
    dm = Chem.Get3DDistanceMatrix(mol)
    to_check = list(to_check)
    # Set membership instead of a list scan per atom pair. Every i is itself
    # in to_check, so this also covers the i == j case.
    excluded = set(to_check)
    for i in to_check:
        atom_i = mol.GetAtomWithIdx(i)
        symbol_i = atom_i.GetSymbol()
        rcov_i = pt.GetRcovalent(atom_i.GetAtomicNum())
        for j, atom_j in enumerate(mol.GetAtoms()):
            # Skip before doing any cutoff work.
            if j in excluded:
                continue
            # calculate cutoff if not C-C
            symbol_j = atom_j.GetSymbol()
            if symbol_i == '*' or symbol_j == '*':
                ij_cutoff = cutoff
            elif symbol_i == 'C' and symbol_j == 'C':
                ij_cutoff = cutoff
            else:
                ij_cutoff = cutoff - 1.36 + (rcov_i + pt.GetRcovalent(atom_j.GetAtomicNum()))
            # determine if to join
            if dm[i, j] > ij_cutoff:
                continue
            self._add_bond_if_possible(mol, atom_i, atom_j)
def xyz2AC(atomicNumList, xyz):
    """Build an atom-connectivity matrix from atomic numbers and coordinates.

    The proto-molecule from ``get_proto_mol`` receives a conformer holding
    ``xyz``. Atoms i and j are flagged as connected when their 3D distance is
    at most the sum of their periodic-table covalent radii, each scaled by
    1.30.

    Returns ``(AC, mol)``: the symmetric integer connectivity matrix and the
    molecule with the conformer attached.
    """
    mol = get_proto_mol(atomicNumList)

    coords = Chem.Conformer(mol.GetNumAtoms())
    for k in range(mol.GetNumAtoms()):
        x, y, z = xyz[k][0], xyz[k][1], xyz[k][2]
        coords.SetAtomPosition(k, (x, y, z))
    mol.AddConformer(coords)

    dMat = Chem.Get3DDistanceMatrix(mol)
    table = Chem.GetPeriodicTable()

    num_atoms = len(atomicNumList)
    # Scaled covalent radius of every atom, in atom-index order.
    reach = [
        table.GetRcovalent(mol.GetAtomWithIdx(k).GetAtomicNum()) * 1.30
        for k in range(num_atoms)
    ]

    AC = np.zeros((num_atoms, num_atoms)).astype(int)
    for i, reach_i in enumerate(reach):
        for j in range(i + 1, num_atoms):
            if dMat[i, j] <= reach_i + reach[j]:
                AC[i, j] = AC[j, i] = 1
    return AC, mol
def _find_centroid(self):
    """Record the most central atom of ``self.mol`` and its reach.

    The central atom is the one whose column of the 3D distance matrix has
    the smallest Euclidean norm. Its PDB atom name is appended to
    ``self.NBR_ATOM`` and its largest distance to any atom is appended, as a
    string, to ``self.NBR_RADIUS``.
    """
    distances = Chem.Get3DDistanceMatrix(self.mol)
    # One 2-norm per column (i.e. per atom), not a single matrix norm.
    column_norms = np.linalg.norm(distances, axis=0)
    central = int(np.argmin(column_norms))
    radius = distances[:, central].max()
    atom_name = self.mol.GetAtomWithIdx(central).GetPDBResidueInfo().GetName()
    self.NBR_ATOM.append(atom_name)
    self.NBR_RADIUS.append(str(radius))
def get_gobbi_similarity(correct_ligand, mol_to_fix, type_fp='normal', use_features=False):
    """Compute the Tanimoto similarity of two ligands' Gobbi pharmacophore fingerprints.

    Bond orders of both molecules are first assigned from a hard-coded
    template SMILES. A Gobbi ``Gen2DFingerprint`` is then generated for each,
    using the molecule's own 3D distance matrix as ``dMat``.

    :param correct_ligand: reference RDKit molecule with a 3D conformer.
    :param mol_to_fix: RDKit molecule with a 3D conformer to compare against
        the reference.
    :param type_fp: currently unused.
    :param use_features: currently unused.
    :return: the Tanimoto similarity of the two fingerprints. It is also
        printed. (Earlier versions printed the value but returned ``None``.)
    """
    # Template used to restore bond orders on both input molecules.
    ref = Chem.MolFromSmiles(
        'C1=CC(=C(C=C1C2=C(C(=O)C3=C(C=C(C=C3O2)O)O)O)O)O')
    mol1 = AllChem.AssignBondOrdersFromTemplate(ref, correct_ligand)
    mol2 = AllChem.AssignBondOrdersFromTemplate(ref, mol_to_fix)

    factory = Gobbi_Pharm2D.factory
    fp1 = Generate.Gen2DFingerprint(mol1, factory, dMat=Chem.Get3DDistanceMatrix(mol1))
    fp2 = Generate.Gen2DFingerprint(mol2, factory, dMat=Chem.Get3DDistanceMatrix(mol2))

    # Tanimoto similarity
    tani = DataStructs.TanimotoSimilarity(fp1, fp2)
    print('GOBBI similarity is ------> ', tani)
    return tani
def _generate_interaction_matrix(self, rdkit_mol):
    """Return the zero-padded interaction tensor of ``rdkit_mol``.

    Feature channel 0 holds the real-valued Euclidean 3D distance between
    each atom pair; any further channels stay zero. Both atom axes are padded
    with zeros up to ``self.max_num_atoms``.
    """
    n_atoms = rdkit_mol.GetNumAtoms()
    features = np.zeros((n_atoms, n_atoms, self.num_interaction_features))
    features[..., 0] = Chem.Get3DDistanceMatrix(rdkit_mol)

    # Pad only at the end of the two atom axes; the feature axis is untouched.
    missing = self.max_num_atoms - n_atoms
    pad_widths = ((0, missing), (0, missing), (0, 0))
    return np.pad(features, pad_widths, mode='constant')
def join_rings(self, mol: Chem.RWMol, cutoff=1.8):
    """Bond pairs of rings that are close together but not yet bonded.

    Special case (x0749): a bond between two rings. Elsewhere bonds are only
    added to non-ring atoms, so bonded rings need this extra pass. For every
    pair of rings that ``_are_rings_bonded`` reports as unbonded, the closest
    pair of atoms (one per ring) is bonded via ``_add_bond_if_possible`` when
    their distance is below ``cutoff``.
    """
    rings = self._get_ring_info(mol)
    dm = Chem.Get3DDistanceMatrix(mol)
    for ringA, ringB in itertools.combinations(rings, 2):
        if self._are_rings_bonded(mol, ringA, ringB):
            continue
        # Distances restricted to rows of ring A and columns of ring B.
        between = np.take(np.take(dm, ringA, 0), ringB, 1)
        shortest = np.nanmin(between)
        if not shortest < cutoff:
            continue
        hits = np.where(between == shortest)
        atom_a = mol.GetAtomWithIdx(ringA[int(hits[0][0])])
        atom_b = mol.GetAtomWithIdx(ringB[int(hits[1][0])])
        self._add_bond_if_possible(mol, atom_a, atom_b)
def find_closest_to_ligand(cls, pdb: Chem.Mol, ligand_resn: str) -> Tuple[Chem.Atom, Chem.Atom]:
    """
    Find the closest atom to the ligand

    :param pdb: a rdkit Chem object (every atom must carry PDB residue info)
    :param ligand_resn: 3 letter code
    :return: tuple of non-ligand atom and ligand atom
    :raises ValueError: if no atom has residue name ``ligand_resn``, or if
        there is no non-ligand atom at a non-zero distance from the ligand.
    """
    ligand = [atom.GetIdx() for atom in pdb.GetAtoms() if
              atom.GetPDBResidueInfo().GetResidueName() == ligand_resn]
    if not ligand:
        raise ValueError(f'No atoms with residue name {ligand_resn!r} found')
    dm = Chem.Get3DDistanceMatrix(pdb)
    # Rows: ligand atoms; columns: every atom of the structure.
    mini = np.take(dm, ligand, 0)
    mini[mini == 0] = np.nan  # ignore zero distances (self and coincident atoms)
    mini[:, ligand] = np.nan  # ignore ligand-ligand distances
    if np.all(np.isnan(mini)):
        raise ValueError(f'No non-ligand atoms to compare with residue {ligand_resn!r}')
    a, b = np.where(mini == np.nanmin(mini))
    lig_atom = pdb.GetAtomWithIdx(ligand[int(a[0])])
    nonlig_atom = pdb.GetAtomWithIdx(int(b[0]))
    return (nonlig_atom, lig_atom)
def _ring_overlap_scenario(self, mol, rings):
    """
    Resolve the case where a border shared by two rings is lost.
    The atoms have to be adjacent.

    For each collapsed ring atom and each neighbouring collapsed ring atom
    (``_ori_i`` == -1) closer than 4 distance units, the two atoms of each
    ring nearest to the other ring's atom are found and paired up.

    :param mol: molecule with a 3D conformer.
    :param rings: iterable of dicts with at least the keys ``'atom'`` (the
        ring atom) and ``'ori_is'`` (original atom indices of the ring).
    :return: list of ``(atom index, atom index)`` tuples to merge.
    """
    dm = Chem.Get3DDistanceMatrix(mol)

    def two_closest(group, ref_i):
        # The two members of ``group`` nearest to atom ``ref_i``, nearest first.
        column = np.take(np.take(dm, group, 0), [ref_i], 1)  # shape (len(group), 1)
        hit = np.where(column == np.nanmin(column))
        first_row = int(hit[0][0])
        column[first_row, :] = float('inf')  # mask the winner, then search again
        hit = np.where(column == np.nanmin(column))
        # Bug fix: the runner-up is identified by its ROW. The column index
        # used before is always 0 here, which always yielded group[0].
        return group[first_row], group[int(hit[0][0])]

    mergituri = []
    for ring in rings:
        ring_atom = ring['atom']
        for n in ring_atom.GetNeighbors():
            # Only collapsed neighbours, and each unordered pair only once.
            if n.GetIntProp('_ori_i') != -1 or ring_atom.GetIdx() >= n.GetIdx():
                continue
            # it may share a border.
            conf = mol.GetConformer()
            rpos = np.array(conf.GetAtomPosition(ring_atom.GetIdx()))
            npos = np.array(conf.GetAtomPosition(n.GetIdx()))
            if np.linalg.norm(rpos - npos) > 4:
                # Connected via a bond rather than an overlapping atom:
                # 2.8 ring dia vs. 2.8 + 1.5 CC bond.
                # this will be fixed depending on if from history or not.
                continue
            # is overlap
            A_idxs = [self._get_new_index(mol, i, search_collapsed=False)
                      for i in ring['ori_is']]
            B_idxs = [self._get_new_index(mol, i, search_collapsed=False)
                      for i in json.loads(n.GetProp('_ori_is'))]
            # do they have overlapping atoms already?
            if len(set(A_idxs).intersection(B_idxs)) != 0:
                continue
            # they still need merging
            # which atoms of A are closer to B center and vice versa
            a_first, a_second = two_closest(A_idxs, n.GetIdx())
            b_first, b_second = two_closest(B_idxs, ring_atom.GetIdx())
            # now determine which are closer
            if dm[a_first, b_first] < dm[a_first, b_second]:
                mergituri.append((a_first, b_first))
                mergituri.append((a_second, b_second))
            else:
                mergituri.append((a_first, b_second))
                mergituri.append((a_second, b_first))
    return mergituri
def xyz2AC(atomicNumList, xyz):
    """Build an atom-connectivity matrix from atomic numbers and coordinates.

    A proto-molecule is created with ``get_proto_mol`` and given a conformer
    holding ``xyz``. Atoms i and j are marked as connected when their 3D
    distance is at most the sum of their tabulated covalent radii, each
    scaled by 1.30.

    Only atomic numbers 1-9 (H through F) are tabulated.

    :param atomicNumList: atomic number of each atom.
    :param xyz: per-atom coordinates, indexable as ``xyz[i][0..2]``.
    :return: ``(AC, mol)`` - symmetric integer connectivity matrix and the
        molecule with the conformer attached.
    :raises IndexError: if an atom's atomic number is outside the tabulated
        range. Atomic number 0 used to wrap around silently to the fluorine
        radius; it is now rejected as well.
    """
    mol = get_proto_mol(atomicNumList)

    conf = Chem.Conformer(mol.GetNumAtoms())
    for i in range(mol.GetNumAtoms()):
        conf.SetAtomPosition(i, (xyz[i][0], xyz[i][1], xyz[i][2]))
    mol.AddConformer(conf)

    dMat = Chem.Get3DDistanceMatrix(mol)

    # Covalent radii indexed by (atomic number - 1), H..F
    # (values presumably in Angstrom -- TODO confirm the source table).
    Rcovtable = [0.31, 0.28, 1.28, 0.96, 0.85, 0.76, 0.71, 0.66, 0.57]

    num_atoms = len(atomicNumList)

    # Validate and scale each radius once per atom.
    scaled_radii = []
    for i in range(num_atoms):
        atomic_num = mol.GetAtomWithIdx(i).GetAtomicNum()
        if not 1 <= atomic_num <= len(Rcovtable):
            raise IndexError(
                f'No covalent radius tabulated for atomic number {atomic_num} '
                f'(atom {i}); supported range is 1-{len(Rcovtable)}')
        scaled_radii.append(Rcovtable[atomic_num - 1] * 1.30)

    AC = np.zeros((num_atoms, num_atoms)).astype(int)
    for i in range(num_atoms):
        for j in range(i + 1, num_atoms):
            if dMat[i, j] <= scaled_radii[i] + scaled_radii[j]:
                AC[i, j] = 1
                AC[j, i] = 1
    return AC, mol
def absorb_overclose(self, mol, to_check=None, to_check2=None, cutoff:float=1.):
    """Merge atoms that sit closer than ``cutoff`` to one another.

    Each atom index in ``to_check`` absorbs (via ``self._absorb``) every
    still-present atom of ``to_check2`` that lies within ``cutoff`` of it.
    Absorbed atoms are removed from ``mol`` at the end. Either list defaults
    to all atoms of ``mol``. Returns the number of atoms removed.
    """
    dm = Chem.Get3DDistanceMatrix(mol)
    absorbers = list(range(mol.GetNumAtoms())) if to_check is None else to_check
    candidates = list(range(mol.GetNumAtoms())) if to_check2 is None else to_check2

    doomed = []  # atoms already absorbed, awaiting removal
    for i in absorbers:
        if i in doomed:
            continue
        for j in candidates:
            if i != j and j not in doomed and dm[i, j] < cutoff:
                self._absorb(mol, i, j)
                doomed.append(j)

    # Remove from the highest index down so the remaining indices stay valid.
    for idx in sorted(doomed, reverse=True):
        mol.RemoveAtom(idx)
    return len(doomed)
def is_clashing(mol, conf_id, clash_threshold):
    """Return True if any two distinct atoms are closer than ``clash_threshold``.

    :param mol: RDKit molecule holding the conformer.
    :param conf_id: id of the conformer to inspect.
    :param clash_threshold: distance below which two atoms count as clashing.
    :return: ``bool`` - whether at least one atom pair clashes.

    Only the strict upper triangle of the distance matrix is inspected, so
    each pair is tested once and the diagonal is skipped. Earlier versions
    dropped the diagonal by discarding every zero distance, which also hid
    atoms lying at exactly the same position.
    """
    import numpy as np

    matrix = Chem.Get3DDistanceMatrix(mol, confId=conf_id)
    pair_distances = matrix[np.triu_indices(matrix.shape[0], k=1)]
    return bool(np.any(pair_distances < clash_threshold))
def process_geometries(self, geometries):
    """Convert (reactant, TS, product) triples into fully connected graph ``Data`` objects.

    For each reaction, up to ``self.n_rxns`` of them:

    * node features (``x``): a zero tensor of shape
      ``(self.MAX_NUM_ATOMS, self.NUM_NODE_FEATS)`` whose first ``num_atoms``
      rows carry a one-hot element type (via ``self.ELEM_TYPES``) and
      atomic number / 10 at column ``self.NUM_ELEMS``;
    * edges: both directions of every pair of distinct atoms;
    * edge features: reactant and product 3D distances of the pair;
    * target ``y``: transition-state 3D distance of the pair.

    :param geometries: iterable of ``(reactant, ts, product)`` RDKit
        molecules with 3D conformers.
    :return: list of ``Data`` objects, one per processed reaction.
    """
    data_list = []

    for rxn_id, rxn in enumerate(tqdm(geometries)):
        if rxn_id == self.n_rxns:
            break

        r, ts, p = rxn
        num_atoms = r.GetNumAtoms()

        f_bonds, edge_index, y = [], [], []
        f_atoms = torch.zeros(self.MAX_NUM_ATOMS, self.NUM_NODE_FEATS, dtype=torch.float)

        # 3D distance matrices (the unused topological ones are no longer computed)
        D_r = Chem.Get3DDistanceMatrix(r)
        D_p = Chem.Get3DDistanceMatrix(p)
        D_ts = Chem.Get3DDistanceMatrix(ts)

        for i in range(num_atoms):
            # node feats
            atom = r.GetAtomWithIdx(i)
            e_ix = self.ELEM_TYPES[atom.GetSymbol()]
            f_atoms[i][e_ix] = 1
            f_atoms[i][self.NUM_ELEMS] = atom.GetAtomicNum() / 10.

            # edge features
            for j in range(i + 1, num_atoms):
                # fully connected graph; both directions are stored back to
                # back so edge_index, f_bonds and y stay aligned
                edge_index.extend([(i, j), (j, i)])

                # for now, naively include both reac and prod
                f_bonds.append([D_r[i][j], D_p[i][j]])
                f_bonds.append([D_r[j][i], D_p[j][i]])
                y.extend([D_ts[i][j], D_ts[j][i]])

        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(f_bonds, dtype=torch.float)
        y = torch.tensor(y, dtype=torch.float)

        data = Data(x=f_atoms, edge_attr=edge_attr, edge_index=edge_index, y=y, idx=rxn_id)
        data_list.append(data)

    return data_list
def prepare_batch(batch_mols, max_size, elements):
    """
    Returns batch with atom features, edge features, and coordinates initialised.
    Edge features based on topological distances (number of bonds between atoms) and 3D RBF distances.

    Args:
        batch_mols: sequence of ``(reactant, ts, product)`` RDKit molecule
            triples; all three need a 3D conformer.
        max_size: number of atom slots per molecule (zero padding).
        elements: sequence of element symbols; the position of a symbol is
            its one-hot index.

    Returns:
        Dict of ``tf.constant`` tensors with keys ``"nodes"``, ``"edges"``,
        ``"sizes"`` and ``"coordinates"``. Coordinates are taken from the
        transition state.

    Relies on the module-level ``MAX_NUM_BONDS`` cap.
    """
    # func constants
    num_elements = len(elements)

    # initialise
    size = len(batch_mols)
    V = np.zeros((size, max_size, num_elements + 1), dtype=np.float32)
    E = np.zeros((size, max_size, max_size, 3), dtype=np.float32)
    sizes = np.zeros(size, dtype=np.int32)
    coordinates = np.zeros((size, max_size, 3), dtype=np.float32)

    # build molecule features for each reaction
    for bx in range(size):
        # get r, ts, p for reaction
        reactant, ts, product = batch_mols[bx]
        num_atoms = reactant.GetNumAtoms()
        sizes[bx] = int(num_atoms)

        # Mean of reactant and product topological distances. It equals 1
        # only when the pair is directly bonded in both.
        D = (Chem.GetDistanceMatrix(reactant) + Chem.GetDistanceMatrix(product)) / 2
        D[D > MAX_NUM_BONDS] = MAX_NUM_BONDS

        # 3D rbf matrix: exp(-mean 3D distance)
        D_3D_rbf = np.exp(-((Chem.Get3DDistanceMatrix(reactant) + Chem.Get3DDistanceMatrix(product)) / 2))

        for i in range(num_atoms):
            # edge features (stays bonded and bond aromatic?, stays bonded?, 3D rbf distance)
            for j in range(num_atoms):
                # if stays bonded
                if D[i][j] == 1.:
                    # if aromatic bond
                    if reactant.GetBondBetweenAtoms(i, j).GetIsAromatic():
                        E[bx, i, j, 0] = 1.
                    E[bx, i, j, 1] = 1
                # add 3D rbf dist
                E[bx, i, j, 2] = D_3D_rbf[i][j]

            # node features
            atom = reactant.GetAtomWithIdx(i)
            e_ix = elements.index(atom.GetSymbol())
            V[bx, i, e_ix] = 1.
            V[bx, i, num_elements] = atom.GetAtomicNum() / 10.

            # recover coordinates
            pos = ts.GetConformer().GetAtomPosition(i)
            coordinates[bx, i, :] = np.asarray([pos.x, pos.y, pos.z])

    batch_dict = {
        "nodes": tf.constant(V),
        "edges": tf.constant(E),
        "sizes": tf.constant(sizes),
        "coordinates": tf.constant(coordinates)
    }
    return batch_dict
def _transform_mol(self, mol):
    """Return the 3D distance matrix of ``mol`` in a ``nanarray`` padded to ``self.max_atoms`` columns.

    The result has one row per atom. The first ``len(mol.atoms)`` columns
    hold the distances; the remaining columns keep whatever ``nanarray``
    initialises them to.
    """
    n_atoms = len(mol.atoms)
    padded = nanarray((n_atoms, self.max_atoms))
    padded[:, :n_atoms] = Chem.Get3DDistanceMatrix(mol)
    return padded
def construct_feature_matrices(self, mol):
    """ Given an rdkit mol, return atom feature matrices, bond feature
    matrices, and connectivity matrices.

    Each atom is connected to its nearest neighbours by 3D distance, at most
    ``self.n_neighbors`` of them.

    Returns
    dict with entries
    'n_atom' : number of atoms in the molecule
    'n_bond' : number of edges (likely n_atom * n_neighbors)
    'atom' : (n_atom,) length list of atom classes
    'bond' : (n_bond,) list of bond classes. 0 for no bond
    'distance' : (n_bond,) list of bond distances
    'connectivity' : (n_bond, 2) array of source atom, target atom pairs.

    Raises
    RuntimeError if ``mol`` is None.
    """
    # Hopefully we've filtered out all problem mols by now. This has to run
    # before the first use of ``mol`` to be effective.
    if mol is None:
        raise RuntimeError("Issue in loading mol")

    # Get a list of the atoms in the molecule.
    atoms = list(mol.GetAtoms())
    n_atom = len(atoms)

    # n_bond is actually the number of atom-atom pairs, so this is defined
    # by the number of neighbors for each atom.
    if self.n_neighbors <= (n_atom - 1):
        n_bond = self.n_neighbors * n_atom
    elif n_atom == 1:
        n_bond = 1
    else:
        # If there are fewer atoms than n_neighbors, all atoms will be
        # connected
        n_bond = (n_atom - 1) * n_atom

    # Initialize the matrices to be filled in during the following loop.
    atom_feature_matrix = np.zeros(n_atom, dtype='int')
    bond_feature_matrix = np.zeros(n_bond, dtype='int')
    bond_distance_matrix = np.zeros(n_bond, dtype=np.float32)
    connectivity = np.zeros((n_bond, 2), dtype='int')

    distance_matrix = Chem.Get3DDistanceMatrix(mol)

    # if n_neighbors is greater than total atoms, then each atom is a
    # neighbor. This is the same for every atom, so compute it once.
    end_index = min(self.n_neighbors + 1, n_atom)

    # Here we loop over each atom, and the inner loop iterates over each
    # neighbor of the current atom.
    bond_index = 0  # keep track of our current bond.
    for n, atom in enumerate(atoms):
        # update atom feature matrix
        atom_feature_matrix[n] = self.atom_tokenizer(
            self.atom_features(atom))

        # Loop over each of the nearest neighbors. Position 0 of the sorted
        # row is skipped (presumably the atom itself at distance 0).
        neighbor_inds = distance_matrix[n, :].argsort()[1:end_index]
        for neighbor in neighbor_inds:
            # update bond feature matrix
            bond = mol.GetBondBetweenAtoms(n, int(neighbor))
            if bond is None:
                bond_feature_matrix[bond_index] = 0
            else:
                # flipped when the bond is stored in the opposite direction
                rev = bond.GetBeginAtomIdx() != n
                bond_feature_matrix[bond_index] = self.bond_tokenizer(
                    self.bond_features(bond, flipped=rev))

            bond_distance_matrix[bond_index] = distance_matrix[n, neighbor]

            # update connectivity matrix
            connectivity[bond_index, 0] = n
            connectivity[bond_index, 1] = neighbor

            bond_index += 1

    return {
        'n_atom': n_atom,
        'n_bond': n_bond,
        'atom': atom_feature_matrix,
        'bond': bond_feature_matrix,
        'distance': bond_distance_matrix,
        'connectivity': connectivity,
    }
def check_test_distribution(args):
    """Check test distribution of TS reaction core is same as original MIT work.

    Loads the test reactant, product and transition-state SDF files from
    ``data/raw/``, rebuilds the TS distance matrices from the test loader
    returned by ``construct_dataset_and_loaders(args)``, and plots the
    density of reaction-core distances from both sources ("GT" from the SDF
    geometries, "Mine" from the loader). The reaction core is the set of atom
    pairs bonded in exactly one of reactant and product.
    """
    mols_folder = r'data/raw/'

    reactant_file = mols_folder + 'test_reactants.sdf'
    test_r = list(Chem.SDMolSupplier(reactant_file, removeHs=False, sanitize=False))

    product_file = mols_folder + 'test_products.sdf'
    test_p = list(Chem.SDMolSupplier(product_file, removeHs=False, sanitize=False))

    test_ts_file = mols_folder + 'test_ts.sdf'
    test_ts = list(Chem.SDMolSupplier(test_ts_file, removeHs=False, sanitize=False))

    _, _, test_loader = construct_dataset_and_loaders(args)
    D_gts = []

    for rxn_batch in test_loader:
        X_gt, _ = to_dense_batch(
            rxn_batch.pos_ts, rxn_batch.x_ts_batch, 0.,
            max(rxn_batch.num_atoms))  # pos_ts = [b * max_num_nodes, 3]
        batched_D_gt = X_to_dist(X_gt)
        # one distance matrix per reaction in the batch
        batched_D_gt = [x[-1] for x in torch.split(batched_D_gt, 1, 0)]
        D_gts.extend(batched_D_gt)
    assert len(
        D_gts
    ) == 842, f"Number of test dist matrices is {len(D_gts)}, you need 842 which was originally published in the MIT model."

    mine, gt = [], []
    for idx in range(len(test_ts)):
        # num_atoms + mask for reaction core
        num_atoms = test_ts[idx].GetNumAtoms()
        # an entry is 1 only where the pair is bonded in exactly one of the two
        core_mask = (Chem.GetAdjacencyMatrix(test_p[idx]) +
                     Chem.GetAdjacencyMatrix(test_r[idx])) == 1

        gt.append(np.ravel(Chem.Get3DDistanceMatrix(test_ts[idx]) * core_mask))
        mine.append(np.ravel(D_gts[idx][0:num_atoms, 0:num_atoms] * core_mask))

    all_ds = [gt, mine]
    all_ds = [np.concatenate(ds).ravel() for ds in all_ds]
    all_ds = [ds[ds != 0] for ds in all_ds]  # only keep non-zero values

    fig, ax = plt.subplots(figsize=(12, 9))
    sns.distplot(all_ds[0], color='b', kde_kws={
        "lw": 5,
        "label": "GT"
    }, hist=False)
    sns.distplot(all_ds[1], color='r', kde_kws={
        "lw": 3,
        "label": "Mine"
    }, hist=False)

    # A single call: a second ax.legend() would replace the first one and
    # silently drop its ``loc``.
    ax.legend(loc='upper right', fontsize=12)
    ax.set_ylabel('Density', fontsize=22)
    ax.set_xlabel(r'Distance ($\AA$)', fontsize=22)
    ax.tick_params(axis='both', which='major', labelsize=22)

    ax.spines["top"].set_visible(False)
    ax.spines["bottom"].set_visible(True)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_visible(True)
def featurization(r_mol: Chem.rdchem.Mol, p_mol: Chem.rdchem.Mol, ):
    """
    Generates features of the reactant and product for one reaction as input for the network.

    The graph is fully connected: both directions of every pair of distinct
    atoms are edges, and each edge carries the reactant and product 3D
    distances of that pair.

    Args:
        r_mol: RDKit molecule object for the reactant (needs a 3D conformer).
        p_mol: RDKit molecule object for the product (needs a 3D conformer).

    Returns:
        data: Torch Geometric Data object, storing the atom and bond features
    """
    # compute properties with rdkit (only works if dataset is clean)
    r_mol.UpdatePropertyCache()
    p_mol.UpdatePropertyCache()

    n_atoms = r_mol.GetNumAtoms()

    # 3D distance matrices (the unused topological ones are no longer computed)
    D_r = Chem.Get3DDistanceMatrix(r_mol)
    D_p = Chem.Get3DDistanceMatrix(p_mol)

    f_atoms = list()  # atom (node) features
    edge_index = list()  # list of (source, target) tuples, one per directed edge
    f_bonds = list()  # bond (edge) features

    for a1 in range(n_atoms):
        # Node features come from the reactant only
        f_atoms.append(atom_features(r_mol.GetAtomWithIdx(a1)))

        # Edge features
        for a2 in range(a1 + 1, n_atoms):
            # fully connected graph; both directions are stored back to back
            # so edge_index and f_bonds stay aligned
            edge_index.extend([(a1, a2), (a2, a1)])

            # for now, naively include both reac and prod
            f_bonds.append([D_r[a1][a2], D_p[a1][a2]])
            f_bonds.append([D_r[a2][a1], D_p[a2][a1]])

    data = tg.data.Data()
    data.x = torch.tensor(f_atoms, dtype=torch.float)
    data.edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    data.edge_attr = torch.tensor(f_bonds, dtype=torch.float)

    return data
def construct_feature_matrices(self, entry):
    """
    Given an entry containing an rdkit molecule and an atom index array for
    the target property, return atom feature matrices, bond feature matrices,
    distance matrices, connectivity matrices and atom ref matrices.

    Each atom is connected to its nearest neighbours by 3D distance, limited
    to ``self.n_neighbors`` and to atoms closer than ``self.cutoff``. An atom
    with no such neighbour gets a single self-connection.

    returns
    dict with entries
    see MolPreproccessor
    'n_pro' : length of the atom index array
    'atom_index' : (n_atom,) position of each atom in the atom index array,
                   -1 for atoms that are not listed

    raises
    RuntimeError if the molecule in ``entry`` is None.
    """
    mol, atom_index_array = entry

    # Hopefully we've filtered out all problem mols by now. This has to run
    # before the first use of ``mol`` to be effective.
    if mol is None:
        raise RuntimeError("Issue in loading mol")

    # Get a list of the atoms in the molecule.
    atoms = list(mol.GetAtoms())
    n_atom = len(atoms)
    n_pro = len(atom_index_array)

    # n_bond is actually the number of atom-atom pairs, so this is defined
    # by the number of neighbors for each atom.
    distance_matrix = Chem.Get3DDistanceMatrix(mol)

    if self.n_neighbors <= (n_atom - 1):
        n_bond = self.n_neighbors * n_atom
    else:
        # If there are fewer atoms than n_neighbors, all atoms within the
        # cutoff will be connected
        n_bond = distance_matrix[(distance_matrix < self.cutoff) & (distance_matrix != 0)].size
    if n_bond == 0:
        n_bond = 1
    # NOTE(review): self-connections added below for isolated atoms are not
    # counted in n_bond in the branch above -- confirm the arrays cannot
    # overflow for molecules that mix isolated and connected atoms.

    # Initialize the matrices to be filled in during the following loop.
    atom_feature_matrix = np.zeros(n_atom, dtype='int')
    bond_feature_matrix = np.zeros(n_bond, dtype='int')
    bond_distance_matrix = np.zeros(n_bond, dtype=np.float32)
    atom_index_matrix = np.full(n_atom, -1, dtype='int')
    connectivity = np.zeros((n_bond, 2), dtype='int')

    # Convert once instead of once per atom.
    atom_index_list = np.asarray(atom_index_array).tolist()

    # if n_neighbors is greater than total atoms, then each atom is a
    # neighbor. This is the same for every atom, so compute it once.
    neighbor_end_index = min(self.n_neighbors + 1, n_atom)

    # Here we loop over each atom, and the inner loop iterates over each
    # neighbor of the current atom.
    bond_index = 0  # keep track of our current bond.
    for n, atom in enumerate(atoms):
        # update atom feature matrix
        atom_feature_matrix[n] = self.atom_tokenizer(
            self.atom_features(atom))
        try:
            atom_index_matrix[n] = atom_index_list.index(atom.GetIdx())
        except ValueError:
            # atom is not listed in the index array: keep the -1 sentinel
            pass

        distance_atom = distance_matrix[n, :]
        cutoff_end_index = distance_atom[distance_atom < self.cutoff].size
        end_index = min(neighbor_end_index, cutoff_end_index)

        # Loop over each of the nearest neighbors. Position 0 of the sorted
        # row is skipped (presumably the atom itself at distance 0).
        neighbor_inds = distance_atom.argsort()[1:end_index]
        if len(neighbor_inds) == 0:
            neighbor_inds = [n]
        for neighbor in neighbor_inds:
            # update bond feature matrix
            bond = mol.GetBondBetweenAtoms(n, int(neighbor))
            if bond is None:
                bond_feature_matrix[bond_index] = 0
            else:
                # flipped when the bond is stored in the opposite direction
                rev = bond.GetBeginAtomIdx() != n
                bond_feature_matrix[bond_index] = self.bond_tokenizer(
                    self.bond_features(bond, flipped=rev))

            bond_distance_matrix[bond_index] = distance_matrix[n, neighbor]

            # update connectivity matrix
            connectivity[bond_index, 0] = n
            connectivity[bond_index, 1] = neighbor

            bond_index += 1

    return {
        'n_atom': n_atom,
        'n_bond': n_bond,
        'n_pro': n_pro,
        'atom': atom_feature_matrix,
        'bond': bond_feature_matrix,
        'distance': bond_distance_matrix,
        'connectivity': connectivity,
        'atom_index': atom_index_matrix,
    }
def process(self):
    """Build and save the drug / host / virus graphs from the raw files.

    Reads host-host interactions (``raw_paths[2]``), virus interactions
    (``raw_paths[3]``), drug structures (``raw_paths[0]``, SDF) and
    drug-host links (``raw_paths[1]``). Protein, virus and drug identifiers
    are re-indexed to contiguous positions, and the results are written
    with ``torch.save`` to ``processed_paths[0..5]``:

    0. collated drug molecule graphs ``(data, slices)``
    1. drug-host edge index
    2. host-host graph dict (``num_nodes``, ``edge_index`` and optional
       protein features)
    3. virus-host edge index
    4. virus-virus graph dict
    5. ``(entrez_ids, virus_ids, drug_ids)`` lookup tensors

    Also sets ``self.data``, ``self.slices``, ``self.entrez_ids``,
    ``self.virus_ids`` and ``self.drug_ids``.
    """
    # Host-Host
    hh = pd.read_table(self.raw_paths[2]).values.T
    # Virus-Virus
    vv = pd.read_table(self.raw_paths[3],
                       usecols=[
                           'EntrezGeneID_InteractorA',
                           'EntrezGeneID_InteractorB',
                           'OrganismID_InteractorA',
                           'OrganismID_InteractorB'
                       ],
                       na_values='-').fillna(-1).convert_dtypes(int).values
    # Rows whose two organism columns are both 9606 (Homo Sapiens).
    human_only = np.all(vv[:, 2:] == 9606, axis=-1)
    if self.add_missing_proteins:
        # Both operands are (2, n_edges) edge lists, so the extra
        # human-human edges must be appended along the edge axis.
        hh = np.concatenate((hh, vv[human_only, :2].T), axis=1)
    entrez_ids, interactome = np.unique(hh, return_inverse=True)
    interactome = interactome.reshape((2, -1))
    pid2pos = {idx: pos for pos, idx in enumerate(entrez_ids)}
    vv = vv[~human_only]
    # mask[k, e] is True when interactor k of edge e is human.
    mask = vv.T[2:] == 9606
    vv = vv.T[:2]
    # Non-human interactors are re-indexed among themselves; human ones are
    # mapped to their position in entrez_ids (-1 when unknown).
    virus_ids, vv[~mask] = np.unique(vv[~mask], return_inverse=True)
    vv[mask] = [pid2pos.get(pid, -1) for pid in vv[mask]]
    # Virus-Host
    human_related = np.any(mask, axis=0)
    # Where the first interactor is the human one, swap the two rows so
    # that row 1 holds the human side.
    vv = np.where(mask[:1], vv[::-1], vv)
    vh = vv[:, human_related & (vv[1] != -1)]
    vv = vv[:, ~human_related]
    # K-Hop Restriction
    if self.restrict_to is not None:
        mask = np.zeros(entrez_ids.shape[0], dtype=bool)
        row, col = interactome
        mask[vh[1].astype(int)] = True
        for _ in range(self.restrict_to):
            last = mask.copy()
            # Expand one hop in both directions. A plain boolean assignment
            # is used instead of `mask[col] |= last[row]`: with repeated
            # indices only the last buffered write survives, so reachable
            # nodes could be missed.
            mask[col[last[row]]] = True
            mask[row[last[col]]] = True
        interactome = interactome[:, mask[row] & mask[col]]
        old_ids, interactome = np.unique(interactome, return_inverse=True)
        interactome = interactome.reshape((2, -1))
        pid2pos = {idx: pos for pos, idx in enumerate(entrez_ids[mask])}
        vh[1] = [pid2pos[entrez_ids[pid]] for pid in vh[1]]
        entrez_ids = entrez_ids[mask]
    # Drug Structures
    drugs = PandasTools.LoadSDF(self.raw_paths[0],
                                idName='RowID')[['RowID', 'ROMol']]
    # Drop the two-character prefix of the id and keep the numeric part.
    drugs.iloc[:, 0] = drugs.iloc[:, 0].apply(lambda did: int(did[2:]))
    drugs = drugs.set_index('RowID')
    # Drugs-Host
    dh = pd.read_table(self.raw_paths[1],
                       converters={
                           '#DrugBankID':
                           lambda did: int(did[2:]),
                           'EntrezGeneID':
                           lambda pid: int(pid2pos.get(int(pid), -1))
                       })
    # Keep links whose protein is known and whose drug has a structure.
    dh = dh[dh.iloc[:, 1] != -1].join(drugs[[]],
                                      on='#DrugBankID',
                                      how='inner').values.T
    drug_ids, dh[0] = np.unique(dh[0], return_inverse=True)
    drugs = drugs.loc[drug_ids, 'ROMol']
    data_list = []
    for mol in drugs:
        if self.feature_extractor is None:
            # Default features: bonded atom pairs as edges, weighted by
            # their 3D distance.
            adj = Chem.GetAdjacencyMatrix(mol) * Chem.Get3DDistanceMatrix(
                mol)
            nz = np.nonzero(adj)
            features = {
                'num_nodes': adj.shape[0],
                'edge_index': torch.LongTensor(np.stack(nz)),
                'edge_attr': torch.FloatTensor(adj[nz])
            }
        else:
            features = self.feature_extractor(mol)
        data = Data(**features)
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        if self.pre_filter is None or self.pre_filter(data):
            data_list.append(data)
    self.data, self.slices = self.collate(data_list)
    num_proteins = entrez_ids.shape[0]
    num_virus = virus_ids.shape[0]
    host_host = {
        'num_nodes':
        num_proteins,
        'edge_index':
        to_undirected(torch.LongTensor(interactome),
                      num_nodes=num_proteins)
    }
    virus_virus = {
        'num_nodes':
        num_virus,
        'edge_index':
        to_undirected(torch.LongTensor(vv.astype(int)),
                      num_nodes=num_virus)
    }
    if self.protein_features is not None:
        host_host.update(self.protein_features(entrez_ids))
        # virus_virus.update(self.protein_features(virus_ids))
    virus_host = torch.LongTensor(vh.astype(int))
    drug_host = torch.LongTensor(dh.astype(int))
    self.entrez_ids = torch.LongTensor(entrez_ids.astype(int))
    self.virus_ids = torch.LongTensor(virus_ids.astype(int))
    self.drug_ids = torch.LongTensor(drug_ids.astype(int))
    torch.save((self.data, self.slices), self.processed_paths[0])
    torch.save(drug_host, self.processed_paths[1])
    torch.save(host_host, self.processed_paths[2])
    torch.save(virus_host, self.processed_paths[3])
    torch.save(virus_virus, self.processed_paths[4])
    torch.save((self.entrez_ids, self.virus_ids, self.drug_ids),
               self.processed_paths[5])
def process_geometry_file(self, geometry_file, current_list=None):
    """Transform the molecules of an SDF file into graph ``Data`` objects.

    Notes:
        - Code mostly lifted from PyG QM9 dataset creation
          https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/qm9.html

    :param geometry_file: SDF file name, appended to ``self.root``.
    :param current_list: optional list of already processed ``Data``
        objects. When it is non-empty the new objects are appended to it
        and their ``idx`` continues after its current length; otherwise a
        new list is started.
    :return: list of ``Data`` objects, one per molecule read (at most
        ``self.n_rxns`` molecules from this file). Each has node features
        ``x``, atomic numbers ``z``, coordinates ``pos``, the flattened
        pairwise 3D distances ``y``, ``edge_index``, one-hot ``edge_attr``
        and ``idx``.
    """
    data_list = current_list if current_list else []
    counted = len(data_list)

    full_path = self.root + geometry_file
    geometries = Chem.SDMolSupplier(full_path,
                                    removeHs=False,
                                    sanitize=False)

    # get atom and edge features for each geometry
    for i, mol in enumerate(tqdm(geometries)):
        if i == self.n_rxns:
            break

        N = mol.GetNumAtoms()

        # Atom positions as matrix with shape [num_atoms, 3], read from the
        # atom lines of the molblock text (they follow 4 header lines).
        atom_data = geometries.GetItemText(i).split('\n')[4:4 + N]
        atom_positions = [[float(x) for x in line.split()[:3]]
                          for line in atom_data]
        atom_positions = torch.tensor(atom_positions, dtype=torch.float)

        # Also get positions as a flattened 3D distance matrix: for every
        # pair a1 < a2 both directions are stored next to each other.
        # NOTE: the loop variables must not be called `i`, otherwise they
        # overwrite the molecule index used for `idx` below.
        y = []
        D_ts = Chem.Get3DDistanceMatrix(mol)
        for a1 in range(N):
            for a2 in range(a1 + 1, N):
                y.extend([D_ts[a1][a2], D_ts[a2][a1]])
        y = torch.tensor(y, dtype=torch.float)

        # all the features
        type_idx, atomic_number, aromatic = [], [], []
        sp, sp2, sp3 = [], [], []

        # atom/node features
        for atom in mol.GetAtoms():
            type_idx.append(self.TYPES[atom.GetSymbol()])
            atomic_number.append(atom.GetAtomicNum())
            aromatic.append(1 if atom.GetIsAromatic() else 0)
            hybridisation = atom.GetHybridization()
            sp.append(1 if hybridisation == HybridizationType.SP else 0)
            sp2.append(1 if hybridisation == HybridizationType.SP2 else 0)
            sp3.append(1 if hybridisation == HybridizationType.SP3 else 0)

        # TODO: lucky does: whether bonded, 3D_rbf

        # bond/edge features
        row, col, edge_type = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            # edge type for each bond type; *2 because both ways
            edge_type += 2 * [self.BONDS[bond.GetBondType()]]
        # edge_index is graph connectivity in COO format with shape
        # [2, num_edges]
        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_type = torch.tensor(edge_type, dtype=torch.long)
        # edge_attr is edge feature matrix with shape
        # [num_edges, num_edge_features]
        edge_attr = F.one_hot(edge_type,
                              num_classes=len(self.BONDS)).to(torch.float)

        # order edges based on combined ascending order
        asc_order_perm = (edge_index[0] * N + edge_index[1]).argsort()
        edge_index = edge_index[:, asc_order_perm]
        edge_type = edge_type[asc_order_perm]
        edge_attr = edge_attr[asc_order_perm]

        row, col = edge_index
        z = torch.tensor(atomic_number, dtype=torch.long)
        hs = (z == 1).to(torch.float)  # hydrogens
        # Number of hydrogen neighbors of every atom.
        num_hs = scatter(hs[row], col, dim_size=N).tolist()

        x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(self.TYPES))
        x2 = torch.tensor(
            [atomic_number, aromatic, sp, sp2, sp3, num_hs],
            dtype=torch.float).t().contiguous()
        x = torch.cat([x1.to(torch.float), x2], dim=-1)

        # Position of the molecule in the combined list.
        idx = counted + i
        mol_data = Data(x=x,
                        z=z,
                        pos=atom_positions,
                        y=y,
                        edge_index=edge_index,
                        edge_attr=edge_attr,
                        idx=idx)
        data_list.append(mol_data)

    return data_list
def create_ds_dict(d_files, d_folder='d_inits/', mols_folder=r'data/raw/'):
    """Collect reaction-core distances of the test set for plotting.

    For every test reaction the "core" is the set of atom pairs that are
    bonded in exactly one of reactant and product. The distances of those
    pairs are gathered from the ground-truth transition state, the
    ``mit_best.npy`` matrices, the reactant/product average and every
    matrix file listed in ``d_files``.

    :param d_files: names of ``.npy`` files inside ``d_folder``.
    :param d_folder: folder holding ``mit_best.npy`` and the ``d_files``.
    :param mols_folder: folder holding ``test_ts.sdf``,
        ``test_reactants.sdf`` and ``test_products.sdf``.
    :return: dict mapping a short key to ``(flat_array, display_name)``;
        keys are ``'gt'``, ``'mit'``, ``'lin_approx'`` and ``'D_fin<k>'``
        for the k-th entry of ``d_files``.
    """
    # TODO: add assert
    # TODO: add way to automate loading multiple files ... pass in file names

    def _read_sdf(file_name):
        # Materialise the supplier so molecules can be indexed repeatedly.
        supplier = Chem.SDMolSupplier(mols_folder + file_name,
                                      removeHs=False,
                                      sanitize=False)
        return list(supplier)

    # get test mols
    reactants = _read_sdf('test_reactants.sdf')
    transition_states = _read_sdf('test_ts.sdf')
    products = _read_sdf('test_products.sdf')

    # load the predicted distance matrices
    mit_d_init = np.load(d_folder + 'mit_best.npy')
    d_inits = [np.load(d_folder + d_file) for d_file in d_files]

    # lists for plotting
    gt, mit, lin_approx = [], [], []
    d_init_lists = [[] for _ in d_inits]

    for idx, ts_mol in enumerate(transition_states):
        # num_atoms + mask for reaction core
        num_atoms = ts_mol.GetNumAtoms()
        product_mol = products[idx]
        reactant_mol = reactants[idx]
        core_mask = (Chem.GetAdjacencyMatrix(product_mol) +
                     Chem.GetAdjacencyMatrix(reactant_mol)) == 1

        # main 3
        gt.append(np.ravel(Chem.Get3DDistanceMatrix(ts_mol) * core_mask))
        mit_block = mit_d_init[idx][0:num_atoms, 0:num_atoms]
        mit.append(np.ravel(mit_block * core_mask))
        mean_rp = (Chem.Get3DDistanceMatrix(reactant_mol) +
                   Chem.Get3DDistanceMatrix(product_mol)) / 2
        lin_approx.append(np.ravel(mean_rp * core_mask))

        # other d_inits
        for collected, d_init in zip(d_init_lists, d_inits):
            block = d_init[idx][0:num_atoms, 0:num_atoms]
            collected.append(np.ravel(block * core_mask))

    # make plottable
    all_ds = [gt, mit, lin_approx, *d_init_lists]
    all_ds = [np.concatenate(ds).ravel() for ds in all_ds]
    assert all_same([len(ds) for ds in all_ds
                     ]), "Lengths of all ds after concat don't match."
    # only keep non-zero values
    all_ds = [ds[ds != 0] for ds in all_ds]
    assert all_same([
        len(ds) for ds in all_ds
    ]), "Lengths of all ds after removing zeroes don't match."

    ds_dict = {
        'gt': (all_ds[0], 'Ground Truth'),
        'mit': (all_ds[1], 'MIT D_init'),
        'lin_approx': (all_ds[2], 'Linear Approximation'),
    }
    base_ds_counter = len(ds_dict)
    for d_id, ds in enumerate(all_ds[base_ds_counter:]):
        name = f'D_fin{d_id}'
        ds_dict[name] = (ds, name)

    return ds_dict
def __init__(self, smiles: Union[str, Tuple[str, List[int]]],
             args: Namespace):
    """Compute the graph structure and featurization of a molecule.

    :param smiles: either a SMILES string, or a tuple
        ``(smiles, index_map, substructures)`` when substructures are to be
        collapsed into single nodes. ``index_map`` maps every real atom
        index to its node index after collapsing; ``substructures`` is a
        collection of sets of real atom indices.
    :param args: options read here: ``addHs``, ``three_d``,
        ``additional_atom_features``, ``virtual_edges``,
        ``drop_virtual_edges``, ``learn_virtual_edges`` (with
        ``lve_model``) and ``atom_messages``.
    """
    if isinstance(smiles, tuple):
        smiles, index_map, substructures = smiles
        # map from real atom idx to idx after collapsing substructures
        self.index_map = index_map
        # the following map gives an arbitrary index for each substructure,
        # but that'll be masked to 0 anyway
        self.reverse_index_map = [
            index_map.index(i) for i in range(max(index_map) + 1)
        ]
        self.substructures = substructures
        self.substructure_atoms = set().union(*substructures)
        self.collapsing_substructures = True
    else:
        self.collapsing_substructures = False

    self.smiles = smiles
    self.n_atoms = 0  # number of atoms
    self.n_bonds = 0  # number of bonds
    self.f_atoms = []  # mapping from atom index to atom features
    # mapping from bond index to concat(in_atom, bond) features
    self.f_bonds = []
    self.a2b = []  # mapping from atom index to incoming bond indices
    # mapping from bond index to the index of the atom the bond is coming
    # from
    self.b2a = []
    # mapping from bond index to the index of the reverse bond
    self.b2revb = []

    # Convert smiles to molecule
    mol = Chem.MolFromSmiles(smiles)

    if not self.collapsing_substructures:
        # Identity maps: node indices are the real atom indices.
        self.index_map = range(mol.GetNumAtoms())
        self.reverse_index_map = self.index_map
        self.substructures = set()
        self.substructure_atoms = set()

    # Add hydrogens
    if args.addHs:
        mol = Chem.AddHs(mol)

    # Get 3D distance matrix
    if args.three_d:
        # Hydrogens are added for the embedding and removed again afterwards
        # unless they were requested explicitly.
        mol = Chem.AddHs(mol)
        AllChem.EmbedMolecule(mol, AllChem.ETKDG())
        if not args.addHs:
            mol = Chem.RemoveHs(mol)
        try:
            distances_3d = Chem.Get3DDistanceMatrix(mol)
            distances_3d = np.digitize(
                distances_3d, THREE_D_DISTANCE_BINS)  # bin 3d distances
        except Exception:
            # zero distance matrix, in case rdkit errors out
            print('distance embedding failed')
            distances_3d = np.zeros((mol.GetNumAtoms(), mol.GetNumAtoms()))

    # Get topological (i.e. path-length) distance matrix and number of atoms
    distances_path = Chem.GetDistanceMatrix(mol)
    # fake the number of "atoms" if we are collapsing substructures
    if self.collapsing_substructures:
        self.n_atoms = max(self.index_map) + 1
    else:
        self.n_atoms = mol.GetNumAtoms()

    # Get atom features
    use_fg_features = 'functional_group' in args.additional_atom_features
    if use_fg_features:
        fg_featurizer = FunctionalGroupFeaturizer(args)
        fg_features = fg_featurizer.featurize(mol)
    for i, atom in enumerate(mol.GetAtoms()):
        if use_fg_features:
            self.f_atoms.append(atom_features(atom, fg_features[i]))
        else:
            self.f_atoms.append(atom_features(atom))
    for atom_idx in self.substructure_atoms:
        # mask all of these features to 0 since they'll end up collapsed in
        # substructures
        self.f_atoms[atom_idx] = [0] * len(self.f_atoms[atom_idx])

    # Re-order the features from real atom indices to node indices.
    self.f_atoms = [
        self.f_atoms[self.reverse_index_map[i]]
        for i in range(self.n_atoms)
    ]

    for _ in range(self.n_atoms):
        self.a2b.append([])

    if args.learn_virtual_edges:
        # this is the MPN model, added to args so we can access it here
        model = args.lve_model
        f_atoms_cuda = torch.Tensor(self.f_atoms).cuda()
        processed_f_atoms_cuda = model.lve(f_atoms_cuda)
        lve_scores = torch.matmul(processed_f_atoms_cuda, f_atoms_cuda.t())
        symmetric_lve_scores = lve_scores + lve_scores.t()

    # Get bond features
    for a1 in range(self.n_atoms):
        for a2 in range(a1 + 1, self.n_atoms):
            # a1 / a2 are node indices; RDKit and the distance matrices are
            # indexed by real atom indices.
            real_a1 = self.reverse_index_map[a1]
            real_a2 = self.reverse_index_map[a2]
            bond = mol.GetBondBetweenAtoms(real_a1, real_a2)
            zero_bond = (real_a1 in self.substructure_atoms
                         or real_a2 in self.substructure_atoms)

            # need to check all possible atoms in the substructure to see
            # if there's a bond
            if zero_bond:
                candidate_endpoints = []
                for atom_idx in [a1, a2]:
                    reverse_idx = self.reverse_index_map[atom_idx]
                    if reverse_idx in self.substructure_atoms:
                        for substructure in self.substructures:
                            if reverse_idx in substructure:
                                candidate_endpoints.append(substructure)
                                break
                    else:
                        # Candidates are real atom indices, so use the
                        # real index of the node, not the node index.
                        candidate_endpoints.append([reverse_idx])
                for alternate_a1 in candidate_endpoints[0]:
                    for alternate_a2 in candidate_endpoints[1]:
                        candidate_bond = mol.GetBondBetweenAtoms(
                            alternate_a1, alternate_a2)
                        if candidate_bond is not None:
                            bond = candidate_bond

            # Randomly drop O(n_atoms) virtual edges so a total of
            # O(n_atoms) edges instead of O(n_atoms^2)
            if bond is None:
                if not args.virtual_edges:
                    continue

                if args.drop_virtual_edges and hash(str(
                    (a1, a2))) % self.n_atoms != 0:
                    continue

                if args.learn_virtual_edges:
                    # if score less than 0, don't add the edge
                    # (scores were symmetrized so a1/a2 order is irrelevant)
                    if symmetric_lve_scores[a1, a2] < 0:
                        continue

            distance_3d = (distances_3d[real_a1, real_a2]
                           if args.three_d else None)
            distance_path = (distances_path[real_a1, real_a2]
                             if args.virtual_edges else None)

            f_bond = bond_features(bond,
                                   distance_path=distance_path,
                                   distance_3d=distance_3d)
            if zero_bond:
                # Bonds touching a collapsed substructure carry no features.
                f_bond = [0] * len(f_bond)

            if args.atom_messages:
                self.f_bonds.append(f_bond)
                self.f_bonds.append(f_bond)
            else:
                self.f_bonds.append(self.f_atoms[a1] + f_bond)
                self.f_bonds.append(self.f_atoms[a2] + f_bond)

            # Update index mappings
            b1 = self.n_bonds
            b2 = b1 + 1
            self.a2b[a2].append(b1)  # b1 = a1 --> a2
            self.b2a.append(a1)
            self.a2b[a1].append(b2)  # b2 = a2 --> a1
            self.b2a.append(a2)
            self.b2revb.append(b2)
            self.b2revb.append(b1)
            self.n_bonds += 2