def check_identical_ranks(self):
    """Merge taxa that cannot be told apart because they share duplicate sequences.

    For every set of identical sequences in ``self.dupseq_sets`` (filled on
    demand by ``self.check_identical_seqs()``), counts the duplicates per
    rank. A rank qualifies when that count exceeds ``cfg.taxa_ident_thres``
    times the rank's total sequence count; if at least two ranks qualify,
    they are merged via ``self.taxonomy.merge_ranks()``.

    Returns:
        dict: merged rank id -> list of the rank ids it replaced (the same
        object is stored in ``self.merged_ranks``).
    """
    if not self.dupseq_sets:
        self.check_identical_seqs()

    self.merged_ranks = {}
    for dup_ids in self.dupseq_sets:
        if len(dup_ids) <= 1:
            continue

        # number of duplicate sequences contributed by each rank
        duprank_map = {}
        for seq_name in dup_ids:
            rank_id = self.taxonomy.seq_rank_id(seq_name)
            duprank_map[rank_id] = duprank_map.get(rank_id, 0) + 1

        if len(duprank_map) > 1 and self.cfg.debug:
            self.cfg.log.debug("Ranks sharing duplicates: %s\n", str(duprank_map))

        # .items() instead of .iteritems(): works on both Python 2 and 3
        dup_ranks = [
            rank_id for rank_id, count in duprank_map.items()
            if count > self.cfg.taxa_ident_thres * self.taxonomy.get_rank_seq_count(rank_id)
        ]

        if len(dup_ranks) > 1:
            prefix = "__TAXCLUSTER%d__" % (len(self.merged_ranks) + 1)
            merged_rank_id = self.taxonomy.merge_ranks(dup_ranks, prefix)
            self.merged_ranks[merged_rank_id] = dup_ranks

    if self.verbose:
        merged_count = 0
        for merged_rank_id, dup_ranks in self.merged_ranks.items():
            dup_ranks_str = "\n".join(
                [Taxonomy.rank_uid_to_lineage_str(rank_id) for rank_id in dup_ranks])
            self.cfg.log.warning(
                "\nWARNING: Following taxa share >%.0f%% identical sequences and are thus considered indistinguishable:\n%s",
                self.cfg.taxa_ident_thres * 100, dup_ranks_str)
            merged_rank_str = Taxonomy.rank_uid_to_lineage_str(merged_rank_id)
            self.cfg.log.warning(
                "For the purpose of mislabels identification, they were merged into one taxon:\n%s\n",
                merged_rank_str)
            merged_count += len(dup_ranks)

        if merged_count > 0:
            self.cfg.log.warning(
                "WARNING: %d indistinguishable taxa have been merged into %d clusters.\n",
                merged_count, len(self.merged_ranks))

    return self.merged_ranks
def get_parent_tip_ranks(self, tax_tree):
    """Collect the inner nodes of the taxonomic tree usable as test subtrees.

    A node is kept when it is neither a leaf nor the root, its own lowest
    assigned rank level is at least 2, its parent's is at least 1, and it
    has between 2 and ``self.reftree_size - 4`` leaves (inclusive).

    Args:
        tax_tree: tree whose inner node names are rank uids.

    Returns:
        tuple: ``(rank_parent, rank_tips)`` - two dicts keyed by node name,
        holding the parent's rank list and the node's leaf names.
    """
    rank_tips = {}
    rank_parent = {}
    for node in tax_tree.traverse("postorder"):
        if node.is_leaf() or node.is_root():
            continue

        tax_path = node.name
        ranks = Taxonomy.split_rank_uid(tax_path)
        if Taxonomy.lowest_assigned_rank_level(ranks) < 2:
            continue

        parent_ranks = Taxonomy.split_rank_uid(node.up.name)
        if Taxonomy.lowest_assigned_rank_level(parent_ranks) < 1:
            continue

        rank_seqs = node.get_leaf_names()
        rank_size = len(rank_seqs)
        if rank_size < 2 or rank_size > self.reftree_size - 4:
            continue

        # reuse the list fetched above instead of walking the subtree again
        rank_tips[tax_path] = rank_seqs
        rank_parent[tax_path] = parent_ranks

    return rank_parent, rank_tips
def get_parent_tip_ranks(self, tax_tree):
    """Collect the inner nodes of the taxonomic tree usable as test subtrees.

    A node is kept when it is neither a leaf nor the root, its own lowest
    assigned rank level is at least 2, its parent's is at least 1, and it
    has between 2 and ``self.reftree_size - 4`` leaves (inclusive).

    Args:
        tax_tree: tree whose inner node names are rank uids.

    Returns:
        tuple: ``(rank_parent, rank_tips)`` - two dicts keyed by node name,
        holding the parent's rank list and the node's leaf names.
    """
    rank_tips = {}
    rank_parent = {}
    for node in tax_tree.traverse("postorder"):
        if node.is_leaf() or node.is_root():
            continue

        tax_path = node.name
        ranks = Taxonomy.split_rank_uid(tax_path)
        if Taxonomy.lowest_assigned_rank_level(ranks) < 2:
            continue

        parent_ranks = Taxonomy.split_rank_uid(node.up.name)
        if Taxonomy.lowest_assigned_rank_level(parent_ranks) < 1:
            continue

        rank_seqs = node.get_leaf_names()
        rank_size = len(rank_seqs)
        if rank_size < 2 or rank_size > self.reftree_size - 4:
            continue

        # reuse the list fetched above instead of walking the subtree again
        rank_tips[tax_path] = rank_seqs
        rank_parent[tax_path] = parent_ranks

    return rank_parent, rank_tips
def test_merge_ranks(self):
    """Merging the ranks of two sequences must move both into the new rank."""
    seq_ids = ["UnpC[Ceti]", "UnpSomer,"]
    taxonomy = Taxonomy(tax_map=self.taxonomy.seq_ranks_map)

    old_rank_ids = []
    for seq_id in seq_ids:
        old_rank_ids.append(taxonomy.seq_rank_id(seq_id))

    merged_id = taxonomy.merge_ranks(old_rank_ids)

    # the merged rank holds exactly the two sequences ...
    self.assertEqual(seq_ids, taxonomy.get_rank_seqs(merged_id))
    # ... and each of them now reports the merged rank as its own
    for seq_id in seq_ids:
        self.assertEqual(taxonomy.seq_rank_id(seq_id), merged_id)
def test_normalize_seq_ids(self):
    """normalize_seq_ids() must replace the special characters in sequence ids."""
    tax = Taxonomy(tax_map=self.taxonomy.seq_ranks_map)

    # assertIn/assertNotIn report the offending key and container on failure,
    # unlike assertTrue(x in y) which only says "False is not true"
    self.assertIn("UnpC[Ceti]", tax.seq_ranks_map)
    self.assertIn("UnpSomer,", tax.seq_ranks_map)

    tax.normalize_seq_ids()

    self.assertNotIn("UnpC[Ceti]", tax.seq_ranks_map)
    self.assertIn("UnpC_Ceti_", tax.seq_ranks_map)
    self.assertNotIn("UnpSomer,", tax.seq_ranks_map)
    self.assertIn("UnpSomer_", tax.seq_ranks_map)
def load_refjson(self, refjson_fname):
    """Parse the reference JSON file and set up everything derived from it.

    Loads and validates ``refjson_fname``, copies rate, node height,
    taxonomy, taxonomic tree and branch-to-taxonomy map onto ``self``,
    builds the reference tree and the helper objects, and resets the
    mislabel counters. Problems with the file are reported through
    ``self.cfg.exit_user_error()``.
    """
    try:
        self.refjson = RefJsonParser(refjson_fname)
    except ValueError:
        self.cfg.exit_user_error("ERROR: Invalid json file format!")

    # validate input json format
    is_valid, err_msg = self.refjson.validate()
    if not is_valid:
        self.cfg.log.error("ERROR: Parsing reference JSON file failed:\n%s", err_msg)
        self.cfg.exit_user_error()

    refjson = self.refjson
    self.rate = refjson.get_rate()
    self.node_height = refjson.get_node_height()
    self.origin_taxonomy = refjson.get_origin_taxonomy()
    self.tax_tree = refjson.get_tax_tree()
    self.cfg.compress_patterns = refjson.get_pattern_compression()

    self.bid_taxonomy_map = refjson.get_branch_tax_map()
    if not self.bid_taxonomy_map:
        # old file format (before 1.6), need to rebuild this map from scratch
        rebuild_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy)
        rebuild_helper.set_mf_rooted_tree(self.tax_tree)
        rebuild_helper.set_bf_unrooted_tree(refjson.get_reftree())
        self.bid_taxonomy_map = rebuild_helper.get_bid_taxonomy_map()
    # NOTE(review): nesting reconstructed from a flattened source; the map is
    # assumed to be written regardless of whether it had to be rebuilt
    self.write_bid_tax_map(self.bid_taxonomy_map, final=False)

    self.reftree = Tree(refjson.get_raxml_readable_tree())
    self.reftree_size = len(self.reftree.get_leaves())

    # IMPORTANT: set EPA heuristic rate based on tree size!
    self.cfg.resolve_auto_settings(self.reftree_size)
    # If we're loading the pre-optimized model, we MUST set the same
    # rate het. mode as in the ref file
    if self.cfg.epa_load_optmod:
        self.cfg.raxml_model = refjson.get_ratehet_model()

    self.classify_helper = TaxClassifyHelper(
        self.cfg, self.bid_taxonomy_map, self.rate, self.node_height)
    self.taxtree_helper = TaxTreeHelper(
        self.cfg, self.origin_taxonomy, self.tax_tree)

    self.tax_code = TaxCode(refjson.get_taxcode())
    self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX,
                             tax_map=self.origin_taxonomy)
    self.tax_common_ranks = self.taxonomy.get_common_ranks()

    self.mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
    self.rank_mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
def test_taxtree_builder(self):
    """TaxTreeBuilder.build() must reproduce the stored taxonomy tree."""
    cfg = EpacConfig()
    testfile_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testfiles")
    tax_fname = os.path.join(testfile_dir, "test.tax")
    tax = Taxonomy(EpacConfig.REF_SEQ_PREFIX, tax_fname)
    tree_fname = os.path.join(testfile_dir, "taxtree.nw")
    expected_tree = Tree(tree_fname, format=8)

    tb = TaxTreeBuilder(cfg, tax)
    tax_tree, seq_ids = tb.build()

    # list(...) keeps the comparison meaningful on Python 3, where keys()
    # is a view that never compares equal to a list; no-op on Python 2
    self.assertEqual(seq_ids, list(tax.get_map().keys()))
    self.assertEqual(tax_tree.write(format=8), expected_tree.write(format=8))
def setUp(self):
    """Build the taxonomy, tree helper and expected outgroup from the test files."""
    here = os.path.dirname(os.path.abspath(__file__))
    self.testfile_dir = os.path.join(here, "testfiles")
    self.tax_fname = os.path.join(self.testfile_dir, "test_clean.tax")

    self.taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX, self.tax_fname)
    seq_rank_map = self.taxonomy.get_map()
    self.taxtree_helper = TaxTreeHelper(EpacConfig(), seq_rank_map)

    self.expected_outgr = Tree(os.path.join(self.testfile_dir, "outgroup.nw"))
def mis_rec_to_string(self, mis_rec):
    """Format a mislabel record as one tab-separated line.

    Columns: sequence name, level name, original rank, proposed rank and
    confidence at the mislabel level, original lineage, proposed lineage,
    ';'-joined per-level confidences and, if present, the rank confidence.

    Records with a negative ``orig_level`` (the "[NotIngroup]" case, whose
    ``ranks`` list is empty) have no rank at the mislabel level; "NA" is
    printed for both ranks and the first confidence value is used.

    Args:
        mis_rec: dict as produced by ``check_seq_tax_labels``.

    Returns:
        str: the formatted line, without a trailing newline.
    """
    lvl = mis_rec['orig_level']
    output = mis_rec['name'] + "\t"
    if lvl >= 0:
        output += "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'],
                                          mis_rec['orig_ranks'][lvl],
                                          mis_rec['ranks'][lvl],
                                          mis_rec['lws'][lvl])
    else:
        # indexing with -1 would raise IndexError on the empty 'ranks' list
        output += "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'], "NA", "NA",
                                          mis_rec['lws'][0])
    output += Taxonomy.lineage_str(mis_rec['orig_ranks']) + "\t"
    output += Taxonomy.lineage_str(mis_rec['ranks']) + "\t"
    output += ";".join(["%.3f" % conf for conf in mis_rec['lws']])
    if 'rank_conf' in mis_rec:
        output += "\t%.3f" % mis_rec['rank_conf']
    return output
def test_normalize_rank_names(self):
    """normalize_rank_names() must sanitise rank names and report the changes."""
    taxonomy = Taxonomy(tax_map=self.taxonomy.seq_ranks_map)

    raw_expected = ["[Bacteria]", "'Firmicutes'", "Clostridia(1)"]
    before = taxonomy.get_seq_ranks("UpbRectu")
    for level, expected in enumerate(raw_expected):
        self.assertEqual(before[level], expected)

    corrected = taxonomy.normalize_rank_names()
    self.assertEqual(len(corrected), 3)

    clean_expected = ["_Bacteria_", "_Firmicutes_", "Clostridia_1_"]
    after = taxonomy.get_seq_ranks("UpbRectu")
    for level, expected in enumerate(clean_expected):
        self.assertEqual(after[level], expected)
def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
    """Compare a sequence's original lineage with its newly assigned one.

    A mislabel record is created when ``ranks`` is empty (reported with
    level name "[NotIngroup]") or when, at the first differing level within
    the common length of both lineages, the new rank is not
    ``Taxonomy.EMPTY_RANK``. Created records are appended to
    ``self.mislabels``.

    Args:
        seq_name: sequence identifier.
        orig_ranks: original rank list of the sequence.
        ranks: newly assigned rank list (may be empty).
        lws: per-level confidence values matching ``ranks``.

    Returns:
        dict or None: the mislabel record, or None if the labels agree.
    """
    mis_rec = None
    if len(ranks) == 0:
        real_lvl = 0
        mis_rec = {
            'name': seq_name,
            'orig_level': -1,
            'real_level': real_lvl,
            'level_name': "[NotIngroup]",
            'inv_level': -1 * real_lvl,  # just for sorting
            'orig_ranks': orig_ranks,
            'ranks': [],
            'lws': [1.0],
        }
        mis_rec['conf'] = mis_rec['lws'][0]
    else:
        # first level at which an assigned new rank contradicts the original
        mislabel_lvl = -1
        for rank_lvl in range(min(len(orig_ranks), len(ranks))):
            if ranks[rank_lvl] != Taxonomy.EMPTY_RANK and ranks[rank_lvl] != orig_ranks[rank_lvl]:
                mislabel_lvl = rank_lvl
                break

        if mislabel_lvl >= 0:
            real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
            mis_rec = {
                'name': seq_name,
                'orig_level': mislabel_lvl,
                'real_level': real_lvl,
                'level_name': self.tax_code.rank_level_name(real_lvl)[0],
                'inv_level': -1 * real_lvl,  # just for sorting
                'orig_ranks': orig_ranks,
                'ranks': ranks,
                'lws': lws,
                'conf': lws[mislabel_lvl],
            }

    if mis_rec:
        self.mislabels.append(mis_rec)

    return mis_rec
def build_ref_tree(self):
    """Run the whole reference-tree training pipeline, logging each stage.

    Stages, in order: load taxonomy and alignment, validate them, build the
    multifurcating tree, export reference alignment and taxonomy, save the
    rooting, resolve multifurcations, label branches with EPA, post-process
    the tree, compute node heights and write the reference JSON file.
    """
    log = self.cfg.log

    log.info("=> Loading taxonomy from file: %s ...\n", self.cfg.taxonomy_fname)
    self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX,
                             tax_fname=self.cfg.taxonomy_fname)

    log.info("==> Loading reference alignment from file: %s ...\n", self.cfg.align_fname)
    self.load_alignment()

    log.info("===> Validating taxonomy and alignment ...\n")
    self.validate_taxonomy()

    log.info("====> Building a multifurcating tree from taxonomy with %d seqs ...\n",
             self.taxonomy.seq_count())
    self.build_multif_tree()

    log.info("=====> Building the reference alignment ...\n")
    self.export_ref_alignment()
    self.export_ref_taxonomy()

    log.info("======> Saving the outgroup for later re-rooting ...\n")
    self.save_rooting()

    # format arguments are handed to the logger (as in every other call
    # here) instead of being interpolated eagerly with '%'
    log.info("=======> Resolving multifurcation: choosing the best topology from %d independent RAxML runs ...\n",
             self.cfg.rep_num)
    self.resolve_multif()
    self.load_reduced_refalign()

    log.info("========> Calling RAxML-EPA to obtain branch labels ...\n")
    self.epa_branch_labeling()

    log.info("=========> Post-processing the EPA tree (re-rooting, taxonomic labeling etc.) ...\n")
    self.epa_post_process()
    self.calc_node_heights()

    log.debug("\n==========> Checking branch labels ...")
    log.debug("shared rank names before training: %s", repr(self.taxonomy.get_common_ranks()))
    log.debug("shared rank names after training: %s\n", repr(self.mono_index()))

    log.info("==========> Saving the reference JSON file: %s\n", self.cfg.refjson_fname)
    self.write_json()
def run_leave_seq_out_test(self):
    """Run the leave-one-sequence-out test over all placements.

    Placements come from ``self.jplace_fname`` if it is set, otherwise from
    a fresh EPA run in "l1o_seq" mode. Every placement is classified and
    compared with the sequence's original lineage; when ``self.ranktest``
    is on, each mislabel record also receives the highest confidence found
    in ``self.misrank_conf_map`` for its lineage prefixes (levels 2 and up).

    Returns:
        int: number of placements processed.
    """
    job_name = self.cfg.subst_name("l1out_seq_%NAME%")
    if self.jplace_fname:
        parser = EpaJsonParser(self.jplace_fname)
    else:
        parser = self.raxml.run_epa(job_name, self.refalign_fname,
                                    self.reftree_fname, self.optmod_fname,
                                    mode="l1o_seq")

    processed = 0
    for placement in parser.get_placement():
        seq_name = placement["n"][0]
        # original taxonomic label vs. the label EPA assigns
        orig_ranks = self.get_orig_ranks(seq_name)
        ranks, lws = self.classify_seq(placement)
        mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

        # cross-check with higher rank mislabels
        if self.ranktest and mis_rec:
            best_conf = 0
            for level in range(2, len(orig_ranks)):
                rank_uid = Taxonomy.get_rank_uid(orig_ranks, level)
                if rank_uid in self.misrank_conf_map:
                    best_conf = max(best_conf, self.misrank_conf_map[rank_uid])
            mis_rec['rank_conf'] = best_conf

        processed += 1

    return processed
