Example #1
0
    def check_identical_ranks(self):
        """Merge taxa that cannot be told apart because they share identical sequences.

        Every entry of ``self.dupseq_sets`` is a group of sequence names
        (``check_identical_seqs()`` is called first when the attribute is
        empty).  For each group with more than one member the members are
        counted per rank; a rank takes part in a merge when its count exceeds
        ``cfg.taxa_ident_thres`` times the total number of sequences of that
        rank.  If more than one rank qualifies, they are merged through
        ``taxonomy.merge_ranks()`` under a ``__TAXCLUSTER<n>__`` prefix.

        Returns:
            dict mapping each merged rank id to the list of rank ids that were
            merged into it (the same object is stored in ``self.merged_ranks``).
        """
        if not self.dupseq_sets:
            self.check_identical_seqs()
        self.merged_ranks = {}
        for dup_ids in self.dupseq_sets:
            if len(dup_ids) <= 1:
                continue

            # number of duplicated sequences contributed by each rank
            duprank_map = {}
            for seq_name in dup_ids:
                rank_id = self.taxonomy.seq_rank_id(seq_name)
                duprank_map[rank_id] = duprank_map.get(rank_id, 0) + 1
            if len(duprank_map) > 1 and self.cfg.debug:
                self.cfg.log.debug("Ranks sharing duplicates: %s\n", str(duprank_map))

            # items() instead of iteritems(): works on both Python 2 and Python 3
            dup_ranks = [
                rank_id for rank_id, count in duprank_map.items()
                if count > self.cfg.taxa_ident_thres * self.taxonomy.get_rank_seq_count(rank_id)
            ]
            if len(dup_ranks) > 1:
                # clusters are numbered consecutively, starting from 1
                prefix = "__TAXCLUSTER%d__" % (len(self.merged_ranks) + 1)
                merged_rank_id = self.taxonomy.merge_ranks(dup_ranks, prefix)
                self.merged_ranks[merged_rank_id] = dup_ranks

        if self.verbose:
            merged_count = 0
            for merged_rank_id, dup_ranks in self.merged_ranks.items():
                dup_ranks_str = "\n".join([Taxonomy.rank_uid_to_lineage_str(rank_id) for rank_id in dup_ranks])
                self.cfg.log.warning("\nWARNING: Following taxa share >%.0f%% identical sequences and are thus considered indistinguishable:\n%s", self.cfg.taxa_ident_thres * 100, dup_ranks_str)
                merged_rank_str = Taxonomy.rank_uid_to_lineage_str(merged_rank_id)
                self.cfg.log.warning("For the purpose of mislabels identification, they were merged into one taxon:\n%s\n", merged_rank_str)
                merged_count += len(dup_ranks)

            if merged_count > 0:
                self.cfg.log.warning("WARNING: %d indistinguishable taxa have been merged into %d clusters.\n", merged_count, len(self.merged_ranks))

        return self.merged_ranks
Example #2
0
    def get_parent_tip_ranks(self, tax_tree):
        """Collect the tips and the parent ranks of every eligible inner node.

        Walks ``tax_tree`` in postorder and keeps an inner, non-root node when
        * the lowest assigned rank level of its own name is at least 2,
        * the lowest assigned rank level of its parent's name is at least 1,
        * it has at least 2 tips and at most ``self.reftree_size - 4`` tips.

        Returns:
            tuple ``(rank_parent, rank_tips)`` of dicts keyed by the node name:
            ``rank_parent`` holds the parent's split ranks and ``rank_tips``
            the list of leaf names below the node.
        """
        rank_tips = {}
        rank_parent = {}
        for node in tax_tree.traverse("postorder"):
            if node.is_leaf() or node.is_root():
                continue
            tax_path = node.name
            ranks = Taxonomy.split_rank_uid(tax_path)
            if Taxonomy.lowest_assigned_rank_level(ranks) < 2:
                continue

            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
            if Taxonomy.lowest_assigned_rank_level(parent_ranks) < 1:
                continue

            # fetch the tip names once and reuse them for the size check and the result
            tip_names = node.get_leaf_names()
            if not 2 <= len(tip_names) <= self.reftree_size - 4:
                continue

            rank_tips[tax_path] = tip_names
            rank_parent[tax_path] = parent_ranks

        return rank_parent, rank_tips
Example #3
0
    def get_parent_tip_ranks(self, tax_tree):
        """Map eligible inner nodes of the taxonomic tree to their tips and parent ranks.

        ``tax_tree`` is traversed in postorder.  Leaves and the root are
        skipped, as are nodes whose own lowest assigned rank level is below 2,
        whose parent's lowest assigned rank level is below 1, or whose number
        of tips is below 2 or above ``self.reftree_size - 4``.

        Returns:
            ``(rank_parent, rank_tips)``: two dicts keyed by node name, holding
            the parent's split ranks and the node's leaf names respectively.
        """
        rank_tips = {}
        rank_parent = {}
        for node in tax_tree.traverse("postorder"):
            if node.is_leaf() or node.is_root():
                continue

            tax_path = node.name
            own_level = Taxonomy.lowest_assigned_rank_level(
                Taxonomy.split_rank_uid(tax_path))
            if own_level < 2:
                continue

            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
            if Taxonomy.lowest_assigned_rank_level(parent_ranks) < 1:
                continue

            # a single call is enough: the same list serves the size check
            # and the returned mapping
            leaf_names = node.get_leaf_names()
            leaf_count = len(leaf_names)
            if leaf_count < 2 or leaf_count > self.reftree_size - 4:
                continue

            rank_tips[tax_path] = leaf_names
            rank_parent[tax_path] = parent_ranks

        return rank_parent, rank_tips
Example #4
0
 def test_merge_ranks(self):
     # Merging the ranks of two sequences must put both of them into the new rank.
     seq_ids = ["UnpC[Ceti]", "UnpSomer,"]
     taxonomy = Taxonomy(tax_map=self.taxonomy.seq_ranks_map)

     # ranks the sequences belong to before the merge
     old_rank_ids = list(map(taxonomy.seq_rank_id, seq_ids))
     merged_id = taxonomy.merge_ranks(old_rank_ids)

     self.assertEqual(seq_ids, taxonomy.get_rank_seqs(merged_id))
     # afterwards every sequence has to resolve to the merged rank
     for seq_id in seq_ids:
         self.assertEqual(taxonomy.seq_rank_id(seq_id), merged_id)
Example #5
0
 def test_normalize_seq_ids(self):
     # (id before normalization, id expected afterwards)
     renamed = [("UnpC[Ceti]", "UnpC_Ceti_"), ("UnpSomer,", "UnpSomer_")]
     tax = Taxonomy(tax_map=self.taxonomy.seq_ranks_map)

     # the raw ids must be present to begin with
     for raw_id, _ in renamed:
         self.assertTrue(raw_id in tax.seq_ranks_map)

     tax.normalize_seq_ids()

     # each raw id is replaced by its normalized form
     for raw_id, clean_id in renamed:
         self.assertFalse(raw_id in tax.seq_ranks_map)
         self.assertTrue(clean_id in tax.seq_ranks_map)
Example #6
0
 def test_normalize_seq_ids(self):
     # sequence ids containing special characters and their normalized counterparts
     raw_ids = ["UnpC[Ceti]", "UnpSomer,"]
     clean_ids = ["UnpC_Ceti_", "UnpSomer_"]
     tax = Taxonomy(tax_map=self.taxonomy.seq_ranks_map)

     for seq_id in raw_ids:
         self.assertTrue(seq_id in tax.seq_ranks_map)

     tax.normalize_seq_ids()

     # after normalization only the cleaned-up ids may remain
     for old_id, new_id in zip(raw_ids, clean_ids):
         self.assertFalse(old_id in tax.seq_ranks_map)
         self.assertTrue(new_id in tax.seq_ranks_map)
Example #7
0
    def load_refjson(self, refjson_fname):
        """Load the reference JSON file and initialise the state derived from it.

        Parses ``refjson_fname`` with ``RefJsonParser``, validates the result
        and then fills the instance attributes read from it (rate, node
        height, origin taxonomy, taxonomic tree, branch-to-taxonomy map,
        reference tree and its size) as well as the helper objects built on
        top of them.  ``self.cfg`` is updated too: pattern compression, the
        auto settings resolved from the tree size and, when
        ``cfg.epa_load_optmod`` is set, the RAxML model.

        refjson_fname -- name of the reference JSON file handed to RefJsonParser
        """
        try:
            self.refjson = RefJsonParser(refjson_fname)
        except ValueError:
            # NOTE(review): exit_user_error() presumably terminates the program;
            # otherwise self.refjson would be unset in the code below - confirm
            self.cfg.exit_user_error("ERROR: Invalid json file format!")

        # validate input json format
        (valid, err) = self.refjson.validate()
        if not valid:
            self.cfg.log.error(
                "ERROR: Parsing reference JSON file failed:\n%s", err)
            self.cfg.exit_user_error()

        # values stored in the reference file
        self.rate = self.refjson.get_rate()
        self.node_height = self.refjson.get_node_height()
        self.origin_taxonomy = self.refjson.get_origin_taxonomy()
        self.tax_tree = self.refjson.get_tax_tree()
        self.cfg.compress_patterns = self.refjson.get_pattern_compression()

        self.bid_taxonomy_map = self.refjson.get_branch_tax_map()
        if not self.bid_taxonomy_map:
            # old file format (before 1.6), need to rebuild this map from scratch
            th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
            th.set_mf_rooted_tree(self.tax_tree)
            th.set_bf_unrooted_tree(self.refjson.get_reftree())
            self.bid_taxonomy_map = th.get_bid_taxonomy_map()

        self.write_bid_tax_map(self.bid_taxonomy_map, final=False)

        # the tree size (number of leaves) drives the auto settings below
        reftree_str = self.refjson.get_raxml_readable_tree()
        self.reftree = Tree(reftree_str)
        self.reftree_size = len(self.reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same rate het. mode as in the ref file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        # helpers working on the data loaded above
        self.classify_helper = TaxClassifyHelper(self.cfg,
                                                 self.bid_taxonomy_map,
                                                 self.rate, self.node_height)
        self.taxtree_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy,
                                            self.tax_tree)

        tax_code_name = self.refjson.get_taxcode()
        self.tax_code = TaxCode(tax_code_name)

        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX,
                                 tax_map=self.origin_taxonomy)
        self.tax_common_ranks = self.taxonomy.get_common_ranks()
        #        print "Common ranks: ", self.tax_common_ranks

        # one counter per taxonomic level, all starting at zero
        self.mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
        self.rank_mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
Example #8
0
 def test_taxtree_builder(self):
     # Build a tree from the test taxonomy and compare it with the stored reference tree.
     cfg = EpacConfig()
     testfile_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testfiles")
     tax_fname = os.path.join(testfile_dir, "test.tax")
     tax = Taxonomy(EpacConfig.REF_SEQ_PREFIX, tax_fname)
     tree_fname = os.path.join(testfile_dir, "taxtree.nw")
     expected_tree = Tree(tree_fname, format=8)
     tb = TaxTreeBuilder(cfg, tax)
     tax_tree, seq_ids = tb.build()
     # keys() is a view on Python 3 and never compares equal to a list,
     # so materialize it; on Python 2 this is just a copy of the key list
     self.assertEqual(seq_ids, list(tax.get_map().keys()))
     self.assertEqual(tax_tree.write(format=8), expected_tree.write(format=8))
Example #9
0
 def test_taxtree_builder(self):
     # The builder must return every sequence id of the taxonomy together with
     # a tree that serializes exactly like the stored reference tree.
     cfg = EpacConfig()
     testfile_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testfiles")
     tax_fname = os.path.join(testfile_dir, "test.tax")
     tax = Taxonomy(EpacConfig.REF_SEQ_PREFIX, tax_fname)
     tree_fname = os.path.join(testfile_dir, "taxtree.nw")
     expected_tree = Tree(tree_fname, format=8)
     tb = TaxTreeBuilder(cfg, tax)
     tax_tree, seq_ids = tb.build()
     # list(...) keeps the comparison meaningful on Python 3, where keys()
     # returns a view object instead of a list
     self.assertEqual(seq_ids, list(tax.get_map().keys()))
     self.assertEqual(tax_tree.write(format=8), expected_tree.write(format=8))
Example #10
0
    def setUp(self):
        # test fixtures live in the "testfiles" directory next to this module
        module_dir = os.path.dirname(os.path.abspath(__file__))
        self.testfile_dir = os.path.join(module_dir, "testfiles")

        self.tax_fname = os.path.join(self.testfile_dir, "test_clean.tax")
        self.taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX, self.tax_fname)
        seq_ranks = self.taxonomy.get_map()
        self.taxtree_helper = TaxTreeHelper(EpacConfig(), seq_ranks)

        # outgroup tree the tests compare against
        self.expected_outgr = Tree(
            os.path.join(self.testfile_dir, "outgroup.nw"))
 def mis_rec_to_string(self, mis_rec):
     """Format a mislabel record as one tab-separated line.

     Columns: sequence name, level name, original and new rank name at the
     mislabeled level, confidence at that level (3 decimals), original
     lineage, new lineage, all confidences joined by ';' and, when the
     record has a 'rank_conf' entry, that value as a final column.

     Records with a negative 'orig_level' have no mislabeled level to index
     (their 'ranks' list may be empty); for them both rank columns are "NA"
     and the first confidence value is reported.
     """
     lvl = mis_rec['orig_level']
     if lvl >= 0:
         orig_name = mis_rec['orig_ranks'][lvl]
         new_name = mis_rec['ranks'][lvl]
         lvl_conf = mis_rec['lws'][lvl]
     else:
         # indexing with a negative level would pick the wrong rank, or raise
         # IndexError when 'ranks' is empty
         orig_name = "NA"
         new_name = "NA"
         lvl_conf = mis_rec['lws'][0]

     output = mis_rec['name'] + "\t"
     output += "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'],
         orig_name, new_name, lvl_conf)
     output += Taxonomy.lineage_str(mis_rec['orig_ranks']) + "\t"
     output += Taxonomy.lineage_str(mis_rec['ranks']) + "\t"
     output += ";".join(["%.3f" % conf for conf in mis_rec['lws']])
     if 'rank_conf' in mis_rec:
         output += "\t%.3f" % mis_rec['rank_conf']
     return output
Example #12
0
 def test_normalize_rank_names(self):
     seq_id = "UpbRectu"
     # first three rank names of the sequence before and after normalization
     raw_names = ["[Bacteria]", "'Firmicutes'", "Clostridia(1)"]
     clean_names = ["_Bacteria_", "_Firmicutes_", "Clostridia_1_"]
     tax = Taxonomy(tax_map=self.taxonomy.seq_ranks_map)

     before = tax.get_seq_ranks(seq_id)
     for lvl, name in enumerate(raw_names):
         self.assertEqual(before[lvl], name)

     corrected = tax.normalize_rank_names()
     self.assertEqual(len(corrected), 3)

     after = tax.get_seq_ranks(seq_id)
     for lvl, name in enumerate(clean_names):
         self.assertEqual(after[lvl], name)
Example #13
0
 def test_normalize_rank_names(self):
     # (rank level, name before normalization, name afterwards)
     cases = [
         (0, "[Bacteria]", "_Bacteria_"),
         (1, "'Firmicutes'", "_Firmicutes_"),
         (2, "Clostridia(1)", "Clostridia_1_"),
     ]
     tax = Taxonomy(tax_map=self.taxonomy.seq_ranks_map)

     original = tax.get_seq_ranks("UpbRectu")
     for level, raw_name, _ in cases:
         self.assertEqual(original[level], raw_name)

     # the call reports the corrected rank names; three are expected here
     fixed = tax.normalize_rank_names()
     self.assertEqual(len(fixed), 3)

     normalized = tax.get_seq_ranks("UpbRectu")
     for level, _, clean_name in cases:
         self.assertEqual(normalized[level], clean_name)
Example #14
0
    def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
        """Compare the original ranks of a sequence with its newly assigned ranks.

        * If ``ranks`` is empty, a "[NotIngroup]" record is created with
          ``orig_level`` -1 and a confidence of 1.0.
        * Otherwise the first level is searched at which the new rank is not
          ``Taxonomy.EMPTY_RANK`` and differs from the original rank; if one
          exists, a record for that level is created with the confidence
          ``lws[level]``.

        A created record is appended to ``self.mislabels``.

        Returns:
            the mislabel record (dict), or None when no mismatch was found.
        """
        mis_rec = None

        if len(ranks) == 0:
            real_lvl = 0
            mis_rec = {
                'name': seq_name,
                'orig_level': -1,
                'real_level': real_lvl,
                'level_name': "[NotIngroup]",
                'inv_level': -1 * real_lvl,  # just for sorting
                'orig_ranks': orig_ranks,
                'ranks': [],
                'lws': [1.0],
                'conf': 1.0,
            }
        else:
            # first level where a non-empty new rank disagrees with the original
            mislabel_lvl = -1
            for rank_lvl in range(min(len(orig_ranks), len(ranks))):
                new_rank = ranks[rank_lvl]
                if new_rank != Taxonomy.EMPTY_RANK and new_rank != orig_ranks[rank_lvl]:
                    mislabel_lvl = rank_lvl
                    break

            if mislabel_lvl >= 0:
                real_lvl = self.tax_code.guess_rank_level(
                    orig_ranks, mislabel_lvl)
                mis_rec = {
                    'name': seq_name,
                    'orig_level': mislabel_lvl,
                    'real_level': real_lvl,
                    'level_name': self.tax_code.rank_level_name(real_lvl)[0],
                    'inv_level': -1 * real_lvl,  # just for sorting
                    'orig_ranks': orig_ranks,
                    'ranks': ranks,
                    'lws': lws,
                    'conf': lws[mislabel_lvl],
                }

        if mis_rec:
            self.mislabels.append(mis_rec)

        return mis_rec
