コード例 #1
0
    def __init__(self,
                 mol_input,
                 standard,
                 inv_temp=None,
                 total=None,
                 pruning_thresh=0.05,
                 input_type='smiles'):
        """Store the settings and build an embedded, MMFF-optimised molecule.

        Parameters
        ----------
        mol_input : str or RDKit molecule
            A SMILES string when ``input_type == 'smiles'``, or an RDKit
            molecule object when ``input_type == 'mol'``.
        standard
            Stored unchanged as ``self.standard_energy``.
        inv_temp
            Stored unchanged as ``self.inv_temp``.
        total
            Stored unchanged as ``self.total``.
        pruning_thresh : float
            Stored unchanged as ``self.pruning_thresh``.
        input_type : str
            ``'smiles'`` parses ``mol_input`` and adds explicit hydrogens;
            ``'mol'`` uses ``mol_input`` as is after refreshing its property
            cache and ring information. ``'file'`` is a placeholder that does
            not load anything.

        Raises
        ------
        ValueError
            If the SMILES string cannot be parsed, or if ``input_type`` does
            not produce a molecule (``'file'`` or an unknown value) and no
            ``self.mol`` is already present.
        """
        self.standard_energy = standard
        self.inv_temp = inv_temp
        self.total = total
        self.pruning_thresh = pruning_thresh
        self.temp_0 = 1

        if input_type == 'smiles':
            parsed = Chem.MolFromSmiles(mol_input)
            # MolFromSmiles signals a parse failure by returning None; fail
            # here with a clear message instead of inside Chem.AddHs.
            if parsed is None:
                raise ValueError(
                    'Could not parse SMILES: {!r}'.format(mol_input))
            self.mol = Chem.AddHs(parsed)
        elif input_type == 'mol':
            self.mol = mol_input
            self.mol.UpdatePropertyCache()
            FastFindRings(self.mol)
        elif getattr(self, 'mol', None) is None:
            # 'file' is not implemented and any other value is unknown; in
            # both cases nothing has set self.mol, so the embedding below
            # could only fail with an unrelated-looking error.
            raise ValueError(
                "input_type {!r} does not provide a molecule; "
                "use 'smiles' or 'mol'".format(input_type))

        # Generate a single conformer and relax it with MMFF.
        Chem.AllChem.EmbedMultipleConfs(self.mol, numConfs=1)
        Chem.AllChem.MMFFOptimizeMoleculeConfs(self.mol, maxIters=10000)
コード例 #2
0
    def get_ordered_scaffold_sets(molecules, include_chirality, log_every_n):
        """Group molecules based on their Bemis-Murcko scaffolds and
        order these groups based on their sizes.

        Groups are sorted in descending order of ``(group size, index of the
        first molecule in the group)``: larger groups come first, and among
        groups of equal size the one whose first molecule has the larger index
        comes first.

        Parameters
        ----------
        molecules : list of rdkit.Chem.rdchem.Mol
            Pre-computed RDKit molecule instances. The indices returned refer
            to positions in this list.
        include_chirality : bool
            Whether to consider chirality in computing scaffolds.
        log_every_n : None or int
            Molecule related computation can take a long time for a large dataset and we want
            to learn the progress of processing. This can be done by printing a message whenever
            a batch of ``log_every_n`` molecules have been processed. If None, no messages will
            be printed.

        Returns
        -------
        scaffold_sets : list
            Each element of the list is a list of int,
            representing the indices of compounds with a same scaffold.
            Molecules whose scaffold could not be computed are left out.
        """
        if log_every_n is not None:
            print('Start computing Bemis-Murcko scaffolds.')
        scaffolds = defaultdict(list)
        for i, mol in enumerate(molecules):
            count_and_log('Computing Bemis-Murcko for compound', i,
                          len(molecules), log_every_n)
            try:
                # For mols that have not been sanitized, we need to compute
                # their ring information first.
                FastFindRings(mol)
                mol_scaffold = MurckoScaffold.MurckoScaffoldSmiles(
                    mol=mol, includeChirality=include_chirality)
            except Exception:
                # Catch Exception rather than everything so that Ctrl-C and
                # SystemExit still stop a long-running computation.
                print('Failed to compute the scaffold for molecule {:d} '
                      'and it will be excluded.'.format(i + 1))
                continue
            # Group molecules that have the same scaffold
            scaffolds[mol_scaffold].append(i)

        # Order groups of molecules by first comparing the size of groups
        # and then the index of the first compound in the group.
        scaffold_sets = sorted(scaffolds.values(),
                               key=lambda group: (len(group), group[0]),
                               reverse=True)

        return scaffold_sets