def label_bf_tree_with_ranks(self):
    """Label the nodes of the bifurcating rooted tree with taxonomic ranks.

    Leaves receive the lineage stored in ``self.origin_taxonomy`` and get
    their lowest assigned rank appended to their name. Each inner node gets
    the deepest rank prefix on which both children agree; if they agree on
    nothing, the node is named "Undefined". On completion the labeled tree
    is stored in ``self.tax_tree``.

    Raises:
        AssertionError: if ``self.bf_rooted_tree`` is not set, or an inner
            node does not have exactly two children.
    """
    if not self.bf_rooted_tree:
        raise AssertionError(
            "self.bf_rooted_tree is not set: TaxTreeHelper.set_bf_unrooted_tree() must be called before!")

    # postorder guarantees both children are labeled before their parent
    for node in self.bf_rooted_tree.traverse("postorder"):
        if node.is_leaf():
            seq_ranks = self.origin_taxonomy[node.name]
            rank_level = Taxonomy.lowest_assigned_rank_level(seq_ranks)
            node.add_feature("rank_level", rank_level)
            node.add_feature("ranks", seq_ranks)
            node.name += "__" + seq_ranks[rank_level]
            continue

        if len(node.children) != 2:
            raise AssertionError("FATAL ERROR: tree is not bifurcating!")
        lchild = node.children[0]
        rchild = node.children[1]

        # walk up from the shallower child's level until both lineages agree
        rank_level = min(lchild.rank_level, rchild.rank_level)
        while rank_level >= 0 and lchild.ranks[rank_level] != rchild.ranks[rank_level]:
            rank_level -= 1
        node.add_feature("rank_level", rank_level)

        node_ranks = [Taxonomy.EMPTY_RANK] * 7
        if rank_level >= 0:
            node_ranks[0:rank_level + 1] = lchild.ranks[0:rank_level + 1]
            node.name = lchild.ranks[rank_level]
        else:
            node.name = "Undefined"
            if hasattr(node, "B") and self.cfg.verbose:
                # single parenthesised argument: same output on Python 2
                # (print statement) and Python 3 (print function)
                print("INFO: no taxonomic annotation for branch %s (reason: children belong to different kingdoms)" % node.B)
        node.add_feature("ranks", node_ranks)

    self.tax_tree = self.bf_rooted_tree
def check_identical_ranks(self):
    """Merge taxa that cannot be told apart because they share duplicate sequences.

    For every set of identical sequences in ``self.dupseq_sets`` (filled on
    demand by ``self.check_identical_seqs()``), counts the duplicates per
    rank. A rank qualifies when that count exceeds ``cfg.taxa_ident_thres``
    times the rank's total sequence count; if at least two ranks qualify,
    they are merged via ``self.taxonomy.merge_ranks()``.

    Returns:
        dict: merged rank id -> list of the rank ids it replaced (the same
        object is stored in ``self.merged_ranks``).
    """
    if not self.dupseq_sets:
        self.check_identical_seqs()

    self.merged_ranks = {}
    for dup_ids in self.dupseq_sets:
        if len(dup_ids) <= 1:
            continue

        # number of duplicate sequences contributed by each rank
        duprank_map = {}
        for seq_name in dup_ids:
            rank_id = self.taxonomy.seq_rank_id(seq_name)
            duprank_map[rank_id] = duprank_map.get(rank_id, 0) + 1

        if len(duprank_map) > 1 and self.cfg.debug:
            self.cfg.log.debug("Ranks sharing duplicates: %s\n", str(duprank_map))

        # .items() instead of .iteritems(): works on both Python 2 and 3
        dup_ranks = [
            rank_id for rank_id, count in duprank_map.items()
            if count > self.cfg.taxa_ident_thres * self.taxonomy.get_rank_seq_count(rank_id)
        ]

        if len(dup_ranks) > 1:
            prefix = "__TAXCLUSTER%d__" % (len(self.merged_ranks) + 1)
            merged_rank_id = self.taxonomy.merge_ranks(dup_ranks, prefix)
            self.merged_ranks[merged_rank_id] = dup_ranks

    if self.verbose:
        merged_count = 0
        for merged_rank_id, dup_ranks in self.merged_ranks.items():
            dup_ranks_str = "\n".join(
                [Taxonomy.rank_uid_to_lineage_str(rank_id) for rank_id in dup_ranks])
            self.cfg.log.warning(
                "\nWARNING: Following taxa share >%.0f%% identical sequences and are thus considered indistinguishable:\n%s",
                self.cfg.taxa_ident_thres * 100, dup_ranks_str)
            merged_rank_str = Taxonomy.rank_uid_to_lineage_str(merged_rank_id)
            self.cfg.log.warning(
                "For the purpose of mislabels identification, they were merged into one taxon:\n%s\n",
                merged_rank_str)
            merged_count += len(dup_ranks)

        if merged_count > 0:
            self.cfg.log.warning(
                "WARNING: %d indistinguishable taxa have been merged into %d clusters.\n",
                merged_count, len(self.merged_ranks))

    return self.merged_ranks
def run_leave_subtree_out_test(self):
    """Run the leave-one-subtree-out test for every eligible rank.

    Writes the tip names of each candidate subtree (one subtree per line)
    to a temporary file, runs EPA in "l1o_subtree" mode and compares the
    resulting classification of each subtree with its parent's lineage.
    Confidences of detected rank mislabels are stored in
    ``self.misrank_conf_map`` under the subtree's rank uid.

    Returns:
        int: number of placements processed (0 if there are no subtrees).
    """
    job_name = self.cfg.subst_name("l1out_rank_%NAME%")

    # create file with subtrees
    rank_parent, rank_tips = self.get_parent_tip_ranks(self.tax_tree)
    subtree_list = list(rank_tips.items())
    if not subtree_list:
        return 0

    subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
    with open(subtree_list_file, "w") as fout:
        for _, tip_names in subtree_list:
            fout.write("%s\n" % " ".join(tip_names))

    jp_list = self.raxml.run_epa(job_name, self.refalign_fname,
                                 self.reftree_fname, self.optmod_fname,
                                 mode="l1o_subtree",
                                 subtree_fname=subtree_list_file)

    processed = 0
    for parser in jp_list:
        for placement in parser.get_placement():
            ranks, lws = self.classify_seq(placement)

            # placements are matched to subtrees by position
            tax_path = subtree_list[processed][0]
            orig_ranks = Taxonomy.split_rank_uid(tax_path)
            rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
            rank_prefix = self.tax_code.guess_rank_level_name(orig_ranks, rank_level)[0]

            rank_name = orig_ranks[rank_level]
            if not rank_name.startswith(rank_prefix):
                rank_name = rank_prefix + rank_name

            mis_rec = self.check_rank_tax_labels(rank_name, rank_parent[tax_path],
                                                 ranks, lws)
            if mis_rec:
                self.misrank_conf_map[tax_path] = mis_rec['conf']

            # NOTE(review): nesting reconstructed from a flattened source;
            # the counter is assumed to advance once per placement
            processed += 1

    return processed
def get_orig_ranks(self, seq_name):
    """Return the original lineage of a sequence from the taxonomic tree.

    The lineage is decoded from the name of the leaf's parent node.

    Args:
        seq_name: name of a leaf in ``self.tax_tree``.

    Returns:
        The rank list produced by ``Taxonomy.split_rank_uid``.

    Raises:
        SystemExit: with status 1 if the sequence is absent from the tree
            or present more than once.
    """
    nodes = self.tax_tree.get_leaves_by_name(seq_name)
    if len(nodes) != 1:
        # parenthesised single argument prints identically on Python 2 and 3
        print("FATAL ERROR: Sequence %s is not found in the taxonomic tree, or is present more than once!" % seq_name)
        # a fatal error must not exit with the success status 0
        sys.exit(1)

    seq_node = nodes[0]
    return Taxonomy.split_rank_uid(seq_node.up.name)
def setUp(self):
    """Write ``self.TAX_DICT`` to a taxonomy file and load it back.

    Also fills ``self.PREFIXED_TAX_DICT`` with the same lineages keyed by
    the sequence id prefixed with ``EpacConfig.REF_SEQ_PREFIX``.
    """
    test_dir = os.path.dirname(os.path.abspath(__file__))
    self.tax_fname = os.path.join(test_dir, "test.tax")
    self.PREFIXED_TAX_DICT = {}
    with open(self.tax_fname, "w") as outf:
        # .items() instead of .iteritems(): works on both Python 2 and 3
        for sid, ranks in self.TAX_DICT.items():
            outf.write("%s\t%s\n" % (sid, ";".join(ranks)))
            self.PREFIXED_TAX_DICT[EpacConfig.REF_SEQ_PREFIX + sid] = ranks
    self.taxonomy = Taxonomy("", self.tax_fname)
def build_ref_tree(self):
    """Run the whole reference-tree training pipeline, logging each stage.

    Stages, in order: load taxonomy and alignment, validate them, build the
    multifurcating tree, export reference alignment and taxonomy, save the
    rooting, resolve multifurcations, label branches with EPA, post-process
    the tree, compute node heights and write the reference JSON file.
    """
    log = self.cfg.log

    log.info("=> Loading taxonomy from file: %s ...\n", self.cfg.taxonomy_fname)
    self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX,
                             tax_fname=self.cfg.taxonomy_fname)

    log.info("==> Loading reference alignment from file: %s ...\n", self.cfg.align_fname)
    self.load_alignment()

    log.info("===> Validating taxonomy and alignment ...\n")
    self.validate_taxonomy()

    log.info("====> Building a multifurcating tree from taxonomy with %d seqs ...\n",
             self.taxonomy.seq_count())
    self.build_multif_tree()

    log.info("=====> Building the reference alignment ...\n")
    self.export_ref_alignment()
    self.export_ref_taxonomy()

    log.info("======> Saving the outgroup for later re-rooting ...\n")
    self.save_rooting()

    # format arguments are handed to the logger (as in every other call
    # here) instead of being interpolated eagerly with '%'
    log.info("=======> Resolving multifurcation: choosing the best topology from %d independent RAxML runs ...\n",
             self.cfg.rep_num)
    self.resolve_multif()
    self.load_reduced_refalign()

    log.info("========> Calling RAxML-EPA to obtain branch labels ...\n")
    self.epa_branch_labeling()

    log.info("=========> Post-processing the EPA tree (re-rooting, taxonomic labeling etc.) ...\n")
    self.epa_post_process()
    self.calc_node_heights()

    log.debug("\n==========> Checking branch labels ...")
    log.debug("shared rank names before training: %s", repr(self.taxonomy.get_common_ranks()))
    log.debug("shared rank names after training: %s\n", repr(self.mono_index()))

    log.info("==========> Saving the reference JSON file: %s\n", self.cfg.refjson_fname)
    self.write_json()
def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
    """Compare a sequence's original lineage with its newly assigned one.

    A mislabel record is created when ``ranks`` is empty (reported with
    level name "[NotIngroup]") or when, at the first differing level within
    the common length of both lineages, the new rank is not
    ``Taxonomy.EMPTY_RANK``. Created records are appended to
    ``self.mislabels``.

    Args:
        seq_name: sequence identifier.
        orig_ranks: original rank list of the sequence.
        ranks: newly assigned rank list (may be empty).
        lws: per-level confidence values matching ``ranks``.

    Returns:
        dict or None: the mislabel record, or None if the labels agree.
    """
    mis_rec = None
    if len(ranks) == 0:
        real_lvl = 0
        mis_rec = {
            'name': seq_name,
            'orig_level': -1,
            'real_level': real_lvl,
            'level_name': "[NotIngroup]",
            'inv_level': -1 * real_lvl,  # just for sorting
            'orig_ranks': orig_ranks,
            'ranks': [],
            'lws': [1.0],
        }
        mis_rec['conf'] = mis_rec['lws'][0]
    else:
        # first level at which an assigned new rank contradicts the original
        mislabel_lvl = -1
        for rank_lvl in range(min(len(orig_ranks), len(ranks))):
            if ranks[rank_lvl] != Taxonomy.EMPTY_RANK and ranks[rank_lvl] != orig_ranks[rank_lvl]:
                mislabel_lvl = rank_lvl
                break

        if mislabel_lvl >= 0:
            real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
            mis_rec = {
                'name': seq_name,
                'orig_level': mislabel_lvl,
                'real_level': real_lvl,
                'level_name': self.tax_code.rank_level_name(real_lvl)[0],
                'inv_level': -1 * real_lvl,  # just for sorting
                'orig_ranks': orig_ranks,
                'ranks': ranks,
                'lws': lws,
                'conf': lws[mislabel_lvl],
            }

    if mis_rec:
        self.mislabels.append(mis_rec)

    return mis_rec
def setUp(self):
    """Build the taxonomy, tree helper and expected outgroup from the test files."""
    here = os.path.dirname(os.path.abspath(__file__))
    self.testfile_dir = os.path.join(here, "testfiles")
    self.tax_fname = os.path.join(self.testfile_dir, "test_clean.tax")

    self.taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX, self.tax_fname)
    seq_rank_map = self.taxonomy.get_map()
    self.taxtree_helper = TaxTreeHelper(EpacConfig(), seq_rank_map)

    self.expected_outgr = Tree(os.path.join(self.testfile_dir, "outgroup.nw"))
def mis_rec_to_string(self, mis_rec):
    """Format a mislabel record as one tab-separated line.

    Sequence id and lineages are first converted through ``self.refjson``
    (``get_uncorr_seqid`` / ``get_uncorr_ranks``), and the reference prefix
    is stripped from the id. Columns: id, level name, original and proposed
    rank plus confidence at the mislabel level ("NA" ranks and the first
    confidence when ``orig_level`` is negative), both lineages, ';'-joined
    per-level confidences and, if present, the rank confidence.

    Returns:
        str: the formatted line, without a trailing newline.
    """
    level = mis_rec['orig_level']
    seq_id = EpacConfig.strip_ref_prefix(
        self.refjson.get_uncorr_seqid(mis_rec['name']))
    old_lineage = self.refjson.get_uncorr_ranks(mis_rec['orig_ranks'])
    new_lineage = self.refjson.get_uncorr_ranks(mis_rec['ranks'])

    if level >= 0:
        summary = (mis_rec['level_name'], old_lineage[level],
                   new_lineage[level], mis_rec['lws'][level])
    else:
        summary = (mis_rec['level_name'], "NA", "NA", mis_rec['lws'][0])

    pieces = [
        seq_id + "\t",
        "%s\t%s\t%s\t%.3f\t" % summary,
        Taxonomy.lineage_str(old_lineage) + "\t",
        Taxonomy.lineage_str(new_lineage) + "\t",
        ";".join(["%.3f" % conf for conf in mis_rec['lws']]),
    ]
    if 'rank_conf' in mis_rec:
        pieces.append("\t%.3f" % mis_rec['rank_conf'])
    return "".join(pieces)
def run_leave_subtree_out_test(self):
    """Run the leave-one-subtree-out test for every eligible rank.

    Writes the tip names of each candidate subtree (one subtree per line)
    to a temporary file, runs EPA in "l1o_subtree" mode and compares the
    resulting classification of each subtree with its parent's lineage.
    Confidences of detected rank mislabels are stored in
    ``self.misrank_conf_map`` under the subtree's rank uid.

    Returns:
        int: number of placements processed (0 if there are no subtrees).
    """
    job_name = self.cfg.subst_name("l1out_rank_%NAME%")

    # create file with subtrees
    # get_parent_tip_ranks takes ``self`` as its first parameter, so it
    # must be invoked as a method
    rank_parent, rank_tips = self.get_parent_tip_ranks(self.tax_tree)
    # materialise the items: the list is indexed by position further down,
    # which a Python 3 dict view does not support
    subtree_list = list(rank_tips.items())

    if len(subtree_list) == 0:
        return 0

    subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
    with open(subtree_list_file, "w") as fout:
        for _, tips in subtree_list:
            fout.write("%s\n" % " ".join(tips))

    jp_list = self.raxml.run_epa(job_name, self.refalign_fname,
                                 self.reftree_fname, self.optmod_fname,
                                 mode="l1o_subtree",
                                 subtree_fname=subtree_list_file)

    subtree_count = 0
    for jp in jp_list:
        placements = jp.get_placement()
        for place in placements:
            ranks, lws = self.classify_seq(place)

            # placements are matched to subtrees by position
            tax_path = subtree_list[subtree_count][0]
            orig_ranks = Taxonomy.split_rank_uid(tax_path)
            rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
            # NOTE(review): guess_rank_level_name is looked up on self here;
            # confirm the class really provides it
            rank_prefix = self.guess_rank_level_name(orig_ranks, rank_level)[0]
            rank_name = orig_ranks[rank_level]
            if not rank_name.startswith(rank_prefix):
                rank_name = rank_prefix + rank_name
            parent_ranks = rank_parent[tax_path]

            mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks, ranks, lws)
            if mis_rec:
                self.misrank_conf_map[tax_path] = mis_rec['conf']

            # NOTE(review): nesting reconstructed from a flattened source;
            # the counter is assumed to advance once per placement
            subtree_count += 1

    return subtree_count
def mis_rec_to_string(self, mis_rec):
    """Format a mislabel record as one tab-separated line.

    Sequence id and lineages are first converted through ``self.refjson``
    (``get_uncorr_seqid`` / ``get_uncorr_ranks``), and the reference prefix
    is stripped from the id. Columns: id, level name, original and proposed
    rank plus confidence at the mislabel level ("NA" ranks and the first
    confidence when ``orig_level`` is negative), both lineages, ';'-joined
    per-level confidences and, if present, the rank confidence.

    Returns:
        str: the formatted line, without a trailing newline.
    """
    level = mis_rec['orig_level']
    seq_id = EpacConfig.strip_ref_prefix(
        self.refjson.get_uncorr_seqid(mis_rec['name']))
    old_lineage = self.refjson.get_uncorr_ranks(mis_rec['orig_ranks'])
    new_lineage = self.refjson.get_uncorr_ranks(mis_rec['ranks'])

    if level >= 0:
        summary = (mis_rec['level_name'], old_lineage[level],
                   new_lineage[level], mis_rec['lws'][level])
    else:
        summary = (mis_rec['level_name'], "NA", "NA", mis_rec['lws'][0])

    pieces = [
        seq_id + "\t",
        "%s\t%s\t%s\t%.3f\t" % summary,
        Taxonomy.lineage_str(old_lineage) + "\t",
        Taxonomy.lineage_str(new_lineage) + "\t",
        ";".join(["%.3f" % conf for conf in mis_rec['lws']]),
    ]
    if 'rank_conf' in mis_rec:
        pieces.append("\t%.3f" % mis_rec['rank_conf'])
    return "".join(pieces)
