Ejemplo n.º 1
0
def main():
    """Command-line entry point: sketch, join, query or add random matches.

    Options are read from ``get_options()``. After validating the requested
    k-mer range, exactly one mode is run, chosen by the first of these flags
    that is set:

    * ``args.sketch``: read a tab-separated file (name, then one or more
      sequence files per line) and build a sketch database.
    * ``args.join``: copy the ``sketches`` group of two HDF5 databases into
      a new ``<output>.h5`` file.
    * ``args.query``: compute distances between reference and query
      databases, either printing them to stdout or pickling them.
    * ``args.add_random``: add random match chances to an existing database.

    The function never returns normally: it ends with ``sys.exit(0)`` on
    success and calls ``sys.exit(1)`` on invalid options or a failed step.
    """
    args = get_options()

    if args.min_k >= args.max_k or args.min_k < 3 or args.max_k > 101 or args.k_step < 1:
        sys.stderr.write(
            "Minimum kmer size " + str(args.min_k) +
            " must be smaller than maximum kmer size " + str(args.max_k) +
            "; range must be between 3 and 101, step must be at least one\n")
        sys.exit(1)
    kmers = np.arange(args.min_k, args.max_k + 1, args.k_step)

    #
    # Create a database (sketch input)
    #
    if args.sketch:
        names = []
        sequences = []

        # The 'U' mode flag was removed in Python 3.11; universal newlines
        # are already the default for text-mode files.
        with open(args.rfile, 'r') as refFile:
            for refLine in refFile:
                stripped = refLine.rstrip()
                if not stripped:
                    # Skip wholly blank lines (e.g. a trailing newline)
                    # rather than recording an empty sample name
                    continue
                refFields = stripped.split("\t")
                names.append(refFields[0])
                sequences.append(refFields[1:])

        if len(set(names)) != len(names):
            sys.stderr.write(
                "Input contains duplicate names! All names must be unique\n")
            sys.exit(1)

        pp_sketchlib.constructDatabase(args.ref_db, names, sequences, kmers,
                                       int(round(args.sketch_size / 64)),
                                       args.codon_phased, not args.no_random,
                                       args.strand, args.min_count,
                                       args.exact_counter, args.cpus,
                                       args.use_gpu, args.gpu_id)

    #
    # Join two databases
    #
    elif args.join:
        join_name = args.output + ".h5"
        tmp_name = join_name + ".tmp"  # add .tmp in case join_name exists
        db1_name = args.ref_db + ".h5"
        db2_name = args.query_db + ".h5"

        # Context managers guarantee the HDF5 handles are closed even when
        # a failure below exits the program.
        with h5py.File(db1_name, 'r') as hdf1:
            with h5py.File(db2_name, 'r') as hdf2:
                # Best-effort compatibility check. A missing group or
                # attribute surfaces as KeyError, so it is caught together
                # with RuntimeError.
                try:
                    v1 = hdf1['sketches'].attrs['sketch_version']
                    v2 = hdf2['sketches'].attrs['sketch_version']
                    if v1 != v2:
                        sys.stderr.write(
                            "Databases have been written with different sketch versions, "
                            "joining not recommended (but proceeding anyway)\n")
                    p1 = hdf1['sketches'].attrs['codon_phased']
                    p2 = hdf2['sketches'].attrs['codon_phased']
                    if p1 != p2:
                        # NOTE(review): the message says "cannot join" but
                        # the join still proceeds; confirm whether this
                        # should abort instead.
                        sys.stderr.write(
                            "One database uses codon-phased seeds - cannot join "
                            "with a standard seed database\n")
                except (KeyError, RuntimeError):
                    sys.stderr.write("Unable to check sketch version\n")

                with h5py.File(tmp_name, 'w') as hdf_join:
                    # Can only copy into new group, so for second file these
                    # are appended one at a time
                    try:
                        hdf1.copy('sketches', hdf_join)
                        join_grp = hdf_join['sketches']
                        read_grp = hdf2['sketches']
                        for dataset in read_grp:
                            join_grp.copy(read_grp[dataset], dataset)

                        if 'random' in hdf1 or 'random' in hdf2:
                            sys.stderr.write(
                                "Random matches found in one database, which will not be copied\n"
                                "Use --add-random to recalculate for the joined DB\n")
                    except RuntimeError as e:
                        sys.stderr.write("ERROR: " + str(e) + "\n")
                        sys.stderr.write("Joining sketches failed\n")
                        sys.exit(1)

        # All handles are closed here; move the finished file into place
        os.rename(tmp_name, join_name)

    #
    # Query a database (calculate distances)
    #
    elif args.query:
        rList = getSampleNames(args.ref_db)
        qList = getSampleNames(args.query_db)

        if args.subset is not None:
            # First tab-separated column of each line is a sample name
            with open(args.subset, 'r') as subset_file:
                subset = [line.rstrip().split("\t")[0] for line in subset_file]
            rList = list(set(rList).intersection(subset))
            qList = list(set(qList).intersection(subset))
            if len(rList) == 0 or len(qList) == 0:
                sys.stderr.write("Subset has removed all samples\n")
                sys.exit(1)

        # Check inputs overlap: k-mer lengths present in both databases,
        # judged from the first sample of each
        with h5py.File(args.ref_db + ".h5", 'r') as ref_h5:
            with h5py.File(args.query_db + ".h5", 'r') as query_h5:
                ref_kmers = ref_h5['sketches/' + rList[0]].attrs['kmers']
                qry_kmers = query_h5['sketches/' + qList[0]].attrs['kmers']
                db_kmers = set(ref_kmers).intersection(qry_kmers)

        if args.read_k:
            query_kmers = sorted(db_kmers)
        else:
            query_kmers = sorted(set(kmers).intersection(db_kmers))
            if len(query_kmers) == 0:
                sys.stderr.write("No requested k-mer lengths found in DB\n")
                sys.exit(1)
            elif len(query_kmers) < len(kmers):
                # Compare against the requested lengths; comparing the list
                # with itself could never trigger this warning
                sys.stderr.write(
                    "Some requested k-mer lengths not found in DB\n")

        if args.sparse:
            sparseIdx = pp_sketchlib.queryDatabaseSparse(
                args.ref_db, args.query_db, rList, qList, query_kmers,
                not args.no_correction, args.threshold, args.kNN,
                not args.accessory, args.cpus, args.use_gpu, args.gpu_id)
            if args.print:
                if args.accessory:
                    distName = 'Accessory'
                else:
                    distName = 'Core'
                sys.stdout.write("\t".join(['Query', 'Reference', distName]) +
                                 "\n")

                # NOTE(review): rows are written reference-first although the
                # header reads Query, Reference (the dense branch writes
                # query-first); confirm the intended column order.
                (i_vec, j_vec, dist_vec) = sparseIdx
                for (i, j, dist) in zip(i_vec, j_vec, dist_vec):
                    sys.stdout.write("\t".join([rList[i], qList[j],
                                                str(dist)]) + "\n")

            else:
                coo_matrix = ijv_to_coo(sparseIdx, (len(rList), len(qList)),
                                        np.float32)
                storePickle(rList, qList, rList == qList, coo_matrix,
                            args.output)

        else:
            distMat = pp_sketchlib.queryDatabase(args.ref_db, args.query_db,
                                                 rList, qList, query_kmers,
                                                 not args.no_correction,
                                                 args.jaccard, args.cpus,
                                                 args.use_gpu, args.gpu_id)

            # get names order
            if args.print:
                names = iterDistRows(rList, qList, rList == qList)
                if not args.jaccard:
                    sys.stdout.write("\t".join(
                        ['Query', 'Reference', 'Core', 'Accessory']) + "\n")
                    for i, (ref, query) in enumerate(names):
                        sys.stdout.write("\t".join([
                            query, ref,
                            str(distMat[i, 0]),
                            str(distMat[i, 1])
                        ]) + "\n")
                else:
                    # One column per k-mer length
                    sys.stdout.write("\t".join(['Query', 'Reference'] +
                                               [str(k)
                                                for k in query_kmers]) + "\n")
                    for i, (ref, query) in enumerate(names):
                        sys.stdout.write("\t".join(
                            [query, ref] + [str(d)
                                            for d in distMat[i, ]]) + "\n")
            else:
                storePickle(rList, qList, rList == qList, distMat, args.output)

    #
    # Add random match chances to an older database
    #
    elif args.add_random:
        rList = getSampleNames(args.ref_db)
        # k-mer lengths are taken from the first sample in the database
        with h5py.File(args.ref_db + ".h5", 'r') as ref_h5:
            db_kmers = ref_h5['sketches/' + rList[0]].attrs['kmers']

        pp_sketchlib.addRandom(args.ref_db, rList, db_kmers, args.strand,
                               args.cpus)

    sys.exit(0)
