Beispiel #1
0
    def __init__(self, mols: tuple, args: Namespace):
        """
        Computes a fully connected graph featurization of a reaction.

        Node features come from the reactant atoms. Every ordered pair of
        distinct atoms (a1, a2) becomes a directed edge whose features are the
        reactant and product 3D distances between the two atoms, and whose
        regression target (stored in ``self.y``) is the corresponding
        transition-state 3D distance.

        :param mols: A ``(reactant, ts, product)`` triple of RDKit molecules.
            Each must have a 3D conformer, and all three must share the same
            atom ordering.
        :param args: Arguments (not used by this constructor).
        """
        self.n_atoms = 0  # number of atoms
        self.n_bonds = 0  # number of directed edges (two per atom pair)
        self.f_atoms = []  # mapping from atom index to atom features
        self.f_bonds = []  # mapping from edge index to [reactant dist, product dist]
        self.a2b = []  # atom index -> incoming bond indices (not populated here)
        self.b2a = []  # bond index -> source atom index (not populated here)
        self.b2revb = []  # bond index -> reverse bond index (not populated here)
        self.parity_atoms = []  # atom index -> CW (+1), CCW (-1), undefined (0); not populated here
        self.edge_index = []  # (source, target) tuple for every directed edge
        self.y = []  # TS distance target for every directed edge

        # extract reactant, ts, product
        r_mol, ts_mol, p_mol = mols
        n_atoms = r_mol.GetNumAtoms()

        # 3D distance matrices; only these are used as features/targets
        D_r = Chem.Get3DDistanceMatrix(r_mol)
        D_p = Chem.Get3DDistanceMatrix(p_mol)
        D_ts = Chem.Get3DDistanceMatrix(ts_mol)

        for a1 in range(n_atoms):
            # Node features
            self.f_atoms.append(atom_features(r_mol.GetAtomWithIdx(a1)))

            # Edge features: fully connected graph, both directions per pair
            for a2 in range(a1 + 1, n_atoms):
                self.edge_index.extend([(a1, a2), (a2, a1)])
                # naively include both reactant and product distances
                self.f_bonds.append([D_r[a1][a2], D_p[a1][a2]])
                self.f_bonds.append([D_r[a2][a1], D_p[a2][a1]])
                self.y.extend([D_ts[a1][a2], D_ts[a2][a1]])

        # Record the graph size so the counters match the populated lists.
        self.n_atoms = n_atoms
        self.n_bonds = len(self.f_bonds)
Beispiel #2
0
def prepare_batch(batch_mols, MAX_SIZE):
    """Build padded node, edge and coordinate arrays for a batch of reactions.

    Relies on ``elements`` and ``num_elements`` defined outside this function.

    :param batch_mols: sequence of ``(reactant, product)`` RDKit molecule pairs
        with 3D conformers and matching atom ordering.
    :param MAX_SIZE: padded number of atoms per molecule; each reactant must
        have at most this many atoms.
    :return: ``(batch_dict, batch_mols)``. ``batch_dict`` contains numpy arrays:
        ``"nodes"`` (size, MAX_SIZE, num_elements + 1),
        ``"edges"`` (size, MAX_SIZE, MAX_SIZE, 3),
        ``"sizes"`` (size,) and ``"coordinates"`` (size, MAX_SIZE, 3), the
        latter taken from the reactant conformer.
    """
    # Cap on the averaged topological (bond-count) distance between two atoms.
    MAX_D = 10.

    size = len(batch_mols)
    # Node features: one-hot element + atomic number / 10.
    V = np.zeros((size, MAX_SIZE, num_elements + 1), dtype=np.float32)
    # Edge channels: 0 = aromatic bond present in both, 1 = bond present in
    # both, 2 = exp(-mean reactant/product 3D distance).
    E = np.zeros((size, MAX_SIZE, MAX_SIZE, 3), dtype=np.float32)
    sizes = np.zeros(size, dtype=np.int32)
    coordinates = np.zeros((size, MAX_SIZE, 3), dtype=np.float32)

    for bx in range(size):
        reactant, product = batch_mols[bx]
        N_atoms = reactant.GetNumAtoms()
        sizes[bx] = int(N_atoms)

        # Mean topological distance matrix, clipped at MAX_D.
        D = (Chem.GetDistanceMatrix(reactant) + Chem.GetDistanceMatrix(product)) / 2
        D[D > MAX_D] = MAX_D

        # exp(-d) of the mean 3D distance (the distance is not squared).
        D_3D_rbf = np.exp(-((Chem.Get3DDistanceMatrix(reactant) +
                             Chem.Get3DDistanceMatrix(product)) / 2))

        conf = reactant.GetConformer()
        for i in range(N_atoms):
            # Edge features
            for j in range(N_atoms):
                E[bx, i, j, 2] = D_3D_rbf[i][j]
                # Off-diagonal topological distances are >= 1, so a mean of
                # exactly 1 means the pair is bonded in both reactant and product.
                if D[i][j] == 1.:
                    if reactant.GetBondBetweenAtoms(i, j).GetIsAromatic():
                        E[bx, i, j, 0] = 1.
                    E[bx, i, j, 1] = 1.

            # Reactant coordinates
            pos = conf.GetAtomPosition(i)
            coordinates[bx, i, :] = np.asarray([pos.x, pos.y, pos.z])

            # Node features: one-hot element and atomic number / 10
            atom = reactant.GetAtomWithIdx(i)
            e_ix = elements.index(atom.GetSymbol())
            V[bx, i, e_ix] = 1.
            V[bx, i, num_elements] = atom.GetAtomicNum() / 10.

    batch_dict = {
        "nodes": V,
        "edges": E,
        "sizes": sizes,
        "coordinates": coordinates
    }
    return batch_dict, batch_mols
Beispiel #3
0
def prepare_batch(batch_mols):
    """Build padded TensorFlow feature tensors for a batch of reactions.

    Relies on ``max_size``, ``elements`` and ``num_elements`` defined outside
    this function.

    :param batch_mols: sequence of ``(reactant, ts, product)`` RDKit molecules
        with 3D conformers and matching atom ordering.
    :return: dict of ``tf.constant`` tensors: ``"nodes"``, ``"edges"``,
        ``"sizes"`` and ``"coordinates"`` (the latter taken from the TS).
    """
    # Cap on the averaged topological (bond-count) distance between two atoms.
    MAX_D = 10.

    size = len(batch_mols)
    V = np.zeros((size, max_size, num_elements + 1), dtype=np.float32)
    E = np.zeros((size, max_size, max_size, 3), dtype=np.float32)
    sizes = np.zeros(size, dtype=np.int32)
    coordinates = np.zeros((size, max_size, 3), dtype=np.float32)

    for bx in range(size):
        reactant, ts, product = batch_mols[bx]
        N_atoms = reactant.GetNumAtoms()
        sizes[bx] = int(N_atoms)

        # Mean topological distance matrix, clipped at MAX_D.
        D = (Chem.GetDistanceMatrix(reactant) +
             Chem.GetDistanceMatrix(product)) / 2
        D[D > MAX_D] = MAX_D

        # exp(-d) of the mean 3D distance (the distance is not squared).
        D_3D_rbf = np.exp(
            -((Chem.Get3DDistanceMatrix(reactant) +
               Chem.Get3DDistanceMatrix(product)) / 2))

        ts_conf = ts.GetConformer()
        for i in range(N_atoms):
            # Edge features
            for j in range(N_atoms):
                E[bx, i, j, 2] = D_3D_rbf[i][j]
                # a mean of exactly 1 means bonded in both reactant and product
                if D[i][j] == 1.:
                    if reactant.GetBondBetweenAtoms(i, j).GetIsAromatic():
                        E[bx, i, j, 0] = 1.
                    E[bx, i, j, 1] = 1.

            # Coordinates are the transition-state geometry
            pos = ts_conf.GetAtomPosition(i)
            coordinates[bx, i, :] = np.asarray([pos.x, pos.y, pos.z])

            # Node features: one-hot element and atomic number / 10
            atom = reactant.GetAtomWithIdx(i)
            e_ix = elements.index(atom.GetSymbol())
            V[bx, i, e_ix] = 1.
            V[bx, i, num_elements] = atom.GetAtomicNum() / 10.

    batch_dict = {
        "nodes": tf.constant(V),
        "edges": tf.constant(E),
        "sizes": tf.constant(sizes),
        "coordinates": tf.constant(coordinates)
    }
    return batch_dict
Beispiel #4
0
def xyz2ac(atomic_num_list, xyz, covalent_factor=1.24):
    """Build an atom connectivity matrix from atomic numbers and coordinates.

    Two atoms are connected when their 3D distance is at most the sum of
    their RDKit covalent radii, each scaled by ``covalent_factor``.

    :param atomic_num_list: atomic number of each atom.
    :param xyz: per-atom coordinates, indexable as ``xyz[i][0..2]``.
    :param covalent_factor: scale applied to each covalent radius.
    :return: ``(ac, mol)``: a symmetric int adjacency matrix with zero
        diagonal, and the molecule built by ``get_proto_mol`` with the
        coordinates attached as a conformer.
    """
    import numpy as np
    mol = get_proto_mol(atomic_num_list)

    conf = Chem.Conformer(mol.GetNumAtoms())
    for i in range(mol.GetNumAtoms()):
        conf.SetAtomPosition(i, (xyz[i][0], xyz[i][1], xyz[i][2]))
    mol.AddConformer(conf)

    d_mat = Chem.Get3DDistanceMatrix(mol)
    pt = Chem.GetPeriodicTable()

    num_atoms = len(atomic_num_list)
    ac = np.zeros((num_atoms, num_atoms)).astype(int)

    # Scaled covalent radius per atom, looked up once.
    radii = [pt.GetRcovalent(mol.GetAtomWithIdx(i).GetAtomicNum()) * covalent_factor
             for i in range(num_atoms)]

    for i in range(num_atoms):
        for j in range(i + 1, num_atoms):
            if d_mat[i, j] <= radii[i] + radii[j]:
                ac[i, j] = 1
                ac[j, i] = 1

    return ac, mol