Example #15
0
 def build_ref_tree(self):
     """Run the whole reference-building pipeline, logging every stage.

     Stages, in order: load the taxonomy and the alignment, validate them,
     build the multifurcating tree, export the reference alignment and
     taxonomy, save the rooting, resolve multifurcations, load the reduced
     alignment, label branches with EPA, post-process the EPA tree,
     compute node heights and finally write the reference JSON file.
     """
     self.cfg.log.info("=> Loading taxonomy from file: %s ...\n" , self.cfg.taxonomy_fname)
     self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_fname=self.cfg.taxonomy_fname)
     self.cfg.log.info("==> Loading reference alignment from file: %s ...\n" , self.cfg.align_fname)
     self.load_alignment()
     self.cfg.log.info("===> Validating taxonomy and alignment ...\n")
     self.validate_taxonomy()
     self.cfg.log.info("====> Building a multifurcating tree from taxonomy with %d seqs ...\n" , self.taxonomy.seq_count())
     self.build_multif_tree()
     self.cfg.log.info("=====> Building the reference alignment ...\n")
     self.export_ref_alignment()
     self.export_ref_taxonomy()
     self.cfg.log.info("======> Saving the outgroup for later re-rooting ...\n")
     self.save_rooting()
     # values are passed as logger arguments, like in the other calls here,
     # instead of being formatted eagerly with %
     self.cfg.log.info("=======> Resolving multifurcation: choosing the best topology from %d independent RAxML runs ...\n", self.cfg.rep_num)
     self.resolve_multif()
     self.load_reduced_refalign()
     self.cfg.log.info("========> Calling RAxML-EPA to obtain branch labels ...\n")
     self.epa_branch_labeling()
     self.cfg.log.info("=========> Post-processing the EPA tree (re-rooting, taxonomic labeling etc.) ...\n")
     self.epa_post_process()
     self.calc_node_heights()

     self.cfg.log.debug("\n==========> Checking branch labels ...")
     self.cfg.log.debug("shared rank names before training: %s", repr(self.taxonomy.get_common_ranks()))
     self.cfg.log.debug("shared rank names after  training: %s\n", repr(self.mono_index()))

     self.cfg.log.info("==========> Saving the reference JSON file: %s\n", self.cfg.refjson_fname)
     self.write_json()
    def run_leave_seq_out_test(self):
        """Check the taxonomic label of every sequence placed in "l1o_seq" mode.

        Placements are read from ``self.jplace_fname`` when it is set,
        otherwise they are produced by ``self.raxml.run_epa()``.  For each
        placement the original ranks are compared with the classified ranks
        via ``check_seq_tax_labels()``; if ``self.ranktest`` is on and a
        mislabel record was returned, the record gets a 'rank_conf' entry: the
        largest value in ``self.misrank_conf_map`` among the rank uids of
        levels 2 and deeper of the original ranks (0 if none is listed).

        Returns the number of placements processed.
        """
        job_name = self.cfg.subst_name("l1out_seq_%NAME%")
        if not self.jplace_fname:
            jp = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname,
                                    self.optmod_fname, mode="l1o_seq")
        else:
            jp = EpaJsonParser(self.jplace_fname)

        processed = 0
        for processed, place in enumerate(jp.get_placement(), 1):
            seq_name = place["n"][0]

            # original label vs. label assigned by EPA
            orig_ranks = self.get_orig_ranks(seq_name)
            ranks, lws = self.classify_seq(place)
            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

            if not (self.ranktest and mis_rec):
                continue

            # cross-check with higher rank mislabels
            rank_uids = [Taxonomy.get_rank_uid(orig_ranks, lvl)
                         for lvl in range(2, len(orig_ranks))]
            known_confs = [self.misrank_conf_map[uid] for uid in rank_uids
                           if uid in self.misrank_conf_map]
            mis_rec['rank_conf'] = max([0] + known_confs)

        return processed
    def label_bf_tree_with_ranks(self):
        """Label every node of the rooted bifurcating tree with taxonomic ranks.

        The tree is walked in postorder.  A leaf takes its ranks from
        ``self.origin_taxonomy`` and gets "__" plus the name of its lowest
        assigned rank appended to its node name.  For an inner node the level
        is the deepest one, not below either child's level, at which both
        children carry the same rank name; the ranks up to that level are
        copied from the first child and the node is named after that rank.
        If the children share no rank at all the node is named "Undefined".

        Every node receives the features ``rank_level`` and ``ranks``, and
        ``self.tax_tree`` is set to the labeled tree.

        Raises:
            AssertionError: if ``self.bf_rooted_tree`` is not set, or if an
            inner node does not have exactly two children.
        """
        if not self.bf_rooted_tree:
            raise AssertionError(
                "self.bf_rooted_tree is not set: TaxTreeHelper.set_bf_unrooted_tree() must be called before!"
            )

        for node in self.bf_rooted_tree.traverse("postorder"):
            if node.is_leaf():
                seq_ranks = self.origin_taxonomy[node.name]
                rank_level = Taxonomy.lowest_assigned_rank_level(seq_ranks)
                node.add_feature("rank_level", rank_level)
                node.add_feature("ranks", seq_ranks)
                node.name += "__" + seq_ranks[rank_level]
            else:
                if len(node.children) != 2:
                    raise AssertionError("FATAL ERROR: tree is not bifurcating!")
                lchild = node.children[0]
                rchild = node.children[1]
                # start at the shallower child's level and move up until both agree
                rank_level = min(lchild.rank_level, rchild.rank_level)
                while rank_level >= 0 and lchild.ranks[rank_level] != rchild.ranks[rank_level]:
                    rank_level -= 1
                node.add_feature("rank_level", rank_level)
                # NOTE(review): 7 is presumably the fixed number of rank levels - confirm
                node_ranks = [Taxonomy.EMPTY_RANK] * 7
                if rank_level >= 0:
                    node_ranks[0 : rank_level + 1] = lchild.ranks[0 : rank_level + 1]
                    node.name = lchild.ranks[rank_level]
                else:
                    node.name = "Undefined"
                    if hasattr(node, "B") and self.cfg.verbose:
                        # single-argument call form prints the same text on
                        # Python 2 and is also valid Python 3
                        print("INFO: no taxonomic annotation for branch %s (reason: children belong to different kingdoms)" % node.B)

                node.add_feature("ranks", node_ranks)

        self.tax_tree = self.bf_rooted_tree
Example #18
0
    def check_identical_ranks(self):
        """Detect and merge taxa that share too many identical sequences.

        ``self.dupseq_sets`` holds groups of sequence names
        (``check_identical_seqs()`` is called first when it is empty).  Within
        each group of more than one sequence, the sequences are counted per
        rank.  A rank qualifies for merging when its count is greater than
        ``cfg.taxa_ident_thres`` times the number of sequences of that rank;
        if more than one rank qualifies they are merged with
        ``taxonomy.merge_ranks()`` using a ``__TAXCLUSTER<n>__`` prefix.

        With ``self.verbose`` set, every merge is reported as a warning.

        Returns:
            dict: merged rank id -> list of merged rank ids; the same dict is
            kept in ``self.merged_ranks``.
        """
        if not self.dupseq_sets:
            self.check_identical_seqs()
        self.merged_ranks = {}
        for dup_ids in self.dupseq_sets:
            if len(dup_ids) <= 1:
                continue

            # how many of the duplicates belong to each rank
            duprank_map = {}
            for seq_name in dup_ids:
                rank_id = self.taxonomy.seq_rank_id(seq_name)
                duprank_map[rank_id] = duprank_map.get(rank_id, 0) + 1
            if len(duprank_map) > 1 and self.cfg.debug:
                self.cfg.log.debug("Ranks sharing duplicates: %s\n",
                                   str(duprank_map))

            # items() rather than iteritems(): valid on Python 2 and Python 3
            dup_ranks = [
                rank_id for rank_id, count in duprank_map.items()
                if count > self.cfg.taxa_ident_thres *
                self.taxonomy.get_rank_seq_count(rank_id)
            ]
            if len(dup_ranks) > 1:
                # consecutive cluster numbering, starting from 1
                prefix = "__TAXCLUSTER%d__" % (len(self.merged_ranks) + 1)
                merged_rank_id = self.taxonomy.merge_ranks(
                    dup_ranks, prefix)
                self.merged_ranks[merged_rank_id] = dup_ranks

        if self.verbose:
            merged_count = 0
            for merged_rank_id, dup_ranks in self.merged_ranks.items():
                dup_ranks_str = "\n".join([
                    Taxonomy.rank_uid_to_lineage_str(rank_id)
                    for rank_id in dup_ranks
                ])
                self.cfg.log.warning(
                    "\nWARNING: Following taxa share >%.0f%% identical sequences and are thus considered indistinguishable:\n%s",
                    self.cfg.taxa_ident_thres * 100, dup_ranks_str)
                merged_rank_str = Taxonomy.rank_uid_to_lineage_str(
                    merged_rank_id)
                self.cfg.log.warning(
                    "For the purpose of mislabels identification, they were merged into one taxon:\n%s\n",
                    merged_rank_str)
                merged_count += len(dup_ranks)

            if merged_count > 0:
                self.cfg.log.warning(
                    "WARNING: %d indistinguishable taxa have been merged into %d clusters.\n",
                    merged_count, len(self.merged_ranks))

        return self.merged_ranks
Example #19
0
    def run_leave_subtree_out_test(self):
        """Run EPA in "l1o_subtree" mode and check the label of every subtree.

        The subtrees come from ``get_parent_tip_ranks()``; their tip names are
        written one subtree per line to a temporary list file that is handed
        to ``self.raxml.run_epa()``.  Each placement is classified and compared
        with the parent ranks of the corresponding subtree through
        ``check_rank_tax_labels()``; the confidence of a reported mislabel is
        stored in ``self.misrank_conf_map`` under the subtree's rank uid.

        Returns the number of placements processed (0 if there are no subtrees).
        """
        job_name = self.cfg.subst_name("l1out_rank_%NAME%")

        rank_parent, rank_tips = self.get_parent_tip_ranks(self.tax_tree)
        subtree_list = list(rank_tips.items())
        if len(subtree_list) == 0:
            return 0

        # one line of space-separated tip names per subtree
        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
        with open(subtree_list_file, "w") as fout:
            for _, tips in subtree_list:
                fout.write("%s\n" % " ".join(tips))

        jp_list = self.raxml.run_epa(
            job_name, self.refalign_fname, self.reftree_fname,
            self.optmod_fname, mode="l1o_subtree",
            subtree_fname=subtree_list_file)

        # placements are matched to subtrees by their running position
        all_places = (place for jp in jp_list for place in jp.get_placement())
        handled = 0
        for handled, place in enumerate(all_places, 1):
            ranks, lws = self.classify_seq(place)
            tax_path = subtree_list[handled - 1][0]
            orig_ranks = Taxonomy.split_rank_uid(tax_path)
            rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
            rank_prefix = self.tax_code.guess_rank_level_name(
                orig_ranks, rank_level)[0]
            rank_name = orig_ranks[rank_level]
            # prepend the level prefix unless the name already starts with it
            if not rank_name.startswith(rank_prefix):
                rank_name = rank_prefix + rank_name
            mis_rec = self.check_rank_tax_labels(
                rank_name, rank_parent[tax_path], ranks, lws)
            if mis_rec:
                self.misrank_conf_map[tax_path] = mis_rec['conf']

        return handled
 def get_orig_ranks(self, seq_name):
     """Return the original ranks of a sequence, taken from the taxonomic tree.

     The ranks are obtained by splitting the name of the parent node of the
     leaf called ``seq_name`` in ``self.tax_tree``.

     Raises:
         SystemExit: if the tree does not contain exactly one leaf with
         that name.  The message is written to stderr and the exit status
         is non-zero, so the failure is visible to calling scripts.
     """
     nodes = self.tax_tree.get_leaves_by_name(seq_name)
     if len(nodes) != 1:
         sys.exit("FATAL ERROR: Sequence %s is not found in the taxonomic tree, or is present more than once!" % seq_name)
     seq_node = nodes[0]
     orig_ranks = Taxonomy.split_rank_uid(seq_node.up.name)
     return orig_ranks
Example #21
0
 def setUp(self):
     # Write self.TAX_DICT to "test.tax" next to this module (one
     # "<seq id>\t<rank>;<rank>;..." line per sequence) and load it back
     # as a Taxonomy.
     test_dir = os.path.dirname(os.path.abspath(__file__))
     self.tax_fname = os.path.join(test_dir, "test.tax")
     self.PREFIXED_TAX_DICT = {}
     with open(self.tax_fname, "w") as outf:
         # items() instead of iteritems(): valid on Python 2 and Python 3
         for sid, ranks in self.TAX_DICT.items():
             outf.write("%s\t%s\n" % (sid, ";".join(ranks)))
             # same ranks, keyed by the sequence id with the reference prefix
             self.PREFIXED_TAX_DICT[EpacConfig.REF_SEQ_PREFIX + sid] = ranks
     self.taxonomy = Taxonomy("", self.tax_fname)
Example #22
0
    def build_ref_tree(self):
        """Build the reference tree step by step and save it as a JSON file.

        The steps run in a fixed order, each announced in the log: load the
        taxonomy and the alignment, validate them, build the multifurcating
        tree, export the reference alignment and taxonomy, save the rooting,
        resolve multifurcations, load the reduced alignment, label branches
        with EPA, post-process the EPA tree, compute node heights and write
        the reference JSON file.
        """
        self.cfg.log.info("=> Loading taxonomy from file: %s ...\n",
                          self.cfg.taxonomy_fname)
        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX,
                                 tax_fname=self.cfg.taxonomy_fname)
        self.cfg.log.info(
            "==> Loading reference alignment from file: %s ...\n",
            self.cfg.align_fname)
        self.load_alignment()
        self.cfg.log.info("===> Validating taxonomy and alignment ...\n")
        self.validate_taxonomy()
        self.cfg.log.info(
            "====> Building a multifurcating tree from taxonomy with %d seqs ...\n",
            self.taxonomy.seq_count())
        self.build_multif_tree()
        self.cfg.log.info("=====> Building the reference alignment ...\n")
        self.export_ref_alignment()
        self.export_ref_taxonomy()
        self.cfg.log.info(
            "======> Saving the outgroup for later re-rooting ...\n")
        self.save_rooting()
        # the value goes in as a logger argument, consistent with the other
        # calls in this method, rather than being %-formatted up front
        self.cfg.log.info(
            "=======> Resolving multifurcation: choosing the best topology from %d independent RAxML runs ...\n",
            self.cfg.rep_num)
        self.resolve_multif()
        self.load_reduced_refalign()
        self.cfg.log.info(
            "========> Calling RAxML-EPA to obtain branch labels ...\n")
        self.epa_branch_labeling()
        self.cfg.log.info(
            "=========> Post-processing the EPA tree (re-rooting, taxonomic labeling etc.) ...\n"
        )
        self.epa_post_process()
        self.calc_node_heights()

        self.cfg.log.debug("\n==========> Checking branch labels ...")
        self.cfg.log.debug("shared rank names before training: %s",
                           repr(self.taxonomy.get_common_ranks()))
        self.cfg.log.debug("shared rank names after  training: %s\n",
                           repr(self.mono_index()))

        self.cfg.log.info("==========> Saving the reference JSON file: %s\n",
                          self.cfg.refjson_fname)
        self.write_json()
Example #23
0
    def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
        """Build a mislabel record if the new ranks contradict the original ones.

        With an empty ``ranks`` list the sequence is reported as
        "[NotIngroup]" (``orig_level`` -1, confidence 1.0).  Otherwise the
        ranks are compared level by level, up to the shorter of the two
        lists; the first level whose new rank is not ``Taxonomy.EMPTY_RANK``
        and differs from the original rank is reported, with ``lws[level]``
        as confidence.

        The record, if any, is appended to ``self.mislabels``.

        Returns:
            dict describing the mislabel, or None if the labels agree.
        """
        mis_rec = None

        if len(ranks) == 0:
            real_lvl = 0
            mis_rec = {
                'name': seq_name,
                'orig_level': -1,
                'real_level': real_lvl,
                'level_name': "[NotIngroup]",
                'inv_level': -1 * real_lvl,  # just for sorting
                'orig_ranks': orig_ranks,
                'ranks': [],
                'lws': [1.0],
                'conf': 1.0,
            }
        else:
            # look for the first level with a conflicting, non-empty new rank
            mislabel_lvl = -1
            min_len = min(len(orig_ranks), len(ranks))
            for rank_lvl in range(min_len):
                if ranks[rank_lvl] != Taxonomy.EMPTY_RANK and ranks[rank_lvl] != orig_ranks[rank_lvl]:
                    mislabel_lvl = rank_lvl
                    break

            if mislabel_lvl >= 0:
                real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
                mis_rec = {
                    'name': seq_name,
                    'orig_level': mislabel_lvl,
                    'real_level': real_lvl,
                    'level_name': self.tax_code.rank_level_name(real_lvl)[0],
                    'inv_level': -1 * real_lvl,  # just for sorting
                    'orig_ranks': orig_ranks,
                    'ranks': ranks,
                    'lws': lws,
                    'conf': lws[mislabel_lvl],
                }

        if mis_rec:
            self.mislabels.append(mis_rec)

        return mis_rec
Example #24
0
    def setUp(self):
        # all fixture files are expected in "testfiles" beside this module
        this_dir = os.path.dirname(os.path.abspath(__file__))
        self.testfile_dir = os.path.join(this_dir, "testfiles")

        # taxonomy under test and the helper built from its map
        self.tax_fname = os.path.join(self.testfile_dir, "test_clean.tax")
        self.taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX, self.tax_fname)
        tax_map = self.taxonomy.get_map()
        self.taxtree_helper = TaxTreeHelper(EpacConfig(), tax_map)

        self.expected_outgr = Tree(os.path.join(self.testfile_dir, "outgroup.nw"))
Example #25
0
 def mis_rec_to_string(self, mis_rec):
     """Render a mislabel record as a single tab-separated line.

     Sequence id and ranks are first mapped back through ``self.refjson``
     (``get_uncorr_seqid`` / ``get_uncorr_ranks``).  Columns: sequence id,
     level name, original rank, new rank, confidence at the mislabeled level,
     original lineage, new lineage, all confidences joined by ';' and,
     if present, the record's 'rank_conf'.  For a negative 'orig_level'
     both rank columns are "NA" and the first confidence is used.
     """
     lvl = mis_rec['orig_level']
     seq_id = EpacConfig.strip_ref_prefix(self.refjson.get_uncorr_seqid(mis_rec['name']))
     old_ranks = self.refjson.get_uncorr_ranks(mis_rec['orig_ranks'])
     new_ranks = self.refjson.get_uncorr_ranks(mis_rec['ranks'])

     if lvl < 0:
         # no mislabeled level to index into
         old_name, new_name, lvl_conf = "NA", "NA", mis_rec['lws'][0]
     else:
         old_name, new_name, lvl_conf = old_ranks[lvl], new_ranks[lvl], mis_rec['lws'][lvl]

     columns = [
         seq_id,
         "%s\t%s\t%s\t%.3f" % (mis_rec['level_name'], old_name, new_name, lvl_conf),
         Taxonomy.lineage_str(old_ranks),
         Taxonomy.lineage_str(new_ranks),
         ";".join(["%.3f" % conf for conf in mis_rec['lws']]),
     ]
     if 'rank_conf' in mis_rec:
         columns.append("%.3f" % mis_rec['rank_conf'])
     return "\t".join(columns)
Example #26
0
    def run_leave_subtree_out_test(self):
        """Run the leave-one-subtree-out test and record rank-level mislabels.

        The tips of every subtree returned by ``get_parent_tip_ranks`` are
        written, one subtree per line, to a temporary list file which is
        passed to ``self.raxml.run_epa()`` in "l1o_subtree" mode.  Each
        placement is classified and checked against the parent ranks of the
        corresponding subtree with ``check_rank_tax_labels()``; for a reported
        mislabel the confidence is stored in ``self.misrank_conf_map``.

        Returns:
            number of placements processed (0 when there are no subtrees).
        """
        job_name = self.cfg.subst_name("l1out_rank_%NAME%")

        #create file with subtrees
        # NOTE(review): called as a plain function here, not as self.get_parent_tip_ranks -
        # confirm that a module-level get_parent_tip_ranks(tax_tree) exists
        rank_parent, rank_tips = get_parent_tip_ranks(self.tax_tree)

        # materialize the items: the list is indexed below, which a
        # Python 3 dict view does not support
        subtree_list = list(rank_tips.items())
        if len(subtree_list) == 0:
            return 0

        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
        with open(subtree_list_file, "w") as fout:
            for rank_name, tips in subtree_list:
                fout.write("%s\n" % " ".join(tips))

        jp_list = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname,
            mode="l1o_subtree", subtree_fname=subtree_list_file)

        # placements are matched to subtrees by their running position
        subtree_count = 0
        for jp in jp_list:
            placements = jp.get_placement()
            for place in placements:
                ranks, lws = self.classify_seq(place)
                tax_path = subtree_list[subtree_count][0]
                orig_ranks = Taxonomy.split_rank_uid(tax_path)
                rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
                # NOTE(review): guess_rank_level_name is looked up on self here - confirm
                # this class really provides it
                rank_prefix = self.guess_rank_level_name(orig_ranks, rank_level)[0]
                rank_name = orig_ranks[rank_level]
                # prepend the level prefix unless the name already carries it
                if not rank_name.startswith(rank_prefix):
                    rank_name = rank_prefix + rank_name
                parent_ranks = rank_parent[tax_path]
                mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks, ranks, lws)
                if mis_rec:
                    self.misrank_conf_map[tax_path] = mis_rec['conf']
                subtree_count += 1

        return subtree_count
