Exemplo n.º 1
0
def main(molecules_file1, molecules_file2, memmap_file, mol_names_file1,
         mol_names_file2, log=None, overwrite=False, parallel_mode=None,
         num_proc=None):
    """Prepare inputs and dispatch the pairwise Tanimoto batches.

    Both molecule files are read with ``read_convert_mols``. When
    ``overwrite`` is set or ``memmap_file`` does not exist yet, a
    double-precision memmap with one row per name from the first file and
    one column per name from the second file is created, and both name
    lists are saved. ``run_batch`` is then run through a ``Parallelizer``
    over the index ranges returned by ``get_start_end_inds``.

    Parameters
    ----------
    molecules_file1, molecules_file2 : str
        Molecule files passed to ``read_convert_mols``.
    memmap_file : str
        Path of the memmap that is forwarded to every batch.
    mol_names_file1, mol_names_file2 : str
        Destinations for the two molecule-name lists.
    log : optional
        Forwarded to ``setup_logging``.
    overwrite : bool, optional
        Recreate the memmap and name files even if the memmap exists.
    parallel_mode, num_proc : optional
        Forwarded to ``Parallelizer``.
    """
    setup_logging(log)

    logging.info("Reading first molecules file.")
    fps1, names1, index_map1 = read_convert_mols(molecules_file1)
    logging.info("Reading second molecules file.")
    fps2, names2, index_map2 = read_convert_mols(molecules_file2)

    needs_init = overwrite or not os.path.isfile(memmap_file)
    if needs_init:
        logging.info("Overwriting memmap file.")
        matrix_shape = (len(names1), len(names2))
        # Opening in "w+" mode creates the file on disk; the handle itself
        # is not needed here, so it is released immediately.
        matrix = np.memmap(memmap_file, mode="w+", dtype=np.double,
                           shape=matrix_shape)
        del matrix
        save_mol_names(mol_names_file1, names1)
        save_mol_names(mol_names_file2, names2)

    logging.info("Computing all pairwise Tanimotos.")

    para = Parallelizer(parallel_mode=parallel_mode, num_proc=num_proc)
    # Rows of the first molecule set are divided into ``num_proc - 1`` ranges.
    row_ranges = get_start_end_inds(len(names1), para.num_proc - 1)
    batch_kwargs = dict(
        fp_array1=fps1,
        mol_names1=names1,
        mol_indices_dict1=index_map1,
        fp_array2=fps2,
        mol_names2=names2,
        mol_indices_dict2=index_map2,
        memmap_file=memmap_file,
    )
    para.run(run_batch, row_ranges, kwargs=batch_kwargs)
Exemplo n.º 2
0
def main(sdf_file,
         save_freq=SAVE_FREQ,
         overwrite=False,
         log=None,
         parallel_mode=None,
         num_proc=None):
    """Count the molecules in an SDF file and dispatch ``run_batch`` over them.

    The molecule count is taken from an RDKit ``SDMolSupplier``; the index
    ranges handed to the parallelizer come from ``get_triangle_indices``
    called with that count and ``num_proc - 1``.

    Parameters
    ----------
    sdf_file : str
        SDF file to read; also forwarded to every batch.
    save_freq : optional
        Forwarded to ``run_batch`` unchanged.
    overwrite : bool, optional
        Forwarded to ``run_batch`` unchanged.
    log : optional
        Forwarded to ``setup_logging``.
    parallel_mode, num_proc : optional
        Forwarded to ``Parallelizer``.
    """
    setup_logging(log)
    logging.info("Reading mols from SDF.")
    # Only the count is needed; the supplier is not kept around afterwards.
    mol_count = len(rdkit.Chem.SDMolSupplier(sdf_file))

    para = Parallelizer(parallel_mode=parallel_mode, num_proc=num_proc)
    index_batches = get_triangle_indices(mol_count, para.num_proc - 1)
    batch_kwargs = dict(sdf_file=sdf_file,
                        save_freq=save_freq,
                        overwrite=overwrite)
    para.run(run_batch, index_batches, kwargs=batch_kwargs)