def load_refjson(self, refjson_fname):
    """Parse the reference JSON file and set up everything derived from it.

    Loads and validates ``refjson_fname``, copies rate, node height,
    taxonomy, taxonomic tree and branch-to-taxonomy map onto ``self``,
    builds the reference tree and the helper objects, and resets the
    mislabel counters. Problems with the file are reported through
    ``self.cfg.exit_user_error()``.
    """
    try:
        self.refjson = RefJsonParser(refjson_fname)
    except ValueError:
        self.cfg.exit_user_error("ERROR: Invalid json file format!")

    # validate input json format
    is_valid, err_msg = self.refjson.validate()
    if not is_valid:
        self.cfg.log.error("ERROR: Parsing reference JSON file failed:\n%s", err_msg)
        self.cfg.exit_user_error()

    refjson = self.refjson
    self.rate = refjson.get_rate()
    self.node_height = refjson.get_node_height()
    self.origin_taxonomy = refjson.get_origin_taxonomy()
    self.tax_tree = refjson.get_tax_tree()
    self.cfg.compress_patterns = refjson.get_pattern_compression()

    self.bid_taxonomy_map = refjson.get_branch_tax_map()
    if not self.bid_taxonomy_map:
        # old file format (before 1.6), need to rebuild this map from scratch
        rebuild_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy)
        rebuild_helper.set_mf_rooted_tree(self.tax_tree)
        rebuild_helper.set_bf_unrooted_tree(refjson.get_reftree())
        self.bid_taxonomy_map = rebuild_helper.get_bid_taxonomy_map()
    # NOTE(review): nesting reconstructed from a flattened source; the map is
    # assumed to be written regardless of whether it had to be rebuilt
    self.write_bid_tax_map(self.bid_taxonomy_map, final=False)

    self.reftree = Tree(refjson.get_raxml_readable_tree())
    self.reftree_size = len(self.reftree.get_leaves())

    # IMPORTANT: set EPA heuristic rate based on tree size!
    self.cfg.resolve_auto_settings(self.reftree_size)
    # If we're loading the pre-optimized model, we MUST set the same
    # rate het. mode as in the ref file
    if self.cfg.epa_load_optmod:
        self.cfg.raxml_model = refjson.get_ratehet_model()

    self.classify_helper = TaxClassifyHelper(
        self.cfg, self.bid_taxonomy_map, self.rate, self.node_height)
    self.taxtree_helper = TaxTreeHelper(
        self.cfg, self.origin_taxonomy, self.tax_tree)

    self.tax_code = TaxCode(refjson.get_taxcode())
    self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX,
                             tax_map=self.origin_taxonomy)
    self.tax_common_ranks = self.taxonomy.get_common_ranks()

    self.mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
    self.rank_mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
def run_leave_seq_out_test(self):
    """Run the leave-one-sequence-out test over all placements.

    Placements are read from ``cfg.jplace_fname`` when it is set (a
    directory is expanded to ``*.jplace``; otherwise the value is used as a
    glob mask), or else produced by an EPA run in "l1o_seq" mode. Every
    placement is classified and compared with the sequence's lineage from
    the taxonomic tree; with ``cfg.ranktest`` on, each mislabel record also
    receives the highest confidence found in ``self.misrank_conf_map`` for
    its lineage prefixes (levels 2 and up). All assignments are finally
    passed to ``self.write_assignments(..., final=False)``.

    Returns:
        int: number of placements processed.
    """
    job_name = self.cfg.subst_name("l1out_seq_%NAME%")
    placements = []
    if self.cfg.jplace_fname:
        if os.path.isdir(self.cfg.jplace_fname):
            jplace_fmask = os.path.join(self.cfg.jplace_fname, '*.jplace')
        else:
            jplace_fmask = self.cfg.jplace_fname

        for jplace_fname in glob.glob(jplace_fmask):
            placements += EpaJsonParser(jplace_fname).get_placement()

        # use the instance's config; the bare name ``config`` is not defined
        # in this scope and every other call here goes through self.cfg
        self.cfg.log.debug("Loaded %d placements from %s\n", len(placements), jplace_fmask)
    else:
        jp = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname,
                                self.optmod_fname, mode="l1o_seq")
        placements = jp.get_placement()
        # NOTE(review): nesting reconstructed from a flattened source; the
        # jplace copy is assumed to apply only when EPA was actually run
        if self.cfg.output_interim_files:
            out_jplace_fname = self.cfg.out_fname("%NAME%.l1out_seq.jplace")
            self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True, mode="l1o_seq")

    seq_count = 0
    l1out_ass = {}
    for place in placements:
        seq_name = place["n"][0]
        # get original taxonomic label
        orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)
        # get EPA tax label
        ranks, lws = self.classify_seq(place)
        l1out_ass[seq_name] = (ranks, lws)
        # check if they match
        mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)
        # cross-check with higher rank mislabels
        if self.cfg.ranktest and mis_rec:
            rank_conf = 0
            for lvl in range(2, len(orig_ranks)):
                tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
                if tax_path in self.misrank_conf_map:
                    rank_conf = max(rank_conf, self.misrank_conf_map[tax_path])
            mis_rec['rank_conf'] = rank_conf
        seq_count += 1

    self.write_assignments(l1out_ass, final=False)

    return seq_count
def test_subst_ranks(self):
    """subst_synonyms() must replace a rank name by its mapped synonym."""
    testfile_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testfiles")
    tax_fname = os.path.join(testfile_dir, "test.tax")
    tax = Taxonomy("", tax_fname)

    old_ranks = tax.get_seq_ranks("WgeSangu")
    self.assertEqual(old_ranks[-2], 'Sneathia')

    syn_map = {'Sneathia': 'Sebaldella'}
    tax.subst_synonyms(syn_map)

    # check the lineage fetched AFTER the substitution; asserting on
    # old_ranks would only pass if the list happened to be updated in place
    new_ranks = tax.get_seq_ranks("WgeSangu")
    self.assertEqual(new_ranks[-2], 'Sebaldella')
def setUp(self):
    """Create an InputValidator over the test taxonomy and alignment."""
    cfg = EpacTrainerConfig()
    cfg.debug = True

    here = os.path.dirname(os.path.abspath(__file__))
    testfile_dir = os.path.join(here, "testfiles")
    tax_fname = os.path.join(testfile_dir, "test.tax")
    phy_fname = os.path.join(testfile_dir, "test.phy")

    taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX, tax_fname)
    alignment = SeqGroup(sequences=phy_fname, format="phylip")
    self.inval = InputValidator(cfg, taxonomy, alignment, False)

    # expectations checked by the individual tests
    self.expected_mis_ids = ["Missing1", "Missing2"]
    self.expected_dups = ["DupSeq(01)", "DupSeq02"]
    merged = []
    for seq_id in self.expected_dups:
        merged.append(self.inval.taxonomy.seq_rank_id(seq_id))
    self.expected_merges = merged
class TaxTreeHelperTests(unittest.TestCase):
    """Tests for TaxTreeHelper: outgroup detection and branch labeling."""

    def setUp(self):
        """Build the taxonomy, tree helper and expected outgroup from the test files."""
        self.testfile_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "testfiles")
        self.tax_fname = os.path.join(self.testfile_dir, "test_clean.tax")
        self.taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX, self.tax_fname)
        tax_map = self.taxonomy.get_map()
        cfg = EpacConfig()
        self.taxtree_helper = TaxTreeHelper(cfg, tax_map)
        outgr_fname = os.path.join(self.testfile_dir, "outgroup.nw")
        self.expected_outgr = Tree(outgr_fname)

    def tearDown(self):
        """Drop the references created in setUp."""
        self.taxonomy = None
        self.taxtree_helper = None

    def test_outgroup(self):
        """The outgroup derived from the rooted tree must match the stored one."""
        mfu_tree_fname = os.path.join(self.testfile_dir, "taxtree.nw")
        mfu_tree = Tree(mfu_tree_fname)
        self.taxtree_helper.set_mf_rooted_tree(mfu_tree)
        outgr = self.taxtree_helper.get_outgroup()
        self.assertEqual(outgr.get_leaf_names(),
                         self.expected_outgr.get_leaf_names())

    def test_branch_labeling(self):
        """The branch-id-to-taxonomy map must match the stored expectation."""
        bfu_tree_fname = os.path.join(self.testfile_dir, "resolved_tree.nw")
        bfu_tree = Tree(bfu_tree_fname)

        # expected map file: branch id, rank id, rank difference, branch length
        map_fname = os.path.join(self.testfile_dir, "bid_tax_map.txt")
        self.expected_map = {}
        with open(map_fname) as inf:
            for line in inf:
                bid, rank_id, rdiff, brlen = line.strip().split("\t")
                self.expected_map[bid] = (rank_id, int(rdiff), float(brlen))

        self.taxtree_helper.set_outgroup(self.expected_outgr)
        self.taxtree_helper.set_bf_unrooted_tree(bfu_tree)
        bid_tax_map = self.taxtree_helper.get_bid_taxonomy_map()

        self.assertEqual(len(bid_tax_map), 2 * len(bfu_tree) - 3)
        # .items() instead of .iterkeys(): works on both Python 2 and 3
        for bid, e_rec in self.expected_map.items():
            rec = bid_tax_map[bid]
            self.assertEqual(e_rec[0], rec[0])
            self.assertEqual(e_rec[1], rec[1])
            self.assertAlmostEqual(e_rec[2], rec[2], 6)
class TaxTreeHelperTests(unittest.TestCase):
    """Tests for TaxTreeHelper: outgroup detection and branch labeling."""

    def setUp(self):
        """Build the taxonomy, tree helper and expected outgroup from the test files."""
        self.testfile_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "testfiles")
        self.tax_fname = os.path.join(self.testfile_dir, "test_clean.tax")
        self.taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX, self.tax_fname)
        tax_map = self.taxonomy.get_map()
        cfg = EpacConfig()
        self.taxtree_helper = TaxTreeHelper(cfg, tax_map)
        outgr_fname = os.path.join(self.testfile_dir, "outgroup.nw")
        self.expected_outgr = Tree(outgr_fname)

    def tearDown(self):
        """Drop the references created in setUp."""
        self.taxonomy = None
        self.taxtree_helper = None

    def test_outgroup(self):
        """The outgroup derived from the rooted tree must match the stored one."""
        mfu_tree_fname = os.path.join(self.testfile_dir, "taxtree.nw")
        mfu_tree = Tree(mfu_tree_fname)
        self.taxtree_helper.set_mf_rooted_tree(mfu_tree)
        outgr = self.taxtree_helper.get_outgroup()
        self.assertEqual(outgr.get_leaf_names(),
                         self.expected_outgr.get_leaf_names())

    def test_branch_labeling(self):
        """The branch-id-to-taxonomy map must match the stored expectation."""
        bfu_tree_fname = os.path.join(self.testfile_dir, "resolved_tree.nw")
        bfu_tree = Tree(bfu_tree_fname)

        # expected map file: branch id, rank id, rank difference, branch length
        map_fname = os.path.join(self.testfile_dir, "bid_tax_map.txt")
        self.expected_map = {}
        with open(map_fname) as inf:
            for line in inf:
                bid, rank_id, rdiff, brlen = line.strip().split("\t")
                self.expected_map[bid] = (rank_id, int(rdiff), float(brlen))

        self.taxtree_helper.set_outgroup(self.expected_outgr)
        self.taxtree_helper.set_bf_unrooted_tree(bfu_tree)
        bid_tax_map = self.taxtree_helper.get_bid_taxonomy_map()

        self.assertEqual(len(bid_tax_map), 2 * len(bfu_tree) - 3)
        # .items() instead of .iterkeys(): works on both Python 2 and 3
        for bid, e_rec in self.expected_map.items():
            rec = bid_tax_map[bid]
            self.assertEqual(e_rec[0], rec[0])
            self.assertEqual(e_rec[1], rec[1])
            self.assertAlmostEqual(e_rec[2], rec[2], 6)
def run_leave_seq_out_test(self):
    """Run the leave-one-sequence-out test over all placements.

    Placements are read from ``cfg.jplace_fname`` when it is set (a
    directory is expanded to ``*.jplace``; otherwise the value is used as a
    glob mask), or else produced by an EPA run in "l1o_seq" mode. Every
    placement is classified and compared with the sequence's lineage from
    the taxonomic tree; with ``cfg.ranktest`` on, each mislabel record also
    receives the highest confidence found in ``self.misrank_conf_map`` for
    its lineage prefixes (levels 2 and up). All assignments are finally
    passed to ``self.write_assignments(..., final=False)``.

    Returns:
        int: number of placements processed.
    """
    job_name = self.cfg.subst_name("l1out_seq_%NAME%")
    placements = []
    if self.cfg.jplace_fname:
        if os.path.isdir(self.cfg.jplace_fname):
            jplace_fmask = os.path.join(self.cfg.jplace_fname, '*.jplace')
        else:
            jplace_fmask = self.cfg.jplace_fname

        for jplace_fname in glob.glob(jplace_fmask):
            placements += EpaJsonParser(jplace_fname).get_placement()

        # use the instance's config; the bare name ``config`` is not defined
        # in this scope and every other call here goes through self.cfg
        self.cfg.log.debug("Loaded %d placements from %s\n", len(placements), jplace_fmask)
    else:
        jp = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname,
                                self.optmod_fname, mode="l1o_seq")
        placements = jp.get_placement()
        # NOTE(review): nesting reconstructed from a flattened source; the
        # jplace copy is assumed to apply only when EPA was actually run
        if self.cfg.output_interim_files:
            out_jplace_fname = self.cfg.out_fname("%NAME%.l1out_seq.jplace")
            self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True, mode="l1o_seq")

    seq_count = 0
    l1out_ass = {}
    for place in placements:
        seq_name = place["n"][0]
        # get original taxonomic label
        orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)
        # get EPA tax label
        ranks, lws = self.classify_seq(place)
        l1out_ass[seq_name] = (ranks, lws)
        # check if they match
        mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)
        # cross-check with higher rank mislabels
        if self.cfg.ranktest and mis_rec:
            rank_conf = 0
            for lvl in range(2, len(orig_ranks)):
                tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
                if tax_path in self.misrank_conf_map:
                    rank_conf = max(rank_conf, self.misrank_conf_map[tax_path])
            mis_rec['rank_conf'] = rank_conf
        seq_count += 1

    self.write_assignments(l1out_ass, final=False)

    return seq_count
