class FindNodesTestCase(unittest.TestCase):
    """Exercise ``hrefpkg_build.find_nodes`` against a small fixture tree.

    The fixture is: root -> {g1, g2}; g1 carries sequence ids directly at
    the genus level, while g2's sequences hang off the species node s3.
    """

    def setUp(self):
        root = TaxNode(rank='root', name='root', tax_id='1')
        root.ranks = ['root', 'class', 'genus', 'species']
        self.taxonomy = root

        # First genus: sequences attached at the genus node itself.
        genus1 = TaxNode(rank='genus', name='g1', tax_id='2')
        self.g1 = genus1
        genus1.sequence_ids = {'s1', 's2'}
        root.add_child(genus1)
        genus1.add_child(TaxNode(rank='species', name='s1', tax_id='s1'))
        genus1.add_child(TaxNode(rank='species', name='s2', tax_id='s2'))

        # Second genus: sequences attached one level down, on species s3.
        genus2 = TaxNode(rank='genus', name='g2', tax_id='3')
        root.add_child(genus2)
        species3 = TaxNode(rank='species', name='s3', tax_id='s3')
        species3.sequence_ids = {'s3', 's4'}
        genus2.add_child(species3)
        genus2.add_child(TaxNode(rank='species', name='s4', tax_id='s4'))

    def test_find_nodes(self):
        """Searching below 'class' yields the nodes that own sequences."""
        found = list(hrefpkg_build.find_nodes(self.taxonomy, 'class'))
        names = frozenset(node.name for node in found)
        self.assertEqual(frozenset(['g1', 's3']), names)

    def test_find_nodes_below_rank(self):
        """Searching below 'genus' excludes sequence-bearing genus nodes."""
        found = list(hrefpkg_build.find_nodes(self.taxonomy, 'genus'))
        self.assertEqual(['s3'], [node.name for node in found])
def generate_tax2tree_map(refpkg, output_fp):
    """Generate a tax2tree map from a reference package, writing to output_fp.

    One tab-separated line per sequence: seqname, then either the
    semicolon-joined rank-prefixed lineage or 'Unclassified' when the
    seq_info row has no tax_id.
    """
    with open(refpkg.file_abspath('taxonomy')) as taxonomy_fp:
        tax_root = TaxNode.from_taxtable(taxonomy_fp)

    def lineage(tax_id):
        # Map each rank on the node's lineage to its tax_id, then format
        # the fixed TAX2TREE_RANKS as 'r__<tax_id>' (blank when absent).
        nodes = tax_root.get_node(tax_id).lineage()
        by_rank = {node.rank: node.tax_id for node in nodes}
        parts = ('{0}__{1}'.format(rank[0], by_rank.get(rank, ''))
                 for rank in TAX2TREE_RANKS)
        return '; '.join(parts)

    with open(refpkg.file_abspath('seq_info')) as seqinfo_fp:
        for row in csv.DictReader(seqinfo_fp):
            seqname, tax_id = row['seqname'], row['tax_id']
            label = lineage(tax_id) if tax_id else 'Unclassified'
            output_fp.write('\t'.join((seqname, label)) + '\n')
def main():
    """Turn a BIOM file with a taxonomy into a taxtable and seqinfo CSV.

    Reads the BIOM JSON, writes one seqinfo row per BIOM row (leaf id and
    the deepest tax_id of its lineage), and builds a TaxNode tree which is
    then serialized as a taxtable.
    """
    logging.basicConfig(level=logging.INFO,
                        format="%(levelname)s: %(message)s")
    parser = argparse.ArgumentParser(
        description="Turn a BIOM file with a taxonomy into a taxtable and seqinfo.")
    parser.add_argument('biom', type=argparse.FileType('r'),
                        help='input BIOM file')
    parser.add_argument('taxtable', type=argparse.FileType('w'),
                        help='output taxtable')
    parser.add_argument('seqinfo', type=argparse.FileType('w'),
                        help='output seqinfo')
    args = parser.parse_args()

    log.info('loading biom')
    with args.biom:
        j = json.load(args.biom)

    root = TaxNode('root', root_id, name='Root')
    root.ranks = rank_order

    seqinfo = csv.writer(args.seqinfo)
    seqinfo.writerow(('seqname', 'tax_id'))

    log.info('determining tax_ids')
    for leaf in j['rows']:
        leaf_taxonomy = leaf['metadata']['taxonomy']
        # Drop nodes containing only rank (e.g. `s__`)
        leaf_taxonomy = [i for i in leaf_taxonomy if i[3:]]
        # FIX: leaf_taxonomy is already filtered above; the original
        # re-applied the same `if i[3:]` filter a second time for no effect.
        leaf_lineages = list(lineages(leaf_taxonomy))
        # NOTE(review): assumes every BIOM row has at least one named rank;
        # an all-`x__` taxonomy would raise IndexError here — confirm input.
        seqinfo.writerow((leaf['id'], leaf_lineages[-1][0]))
        for tax_id, node, parent in leaf_lineages:
            if tax_id in root.index:
                continue
            root.get_node(parent).add_child(
                TaxNode(ranks[node[0]], tax_id, name=node[3:] or node))

    log.info('writing taxtable')
    with args.taxtable:
        root.write_taxtable(args.taxtable)