def make_fingerprints(data, length=512, verbose=False):
    """Create the standard panel of fingerprint calculators and apply each to ``data``.

    :param data: passed unchanged to every ``fingerprint.apply_fp`` call.
    :param length: bit length used by the Morgan, Avalon and RDKit fingerprints.
    :param verbose: if true, print each fingerprint's name before applying it.
    :return: list of ``fingerprint`` objects, in the order they were applied.
    """
    # (callable, display name) pairs; names are kept exactly as published.
    specs = [
        (Chem.rdMolDescriptors.GetHashedTopologicalTorsionFingerprintAsBitVect, "Torsion "),
        (lambda x: GetMorganFingerprintAsBitVect(x, 2, nBits=length), "Morgan"),
        (FingerprintMol, "Estate (1995)"),
        (lambda x: GetAvalonFP(x, nBits=length), "Avalon bit based (2006)"),
        (lambda x: np.append(GetAvalonFP(x, nBits=length), Descriptors.MolWt(x)),
         "Avalon+mol. weight"),
        (lambda x: GetErGFingerprint(x), "ErG fingerprint (2006)"),
        (lambda x: RDKFingerprint(x, fpSize=length), "RDKit fingerprint"),
        (lambda x: MACCSkeys.GenMACCSKeys(x), "MACCS fingerprint"),
        (lambda x: get_fingerprint(x, fp_type='pubchem'), "PubChem"),
        (lambda x: Generate.Gen2DFingerprint(x, Gobbi_Pharm2D.factory,
                                             dMat=Chem.Get3DDistanceMatrix(x)),
         "3D pharmacophore"),
    ]
    fp_list = [fingerprint(func, label) for func, label in specs]

    for fp in fp_list:
        if verbose:
            print("doing", fp.name)
        fp.apply_fp(data)

    return fp_list
Beispiel #6
0
    def _get_distance_matrix(self, combo: Chem.Mol,
                             A: Union[Chem.Mol, np.ndarray],
                             B: Union[Chem.Mol, np.ndarray]) -> np.ndarray:
        """
        Called by ``_find_closest`` and ``_determine_mergers_novel_ringcore_pair`` in collapse ring (for expansion).
        Returns the 3D distance matrix of ``combo`` blanked (via
        ``self._nan_fill_submatrix``) within fragment A and within fragment B,
        so that only inter-fragment distances remain.

        :param combo: combined molecule containing both fragments.
        :param A: first fragment, either as a Mol (taken to occupy the first
            ``A.GetNumAtoms()`` atoms of ``combo``) or as an array of its atom
            indices in ``combo``.
        :param B: second fragment, either as a Mol or as an array of indices.
            When given as a Mol, its atoms are taken to follow A's atoms.
        :return: the blanked distance matrix.
        """
        # TODO move to base once made.
        if isinstance(A, Chem.Mol):
            A_idxs = np.arange(A.GetNumAtoms())
            b_offset = A.GetNumAtoms()
        else:
            A_idxs = np.array(A)
            # Offset unknown from indices alone; resolved below if B is a Mol.
            b_offset = None
        if isinstance(B, Chem.Mol):
            if b_offset is None:
                # NOTE(review): assumes B occupies the last atoms of combo (as
                # with Chem.CombineMols(A, B)); previously this case crashed.
                b_offset = combo.GetNumAtoms() - B.GetNumAtoms()
            B_idxs = np.arange(B.GetNumAtoms()) + b_offset
        else:
            B_idxs = np.array(B)
        distance_matrix = Chem.Get3DDistanceMatrix(combo)
        # nan fill the intra-fragment values
        self._nan_fill_submatrix(distance_matrix, A_idxs)
        self._nan_fill_submatrix(distance_matrix, B_idxs)
        return distance_matrix
Beispiel #7
0
def get_AC(mol, covalent_factor=1.3):
    """
    Generate an adjacency matrix from atoms and 3D coordinates.

    AC is a (num_atoms, num_atoms) int matrix with 1 where two distinct atoms
    lie within the sum of their scaled covalent radii and 0 otherwise.

    args:
        mol - rdkit molobj with 3D conformer
    optional
        covalent_factor - factor increasing the covalent bond length threshold
                          (1.3 is an arbitrary choice)
    returns:
        AC - adjacency matrix
    """
    dist = Chem.Get3DDistanceMatrix(mol)
    table = Chem.GetPeriodicTable()

    # Scaled covalent radius of every atom, in index order.
    radii = np.array([table.GetRcovalent(atom.GetAtomicNum()) * covalent_factor
                      for atom in mol.GetAtoms()])
    thresholds = radii[:, None] + radii[None, :]

    # Compare the strict upper triangle only, then mirror it; the diagonal stays 0.
    upper = np.triu(dist <= thresholds, k=1)
    return (upper | upper.T).astype(int)
    def join_overclose(self, mol: Chem.RWMol, to_check, cutoff=2.2): # was 1.8
        """
        Bond atoms in ``to_check`` to other atoms closer than an element-adjusted cutoff.

        Atoms that are both in ``to_check`` are never joined to each other. For
        C-C pairs, or pairs involving a dummy atom (``*``), ``cutoff`` is used
        unchanged. Otherwise the cutoff is ``cutoff - 1.36 + rcov_i + rcov_j``
        using RDKit covalent radii. Bonds are requested through
        ``self._add_bond_if_possible``. Distances come from the geometry at
        call time and are not recomputed as bonds are added.

        :param mol: editable molecule with a conformer.
        :param to_check: list of atom indices that need joining (but not to each other).
        :param cutoff: distance threshold for a C-C pair.
        :return: None
        """
        pt = Chem.GetPeriodicTable()
        dm = Chem.Get3DDistanceMatrix(mol)
        to_check = list(to_check)
        excluded = set(to_check)  # O(1) membership tests in the inner loop
        for i in to_check:
            atom_i = mol.GetAtomWithIdx(i)
            sym_i = atom_i.GetSymbol()
            for j, atom_j in enumerate(mol.GetAtoms()):
                # Skip self and other to-check atoms before any cutoff work.
                if i == j or j in excluded:
                    continue
                sym_j = atom_j.GetSymbol()
                if sym_i == '*' or sym_j == '*' or (sym_i == 'C' and sym_j == 'C'):
                    ij_cutoff = cutoff
                else:
                    ij_cutoff = cutoff - 1.36 + (pt.GetRcovalent(atom_i.GetAtomicNum()) +
                                                 pt.GetRcovalent(atom_j.GetAtomicNum()))
                if dm[i, j] > ij_cutoff:
                    continue
                self._add_bond_if_possible(mol, atom_i, atom_j)
Beispiel #9
0
    def xyz2AC(atomicNumList, xyz, covalent_factor=1.30):
        """Build an atom connectivity matrix from atomic numbers and coordinates.

        Atoms i and j are connected when their 3D distance is at most the sum
        of their RDKit covalent radii, each scaled by ``covalent_factor``.

        :param atomicNumList: atomic number of each atom.
        :param xyz: per-atom coordinates, indexable as ``xyz[i][0..2]``.
        :param covalent_factor: scale applied to each covalent radius.
        :return: ``(AC, mol)``: a symmetric int adjacency matrix with zero
            diagonal, and the ``get_proto_mol`` molecule with the coordinates
            added as a conformer.
        """
        mol = get_proto_mol(atomicNumList)

        conf = Chem.Conformer(mol.GetNumAtoms())
        for i in range(mol.GetNumAtoms()):
            conf.SetAtomPosition(i, (xyz[i][0], xyz[i][1], xyz[i][2]))
        mol.AddConformer(conf)

        dMat = Chem.Get3DDistanceMatrix(mol)
        pt = Chem.GetPeriodicTable()

        num_atoms = len(atomicNumList)
        AC = np.zeros((num_atoms, num_atoms)).astype(int)

        # Scaled covalent radius per atom, looked up once.
        radii = [pt.GetRcovalent(mol.GetAtomWithIdx(i).GetAtomicNum()) * covalent_factor
                 for i in range(num_atoms)]

        for i in range(num_atoms):
            for j in range(i + 1, num_atoms):
                if dMat[i, j] <= radii[i] + radii[j]:
                    AC[i, j] = 1
                    AC[j, i] = 1

        return AC, mol
Beispiel #10
0
 def _find_centroid(self):
     """Record the most central atom's PDB name and its radius.

     The central atom is the one whose column of the 3D distance matrix has
     the smallest Euclidean (L2) norm. The radius is that atom's largest
     distance to any atom. Both are appended, as strings, to
     ``self.NBR_ATOM`` and ``self.NBR_RADIUS``.
     """
     dists = Chem.Get3DDistanceMatrix(self.mol)
     # Per-column L2 norm (a vector norm per atom, not a matrix Frobenius norm).
     col_norms = np.linalg.norm(dists, axis=0)
     i = int(np.argmin(col_norms))
     # Only the chosen atom's column maximum is needed.
     radius = np.max(dists[:, i])
     self.NBR_ATOM.append(self.mol.GetAtomWithIdx(i).GetPDBResidueInfo().GetName())
     self.NBR_RADIUS.append(str(radius))