class RefTreeBuilder:
    """Builds a reference tree from a taxonomy and an alignment and stores it as JSON.

    The public entry point is :meth:`build_ref_tree`; the remaining methods
    are the individual pipeline steps it calls in order.
    """

    def __init__(self, config):
        """Store the config and derive job names and temporary file names from it."""
        self.cfg = config
        self.mfresolv_job_name = self.cfg.subst_name("mfresolv_%NAME%")
        self.epalbl_job_name = self.cfg.subst_name("epalbl_%NAME%")
        self.optmod_job_name = self.cfg.subst_name("optmod_%NAME%")
        self.raxml_wrapper = RaxmlWrapper(config)

        self.outgr_fname = self.cfg.tmp_fname("%NAME%_outgr.tre")
        self.reftree_mfu_fname = self.cfg.tmp_fname("%NAME%_mfu.tre")
        self.reftree_bfu_fname = self.cfg.tmp_fname("%NAME%_bfu.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.lblalign_fname = self.cfg.tmp_fname("%NAME%_lblq.fa")
        self.reftree_lbl_fname = self.cfg.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = self.cfg.tmp_fname("%NAME%_tax.tre")
        self.brmap_fname = self.cfg.tmp_fname("%NAME%_map.txt")

    def load_alignment(self):
        """Load the input alignment, trying each supported format in turn.

        Calls ``cfg.exit_user_error`` if no format could be parsed.
        """
        in_file = self.cfg.align_fname
        self.input_seqs = None
        formats = ["fasta", "phylip_relaxed", "iphylip_relaxed", "phylip", "iphylip"]
        for fmt in formats:
            try:
                self.input_seqs = SeqGroup(sequences=in_file, format=fmt)
                break
            except Exception:
                # Format probing: a parse failure just means "try the next one".
                # `Exception` (not a bare except) lets Ctrl-C / SystemExit through.
                self.cfg.log.debug("Guessing input format: not " + fmt)
        if self.input_seqs is None:
            self.cfg.exit_user_error(
                "Invalid input file format: %s\nThe supported input formats are fasta and phylip" % in_file)

    def validate_taxonomy(self):
        """Run the input validator on the loaded taxonomy and alignment."""
        self.input_validator = InputValidator(self.cfg, self.taxonomy, self.input_seqs)
        self.input_validator.validate()

    def build_multif_tree(self):
        """Build the multifurcating taxonomy tree and remember its sequence ids."""
        c = self.cfg
        tb = TaxTreeBuilder(c, self.taxonomy)
        (t, ids) = tb.build(c.reftree_min_rank, c.reftree_max_seqs_per_leaf,
                            c.reftree_clades_to_include, c.reftree_clades_to_ignore)
        self.reftree_ids = frozenset(ids)
        self.reftree_size = len(ids)
        self.reftree_multif = t

        # IMPORTANT: select GAMMA or CAT model based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)

        if self.cfg.debug:
            refseq_fname = self.cfg.tmp_fname("%NAME%_seq_ids.txt")
            # list of sequence ids which comprise the reference tree
            with open(refseq_fname, "w") as f:
                for sid in ids:
                    f.write("%s\n" % sid)

            # original tree with taxonomic ranks as internal node labels
            reftax_fname = self.cfg.tmp_fname("%NAME%_mfu_tax.tre")
            t.write(outfile=reftax_fname, format=8)

    def export_ref_alignment(self):
        """Write the reference alignment to a temporary FASTA-style file.

        The input alignment is transformed in the following way:
        1. Sequences which are not part of the reference tree are filtered out
        2. The reference sequence name prefix is added to every name
        """
        self.refalign_fname = self.cfg.tmp_fname("%NAME%_matrix.afa")
        with open(self.refalign_fname, "w") as fout:
            for name, seq, comment, sid in self.input_seqs.iter_entries():
                seq_name = EpacConfig.REF_SEQ_PREFIX + name
                if seq_name in self.input_validator.corr_seqid:
                    seq_name = self.input_validator.corr_seqid[seq_name]
                if seq_name in self.reftree_ids:
                    fout.write(">" + seq_name + "\n" + seq + "\n")

        # we do not need the original alignment anymore, so free its memory
        self.input_seqs = None

    def export_ref_taxonomy(self):
        """Collect the taxonomy entries of sequences present in the reference tree."""
        self.taxonomy_map = {}
        for sid, ranks in self.taxonomy.iteritems():
            if sid in self.reftree_ids:
                self.taxonomy_map[sid] = ranks

        if self.cfg.debug:
            tax_fname = self.cfg.tmp_fname("%NAME%_tax.txt")
            with open(tax_fname, "w") as fout:
                for sid, ranks in self.taxonomy_map.iteritems():
                    ranks_str = self.taxonomy.seq_lineage_str(sid)
                    fout.write(sid + "\t" + ranks_str + "\n")

    def save_rooting(self):
        """Save the outgroup for later re-rooting and write the tree for RAxML."""
        rt = self.reftree_multif

        tax_map = self.taxonomy.get_map()
        self.taxtree_helper = TaxTreeHelper(self.cfg, tax_map)
        self.taxtree_helper.set_mf_rooted_tree(rt)
        outgr = self.taxtree_helper.get_outgroup()
        outgr_size = len(outgr.get_leaves())
        outgr.write(outfile=self.outgr_fname, format=9)
        self.reftree_outgroup = outgr
        self.cfg.log.debug("Outgroup for rooting was saved to: %s, outgroup size: %d",
                           self.outgr_fname, outgr_size)

        # remove unifurcation at the root
        if len(rt.children) == 1:
            rt = rt.children[0]

        # now we can safely unroot the tree and remove internal node labels
        # to make it suitable for raxml
        rt.write(outfile=self.reftree_mfu_fname, format=9)

    def resolve_multif(self):
        """Use RAxML to turn the multifurcating tree into a strictly bifurcating one.

        Depending on ``cfg.reopt_model`` the model parameters are either taken
        from the resolution run or re-optimized in a second RAxML call. A
        missing RAxML result is reported via ``cfg.exit_fatal_error``.
        """
        self.cfg.log.debug("\nReducing the alignment: \n")
        self.reduced_refalign_fname = self.raxml_wrapper.reduce_alignment(self.refalign_fname)

        self.cfg.log.debug("\nConstrained ML inference: \n")
        raxml_params = ["-s", self.reduced_refalign_fname, "-g", self.reftree_mfu_fname,
                        "--no-seq-check", "-N", str(self.cfg.rep_num)]
        if self.cfg.mfresolv_method == "fast":
            raxml_params += ["-D"]
        elif self.cfg.mfresolv_method == "ultrafast":
            raxml_params += ["-f", "e"]

        if self.cfg.restart and self.raxml_wrapper.result_exists(self.mfresolv_job_name):
            self.invocation_raxml_multif = self.raxml_wrapper.get_invocation_str(self.mfresolv_job_name)
            self.cfg.log.debug("\nUsing existing ML tree found in: %s\n",
                               self.raxml_wrapper.result_fname(self.mfresolv_job_name))
        else:
            self.invocation_raxml_multif = self.raxml_wrapper.run(self.mfresolv_job_name, raxml_params)
            if self.cfg.mfresolv_method == "ultrafast":
                # make the result tree available under the "best tree" name
                self.raxml_wrapper.copy_result_tree(
                    self.mfresolv_job_name,
                    self.raxml_wrapper.besttree_fname(self.mfresolv_job_name))

        if self.raxml_wrapper.besttree_exists(self.mfresolv_job_name):
            if not self.cfg.reopt_model:
                self.raxml_wrapper.copy_best_tree(self.mfresolv_job_name, self.reftree_bfu_fname)
                self.raxml_wrapper.copy_optmod_params(self.mfresolv_job_name, self.optmod_fname)
                self.invocation_raxml_optmod = ""
                job_name = self.mfresolv_job_name
            else:
                bfu_fname = self.raxml_wrapper.besttree_fname(self.mfresolv_job_name)
                job_name = self.optmod_job_name

                # RAxML call to optimize model parameters and write them down
                # to the binary model file
                self.cfg.log.debug("\nOptimizing model parameters: \n")
                raxml_params = ["-f", "e", "-s", self.reduced_refalign_fname,
                                "-t", bfu_fname, "--no-seq-check"]
                if self.cfg.raxml_model.startswith("GTRCAT") and not self.cfg.compress_patterns:
                    raxml_params += ["-H"]
                if self.cfg.restart and self.raxml_wrapper.result_exists(self.optmod_job_name):
                    self.invocation_raxml_optmod = self.raxml_wrapper.get_invocation_str(self.optmod_job_name)
                    self.cfg.log.debug("\nUsing existing optimized tree and parameters found in: %s\n",
                                       self.raxml_wrapper.result_fname(self.optmod_job_name))
                else:
                    self.invocation_raxml_optmod = self.raxml_wrapper.run(self.optmod_job_name, raxml_params)
                if self.raxml_wrapper.result_exists(self.optmod_job_name):
                    self.raxml_wrapper.copy_result_tree(self.optmod_job_name, self.reftree_bfu_fname)
                    self.raxml_wrapper.copy_optmod_params(self.optmod_job_name, self.optmod_fname)
                else:
                    errmsg = "RAxML run failed (model optimization), please examine the log for details: %s" \
                        % self.raxml_wrapper.make_raxml_fname("output", self.optmod_job_name)
                    self.cfg.exit_fatal_error(errmsg)

            if self.cfg.raxml_model.startswith("GTRCAT"):
                mod_name = "CAT"
            else:
                mod_name = "GAMMA"
            self.reftree_loglh = self.raxml_wrapper.get_tree_lh(job_name, mod_name)
            self.cfg.log.debug("\n%s-based logLH of the reference tree: %f\n" % (mod_name, self.reftree_loglh))
        else:
            errmsg = "RAxML run failed (mutlifurcation resolution), please examine the log for details: %s" \
                % self.raxml_wrapper.make_raxml_fname("output", self.mfresolv_job_name)
            self.cfg.exit_fatal_error(errmsg)

    def load_reduced_refalign(self):
        """Load the reduced reference alignment (fasta or relaxed phylip).

        Calls ``cfg.exit_fatal_error`` if the file cannot be parsed in either format.
        """
        # Must be initialised up front: if every format fails, the check
        # below would otherwise raise AttributeError instead of reporting
        # the intended fatal error.
        self.reduced_refalign_seqs = None
        formats = ["fasta", "phylip_relaxed"]
        for fmt in formats:
            try:
                self.reduced_refalign_seqs = SeqGroup(sequences=self.reduced_refalign_fname, format=fmt)
                break
            except Exception:
                # deliberate best-effort probing: try the next format
                pass
        if self.reduced_refalign_seqs is None:
            errmsg = "FATAL ERROR: Invalid input file format in %s! (load_reduced_refalign)" \
                % self.reduced_refalign_fname
            self.cfg.exit_fatal_error(errmsg)

    def epa_branch_labeling(self):
        """Dummy EPA run to label the branches of the reference tree.

        The branch labels are needed to build a mapping to taxonomic ranks.
        """
        # create alignment with dummy query seq
        self.refalign_width = len(self.reduced_refalign_seqs.get_seqbyid(0))
        self.reduced_refalign_seqs.write(format="fasta", outfile=self.lblalign_fname)

        with open(self.lblalign_fname, "a") as fout:
            fout.write(">" + "DUMMY131313" + "\n")
            fout.write("A" * self.refalign_width + "\n")

        # TODO always load model regardless of the config file settings?
        epa_result = self.raxml_wrapper.run_epa(self.epalbl_job_name, self.lblalign_fname,
                                                self.reftree_bfu_fname, self.optmod_fname,
                                                mode="epa_mp")
        self.reftree_lbl_str = epa_result.get_std_newick_tree()
        self.raxml_version = epa_result.get_raxml_version()
        self.invocation_raxml_epalbl = epa_result.get_raxml_invocation()

        if not self.raxml_wrapper.epa_result_exists(self.epalbl_job_name):
            errmsg = "RAxML EPA run failed, please examine the log for details: %s" \
                % self.raxml_wrapper.make_raxml_fname("output", self.epalbl_job_name)
            self.cfg.exit_fatal_error(errmsg)

    def epa_post_process(self):
        """Derive the taxonomic tree and the branch-id -> ranks map from the labelled tree."""
        lbl_tree = Tree(self.reftree_lbl_str)
        self.taxtree_helper.set_bf_unrooted_tree(lbl_tree)
        self.reftree_tax = self.taxtree_helper.get_tax_tree()
        self.bid_ranks_map = self.taxtree_helper.get_bid_taxonomy_map()

        if self.cfg.debug:
            self.reftree_tax.write(outfile=self.reftree_tax_fname, format=3)
            with open(self.reftree_lbl_fname, "w") as outf:
                outf.write(self.reftree_lbl_str)
            with open(self.brmap_fname, "w") as outf:
                for bid, br_rec in self.bid_ranks_map.iteritems():
                    outf.write("%s\t%s\t%d\t%f\n" % (bid, br_rec[0], br_rec[1], br_rec[2]))

    def calc_node_heights(self):
        """Calculate node heights on the reference tree.

        The heights are used to define the branch-length cutoff during the
        classification step. Algorithm is as follows:
            Tip node or node resolved to Species level: height = 1
            Inner node resolved to Genus or above: height = min(left_height, right_height) + 1
        """
        nh_map = {}
        dummy_added = False
        for node in self.reftree_tax.traverse("postorder"):
            if not node.is_root():
                if not hasattr(node, "B"):
                    # In a rooted tree, there is always one more node/branch than in unrooted one
                    # That's why one branch will be always not EPA-labelled after the rooting
                    if not dummy_added:
                        node.B = "DDD"
                        dummy_added = True
                        species_rank = Taxonomy.EMPTY_RANK
                    else:
                        errmsg = "FATAL ERROR: More than one tree branch without EPA label (calc_node_heights)"
                        self.cfg.exit_fatal_error(errmsg)
                else:
                    # NOTE(review): epa_post_process writes these records as
                    # (rank id, int, float), so [-1] looks like the float field
                    # rather than a species rank -- confirm the record layout.
                    species_rank = self.bid_ranks_map[node.B][-1]
                bid = node.B
                if node.is_leaf() or species_rank != Taxonomy.EMPTY_RANK:
                    nh_map[bid] = 1
                else:
                    lchild = node.children[0]
                    rchild = node.children[1]
                    nh_map[bid] = min(nh_map[lchild.B], nh_map[rchild.B]) + 1

        # remove heights for dummy nodes, since there won't be any placements on them
        if dummy_added:
            del nh_map["DDD"]

        self.node_height_map = nh_map

    def __get_all_rank_names(self, root):
        """Return the set of all rank names found on nodes of the subtree under ``root``."""
        rnames = set([])
        for node in root.traverse("postorder"):
            rnames.update(node.ranks)
        return rnames

    def mono_index(self):
        """Return the rank names shared by the two sides of the taxonomic tree.

        Unifurcations at the root are skipped first. If the first branching
        node does not have exactly two children, an error is printed and an
        empty set is returned.
        """
        children = self.reftree_tax.children
        # descend through unifurcations until the first real split
        while len(children) == 1:
            children = children[0].children
        if len(children) == 2:
            left = children[0]
            right = children[1]
            lset = self.__get_all_rank_names(left)
            rset = self.__get_all_rank_names(right)
            return lset & rset
        else:
            print("Error: input tree not birfurcating")
            return set([])

    def build_hmm_profile(self, json_builder):
        """Build a HMMER profile from the reduced alignment and add it to ``json_builder``."""
        print("Building the HMMER profile...\n")

        # this workaround is needed because RAxML outputs the reduced
        # alignment in relaxed PHYLIP format, which is not supported by HMMER
        refalign_fasta = self.cfg.tmp_fname("%NAME%_ref_reduced.fa")
        self.reduced_refalign_seqs.write(outfile=refalign_fasta)

        hmm = hmmer(self.cfg, refalign_fasta)
        fprofile = hmm.build_hmm_profile()

        json_builder.set_hmm_profile(fprofile)

    def write_json(self):
        """Assemble all results into a RefJsonBuilder and dump it to ``cfg.refjson_fname``."""
        jw = RefJsonBuilder()

        jw.set_branch_tax_map(self.bid_ranks_map)
        jw.set_tree(self.reftree_lbl_str)
        jw.set_outgroup(self.reftree_outgroup)
        jw.set_ratehet_model(self.cfg.raxml_model)
        jw.set_tax_tree(self.reftree_multif)
        jw.set_pattern_compression(self.cfg.compress_patterns)
        jw.set_taxcode(self.cfg.taxcode_name)

        jw.set_merged_ranks_map(self.input_validator.merged_ranks)
        # the validator maps original -> corrected; the JSON stores the reverse
        corr_ranks_reverse = dict((reversed(item) for item in self.input_validator.corr_ranks.items()))
        jw.set_corr_ranks_map(corr_ranks_reverse)
        corr_seqid_reverse = dict((reversed(item) for item in self.input_validator.corr_seqid.items()))
        jw.set_corr_seqid_map(corr_seqid_reverse)

        mdata = {
            "ref_tree_size": self.reftree_size,
            "ref_alignment_width": self.refalign_width,
            "raxml_version": self.raxml_version,
            "timestamp": str(datetime.datetime.now()),
            "invocation_epac": self.invocation_epac,
            "invocation_raxml_multif": self.invocation_raxml_multif,
            "invocation_raxml_optmod": self.invocation_raxml_optmod,
            "invocation_raxml_epalbl": self.invocation_raxml_epalbl,
            "reftree_loglh": self.reftree_loglh
        }
        jw.set_metadata(mdata)

        seqs = self.reduced_refalign_seqs.get_entries()
        jw.set_sequences(seqs)

        if not self.cfg.no_hmmer:
            self.build_hmm_profile(jw)

        orig_tax = self.taxonomy_map
        jw.set_origin_taxonomy(orig_tax)

        self.cfg.log.debug("Calculating the speciation rate...\n")
        tp = tree_param(tree=self.reftree_lbl_str, origin_taxonomy=orig_tax)
        jw.set_rate(tp.get_speciation_rate_fast())
        jw.set_nodes_height(self.node_height_map)

        jw.set_binary_model(self.optmod_fname)

        self.cfg.log.debug("Writing down the reference file...\n")
        jw.dump(self.cfg.refjson_fname)

    def build_ref_tree(self):
        """Top-level function to build a reference tree: runs all steps in order."""
        self.cfg.log.info("=> Loading taxonomy from file: %s ...\n", self.cfg.taxonomy_fname)
        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_fname=self.cfg.taxonomy_fname)
        self.cfg.log.info("==> Loading reference alignment from file: %s ...\n", self.cfg.align_fname)
        self.load_alignment()
        self.cfg.log.info("===> Validating taxonomy and alignment ...\n")
        self.validate_taxonomy()
        self.cfg.log.info("====> Building a multifurcating tree from taxonomy with %d seqs ...\n",
                          self.taxonomy.seq_count())
        self.build_multif_tree()
        self.cfg.log.info("=====> Building the reference alignment ...\n")
        self.export_ref_alignment()
        self.export_ref_taxonomy()
        self.cfg.log.info("======> Saving the outgroup for later re-rooting ...\n")
        self.save_rooting()
        self.cfg.log.info(
            "=======> Resolving multifurcation: choosing the best topology from %d independent RAxML runs ...\n"
            % self.cfg.rep_num)
        self.resolve_multif()
        self.load_reduced_refalign()
        self.cfg.log.info("========> Calling RAxML-EPA to obtain branch labels ...\n")
        self.epa_branch_labeling()
        self.cfg.log.info("=========> Post-processing the EPA tree (re-rooting, taxonomic labeling etc.) ...\n")
        self.epa_post_process()
        self.calc_node_heights()

        self.cfg.log.debug("\n==========> Checking branch labels ...")
        self.cfg.log.debug("shared rank names before training: %s", repr(self.taxonomy.get_common_ranks()))
        self.cfg.log.debug("shared rank names after training: %s\n", repr(self.mono_index()))

        self.cfg.log.info("==========> Saving the reference JSON file: %s\n" % self.cfg.refjson_fname)
        self.write_json()
def test_rank_uid(self):
    """Splitting a sequence's rank uid must give back that sequence's ranks."""
    taxonomy = self.taxonomy
    for seq_id in taxonomy.get_map().iterkeys():
        expected_ranks = taxonomy.get_seq_ranks(seq_id)
        rank_uid = taxonomy.seq_rank_id(seq_id)
        self.assertEqual(expected_ranks, Taxonomy.split_rank_uid(rank_uid))
def test_load(self):
    """Loaded rank maps must equal the expected dicts, with and without prefix."""
    loaded_map = self.taxonomy.seq_ranks_map
    self.assertEqual(self.TAX_DICT, loaded_map)

    # same file, loaded again with the reference sequence prefix
    tax_with_prefix = Taxonomy(EpacConfig.REF_SEQ_PREFIX, self.tax_fname)
    self.assertEqual(self.PREFIXED_TAX_DICT, tax_with_prefix.seq_ranks_map)
    tax_with_prefix = None
def run_leave_subtree_out_test(self):
    """Run the leave-one-rank-out test.

    Collects the taxonomic-tree clades that pass the level and size filters,
    writes their tips to a subtree list file, runs EPA in ``l1o_subtree``
    mode and checks each resulting placement against the clade's parent
    ranks. Confidences of detected rank mislabels are stored in
    ``self.misrank_conf_map``.

    Returns the number of EPA results processed (0 if no clade qualified).
    """
    job_name = self.cfg.subst_name("l1out_rank_%NAME%")

    # --- select candidate clades -------------------------------------
    rank_tips = {}
    rank_parent = {}
    max_clade_size = self.reftree_size - 4
    for node in self.tax_tree.traverse("postorder"):
        if node.is_leaf() or node.is_root():
            continue
        tax_path = node.name
        node_ranks = Taxonomy.split_rank_uid(tax_path)
        if Taxonomy.lowest_assigned_rank_level(node_ranks) < 2:
            continue
        parent_ranks = Taxonomy.split_rank_uid(node.up.name)
        if Taxonomy.lowest_assigned_rank_level(parent_ranks) < 1:
            continue
        clade_size = len(node.get_leaf_names())
        if clade_size < 2 or clade_size > max_clade_size:
            continue
        rank_tips[tax_path] = node.get_leaf_names()
        rank_parent[tax_path] = parent_ranks

    subtree_list = rank_tips.items()
    if len(subtree_list) == 0:
        return 0

    # --- one line of space-separated tip names per clade -------------
    subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
    with open(subtree_list_file, "w") as fout:
        for _, clade_tips in subtree_list:
            fout.write(" ".join(clade_tips) + "\n")

    jp_list = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname,
                                 self.optmod_fname, mode="l1o_subtree",
                                 subtree_fname=subtree_list_file)

    # --- evaluate placements; result i belongs to subtree_list[i] ----
    subtree_count = 0
    for jp in jp_list:
        for place in jp.get_placement():
            ranks, lws = self.classify_seq(place)
            tax_path = subtree_list[subtree_count][0]
            orig_ranks = Taxonomy.split_rank_uid(tax_path)
            rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
            rank_prefix = self.guess_rank_level_name(orig_ranks, rank_level)[0]
            rank_name = orig_ranks[rank_level]
            if not rank_name.startswith(rank_prefix):
                rank_name = rank_prefix + rank_name
            mis_rec = self.check_rank_tax_labels(rank_name, rank_parent[tax_path], ranks, lws)
            if mis_rec:
                self.misrank_conf_map[tax_path] = mis_rec['conf']
        subtree_count += 1

    return subtree_count