Example #27
0
 def test_merge_ranks(self):
     # two sequences whose ranks get merged into a single new rank
     first_sid, second_sid = "UnpC[Ceti]", "UnpSomer,"
     merge_sids = [first_sid, second_sid]
     tax = Taxonomy(tax_map=self.taxonomy.seq_ranks_map)

     ranks_to_merge = []
     for sid in merge_sids:
         ranks_to_merge.append(tax.seq_rank_id(sid))
     merged_rank = tax.merge_ranks(ranks_to_merge)

     # the new rank holds exactly these sequences ...
     self.assertEqual(merge_sids, tax.get_rank_seqs(merged_rank))
     # ... and each of them now maps to it
     for sid in merge_sids:
         self.assertEqual(tax.seq_rank_id(sid), merged_rank)
Example #28
0
    def mis_rec_to_string(self, mis_rec):
        """Turn a mislabel record into one tab-separated output line.

        The sequence id and both rank lists are passed through
        ``self.refjson`` (``get_uncorr_seqid`` / ``get_uncorr_ranks``) before
        printing.  The line contains: sequence id, level name, original rank,
        new rank, confidence at the mislabeled level, original lineage, new
        lineage, the ';'-joined confidences and, when the record has one, its
        'rank_conf'.  If 'orig_level' is negative, "NA" is printed for both
        ranks together with the first confidence value.
        """
        lvl = mis_rec['orig_level']
        uncorr_name = EpacConfig.strip_ref_prefix(
            self.refjson.get_uncorr_seqid(mis_rec['name']))
        uncorr_orig_ranks = self.refjson.get_uncorr_ranks(
            mis_rec['orig_ranks'])
        uncorr_ranks = self.refjson.get_uncorr_ranks(mis_rec['ranks'])

        # the three level-dependent columns
        if lvl >= 0:
            level_cols = (uncorr_orig_ranks[lvl], uncorr_ranks[lvl],
                          mis_rec['lws'][lvl])
        else:
            level_cols = ("NA", "NA", mis_rec['lws'][0])

        fields = [
            uncorr_name,
            "%s\t%s\t%s\t%.3f" % ((mis_rec['level_name'],) + level_cols),
            Taxonomy.lineage_str(uncorr_orig_ranks),
            Taxonomy.lineage_str(uncorr_ranks),
            ";".join(["%.3f" % conf for conf in mis_rec['lws']]),
        ]
        if 'rank_conf' in mis_rec:
            fields.append("%.3f" % mis_rec['rank_conf'])
        return "\t".join(fields)
Example #29
0
    def load_refjson(self, refjson_fname):
        """Load the reference JSON file and initialise everything derived from it.

        Parses and validates the file, copies its contents (rate, node heights,
        taxonomy, taxonomic tree, branch-to-taxonomy map, reference tree) into
        instance attributes, adjusts the config accordingly and creates the
        helper objects and mislabel counters.

        Reports a user error via self.cfg if the file is not valid JSON or
        fails validation.
        """
        try:
            self.refjson = RefJsonParser(refjson_fname)
        except ValueError:
            self.cfg.exit_user_error("ERROR: Invalid json file format!")

        # check the structure of the parsed file before reading anything from it
        is_valid, err_msg = self.refjson.validate()
        if not is_valid:
            self.cfg.log.error("ERROR: Parsing reference JSON file failed:\n%s", err_msg)
            self.cfg.exit_user_error()

        self.rate = self.refjson.get_rate()
        self.node_height = self.refjson.get_node_height()
        self.origin_taxonomy = self.refjson.get_origin_taxonomy()
        self.tax_tree = self.refjson.get_tax_tree()
        self.cfg.compress_patterns = self.refjson.get_pattern_compression()

        self.bid_taxonomy_map = self.refjson.get_branch_tax_map()
        if not self.bid_taxonomy_map:
            # old file format (before 1.6) has no stored map: rebuild it from the trees
            rebuild_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy)
            rebuild_helper.set_mf_rooted_tree(self.tax_tree)
            rebuild_helper.set_bf_unrooted_tree(self.refjson.get_reftree())
            self.bid_taxonomy_map = rebuild_helper.get_bid_taxonomy_map()

        self.write_bid_tax_map(self.bid_taxonomy_map, final=False)

        self.reftree = Tree(self.refjson.get_raxml_readable_tree())
        self.reftree_size = len(self.reftree.get_leaves())

        # IMPORTANT: the EPA heuristic rate depends on the tree size, so it can
        # only be resolved once the reference tree has been loaded
        self.cfg.resolve_auto_settings(self.reftree_size)
        # A pre-optimized model can only be loaded with the same rate
        # heterogeneity mode that is stored in the reference file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map,
                                                 self.rate, self.node_height)
        self.taxtree_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy, self.tax_tree)

        self.tax_code = TaxCode(self.refjson.get_taxcode())

        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_map=self.origin_taxonomy)
        self.tax_common_ranks = self.taxonomy.get_common_ranks()

        # one counter per taxonomic level, for sequence- and rank-level mislabels
        self.mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
        self.rank_mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
Example #30
0
    def run_leave_seq_out_test(self):
        """Run the leave-one-sequence-out test over all placements.

        Placements are taken from existing jplace file(s) when
        self.cfg.jplace_fname is set (a directory is expanded to every
        '*.jplace' file inside it, anything else is used as a glob mask).
        Otherwise they are produced by an EPA run in "l1o_seq" mode.

        Each placement is classified and compared against the original label of
        the sequence; the resulting assignments are written out as non-final.

        Returns:
            The number of placements processed.
        """
        job_name = self.cfg.subst_name("l1out_seq_%NAME%")
        placements = []
        if self.cfg.jplace_fname:
            if os.path.isdir(self.cfg.jplace_fname):
                jplace_fmask = os.path.join(self.cfg.jplace_fname, '*.jplace')
            else:
                jplace_fmask = self.cfg.jplace_fname

            for jplace_fname in glob.glob(jplace_fmask):
                jp = EpaJsonParser(jplace_fname)
                placements += jp.get_placement()

            # log through the instance config like the rest of this method;
            # the bare name `config` is not defined in this scope
            self.cfg.log.debug("Loaded %d placements from %s\n", len(placements), jplace_fmask)
        else:
            jp = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname, mode="l1o_seq")
            placements = jp.get_placement()
            if self.cfg.output_interim_files:
                out_jplace_fname = self.cfg.out_fname("%NAME%.l1out_seq.jplace")
                self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True, mode="l1o_seq")

        seq_count = 0
        l1out_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # get EPA tax label
            ranks, lws = self.classify_seq(place)
            l1out_ass[seq_name] = (ranks, lws)

            # check if they match
            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)
            # cross-check with higher rank mislabels: attach the highest
            # confidence recorded for any rank on the original lineage
            if self.cfg.ranktest and mis_rec:
                rank_conf = 0
                for lvl in range(2, len(orig_ranks)):
                    tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
                    if tax_path in self.misrank_conf_map:
                        rank_conf = max(rank_conf, self.misrank_conf_map[tax_path])
                mis_rec['rank_conf'] = rank_conf
            seq_count += 1

        self.write_assignments(l1out_ass, final=False)

        return seq_count
Example #31
0
 def test_subst_ranks(self):
     """subst_synonyms() must replace a rank name in a sequence's lineage by its synonym."""
     testfile_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testfiles")
     tax_fname = os.path.join(testfile_dir, "test.tax")
     tax = Taxonomy("", tax_fname)
     old_ranks = tax.get_seq_ranks("WgeSangu")
     self.assertEqual(old_ranks[-2], 'Sneathia')
     syn_map = {'Sneathia' : 'Sebaldella'}
     tax.subst_synonyms(syn_map)
     # verify the lineage fetched AFTER the substitution; checking the list
     # obtained before it would only pass if it happened to be mutated in place
     new_ranks = tax.get_seq_ranks("WgeSangu")
     self.assertEqual(new_ranks[-2], 'Sebaldella')
Example #32
0
    def setUp(self):
        """Create an InputValidator over the test taxonomy and alignment, plus the expected results."""
        trainer_cfg = EpacTrainerConfig()
        trainer_cfg.debug = True

        this_dir = os.path.dirname(os.path.abspath(__file__))
        testfile_dir = os.path.join(this_dir, "testfiles")
        taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX,
                            os.path.join(testfile_dir, "test.tax"))
        alignment = SeqGroup(sequences=os.path.join(testfile_dir, "test.phy"),
                             format="phylip")
        self.inval = InputValidator(trainer_cfg, taxonomy, alignment, False)

        # reference values the individual tests compare against
        self.expected_mis_ids = ["Missing1", "Missing2"]
        self.expected_dups = ["DupSeq(01)", "DupSeq02"]
        self.expected_merges = []
        for dup_id in self.expected_dups:
            self.expected_merges.append(self.inval.taxonomy.seq_rank_id(dup_id))
Example #33
0
class TaxTreeHelperTests(unittest.TestCase):
    """Tests for TaxTreeHelper: outgroup extraction and branch-to-taxonomy mapping."""

    def setUp(self):
        """Load the clean test taxonomy, build the helper and read the expected outgroup."""
        module_dir = os.path.dirname(os.path.abspath(__file__))
        self.testfile_dir = os.path.join(module_dir, "testfiles")
        self.tax_fname = os.path.join(self.testfile_dir, "test_clean.tax")
        self.taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX, self.tax_fname)
        seq_ranks = self.taxonomy.get_map()
        self.taxtree_helper = TaxTreeHelper(EpacConfig(), seq_ranks)

        self.expected_outgr = Tree(
            os.path.join(self.testfile_dir, "outgroup.nw"))

    def tearDown(self):
        """Drop the references created in setUp."""
        self.taxonomy = None
        self.taxtree_helper = None

    def test_outgroup(self):
        """The outgroup derived from the rooted taxonomic tree has the expected leaves."""
        rooted_tree = Tree(os.path.join(self.testfile_dir, "taxtree.nw"))
        self.taxtree_helper.set_mf_rooted_tree(rooted_tree)
        actual_outgr = self.taxtree_helper.get_outgroup()
        self.assertEqual(actual_outgr.get_leaf_names(),
                         self.expected_outgr.get_leaf_names())

    def test_branch_labeling(self):
        """The branch-to-taxonomy map of the resolved tree matches the stored reference map."""
        resolved_tree = Tree(
            os.path.join(self.testfile_dir, "resolved_tree.nw"))

        # reference map: one tab-separated record per line
        # (branch id, rank id, integer, float)
        self.expected_map = {}
        with open(os.path.join(self.testfile_dir, "bid_tax_map.txt")) as map_file:
            for record in map_file:
                bid, rank_id, rdiff, brlen = record.strip().split("\t")
                self.expected_map[bid] = (rank_id, int(rdiff), float(brlen))

        self.taxtree_helper.set_outgroup(self.expected_outgr)
        self.taxtree_helper.set_bf_unrooted_tree(resolved_tree)
        actual_map = self.taxtree_helper.get_bid_taxonomy_map()

        self.assertEqual(len(actual_map), 2 * len(resolved_tree) - 3)
        for bid, expected in self.expected_map.items():
            actual = actual_map[bid]
            self.assertEqual(expected[0], actual[0])
            self.assertEqual(expected[1], actual[1])
            self.assertAlmostEqual(expected[2], actual[2], 6)
Example #34
0
class TaxTreeHelperTests(unittest.TestCase):
    """Checks outgroup extraction and branch labeling done by TaxTreeHelper."""

    def _testfile(self, basename):
        """Return the path of a file inside the test data directory."""
        return os.path.join(self.testfile_dir, basename)

    def setUp(self):
        """Prepare the taxonomy, the helper under test and the expected outgroup tree."""
        self.testfile_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testfiles")
        self.tax_fname = self._testfile("test_clean.tax")
        self.taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX, self.tax_fname)
        tax_map = self.taxonomy.get_map()
        self.taxtree_helper = TaxTreeHelper(EpacConfig(), tax_map)

        self.expected_outgr = Tree(self._testfile("outgroup.nw"))

    def tearDown(self):
        """Release the objects created for the test."""
        self.taxonomy = None
        self.taxtree_helper = None

    def test_outgroup(self):
        """Leaf names of the computed outgroup must equal those of the stored outgroup."""
        self.taxtree_helper.set_mf_rooted_tree(Tree(self._testfile("taxtree.nw")))
        computed = self.taxtree_helper.get_outgroup()
        self.assertEqual(computed.get_leaf_names(), self.expected_outgr.get_leaf_names())

    def test_branch_labeling(self):
        """Every record of the stored branch map must be reproduced by the helper."""
        bfu_tree = Tree(self._testfile("resolved_tree.nw"))

        self.expected_map = {}
        with open(self._testfile("bid_tax_map.txt")) as inf:
            for line in inf:
                fields = line.strip().split("\t")
                bid, rank_id, rdiff, brlen = fields
                self.expected_map[bid] = (rank_id, int(rdiff), float(brlen))

        helper = self.taxtree_helper
        helper.set_outgroup(self.expected_outgr)
        helper.set_bf_unrooted_tree(bfu_tree)
        bid_tax_map = helper.get_bid_taxonomy_map()

        self.assertEqual(len(bid_tax_map), 2 * len(bfu_tree) - 3)
        for bid in self.expected_map:
            exp_rank_id, exp_rdiff, exp_brlen = self.expected_map[bid]
            got = bid_tax_map[bid]
            self.assertEqual(exp_rank_id, got[0])
            self.assertEqual(exp_rdiff, got[1])
            # the third field is a float: compare to 6 decimal places
            self.assertAlmostEqual(exp_brlen, got[2], 6)
Example #35
0
    def run_leave_seq_out_test(self):
        """Classify every leave-one-out placement and check it against the original label.

        Source of the placements:
          * self.cfg.jplace_fname set: existing jplace file(s) are parsed; a
            directory means all '*.jplace' files in it, any other value is
            treated as a glob mask;
          * otherwise: an EPA run in "l1o_seq" mode (its jplace output is kept
            when self.cfg.output_interim_files is set).

        The (ranks, weights) assignment of every sequence is collected and
        written out as non-final.

        Returns:
            The number of placements processed.
        """
        job_name = self.cfg.subst_name("l1out_seq_%NAME%")
        placements = []
        if self.cfg.jplace_fname:
            if os.path.isdir(self.cfg.jplace_fname):
                jplace_fmask = os.path.join(self.cfg.jplace_fname, '*.jplace')
            else:
                jplace_fmask = self.cfg.jplace_fname

            jplace_fname_list = glob.glob(jplace_fmask)
            for jplace_fname in jplace_fname_list:
                jp = EpaJsonParser(jplace_fname)
                placements += jp.get_placement()

            # use the instance config: the bare name `config` is not bound in
            # this method, every other access here goes through self.cfg
            self.cfg.log.debug("Loaded %d placements from %s\n",
                               len(placements), jplace_fmask)
        else:
            jp = self.raxml.run_epa(job_name,
                                    self.refalign_fname,
                                    self.reftree_fname,
                                    self.optmod_fname,
                                    mode="l1o_seq")
            placements = jp.get_placement()
            if self.cfg.output_interim_files:
                out_jplace_fname = self.cfg.out_fname(
                    "%NAME%.l1out_seq.jplace")
                self.raxml.copy_epa_jplace(job_name,
                                           out_jplace_fname,
                                           move=True,
                                           mode="l1o_seq")

        seq_count = 0
        l1out_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # get EPA tax label
            ranks, lws = self.classify_seq(place)
            l1out_ass[seq_name] = (ranks, lws)

            # check if they match
            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks,
                                                lws)
            # cross-check with higher rank mislabels: take the maximum
            # confidence stored for any rank along the original lineage
            if self.cfg.ranktest and mis_rec:
                rank_conf = 0
                for lvl in range(2, len(orig_ranks)):
                    tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
                    if tax_path in self.misrank_conf_map:
                        rank_conf = max(rank_conf,
                                        self.misrank_conf_map[tax_path])
                mis_rec['rank_conf'] = rank_conf
            seq_count += 1

        self.write_assignments(l1out_ass, final=False)

        return seq_count
