Example #1
0
def _frequent_types(type_counts, min_freq):
    """Return the types counted more than *min_freq* times, most frequent first.

    ``'Others'`` is always appended as the catch-all bucket for rare or unseen
    types. Types with equal counts come out in reverse insertion order: an
    ascending stable sort followed by a reversal.
    """
    ranked = sorted(type_counts.items(), key=operator.itemgetter(1))
    ranked.reverse()
    vocab = [type_key for type_key, count in ranked if count > min_freq]
    vocab.append('Others')
    return vocab


def _load_sdf_mols(sdf_file):
    """Read every record of *sdf_file* with its explicit hydrogens removed.

    Records RDKit cannot parse are kept as ``None``, so the i-th entry still
    corresponds to the i-th SMILES handled by the caller.
    """
    mols = []
    for mol in Chem.SDMolSupplier(sdf_file, sanitize=False):
        if mol is not None:
            # RemoveHs also drops the removed atoms from the conformers, so
            # each remaining heavy atom keeps its own coordinates; no manual
            # per-index copy is needed (and an index-based copy would be
            # wrong whenever hydrogens are not the trailing atoms).
            mol = Chem.RemoveHs(mol, sanitize=False)
        mols.append(mol)
    return mols


def _mol_with_coords_from_smiles(smile):
    """Parse *smile* and attach coordinates; return ``None`` if unparseable.

    A UFF-optimised 3D conformer is tried first (embedded with explicit
    hydrogens, which are removed again afterwards). If embedding or
    optimisation fails, 2D coordinates are computed on the hydrogen-free
    molecule, so both paths yield a molecule over heavy atoms only.
    """
    mol = MolFromSmiles(smile)
    if mol is None:
        return None
    try:
        mol_h = Chem.AddHs(mol)
        # EmbedMolecule reports failure by returning -1 rather than raising.
        if AllChem.EmbedMolecule(mol_h, maxAttempts=5000) == 0:
            AllChem.UFFOptimizeMolecule(mol_h)
            return Chem.RemoveHs(mol_h)
    except Exception:  # best effort: any 3D failure falls back to 2D below
        pass
    AllChem.Compute2DCoords(mol)
    return mol


def load_data_from_smiles(smiles,
                          labels,
                          target,
                          bondtype_freq=20,
                          atomtype_freq=10,
                          sdf_file=None,
                          dummyNode=False,
                          formal_charge_one_hot=False):
    """Convert SMILES strings into per-molecule graph feature matrices.

    The atom/bond type vocabularies are always built from *smiles*, even
    when *sdf_file* supplies the molecules themselves.

    Args:
        smiles: Sequence of SMILES strings.
        labels: Labels aligned with *smiles*; each kept label is wrapped in a
            one-element list.
        target: Returned unchanged as the third element of the result.
        bondtype_freq: A bond type must occur more than this many times to
            get its own vocabulary entry; rarer types fall into ``'Others'``.
        atomtype_freq: The same threshold for atom types.
        sdf_file: Optional SDF path. Its records are consumed in order, one
            per SMILES, instead of parsing and embedding *smiles*.
        dummyNode: If true, ``dump_as_matrices_Att_dummyNode`` is used
            instead of ``dump_as_matrices_Att``.
        formal_charge_one_hot: Forwarded to ``molToGraph``.

    Returns:
        ``(x_all, y_all, target, mol_sizes)``:

        * ``x_all`` holds, per converted molecule, the 8 matrices returned by
          the ``molToGraph`` dump method.
        * ``y_all`` holds the matching ``[label]`` lists.
        * ``mol_sizes`` holds ``adj.shape[0]`` of each molecule.

        Molecules that fail are reported on stdout and skipped.
    """
    bondtype_dic = {}
    atomtype_dic = {}
    for smile in smiles:
        try:
            mol = MolFromSmiles(smile)
            bondtype_dic = fillBondType_dic(mol, bondtype_dic)
            atomtype_dic = fillAtomType_dic(mol, atomtype_dic)
        except AttributeError:
            # NOTE(review): presumably how the fill helpers fail on an
            # unparseable SMILES (MolFromSmiles -> None); such entries are
            # simply left out of the type counts.
            continue

    filted_bondtype_list_order = _frequent_types(bondtype_dic, bondtype_freq)
    filted_atomtype_list_order = _frequent_types(atomtype_dic, atomtype_freq)

    print('filted_atomtype_list_order: {}, \n filted_bondtype_list_order: {}'.
          format(filted_atomtype_list_order, filted_bondtype_list_order))

    # mol to graph
    mol_sizes = []
    x_all = []
    y_all = []

    print('Transfer mol to matrices')
    sdf_mols = iter(_load_sdf_mols(sdf_file)) if sdf_file else None
    for smile, label in zip(smiles, labels):
        try:
            if sdf_mols is None:
                mol = _mol_with_coords_from_smiles(smile)
            else:
                # Exactly one SDF record per SMILES keeps the pairing aligned;
                # running out raises StopIteration, reported below.
                mol = next(sdf_mols)
            if mol is None:
                print('the smile: {} has an error'.format(smile))
                continue

            graph = molToGraph(mol,
                               filted_bondtype_list_order,
                               filted_atomtype_list_order,
                               formal_charge_one_hot=formal_charge_one_hot)
            dump = (graph.dump_as_matrices_Att_dummyNode
                    if dummyNode else graph.dump_as_matrices_Att)
            # Unpacking also checks that exactly 8 matrices came back.
            (afm, adj, bft, adjTensor_OrderAtt, adjTensor_AromAtt,
             adjTensor_ConjAtt, adjTensor_RingAtt, mat_positions) = dump()

            # Feature matrices of the mol: adjacency, atom and bond features.
            x_all.append([
                afm, adj, bft, adjTensor_OrderAtt, adjTensor_AromAtt,
                adjTensor_ConjAtt, adjTensor_RingAtt, mat_positions
            ])
            y_all.append([label])
            mol_sizes.append(adj.shape[0])
        except ValueError as e:
            print('the smile: {}, can not convert to graph structure'.format(
                smile))
            print(e)
        except Exception:
            print('the smile: {} has an error'.format(smile))

    print('Done.')
    return (x_all, y_all, target, mol_sizes)