コード例 #3
0
def pair_prods(mol_list, rxn, debug=False):
    """Run ``rxn`` on every molecule and collect the first two products.

    Each molecule has its property cache and ring information refreshed, gets
    explicit hydrogens added, and is passed to ``rxn.RunReactants`` as the
    single reactant. Molecules for which the refresh step fails are reported
    with ``print`` and skipped.

    Parameters
    ----------
    mol_list : iterable of RDKit molecules
        Input reactants.
    rxn
        Object providing ``RunReactants``; every product set it yields is
        indexed at positions 0 and 1.
    debug : bool
        If true, print the SMILES of each molecule before processing it.

    Returns
    -------
    tuple of (list, list)
        ``prod1_list`` holds element 0 and ``prod2_list`` holds element 1 of
        every product set, in matching order.
    """
    prod1_list = []
    prod2_list = []
    for mol in mol_list:
        if debug:
            print(MolToSmiles(mol))
        try:
            mol.UpdatePropertyCache()
            FastFindRings(mol)
        except Exception:
            # Exception (not a bare except) so KeyboardInterrupt/SystemExit
            # still propagate while a bad molecule is merely skipped.
            print('This mol fails! ' + MolToSmiles(mol))
            continue
        # An empty result simply contributes nothing to either list.
        for prod in rxn.RunReactants((Chem.AddHs(mol), )):
            prod1_list.append(prod[0])
            prod2_list.append(prod[1])
    return prod1_list, prod2_list
コード例 #4
0
    def render(self, mode="human"):
        """
        Generates a 3D rendering of the current molecule in the environment.

        When both ``check_valency()`` and ``check_chemical_validity()`` pass,
        a hydrogen-added copy of ``self.mol`` is embedded in 3D, MMFF
        optimised, sent to a ``MolViewer`` and saved through ``pymol.cmd.save``
        as ``./pymol_renderings/<SMILES>.pse``. Otherwise a message is printed
        and nothing is rendered.

        Side effects on ``self.mol``: its property cache and ring information
        are refreshed and its stereochemistry information is removed.

        Parameters
        ----------
        mode: string
            Just a way of indicating whether the rendering is mainly for humans vs machines in OpenAI.
            It is not read by this method.
        """
        if not (self.check_valency() and self.check_chemical_validity()):
            print(
                "The molecule is not chemically valid, and rendering has been terminated."
            )
            return

        if not self.pymol_window_flag:
            self.start_pymol()

        self.mol.UpdatePropertyCache(
            strict=False)  # Update valence information
        # Quick for finding out if an atom is in a ring, use Chem.GetSymmSSR()
        # if more reliability is desired
        FastFindRings(self.mol)

        # remove stereochemistry information
        rdmolops.RemoveStereochemistry(self.mol)

        # Chem.AddHs returns a NEW molecule and leaves its argument untouched,
        # so the result must be kept; the previous code discarded it and
        # therefore embedded/optimised a structure without explicit hydrogens.
        # Working on the copy also means no RemoveHs step is needed afterwards.
        molecule = Chem.AddHs(self.mol)

        # Generate 3D structure
        AllChem.EmbedMolecule(molecule)
        AllChem.MMFFOptimizeMolecule(molecule)

        v = MolViewer()
        v.ShowMol(molecule)
        v.GetPNG(h=400)

        # Save the rendering in pse format. The file name is built from the
        # molecule without explicit hydrogens so existing names are unchanged.
        pymol.cmd.save("./pymol_renderings/" + Chem.MolToSmiles(self.mol) +
                       ".pse",
                       format="pse")