Beispiel #11
0
def get_gobbi_similarity(correct_ligand,
                         mol_to_fix,
                         type_fp='normal',
                         use_features=False,
                         ref_smiles='C1=CC(=C(C=C1C2=C(C(=O)C3=C(C=C(C=C3O2)O)O)O)O)O'):
    """Print and return the Gobbi pharmacophore Tanimoto similarity of two ligand poses.

    Bond orders are assigned to both molecules from the template built from
    ``ref_smiles``. Gobbi 2D pharmacophore fingerprints are then generated
    using each molecule's 3D distance matrix and compared with Tanimoto
    similarity.

    :param correct_ligand: reference ligand pose (RDKit mol with 3D coordinates).
    :param mol_to_fix: ligand pose to compare (RDKit mol with 3D coordinates).
    :param type_fp: unused.
    :param use_features: unused.
    :param ref_smiles: SMILES of the bond-order template.
    :return: the Tanimoto similarity (it is also printed).
    """
    ref = Chem.MolFromSmiles(ref_smiles)
    mol1 = AllChem.AssignBondOrdersFromTemplate(ref, correct_ligand)
    mol2 = AllChem.AssignBondOrdersFromTemplate(ref, mol_to_fix)

    factory = Gobbi_Pharm2D.factory
    fp1 = Generate.Gen2DFingerprint(mol1,
                                    factory,
                                    dMat=Chem.Get3DDistanceMatrix(mol1))
    fp2 = Generate.Gen2DFingerprint(mol2,
                                    factory,
                                    dMat=Chem.Get3DDistanceMatrix(mol2))
    # Tanimoto similarity
    tani = DataStructs.TanimotoSimilarity(fp1, fp2)
    print('GOBBI similarity is ------> ', tani)
    return tani
Beispiel #12
0
    def _generate_interaction_matrix(self, rdkit_mol):
        """Generate interaction matrix using the real-valued Euclidean distance as interaction feature.

        Channel 0 holds the 3D distance matrix; any other channels (up to
        ``self.num_interaction_features``) stay zero. The result is
        zero-padded to ``(self.max_num_atoms, self.max_num_atoms,
        self.num_interaction_features)``.

        :raises ValueError: if the molecule has more than ``self.max_num_atoms`` atoms.
        """
        num_atoms = rdkit_mol.GetNumAtoms()
        padding_size = self.max_num_atoms - num_atoms
        if padding_size < 0:
            # np.pad would otherwise fail with an opaque negative-width error.
            raise ValueError(
                f"Molecule has {num_atoms} atoms, more than max_num_atoms={self.max_num_atoms}")

        interaction_matrix = np.zeros(
            (num_atoms, num_atoms, self.num_interaction_features))
        interaction_matrix[:, :, 0] = Chem.Get3DDistanceMatrix(rdkit_mol)

        # add zero padding on the two atom axes only
        interaction_matrix = np.pad(interaction_matrix,
                                    ((0, padding_size), (0, padding_size),
                                     (0, 0)),
                                    mode='constant')

        return interaction_matrix
 def join_rings(self, mol: Chem.RWMol, cutoff=1.8):
     """Bond the closest atom pair of any two unbonded rings that lie within ``cutoff``.

     Special case (x0749): a bond between two rings. Bonds are otherwise added
     to non-ring atoms, so bonded rings need this extra pass. Distances are
     taken once, before any bond is added.
     """
     rings = self._get_ring_info(mol)
     dm = Chem.Get3DDistanceMatrix(mol)
     for ring_a, ring_b in itertools.combinations(rings, 2):
         if self._are_rings_bonded(mol, ring_a, ring_b):
             continue
         # rows: atoms of ring_a, columns: atoms of ring_b
         block = np.take(np.take(dm, ring_a, 0), ring_b, 1)
         closest = np.nanmin(block)
         if not closest < cutoff:
             continue
         # first (row-major) position of the minimum distance
         row, col = np.unravel_index(np.nanargmin(block), block.shape)
         first = ring_a[int(row)]
         second = ring_b[int(col)]
         self._add_bond_if_possible(mol, mol.GetAtomWithIdx(first), mol.GetAtomWithIdx(second))
    def find_closest_to_ligand(cls, pdb: Chem.Mol, ligand_resn: str) -> Tuple[Chem.Atom, Chem.Atom]:
        """
        Find the non-ligand atom closest to the ligand.

        Atoms without PDB residue information are treated as non-ligand atoms.
        Pairs at exactly zero distance are ignored.

        :param pdb: a rdkit Chem object with 3D coordinates
        :param ligand_resn: 3 letter residue name of the ligand
        :return: tuple of non-ligand atom and ligand atom
        :raises ValueError: if no atom has residue name ``ligand_resn`` or no
            non-ligand atom is available to compare against.
        """
        ligand = []
        for atom in pdb.GetAtoms():
            info = atom.GetPDBResidueInfo()
            if info is not None and info.GetResidueName() == ligand_resn:
                ligand.append(atom.GetIdx())
        if not ligand:
            raise ValueError(f'No atoms with residue name {ligand_resn!r} found')

        dm = Chem.Get3DDistanceMatrix(pdb)
        # rows: ligand atoms, columns: all atoms (np.take returns a copy)
        mini = np.take(dm, ligand, 0)
        mini[mini == 0] = np.nan
        mini[:, ligand] = np.nan  # exclude ligand-ligand distances
        if np.all(np.isnan(mini)):
            raise ValueError(f'No non-ligand atom to compare with {ligand_resn!r}')
        a, b = np.where(mini == np.nanmin(mini))
        lig_atom = pdb.GetAtomWithIdx(ligand[int(a[0])])
        nonlig_atom = pdb.GetAtomWithIdx(int(b[0]))
        return (nonlig_atom, lig_atom)
 def _ring_overlap_scenario(self, mol, rings):
     """
     Find atom pairs to merge where two collapsed rings overlap (share a border).

     For each ring placeholder atom in ``rings`` and each neighbour whose
     ``_ori_i`` int property is -1, each pair is visited once (lower index
     first). The pair is treated as overlapping when the two positions are at
     most 4 A apart. If the mapped atom indices of the two rings (``'ori_is'``
     / ``_ori_is``, passed through ``self._get_new_index``) do not already
     intersect, the two atoms of each ring closest to the other ring's
     placeholder are found. These atoms are then paired up by distance.

     :param mol: molecule with a conformer.
     :param rings: iterable of dicts with an ``'atom'`` placeholder atom and
         an ``'ori_is'`` list of indices.
     :return: list of ``(atom_index, atom_index)`` pairs to merge.
     """
     # Resolve the case where a border shared by two rings is lost.
     # The placeholder atoms have to be adjacent.
     dm = Chem.Get3DDistanceMatrix(mol)
     conf = mol.GetConformer()  # loop-invariant
     mergituri = []
     for ring in rings:
         for n in ring['atom'].GetNeighbors():
             if n.GetIntProp('_ori_i') == -1 and ring['atom'].GetIdx() < n.GetIdx():  # it may share a border.
                 # variables to assess if overlap or bond
                 rpos = np.array(conf.GetAtomPosition(ring['atom'].GetIdx()))
                 npos = np.array(conf.GetAtomPosition(n.GetIdx()))
                 if np.linalg.norm(rpos - npos) > 4:  # is bond
                     # Connected via a bond rather than an overlapping atom:
                     # 2.8 ring diameter vs. 2.8 + 1.5 CC bond.
                     # This is fixed elsewhere depending on history.
                     continue
                 # is overlap
                 A_idxs_old = ring['ori_is']
                 A_idxs = [self._get_new_index(mol, i, search_collapsed=False) for i in A_idxs_old]
                 B_idxs_old = json.loads(n.GetProp('_ori_is'))
                 B_idxs = [self._get_new_index(mol, i, search_collapsed=False) for i in B_idxs_old]
                 # do they have overlapping atoms already?
                 if len(set(A_idxs).intersection(B_idxs)) != 0:
                     continue
                 # They still need merging: for each ring, find its two atoms
                 # closest to the other ring's placeholder atom.
                 pairs = []
                 for G_idxs, ref_i in [(A_idxs, n.GetIdx()), (B_idxs, ring['atom'].GetIdx())]:
                     tm = np.take(dm, G_idxs, 0)
                     tm2 = np.take(tm, [ref_i], 1)  # shape (len(G_idxs), 1)
                     p = np.where(tm2 == np.nanmin(tm2))
                     f = G_idxs[int(p[0][0])]
                     tm2[p[0][0], :] = np.ones(tm2.shape[1]) * float('inf')
                     p = np.where(tm2 == np.nanmin(tm2))
                     # Row index identifies the atom; the column index is always
                     # 0 here (single column), which previously made s == G_idxs[0].
                     s = G_idxs[int(p[0][0])]
                     pairs.append((f, s))
                 # now determine which are closer
                 if dm[pairs[0][0], pairs[1][0]] < dm[pairs[0][0], pairs[1][1]]:
                     mergituri.append((pairs[0][0], pairs[1][0]))
                     mergituri.append((pairs[0][1], pairs[1][1]))
                 else:
                     mergituri.append((pairs[0][0], pairs[1][1]))
                     mergituri.append((pairs[0][1], pairs[1][0]))
     return mergituri