def update_taxids(refpkg, tax2tree_dict, output_fp, allow_rename=True,
                  unknown_taxid=None):
    """Rewrite seq_info tax_ids using tax2tree assignments.

    For each seq_info row: if tax2tree proposes a tax_id that is not
    already on the row's current lineage, apply it (when ``allow_rename``
    is true or the row had no tax_id). Rows with neither an existing nor a
    proposed tax_id get ``unknown_taxid``. The original tax_id is kept in
    a new 'orig_tax_id' column written to ``output_fp``.
    """
    with open(refpkg.file_abspath('taxonomy')) as fp:
        tax_root = TaxNode.from_taxtable(fp)

    def lineage_ids(tax_id):
        # All tax_ids on the path from the root down to tax_id.
        if not tax_id:
            return frozenset()
        n = tax_root.get_node(tax_id)
        return frozenset(i.tax_id for i in n.lineage())

    with open(refpkg.file_abspath('seq_info')) as fp:
        # Sniff the delimiter; fall back to the default excel dialect.
        try:
            dialect = csv.Sniffer().sniff(fp.read(1024), delimiters=',')
        except csv.Error:
            dialect = csv.excel
        fp.seek(0)
        r = csv.DictReader(fp, dialect=dialect)
        # FIX: build a fresh header list instead of appending to
        # r.fieldnames in place — mutating the reader's fieldnames made it
        # inject a phantom 'orig_tax_id' key (restval) into every row.
        h = list(r.fieldnames) + ['orig_tax_id']
        w = csv.DictWriter(output_fp, h, dialect=dialect)
        w.writeheader()
        for i in r:
            i['orig_tax_id'] = i['tax_id']
            n = tax2tree_dict.get(i['seqname'])
            if n and n not in lineage_ids(i['tax_id']):
                # tax2tree disagrees with the current classification.
                node = tax_root.get_node(n)
                new_name, new_rank = node.name, node.rank
                if i['tax_id']:
                    node = tax_root.get_node(i['tax_id'])
                    orig_name, orig_rank = node.name, node.rank
                else:
                    orig_name, orig_rank = '', ''
                logging.info('%s changed from "%s" (%s) to "%s" (%s)',
                             i['seqname'], orig_name, orig_rank,
                             new_name, new_rank)
                if allow_rename or not i['tax_id']:
                    i['tax_id'] = n
                else:
                    logging.info("Not applied.")
            elif n is None and not i['tax_id']:
                # No classification at all: fall back to unknown_taxid.
                i['tax_id'] = unknown_taxid
                # FIX: logging.warn is a deprecated alias of warning.
                logging.warning("no taxonomy for %s", i['seqname'])
            w.writerow(i)
def main():
    """Turn a BIOM file with a taxonomy into a taxtable and seqinfo CSV.

    Reads the BIOM JSON, writes one seqinfo row per BIOM row (leaf id and
    the deepest tax_id of its lineage), and builds a TaxNode tree which is
    then serialized as a taxtable.
    """
    logging.basicConfig(level=logging.INFO,
                        format="%(levelname)s: %(message)s")
    parser = argparse.ArgumentParser(
        description="Turn a BIOM file with a taxonomy into a taxtable and seqinfo.")
    parser.add_argument(
        'biom', type=argparse.FileType('r'), help='input BIOM file')
    parser.add_argument(
        'taxtable', type=argparse.FileType('w'), help='output taxtable')
    parser.add_argument(
        'seqinfo', type=argparse.FileType('w'), help='output seqinfo')
    args = parser.parse_args()

    log.info('loading biom')
    with args.biom:
        j = json.load(args.biom)

    root = TaxNode('root', root_id, name='Root')
    root.ranks = rank_order

    seqinfo = csv.writer(args.seqinfo)
    seqinfo.writerow(('seqname', 'tax_id'))

    log.info('determining tax_ids')
    for leaf in j['rows']:
        leaf_taxonomy = leaf['metadata']['taxonomy']
        # Drop nodes containing only rank (e.g. `s__`)
        leaf_taxonomy = [i for i in leaf_taxonomy if i[3:]]
        # FIX: leaf_taxonomy is already filtered above; the original
        # re-applied the same `if i[3:]` filter a second time for no effect.
        leaf_lineages = list(lineages(leaf_taxonomy))
        # NOTE(review): assumes every BIOM row has at least one named rank;
        # an all-`x__` taxonomy would raise IndexError here — confirm input.
        seqinfo.writerow((leaf['id'], leaf_lineages[-1][0]))
        for tax_id, node, parent in leaf_lineages:
            if tax_id in root.index:
                continue
            root.get_node(parent).add_child(
                TaxNode(ranks[node[0]], tax_id, name=node[3:] or node))

    log.info('writing taxtable')
    with args.taxtable:
        root.write_taxtable(args.taxtable)