コード例 #5
0
def pair_prods(mol_list, rxn, debug=False):
    """
    Function that applies a one-reactant two-product reaction SMILES to a list of input RDKit molecules,
    returning the products as two separate lists of RDKit molecules.

    Each molecule has its property cache and ring information refreshed and
    gets explicit hydrogens added before the reaction is run. Molecules for
    which the refresh step fails are logged at INFO level and skipped.

    Parameters
    ----------
    mol_list : iterable of RDKit molecules
        Input reactants.
    rxn
        Object providing ``RunReactants``; every product set it yields is
        indexed at positions 0 and 1.
    debug : bool
        If true, log the SMILES of each molecule before processing it.

    Returns
    -------
    tuple of (list, list)
        First products and second products, in matching order.
    """
    prod1_list = []
    prod2_list = []
    for mol in mol_list:
        if debug:
            logging.info(MolToSmiles(mol))
        try:
            mol.UpdatePropertyCache()
            FastFindRings(mol)
        except Exception:
            # Exception (not a bare except) so KeyboardInterrupt/SystemExit
            # still propagate while a bad molecule is merely skipped.
            logging.info('This mol fails! ' + MolToSmiles(mol))
            continue
        # An empty result simply contributes nothing to either list.
        for prod in rxn.RunReactants((Chem.AddHs(mol), )):
            prod1_list.append(prod[0])
            prod2_list.append(prod[1])
    return prod1_list, prod2_list