Example #2
0
def extract_graph(data_path, out_file_path, max_atom_num, label_name=None):
    """Featurise the SMILES of a CSV into zero-padded graph matrices and save them.

    Reads the ``SMILES`` column of the CSV at *data_path*. For every molecule
    with at most *max_atom_num* atoms it builds, padded to *max_atom_num*:

    * an integer adjacency matrix (1 for each bonded atom pair);
    * a pairwise distance matrix from 2D depiction coordinates, computed via
      ``get_atom_distance``;
    * an integer node-feature matrix from ``extract_atom_features``, with
      acceptor/donor flags from RDKit's ``BaseFeatures.fdef``.

    Unparseable SMILES and molecules with more than *max_atom_num* atoms are
    reported on stdout and skipped. The stacked arrays are written to
    *out_file_path* with ``np.savez_compressed``. When *label_name* names a
    CSV column, the labels of the kept rows are saved as well, under the
    literal key ``'label_name'``.

    Args:
        data_path: Path of the input CSV.
        out_file_path: Destination passed to ``np.savez_compressed``.
        max_atom_num: Padding size; larger molecules are skipped.
        label_name: Optional label column to store alongside the matrices.
    """
    import os
    from rdkit import RDConfig
    from rdkit.Chem import ChemicalFeatures
    fdefName = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    factory = ChemicalFeatures.BuildFeatureFactory(fdefName)

    data_pd = pd.read_csv(data_path)
    smiles_list = data_pd['SMILES'].tolist()

    symbol_candidates = set()
    atom_attribute_dim = num_atom_features()

    node_attribute_matrix_list = []
    # NOTE(review): nothing appends to this list, so an empty array is saved
    # under this key; kept so the output file's keys stay unchanged.
    bond_attribute_matrix_list = []
    adjacent_matrix_list = []
    distance_matrix_list = []
    valid_index = []

    # Distinct per-atom values, printed at the end as a diagnostic summary.
    degree_set = set()
    h_num_set = set()
    implicit_valence_set = set()
    charge_set = set()

    for line_idx, smiles in enumerate(smiles_list):
        # Non-string cells (e.g. NaN from an empty CSV field) count as invalid.
        mol = MolFromSmiles(smiles.strip()) if isinstance(smiles, str) else None
        if mol is None:
            print('Invalid SMILES at line {}: {}'.format(line_idx, smiles))
            continue
        num_atoms = mol.GetNumAtoms()
        # Check the size before doing any per-molecule work.
        if num_atoms > max_atom_num:
            print('Outlier {} has {} atoms'.format(line_idx, num_atoms))
            continue
        valid_index.append(line_idx)

        AllChem.Compute2DCoords(mol)
        conformer = mol.GetConformers()[0]

        # Materialise as sets: a Python 3 ``map`` iterator would be consumed
        # by the first ``in`` test, corrupting the flags of later atoms.
        feats = factory.GetFeaturesForMol(mol)
        acceptor_atom_ids = {
            feat.GetAtomIds()[0] for feat in feats
            if feat.GetFamily() == 'Acceptor'
        }
        donor_atom_ids = {
            feat.GetAtomIds()[0] for feat in feats
            if feat.GetFamily() == 'Donor'
        }

        node_attribute_matrix = np.zeros((max_atom_num, atom_attribute_dim),
                                         dtype=int)
        for atom in mol.GetAtoms():
            atom_idx = atom.GetIdx()
            symbol_candidates.add(atom.GetSymbol())
            degree_set.add(atom.GetDegree())
            h_num_set.add(atom.GetTotalNumHs())
            implicit_valence_set.add(atom.GetImplicitValence())
            charge_set.add(atom.GetFormalCharge())
            node_attribute_matrix[atom_idx] = extract_atom_features(
                atom,
                is_acceptor=atom_idx in acceptor_atom_ids,
                is_donor=atom_idx in donor_atom_ids)
        node_attribute_matrix_list.append(node_attribute_matrix)

        # Symmetric matrix; only the upper triangle is computed.
        distance_matrix = np.zeros((max_atom_num, max_atom_num))
        for idx_i in range(num_atoms):
            for idx_j in range(idx_i + 1, num_atoms):
                distance = get_atom_distance(conformer.GetAtomPosition(idx_i),
                                             conformer.GetAtomPosition(idx_j))
                distance_matrix[idx_i, idx_j] = distance
                distance_matrix[idx_j, idx_i] = distance
        distance_matrix_list.append(distance_matrix)

        adjacent_matrix = np.zeros((max_atom_num, max_atom_num), dtype=int)
        for bond in mol.GetBonds():
            begin_index = bond.GetBeginAtom().GetIdx()
            end_index = bond.GetEndAtom().GetIdx()
            adjacent_matrix[begin_index, end_index] = 1
            adjacent_matrix[end_index, begin_index] = 1
        adjacent_matrix_list.append(adjacent_matrix)

    adjacent_matrix_list = np.asarray(adjacent_matrix_list)
    distance_matrix_list = np.asarray(distance_matrix_list)
    node_attribute_matrix_list = np.asarray(node_attribute_matrix_list)
    bond_attribute_matrix_list = np.asarray(bond_attribute_matrix_list)
    print('adjacent matrix shape\t', adjacent_matrix_list.shape)
    print('distance matrix shape\t', distance_matrix_list.shape)
    print('node attr matrix shape\t', node_attribute_matrix_list.shape)
    print('bond attr matrix shape\t', bond_attribute_matrix_list.shape)
    print(symbol_candidates)
    print('{} valid out of {}'.format(len(valid_index), len(smiles_list)))

    print('degree set:\t', degree_set)
    print('h num set: \t', h_num_set)
    print('implicit valence set: \t', implicit_valence_set)
    print('charge set:\t', charge_set)

    arrays = dict(
        adjacent_matrix_list=adjacent_matrix_list,
        distance_matrix_list=distance_matrix_list,
        node_attribute_matrix_list=node_attribute_matrix_list,
        bond_attribute_matrix_list=bond_attribute_matrix_list)
    if label_name is not None:
        true_labels = np.array(data_pd[label_name].tolist())
        # dtype=int keeps an empty index usable (an empty float array is not
        # a valid index).
        true_labels = true_labels[np.array(valid_index, dtype=int)]
        # The key is the literal string 'label_name', not the column name;
        # readers of the .npz presumably rely on that — kept as-is.
        arrays['label_name'] = true_labels
    np.savez_compressed(out_file_path, **arrays)
    print()