def action(a):
    """Prune each filter-rank taxon's sequences in parallel, keeping the
    survivors in the output FASTA (plus, optionally, a filtered seqinfo).

    Sequences classified above ``a.filter_rank`` are kept unconditionally;
    each node at ``a.filter_rank`` is handed to ``filter_worker`` in a
    thread pool, and whatever ids it returns are appended to the output.
    """
    # Load taxonomy
    with a.taxonomy as fp:
        taxonomy = TaxNode.from_taxtable(fp)
        logging.info('Loaded taxonomy')

    # Load sequences into taxonomy
    with a.seqinfo_file as fp:
        taxonomy.populate_from_seqinfo(fp)
        logging.info('Added %d sequences',
                     sum(1 for i in taxonomy.subtree_sequence_ids()))

    # Sequences which are classified above the desired rank should just be kept
    kept_ids = frozenset(sequences_above_rank(taxonomy, a.filter_rank))

    # Optional per-taxon audit log; log_taxid stays None when not requested.
    log_taxid = None
    if a.log:
        writer = csv.writer(a.log, lineterminator='\n',
                            quoting=csv.QUOTE_NONNUMERIC)
        writer.writerow(('tax_id', 'tax_name', 'n', 'kept', 'pruned'))

        def log_taxid(tax_id, tax_name, n, kept, pruned):
            writer.writerow((tax_id, tax_name, n, kept, pruned))

    # `a.log or util.nothing()` keeps the `with` valid when no log was given.
    with a.output_fp as fp, a.log or util.nothing():
        logging.info('Keeping %d sequences classified above %s',
                     len(kept_ids), a.filter_rank)
        wrap.esl_sfetch(a.sequence_file, kept_ids, fp)

        # For each filter-rank, filter
        nodes = [i for i in taxonomy if i.rank == a.filter_rank]

        # Filter each tax_id, running in ``--threads`` tasks in parallel
        with futures.ThreadPoolExecutor(a.threads) as executor:
            futs = {}
            for i, node in enumerate(nodes):
                seqs = frozenset(node.subtree_sequence_ids())
                if not seqs:
                    logging.warn("No sequences for %s (%s)",
                                 node.tax_id, node.name)
                    if log_taxid:
                        log_taxid(node.tax_id, node.name, 0, 0, 0)
                    continue
                elif len(seqs) == 1:
                    # A single sequence cannot be distance-filtered.
                    logging.warn('Only 1 sequence for %s (%s). Dropping',
                                 node.tax_id, node.name)
                    continue
                f = executor.submit(filter_worker,
                                    sequence_file=a.sequence_file,
                                    node=node, seqs=seqs,
                                    distance_cutoff=a.distance_cutoff,
                                    log_taxid=log_taxid)
                # Remember the submission so results can be matched back.
                futs[f] = {'n_seqs': len(seqs), 'node': node}

            complete = 0
            # Poll with a 1-second timeout so progress can be reported.
            while futs:
                done, pending = futures.wait(futs, 1, futures.FIRST_COMPLETED)
                complete += len(done)
                sys.stderr.write('{0:8d}/{1:8d} taxa completed\r'.format(
                    complete, complete + len(pending)))
                for f in done:
                    if f.exception():
                        # Abort the whole run on the first worker failure.
                        logging.exception("Error in child process: %s",
                                          f.exception())
                        executor.shutdown(False)
                        raise f.exception()
                    info = futs.pop(f)
                    kept = f.result()
                    kept_ids |= kept
                    if len(kept) != info['n_seqs']:
                        logging.warn('Pruned %d/%d sequences for %s (%s)',
                                     info['n_seqs'] - len(kept),
                                     info['n_seqs'], info['node'].tax_id,
                                     info['node'].name)
                    # Extract
                    wrap.esl_sfetch(a.sequence_file, kept, fp)

    # Filter seqinfo
    if a.filtered_seqinfo:
        with open(a.seqinfo_file.name) as fp:
            r = csv.DictReader(fp)
            rows = (i for i in r if i['seqname'] in kept_ids)
            with a.filtered_seqinfo as ofp:
                w = csv.DictWriter(ofp, r.fieldnames,
                                   quoting=csv.QUOTE_NONNUMERIC)
                w.writeheader()
                w.writerows(rows)