Example #36
0
class RefTreeBuilder:
    """Builds a reference tree from a taxonomy and an alignment and saves it as a reference JSON file.

    build_ref_tree() is the entry point; it runs the remaining methods in a
    fixed order. The steps communicate through instance attributes and through
    temporary files whose names are derived from the config object.
    """

    def __init__(self, config):
        """Store the config and derive RAxML job names and temporary file names from it."""
        self.cfg = config
        self.mfresolv_job_name = self.cfg.subst_name("mfresolv_%NAME%")
        self.epalbl_job_name = self.cfg.subst_name("epalbl_%NAME%")
        self.optmod_job_name = self.cfg.subst_name("optmod_%NAME%")
        self.raxml_wrapper = RaxmlWrapper(config)

        self.outgr_fname = self.cfg.tmp_fname("%NAME%_outgr.tre")
        self.reftree_mfu_fname = self.cfg.tmp_fname("%NAME%_mfu.tre")
        self.reftree_bfu_fname = self.cfg.tmp_fname("%NAME%_bfu.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.lblalign_fname = self.cfg.tmp_fname("%NAME%_lblq.fa")
        self.reftree_lbl_fname = self.cfg.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = self.cfg.tmp_fname("%NAME%_tax.tre")
        self.brmap_fname = self.cfg.tmp_fname("%NAME%_map.txt")

    def load_alignment(self):
        """Load the input alignment into self.input_seqs.

        The file format is guessed by trying each candidate format in turn;
        if none of them can be parsed, a user error is reported via the config.
        """
        in_file = self.cfg.align_fname
        self.input_seqs = None
        formats = [
            "fasta", "phylip_relaxed", "iphylip_relaxed", "phylip", "iphylip"
        ]
        for fmt in formats:
            try:
                self.input_seqs = SeqGroup(sequences=in_file, format=fmt)
                break
            except Exception:
                # a failure only means "not this format", so try the next one;
                # KeyboardInterrupt / SystemExit are deliberately not swallowed
                self.cfg.log.debug("Guessing input format: not " + fmt)
        if self.input_seqs is None:
            self.cfg.exit_user_error(
                "Invalid input file format: %s\nThe supported input formats are fasta and phylip"
                % in_file)

    def validate_taxonomy(self):
        """Create the InputValidator for taxonomy and alignment and run it."""
        self.input_validator = InputValidator(self.cfg, self.taxonomy,
                                              self.input_seqs)
        self.input_validator.validate()

    def build_multif_tree(self):
        """Build the multifurcating tree from the taxonomy and record its sequence ids.

        Sets self.reftree_ids, self.reftree_size and self.reftree_multif, then
        lets the config resolve its size-dependent automatic settings. In debug
        mode the sequence ids and the tree are also written to temporary files.
        """
        c = self.cfg

        tb = TaxTreeBuilder(c, self.taxonomy)
        (t, ids) = tb.build(c.reftree_min_rank, c.reftree_max_seqs_per_leaf,
                            c.reftree_clades_to_include,
                            c.reftree_clades_to_ignore)
        self.reftree_ids = frozenset(ids)
        self.reftree_size = len(ids)
        self.reftree_multif = t

        # IMPORTANT: select GAMMA or CAT model based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)

        if self.cfg.debug:
            refseq_fname = self.cfg.tmp_fname("%NAME%_seq_ids.txt")
            # list of sequence ids which comprise the reference tree
            with open(refseq_fname, "w") as f:
                for sid in ids:
                    f.write("%s\n" % sid)

            # original tree with taxonomic ranks as internal node labels
            reftax_fname = self.cfg.tmp_fname("%NAME%_mfu_tax.tre")
            t.write(outfile=reftax_fname, format=8)

    def export_ref_alignment(self):
        """This function transforms the input alignment in the following way:
           1. Filter out sequences which are not part of the reference tree
           2. Add sequence name prefix (r_)

           The result is written in FASTA layout to self.refalign_fname;
           afterwards self.input_seqs is released."""

        self.refalign_fname = self.cfg.tmp_fname("%NAME%_matrix.afa")
        with open(self.refalign_fname, "w") as fout:
            for name, seq, comment, sid in self.input_seqs.iter_entries():
                seq_name = EpacConfig.REF_SEQ_PREFIX + name
                # use the validator's replacement id if it recorded one
                if seq_name in self.input_validator.corr_seqid:
                    seq_name = self.input_validator.corr_seqid[seq_name]
                if seq_name in self.reftree_ids:
                    fout.write(">" + seq_name + "\n" + seq + "\n")

        # we do not need the original alignment anymore, so free its memory
        self.input_seqs = None

    def export_ref_taxonomy(self):
        """Collect the taxonomy entries of the reference tree sequences into self.taxonomy_map.

        In debug mode the resulting map is also dumped to a temporary text file
        (one "id<TAB>lineage" line per sequence).
        """
        self.taxonomy_map = {}

        for sid, ranks in self.taxonomy.iteritems():
            if sid in self.reftree_ids:
                self.taxonomy_map[sid] = ranks

        if self.cfg.debug:
            tax_fname = self.cfg.tmp_fname("%NAME%_tax.txt")
            with open(tax_fname, "w") as fout:
                for sid in self.taxonomy_map:
                    ranks_str = self.taxonomy.seq_lineage_str(sid)
                    fout.write(sid + "\t" + ranks_str + "\n")

    def save_rooting(self):
        """Save the outgroup of the multifurcating tree and write the tree for RAxML.

        The outgroup is stored in self.reftree_outgroup and written to
        self.outgr_fname; the tree itself is written to self.reftree_mfu_fname.
        """
        rt = self.reftree_multif

        tax_map = self.taxonomy.get_map()
        self.taxtree_helper = TaxTreeHelper(self.cfg, tax_map)
        self.taxtree_helper.set_mf_rooted_tree(rt)
        outgr = self.taxtree_helper.get_outgroup()
        outgr_size = len(outgr.get_leaves())
        outgr.write(outfile=self.outgr_fname, format=9)
        self.reftree_outgroup = outgr
        self.cfg.log.debug(
            "Outgroup for rooting was saved to: %s, outgroup size: %d",
            self.outgr_fname, outgr_size)

        # remove unifurcation at the root
        if len(rt.children) == 1:
            rt = rt.children[0]

        # now we can safely unroot the tree and remove internal node labels to make it suitable for raxml
        rt.write(outfile=self.reftree_mfu_fname, format=9)

    def resolve_multif(self):
        """RAxML call to convert the multifurcating tree to a strictly bifurcating one.

        Steps: reduce the alignment, run the constrained ML inference (or reuse
        an existing result when restarting), then either take tree and model
        parameters from that run or, if cfg.reopt_model is set, run a separate
        model optimization. Finally the log-likelihood of the reference tree is
        read into self.reftree_loglh. Failed RAxML runs are reported as fatal
        errors via the config.
        """
        self.cfg.log.debug("\nReducing the alignment: \n")
        self.reduced_refalign_fname = self.raxml_wrapper.reduce_alignment(
            self.refalign_fname)

        self.cfg.log.debug("\nConstrained ML inference: \n")
        raxml_params = [
            "-s", self.reduced_refalign_fname, "-g", self.reftree_mfu_fname,
            "--no-seq-check", "-N",
            str(self.cfg.rep_num)
        ]
        if self.cfg.mfresolv_method == "fast":
            raxml_params += ["-D"]
        elif self.cfg.mfresolv_method == "ultrafast":
            raxml_params += ["-f", "e"]
        if self.cfg.restart and self.raxml_wrapper.result_exists(
                self.mfresolv_job_name):
            self.invocation_raxml_multif = self.raxml_wrapper.get_invocation_str(
                self.mfresolv_job_name)
            self.cfg.log.debug(
                "\nUsing existing ML tree found in: %s\n",
                self.raxml_wrapper.result_fname(self.mfresolv_job_name))
        else:
            self.invocation_raxml_multif = self.raxml_wrapper.run(
                self.mfresolv_job_name, raxml_params)
            if self.cfg.mfresolv_method == "ultrafast":
                # copy the result tree to the best-tree file name so that the
                # besttree_exists() check below finds it
                self.raxml_wrapper.copy_result_tree(
                    self.mfresolv_job_name,
                    self.raxml_wrapper.besttree_fname(self.mfresolv_job_name))

        if self.raxml_wrapper.besttree_exists(self.mfresolv_job_name):
            if not self.cfg.reopt_model:
                # take tree and model parameters directly from the inference run
                self.raxml_wrapper.copy_best_tree(self.mfresolv_job_name,
                                                  self.reftree_bfu_fname)
                self.raxml_wrapper.copy_optmod_params(self.mfresolv_job_name,
                                                      self.optmod_fname)
                self.invocation_raxml_optmod = ""
                job_name = self.mfresolv_job_name
            else:
                bfu_fname = self.raxml_wrapper.besttree_fname(
                    self.mfresolv_job_name)
                job_name = self.optmod_job_name

                # RAxML call to optimize model parameters and write them down to the binary model file
                self.cfg.log.debug("\nOptimizing model parameters: \n")
                raxml_params = [
                    "-f", "e", "-s", self.reduced_refalign_fname, "-t",
                    bfu_fname, "--no-seq-check"
                ]
                if self.cfg.raxml_model.startswith(
                        "GTRCAT") and not self.cfg.compress_patterns:
                    raxml_params += ["-H"]
                if self.cfg.restart and self.raxml_wrapper.result_exists(
                        self.optmod_job_name):
                    self.invocation_raxml_optmod = self.raxml_wrapper.get_invocation_str(
                        self.optmod_job_name)
                    self.cfg.log.debug(
                        "\nUsing existing optimized tree and parameters found in: %s\n",
                        self.raxml_wrapper.result_fname(self.optmod_job_name))
                else:
                    self.invocation_raxml_optmod = self.raxml_wrapper.run(
                        self.optmod_job_name, raxml_params)
                if self.raxml_wrapper.result_exists(self.optmod_job_name):
                    self.raxml_wrapper.copy_result_tree(
                        self.optmod_job_name, self.reftree_bfu_fname)
                    self.raxml_wrapper.copy_optmod_params(
                        self.optmod_job_name, self.optmod_fname)
                else:
                    errmsg = "RAxML run failed (model optimization), please examine the log for details: %s" \
                            % self.raxml_wrapper.make_raxml_fname("output", self.optmod_job_name)
                    self.cfg.exit_fatal_error(errmsg)

            if self.cfg.raxml_model.startswith("GTRCAT"):
                mod_name = "CAT"
            else:
                mod_name = "GAMMA"
            self.reftree_loglh = self.raxml_wrapper.get_tree_lh(
                job_name, mod_name)
            self.cfg.log.debug("\n%s-based logLH of the reference tree: %f\n" %
                               (mod_name, self.reftree_loglh))

        else:
            errmsg = "RAxML run failed (mutlifurcation resolution), please examine the log for details: %s" \
                    % self.raxml_wrapper.make_raxml_fname("output", self.mfresolv_job_name)
            self.cfg.exit_fatal_error(errmsg)

    def load_reduced_refalign(self):
        """Load the reduced reference alignment into self.reduced_refalign_seqs.

        Tries each candidate format in turn and reports a fatal error via the
        config if the file cannot be parsed in any of them.
        """
        # start from a defined state, so that the check below reports the
        # fatal error instead of failing on a missing attribute
        self.reduced_refalign_seqs = None
        formats = ["fasta", "phylip_relaxed"]
        for fmt in formats:
            try:
                self.reduced_refalign_seqs = SeqGroup(
                    sequences=self.reduced_refalign_fname, format=fmt)
                break
            except Exception:
                # wrong format guess: fall through to the next candidate
                pass
        if self.reduced_refalign_seqs is None:
            errmsg = "FATAL ERROR: Invalid input file format in %s! (load_reduced_refalign)" % self.reduced_refalign_fname
            self.cfg.exit_fatal_error(errmsg)

    def epa_branch_labeling(self):
        """Dummy EPA run to label the branches of the reference tree.

        The branch labels are needed to build the mapping to taxonomic ranks.
        Stores the labelled tree string, the RAxML version and the RAxML
        invocation as attributes, and reports a fatal error if no EPA result
        exists for the job afterwards.
        """
        # create alignment with dummy query seq
        self.refalign_width = len(self.reduced_refalign_seqs.get_seqbyid(0))
        self.reduced_refalign_seqs.write(format="fasta",
                                         outfile=self.lblalign_fname)

        with open(self.lblalign_fname, "a") as fout:
            fout.write(">" + "DUMMY131313" + "\n")
            fout.write("A" * self.refalign_width + "\n")

        # TODO always load model regardless of the config file settings?
        epa_result = self.raxml_wrapper.run_epa(self.epalbl_job_name,
                                                self.lblalign_fname,
                                                self.reftree_bfu_fname,
                                                self.optmod_fname,
                                                mode="epa_mp")
        self.reftree_lbl_str = epa_result.get_std_newick_tree()
        self.raxml_version = epa_result.get_raxml_version()
        self.invocation_raxml_epalbl = epa_result.get_raxml_invocation()

        if not self.raxml_wrapper.epa_result_exists(self.epalbl_job_name):
            errmsg = "RAxML EPA run failed, please examine the log for details: %s" \
                    % self.raxml_wrapper.make_raxml_fname("output", self.epalbl_job_name)
            self.cfg.exit_fatal_error(errmsg)

    def epa_post_process(self):
        """Derive the taxonomic tree and the branch-to-taxonomy map from the labelled tree.

        Sets self.reftree_tax and self.bid_ranks_map; in debug mode both trees
        and the map are also written to temporary files.
        """
        lbl_tree = Tree(self.reftree_lbl_str)
        self.taxtree_helper.set_bf_unrooted_tree(lbl_tree)
        self.reftree_tax = self.taxtree_helper.get_tax_tree()
        self.bid_ranks_map = self.taxtree_helper.get_bid_taxonomy_map()

        if self.cfg.debug:
            self.reftree_tax.write(outfile=self.reftree_tax_fname, format=3)
            with open(self.reftree_lbl_fname, "w") as outf:
                outf.write(self.reftree_lbl_str)
            with open(self.brmap_fname, "w") as outf:
                for bid, br_rec in self.bid_ranks_map.iteritems():
                    outf.write("%s\t%s\t%d\t%f\n" %
                               (bid, br_rec[0], br_rec[1], br_rec[2]))

    def calc_node_heights(self):
        """Calculate node heights on the reference tree (used to define branch-length cutoff during classification step)
           Algorithm is as follows:
           Tip node or node resolved to Species level: height = 1
           Inner node resolved to Genus or above:      height = min(left_height, right_height) + 1

           The result, keyed by branch label, is stored in self.node_height_map.
         """
        nh_map = {}
        dummy_added = False
        for node in self.reftree_tax.traverse("postorder"):
            if not node.is_root():
                if not hasattr(node, "B"):
                    # In a rooted tree, there is always one more node/branch than in unrooted one
                    # That's why one branch will be always not EPA-labelled after the rooting
                    if not dummy_added:
                        node.B = "DDD"
                        dummy_added = True
                        species_rank = Taxonomy.EMPTY_RANK
                    else:
                        errmsg = "FATAL ERROR: More than one tree branch without EPA label (calc_node_heights)"
                        self.cfg.exit_fatal_error(errmsg)
                else:
                    # NOTE(review): epa_post_process() writes the elements of a
                    # map record as (string, int, float); confirm that the last
                    # element is really what should be compared to EMPTY_RANK
                    species_rank = self.bid_ranks_map[node.B][-1]
                bid = node.B
                if node.is_leaf() or species_rank != Taxonomy.EMPTY_RANK:
                    nh_map[bid] = 1
                else:
                    # postorder traversal: both children already have a height
                    lchild = node.children[0]
                    rchild = node.children[1]
                    nh_map[bid] = min(nh_map[lchild.B], nh_map[rchild.B]) + 1

        # remove heights for dummy nodes, since there won't be any placements on them
        if dummy_added:
            del nh_map["DDD"]

        self.node_height_map = nh_map

    def __get_all_rank_names(self, root):
        """Return the set of all rank names found on the nodes of the subtree under root."""
        rnames = set()
        for node in root.traverse("postorder"):
            rnames.update(node.ranks)
        return rnames

    def mono_index(self):
        """This method will calculate monophyly index by looking at the left and right hand side of the tree

        Returns the set of rank names that occur on both sides of the first
        bifurcation, or an empty set (after printing an error) if the tree
        does not split in two there.
        """
        children = self.reftree_tax.children
        # skip unifurcations at the top of the tree
        while len(children) == 1:
            children = children[0].children
        if len(children) == 2:
            left, right = children
            lset = self.__get_all_rank_names(left)
            rset = self.__get_all_rank_names(right)
            return lset & rset
        else:
            print("Error: input tree not birfurcating")
            return set()

    def build_hmm_profile(self, json_builder):
        """Build a HMMER profile from the reduced reference alignment and hand it to json_builder."""
        # call form of print: same output under Python 2, valid under Python 3
        print("Building the HMMER profile...\n")

        # this stupid workaround is needed because RAxML outputs the reduced
        # alignment in relaxed PHYLIP format, which is not supported by HMMER
        refalign_fasta = self.cfg.tmp_fname("%NAME%_ref_reduced.fa")
        self.reduced_refalign_seqs.write(outfile=refalign_fasta)

        hmm = hmmer(self.cfg, refalign_fasta)
        fprofile = hmm.build_hmm_profile()

        json_builder.set_hmm_profile(fprofile)

    def write_json(self):
        """Assemble all results into a RefJsonBuilder and dump it to cfg.refjson_fname.

        Relies on the attributes set by the preceding build steps and on
        self.invocation_epac, which is not assigned anywhere in this class and
        therefore has to be set by the caller.
        """
        jw = RefJsonBuilder()

        jw.set_branch_tax_map(self.bid_ranks_map)
        jw.set_tree(self.reftree_lbl_str)
        jw.set_outgroup(self.reftree_outgroup)
        jw.set_ratehet_model(self.cfg.raxml_model)
        jw.set_tax_tree(self.reftree_multif)
        jw.set_pattern_compression(self.cfg.compress_patterns)
        jw.set_taxcode(self.cfg.taxcode_name)

        jw.set_merged_ranks_map(self.input_validator.merged_ranks)
        # the validator's maps are stored inverted (value -> key)
        corr_ranks_reverse = dict(
            (value, key)
            for key, value in self.input_validator.corr_ranks.items())
        jw.set_corr_ranks_map(corr_ranks_reverse)
        corr_seqid_reverse = dict(
            (value, key)
            for key, value in self.input_validator.corr_seqid.items())
        jw.set_corr_seqid_map(corr_seqid_reverse)

        mdata = {
            "ref_tree_size": self.reftree_size,
            "ref_alignment_width": self.refalign_width,
            "raxml_version": self.raxml_version,
            "timestamp": str(datetime.datetime.now()),
            "invocation_epac": self.invocation_epac,
            "invocation_raxml_multif": self.invocation_raxml_multif,
            "invocation_raxml_optmod": self.invocation_raxml_optmod,
            "invocation_raxml_epalbl": self.invocation_raxml_epalbl,
            "reftree_loglh": self.reftree_loglh
        }
        jw.set_metadata(mdata)

        seqs = self.reduced_refalign_seqs.get_entries()
        jw.set_sequences(seqs)

        if not self.cfg.no_hmmer:
            self.build_hmm_profile(jw)

        orig_tax = self.taxonomy_map
        jw.set_origin_taxonomy(orig_tax)

        self.cfg.log.debug("Calculating the speciation rate...\n")
        tp = tree_param(tree=self.reftree_lbl_str, origin_taxonomy=orig_tax)
        jw.set_rate(tp.get_speciation_rate_fast())
        jw.set_nodes_height(self.node_height_map)

        jw.set_binary_model(self.optmod_fname)

        self.cfg.log.debug("Writing down the reference file...\n")
        jw.dump(self.cfg.refjson_fname)

    def build_ref_tree(self):
        """Top-level function to build a reference tree.

        Runs all build steps in order, from loading the taxonomy and the
        alignment to writing the reference JSON file, logging progress along
        the way.
        """
        self.cfg.log.info("=> Loading taxonomy from file: %s ...\n",
                          self.cfg.taxonomy_fname)
        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX,
                                 tax_fname=self.cfg.taxonomy_fname)
        self.cfg.log.info(
            "==> Loading reference alignment from file: %s ...\n",
            self.cfg.align_fname)
        self.load_alignment()
        self.cfg.log.info("===> Validating taxonomy and alignment ...\n")
        self.validate_taxonomy()
        self.cfg.log.info(
            "====> Building a multifurcating tree from taxonomy with %d seqs ...\n",
            self.taxonomy.seq_count())
        self.build_multif_tree()
        self.cfg.log.info("=====> Building the reference alignment ...\n")
        self.export_ref_alignment()
        self.export_ref_taxonomy()
        self.cfg.log.info(
            "======> Saving the outgroup for later re-rooting ...\n")
        self.save_rooting()
        self.cfg.log.info(
            "=======> Resolving multifurcation: choosing the best topology from %d independent RAxML runs ...\n"
            % self.cfg.rep_num)
        self.resolve_multif()
        self.load_reduced_refalign()
        self.cfg.log.info(
            "========> Calling RAxML-EPA to obtain branch labels ...\n")
        self.epa_branch_labeling()
        self.cfg.log.info(
            "=========> Post-processing the EPA tree (re-rooting, taxonomic labeling etc.) ...\n"
        )
        self.epa_post_process()
        self.calc_node_heights()

        self.cfg.log.debug("\n==========> Checking branch labels ...")
        self.cfg.log.debug("shared rank names before training: %s",
                           repr(self.taxonomy.get_common_ranks()))
        self.cfg.log.debug("shared rank names after  training: %s\n",
                           repr(self.mono_index()))

        self.cfg.log.info("==========> Saving the reference JSON file: %s\n" %
                          self.cfg.refjson_fname)
        self.write_json()
Example #37
0
 def test_rank_uid(self):
     """Splitting a sequence's rank UID must give back that sequence's ranks."""
     taxonomy = self.taxonomy
     # keys() exists on both Python 2 and 3 mappings; iterkeys() is Python 2 only.
     for seq_id in taxonomy.get_map().keys():
         expected_ranks = taxonomy.get_seq_ranks(seq_id)
         rank_uid = taxonomy.seq_rank_id(seq_id)
         self.assertEqual(expected_ranks, Taxonomy.split_rank_uid(rank_uid))
Example #38
0
 def test_load(self):
     """The loaded rank maps must equal the expected ones, with and without prefix."""
     expected_plain = self.TAX_DICT
     self.assertEqual(expected_plain, self.taxonomy.seq_ranks_map)

     # Same file loaded again, this time with the reference sequence prefix.
     tax_with_prefix = Taxonomy(EpacConfig.REF_SEQ_PREFIX, self.tax_fname)
     expected_prefixed = self.PREFIXED_TAX_DICT
     self.assertEqual(expected_prefixed, tax_with_prefix.seq_ranks_map)
     tax_with_prefix = None