Ejemplo n.º 2
0
def constructDatabase(assemblyList, klist, sketch_size, oPrefix,
                        threads, overwrite,
                        strand_preserved, min_count,
                        use_exact, qc_dict, calc_random = True,
                        codon_phased = False,
                        use_gpu = False, deviceid = 0):
    """Sketch the input assemblies at the requested k-mer lengths

    Reads sample names and sequence files from ``assemblyList``, builds the
    sketch database ``<oPrefix>/<basename of oPrefix>.h5`` with
    ``pp_sketchlib.constructDatabase``, optionally runs QC on the sketched
    samples, and optionally adds random match chances for the samples that
    remain.

    Args:
        assemblyList (str)
            File with locations of assembly files to be sketched
        klist (list)
            List of k-mer sizes to sketch
        sketch_size (int)
            Size of sketch, passed unchanged to the sketching library
        oPrefix (str)
            Output prefix; the database file is written inside this
            directory, which is not created by this function
        threads (int)
            Number of threads to use
        overwrite (bool)
            Whether to delete an existing database file before sketching
        strand_preserved (bool)
            Ignore reverse complement k-mers
        min_count (int)
            Minimum count of k-mer in reads to include
        use_exact (bool)
            Use exact count of k-mer appearance in reads
        qc_dict (dict)
            Dict containing QC settings; QC runs only when
            ``qc_dict['run_qc']`` is true
        calc_random (bool)
            Add random match chances to DB (turn off for queries)
            (default = True)
        codon_phased (bool)
            Use codon phased seeds
            (default = False)
        use_gpu (bool)
            Use GPU for read sketching
            (default = False)
        deviceid (int)
            GPU device id
            (default = 0)
    Returns:
        names (list)
            List of names included in the database (some may be pruned due
            to QC)
    """
    # read file names
    names, sequences = readRfile(assemblyList)

    # locate the database file inside the output directory
    dbname = oPrefix + "/" + os.path.basename(oPrefix)
    dbfilename = dbname + ".h5"
    if overwrite and os.path.isfile(dbfilename):
        sys.stderr.write("Overwriting db: " + dbfilename + "\n")
        os.remove(dbfilename)

    # generate sketches
    # Random matches are not requested here (the literal False); when
    # wanted they are added after QC below.
    pp_sketchlib.constructDatabase(dbname,
                                   names,
                                   sequences,
                                   klist,
                                   sketch_size,
                                   codon_phased,
                                   False,
                                   not strand_preserved,
                                   min_count,
                                   use_exact,
                                   threads,
                                   use_gpu,
                                   deviceid)

    # QC sequences
    if qc_dict['run_qc']:
        filtered_names = sketchlibAssemblyQC(oPrefix,
                                             names,
                                             klist,
                                             qc_dict,
                                             strand_preserved,
                                             threads)
    else:
        filtered_names = names

    # Add random matches if required
    # (typically on for reference, off for query)
    if calc_random:
        addRandom(oPrefix,
                  filtered_names,
                  klist,
                  strand_preserved,
                  overwrite = True,
                  threads = threads)

    # return filtered file names
    return filtered_names