def run_leave_subtree_out_test(self):
    """Leave-one-rank-out test over the clades of the taxonomic tree.

    Qualifying clades (inner, non-root nodes passing the rank-level and size
    filters) are listed in a temporary file, EPA is run in ``l1o_subtree``
    mode, and every placement is compared with the parent ranks of the
    corresponding clade. Mislabel confidences go to ``self.misrank_conf_map``.

    Returns the number of EPA results consumed, or 0 when no clade qualifies.
    """
    job_name = self.cfg.subst_name("l1out_rank_%NAME%")

    tips_by_path = {}
    parent_by_path = {}
    for clade in self.tax_tree.traverse("postorder"):
        if clade.is_leaf() or clade.is_root():
            continue
        clade_path = clade.name
        clade_lvl = Taxonomy.lowest_assigned_rank_level(Taxonomy.split_rank_uid(clade_path))
        if clade_lvl < 2:
            continue
        up_ranks = Taxonomy.split_rank_uid(clade.up.name)
        up_lvl = Taxonomy.lowest_assigned_rank_level(up_ranks)
        if up_lvl < 1:
            continue
        n_tips = len(clade.get_leaf_names())
        too_small = n_tips < 2
        if too_small or n_tips > self.reftree_size - 4:
            continue
        tips_by_path[clade_path] = clade.get_leaf_names()
        parent_by_path[clade_path] = up_ranks

    subtree_list = tips_by_path.items()
    if len(subtree_list) == 0:
        return 0

    subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
    with open(subtree_list_file, "w") as fout:
        for entry in subtree_list:
            fout.write("%s\n" % " ".join(entry[1]))

    jp_list = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname,
                                 self.optmod_fname, mode="l1o_subtree",
                                 subtree_fname=subtree_list_file)

    subtree_count = 0
    for jp in jp_list:
        placements = jp.get_placement()
        for place in placements:
            ranks, lws = self.classify_seq(place)
            # the current EPA result corresponds to the clade at the same index
            tax_path = subtree_list[subtree_count][0]
            orig_ranks = Taxonomy.split_rank_uid(tax_path)
            rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
            rank_prefix = self.guess_rank_level_name(orig_ranks, rank_level)[0]
            rank_name = orig_ranks[rank_level]
            if not rank_name.startswith(rank_prefix):
                rank_name = rank_prefix + rank_name
            parent_ranks = parent_by_path[tax_path]
            mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks, ranks, lws)
            if mis_rec:
                self.misrank_conf_map[tax_path] = mis_rec['conf']
        subtree_count += 1

    return subtree_count
class LeaveOneTest:
    """Leave-one-out mislabel detection on top of a reference JSON file.

    Typical use: construct with a config, call :meth:`load_refjson`, then
    :meth:`run_test`. Detected sequence mislabels are collected in
    ``self.mislabels`` and rank mislabels in ``self.rank_mislabels``.
    """

    def __init__(self, config):
        """Store the config, derive file names and refuse to overwrite old output."""
        self.cfg = config

        self.mis_fname = self.cfg.out_fname("%NAME%.mis")
        self.premis_fname = self.cfg.out_fname("%NAME%.premis")
        self.misrank_fname = self.cfg.out_fname("%NAME%.misrank")
        self.stats_fname = self.cfg.out_fname("%NAME%.stats")

        if os.path.isfile(self.mis_fname):
            print("\nERROR: Output file already exists: %s" % self.mis_fname)
            print("Please specify a different job name using -n or remove old output files.")
            self.cfg.exit_user_error()

        self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
        self.reftree_lbl_fname = config.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = config.tmp_fname("%NAME%_tax.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")

        self.mislabels = []
        self.mislabels_cnt = []
        self.rank_mislabels = []
        self.rank_mislabels_cnt = []
        self.misrank_conf_map = {}

    def write_bid_tax_map(self, bid_tax_map, final):
        """In debug mode, dump the branch-id -> taxonomy map to a temporary file."""
        if self.cfg.debug:
            fname_suffix = "final" if final else "l1out"
            bid_fname = self.cfg.tmp_fname("%NAME%_" + "bid_tax_map_%s.txt" % fname_suffix)
            with open(bid_fname, "w") as outf:
                for bid, bid_rec in bid_tax_map.iteritems():
                    outf.write("%s\t%s\t%d\t%f\n" % (bid, bid_rec[0], bid_rec[1], bid_rec[2]))

    def write_assignments(self, assign_map, final):
        """In debug mode, dump the per-sequence (ranks, weights) assignments to a temporary file."""
        if self.cfg.debug:
            fname_suffix = "final" if final else "l1out"
            assign_fname = self.cfg.tmp_fname("%NAME%_" + "taxassign_%s.txt" % fname_suffix)
            with open(assign_fname, "w") as outf:
                for seq_name in assign_map.iterkeys():
                    ranks, lws = assign_map[seq_name]
                    outf.write("%s\t%s\t%s\n" % (seq_name, ";".join(ranks),
                                                   ";".join(["%.3f" % l for l in lws])))

    def load_refjson(self, refjson_fname):
        """Parse and validate the reference JSON file and set up all helpers from it."""
        try:
            self.refjson = RefJsonParser(refjson_fname)
        except ValueError:
            self.cfg.exit_user_error("ERROR: Invalid json file format!")

        # validate input json format
        (valid, err) = self.refjson.validate()
        if not valid:
            self.cfg.log.error("ERROR: Parsing reference JSON file failed:\n%s", err)
            self.cfg.exit_user_error()

        self.rate = self.refjson.get_rate()
        self.node_height = self.refjson.get_node_height()
        self.origin_taxonomy = self.refjson.get_origin_taxonomy()
        self.tax_tree = self.refjson.get_tax_tree()
        self.cfg.compress_patterns = self.refjson.get_pattern_compression()

        self.bid_taxonomy_map = self.refjson.get_branch_tax_map()
        if not self.bid_taxonomy_map:
            # old file format (before 1.6), need to rebuild this map from scratch
            th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
            th.set_mf_rooted_tree(self.tax_tree)
            th.set_bf_unrooted_tree(self.refjson.get_reftree())
            self.bid_taxonomy_map = th.get_bid_taxonomy_map()

        self.write_bid_tax_map(self.bid_taxonomy_map, final=False)

        reftree_str = self.refjson.get_raxml_readable_tree()
        self.reftree = Tree(reftree_str)
        self.reftree_size = len(self.reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same
        # rate het. mode as in the ref file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map,
                                                 self.rate, self.node_height)
        self.taxtree_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy, self.tax_tree)

        tax_code_name = self.refjson.get_taxcode()
        self.tax_code = TaxCode(tax_code_name)

        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_map=self.origin_taxonomy)
        self.tax_common_ranks = self.taxonomy.get_common_ranks()

        self.mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
        self.rank_mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS

    def run_epa_trainer(self):
        """Run the trainer and abort if it did not produce the reference JSON file."""
        epa_trainer.run_trainer(self.cfg)

        if not os.path.isfile(self.cfg.refjson_fname):
            self.cfg.log.error("\nBuilding reference tree failed, see error messages above.")
            self.cfg.exit_fatal_error()

    def classify_seq(self, placement):
        """Classify one placement record; returns what the classify helper returns.

        If the record has no placement edges, an error is printed and
        ``None`` is returned implicitly.
        """
        edges = placement["p"]
        if len(edges) > 0:
            return self.classify_helper.classify_seq(edges)
        else:
            print("ERROR: no placements! something is definitely wrong!")

    def _first_mismatch_level(self, orig_ranks, ranks):
        """Return the first level where an assigned rank differs from the original, or -1."""
        for rank_lvl in range(min(len(orig_ranks), len(ranks))):
            if ranks[rank_lvl] != Taxonomy.EMPTY_RANK and ranks[rank_lvl] != orig_ranks[rank_lvl]:
                return rank_lvl
        return -1

    def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
        """Compare original and inferred ranks of a sequence.

        Returns a mislabel record (dict) and appends it to ``self.mislabels``
        when the ranks disagree or ``ranks`` is empty; returns ``None`` otherwise.
        """
        mis_rec = None

        if len(ranks) == 0:
            # no ranks at all: reported as "[NotIngroup]" with full confidence
            mis_rec = {}
            mis_rec['name'] = seq_name
            mis_rec['orig_level'] = -1
            mis_rec['real_level'] = 0
            mis_rec['level_name'] = "[NotIngroup]"
            mis_rec['inv_level'] = -1 * mis_rec['real_level']  # just for sorting
            mis_rec['orig_ranks'] = orig_ranks
            mis_rec['ranks'] = []
            mis_rec['lws'] = [1.0]
            mis_rec['conf'] = mis_rec['lws'][0]
        else:
            mislabel_lvl = self._first_mismatch_level(orig_ranks, ranks)
            if mislabel_lvl >= 0:
                real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
                mis_rec = {}
                mis_rec['name'] = seq_name
                mis_rec['orig_level'] = mislabel_lvl
                mis_rec['real_level'] = real_lvl
                mis_rec['level_name'] = self.tax_code.rank_level_name(real_lvl)[0]
                mis_rec['inv_level'] = -1 * mis_rec['real_level']  # just for sorting
                mis_rec['orig_ranks'] = orig_ranks
                mis_rec['ranks'] = ranks
                mis_rec['lws'] = lws
                mis_rec['conf'] = lws[mislabel_lvl]

        if mis_rec:
            self.mislabels.append(mis_rec)

        return mis_rec

    def filter_mislabels(self):
        """Keep only mislabels whose confidence reaches ``cfg.conf_cutoff``."""
        cutoff = self.cfg.conf_cutoff
        self.mislabels = [rec for rec in self.mislabels if rec['conf'] >= cutoff]

    def check_rank_tax_labels(self, rank_name, orig_ranks, ranks, lws):
        """Compare original and inferred ranks of a whole rank (clade).

        Returns a mislabel record and appends it to ``self.rank_mislabels``
        on disagreement; returns ``None`` otherwise.
        """
        mislabel_lvl = self._first_mismatch_level(orig_ranks, ranks)

        if mislabel_lvl >= 0:
            real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
            mis_rec = {}
            mis_rec['name'] = rank_name
            mis_rec['orig_level'] = mislabel_lvl
            mis_rec['real_level'] = real_lvl
            mis_rec['level_name'] = self.tax_code.rank_level_name(real_lvl)[0]
            mis_rec['inv_level'] = -1 * real_lvl  # just for sorting
            mis_rec['orig_ranks'] = orig_ranks
            mis_rec['ranks'] = ranks
            mis_rec['lws'] = lws
            mis_rec['conf'] = lws[mislabel_lvl]
            self.rank_mislabels.append(mis_rec)
            return mis_rec
        else:
            return None

    def mis_rec_to_string_old(self, mis_rec):
        """Format a mislabel record in the older multi-line layout."""
        lvl = mis_rec['orig_level']
        output = mis_rec['name'] + "\t"
        output += "%s\t%s\t%s\t%.3f\n" % (mis_rec['level_name'], mis_rec['orig_ranks'][lvl],
                                           mis_rec['ranks'][lvl], mis_rec['lws'][lvl])
        output += ";".join(mis_rec['orig_ranks']) + "\n"
        output += ";".join(mis_rec['ranks']) + "\n"
        output += "\t".join(["%.3f" % conf for conf in mis_rec['lws']]) + "\n"
        return output

    def mis_rec_to_string(self, mis_rec):
        """Format a mislabel record as one tab-separated line (no trailing newline).

        Names and ranks are first mapped back through the reference JSON's
        ``get_uncorr_*`` lookups.
        """
        lvl = mis_rec['orig_level']
        uncorr_name = EpacConfig.strip_ref_prefix(self.refjson.get_uncorr_seqid(mis_rec['name']))
        uncorr_orig_ranks = self.refjson.get_uncorr_ranks(mis_rec['orig_ranks'])
        uncorr_ranks = self.refjson.get_uncorr_ranks(mis_rec['ranks'])
        output = uncorr_name + "\t"
        if lvl >= 0:
            output += "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'], uncorr_orig_ranks[lvl],
                                               uncorr_ranks[lvl], mis_rec['lws'][lvl])
        else:
            # negative level: no per-level labels to show
            output += "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'], "NA", "NA", mis_rec['lws'][0])
        output += Taxonomy.lineage_str(uncorr_orig_ranks) + "\t"
        output += Taxonomy.lineage_str(uncorr_ranks) + "\t"
        output += ";".join(["%.3f" % conf for conf in mis_rec['lws']])
        if 'rank_conf' in mis_rec:
            output += "\t%.3f" % mis_rec['rank_conf']
        return output

    def sort_mislabels(self):
        """Sort mislabels (descending by inv_level, conf, name) and count them per level."""
        self.mislabels = sorted(self.mislabels, key=itemgetter('inv_level', 'conf', 'name'), reverse=True)
        for mis_rec in self.mislabels:
            real_lvl = mis_rec["real_level"]
            self.mislabels_cnt[real_lvl] += 1

        if self.cfg.ranktest:
            self.rank_mislabels = sorted(self.rank_mislabels,
                                         key=itemgetter('inv_level', 'conf', 'name'), reverse=True)
            for mis_rec in self.rank_mislabels:
                real_lvl = mis_rec["real_level"]
                self.rank_mislabels_cnt[real_lvl] += 1

    def write_stats(self, toFile=False):
        """Log per-rank mislabel counts; optionally also write them to the stats file."""
        self.cfg.log.info("Mislabeled sequences by rank:")
        rank_sum = 0
        stats = []
        for i, seq_cnt in enumerate(self.mislabels_cnt):
            if i > 0:
                rname = self.tax_code.rank_level_name(i)[0].ljust(12)
            else:
                rname = "[NotIngroup]"
            if seq_cnt > 0:
                output = "%s:\t%d" % (rname, seq_cnt)
                if self.cfg.ranktest:
                    # running total, accumulated only over levels that are printed
                    rank_sum += self.rank_mislabels_cnt[i]
                    output += "\t%d" % rank_sum
                self.cfg.log.info(output)
                stats.append(output)

        if toFile:
            with open(self.stats_fname, "w") as fo_stat:
                for line in stats:
                    fo_stat.write(line + "\n")

    def write_mislabels(self, final=True):
        """Write sequence mislabels (and, for the final pass, rank mislabels and stats).

        ``final=False`` writes only the preliminary sequence mislabel file.
        """
        if final:
            out_fname = self.mis_fname
        else:
            out_fname = self.premis_fname

        with open(out_fname, "w") as fo_all:
            fields = ["SeqID", "MislabeledLevel", "OriginalLabel", "ProposedLabel",
                      "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath",
                      "PerRankConfidence"]
            if self.cfg.ranktest:
                fields += ["HigherRankMisplacedConfidence"]
            header = ";" + "\t".join(fields) + "\n"
            fo_all.write(header)
            if self.cfg.verbose and len(self.mislabels) > 0 and final:
                print("Mislabeled sequences:\n")
                print(header)
            for mis_rec in self.mislabels:
                output = self.mis_rec_to_string(mis_rec) + "\n"
                fo_all.write(output)
                if self.cfg.verbose and final:
                    print(output)

        if not final:
            return

        if self.cfg.ranktest:
            with open(self.misrank_fname, "w") as fo_all:
                fields = ["RankID", "MislabeledLevel", "OriginalLabel", "ProposedLabel",
                          "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath",
                          "PerRankConfidence"]
                header = ";" + "\t".join(fields) + "\n"
                fo_all.write(header)
                if self.cfg.verbose and len(self.rank_mislabels) > 0:
                    print("\nMislabeled higher ranks:\n")
                    print(header)
                for mis_rec in self.rank_mislabels:
                    output = self.mis_rec_to_string(mis_rec) + "\n"
                    fo_all.write(output)
                    if self.cfg.verbose:
                        print(output)

        self.write_stats()

    def run_leave_subtree_out_test(self):
        """Run the leave-one-rank-out test.

        Clades of the taxonomic tree that pass the level and size filters are
        written to a subtree list file, EPA is run in ``l1o_subtree`` mode and
        each placement is checked against the clade's parent ranks.

        Returns the number of EPA results processed (0 if no clade qualified).
        """
        job_name = self.cfg.subst_name("l1out_rank_%NAME%")

        # create file with subtrees
        rank_tips = {}
        rank_parent = {}
        for node in self.tax_tree.traverse("postorder"):
            if node.is_leaf() or node.is_root():
                continue
            tax_path = node.name
            ranks = Taxonomy.split_rank_uid(tax_path)
            rank_lvl = Taxonomy.lowest_assigned_rank_level(ranks)
            if rank_lvl < 2:
                continue

            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
            parent_lvl = Taxonomy.lowest_assigned_rank_level(parent_ranks)
            if parent_lvl < 1:
                continue

            rank_seqs = node.get_leaf_names()
            rank_size = len(rank_seqs)
            if rank_size < 2 or rank_size > self.reftree_size - 4:
                continue

            rank_tips[tax_path] = rank_seqs
            rank_parent[tax_path] = parent_ranks

        # materialise as a list: entries are looked up by index below
        subtree_list = list(rank_tips.items())

        if len(subtree_list) == 0:
            return 0

        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
        with open(subtree_list_file, "w") as fout:
            for rank_name, tips in subtree_list:
                fout.write("%s\n" % " ".join(tips))

        jp_list = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname,
                                     self.optmod_fname, mode="l1o_subtree",
                                     subtree_fname=subtree_list_file)

        subtree_count = 0
        for jp in jp_list:
            placements = jp.get_placement()
            for place in placements:
                ranks, lws = self.classify_seq(place)
                tax_path = subtree_list[subtree_count][0]
                orig_ranks = Taxonomy.split_rank_uid(tax_path)
                rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
                # NOTE(review): guess_rank_level_name is not defined in this
                # class; confirm it is provided elsewhere (other code here
                # goes through self.tax_code).
                rank_prefix = self.guess_rank_level_name(orig_ranks, rank_level)[0]
                rank_name = orig_ranks[rank_level]
                if not rank_name.startswith(rank_prefix):
                    rank_name = rank_prefix + rank_name
                parent_ranks = rank_parent[tax_path]
                mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks, ranks, lws)
                if mis_rec:
                    self.misrank_conf_map[tax_path] = mis_rec['conf']
            subtree_count += 1

        return subtree_count

    def run_leave_seq_out_test(self):
        """Run the leave-one-sequence-out test and record suspicious sequences.

        Placements are loaded from ``cfg.jplace_fname`` (directory or glob
        mask) if set, otherwise computed by EPA in ``l1o_seq`` mode.

        Returns the number of placements processed.
        """
        job_name = self.cfg.subst_name("l1out_seq_%NAME%")
        placements = []
        if self.cfg.jplace_fname:
            if os.path.isdir(self.cfg.jplace_fname):
                jplace_fmask = os.path.join(self.cfg.jplace_fname, '*.jplace')
            else:
                jplace_fmask = self.cfg.jplace_fname
            jplace_fname_list = glob.glob(jplace_fmask)
            for jplace_fname in jplace_fname_list:
                jp = EpaJsonParser(jplace_fname)
                placements += jp.get_placement()
            # use the instance config; a global `config` is not defined here
            self.cfg.log.debug("Loaded %d placements from %s\n", len(placements), jplace_fmask)
        else:
            jp = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname,
                                    self.optmod_fname, mode="l1o_seq")
            placements = jp.get_placement()
            if self.cfg.output_interim_files:
                out_jplace_fname = self.cfg.out_fname("%NAME%.l1out_seq.jplace")
                self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True, mode="l1o_seq")

        seq_count = 0
        l1out_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # get EPA tax label
            ranks, lws = self.classify_seq(place)
            l1out_ass[seq_name] = (ranks, lws)

            # check if they match
            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)
            # cross-check with higher rank mislabels
            if self.cfg.ranktest and mis_rec:
                rank_conf = 0
                for lvl in range(2, len(orig_ranks)):
                    tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
                    if tax_path in self.misrank_conf_map:
                        rank_conf = max(rank_conf, self.misrank_conf_map[tax_path])
                mis_rec['rank_conf'] = rank_conf
            seq_count += 1

        self.write_assignments(l1out_ass, final=False)

        return seq_count

    def run_final_epa_test(self):
        """Prune suspicious sequences from the trees and re-place all of them with EPA.

        ``self.mislabels`` is reset and refilled from the final placements.
        """
        self.reftree_outgroup = self.refjson.get_outgroup()

        tmp_reftree = self.reftree.copy(method="newick")
        name2refnode = {}
        for leaf in tmp_reftree.iter_leaves():
            name2refnode[leaf.name] = leaf

        tmp_taxtree = self.tax_tree.copy(method="newick")
        name2taxnode = {}
        for leaf in tmp_taxtree.iter_leaves():
            name2taxnode[leaf.name] = leaf

        for mis_rec in self.mislabels:
            rname = mis_rec['name']

            if rname in name2refnode:
                name2refnode[rname].delete()
            else:
                print("Node not found in the reference tree: %s" % rname)

            if rname in name2taxnode:
                name2taxnode[rname].delete()
            else:
                print("Node not found in the taxonomic tree: %s" % rname)

        # remove unifurcation at the root
        if len(tmp_reftree.children) == 1:
            tmp_reftree = tmp_reftree.children[0]

        self.mislabels = []

        th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
        th.set_mf_rooted_tree(tmp_taxtree)

        epa_result = self.run_epa_once(tmp_reftree)

        reftree_epalbl_str = epa_result.get_std_newick_tree()
        placements = epa_result.get_placement()

        # update branchid-taxonomy mapping to account for possible changes in branch numbering
        reftree_tax = Tree(reftree_epalbl_str)
        th.set_bf_unrooted_tree(reftree_tax)
        bid_tax_map = th.get_bid_taxonomy_map()

        self.write_bid_tax_map(bid_tax_map, final=True)

        cl = TaxClassifyHelper(self.cfg, bid_tax_map, self.rate, self.node_height)

        final_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # EXPERIMENTAL FEATURE - disabled for now!
            # It could happen that certain ranks were present in the "original" reference tree, but
            # are completely missing in the pruned tree (e.g., all seqs of a species were considered
            # "suspicious" after the leave-one-out test and thus pruned).
            # In this case, EPA has no chance to infer full original taxonomic annotation (=species)
            # since the corresponding clade is now missing. To account for this fact, we amend the
            # original taxonomic annotation and set ranks missing from pruned tree to "Undefined".
            # orig_ranks = th.strip_missing_ranks(orig_ranks)

            # get EPA tax label
            ranks, lws = cl.classify_seq(place["p"])
            final_ass[seq_name] = (ranks, lws)

            # check if they match (records mislabels as a side effect)
            self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

        self.write_assignments(final_ass, final=True)

    def run_epa_once(self, reftree):
        """Write ``reftree`` to a temporary file, run EPA on it and return the EPA result."""
        reftree_fname = self.cfg.tmp_fname("final_ref_%NAME%.tre")
        job_name = self.cfg.subst_name("final_epa_%NAME%")

        reftree.write(outfile=reftree_fname)

        # IMPORTANT: don't load the model, since it's invalid for the pruned tree !!!
        optmod_fname = ""
        epa_result = self.raxml.run_epa(job_name, self.refalign_fname, reftree_fname, optmod_fname)

        if self.cfg.output_interim_files:
            out_jplace_fname = self.cfg.out_fname("%NAME%.final_epa.jplace")
            self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True)

        return epa_result

    def run_test(self):
        """Run the complete mislabel detection and write the result files."""
        self.raxml = RaxmlWrapper(self.cfg)

        # export tree, alignment and binary model from the reference JSON
        self.refjson.get_raxml_readable_tree(self.reftree_fname)
        self.refalign_fname = self.refjson.get_alignment(self.tmp_refaln)
        self.refjson.get_binary_model(self.optmod_fname)

        # All logging goes through self.cfg; the original used a global
        # `config`, which only exists when a caller happens to define it.
        if self.cfg.ranktest:
            self.cfg.log.info("Running the leave-one-rank-out test...\n")
            self.run_leave_subtree_out_test()

        self.cfg.log.info("Running the leave-one-sequence-out test...\n")
        self.run_leave_seq_out_test()

        if len(self.mislabels) > 0:
            self.cfg.log.info(
                "Leave-one-out test identified %d suspicious sequences; running final EPA test to check them...\n",
                len(self.mislabels))
            if self.cfg.debug:
                self.write_mislabels(final=False)
            self.run_final_epa_test()

        self.filter_mislabels()
        self.sort_mislabels()
        self.write_mislabels()
        self.cfg.log.info("\nTotal mislabels: %d / %.2f %%", len(self.mislabels),
                          (float(len(self.mislabels)) / self.reftree_size * 100))