def main(molecules_file,
         log=None,
         overwrite=False,
         parallel_mode=None,
         num_proc=None,
         merge_confs=False,
         save_freq=SAVE_FREQ,
         compress=False):
    """Dispatch ``run_batch`` once per work item built from a molecules file.

    On the master, ``num_proc - 1`` work items of the form
    ``(molecules_file, index, num_proc - 1)`` are generated lazily; any
    other process supplies an empty iterator.

    Parameters
    ----------
    molecules_file : str
        File placed in every work item.
    log : optional
        Forwarded to ``setup_logging``.
    overwrite, merge_confs, save_freq, compress : optional
        Forwarded to ``run_batch`` unchanged as keyword arguments.
    parallel_mode, num_proc : optional
        Forwarded to ``Parallelizer``.
    """
    setup_logging(log)
    para = Parallelizer(parallel_mode=parallel_mode, num_proc=num_proc)

    # Non-master processes contribute no work items of their own.
    data_iter = iter([])
    if para.is_master():
        data_iter = ((molecules_file, batch_index, para.num_proc - 1)
                     for batch_index in range(para.num_proc - 1))

    batch_kwargs = dict(overwrite=overwrite,
                        merge_confs=merge_confs,
                        save_freq=save_freq,
                        compress=compress)
    para.run(run_batch, data_iter, kwargs=batch_kwargs)
Exemplo n.º 4
0
def generate_fingerprints_parallel(sdf_files, threads):
    """Fingerprint every SDF file in parallel and collect a database.

    Each file is handed to ``generate_e3fp_fingerprint`` through a
    ``Parallelizer`` in "processes" mode. The parallelizer is configured
    with an empty fingerprint as its ``fail_value``. Results are re-ordered
    to follow the order of ``sdf_files`` before being added to the database.

    Parameters
    ----------
    sdf_files : sequence of str
        SDF files to fingerprint; iterated more than once.
    threads : int or None
        Passed to ``Parallelizer`` as ``num_proc``. ``None`` means use all
        available processors; otherwise an integral number of workers.

    Returns
    -------
    db.FingerprintDatabase
        Database holding one entry per input file, in input order.
    """
    parallelizer = Parallelizer(parallel_mode="processes",
                                num_proc=threads,
                                fail_value=fprint.Fingerprint([]))

    # The parallelizer takes one argument list per call, hence the wrapping.
    results = parallelizer.run(generate_e3fp_fingerprint,
                               [[path] for path in sdf_files])

    # Each result is indexed as (fingerprint, argument list); key the
    # fingerprint by the single file name in its argument list.
    fp_by_file = {item[1][0]: item[0] for item in results}
    ordered_fps = [fp_by_file[path] for path in sdf_files]

    database = db.FingerprintDatabase()
    database.add_fingerprints(ordered_fps)
    return database