def action(a):
    """Recruit unnamed sequences into underrepresented taxa.

    Two-pass search: sequences from underrepresented taxids are searched
    against the unnamed corpus, then the hits are searched back; a hit is
    accepted only when both directions land in the same taxon group.
    Accepted sequences are appended to a new FASTA and seq_info file.
    """
    with a.taxonomy as fp:
        taxonomy = TaxNode.from_taxtable(fp)
    with a.seqinfo_file as fp:
        # List of sequences
        r = csv.DictReader(fp)
        current_seqs = frozenset(i['seqname'] for i in r)
        fp.seek(0)
        taxonomy.populate_from_seqinfo(fp)

    # Find sequences from underrepresented taxids to search
    underrep = find_underrepresented(taxonomy, a.min_at_rank, a.rank)
    tax_seqs = {}
    seq_group = {}  # sequence id -> tax_id of its (underrepresented) group
    for n, seqs in underrep:
        tax_seqs[n.tax_id] = seqs
        seq_group.update({i: n.tax_id for i in seqs})

    with util.ntf(prefix='to_expand-', suffix='.fasta') as expand_fp, \
            util.ntf(prefix='expand_hits-', suffix='.fasta') as hits_fp:
        # Extract sequences
        c = wrap.esl_sfetch(a.sequence_file, seq_group, expand_fp)
        logging.info('fetched %d sequences', c)
        # Close so the external search tool can read the file by name.
        expand_fp.close()

        # Search sequences against unnamed
        r = uclust_search(expand_fp.name, a.unnamed_file, pct_id=a.pct_id,
                          maxaccepts=4, search_pct_id=0.9, trunclabels=True)
        hits = list(r)
        # Map from hit to group
        hit_group = {i.target_label: seq_group[i.query_label] for i in hits}

        # Extract hits
        c = wrap.esl_sfetch(a.unnamed_file, hit_group, hits_fp)
        logging.info('%d hits', c)
        hits_fp.close()

        # Search hits back against named file
        r = uclust_search(hits_fp.name, expand_fp.name, pct_id=a.pct_id,
                          maxaccepts=1, search_pct_id=0.9, trunclabels=True)

        # Sequences which hit the same group in both directions are kept.
        update_hits = dict((i.query_label, seq_group[i.target_label])
                           for i in r
                           if seq_group[i.target_label] == hit_group[i.query_label])

        overlap = frozenset(update_hits) & current_seqs
        if overlap:
            logging.warn('%d sequences already present in corpus: %s',
                         len(overlap), ', '.join(overlap))

        # Add sequences: copy the original corpus, then append new hits.
        with open(a.output + '.fasta', 'w') as ofp:
            with open(a.sequence_file) as fp:
                shutil.copyfileobj(fp, ofp)
            try:
                wrap.esl_sfetch(hits_fp.name,
                                frozenset(update_hits) - current_seqs, ofp)
            finally:
                # Always remove the esl_sfetch index built for hits_fp.
                os.remove(hits_fp.name + '.ssi')

        # Write a new seq_info with an extra 'inferred_tax_id' column.
        with open(a.output + '.seq_info.csv', 'w') as ofp, \
                open(a.seqinfo_file.name) as sinfo:
            r = csv.DictReader(sinfo)
            fn = list(r.fieldnames) + ['inferred_tax_id']
            w = csv.DictWriter(ofp, fn, lineterminator='\n',
                               quoting=csv.QUOTE_NONNUMERIC)
            w.writeheader()
            # Existing rows pass through, except those superseded above.
            w.writerows(i for i in r if i['seqname'] not in overlap)
            if 'cluster' in fn:
                rows = ({'seqname': k, 'tax_id': v, 'inferred_tax_id': 'yes',
                         'cluster': v} for k, v in update_hits.items())
            else:
                rows = ({'seqname': k, 'tax_id': v, 'inferred_tax_id': 'yes'}
                        for k, v in update_hits.items())
            w.writerows(rows)
def setUp(self):
    """Build the fixture taxonomy: root -> {g1, g2}.

    g1 carries sequence ids at the genus level; g2's sequences are
    attached to its species child s3.
    """
    root = TaxNode(rank='root', name='root', tax_id='1')
    root.ranks = ['root', 'class', 'genus', 'species']
    self.taxonomy = root

    genus1 = TaxNode(rank='genus', name='g1', tax_id='2')
    self.g1 = genus1
    genus1.sequence_ids = {'s1', 's2'}
    root.add_child(genus1)
    genus1.add_child(TaxNode(rank='species', name='s1', tax_id='s1'))
    genus1.add_child(TaxNode(rank='species', name='s2', tax_id='s2'))

    genus2 = TaxNode(rank='genus', name='g2', tax_id='3')
    root.add_child(genus2)
    species3 = TaxNode(rank='species', name='s3', tax_id='s3')
    species3.sequence_ids = {'s3', 's4'}
    genus2.add_child(species3)
    genus2.add_child(TaxNode(rank='species', name='s4', tax_id='s4'))
def action(a):
    """Build a hierarchical refpkg set ('hrefpkg') in ``a.output_dir``.

    One refpkg is built (in parallel threads) per taxonomy node at
    ``a.index_rank``; an index refpkg and an index.csv tie them together.
    If the partition options are set, delegates to ``partition_hrefpkg``
    instead.
    """
    random.seed(a.seed)
    # j('x') == os.path.join(a.output_dir, 'x')
    j = functools.partial(os.path.join, a.output_dir)
    if not os.path.isdir(j()):
        raise IOError('Does not exist: {0}'.format(j()))
    if os.path.exists(j('index.refpkg')):
        raise IOError('index.refpkg exists.')
    with open(a.taxonomy) as fp:
        logging.info('loading taxonomy')
        taxonomy = TaxNode.from_taxtable(fp)

    # If partitioning, partition with current args, return
    if a.partition_below_rank is not None or a.partition_rank is not None:
        if not a.partition_below_rank or not a.partition_rank:
            raise ValueError(
                "--partition-below-rank and --partition-rank must be specified together")
        return partition_hrefpkg(a, taxonomy)

    with open(a.seqinfo_file) as fp:
        logging.info("loading seqinfo")
        seqinfo = load_seqinfo(fp)

    # Build an hrefpkg: one refpkg per node at the index rank.
    nodes = [i for i in taxonomy if i.rank == a.index_rank]
    hrefpkgs = []
    futs = {}
    with open(j('index.csv'), 'w') as fp, \
            open(j('train.fasta'), 'w') as train_fp, \
            open(j('test.fasta'), 'w') as test_fp, \
            futures.ThreadPoolExecutor(a.threads) as executor:

        def log_hrefpkg(tax_id):
            # Record a finished refpkg in index.csv and the result list.
            path = j(tax_id + '.refpkg')
            fp.write('{0},{0}.refpkg\n'.format(tax_id))
            hrefpkgs.append(path)

        for i, node in enumerate(nodes):
            if a.only and node.tax_id not in a.only:
                logging.info("Skipping %s", node.tax_id)
                continue
            if os.path.exists(j(node.tax_id + '.refpkg')):
                # Already built on a previous run; just index it.
                logging.warn("Refpkg exists: %s.refpkg. Skipping",
                             node.tax_id)
                log_hrefpkg(node.tax_id)
                continue
            f = executor.submit(tax_id_refpkg, node.tax_id, taxonomy,
                                seqinfo, a.sequence_file,
                                output_dir=a.output_dir,
                                test_file=test_fp, train_file=train_fp)
            futs[f] = node.tax_id, node.name

        # Poll with a 1-second timeout so completions are logged promptly.
        while futs:
            done, pending = futures.wait(futs, 1, futures.FIRST_COMPLETED)
            for f in done:
                tax_id, name = futs.pop(f)
                # Re-raises any worker exception.
                r = f.result()
                if r:
                    logging.info("Finished refpkg for %s (%s) [%d remaining]",
                                 name, tax_id, len(pending))
                    log_hrefpkg(tax_id)
            assert len(futs) == len(pending)

        # Build index refpkg
        logging.info('Building index.refpkg')
        index_rp, sequence_ids = build_index_refpkg(
            hrefpkgs, a.sequence_file, seqinfo, taxonomy,
            dest=j('index.refpkg'), index_rank=a.index_rank)

        # Write unused seqs
        logging.info("Extracting unused sequences")
        seqs = (i for i in SeqIO.parse(a.sequence_file, 'fasta')
                if i.id not in sequence_ids)
        c = SeqIO.write(seqs, j('not_in_hrefpkgs.fasta'), 'fasta')
        logging.info("%d sequences not in hrefpkgs.", c)