Beispiel #16
0
def xyz2AC(atomicNumList, xyz, covalent_factor=1.30):
    """Build an atom connectivity matrix from atomic numbers and coordinates.

    Atoms i and j are connected when their 3D distance is at most the sum of
    their covalent radii, each scaled by ``covalent_factor``. Radii for H..F
    (Z = 1..9) come from the local table; other elements use RDKit's
    periodic table.

    :param atomicNumList: atomic number of each atom.
    :param xyz: per-atom coordinates, indexable as ``xyz[i][0..2]``.
    :param covalent_factor: scale applied to each covalent radius.
    :return: ``(AC, mol)``: symmetric int adjacency matrix with zero diagonal,
        and the ``get_proto_mol`` molecule with the coordinates as a conformer.
    """
    mol = get_proto_mol(atomicNumList)
    conf = Chem.Conformer(mol.GetNumAtoms())
    for i in range(mol.GetNumAtoms()):
        conf.SetAtomPosition(i, (xyz[i][0], xyz[i][1], xyz[i][2]))
    mol.AddConformer(conf)
    dMat = Chem.Get3DDistanceMatrix(mol)

    # Covalent radii for Z = 1 (H) .. 9 (F), indexed by Z - 1.
    Rcovtable = [0.31, 0.28, 1.28, 0.96, 0.85, 0.76, 0.71, 0.66, 0.57]
    pt = Chem.GetPeriodicTable()

    def _rcov(z):
        # Plain Rcovtable[z - 1] would wrap to F for Z = 0 and raise for Z > 9.
        if 1 <= z <= len(Rcovtable):
            return Rcovtable[z - 1]
        return pt.GetRcovalent(z)

    num_atoms = len(atomicNumList)
    radii = [_rcov(mol.GetAtomWithIdx(i).GetAtomicNum()) * covalent_factor
             for i in range(num_atoms)]
    AC = np.zeros((num_atoms, num_atoms)).astype(int)
    for i in range(num_atoms):
        for j in range(i + 1, num_atoms):
            if dMat[i, j] <= radii[i] + radii[j]:
                AC[i, j] = 1
                AC[j, i] = 1
    return AC, mol
 def absorb_overclose(self, mol, to_check=None, to_check2=None, cutoff: float = 1.):
     """Absorb atoms closer than ``cutoff`` into one another, then delete the absorbed atoms.

     For each i in ``to_check`` (skipping atoms already absorbed) and each
     other j in ``to_check2`` (also skipping absorbed atoms), calls
     ``self._absorb(mol, i, j)`` when their 3D distance is below ``cutoff``
     and marks j for removal. Distances are computed once, up front.

     :param mol: molecule with a conformer; absorbed atoms are removed via
         ``mol.RemoveAtom``.
     :param to_check: indices to absorb into; all atoms when None.
     :param to_check2: indices that may be absorbed; all atoms when None.
     :param cutoff: distance below which two atoms are merged.
     :return: number of atoms removed.
     """
     dm = Chem.Get3DDistanceMatrix(mol)
     n_atoms = mol.GetNumAtoms()
     if to_check is None:
         to_check = list(range(n_atoms))
     if to_check2 is None:
         to_check2 = list(range(n_atoms))
     morituri = set()  # O(1) membership instead of scanning a list
     for i in to_check:
         if i in morituri:
             continue
         for j in to_check2:
             if i == j or j in morituri:
                 continue
             if dm[i, j] < cutoff:
                 self._absorb(mol, i, j)
                 morituri.add(j)
     # Remove from the highest index down so earlier indices stay valid.
     for idx in sorted(morituri, reverse=True):
         mol.RemoveAtom(idx)
     return len(morituri)
Beispiel #18
0
def is_clashing(mol, conf_id, clash_threshold):
    """Return True if any two distinct atoms of conformer ``conf_id`` are closer than ``clash_threshold``.

    Only the atom pairs i < j are examined. This skips the zero diagonal while
    still counting coincident atoms (distance 0) as clashes; a plain ``> 0``
    filter would silently drop them.
    """
    dists = Chem.Get3DDistanceMatrix(mol, confId=conf_id)
    pair_dists = dists[np.triu_indices_from(dists, k=1)]
    return bool(np.any(pair_dists < clash_threshold))
Beispiel #19
0
    def process_geometries(self, geometries):
        """Convert ``(reactant, ts, product)`` triples into fully connected graph ``Data`` objects.

        Stops after ``self.n_rxns`` reactions.
        - Node features: a one-hot element (via ``self.ELEM_TYPES``) plus
          atomic number / 10, in a zero tensor of shape
          ``(self.MAX_NUM_ATOMS, self.NUM_NODE_FEATS)``.
        - Edges: every ordered atom pair, with features
          ``[reactant distance, product distance]``.
        - Targets: the TS distance of each edge.

        :param geometries: iterable of ``(reactant, ts, product)`` RDKit mols
            with 3D conformers and matching atom ordering.
        :return: list of ``Data(x, edge_attr, edge_index, y, idx)`` objects.
        """
        data_list = []

        for rxn_id, rxn in enumerate(tqdm(geometries)):

            if rxn_id == self.n_rxns:
                break

            r, ts, p = rxn
            num_atoms = r.GetNumAtoms()

            f_bonds, edge_index, y = [], [], []
            f_atoms = torch.zeros(self.MAX_NUM_ATOMS,
                                  self.NUM_NODE_FEATS,
                                  dtype=torch.float)

            # 3D distance matrices (topological ones were never used)
            D_r = Chem.Get3DDistanceMatrix(r)
            D_p = Chem.Get3DDistanceMatrix(p)
            D_ts = Chem.Get3DDistanceMatrix(ts)

            for i in range(num_atoms):

                # node feats
                atom = r.GetAtomWithIdx(i)
                e_ix = self.ELEM_TYPES[atom.GetSymbol()]
                f_atoms[i][e_ix] = 1
                f_atoms[i][self.NUM_ELEMS] = atom.GetAtomicNum() / 10.

                # edge features: fully connected graph, both directions
                for j in range(i + 1, num_atoms):
                    edge_index.extend([(i, j), (j, i)])
                    # for now, naively include both reac and prod
                    f_bonds.append([D_r[i][j], D_p[i][j]])
                    f_bonds.append([D_r[j][i], D_p[j][i]])
                    y.extend([D_ts[i][j], D_ts[j][i]])

            # view(-1, 2) keeps 2-D shapes ((2, 0) and (0, 2)) even when a
            # single-atom reaction leaves the lists empty.
            edge_index = torch.tensor(edge_index,
                                      dtype=torch.long).view(-1, 2).t().contiguous()
            edge_attr = torch.tensor(f_bonds, dtype=torch.float).view(-1, 2)
            y = torch.tensor(y, dtype=torch.float)

            data = Data(x=f_atoms,
                        edge_attr=edge_attr,
                        edge_index=edge_index,
                        y=y,
                        idx=rxn_id)
            data_list.append(data)

        return data_list
Beispiel #20
0
def prepare_batch(batch_mols, max_size, elements):
    """Return a batch with atom features, edge features and coordinates initialised.

    Edge features are based on topological distances (number of bonds
    between atoms, clipped at ``MAX_NUM_BONDS``, which is defined outside this
    function) and on 3D RBF distances. Coordinates come from the TS geometry.

    :param batch_mols: sequence of ``(reactant, ts, product)`` RDKit molecules
        with 3D conformers and matching atom ordering.
    :param max_size: padded number of atoms per molecule.
    :param elements: list of element symbols defining the one-hot encoding.
    :return: dict of ``tf.constant`` tensors: ``"nodes"``, ``"edges"``,
        ``"sizes"``, ``"coordinates"``.
    """
    num_elements = len(elements)

    # initialise
    size = len(batch_mols)
    V = np.zeros((size, max_size, num_elements + 1), dtype=np.float32)
    E = np.zeros((size, max_size, max_size, 3), dtype=np.float32)
    sizes = np.zeros(size, dtype=np.int32)
    coordinates = np.zeros((size, max_size, 3), dtype=np.float32)

    # build molecule features for each reaction
    for bx in range(size):

        # get r, ts, p for reaction
        reactant, ts, product = batch_mols[bx]
        num_atoms = reactant.GetNumAtoms()
        sizes[bx] = int(num_atoms)

        # mean topological distance matrix, clipped
        D = (Chem.GetDistanceMatrix(reactant) +
             Chem.GetDistanceMatrix(product)) / 2
        D[D > MAX_NUM_BONDS] = MAX_NUM_BONDS

        # 3D rbf matrix: exp(-mean 3D distance)
        D_3D_rbf = np.exp(-((Chem.Get3DDistanceMatrix(reactant) +
                             Chem.Get3DDistanceMatrix(product)) / 2))

        ts_conf = ts.GetConformer()
        for i in range(num_atoms):

            # edge features (stays bonded and bond aromatic?, stays bonded?, 3D rbf distance)
            for j in range(num_atoms):
                # a mean of exactly 1 means bonded in both reactant and product
                if D[i][j] == 1.:
                    if reactant.GetBondBetweenAtoms(i, j).GetIsAromatic():
                        E[bx, i, j, 0] = 1.
                    E[bx, i, j, 1] = 1
                E[bx, i, j, 2] = D_3D_rbf[i][j]

            # node features: one-hot element and atomic number / 10
            atom = reactant.GetAtomWithIdx(i)
            e_ix = elements.index(atom.GetSymbol())
            V[bx, i, e_ix] = 1.
            V[bx, i, num_elements] = atom.GetAtomicNum() / 10.

            # recover TS coordinates
            pos = ts_conf.GetAtomPosition(i)
            coordinates[bx, i, :] = np.asarray([pos.x, pos.y, pos.z])

    batch_dict = {
        "nodes": tf.constant(V),
        "edges": tf.constant(E),
        "sizes": tf.constant(sizes),
        "coordinates": tf.constant(coordinates)
    }

    return batch_dict
