Example #1
import sourmash


def test_sourmash_signature_api():
    e = sourmash.MinHash(n=1, ksize=20)
    sig = sourmash.SourmashSignature(e)

    s = sourmash.save_signatures([sig])
    sig_x1 = sourmash.load_one_signature(s)
    sig_x2 = list(sourmash.load_signatures(s))[0]

    assert sig_x1 == sig
    assert sig_x2 == sig
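
A minimal round-trip sketch of the same API outside a test, assuming a working
sourmash install; the sequence and name below are placeholders:

import sourmash

mh = sourmash.MinHash(n=1, ksize=20)
mh.add_sequence('ATGGCATGGCATGGCATGGC')  # placeholder: any DNA of length >= ksize
sig = sourmash.SourmashSignature(mh, name='example')  # placeholder name

serialized = sourmash.save_signatures([sig])
roundtrip = sourmash.load_one_signature(serialized)
assert roundtrip == sig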
Example #2
def main():
    p = argparse.ArgumentParser()
    p.add_argument('--db', nargs='+', action='append')
    p.add_argument('--query', nargs='+', action='append')
    p.add_argument('--threshold', type=int, default=DEFAULT_THRESHOLD)
    p.add_argument('-o',
                   '--output',
                   type=argparse.FileType('wt'),
                   help='output CSV to this file instead of stdout')
    #p.add_argument('-v', '--verbose', action='store_true')
    p.add_argument('-d', '--debug', action='store_true')
    args = p.parse_args()

    if args.debug:
        global _print_debug
        _print_debug = True

    ksize_vals = set()
    scaled_vals = set()
    dblist = []

    # flatten --db and --query
    args.db = [item for sublist in args.db for item in sublist]
    args.query = [item for sublist in args.query for item in sublist]

    for db_name in args.db:
        print(u'\r\033[K', end=u'', file=sys.stderr)
        print('... loading database {}'.format(db_name),
              end='\r',
              file=sys.stderr)

        lca_db = LCA_Database()
        lca_db.load(db_name)

        ksize_vals.add(lca_db.ksize)
        if len(ksize_vals) > 1:
            raise Exception('multiple ksizes, quitting')
        scaled_vals.add(lca_db.scaled)
        if len(scaled_vals) > 1:
            raise Exception('multiple scaled vals, quitting')

        dblist.append(lca_db)

    print(u'\r\033[K', end=u'', file=sys.stderr)
    print('loaded {} databases for LCA use.'.format(len(dblist)))

    ksize = ksize_vals.pop()
    scaled = scaled_vals.pop()
    print('ksize={} scaled={}'.format(ksize, scaled))

    # for each query, gather all the matches across databases, then classify.
    csvfp = csv.writer(sys.stdout)
    if args.output:
        print("outputting classifications to '{}'".format(args.output.name))
        csvfp = csv.writer(args.output)
    else:
        print("outputting classifications to stdout")
    csvfp.writerow(['ID'] + taxlist)

    total_count = 0
    n = 0
    total_n = len(args.query)
    for query_filename in args.query:
        n += 1
        for query_sig in sourmash_lib.load_signatures(query_filename,
                                                      ksize=ksize):
            print(u'\r\033[K', end=u'', file=sys.stderr)
            print('... classifying {} (file {} of {})'.format(
                query_sig.name(), n, total_n),
                  end='\r',
                  file=sys.stderr)
            debug('classifying', query_sig.name())
            total_count += 1

            # make sure we're looking at the same scaled value as database
            query_sig.minhash = query_sig.minhash.downsample_scaled(scaled)

            lineage = classify_signature(query_sig, dblist, args.threshold)

            # output!
            row = [query_sig.name()]
            for taxrank, (rank, name) in itertools.zip_longest(
                    taxlist, lineage, fillvalue=('', '')):
                if rank:
                    assert taxrank == rank
                row.append(name)

            csvfp.writerow(row)

    print(u'\r\033[K', end=u'', file=sys.stderr)
    print('classified {} signatures total'.format(total_count),
          file=sys.stderr)
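
Because --db and --query combine nargs='+' with action='append', argparse
returns a list of lists (one sublist per flag occurrence), which is why main()
flattens them. A self-contained illustration of that behavior:

import argparse

p = argparse.ArgumentParser()
p.add_argument('--db', nargs='+', action='append')
ns = p.parse_args(['--db', 'a.lca.json', 'b.lca.json', '--db', 'c.lca.json'])
assert ns.db == [['a.lca.json', 'b.lca.json'], ['c.lca.json']]

# the same flatten step used in main() above
flat = [item for sublist in ns.db for item in sublist]
assert flat == ['a.lca.json', 'b.lca.json', 'c.lca.json']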
Example #3
def index(args):
    """
    main function for building an LCA database.
    """
    p = argparse.ArgumentParser()
    p.add_argument('csv', help='taxonomy spreadsheet')
    p.add_argument('lca_db_out', help='name to save database to')
    p.add_argument('signatures', nargs='+',
                   help='one or more sourmash signatures')
    p.add_argument('--scaled', default=10000, type=float)
    p.add_argument('-k', '--ksize', default=31, type=int)
    p.add_argument('-d', '--debug', action='store_true')
    p.add_argument('-C', '--start-column', default=2, type=int,
                   help='column at which taxonomic assignments start')
    p.add_argument('--tabs', action='store_true',
                   help='input spreadsheet is tab-delimited (default: commas)')
    p.add_argument('--no-headers', action='store_true',
                   help='no headers present in taxonomy spreadsheet')
    p.add_argument('--split-identifiers', action='store_true',
                   help='split names in signatures on whitespace and period')
    p.add_argument('-f', '--force', action='store_true')
    p.add_argument('--traverse-directory', action='store_true',
                   help='load all signatures underneath directories.')
    p.add_argument('--report', help='output a report on anomalies, if any.')
    args = p.parse_args(args)

    if args.start_column < 2:
        error('error, --start-column cannot be less than 2')
        sys.exit(-1)

    if args.debug:
        set_debug(args.debug)

    args.scaled = int(args.scaled)

    # first, load taxonomy spreadsheet
    delimiter = ','
    if args.tabs:
        delimiter = '\t'
    assignments, num_rows = load_taxonomy_assignments(
        args.csv,
        delimiter=delimiter,
        start_column=args.start_column,
        use_headers=not args.no_headers,
        force=args.force)

    # convert lineages to numbers.
    next_lineage_index = 0
    lineage_dict = {}

    assignments_idx = {}
    lineage_to_idx = {}
    for (ident, lineage) in assignments.items():
        idx = lineage_to_idx.get(lineage)
        if idx is None:
            idx = next_lineage_index
            next_lineage_index += 1

            lineage_dict[idx] = lineage
            lineage_to_idx[lineage] = idx

        assignments_idx[ident] = idx

    notify('{} distinct lineages in spreadsheet out of {} rows',
           len(lineage_dict), num_rows)

    # load signatures, construct index of hashvals to lineages
    hashval_to_lineage = defaultdict(set)
    md5_to_lineage = {}

    notify('finding signatures...')
    if args.traverse_directory:
        yield_all_files = False           # by default, only load *.sig files
        if args.force:
            yield_all_files = True
        inp_files = list(sourmash_args.traverse_find_sigs(args.signatures,
                                                          yield_all_files=yield_all_files))
    else:
        inp_files = list(args.signatures)

    n = 0
    total_n = len(inp_files)
    record_duplicates = set()
    record_no_lineage = set()
    record_remnants = set(assignments_idx.keys())
    for filename in inp_files:
        n += 1
        for sig in sourmash_lib.load_signatures(filename, ksize=args.ksize):
            notify(u'\r\033[K', end=u'')
            notify('... loading signature {} (file {} of {})', sig.name()[:30], n, total_n, end='\r')
            debug(filename, sig.name())

            if sig.md5sum() in md5_to_lineage:
                notify('\nWARNING: in file {}, duplicate md5sum: {}; skipping', filename, sig.md5sum())
                record_duplicates.add(filename)
                continue

            name = sig.name()
            if args.split_identifiers: # hack for NCBI-style names, etc.
                name = name.split(' ')[0].split('.')[0]

            # is this one for which we have a lineage assigned?
            lineage_idx = assignments_idx.get(name)
            if lineage_idx is None:
                notify('\nWARNING: no lineage assignment for {}.', name)
                record_no_lineage.add(name)
            else:
                # remove from our list of remnant lineages
                record_remnants.remove(name)

                # downsample to specified scaled; this has the side effect of
                # making sure they're all at the same scaled value!
                minhash = sig.minhash.downsample_scaled(args.scaled)

                # connect hashvals to lineage
                for hashval in minhash.get_mins():
                    hashval_to_lineage[hashval].add(lineage_idx)

                # store md5 -> lineage too
                md5_to_lineage[sig.md5sum()] = lineage_idx

    notify(u'\r\033[K', end=u'')
    notify('...found {} genomes with lineage assignments!!',
           len(md5_to_lineage))

    # remove those lineages with no genomes associated
    assigned_lineages = set(md5_to_lineage.values())
    lineage_dict_2 = {}
    for idx in assigned_lineages:
        lineage_dict_2[idx] = lineage_dict[idx]

    unused_lineages = set(lineage_dict.values()) - set(lineage_dict_2.values())

    notify('{} assigned lineages out of {} distinct lineages in spreadsheet',
           len(lineage_dict_2), len(lineage_dict))
    lineage_dict = lineage_dict_2

    # now, save!
    db_outfile = args.lca_db_out
    if not (db_outfile.endswith('.lca.json') or db_outfile.endswith('.lca.json.gz')):
        db_outfile += '.lca.json'
    notify('saving to LCA DB: {}'.format(db_outfile))

    db = lca_utils.LCA_Database()
    db.lineage_dict = lineage_dict
    db.hashval_to_lineage_id = hashval_to_lineage
    db.ksize = int(args.ksize)
    db.scaled = int(args.scaled)
    db.signatures_to_lineage = md5_to_lineage

    db.save(db_outfile)

    if record_duplicates or record_no_lineage or record_remnants or unused_lineages:
        if record_duplicates:
            notify('WARNING: {} duplicate signatures.', len(record_duplicates))
        if record_no_lineage:
            notify('WARNING: no lineage provided for {} signatures.',
                   len(record_no_lineage))
        if record_remnants:
            notify('WARNING: no signatures for {} lineage assignments.',
                   len(record_remnants))
        if unused_lineages:
            notify('WARNING: {} unused lineages.', len(unused_lineages))

        if args.report:
            notify("generating a report and saving in '{}'", args.report)
            generate_report(record_duplicates, record_no_lineage,
                            record_remnants, unused_lineages, args.report)
        else:
            notify('(You can use --report to generate a detailed report.)')
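
The spreadsheet-conversion loop above interns each distinct lineage tuple as a
small consecutive integer id. The idiom in isolation (names and lineages here
are illustrative only):

lineage_to_idx = {}
lineage_dict = {}

def intern_lineage(lineage):
    idx = lineage_to_idx.get(lineage)
    if idx is None:
        idx = len(lineage_to_idx)    # next free id
        lineage_to_idx[lineage] = idx
        lineage_dict[idx] = lineage
    return idx

a = intern_lineage((('superkingdom', 'Bacteria'), ('phylum', 'Firmicutes')))
b = intern_lineage((('superkingdom', 'Bacteria'), ('phylum', 'Firmicutes')))
assert a == b == 0                   # identical lineages share a single id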
Example #4
def classify(args):
    """
    main single-genome classification function.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--db', nargs='+', action='append')
    p.add_argument('--query', nargs='+', action='append')
    p.add_argument('--threshold', type=int, default=DEFAULT_THRESHOLD)
    p.add_argument('-o',
                   '--output',
                   type=argparse.FileType('wt'),
                   help='output CSV to this file instead of stdout')
    p.add_argument('--scaled', type=float)
    p.add_argument('--traverse-directory',
                   action='store_true',
                   help='load all signatures underneath directories.')
    p.add_argument('-d', '--debug', action='store_true')
    args = p.parse_args(args)

    if not args.db:
        error('Error! must specify at least one LCA database with --db')
        sys.exit(-1)

    if not args.query:
        error('Error! must specify at least one query signature with --query')
        sys.exit(-1)

    if args.debug:
        set_debug(args.debug)

    # flatten --db and --query
    args.db = [item for sublist in args.db for item in sublist]
    args.query = [item for sublist in args.query for item in sublist]

    # load all the databases
    dblist, ksize, scaled = lca_utils.load_databases(args.db, args.scaled)
    notify('ksize={} scaled={}', ksize, scaled)

    # find all the queries
    notify('finding query signatures...')
    if args.traverse_directory:
        inp_files = list(sourmash_args.traverse_find_sigs(args.query))
    else:
        inp_files = list(args.query)

    # set up output
    csvfp = csv.writer(sys.stdout)
    if args.output:
        notify("outputting classifications to '{}'", args.output.name)
        csvfp = csv.writer(args.output)
    else:
        notify("outputting classifications to stdout")
    csvfp.writerow(['ID', 'status'] + list(lca_utils.taxlist()))

    # for each query, gather all the matches across databases
    total_count = 0
    n = 0
    total_n = len(inp_files)
    for query_filename in inp_files:
        n += 1
        for query_sig in sourmash_lib.load_signatures(query_filename,
                                                      ksize=ksize):
            notify(u'\r\033[K', end=u'')
            notify('... classifying {} (file {} of {})',
                   query_sig.name(),
                   n,
                   total_n,
                   end='\r')
            debug('classifying', query_sig.name())
            total_count += 1

            # make sure we're looking at the same scaled value as database
            query_sig.minhash = query_sig.minhash.downsample_scaled(scaled)

            # do the classification
            lineage, status = classify_signature(query_sig, dblist,
                                                 args.threshold)
            debug(lineage)

            # output each classification to the spreadsheet
            row = [query_sig.name(), status]
            row += lca_utils.zip_lineage(lineage)

            # when outputting to stdout, make output intelligible
            if not args.output:
                notify(u'\r\033[K', end=u'')
            csvfp.writerow(row)

    notify(u'\r\033[K', end=u'')
    notify('classified {} signatures total', total_count)
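
classify() receives an argv-style list (it calls p.parse_args(args)), so it
can be driven directly from Python; a hypothetical invocation, with
placeholder filenames:

classify(['--db', 'genomes.lca.json',      # placeholder database
          '--query', 'metagenome.sig',     # placeholder query signature
          '--threshold', '5',
          '-o', 'classifications.csv'])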
Example #5
def main():
    p = argparse.ArgumentParser()
    p.add_argument('csv')
    p.add_argument('revindex')
    p.add_argument('siglist', nargs='+')
    p.add_argument('--lca', nargs='+', default=LCA_DBs)
    p.add_argument('-k', '--ksize', default=31, type=int)
    p.add_argument('-o', '--output', type=argparse.FileType('wt'),
                   help='output CSV to this file instead of stdout')
    #p.add_argument('-v', '--verbose', action='store_true')
    p.add_argument('-d', '--debug', action='store_true')
    args = p.parse_args()

    if args.debug:
        global _print_debug
        _print_debug = True

    ## load LCA databases
    lca_db_list = []
    for lca_filename in args.lca:
        print('loading LCA database from {}'.format(lca_filename),
              file=sys.stderr)
        lca_db = lca_json.LCA_Database(lca_filename)
        taxfoo, hashval_to_lca, _ = lca_db.get_database(args.ksize, SCALED)
        lca_db_list.append((taxfoo, hashval_to_lca))
    
    # reverse index names -> taxids (uses taxfoo from the last database loaded)
    names_to_taxids = defaultdict(set)
    for taxid, (name, _, _) in taxfoo.taxid_to_names.items():
        names_to_taxids[name].add(taxid)

    ### parse spreadsheet
    r = csv.reader(open(args.csv, 'rt'))
    row_headers = ['identifier'] + taxlist

    print('examining spreadsheet headers...', file=sys.stderr)
    first_row = next(iter(r))

    n_disagree = 0
    for (column, value) in zip(row_headers, first_row):
        if column.lower() != value.lower():
            print('** assuming {} == {} in spreadsheet'.format(column, value),
                  file=sys.stderr)
            n_disagree += 1
            if n_disagree > 2:
                print('whoa, too many assumptions. are the headers right?',
                      file=sys.stderr)
                sys.exit(-1)

    confusing_lineages = defaultdict(list)
    incompatible_lineages = defaultdict(list)
    assignments = {}
    for row in r:
        lineage = list(zip(row_headers, row))

        ident = lineage[0][1]
        lineage = lineage[1:]

        # clean lineage of null names
        lineage = [(a,b) for (a,b) in lineage if b not in null_names]

        # ok, find the least-common-ancestor taxid...
        taxid, rest = get_lca_taxid_for_lineage(taxfoo, names_to_taxids,
                                                lineage)

        # and find the *lowest* identifiable ancestor taxid, just to see
        # if there are confusing lineages.
        lowest_taxid, lowest_rest = \
          get_lowest_taxid_for_lineage(taxfoo, names_to_taxids, lineage)

        # do they match? if not, report.
        if lowest_taxid != taxid:
            lowest_lineage = taxfoo.get_lineage(lowest_taxid, taxlist)
            lowest_str = ', '.join(lowest_lineage)

            # find last matching, in case different classification levels.
            match_lineage = [ b for (a, b) in lineage ]
            end = match_lineage.index(lowest_lineage[-1])
            assert end >= 0
            match_lineage = match_lineage[:end + 1]
            match_str = ', '.join(match_lineage)

            confusing_lineages[(match_str, lowest_str)].append(ident)

        # check! NCBI lineage should be lineage of taxid + rest
        ncbi_lineage = taxfoo.get_lineage(taxid, taxlist)
        assert len(ncbi_lineage)
        reconstructed = ncbi_lineage + [ b for (a,b) in rest ]

        # ...make a comparable lineage from the CSV line...
        csv_lineage = [ b for (a, b) in lineage ]

        # are NCBI-rooted and CSV lineages the same?? if not, report.
        if csv_lineage != reconstructed:
            csv_str = ", ".join(csv_lineage[:len(ncbi_lineage)])
            ncbi_str = ", ".join(ncbi_lineage)
            incompatible_lineages[(csv_str, ncbi_str)].append(ident)

        # all is well if we've reached this point! We've got NCBI-rooted
        # taxonomies and now we need to record. next:
        #
        # build a set of triples: (rank, name, taxid), where taxid can
        # be None.

        lineage_taxids = taxfoo.get_lineage_as_taxids(taxid)
        tuples_info = []
        for taxid in lineage_taxids:
            name = taxfoo.get_taxid_name(taxid)
            rank = taxfoo.get_taxid_rank(taxid)

            if rank in taxlist:
                tuples_info.append((rank, name))

        for (rank, name) in rest:
            assert rank in taxlist
            tuples_info.append((rank, name))

        assignments[ident] = tuples_info

    print("{} weird lineages that maybe don't match with NCBI.".format(len(confusing_lineages) + len(incompatible_lineages)), file=sys.stderr)

    ## next phase: collapse lineages etc.

    ## load revindex
    print('loading reverse index:', args.revindex, file=sys.stderr)
    custom_bins_ri = revindex_utils.HashvalRevindex(args.revindex)

    # load the signatures associated with each revindex.
    print('loading signatures for custom genomes...', file=sys.stderr)
    sigids_to_sig = {}
    for sigid, (filename, md5) in custom_bins_ri.sigid_to_siginfo.items():
        sig = revindex_utils.get_sourmash_signature(filename, md5)
        if sig.name() in assignments:
            sigids_to_sig[sigid] = sig
        else:
            debug('no assignment:', sig.name())

    # figure out what ksize we're talking about here! (this should
    # probably be stored on the revindex...)
    random_db_sig = next(iter(sigids_to_sig.values()))
    ksize = random_db_sig.minhash.ksize

    print('...found {} custom genomes that also have assignments!!'.format(
        len(sigids_to_sig)), file=sys.stderr)

    ## now, connect the dots: hashvals to custom classifications
    hashval_to_custom = defaultdict(list)
    for hashval, sigids in custom_bins_ri.hashval_to_sigids.items():
        for sigid in sigids:
            sig = sigids_to_sig.get(sigid, None)
            if sig:
                assignment = assignments[sig.name()]
                hashval_to_custom[hashval].append(assignment)

    # whew! done!! we can now go from a hashval to a custom assignment!!

    # for each query, gather all the matches in both custom and NCBI, then
    # classify.
    csvfp = csv.writer(sys.stdout)
    if args.output:
        print("outputting classifications to '{}'".format(args.output.name))
        csvfp = csv.writer(args.output)
    else:
        print("outputting classifications to stdout")
    csvfp.writerow(['ID'] + taxlist)

    total_count = 0
    for query_filename in args.siglist:
        for query_sig in sourmash_lib.load_signatures(query_filename,
                                                      ksize=ksize):
            print(u'\r\033[K', end=u'', file=sys.stderr)
            print('... classifying {}'.format(query_sig.name()), end='\r',
                  file=sys.stderr)
            debug('classifying', query_sig.name())
            total_count += 1

            these_assignments = defaultdict(list)
            n_custom = 0
            for hashval in query_sig.minhash.get_mins():
                # custom
                assignment = hashval_to_custom.get(hashval, [])
                if assignment:
                    these_assignments[hashval].extend(assignment)
                    n_custom += 1

                # NCBI
                for (this_taxfoo, hashval_to_lca) in lca_db_list:
                    hashval_lca = hashval_to_lca.get(hashval)
                    if hashval_lca is not None and hashval_lca != 1:
                        lineage = this_taxfoo.get_lineage_as_dict(hashval_lca,
                                                                  taxlist)

                        tuple_info = []
                        for rank in taxlist:
                            if rank not in lineage:
                                break
                            tuple_info.append((rank, lineage[rank]))
                        these_assignments[hashval_lca].append(tuple_info)

            check_counts = Counter()
            for tuple_info in these_assignments.values():
                last_tup = tuple(tuple_info[-1])
                check_counts[last_tup] += 1

            debug('n custom hashvals:', n_custom)
            debug(pprint.pformat(check_counts.most_common()))

            # now convert to trees -> do LCA & counts
            counts = Counter()
            parents = {}
            for hashval in these_assignments:

                # for each list of tuple_info [(rank, name), ...] build
                # a tree that lets us discover least-common-ancestor.
                tuple_info = these_assignments[hashval]
                tree = build_tree(tuple_info)

                # also update a tree that we can ascend from leaves -> parents
                # for all assignments for all hashvals
                parents = build_reverse_tree(tuple_info, parents)

                # now find either a leaf or the first node with multiple
                # children; that's our least-common-ancestor node.
                lca, reason = find_lca(tree)
                counts[lca] += 1

            # ok, we now have the LCAs for each hashval, and their number
            # of counts. Now sum across "significant" LCAs - those above
            # threshold.

            tree = {}
            tree_counts = defaultdict(int)

            debug(pprint.pformat(counts.most_common()))

            n = 0
            for lca, count in counts.most_common():
                if count < THRESHOLD:
                    break

                n += 1

                xx = []
                parent = lca
                while parent:
                    xx.insert(0, parent)
                    tree_counts[parent] += count
                    parent = parents.get(parent)
                debug(n, count, xx[1:])

                # update tree with this set of assignments
                build_tree([xx], tree)

            if n > 1:
                debug('XXX', n)

            # now find LCA? or whatever.
            lca, reason = find_lca(tree)
            if reason == 0:               # leaf node
                debug('END', lca)
            else:                         # internal node
                debug('MULTI', lca)

            # backtrack to full lineage via parents
            lineage = []
            parent = lca
            while parent != ('root', 'root'):
                lineage.insert(0, parent)
                parent = parents.get(parent)

            # output!
            row = [query_sig.name()]
            for taxrank, (rank, name) in itertools.zip_longest(taxlist, lineage, fillvalue=('', '')):
                if rank:
                    assert taxrank == rank
                row.append(name)

            csvfp.writerow(row)

    print(u'\r\033[K', end=u'', file=sys.stderr)
    print('classified {} signatures total'.format(total_count), file=sys.stderr)
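
The final backtracking step above recovers a full lineage by walking a
child-to-parent map up to the root. The same idiom on a toy two-level
taxonomy:

parents = {
    ('genus', 'Escherichia'): ('family', 'Enterobacteriaceae'),
    ('family', 'Enterobacteriaceae'): ('root', 'root'),
}

lineage = []
node = ('genus', 'Escherichia')
while node != ('root', 'root'):
    lineage.insert(0, node)
    node = parents[node]

assert lineage == [('family', 'Enterobacteriaceae'),
                   ('genus', 'Escherichia')]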
Example #6
def main():
    p = argparse.ArgumentParser()
    p.add_argument('sigs', nargs='+')
    p.add_argument('--traverse-directory',
                   action='store_true',
                   help='load all signatures underneath directories.')
    p.add_argument('-k', '--ksize', default=31, type=int)
    p.add_argument('-d', '--debug', action='store_true')
    p.add_argument('-f', '--force', action='store_true')
    p.add_argument('--scaled', type=float, default=10000)
    p.add_argument('--plot', default=None)
    p.add_argument('-o',
                   '--output',
                   type=argparse.FileType('wt'),
                   help='CSV output')
    p.add_argument('--step', type=int, default=1000)
    p.add_argument('--repeat', type=int, default=5)
    p.add_argument('--db', nargs='+', action='append')
    args = p.parse_args()

    if args.debug:
        set_debug(args.debug)

    args.scaled = int(args.scaled)

    dblist = []
    known_hashes = set()
    if args.db:
        args.db = [item for sublist in args.db for item in sublist]
        dblist, ksize, scaled = lca_utils.load_databases(args.db, args.scaled)
        assert ksize == args.ksize
        notify('loaded {} LCA databases', len(dblist))

        for db in dblist:
            known_hashes.update(db.hashval_to_lineage_id.keys())
        notify('got {} known hashes!', len(known_hashes))

    notify('finding signatures...')
    if args.traverse_directory:
        yield_all_files = False  # by default, only load *.sig files
        if args.force:
            yield_all_files = True
        inp_files = list(
            sourmash_args.traverse_find_sigs(args.sigs,
                                             yield_all_files=yield_all_files))
    else:
        inp_files = list(args.sigs)

    n = 0
    total_n = len(inp_files)
    sigs = []
    total_hashvals = list()
    for filename in inp_files:
        n += 1
        for sig in sourmash_lib.load_signatures(filename, ksize=args.ksize):
            notify(u'\r\033[K', end=u'')
            notify('... loading signature {} (file {} of {})',
                   sig.name()[:30],
                   n,
                   total_n,
                   end='\r')
            debug(filename, sig.name())

            sig.minhash = sig.minhash.downsample_scaled(args.scaled)

            total_hashvals.extend(sig.minhash.get_mins())
            sigs.append(sig)

    notify(u'\r\033[K', end=u'')
    notify('...found {} signatures total in {} files.', len(sigs), total_n)

    distinct_hashvals = set(total_hashvals)
    notify('{} distinct out of {} total hashvals.', len(distinct_hashvals),
           len(total_hashvals))
    if known_hashes:
        n_known = len(known_hashes.intersection(distinct_hashvals))
        notify('{} of them known, or {:.1f}%', n_known,
               n_known / float(len(distinct_hashvals)) * 100)

    x = []
    y = []
    z = []
    notify('subsampling...')
    for n in range(0, len(total_hashvals), args.step):
        notify(u'\r\033[K', end=u'')
        notify('... {} of {}', n, len(total_hashvals), end='\r')
        avg = 0
        known = 0
        for j in range(0, args.repeat):
            subsample = random.sample(total_hashvals, n)
            distinct = len(set(subsample))
            if known_hashes:
                known += len(set(subsample).intersection(known_hashes))
            avg += distinct

        x.append(n)
        y.append(avg / args.repeat)
        z.append(known / args.repeat)

    notify('\n...done!')

    if args.output:
        w = csv.writer(args.output)
        w.writerow(['n', 'k', 'known'])
        for a, b, c in zip(x, y, z):
            w.writerow([a, b, c])

    if args.plot:
        from matplotlib import pyplot
        pyplot.plot(x, y)
        pyplot.savefig(args.plot)
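
The subsampling loop is a rarefaction measurement: draw n hash values at
random, count how many are distinct, and average over several repeats. The
same idea as a small standalone helper (a sketch, not part of the script):

import random

def rarefy(hashvals, n, repeat=5):
    # average number of distinct values over `repeat` samples of size n
    total = 0
    for _ in range(repeat):
        total += len(set(random.sample(hashvals, n)))
    return total / repeat

print(rarefy(list(range(1000)) * 5, 500))  # toy data with built-in duplicates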
Example #7
def summarize_main(args):
    """
    main summarization function.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--db', nargs='+', action='append')
    p.add_argument('--query', nargs='+', action='append')
    p.add_argument('--threshold', type=int, default=DEFAULT_THRESHOLD)
    p.add_argument('--traverse-directory',
                   action='store_true',
                   help='load all signatures underneath directories.')
    p.add_argument('-o',
                   '--output',
                   type=argparse.FileType('wt'),
                   help='CSV output')
    p.add_argument('--scaled', type=float)
    p.add_argument('-d', '--debug', action='store_true')
    args = p.parse_args(args)

    if not args.db:
        error('Error! must specify at least one LCA database with --db')
        sys.exit(-1)

    if not args.query:
        error('Error! must specify at least one query signature with --query')
        sys.exit(-1)

    if args.debug:
        set_debug(args.debug)

    if args.scaled:
        args.scaled = int(args.scaled)

    # flatten --db and --query
    args.db = [item for sublist in args.db for item in sublist]
    args.query = [item for sublist in args.query for item in sublist]

    # load all the databases
    dblist, ksize, scaled = lca_utils.load_databases(args.db, args.scaled)
    notify('ksize={} scaled={}', ksize, scaled)

    # find all the queries
    notify('finding query signatures...')
    if args.traverse_directory:
        inp_files = list(sourmash_args.traverse_find_sigs(args.query))
    else:
        inp_files = list(args.query)

    # for each query, gather all the hashvals across databases
    total_count = 0
    n = 0
    total_n = len(inp_files)
    hashvals = defaultdict(int)
    for query_filename in inp_files:
        n += 1
        for query_sig in sourmash_lib.load_signatures(query_filename,
                                                      ksize=ksize):
            notify(u'\r\033[K', end=u'')
            notify('... loading {} (file {} of {})',
                   query_sig.name(),
                   n,
                   total_n,
                   end='\r')
            total_count += 1

            mh = query_sig.minhash.downsample_scaled(scaled)
            for hashval in mh.get_mins():
                hashvals[hashval] += 1

    notify(u'\r\033[K', end=u'')
    notify('loaded {} signatures from {} files total.', total_count, n)

    # get the full counted list of lineage counts in this signature
    lineage_counts = summarize(hashvals, dblist, args.threshold)

    # output!
    total = float(len(hashvals))
    for (lineage, count) in lineage_counts.items():
        if lineage:
            lineage = lca_utils.zip_lineage(lineage, truncate_empty=True)
            lineage = ';'.join(lineage)
        else:
            lineage = '(root)'

        p = count / total * 100.
        p = '{:.1f}%'.format(p)

        print_results('{:5} {:>5}   {}'.format(p, count, lineage))

    # CSV:
    if args.output:
        w = csv.writer(args.output)
        headers = ['count'] + list(lca_utils.taxlist())
        w.writerow(headers)

        for (lineage, count) in lineage_counts.items():
            debug('lineage:', lineage)
            row = [count] + lca_utils.zip_lineage(lineage)
            w.writerow(row)
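
Like classify() in Example #4, summarize_main() takes an argv-style list, so a
hypothetical direct invocation (placeholder filenames) looks like:

summarize_main(['--db', 'genomes.lca.json',   # placeholder database
                '--query', 'metagenome.sig',  # placeholder query signature
                '-o', 'summary.csv'])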
Example #8
def main():
    p = argparse.ArgumentParser()
    p.add_argument('lca_filename')
    p.add_argument('sigfiles', nargs='+')
    p.add_argument('-k', '--ksize', default=31, type=int)
    p.add_argument('--output-unassigned',
                   type=argparse.FileType('wt'),
                   help='output unassigned portions of the query as a '
                        'signature to this file')
    args = p.parse_args()

    # load lca info
    lca_db = lca_json.LCA_Database(args.lca_filename)
    taxfoo, hashval_to_lca, scaled = lca_db.get_database(args.ksize, SCALED)

    # load signatures
    siglist = []
    print('loading signatures from {} signature files'.format(
        len(args.sigfiles)))
    for sigfile in args.sigfiles:
        sigs = sourmash_lib.load_signatures(sigfile, ksize=args.ksize)
        sigs = list(sigs)
        siglist.extend(sigs)

    print('loaded {} signatures total at k={}'.format(len(siglist),
                                                      args.ksize))

    # downsample
    print('downsampling to scaled value: {}'.format(scaled))
    for sig in siglist:
        if sig.minhash.scaled < scaled:
            sig.minhash = sig.minhash.downsample_scaled(scaled)

    # now, extract hash values!
    hashvals = collections.defaultdict(int)
    for sig in siglist:
        for hashval in sig.minhash.get_mins():
            hashvals[hashval] += 1

    found = 0
    total = 0
    by_taxid = collections.defaultdict(int)

    unassigned_hashvals = set()

    # for every hash, get LCA of labels
    for hashval, count in hashvals.items():
        lca = hashval_to_lca.get(hashval)
        total += count

        if lca is None:
            by_taxid[0] += count
            unassigned_hashvals.add(hashval)
            continue

        by_taxid[lca] += count
        found += count

    print('found LCA classifications for', found, 'of', total, 'hashes')
    not_found = total - found

    # now, propagate counts up the taxonomic tree.
    by_taxid_lca = collections.defaultdict(int)
    for taxid, count in by_taxid.items():
        by_taxid_lca[taxid] += count

        parent = taxfoo.child_to_parent.get(taxid)
        while parent is not None and parent != 1:
            by_taxid_lca[parent] += count
            parent = taxfoo.child_to_parent.get(parent)

    total_count = sum(by_taxid.values())

    # sort by lineage length
    x = []
    for taxid, count in by_taxid_lca.items():
        x.append((len(taxfoo.get_lineage(taxid)), taxid, count))

    x.sort()

    # ...aaaaaand output.
    print('{}\t{}\t{}\t{}\t{}\t{}'.format('percent', 'below', 'at node',
                                          'code', 'taxid', 'name'))
    for _, taxid, count_below in x:
        if taxid == 0:
            continue

        percent = round(100 * count_below / total_count, 2)
        count_at = by_taxid[taxid]

        rank = taxfoo.node_to_info.get(taxid)
        if rank:
            rank = rank[0]
            classify_code = kraken_rank_code.get(rank, '-')
        else:
            classify_code = '-'

        name = taxfoo.taxid_to_names.get(taxid)
        if name:
            name = name[0]
        else:
            name = '-'

        print('{}\t{}\t{}\t{}\t{}\t{}'.format(percent, count_below, count_at,
                                              classify_code, taxid, name))

    if not_found:
        classify_code = 'U'
        percent = round(100 * not_found / total_count, 2)
        count_below = not_found
        count_at = not_found
        taxid = 0
        name = 'not classified'

        print('{}\t{}\t{}\t{}\t{}\t{}'.format(percent, count_below, count_at,
                                              classify_code, taxid, name))

        if args.output_unassigned:
            outname = args.output_unassigned.name
            print('saving unassigned hashes to "{}"'.format(outname))

            e = sourmash_lib.MinHash(ksize=args.ksize, n=0, scaled=scaled)
            e.add_many(unassigned_hashvals)
            sourmash_lib.save_signatures(
                [sourmash_lib.SourmashSignature('', e)],
                args.output_unassigned)
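
The report above hinges on propagating each node's count to all of its
ancestors through a child-to-parent map, stopping at the root (taxid 1). The
propagation step in miniature:

from collections import defaultdict

child_to_parent = {3: 2, 2: 1}   # toy taxonomy: 3 -> 2 -> 1 (root)
by_taxid = {3: 10, 2: 5}         # direct counts at each node

by_taxid_lca = defaultdict(int)
for taxid, count in by_taxid.items():
    by_taxid_lca[taxid] += count
    parent = child_to_parent.get(taxid)
    while parent is not None and parent != 1:
        by_taxid_lca[parent] += count
        parent = child_to_parent.get(parent)

assert by_taxid_lca[2] == 15     # 5 at node 2 plus 10 propagated from node 3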
Example #9
def main():
    p = argparse.ArgumentParser()
    p.add_argument('csv')
    p.add_argument('lca_db_out')
    p.add_argument('genome_sigs', nargs='+')
    p.add_argument('--scaled', default=10000, type=float)
    p.add_argument('-k', '--ksize', default=31, type=int)
    p.add_argument('-d', '--debug', action='store_true')
    p.add_argument('-1',
                   '--start-column',
                   default=2,
                   type=int,
                   help='column at which taxonomic assignments start')
    p.add_argument('-f', '--force', action='store_true')
    args = p.parse_args()

    if args.start_column < 2:
        print('error, --start-column cannot be less than 2', file=sys.stderr)
        sys.exit(-1)

    if args.debug:
        global _print_debug
        _print_debug = True

    scaled = int(args.scaled)
    ksize = int(args.ksize)

    # parse spreadsheet!
    r = csv.reader(open(args.csv, 'rt'))
    row_headers = ['identifiers']
    row_headers += ['_skip_'] * (args.start_column - 2)
    row_headers += taxlist

    # first check that headers are interpretable.
    print('examining spreadsheet headers...', file=sys.stderr)
    first_row = next(iter(r))

    n_disagree = 0
    for (column, value) in zip(row_headers, first_row):
        if column == '_skip_':
            continue

        if column.lower() != value.lower():
            print("** assuming column '{}' is {} in spreadsheet".format(
                value, column),
                  file=sys.stderr)
            n_disagree += 1
            if n_disagree > 2:
                print('whoa, too many assumptions. are the headers right?',
                      file=sys.stderr)
                if not args.force:
                    sys.exit(-1)
                print('...continuing, because --force was specified.',
                      file=sys.stderr)

    # convert
    assignments = {}
    for row in r:
        lineage = list(zip(row_headers, row))
        lineage = [x for x in lineage if x[0] != '_skip_']

        ident = lineage[0][1]
        lineage = lineage[1:]

        # clean lineage of null names
        lineage = [(a, b) for (a, b) in lineage if b not in null_names]

        # store lineage tuple
        assignments[ident] = tuple(lineage)

    # clean up with some indirection: convert lineages to numbers.
    next_lineage_index = 0
    lineage_dict = {}

    assignments_idx = {}
    lineage_to_idx = {}
    for (ident, lineage_tuple) in assignments.items():
        idx = lineage_to_idx.get(lineage_tuple)
        if idx is None:
            idx = next_lineage_index
            next_lineage_index += 1

            lineage_dict[idx] = lineage_tuple
            lineage_to_idx[lineage_tuple] = idx

        assignments_idx[ident] = idx

    # load signatures, construct index of hashvals to lineages
    hashval_to_lineage = defaultdict(list)
    md5_to_lineage = {}

    n = 0
    total_n = len(args.genome_sigs)
    for filename in args.genome_sigs:
        n += 1
        for sig in sourmash_lib.load_signatures(filename, ksize=args.ksize):
            print(u'\r\033[K', end=u'', file=sys.stderr)
            print('... loading signature {} (file {} of {})'.format(
                sig.name(), n, total_n),
                  end='\r',
                  file=sys.stderr)

            # is this one for which we have a lineage assigned?
            lineage_idx = assignments_idx.get(sig.name())
            if lineage_idx is not None:
                # downsample to specified scaled; this has the side effect of
                # making sure they're all at the same scaled value!
                sig.minhash = sig.minhash.downsample_scaled(args.scaled)

                # connect hashvals to lineage
                for hashval in sig.minhash.get_mins():
                    hashval_to_lineage[hashval].append(lineage_idx)

                # store md5 -> lineage too
                md5_to_lineage[sig.md5sum()] = lineage_idx

    print(u'\r\033[K', end=u'', file=sys.stderr)
    print('...found {} genomes with lineage assignments!!'.format(
        len(md5_to_lineage)),
          file=sys.stderr)

    # remove those lineages with no genomes associated
    assigned_lineages = set(md5_to_lineage.values())
    lineage_dict_2 = {}
    for idx in assigned_lineages:
        lineage_dict_2[idx] = lineage_dict[idx]

    print('{} assigned lineages out of {} distinct lineages in spreadsheet'.
          format(len(lineage_dict_2), len(lineage_dict)))
    lineage_dict = lineage_dict_2

    # now, save!
    print('saving to LCA DB v2: {}'.format(args.lca_db_out))
    with open(args.lca_db_out, 'wt') as fp:
        save_d = OrderedDict()
        save_d['version'] = '1.0'
        save_d['type'] = 'sourmash_lca'
        save_d['license'] = 'CC0'
        save_d['ksize'] = ksize
        save_d['scaled'] = scaled
        # convert lineage internals from tuples to dictionaries
        save_d['lineages'] = OrderedDict([(k, OrderedDict(v))
                                          for k, v in lineage_dict.items()])
        save_d['hashval_assignments'] = hashval_to_lineage
        save_d['signatures_to_lineage'] = md5_to_lineage
        json.dump(save_d, fp)
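
Since the database is written as plain JSON, a quick sanity check can reload
it and inspect the header fields saved above (the filename is a placeholder):

import json

with open('db.lca.json', 'rt') as fp:   # placeholder filename
    d = json.load(fp)

assert d['type'] == 'sourmash_lca' and d['version'] == '1.0'
print('ksize:', d['ksize'], 'scaled:', d['scaled'])
print(len(d['lineages']), 'lineages;', len(d['hashval_assignments']), 'hashvals')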