コード例 #6
0
def main():
    """Convert molecules into graph inputs and write them to ``args.output``.

    Options come from ``get_parser()`` and the molecules, labels and optional
    side data from ``extract_mol_info(args)``. For every usable molecule an
    adjacency matrix and a feature matrix are built. Then either:

    * ``args.tfrecords`` is set: examples are created with
      ``convert_to_example``; with ``args.csv_reaxys`` they are split by
      publication year and written with ``save_tfrecords``. ``tasks.txt`` is
      written into ``args.output`` and the process exits via ``sys.exit(0)``.
    * otherwise: a dict is assembled (``feature``, ``adj``, labels, optional
      ``task_names`` / ``dragon`` / ``profeat`` / ``class_weight`` / ``mfp`` /
      sequence entries, ``max_node_num``, ``mol_info``) and dumped with
      ``joblib.dump`` to ``args.output``.
    """
    args = get_parser()
    if args.use_deepchem_feature:
        # This option pins the degree dimension and switches the optional
        # descriptors off.
        args.degree_dim = 11
        args.use_sybyl = False
        args.use_electronegativity = False
        args.use_gasteiger = False

    adj_list = []
    feature_list = []
    label_data_list = []
    label_mask_list = []
    atom_num_list = []
    mol_name_list = []
    seq_symbol_list = None
    dragon_data_list = None
    seq = None
    seq_list = None
    seq_domain = None
    seq_domain_list = None
    dragon_data = None
    profeat = None
    mol_list = []
    train_list = []
    eval_list = []
    test_list = []
    prefix_idx = 0
    if args.solubility:
        args.sdf_label = "SOL_classification"
        args.sdf_label_active = "high"
        args.sdf_label_inactive = "low"
    if args.assay_dir is not None:
        (mol_obj_list, label_data, label_mask, dragon_data, task_name_list,
         seq, seq_symbol, seq_domain, profeat,
         publication_years) = extract_mol_info(args)
    else:
        # Without an assay directory the side data is ignored.
        (mol_obj_list, label_data, label_mask, _, task_name_list, _, _, _, _,
         publication_years) = extract_mol_info(args)

    if args.vector_modal is not None:
        dragon_data = build_vector_modal(args)
    if args.atom_num_limit is None:
        # No limit given: use the largest atom count found in the input.
        args.atom_num_limit = 0
        for mol in mol_obj_list:
            if mol is None:
                continue
            Chem.SanitizeMol(mol, sanitizeOps=Chem.SANITIZE_ADJUSTHS)
            args.atom_num_limit = max(args.atom_num_limit, mol.GetNumAtoms())

    # Keyword arguments shared by every create_feature_matrix call.
    feature_kwargs = {
        "use_electronegativity": args.use_electronegativity,
        "use_sybyl": args.use_sybyl,
        "use_gasteiger": args.use_gasteiger,
        "use_tfrecords": args.tfrecords,
        "degree_dim": args.degree_dim,
    }
    if args.use_electronegativity:
        # Pauling electronegativity for elements 1..99; missing values -> 0.
        pauling = (element(i).electronegativity('pauling')
                   for i in range(1, 100))
        feature_kwargs["en_list"] = [
            e if e is not None else 0 for e in pauling
        ]

    for index, mol in enumerate(mol_obj_list):
        if mol is None:
            continue
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SANITIZE_ADJUSTHS)
        # Skip molecules with more atoms than the limit.
        if (args.atom_num_limit is not None
                and mol.GetNumAtoms() > args.atom_num_limit):
            continue
        try:
            name = mol.GetProp("_Name")
        except KeyError:
            name = "index_" + str(index)
        mol_list.append(mol)
        mol_name_list.append(name)
        adj = create_adjancy_matrix(mol)
        feature = create_feature_matrix(mol, args.atom_num_limit,
                                        **feature_kwargs)

        if args.tfrecords:
            ex = convert_to_example(adj, feature, label_data[index],
                                    label_mask[index])
            if args.csv_reaxys:
                # Publications before 2015 go to train; later ones are
                # assigned to test or eval at random.
                if publication_years[index] < 2015:
                    train_list.append(ex)
                elif random.choice(["test", "eval"]) == "test":
                    test_list.append(ex)
                else:
                    eval_list.append(ex)
            # Flush the buffers periodically to a new file prefix.
            if index % 100000 == 0 and index > 0:
                save_tfrecords(args.output, train_list, eval_list, test_list,
                               prefix_idx)
                train_list.clear()
                eval_list.clear()
                test_list.clear()
                prefix_idx += 1
            continue

        atom_num_list.append(mol.GetNumAtoms())
        adj_list.append(dense_to_sparse(adj))
        feature_list.append(feature)
        # Create labels
        if args.sdf_label:
            line = mol.GetProp(args.sdf_label)
            if line.find(args.sdf_label_active) != -1:
                label_data_list.append([0, 1])
                label_mask_list.append([1, 1])
            elif line.find(args.sdf_label_inactive) != -1:
                label_data_list.append([1, 0])
                label_mask_list.append([1, 1])
            else:
                # Unknown label: keep the row but mask it out.
                print(f"[WARN] unknown label: {line}")
                label_data_list.append([0, 0])
                label_mask_list.append([0, 0])
        else:
            label_data_list.append(label_data[index])
            label_mask_list.append(label_mask[index])
            if dragon_data is not None:
                if dragon_data_list is None:
                    dragon_data_list = []
                dragon_data_list.append(dragon_data[index])
        if args.multimodal:
            if seq is not None:
                if seq_list is None:
                    seq_list, seq_symbol_list = [], []
                seq_list.append(seq[index])
                # NOTE(review): ``seq_symbol`` from extract_mol_info is never
                # used; this stores ``seq[index]`` a second time. Confirm
                # whether ``seq_symbol`` was intended here.
                seq_symbol_list.append(seq[index])
            if seq_domain is not None:
                if seq_domain_list is None:
                    seq_domain_list = []
                seq_domain_list.append(seq_domain[index])

    if args.csv_reaxys:
        # Write whatever is left in the buffers.
        save_tfrecords(args.output, train_list, eval_list, test_list,
                       prefix_idx)
    if args.tfrecords:
        with open(os.path.join(args.output, "tasks.txt"), "w") as f:
            # task_name_list is treated as optional further below, so do not
            # crash on None here either.
            f.write("\n".join(
                task_name_list if task_name_list is not None else []))
        sys.exit(0)
    # joblib output
    obj = {
        "feature": np.asarray(feature_list, dtype=np.float32),
        "adj": np.asarray(adj_list)
    }
    if not args.output_sparse_label:
        obj["label"] = np.asarray(label_data_list)
        obj["mask_label"] = np.asarray(label_mask_list)
    else:
        if args.input_sparse_label:
            obj['label_dim'] = label_data[0].get_shape(
            )[1] if args.label_dim is None else args.label_dim
            obj['label_sparse'] = csr_matrix(vstack(label_data))
            obj['mask_label_sparse'] = csr_matrix(vstack(label_mask))
        else:
            label_data = np.asarray(label_data_list)
            label_mask = np.asarray(label_mask_list)
            obj['label_dim'] = label_data.shape[
                1] if args.label_dim is None else args.label_dim
            obj['label_sparse'] = csr_matrix(label_data.astype(np.float32))
            obj['mask_label_sparse'] = csr_matrix(label_mask.astype(
                np.float32))
    if task_name_list is not None:
        obj["task_names"] = np.asarray(task_name_list)
    if dragon_data_list is not None:
        obj["dragon"] = np.asarray(dragon_data_list)
    if profeat is not None:
        obj["profeat"] = np.asarray(profeat)
    obj["max_node_num"] = args.atom_num_limit
    mol_info = {"obj_list": mol_list, "name_list": mol_name_list}
    obj["mol_info"] = mol_info
    if not args.regression:
        if args.input_sparse_label:
            label_int = np.ravel(obj['label_sparse'].argmax(axis=1))
        else:
            label_int = np.argmax(label_data_list, axis=1)
        # ``classes`` and ``y`` are keyword-only in current scikit-learn;
        # passing them by name works on older releases as well.
        cw = class_weight.compute_class_weight("balanced",
                                               classes=np.unique(label_int),
                                               y=label_int)
        obj["class_weight"] = cw

    if args.generate_mfp:
        mfps = []
        for mol in mol_list:
            FastFindRings(mol)
            mfp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
            mfp_vec = np.array([mfp.GetBit(i) for i in range(2048)], np.int32)
            mfps.append(mfp_vec)
        obj["mfp"] = np.array(mfps)

    if args.multimodal:
        # Bound up front so the domain branch can test it even when there is
        # no sequence data (previously an UnboundLocalError).
        max_len_seq = None
        if seq is not None:
            max_len_seq = args.max_len_seq if args.max_len_seq is not None else max(
                map(len, seq_list))
            print(f"max_len_seq: {max_len_seq}")
            # Zero-padded matrix with one row per sequence.
            seq_mat = np.zeros((len(seq_list), max_len_seq), np.int32)
            for i, s in enumerate(seq_list):
                seq_mat[i, 0:len(s)] = s
            obj["sequence"] = seq_mat
            obj["sequence_symbol"] = seq_symbol_list
            obj["sequence_length"] = list(map(len, seq_list))
            obj["sequence_symbol_num"] = int(np.max(seq_mat) + 1)
        if seq_domain is not None:
            if max_len_seq is None:
                if args.max_len_seq is not None:
                    max_len_seq = args.max_len_seq
                elif seq_list is not None:
                    max_len_seq = max(map(len, seq_list))
            print(f"max_len_seq: {max_len_seq}")

            if args.assay_domain_name_clone is None:
                # Collect every domain name seen and fix their order.
                seq_domain_name = sorted(
                    {key
                     for doms in seq_domain_list
                     for key, _ in doms.items()})
            else:
                # Reuse the name order stored in a previously saved object.
                print("[LOAD]", args.assay_domain_name_clone)
                loaded_obj = joblib.load(args.assay_domain_name_clone)
                seq_domain_name = loaded_obj["sequence_vec_name"]
            print(seq_domain_name)
            # For each entry, map the position of a domain name in
            # seq_domain_name to that domain's value.
            seq_vec_range = []
            for doms in seq_domain_list:
                seq_r = {}
                for key, vec in doms.items():
                    j = seq_domain_name.index(key)
                    seq_r[j] = vec
                seq_vec_range.append(seq_r)
            obj["sequence_vec_range"] = np.array(seq_vec_range)
            obj["sequence_vec_name"] = seq_domain_name
    print(f"[SAVE] {args.output}")
    joblib.dump(obj, args.output, compress=3)