Beispiel #21
0
 def _transform_mol(self, mol):
     """Return ``mol``'s 3D distance matrix padded on the right to ``self.max_atoms`` columns.

     The padding holds whatever ``nanarray`` initialises the array with
     (presumably NaN - confirm against its definition).
     """
     n_atoms = len(mol.atoms)
     padded = nanarray((n_atoms, self.max_atoms))
     padded[:, :n_atoms] = Chem.Get3DDistanceMatrix(mol)
     return padded
Beispiel #22
0
    def construct_feature_matrices(self, mol):
        """ Given an rdkit mol, return atom feature matrices, bond feature
        matrices, and connectivity matrices.

        Each atom is connected to its ``self.n_neighbors`` nearest atoms by 3D
        distance, or to every other atom if the molecule is smaller than that.

        Returns
        dict with entries
        'n_atom' : number of atoms in the molecule
        'n_bond' : number of edges (likely n_atom * n_neighbors)
        'atom' : (n_atom,) length list of atom classes
        'bond' : (n_bond,) list of bond classes. 0 for no bond
        'distance' : (n_bond,) list of bond distances
        'connectivity' : (n_bond, 2) array of source atom, target atom pairs.

        Raises
        RuntimeError if ``mol`` is None.
        """
        # Check before touching mol; the old check ran after mol.GetAtoms()
        # had already failed with AttributeError on None.
        if mol is None:
            raise RuntimeError("Issue in loading mol")

        n_atom = mol.GetNumAtoms()

        # n_bond is actually the number of atom-atom pairs, so this is defined
        # by the number of neighbors for each atom.
        if self.n_neighbors <= (n_atom - 1):
            n_bond = self.n_neighbors * n_atom
        elif n_atom == 1:
            n_bond = 1
        else:
            # If there are fewer atoms than n_neighbors, all atoms will be
            # connected
            n_bond = (n_atom - 1) * n_atom

        # Initialize the matrices to be filled in during the following loop.
        atom_feature_matrix = np.zeros(n_atom, dtype='int')
        bond_feature_matrix = np.zeros(n_bond, dtype='int')
        bond_distance_matrix = np.zeros(n_bond, dtype=np.float32)
        connectivity = np.zeros((n_bond, 2), dtype='int')

        distance_matrix = Chem.Get3DDistanceMatrix(mol)

        # Neighbours per atom: n_neighbors, capped at the other n_atom - 1 atoms.
        n_per_atom = min(self.n_neighbors + 1, n_atom) - 1

        bond_index = 0  # keep track of our current bond.
        for n, atom in enumerate(mol.GetAtoms()):

            # update atom feature matrix
            atom_feature_matrix[n] = self.atom_tokenizer(
                self.atom_features(atom))

            # Nearest atoms by 3D distance, excluding the atom itself
            # explicitly rather than assuming it sorts first (a coincident
            # atom at distance 0 could otherwise displace it).
            order = distance_matrix[n, :].argsort()
            neighbor_inds = order[order != n][:n_per_atom]
            for neighbor in neighbor_inds:
                neighbor = int(neighbor)

                # update bond feature matrix (0 when the pair is not bonded)
                bond = mol.GetBondBetweenAtoms(n, neighbor)
                if bond is None:
                    bond_feature_matrix[bond_index] = 0
                else:
                    flipped = bond.GetBeginAtomIdx() != n
                    bond_feature_matrix[bond_index] = self.bond_tokenizer(
                        self.bond_features(bond, flipped=flipped))

                bond_distance_matrix[bond_index] = distance_matrix[n, neighbor]

                # update connectivity matrix
                connectivity[bond_index, 0] = n
                connectivity[bond_index, 1] = neighbor

                bond_index += 1

        return {
            'n_atom': n_atom,
            'n_bond': n_bond,
            'atom': atom_feature_matrix,
            'bond': bond_feature_matrix,
            'distance': bond_distance_matrix,
            'connectivity': connectivity,
        }
def check_test_distribution(args):
    """Compare reaction-core TS distances from the test loader with the raw test set.

    Sanity check that the test split built by ``construct_dataset_and_loaders``
    matches the 842 test reactions published with the original MIT model.

    The "reaction core" of a reaction is the set of atom pairs whose reactant and
    product adjacency entries sum to exactly 1, i.e. pairs bonded in exactly one of
    the two. The TS distances of those pairs are taken from two sources:

    - the raw ``test_ts.sdf`` geometries ("GT");
    - the test loader ("Mine").

    Both distributions are drawn as KDE curves on a new matplotlib figure.

    Args:
        args: Arguments forwarded to ``construct_dataset_and_loaders``.

    Returns:
        The ``(fig, ax)`` pair holding the plot, so callers can save or show it.

    Raises:
        AssertionError: If the test loader does not yield exactly 842 distance
            matrices.
    """
    mols_folder = r'data/raw/'

    def _read_sdf(file_name):
        # keep explicit hydrogens and skip sanitization when reading raw geometries
        supplier = Chem.SDMolSupplier(mols_folder + file_name,
                                      removeHs=False,
                                      sanitize=False)
        return list(supplier)

    test_r = _read_sdf('test_reactants.sdf')
    test_p = _read_sdf('test_products.sdf')
    test_ts = _read_sdf('test_ts.sdf')

    _, _, test_loader = construct_dataset_and_loaders(args)

    D_gts = []
    for rxn_batch in test_loader:
        # pad the flat TS coordinates (pos_ts = [b * max_num_nodes, 3]) into a dense batch
        X_gt, _ = to_dense_batch(rxn_batch.pos_ts, rxn_batch.x_ts_batch, 0.,
                                 max(rxn_batch.num_atoms))
        batched_D_gt = X_to_dist(X_gt)
        # one distance matrix per reaction: split along dim 0 and drop the size-1 axis
        D_gts.extend(x[-1] for x in torch.split(batched_D_gt, 1, 0))

    assert len(D_gts) == 842, (
        f"Number of test dist matrices is {len(D_gts)}, you need 842 which was "
        "originally published in the MIT model.")

    mine, gt = [], []
    for idx, ts_mol in enumerate(test_ts):
        # num_atoms + mask for reaction core
        num_atoms = ts_mol.GetNumAtoms()
        core_mask = (Chem.GetAdjacencyMatrix(test_p[idx]) +
                     Chem.GetAdjacencyMatrix(test_r[idx])) == 1

        # Convert the loader tensor to a CPU numpy array before masking it with a
        # numpy mask; this also works when the batch lives on a GPU.
        D_mine = D_gts[idx][0:num_atoms, 0:num_atoms].detach().cpu().numpy()

        gt.append(np.ravel(Chem.Get3DDistanceMatrix(ts_mol) * core_mask))
        mine.append(np.ravel(D_mine * core_mask))

    all_ds = [np.concatenate(ds).ravel() for ds in (gt, mine)]
    all_ds = [ds[ds != 0] for ds in all_ds]  # only keep non-zero values

    fig, ax = plt.subplots(figsize=(12, 9))
    sns.distplot(all_ds[0], color='b', kde_kws={"lw": 5, "label": "GT"}, hist=False)
    sns.distplot(all_ds[1], color='r', kde_kws={"lw": 3, "label": "Mine"}, hist=False)

    # Set location and font size in ONE call. Each ax.legend() call builds a fresh
    # legend, so a second call would silently discard loc='upper right'.
    ax.legend(loc='upper right', fontsize=12)
    ax.set_ylabel('Density', fontsize=22)
    ax.set_xlabel(r'Distance ($\AA$)', fontsize=22)
    ax.tick_params(axis='both', which='major', labelsize=22)
    for side, visible in (("top", False), ("bottom", True),
                          ("right", False), ("left", True)):
        ax.spines[side].set_visible(visible)

    return fig, ax