class LeaveOneTest:
    """Leave-one-out mislabel detection driven by EPA placements.

    The workflow (see ``run_test``) is:
      1. optionally a leave-one-rank(subtree)-out test,
      2. a leave-one-sequence-out test that collects suspicious sequences,
      3. a final EPA run on a tree pruned of those suspicious sequences,
      4. filtering, sorting and writing of the resulting mislabel records.

    A mislabel record ("mis_rec") is a dict with the keys ``name``,
    ``orig_level``, ``real_level``, ``level_name``, ``inv_level``,
    ``orig_ranks``, ``ranks``, ``lws``, ``conf`` and, optionally,
    ``rank_conf``.
    """

    def __init__(self, config):
        """Derive all output/temporary file names from *config*.

        Reports a user error through ``config.exit_user_error()`` when the
        main output file already exists.
        """
        self.cfg = config

        self.mis_fname = self.cfg.out_fname("%NAME%.mis")
        self.premis_fname = self.cfg.out_fname("%NAME%.premis")
        self.misrank_fname = self.cfg.out_fname("%NAME%.misrank")
        self.stats_fname = self.cfg.out_fname("%NAME%.stats")

        if os.path.isfile(self.mis_fname):
            print("\nERROR: Output file already exists: %s" % self.mis_fname)
            print(
                "Please specify a different job name using -n or remove old output files."
            )
            self.cfg.exit_user_error()

        self.tmp_refaln = self.cfg.tmp_fname("%NAME%.refaln")
        self.reftree_lbl_fname = self.cfg.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = self.cfg.tmp_fname("%NAME%_tax.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")

        self.mislabels = []
        self.mislabels_cnt = []
        self.rank_mislabels = []
        self.rank_mislabels_cnt = []
        self.misrank_conf_map = {}

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _jplace_fmask(path):
        """Return a glob mask for *path*: ``<path>/*.jplace`` for a directory,
        otherwise *path* itself."""
        if os.path.isdir(path):
            return os.path.join(path, '*.jplace')
        return path

    @staticmethod
    def _first_mismatch_level(orig_ranks, ranks):
        """Return the first level at which an assigned rank in *ranks*
        differs from *orig_ranks*, or -1 if there is none.

        Only the common prefix of both lists is compared; levels where
        *ranks* holds ``Taxonomy.EMPTY_RANK`` are ignored.
        """
        for lvl, (orig, new) in enumerate(zip(orig_ranks, ranks)):
            if new != Taxonomy.EMPTY_RANK and new != orig:
                return lvl
        return -1

    def _build_mis_rec(self, name, mislabel_lvl, orig_ranks, ranks, lws):
        """Assemble a mislabel record for a mismatch found at *mislabel_lvl*."""
        real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
        return {
            'name': name,
            'orig_level': mislabel_lvl,
            'real_level': real_lvl,
            'level_name': self.tax_code.rank_level_name(real_lvl)[0],
            'inv_level': -1 * real_lvl,  # just for sorting
            'orig_ranks': orig_ranks,
            'ranks': ranks,
            'lws': lws,
            'conf': lws[mislabel_lvl],
        }

    # ------------------------------------------------------------------
    # debug dumps
    # ------------------------------------------------------------------

    def write_bid_tax_map(self, bid_tax_map, final):
        """In debug mode, dump the branch-id -> taxonomy map to a temp file.

        Each record is written as ``bid, rec[0], rec[1], rec[2]`` separated
        by tabs; *final* selects the "final" or "l1out" file name suffix.
        """
        if not self.cfg.debug:
            return
        fname_suffix = "final" if final else "l1out"
        bid_fname = self.cfg.tmp_fname("%NAME%_" + "bid_tax_map_%s.txt" % fname_suffix)
        with open(bid_fname, "w") as outf:
            for bid, bid_rec in bid_tax_map.items():
                outf.write("%s\t%s\t%d\t%f\n" %
                           (bid, bid_rec[0], bid_rec[1], bid_rec[2]))

    def write_assignments(self, assign_map, final):
        """In debug mode, dump ``seq_name -> (ranks, lws)`` assignments to a
        temp file; *final* selects the "final" or "l1out" file name suffix."""
        if not self.cfg.debug:
            return
        fname_suffix = "final" if final else "l1out"
        assign_fname = self.cfg.tmp_fname("%NAME%_" + "taxassign_%s.txt" % fname_suffix)
        with open(assign_fname, "w") as outf:
            for seq_name, (ranks, lws) in assign_map.items():
                outf.write("%s\t%s\t%s\n" % (seq_name, ";".join(ranks),
                                             ";".join(["%.3f" % l for l in lws])))

    # ------------------------------------------------------------------
    # setup
    # ------------------------------------------------------------------

    def load_refjson(self, refjson_fname):
        """Load and validate the reference JSON file and initialise all
        state derived from it (trees, taxonomy, classifier helpers, counters).

        Invalid input is reported through ``self.cfg.exit_user_error``.
        """
        try:
            self.refjson = RefJsonParser(refjson_fname)
        except ValueError:
            self.cfg.exit_user_error("ERROR: Invalid json file format!")

        # validate input json format
        (valid, err) = self.refjson.validate()
        if not valid:
            self.cfg.log.error(
                "ERROR: Parsing reference JSON file failed:\n%s", err)
            self.cfg.exit_user_error()

        self.rate = self.refjson.get_rate()
        self.node_height = self.refjson.get_node_height()
        self.origin_taxonomy = self.refjson.get_origin_taxonomy()
        self.tax_tree = self.refjson.get_tax_tree()
        self.cfg.compress_patterns = self.refjson.get_pattern_compression()

        self.bid_taxonomy_map = self.refjson.get_branch_tax_map()
        if not self.bid_taxonomy_map:
            # old file format (before 1.6), need to rebuild this map from scratch
            th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
            th.set_mf_rooted_tree(self.tax_tree)
            th.set_bf_unrooted_tree(self.refjson.get_reftree())
            self.bid_taxonomy_map = th.get_bid_taxonomy_map()

        # NOTE(review): source indentation was lost; the debug dump is assumed
        # to run for both the stored and the rebuilt map.
        self.write_bid_tax_map(self.bid_taxonomy_map, final=False)

        reftree_str = self.refjson.get_raxml_readable_tree()
        self.reftree = Tree(reftree_str)
        self.reftree_size = len(self.reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same
        # rate het. mode as in the ref file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map,
                                                 self.rate, self.node_height)
        self.taxtree_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy,
                                            self.tax_tree)

        tax_code_name = self.refjson.get_taxcode()
        self.tax_code = TaxCode(tax_code_name)

        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX,
                                 tax_map=self.origin_taxonomy)
        self.tax_common_ranks = self.taxonomy.get_common_ranks()

        self.mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
        self.rank_mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS

    def run_epa_trainer(self):
        """Run the EPA trainer and abort if it did not produce the reference
        JSON file."""
        epa_trainer.run_trainer(self.cfg)

        if not os.path.isfile(self.cfg.refjson_fname):
            self.cfg.log.error(
                "\nBuilding reference tree failed, see error messages above.")
            self.cfg.exit_fatal_error()

    # ------------------------------------------------------------------
    # classification and label checks
    # ------------------------------------------------------------------

    def classify_seq(self, placement):
        """Classify one placement record via its ``"p"`` edge list.

        Returns whatever ``self.classify_helper.classify_seq`` returns
        (callers unpack it as ``(ranks, lws)``).  When the edge list is empty
        an error is printed and ``None`` is returned implicitly, so callers
        that unpack the result will fail in that case.
        """
        edges = placement["p"]
        if len(edges) > 0:
            return self.classify_helper.classify_seq(edges)
        else:
            print("ERROR: no placements! something is definitely wrong!")

    def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
        """Compare the original and the EPA-inferred ranks of a sequence.

        An empty *ranks* list is recorded as a "[NotIngroup]" mislabel with
        confidence 1.0; otherwise the first differing assigned rank defines
        the mislabel.  A resulting record is appended to ``self.mislabels``.

        Returns the mislabel record, or ``None`` if the labels agree.
        """
        mis_rec = None
        if len(ranks) == 0:
            mis_rec = {
                'name': seq_name,
                'orig_level': -1,
                'real_level': 0,
                'level_name': "[NotIngroup]",
                'inv_level': -1 * 0,  # just for sorting
                'orig_ranks': orig_ranks,
                'ranks': [],
                'lws': [1.0],
                'conf': 1.0,
            }
        else:
            mislabel_lvl = self._first_mismatch_level(orig_ranks, ranks)
            if mislabel_lvl >= 0:
                mis_rec = self._build_mis_rec(seq_name, mislabel_lvl,
                                              orig_ranks, ranks, lws)

        if mis_rec:
            self.mislabels.append(mis_rec)

        return mis_rec

    def filter_mislabels(self):
        """Keep only mislabels whose confidence reaches ``cfg.conf_cutoff``."""
        cutoff = self.cfg.conf_cutoff
        self.mislabels = [rec for rec in self.mislabels if rec['conf'] >= cutoff]

    def check_rank_tax_labels(self, rank_name, orig_ranks, ranks, lws):
        """Rank-level counterpart of ``check_seq_tax_labels``.

        A mismatch record is appended to ``self.rank_mislabels`` and returned;
        ``None`` is returned when no assigned rank differs.
        """
        mislabel_lvl = self._first_mismatch_level(orig_ranks, ranks)
        if mislabel_lvl < 0:
            return None

        mis_rec = self._build_mis_rec(rank_name, mislabel_lvl, orig_ranks,
                                      ranks, lws)
        self.rank_mislabels.append(mis_rec)
        return mis_rec

    # ------------------------------------------------------------------
    # formatting and output
    # ------------------------------------------------------------------

    def mis_rec_to_string_old(self, mis_rec):
        """Format *mis_rec* in the legacy four-line layout."""
        lvl = mis_rec['orig_level']
        output = mis_rec['name'] + "\t"
        output += "%s\t%s\t%s\t%.3f\n" % (
            mis_rec['level_name'], mis_rec['orig_ranks'][lvl],
            mis_rec['ranks'][lvl], mis_rec['lws'][lvl])
        output += ";".join(mis_rec['orig_ranks']) + "\n"
        output += ";".join(mis_rec['ranks']) + "\n"
        output += "\t".join(["%.3f" % conf for conf in mis_rec['lws']]) + "\n"
        return output

    def mis_rec_to_string(self, mis_rec):
        """Format *mis_rec* as one tab-separated line (no trailing newline).

        Sequence id and ranks are mapped back through the reference JSON's
        ``get_uncorr_*`` accessors before printing.  Records with a negative
        ``orig_level`` print "NA" for both labels and use ``lws[0]`` as the
        confidence.  ``rank_conf`` is appended when present.
        """
        lvl = mis_rec['orig_level']
        uncorr_name = EpacConfig.strip_ref_prefix(
            self.refjson.get_uncorr_seqid(mis_rec['name']))
        uncorr_orig_ranks = self.refjson.get_uncorr_ranks(mis_rec['orig_ranks'])
        uncorr_ranks = self.refjson.get_uncorr_ranks(mis_rec['ranks'])

        output = uncorr_name + "\t"
        if lvl >= 0:
            output += "%s\t%s\t%s\t%.3f\t" % (
                mis_rec['level_name'], uncorr_orig_ranks[lvl],
                uncorr_ranks[lvl], mis_rec['lws'][lvl])
        else:
            output += "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'], "NA",
                                              "NA", mis_rec['lws'][0])
        output += Taxonomy.lineage_str(uncorr_orig_ranks) + "\t"
        output += Taxonomy.lineage_str(uncorr_ranks) + "\t"
        output += ";".join(["%.3f" % conf for conf in mis_rec['lws']])
        if 'rank_conf' in mis_rec:
            output += "\t%.3f" % mis_rec['rank_conf']
        return output

    def sort_mislabels(self):
        """Sort mislabels (descending by ``inv_level``, ``conf``, ``name``)
        and add them to the per-level counters.

        Rank mislabels are processed the same way when ``cfg.ranktest`` is
        set.  The counters are incremented, not reset, on every call.
        """
        sort_key = itemgetter('inv_level', 'conf', 'name')

        self.mislabels = sorted(self.mislabels, key=sort_key, reverse=True)
        for mis_rec in self.mislabels:
            self.mislabels_cnt[mis_rec["real_level"]] += 1

        if self.cfg.ranktest:
            self.rank_mislabels = sorted(self.rank_mislabels, key=sort_key,
                                         reverse=True)
            for mis_rec in self.rank_mislabels:
                self.rank_mislabels_cnt[mis_rec["real_level"]] += 1

    def write_stats(self, toFile=False):
        """Log per-level mislabel counts; also write them to
        ``self.stats_fname`` when *toFile* is true.

        Only levels with at least one mislabeled sequence are reported.
        """
        self.cfg.log.info("Mislabeled sequences by rank:")
        rank_sum = 0
        stats = []
        for i, seq_cnt in enumerate(self.mislabels_cnt):
            if i > 0:
                rname = self.tax_code.rank_level_name(i)[0].ljust(12)
            else:
                rname = "[NotIngroup]"
            if seq_cnt > 0:
                output = "%s:\t%d" % (rname, seq_cnt)
                if self.cfg.ranktest:
                    # NOTE(review): this column is a running total over the
                    # reported levels, while the sequence column is per level;
                    # kept as in the original - confirm this is intended.
                    rank_sum += self.rank_mislabels_cnt[i]
                    output += "\t%d" % rank_sum
                self.cfg.log.info(output)
                stats.append(output)

        if toFile:
            with open(self.stats_fname, "w") as fo_stat:
                for line in stats:
                    fo_stat.write(line + "\n")

    def write_mislabels_header(self, fo, final, fields):
        """Write the column header to *fo*, preceded by the ';'-commented
        DISCLAIMER when *final* is true; echo both to the console in verbose
        mode."""
        header = ";" + "\t".join(fields) + "\n"

        # write to file
        if final:
            for line in DISCLAIMER.split("\n"):
                fo.write(";%s\n" % line)
            fo.write(";\n")
        fo.write(header)

        # print to console
        # NOTE(review): the guard tests rank_mislabels even when called for
        # sequence mislabels; kept as in the original - confirm.
        if final and self.cfg.verbose and len(self.rank_mislabels) > 0:
            # Single-argument print: a two-argument call would print a tuple
            # under the Python 2 print statement.
            print(DISCLAIMER + "\n")
            print("Mislabeled sequences:\n")
            print(header)

    def write_rank_mislabels(self, final=True):
        """Write rank-level mislabels to ``self.misrank_fname``
        (no-op unless ``cfg.ranktest`` is set)."""
        if not self.cfg.ranktest:
            return

        with open(self.misrank_fname, "w") as fo_all:
            fields = [
                "RankID", "MislabeledLevel", "OriginalLabel", "ProposedLabel",
                "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath",
                "PerRankConfidence"
            ]
            self.write_mislabels_header(fo_all, final, fields)
            for mis_rec in self.rank_mislabels:
                output = self.mis_rec_to_string(mis_rec) + "\n"
                fo_all.write(output)
                if self.cfg.verbose:
                    print(output)

    def write_mislabels(self, final=True):
        """Write sequence mislabels to the final (``.mis``) or preliminary
        (``.premis``) file; for the final output also write rank mislabels
        and statistics."""
        out_fname = self.mis_fname if final else self.premis_fname

        with open(out_fname, "w") as fo_all:
            fields = [
                "SeqID", "MislabeledLevel", "OriginalLabel", "ProposedLabel",
                "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath",
                "PerRankConfidence"
            ]
            if self.cfg.ranktest:
                fields += ["HigherRankMisplacedConfidence"]
            self.write_mislabels_header(fo_all, final, fields)
            for mis_rec in self.mislabels:
                output = self.mis_rec_to_string(mis_rec) + "\n"
                fo_all.write(output)
                if self.cfg.verbose and final:
                    print(output)

        # NOTE(review): source indentation was lost; both calls are assumed
        # to belong to the final-output branch.
        if final:
            self.write_rank_mislabels()
            self.write_stats()

    # ------------------------------------------------------------------
    # leave-one-out tests
    # ------------------------------------------------------------------

    def get_parent_tip_ranks(self, tax_tree):
        """Collect the ranks (inner nodes of *tax_tree*) eligible for the
        leave-one-rank-out test.

        A node is skipped if it is a leaf or the root, if its lowest assigned
        rank level is below 2, if its parent's is below 1, or if it has fewer
        than 2 or more than ``reftree_size - 4`` tips.

        Returns ``(rank_parent, rank_tips)``: two dicts keyed by the node
        name, holding the parent's ranks and the node's leaf names.
        """
        rank_tips = {}
        rank_parent = {}
        max_size = self.reftree_size - 4
        for node in tax_tree.traverse("postorder"):
            if node.is_leaf() or node.is_root():
                continue
            tax_path = node.name
            ranks = Taxonomy.split_rank_uid(tax_path)
            if Taxonomy.lowest_assigned_rank_level(ranks) < 2:
                continue

            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
            if Taxonomy.lowest_assigned_rank_level(parent_ranks) < 1:
                continue

            rank_seqs = node.get_leaf_names()
            if len(rank_seqs) < 2 or len(rank_seqs) > max_size:
                continue

            rank_tips[tax_path] = rank_seqs
            rank_parent[tax_path] = parent_ranks
        return rank_parent, rank_tips

    def run_leave_subtree_out_test(self):
        """Run the leave-one-rank-out EPA test.

        Writes one line of tip names per eligible rank to a temp file, runs
        EPA in "l1o_subtree" mode, and checks every placement against the
        parent ranks of the corresponding subtree.  Confidences of detected
        rank mislabels are stored in ``self.misrank_conf_map``.

        Returns the number of processed placements (0 if no rank qualifies).
        """
        job_name = self.cfg.subst_name("l1out_rank_%NAME%")

        # create file with subtrees
        rank_parent, rank_tips = self.get_parent_tip_ranks(self.tax_tree)

        subtree_list = list(rank_tips.items())
        if len(subtree_list) == 0:
            return 0

        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
        with open(subtree_list_file, "w") as fout:
            for _, tips in subtree_list:
                fout.write("%s\n" % " ".join(tips))

        jp_list = self.raxml.run_epa(job_name, self.refalign_fname,
                                     self.reftree_fname, self.optmod_fname,
                                     mode="l1o_subtree",
                                     subtree_fname=subtree_list_file)

        subtree_count = 0
        for jp in jp_list:
            for place in jp.get_placement():
                ranks, lws = self.classify_seq(place)
                # placements are matched to subtrees by position
                tax_path = subtree_list[subtree_count][0]
                orig_ranks = Taxonomy.split_rank_uid(tax_path)
                rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
                rank_prefix = self.tax_code.guess_rank_level_name(
                    orig_ranks, rank_level)[0]
                rank_name = orig_ranks[rank_level]
                if not rank_name.startswith(rank_prefix):
                    rank_name = rank_prefix + rank_name
                parent_ranks = rank_parent[tax_path]
                mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks,
                                                     ranks, lws)
                if mis_rec:
                    self.misrank_conf_map[tax_path] = mis_rec['conf']
                # NOTE(review): source indentation was lost; the counter is
                # assumed to advance once per placement.
                subtree_count += 1

        return subtree_count

    def run_leave_seq_out_test(self):
        """Run the leave-one-sequence-out EPA test.

        Placements are either loaded from ``cfg.jplace_fname`` (a file, glob
        mask or directory of ``*.jplace`` files) or computed by running EPA in
        "l1o_seq" mode.  Each placement is classified and compared with the
        original label; with ``cfg.ranktest`` a mislabel additionally gets a
        ``rank_conf`` - the highest rank-mislabel confidence among its
        ancestor ranks from level 2 on.

        Returns the number of processed placements.
        """
        job_name = self.cfg.subst_name("l1out_seq_%NAME%")
        placements = []
        if self.cfg.jplace_fname:
            jplace_fmask = self._jplace_fmask(self.cfg.jplace_fname)
            for jplace_fname in glob.glob(jplace_fmask):
                jp = EpaJsonParser(jplace_fname)
                placements += jp.get_placement()
            # use the instance's config: a module-level `config` is not
            # guaranteed to exist
            self.cfg.log.debug("Loaded %d placements from %s\n",
                               len(placements), jplace_fmask)
        else:
            jp = self.raxml.run_epa(job_name, self.refalign_fname,
                                    self.reftree_fname, self.optmod_fname,
                                    mode="l1o_seq")
            placements = jp.get_placement()
            if self.cfg.output_interim_files:
                out_jplace_fname = self.cfg.out_fname(
                    "%NAME%.l1out_seq.jplace")
                self.raxml.copy_epa_jplace(job_name, out_jplace_fname,
                                           move=True, mode="l1o_seq")

        seq_count = 0
        l1out_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # get EPA tax label
            ranks, lws = self.classify_seq(place)
            l1out_ass[seq_name] = (ranks, lws)

            # check if they match
            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

            # cross-check with higher rank mislabels
            if self.cfg.ranktest and mis_rec:
                rank_conf = 0
                for lvl in range(2, len(orig_ranks)):
                    tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
                    if tax_path in self.misrank_conf_map:
                        rank_conf = max(rank_conf,
                                        self.misrank_conf_map[tax_path])
                mis_rec['rank_conf'] = rank_conf
            seq_count += 1

        self.write_assignments(l1out_ass, final=False)

        return seq_count

    def prune_mislabels_from_tree(self, src_tree, tree_name):
        """Return a copy of *src_tree* with the leaves named in
        ``self.mislabels`` deleted.

        *src_tree* itself is not modified.  *tree_name* is only used in the
        debug message logged for mislabels that have no matching leaf.
        """
        pruned_tree = src_tree.copy(method="newick")
        name2node = dict((leaf.name, leaf) for leaf in pruned_tree.iter_leaves())

        for mis_rec in self.mislabels:
            rname = mis_rec['name']
            if rname in name2node:
                name2node[rname].delete()
            else:
                self.cfg.log.debug("Node not found in the %s tree: %s",
                                   tree_name, rname)

        return pruned_tree

    def run_final_epa_test(self):
        """Re-test the suspicious sequences against a reference pruned of them.

        Both the reference tree and the taxonomic tree are pruned of the
        current mislabels, ``self.mislabels`` is reset, and the placements
        (loaded from ``cfg.final_jplace_fname`` or computed by a fresh EPA
        run) are classified with a branch-taxonomy map rebuilt for the pruned
        tree.  Mislabels confirmed by this run are collected again in
        ``self.mislabels``.
        """
        self.reftree_outgroup = self.refjson.get_outgroup()
        pruned_reftree = self.prune_mislabels_from_tree(self.reftree,
                                                        "reference")
        # The taxonomic tree (not the reference tree) must be pruned here:
        # it is what set_mf_rooted_tree() receives below.
        pruned_taxtree = self.prune_mislabels_from_tree(self.tax_tree,
                                                        "taxonomic")

        # remove unifurcation at the root
        if len(pruned_reftree.children) == 1:
            pruned_reftree = pruned_reftree.children[0]

        self.mislabels = []

        th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
        th.set_mf_rooted_tree(pruned_taxtree)

        reftree_epalbl_str = None
        if self.cfg.final_jplace_fname:
            jplace_fmask = self._jplace_fmask(self.cfg.final_jplace_fname)
            jplace_fname_list = glob.glob(jplace_fmask)
            if not jplace_fname_list:
                # without any file there is no labelled tree to continue with
                self.cfg.exit_user_error(
                    "ERROR: No jplace files found: %s" % jplace_fmask)
            placements = []
            for jplace_fname in jplace_fname_list:
                jp = EpaJsonParser(jplace_fname)
                placements += jp.get_placement()
                # the labelled tree is taken from the first file that has one
                if not reftree_epalbl_str:
                    reftree_epalbl_str = jp.get_std_newick_tree()
            self.cfg.log.debug("Loaded %d final epa placements from %s\n",
                               len(placements), jplace_fmask)
        else:
            epa_result = self.run_epa_once(pruned_reftree)
            reftree_epalbl_str = epa_result.get_std_newick_tree()
            placements = epa_result.get_placement()

        # update branchid-taxonomy mapping to account for possible changes
        # in branch numbering
        reftree_tax = Tree(reftree_epalbl_str)
        th.set_bf_unrooted_tree(reftree_tax)
        bid_tax_map = th.get_bid_taxonomy_map()

        self.write_bid_tax_map(bid_tax_map, final=True)

        cl = TaxClassifyHelper(self.cfg, bid_tax_map, self.rate,
                               self.node_height)

        final_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # EXPERIMENTAL FEATURE - disabled for now!
            # It could happen that certain ranks were present in the "original"
            # reference tree, but are completely missing in the pruned tree
            # (e.g., all seqs of a species were considered "suspicious" after
            # the leave-one-out test and thus pruned).
            # In this case, EPA has no chance to infer full original taxonomic
            # annotation (=species) since the corresponding clade is now
            # missing. To account for this fact, we amend the original
            # taxonomic annotation and set ranks missing from pruned tree to
            # "Undefined".
            # orig_ranks = th.strip_missing_ranks(orig_ranks)

            # get EPA tax label
            ranks, lws = cl.classify_seq(place["p"])
            final_ass[seq_name] = (ranks, lws)

            # check if they match (records confirmed mislabels as side effect)
            self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

        self.write_assignments(final_ass, final=True)

    def run_epa_once(self, reftree):
        """Write *reftree* to a temp file and run EPA on it without a
        pre-optimized model; returns the EPA result object."""
        reftree_fname = self.cfg.tmp_fname("final_ref_%NAME%.tre")
        job_name = self.cfg.subst_name("final_epa_%NAME%")

        reftree.write(outfile=reftree_fname)

        # IMPORTANT: don't load the model, since it's invalid for the pruned tree !!!
        optmod_fname = ""
        epa_result = self.raxml.run_epa(job_name, self.refalign_fname,
                                        reftree_fname, optmod_fname)

        if self.cfg.output_interim_files:
            out_jplace_fname = self.cfg.out_fname("%NAME%.final_epa.jplace")
            self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True)

        return epa_result

    def run_test(self):
        """Top-level driver: export reference data, run the leave-one-out
        tests and the final EPA check, then write the results."""
        self.raxml = RaxmlWrapper(self.cfg)

        self.refjson.get_raxml_readable_tree(self.reftree_fname)
        self.refalign_fname = self.refjson.get_alignment(self.tmp_refaln)
        self.refjson.get_binary_model(self.optmod_fname)

        log = self.cfg.log

        if self.cfg.ranktest:
            log.info("Running the leave-one-rank-out test...\n")
            self.run_leave_subtree_out_test()

        log.info("Running the leave-one-sequence-out test...\n")
        self.run_leave_seq_out_test()

        if len(self.mislabels) > 0:
            log.info(
                "Leave-one-out test identified %d suspicious sequences; running final EPA test to check them...\n",
                len(self.mislabels))
            if self.cfg.debug:
                self.write_mislabels(final=False)
            self.run_final_epa_test()

        self.filter_mislabels()
        self.sort_mislabels()
        self.write_mislabels()
        log.info("\nTotal mislabels: %d / %.2f %%", len(self.mislabels),
                 (float(len(self.mislabels)) / self.reftree_size * 100))