def action(a):
    """Recruit unnamed sequences into underrepresented taxa.

    Two-pass search: sequences from underrepresented taxids are searched
    against the unnamed corpus, then the hits are searched back; a hit is
    accepted only when both directions land in the same taxon group.
    Accepted sequences are appended to a new FASTA and seq_info file.
    """
    with a.taxonomy as fp:
        taxonomy = TaxNode.from_taxtable(fp)
    with a.seqinfo_file as fp:
        # List of sequences
        r = csv.DictReader(fp)
        current_seqs = frozenset(i['seqname'] for i in r)
        fp.seek(0)
        taxonomy.populate_from_seqinfo(fp)

    # Find sequences from underrepresented taxids to search
    underrep = find_underrepresented(taxonomy, a.min_at_rank, a.rank)
    tax_seqs = {}
    seq_group = {}  # sequence id -> tax_id of its (underrepresented) group
    for n, seqs in underrep:
        tax_seqs[n.tax_id] = seqs
        seq_group.update({i: n.tax_id for i in seqs})

    with util.ntf(prefix='to_expand-', suffix='.fasta') as expand_fp, \
            util.ntf(prefix='expand_hits-', suffix='.fasta') as hits_fp:
        # Extract sequences
        c = wrap.esl_sfetch(a.sequence_file, seq_group, expand_fp)
        logging.info('fetched %d sequences', c)
        # Close so the external search tool can read the file by name.
        expand_fp.close()

        # Search sequences against unnamed
        r = uclust_search(expand_fp.name, a.unnamed_file, pct_id=a.pct_id,
                          maxaccepts=4, search_pct_id=0.9)
        hits = list(r)
        # Map from hit to group
        hit_group = {i.target_label: seq_group[i.query_label] for i in hits}

        # Extract hits
        c = wrap.esl_sfetch(a.unnamed_file, hit_group, hits_fp)
        logging.info('%d hits', c)
        hits_fp.close()

        # Search hits back against named file
        r = uclust_search(hits_fp.name, expand_fp.name, pct_id=a.pct_id,
                          maxaccepts=1, search_pct_id=0.9)

        # Sequences which hit the same group in both directions are kept.
        update_hits = dict(
            (i.query_label, seq_group[i.target_label]) for i in r
            if seq_group[i.target_label] == hit_group[i.query_label])

        overlap = frozenset(update_hits) & current_seqs
        if overlap:
            logging.warn('%d sequences already present in corpus: %s',
                         len(overlap), ', '.join(overlap))

        # Add sequences: copy the original corpus, then append new hits.
        with open(a.output + '.fasta', 'w') as ofp:
            with open(a.sequence_file) as fp:
                shutil.copyfileobj(fp, ofp)
            try:
                wrap.esl_sfetch(hits_fp.name,
                                frozenset(update_hits) - current_seqs, ofp)
            finally:
                # Always remove the esl_sfetch index built for hits_fp.
                os.remove(hits_fp.name + '.ssi')

        # Write a new seq_info with an extra 'inferred_tax_id' column.
        with open(a.output + '.seq_info.csv', 'w') as ofp, \
                open(a.seqinfo_file.name) as sinfo:
            r = csv.DictReader(sinfo)
            fn = list(r.fieldnames) + ['inferred_tax_id']
            w = csv.DictWriter(ofp, fn, lineterminator='\n',
                               quoting=csv.QUOTE_NONNUMERIC)
            w.writeheader()
            # Existing rows pass through, except those superseded above.
            w.writerows(i for i in r if i['seqname'] not in overlap)
            if 'cluster' in fn:
                rows = ({
                    'seqname': k,
                    'tax_id': v,
                    'inferred_tax_id': 'yes',
                    'cluster': v
                } for k, v in update_hits.items())
            else:
                rows = ({
                    'seqname': k,
                    'tax_id': v,
                    'inferred_tax_id': 'yes'
                } for k, v in update_hits.items())
            w.writerows(rows)
