Example #1
def main(smiles_file,
         params_file,
         sdf_dir=None,
         out_file="molecules.csv.bz2",
         log=None,
         num_proc=None,
         parallel_mode=None,
         verbose=False):
    """Fingerprint molecules."""
    setup_logging(log, verbose=verbose)
    parallelizer = Parallelizer(parallel_mode=parallel_mode, num_proc=num_proc)

    # set conformer generation and fingerprinting parameters
    confgen_params, fprint_params = params_to_dicts(params_file)
    kwargs = {"save": False, "fprint_params": fprint_params}

    smiles_dict = smiles_to_dict(smiles_file)
    mol_num = len({x.split('-')[0] for x in smiles_dict})

    if sdf_dir is not None:
        sdf_files = glob.glob(os.path.join(sdf_dir, "*.sdf*"))
        sdf_files = sorted(
            [x for x in sdf_files if name_from_sdf_filename(x) in smiles_dict])
        data_iter = make_data_iterator(sdf_files)
        fp_method = native_tuples_from_sdf
        logging.info("Using SDF files from {}".format(sdf_dir))
    else:
        kwargs["confgen_params"] = confgen_params
        data_iter = ((smiles, name)
                     for name, smiles in smiles_dict.items())
        fp_method = native_tuples_from_smiles
        logging.info("Will generate conformers.")
        logging.info(
            "Conformer generation params: {!r}.".format(confgen_params))
    logging.info("Fingerprinting params: {!r}.".format(fprint_params))

    # fingerprint in parallel
    logging.info("Fingerprinting {:d} molecules".format(mol_num))
    mol_list_dict = {}
    for result, data in parallelizer.run_gen(fp_method,
                                             data_iter,
                                             kwargs=kwargs):
        if not result:
            logging.warning("Fingerprinting failed for {}.".format(data[0]))
            continue
        try:
            _, name = result[0]
            name = name.split('_')[0]
        except IndexError:
            logging.warning("Fingerprinting failed for {}.".format(data[0]))
            continue
        mol_list_dict[name] = result
    logging.info("Finished fingerprinting molecules")

    # save to SEA molecules file
    logging.info("Saving fingerprints to {}".format(out_file))
    fp_type = fprint_params_to_fptype(**fprint_params)
    lists_dicts_to_molecules(out_file, smiles_dict, mol_list_dict, fp_type)
    logging.info("Finished!")
Example #2
 def __init__(self,
              k=5,
              splitter=MoleculeSplitter,
              cv_method=SEASearchCVMethod(),
              input_processor=None,
              parallelizer=None,
              out_dir=os.getcwd(),
              overwrite=False,
              return_auc_type="roc",
              reduce_negatives=False,
              fold_kwargs={}):
     if isinstance(splitter, type):
         self.splitter = splitter(k)
     else:
         assert splitter.k == k
         self.splitter = splitter
     self.k = k
     if (isinstance(cv_method, SEASearchCVMethod)
             and input_processor is not None):
         raise ValueError(
             "Input processing is not (currently) compatible with SEA.")
     self.cv_method = cv_method
     self.input_processor = input_processor
     self.overwrite = overwrite
     if parallelizer is None:
         self.parallelizer = Parallelizer(parallel_mode="serial")
     else:
         self.parallelizer = parallelizer
     self.out_dir = out_dir
     touch_dir(out_dir)
     self.input_file = os.path.join(self.out_dir, "inputs.pkl.bz2")
     self.return_auc_type = return_auc_type.lower()
     self.reduce_negatives = reduce_negatives
     self.fold_kwargs = fold_kwargs
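The splitter argument accepts either a class (instantiated with k) or a pre-built instance (whose k must match). A small sketch of that pattern in isolation; MySplitter and resolve_splitter are illustrative, not library code.

class MySplitter(object):
    def __init__(self, k):
        self.k = k

def resolve_splitter(splitter, k):
    if isinstance(splitter, type):
        return splitter(k)   # a class was passed; instantiate it with k
    assert splitter.k == k   # an instance was passed; its k must match
    return splitter

print(resolve_splitter(MySplitter, 5).k)     # class -> fresh instance
print(resolve_splitter(MySplitter(5), 5).k)  # instance -> reused as-is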
Example #3
def main(molecules_file1, molecules_file2, memmap_file, mol_names_file1,
         mol_names_file2, log=None, overwrite=False, parallel_mode=None,
         num_proc=None):
    setup_logging(log)
    logging.info("Reading first molecules file.")
    fp_array1, mol_names1, mol_indices_dict1 = read_convert_mols(
        molecules_file1)
    logging.info("Reading second molecules file.")
    fp_array2, mol_names2, mol_indices_dict2 = read_convert_mols(
        molecules_file2)

    if overwrite or not os.path.isfile(memmap_file):
        logging.info("Creating memmap file.")
        memmap = np.memmap(memmap_file, mode="w+", dtype=np.double,
                           shape=(len(mol_names1), len(mol_names2)))
        del memmap
        save_mol_names(mol_names_file1, mol_names1)
        save_mol_names(mol_names_file2, mol_names2)

    logging.info("Computing all pairwise Tanimotos.")

    para = Parallelizer(parallel_mode=parallel_mode, num_proc=num_proc)
    start_end_indices = get_start_end_inds(len(mol_names1), para.num_proc - 1)
    kwargs = {"fp_array1": fp_array1,
              "mol_names1": mol_names1,
              "mol_indices_dict1": mol_indices_dict1,
              "fp_array2": fp_array2,
              "mol_names2": mol_names2,
              "mol_indices_dict2": mol_indices_dict2,
              "memmap_file": memmap_file}
    para.run(run_batch, start_end_indices, kwargs=kwargs)
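The memmap handoff works because only the file on disk is shared: the master pre-allocates it with mode="w+", then each worker reopens it and writes its assigned row block. A minimal numpy-only sketch (the file name is a placeholder):

import numpy as np

shape = (4, 3)  # (len(mol_names1), len(mol_names2)) in the example above
memmap = np.memmap("pairwise.dat", mode="w+", dtype=np.double, shape=shape)
del memmap  # flush and close; the zero-filled file now exists on disk

worker_view = np.memmap("pairwise.dat", mode="r+", dtype=np.double,
                        shape=shape)
worker_view[0:2, :] = 1.0  # a worker fills its start:end row range
worker_view.flush()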
Example #4
def main(mfile1, mfile2, name1, name2, out_file, precision=PRECISION,
         log_freq=LOG_FREQ, num_proc=None, parallel_mode=None):
    setup_logging()
    if not out_file:
        out_file = (name1.lower().replace(' ', '_') + "_" +
                    name2.lower().replace(' ', '_') + "_tcs.csv.gz")

    # Load files
    mmap1 = load_mmap(mfile1)
    mmap2 = load_mmap(mfile2)
    if mmap1.shape != mmap2.shape:
        raise ValueError(
            "Memmaps do not have the same shape: {} {}".format(
                mmap1.shape, mmap2.shape))

    # Count binned pairs
    pair_num = mmap1.shape[0]
    del mmap1, mmap2

    para = Parallelizer(parallel_mode=parallel_mode, num_proc=num_proc)
    num_proc = max(para.num_proc - 1, 1)
    chunk_bounds = np.linspace(-1, pair_num - 1, num_proc + 1, dtype=int)
    chunk_bounds = list(zip(chunk_bounds[:-1] + 1, chunk_bounds[1:]))
    logging.info("Divided into {} chunks with ranges: {}".format(num_proc,
                                                                 chunk_bounds))

    logging.info("Counting TCs in chunks.")
    kwargs = {"mfile1": mfile1, "mfile2": mfile2, "precision": precision,
              "log_freq": log_freq}
    results_iter = para.run_gen(count_tcs, chunk_bounds, kwargs=kwargs)
    tc_pair_counts = Counter()
    for chunk_counts, _ in results_iter:
        if not isinstance(chunk_counts, dict):
            logging.error("Results are not in dict form.")
            continue
        tc_pair_counts.update(chunk_counts)

    # Write pairs to file
    logging.info("Writing binned pairs to {}.".format(out_file))
    mult = 10**precision
    with smart_open(out_file, "wb") as f:
        writer = csv.writer(f, delimiter=SEP)
        writer.writerow([name1, name2, "Count"])
        for pair in sorted(tc_pair_counts):
            writer.writerow([round(pair[0] / mult, precision),
                             round(pair[1] / mult, precision),
                             tc_pair_counts[pair]])

    total_counts = sum(tc_pair_counts.values())
    if total_counts != pair_num:
        logging.warning(
            "Pair counts {} did not match expected number {}".format(
                total_counts, pair_num))
        return
    logging.info("Completed.")
Example #5
def generate_fingerprints_parallel(sdf_files, threads):
    """Generate fingerprints for each sdf_file and construct a database.
    If threads=None, use all available processors, else specify an integral number
    of threads to use in parallel."""
    empty_fp = fprint.Fingerprint([])
    parallelizer = Parallelizer(parallel_mode="processes",
                                num_proc=threads,
                                fail_value=empty_fp)

    wrapped_files = [[f] for f in sdf_files]
    results = parallelizer.run(generate_e3fp_fingerprint, wrapped_files)

    # parallelizer.run returns (result, data) pairs; data is the wrapped
    # [sdf_file] argument list, so data[0] recovers the source filename
    results_dict = {}
    for r in results:
        results_dict[r[1][0]] = r[0]
    fingerprints = [results_dict[name] for name in sdf_files]

    database = db.FingerprintDatabase()
    database.add_fingerprints(fingerprints)
    return database
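The fail_value argument is what lets the loop above index results safely: per this usage, a worker that raises contributes fail_value in the result slot instead of crashing the run. A minimal sketch with a hypothetical worker:

from python_utilities.parallel import Parallelizer

def shaky(x):
    # hypothetical worker that fails on one input
    if x == 2:
        raise ValueError("boom")
    return x * x

if __name__ == "__main__":
    para = Parallelizer(parallel_mode="processes", num_proc=2,
                        fail_value=None)
    for result, data in para.run(shaky, [[0], [1], [2], [3]]):
        if result is None:   # fail_value marks the failed item
            print("failed on", data)
        else:
            print(data, "->", result)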
Example #6
def main(sdf_file,
         save_freq=SAVE_FREQ,
         overwrite=False,
         log=None,
         parallel_mode=None,
         num_proc=None):
    setup_logging(log)
    logging.info("Reading mols from SDF.")
    supp = rdkit.Chem.SDMolSupplier(sdf_file)
    num_mol = len(supp)
    del supp

    para = Parallelizer(parallel_mode=parallel_mode, num_proc=num_proc)
    start_end_indices = get_triangle_indices(num_mol, para.num_proc - 1)
    kwargs = {
        "sdf_file": sdf_file,
        "save_freq": save_freq,
        "overwrite": overwrite
    }
    para.run(run_batch, start_end_indices, kwargs=kwargs)
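get_triangle_indices presumably partitions the upper triangle of the all-vs-all comparison into near-equal batches, one per worker. A hypothetical sketch of that kind of split (not the library's actual algorithm):

import numpy as np

def triangle_chunks(n, num_chunks):
    # the upper triangle (i < j) holds n * (n - 1) // 2 pairs in total
    total = n * (n - 1) // 2
    bounds = np.linspace(0, total, num_chunks + 1, dtype=int)
    return list(zip(bounds[:-1], bounds[1:]))

print(triangle_chunks(10, 3))  # start/end offsets into the 45 pairs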
Example #7
def main(smiles_file, sdf_dir, out_file):
    _, fprint_params = params_to_dicts(load_params())
    smiles_dict = smiles_to_dict(smiles_file)

    para = Parallelizer()
    smiles_iter = ((smiles, get_sdf_file(name, sdf_dir), name)
                   for name, smiles in smiles_dict.items())
    kwargs = {"fprint_params": fprint_params}
    results_iter = para.run_gen(benchmark_fprinting,
                                smiles_iter,
                                kwargs=kwargs)

    with open(out_file, "w") as f:
        f.write("\t".join([
            "Name", "ECFP4 Time", "E3FP Time", "Num Heavy", "Num Confs",
            "Num Rot"
        ]) + "\n")
        for results, (_, _, name) in results_iter:
            print(results)
            f.write("{}\t{:.4g}\t{:.4g}\t{:d}\t{:d}\t{:d}\n".format(
                name, *results))
Example #8
def main(molecules_file,
         log=None,
         overwrite=False,
         parallel_mode=None,
         num_proc=None,
         merge_confs=False,
         save_freq=SAVE_FREQ,
         compress=False):
    setup_logging(log)
    para = Parallelizer(parallel_mode=parallel_mode, num_proc=num_proc)
    if para.is_master():
        data_iter = ((molecules_file, i, para.num_proc - 1)
                     for i in range(para.num_proc - 1))
    else:
        data_iter = iter([])

    kwargs = {
        "overwrite": overwrite,
        "merge_confs": merge_confs,
        "save_freq": save_freq,
        "compress": compress
    }
    para.run(run_batch, data_iter, kwargs=kwargs)
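Only the master builds the work list here; every other rank hands run an empty iterator and just executes jobs. A condensed sketch of that gating, with a hypothetical worker:

from python_utilities.parallel import Parallelizer

def work(batch_index, num_batches):
    # hypothetical batch worker
    return "batch {} of {}".format(batch_index, num_batches)

if __name__ == "__main__":
    para = Parallelizer(parallel_mode=None)  # let the library pick a mode
    if para.is_master():
        data_iter = ((i, para.num_proc - 1)
                     for i in range(para.num_proc - 1))
    else:
        data_iter = iter([])
    para.run(work, data_iter)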
Example #9
def main(molecules_file="molecules.csv.bz2",
         targets_file="targets.csv.bz2",
         k=5,
         method='sea',
         tc_files=None,
         auc_type='sum',
         process_inputs=None,
         split_by='target',
         reduce_negatives=False,
         min_mols=50,
         affinity=10000,
         out_dir="./",
         overwrite=False,
         log=None,
         num_proc=None,
         parallel_mode=None,
         verbose=False):
    setup_logging(log, verbose=verbose)
    if num_proc is None:
        num_proc = k + 1
    parallelizer = Parallelizer(parallel_mode=parallel_mode, num_proc=num_proc)

    cv_class = CV_METHODS[method]
    if cv_class is MaxTanimotoCVMethod and tc_files is not None:
        score_matrix = ScoreMatrix(*tc_files)
        cv_class = cv_class(score_matrix)
    splitter_class = SPLITTERS[split_by]
    if splitter_class is MoleculeSplitter:
        splitter = splitter_class(k=k)
    else:
        splitter = splitter_class(k=k, reduce_negatives=reduce_negatives)

    if process_inputs is not None:
        processor = InputProcessor(mode=process_inputs)
    else:
        processor = None

    kfold_cv = KFoldCrossValidator(k=k,
                                   parallelizer=parallelizer,
                                   splitter=splitter,
                                   input_processor=processor,
                                   cv_method=cv_class,
                                   return_auc_type=auc_type,
                                   out_dir=out_dir,
                                   overwrite=overwrite)
    auc = kfold_cv.run(molecules_file,
                       targets_file,
                       min_mols=min_mols,
                       affinity=affinity)
    logging.info("CV Mean AUC: {:.4f}".format(auc))
Example #10
def main(job_id, params, main_conf_dir=MAIN_CONF_DIR, main_dir=CV_DIR,
         out_dir=None, smiles_file=SMILES_FILE, check_existing=True,
         mol_targets_file=MOL_TARGETS_FILE, k=CV_K, log_file=LOG_FILE,
         verbose=False, overwrite=False, min_mols=MIN_MOLS_PER_TARGET,
         parallelizer=None):
    params = format_params(params)

    pre_encoding_params_string = params_to_str(params, with_first=False)
    params_string = params_to_str(params)
    if out_dir is None:
        out_dir = os.path.join(main_dir, params_string)
    touch_dir(out_dir)
    if log_file is not None:
        log_file = os.path.join(out_dir, log_file)
    setup_logging(log_file, verbose=verbose)

    params_file = os.path.join(out_dir, "params.cfg")
    config_parser = update_params(params, section_name="fingerprinting")
    write_params(config_parser, params_file)

    if not isinstance(parallelizer, Parallelizer):
        parallelizer = Parallelizer(parallel_mode="processes",
                                    num_proc=NUM_PROC)

    logging.info("Params: {!r}".format(params.items()))
    logging.info("Saving files to {:s}.".format(out_dir))

    logging.info("Checking for usable pre-existing fingerprints.")
    existing_molecules_file = get_existing_fprints(pre_encoding_params_string,
                                                   params['first'], main_dir)

    molecules_file = get_molecules_file(out_dir)
    if os.path.isfile(molecules_file) and not overwrite:
        logging.info("Molecules file already exists. Loading.")
        smiles_dict, mol_lists_dict, fp_type = molecules_to_lists_dicts(
            molecules_file)
    elif existing_molecules_file is None:
        conf_dir = os.path.join(main_conf_dir, params['conformers'])
        logging.info("Generating fingerprints from conformers in "
                     "{!s}.".format(conf_dir))
        smiles_dict, mol_lists_dict, fp_type = params_to_molecules(
            params, smiles_file, conf_dir, out_dir, parallelizer=parallelizer)
    else:
        logging.info("Using native strings from existing molecules "
                     "file {!s}.".format(existing_molecules_file))
        smiles_dict, mol_lists_dict, fp_type = molecules_to_lists_dicts(
            existing_molecules_file, first=params['first'])
        lists_dicts_to_molecules(get_molecules_file(out_dir),
                                 smiles_dict, mol_lists_dict, fp_type)

    targets_file = get_targets_file(out_dir)
    if overwrite or not os.path.isfile(targets_file):
        logging.info("Reading targets from {!s}.".format(mol_targets_file))
        targets_dict = targets_to_dict(mol_targets_file, affinity=AFFINITY)
        logging.debug("Read {:d} targets.".format(len(targets_dict)))
        logging.info("Filtering targets by molecules.")
        filtered_targets_dict = targets_to_mol_lists_targets(
            filter_targets_by_molecules(targets_dict, mol_lists_dict),
            mol_lists_dict)

        del targets_dict, smiles_dict, mol_lists_dict, fp_type
        logging.info("Saving filtered targets to {!s}.".format(targets_file))
        dict_to_targets(targets_file, filtered_targets_dict)
        del filtered_targets_dict
    else:
        logging.info("Targets file already exists. Skipping.")

    parallel_mode = parallelizer.parallel_mode
    parallelizer = Parallelizer(parallel_mode=parallel_mode, num_proc=k + 1)

    splitter = ByTargetMoleculeSplitter(k, reduce_negatives=REDUCE_NEGATIVES)
    kfold_cv = KFoldCrossValidator(k=k, parallelizer=parallelizer,
                                   splitter=splitter,
                                   return_auc_type=AUC_TYPE, out_dir=out_dir,
                                   overwrite=False)
    auc = kfold_cv.run(molecules_file, targets_file, min_mols=min_mols,
                       affinity=AFFINITY)
    logging.info("CV Mean AUC: {:.4f}".format(auc))
    return 1 - auc
Example #11
def run(sdf_files,
        bits=BITS,
        first=FIRST_DEF,
        level=LEVEL_DEF,
        radius_multiplier=RADIUS_MULTIPLIER_DEF,
        counts=COUNTS_DEF,
        stereo=STEREO_DEF,
        include_disconnected=INCLUDE_DISCONNECTED_DEF,
        rdkit_invariants=RDKIT_INVARIANTS_DEF,
        exclude_floating=EXCLUDE_FLOATING_DEF,
        params=None,
        out_dir_base=None,
        out_ext=OUT_EXT_DEF,
        db_file=None,
        overwrite=False,
        all_iters=False,
        log=None,
        num_proc=None,
        parallel_mode=None,
        verbose=False):
    """Generate E3FP fingerprints from SDF files."""
    setup_logging(log, verbose=verbose)

    if params is not None:
        params = read_params(params, fill_defaults=True)
        bits = get_value(params, "fingerprinting", "bits", int)
        first = get_value(params, "fingerprinting", "first", int)
        level = get_value(params, "fingerprinting", "level", int)
        radius_multiplier = get_value(params, "fingerprinting",
                                      "radius_multiplier", float)
        counts = get_value(params, "fingerprinting", "counts", bool)
        stereo = get_value(params, "fingerprinting", "stereo", bool)
        include_disconnected = get_value(params, "fingerprinting",
                                         "include_disconnected", bool)
        rdkit_invariants = get_value(params, "fingerprinting",
                                     "rdkit_invariants", bool)
        exclude_floating = get_value(params, "fingerprinting",
                                     "exclude_floating", bool)

    para = Parallelizer(num_proc=num_proc, parallel_mode=parallel_mode)

    if para.rank == 0:
        logging.info("Initializing E3FP generation.")
        logging.info("Getting SDF files")

        if len(sdf_files) == 1 and os.path.isdir(sdf_files[0]):
            from glob import glob
            sdf_files = glob("{:s}/*sdf*".format(sdf_files[0]))

        data_iterator = make_data_iterator(sdf_files)

        logging.info("SDF File Number: {:d}".format(len(sdf_files)))
        if out_dir_base is not None:
            logging.info("Out Directory Basename: {:s}".format(out_dir_base))
            logging.info("Out Extension: {:s}".format(out_ext))
        if db_file is not None:
            logging.info("Database File: {:s}".format(db_file))
        if db_file is None and out_dir_base is None:
            sys.exit('Either `db_file` or `out_dir_base` must be specified.')
        logging.info("Max First Conformers: {:d}".format(first))
        logging.info("Bits: {:d}".format(bits))
        logging.info("Level/Max Iterations: {:d}".format(level))
        logging.info(
            "Shell Radius Multiplier: {:.4g}".format(radius_multiplier))
        logging.info("Stereo Mode: {!s}".format(stereo))
        if not include_disconnected:
            logging.info("Connected-only mode: on")
        if rdkit_invariants:
            logging.info("Invariant type: RDKit")
        else:
            logging.info("Invariant type: Daylight")
        logging.info("Parallel Mode: {!s}".format(para.parallel_mode))
        logging.info("Starting")
    else:
        data_iterator = iter([])

    fp_kwargs = {
        "first": first,
        "bits": bits,
        "level": level,
        "radius_multiplier": radius_multiplier,
        "stereo": stereo,
        "counts": counts,
        "include_disconnected": include_disconnected,
        "rdkit_invariants": rdkit_invariants,
        "exclude_floating": exclude_floating,
        "out_dir_base": out_dir_base,
        "out_ext": out_ext,
        "all_iters": all_iters,
        "overwrite": overwrite,
        "save": False
    }
    if out_dir_base is not None:
        fp_kwargs['save'] = True

    run_kwargs = {"kwargs": fp_kwargs}

    results_iter = para.run_gen(fprints_dict_from_sdf, data_iterator,
                                **run_kwargs)

    if db_file is not None:
        fprints = []
        for result, data in results_iter:
            try:
                fprints.extend(result.get(level, result[max(result.keys())]))
            except (AttributeError, ValueError):
                # fprinting failed, assume logged in method
                continue
        if len(fprints) > 0:
            db = FingerprintDatabase(fp_type=type(fprints[0]), level=level)
            db.add_fingerprints(fprints)
            db.save(db_file)
            logging.info("Saved fingerprints to {:s}".format(db_file))
    else:
        list(results_iter)
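The result.get(level, result[max(result.keys())]) lookup falls back to the deepest iteration actually generated when the requested level is missing. A toy illustration:

fprints_dict = {0: ["fp0"], 2: ["fp2a", "fp2b"]}  # stand-in level -> fingerprints
level = 5
chosen = fprints_dict.get(level, fprints_dict[max(fprints_dict)])
print(chosen)  # ['fp2a', 'fp2b']: level 5 absent, so the deepest (2) is used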
Example #12
            pair[0], pair[1], tcs[0], tcs[1], precision))
    f.close()
    del example_pairs

    return counts_file, examples_file


if __name__ == "__main__":
    usage = "python compare_fingerprints.py <ecfp4_molecules> <e3fp_molecules>"
    try:
        ecfp_mol_file, e3fp_mol_file = sys.argv[1:]
    except ValueError:
        sys.exit(usage)

    setup_logging("log.txt")
    para = Parallelizer(parallel_mode="mpi")
    if para.rank == 0:
        logging.info("Reading molecules")
    ecfp_fp_sets = molecules_to_fp_sets(ecfp_mol_file)
    e3fp_fp_sets = molecules_to_fp_sets(e3fp_mol_file)

    mutual_mols = sorted(set(ecfp_fp_sets.keys()) & set(e3fp_fp_sets.keys()))
    mol_num = int(MOL_FRAC * len(mutual_mols))

    if para.rank == 0:
        logging.info(
            "Found total of {} mols. Selecting {} for comparison.".format(
                len(mutual_mols), mol_num))
    mols = sorted(np.random.choice(mutual_mols, size=mol_num, replace=False))
    pairs = ((i, j) for i in range(mol_num) for j in range(i + 1, mol_num))
Example #13
def run(
    mol2=None,
    smiles=None,
    standardise=STANDARDISE_DEF,
    num_conf=NUM_CONF_DEF,
    first=FIRST_DEF,
    pool_multiplier=POOL_MULTIPLIER_DEF,
    rmsd_cutoff=RMSD_CUTOFF_DEF,
    max_energy_diff=MAX_ENERGY_DIFF_DEF,
    forcefield=FORCEFIELD_DEF,
    seed=SEED_DEF,
    params=None,
    prioritize=False,
    out_dir=OUTDIR_DEF,
    compress=COMPRESS_DEF,
    overwrite=False,
    values_file=None,
    log=None,
    num_proc=None,
    parallel_mode=None,
    verbose=False,
):
    """Run conformer generation."""
    setup_logging(log, verbose=verbose)

    if params is not None:
        params = read_params(params)
        standardise = get_value(params, "preprocessing", "standardise", bool)
        num_conf = get_value(params, "conformer_generation", "num_conf", int)
        first = get_value(params, "conformer_generation", "first", int)
        pool_multiplier = get_value(params, "conformer_generation",
                                    "pool_multiplier", int)
        rmsd_cutoff = get_value(params, "conformer_generation", "rmsd_cutoff",
                                float)
        max_energy_diff = get_value(params, "conformer_generation",
                                    "max_energy_diff", float)
        forcefield = get_value(params, "conformer_generation", "forcefield")
        seed = get_value(params, "conformer_generation", "seed", int)

    # check args
    if forcefield not in FORCEFIELD_CHOICES:
        raise ValueError(
            "Specified forcefield {} is not in valid options {!r}".format(
                forcefield, FORCEFIELD_CHOICES))

    para = Parallelizer(num_proc=num_proc, parallel_mode=parallel_mode)

    # Check to make sure args make sense
    if mol2 is None and smiles is None:
        if para.is_master():
            parser.print_usage()
            logging.error("Please provide mol2 file or a SMILES file.")
        sys.exit()

    if mol2 is not None and smiles is not None:
        if para.is_master():
            parser.print_usage()
            logging.error("Please provide only a mol2 file OR a SMILES file.")
        sys.exit()

    if num_proc is not None and num_proc < 1:
        if para.is_master():
            parser.print_usage()
            logging.error(
                "Please provide at least one processor with `--num_proc`.")
        sys.exit()

    # Set up input type
    if mol2 is not None:
        in_type = "mol2"
    elif smiles is not None:
        in_type = "smiles"

    if para.is_master():
        if in_type == "mol2":
            logging.info("Input type: mol2 file(s)")
            logging.info("Input file number: {:d}".format(len(mol2)))
            mol_iter = (mol_from_mol2(_mol2_file,
                                      _name,
                                      standardise=standardise)
                        for _mol2_file, _name in mol2_generator(*mol2))
        else:
            logging.info("Input type: Detected SMILES file(s)")
            logging.info("Input file number: {:d}".format(len(smiles)))
            mol_iter = (mol_from_smiles(_smiles,
                                        _name,
                                        standardise=standardise)
                        for _smiles, _name in smiles_generator(*smiles))

        if prioritize:
            logging.info(("Prioritizing mols with low rotatable bond number"
                          " and molecular weight first."))
            mols_with_properties = [(
                AllChem.CalcNumRotatableBonds(mol),
                AllChem.CalcExactMolWt(mol),
                mol,
            ) for mol in mol_iter if mol is not None]
            # sort on the numeric keys only; Mol objects are not orderable
            data_iterator = make_data_iterator(
                (x[-1] for x in sorted(mols_with_properties,
                                       key=lambda x: x[:2])))
        else:
            data_iterator = make_data_iterator(
                (x for x in mol_iter if x is not None))

        # Set up parallel-specific options
        logging.info("Parallel Type: {}".format(para.parallel_mode))

        # Set other options
        touch_dir(out_dir)

        if not num_conf:
            num_conf = -1

        logging.info("Out Directory: {}".format(out_dir))
        logging.info("Overwrite Existing Files: {}".format(overwrite))
        if values_file is not None:
            if os.path.exists(values_file) and not overwrite:
                value_args = (values_file, "a")
                logging.info("Values file: {} (append)".format(values_file))
            else:
                value_args = (values_file, "w")
                logging.info(
                    "Values file: {} (new file)".format(values_file))
        if num_conf is None or num_conf == -1:
            logging.info("Target Conformer Number: auto")
        else:
            logging.info("Target Conformer Number: {:d}".format(num_conf))
        if first is None or first == -1:
            logging.info("First Conformers Number: all")
        else:
            logging.info("First Conformers Number: {:d}".format(first))
        logging.info("Pool Multiplier: {:d}".format(pool_multiplier))
        logging.info("RMSD Cutoff: {:.4g}".format(rmsd_cutoff))
        if max_energy_diff is None:
            logging.info("Maximum Energy Difference: None")
        else:
            logging.info("Maximum Energy Difference: {:.4g} kcal".format(
                max_energy_diff))
        logging.info("Forcefield: {}".format(forcefield.upper()))
        if seed != -1:
            logging.info("Seed: {:d}".format(seed))

        logging.info("Starting.")
    else:
        data_iterator = iter([])

    gen_conf_kwargs = {
        "out_dir": out_dir,
        "num_conf": num_conf,
        "rmsd_cutoff": rmsd_cutoff,
        "max_energy_diff": max_energy_diff,
        "forcefield": forcefield,
        "pool_multiplier": pool_multiplier,
        "first": first,
        "seed": seed,
        "save": True,
        "overwrite": overwrite,
        "compress": compress,
    }

    run_kwargs = {"kwargs": gen_conf_kwargs}

    results_iterator = para.run_gen(generate_conformers, data_iterator,
                                    **run_kwargs)

    if para.is_master() and values_file is not None:
        hdf5_buffer = HDF5Buffer(*value_args)

    for result, data in results_iterator:
        if (para.is_master() and values_file is not None
                and result is not False):
            values_to_hdf5(hdf5_buffer, result)

    if para.is_master() and values_file is not None:
        hdf5_buffer.flush()
        hdf5_buffer.close()
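A self-contained sketch of the prioritization step above, assuming RDKit; sorting on the numeric keys plus an index avoids ever comparing Mol objects on ties:

from rdkit import Chem
from rdkit.Chem import AllChem

smiles = ("CCCCCCCC", "CCO", "c1ccccc1")
mols = [Chem.MolFromSmiles(s) for s in smiles]
keyed = sorted((AllChem.CalcNumRotatableBonds(m), AllChem.CalcExactMolWt(m), i)
               for i, m in enumerate(mols) if m is not None)
ordered = [mols[i] for _, _, i in keyed]  # fewest rotatable bonds first
print([Chem.MolToSmiles(m) for m in ordered])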
Example #14
class KFoldCrossValidator(object):
    """Class to perform k-fold cross-validation."""
    def __init__(self,
                 k=5,
                 splitter=MoleculeSplitter,
                 cv_method=SEASearchCVMethod(),
                 input_processor=None,
                 parallelizer=None,
                 out_dir=os.getcwd(),
                 overwrite=False,
                 return_auc_type="roc",
                 reduce_negatives=False,
                 fold_kwargs={}):
        if isinstance(splitter, type):
            self.splitter = splitter(k)
        else:
            assert splitter.k == k
            self.splitter = splitter
        self.k = k
        if (isinstance(cv_method, SEASearchCVMethod)
                and input_processor is not None):
            raise ValueError(
                "Input processing is not (currently) compatible with SEA.")
        self.cv_method = cv_method
        self.input_processor = input_processor
        self.overwrite = overwrite
        if parallelizer is None:
            self.parallelizer = Parallelizer(parallel_mode="serial")
        else:
            self.parallelizer = parallelizer
        self.out_dir = out_dir
        touch_dir(out_dir)
        self.input_file = os.path.join(self.out_dir, "inputs.pkl.bz2")
        self.return_auc_type = return_auc_type.lower()
        self.reduce_negatives = reduce_negatives
        self.fold_kwargs = fold_kwargs

    def run(self,
            molecules_file,
            targets_file,
            min_mols=50,
            affinity=None,
            overwrite=False):
        fold_validators = {
            fold_num: FoldValidator(fold_num,
                                    self._get_fold_dir(fold_num),
                                    cv_method=copy.deepcopy(self.cv_method),
                                    input_file=self.input_file,
                                    overwrite=self.overwrite,
                                    **self.fold_kwargs)
            for fold_num in range(self.k)
        }
        if not os.path.isfile(self.input_file) or not all(
            [x.fold_files_exist() for x in fold_validators.values()]):
            logging.info("Loading and filtering input files.")
            ((smiles_dict, mol_list_dict, fp_type),
             target_dict) = self.load_input_files(molecules_file,
                                                  targets_file,
                                                  min_mols=min_mols,
                                                  affinity=affinity)

            mol_list = sorted(mol_list_dict.keys())
            if isinstance(self.cv_method, SEASearchCVMethod):
                # efficiency hack
                fp_array, mol_to_fp_inds = (None, None)
            else:
                logging.info("Converting inputs to arrays.")
                fp_array, mol_to_fp_inds = molecules_to_array(
                    mol_list_dict, mol_list, processor=self.input_processor)
            target_mol_array, target_list = targets_to_array(
                target_dict, mol_list, dtype=TARGET_MOL_DENSE_DTYPE)
            total_imbalance = get_imbalance(target_mol_array)

            if self.overwrite or not os.path.isfile(self.input_file):
                logging.info("Saving arrays and labels to files.")
                save_cv_inputs(self.input_file, fp_array, mol_to_fp_inds,
                               target_mol_array, target_list, mol_list)
            del fp_array, mol_to_fp_inds

            logging.info("Splitting data into {} folds using {}.".format(
                self.k,
                type(self.splitter).__name__))
            if self.splitter.reduce_negatives:
                logging.info("After splitting, negatives will be reduced.")
            train_test_masks = self.splitter.get_train_test_masks(
                target_mol_array)
            for fold_num, train_test_mask in enumerate(train_test_masks):
                logging.info(
                    "Saving inputs to files (fold {})".format(fold_num))
                fold_val = fold_validators[fold_num]
                fold_val.save_fold_files(train_test_mask, mol_list,
                                         target_list, smiles_dict,
                                         mol_list_dict, fp_type, target_dict)
            del (smiles_dict, mol_list_dict, fp_type, target_mol_array,
                 mol_list, train_test_masks, target_dict, target_list)
        else:
            logging.info("Resuming from input and fold files.")
            (fp_array, mol_to_fp_inds, target_mol_array, target_list,
             mol_list) = load_cv_inputs(self.input_file)
            total_imbalance = get_imbalance(target_mol_array)
            del (fp_array, mol_to_fp_inds, target_mol_array, target_list,
                 mol_list)

        # run cross-validation and gather scores
        logging.info("Running fold validation.")
        para_args = sorted(fold_validators.items())
        aucs = list(zip(*self.parallelizer.run(_run_fold, para_args)))[0]
        if next(iter(fold_validators.values())).compute_combined:
            aurocs, auprcs = zip(*aucs)
            mean_auroc = np.mean(aurocs)
            std_auroc = np.std(aurocs)
            logging.info("CV Mean AUROC: {:.4f} +/- {:.4f}".format(
                mean_auroc, std_auroc))
            mean_auprc = np.mean(auprcs)
            std_auprc = np.std(auprcs)
            logging.info(("CV Mean AUPRC: {:.4f} +/- {:.4f} ({:.4f} of data "
                          "is positive)").format(mean_auprc, std_auprc,
                                                 total_imbalance))
        else:
            (mean_auroc, mean_auprc) = (None, None)

        target_aucs = []
        for fold_val in fold_validators.values():
            with smart_open(fold_val.target_aucs_file, "rb") as f:
                target_aucs.extend(pkl.load(f).values())
        target_aucs = np.array(target_aucs)
        mean_target_auroc, mean_target_auprc = np.mean(target_aucs, axis=0)
        std_target_auroc, std_target_auprc = np.std(target_aucs, axis=0)
        logging.info("CV Mean Target AUROC: {:.4f} +/- {:.4f}".format(
            mean_target_auroc, std_target_auroc))
        logging.info("CV Mean Target AUPRC: {:.4f} +/- {:.4f}".format(
            mean_target_auprc, std_target_auprc))

        if "target" in self.return_auc_type or mean_auroc is None:
            logging.info("Returning target AUC.")
            aucs = (mean_target_auroc, mean_target_auprc)
        else:
            logging.info("Returning combined average AUC.")
            aucs = (mean_auroc, mean_auprc)

        if "pr" in self.return_auc_type:
            logging.info("Returned AUC is AUPRC.")
            return aucs[1]
        elif "sum" in self.return_auc_type:
            return sum(aucs)
        else:
            logging.info("Returned AUC is AUROC.")
            return aucs[0]

    def load_input_files(self,
                         molecules_file,
                         targets_file,
                         min_mols=50,
                         affinity=None,
                         overwrite=False):
        smiles_dict, mol_lists_dict, fp_type = molecules_to_lists_dicts(
            molecules_file)

        mol_names_target_dict = mol_lists_targets_to_targets(
            targets_to_dict(targets_file, affinity=affinity))
        target_dict = filter_targets_by_molecules(mol_names_target_dict,
                                                  mol_lists_dict)
        del mol_names_target_dict
        target_dict = filter_targets_by_molnum(target_dict, n=min_mols)

        return (smiles_dict, mol_lists_dict, fp_type), target_dict

    def _get_fold_dir(self, fold_num):
        return os.path.join(self.out_dir, str(fold_num))
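A usage sketch assembled from the calls shown in Examples 9 and 10; the file names are placeholders and the Parallelizer import path is assumed:

from python_utilities.parallel import Parallelizer

parallelizer = Parallelizer(parallel_mode="processes", num_proc=6)
kfold_cv = KFoldCrossValidator(k=5, parallelizer=parallelizer,
                               return_auc_type="roc", out_dir="./cv")
auc = kfold_cv.run("molecules.csv.bz2", "targets.csv.bz2",
                   min_mols=50, affinity=10000)
print("CV mean AUROC: {:.4f}".format(auc))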