Exemplo n.º 5
0
class KFoldCrossValidator(object):
    """Class to perform k-fold cross-validation.

    Input data are loaded and filtered, split into ``k`` folds, and one
    ``FoldValidator`` per fold is run through the configured parallelizer.
    Fold-level and target-level AUROC/AUPRC values are then averaged and
    logged, and one summary score is returned by `run`.
    """

    def __init__(self,
                 k=5,
                 splitter=MoleculeSplitter,
                 cv_method=SEASearchCVMethod(),
                 input_processor=None,
                 parallelizer=None,
                 out_dir=os.getcwd(),
                 overwrite=False,
                 return_auc_type="roc",
                 reduce_negatives=False,
                 fold_kwargs=None):
        """Configure the cross-validator.

        Parameters
        ----------
        k : int, optional
            Number of folds.
        splitter : type or splitter instance, optional
            If a class, it is instantiated as ``splitter(k)``. If an
            instance, its ``k`` attribute must equal `k`.
        cv_method : object, optional
            Cross-validation method; a deep copy is handed to each fold.
        input_processor : optional
            Passed to ``molecules_to_array`` as ``processor``. Not compatible
            with ``SEASearchCVMethod``.
        parallelizer : optional
            Object whose ``run`` executes the folds. Defaults to a serial
            ``Parallelizer``.
        out_dir : str, optional
            Output directory. NOTE: the default is evaluated once, when the
            class is defined, not on each call.
        overwrite : bool, optional
            Re-save the input file even if it exists; also forwarded to
            each ``FoldValidator``.
        return_auc_type : str, optional
            Lower-cased and substring-matched in `run` to pick the returned
            score (see `run`).
        reduce_negatives : bool, optional
            Stored on the instance. NOTE(review): `run` consults
            ``self.splitter.reduce_negatives`` rather than this value --
            confirm whether the two are meant to be linked.
        fold_kwargs : dict or None, optional
            Extra keyword arguments for each ``FoldValidator``. ``None``
            means no extra arguments.

        Raises
        ------
        ValueError
            If `input_processor` is given together with a SEA search
            cross-validation method.
        """
        if isinstance(splitter, type):
            self.splitter = splitter(k)
        else:
            assert splitter.k == k, "splitter.k must equal k"
            self.splitter = splitter
        self.k = k
        # `cv_method` is normally an instance (the default is one), so an
        # identity comparison against the class alone would never match.
        uses_sea = (cv_method is SEASearchCVMethod
                    or isinstance(cv_method, SEASearchCVMethod))
        if uses_sea and input_processor is not None:
            raise ValueError(
                "Input processing is not (currently) compatible with SEA.")
        self.cv_method = cv_method
        self.input_processor = input_processor
        self.overwrite = overwrite
        if parallelizer is None:
            self.parallelizer = Parallelizer(parallel_mode="serial")
        else:
            self.parallelizer = parallelizer
        self.out_dir = out_dir
        touch_dir(out_dir)
        self.input_file = os.path.join(self.out_dir, "inputs.pkl.bz2")
        self.return_auc_type = return_auc_type.lower()
        self.reduce_negatives = reduce_negatives
        # Copy so that instances never share (or mutate) a caller's dict.
        self.fold_kwargs = dict(fold_kwargs) if fold_kwargs else {}

    def run(self,
            molecules_file,
            targets_file,
            min_mols=50,
            affinity=None,
            overwrite=False):
        """Run cross-validation and return a summary AUC.

        If the saved input file or any fold's files are missing, inputs are
        loaded, converted, split and saved; otherwise the saved input file
        is reloaded only to recompute the class imbalance.

        Parameters
        ----------
        molecules_file : str
            Molecules file passed to `load_input_files`.
        targets_file : str
            Targets file passed to `load_input_files`.
        min_mols : int, optional
            Passed to `load_input_files`.
        affinity : optional
            Passed to `load_input_files`.
        overwrite : bool, optional
            Currently unused; ``self.overwrite`` controls overwriting.

        Returns
        -------
        float
            Chosen by ``self.return_auc_type``: target-averaged values are
            used if it contains "target" (or if combined values were not
            computed), otherwise fold-averaged values. Of that pair, the
            AUPRC is returned if the type contains "pr", the sum of both if
            it contains "sum", and the AUROC otherwise.
        """
        fold_validators = {
            fold_num: FoldValidator(fold_num,
                                    self._get_fold_dir(fold_num),
                                    cv_method=copy.deepcopy(self.cv_method),
                                    input_file=self.input_file,
                                    overwrite=self.overwrite,
                                    **self.fold_kwargs)
            for fold_num in range(self.k)
        }
        if not os.path.isfile(self.input_file) or not all(
                x.fold_files_exist() for x in fold_validators.values()):
            logging.info("Loading and filtering input files.")
            ((smiles_dict, mol_list_dict, fp_type),
             target_dict) = self.load_input_files(molecules_file,
                                                  targets_file,
                                                  min_mols=min_mols,
                                                  affinity=affinity)

            mol_list = sorted(mol_list_dict.keys())
            if isinstance(self.cv_method, SEASearchCVMethod):
                # efficiency hack: no fingerprint array is built for SEA
                fp_array, mol_to_fp_inds = (None, None)
            else:
                logging.info("Converting inputs to arrays.")
                fp_array, mol_to_fp_inds = molecules_to_array(
                    mol_list_dict, mol_list, processor=self.input_processor)
            target_mol_array, target_list = targets_to_array(
                target_dict, mol_list, dtype=TARGET_MOL_DENSE_DTYPE)
            total_imbalance = get_imbalance(target_mol_array)

            if self.overwrite or not os.path.isfile(self.input_file):
                logging.info("Saving arrays and labels to files.")
                save_cv_inputs(self.input_file, fp_array, mol_to_fp_inds,
                               target_mol_array, target_list, mol_list)
            # Drop references to the arrays once they are saved.
            del fp_array, mol_to_fp_inds

            logging.info("Splitting data into {} folds using {}.".format(
                self.k,
                type(self.splitter).__name__))
            if self.splitter.reduce_negatives:
                logging.info("After splitting, negatives will be reduced.")
            train_test_masks = self.splitter.get_train_test_masks(
                target_mol_array)
            for fold_num, train_test_mask in enumerate(train_test_masks):
                logging.info(
                    "Saving inputs to files (fold {})".format(fold_num))
                fold_val = fold_validators[fold_num]
                fold_val.save_fold_files(train_test_mask, mol_list,
                                         target_list, smiles_dict,
                                         mol_list_dict, fp_type, target_dict)
            del (smiles_dict, mol_list_dict, fp_type, target_mol_array,
                 mol_list, train_test_masks, target_dict, target_list)
        else:
            logging.info("Resuming from input and fold files.")
            (fp_array, mol_to_fp_inds, target_mol_array, target_list,
             mol_list) = load_cv_inputs(self.input_file)
            # Only the imbalance is needed when resuming.
            total_imbalance = get_imbalance(target_mol_array)
            del (fp_array, mol_to_fp_inds, target_mol_array, target_list,
                 mol_list)

        # run cross-validation and gather scores
        logging.info("Running fold validation.")
        para_args = sorted(fold_validators.items())
        fold_results = self.parallelizer.run(_run_fold, para_args)
        # The first element of each result is that fold's AUC output.
        # `zip` returns an iterator on Python 3, so materialize it before
        # indexing.
        aucs = list(zip(*fold_results))[0]
        # `dict.values()` is a non-indexable view on Python 3.
        first_fold_val = next(iter(fold_validators.values()))
        if first_fold_val.compute_combined:
            aurocs, auprcs = zip(*aucs)
            mean_auroc = np.mean(aurocs)
            std_auroc = np.std(aurocs)
            logging.info("CV Mean AUROC: {:.4f} +/- {:.4f}".format(
                mean_auroc, std_auroc))
            mean_auprc = np.mean(auprcs)
            std_auprc = np.std(auprcs)
            logging.info(("CV Mean AUPRC: {:.4f} +/- {:.4f} ({:.4f} of data "
                          "is positive)").format(mean_auprc, std_auprc,
                                                 total_imbalance))
        else:
            (mean_auroc, mean_auprc) = (None, None)

        # Pool the per-target AUC pairs saved by every fold.
        # NOTE(review): these pickles are assumed to be files written by
        # the fold validators themselves; never point this at untrusted
        # data.
        target_aucs = []
        for fold_val in fold_validators.values():
            with smart_open(fold_val.target_aucs_file, "rb") as f:
                target_aucs.extend(pkl.load(f).values())
        target_aucs = np.array(target_aucs)
        mean_target_auroc, mean_target_auprc = np.mean(target_aucs, axis=0)
        std_target_auroc, std_target_auprc = np.std(target_aucs, axis=0)
        logging.info("CV Mean Target AUROC: {:.4f} +/- {:.4f}".format(
            mean_target_auroc, std_target_auroc))
        logging.info("CV Mean Target AUPRC: {:.4f} +/- {:.4f}".format(
            mean_target_auprc, std_target_auprc))

        if "target" in self.return_auc_type or mean_auroc is None:
            logging.info("Returning target AUC.")
            aucs = (mean_target_auroc, mean_target_auprc)
        else:
            logging.info("Returning combined average AUC.")
            aucs = (mean_auroc, mean_auprc)

        if "pr" in self.return_auc_type:
            logging.info("Returned AUC is AUPRC.")
            return aucs[1]
        elif "sum" in self.return_auc_type:
            return sum(aucs)
        else:
            logging.info("Returned AUC is AUROC.")
            return aucs[0]

    def load_input_files(self,
                         molecules_file,
                         targets_file,
                         min_mols=50,
                         affinity=None,
                         overwrite=False):
        """Load molecules and targets, then filter the targets.

        Targets are filtered first by the loaded molecules and then by
        ``filter_targets_by_molnum`` with ``n=min_mols``.

        Parameters
        ----------
        molecules_file : str
            File read by ``molecules_to_lists_dicts``.
        targets_file : str
            File read by ``targets_to_dict``.
        min_mols : int, optional
            Passed to ``filter_targets_by_molnum`` as ``n``.
        affinity : optional
            Passed to ``targets_to_dict``.
        overwrite : bool, optional
            Currently unused.

        Returns
        -------
        tuple
            ``((smiles_dict, mol_lists_dict, fp_type), target_dict)``.
        """
        smiles_dict, mol_lists_dict, fp_type = molecules_to_lists_dicts(
            molecules_file)

        mol_names_target_dict = mol_lists_targets_to_targets(
            targets_to_dict(targets_file, affinity=affinity))
        target_dict = filter_targets_by_molecules(mol_names_target_dict,
                                                  mol_lists_dict)
        del mol_names_target_dict
        target_dict = filter_targets_by_molnum(target_dict, n=min_mols)

        return (smiles_dict, mol_lists_dict, fp_type), target_dict

    def _get_fold_dir(self, fold_num):
        """Return the directory for fold `fold_num` inside ``self.out_dir``."""
        return os.path.join(self.out_dir, str(fold_num))