def action(a):
    """Cluster sequences: taxonomically at ``a.cluster_rank`` where
    possible, then de-novo (OTUs) for everything else.

    Writes all kept sequences to ``a.sequence_out`` and a seqinfo file
    (with a new 'cluster' column) to ``a.seqinfo_out``.
    """
    # Load taxtable
    with a.taxtable as fp:
        logging.info('Loading taxonomy')
        taxonomy = TaxNode.from_taxtable(fp)
    with a.seqinfo_file as fp:
        logging.info('Loading seqinfo')
        taxonomy.populate_from_seqinfo(fp)
        fp.seek(0)
        r = csv.DictReader(fp)
        seqinfo = {i['seqname']: i for i in r}
    if a.unnamed_sequence_meta:
        with a.unnamed_sequence_meta as fp:
            r = csv.DictReader(fp)
            unnamed_seqinfo = {i['seqname']: i for i in r}
        # Named and unnamed metadata must not share sequence names.
        assert not set(unnamed_seqinfo) & set(seqinfo)
        seqinfo.update(unnamed_seqinfo)

    # Write clustering information for sequences with cluster_rank-level
    # classifications
    done = set()        # sequence ids already assigned to a cluster
    cluster_ids = {}    # sequence id -> cluster label (tax_id or otu_N)
    with a.sequence_out:
        for tax_id, sequences in taxonomic_clustered(taxonomy, a.cluster_rank):
            for sequence in sequences:
                cluster_ids[sequence] = tax_id
            done |= set(sequences)

        # Fetch sequences
        logging.info('Fetching %d %s-level sequences', len(done),
                     a.cluster_rank)
        wrap.esl_sfetch(a.named_sequence_file, done, a.sequence_out)
        a.sequence_out.flush()

        # Find sequences *above* cluster_rank
        above_rank_seqs = frozenset(i for i in taxonomy.subtree_sequence_ids()
                                    if i not in done)
        logging.info('%d sequences above rank %s', len(above_rank_seqs),
                     a.cluster_rank)

        # Write sequences clustered above species level, unnamed sequences to
        # file
        with util.ntf(prefix='to_cluster', suffix='.fasta') as tf, \
                util.ntf(prefix='unnamed_to_cluster', suffix='.fasta') as unnamed_fp:
            wrap.esl_sfetch(a.named_sequence_file, above_rank_seqs, tf)
            if a.unnamed_sequences:
                with open(a.unnamed_sequences) as fp:
                    shutil.copyfileobj(fp, tf)
            # Close so external tools can read the file by name.
            tf.close()

            # Remove redundant sequences: we don't need anything that's unnamed
            # & close to something named.
            redundant_ids = cluster_identify_redundant(
                a.sequence_out.name, done, to_cluster=tf.name,
                threshold=a.redundant_cluster_id)
            logging.info('%d redundant sequences', len(redundant_ids))

            # Extract desired sequences
            sequences = SeqIO.parse(tf.name, 'fasta')
            sequences = (i for i in sequences if i.id not in redundant_ids)

            # Write to file for clustering
            unnamed_count = SeqIO.write(sequences, unnamed_fp, 'fasta')
            logging.info('Kept %d non-redundant, unnamed sequences',
                         unnamed_count)

            # Write to output sequence file
            unnamed_fp.seek(0)
            shutil.copyfileobj(unnamed_fp, a.sequence_out)
            unnamed_fp.close()

            # Cluster remaining sequences into OTUs (labelled otu_0, otu_1, …)
            for i, cluster_seqs in enumerate(
                    identify_otus_unnamed(unnamed_fp.name, a.cluster_id)):
                done |= set(cluster_seqs)
                otu = 'otu_{0}'.format(i)
                for sequence in cluster_seqs:
                    cluster_ids[sequence] = otu

    with a.seqinfo_out as fp:
        def add_cluster(i):
            """Add a cluster identifier to sequence metadata"""
            i['cluster'] = cluster_ids[i['seqname']]
            return i
        # Sequences with no metadata row get a minimal {'seqname': ...} stub.
        seqinfo_records = (seqinfo.get(i, {'seqname': i}) for i in done)
        seqinfo_records = (add_cluster(i) for i in seqinfo_records)
        # NOTE(review): Python 2 idiom — values()[0] assumes seqinfo is
        # non-empty and that all rows share one header; confirm inputs.
        fields = list(seqinfo.values()[0].keys())
        fields.append('cluster')
        w = csv.DictWriter(fp, fields, quoting=csv.QUOTE_NONNUMERIC,
                           lineterminator='\n')
        w.writeheader()
        w.writerows(seqinfo_records)
def setUp(self):
    """Load the shared fixture taxonomy used by this test case."""
    taxtable_path = data_path('simple_taxtable.csv')
    with open(taxtable_path) as handle:
        self.root = TaxNode.from_taxtable(handle)