Example #39
0
    def run_leave_subtree_out_test(self):
        """Run the leave-one-rank-out test over all eligible rank subtrees.

        Every inner, non-root node of ``self.tax_tree`` is a candidate. A node
        is skipped when its own lowest assigned rank level is below 2, when its
        parent's lowest assigned rank level is below 1, or when it has fewer
        than 2 tips or more than ``self.reftree_size - 4`` tips. The tips of
        the remaining subtrees are written one subtree per line to a temporary
        file, which is handed to EPA in ``l1o_subtree`` mode. Each returned
        placement is classified and compared with the parent's ranks; the
        confidence of every detected mismatch is stored in
        ``self.misrank_conf_map`` under the subtree's taxonomic path.

        Returns:
            The number of placements processed (0 when no subtree qualified).
        """
        job_name = self.cfg.subst_name("l1out_rank_%NAME%")

        # Collect the tips and the parent ranks of every subtree to test.
        rank_tips = {}
        rank_parent = {}
        for node in self.tax_tree.traverse("postorder"):
            if node.is_leaf() or node.is_root():
                continue
            tax_path = node.name
            ranks = Taxonomy.split_rank_uid(tax_path)
            if Taxonomy.lowest_assigned_rank_level(ranks) < 2:
                continue

            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
            if Taxonomy.lowest_assigned_rank_level(parent_ranks) < 1:
                continue

            rank_seqs = node.get_leaf_names()
            rank_size = len(rank_seqs)
            if rank_size < 2 or rank_size > self.reftree_size - 4:
                continue

            # Reuse the tip list fetched for the size check.
            rank_tips[tax_path] = rank_seqs
            rank_parent[tax_path] = parent_ranks

        # The list is indexed below; a Python 3 dict view is not subscriptable.
        subtree_list = list(rank_tips.items())
        if not subtree_list:
            return 0

        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
        with open(subtree_list_file, "w") as fout:
            for _, tips in subtree_list:
                fout.write("%s\n" % " ".join(tips))

        jp_list = self.raxml.run_epa(job_name,
                                     self.refalign_fname,
                                     self.reftree_fname,
                                     self.optmod_fname,
                                     mode="l1o_subtree",
                                     subtree_fname=subtree_list_file)

        # The n-th placement overall is matched with the n-th subtree written.
        subtree_count = 0
        for jp in jp_list:
            for place in jp.get_placement():
                ranks, lws = self.classify_seq(place)
                tax_path = subtree_list[subtree_count][0]
                orig_ranks = Taxonomy.split_rank_uid(tax_path)
                rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
                rank_prefix = self.guess_rank_level_name(
                    orig_ranks, rank_level)[0]
                rank_name = orig_ranks[rank_level]
                if not rank_name.startswith(rank_prefix):
                    rank_name = rank_prefix + rank_name
                parent_ranks = rank_parent[tax_path]
                mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks,
                                                     ranks, lws)
                if mis_rec:
                    self.misrank_conf_map[tax_path] = mis_rec['conf']
                subtree_count += 1

        return subtree_count
Example #40
0
    def run_leave_subtree_out_test(self):
        """Leave out one taxonomic rank at a time and re-place it with EPA.

        Candidate subtrees are the inner, non-root nodes of ``self.tax_tree``
        whose lowest assigned rank level is at least 2, whose parent's lowest
        assigned rank level is at least 1, and which have between 2 and
        ``self.reftree_size - 4`` tips. Their tip names are written (one
        subtree per line) to a temporary file that EPA reads in
        ``l1o_subtree`` mode. For every placement returned, the classification
        is compared with the parent's ranks, and the confidence of each
        mismatch is recorded in ``self.misrank_conf_map``.

        Returns:
            Number of placements processed; 0 if there was nothing to test.
        """
        job_name = self.cfg.subst_name("l1out_rank_%NAME%")

        # create file with subtrees
        rank_tips = {}
        rank_parent = {}
        for node in self.tax_tree.traverse("postorder"):
            if node.is_leaf() or node.is_root():
                continue
            tax_path = node.name
            ranks = Taxonomy.split_rank_uid(tax_path)
            rank_lvl = Taxonomy.lowest_assigned_rank_level(ranks)
            if rank_lvl < 2:
                continue

            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
            parent_lvl = Taxonomy.lowest_assigned_rank_level(parent_ranks)
            if parent_lvl < 1:
                continue

            rank_seqs = node.get_leaf_names()
            rank_size = len(rank_seqs)
            if rank_size < 2 or rank_size > self.reftree_size - 4:
                continue

            # The tip list was already fetched for the size check.
            rank_tips[tax_path] = rank_seqs
            rank_parent[tax_path] = parent_ranks

        # Materialize the pairs: they are accessed by position further down,
        # which a dict view (Python 3) does not support.
        subtree_list = list(rank_tips.items())

        if not subtree_list:
            return 0

        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
        with open(subtree_list_file, "w") as fout:
            for _, tips in subtree_list:
                fout.write("%s\n" % " ".join(tips))

        jp_list = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname,
            mode="l1o_subtree", subtree_fname=subtree_list_file)

        # Placements are paired with subtrees by their running position.
        subtree_count = 0
        for jp in jp_list:
            placements = jp.get_placement()
            for place in placements:
                ranks, lws = self.classify_seq(place)
                tax_path = subtree_list[subtree_count][0]
                orig_ranks = Taxonomy.split_rank_uid(tax_path)
                rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
                rank_prefix = self.guess_rank_level_name(orig_ranks, rank_level)[0]
                rank_name = orig_ranks[rank_level]
                if not rank_name.startswith(rank_prefix):
                    rank_name = rank_prefix + rank_name
                parent_ranks = rank_parent[tax_path]
                mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks, ranks, lws)
                if mis_rec:
                    self.misrank_conf_map[tax_path] = mis_rec['conf']
                subtree_count += 1

        return subtree_count
Example #41
0
class LeaveOneTest:
    def __init__(self, config):
        self.cfg = config
        
        self.mis_fname = self.cfg.out_fname("%NAME%.mis")
        self.premis_fname = self.cfg.out_fname("%NAME%.premis")
        self.misrank_fname = self.cfg.out_fname("%NAME%.misrank")
        self.stats_fname = self.cfg.out_fname("%NAME%.stats")
        
        if os.path.isfile(self.mis_fname):
            print "\nERROR: Output file already exists: %s" % self.mis_fname
            print "Please specify a different job name using -n or remove old output files."
            self.cfg.exit_user_error()

        self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
        self.reftree_lbl_fname = config.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = config.tmp_fname("%NAME%_tax.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")

        self.mislabels = []
        self.mislabels_cnt = []
        self.rank_mislabels = []
        self.rank_mislabels_cnt = []
        self.misrank_conf_map = {}
        
    def write_bid_tax_map(self, bid_tax_map, final):
        if self.cfg.debug:
            fname_suffix = "final" if final else "l1out"
            bid_fname = self.cfg.tmp_fname("%NAME%_" + "bid_tax_map_%s.txt" % fname_suffix)
            with open(bid_fname, "w") as outf:
              for bid, bid_rec in bid_tax_map.iteritems():
                outf.write("%s\t%s\t%d\t%f\n" % (bid, bid_rec[0], bid_rec[1], bid_rec[2]));    

    def write_assignments(self, assign_map, final):
        if self.cfg.debug:
            fname_suffix = "final" if final else "l1out"
            assign_fname = self.cfg.tmp_fname("%NAME%_" + "taxassign_%s.txt" % fname_suffix)
            with open(assign_fname, "w") as outf:
                for seq_name in assign_map.iterkeys():
                    ranks, lws = assign_map[seq_name]
                    outf.write("%s\t%s\t%s\n" % (seq_name, ";".join(ranks), ";".join(["%.3f" % l for l in lws])))

    def load_refjson(self, refjson_fname):
        """Load the reference JSON file and initialise all state derived from it.

        Parses and validates the file, then reads the rate, node heights,
        original taxonomy, taxonomic tree, pattern-compression flag,
        branch-to-taxonomy map and the reference tree from it. Finally the
        classification/tree helpers, the taxonomic code, the taxonomy object
        and the per-level mislabel counters are created.

        Args:
            refjson_fname: path of the reference JSON file, passed to
                ``RefJsonParser``.
        """
        try:
            self.refjson = RefJsonParser(refjson_fname)
        except ValueError:
            self.cfg.exit_user_error("ERROR: Invalid json file format!")
            
        # validate the structure of the parsed reference file
        (valid, err) = self.refjson.validate()
        if not valid:
            self.cfg.log.error("ERROR: Parsing reference JSON file failed:\n%s", err)
            self.cfg.exit_user_error()
        
        self.rate = self.refjson.get_rate()
        self.node_height = self.refjson.get_node_height()
        self.origin_taxonomy = self.refjson.get_origin_taxonomy()
        self.tax_tree = self.refjson.get_tax_tree()
        # the pattern-compression setting of the reference overrides the configured one
        self.cfg.compress_patterns = self.refjson.get_pattern_compression()

        self.bid_taxonomy_map = self.refjson.get_branch_tax_map()
        if not self.bid_taxonomy_map:
            # old file format (before 1.6), need to rebuild this map from scratch
            th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
            th.set_mf_rooted_tree(self.tax_tree)
            th.set_bf_unrooted_tree(self.refjson.get_reftree())
            self.bid_taxonomy_map = th.get_bid_taxonomy_map()
            
        # written to a temporary file in debug mode only
        self.write_bid_tax_map(self.bid_taxonomy_map, final=False)

        reftree_str = self.refjson.get_raxml_readable_tree()
        self.reftree = Tree(reftree_str)
        self.reftree_size = len(self.reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!                
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same rate het. mode as in the ref file        
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map, self.rate, self.node_height)
        self.taxtree_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy, self.tax_tree)
        
        tax_code_name = self.refjson.get_taxcode()
        self.tax_code = TaxCode(tax_code_name)
        
        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_map=self.origin_taxonomy)
        self.tax_common_ranks = self.taxonomy.get_common_ranks()
#        print "Common ranks: ", self.tax_common_ranks

        # one counter per universal taxonomic level, for sequences and for ranks
        self.mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
        self.rank_mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
        
    def run_epa_trainer(self):
        """Run the EPA trainer and make sure it produced the reference JSON file.

        When the file named by ``cfg.refjson_fname`` does not exist afterwards,
        an error is logged and ``cfg.exit_fatal_error()`` is called.
        """
        epa_trainer.run_trainer(self.cfg)

        refjson_created = os.path.isfile(self.cfg.refjson_fname)
        if refjson_created:
            return

        self.cfg.log.error("\nBuilding reference tree failed, see error messages above.")
        self.cfg.exit_fatal_error()
        
    def classify_seq(self, placement):
        edges = placement["p"]
        if len(edges) > 0:
            return self.classify_helper.classify_seq(edges)
        else:
            print "ERROR: no placements! something is definitely wrong!"

    def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
        mis_rec = None
        
        num_common_ranks = len(self.tax_common_ranks)
        orig_rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
        new_rank_level = Taxonomy.lowest_assigned_rank_level(ranks)
        #if new_rank_level < 0 or (new_rank_level < num_common_ranks and orig_rank_level >= num_common_ranks):
#        if new_rank_level < 0:
        if len(ranks) == 0:
            mis_rec = {}
            mis_rec['name'] = seq_name
            mis_rec['orig_level'] = -1
            mis_rec['real_level'] = 0
            mis_rec['level_name'] = "[NotIngroup]"
            mis_rec['inv_level'] = -1 * mis_rec['real_level']  # just for sorting
            mis_rec['orig_ranks'] = orig_ranks
            mis_rec['ranks'] = []
            mis_rec['lws'] = [1.0]
            mis_rec['conf'] = mis_rec['lws'][0]
        else:
            mislabel_lvl = -1
            min_len = min(len(orig_ranks),len(ranks))
            for rank_lvl in range(min_len):
                if ranks[rank_lvl] != Taxonomy.EMPTY_RANK and ranks[rank_lvl] != orig_ranks[rank_lvl]:
                    mislabel_lvl = rank_lvl
                    break

            if mislabel_lvl >= 0:
                real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
                mis_rec = {}
                mis_rec['name'] = seq_name
                mis_rec['orig_level'] = mislabel_lvl
                mis_rec['real_level'] = real_lvl
                mis_rec['level_name'] = self.tax_code.rank_level_name(real_lvl)[0]
                mis_rec['inv_level'] = -1 * mis_rec['real_level']  # just for sorting
                mis_rec['orig_ranks'] = orig_ranks
                mis_rec['ranks'] = ranks
                mis_rec['lws'] = lws
                mis_rec['conf'] = lws[mislabel_lvl]
    
        if mis_rec:
            self.mislabels.append(mis_rec)
            
        return mis_rec
        
    def filter_mislabels(self):
        filtered_mis = []
        for i in range(len(self.mislabels)):
            if self.mislabels[i]['conf'] >= self.cfg.conf_cutoff:
                filtered_mis.append(self.mislabels[i])
        
        self.mislabels = filtered_mis

    def check_rank_tax_labels(self, rank_name, orig_ranks, ranks, lws):
        mislabel_lvl = -1
        min_len = min(len(orig_ranks),len(ranks))
        for rank_lvl in range(min_len):
            if ranks[rank_lvl] != Taxonomy.EMPTY_RANK and ranks[rank_lvl] != orig_ranks[rank_lvl]:
                mislabel_lvl = rank_lvl
                break

        if mislabel_lvl >= 0:
            real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
            mis_rec = {}
            mis_rec['name'] = rank_name
            mis_rec['orig_level'] = mislabel_lvl
            mis_rec['real_level'] = real_lvl
            mis_rec['level_name'] = self.tax_code.rank_level_name(real_lvl)[0]
            mis_rec['inv_level'] = -1 * real_lvl  # just for sorting
            mis_rec['orig_ranks'] = orig_ranks
            mis_rec['ranks'] = ranks
            mis_rec['lws'] = lws
            mis_rec['conf'] = lws[mislabel_lvl]
            self.rank_mislabels.append(mis_rec)
               
            return mis_rec
        else:
            return None                

    def mis_rec_to_string_old(self, mis_rec):
        lvl = mis_rec['orig_level']
        output = mis_rec['name'] + "\t"
        output += "%s\t%s\t%s\t%.3f\n" % (mis_rec['level_name'], 
            mis_rec['orig_ranks'][lvl], mis_rec['ranks'][lvl], mis_rec['lws'][lvl])
        output += ";".join(mis_rec['orig_ranks']) + "\n"
        output += ";".join(mis_rec['ranks']) + "\n"
        output += "\t".join(["%.3f" % conf for conf in mis_rec['lws']]) + "\n"
        return output

    def mis_rec_to_string(self, mis_rec):
        """Format a mislabel record as one tab-separated output line (no newline).

        Name and ranks are first mapped back through the reference JSON's
        ``get_uncorr_seqid`` / ``get_uncorr_ranks``. The columns are: name,
        level name, original label, proposed label, confidence, original
        lineage, proposed lineage, ``;``-joined per-level confidences and, if
        the record has a ``'rank_conf'`` entry, that value as a last column.
        Records with a negative ``'orig_level'`` get ``NA`` labels and the
        first confidence value.
        """
        lvl = mis_rec['orig_level']
        seq_id = EpacConfig.strip_ref_prefix(self.refjson.get_uncorr_seqid(mis_rec['name']))
        old_ranks = self.refjson.get_uncorr_ranks(mis_rec['orig_ranks'])
        new_ranks = self.refjson.get_uncorr_ranks(mis_rec['ranks'])

        if lvl >= 0:
            label_cols = "%s\t%s\t%s\t%.3f\t" % (
                mis_rec['level_name'], old_ranks[lvl], new_ranks[lvl], mis_rec['lws'][lvl])
        else:
            label_cols = "%s\t%s\t%s\t%.3f\t" % (
                mis_rec['level_name'], "NA", "NA", mis_rec['lws'][0])

        pieces = [
            seq_id + "\t",
            label_cols,
            Taxonomy.lineage_str(old_ranks) + "\t",
            Taxonomy.lineage_str(new_ranks) + "\t",
            ";".join(["%.3f" % conf for conf in mis_rec['lws']]),
        ]
        if 'rank_conf' in mis_rec:
            pieces.append("\t%.3f" % mis_rec['rank_conf'])
        return "".join(pieces)

    def sort_mislabels(self):
        self.mislabels = sorted(self.mislabels, key=itemgetter('inv_level', 'conf', 'name'), reverse=True)
        for mis_rec in self.mislabels:
            real_lvl = mis_rec["real_level"]
            self.mislabels_cnt[real_lvl] += 1
        
        if self.cfg.ranktest:
            self.rank_mislabels = sorted(self.rank_mislabels, key=itemgetter('inv_level', 'conf', 'name'), reverse=True)
            for mis_rec in self.rank_mislabels:
                real_lvl = mis_rec["real_level"]
                self.rank_mislabels_cnt[real_lvl] += 1
    
    def write_stats(self, toFile=False):
        self.cfg.log.info("Mislabeled sequences by rank:")
        seq_sum = 0
        rank_sum = 0
        stats = []
        for i in range(len(self.mislabels_cnt)):
            if i > 0:
                rname = self.tax_code.rank_level_name(i)[0].ljust(12)
            else:
                rname = "[NotIngroup]"
            if self.mislabels_cnt[i] > 0:
                seq_sum += self.mislabels_cnt[i]
#                    output = "%s:\t%d" % (rname, seq_sum)
                output = "%s:\t%d" % (rname, self.mislabels_cnt[i])
                if self.cfg.ranktest:
                    rank_sum += self.rank_mislabels_cnt[i]
                    output += "\t%d" % rank_sum
                self.cfg.log.info(output) 
                stats.append(output)

        if toFile:
            with open(self.stats_fname, "w") as fo_stat:
                for line in stats:
                    fo_stat.write(line + "\n")
    
    def write_mislabels(self, final=True):
        if final:
            out_fname = self.mis_fname
        else:
            out_fname = self.premis_fname
        
        with open(out_fname, "w") as fo_all:
            fields = ["SeqID", "MislabeledLevel", "OriginalLabel", "ProposedLabel", "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath", "PerRankConfidence"]
            if self.cfg.ranktest:
                fields += ["HigherRankMisplacedConfidence"]
            header = ";" + "\t".join(fields) + "\n"
            fo_all.write(header)
            if self.cfg.verbose and len(self.mislabels) > 0 and final:
                print "Mislabeled sequences:\n"
                print header 
            for mis_rec in self.mislabels:
                output = self.mis_rec_to_string(mis_rec)  + "\n"
                fo_all.write(output)
                if self.cfg.verbose and final:
                    print(output) 
                    
        if not final:
            return

        if self.cfg.ranktest:
            with open(self.misrank_fname, "w") as fo_all:
                fields = ["RankID", "MislabeledLevel", "OriginalLabel", "ProposedLabel", "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath", "PerRankConfidence"]
                header = ";" + "\t".join(fields)  + "\n"
                fo_all.write(header)
                if self.cfg.verbose  and len(self.rank_mislabels) > 0:
                    print "\nMislabeled higher ranks:\n"
                    print header 
                for mis_rec in self.rank_mislabels:
                    output = self.mis_rec_to_string(mis_rec) + "\n"
                    fo_all.write(output)
                    if self.cfg.verbose:
                        print(output) 
                        
        self.write_stats()
   
    def run_leave_subtree_out_test(self):
        job_name = self.cfg.subst_name("l1out_rank_%NAME%")
#        if self.jplace_fname:
#            jp = EpaJsonParser(self.jplace_fname)
#        else:        

        #create file with subtrees
        rank_tips = {}
        rank_parent = {}
        for node in self.tax_tree.traverse("postorder"):
            if node.is_leaf() or node.is_root():
                continue
            tax_path = node.name
            ranks = Taxonomy.split_rank_uid(tax_path)
            rank_lvl = Taxonomy.lowest_assigned_rank_level(ranks)
            if rank_lvl < 2:
                continue
                
            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
            parent_lvl = Taxonomy.lowest_assigned_rank_level(parent_ranks)
            if parent_lvl < 1:
                continue
            
            rank_seqs = node.get_leaf_names()
            rank_size = len(rank_seqs)
            if rank_size < 2 or rank_size > self.reftree_size-4:
                continue