Ejemplo n.º 3
0
def constructDatabase(assemblyList,
                      klist,
                      sketch_size,
                      oPrefix,
                      ignoreLengthOutliers=False,
                      threads=1,
                      overwrite=False,
                      reads=False,
                      min_count=0):
    """Sketch the input assemblies at the requested k-mer lengths

    Reads sample names and sequence files from ``assemblyList``, runs
    assembly QC (unless reads are being used) and reports the worst random
    match probability, then builds the sketch database
    ``<oPrefix>/<basename of oPrefix>.h5`` with
    ``pp_sketchlib.constructDatabase``.

    Args:
        assemblyList (str)
            File with locations of assembly files to be sketched
        klist (list)
            List of k-mer sizes to sketch
        sketch_size (int)
            Size of sketch, passed unchanged to the sketching library
        oPrefix (str)
            Output prefix; the database file is written inside this
            directory, which is not created by this function
        ignoreLengthOutliers (bool)
            Whether to check for outlying genome lengths (and error
            if found)

            (default = False)
        threads (int)
            Number of threads to use

            (default = 1)
        overwrite (bool)
            Whether to overwrite sketch DBs, if they already exist.

            (default = False)
        reads (bool)
            If any reads are being used as input, do not run QC

            (default = False)
        min_count (int)
            Minimum k-mer count, passed unchanged to the sketching library

            (default = 0)
    """
    names, sequences = readRfile(assemblyList)
    if not reads:
        # Only the probability is reported; the genome length is unused here
        _, max_prob = assembly_qc(sequences, klist, ignoreLengthOutliers)
        sys.stderr.write("Worst random match probability at " +
                         str(min(klist)) + "-mers: " +
                         "{:.2f}".format(max_prob) + "\n")

    # Use only the final path component for the file name, so a prefix with
    # directory parts ("out/db") yields "out/db/db.h5" rather than
    # duplicating the whole path inside itself.
    dbname = oPrefix + "/" + os.path.basename(oPrefix)
    dbfilename = dbname + ".h5"
    if overwrite and os.path.isfile(dbfilename):
        sys.stderr.write("Overwriting db: " + dbfilename + "\n")
        os.remove(dbfilename)

    pp_sketchlib.constructDatabase(dbname, names, sequences, klist,
                                   sketch_size, min_count, threads)