Example 1
def make_lca_counts(dblist):
    """
    Collect counts of all the LCAs in the list of databases.

    CTB this could usefully be converted to a generator function.
    """

    # gather all hashvalue assignments from across all the databases
    assignments = defaultdict(set)
    for lca_db in dblist:
        for hashval, lid_list in lca_db.hashval_to_lineage_id.items():
            lineages = [lca_db.lineage_dict[lid] for lid in lid_list]
            assignments[hashval].update(lineages)

    # now convert to trees -> do LCA & counts
    counts = defaultdict(int)
    for hashval, lineages in assignments.items():

        # for each list of tuple_info [(rank, name), ...] build
        # a tree that lets us discover lowest-common-ancestor.
        debug(lineages)
        tree = lca_utils.build_tree(lineages)

        # now find either a leaf or the first node with multiple
        # children; that's our lowest-common-ancestor node.
        lca, reason = lca_utils.find_lca(tree)
        counts[lca] += 1

    return counts
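
A minimal sketch of the build_tree/find_lca helpers this function leans on. The import path and the LineagePair constructor follow sourmash's lca_utils, but the exact module layout varies across sourmash versions, so treat them as assumptions:

from sourmash.lca.lca_utils import LineagePair, build_tree, find_lca

# two lineages that agree at superkingdom but diverge at phylum
lin1 = (LineagePair('superkingdom', 'Bacteria'),
        LineagePair('phylum', 'Firmicutes'))
lin2 = (LineagePair('superkingdom', 'Bacteria'),
        LineagePair('phylum', 'Proteobacteria'))

tree = build_tree([lin1, lin2])
lca, reason = find_lca(tree)
# lca is the shared prefix, (LineagePair('superkingdom', 'Bacteria'),);
# reason is the number of children at the divergence point (here, 2).
# reason == 0 would mean the whole lineage was a leaf.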
Example 2
def classify_signature(query_sig, dblist, threshold):
    """
    Classify 'query_sig' using the given list of databases.

    Insist on at least 'threshold' counts of a given lineage before taking
    it seriously.

    Return (lineage, status) where 'lineage' is a tuple of LineagePairs
    and 'status' is either 'nomatch', 'found', or 'disagree'.

    This function proceeds in two stages:

       * first, build a list of assignments for all the lineages for each
         hashval.  (For e.g. kraken, this is done in the database preparation
         step; here, we do it dynamically each time.)
       * then, across all the hashvals, count the number of times each lineage
         shows up, and filter out low-abundance ones (under threshold).
         Then, determine the LCA of all of those.

    """
    # gather assignments from across all the databases
    assignments = lca_utils.gather_assignments(query_sig.minhash.get_mins(),
                                               dblist)

    # now convert to trees -> do LCA & counts
    counts = lca_utils.count_lca_for_assignments(assignments)
    debug(counts.most_common())

    # ok, we now have the LCAs for each hashval, and their number of
    # counts. Now build a tree across "significant" LCAs - those above
    # threshold.

    tree = {}

    for lca, count in counts.most_common():
        if count < threshold:
            break

        # update tree with this set of assignments
        lca_utils.build_tree([lca], tree)

    status = 'nomatch'
    if not tree:
        return [], status

    # now find lowest-common-ancestor of the resulting tree.
    lca, reason = lca_utils.find_lca(tree)
    if reason == 0:  # leaf node
        debug('END', lca)
        status = 'found'
    else:  # internal node => disagreement
        debug('MULTI', lca)
        status = 'disagree'

    debug('lineage is:', lca)

    return lca, status
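
A hedged usage sketch (file names are placeholders; the loaders are the same lca_utils/sourmash_args helpers used by the command functions below):

dblist, ksize, scaled = lca_utils.load_databases(['genbank.lca.json'], None)
query_sig = sourmash_args.load_query_signature('genome.sig', ksize, 'DNA')
query_sig.minhash = query_sig.minhash.downsample_scaled(scaled)

lineage, status = classify_signature(query_sig, dblist, threshold=5)
if status == 'found':
    print(';'.join(lca_utils.zip_lineage(lineage)))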
Example 3
def summarize(hashvals, dblist, threshold):
    """
    Summarize 'hashvals' using the given list of databases.

    Insist on at least 'threshold' counts of a given lineage before taking
    it seriously.

    Return a dict mapping each lineage (a tuple of LineagePairs) to the
    aggregated count of hashvals assigned at or below it.
    """

    # gather assignments from across all the databases
    assignments = lca_utils.gather_assignments(hashvals, dblist)

    # now convert to trees -> do LCA & counts
    counts = lca_utils.count_lca_for_assignments(assignments)
    debug(counts.most_common())

    # ok, we now have the LCAs for each hashval, and their number
    # of counts. Now aggregate counts across the tree, going up from
    # the leaves.
    aggregated_counts = defaultdict(int)
    for lca, count in counts.most_common():
        if count < threshold:
            break

        if not lca:
            aggregated_counts[lca] += count

        # climb from the lca to the root.
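        # (each lineage is a tuple of LineagePairs, so lca[:-1] drops the
        # most specific rank, e.g. a genus-level lineage truncates to its
        # family-level prefix.)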
        while lca:
            aggregated_counts[lca] += count
            lca = lca[:-1]

    debug(aggregated_counts)

    return aggregated_counts
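
A hedged sketch of calling summarize() directly on one signature's hashes (file names are placeholders; the helpers are the same ones used by summarize_main below):

dblist, ksize, scaled = lca_utils.load_databases(['genbank.lca.json'], None)
sig = sourmash_args.load_query_signature('metagenome.sig', ksize, 'DNA')
mh = sig.minhash.downsample_scaled(scaled)

counts = summarize(mh.get_mins(), dblist, threshold=5)
for lineage, count in sorted(counts.items(), key=lambda kv: -kv[1]):
    print(count, ';'.join(lca_utils.zip_lineage(lineage)))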
Example 4
def gather_main(args):
    """
    Do a greedy search for the hash components of a query against an LCA db.

    Here we don't actually do a least-common-ancestor search of any kind; we
    do essentially the same kind of search as we do in `sourmash gather`, with
    the main difference that we are implicitly combining different genomes of
    identical lineages.

    This takes advantage of the structure of the LCA db, where we store the
    full lineage information for each known hash, as opposed to storing only
    the least-common-ancestor information for it.
    """
    p = argparse.ArgumentParser(prog="sourmash lca gather")
    p.add_argument('query')
    p.add_argument('db', nargs='+')
    p.add_argument('-d', '--debug', action='store_true')
    p.add_argument('-o', '--output', type=argparse.FileType('wt'),
                   help='output CSV containing matches to this file')
    p.add_argument('--output-unassigned', type=argparse.FileType('wt'),
                   help='output unassigned portions of the query as a signature to this file')
    p.add_argument('--ignore-abundance', action='store_true',
                   help='do NOT use k-mer abundances if present')
    args = p.parse_args(args)

    if args.debug:
        set_debug(args.debug)

    # load all the databases
    dblist, ksize, scaled = lca_utils.load_databases(args.db, None)

    # for each query, gather all the matches across databases
    query_sig = sourmash_args.load_query_signature(args.query, ksize, 'DNA')
    debug('classifying', query_sig.name())

    # make sure we're looking at the same scaled value as database
    query_sig.minhash = query_sig.minhash.downsample_scaled(scaled)

    # do the classification, output results
    found = []
    for result, f_unassigned, est_bp, remaining_mins in gather_signature(query_sig, dblist, args.ignore_abundance):
        # is this our first time through the loop? print headers, if so.
        if not found:
            print_results("")
            print_results("overlap     p_query p_match ")
            print_results("---------   ------- --------")

        # output!
        pct_query = '{:.1f}%'.format(result.f_unique_to_query*100)
        pct_match = '{:.1f}%'.format(result.f_match*100)
        str_bp = format_bp(result.intersect_bp)
        name = format_lineage(result.lineage)

        equal_match_str = ""
        if result.n_equal_matches:
            equal_match_str = " (** {} equal matches)".format(result.n_equal_matches)

        print_results('{:9}   {:>6}  {:>6}      {}{}', str_bp, pct_query,
                      pct_match, name, equal_match_str)

        found.append(result)

    if found:
        print_results('')
        if f_unassigned:
            print_results('{:.1f}% ({}) of hashes have no assignment.', f_unassigned*100,
                          format_bp(est_bp))
        else:
            print_results('Query is completely assigned.')
            print_results('')
    # nothing found.
    else:
        est_bp = len(query_sig.minhash.get_mins()) * query_sig.minhash.scaled
        print_results('')
        print_results('No assignment for est {} of sequence.',
                      format_bp(est_bp))
        print_results('')

    if not found:
        sys.exit(0)

    if args.output:
        fieldnames = ['intersect_bp', 'f_match', 'f_unique_to_query', 'f_unique_weighted',
                      'average_abund', 'name', 'n_equal_matches'] + list(lca_utils.taxlist())

        w = csv.DictWriter(args.output, fieldnames=fieldnames)
        w.writeheader()
        for result in found:
            lineage = result.lineage
            d = dict(result._asdict())
            del d['lineage']

            for (rank, value) in lineage:
                d[rank] = value

            w.writerow(d)

    if args.output_unassigned:
        if not found:
            notify('nothing found - entire query signature unassigned.')
        elif not remaining_mins:
            notify('no unassigned hashes! not saving.')
        else:
            outname = args.output_unassigned.name
            notify('saving unassigned hashes to "{}"', outname)

            e = query_sig.minhash.copy_and_clear()
            e.add_many(remaining_mins)

            sourmash_lib.save_signatures([sourmash_lib.SourmashSignature(e)],
                                         args.output_unassigned)
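
Like the other command entry points here, gather_main() takes its argument list explicitly; a hedged invocation with placeholder file names:

gather_main(['query-metagenome.sig', 'genbank.lca.json',
             '-o', 'matches.csv',
             '--output-unassigned', 'unassigned.sig'])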
Example 5
def classify(args):
    """
    main single-genome classification function.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--db', nargs='+', action='append')
    p.add_argument('--query', nargs='+', action='append')
    p.add_argument('--threshold', type=int, default=DEFAULT_THRESHOLD)
    p.add_argument('-o',
                   '--output',
                   type=argparse.FileType('wt'),
                   help='output CSV to this file instead of stdout')
    p.add_argument('--scaled', type=float)
    p.add_argument('--traverse-directory',
                   action='store_true',
                   help='load all signatures underneath directories.')
    p.add_argument('-d', '--debug', action='store_true')
    args = p.parse_args(args)

    if not args.db:
        error('Error! must specify at least one LCA database with --db')
        sys.exit(-1)

    if not args.query:
        error('Error! must specify at least one query signature with --query')
        sys.exit(-1)

    if args.debug:
        set_debug(args.debug)

    # flatten --db and --query
    args.db = [item for sublist in args.db for item in sublist]
    args.query = [item for sublist in args.query for item in sublist]

    # load all the databases
    dblist, ksize, scaled = lca_utils.load_databases(args.db, args.scaled)
    notify('ksize={} scaled={}', ksize, scaled)

    # find all the queries
    notify('finding query signatures...')
    if args.traverse_directory:
        inp_files = list(sourmash_args.traverse_find_sigs(args.query))
    else:
        inp_files = list(args.query)

    # set up output
    csvfp = csv.writer(sys.stdout)
    if args.output:
        notify("outputting classifications to '{}'", args.output.name)
        csvfp = csv.writer(args.output)
    else:
        notify("outputting classifications to stdout")
    csvfp.writerow(['ID', 'status'] + list(lca_utils.taxlist()))

    # for each query, gather all the matches across databases
    total_count = 0
    n = 0
    total_n = len(inp_files)
    for query_filename in inp_files:
        n += 1
        for query_sig in sourmash_lib.load_signatures(query_filename,
                                                      ksize=ksize):
            notify(u'\r\033[K', end=u'')
            notify('... classifying {} (file {} of {})',
                   query_sig.name(),
                   n,
                   total_n,
                   end='\r')
            debug('classifying', query_sig.name())
            total_count += 1

            # make sure we're looking at the same scaled value as database
            query_sig.minhash = query_sig.minhash.downsample_scaled(scaled)

            # do the classification
            lineage, status = classify_signature(query_sig, dblist,
                                                 args.threshold)
            debug(lineage)

            # output each classification to the spreadsheet
            row = [query_sig.name(), status]
            row += lca_utils.zip_lineage(lineage)

            # when outputting to stdout, make output intelligible
            if not args.output:
                notify(u'\r\033[K', end=u'')
            csvfp.writerow(row)

    notify(u'\r\033[K', end=u'')
    notify('classified {} signatures total', total_count)
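
A hedged invocation sketch for classify(), with placeholder file names; note that argparse converts the '--threshold' string to an int:

classify(['--db', 'genbank.lca.json',
          '--query', 'genome1.sig', 'genome2.sig',
          '--threshold', '10',
          '-o', 'classifications.csv'])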
Example 6
def index(args):
    """
    main function for building an LCA database.
    """
    p = argparse.ArgumentParser()
    p.add_argument('csv', help='taxonomy spreadsheet')
    p.add_argument('lca_db_out', help='name to save database to')
    p.add_argument('signatures', nargs='+',
                   help='one or more sourmash signatures')
    p.add_argument('--scaled', default=10000, type=float)
    p.add_argument('-k', '--ksize', default=31, type=int)
    p.add_argument('-d', '--debug', action='store_true')
    p.add_argument('-C', '--start-column', default=2, type=int,
                   help='column at which taxonomic assignments start')
    p.add_argument('--tabs', action='store_true',
                   help='input spreadsheet is tab-delimited (default: commas)')
    p.add_argument('--no-headers', action='store_true',
                   help='no headers present in taxonomy spreadsheet')
    p.add_argument('--split-identifiers', action='store_true',
                   help='split names in signatures on whitespace and period')
    p.add_argument('-f', '--force', action='store_true')
    p.add_argument('--traverse-directory', action='store_true',
                   help='load all signatures underneath directories.')
    p.add_argument('--report', help='output a report on anomalies, if any.')
    args = p.parse_args(args)

    if args.start_column < 2:
        error('error, --start-column cannot be less than 2')
        sys.exit(-1)

    if args.debug:
        set_debug(args.debug)

    args.scaled = int(args.scaled)

    # first, load taxonomy spreadsheet
    delimiter = ','
    if args.tabs:
        delimiter = '\t'
    assignments, num_rows = load_taxonomy_assignments(args.csv,
                                                      delimiter=delimiter,
                                                      start_column=args.start_column,
                                                      use_headers=not args.no_headers,
                                                      force=args.force)

    # convert lineages to numbers.
    next_lineage_index = 0
    lineage_dict = {}

    assignments_idx = {}
    lineage_to_idx = {}
    for (ident, lineage) in assignments.items():
        idx = lineage_to_idx.get(lineage)
        if idx is None:
            idx = next_lineage_index
            next_lineage_index += 1

            lineage_dict[idx] = lineage
            lineage_to_idx[lineage] = idx

        assignments_idx[ident] = idx

    notify('{} distinct lineages in spreadsheet out of {} rows',
           len(lineage_dict), num_rows)

    # load signatures, construct index of hashvals to lineages
    hashval_to_lineage = defaultdict(set)
    md5_to_lineage = {}

    notify('finding signatures...')
    if args.traverse_directory:
        yield_all_files = False           # only pick up *.sig files?
        if args.force:
            yield_all_files = True
        inp_files = list(sourmash_args.traverse_find_sigs(args.signatures,
                                                          yield_all_files=yield_all_files))
    else:
        inp_files = list(args.signatures)

    n = 0
    total_n = len(inp_files)
    record_duplicates = set()
    record_no_lineage = set()
    record_remnants = set(assignments_idx.keys())
    for filename in inp_files:
        n += 1
        for sig in sourmash_lib.load_signatures(filename, ksize=args.ksize):
            notify(u'\r\033[K', end=u'')
            notify('... loading signature {} (file {} of {})', sig.name()[:30], n, total_n, end='\r')
            debug(filename, sig.name())

            if sig.md5sum() in md5_to_lineage:
                notify('\nWARNING: in file {}, duplicate md5sum: {}; skipping', filename, sig.md5sum())
                record_duplicates.add(filename)
                continue

            name = sig.name()
            if args.split_identifiers: # hack for NCBI-style names, etc.
                name = name.split(' ')[0].split('.')[0]

            # is this one for which we have a lineage assigned?
            lineage_idx = assignments_idx.get(name)
            if lineage_idx is None:
                notify('\nWARNING: no lineage assignment for {}.', name)
                record_no_lineage.add(name)
            else:
                # remove from our list of remnant lineages
                record_remnants.remove(name)

                # downsample to specified scaled; this has the side effect of
                # making sure they're all at the same scaled value!
                minhash = sig.minhash.downsample_scaled(args.scaled)

                # connect hashvals to lineage
                for hashval in minhash.get_mins():
                    hashval_to_lineage[hashval].add(lineage_idx)

                # store md5 -> lineage too
                md5_to_lineage[sig.md5sum()] = lineage_idx

    notify(u'\r\033[K', end=u'')
    notify('...found {} genomes with lineage assignments!!',
           len(md5_to_lineage))

    # remove those lineages with no genomes associated
    assigned_lineages = set(md5_to_lineage.values())
    lineage_dict_2 = {}
    for idx in assigned_lineages:
        lineage_dict_2[idx] = lineage_dict[idx]

    unused_lineages = set(lineage_dict.values()) - set(lineage_dict_2.values())

    notify('{} assigned lineages out of {} distinct lineages in spreadsheet',
           len(lineage_dict_2), len(lineage_dict))
    lineage_dict = lineage_dict_2

    # now, save!
    db_outfile = args.lca_db_out
    if not (db_outfile.endswith('.lca.json') or db_outfile.endswith('.lca.json.gz')):
        db_outfile += '.lca.json'
    notify('saving to LCA DB: {}'.format(db_outfile))

    db = lca_utils.LCA_Database()
    db.lineage_dict = lineage_dict
    db.hashval_to_lineage_id = hashval_to_lineage
    db.ksize = int(args.ksize)
    db.scaled = int(args.scaled)
    db.signatures_to_lineage = md5_to_lineage

    db.save(db_outfile)

    if record_duplicates or record_no_lineage or record_remnants or unused_lineages:
        if record_duplicates:
            notify('WARNING: {} duplicate signatures.', len(record_duplicates))
        if record_no_lineage:
            notify('WARNING: no lineage provided for {} signatures.',
                   len(record_no_lineage))
        if record_remnants:
            notify('WARNING: no signatures for {} lineage assignments.',
                   len(record_remnants))
        if unused_lineages:
            notify('WARNING: {} unused lineages.', len(unused_lineages))

        if args.report:
            notify("generating a report and saving in '{}'", args.report)
            generate_report(record_duplicates, record_no_lineage,
                            record_remnants, unused_lineages, args.report)
        else:
            notify('(You can use --report to generate a detailed report.)')
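
A hedged invocation sketch for index(), with placeholder file names:

index(['podar-ref.taxonomy.csv', 'podar-ref.lca.json', 'sigs/',
       '--traverse-directory', '--split-identifiers',
       '--report', 'index-report.txt'])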
Example 7
def main():
    p = argparse.ArgumentParser()
    p.add_argument('sigs', nargs='+')
    p.add_argument('--traverse-directory',
                   action='store_true',
                   help='load all signatures underneath directories.')
    p.add_argument('-k', '--ksize', default=31, type=int)
    p.add_argument('-d', '--debug', action='store_true')
    p.add_argument('-f', '--force', action='store_true')
    p.add_argument('--scaled', type=float, default=10000)
    p.add_argument('--plot', default=None)
    p.add_argument('-o',
                   '--output',
                   type=argparse.FileType('wt'),
                   help='CSV output')
    p.add_argument('--step', type=int, default=1000)
    p.add_argument('--repeat', type=int, default=5)
    p.add_argument('--db', nargs='+', action='append')
    args = p.parse_args()

    if args.debug:
        set_debug(args.debug)

    args.scaled = int(args.scaled)

    dblist = []
    known_hashes = set()
    if args.db:
        args.db = [item for sublist in args.db for item in sublist]
        dblist, ksize, scaled = lca_utils.load_databases(args.db, args.scaled)
        assert ksize == args.ksize
        notify('loaded {} LCA databases', len(dblist))

        for db in dblist:
            known_hashes.update(db.hashval_to_lineage_id.keys())
        notify('got {} known hashes!', len(known_hashes))

    notify('finding signatures...')
    if args.traverse_directory:
        yield_all_files = False  # only pick up *.sig files?
        if args.force:
            yield_all_files = True
        inp_files = list(
            sourmash_args.traverse_find_sigs(args.sigs,
                                             yield_all_files=yield_all_files))
    else:
        inp_files = list(args.sigs)

    n = 0
    total_n = len(inp_files)
    sigs = []
    total_hashvals = list()
    for filename in inp_files:
        n += 1
        for sig in sourmash_lib.load_signatures(filename, ksize=args.ksize):
            notify(u'\r\033[K', end=u'')
            notify('... loading signature {} (file {} of {})',
                   sig.name()[:30],
                   n,
                   total_n,
                   end='\r')
            debug(filename, sig.name())

            sig.minhash = sig.minhash.downsample_scaled(args.scaled)

            total_hashvals.extend(sig.minhash.get_mins())
            sigs.append(sig)

    notify(u'\r\033[K', end=u'')
    notify('...found {} signatures total in {} files.', len(sigs), total_n)

    distinct_hashvals = set(total_hashvals)
    notify('{} distinct out of {} total hashvals.', len(distinct_hashvals),
           len(total_hashvals))
    if known_hashes:
        n_known = len(known_hashes.intersection(distinct_hashvals))
        notify('{} of them known, or {:.1f}%', n_known,
               n_known / float(len(distinct_hashvals)) * 100)

    x = []
    y = []
    z = []
    notify('subsampling...')
    for n in range(0, len(total_hashvals), args.step):
        notify(u'\r\033[K', end=u'')
        notify('... {} of {}', n, len(total_hashvals), end='\r')
        avg = 0
        known = 0
        for j in range(0, args.repeat):
            subsample = random.sample(total_hashvals, n)
            distinct = len(set(subsample))
            if known_hashes:
                known += len(set(subsample).intersection(known_hashes))
            avg += distinct

        x.append(n)
        y.append(avg / args.repeat)
        z.append(known / args.repeat)

    notify('\n...done!')

    if args.output:
        w = csv.writer(args.output)
        w.writerow(['n', 'k', 'known'])
        for a, b, c in zip(x, y, z):
            w.writerow([a, b, c])

    if args.plot:
        from matplotlib import pyplot
        pyplot.plot(x, y)
        pyplot.savefig(args.plot)
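
Unlike the other examples, this main() parses sys.argv directly, so a programmatic invocation has to set it first; a hedged sketch with placeholder file names:

import sys

sys.argv = ['rarefaction.py', 'sig1.sig', 'sig2.sig',
            '--db', 'genbank.lca.json',
            '--step', '500', '--repeat', '3',
            '-o', 'rarefaction.csv', '--plot', 'rarefaction.png']
main()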
Example 8
def summarize_main(args):
    """
    main summarization function.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--db', nargs='+', action='append')
    p.add_argument('--query', nargs='+', action='append')
    p.add_argument('--threshold', type=int, default=DEFAULT_THRESHOLD)
    p.add_argument('--traverse-directory',
                   action='store_true',
                   help='load all signatures underneath directories.')
    p.add_argument('-o',
                   '--output',
                   type=argparse.FileType('wt'),
                   help='CSV output')
    p.add_argument('--scaled', type=float)
    p.add_argument('-d', '--debug', action='store_true')
    args = p.parse_args(args)

    if not args.db:
        error('Error! must specify at least one LCA database with --db')
        sys.exit(-1)

    if not args.query:
        error('Error! must specify at least one query signature with --query')
        sys.exit(-1)

    if args.debug:
        set_debug(args.debug)

    if args.scaled:
        args.scaled = int(args.scaled)

    # flatten --db and --query
    args.db = [item for sublist in args.db for item in sublist]
    args.query = [item for sublist in args.query for item in sublist]

    # load all the databases
    dblist, ksize, scaled = lca_utils.load_databases(args.db, args.scaled)
    notify('ksize={} scaled={}', ksize, scaled)

    # find all the queries
    notify('finding query signatures...')
    if args.traverse_directory:
        inp_files = list(sourmash_args.traverse_find_sigs(args.query))
    else:
        inp_files = list(args.query)

    # for each query, gather all the hashvals across databases
    total_count = 0
    n = 0
    total_n = len(inp_files)
    hashvals = defaultdict(int)
    for query_filename in inp_files:
        n += 1
        for query_sig in sourmash_lib.load_signatures(query_filename,
                                                      ksize=ksize):
            notify(u'\r\033[K', end=u'')
            notify('... loading {} (file {} of {})',
                   query_sig.name(),
                   n,
                   total_n,
                   end='\r')
            total_count += 1

            mh = query_sig.minhash.downsample_scaled(scaled)
            for hashval in mh.get_mins():
                hashvals[hashval] += 1

    notify(u'\r\033[K', end=u'')
    notify('loaded {} signatures from {} files total.', total_count, n)

    # get the full counted list of lineage counts in this signature
    lineage_counts = summarize(hashvals, dblist, args.threshold)

    # output!
    total = float(len(hashvals))
    for (lineage, count) in lineage_counts.items():
        if lineage:
            lineage = lca_utils.zip_lineage(lineage, truncate_empty=True)
            lineage = ';'.join(lineage)
        else:
            lineage = '(root)'

        p = count / total * 100.
        p = '{:.1f}%'.format(p)

        print_results('{:5} {:>5}   {}'.format(p, count, lineage))

    # CSV:
    if args.output:
        w = csv.writer(args.output)
        headers = ['count'] + list(lca_utils.taxlist())
        w.writerow(headers)

        for (lineage, count) in lineage_counts.items():
            debug('lineage:', lineage)
            row = [count] + lca_utils.zip_lineage(lineage)
            w.writerow(row)
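
A hedged invocation sketch for summarize_main(), with placeholder file names:

summarize_main(['--db', 'genbank.lca.json',
                '--query', 'metagenome.sig',
                '--threshold', '5',
                '-o', 'summary.csv'])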