#            print rank_lvl, "\t", tax_path, "\t", rank_seqs, "\n"
            rank_tips[tax_path] = node.get_leaf_names()
            rank_parent[tax_path] = parent_ranks
                
        subtree_list = rank_tips.items()
        
        if len(subtree_list) == 0:
            return 0
            
        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
        with open(subtree_list_file, "w") as fout:
            for rank_name, tips in subtree_list:
                fout.write("%s\n" % " ".join(tips))
        
        jp_list = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname, 
            mode="l1o_subtree", subtree_fname=subtree_list_file)

        subtree_count = 0
        for jp in jp_list:
            placements = jp.get_placement()
            for place in placements:
                ranks, lws = self.classify_seq(place)
                tax_path = subtree_list[subtree_count][0]
                orig_ranks = Taxonomy.split_rank_uid(tax_path)
                rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
                rank_prefix = self.guess_rank_level_name(orig_ranks, rank_level)[0]
                rank_name = orig_ranks[rank_level]
                if not rank_name.startswith(rank_prefix):
                    rank_name = rank_prefix + rank_name
                parent_ranks = rank_parent[tax_path]
#                print orig_ranks, "\n", parent_ranks, "\n", ranks, "\n"
                mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks, ranks, lws)
                if mis_rec:
                    self.misrank_conf_map[tax_path] = mis_rec['conf']
                subtree_count += 1

        return subtree_count    
        
    def run_leave_seq_out_test(self):
        job_name = self.cfg.subst_name("l1out_seq_%NAME%")
        placements = []
        if self.cfg.jplace_fname:
            if os.path.isdir(self.cfg.jplace_fname):
                jplace_fmask = os.path.join(self.cfg.jplace_fname, '*.jplace')
            else:
                jplace_fmask = self.cfg.jplace_fname

            jplace_fname_list = glob.glob(jplace_fmask)
            for jplace_fname in jplace_fname_list:
                jp = EpaJsonParser(jplace_fname)
                placements += jp.get_placement()
                
            config.log.debug("Loaded %d placements from %s\n", len(placements), jplace_fmask)
        else:        
            jp = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname, mode="l1o_seq")
            placements = jp.get_placement()
            if self.cfg.output_interim_files:
                out_jplace_fname = self.cfg.out_fname("%NAME%.l1out_seq.jplace")
                self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True, mode="l1o_seq")
        
        seq_count = 0
        l1out_ass = {}
        for place in placements:
            seq_name = place["n"][0]
            
            # get original taxonomic label
#            orig_ranks = self.get_orig_ranks(seq_name)
            orig_ranks =  self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # get EPA tax label
            ranks, lws = self.classify_seq(place)
            l1out_ass[seq_name] = (ranks, lws)
            
            # check if they match
            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)
            # cross-check with higher rank mislabels
            if self.cfg.ranktest and mis_rec:
                rank_conf = 0
                for lvl in range(2,len(orig_ranks)):
                    tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
                    if tax_path in self.misrank_conf_map:
                        rank_conf = max(rank_conf, self.misrank_conf_map[tax_path])
                mis_rec['rank_conf'] = rank_conf
            seq_count += 1

        self.write_assignments(l1out_ass, final=False)
            
        return seq_count    
        
    def run_final_epa_test(self):
        self.reftree_outgroup = self.refjson.get_outgroup()

        tmp_reftree = self.reftree.copy(method="newick") 
        name2refnode = {}
        for leaf in tmp_reftree.iter_leaves():
            name2refnode[leaf.name] = leaf        

        tmp_taxtree = self.tax_tree.copy(method="newick") 
        name2taxnode = {}
        for leaf in tmp_taxtree.iter_leaves():
            name2taxnode[leaf.name] = leaf        

        for mis_rec in self.mislabels:
            rname = mis_rec['name']
#            rname = EpacConfig.REF_SEQ_PREFIX + name

            if rname in name2refnode:
                name2refnode[rname].delete()
            else:
                print "Node not found in the reference tree: %s" % rname

            if rname in name2taxnode:
                name2taxnode[rname].delete()
            else:
                print "Node not found in the taxonomic tree: %s" % rname

        # remove unifurcation at the root
        if len(tmp_reftree.children) == 1:
            tmp_reftree = tmp_reftree.children[0]
            
        self.mislabels = []

        th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
        th.set_mf_rooted_tree(tmp_taxtree)
            
        epa_result = self.run_epa_once(tmp_reftree)
        
        reftree_epalbl_str = epa_result.get_std_newick_tree()        
        placements = epa_result.get_placement()
        
        # update branchid-taxonomy mapping to account for possible changes in branch numbering
        reftree_tax = Tree(reftree_epalbl_str)
        th.set_bf_unrooted_tree(reftree_tax)
        bid_tax_map = th.get_bid_taxonomy_map()
        
        self.write_bid_tax_map(bid_tax_map, final=True)

        cl = TaxClassifyHelper(self.cfg, bid_tax_map, self.rate, self.node_height)
        
#        newtax_fname = self.cfg.subst_name("newtax_%NAME%.tre")
#        th.get_tax_tree().write(outfile=newtax_fname, format=3)

        final_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # EXPERIMENTAL FEATURE - disabled for now!
            # It could happen that certain ranks were present in the "original" reference tree, but 
            # are completely missing in the pruned tree (e.g., all seqs of a species were considered "suspicious" 
            # after the leave-one-out test and thus pruned)
            # In this case, EPA has no chance to infer full original taxonomic annotation (=species) since the corresponding clade
            # is now missing. To account for this fact, we amend the original taxonomic annotation and set ranks missing from  
            # pruned tree to "Undefined".
#            orig_ranks = th.strip_missing_ranks(orig_ranks)
#            print orig_ranks

            # get EPA tax label
            ranks, lws = cl.classify_seq(place["p"])
            final_ass[seq_name] = (ranks, lws)

            #print seq_name, ": ", orig_ranks, "--->", ranks

            # check if they match
            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

        self.write_assignments(final_ass, final=True)

    def run_epa_once(self, reftree):
        reftree_fname = self.cfg.tmp_fname("final_ref_%NAME%.tre")
        job_name = self.cfg.subst_name("final_epa_%NAME%")

        reftree.write(outfile=reftree_fname)

        # IMPORTANT: don't load the model, since it's invalid for the pruned true !!! 
        optmod_fname=""
        epa_result = self.raxml.run_epa(job_name, self.refalign_fname, reftree_fname, optmod_fname)

        if self.cfg.output_interim_files:
            out_jplace_fname = self.cfg.out_fname("%NAME%.final_epa.jplace")
            self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True)

        return epa_result

    def run_test(self):
        """Run the complete mislabel detection workflow and write the results.

        Exports the reference tree, alignment and binary model from the
        reference JSON; runs the leave-one-rank-out test (only with
        ``cfg.ranktest``) and the leave-one-sequence-out test; if any
        suspicious sequences were found, re-checks them with the final EPA
        test; then filters, sorts and writes the mislabels and logs the total.
        """
        self.raxml = RaxmlWrapper(self.cfg)

        self.refjson.get_raxml_readable_tree(self.reftree_fname)
        self.refalign_fname = self.refjson.get_alignment(self.tmp_refaln)
        self.refjson.get_binary_model(self.optmod_fname)

        # Log through the instance's configuration: a bare `config` name is not
        # defined in this scope and would raise NameError unless a
        # module-level global of that name happens to exist.
        if self.cfg.ranktest:
            self.cfg.log.info("Running the leave-one-rank-out test...\n")
            self.run_leave_subtree_out_test()

        self.cfg.log.info("Running the leave-one-sequence-out test...\n")
        self.run_leave_seq_out_test()

        if len(self.mislabels) > 0:
            self.cfg.log.info("Leave-one-out test identified %d suspicious sequences; running final EPA test to check them...\n", len(self.mislabels))
            if self.cfg.debug:
                # keep the preliminary list for inspection
                self.write_mislabels(final=False)
            self.run_final_epa_test()

        self.filter_mislabels()
        self.sort_mislabels()
        self.write_mislabels()
        self.cfg.log.info("\nTotal mislabels: %d / %.2f %%", len(self.mislabels), (float(len(self.mislabels)) / self.reftree_size * 100))
Example #42
0
class LeaveOneTest:
    def __init__(self, config):
        self.cfg = config

        self.mis_fname = self.cfg.out_fname("%NAME%.mis")
        self.premis_fname = self.cfg.out_fname("%NAME%.premis")
        self.misrank_fname = self.cfg.out_fname("%NAME%.misrank")
        self.stats_fname = self.cfg.out_fname("%NAME%.stats")

        if os.path.isfile(self.mis_fname):
            print("\nERROR: Output file already exists: %s" % self.mis_fname)
            print(
                "Please specify a different job name using -n or remove old output files."
            )
            self.cfg.exit_user_error()

        self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
        self.reftree_lbl_fname = config.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = config.tmp_fname("%NAME%_tax.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")

        self.mislabels = []
        self.mislabels_cnt = []
        self.rank_mislabels = []
        self.rank_mislabels_cnt = []
        self.misrank_conf_map = {}

    def write_bid_tax_map(self, bid_tax_map, final):
        if self.cfg.debug:
            fname_suffix = "final" if final else "l1out"
            bid_fname = self.cfg.tmp_fname("%NAME%_" +
                                           "bid_tax_map_%s.txt" % fname_suffix)
            with open(bid_fname, "w") as outf:
                for bid, bid_rec in bid_tax_map.items():
                    outf.write("%s\t%s\t%d\t%f\n" %
                               (bid, bid_rec[0], bid_rec[1], bid_rec[2]))

    def write_assignments(self, assign_map, final):
        if self.cfg.debug:
            fname_suffix = "final" if final else "l1out"
            assign_fname = self.cfg.tmp_fname("%NAME%_" + "taxassign_%s.txt" %
                                              fname_suffix)
            with open(assign_fname, "w") as outf:
                for seq_name in assign_map.keys():
                    ranks, lws = assign_map[seq_name]
                    outf.write("%s\t%s\t%s\n" %
                               (seq_name, ";".join(ranks), ";".join(
                                   ["%.3f" % l for l in lws])))

    def load_refjson(self, refjson_fname):
        """Load the reference JSON file and initialise everything derived from it.

        Parses and validates the file, copies its contents (rate, node height,
        original taxonomy, taxonomic tree, branch-to-taxonomy map, reference
        tree) onto the instance, adjusts a few config settings to match the
        reference, and creates the helper objects used by the tests.

        A file that fails to parse (ValueError) or to validate is reported via
        `self.cfg.exit_user_error`.
        """
        try:
            self.refjson = RefJsonParser(refjson_fname)
        except ValueError:
            self.cfg.exit_user_error("ERROR: Invalid json file format!")

        # validate input json format
        (valid, err) = self.refjson.validate()
        if not valid:
            self.cfg.log.error(
                "ERROR: Parsing reference JSON file failed:\n%s", err)
            self.cfg.exit_user_error()

        self.rate = self.refjson.get_rate()
        self.node_height = self.refjson.get_node_height()
        self.origin_taxonomy = self.refjson.get_origin_taxonomy()
        self.tax_tree = self.refjson.get_tax_tree()
        # the config setting is overwritten with the value stored in the file
        self.cfg.compress_patterns = self.refjson.get_pattern_compression()

        self.bid_taxonomy_map = self.refjson.get_branch_tax_map()
        if not self.bid_taxonomy_map:
            # old file format (before 1.6), need to rebuild this map from scratch
            th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
            th.set_mf_rooted_tree(self.tax_tree)
            th.set_bf_unrooted_tree(self.refjson.get_reftree())
            self.bid_taxonomy_map = th.get_bid_taxonomy_map()

        # only writes a file when debug mode is enabled
        self.write_bid_tax_map(self.bid_taxonomy_map, final=False)

        reftree_str = self.refjson.get_raxml_readable_tree()
        self.reftree = Tree(reftree_str)
        # number of leaves in the reference tree
        self.reftree_size = len(self.reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same rate het. mode as in the ref file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        # NOTE: the helpers are created after the config adjustments above
        self.classify_helper = TaxClassifyHelper(self.cfg,
                                                 self.bid_taxonomy_map,
                                                 self.rate, self.node_height)
        self.taxtree_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy,
                                            self.tax_tree)

        tax_code_name = self.refjson.get_taxcode()
        self.tax_code = TaxCode(tax_code_name)

        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX,
                                 tax_map=self.origin_taxonomy)
        self.tax_common_ranks = self.taxonomy.get_common_ranks()

        # per-level mislabel counters, one slot per taxonomic level
        self.mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
        self.rank_mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS

    def run_epa_trainer(self):
        """Run the EPA trainer and abort if no reference JSON file was produced."""
        epa_trainer.run_trainer(self.cfg)

        if os.path.isfile(self.cfg.refjson_fname):
            return

        self.cfg.log.error(
            "\nBuilding reference tree failed, see error messages above.")
        self.cfg.exit_fatal_error()

    def classify_seq(self, placement):
        edges = placement["p"]
        if len(edges) > 0:
            return self.classify_helper.classify_seq(edges)
        else:
            print("ERROR: no placements! something is definitely wrong!")

    def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
        mis_rec = None

        num_common_ranks = len(self.tax_common_ranks)
        orig_rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
        new_rank_level = Taxonomy.lowest_assigned_rank_level(ranks)
        #if new_rank_level < 0 or (new_rank_level < num_common_ranks and orig_rank_level >= num_common_ranks):
        #        if new_rank_level < 0:
        if len(ranks) == 0:
            mis_rec = {}
            mis_rec['name'] = seq_name
            mis_rec['orig_level'] = -1
            mis_rec['real_level'] = 0
            mis_rec['level_name'] = "[NotIngroup]"
            mis_rec['inv_level'] = -1 * mis_rec[
                'real_level']  # just for sorting
            mis_rec['orig_ranks'] = orig_ranks
            mis_rec['ranks'] = []
            mis_rec['lws'] = [1.0]
            mis_rec['conf'] = mis_rec['lws'][0]
        else:
            mislabel_lvl = -1
            min_len = min(len(orig_ranks), len(ranks))
            for rank_lvl in range(min_len):
                if ranks[rank_lvl] != Taxonomy.EMPTY_RANK and ranks[
                        rank_lvl] != orig_ranks[rank_lvl]:
                    mislabel_lvl = rank_lvl
                    break

            if mislabel_lvl >= 0:
                real_lvl = self.tax_code.guess_rank_level(
                    orig_ranks, mislabel_lvl)
                mis_rec = {}
                mis_rec['name'] = seq_name
                mis_rec['orig_level'] = mislabel_lvl
                mis_rec['real_level'] = real_lvl
                mis_rec['level_name'] = self.tax_code.rank_level_name(
                    real_lvl)[0]
                mis_rec['inv_level'] = -1 * mis_rec[
                    'real_level']  # just for sorting
                mis_rec['orig_ranks'] = orig_ranks
                mis_rec['ranks'] = ranks
                mis_rec['lws'] = lws
                mis_rec['conf'] = lws[mislabel_lvl]

        if mis_rec:
            self.mislabels.append(mis_rec)

        return mis_rec

    def filter_mislabels(self):
        filtered_mis = []
        for i in range(len(self.mislabels)):
            if self.mislabels[i]['conf'] >= self.cfg.conf_cutoff:
                filtered_mis.append(self.mislabels[i])

        self.mislabels = filtered_mis

    def check_rank_tax_labels(self, rank_name, orig_ranks, ranks, lws):
        mislabel_lvl = -1
        min_len = min(len(orig_ranks), len(ranks))
        for rank_lvl in range(min_len):
            if ranks[rank_lvl] != Taxonomy.EMPTY_RANK and ranks[
                    rank_lvl] != orig_ranks[rank_lvl]:
                mislabel_lvl = rank_lvl
                break

        if mislabel_lvl >= 0:
            real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
            mis_rec = {}
            mis_rec['name'] = rank_name
            mis_rec['orig_level'] = mislabel_lvl
            mis_rec['real_level'] = real_lvl
            mis_rec['level_name'] = self.tax_code.rank_level_name(real_lvl)[0]
            mis_rec['inv_level'] = -1 * real_lvl  # just for sorting
            mis_rec['orig_ranks'] = orig_ranks
            mis_rec['ranks'] = ranks
            mis_rec['lws'] = lws
            mis_rec['conf'] = lws[mislabel_lvl]
            self.rank_mislabels.append(mis_rec)

            return mis_rec
        else:
            return None

    def mis_rec_to_string_old(self, mis_rec):
        lvl = mis_rec['orig_level']
        output = mis_rec['name'] + "\t"
        output += "%s\t%s\t%s\t%.3f\n" % (
            mis_rec['level_name'], mis_rec['orig_ranks'][lvl],
            mis_rec['ranks'][lvl], mis_rec['lws'][lvl])
        output += ";".join(mis_rec['orig_ranks']) + "\n"
        output += ";".join(mis_rec['ranks']) + "\n"
        output += "\t".join(["%.3f" % conf for conf in mis_rec['lws']]) + "\n"
        return output

    def mis_rec_to_string(self, mis_rec):
        """Render a mislabel record as one tab-separated report line.

        Name and ranks are first translated back through the reference JSON
        (`get_uncorr_seqid` / `get_uncorr_ranks`). For records without a
        mislabeled level (negative 'orig_level') the two label columns are
        "NA" and the first 'lws' value is used as confidence. An extra
        column is appended when the record carries 'rank_conf'. The returned
        string has no trailing newline.
        """
        lvl = mis_rec['orig_level']
        seq_id = EpacConfig.strip_ref_prefix(
            self.refjson.get_uncorr_seqid(mis_rec['name']))
        old_ranks = self.refjson.get_uncorr_ranks(mis_rec['orig_ranks'])
        new_ranks = self.refjson.get_uncorr_ranks(mis_rec['ranks'])
        weights = mis_rec['lws']

        if lvl < 0:
            old_label, new_label, conf = "NA", "NA", weights[0]
        else:
            old_label, new_label, conf = (old_ranks[lvl], new_ranks[lvl],
                                          weights[lvl])

        summary = "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'], old_label,
                                          new_label, conf)
        line = seq_id + "\t" + summary
        line += Taxonomy.lineage_str(old_ranks) + "\t"
        line += Taxonomy.lineage_str(new_ranks) + "\t"
        line += ";".join(["%.3f" % w for w in weights])

        if 'rank_conf' not in mis_rec:
            return line
        return line + "\t%.3f" % mis_rec['rank_conf']

    def sort_mislabels(self):
        self.mislabels = sorted(self.mislabels,
                                key=itemgetter('inv_level', 'conf', 'name'),
                                reverse=True)
        for mis_rec in self.mislabels:
            real_lvl = mis_rec["real_level"]
            self.mislabels_cnt[real_lvl] += 1

        if self.cfg.ranktest:
            self.rank_mislabels = sorted(self.rank_mislabels,
                                         key=itemgetter(
                                             'inv_level', 'conf', 'name'),
                                         reverse=True)
            for mis_rec in self.rank_mislabels:
                real_lvl = mis_rec["real_level"]
                self.rank_mislabels_cnt[real_lvl] += 1

    def write_stats(self, toFile=False):
        self.cfg.log.info("Mislabeled sequences by rank:")
        seq_sum = 0
        rank_sum = 0
        stats = []
        for i in range(len(self.mislabels_cnt)):
            if i > 0:
                rname = self.tax_code.rank_level_name(i)[0].ljust(12)
            else:
                rname = "[NotIngroup]"
            if self.mislabels_cnt[i] > 0:
                seq_sum += self.mislabels_cnt[i]
                #                    output = "%s:\t%d" % (rname, seq_sum)
                output = "%s:\t%d" % (rname, self.mislabels_cnt[i])
                if self.cfg.ranktest:
                    rank_sum += self.rank_mislabels_cnt[i]
                    output += "\t%d" % rank_sum
                self.cfg.log.info(output)
                stats.append(output)

        if toFile:
            with open(self.stats_fname, "w") as fo_stat:
                for line in stats:
                    fo_stat.write(line + "\n")

    def write_mislabels_header(self, fo, final, fields):
        header = ";" + "\t".join(fields) + "\n"

        # write to file
        if final:
            for line in DISCLAIMER.split("\n"):
                fo.write(";%s\n" % line)
            fo.write(";\n")
        fo.write(header)

        # print to console
        if final and self.cfg.verbose and len(self.rank_mislabels) > 0:
            print(DISCLAIMER, "\n")
            print("Mislabeled sequences:\n")
            print(header)

    def write_rank_mislabels(self, final=True):
        if not self.cfg.ranktest:
            return

        with open(self.misrank_fname, "w") as fo_all:
            fields = [
                "RankID", "MislabeledLevel", "OriginalLabel", "ProposedLabel",
                "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath",
                "PerRankConfidence"
            ]
            self.write_mislabels_header(fo_all, final, fields)
            for mis_rec in self.rank_mislabels:
                output = self.mis_rec_to_string(mis_rec) + "\n"
                fo_all.write(output)
                if self.cfg.verbose:
                    print(output)

    def write_mislabels(self, final=True):
        if final:
            out_fname = self.mis_fname
        else:
            out_fname = self.premis_fname

        with open(out_fname, "w") as fo_all:
            fields = [
                "SeqID", "MislabeledLevel", "OriginalLabel", "ProposedLabel",
                "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath",
                "PerRankConfidence"
            ]
            if self.cfg.ranktest:
                fields += ["HigherRankMisplacedConfidence"]
            self.write_mislabels_header(fo_all, final, fields)
            for mis_rec in self.mislabels:
                output = self.mis_rec_to_string(mis_rec) + "\n"
                fo_all.write(output)
                if self.cfg.verbose and final:
                    print(output)

        if final:
            self.write_rank_mislabels()
            self.write_stats()

    def get_parent_tip_ranks(self, tax_tree):
        rank_tips = {}
        rank_parent = {}
        for node in tax_tree.traverse("postorder"):
            if node.is_leaf() or node.is_root():
                continue
            tax_path = node.name
            ranks = Taxonomy.split_rank_uid(tax_path)
            rank_lvl = Taxonomy.lowest_assigned_rank_level(ranks)
            if rank_lvl < 2:
                continue

            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
            parent_lvl = Taxonomy.lowest_assigned_rank_level(parent_ranks)
            if parent_lvl < 1:
                continue

            rank_seqs = node.get_leaf_names()
            rank_size = len(rank_seqs)
            if rank_size < 2 or rank_size > self.reftree_size - 4:
                continue