Beispiel #24
0
def featurization(r_mol: Chem.rdchem.Mol,
                  p_mol: Chem.rdchem.Mol,
                  ):
    """
    Generates features of the reactant and product for one reaction as input for the network.

    Nodes are the reactant atoms, featurized with ``atom_features``. The graph is
    fully connected: for every atom pair a1 < a2 the edge (a1, a2) is emitted,
    immediately followed by (a2, a1). Each edge carries
    ``[reactant 3D distance, product 3D distance]``. Only distances are used as
    edge features; chemical bond features are not included.

    Args:
        r_mol: RDKit molecule object for the reactant (needs a 3D conformer).
        p_mol: RDKit molecule object for the product (needs a 3D conformer). It is
            indexed with the reactant's atom indices, so both molecules are assumed
            to share the same atom ordering.

    Returns:
        data: Torch Geometric Data object, storing the atom and bond features:
            ``x`` (float, one row per atom),
            ``edge_index`` (long, shape [2, n_atoms * (n_atoms - 1)]),
            ``edge_attr`` (float, shape [n_atoms * (n_atoms - 1), 2]).
            A single-atom reaction yields ``edge_index`` of shape [2, 0] and
            ``edge_attr`` of shape [0, 2].
    """

    # compute properties with rdkit (only works if dataset is clean)
    r_mol.UpdatePropertyCache()
    p_mol.UpdatePropertyCache()

    n_atoms = r_mol.GetNumAtoms()

    # 3D (Euclidean) distance matrices of reactant and product
    D_r = Chem.Get3DDistanceMatrix(r_mol)
    D_p = Chem.Get3DDistanceMatrix(p_mol)

    # atom (node) features
    f_atoms = [atom_features(r_mol.GetAtomWithIdx(a)) for a in range(n_atoms)]

    edge_index = []  # (source, target) tuples of the fully connected graph
    f_bonds = []     # [reactant distance, product distance] per directed edge
    for a1 in range(n_atoms):
        for a2 in range(a1 + 1, n_atoms):
            edge_index.extend([(a1, a2), (a2, a1)])
            f_bonds.append([D_r[a1][a2], D_p[a1][a2]])
            f_bonds.append([D_r[a2][a1], D_p[a2][a1]])

    data = tg.data.Data()
    data.x = torch.tensor(f_atoms, dtype=torch.float)
    # view(-1, 2) keeps the 2-column layout when there are no edges (single atom);
    # torch.tensor([]) alone would be 1-D and give edge_index shape [0].
    data.edge_index = torch.tensor(edge_index, dtype=torch.long).view(-1, 2).t().contiguous()
    data.edge_attr = torch.tensor(f_bonds, dtype=torch.float).view(-1, 2)

    return data
