Example #1
# Imports this snippet relies on; fitKmerCurve is defined elsewhere in the
# same PopPUNK module, and plot_fit comes from PopPUNK.plot
import os
import sys
from random import sample

import numpy as np
import pp_sketchlib

from .plot import plot_fit
def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, number_plot_fits = 0,
                  threads = 1, use_gpu = False, deviceid = 0):
    """Calculate core and accessory distances between query sequences and a sketched database

    For a reference database, runs the query against itself to find all pairwise
    core and accessory distances.

    Uses the relation :math:`pr(a, c; k) = (1-a)(1-c)^k`

    To get the ref and query name for each row of the returned distances, call the iterator
    :func:`~PopPUNK.utils.iterDistRows` with rNames and qNames

    Args:
        rNames (list)
            Names of references to query
        qNames (list)
            Names of queries
        dbPrefix (str)
            Prefix for reference mash sketch database created by :func:`~constructDatabase`
        queryPrefix (str)
            Prefix for query mash sketch database created by :func:`~constructDatabase`
        klist (numpy.ndarray)
            K-mer sizes to use in the calculation
        self (bool)
            Set true if query = ref
            (default = True)
        number_plot_fits (int)
            If > 0, the number of k-mer length fits to plot (saved as pdfs).
            Takes random pairs of comparisons and calls :func:`~PopPUNK.plot.plot_fit`
            (default = 0)
        threads (int)
            Number of threads to use in the mash process
            (default = 1)
        use_gpu (bool)
            Use a GPU for querying
            (default = False)
        deviceid (int)
            Index of the CUDA GPU device to use
            (default = 0)

    Returns:
         distMat (numpy.array)
            Core distances (column 0) and accessory distances (column 1) between
            rNames and qNames
    """
    ref_db = dbPrefix + "/" + os.path.basename(dbPrefix)

    if self:
        if dbPrefix != queryPrefix:
            raise RuntimeError("Must use same db for self query")
        qNames = rNames

        # Calls to library
        distMat = pp_sketchlib.queryDatabase(ref_db, ref_db, rNames, rNames, klist,
                                             True, False, threads, use_gpu, deviceid)

        # option to plot core/accessory fits. Choose a random number from cmd line option
        if number_plot_fits > 0:
            jacobian = -np.hstack((np.ones((klist.shape[0], 1)), klist.reshape(-1, 1)))
            for plot_idx in range(number_plot_fits):
                example = sample(rNames, k=2)
                raw = np.zeros(len(klist))
                corrected = np.zeros(len(klist))
                for kidx, kmer in enumerate(klist):
                    raw[kidx] = pp_sketchlib.jaccardDist(ref_db, example[0], example[1], kmer, False)
                    corrected[kidx] = pp_sketchlib.jaccardDist(ref_db, example[0], example[1], kmer, True)
                raw_fit = fitKmerCurve(raw, klist, jacobian)
                corrected_fit = fitKmerCurve(corrected, klist, jacobian)
                plot_fit(klist,
                         raw,
                         raw_fit,
                         corrected,
                         corrected_fit,
                         dbPrefix + "/" + dbPrefix + "_fit_example_" + str(plot_idx + 1),
                         "Example fit " + str(plot_idx + 1) + " - " +  example[0] + " vs. " + example[1])
    else:
        duplicated = set(rNames).intersection(set(qNames))
        if len(duplicated) > 0:
            sys.stderr.write("Sample names in query are contained in reference database:\n")
            sys.stderr.write("\n".join(duplicated))
            sys.stderr.write("Unique names are required!\n")
            sys.exit(1)

        # Calls to library
        query_db = queryPrefix + "/" + os.path.basename(queryPrefix)
        distMat = pp_sketchlib.queryDatabase(ref_db, query_db, rNames, qNames, klist,
                                             True, False, threads, use_gpu, deviceid)

    return distMat
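
# The jacobian above comes from the log-linear form of the docstring
# relation: log pr = log(1 - a) + k * log(1 - c), which is linear in k.
# A minimal sketch of that fit on synthetic data with plain numpy
# (np.linalg.lstsq stands in for PopPUNK's fitKmerCurve; all values below
# are illustrative assumptions):
demo_klist = np.array([13, 17, 21, 25, 29])
true_a, true_c = 0.05, 0.01
match_prob = (1 - true_a) * (1 - true_c) ** demo_klist

# Design matrix columns [1, k]; coefficients [log(1-a), log(1-c)]
design = np.hstack((np.ones((demo_klist.shape[0], 1)),
                    demo_klist.reshape(-1, 1)))
coef, *_ = np.linalg.lstsq(design, np.log(match_prob), rcond=None)
a_est, c_est = 1 - np.exp(coef[0]), 1 - np.exp(coef[1])
print(a_est, c_est)  # recovers ~0.05 (accessory) and ~0.01 (core)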
Example #2
# Module-level imports this snippet relies on; cugraph is additionally
# required (imported elsewhere in PopPUNK) when gpu_graph is set
import os
import sys

import numpy as np
import pp_sketchlib
def generate_visualisations(query_db, ref_db, distances, threads, output,
                            gpu_dist, deviceid, external_clustering,
                            microreact, phandango, grapetree, cytoscape,
                            perplexity, strand_preserved, include_files,
                            model_dir, previous_clustering,
                            previous_query_clustering, network_file, gpu_graph,
                            info_csv, rapidnj, tree, mst_distances, overwrite,
                            core_only, accessory_only, display_cluster, web):

    from .models import loadClusterFit

    from .network import construct_network_from_assignments
    from .network import fetchNetwork
    from .network import generate_minimum_spanning_tree
    from .network import load_network_file
    from .network import cugraph_to_graph_tool

    from .plot import drawMST
    from .plot import outputsForMicroreact
    from .plot import outputsForCytoscape
    from .plot import outputsForPhandango
    from .plot import outputsForGrapetree
    from .plot import writeClusterCsv

    from .prune_db import prune_distance_matrix

    from .sketchlib import readDBParams
    from .sketchlib import getKmersFromReferenceDatabase
    from .sketchlib import addRandom

    from .trees import load_tree, generate_nj_tree, mst_to_phylogeny

    from .utils import isolateNameToLabel
    from .utils import readPickle
    from .utils import setGtThreads
    from .utils import update_distance_matrices
    from .utils import readIsolateTypeFromCsv
    from .utils import joinClusterDicts
    from .utils import listDistInts

    # Check on parallelisation of graph-tools
    setGtThreads(threads)

    sys.stderr.write("PopPUNK: visualise\n")
    if not (microreact or phandango or grapetree or cytoscape):
        sys.stderr.write(
            "Must specify at least one type of visualisation to output\n")
        sys.exit(1)

    # make directory for new output files
    if not os.path.isdir(output):
        try:
            os.makedirs(output)
        except OSError:
            sys.stderr.write("Cannot create output directory\n")
            sys.exit(1)

    if distances is None:
        if query_db is None:
            distances = ref_db + "/" + os.path.basename(ref_db) + ".dists"
        else:
            distances = query_db + "/" + os.path.basename(query_db) + ".dists"

    rlist, qlist, self, complete_distMat = readPickle(distances)
    if not self:
        qr_distMat = complete_distMat
    else:
        rr_distMat = complete_distMat

    # Fill in qq-distances if required
    if not self:
        sys.stderr.write(
            "Note: Distances in " + distances + " are from assign mode\n"
            "Note: Distances will be extended to full all-vs-all distances\n"
            "Note: Re-run poppunk_assign with --update-db to avoid this\n")
        ref_db_loc = ref_db + "/" + os.path.basename(ref_db)
        rlist_original, qlist_original, self_ref, rr_distMat = readPickle(
            ref_db_loc + ".dists")
        if not self_ref:
            sys.stderr.write("Distances in " + ref_db +
                             " not self all-vs-all either\n")
            sys.exit(1)
        kmers, sketch_sizes, codon_phased = readDBParams(query_db)
        addRandom(query_db,
                  qlist,
                  kmers,
                  strand_preserved=strand_preserved,
                  threads=threads)
        query_db_loc = query_db + "/" + os.path.basename(query_db)
        qq_distMat = pp_sketchlib.queryDatabase(query_db_loc, query_db_loc,
                                                qlist, qlist, kmers, True,
                                                False, threads, gpu_dist,
                                                deviceid)

        # If the assignment was run with references, qrDistMat will be incomplete
        if rlist != rlist_original:
            rlist = rlist_original
            qr_distMat = pp_sketchlib.queryDatabase(ref_db_loc, query_db_loc,
                                                    rlist, qlist, kmers, True,
                                                    False, threads, gpu_dist,
                                                    deviceid)

    else:
        qlist = None
        qr_distMat = None
        qq_distMat = None

    # Turn long form matrices into square form
    combined_seq, core_distMat, acc_distMat = \
            update_distance_matrices(rlist, rr_distMat,
                                     qlist, qr_distMat, qq_distMat,
                                     threads = threads)

    # extract subset of distances if requested
    if include_files is not None:
        viz_subset = set()
        with open(include_files, 'r') as assemblyFiles:
            for assembly in assemblyFiles:
                viz_subset.add(assembly.rstrip())
        if len(viz_subset.difference(combined_seq)) > 0:
            sys.stderr.write(
                "--include-files contains names not in --distances\n")

        # Only keep found rows
        row_slice = [name in viz_subset for name in combined_seq]
        combined_seq = [name for name in combined_seq if name in viz_subset]
        if qlist is not None:
            qlist = list(viz_subset.intersection(qlist))
        core_distMat = core_distMat[np.ix_(row_slice, row_slice)]
        acc_distMat = acc_distMat[np.ix_(row_slice, row_slice)]
    else:
        viz_subset = None

    # Either use strain definitions, lineage assignments or external clustering
    isolateClustering = {}
    # Use external clustering if specified
    if external_clustering:
        cluster_file = external_clustering
        isolateClustering = readIsolateTypeFromCsv(cluster_file,
                                                   mode='external',
                                                   return_dict=True)

    # identify existing model and cluster files
    if model_dir is not None:
        model_prefix = model_dir
    else:
        model_prefix = ref_db
    try:
        model_file = model_prefix + "/" + os.path.basename(model_prefix)
        model = loadClusterFit(model_file + '_fit.pkl',
                               model_file + '_fit.npz')
        model.set_threads(threads)
    except FileNotFoundError:
        sys.stderr.write('Unable to locate previous model fit in ' +
                         model_prefix + '\n')
        sys.exit(1)

    # Load previous clusters
    if previous_clustering is not None:
        prev_clustering = previous_clustering
        mode = "clusters"
        suffix = "_clusters.csv"
        if prev_clustering.endswith('_lineages.csv'):
            mode = "lineages"
            suffix = "_lineages.csv"
    else:
        # Identify type of clustering based on model
        mode = "clusters"
        suffix = "_clusters.csv"
        if model.type == "lineage":
            mode = "lineages"
            suffix = "_lineages.csv"
        if model.indiv_fitted:
            sys.stderr.write(
                "Note: Individual (core/accessory) fits found, but "
                "visualisation only supports combined boundary fit\n")
        prev_clustering = os.path.dirname(
            model_file) + '/' + os.path.basename(model_file) + suffix
    isolateClustering = readIsolateTypeFromCsv(prev_clustering,
                                               mode=mode,
                                               return_dict=True)

    # Join clusters with query clusters if required
    if not self:
        if previous_query_clustering is not None:
            prev_query_clustering = previous_query_clustering
        else:
            prev_query_clustering = query_db + '/' + \
                os.path.basename(query_db) + suffix

        queryIsolateClustering = readIsolateTypeFromCsv(prev_query_clustering,
                                                        mode=mode,
                                                        return_dict=True)
        isolateClustering = joinClusterDicts(isolateClustering,
                                             queryIsolateClustering)

    # Generate MST
    mst_tree = None
    mst_graph = None
    nj_tree = None
    if len(combined_seq) >= 3:
        # MST tree
        if tree == 'mst' or tree == 'both':
            existing_tree = None
            if not overwrite:
                existing_tree = load_tree(output,
                                          "MST",
                                          distances=mst_distances)
            if existing_tree is None:
                # Check the selected clustering type is in the CSV
                clustering_name = 'Cluster'
                if display_cluster is not None:
                    if display_cluster not in isolateClustering.keys():
                        clustering_name = list(isolateClustering.keys())[0]
                        sys.stderr.write('Unable to find clustering column ' +
                                         display_cluster + ' in file ' +
                                         prev_clustering + '; instead using ' +
                                         clustering_name + '\n')
                    else:
                        clustering_name = display_cluster
                else:
                    clustering_name = list(isolateClustering.keys())[0]
                # Get distance matrix
                complete_distMat = \
                    np.hstack((pp_sketchlib.squareToLong(core_distMat, threads).reshape(-1, 1),
                               pp_sketchlib.squareToLong(acc_distMat, threads).reshape(-1, 1)))
                # Dense network may be slow
                sys.stderr.write(
                    "Generating MST from dense distances (may be slow)\n")
                G = construct_network_from_assignments(
                    combined_seq,
                    combined_seq, [0] * complete_distMat.shape[0],
                    within_label=0,
                    distMat=complete_distMat,
                    weights_type=mst_distances,
                    use_gpu=gpu_graph,
                    summarise=False)
                if gpu_graph:
                    G = cugraph.minimum_spanning_tree(G, weight='weights')
                mst_graph = generate_minimum_spanning_tree(G, gpu_graph)
                del G
                mst_tree = mst_to_phylogeny(
                    mst_graph,
                    isolateNameToLabel(combined_seq),
                    use_gpu=gpu_graph)
                if gpu_graph:
                    mst_graph = cugraph_to_graph_tool(
                        mst_graph, isolateNameToLabel(combined_seq))
                else:
                    vid = mst_graph.new_vertex_property(
                        'string', vals=isolateNameToLabel(combined_seq))
                    mst_graph.vp.id = vid
                drawMST(mst_graph, output, isolateClustering, clustering_name,
                        overwrite)
            else:
                mst_tree = existing_tree

        # Generate NJ tree
        if tree == 'nj' or tree == 'both':
            existing_tree = None
            if not overwrite:
                existing_tree = load_tree(output, "NJ")
            if existing_tree is None:
                nj_tree = generate_nj_tree(core_distMat,
                                           combined_seq,
                                           output,
                                           rapidnj,
                                           threads=threads)
            else:
                nj_tree = existing_tree
    else:
        sys.stderr.write("Fewer than three sequences, not drawing trees\n")

    # Now have all the objects needed to generate selected visualisations
    if microreact:
        sys.stderr.write("Writing microreact output\n")
        outputsForMicroreact(combined_seq,
                             isolateClustering,
                             nj_tree,
                             mst_tree,
                             acc_distMat,
                             perplexity,
                             output,
                             info_csv,
                             queryList=qlist,
                             overwrite=overwrite,
                             use_gpu=gpu_graph)

    if phandango:
        sys.stderr.write("Writing phandango output\n")
        outputsForPhandango(combined_seq,
                            isolateClustering,
                            nj_tree,
                            mst_tree,
                            output,
                            info_csv,
                            queryList=qlist,
                            overwrite=overwrite)

    if grapetree:
        sys.stderr.write("Writing grapetree output\n")
        outputsForGrapetree(combined_seq,
                            isolateClustering,
                            nj_tree,
                            mst_tree,
                            output,
                            info_csv,
                            queryList=qlist,
                            overwrite=overwrite)

    if cytoscape:
        sys.stderr.write("Writing cytoscape output\n")
        if network_file is None:
            sys.stderr.write(
                'Cytoscape output requires a network file to be provided\n')
            sys.exit(1)
        genomeNetwork = load_network_file(network_file, use_gpu=gpu_graph)
        if gpu_graph:
            genomeNetwork = cugraph_to_graph_tool(
                genomeNetwork, isolateNameToLabel(combined_seq))
        outputsForCytoscape(genomeNetwork,
                            mst_graph,
                            combined_seq,
                            isolateClustering,
                            output,
                            info_csv,
                            viz_subset=viz_subset)
        if model.type == 'lineage':
            sys.stderr.write(
                "Note: Only support for output of cytoscape graph at lowest rank\n"
            )

    sys.stderr.write("\nDone\n")
Example #3
# Imports this snippet relies on; get_options, getSampleNames, storePickle,
# iterDistRows and ijv_to_coo are helpers defined elsewhere in this package
import os
import sys

import h5py
import numpy as np
import pp_sketchlib
def main():
    args = get_options()

    if args.min_k >= args.max_k or args.min_k < 3 or args.max_k > 101 or args.k_step < 1:
        sys.stderr.write(
            "Minimum k-mer size " + str(args.min_k) +
            " must be smaller than maximum k-mer size " + str(args.max_k) +
            "; the range must lie between 3 and 101 and the step be at least one\n")
        sys.exit(1)
    kmers = np.arange(args.min_k, args.max_k + 1, args.k_step)

    #
    # Create a database (sketch input)
    #
    if args.sketch:
        names = []
        sequences = []

        with open(args.rfile, 'r') as refFile:
            for refLine in refFile:
                refFields = refLine.rstrip().split("\t")
                names.append(refFields[0])
                sequences.append(list(refFields[1:]))

        if len(set(names)) != len(names):
            sys.stderr.write(
                "Input contains duplicate names! All names must be unique\n")
            sys.exit(1)

        pp_sketchlib.constructDatabase(args.ref_db, names, sequences, kmers,
                                       int(round(args.sketch_size / 64)),
                                       args.codon_phased, not args.no_random,
                                       args.strand, args.min_count,
                                       args.exact_counter, args.cpus,
                                       args.use_gpu, args.gpu_id)

    #
    # Join two databases
    #
    elif args.join:
        join_name = args.output + ".h5"
        db1_name = args.ref_db + ".h5"
        db2_name = args.query_db + ".h5"

        hdf1 = h5py.File(db1_name, 'r')
        hdf2 = h5py.File(db2_name, 'r')

        try:
            v1 = hdf1['sketches'].attrs['sketch_version']
            v2 = hdf2['sketches'].attrs['sketch_version']
            if (v1 != v2):
                sys.stderr.write(
                    "Databases have been written with different sketch versions, "
                    "joining not recommended (but proceeding anyway)\n")
            p1 = hdf1['sketches'].attrs['codon_phased']
            p2 = hdf2['sketches'].attrs['codon_phased']
            if (p1 != p2):
                sys.stderr.write(
                    "One database uses codon-phased seeds - cannot join "
                    "with a standard seed database\n")
                sys.exit(1)
        except RuntimeError:
            sys.stderr.write("Unable to check sketch version\n")

        hdf_join = h5py.File(join_name + ".tmp",
                             'w')  # add .tmp in case join_name exists

        # Can only copy into new group, so for second file these are appended one at a time
        try:
            hdf1.copy('sketches', hdf_join)
            join_grp = hdf_join['sketches']
            read_grp = hdf2['sketches']
            for dataset in read_grp:
                join_grp.copy(read_grp[dataset], dataset)

            if 'random' in hdf1 or 'random' in hdf2:
                sys.stderr.write(
                    "Random matches found in one database, which will not be copied\n"
                    "Use --add-random to recalculate for the joined DB\n")
        except RuntimeError as e:
            sys.stderr.write("ERROR: " + str(e) + "\n")
            sys.stderr.write("Joining sketches failed\n")
            sys.exit(1)

        # Clean up
        hdf1.close()
        hdf2.close()
        hdf_join.close()
        os.rename(join_name + ".tmp", join_name)

    #
    # Query a database (calculate distances)
    #
    elif args.query:
        rList = getSampleNames(args.ref_db)
        qList = getSampleNames(args.query_db)

        if args.subset is not None:
            subset = []
            with open(args.subset, 'r') as subset_file:
                for line in subset_file:
                    sample_name = line.rstrip().split("\t")[0]
                    subset.append(sample_name)
            rList = list(set(rList).intersection(subset))
            qList = list(set(qList).intersection(subset))
            if (len(rList) == 0 or len(qList) == 0):
                sys.stderr.write("Subset has removed all samples\n")
                sys.exit(1)

        # Check inputs overlap
        ref = h5py.File(args.ref_db + ".h5", 'r')
        query = h5py.File(args.query_db + ".h5", 'r')
        db_kmers = set(ref['sketches/' +
                           rList[0]].attrs['kmers']).intersection(
                               query['sketches/' + qList[0]].attrs['kmers'])
        if args.read_k:
            query_kmers = sorted(db_kmers)
        else:
            query_kmers = sorted(set(kmers).intersection(db_kmers))
            if (len(query_kmers) == 0):
                sys.stderr.write("No requested k-mer lengths found in DB\n")
                sys.exit(1)
            elif (len(query_kmers) < len(kmers)):
                sys.stderr.write(
                    "Some requested k-mer lengths not found in DB\n")
        ref.close()
        query.close()

        if args.sparse:
            sparseIdx = pp_sketchlib.queryDatabaseSparse(
                args.ref_db, args.query_db, rList, qList, query_kmers,
                not args.no_correction, args.threshold, args.kNN,
                not args.accessory, args.cpus, args.use_gpu, args.gpu_id)
            if args.print:
                if args.accessory:
                    distName = 'Accessory'
                else:
                    distName = 'Core'
                sys.stdout.write("\t".join(['Query', 'Reference', distName]) +
                                 "\n")

                (i_vec, j_vec, dist_vec) = sparseIdx
                for (i, j, dist) in zip(i_vec, j_vec, dist_vec):
                    sys.stdout.write("\t".join([rList[i], qList[j],
                                                str(dist)]) + "\n")

            else:
                coo_matrix = ijv_to_coo(sparseIdx, (len(rList), len(qList)),
                                        np.float32)
                storePickle(rList, qList, rList == qList, coo_matrix,
                            args.output)

        else:
            distMat = pp_sketchlib.queryDatabase(args.ref_db, args.query_db,
                                                 rList, qList, query_kmers,
                                                 not args.no_correction,
                                                 args.jaccard, args.cpus,
                                                 args.use_gpu, args.gpu_id)

            # get names order
            if args.print:
                names = iterDistRows(rList, qList, rList == qList)
                if not args.jaccard:
                    sys.stdout.write("\t".join(
                        ['Query', 'Reference', 'Core', 'Accessory']) + "\n")
                    for i, (ref, query) in enumerate(names):
                        sys.stdout.write("\t".join([
                            query, ref,
                            str(distMat[i, 0]),
                            str(distMat[i, 1])
                        ]) + "\n")
                else:
                    sys.stdout.write("\t".join(['Query', 'Reference'] +
                                               [str(i)
                                                for i in query_kmers]) + "\n")
                    for i, (ref, query) in enumerate(names):
                        sys.stdout.write("\t".join(
                            [query, ref] + [str(k)
                                            for k in distMat[i, ]]) + "\n")
            else:
                storePickle(rList, qList, rList == qList, distMat, args.output)

    #
    # Add random match chances to an older database
    #
    elif args.add_random:
        rList = getSampleNames(args.ref_db)
        ref = h5py.File(args.ref_db + ".h5", 'r')
        db_kmers = ref['sketches/' + rList[0]].attrs['kmers']
        ref.close()

        pp_sketchlib.addRandom(args.ref_db, rList, db_kmers, args.strand,
                               args.cpus)

    sys.exit(0)
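
# In the sparse branch above, queryDatabaseSparse returns parallel
# (i, j, distance) vectors holding each sample's nearest neighbours;
# ijv_to_coo packs them into a scipy sparse matrix, roughly as sketched
# here (illustrative values; the real helper lives in this package):
from scipy.sparse import coo_matrix

i_vec, j_vec = np.array([0, 0, 1]), np.array([1, 2, 2])
dist_vec = np.array([0.01, 0.02, 0.005], dtype=np.float32)
sparse_dists = coo_matrix((dist_vec, (i_vec, j_vec)), shape=(3, 3))
print(sparse_dists.toarray())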
Example #4
# Imports this snippet relies on
import h5py
import numpy as np
import pp_sketchlib


# From PopPUNK.utils: pairs the (ref, query) names for each distance row
def iterDistRows(refSeqs, querySeqs, self=True):
    if self:
        if refSeqs != querySeqs:
            raise RuntimeError('refSeqs must equal querySeqs for a self comparison')
        for i, ref in enumerate(refSeqs):
            for j in range(i + 1, len(refSeqs)):
                yield (refSeqs[j], ref)
    else:
        for query in querySeqs:
            for ref in refSeqs:
                yield (ref, query)


# Generate distances
rList = []
ref = h5py.File(ref_db + ".h5", 'r')
for sample_name in list(ref['sketches'].keys()):
    rList.append(sample_name)

db_kmers = ref['sketches/' + rList[0]].attrs['kmers']
ref.close()

distMat = pp_sketchlib.queryDatabase(ref_db, ref_db, rList, rList, db_kmers)
jaccard_dists = pp_sketchlib.queryDatabase(ref_db,
                                           ref_db,
                                           rList,
                                           rList,
                                           db_kmers,
                                           jaccard=True)
jaccard_dists_raw = pp_sketchlib.queryDatabase(ref_db,
                                               ref_db,
                                               rList,
                                               rList,
                                               db_kmers,
                                               jaccard=True,
                                               random_correct=False)
distMat = np.hstack((distMat, jaccard_dists, jaccard_dists_raw))
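
# Column layout of the stacked matrix, with m = len(db_kmers): column 0 is
# the core distance, column 1 the accessory distance, columns 2..m+1 the
# corrected Jaccard distances per k-mer length, and the last m columns the
# raw (uncorrected) ones. Rows pair with sample names via the iterator
# defined above:
for row_idx, (ref_name, query_name) in enumerate(iterDistRows(rList, rList, self=True)):
    print(ref_name, query_name, distMat[row_idx, 0], distMat[row_idx, 1])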
Example #5
# Imports this test snippet relies on; python_cmd, run_regression and
# compare_sparse_matrices are defined earlier in the original test script
import pickle
import subprocess

import h5py
import numpy as np
import pp_sketchlib
import scipy.sparse
subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model lineage --ref-db batch12 --ranks 1,2", shell=True, check=True)
subprocess.run(python_cmd + " ../poppunk-runner.py --create-db --r-files rfile1.txt --output batch1 --overwrite", shell=True, check=True)
subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model lineage --ref-db batch1 --ranks 1,2", shell=True, check=True)
subprocess.run(python_cmd + " ../poppunk_assign-runner.py --db batch1 --query rfile2.txt --output batch2 --update-db --overwrite", shell=True, check=True)

# Load updated distances
X2 = np.load("batch2/batch2.dists.npy")
with open("batch2/batch2.dists.pkl", 'rb') as pickle_file:
    rlist2, qlist, self = pickle.load(pickle_file)

# Get same distances from the full database
ref_db = "batch12/batch12"
ref_h5 = h5py.File(ref_db + ".h5", 'r')
db_kmers = sorted(ref_h5['sketches/' + rlist2[0]].attrs['kmers'])
ref_h5.close()
X1 = pp_sketchlib.queryDatabase(ref_db, ref_db, rlist2, rlist2, db_kmers,
                                True, False, 1, False, 0)

# Check distances match
run_regression(X1[:, 0], X2[:, 0])
run_regression(X1[:, 1], X2[:, 1])

# Check sparse distances after one query
with open("batch12/batch12.dists.pkl", 'rb') as pickle_file:
    rlist1, qlist1, self = pickle.load(pickle_file)
S1 = scipy.sparse.load_npz("batch12/batch12_rank2_fit.npz")
S2 = scipy.sparse.load_npz("batch2/batch2_rank2_fit.npz")
compare_sparse_matrices(S1,S2,rlist1,rlist2)

# Check distances after second query

# Check that order is the same after doing 1 + 2 + 3 with --update-db, as doing all of 1 + 2 + 3 together
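
# A minimal sketch of what run_regression above might assert (hypothetical
# implementation; the real helper is defined earlier in the test script):
from scipy import stats

def run_regression_sketch(x, y, threshold=0.99):
    # Distances from the incremental and the one-shot database builds
    # should lie on the identity line, i.e. regress with r^2 ~ 1
    reg = stats.linregress(x, y)
    assert reg.rvalue ** 2 > threshold, "Distances do not match between runs"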