#            print rank_lvl, "\t", tax_path, "\t", rank_seqs, "\n"
            rank_tips[tax_path] = node.get_leaf_names()
            rank_parent[tax_path] = parent_ranks

        return rank_parent, rank_tips

    def run_leave_subtree_out_test(self):
        job_name = self.cfg.subst_name("l1out_rank_%NAME%")
        #        if self.jplace_fname:
        #            jp = EpaJsonParser(self.jplace_fname)
        #        else:

        #create file with subtrees
        rank_parent, rank_tips = self.get_parent_tip_ranks(self.tax_tree)

        subtree_list = list(rank_tips.items())
        if len(subtree_list) == 0:
            return 0

        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
        with open(subtree_list_file, "w") as fout:
            for rank_name, tips in subtree_list:
                fout.write("%s\n" % " ".join(tips))

        jp_list = self.raxml.run_epa(job_name,
                                     self.refalign_fname,
                                     self.reftree_fname,
                                     self.optmod_fname,
                                     mode="l1o_subtree",
                                     subtree_fname=subtree_list_file)

        subtree_count = 0
        for jp in jp_list:
            placements = jp.get_placement()
            for place in placements:
                ranks, lws = self.classify_seq(place)
                tax_path = subtree_list[subtree_count][0]
                orig_ranks = Taxonomy.split_rank_uid(tax_path)
                rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
                rank_prefix = self.tax_code.guess_rank_level_name(
                    orig_ranks, rank_level)[0]
                rank_name = orig_ranks[rank_level]
                if not rank_name.startswith(rank_prefix):
                    rank_name = rank_prefix + rank_name
                parent_ranks = rank_parent[tax_path]
                #                print orig_ranks, "\n", parent_ranks, "\n", ranks, "\n"
                mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks,
                                                     ranks, lws)
                if mis_rec:
                    self.misrank_conf_map[tax_path] = mis_rec['conf']
                subtree_count += 1

        return subtree_count

    def run_leave_seq_out_test(self):
        job_name = self.cfg.subst_name("l1out_seq_%NAME%")
        placements = []
        if self.cfg.jplace_fname:
            if os.path.isdir(self.cfg.jplace_fname):
                jplace_fmask = os.path.join(self.cfg.jplace_fname, '*.jplace')
            else:
                jplace_fmask = self.cfg.jplace_fname

            jplace_fname_list = glob.glob(jplace_fmask)
            for jplace_fname in jplace_fname_list:
                jp = EpaJsonParser(jplace_fname)
                placements += jp.get_placement()

            config.log.debug("Loaded %d placements from %s\n", len(placements),
                             jplace_fmask)
        else:
            jp = self.raxml.run_epa(job_name,
                                    self.refalign_fname,
                                    self.reftree_fname,
                                    self.optmod_fname,
                                    mode="l1o_seq")
            placements = jp.get_placement()
            if self.cfg.output_interim_files:
                out_jplace_fname = self.cfg.out_fname(
                    "%NAME%.l1out_seq.jplace")
                self.raxml.copy_epa_jplace(job_name,
                                           out_jplace_fname,
                                           move=True,
                                           mode="l1o_seq")

        seq_count = 0
        l1out_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            #            orig_ranks = self.get_orig_ranks(seq_name)
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # get EPA tax label
            ranks, lws = self.classify_seq(place)
            l1out_ass[seq_name] = (ranks, lws)

            # check if they match
            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks,
                                                lws)
            # cross-check with higher rank mislabels
            if self.cfg.ranktest and mis_rec:
                rank_conf = 0
                for lvl in range(2, len(orig_ranks)):
                    tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
                    if tax_path in self.misrank_conf_map:
                        rank_conf = max(rank_conf,
                                        self.misrank_conf_map[tax_path])
                mis_rec['rank_conf'] = rank_conf
            seq_count += 1

        self.write_assignments(l1out_ass, final=False)

        return seq_count

    def prune_mislabels_from_tree(self, src_tree, tree_name):
        pruned_tree = src_tree.copy(method="newick")
        name2node = {}
        for leaf in pruned_tree.iter_leaves():
            name2node[leaf.name] = leaf

        for mis_rec in self.mislabels:
            rname = mis_rec['name']
            #            rname = EpacConfig.REF_SEQ_PREFIX + name

            if rname in name2node:
                name2node[rname].delete()
            else:
                config.log.debug("Node not found in the %s tree: %s" %
                                 (tree_name, rname))

        return pruned_tree

    def run_final_epa_test(self):
        """Re-check the suspicious sequences against trees pruned of them.

        All sequences currently listed in `self.mislabels` are removed from
        copies of the reference tree and of the taxonomic tree. Placements on
        the pruned reference tree are then either loaded from existing jplace
        file(s) (`cfg.final_jplace_fname`: a file name / glob mask, or a
        directory whose `*.jplace` files are used) or computed with one EPA
        run. `self.mislabels` is reset and rebuilt from these placements.
        """
        self.reftree_outgroup = self.refjson.get_outgroup()

        # Each pruned tree must come from its own source: the reference tree
        # is what EPA places on, the taxonomic tree is what the branch-id ->
        # taxonomy mapping below is derived from.
        pruned_reftree = self.prune_mislabels_from_tree(
            self.reftree, "reference")
        pruned_taxtree = self.prune_mislabels_from_tree(
            self.tax_tree, "taxonomic")

        # remove unifurcation at the root
        if len(pruned_reftree.children) == 1:
            pruned_reftree = pruned_reftree.children[0]

        # the final test rebuilds the mislabel list from scratch
        self.mislabels = []

        th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
        th.set_mf_rooted_tree(pruned_taxtree)

        reftree_epalbl_str = None
        if self.cfg.final_jplace_fname:
            if os.path.isdir(self.cfg.final_jplace_fname):
                jplace_fmask = os.path.join(self.cfg.final_jplace_fname,
                                            '*.jplace')
            else:
                jplace_fmask = self.cfg.final_jplace_fname

            placements = []
            for jplace_fname in glob.glob(jplace_fmask):
                jp = EpaJsonParser(jplace_fname)
                placements += jp.get_placement()
                # the labelled tree is taken from the first file that has one
                if not reftree_epalbl_str:
                    reftree_epalbl_str = jp.get_std_newick_tree()

            # log via the instance config; no module-level `config` is needed
            self.cfg.log.debug("Loaded %d final epa placements from %s\n",
                               len(placements), jplace_fmask)
        else:
            epa_result = self.run_epa_once(pruned_reftree)
            reftree_epalbl_str = epa_result.get_std_newick_tree()
            placements = epa_result.get_placement()

        # update branchid-taxonomy mapping to account for possible changes in branch numbering
        reftree_tax = Tree(reftree_epalbl_str)
        th.set_bf_unrooted_tree(reftree_tax)
        bid_tax_map = th.get_bid_taxonomy_map()

        self.write_bid_tax_map(bid_tax_map, final=True)

        cl = TaxClassifyHelper(self.cfg, bid_tax_map, self.rate,
                               self.node_height)

        final_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # EXPERIMENTAL FEATURE - disabled for now!
            # It could happen that certain ranks were present in the "original" reference tree, but
            # are completely missing in the pruned tree (e.g., all seqs of a species were considered "suspicious"
            # after the leave-one-out test and thus pruned)
            # In this case, EPA has no chance to infer full original taxonomic annotation (=species) since the corresponding clade
            # is now missing. To account for this fact, we amend the original taxonomic annotation and set ranks missing from
            # pruned tree to "Undefined".
            #            orig_ranks = th.strip_missing_ranks(orig_ranks)

            # get EPA tax label
            ranks, lws = cl.classify_seq(place["p"])
            final_ass[seq_name] = (ranks, lws)

            # check if they match; conflicts are appended to self.mislabels
            self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

        self.write_assignments(final_ass, final=True)

    def run_epa_once(self, reftree):
        reftree_fname = self.cfg.tmp_fname("final_ref_%NAME%.tre")
        job_name = self.cfg.subst_name("final_epa_%NAME%")

        reftree.write(outfile=reftree_fname)

        # IMPORTANT: don't load the model, since it's invalid for the pruned true !!!
        optmod_fname = ""
        epa_result = self.raxml.run_epa(job_name, self.refalign_fname,
                                        reftree_fname, optmod_fname)

        if self.cfg.output_interim_files:
            out_jplace_fname = self.cfg.out_fname("%NAME%.final_epa.jplace")
            self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True)

        return epa_result

    def run_test(self):
        """Run the whole mislabel-detection pipeline and write the results.

        Steps, in order: create the RAxML wrapper; extract the reference
        tree, alignment and binary model from the reference JSON; optionally
        run the leave-one-rank-out test; run the leave-one-sequence-out test;
        if that produced any candidates, re-check them with a final EPA run;
        finally filter, sort and write the mislabels and log the total.
        """
        # Log through the config stored on the instance instead of relying on
        # a module-level `config` name being defined.
        log = self.cfg.log
        self.raxml = RaxmlWrapper(self.cfg)

        self.refjson.get_raxml_readable_tree(self.reftree_fname)
        self.refalign_fname = self.refjson.get_alignment(self.tmp_refaln)
        self.refjson.get_binary_model(self.optmod_fname)

        if self.cfg.ranktest:
            log.info("Running the leave-one-rank-out test...\n")
            self.run_leave_subtree_out_test()

        log.info("Running the leave-one-sequence-out test...\n")
        self.run_leave_seq_out_test()

        if len(self.mislabels) > 0:
            log.info(
                "Leave-one-out test identified %d suspicious sequences; running final EPA test to check them...\n",
                len(self.mislabels))
            if self.cfg.debug:
                # keep the preliminary list around for debugging
                self.write_mislabels(final=False)
            self.run_final_epa_test()

        self.filter_mislabels()
        self.sort_mislabels()
        self.write_mislabels()
        log.info("\nTotal mislabels: %d / %.2f %%", len(self.mislabels),
                 (float(len(self.mislabels)) / self.reftree_size * 100))
Example #43
0
class RefTreeBuilder:
    def __init__(self, config):
        """Keep the configuration, create the RAxML wrapper and derive all
        job names and temporary file names used while building the tree."""
        self.cfg = config

        # names of the RAxML jobs
        job_name = self.cfg.subst_name
        self.mfresolv_job_name = job_name("mfresolv_%NAME%")
        self.epalbl_job_name = job_name("epalbl_%NAME%")
        self.optmod_job_name = job_name("optmod_%NAME%")

        self.raxml_wrapper = RaxmlWrapper(config)

        # names of the temporary files
        tmp_name = self.cfg.tmp_fname
        self.outgr_fname = tmp_name("%NAME%_outgr.tre")
        self.reftree_mfu_fname = tmp_name("%NAME%_mfu.tre")
        self.reftree_bfu_fname = tmp_name("%NAME%_bfu.tre")
        self.optmod_fname = tmp_name("%NAME%.opt")
        self.lblalign_fname = tmp_name("%NAME%_lblq.fa")
        self.reftree_lbl_fname = tmp_name("%NAME%_lbl.tre")
        self.reftree_tax_fname = tmp_name("%NAME%_tax.tre")
        self.brmap_fname = tmp_name("%NAME%_map.txt")

    def load_alignment(self):
        """Load the input alignment into `self.input_seqs`, guessing its format.

        The candidate formats are tried in order and the first one that
        `SeqGroup` accepts wins.  If none of them works, the problem is
        reported through `cfg.exit_user_error`.
        """
        in_file = self.cfg.align_fname
        self.input_seqs = None
        formats = ["fasta", "phylip_relaxed", "iphylip_relaxed", "phylip", "iphylip"]
        for fmt in formats:
            try:
                self.input_seqs = SeqGroup(sequences=in_file, format=fmt)
                break
            except Exception:
                # A parse failure just means "wrong guess, try the next format".
                # Catch Exception rather than everything so that Ctrl-C and
                # SystemExit are not swallowed while guessing.
                self.cfg.log.debug("Guessing input format: not " + fmt)
        if self.input_seqs is None:
            self.cfg.exit_user_error("Invalid input file format: %s\nThe supported input formats are fasta and phylip" % in_file)
            
    def validate_taxonomy(self):
        """Create an InputValidator for the loaded taxonomy and alignment,
        keep it in `self.input_validator` and run its validation."""
        validator = InputValidator(self.cfg, self.taxonomy, self.input_seqs)
        self.input_validator = validator
        validator.validate()
        
    def build_multif_tree(self):
        """Build the taxonomy-based tree and remember it together with the
        ids of the sequences it contains.

        Sets `self.reftree_ids`, `self.reftree_size` and `self.reftree_multif`,
        then lets the configuration resolve its automatic settings for this
        tree size.  In debug mode the sequence ids and the tree are also
        written to temporary files.
        """
        cfg = self.cfg

        builder = TaxTreeBuilder(cfg, self.taxonomy)
        tax_tree, seq_ids = builder.build(cfg.reftree_min_rank,
                                          cfg.reftree_max_seqs_per_leaf,
                                          cfg.reftree_clades_to_include,
                                          cfg.reftree_clades_to_ignore)
        self.reftree_ids = frozenset(seq_ids)
        self.reftree_size = len(seq_ids)
        self.reftree_multif = tax_tree

        # IMPORTANT: select GAMMA or CAT model based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)

        if not self.cfg.debug:
            return

        # list of sequence ids which comprise the reference tree
        ids_fname = self.cfg.tmp_fname("%NAME%_seq_ids.txt")
        with open(ids_fname, "w") as ids_file:
            for seq_id in seq_ids:
                ids_file.write("%s\n" % seq_id)

        # original tree with taxonomic ranks as internal node labels
        tax_tree_fname = self.cfg.tmp_fname("%NAME%_mfu_tax.tre")
        tax_tree.write(outfile=tax_tree_fname, format=8)

    def export_ref_alignment(self):
        """Write the reference alignment derived from the input alignment:
           1. Sequence names get the reference prefix (EpacConfig.REF_SEQ_PREFIX)
           2. Names corrected by the input validator are replaced accordingly
           3. Sequences which are not part of the reference tree are dropped"""

        self.refalign_fname = self.cfg.tmp_fname("%NAME%_matrix.afa")
        with open(self.refalign_fname, "w") as fout:
            for name, seq, comment, sid in self.input_seqs.iter_entries():
                ref_name = EpacConfig.REF_SEQ_PREFIX + name
                if ref_name in self.input_validator.corr_seqid:
                    ref_name = self.input_validator.corr_seqid[ref_name]
                if ref_name not in self.reftree_ids:
                    continue
                record = ">" + ref_name + "\n" + seq + "\n"
                fout.write(record)

        # the original alignment is not needed anymore, so release it
        self.input_seqs = None

    def export_ref_taxonomy(self):
        self.taxonomy_map = {}
        
        for sid, ranks in self.taxonomy.iteritems():
            if sid in self.reftree_ids:
                self.taxonomy_map[sid] = ranks
            
        if self.cfg.debug:
            tax_fname = self.cfg.tmp_fname("%NAME%_tax.txt")
            with open(tax_fname, "w") as fout:
                for sid, ranks in self.taxonomy_map.iteritems():
                    ranks_str = self.taxonomy.seq_lineage_str(sid) 
                    fout.write(sid + "\t" + ranks_str + "\n")   

    def save_rooting(self):
        """Save the outgroup of the rooted multifurcating tree for later
        re-rooting, then write the tree itself for RAxML.

        Sets `self.taxtree_helper` and `self.reftree_outgroup`; writes the
        outgroup to `self.outgr_fname` and the tree to `self.reftree_mfu_fname`.
        """
        tree = self.reftree_multif

        tax_map = self.taxonomy.get_map()
        helper = TaxTreeHelper(self.cfg, tax_map)
        self.taxtree_helper = helper
        helper.set_mf_rooted_tree(tree)

        outgroup = helper.get_outgroup()
        outgroup_size = len(outgroup.get_leaves())
        outgroup.write(outfile=self.outgr_fname, format=9)
        self.reftree_outgroup = outgroup
        self.cfg.log.debug("Outgroup for rooting was saved to: %s, outgroup size: %d", self.outgr_fname, outgroup_size)

        # remove unifurcation at the root
        top = tree.children[0] if len(tree.children) == 1 else tree

        # now the tree can be written without internal node labels, in a form suitable for raxml
        top.write(outfile=self.reftree_mfu_fname, format=9)

    # RAxML call to convert multifurcating tree to the strictly bifurcating one
    def resolve_multif(self):
        """Resolve the multifurcating reference tree into a bifurcating one with RAxML.

        Steps, as performed below:
          1. Reduce the reference alignment via the RAxML wrapper.
          2. Run the constrained ML inference (tree passed with "-g"), unless
             cfg.restart is set and a result of that job already exists.
          3. Depending on cfg.reopt_model, either take the best tree and the
             model parameters directly from that job, or run a separate job
             which optimizes the model parameters first.
          4. Store the log-likelihood reported for the job in self.reftree_loglh.

        Sets self.reduced_refalign_fname, self.invocation_raxml_multif,
        self.invocation_raxml_optmod and self.reftree_loglh.  The tree and the
        model parameters are copied to self.reftree_bfu_fname and
        self.optmod_fname.  Missing RAxML results are reported through
        cfg.exit_fatal_error.
        """
        self.cfg.log.debug("\nReducing the alignment: \n")
        self.reduced_refalign_fname = self.raxml_wrapper.reduce_alignment(self.refalign_fname)
        
        self.cfg.log.debug("\nConstrained ML inference: \n")
        raxml_params = ["-s", self.reduced_refalign_fname, "-g", self.reftree_mfu_fname, "--no-seq-check", "-N", str(self.cfg.rep_num)] 
        # extra RAxML switches for the faster resolution methods
        if self.cfg.mfresolv_method  == "fast":
            raxml_params += ["-D"]
        elif self.cfg.mfresolv_method  == "ultrafast":
            raxml_params += ["-f", "e"]
        if self.cfg.restart and self.raxml_wrapper.result_exists(self.mfresolv_job_name):
            # restart mode: reuse the result of a previous run instead of calling RAxML again
            self.invocation_raxml_multif = self.raxml_wrapper.get_invocation_str(self.mfresolv_job_name)
            self.cfg.log.debug("\nUsing existing ML tree found in: %s\n", self.raxml_wrapper.result_fname(self.mfresolv_job_name))
        else:
            self.invocation_raxml_multif = self.raxml_wrapper.run(self.mfresolv_job_name, raxml_params)