def action(a):
    """Build a hierarchical refpkg set ('hrefpkg') in ``a.output_dir``.

    One refpkg is built (in parallel threads) per taxonomy node at
    ``a.index_rank``; an index refpkg and an index.csv tie them together.
    If the partition options are set, delegates to ``partition_hrefpkg``
    instead.
    """
    random.seed(a.seed)
    # j('x') == os.path.join(a.output_dir, 'x')
    j = functools.partial(os.path.join, a.output_dir)
    if not os.path.isdir(j()):
        raise IOError('Does not exist: {0}'.format(j()))
    if os.path.exists(j('index.refpkg')):
        raise IOError('index.refpkg exists.')
    with open(a.taxonomy) as fp:
        logging.info('loading taxonomy')
        taxonomy = TaxNode.from_taxtable(fp)

    # If partitioning, partition with current args, return
    if a.partition_below_rank is not None or a.partition_rank is not None:
        if not a.partition_below_rank or not a.partition_rank:
            raise ValueError(
                "--partition-below-rank and --partition-rank must be specified together"
            )
        return partition_hrefpkg(a, taxonomy)

    with open(a.seqinfo_file) as fp:
        logging.info("loading seqinfo")
        seqinfo = load_seqinfo(fp)

    # Build an hrefpkg: one refpkg per node at the index rank.
    nodes = [i for i in taxonomy if i.rank == a.index_rank]
    hrefpkgs = []
    futs = {}
    with open(j('index.csv'), 'w') as fp, \
            open(j('train.fasta'), 'w') as train_fp, \
            open(j('test.fasta'), 'w') as test_fp, \
            futures.ThreadPoolExecutor(a.threads) as executor:

        def log_hrefpkg(tax_id):
            # Record a finished refpkg in index.csv and the result list.
            path = j(tax_id + '.refpkg')
            fp.write('{0},{0}.refpkg\n'.format(tax_id))
            hrefpkgs.append(path)

        for i, node in enumerate(nodes):
            if a.only and node.tax_id not in a.only:
                logging.info("Skipping %s", node.tax_id)
                continue
            if os.path.exists(j(node.tax_id + '.refpkg')):
                # Already built on a previous run; just index it.
                logging.warn("Refpkg exists: %s.refpkg. Skipping",
                             node.tax_id)
                log_hrefpkg(node.tax_id)
                continue
            f = executor.submit(tax_id_refpkg, node.tax_id, taxonomy,
                                seqinfo, a.sequence_file,
                                output_dir=a.output_dir,
                                test_file=test_fp, train_file=train_fp)
            futs[f] = node.tax_id, node.name

        # Poll with a 1-second timeout so completions are logged promptly.
        while futs:
            done, pending = futures.wait(futs, 1, futures.FIRST_COMPLETED)
            for f in done:
                tax_id, name = futs.pop(f)
                # Re-raises any worker exception.
                r = f.result()
                if r:
                    logging.info("Finished refpkg for %s (%s) [%d remaining]",
                                 name, tax_id, len(pending))
                    log_hrefpkg(tax_id)
            assert len(futs) == len(pending)

        # Build index refpkg
        logging.info('Building index.refpkg')
        index_rp, sequence_ids = build_index_refpkg(
            hrefpkgs, a.sequence_file, seqinfo, taxonomy,
            dest=j('index.refpkg'), index_rank=a.index_rank)

        # Write unused seqs
        logging.info("Extracting unused sequences")
        seqs = (i for i in SeqIO.parse(a.sequence_file, 'fasta')
                if i.id not in sequence_ids)
        c = SeqIO.write(seqs, j('not_in_hrefpkgs.fasta'), 'fasta')
        logging.info("%d sequences not in hrefpkgs.", c)
def main():
    """Load CSV classifications into a pplacer-style sqlite database.

    Each (name, mass) group of input rows becomes one placement; each
    taxon in the group contributes 1/len(taxa) likelihood, spread over
    every (want_rank, node) pair on its lineage. Finally, rows from older
    runs are removed for any name that this run also classified.
    """
    logging.basicConfig(level=logging.INFO,
                        format="%(levelname)s: %(message)s")
    parser = argparse.ArgumentParser(
        description="Add classifications to a database.")
    parser.add_argument('refpkg', type=Refpkg,
                        help="refpkg containing input taxa")
    parser.add_argument('classification_db', type=sqlite3.connect,
                        help="output sqlite database")
    parser.add_argument('classifications', type=argparse.FileType('r'),
                        nargs='?', default=sys.stdin,
                        help="input query sequences")
    args = parser.parse_args()

    log.info('loading taxonomy')
    taxtable = TaxNode.from_taxtable(
        args.refpkg.open_resource('taxonomy', 'rU'))
    # Rank name -> position in the refpkg's rank ordering.
    rank_order = {rank: e for e, rank in enumerate(taxtable.ranks)}

    def full_lineage(node):
        # For each lineage node (deepest first), yield it together with the
        # ranks at or below it that have no more specific classification.
        # rank_iter is shared across iterations: each node consumes the
        # ranks it covers, so every rank is assigned to exactly one node.
        rank_iter = reversed(taxtable.ranks)
        for n in reversed(node.lineage()):
            n_order = rank_order[n.rank]
            yield n, list(itertools.takewhile(
                lambda r: rank_order[r] >= n_order, rank_iter))

    def multiclass_rows(placement_id, seq, taxa):
        # Accumulate likelihood per (placement, seq, want_rank, rank, tax_id)
        # key; ties across taxa sharing lineage prefixes sum together.
        ret = collections.defaultdict(float)
        likelihood = 1. / len(taxa)
        for taxon in taxa:
            for node, want_ranks in full_lineage(taxtable.get_node(taxon)):
                for want_rank in want_ranks:
                    ret[placement_id, seq, want_rank, node.rank,
                        node.tax_id] += likelihood
        for k, v in sorted(ret.items()):
            yield k + (v,)

    curs = args.classification_db.cursor()
    curs.execute('INSERT INTO runs (params) VALUES (?)',
                 (' '.join(sys.argv),))
    run_id = curs.lastrowid

    log.info('inserting classifications')
    for (name, mass), rows in group_by_name_and_mass(
            csv.DictReader(args.classifications)):
        curs.execute(
            'INSERT INTO placements (classifier, run_id) VALUES ("csv", ?)',
            (run_id,))
        placement_id = curs.lastrowid
        curs.execute(
            'INSERT INTO placement_names (placement_id, name, origin, mass) VALUES (?, ?, ?, ?)',
            (placement_id, name, args.classifications.name, mass))
        taxa = [row['tax_id'] for row in rows]
        curs.executemany('INSERT INTO multiclass VALUES (?, ?, ?, ?, ?, ?)',
                         multiclass_rows(placement_id, name, taxa))

    log.info('cleaning up `multiclass` table')
    # Names classified both by this run and by at least one earlier run.
    curs.execute("""
        CREATE TEMPORARY TABLE duplicates AS
        SELECT name
        FROM multiclass JOIN placements USING (placement_id)
        GROUP BY name
        HAVING SUM(run_id = ?) AND COUNT(DISTINCT run_id) > 1
    """, (run_id,))
    # Drop the older runs' rows for those names; this run's rows win.
    curs.execute("""
        DELETE FROM multiclass
        WHERE (SELECT run_id FROM placements p
               WHERE p.placement_id = multiclass.placement_id) <> ?
          AND name IN (SELECT name FROM duplicates)
    """, (run_id,))
    args.classification_db.commit()