コード例 #7
0
ファイル: chem.py プロジェクト: embeddedsamurai/kGCN-1
def main():
    """Convert molecules into graph inputs and write them to ``args.output``.

    Options come from ``get_parser()`` and the molecules, labels and optional
    side data from ``extract_mol_info(args)``. For every usable molecule an
    adjacency matrix and a feature matrix are built. Then either:

    * ``args.tfrecords`` is set: each molecule is written to its own
      ``*_.tfrecords`` file below ``args.output`` (in a per-year
      sub-directory when ``args.csv_reaxys`` is set), ``tasks.txt`` is
      written and the process exits via ``sys.exit(0)``.
    * otherwise: a dict is assembled (``feature``, ``adj``, labels, optional
      ``task_names`` / ``dragon`` / ``profeat`` / ``class_weight`` / ``mfp`` /
      sequence entries, ``max_node_num``, ``mol_info``) and dumped with
      ``joblib.dump`` to ``args.output``.
    """
    args = get_parser()
    if args.use_deepchem_feature:
        # This option pins the degree dimension and switches the optional
        # descriptors off.
        args.degree_dim = 11
        args.use_sybyl = False
        args.use_electronegativity = False
        args.use_gasteiger = False

    adj_list = []
    feature_list = []
    label_data_list = []
    label_mask_list = []
    atom_num_list = []
    mol_name_list = []
    seq_symbol_list = None
    dragon_data_list = None
    task_name_list = None
    seq_list = None
    seq = None
    dragon_data = None
    profeat = None
    mol_list = []
    if args.solubility:
        args.sdf_label = "SOL_classification"
        args.sdf_label_active = "high"
        args.sdf_label_inactive = "low"
    if args.assay_dir is not None:
        (mol_obj_list, label_data, label_mask, dragon_data, task_name_list,
         seq, seq_symbol, profeat,
         publication_years) = extract_mol_info(args)
    else:
        # Without an assay directory the side data is ignored.
        (mol_obj_list, label_data, label_mask, _, task_name_list, _, _, _,
         publication_years) = extract_mol_info(args)

    if args.vector_modal is not None:
        dragon_data = build_vector_modal(args)
    # Automatically set atom_num_limit to the largest atom count in the input.
    if args.atom_num_limit is None:
        args.atom_num_limit = 0
        for mol in mol_obj_list:
            if mol is None:
                continue
            Chem.SanitizeMol(mol, sanitizeOps=Chem.SANITIZE_ADJUSTHS)
            args.atom_num_limit = max(args.atom_num_limit, mol.GetNumAtoms())

    if args.use_electronegativity:
        # Pauling electronegativity for elements 1..99; missing values -> 0.
        ELECTRONEGATIVITIES = [
            element(i).electronegativity('pauling') for i in range(1, 100)
        ]
        ELECTRONEGATIVITIES = [
            e if e is not None else 0 for e in ELECTRONEGATIVITIES
        ]

    for index, mol in enumerate(mol_obj_list):
        if mol is None:
            continue
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SANITIZE_ADJUSTHS)
        # Skip the compound whose total number of atoms is larger than "atom_num_limit"
        if (args.atom_num_limit is not None
                and mol.GetNumAtoms() > args.atom_num_limit):
            continue
        # Get mol. name
        try:
            name = mol.GetProp("_Name")
        except KeyError:
            name = "index_" + str(index)
        mol_list.append(mol)
        mol_name_list.append(name)
        adj = create_adjancy_matrix(mol)
        if args.use_electronegativity:
            feature = create_feature_matrix(mol,
                                            args,
                                            en_list=ELECTRONEGATIVITIES)
        else:
            feature = create_feature_matrix(mol, args)
        if args.tfrecords:
            tfrname = os.path.join(args.output, name + '_.tfrecords')
            if args.csv_reaxys:
                # Publications before 2015 are tagged as train; later ones
                # are tagged test or eval at random. Files are grouped in a
                # sub-directory named after the publication year.
                if publication_years[index] < 2015:
                    name += "_train"
                else:
                    name += random.choice(["_test", "_eval"])
                tfrname = os.path.join(args.output,
                                       str(publication_years[index]),
                                       name + '_.tfrecords')
            pathlib.Path(os.path.dirname(tfrname)).mkdir(parents=True,
                                                         exist_ok=True)
            ex = convert_to_example(adj, feature, label_data[index],
                                    label_mask[index])
            with TFRecordWriter(tfrname) as single_writer:
                single_writer.write(ex.SerializeToString())
            continue

        atom_num_list.append(mol.GetNumAtoms())
        adj_list.append(dense_to_sparse(adj))
        feature_list.append(feature)
        # Create labels
        if args.sdf_label:
            line = mol.GetProp(args.sdf_label)
            if line.find(args.sdf_label_active) != -1:
                label_data_list.append([0, 1])
                label_mask_list.append([1, 1])
            elif line.find(args.sdf_label_inactive) != -1:
                label_data_list.append([1, 0])
                label_mask_list.append([1, 1])
            else:
                # Unknown label: keep the row but mask it out.
                print("[WARN] unknown label:", line)
                label_data_list.append([0, 0])
                label_mask_list.append([0, 0])
        else:
            label_data_list.append(label_data[index])
            label_mask_list.append(label_mask[index])
            if dragon_data is not None:
                if dragon_data_list is None:
                    dragon_data_list = []
                dragon_data_list.append(dragon_data[index])
        if args.multimodal:
            if seq is not None:
                if seq_list is None:
                    seq_list, seq_symbol_list = [], []
                seq_list.append(seq[index])
                # NOTE(review): ``seq_symbol`` from extract_mol_info is never
                # used; this stores ``seq[index]`` a second time. Confirm
                # whether ``seq_symbol`` was intended here.
                seq_symbol_list.append(seq[index])
    if args.tfrecords:
        with open(os.path.join(args.output, "tasks.txt"), "w") as f:
            # task_name_list is treated as optional further below, so do not
            # crash on None here either.
            f.write("\n".join(
                task_name_list if task_name_list is not None else []))
        sys.exit(0)
    # joblib output
    obj = {"feature": np.asarray(feature_list), "adj": np.asarray(adj_list)}
    if not args.sparse_label:
        obj["label"] = np.asarray(label_data_list)
        obj["mask_label"] = np.asarray(label_mask_list)
    else:
        from scipy.sparse import csr_matrix
        label_data = np.asarray(label_data_list)
        label_mask = np.asarray(label_mask_list)
        if args.label_dim is None:
            obj['label_dim'] = label_data.shape[1]
        else:
            obj['label_dim'] = args.label_dim
        obj['label_sparse'] = csr_matrix(label_data.astype(float))
        obj['mask_label_sparse'] = csr_matrix(label_mask.astype(float))
    if task_name_list is not None:
        obj["task_names"] = np.asarray(task_name_list)
    if dragon_data_list is not None:
        obj["dragon"] = np.asarray(dragon_data_list)
    if profeat is not None:
        obj["profeat"] = np.asarray(profeat)
    obj["max_node_num"] = args.atom_num_limit
    mol_info = {"obj_list": mol_list, "name_list": mol_name_list}
    obj["mol_info"] = mol_info
    if not args.regression:
        label_int = np.argmax(label_data_list, axis=1)
        # ``classes`` and ``y`` are keyword-only in current scikit-learn;
        # passing them by name works on older releases as well.
        cw = class_weight.compute_class_weight("balanced",
                                               classes=np.unique(label_int),
                                               y=label_int)
        obj["class_weight"] = cw

    if args.generate_mfp:
        from rdkit.Chem import AllChem
        mfps = []
        for mol in mol_list:
            FastFindRings(mol)
            mfp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
            mfp_vec = np.array([mfp.GetBit(i) for i in range(2048)], np.int32)
            mfps.append(mfp_vec)
        obj["mfp"] = np.array(mfps)

    if args.multimodal:
        if seq is not None:
            if args.max_len_seq is not None:
                max_len_seq = args.max_len_seq
            else:
                max_len_seq = max(map(len, seq_list))
            print("max_len_seq:", max_len_seq)
            # Zero-padded matrix with one row per sequence.
            seq_mat = np.zeros((len(seq_list), max_len_seq), np.int32)
            for i, s in enumerate(seq_list):
                seq_mat[i, 0:len(s)] = s
            obj["sequence"] = seq_mat
            obj["sequence_symbol"] = seq_symbol_list
            obj["sequence_length"] = list(map(len, seq_list))
            obj["sequence_symbol_num"] = int(np.max(seq_mat) + 1)

    filename = args.output
    print("[SAVE] " + filename)
    joblib.dump(obj, filename, compress=3)