#            self.invocation_raxml_multif = self.raxml_wrapper.run_multiple(self.mfresolv_job_name, raxml_params, self.cfg.rep_num)
            if self.cfg.mfresolv_method  == "ultrafast":
              # in this mode the result tree is copied to the best-tree file name,
              # so that the besttree_exists() check below can find it
              self.raxml_wrapper.copy_result_tree(self.mfresolv_job_name, self.raxml_wrapper.besttree_fname(self.mfresolv_job_name))
              
        if self.raxml_wrapper.besttree_exists(self.mfresolv_job_name):        
            if not self.cfg.reopt_model:
                # no re-optimization: use tree and model parameters of the resolution job as they are
                self.raxml_wrapper.copy_best_tree(self.mfresolv_job_name, self.reftree_bfu_fname)
                self.raxml_wrapper.copy_optmod_params(self.mfresolv_job_name, self.optmod_fname)
                self.invocation_raxml_optmod = ""
                job_name = self.mfresolv_job_name
            else:
                bfu_fname = self.raxml_wrapper.besttree_fname(self.mfresolv_job_name)
                job_name = self.optmod_job_name

                # RAxML call to optimize model parameters and write them down to the binary model file
                self.cfg.log.debug("\nOptimizing model parameters: \n")
                raxml_params = ["-f", "e", "-s", self.reduced_refalign_fname, "-t", bfu_fname, "--no-seq-check"]
                if self.cfg.raxml_model.startswith("GTRCAT") and not self.cfg.compress_patterns:
                    raxml_params +=  ["-H"]
                if self.cfg.restart and self.raxml_wrapper.result_exists(self.optmod_job_name):
                    # restart mode: reuse the result of a previous optimization run
                    self.invocation_raxml_optmod = self.raxml_wrapper.get_invocation_str(self.optmod_job_name)
                    self.cfg.log.debug("\nUsing existing optimized tree and parameters found in: %s\n", self.raxml_wrapper.result_fname(self.optmod_job_name))
                else:
                    self.invocation_raxml_optmod = self.raxml_wrapper.run(self.optmod_job_name, raxml_params)
                if self.raxml_wrapper.result_exists(self.optmod_job_name):
                    self.raxml_wrapper.copy_result_tree(self.optmod_job_name, self.reftree_bfu_fname)
                    self.raxml_wrapper.copy_optmod_params(self.optmod_job_name, self.optmod_fname)
                else:
                    errmsg = "RAxML run failed (model optimization), please examine the log for details: %s" \
                            % self.raxml_wrapper.make_raxml_fname("output", self.optmod_job_name)
                    self.cfg.exit_fatal_error(errmsg)
                    
            # job_name is the job whose tree was copied above; its likelihood is read
            # under the model family (CAT or GAMMA) matching cfg.raxml_model
            if self.cfg.raxml_model.startswith("GTRCAT"):
              mod_name = "CAT"
            else:
              mod_name = "GAMMA" 
            self.reftree_loglh = self.raxml_wrapper.get_tree_lh(job_name, mod_name)
            self.cfg.log.debug("\n%s-based logLH of the reference tree: %f\n" % (mod_name, self.reftree_loglh))

        else:
            # no best tree was produced by the multifurcation resolution job
            errmsg = "RAxML run failed (mutlifurcation resolution), please examine the log for details: %s" \
                    % self.raxml_wrapper.make_raxml_fname("output", self.mfresolv_job_name)
            self.cfg.exit_fatal_error(errmsg)
            
    def load_reduced_refalign(self):
        formats = ["fasta", "phylip_relaxed"]
        for fmt in formats:
            try:
                self.reduced_refalign_seqs = SeqGroup(sequences=self.reduced_refalign_fname, format = fmt)
                break
            except:
                pass
        if self.reduced_refalign_seqs == None:
            errmsg = "FATAL ERROR: Invalid input file format in %s! (load_reduced_refalign)" % self.reduced_refalign_fname
            self.cfg.exit_fatal_error(errmsg)
    
    # dummy EPA run to label the branches of the reference tree, which we need to build a mapping to tax ranks    
    def epa_branch_labeling(self):
        """Run EPA with a single dummy query sequence to get a branch-labelled
        reference tree.

        Sets `self.refalign_width`, `self.reftree_lbl_str`, `self.raxml_version`
        and `self.invocation_raxml_epalbl`.  A missing EPA result is reported
        through `cfg.exit_fatal_error`.
        """
        # write the reduced alignment and append a dummy query seq of the same width
        self.refalign_width = len(self.reduced_refalign_seqs.get_seqbyid(0))
        self.reduced_refalign_seqs.write(format="fasta", outfile=self.lblalign_fname)

        dummy_header = ">DUMMY131313\n"
        dummy_seq = "A" * self.refalign_width + "\n"
        with open(self.lblalign_fname, "a") as fout:
            fout.write(dummy_header)
            fout.write(dummy_seq)

        # TODO always load model regardless of the config file settings?
        labelled = self.raxml_wrapper.run_epa(self.epalbl_job_name,
                                              self.lblalign_fname,
                                              self.reftree_bfu_fname,
                                              self.optmod_fname,
                                              mode="epa_mp")
        self.reftree_lbl_str = labelled.get_std_newick_tree()
        self.raxml_version = labelled.get_raxml_version()
        self.invocation_raxml_epalbl = labelled.get_raxml_invocation()

        # NOTE(review): the result object is used before this existence check;
        # confirm that run_epa cannot hand back an unusable result on failure.
        if self.raxml_wrapper.epa_result_exists(self.epalbl_job_name):
            return

        log_fname = self.raxml_wrapper.make_raxml_fname("output", self.epalbl_job_name)
        self.cfg.exit_fatal_error("RAxML EPA run failed, please examine the log for details: %s" % log_fname)

    def epa_post_process(self):
        """Feed the branch-labelled EPA tree to the taxonomy tree helper and
        fetch the taxonomic tree and the branch-to-taxonomy map from it.

        Sets `self.reftree_tax` and `self.bid_ranks_map`; in debug mode the
        taxonomic tree, the labelled tree and the branch map are also written
        to temporary files.
        """
        labelled_tree = Tree(self.reftree_lbl_str)
        self.taxtree_helper.set_bf_unrooted_tree(labelled_tree)
        self.reftree_tax = self.taxtree_helper.get_tax_tree()
        self.bid_ranks_map = self.taxtree_helper.get_bid_taxonomy_map()

        if not self.cfg.debug:
            return

        self.reftree_tax.write(outfile=self.reftree_tax_fname, format=3)
        with open(self.reftree_lbl_fname, "w") as tree_file:
            tree_file.write(self.reftree_lbl_str)
        with open(self.brmap_fname, "w") as map_file:
            # one line per branch id with the first three fields of its record
            for branch_id, record in self.bid_ranks_map.iteritems():
                fields = (branch_id, record[0], record[1], record[2])
                map_file.write("%s\t%s\t%d\t%f\n" % fields)

    def calc_node_heights(self):
        """Calculate node heights on the reference tree (used to define branch-length cutoff during classification step)
           Algorithm is as follows:
           Tip node or node resolved to Species level: height = 1
           Inner node resolved to Genus or above:      height = min(left_height, right_height) + 1

           The result, keyed by branch label, is stored in self.node_height_map.
         """
        dummy_label = "DDD"
        heights = {}
        dummy_added = False
        for node in self.reftree_tax.traverse("postorder"):
            if node.is_root():
                continue

            if hasattr(node, "B"):
                # NOTE(review): the last element of the branch record is taken
                # as the species-level rank; confirm the record layout.
                species_rank = self.bid_ranks_map[node.B][-1]
            elif dummy_added:
                self.cfg.exit_fatal_error("FATAL ERROR: More than one tree branch without EPA label (calc_node_heights)")
            else:
                # In a rooted tree, there is always one more node/branch than in unrooted one
                # That's why one branch will be always not EPA-labelled after the rooting
                node.B = dummy_label
                dummy_added = True
                species_rank = Taxonomy.EMPTY_RANK

            branch_id = node.B
            if node.is_leaf() or species_rank != Taxonomy.EMPTY_RANK:
                heights[branch_id] = 1
            else:
                left_child = node.children[0]
                right_child = node.children[1]
                heights[branch_id] = min(heights[left_child.B], heights[right_child.B]) + 1

        # remove heights for dummy nodes, since there won't be any placements on them
        if dummy_added:
            del heights[dummy_label]

        self.node_height_map = heights

    def __get_all_rank_names(self, root):
        rnames = set([])
        for node in root.traverse("postorder"):
            ranks = node.ranks
            for rk in ranks:
                rnames.add(rk)
        return rnames

    def mono_index(self):
        """This method will calculate monophyly index by looking at the left and right hand side of the tree"""
        children = self.reftree_tax.children
        if len(children) == 1:
            while len(children) == 1:
                children = children[0].children 
        if len(children) == 2:
            left = children[0]
            right =children[1]
            lset = self.__get_all_rank_names(left)
            rset = self.__get_all_rank_names(right)
            iset = lset & rset
            return iset
        else:
            print("Error: input tree not birfurcating")
            return set([])

    def build_hmm_profile(self, json_builder):
        """Build a HMMER profile from the reduced reference alignment and
        pass it to `json_builder.set_hmm_profile`."""
        # Parenthesised single-argument form: prints exactly the same text
        # under Python 2, and unlike the print statement it is also valid
        # Python 3 syntax.
        print("Building the HMMER profile...\n")

        # This workaround is needed because RAxML outputs the reduced
        # alignment in relaxed PHYLIP format, which is not supported by HMMER:
        # the alignment is written out again before it is handed to HMMER.
        refalign_fasta = self.cfg.tmp_fname("%NAME%_ref_reduced.fa")
        self.reduced_refalign_seqs.write(outfile=refalign_fasta)

        hmm = hmmer(self.cfg, refalign_fasta)
        fprofile = hmm.build_hmm_profile()

        json_builder.set_hmm_profile(fprofile)
        
    def write_json(self):
        """Collect everything computed while building the reference tree in a
        RefJsonBuilder and dump it to `cfg.refjson_fname`.

        Reads the attributes set by the earlier build steps (trees, branch
        map, outgroup, RAxML invocations, node heights etc.); unless
        `cfg.no_hmmer` is set, a HMMER profile is built and added as well.
        """
        builder = RefJsonBuilder()

        builder.set_branch_tax_map(self.bid_ranks_map)
        builder.set_tree(self.reftree_lbl_str)
        builder.set_outgroup(self.reftree_outgroup)
        builder.set_ratehet_model(self.cfg.raxml_model)
        builder.set_tax_tree(self.reftree_multif)
        builder.set_pattern_compression(self.cfg.compress_patterns)
        builder.set_taxcode(self.cfg.taxcode_name)

        validator = self.input_validator
        builder.set_merged_ranks_map(validator.merged_ranks)
        # the correction maps are stored inverted (value -> key)
        inverted_ranks = dict((new, old) for old, new in validator.corr_ranks.items())
        builder.set_corr_ranks_map(inverted_ranks)
        inverted_seqids = dict((new, old) for old, new in validator.corr_seqid.items())
        builder.set_corr_seqid_map(inverted_seqids)

        metadata = {}
        metadata["ref_tree_size"] = self.reftree_size
        metadata["ref_alignment_width"] = self.refalign_width
        metadata["raxml_version"] = self.raxml_version
        metadata["timestamp"] = str(datetime.datetime.now())
        metadata["invocation_epac"] = self.invocation_epac
        metadata["invocation_raxml_multif"] = self.invocation_raxml_multif
        metadata["invocation_raxml_optmod"] = self.invocation_raxml_optmod
        metadata["invocation_raxml_epalbl"] = self.invocation_raxml_epalbl
        metadata["reftree_loglh"] = self.reftree_loglh
        builder.set_metadata(metadata)

        builder.set_sequences(self.reduced_refalign_seqs.get_entries())

        if not self.cfg.no_hmmer:
            self.build_hmm_profile(builder)

        origin_taxonomy = self.taxonomy_map
        builder.set_origin_taxonomy(origin_taxonomy)

        self.cfg.log.debug("Calculating the speciation rate...\n")
        params = tree_param(tree = self.reftree_lbl_str, origin_taxonomy = origin_taxonomy)
        builder.set_rate(params.get_speciation_rate_fast())
        builder.set_nodes_height(self.node_height_map)

        builder.set_binary_model(self.optmod_fname)

        self.cfg.log.debug("Writing down the reference file...\n")
        builder.dump(self.cfg.refjson_fname)

    # top-level function to build a reference tree    
    def build_ref_tree(self):
        """Run the whole reference-building pipeline, from the input files to the reference JSON.

        The steps below depend on attributes set by the preceding ones, so
        their order matters: load taxonomy and alignment, validate them, build
        the multifurcating taxonomy tree, export the reference alignment and
        taxonomy, save the outgroup, resolve multifurcations with RAxML, label
        the branches with EPA, post-process the labelled tree, calculate node
        heights and finally write the reference JSON file.
        """
        self.cfg.log.info("=> Loading taxonomy from file: %s ...\n" , self.cfg.taxonomy_fname)
        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_fname=self.cfg.taxonomy_fname)
        self.cfg.log.info("==> Loading reference alignment from file: %s ...\n" , self.cfg.align_fname)
        self.load_alignment()
        self.cfg.log.info("===> Validating taxonomy and alignment ...\n")
        self.validate_taxonomy()
        self.cfg.log.info("====> Building a multifurcating tree from taxonomy with %d seqs ...\n" , self.taxonomy.seq_count())
        self.build_multif_tree()
        self.cfg.log.info("=====> Building the reference alignment ...\n")
        self.export_ref_alignment()
        self.export_ref_taxonomy()
        self.cfg.log.info("======> Saving the outgroup for later re-rooting ...\n")
        self.save_rooting()
        self.cfg.log.info("=======> Resolving multifurcation: choosing the best topology from %d independent RAxML runs ...\n" % self.cfg.rep_num)
        self.resolve_multif()
        # the reduced alignment file is produced by resolve_multif(), so it can only be loaded now
        self.load_reduced_refalign()
        self.cfg.log.info("========> Calling RAxML-EPA to obtain branch labels ...\n")
        self.epa_branch_labeling()
        self.cfg.log.info("=========> Post-processing the EPA tree (re-rooting, taxonomic labeling etc.) ...\n")
        self.epa_post_process()
        self.calc_node_heights()
        
        # debug-only comparison of the rank names shared according to the taxonomy
        # with those shared between the two sides of the final tree (see mono_index)
        self.cfg.log.debug("\n==========> Checking branch labels ...")
        self.cfg.log.debug("shared rank names before training: %s", repr(self.taxonomy.get_common_ranks()))
        self.cfg.log.debug("shared rank names after  training: %s\n", repr(self.mono_index()))
        
        self.cfg.log.info("==========> Saving the reference JSON file: %s\n" % self.cfg.refjson_fname)
        self.write_json()
    def assign_taxonomy_maxsum(self, edges, minlw):
        """Sum up all LH-weights for each rank and take the rank with the max. sum.

        edges -- placement records of one sequence; edge[0] is the branch id
                 and edge[2] the likelihood weight of the placement
        minlw -- minimal "total" weight a rank needs in order to be selected

        Returns a pair (ranks, confidences): the ranks of the branch chosen
        for the assignment and, for each non-empty rank, its "total" weight
        (0.0 for empty ranks).  If no placement carries any non-empty rank,
        the ranks seen last are returned with a confidence of 1.0 each.
        """
        # in EPA result, each placement(=branch) has a "weight"
        # since we are interested in taxonomic placement, we do not care about branch vs. branch comparisons,
        # but only consider rank vs. rank (e. g. G1 S1 vs. G1 S2 vs. G1)
        # Thus we accumulate weights for each rank, there are two measures:
        # "own" weight   = sum of weight of all placements EXACTLY to this rank (e.g. for G1: G1 only)
        # "total" weight = own weight + own weight of all children (for G1: G1 or G1 S1 or G1 S2)
        rw_own = {}
        rw_total = {}
        # rank id -> id of a branch carrying this rank
        rb = {}

        ranks = [Taxonomy.EMPTY_RANK]

        for edge in edges:
            br_id = str(edge[0])
            lweight = edge[2]
            lowest_rank = None

            if lweight == 0.0:
                continue

            # accumulate weight for the current sequence
            ranks = self.bid_taxonomy_map[br_id]
            for i, rank in enumerate(ranks):
                if rank == Taxonomy.EMPTY_RANK:
                    # ranks below the first empty one are not considered
                    break
                rank_id = Taxonomy.get_rank_uid(ranks, i)
                rw_total[rank_id] = rw_total.get(rank_id, 0) + lweight
                lowest_rank = rank_id
                if rank_id not in rb:
                    rb[rank_id] = br_id

            if lowest_rank:
                rw_own[lowest_rank] = rw_own.get(lowest_rank, 0) + lweight
                rb[lowest_rank] = br_id
            elif self.cfg.verbose:
                # Single formatted argument: same text as the former print
                # statement (which put a separator space after the message),
                # and valid under both Python 2 and Python 3.
                print("WARNING: no annotation for branch  %s" % br_id)

        # if all branches have empty ranks only, just return this placement
        if len(rw_total) == 0:
            return ranks, [1.0] * len(ranks)

        # we assign the sequence to a rank, which has the max "own" weight AND
        # whose "total" weight is greater than a confidence threshold
        max_rw = 0.0
        s_r = None
        for r in rw_own:
            if rw_own[r] > max_rw and rw_total[r] >= minlw:
                s_r = r
                max_rw = rw_own[r]
        if not s_r:
            # no rank passed the threshold: fall back to the max "total" weight
            s_r = max(rw_total, key=(lambda key: rw_total[key]))

        a_br_id = rb[s_r]
        a_ranks = self.bid_taxonomy_map[a_br_id]

        # "total" weight is considered as confidence value for now
        a_conf = [0.0] * len(a_ranks)
        for i, rank in enumerate(a_ranks):
            if rank != Taxonomy.EMPTY_RANK:
                rank_id = Taxonomy.get_rank_uid(a_ranks, i)
                a_conf[i] = rw_total[rank_id]

        return a_ranks, a_conf
Example #45
0
 def test_rank_uid(self):
     """Splitting the rank uid of a sequence must give back its ranks."""
     taxonomy = self.taxonomy
     for seq_id in taxonomy.get_map().iterkeys():
         expected_ranks = taxonomy.get_seq_ranks(seq_id)
         split_ranks = Taxonomy.split_rank_uid(taxonomy.seq_rank_id(seq_id))
         self.assertEqual(expected_ranks, split_ranks)