Example #1
import argparse
import csv
import sys

import lca_json   # project-local module

# want_taxonomy and summarize_lca_db() are assumed to be defined at module
# level; a sketch of summarize_lca_db() follows this example.

def main():
    p = argparse.ArgumentParser()
    p.add_argument('lca_filename')
    p.add_argument('-k', '--ksize-list', default="31", type=str)
    p.add_argument('-o', '--output', type=argparse.FileType('wt'))
    args = p.parse_args()

    lca_db = lca_json.LCA_Database(args.lca_filename)

    ksizes = list(map(int, args.ksize_list.split(',')))

    ksize_to_rank_counts = dict()

    for ksize in ksizes:
        assert ksize not in ksize_to_rank_counts, "duplicate ksize in -k list"
        taxfoo, hashval_to_lca, scaled = lca_db.get_database(ksize, None)

        rank_counts = summarize_lca_db(taxfoo, hashval_to_lca)
        ksize_to_rank_counts[ksize] = rank_counts

    # this should be enforced by summarize_lca_db(...)
    all_ranks = set()
    for rank_counts in ksize_to_rank_counts.values():
        all_ranks.update(rank_counts.keys())

    assert all_ranks - set(want_taxonomy) == set()

    if args.output:
        w = csv.writer(args.output)
    else:
        w = csv.writer(sys.stdout)

    w.writerow(['rank'] + ksizes)
    for rank in want_taxonomy:
        count_list = [rank]
        for ksize in ksizes:
            rank_counts = ksize_to_rank_counts[ksize]
            count = rank_counts.get(rank, 0)
            count_list.append(str(count))

        w.writerow(count_list)
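
Example #1 calls summarize_lca_db(), which isn't shown here. Judging from the rank-tallying loop in example #5, a minimal sketch could look like this (hypothetical; the real helper may also restrict ranks to want_taxonomy, per the assert above):

import collections

def summarize_lca_db(taxfoo, hashval_to_lca):
    """Tally how many hash values resolve to an LCA at each taxonomic rank."""
    rank_counts = collections.defaultdict(int)
    for lca_taxid in hashval_to_lca.values():
        rank = taxfoo.get_taxid_rank(lca_taxid)
        rank_counts[rank] += 1
    return rank_counts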
Example #2

import argparse
import csv

import lca_json   # project-local module

# want_taxonomy and get_taxids_for_name() are assumed to be defined at
# module level; a sketch of get_taxids_for_name() follows this example.

def main():
    p = argparse.ArgumentParser()
    p.add_argument('lca_db')
    p.add_argument('csv')
    p.add_argument('outcsv')
    args = p.parse_args()

    lca_db = lca_json.LCA_Database(args.lca_db)
    taxfoo = lca_db.get_taxonomy()

    r = csv.reader(open(args.csv, 'rt', encoding='utf8'))
    next(r)

    results = []

    n = 0
    for row in r:
        n += 1
        last_rank = None
        last_name = None

        genome_name, row = row[0], row[1:]

        for rank, name in zip(want_taxonomy, row):
            if name == 'null' or name.startswith('novel'):
                break
            last_rank, last_name = rank, name

        if last_rank is None:
            # every rank was null/novel; fall back to the root taxid.
            taxid = 1
        else:
            taxid = get_taxids_for_name(taxfoo, **{last_rank: last_name})
            if taxid == -1:
                taxid = 1

        lineage = taxfoo.get_lineage(taxid, want_taxonomy)
        results.append((genome_name, str(taxid), ";".join(lineage)))
        print(genome_name, taxid, ";".join(lineage))

    with open(args.outcsv, 'wt', encoding='utf8') as outfp:
        w = csv.writer(outfp)
        for name, taxid, lineage in results:
            w.writerow([name, taxid, lineage])
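
get_taxids_for_name() isn't shown either. Example #8 reveals that taxfoo.taxid_to_names maps taxid -> (name, ...), so a hedged sketch is a linear scan; returning -1 on a miss matches the `taxid == -1` check above. The real helper may also use the rank keyword to disambiguate homonyms.

def get_taxids_for_name(taxfoo, **rank_to_name):
    (rank, name), = rank_to_name.items()     # exactly one rank=name pair
    for taxid, name_info in taxfoo.taxid_to_names.items():
        if name_info[0] == name:             # rank is unused in this sketch
            return taxid
    return -1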
Example #3
import argparse
import collections

import lca_json   # project-local module

# LCA_DBs, SCALED, THRESHOLD, want_taxonomy, and load_all_signatures()
# are assumed to be defined at module level.

def main():
    p = argparse.ArgumentParser()
    p.add_argument('-k', '--ksize', default=31, type=int)
    p.add_argument('dir')
    args = p.parse_args()

    # load all the LCA JSON files
    lca_db_list = []
    for lca_filename in LCA_DBs:
        lca_db = lca_json.LCA_Database(lca_filename)
        taxfoo, hashval_to_lca, _ = lca_db.get_database(args.ksize, SCALED)
        lca_db_list.append((taxfoo, hashval_to_lca))
    
    print('loading all signatures:', args.dir)
    sigdict = load_all_signatures(args.dir, args.ksize)
    print('...loaded {} signatures at k={}'.format(len(sigdict), args.ksize))

    ###

    disagree_at = collections.defaultdict(int)

    n = 0
    for name, sig in sigdict.items():
        taxid_set = collections.defaultdict(int)
        for hashval in sig.minhash.get_mins():

            this_hashval_taxids = set()
            for (_, hashval_to_lca) in lca_db_list:
                hashval_lca = hashval_to_lca.get(hashval)
                if hashval_lca is not None:
                    this_hashval_taxids.add(hashval_lca)

            # note: taxfoo here is from the last-loaded DB; this assumes
            # all databases share the same taxonomy.
            this_hashval_lca = taxfoo.find_lca(this_hashval_taxids)
            taxid_set[this_hashval_lca] += 1

        abundant_taxids = [k for k in taxid_set if taxid_set[k] >= THRESHOLD]

        if not abundant_taxids:
            continue

        n += 1

        ranks_found = collections.defaultdict(set)
        for taxid in abundant_taxids:
            d = taxfoo.get_lineage_as_dict(taxid)
            for k, v in d.items():
                ranks_found[k].add(v)

        found_disagree = False
        for rank in reversed(want_taxonomy):
            if len(ranks_found[rank]) > 1:
                disagree_at[rank] += 1
                found_disagree = True
                break

        if found_disagree:
            print('{} has multiple LCA at rank \'{}\': {}'.format(name,
                                                                  rank,
                                                                  ", ".join(ranks_found[rank])))

    print('for', args.dir, 'found', len(sigdict), 'signatures;')
    print('out of {} that could be classified, {} disagree at some rank.'.format(n, sum(disagree_at.values())))

    for rank in want_taxonomy:
        if disagree_at.get(rank):
            print('\t{}: {}'.format(rank, disagree_at.get(rank, 0)))
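
taxfoo.find_lca() does the core work in several of these examples. Here is a minimal sketch, written as a free function over the child_to_parent mapping that example #7 uses (the real method may differ; returning 1, the root, for an empty set matches what example #6 expects):

def find_lca(child_to_parent, taxid_set):
    """Least common ancestor of a set of taxids in an NCBI-style tree."""
    if not taxid_set:
        return 1
    paths = []
    for taxid in taxid_set:
        path = []
        while taxid is not None:
            path.append(taxid)               # leaf -> ... -> root order
            parent = child_to_parent.get(taxid)
            if parent == taxid:              # NCBI root points to itself
                break
            taxid = parent
        paths.append(path)
    shared = set(paths[0]).intersection(*paths[1:])
    for taxid in paths[0]:                   # deepest node shared by all
        if taxid in shared:                  # root-ward paths is the LCA
            return taxid
    return 1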
Example #4
import argparse
import collections
import csv
import os

import sourmash_lib.signature

import lca_json   # project-local module

# LCA_DBs, SCALED, THRESHOLD, and want_taxonomy are assumed to be defined
# at module level.

def main():
    p = argparse.ArgumentParser()
    p.add_argument('-k', '--ksize', default=31, type=int)
    p.add_argument('--lca', nargs='+', default=LCA_DBs)
    p.add_argument('sig')
    p.add_argument('-o', '--output-csv')
    p.add_argument('--threshold', type=int, default=THRESHOLD,
                   help="minimum number of times a taxid must be present to count")
    p.add_argument('-X', '--output-unassigned', action='store_true')

    args = p.parse_args()

    if args.output_csv:
        output_filename = args.output_csv
    else:
        output_filename = os.path.basename(args.sig) + '.taxonomy.csv'

    # load the LCA databases from the JSON file(s)
    lca_db_list = []
    for lca_filename in args.lca:
        print('loading LCA database from {}'.format(lca_filename))
        lca_db = lca_json.LCA_Database(lca_filename)
        taxfoo, hashval_to_lca, _ = lca_db.get_database(args.ksize, SCALED)
        lca_db_list.append((taxfoo, hashval_to_lca))

    # load signature
    sig = sourmash_lib.signature.load_one_signature(args.sig,
                                                    ksize=args.ksize)
    hashes = sig.minhash.get_mins()

    # open output file.
    outfp = open(output_filename, 'wt')
    outw = csv.writer(outfp)
    outw.writerow(['hash', 'taxid', 'status', 'rank_info', 'lineage'])

    ###

    # track number of classifications at various rank levels
    classified_at = collections.defaultdict(int)

    # also track unassigned
    unassigned = set()

    # for each hash in the minhash signature, get its LCA.
    n_in_lca = 0
    taxid_counts = collections.defaultdict(int)
    lca_to_hashvals = collections.defaultdict(set)
    for hashval in hashes:
        # if a k-mer is present in multiple DBs, pull the
        # least-common-ancestor taxonomic node across all of the
        # DBs.

        this_hashval_taxids = set()
        for (_, hashval_to_lca) in lca_db_list:
            hashval_lca = hashval_to_lca.get(hashval)
            if hashval_lca is not None and hashval_lca != 1:
                this_hashval_taxids.add(hashval_lca)

        if this_hashval_taxids:
            # note: taxfoo here is from the last-loaded DB; this assumes
            # all databases share the same taxonomy.
            this_hashval_lca = taxfoo.find_lca(this_hashval_taxids)
            if this_hashval_lca is not None:
                taxid_counts[this_hashval_lca] += 1
                lca_to_hashvals[this_hashval_lca].add(hashval)
        else:
            unassigned.add(hashval)

    # filter on given threshold - only taxids that show up in this
    # signature more than THRESHOLD.
    abundant_taxids = set([k for (k, cnt) in taxid_counts.items() \
                           if cnt >= args.threshold])

    # remove root (taxid == 1) if it's in there:
    if 1 in abundant_taxids:
        abundant_taxids.remove(1)

    # now, output hashval classifications.
    n_classified = 0
    for lca_taxid in abundant_taxids:
        for hashval in lca_to_hashvals[lca_taxid]:
            status = 'match'
            status_rank = taxfoo.get_taxid_rank(lca_taxid)
            lineage = taxfoo.get_lineage(lca_taxid,
                                         want_taxonomy=want_taxonomy)
            lineage = ";".join(lineage)

            classified_at[status_rank] += 1
            n_classified += 1

            outw.writerow([str(hashval), str(lca_taxid),
                           status, status_rank, lineage])

    # output unassigned?
    if args.output_unassigned:
        for hashval in unassigned:
            status = 'nomatch'
            status_rank = 'unknown'
            lineage = 'UNKNOWN'

            # unassigned hash values get taxid 0, not a leftover lca_taxid.
            outw.writerow([str(hashval), '0',
                           status, status_rank, lineage])

    print('')
    print('classified sourmash signature \'{}\''.format(args.sig))
    print('LCA databases used: {}'.format(', '.join(args.lca)))
    print('')

    print('total hash values: {}'.format(len(hashes)))
    print('num classified: {}'.format(n_classified))

    n_rare_taxids = sum([cnt for (k, cnt) in taxid_counts.items() \
                         if cnt < args.threshold ])
    print('n rare taxids not used: {}'.format(n_rare_taxids))
    print('unclassified: {}'.format(len(unassigned)))

    print('')
    print('number classified unambiguously, by lowest classification rank:')
    for rank in want_taxonomy:
        if classified_at.get(rank):
            print('\t{}: {}'.format(rank, classified_at.get(rank, 0)))

    print('')
    print('classification output as CSV, here: {}'.format(output_filename))
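
A quick, hypothetical read-back of the per-hash CSV this script writes; the column names come from the header row above, and 'example.sig' stands in for whatever signature file was classified:

import csv

with open('example.sig.taxonomy.csv', 'rt') as fp:
    for row in csv.DictReader(fp):
        print(row['hash'], row['taxid'], row['status'], row['lineage'])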
Example #5
import argparse
import collections

import lca_json         # project-local modules
import revindex_utils
from revindex_utils import HashvalRevindex

# SCALED and want_taxonomy are assumed to be defined at module level.

def main():
    p = argparse.ArgumentParser()
    p.add_argument('-k', '--ksize', default=31, type=int)
    p.add_argument('lca')
    p.add_argument('revindex')
    p.add_argument('accessions_csv')
    args = p.parse_args()

    # load the LCA databases from the JSON file(s)
    print('loading LCA database from {}'.format(args.lca))
    lca_db = lca_json.LCA_Database(args.lca)
    taxfoo, hashval_to_lca, _ = lca_db.get_database(args.ksize, SCALED)

    print('loading revindex:', args.revindex)
    revidx = HashvalRevindex(args.revindex)
    print('...loaded.')

    # load classification CSV
    print('loading classifications:', args.accessions_csv)
    taxfoo.load_accessions_csv(args.accessions_csv)
    print('...loaded.')

    ###

    # track number of classifications at various rank levels
    classified_at = collections.defaultdict(int)
    classified_samples = collections.defaultdict(int)

    hashval_troubles = collections.defaultdict(set)
    for hashval, lca in hashval_to_lca.items():
        rank = taxfoo.get_taxid_rank(lca)
        classified_at[rank] += 1

        if rank == 'superkingdom' and 0:   # 'and 0' disables this debug branch
            n_sigids = len(revidx.hashval_to_sigids[hashval])
            classified_samples[n_sigids] += 1
            if n_sigids >= 4:
                for sigid in revidx.hashval_to_sigids[hashval]:
                    siginfo = revidx.sigid_to_siginfo[sigid]
                    hashval_troubles[hashval].add(siginfo)

    for hashval, siginfo_set in hashval_troubles.items():
        break   # debugging output below is disabled; remove this to enable it
        print('getting {} sigs for {}'.format(len(siginfo_set), hashval))
        siglist = []
        for (filename, md5) in siginfo_set:
            sig = revindex_utils.get_sourmash_signature(filename, md5)
            siglist.append(sig)

        for sig in siglist:
            acc = sig.name().split()[0]
            taxid = taxfoo.get_taxid(acc)
            if taxid:
                print('\t', ";".join(taxfoo.get_lineage(taxid)),
                      taxfoo.get_taxid_rank(taxid))

    print('')
    for rank in want_taxonomy:
        if classified_at.get(rank):
            print('\t{}: {}'.format(rank, classified_at.get(rank, 0)))

    print(classified_samples)
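
For reference, these are the HashvalRevindex attributes this script relies on, with shapes inferred from the usage here and in example #8 (a sketch, not the real class):

class HashvalRevindexSketch:
    def __init__(self, filename):
        # loading from 'filename' omitted; only the attribute shapes matter.
        self.hashval_to_sigids = {}     # hashval -> set of signature ids
        self.sigid_to_siginfo = {}      # sigid -> (filename, md5)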
Example #6
import argparse
import os
from collections import defaultdict
from pickle import dump, load   # assuming pickle for the .hashvals dumps

import sourmash_lib.signature

import lca_json   # project-local module

# NCBI_TaxonomyFoo and traverse_find_sigs() are assumed to come from the
# project's taxonomy utilities; a sketch of traverse_find_sigs() follows
# this example.

def main():
    p = argparse.ArgumentParser()
    p.add_argument('lca_output')
    p.add_argument('genbank_csv')
    p.add_argument('nodes_dmp')
    p.add_argument('sigs', nargs='+')
    p.add_argument('-k', '--ksize', default=31, type=int)
    p.add_argument('--scaled', default=10000, type=int)

    p.add_argument('--traverse-directory', action='store_true')

    p.add_argument('-s', '--save-hashvals', action='store_true')
    p.add_argument('-l', '--load-hashvals', action='store_true')

    p.add_argument('--lca-json')
    p.add_argument('--names-dmp', default='')
    args = p.parse_args()

    taxfoo = NCBI_TaxonomyFoo()

    # load the accessions->taxid info
    taxfoo.load_accessions_csv(args.genbank_csv)

    # load the nodes_dmp file to get the tax tree
    print('loading nodes_dmp / taxonomic tree')
    taxfoo.load_nodes_dmp(args.nodes_dmp)

    # track hashval -> set of taxids
    hashval_to_taxids = defaultdict(set)

    # for every minhash in every signature, link it to its NCBI taxonomic ID.
    if args.traverse_directory:
        inp_files = list(traverse_find_sigs(args.sigs))
    else:
        inp_files = list(args.sigs)

    if args.load_hashvals:
        with open(args.lca_output + '.hashvals', 'rb') as hashval_fp:
            print('loading hashvals dict per -l/--load-hashvals...')
            hashval_to_taxids = load(hashval_fp)
            print('loaded {} hashvals'.format(len(hashval_to_taxids)))
    else:

        print('loading signatures & traversing hashes')
        bad_input = 0
        for n, filename in enumerate(inp_files):
            if n % 100 == 0:
                print('... loading file #', n, 'of', len(inp_files), end='\r')

            try:
                sig = sourmash_lib.signature.load_one_signature(
                    filename, ksize=args.ksize)
            except (FileNotFoundError, ValueError):
                if not args.traverse_directory:
                    raise

                bad_input += 1
                continue

            acc = sig.name().split(' ')[0]  # first part of sequence name
            acc = acc.split('.')[0]  # get acc w/o version

            taxid = taxfoo.get_taxid(acc)
            if taxid is None:
                continue

            sig.minhash = sig.minhash.downsample_scaled(args.scaled)

            mins = sig.minhash.get_mins()

            for m in mins:
                hashval_to_taxids[m].add(taxid)
        print('\n...done')
        if bad_input:
            print('failed to load {} of {} files found'.format(
                bad_input, len(inp_files)))

        if args.save_hashvals:
            with open(args.lca_output + '.hashvals', 'wb') as hashval_fp:
                dump(hashval_to_taxids, hashval_fp)

    ####

    print(
        'traversing tags and finding last-common-ancestor for {} tags'.format(
            len(hashval_to_taxids)))

    hashval_to_lca = {}
    found_root = 0
    empty_set = 0

    # find the LCA for each hashval and store.
    for n, (hashval, taxid_set) in enumerate(hashval_to_taxids.items()):
        if n % 10000 == 0:
            print('...', n, end='\r')

        # find associated least-common-ancestors.
        lca = taxfoo.find_lca(taxid_set)

        if lca == 1:
            if taxid_set:
                found_root += 1
            else:
                empty_set += 1
            continue

        # save!!
        hashval_to_lca[hashval] = lca
    print('\ndone')

    if found_root:
        print('found root {} times'.format(found_root))
    if empty_set:
        print('found empty set {} times'.format(empty_set))

    print('saving to', args.lca_output)
    with lca_json.xopen(args.lca_output, 'wb') as lca_fp:
        dump(hashval_to_lca, lca_fp)

    # update LCA DB JSON file if provided
    if args.lca_json:
        lca_db = lca_json.LCA_Database()
        if os.path.exists(args.lca_json):
            print('loading LCA JSON file:', args.lca_json)
            lca_db.load(args.lca_json)

        prefix = os.path.dirname(args.lca_json) + '/'

        lca_output = args.lca_output
        if lca_output.startswith(prefix):
            lca_output = lca_output[len(prefix):]

        nodes_dmp = args.nodes_dmp
        if nodes_dmp.startswith(prefix):
            nodes_dmp = nodes_dmp[len(prefix):]

        names_dmp = args.names_dmp
        if names_dmp:
            if names_dmp.startswith(prefix):
                names_dmp = names_dmp[len(prefix):]
        else:
            names_dmp = nodes_dmp.replace('nodes', 'names')

        lca_db.add_db(args.ksize, args.scaled, lca_output, nodes_dmp,
                      names_dmp)
        print('saving LCA JSON file:', args.lca_json)
        lca_db.save(args.lca_json)
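
traverse_find_sigs() isn't shown; here is a plausible sketch for --traverse-directory, assuming signature files end in '.sig' (the real helper may accept other extensions):

import os

def traverse_find_sigs(dirnames):
    for dirname in dirnames:
        for root, _, files in os.walk(dirname):
            for name in files:
                if name.endswith('.sig'):
                    yield os.path.join(root, name)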
Example #7
import argparse
import collections

import sourmash_lib.signature

import lca_json   # project-local module

# SCALED and kraken_rank_code are assumed to be defined at module level;
# a plausible kraken_rank_code follows this example.

def main():
    p = argparse.ArgumentParser()
    p.add_argument('lca_filename')
    p.add_argument('sigfiles', nargs='+')
    p.add_argument('-k', '--ksize', default=31, type=int)
    args = p.parse_args()

    # load lca info
    lca_db = lca_json.LCA_Database(args.lca_filename)
    taxfoo, hashval_to_lca, scaled = lca_db.get_database(args.ksize, SCALED)

    # load signatures
    siglist = []
    print('loading signatures from {} signature files'.format(
        len(args.sigfiles)))
    for sigfile in args.sigfiles:
        sigs = sourmash_lib.signature.load_signatures(sigfile,
                                                      ksize=args.ksize)
        sigs = list(sigs)
        siglist.extend(sigs)

    print('loaded {} signatures total at k={}'.format(len(siglist),
                                                      args.ksize))

    # downsample
    print('downsampling to scaled value: {}'.format(scaled))
    for sig in siglist:
        if sig.minhash.scaled < scaled:
            sig.minhash = sig.minhash.downsample_scaled(scaled)

    # now, extract hash values!
    hashvals = collections.defaultdict(int)
    for sig in siglist:
        for hashval in sig.minhash.get_mins():
            hashvals[hashval] += 1

    found = 0
    total = 0
    by_taxid = collections.defaultdict(int)

    # for every hash, get LCA of labels
    for hashval, count in hashvals.items():
        lca = hashval_to_lca.get(hashval)
        total += count

        if lca is None:
            by_taxid[0] += count
            continue

        by_taxid[lca] += count
        found += count

    print('found LCA classifications for', found, 'of', total, 'hashes')
    not_found = total - found

    # now, propagate counts up the taxonomic tree.
    by_taxid_lca = collections.defaultdict(int)
    for taxid, count in by_taxid.items():
        by_taxid_lca[taxid] += count

        parent = taxfoo.child_to_parent.get(taxid)
        while parent is not None and parent != 1:
            by_taxid_lca[parent] += count
            parent = taxfoo.child_to_parent.get(parent)

    total_count = sum(by_taxid.values())

    # sort by lineage length
    x = []
    for taxid, count in by_taxid_lca.items():
        x.append((len(taxfoo.get_lineage(taxid)), taxid, count))

    x.sort()

    # ...aaaaaand output.
    print('{}\t{}\t{}\t{}\t{}\t{}'.format('percent', 'below', 'at node',
                                          'code', 'taxid', 'name'))
    for _, taxid, count_below in x:
        percent = round(100 * count_below / total_count, 2)
        count_at = by_taxid[taxid]

        rank = taxfoo.node_to_info.get(taxid)
        if rank:
            rank = rank[0]
            classify_code = kraken_rank_code.get(rank, '-')
        else:
            classify_code = '-'

        name = taxfoo.taxid_to_names.get(taxid)
        if name:
            name = name[0]
        else:
            name = '-'

        print('{}\t{}\t{}\t{}\t{}\t{}'.format(percent, count_below, count_at,
                                              classify_code, taxid, name))

    if not_found:
        classify_code = 'U'
        percent = round(100 * not_found / total_count, 2)
        count_below = not_found
        count_at = not_found
        taxid = 0
        name = 'not classified'

        print('{}\t{}\t{}\t{}\t{}\t{}'.format(percent, count_below, count_at,
                                              classify_code, taxid, name))
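
kraken_rank_code is assumed to map NCBI rank names to the one-letter codes that Kraken-style reports use; a plausible definition:

kraken_rank_code = {
    'superkingdom': 'D', 'kingdom': 'K', 'phylum': 'P', 'class': 'C',
    'order': 'O', 'family': 'F', 'genus': 'G', 'species': 'S',
}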
Example #8
import argparse
import csv
import itertools
import pprint
import sys
from collections import Counter, defaultdict

import sourmash_lib

import lca_json         # project-local modules
import revindex_utils

# LCA_DBs, SCALED, THRESHOLD, taxlist, null_names, debug()/_print_debug,
# get_lca_taxid_for_lineage(), get_lowest_taxid_for_lineage(), and the
# tree helpers used below are assumed to be defined at module level;
# sketches of build_tree(), build_reverse_tree(), and find_lca() follow
# this example.

def main():
    p = argparse.ArgumentParser()
    p.add_argument('csv')
    p.add_argument('revindex')
    p.add_argument('siglist', nargs='+')
    p.add_argument('--lca', nargs='+', default=LCA_DBs)
    p.add_argument('-k', '--ksize', default=31, type=int)
    p.add_argument('-o', '--output', type=argparse.FileType('wt'),
                   help='output CSV to this file instead of stdout')
    #p.add_argument('-v', '--verbose', action='store_true')
    p.add_argument('-d', '--debug', action='store_true')
    args = p.parse_args()

    if args.debug:
        global _print_debug
        _print_debug = True

    ## load LCA databases
    lca_db_list = []
    for lca_filename in args.lca:
        print('loading LCA database from {}'.format(lca_filename),
              file=sys.stderr)
        lca_db = lca_json.LCA_Database(lca_filename)
        taxfoo, hashval_to_lca, _ = lca_db.get_database(args.ksize, SCALED)
        lca_db_list.append((taxfoo, hashval_to_lca))
    
    # reverse index names -> taxids
    names_to_taxids = defaultdict(set)
    for taxid, (name, _, _) in taxfoo.taxid_to_names.items():
        names_to_taxids[name].add(taxid)

    ### parse spreadsheet
    r = csv.reader(open(args.csv, 'rt'))
    row_headers = ['identifier'] + taxlist

    print('examining spreadsheet headers...', file=sys.stderr)
    first_row = next(r)

    n_disagree = 0
    for (column, value) in zip(row_headers, first_row):
        if column.lower() != value.lower():
            print('** assuming {} == {} in spreadsheet'.format(column, value),
                  file=sys.stderr)
            n_disagree += 1
            if n_disagree > 2:
                print('whoa, too many assumptions. are the headers right?',
                      file=sys.stderr)
                sys.exit(-1)

    confusing_lineages = defaultdict(list)
    incompatible_lineages = defaultdict(list)
    assignments = {}
    for row in r:
        lineage = list(zip(row_headers, row))

        ident = lineage[0][1]
        lineage = lineage[1:]

        # clean lineage of null names
        lineage = [(a,b) for (a,b) in lineage if b not in null_names]

        # ok, find the least-common-ancestor taxid...
        taxid, rest = get_lca_taxid_for_lineage(taxfoo, names_to_taxids,
                                                lineage)

        # and find the *lowest* identifiable ancestor taxid, just to see
        # if there are confusing lineages.
        lowest_taxid, lowest_rest = \
          get_lowest_taxid_for_lineage(taxfoo, names_to_taxids, lineage)

        # do they match? if not, report.
        if lowest_taxid != taxid:
            lowest_lineage = taxfoo.get_lineage(lowest_taxid, taxlist)
            lowest_str = ', '.join(lowest_lineage)

            # find last matching, in case different classification levels.
            match_lineage = [ b for (a, b) in lineage ]
            end = match_lineage.index(lowest_lineage[-1])
            assert end >= 0
            match_lineage = match_lineage[:end + 1]
            match_str = ', '.join(match_lineage)

            confusing_lineages[(match_str, lowest_str)].append(ident)

        # check! NCBI lineage should be lineage of taxid + rest
        ncbi_lineage = taxfoo.get_lineage(taxid, taxlist)
        assert len(ncbi_lineage)
        reconstructed = ncbi_lineage + [ b for (a,b) in rest ]

        # ...make a comparable lineage from the CSV line...
        csv_lineage = [ b for (a, b) in lineage ]

        # are NCBI-rooted and CSV lineages the same?? if not, report.
        if csv_lineage != reconstructed:
            csv_str = ", ".join(csv_lineage[:len(ncbi_lineage)])
            ncbi_str = ", ".join(ncbi_lineage)
            incompatible_lineages[(csv_str, ncbi_str)].append(ident)

        # all is well if we've reached this point! We've got NCBI-rooted
        # taxonomies and now we need to record. next:
        #
        # build a set of triples: (rank, name, taxid), where taxid can
        # be None.

        lineage_taxids = taxfoo.get_lineage_as_taxids(taxid)
        tuples_info = []
        for taxid in lineage_taxids:
            name = taxfoo.get_taxid_name(taxid)
            rank = taxfoo.get_taxid_rank(taxid)

            if rank in taxlist:
                tuples_info.append((rank, name))

        for (rank, name) in rest:
            assert rank in taxlist
            tuples_info.append((rank, name))

        assignments[ident] = tuples_info

    print("{} weird lineages that maybe don't match with NCBI.".format(len(confusing_lineages) + len(incompatible_lineages)), file=sys.stderr)

    ## next phase: collapse lineages etc.

    ## load revindex
    print('loading reverse index:', args.revindex, file=sys.stderr)
    custom_bins_ri = revindex_utils.HashvalRevindex(args.revindex)

    # load the signatures associated with each revindex.
    print('loading signatures for custom genomes...', file=sys.stderr)
    sigids_to_sig = {}
    for sigid, (filename, md5) in custom_bins_ri.sigid_to_siginfo.items():
        sig = revindex_utils.get_sourmash_signature(filename, md5)
        if sig.name() in assignments:
            sigids_to_sig[sigid] = sig
        else:
            debug('no assignment:', sig.name())

    # figure out what ksize we're talking about here! (this should
    # probably be stored on the revindex...)
    random_db_sig = next(iter(sigids_to_sig.values()))
    ksize = random_db_sig.minhash.ksize

    print('...found {} custom genomes that also have assignments!!'.format(
        len(sigids_to_sig)), file=sys.stderr)

    ## now, connect the dots: hashvals to custom classifications
    hashval_to_custom = defaultdict(list)
    for hashval, sigids in custom_bins_ri.hashval_to_sigids.items():
        for sigid in sigids:
            sig = sigids_to_sig.get(sigid, None)
            if sig:
                assignment = assignments[sig.name()]
                hashval_to_custom[hashval].append(assignment)

    # whew! done!! we can now go from a hashval to a custom assignment!!

    # for each query, gather all the matches in both custom and NCBI, then
    # classify.
    csvfp = csv.writer(sys.stdout)
    if args.output:
        print("outputting classifications to '{}'".format(args.output.name))
        csvfp = csv.writer(args.output)
    else:
        print("outputting classifications to stdout")
    csvfp.writerow(['ID'] + taxlist)

    total_count = 0
    for query_filename in args.siglist:
        for query_sig in sourmash_lib.load_signatures(query_filename,
                                                      ksize=ksize):
            print(u'\r\033[K', end=u'', file=sys.stderr)
            print('... classifying {}'.format(query_sig.name()), end='\r',
                  file=sys.stderr)
            debug('classifying', query_sig.name())
            total_count += 1

            these_assignments = defaultdict(list)
            n_custom = 0
            for hashval in query_sig.minhash.get_mins():
                # custom
                assignment = hashval_to_custom.get(hashval, [])
                if assignment:
                    these_assignments[hashval].extend(assignment)
                    n_custom += 1

                # NCBI
                for (this_taxfoo, hashval_to_lca) in lca_db_list:
                    hashval_lca = hashval_to_lca.get(hashval)
                    if hashval_lca is not None and hashval_lca != 1:
                        lineage = this_taxfoo.get_lineage_as_dict(hashval_lca,
                                                                  taxlist)

                        tuple_info = []
                        for rank in taxlist:
                            if rank not in lineage:
                                break
                            tuple_info.append((rank, lineage[rank]))
                        these_assignments[hashval_lca].append(tuple_info)

            check_counts = Counter()
            for tuple_info in these_assignments.values():
                last_tup = tuple(tuple_info[-1])
                check_counts[last_tup] += 1

            debug('n custom hashvals:', n_custom)
            debug(pprint.pformat(check_counts.most_common()))

            # now convert to trees -> do LCA & counts
            counts = Counter()
            parents = {}
            for hashval in these_assignments:

                # for each list of tuple_info [(rank, name), ...] build
                # a tree that lets us discover least-common-ancestor.
                tuple_info = these_assignments[hashval]
                tree = build_tree(tuple_info)

                # also update a tree that we can ascend from leaves -> parents
                # for all assignments for all hashvals
                parents = build_reverse_tree(tuple_info, parents)

                # now find either a leaf or the first node with multiple
                # children; that's our least-common-ancestor node.
                lca, reason = find_lca(tree)
                counts[lca] += 1

            # ok, we now have the LCAs for each hashval, and their number
            # of counts. Now sum across "significant" LCAs - those above
            # threshold.

            tree = {}
            tree_counts = defaultdict(int)

            debug(pprint.pformat(counts.most_common()))

            n = 0
            for lca, count in counts.most_common():
                if count < THRESHOLD:
                    break

                n += 1

                xx = []
                parent = lca
                while parent:
                    xx.insert(0, parent)
                    tree_counts[parent] += count
                    parent = parents.get(parent)
                debug(n, count, xx[1:])

                # update tree with this set of assignments
                build_tree([xx], tree)

            if n > 1:
                debug('XXX', n)

            # now find LCA? or whatever.
            lca, reason = find_lca(tree)
            if reason == 0:               # leaf node
                debug('END', lca)
            else:                         # internal node
                debug('MULTI', lca)

            # backtrack to full lineage via parents
            lineage = []
            parent = lca
            while parent != ('root', 'root'):
                lineage.insert(0, parent)
                parent = parents.get(parent)

            # output!
            row = [query_sig.name()]
            for taxrank, (rank, name) in itertools.zip_longest(
                    taxlist, lineage, fillvalue=('', '')):
                if rank:
                    assert taxrank == rank
                row.append(name)

            csvfp.writerow(row)

    print(u'\r\033[K', end=u'', file=sys.stderr)
    print('classified {} signatures total'.format(total_count), file=sys.stderr)
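
Minimal sketches of the tree helpers as example #8 uses them: build_tree() nests dicts keyed by (rank, name) tuples, build_reverse_tree() records each node's parent, and find_lca() descends until it reaches a leaf (reason 0) or the first node with multiple children. The real helpers may differ in detail.

def build_tree(assignments, initial=None):
    tree = initial if initial is not None else {}
    for lineage in assignments:              # lineage: [(rank, name), ...]
        node = tree
        for tup in lineage:
            node = node.setdefault(tuple(tup), {})
    return tree

def build_reverse_tree(assignments, parents=None):
    parents = parents if parents is not None else {}
    for lineage in assignments:
        parent = ('root', 'root')
        for tup in lineage:
            parents[tuple(tup)] = parent
            parent = tuple(tup)
    return parents

def find_lca(tree):
    node, lca = tree, ('root', 'root')
    while len(node) == 1:                    # follow the single-child chain
        lca, node = next(iter(node.items()))
    return lca, len(node)                    # 0 == leaf; >1 == disagreement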
Example #9
import argparse
import collections
import csv
import os

import lca_json   # project-local module

# LCA_DBs, SCALED, THRESHOLD, want_taxonomy, and load_all_signatures()
# are assumed to be defined at module level; a sketch of
# load_all_signatures() follows this example.

def main():
    p = argparse.ArgumentParser()
    p.add_argument('-k', '--ksize', default=31, type=int)
    p.add_argument('--lca', nargs='+', default=LCA_DBs)
    p.add_argument('dir')
    p.add_argument('-o', '--output-csv')
    p.add_argument(
        '--threshold',
        type=int,
        default=THRESHOLD,
        help="minimum number of times a taxid must be present to count")

    args = p.parse_args()

    if args.output_csv:
        output_filename = args.output_csv
    else:
        output_filename = os.path.basename(args.dir) + '.taxonomy.csv'

    outfp = open(output_filename, 'wt')
    outw = csv.writer(outfp)
    outw.writerow(['name', 'taxid', 'status', 'rank_info', 'lineage'])

    # load the LCA databases from the JSON file(s)
    lca_db_list = []
    for lca_filename in args.lca:
        print('loading LCA database from {}'.format(lca_filename))
        lca_db = lca_json.LCA_Database(lca_filename)
        taxfoo, hashval_to_lca, _ = lca_db.get_database(args.ksize, SCALED)
        lca_db_list.append((taxfoo, hashval_to_lca))

    print('loading all signatures in directory:', args.dir)
    sigdict = load_all_signatures(args.dir, args.ksize)
    print('...loaded {} signatures at k={}'.format(len(sigdict), args.ksize))

    ###

    # track number of classifications at various rank levels
    classified_at = collections.defaultdict(int)

    # track number of disagreements at various rank levels
    disagree_at = collections.defaultdict(int)

    # for each minhash signature in the directory,
    n_in_lca = 0
    for name, sig in sigdict.items():

        # for each k-mer in each minhash signature, collect assigned taxids
        # across all databases (& count).
        taxid_set = collections.defaultdict(int)

        for hashval in sig.minhash.get_mins():

            # if a k-mer is present in multiple DBs, pull the
            # least-common-ancestor taxonomic node across all of the
            # DBs.

            this_hashval_taxids = set()
            for (_, hashval_to_lca) in lca_db_list:
                hashval_lca = hashval_to_lca.get(hashval)
                if hashval_lca is not None and hashval_lca != 1:
                    this_hashval_taxids.add(hashval_lca)

            if this_hashval_taxids:
                # note: taxfoo here is from the last-loaded DB; this assumes
                # all databases share the same taxonomy.
                this_hashval_lca = taxfoo.find_lca(this_hashval_taxids)
                if this_hashval_lca is not None:
                    taxid_set[this_hashval_lca] += 1

        # filter on given threshold - only taxids that show up in this
        # signature more than THRESHOLD.
        abundant_taxids = set([k for (k, cnt) in taxid_set.items() \
                               if cnt >= args.threshold])

        # remove root (taxid == 1) if it's in there:
        if 1 in abundant_taxids:
            abundant_taxids.remove(1)

        # default to nothing found. boo.
        status = 'nomatch'
        status_rank = ''
        taxid = 0

        # ok - out of the loop, got our LCAs, ...are there any left?
        if abundant_taxids:
            # increment number that are classifiable at *some* rank.
            n_in_lca += 1

            try:
                disagree_below_rank, taxid_at_rank, disagree_taxids = \
                  taxfoo.get_lineage_first_disagreement(abundant_taxids,
                                                        want_taxonomy)
            except ValueError:
                # @CTB this is probably bad.
                assert 0
                continue

            # we found a disagreement - report the rank *at* the disagreement,
            # the lineage *above* the disagreement.
            if disagree_below_rank:
                # global record of disagreements
                disagree_at[disagree_below_rank] += 1

                list_at_rank = [
                    taxfoo.get_taxid_name(r) for r in disagree_taxids
                ]
                list_at_rank = ", ".join(list_at_rank)

                print('{} has multiple LCA below {}: \'{}\''.format(
                    name, disagree_below_rank, list_at_rank))

                # set output
                status = 'disagree'
                status_rank = disagree_below_rank
                taxid = taxid_at_rank
            else:
                # found unambiguous! yay.
                status = 'found'

                taxid = taxfoo.get_lowest_lineage(abundant_taxids,
                                                  want_taxonomy)
                status_rank = taxfoo.get_taxid_rank(taxid)
                classified_at[status_rank] += 1

        if taxid != 0:
            lineage_found = taxfoo.get_lineage(taxid,
                                               want_taxonomy=want_taxonomy)
            lineage_found = ";".join(lineage_found)
        else:
            lineage_found = ""

        outw.writerow([name, taxid, status, status_rank, lineage_found])

    print('')
    print('classified sourmash signatures in directory: \'{}\''.format(
        args.dir))
    print('LCA databases used: {}'.format(', '.join(args.lca)))
    print('')

    print('total signatures found: {}'.format(len(sigdict)))
    print('no classification information: {}'.format(len(sigdict) - n_in_lca))
    print('')
    print('could classify {}'.format(n_in_lca))
    print('of those, {} disagree at some rank.'.format(
        sum(disagree_at.values())))

    print('')
    print('number classified unambiguously, by lowest classification rank:')
    for rank in want_taxonomy:
        if classified_at.get(rank):
            print('\t{}: {}'.format(rank, classified_at.get(rank, 0)))

    print('')
    print('disagreements by rank:')
    for rank in want_taxonomy:
        if disagree_at.get(rank):
            print('\t{}: {}'.format(rank, disagree_at.get(rank, 0)))

    print('')
    print('classification output as CSV, here: {}'.format(output_filename))
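
Finally, load_all_signatures() (used in examples #3 and #9) isn't shown; here is a plausible sketch that loads one signature per '.sig' file in a directory, keyed by signature name (the real helper may recurse or handle errors differently):

import os
import sourmash_lib.signature

def load_all_signatures(dirname, ksize):
    sigdict = {}
    for filename in os.listdir(dirname):
        if filename.endswith('.sig'):
            path = os.path.join(dirname, filename)
            sig = sourmash_lib.signature.load_one_signature(path, ksize=ksize)
            sigdict[sig.name()] = sig
    return sigdict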