Beispiel #25
0
    def construct_feature_matrices(self, entry):
        """
        Given an entry containing an rdkit molecule and an array of atom indices,
        return atom feature matrices, bond feature matrices, distance matrices,
        connectivity matrices and atom index reference matrices.

        Each atom is linked to its nearest neighbors by 3D distance. It gets at most
        ``self.n_neighbors`` neighbors, and only those closer than ``self.cutoff``.
        An atom with no qualifying neighbor gets a single self-edge.

        When ``self.n_neighbors <= n_atom - 1``, the bond arrays have a fixed size of
        ``n_neighbors * n_atom``. Rows beyond the written edges stay zero (bond token
        0, distance 0, connectivity [0, 0]).

        Args:
            entry: ``(mol, atom_index_array)``. ``mol`` is an RDKit molecule with a
                3D conformer; ``atom_index_array`` is an array-like of atom indices.

        returns
        dict with entries
        see MolPreproccessor
        'n_pro' : number of entries in atom_index_array
        'atom_index' : for each atom, the position of its index in atom_index_array
                       (first occurrence), or -1 if it does not occur

        raises
        RuntimeError if mol is None
        """
        mol, atom_index_array = entry

        # Hopefully we've filtered out all problem mols by now; fail before any use.
        if mol is None:
            raise RuntimeError("Issue in loading mol")

        atoms = list(mol.GetAtoms())
        n_atom = len(atoms)
        n_pro = len(atom_index_array)

        distance_matrix = Chem.Get3DDistanceMatrix(mol)

        # Neighbors of every atom, nearest first. argsort()[0] is normally the atom
        # itself (distance 0), so slicing starts at 1. The slice is bounded by both
        # the neighbor limit and the distance cutoff.
        neighbor_end_index = min(self.n_neighbors + 1, n_atom)
        neighbor_lists = []
        for n in range(n_atom):
            distance_atom = distance_matrix[n, :]
            cutoff_end_index = distance_atom[distance_atom < self.cutoff].size
            end_index = min(neighbor_end_index, cutoff_end_index)
            neighbor_inds = distance_atom.argsort()[1:end_index]
            if len(neighbor_inds) == 0:
                neighbor_inds = [n]  # isolated atom: fall back to a self-edge
            neighbor_lists.append(neighbor_inds)

        # n_bond is actually the number of atom-atom pairs (rows of the bond arrays).
        if self.n_neighbors <= (n_atom - 1):
            n_bond = self.n_neighbors * n_atom
        else:
            # If there are fewer atoms than n_neighbors, all atoms within the
            # cutoff are connected.
            n_bond = distance_matrix[(distance_matrix < self.cutoff)
                                     & (distance_matrix != 0)].size
        # The cutoff count above ignores the self-edges of isolated atoms, which
        # previously overflowed the arrays (IndexError). Make room for every edge
        # written below, and for at least one row.
        n_bond = max(n_bond, sum(len(inds) for inds in neighbor_lists), 1)

        # Initialize the matrices to be filled in during the following loop.
        atom_feature_matrix = np.zeros(n_atom, dtype='int')
        bond_feature_matrix = np.zeros(n_bond, dtype='int')
        bond_distance_matrix = np.zeros(n_bond, dtype=np.float32)
        atom_index_matrix = np.full(n_atom, -1, dtype='int')
        connectivity = np.zeros((n_bond, 2), dtype='int')

        # First position of each atom index in atom_index_array (same result as
        # list.index). The lookup is built once instead of once per atom.
        atom_positions = {}
        for pos, atom_idx in enumerate(np.asarray(atom_index_array).tolist()):
            atom_positions.setdefault(atom_idx, pos)

        bond_index = 0  # next free row in the bond arrays
        for n, atom in enumerate(atoms):
            # update atom feature matrix
            atom_feature_matrix[n] = self.atom_tokenizer(
                self.atom_features(atom))
            atom_index_matrix[n] = atom_positions.get(atom.GetIdx(), -1)

            for neighbor in neighbor_lists[n]:
                # update bond feature matrix
                bond = mol.GetBondBetweenAtoms(n, int(neighbor))
                if bond is None:
                    bond_feature_matrix[bond_index] = 0  # non-bonded pair
                else:
                    # flip the features when this atom is the bond's end atom
                    rev = bond.GetBeginAtomIdx() != n
                    bond_feature_matrix[bond_index] = self.bond_tokenizer(
                        self.bond_features(bond, flipped=rev))

                bond_distance_matrix[bond_index] = distance_matrix[n, neighbor]

                # update connectivity matrix
                connectivity[bond_index, 0] = n
                connectivity[bond_index, 1] = neighbor

                bond_index += 1

        return {
            'n_atom': n_atom,
            'n_bond': n_bond,
            'n_pro': n_pro,
            'atom': atom_feature_matrix,
            'bond': bond_feature_matrix,
            'distance': bond_distance_matrix,
            'connectivity': connectivity,
            'atom_index': atom_index_matrix,
        }
    def process(self):
        """Build and save the drug / host / virus interaction graphs from the raw files.

        Reads ``self.raw_paths``:
            [0] drug structures (SDF),
            [1] drug-host target table,
            [2] host-host interactome,
            [3] virus-virus interaction table.
        Human proteins are recognised by organism id 9606 (Homo sapiens).

        Saves to ``self.processed_paths``:
            [0] collated drug molecule graphs (``self.data``, ``self.slices``),
            [1] drug-host edges (row 0 = drug position, row 1 = protein position),
            [2] host-host graph dict (``num_nodes``, undirected ``edge_index``, plus
                the output of ``self.protein_features`` when set),
            [3] virus-host edges (row 0 = virus position, row 1 = protein position),
            [4] virus-virus graph dict,
            [5] ``(entrez_ids, virus_ids, drug_ids)`` mapping positions back to raw ids.
        """
        # Host-Host: presumably a two-column table of interacting Entrez ids,
        # transposed to shape [2, E] -- TODO confirm against the raw file.
        hh = pd.read_table(self.raw_paths[2]).values.T

        # Virus-Virus
        vv = pd.read_table(self.raw_paths[3], usecols=[
            'EntrezGeneID_InteractorA', 'EntrezGeneID_InteractorB',
            'OrganismID_InteractorA', 'OrganismID_InteractorB'
        ], na_values='-').fillna(-1).convert_dtypes(int).values

        human_only = np.all(vv[:, 2:] == 9606, axis=-1)  # Homo sapiens

        if self.add_missing_proteins:
            # Append the human-human pairs as extra interactome edges. Both arrays
            # are [2, E], so they must be joined along the edge axis (axis=1).
            hh = np.concatenate((hh, vv[human_only, :2].T), axis=1)

        entrez_ids, interactome = np.unique(hh, return_inverse=True)
        interactome = interactome.reshape((2, -1))
        pid2pos = {idx: pos for pos, idx in enumerate(entrez_ids)}

        # keep interactions with at least one non-human partner
        vv = vv[~human_only]
        mask = vv.T[2:] == 9606  # [2, M]: which interactor is human
        vv = vv.T[:2]

        # Virus proteins -> positions in virus_ids. Human proteins -> positions in
        # entrez_ids, or -1 when absent from the interactome.
        virus_ids, vv[~mask] = np.unique(vv[~mask], return_inverse=True)
        vv[mask] = [pid2pos.get(pid, -1) for pid in vv[mask]]

        # Virus-Host: orient every virus-human pair as (virus, human)
        human_related = np.any(mask, axis=0)
        vv = np.where(mask[:1], vv[::-1], vv)
        vh = vv[:, human_related & (vv[1] != -1)]
        vv = vv[:, ~human_related]

        # K-Hop Restriction: keep proteins within restrict_to hops of a virus target
        if self.restrict_to is not None:
            mask = np.zeros(entrez_ids.shape[0], dtype=bool)
            row, col = interactome
            mask[vh[1].astype(int)] = True

            for _ in range(self.restrict_to):
                last = mask.copy()
                # Boolean selection, not `mask[col] |= last[row]`. With repeated
                # indices, fancy-index assignment may let a later False overwrite
                # an earlier True.
                mask[col[last[row]]] = True
                mask[row[last[col]]] = True

            interactome = interactome[:, mask[row] & mask[col]]
            # Re-index by position among ALL kept proteins, i.e. in entrez_ids[mask].
            # np.unique over the kept edges would skip kept proteins that have no
            # kept edge (possible when restrict_to == 0). That would misalign the
            # indices with entrez_ids, pid2pos and num_nodes.
            new_pos = np.cumsum(mask) - 1
            interactome = new_pos[interactome]
            pid2pos = {idx: pos for pos, idx in enumerate(entrez_ids[mask])}
            vh[1] = [pid2pos[entrez_ids[pid]] for pid in vh[1]]
            entrez_ids = entrez_ids[mask]

        # Drug Structures: strip the 2-character id prefix (e.g. "DB") to an int
        drugs = PandasTools.LoadSDF(self.raw_paths[0], idName='RowID')[['RowID', 'ROMol']]
        drugs.iloc[:, 0] = drugs.iloc[:, 0].apply(lambda did: int(did[2:]))
        drugs = drugs.set_index('RowID')

        # Drugs-Host: gene ids are mapped through the (possibly restricted) pid2pos
        dh = pd.read_table(self.raw_paths[1], converters={
            '#DrugBankID': lambda did: int(did[2:]),
            'EntrezGeneID': lambda pid: int(pid2pos.get(int(pid), -1))
        })

        # drop unknown proteins and drugs without a structure
        dh = dh[dh.iloc[:, 1] != -1].join(drugs[[]], on='#DrugBankID', how='inner').values.T
        drug_ids, dh[0] = np.unique(dh[0], return_inverse=True)
        drugs = drugs.loc[drug_ids, 'ROMol']

        data_list = []

        for mol in drugs:
            if self.feature_extractor is None:
                # edges = bonded atom pairs, weighted by their 3D distance
                adj = Chem.GetAdjacencyMatrix(mol) * Chem.Get3DDistanceMatrix(mol)
                nz = np.nonzero(adj)

                features = {
                    'num_nodes': adj.shape[0],
                    'edge_index': torch.LongTensor(np.stack(nz)),
                    'edge_attr': torch.FloatTensor(adj[nz])
                }
            else:
                features = self.feature_extractor(mol)

            data = Data(**features)

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            if self.pre_filter is None or self.pre_filter(data):
                data_list.append(data)

        self.data, self.slices = self.collate(data_list)
        num_proteins = entrez_ids.shape[0]
        num_virus = virus_ids.shape[0]

        host_host = {
            'num_nodes': num_proteins,
            'edge_index': to_undirected(torch.LongTensor(interactome),
                                        num_nodes=num_proteins)
        }

        virus_virus = {
            'num_nodes': num_virus,
            'edge_index': to_undirected(torch.LongTensor(vv.astype(int)),
                                        num_nodes=num_virus)
        }

        if self.protein_features is not None:
            host_host.update(self.protein_features(entrez_ids))

        virus_host = torch.LongTensor(vh.astype(int))
        drug_host = torch.LongTensor(dh.astype(int))
        self.entrez_ids = torch.LongTensor(entrez_ids.astype(int))
        self.virus_ids = torch.LongTensor(virus_ids.astype(int))
        self.drug_ids = torch.LongTensor(drug_ids.astype(int))

        torch.save((self.data, self.slices), self.processed_paths[0])
        torch.save(drug_host, self.processed_paths[1])
        torch.save(host_host, self.processed_paths[2])
        torch.save(virus_host, self.processed_paths[3])
        torch.save(virus_virus, self.processed_paths[4])
        torch.save((self.entrez_ids, self.virus_ids, self.drug_ids), self.processed_paths[5])
    def process_geometry_file(self, geometry_file, current_list=None):
        """ Transforms molecules to their atom features and adjacency lists.

        Args:
            geometry_file: SDF file name, appended to ``self.root``.
            current_list: Optional list of ``Data`` objects to extend. The ``idx``
                of each new molecule continues from its length.

        Returns:
            The list of ``Data`` objects. When ``current_list`` is a non-empty list
            it is extended in place; otherwise a new list is returned. Reading stops
            once ``self.n_rxns`` molecules have been processed.

        Notes:
            - Code mostly lifted from PyG QM9 dataset creation https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/qm9.html 
        """
        data_list = current_list if current_list else []
        counted = len(data_list)
        full_path = self.root + geometry_file
        geometries = Chem.SDMolSupplier(full_path,
                                        removeHs=False,
                                        sanitize=False)

        # get atom and edge features for each geometry
        for mol_idx, mol in enumerate(tqdm(geometries)):

            if mol_idx == self.n_rxns:
                break

            N = mol.GetNumAtoms()

            # Atom positions as a [num_atoms, 3] matrix, parsed from the molblock's
            # atom lines. Assumes the V2000 layout: 3 header lines plus the counts
            # line come first.
            atom_data = geometries.GetItemText(mol_idx).split('\n')[4:4 + N]
            atom_positions = torch.tensor(
                [[float(x) for x in line.split()[:3]] for line in atom_data],
                dtype=torch.float)

            # Flattened 3D distance matrix: (a, b) then (b, a) for every pair a < b.
            # Separate loop names keep mol_idx intact. Reusing `i` here previously
            # clobbered the molecule index used for `idx` below.
            D_ts = Chem.Get3DDistanceMatrix(mol)
            y = []
            for a in range(N):
                for b in range(a + 1, N):
                    y.extend([D_ts[a][b], D_ts[b][a]])
            y = torch.tensor(y, dtype=torch.float)

            # atom/node features
            type_idx, atomic_number, aromatic = [], [], []
            sp, sp2, sp3 = [], [], []
            for atom in mol.GetAtoms():
                type_idx.append(self.TYPES[atom.GetSymbol()])
                atomic_number.append(atom.GetAtomicNum())
                aromatic.append(1 if atom.GetIsAromatic() else 0)
                hybridisation = atom.GetHybridization()
                sp.append(1 if hybridisation == HybridizationType.SP else 0)
                sp2.append(1 if hybridisation == HybridizationType.SP2 else 0)
                sp3.append(1 if hybridisation == HybridizationType.SP3 else 0)
                # TODO: lucky does: whether bonded, 3D_rbf

            # bond/edge features
            row, col, edge_type = [], [], []
            for bond in mol.GetBonds():
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                row += [start, end]
                col += [end, start]
                # edge type for each bond type; *2 because both ways
                edge_type += 2 * [self.BONDS[bond.GetBondType()]]
            # edge_index is graph connectivity in COO format with shape [2, num_edges]
            edge_index = torch.tensor([row, col], dtype=torch.long)
            edge_type = torch.tensor(edge_type, dtype=torch.long)
            # edge_attr is edge feature matrix with shape [num_edges, num_edge_features]
            edge_attr = F.one_hot(edge_type,
                                  num_classes=len(self.BONDS)).to(torch.float)

            # order edges based on combined ascending order
            asc_order_perm = (edge_index[0] * N + edge_index[1]).argsort()
            edge_index = edge_index[:, asc_order_perm]
            edge_attr = edge_attr[asc_order_perm]

            row, col = edge_index
            z = torch.tensor(atomic_number, dtype=torch.long)
            hs = (z == 1).to(torch.float)  # hydrogens
            # number of hydrogen neighbours of each atom
            num_hs = scatter(hs[row], col, dim_size=N).tolist()

            x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(self.TYPES))
            x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs],
                              dtype=torch.float).t().contiguous()
            x = torch.cat([x1.to(torch.float), x2], dim=-1)

            idx = counted + mol_idx
            mol_data = Data(x=x,
                            z=z,
                            pos=atom_positions,
                            y=y,
                            edge_index=edge_index,
                            edge_attr=edge_attr,
                            idx=idx)
            data_list.append(mol_data)

        return data_list