class RefTreeBuilder: def __init__(self, config): self.cfg = config self.mfresolv_job_name = self.cfg.subst_name("mfresolv_%NAME%") self.epalbl_job_name = self.cfg.subst_name("epalbl_%NAME%") self.optmod_job_name = self.cfg.subst_name("optmod_%NAME%") self.raxml_wrapper = RaxmlWrapper(config) self.outgr_fname = self.cfg.tmp_fname("%NAME%_outgr.tre") self.reftree_mfu_fname = self.cfg.tmp_fname("%NAME%_mfu.tre") self.reftree_bfu_fname = self.cfg.tmp_fname("%NAME%_bfu.tre") self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt") self.lblalign_fname = self.cfg.tmp_fname("%NAME%_lblq.fa") self.reftree_lbl_fname = self.cfg.tmp_fname("%NAME%_lbl.tre") self.reftree_tax_fname = self.cfg.tmp_fname("%NAME%_tax.tre") self.brmap_fname = self.cfg.tmp_fname("%NAME%_map.txt") def load_alignment(self): in_file = self.cfg.align_fname self.input_seqs = None formats = ["fasta", "phylip_relaxed", "iphylip_relaxed", "phylip", "iphylip"] for fmt in formats: try: self.input_seqs = SeqGroup(sequences=in_file, format = fmt) break except: self.cfg.log.debug("Guessing input format: not " + fmt) if self.input_seqs == None: self.cfg.exit_user_error("Invalid input file format: %s\nThe supported input formats are fasta and phylip" % in_file) def validate_taxonomy(self): self.input_validator = InputValidator(self.cfg, self.taxonomy, self.input_seqs) self.input_validator.validate() def build_multif_tree(self): c = self.cfg tb = TaxTreeBuilder(c, self.taxonomy) (t, ids) = tb.build(c.reftree_min_rank, c.reftree_max_seqs_per_leaf, c.reftree_clades_to_include, c.reftree_clades_to_ignore) self.reftree_ids = frozenset(ids) self.reftree_size = len(ids) self.reftree_multif = t # IMPORTANT: select GAMMA or CAT model based on tree size! 
self.cfg.resolve_auto_settings(self.reftree_size) if self.cfg.debug: refseq_fname = self.cfg.tmp_fname("%NAME%_seq_ids.txt") # list of sequence ids which comprise the reference tree with open(refseq_fname, "w") as f: for sid in ids: f.write("%s\n" % sid) # original tree with taxonomic ranks as internal node labels reftax_fname = self.cfg.tmp_fname("%NAME%_mfu_tax.tre") t.write(outfile=reftax_fname, format=8) # t.show() def export_ref_alignment(self): """This function transforms the input alignment in the following way: 1. Filter out sequences which are not part of the reference tree 2. Add sequence name prefix (r_)""" self.refalign_fname = self.cfg.tmp_fname("%NAME%_matrix.afa") with open(self.refalign_fname, "w") as fout: for name, seq, comment, sid in self.input_seqs.iter_entries(): seq_name = EpacConfig.REF_SEQ_PREFIX + name if seq_name in self.input_validator.corr_seqid: seq_name = self.input_validator.corr_seqid[seq_name] if seq_name in self.reftree_ids: fout.write(">" + seq_name + "\n" + seq + "\n") # we do not need the original alignment anymore, so free its memory self.input_seqs = None def export_ref_taxonomy(self): self.taxonomy_map = {} for sid, ranks in self.taxonomy.iteritems(): if sid in self.reftree_ids: self.taxonomy_map[sid] = ranks if self.cfg.debug: tax_fname = self.cfg.tmp_fname("%NAME%_tax.txt") with open(tax_fname, "w") as fout: for sid, ranks in self.taxonomy_map.iteritems(): ranks_str = self.taxonomy.seq_lineage_str(sid) fout.write(sid + "\t" + ranks_str + "\n") def save_rooting(self): rt = self.reftree_multif tax_map = self.taxonomy.get_map() self.taxtree_helper = TaxTreeHelper(self.cfg, tax_map) self.taxtree_helper.set_mf_rooted_tree(rt) outgr = self.taxtree_helper.get_outgroup() outgr_size = len(outgr.get_leaves()) outgr.write(outfile=self.outgr_fname, format=9) self.reftree_outgroup = outgr self.cfg.log.debug("Outgroup for rooting was saved to: %s, outgroup size: %d", self.outgr_fname, outgr_size) # remove unifurcation at the root if 
len(rt.children) == 1: rt = rt.children[0] # now we can safely unroot the tree and remove internal node labels to make it suitable for raxml rt.write(outfile=self.reftree_mfu_fname, format=9) # RAxML call to convert multifurcating tree to the strictly bifurcating one def resolve_multif(self): self.cfg.log.debug("\nReducing the alignment: \n") self.reduced_refalign_fname = self.raxml_wrapper.reduce_alignment(self.refalign_fname) self.cfg.log.debug("\nConstrained ML inference: \n") raxml_params = ["-s", self.reduced_refalign_fname, "-g", self.reftree_mfu_fname, "--no-seq-check", "-N", str(self.cfg.rep_num)] if self.cfg.mfresolv_method == "fast": raxml_params += ["-D"] elif self.cfg.mfresolv_method == "ultrafast": raxml_params += ["-f", "e"] if self.cfg.restart and self.raxml_wrapper.result_exists(self.mfresolv_job_name): self.invocation_raxml_multif = self.raxml_wrapper.get_invocation_str(self.mfresolv_job_name) self.cfg.log.debug("\nUsing existing ML tree found in: %s\n", self.raxml_wrapper.result_fname(self.mfresolv_job_name)) else: self.invocation_raxml_multif = self.raxml_wrapper.run(self.mfresolv_job_name, raxml_params) # self.invocation_raxml_multif = self.raxml_wrapper.run_multiple(self.mfresolv_job_name, raxml_params, self.cfg.rep_num) if self.cfg.mfresolv_method == "ultrafast": self.raxml_wrapper.copy_result_tree(self.mfresolv_job_name, self.raxml_wrapper.besttree_fname(self.mfresolv_job_name)) if self.raxml_wrapper.besttree_exists(self.mfresolv_job_name): if not self.cfg.reopt_model: self.raxml_wrapper.copy_best_tree(self.mfresolv_job_name, self.reftree_bfu_fname) self.raxml_wrapper.copy_optmod_params(self.mfresolv_job_name, self.optmod_fname) self.invocation_raxml_optmod = "" job_name = self.mfresolv_job_name else: bfu_fname = self.raxml_wrapper.besttree_fname(self.mfresolv_job_name) job_name = self.optmod_job_name # RAxML call to optimize model parameters and write them down to the binary model file self.cfg.log.debug("\nOptimizing model parameters: \n") 
raxml_params = ["-f", "e", "-s", self.reduced_refalign_fname, "-t", bfu_fname, "--no-seq-check"] if self.cfg.raxml_model.startswith("GTRCAT") and not self.cfg.compress_patterns: raxml_params += ["-H"] if self.cfg.restart and self.raxml_wrapper.result_exists(self.optmod_job_name): self.invocation_raxml_optmod = self.raxml_wrapper.get_invocation_str(self.optmod_job_name) self.cfg.log.debug("\nUsing existing optimized tree and parameters found in: %s\n", self.raxml_wrapper.result_fname(self.optmod_job_name)) else: self.invocation_raxml_optmod = self.raxml_wrapper.run(self.optmod_job_name, raxml_params) if self.raxml_wrapper.result_exists(self.optmod_job_name): self.raxml_wrapper.copy_result_tree(self.optmod_job_name, self.reftree_bfu_fname) self.raxml_wrapper.copy_optmod_params(self.optmod_job_name, self.optmod_fname) else: errmsg = "RAxML run failed (model optimization), please examine the log for details: %s" \ % self.raxml_wrapper.make_raxml_fname("output", self.optmod_job_name) self.cfg.exit_fatal_error(errmsg) if self.cfg.raxml_model.startswith("GTRCAT"): mod_name = "CAT" else: mod_name = "GAMMA" self.reftree_loglh = self.raxml_wrapper.get_tree_lh(job_name, mod_name) self.cfg.log.debug("\n%s-based logLH of the reference tree: %f\n" % (mod_name, self.reftree_loglh)) else: errmsg = "RAxML run failed (mutlifurcation resolution), please examine the log for details: %s" \ % self.raxml_wrapper.make_raxml_fname("output", self.mfresolv_job_name) self.cfg.exit_fatal_error(errmsg) def load_reduced_refalign(self): formats = ["fasta", "phylip_relaxed"] for fmt in formats: try: self.reduced_refalign_seqs = SeqGroup(sequences=self.reduced_refalign_fname, format = fmt) break except: pass if self.reduced_refalign_seqs == None: errmsg = "FATAL ERROR: Invalid input file format in %s! 
(load_reduced_refalign)" % self.reduced_refalign_fname self.cfg.exit_fatal_error(errmsg)

