Example #1
class FindNodesTestCase(unittest.TestCase):
    def setUp(self):
        self.taxonomy = TaxNode(rank='root', name='root', tax_id='1')
        self.taxonomy.ranks = ['root', 'class', 'genus', 'species']
        g1 = TaxNode(rank='genus', name='g1', tax_id='2')
        self.g1 = g1
        g1.sequence_ids = set(['s1', 's2'])
        self.taxonomy.add_child(g1)
        g1.add_child(TaxNode(rank='species', name='s1', tax_id='s1'))
        g1.add_child(TaxNode(rank='species', name='s2', tax_id='s2'))

        g2 = TaxNode(rank='genus', name='g2', tax_id='3')
        self.taxonomy.add_child(g2)
        s3 = TaxNode(rank='species', name='s3', tax_id='s3')
        s3.sequence_ids = set(['s3', 's4'])
        g2.add_child(s3)
        g2.add_child(TaxNode(rank='species', name='s4', tax_id='s4'))

    def test_find_nodes(self):
        r = list(hrefpkg_build.find_nodes(self.taxonomy, 'class'))
        self.assertEqual(frozenset(['g1', 's3']), frozenset(i.name for i in r))

    def test_find_nodes_below_rank(self):
        r = list(hrefpkg_build.find_nodes(self.taxonomy, 'genus'))
        self.assertEqual(['s3'], [i.name for i in r])
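The two tests pin down the contract of hrefpkg_build.find_nodes: given a taxonomy and a rank, it yields the nodes strictly below that rank that have sequences directly attached ('g1' and 's3' for 'class'; only 's3' for 'genus', since g1 sits at the 'genus' rank itself). A minimal sketch consistent with both assertions, assuming TaxNode iteration walks the whole subtree; the real implementation may differ:

def find_nodes(taxonomy, rank):
    # Sketch only: nodes strictly below ``rank`` that carry sequence ids.
    below = set(taxonomy.ranks[taxonomy.ranks.index(rank) + 1:])
    for node in taxonomy:
        if node.rank in below and node.sequence_ids:
            yield node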
Example #2
class FindNodesTestCase(unittest.TestCase):
    def setUp(self):
        self.taxonomy = TaxNode(rank='root', name='root', tax_id='1')
        self.taxonomy.ranks = ['root', 'class', 'genus', 'species']
        g1 = TaxNode(rank='genus', name='g1', tax_id='2')
        self.g1 = g1
        g1.sequence_ids = set(['s1', 's2'])
        self.taxonomy.add_child(g1)
        g1.add_child(TaxNode(rank='species', name='s1', tax_id='s1'))
        g1.add_child(TaxNode(rank='species', name='s2', tax_id='s2'))

        g2 = TaxNode(rank='genus', name='g2', tax_id='3')
        self.taxonomy.add_child(g2)
        s3 = TaxNode(rank='species', name='s3', tax_id='s3')
        s3.sequence_ids = set(['s3', 's4'])
        g2.add_child(s3)
        g2.add_child(TaxNode(rank='species', name='s4', tax_id='s4'))

    def test_find_nodes(self):
        r = list(hrefpkg_build.find_nodes(self.taxonomy, 'class'))
        self.assertEqual(frozenset(['g1', 's3']), frozenset(i.name for i in r))

    def test_find_nodes_below_rank(self):
        r = list(hrefpkg_build.find_nodes(self.taxonomy, 'genus'))
        self.assertEqual(['s3'], [i.name for i in r])
Example #3
def generate_tax2tree_map(refpkg, output_fp):
    """
    Generate a tax2tree map from a reference package, writing to output_fp
    """
    with open(refpkg.file_abspath('taxonomy')) as fp:
        tax_root = TaxNode.from_taxtable(fp)

    def lineage(tax_id):
        l = tax_root.get_node(tax_id).lineage()
        d = {i.rank: i.tax_id for i in l}
        r = ('{0}__{1}'.format(rank[0], d.get(rank, ''))
             for rank in TAX2TREE_RANKS)
        return '; '.join(r)

    with open(refpkg.file_abspath('seq_info')) as fp:
        reader = csv.DictReader(fp)
        seq_map = ((i['seqname'], i['tax_id']) for i in reader)

        for seqname, tax_id in seq_map:
            if not tax_id:
                l = 'Unclassified'
            else:
                l = lineage(tax_id)

            print >> output_fp, '\t'.join((seqname, l))
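generate_tax2tree_map emits one tab-separated line per sequence, pairing the sequence name with a greengenes-style lineage string built from TAX2TREE_RANKS (note the Python 2 print >> syntax used throughout these examples). The constant itself is not shown; a plausible definition and the resulting line shape, purely for illustration:

# Assumed rank list; the real TAX2TREE_RANKS may differ.
TAX2TREE_RANKS = ('kingdom', 'phylum', 'class', 'order',
                  'family', 'genus', 'species')

# Under that assumption each output line looks like:
# seq00123<TAB>k__2; p__1239; c__91061; o__; f__; g__; s__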
Example #4
def main():
    logging.basicConfig(level=logging.INFO,
                        format="%(levelname)s: %(message)s")

    parser = argparse.ArgumentParser(
        description=
        "Turn a BIOM file with a taxonomy into a taxtable and seqinfo.")

    parser.add_argument('biom',
                        type=argparse.FileType('r'),
                        help='input BIOM file')
    parser.add_argument('taxtable',
                        type=argparse.FileType('w'),
                        help='output taxtable')
    parser.add_argument('seqinfo',
                        type=argparse.FileType('w'),
                        help='output seqinfo')

    args = parser.parse_args()

    log.info('loading biom')
    with args.biom:
        j = json.load(args.biom)

    root = TaxNode('root', root_id, name='Root')
    root.ranks = rank_order
    seqinfo = csv.writer(args.seqinfo)
    seqinfo.writerow(('seqname', 'tax_id'))

    log.info('determining tax_ids')
    for leaf in j['rows']:
        leaf_taxonomy = leaf['metadata']['taxonomy']

        # Drop nodes containing only rank (e.g. `s__`)
        leaf_taxonomy = [i for i in leaf_taxonomy if i[3:]]
        leaf_lineages = list(lineages([i for i in leaf_taxonomy if i[3:]]))

        seqinfo.writerow((leaf['id'], leaf_lineages[-1][0]))

        for tax_id, node, parent in leaf_lineages:
            if tax_id in root.index:
                continue
            root.get_node(parent).add_child(
                TaxNode(ranks[node[0]], tax_id, name=node[3:] or node))

    log.info('writing taxtable')
    with args.taxtable:
        root.write_taxtable(args.taxtable)
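This main relies on several names defined elsewhere in the module: root_id, rank_order, ranks (a prefix-letter-to-rank map), and a lineages helper that expands a greengenes-style lineage list into (tax_id, node, parent_tax_id) triples. A sketch of what they could look like, consistent with the call sites above but strictly a guess:

root_id = 'root'
rank_order = ['root', 'kingdom', 'phylum', 'class', 'order',
              'family', 'genus', 'species']
ranks = {'k': 'kingdom', 'p': 'phylum', 'c': 'class', 'o': 'order',
         'f': 'family', 'g': 'genus', 's': 'species'}

def lineages(taxonomy):
    # Guess at the helper: use joined prefix strings as tax_ids so that
    # identical partial lineages resolve to the same node.
    parent = root_id
    for i, node in enumerate(taxonomy):
        tax_id = ';'.join(taxonomy[:i + 1])
        yield tax_id, node, parent
        parent = tax_id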
Example #5
def update_taxids(refpkg,
                  tax2tree_dict,
                  output_fp,
                  allow_rename=True,
                  unknown_taxid=None):
    with open(refpkg.file_abspath('taxonomy')) as fp:
        tax_root = TaxNode.from_taxtable(fp)

    def lineage_ids(tax_id):
        if not tax_id:
            return frozenset()
        n = tax_root.get_node(tax_id)
        s = frozenset(i.tax_id for i in n.lineage())
        return s

    with open(refpkg.file_abspath('seq_info')) as fp:
        try:
            dialect = csv.Sniffer().sniff(fp.read(1024), delimiters=',')
        except csv.Error:
            dialect = csv.excel
        fp.seek(0)
        r = csv.DictReader(fp, dialect=dialect)
        h = r.fieldnames
        h.append('orig_tax_id')
        w = csv.DictWriter(output_fp, h, dialect=dialect)
        w.writeheader()
        for i in r:
            i['orig_tax_id'] = i['tax_id']
            n = tax2tree_dict.get(i['seqname'])
            if n and n not in lineage_ids(i['tax_id']):
                node = tax_root.get_node(n)
                new_name, new_rank = node.name, node.rank
                if i['tax_id']:
                    node = tax_root.get_node(i['tax_id'])
                    orig_name, orig_rank = node.name, node.rank
                else:
                    orig_name, orig_rank = '', ''

                logging.info('%s changed from "%s" (%s) to "%s" (%s)',
                             i['seqname'], orig_name, orig_rank, new_name,
                             new_rank)
                if allow_rename or not i['tax_id']:
                    i['tax_id'] = n
                else:
                    logging.info("Not applied.")
            elif n is None and not i['tax_id']:
                i['tax_id'] = unknown_taxid
                logging.warn("no taxonomy for %s", i['seqname'])
            w.writerow(i)
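update_taxids rewrites a refpkg's seq_info CSV, preserving each original tax_id in a new orig_tax_id column and adopting the tax2tree-assigned id whenever that id falls outside the sequence's current lineage. A hypothetical call, assuming a refpkg on disk and a {seqname: tax_id} dict parsed from tax2tree output (the file names and ids here are illustrative):

from taxtastic.refpkg import Refpkg  # assumed import path

with open('seq_info.updated.csv', 'w') as out_fp:
    update_taxids(Refpkg('my.refpkg'), {'seq0042': '1280'}, out_fp,
                  allow_rename=False, unknown_taxid='12908')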
Example #6
def update_taxids(refpkg, tax2tree_dict, output_fp, allow_rename=True, unknown_taxid=None):
    with open(refpkg.file_abspath('taxonomy')) as fp:
        tax_root = TaxNode.from_taxtable(fp)

    def lineage_ids(tax_id):
        if not tax_id:
            return frozenset()
        n = tax_root.get_node(tax_id)
        s = frozenset(i.tax_id for i in n.lineage())
        return s

    with open(refpkg.file_abspath('seq_info')) as fp:
        try:
            dialect = csv.Sniffer().sniff(fp.read(1024), delimiters=',')
        except csv.Error:
            dialect = csv.excel
        fp.seek(0)
        r = csv.DictReader(fp, dialect=dialect)
        h = r.fieldnames
        h.append('orig_tax_id')
        w = csv.DictWriter(output_fp, h, dialect=dialect)
        w.writeheader()
        for i in r:
            i['orig_tax_id'] = i['tax_id']
            n = tax2tree_dict.get(i['seqname'])
            if n and n not in lineage_ids(i['tax_id']):
                node = tax_root.get_node(n)
                new_name, new_rank = node.name, node.rank
                if i['tax_id']:
                    node = tax_root.get_node(i['tax_id'])
                    orig_name, orig_rank = node.name, node.rank
                else:
                    orig_name, orig_rank = '', ''

                logging.info('%s changed from "%s" (%s) to "%s" (%s)',
                        i['seqname'], orig_name, orig_rank, new_name, new_rank)
                if allow_rename or not i['tax_id']:
                    i['tax_id'] = n
                else:
                    logging.info("Not applied.")
            elif n is None and not i['tax_id']:
                i['tax_id'] = unknown_taxid
                logging.warn("no taxonomy for %s", i['seqname'])
            w.writerow(i)
Example #7
def generate_tax2tree_map(refpkg, output_fp):
    """
    Generate a tax2tree map from a reference package, writing to output_fp
    """
    with open(refpkg.file_abspath('taxonomy')) as fp:
        tax_root = TaxNode.from_taxtable(fp)

    def lineage(tax_id):
        l = tax_root.get_node(tax_id).lineage()
        d = {i.rank: i.tax_id for i in l}
        r = ('{0}__{1}'.format(rank[0], d.get(rank, '')) for rank in TAX2TREE_RANKS)
        return '; '.join(r)

    with open(refpkg.file_abspath('seq_info')) as fp:
        reader = csv.DictReader(fp)
        seq_map = ((i['seqname'], i['tax_id']) for i in reader)

        for seqname, tax_id in seq_map:
            if not tax_id:
                l = 'Unclassified'
            else:
                l = lineage(tax_id)

            print >> output_fp, '\t'.join((seqname, l))
Example #8
def main():
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    parser = argparse.ArgumentParser(
        description="Turn a BIOM file with a taxonomy into a taxtable and seqinfo.")

    parser.add_argument(
        'biom', type=argparse.FileType('r'), help='input BIOM file')
    parser.add_argument(
        'taxtable', type=argparse.FileType('w'), help='output taxtable')
    parser.add_argument(
        'seqinfo', type=argparse.FileType('w'), help='output seqinfo')

    args = parser.parse_args()

    log.info('loading biom')
    with args.biom:
        j = json.load(args.biom)

    root = TaxNode('root', root_id, name='Root')
    root.ranks = rank_order
    seqinfo = csv.writer(args.seqinfo)
    seqinfo.writerow(('seqname', 'tax_id'))

    log.info('determining tax_ids')
    for leaf in j['rows']:
        leaf_taxonomy = leaf['metadata']['taxonomy']

        # Drop nodes containing only rank (e.g. `s__`)
        leaf_taxonomy = [i for i in leaf_taxonomy if i[3:]]
        leaf_lineages = list(lineages([i for i in leaf_taxonomy if i[3:]]))

        seqinfo.writerow((leaf['id'], leaf_lineages[-1][0]))

        for tax_id, node, parent in leaf_lineages:
            if tax_id in root.index:
                continue
            root.get_node(parent).add_child(
                TaxNode(ranks[node[0]], tax_id, name=node[3:] or node))

    log.info('writing taxtable')
    with args.taxtable:
        root.write_taxtable(args.taxtable)
Example #9
def action(a):
    # Load taxonomy
    with a.taxonomy as fp:
        taxonomy = TaxNode.from_taxtable(fp)
        logging.info('Loaded taxonomy')

    # Load sequences into taxonomy
    with a.seqinfo_file as fp:
        taxonomy.populate_from_seqinfo(fp)
        logging.info('Added %d sequences', sum(1 for i in taxonomy.subtree_sequence_ids()))

    # Sequences which are classified above the desired rank should just be kept
    kept_ids = frozenset(sequences_above_rank(taxonomy, a.filter_rank))

    log_taxid = None
    if a.log:
        writer = csv.writer(a.log, lineterminator='\n',
                quoting=csv.QUOTE_NONNUMERIC)
        writer.writerow(('tax_id', 'tax_name', 'n', 'kept', 'pruned'))
        def log_taxid(tax_id, tax_name, n, kept, pruned):
            writer.writerow((tax_id, tax_name, n, kept, pruned))

    with a.output_fp as fp, a.log or util.nothing():
        logging.info('Keeping %d sequences classified above %s', len(kept_ids), a.filter_rank)
        wrap.esl_sfetch(a.sequence_file, kept_ids, fp)

        # For each filter-rank, filter
        nodes = [i for i in taxonomy if i.rank == a.filter_rank]

        # Filter each tax_id, running in ``--threads`` tasks in parallel
        with futures.ThreadPoolExecutor(a.threads) as executor:
            futs = {}
            for i, node in enumerate(nodes):
                seqs = frozenset(node.subtree_sequence_ids())
                if not seqs:
                    logging.warn("No sequences for %s (%s)", node.tax_id, node.name)
                    if log_taxid:
                        log_taxid(node.tax_id, node.name, 0, 0, 0)
                    continue
                elif len(seqs) == 1:
                    logging.warn('Only 1 sequence for %s (%s). Dropping', node.tax_id, node.name)
                    continue

                f = executor.submit(filter_worker,
                        sequence_file=a.sequence_file,
                        node=node,
                        seqs=seqs,
                        distance_cutoff=a.distance_cutoff,
                        log_taxid=log_taxid)
                futs[f] = {'n_seqs': len(seqs), 'node': node}

            complete = 0
            while futs:
                done, pending = futures.wait(futs, 1, futures.FIRST_COMPLETED)
                complete += len(done)
                sys.stderr.write('{0:8d}/{1:8d} taxa completed\r'.format(complete,
                    complete+len(pending)))
                for f in done:
                    if f.exception():
                        logging.exception("Error in child process: %s", f.exception())
                        executor.shutdown(False)
                        raise f.exception()

                    info = futs.pop(f)
                    kept = f.result()
                    kept_ids |= kept
                    if len(kept) != info['n_seqs']:
                        logging.warn('Pruned %d/%d sequences for %s (%s)',
                                info['n_seqs'] - len(kept), info['n_seqs'],
                                info['node'].tax_id, info['node'].name)
                    # Extract
                    wrap.esl_sfetch(a.sequence_file, kept, fp)

    # Filter seqinfo
    if a.filtered_seqinfo:
        with open(a.seqinfo_file.name) as fp:
            r = csv.DictReader(fp)
            rows = (i for i in r if i['seqname'] in kept_ids)
            with a.filtered_seqinfo as ofp:
                w = csv.DictWriter(ofp, r.fieldnames, quoting=csv.QUOTE_NONNUMERIC)
                w.writeheader()
                w.writerows(rows)
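The action depends on sequences_above_rank, which is not shown; per the comment, it collects ids for sequences classified above the desired rank so they are kept unconditionally. A guess at its shape, again assuming TaxNode iteration covers the subtree:

def sequences_above_rank(taxonomy, rank):
    # Sketch: ids attached to nodes whose rank sorts above ``rank``.
    above = set(taxonomy.ranks[:taxonomy.ranks.index(rank)])
    for node in taxonomy:
        if node.rank in above:
            for sequence_id in node.sequence_ids:
                yield sequence_id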
Example #10
def action(a):
    with a.taxonomy as fp:
        taxonomy = TaxNode.from_taxtable(fp)
    with a.seqinfo_file as fp:
        # List of sequences
        r = csv.DictReader(fp)
        current_seqs = frozenset(i['seqname'] for i in r)
        fp.seek(0)
        taxonomy.populate_from_seqinfo(fp)

    # Find sequences from underrepresented taxids to search
    underrep = find_underrepresented(taxonomy, a.min_at_rank, a.rank)

    tax_seqs = {}
    seq_group = {}
    for n, seqs in underrep:
        tax_seqs[n.tax_id] = seqs
        seq_group.update({i:n.tax_id for i in seqs})

    with util.ntf(prefix='to_expand-', suffix='.fasta') as expand_fp, \
         util.ntf(prefix='expand_hits-', suffix='.fasta') as hits_fp:
        # Extract sequences
        c = wrap.esl_sfetch(a.sequence_file, seq_group, expand_fp)
        logging.info('fetched %d sequences', c)
        expand_fp.close()

        # Search sequences against unnamed
        r = uclust_search(expand_fp.name, a.unnamed_file, pct_id=a.pct_id,
                maxaccepts=4, search_pct_id=0.9, trunclabels=True)
        hits = list(r)
        # Map from hit to group
        hit_group = {i.target_label: seq_group[i.query_label] for i in hits}

        # Extract hits
        c = wrap.esl_sfetch(a.unnamed_file, hit_group, hits_fp)
        logging.info('%d hits', c)
        hits_fp.close()

        # Search hits back against named file
        r = uclust_search(hits_fp.name, expand_fp.name, pct_id=a.pct_id,
                maxaccepts=1, search_pct_id=0.9, trunclabels=True)

        # Sequences which hit the same group
        update_hits = dict((i.query_label, seq_group[i.target_label])
                       for i in r
                       if seq_group[i.target_label] == hit_group[i.query_label])

        overlap = frozenset(update_hits) & current_seqs
        if overlap:
            logging.warn('%d sequences already present in corpus: %s',
                    len(overlap), ', '.join(overlap))

        # Add sequences
        with open(a.output + '.fasta', 'w') as ofp:
            with open(a.sequence_file) as fp:
                shutil.copyfileobj(fp, ofp)
            try:
                wrap.esl_sfetch(hits_fp.name, frozenset(update_hits) - current_seqs, ofp)
            finally:
                os.remove(hits_fp.name + '.ssi')

        # Write a new seq_info
        with open(a.output + '.seq_info.csv', 'w') as ofp, open(a.seqinfo_file.name) as sinfo:
            r = csv.DictReader(sinfo)
            fn = list(r.fieldnames) + ['inferred_tax_id']
            w = csv.DictWriter(ofp, fn, lineterminator='\n', quoting=csv.QUOTE_NONNUMERIC)
            w.writeheader()
            w.writerows(i for i in r if i['seqname'] not in overlap)
            if 'cluster' in fn:
                rows = ({'seqname': k, 'tax_id': v, 'inferred_tax_id': 'yes', 'cluster': v}
                        for k, v in update_hits.items())
            else:
                rows = ({'seqname': k, 'tax_id': v, 'inferred_tax_id': 'yes'}
                        for k, v in update_hits.items())
            w.writerows(rows)
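find_underrepresented is used but never defined here; it evidently yields (node, sequence_ids) pairs for taxa with too few representatives at the target rank. One plausible reading, offered only as a sketch:

def find_underrepresented(taxonomy, min_at_rank, rank):
    # Sketch: rank-level nodes whose subtrees hold fewer than
    # ``min_at_rank`` sequences (but at least one to seed the search).
    for node in taxonomy:
        if node.rank != rank:
            continue
        seqs = frozenset(node.subtree_sequence_ids())
        if 0 < len(seqs) < min_at_rank:
            yield node, seqs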
Example #11
    def setUp(self):
        self.taxonomy = TaxNode(rank='root', name='root', tax_id='1')
        self.taxonomy.ranks = ['root', 'class', 'genus', 'species']
        g1 = TaxNode(rank='genus', name='g1', tax_id='2')
        self.g1 = g1
        g1.sequence_ids = set(['s1', 's2'])
        self.taxonomy.add_child(g1)
        g1.add_child(TaxNode(rank='species', name='s1', tax_id='s1'))
        g1.add_child(TaxNode(rank='species', name='s2', tax_id='s2'))

        g2 = TaxNode(rank='genus', name='g2', tax_id='3')
        self.taxonomy.add_child(g2)
        s3 = TaxNode(rank='species', name='s3', tax_id='s3')
        s3.sequence_ids = set(['s3', 's4'])
        g2.add_child(s3)
        g2.add_child(TaxNode(rank='species', name='s4', tax_id='s4'))
Example #12
    def setUp(self):
        self.taxonomy = TaxNode(rank='root', name='root', tax_id='1')
        self.taxonomy.ranks = ['root', 'class', 'genus', 'species']
        g1 = TaxNode(rank='genus', name='g1', tax_id='2')
        self.g1 = g1
        g1.sequence_ids = set(['s1', 's2'])
        self.taxonomy.add_child(g1)
        g1.add_child(TaxNode(rank='species', name='s1', tax_id='s1'))
        g1.add_child(TaxNode(rank='species', name='s2', tax_id='s2'))

        g2 = TaxNode(rank='genus', name='g2', tax_id='3')
        self.taxonomy.add_child(g2)
        s3 = TaxNode(rank='species', name='s3', tax_id='s3')
        s3.sequence_ids = set(['s3', 's4'])
        g2.add_child(s3)
        g2.add_child(TaxNode(rank='species', name='s4', tax_id='s4'))
Example #13
def action(a):
    random.seed(a.seed)
    j = functools.partial(os.path.join, a.output_dir)
    if not os.path.isdir(j()):
        raise IOError('Does not exist: {0}'.format(j()))
    if os.path.exists(j('index.refpkg')):
        raise IOError('index.refpkg exists.')

    with open(a.taxonomy) as fp:
        logging.info('loading taxonomy')
        taxonomy = TaxNode.from_taxtable(fp)

    # If partitioning, partition with current args, return
    if a.partition_below_rank is not None or a.partition_rank is not None:
        if not a.partition_below_rank or not a.partition_rank:
            raise ValueError("--partition-below-rank and --partition-rank must be specified together")
        return partition_hrefpkg(a, taxonomy)

    with open(a.seqinfo_file) as fp:
        logging.info("loading seqinfo")
        seqinfo = load_seqinfo(fp)

    # Build an hrefpkg
    nodes = [i for i in taxonomy if i.rank == a.index_rank]
    hrefpkgs = []
    futs = {}
    with open(j('index.csv'), 'w') as fp, \
         open(j('train.fasta'), 'w') as train_fp, \
         open(j('test.fasta'), 'w') as test_fp, \
         futures.ThreadPoolExecutor(a.threads) as executor:
        def log_hrefpkg(tax_id):
            path = j(tax_id + '.refpkg')
            fp.write('{0},{0}.refpkg\n'.format(tax_id))
            hrefpkgs.append(path)

        for i, node in enumerate(nodes):
            if a.only and node.tax_id not in a.only:
                logging.info("Skipping %s", node.tax_id)
                continue

            if os.path.exists(j(node.tax_id + '.refpkg')):
                logging.warn("Refpkg exists: %s.refpkg. Skipping", node.tax_id)
                log_hrefpkg(node.tax_id)
                continue

            f = executor.submit(tax_id_refpkg, node.tax_id, taxonomy, seqinfo,
                    a.sequence_file, output_dir=a.output_dir, test_file=test_fp,
                    train_file=train_fp)
            futs[f] = node.tax_id, node.name

        while futs:
            done, pending = futures.wait(futs, 1, futures.FIRST_COMPLETED)
            for f in done:
                tax_id, name = futs.pop(f)
                r = f.result()
                if r:
                    logging.info("Finished refpkg for %s (%s) [%d remaining]", name, tax_id, len(pending))
                    log_hrefpkg(tax_id)
            assert len(futs) == len(pending)

        # Build index refpkg
        logging.info('Building index.refpkg')
        index_rp, sequence_ids = build_index_refpkg(hrefpkgs, a.sequence_file,
            seqinfo, taxonomy, dest=j('index.refpkg'),
            index_rank=a.index_rank)

        # Write unused seqs
        logging.info("Extracting unused sequences")
        seqs = (i for i in SeqIO.parse(a.sequence_file, 'fasta')
                if i.id not in sequence_ids)
        c = SeqIO.write(seqs, j('not_in_hrefpkgs.fasta'), 'fasta')
        logging.info("%d sequences not in hrefpkgs.", c)
Example #14
def action(a):
    with a.taxonomy as fp:
        taxonomy = TaxNode.from_taxtable(fp)
    with a.seqinfo_file as fp:
        # List of sequences
        r = csv.DictReader(fp)
        current_seqs = frozenset(i['seqname'] for i in r)
        fp.seek(0)
        taxonomy.populate_from_seqinfo(fp)

    # Find sequences from underrepresented taxids to search
    underrep = find_underrepresented(taxonomy, a.min_at_rank, a.rank)

    tax_seqs = {}
    seq_group = {}
    for n, seqs in underrep:
        tax_seqs[n.tax_id] = seqs
        seq_group.update({i: n.tax_id for i in seqs})

    with util.ntf(prefix='to_expand-', suffix='.fasta') as expand_fp, \
         util.ntf(prefix='expand_hits-', suffix='.fasta') as hits_fp:
        # Extract sequences
        c = wrap.esl_sfetch(a.sequence_file, seq_group, expand_fp)
        logging.info('fetched %d sequences', c)
        expand_fp.close()

        # Search sequences against unnamed
        r = uclust_search(expand_fp.name,
                          a.unnamed_file,
                          pct_id=a.pct_id,
                          maxaccepts=4,
                          search_pct_id=0.9)
        hits = list(r)
        # Map from hit to group
        hit_group = {i.target_label: seq_group[i.query_label] for i in hits}

        # Extract hits
        c = wrap.esl_sfetch(a.unnamed_file, hit_group, hits_fp)
        logging.info('%d hits', c)
        hits_fp.close()

        # Search hits back against named file
        r = uclust_search(hits_fp.name,
                          expand_fp.name,
                          pct_id=a.pct_id,
                          maxaccepts=1,
                          search_pct_id=0.9)

        # Sequences which hit the same group
        update_hits = dict(
            (i.query_label, seq_group[i.target_label]) for i in r
            if seq_group[i.target_label] == hit_group[i.query_label])

        overlap = frozenset(update_hits) & current_seqs
        if overlap:
            logging.warn('%d sequences already present in corpus: %s',
                         len(overlap), ', '.join(overlap))

        # Add sequences
        with open(a.output + '.fasta', 'w') as ofp:
            with open(a.sequence_file) as fp:
                shutil.copyfileobj(fp, ofp)
            try:
                wrap.esl_sfetch(hits_fp.name,
                                frozenset(update_hits) - current_seqs, ofp)
            finally:
                os.remove(hits_fp.name + '.ssi')

        # Write a new seq_info
        with open(a.output + '.seq_info.csv',
                  'w') as ofp, open(a.seqinfo_file.name) as sinfo:
            r = csv.DictReader(sinfo)
            fn = list(r.fieldnames) + ['inferred_tax_id']
            w = csv.DictWriter(ofp,
                               fn,
                               lineterminator='\n',
                               quoting=csv.QUOTE_NONNUMERIC)
            w.writeheader()
            w.writerows(i for i in r if i['seqname'] not in overlap)
            if 'cluster' in fn:
                rows = ({
                    'seqname': k,
                    'tax_id': v,
                    'inferred_tax_id': 'yes',
                    'cluster': v
                } for k, v in update_hits.items())
            else:
                rows = ({
                    'seqname': k,
                    'tax_id': v,
                    'inferred_tax_id': 'yes'
                } for k, v in update_hits.items())
            w.writerows(rows)
Example #15
def action(a):
    # Load taxtable
    with a.taxtable as fp:
        logging.info('Loading taxonomy')
        taxonomy = TaxNode.from_taxtable(fp)
    with a.seqinfo_file as fp:
        logging.info('Loading seqinfo')
        taxonomy.populate_from_seqinfo(fp)
        fp.seek(0)
        r = csv.DictReader(fp)
        seqinfo = {i['seqname']: i for i in r}
    if a.unnamed_sequence_meta:
        with a.unnamed_sequence_meta as fp:
            r = csv.DictReader(fp)
            unnamed_seqinfo = {i['seqname']: i for i in r}
        assert not set(unnamed_seqinfo) & set(seqinfo)
        seqinfo.update(unnamed_seqinfo)

    # Write clustering information for sequences with cluster_rank-level
    # classifications
    done = set()
    cluster_ids = {}
    with a.sequence_out:
        for tax_id, sequences in taxonomic_clustered(taxonomy, a.cluster_rank):
            for sequence in sequences:
                cluster_ids[sequence] = tax_id
            done |= set(sequences)

        # Fetch sequences
        logging.info('Fetching %d %s-level sequences', len(done), a.cluster_rank)
        wrap.esl_sfetch(a.named_sequence_file, done, a.sequence_out)
        a.sequence_out.flush()

        # Find sequences *above* cluster_rank
        above_rank_seqs = frozenset(i for i in taxonomy.subtree_sequence_ids()
                                    if i not in done)
        logging.info('%d sequences above rank %s', len(above_rank_seqs),
                a.cluster_rank)

        # Write sequences clustered above species level, unnamed sequences to
        # file
        with util.ntf(prefix='to_cluster', suffix='.fasta') as tf, \
                util.ntf(prefix='unnamed_to_cluster', suffix='.fasta') as unnamed_fp:
            wrap.esl_sfetch(a.named_sequence_file, above_rank_seqs, tf)
            if a.unnamed_sequences:
                with open(a.unnamed_sequences) as fp:
                    shutil.copyfileobj(fp, tf)
            tf.close()

            # Remove redundant sequences: we don't need anything that's unnamed
            # & close to something named.
            redundant_ids = cluster_identify_redundant(a.sequence_out.name,
                    done, to_cluster=tf.name, threshold=a.redundant_cluster_id)
            logging.info('%d redundant sequences', len(redundant_ids))

            # Extract desired sequences
            sequences = SeqIO.parse(tf.name, 'fasta')
            sequences = (i for i in sequences if i.id not in redundant_ids)

            # Write to file for clustering
            unnamed_count = SeqIO.write(sequences, unnamed_fp, 'fasta')
            logging.info('Kept %d non-redundant, unnamed sequences', unnamed_count)

            # Write to output sequence file
            unnamed_fp.seek(0)
            shutil.copyfileobj(unnamed_fp, a.sequence_out)
            unnamed_fp.close()

            # Cluster remaining sequences into OTUs
            for i, cluster_seqs in enumerate(identify_otus_unnamed(unnamed_fp.name, a.cluster_id)):
                done |= set(cluster_seqs)
                otu = 'otu_{0}'.format(i)
                for sequence in cluster_seqs:
                    cluster_ids[sequence] = otu

    with a.seqinfo_out as fp:
        def add_cluster(i):
            """Add a cluster identifier to sequence metadata"""
            i['cluster'] = cluster_ids[i['seqname']]
            return i
        seqinfo_records = (seqinfo.get(i, {'seqname': i}) for i in done)
        seqinfo_records = (add_cluster(i) for i in seqinfo_records)

        fields = list(seqinfo.values()[0].keys())
        fields.append('cluster')
        w = csv.DictWriter(fp, fields,
                quoting=csv.QUOTE_NONNUMERIC, lineterminator='\n')
        w.writeheader()
        w.writerows(seqinfo_records)
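taxonomic_clustered(taxonomy, rank) is consumed as (tax_id, sequences) pairs. A plausible sketch, grouping each rank-level node's subtree sequences under its tax_id, illustrative only:

def taxonomic_clustered(taxonomy, rank):
    # Sketch: one cluster per ``rank``-level node.
    return ((node.tax_id, frozenset(node.subtree_sequence_ids()))
            for node in taxonomy if node.rank == rank)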
Example #16
    def setUp(self):
        with open(data_path('simple_taxtable.csv')) as fp:
            self.root = TaxNode.from_taxtable(fp)
Example #17
def action(a):
    random.seed(a.seed)
    j = functools.partial(os.path.join, a.output_dir)
    if not os.path.isdir(j()):
        raise IOError('Does not exist: {0}'.format(j()))
    if os.path.exists(j('index.refpkg')):
        raise IOError('index.refpkg exists.')

    with open(a.taxonomy) as fp:
        logging.info('loading taxonomy')
        taxonomy = TaxNode.from_taxtable(fp)

    # If partitioning, partition with current args, return
    if a.partition_below_rank is not None or a.partition_rank is not None:
        if not a.partition_below_rank or not a.partition_rank:
            raise ValueError(
                "--partition-below-rank and --partition-rank must be specified together"
            )
        return partition_hrefpkg(a, taxonomy)

    with open(a.seqinfo_file) as fp:
        logging.info("loading seqinfo")
        seqinfo = load_seqinfo(fp)

    # Build an hrefpkg
    nodes = [i for i in taxonomy if i.rank == a.index_rank]
    hrefpkgs = []
    futs = {}
    with open(j('index.csv'), 'w') as fp, \
         open(j('train.fasta'), 'w') as train_fp, \
         open(j('test.fasta'), 'w') as test_fp, \
         futures.ThreadPoolExecutor(a.threads) as executor:

        def log_hrefpkg(tax_id):
            path = j(tax_id + '.refpkg')
            fp.write('{0},{0}.refpkg\n'.format(tax_id))
            hrefpkgs.append(path)

        for i, node in enumerate(nodes):
            if a.only and node.tax_id not in a.only:
                logging.info("Skipping %s", node.tax_id)
                continue

            if os.path.exists(j(node.tax_id + '.refpkg')):
                logging.warn("Refpkg exists: %s.refpkg. Skipping", node.tax_id)
                log_hrefpkg(node.tax_id)
                continue

            f = executor.submit(tax_id_refpkg,
                                node.tax_id,
                                taxonomy,
                                seqinfo,
                                a.sequence_file,
                                output_dir=a.output_dir,
                                test_file=test_fp,
                                train_file=train_fp)
            futs[f] = node.tax_id, node.name

        while futs:
            done, pending = futures.wait(futs, 1, futures.FIRST_COMPLETED)
            for f in done:
                tax_id, name = futs.pop(f)
                r = f.result()
                if r:
                    logging.info("Finished refpkg for %s (%s) [%d remaining]",
                                 name, tax_id, len(pending))
                    log_hrefpkg(tax_id)
            assert len(futs) == len(pending)

        # Build index refpkg
        logging.info('Building index.refpkg')
        index_rp, sequence_ids = build_index_refpkg(hrefpkgs,
                                                    a.sequence_file,
                                                    seqinfo,
                                                    taxonomy,
                                                    dest=j('index.refpkg'),
                                                    index_rank=a.index_rank)

        # Write unused seqs
        logging.info("Extracting unused sequences")
        seqs = (i for i in SeqIO.parse(a.sequence_file, 'fasta')
                if i.id not in sequence_ids)
        c = SeqIO.write(seqs, j('not_in_hrefpkgs.fasta'), 'fasta')
        logging.info("%d sequences not in hrefpkgs.", c)
Example #18
    def setUp(self):
        with open(data_path('simple_taxtable.csv')) as fp:
            self.root = TaxNode.from_taxtable(fp)
Example #19
def action(a):
    # Load taxtable
    with a.taxtable as fp:
        logging.info('Loading taxonomy')
        taxonomy = TaxNode.from_taxtable(fp)
    with a.seqinfo_file as fp:
        logging.info('Loading seqinfo')
        taxonomy.populate_from_seqinfo(fp)
        fp.seek(0)
        r = csv.DictReader(fp)
        seqinfo = {i['seqname']: i for i in r}
    if a.unnamed_sequence_meta:
        with a.unnamed_sequence_meta as fp:
            r = csv.DictReader(fp)
            unnamed_seqinfo = {i['seqname']: i for i in r}
        assert not set(unnamed_seqinfo) & set(seqinfo)
        seqinfo.update(unnamed_seqinfo)


    # Write clustering information for sequences with cluster_rank-level
    # classifications
    done = set()
    cluster_ids = {}
    with a.sequence_out:
        for tax_id, sequences in taxonomic_clustered(taxonomy, a.cluster_rank):
            for sequence in sequences:
                cluster_ids[sequence] = tax_id
            done |= set(sequences)

        # Fetch sequences
        logging.info('Fetching %d %s-level sequences', len(done), a.cluster_rank)
        wrap.esl_sfetch(a.named_sequence_file, done, a.sequence_out)
        a.sequence_out.flush()

        # Find sequences *above* cluster_rank
        above_rank_seqs = frozenset(i for i in taxonomy.subtree_sequence_ids()
                                    if i not in done)
        logging.info('%d sequences above rank %s', len(above_rank_seqs),
                a.cluster_rank)

        # Write sequences clustered above species level, unnamed sequences to
        # file
        with util.ntf(prefix='to_cluster', suffix='.fasta') as tf, \
                util.ntf(prefix='unnamed_to_cluster', suffix='.fasta') as unnamed_fp:
            wrap.esl_sfetch(a.named_sequence_file, above_rank_seqs, tf)
            if a.unnamed_sequences:
                with open(a.unnamed_sequences) as fp:
                    shutil.copyfileobj(fp, tf)
            tf.close()

            # Remove redundant sequences: we don't need anything that's unnamed
            # & close to something named.
            redundant_ids = cluster_identify_redundant(a.sequence_out.name,
                    done, to_cluster=tf.name, threshold=a.redundant_cluster_id)
            logging.info('%d redundant sequences', len(redundant_ids))

            # Extract desired sequences
            sequences = SeqIO.parse(tf.name, 'fasta')
            sequences = (i for i in sequences if i.id not in redundant_ids)

            # Write to file for clustering
            unnamed_count = SeqIO.write(sequences, unnamed_fp, 'fasta')
            logging.info('Kept %d non-redundant, unnamed sequences', unnamed_count)

            # Write to output sequence file
            unnamed_fp.seek(0)
            shutil.copyfileobj(unnamed_fp, a.sequence_out)
            unnamed_fp.close()

            # Cluster remaining sequences into OTUs
            for i, cluster_seqs in enumerate(identify_otus_unnamed(unnamed_fp.name, a.cluster_id)):
                done |= set(cluster_seqs)
                otu = 'otu_{0}'.format(i)
                for sequence in cluster_seqs:
                    cluster_ids[sequence] = otu

    with a.seqinfo_out as fp:
        def add_cluster(i):
            """Add a cluster identifier to sequence metadata"""
            i['cluster'] = cluster_ids[i['seqname']]
            return i
        seqinfo_records = (seqinfo.get(i, {'seqname': i}) for i in done)
        seqinfo_records = (add_cluster(i) for i in seqinfo_records)

        fields = list(seqinfo.values()[0].keys())
        fields.append('cluster')
        w = csv.DictWriter(fp, fields,
                quoting=csv.QUOTE_NONNUMERIC, lineterminator='\n')
        w.writeheader()
        w.writerows(seqinfo_records)
Example #20
def main():
    logging.basicConfig(
        level=logging.INFO, format="%(levelname)s: %(message)s")

    parser = argparse.ArgumentParser(
        description="Add classifications to a database.")
    parser.add_argument('refpkg', type=Refpkg,
                        help="refpkg containing input taxa")
    parser.add_argument('classification_db', type=sqlite3.connect,
                        help="output sqlite database")
    parser.add_argument('classifications', type=argparse.FileType('r'), nargs='?',
                        default=sys.stdin, help="input query sequences")

    args = parser.parse_args()

    log.info('loading taxonomy')
    taxtable = TaxNode.from_taxtable(args.refpkg.open_resource('taxonomy', 'rU'))
    rank_order = {rank: e for e, rank in enumerate(taxtable.ranks)}
    def full_lineage(node):
        rank_iter = reversed(taxtable.ranks)
        for n in reversed(node.lineage()):
            n_order = rank_order[n.rank]
            yield n, list(itertools.takewhile(lambda r: rank_order[r] >= n_order, rank_iter))

    def multiclass_rows(placement_id, seq, taxa):
        ret = collections.defaultdict(float)
        likelihood = 1. / len(taxa)
        for taxon in taxa:
            for node, want_ranks in full_lineage(taxtable.get_node(taxon)):
                for want_rank in want_ranks:
                    ret[placement_id, seq, want_rank, node.rank, node.tax_id] += likelihood
        for k, v in sorted(ret.items()):
            yield k + (v,)

    curs = args.classification_db.cursor()
    curs.execute('INSERT INTO runs (params) VALUES (?)', (' '.join(sys.argv),))
    run_id = curs.lastrowid

    log.info('inserting classifications')
    for (name, mass), rows in group_by_name_and_mass(csv.DictReader(args.classifications)):
        curs.execute('INSERT INTO placements (classifier, run_id) VALUES ("csv", ?)', (run_id,))
        placement_id = curs.lastrowid
        curs.execute(
            'INSERT INTO placement_names (placement_id, name, origin, mass) VALUES (?, ?, ?, ?)',
            (placement_id, name, args.classifications.name, mass))
        taxa = [row['tax_id'] for row in rows]
        curs.executemany('INSERT INTO multiclass VALUES (?, ?, ?, ?, ?, ?)',
                         multiclass_rows(placement_id, name, taxa))

    log.info('cleaning up `multiclass` table')
    curs.execute("""
        CREATE TEMPORARY TABLE duplicates AS
          SELECT name
            FROM multiclass
                 JOIN placements USING (placement_id)
           GROUP BY name
          HAVING SUM(run_id = ?)
                 AND COUNT(DISTINCT run_id) > 1
    """, (run_id,))
    curs.execute("""
        DELETE FROM multiclass
         WHERE (SELECT run_id
                  FROM placements p
                 WHERE p.placement_id = multiclass.placement_id) <> ?
           AND name IN (SELECT name
                          FROM duplicates)
    """, (run_id,))

    args.classification_db.commit()
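group_by_name_and_mass is not shown; the loop unpacks (name, mass) keys alongside a row iterator, which suggests an itertools.groupby over sorted classifier rows. A sketch under the assumption that each CSV row carries name and mass columns (the column names are a guess):

import itertools

def group_by_name_and_mass(rows):
    # Sketch: group classification rows by (name, mass).
    def key(row):
        return row['name'], float(row.get('mass', 1.0))
    return itertools.groupby(sorted(rows, key=key), key=key)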
Example #21
def main():
    logging.basicConfig(level=logging.INFO,
                        format="%(levelname)s: %(message)s")

    parser = argparse.ArgumentParser(
        description="Add classifications to a database.")
    parser.add_argument('refpkg',
                        type=Refpkg,
                        help="refpkg containing input taxa")
    parser.add_argument('classification_db',
                        type=sqlite3.connect,
                        help="output sqlite database")
    parser.add_argument('classifications',
                        type=argparse.FileType('r'),
                        nargs='?',
                        default=sys.stdin,
                        help="input query sequences")

    args = parser.parse_args()

    log.info('loading taxonomy')
    taxtable = TaxNode.from_taxtable(
        args.refpkg.open_resource('taxonomy', 'rU'))
    rank_order = {rank: e for e, rank in enumerate(taxtable.ranks)}

    def full_lineage(node):
        rank_iter = reversed(taxtable.ranks)
        for n in reversed(node.lineage()):
            n_order = rank_order[n.rank]
            yield n, list(
                itertools.takewhile(lambda r: rank_order[r] >= n_order,
                                    rank_iter))

    def multiclass_rows(placement_id, seq, taxa):
        ret = collections.defaultdict(float)
        likelihood = 1. / len(taxa)
        for taxon in taxa:
            for node, want_ranks in full_lineage(taxtable.get_node(taxon)):
                for want_rank in want_ranks:
                    ret[placement_id, seq, want_rank, node.rank,
                        node.tax_id] += likelihood
        for k, v in sorted(ret.items()):
            yield k + (v, )

    curs = args.classification_db.cursor()
    curs.execute('INSERT INTO runs (params) VALUES (?)',
                 (' '.join(sys.argv), ))
    run_id = curs.lastrowid

    log.info('inserting classifications')
    for (name, mass), rows in group_by_name_and_mass(
            csv.DictReader(args.classifications)):
        curs.execute(
            'INSERT INTO placements (classifier, run_id) VALUES ("csv", ?)',
            (run_id, ))
        placement_id = curs.lastrowid
        curs.execute(
            'INSERT INTO placement_names (placement_id, name, origin, mass) VALUES (?, ?, ?, ?)',
            (placement_id, name, args.classifications.name, mass))
        taxa = [row['tax_id'] for row in rows]
        curs.executemany('INSERT INTO multiclass VALUES (?, ?, ?, ?, ?, ?)',
                         multiclass_rows(placement_id, name, taxa))

    log.info('cleaning up `multiclass` table')
    curs.execute(
        """
        CREATE TEMPORARY TABLE duplicates AS
          SELECT name
            FROM multiclass
                 JOIN placements USING (placement_id)
           GROUP BY name
          HAVING SUM(run_id = ?)
                 AND COUNT(DISTINCT run_id) > 1
    """, (run_id, ))
    curs.execute(
        """
        DELETE FROM multiclass
         WHERE (SELECT run_id
                  FROM placements p
                 WHERE p.placement_id = multiclass.placement_id) <> ?
           AND name IN (SELECT name
                          FROM duplicates)
    """, (run_id, ))

    args.classification_db.commit()