Beispiel #28
0
def create_ds_dict(d_files, d_folder='d_inits/', mols_folder=r'data/raw/',
                   mit_file='mit_best.npy'):
    """Collect reaction-core TS distances to compare ground truth with D_init predictions.

    The reaction core of a test reaction is the set of atom pairs bonded in exactly
    one of reactant/product (adjacency entries summing to 1). Distances of those
    pairs are gathered from four sources:

    - the ground-truth TS geometry;
    - the MIT model's D_init (``mit_file``);
    - the linear approximation, i.e. the mean of reactant and product distances;
    - each prediction file in ``d_files``.

    Args:
        d_files: File names (inside ``d_folder``) of saved D_init predictions.
            Presumably each holds one matrix per test reaction of at least
            num_atoms x num_atoms -- confirm against the saving code.
        d_folder: Folder containing ``mit_file`` and ``d_files``.
        mols_folder: Folder containing test_reactants.sdf, test_products.sdf and
            test_ts.sdf.
        mit_file: File name (inside ``d_folder``) of the MIT D_init predictions.

    Returns:
        Dict mapping 'gt', 'mit', 'lin_approx', 'D_fin0', 'D_fin1', ... to
        ``(distances, label)``. ``distances`` is a 1-D array of the non-zero
        reaction-core distances.

    Raises:
        AssertionError: If the collected distance arrays differ in length.
    """

    def _read_sdf(file_name):
        # keep explicit hydrogens and skip sanitization when reading raw geometries
        return list(Chem.SDMolSupplier(mols_folder + file_name,
                                       removeHs=False,
                                       sanitize=False))

    # get test mols
    test_r = _read_sdf('test_reactants.sdf')
    test_ts = _read_sdf('test_ts.sdf')
    test_p = _read_sdf('test_products.sdf')

    # saved distance predictions
    mit_d_init = np.load(d_folder + mit_file)
    d_inits = [np.load(d_folder + d_file) for d_file in d_files]

    # lists for plotting
    gt, mit, lin_approx = [], [], []
    d_init_lists = [[] for _ in d_inits]

    for idx, ts_mol in enumerate(test_ts):

        # num_atoms + mask for reaction core
        num_atoms = ts_mol.GetNumAtoms()
        core_mask = (Chem.GetAdjacencyMatrix(test_p[idx]) +
                     Chem.GetAdjacencyMatrix(test_r[idx])) == 1

        # main 3
        D_r = Chem.Get3DDistanceMatrix(test_r[idx])
        D_p = Chem.Get3DDistanceMatrix(test_p[idx])
        gt.append(np.ravel(Chem.Get3DDistanceMatrix(ts_mol) * core_mask))
        mit.append(
            np.ravel(mit_d_init[idx][0:num_atoms, 0:num_atoms] * core_mask))
        lin_approx.append(np.ravel((D_r + D_p) / 2 * core_mask))

        # other d_inits
        for d_init, d_init_list in zip(d_inits, d_init_lists):
            d_init_list.append(
                np.ravel(d_init[idx][0:num_atoms, 0:num_atoms] * core_mask))

    # make plottable
    all_ds = [np.concatenate(ds).ravel()
              for ds in (gt, mit, lin_approx, *d_init_lists)]
    assert all_same([len(ds) for ds in all_ds
                     ]), "Lengths of all ds after concat don't match."
    all_ds = [ds[ds != 0] for ds in all_ds]  # only keep non-zero values
    assert all_same([len(ds) for ds in all_ds
                     ]), "Lengths of all ds after removing zeroes don't match."

    ds_dict = {
        'gt': (all_ds[0], 'Ground Truth'),
        'mit': (all_ds[1], 'MIT D_init'),
        'lin_approx': (all_ds[2], 'Linear Approximation'),
    }
    base_ds_counter = len(ds_dict)

    for d_id, ds in enumerate(all_ds[base_ds_counter:]):
        name = f'D_fin{d_id}'
        ds_dict[name] = (ds, name)

    return ds_dict
Beispiel #29
0
    def __init__(self, smiles: Union[str, Tuple[str, List[int]]],
                 args: Namespace):
        if type(smiles) == tuple:
            smiles, index_map, substructures = smiles
            self.index_map = index_map  # map from real atom idx to idx after collapsing substructures
            # the following map gives an arbitrary index for each substructure, but that'll be masked to 0 anyway
            self.reverse_index_map = [
                index_map.index(i) for i in range(max(index_map) + 1)
            ]
            self.substructures = substructures
            self.substructure_atoms = set().union(*substructures)
            self.collapsing_substructures = True
        else:
            self.collapsing_substructures = False
        self.smiles = smiles
        self.n_atoms = 0  # number of atoms
        self.n_bonds = 0  # number of bonds
        self.f_atoms = []  # mapping from atom index to atom features
        self.f_bonds = [
        ]  # mapping from bond index to concat(in_atom, bond) features
        self.a2b = []  # mapping from atom index to incoming bond indices
        self.b2a = [
        ]  # mapping from bond index to the index of the atom the bond is coming from
        self.b2revb = [
        ]  # mapping from bond index to the index of the reverse bond

        # Convert smiles to molecule
        mol = Chem.MolFromSmiles(smiles)
        if not self.collapsing_substructures:
            self.index_map = range(mol.GetNumAtoms())
            self.reverse_index_map = self.index_map
            self.substructures = set()
            self.substructure_atoms = set()

        # Add hydrogens
        if args.addHs:
            mol = Chem.AddHs(mol)

        # Get 3D distance matrix
        if args.three_d:
            mol = Chem.AddHs(mol)
            AllChem.EmbedMolecule(mol, AllChem.ETKDG())
            if not args.addHs:
                mol = Chem.RemoveHs(mol)
            try:
                distances_3d = Chem.Get3DDistanceMatrix(mol)
                distances_3d = np.digitize(
                    distances_3d, THREE_D_DISTANCE_BINS)  # bin 3d distances
            except:
                # zero distance matrix, in case rdkit errors out
                print('distance embedding failed')
                distances_3d = np.zeros((mol.GetNumAtoms(), mol.GetNumAtoms()))

        # Get topological (i.e. path-length) distance matrix and number of atoms
        distances_path = Chem.GetDistanceMatrix(mol)

        # fake the number of "atoms" if we are collapsing substructures
        self.n_atoms = mol.GetNumAtoms(
        ) if not self.collapsing_substructures else max(self.index_map) + 1

        # Get atom features
        if 'functional_group' in args.additional_atom_features:
            fg_featurizer = FunctionalGroupFeaturizer(args)
            fg_features = fg_featurizer.featurize(mol)
        for i, atom in enumerate(mol.GetAtoms()):
            if 'functional_group' in args.additional_atom_features:
                self.f_atoms.append(atom_features(atom, fg_features[i]))
            else:
                self.f_atoms.append(atom_features(atom))
        for atom_idx in self.substructure_atoms:
            # mask all of these features to 0 since they'll end up collapsed in substructures
            self.f_atoms[atom_idx] = [
                0 for _ in range(len(self.f_atoms[atom_idx]))
            ]
        self.f_atoms = [
            self.f_atoms[self.reverse_index_map[i]]
            for i in range(self.n_atoms)
        ]

        for _ in range(self.n_atoms):
            self.a2b.append([])

        if args.learn_virtual_edges:
            model = args.lve_model  # this is the MPN model, added to args so we can access it here
            f_atoms_cuda = torch.Tensor(self.f_atoms).cuda()
            processed_f_atoms_cuda = model.lve(f_atoms_cuda)
            lve_scores = torch.matmul(processed_f_atoms_cuda, f_atoms_cuda.t())
            symmetric_lve_scores = lve_scores + lve_scores.t()

        # Get bond features
        for a1 in range(self.n_atoms):
            for a2 in range(a1 + 1, self.n_atoms):
                bond = mol.GetBondBetweenAtoms(self.reverse_index_map[a1],
                                               self.reverse_index_map[a2])
                zero_bond = self.reverse_index_map[
                    a1] in self.substructure_atoms or self.reverse_index_map[
                        a2] in self.substructure_atoms

                # need to check all possible atoms in the substructure to see if there's a bond
                if zero_bond:
                    candidate_endpoints = []
                    for atom_idx in [a1, a2]:
                        reverse_idx = self.reverse_index_map[atom_idx]
                        if reverse_idx in self.substructure_atoms:
                            for substructure in self.substructures:
                                if reverse_idx in substructure:
                                    candidate_endpoints.append(substructure)
                                    break
                        else:
                            candidate_endpoints.append([atom_idx])
                    for alternate_a1 in candidate_endpoints[0]:
                        for alternate_a2 in candidate_endpoints[1]:
                            if mol.GetBondBetweenAtoms(
                                    alternate_a1, alternate_a2) is not None:
                                bond = mol.GetBondBetweenAtoms(
                                    alternate_a1, alternate_a2)

                # Randomly drop O(n_atoms) virtual edges so a total of O(n_atoms) edges instead of O(n_atoms^2)
                if bond is None:
                    if not args.virtual_edges:
                        continue

                    if args.drop_virtual_edges and hash(str(
                        (a1, a2))) % self.n_atoms != 0:
                        continue

                    # this option below doesn't seem to be as good
                    # if args.drop_virtual_edges and (hash(mol_fatoms[a1])+hash(mol_fatoms[a2])) % n_atoms != 0:
                    #     continue

                    if args.learn_virtual_edges:
                        # if score less than 0, don't add the edge
                        model = args.lve_model
                        if symmetric_lve_scores[
                                a1, a2] < 0:  # want symmetry in a1/a2
                            continue

                distance_3d = distances_3d[a1, a2] if args.three_d else None
                distance_path = distances_path[
                    a1, a2] if args.virtual_edges else None

                f_bond = bond_features(bond,
                                       distance_path=distance_path,
                                       distance_3d=distance_3d)
                if zero_bond:
                    f_bond = [0 for _ in range(len(f_bond))]

                if args.atom_messages:
                    self.f_bonds.append(f_bond)
                    self.f_bonds.append(f_bond)
                else:
                    self.f_bonds.append(self.f_atoms[a1] + f_bond)
                    self.f_bonds.append(self.f_atoms[a2] + f_bond)

                # Update index mappings
                b1 = self.n_bonds
                b2 = b1 + 1
                self.a2b[a2].append(b1)  # b1 = a1 --> a2
                self.b2a.append(a1)
                self.a2b[a1].append(b2)  # b2 = a2 --> a1
                self.b2a.append(a2)
                self.b2revb.append(b2)
                self.b2revb.append(b1)
                self.n_bonds += 2