# dummy EPA run to label the branches of the reference tree, which we need to build a mapping to tax ranks
def epa_branch_labeling(self):
    """Run a dummy RAxML-EPA placement to obtain a branch-labelled reference tree.

    Writes the reduced reference alignment plus one artificial all-"A" query
    sequence to self.lblalign_fname, runs EPA on that alignment and stores on
    self: the alignment width (refalign_width), the labelled tree string
    (reftree_lbl_str), the RAxML version (raxml_version) and the RAxML
    invocation (invocation_raxml_epalbl).

    Calls self.cfg.exit_fatal_error() if the wrapper reports that no EPA
    result exists for the labelling job.
    """
    # create alignment with dummy query seq
    # (width is the length of the entry returned by get_seqbyid(0))
    self.refalign_width = len(self.reduced_refalign_seqs.get_seqbyid(0))
    self.reduced_refalign_seqs.write(format="fasta", outfile=self.lblalign_fname)

    # append mode: the dummy query goes after the reference sequences written above
    with open(self.lblalign_fname, "a") as fout:
        fout.write(">" + "DUMMY131313" + "\n")
        fout.write("A"*self.refalign_width + "\n")

    # TODO always load model regardless of the config file settings?
    epa_result = self.raxml_wrapper.run_epa(self.epalbl_job_name, self.lblalign_fname, self.reftree_bfu_fname, self.optmod_fname, mode="epa_mp")
    self.reftree_lbl_str = epa_result.get_std_newick_tree()
    self.raxml_version = epa_result.get_raxml_version()
    self.invocation_raxml_epalbl = epa_result.get_raxml_invocation()

    # NOTE(review): epa_result is already dereferenced above, before this
    # existence check runs -- confirm run_epa() always returns a usable object
    # on failure, otherwise the friendly error below is never reached.
    if not self.raxml_wrapper.epa_result_exists(self.epalbl_job_name):
        errmsg = "RAxML EPA run failed, please examine the log for details: %s" \
                % self.raxml_wrapper.make_raxml_fname("output", self.epalbl_job_name)
        self.cfg.exit_fatal_error(errmsg)

def epa_post_process(self):
    """Turn the EPA-labelled tree string into a taxonomic tree and a branch map.

    Parses self.reftree_lbl_str, hands the tree to self.taxtree_helper and
    stores the helper's taxonomic tree in self.reftree_tax and its
    branch-id -> taxonomy map in self.bid_ranks_map. In debug mode the
    intermediate results are also written to disk.
    """
    lbl_tree = Tree(self.reftree_lbl_str)
    self.taxtree_helper.set_bf_unrooted_tree(lbl_tree)

    self.reftree_tax = self.taxtree_helper.get_tax_tree()
    self.bid_ranks_map = self.taxtree_helper.get_bid_taxonomy_map()

    # NOTE(review): nesting reconstructed from a whitespace-mangled source;
    # the labelled-tree and branch-map dumps are assumed to be debug-only
    # like the tax tree dump -- confirm against the original file.
    if self.cfg.debug:
        self.reftree_tax.write(outfile=self.reftree_tax_fname, format=3)
        with open(self.reftree_lbl_fname, "w") as outf:
            outf.write(self.reftree_lbl_str)
        with open(self.brmap_fname, "w") as outf:
            # one tab-separated line per branch: id, then the first three
            # fields of its record (formatted as string, integer, float)
            for bid, br_rec in self.bid_ranks_map.iteritems():
                outf.write("%s\t%s\t%d\t%f\n" % (bid, br_rec[0], br_rec[1], br_rec[2]))

def calc_node_heights(self):
    """Calculate node heights on the reference tree
    (used to define branch-length cutoff during classification step)

    Algorithm is as follows:
      Tip node or node resolved to Species level: height = 1
      Inner node resolved to Genus or above:      height = min(left_height, right_height) + 1

    Heights are keyed by the EPA branch label (node.B) and stored in
    self.node_height_map. The root node gets no entry. Exactly one non-root
    node may lack an EPA label; a second one is a fatal error.
    """
    nh_map = {}
    dummy_added = False
    # postorder: children are visited (and their heights stored) before their parent
    for node in self.reftree_tax.traverse("postorder"):
        if not node.is_root():
            if not hasattr(node, "B"):
                # In a rooted tree, there is always one more node/branch than in unrooted one
                # That's why one branch will be always not EPA-labelled after the rooting
                if not dummy_added:
                    # give the unlabelled node a placeholder id so that its
                    # parent can still look up a height for it
                    node.B = "DDD"
                    dummy_added = True
                    species_rank = Taxonomy.EMPTY_RANK
                else:
                    errmsg = "FATAL ERROR: More than one tree branch without EPA label (calc_node_heights)"
                    self.cfg.exit_fatal_error(errmsg)
            else:
                # last element of the branch record is used as the species-level rank
                species_rank = self.bid_ranks_map[node.B][-1]
            bid = node.B
            if node.is_leaf() or species_rank != Taxonomy.EMPTY_RANK:
                nh_map[bid] = 1
            else:
                # NOTE(review): only the first two children are read, i.e. the
                # tree is assumed to be bifurcating here -- confirm
                lchild = node.children[0]
                rchild = node.children[1]
                nh_map[bid] = min(nh_map[lchild.B], nh_map[rchild.B]) + 1

    # remove heights for dummy nodes, since there won't be any placements on them
    if dummy_added:
        del nh_map["DDD"]

    self.node_height_map = nh_map

def __get_all_rank_names(self, root):
    """Return the set of all entries found in node.ranks over every node
    yielded by root.traverse("postorder")."""
    rnames = set([])
    for node in root.traverse("postorder"):
        ranks = node.ranks
        for rk in ranks:
            rnames.add(rk)
    return rnames

def mono_index(self):
    """Return the rank names that occur on both sides of the first split of the tax tree.

    Starting at the root of self.reftree_tax, chains of single-child nodes are
    skipped. If the node reached has exactly two children, the rank names of
    the left and right subtrees are collected and their intersection is
    returned. Otherwise an error is printed and an empty set is returned.
    """
    children = self.reftree_tax.children
    if len(children) == 1:
        # descend until the node has zero or more than one child
        while len(children) == 1:
            children = children[0].children
    if len(children) == 2:
        left = children[0]
        right =children[1]
        lset = self.__get_all_rank_names(left)
        rset = self.__get_all_rank_names(right)
        # names present in both subtrees
        iset = lset & rset
        return iset
    else:
        print("Error: input tree not birfurcating")
        return set([])

def build_hmm_profile(self, json_builder):
    """Build a HMMER profile from the reduced reference alignment and store it
    in json_builder via set_hmm_profile()."""
    print "Building the HMMER profile...\n"

    # this stupid workaround is needed because RAxML outputs the reduced
    # alignment in relaxed PHYLIP format, which is not supported by HMMER
    refalign_fasta = self.cfg.tmp_fname("%NAME%_ref_reduced.fa")
    self.reduced_refalign_seqs.write(outfile=refalign_fasta)

    hmm = hmmer(self.cfg, refalign_fasta)
    fprofile = hmm.build_hmm_profile()

    json_builder.set_hmm_profile(fprofile)

def write_json(self):
    """Collect all training results in a RefJsonBuilder and dump them to
    self.cfg.refjson_fname.

    Reads attributes set by the earlier steps of the pipeline (labelled tree,
    branch map, node heights, RAxML invocations etc.). The HMMER profile is
    built here unless self.cfg.no_hmmer is set.
    """
    jw = RefJsonBuilder()

    jw.set_branch_tax_map(self.bid_ranks_map)
    jw.set_tree(self.reftree_lbl_str)
    jw.set_outgroup(self.reftree_outgroup)
    jw.set_ratehet_model(self.cfg.raxml_model)
    jw.set_tax_tree(self.reftree_multif)
    jw.set_pattern_compression(self.cfg.compress_patterns)
    jw.set_taxcode(self.cfg.taxcode_name)

    jw.set_merged_ranks_map(self.input_validator.merged_ranks)
    # both correction maps are stored inverted (value -> key)
    corr_ranks_reverse = dict((reversed(item) for item in self.input_validator.corr_ranks.items()))
    jw.set_corr_ranks_map(corr_ranks_reverse)
    corr_seqid_reverse = dict((reversed(item) for item in self.input_validator.corr_seqid.items()))
    jw.set_corr_seqid_map(corr_seqid_reverse)

    mdata = { "ref_tree_size": self.reftree_size,
              "ref_alignment_width": self.refalign_width,
              "raxml_version": self.raxml_version,
              "timestamp": str(datetime.datetime.now()),
              "invocation_epac": self.invocation_epac,
              "invocation_raxml_multif": self.invocation_raxml_multif,
              "invocation_raxml_optmod": self.invocation_raxml_optmod,
              "invocation_raxml_epalbl": self.invocation_raxml_epalbl,
              "reftree_loglh": self.reftree_loglh
            }
    jw.set_metadata(mdata)

    seqs = self.reduced_refalign_seqs.get_entries()
    jw.set_sequences(seqs)

    if not self.cfg.no_hmmer:
        self.build_hmm_profile(jw)

    orig_tax = self.taxonomy_map
    jw.set_origin_taxonomy(orig_tax)

    self.cfg.log.debug("Calculating the speciation rate...\n")
    tp = tree_param(tree = self.reftree_lbl_str, origin_taxonomy = orig_tax)
    jw.set_rate(tp.get_speciation_rate_fast())
    jw.set_nodes_height(self.node_height_map)

    jw.set_binary_model(self.optmod_fname)

    self.cfg.log.debug("Writing down the reference file...\n")
    jw.dump(self.cfg.refjson_fname)

# top-level function to build a reference tree
def build_ref_tree(self):
    """Run the whole training pipeline, from loading the taxonomy to writing
    the reference JSON file.

    The steps are executed strictly in the order below; each one logs a
    progress message first. Later steps read attributes set by earlier ones
    (e.g. write_json() uses the results of epa_branch_labeling(),
    epa_post_process() and calc_node_heights()).
    """
    self.cfg.log.info("=> Loading taxonomy from file: %s ...\n" , self.cfg.taxonomy_fname)
    self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_fname=self.cfg.taxonomy_fname)
    self.cfg.log.info("==> Loading reference alignment from file: %s ...\n" , self.cfg.align_fname)
    self.load_alignment()
    self.cfg.log.info("===> Validating taxonomy and alignment ...\n")
    self.validate_taxonomy()
    self.cfg.log.info("====> Building a multifurcating tree from taxonomy with %d seqs ...\n" , self.taxonomy.seq_count())
    self.build_multif_tree()
    self.cfg.log.info("=====> Building the reference alignment ...\n")
    self.export_ref_alignment()
    self.export_ref_taxonomy()
    self.cfg.log.info("======> Saving the outgroup for later re-rooting ...\n")
    self.save_rooting()
    self.cfg.log.info("=======> Resolving multifurcation: choosing the best topology from %d independent RAxML runs ...\n" % self.cfg.rep_num)
    self.resolve_multif()
    self.load_reduced_refalign()
    self.cfg.log.info("========> Calling RAxML-EPA to obtain branch labels ...\n")
    self.epa_branch_labeling()
    self.cfg.log.info("=========> Post-processing the EPA tree (re-rooting, taxonomic labeling etc.) ...\n")
    self.epa_post_process()
    self.calc_node_heights()

    # debug-level sanity report: rank names shared before vs. after training
    self.cfg.log.debug("\n==========> Checking branch labels ...")
    self.cfg.log.debug("shared rank names before training: %s", repr(self.taxonomy.get_common_ranks()))
    self.cfg.log.debug("shared rank names after  training: %s\n", repr(self.mono_index()))

    self.cfg.log.info("==========> Saving the reference JSON file: %s\n" % self.cfg.refjson_fname)
    self.write_json()
def assign_taxonomy_maxsum(self, edges, minlw):
    """Assign a lineage by summing the LH-weights per rank and taking the
    rank with the max. sum.

    In the EPA result each placement (= branch) has a "weight". Since we are
    interested in taxonomic placement, we do not care about branch vs. branch
    comparisons, but only consider rank vs. rank (e.g. G1 S1 vs. G1 S2 vs. G1).
    Thus weights are accumulated for each rank, using two measures:
      "own" weight   = sum of weights of all placements EXACTLY to this rank
                       (e.g. for G1: G1 only)
      "total" weight = own weight + own weight of all children
                       (for G1: G1 or G1 S1 or G1 S2)

    Args:
        edges: iterable of placement records; edge[0] is the branch id and
            edge[2] its likelihood weight. Records with weight 0.0 are ignored.
        minlw: minimum "total" weight a rank needs in order to be selected
            by its "own" weight.

    Returns:
        A tuple (ranks, confidences) of two lists of equal length. If no
        placement carries any assigned rank, the ranks of the last processed
        placement (or a single empty rank) are returned with confidence 1.0.
    """
    rw_own = {}
    rw_total = {}
    rb = {}
    ranks = [Taxonomy.EMPTY_RANK]

    for edge in edges:
        br_id = str(edge[0])
        lweight = edge[2]
        if lweight == 0.0:
            continue

        # accumulate weight for the current sequence
        lowest_rank = None
        ranks = self.bid_taxonomy_map[br_id]
        for i, rank in enumerate(ranks):
            # only the leading run of assigned ranks is counted
            if rank == Taxonomy.EMPTY_RANK:
                break
            rank_id = Taxonomy.get_rank_uid(ranks, i)
            rw_total[rank_id] = rw_total.get(rank_id, 0) + lweight
            lowest_rank = rank_id
            # remember the first branch on which this rank was seen
            if rank_id not in rb:
                rb[rank_id] = br_id

        if lowest_rank:
            rw_own[lowest_rank] = rw_own.get(lowest_rank, 0) + lweight
            # a branch resolved exactly to this rank takes precedence
            rb[lowest_rank] = br_id
        elif self.cfg.verbose:
            # double space kept on purpose: identical to the output of the
            # former two-argument print statement
            print("WARNING: no annotation for branch  %s" % br_id)

    # if all branches have empty ranks only, just return this placement
    if len(rw_total) == 0:
        return ranks, [1.0] * len(ranks)

    # we assign the sequence to a rank, which has the max "own" weight AND
    # whose "total" weight is greater than a confidence threshold
    max_rw = 0.0
    s_r = None
    for r, own_weight in rw_own.items():
        if own_weight > max_rw and rw_total[r] >= minlw:
            s_r = r
            max_rw = own_weight

    # no rank passed the threshold: fall back to the max "total" weight
    if not s_r:
        s_r = max(rw_total, key=rw_total.get)

    a_br_id = rb[s_r]
    a_ranks = self.bid_taxonomy_map[a_br_id]

    # "total" weight is considered as confidence value for now
    a_conf = [0.0] * len(a_ranks)
    for i, rank in enumerate(a_ranks):
        if rank != Taxonomy.EMPTY_RANK:
            rank_id = Taxonomy.get_rank_uid(a_ranks, i)
            # ranks below a gap (an empty rank) were never accumulated above,
            # so they have no entry in rw_total; keep confidence 0.0 for them
            # instead of raising KeyError
            a_conf[i] = rw_total.get(rank_id, 0.0)

    return a_ranks, a_conf