def main():
    """Load CSV classifications into a pplacer-style sqlite database.

    Each (name, mass) group of input rows becomes one placement; each
    taxon in the group contributes 1/len(taxa) likelihood, spread over
    every (want_rank, node) pair on its lineage. Finally, rows from older
    runs are removed for any name that this run also classified.
    """
    logging.basicConfig(level=logging.INFO,
                        format="%(levelname)s: %(message)s")
    parser = argparse.ArgumentParser(
        description="Add classifications to a database.")
    parser.add_argument('refpkg', type=Refpkg,
                        help="refpkg containing input taxa")
    parser.add_argument('classification_db', type=sqlite3.connect,
                        help="output sqlite database")
    parser.add_argument('classifications', type=argparse.FileType('r'),
                        nargs='?', default=sys.stdin,
                        help="input query sequences")
    args = parser.parse_args()

    log.info('loading taxonomy')
    taxtable = TaxNode.from_taxtable(
        args.refpkg.open_resource('taxonomy', 'rU'))
    # Rank name -> position in the refpkg's rank ordering.
    rank_order = {rank: e for e, rank in enumerate(taxtable.ranks)}

    def full_lineage(node):
        # For each lineage node (deepest first), yield it together with the
        # ranks at or below it that have no more specific classification.
        # rank_iter is shared across iterations: each node consumes the
        # ranks it covers, so every rank is assigned to exactly one node.
        rank_iter = reversed(taxtable.ranks)
        for n in reversed(node.lineage()):
            n_order = rank_order[n.rank]
            yield n, list(
                itertools.takewhile(lambda r: rank_order[r] >= n_order,
                                    rank_iter))

    def multiclass_rows(placement_id, seq, taxa):
        # Accumulate likelihood per (placement, seq, want_rank, rank, tax_id)
        # key; ties across taxa sharing lineage prefixes sum together.
        ret = collections.defaultdict(float)
        likelihood = 1. / len(taxa)
        for taxon in taxa:
            for node, want_ranks in full_lineage(taxtable.get_node(taxon)):
                for want_rank in want_ranks:
                    ret[placement_id, seq, want_rank, node.rank,
                        node.tax_id] += likelihood
        for k, v in sorted(ret.items()):
            yield k + (v, )

    curs = args.classification_db.cursor()
    curs.execute('INSERT INTO runs (params) VALUES (?)',
                 (' '.join(sys.argv), ))
    run_id = curs.lastrowid

    log.info('inserting classifications')
    for (name, mass), rows in group_by_name_and_mass(
            csv.DictReader(args.classifications)):
        curs.execute(
            'INSERT INTO placements (classifier, run_id) VALUES ("csv", ?)',
            (run_id, ))
        placement_id = curs.lastrowid
        curs.execute(
            'INSERT INTO placement_names (placement_id, name, origin, mass) VALUES (?, ?, ?, ?)',
            (placement_id, name, args.classifications.name, mass))
        taxa = [row['tax_id'] for row in rows]
        curs.executemany('INSERT INTO multiclass VALUES (?, ?, ?, ?, ?, ?)',
                         multiclass_rows(placement_id, name, taxa))

    log.info('cleaning up `multiclass` table')
    # Names classified both by this run and by at least one earlier run.
    curs.execute(
        """
        CREATE TEMPORARY TABLE duplicates AS
        SELECT name
        FROM multiclass JOIN placements USING (placement_id)
        GROUP BY name
        HAVING SUM(run_id = ?) AND COUNT(DISTINCT run_id) > 1
        """, (run_id, ))
    # Drop the older runs' rows for those names; this run's rows win.
    curs.execute(
        """
        DELETE FROM multiclass
        WHERE (SELECT run_id FROM placements p
               WHERE p.placement_id = multiclass.placement_id) <> ?
          AND name IN (SELECT name FROM duplicates)
        """, (run_id, ))
    args.classification_db.commit()