def run_epa(self):
    """Place the query sequences on the reference tree with RAxML-EPA.

    Writes the reference tree and the binary model stored in the reference
    JSON to temporary files, resolves the tree-size-dependent settings,
    runs EPA on the reduced query alignment and moves the resulting jplace
    file to ``self.out_jplace_fname``.

    Returns:
        The result object returned by ``RaxmlWrapper.run_epa``.
    """
    self.cfg.log.info(
        "Running RAxML-EPA to place %d query sequences...\n" % self.query_count)
    # Use the instance configuration: a module-level ``config`` name only
    # exists when this file is executed as a script.
    raxml = RaxmlWrapper(self.cfg)
    reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")
    self.refjson.get_raxml_readable_tree(reftree_fname)
    optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
    self.refjson.get_binary_model(optmod_fname)
    job_name = self.cfg.subst_name("epa_%NAME%")

    reftree_str = self.refjson.get_raxml_readable_tree()
    reftree = Tree(reftree_str)
    self.reftree_size = len(reftree.get_leaves())

    # IMPORTANT: set EPA heuristic rate based on tree size!
    self.cfg.resolve_auto_settings(self.reftree_size)
    # If we're loading the pre-optimized model, we MUST set the same rate
    # het. mode as in the ref file
    if self.cfg.epa_load_optmod:
        self.cfg.raxml_model = self.refjson.get_ratehet_model()

    reduced_align_fname = raxml.reduce_alignment(self.epa_alignment)
    jp = raxml.run_epa(job_name, reduced_align_fname, reftree_fname, optmod_fname)
    raxml.copy_epa_jplace(job_name, self.out_jplace_fname, move=True)
    return jp
def get_nodesheight(self):
    """Map the ``B`` label of every labelled node to a topological height.

    For each node of ``Tree(self.tree)`` that carries a ``B`` attribute, the
    height is the second element of ``node.get_closest_leaf(topology_only=True)``
    (presumably the edge distance to the closest leaf) plus one.

    Returns:
        dict keyed by ``node.B``.
    """
    height_by_label = {}
    tree = Tree(self.tree)
    labelled_nodes = (n for n in tree.traverse(strategy="preorder") if hasattr(n, "B"))
    for node in labelled_nodes:
        closest = node.get_closest_leaf(topology_only=True)
        height_by_label[node.B] = closest[1] + 1
    return height_by_label
def load_refjson(self, refjson_fname):
    """Load and validate a reference JSON file and derive the run state from it.

    Aborts through ``self.cfg.exit_user_error`` when the file cannot be
    parsed or fails validation. Otherwise stores the reference data (rates,
    node heights, taxonomy, trees, branch-to-taxonomy map), resolves the
    tree-size-dependent configuration and creates the helper objects and
    mislabel counters.

    Args:
        refjson_fname: path of the reference JSON file.
    """
    try:
        self.refjson = RefJsonParser(refjson_fname)
    except ValueError:
        self.cfg.exit_user_error("ERROR: Invalid json file format!")

    # validate input json format
    valid, err = self.refjson.validate()
    if not valid:
        self.cfg.log.error(
            "ERROR: Parsing reference JSON file failed:\n%s", err)
        self.cfg.exit_user_error()

    ref = self.refjson
    cfg = self.cfg

    self.rate = ref.get_rate()
    self.node_height = ref.get_node_height()
    self.origin_taxonomy = ref.get_origin_taxonomy()
    self.tax_tree = ref.get_tax_tree()
    cfg.compress_patterns = ref.get_pattern_compression()

    self.bid_taxonomy_map = ref.get_branch_tax_map()
    if not self.bid_taxonomy_map:
        # Reference files from before v1.6 carry no branch-to-taxonomy map,
        # so it is rebuilt from the taxonomic and reference trees.
        legacy_helper = TaxTreeHelper(cfg, self.origin_taxonomy)
        legacy_helper.set_mf_rooted_tree(self.tax_tree)
        legacy_helper.set_bf_unrooted_tree(ref.get_reftree())
        self.bid_taxonomy_map = legacy_helper.get_bid_taxonomy_map()
    self.write_bid_tax_map(self.bid_taxonomy_map, final=False)

    self.reftree = Tree(ref.get_raxml_readable_tree())
    self.reftree_size = len(self.reftree.get_leaves())

    # IMPORTANT: the EPA heuristic rate depends on the reference tree size.
    cfg.resolve_auto_settings(self.reftree_size)
    # A pre-optimized model must be used with the same rate heterogeneity
    # mode as stored in the reference file.
    if cfg.epa_load_optmod:
        cfg.raxml_model = ref.get_ratehet_model()

    self.classify_helper = TaxClassifyHelper(cfg, self.bid_taxonomy_map, self.rate, self.node_height)
    self.taxtree_helper = TaxTreeHelper(cfg, self.origin_taxonomy, self.tax_tree)
    self.tax_code = TaxCode(ref.get_taxcode())

    self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_map=self.origin_taxonomy)
    self.tax_common_ranks = self.taxonomy.get_common_ranks()

    # one counter slot per universal taxonomic level
    level_count = TaxCode.UNI_TAX_LEVELS
    self.mislabels_cnt = [0] * level_count
    self.rank_mislabels_cnt = [0] * level_count
def test_taxtree_builder(self):
    """TaxTreeBuilder must rebuild the tree in taxtree.nw from test.tax.

    Also checks that the returned sequence ids match the taxonomy's keys,
    in the same order.
    """
    cfg = EpacConfig()
    testfile_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testfiles")
    tax_fname = os.path.join(testfile_dir, "test.tax")
    tax = Taxonomy(EpacConfig.REF_SEQ_PREFIX, tax_fname)
    tree_fname = os.path.join(testfile_dir, "taxtree.nw")
    expected_tree = Tree(tree_fname, format=8)

    tb = TaxTreeBuilder(cfg, tax)
    tax_tree, seq_ids = tb.build()

    # On Python 3, dict.keys() is a view that never compares equal to a list.
    # Materialise both sides so the ordered comparison works on Python 2 and 3.
    self.assertEqual(list(seq_ids), list(tax.get_map().keys()))
    self.assertEqual(tax_tree.write(format=8), expected_tree.write(format=8))
def setUp(self):
    """Load test_clean.tax, build a TaxTreeHelper for it and read outgroup.nw."""
    this_dir = os.path.dirname(os.path.abspath(__file__))
    self.testfile_dir = os.path.join(this_dir, "testfiles")
    self.tax_fname = os.path.join(self.testfile_dir, "test_clean.tax")
    self.taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX, self.tax_fname)
    seq_ranks = self.taxonomy.get_map()
    self.taxtree_helper = TaxTreeHelper(EpacConfig(), seq_ranks)
    self.expected_outgr = Tree(os.path.join(self.testfile_dir, "outgroup.nw"))
def test_outgroup(self):
    """The outgroup computed from taxtree.nw has the leaves of outgroup.nw."""
    taxtree_path = os.path.join(self.testfile_dir, "taxtree.nw")
    self.taxtree_helper.set_mf_rooted_tree(Tree(taxtree_path))
    computed = self.taxtree_helper.get_outgroup()
    self.assertEqual(computed.get_leaf_names(), self.expected_outgr.get_leaf_names())
def setUp(self):
    """Fixture: taxonomy from test_clean.tax, its TaxTreeHelper and outgroup.nw."""
    def fixture_path(fname):
        return os.path.join(self.testfile_dir, fname)

    self.testfile_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "testfiles")
    self.tax_fname = fixture_path("test_clean.tax")
    self.taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX, self.tax_fname)
    rank_map = self.taxonomy.get_map()
    config = EpacConfig()
    self.taxtree_helper = TaxTreeHelper(config, rank_map)
    self.expected_outgr = Tree(fixture_path("outgroup.nw"))
def epa_2_ptp(epa_jp, ref_jp, full_alignment, min_lw = 0.5, debug = False):
    """Delimit species among EPA-placed query sequences with PTP.

    Queries are grouped by the branch of their first-listed placement edge.
    A placement is ignored if it has no edges, or if the like-weight ratio
    of its first edge is below ``min_lw``. Groups with fewer than four
    sequences are reported unchanged as one cluster. For larger groups, the
    sequences are copied out of ``full_alignment`` and passed to
    ``build_tree_run_ptp`` together with ``ref_jp.get_rate()``, and the
    clusters it returns are appended.

    Args:
        epa_jp: EPA result providing ``get_placement()`` (jplace-style dicts
            with ``"n"`` = names and ``"p"`` = edge records).
        ref_jp: reference JSON parser providing ``get_rate()``.
        full_alignment: alignment providing ``get_seq(name)``.
        min_lw: minimum like-weight ratio of the first edge (inclusive).
        debug: unused; kept for interface compatibility.

    Returns:
        list of species clusters, each a list of query names.
    """
    # Group query names by placement branch, keeping branches in order of
    # first appearance (a single pass replaces the former two-pass scan).
    placemap = {}
    for placement in epa_jp.get_placement():
        edges = placement["p"]
        if not edges:
            # nothing placed for this query; previously raised IndexError
            continue
        best_edge = edges[0][0]
        best_lw = edges[0][2]
        if best_lw >= min_lw:
            placemap.setdefault(best_edge, []).extend(placement["n"])

    species_list = []
    for seqset in placemap.values():
        if len(seqset) < 4:
            # presumably too few sequences for PTP: keep the group as one cluster
            species_list.append(seqset)
        else:
            branch_alignment = SeqGroup()
            for taxa in seqset:
                branch_alignment.set_seq(taxa, full_alignment.get_seq(taxa))
            species_list.extend(build_tree_run_ptp(branch_alignment, ref_jp.get_rate()))
    return species_list
def get_speciation_rate(self):
    """Estimate a speciation rate as node count per unit of branch length.

    ``Tree(self.tree)`` is pruned, preserving branch lengths, so that each
    species label (the last rank in ``self.taxonomy``) keeps only its first
    sequence. Sequences labelled ``"-"`` are always kept. The result is
    the number of nodes in the pruned tree divided by the sum of their
    ``dist`` values.
    """
    seen_species = set()
    kept_names = []
    for seq_name in self.taxonomy.keys():
        label = self.taxonomy[seq_name][-1]
        if label != "-":
            if label in seen_species:
                continue
            seen_species.add(label)
        kept_names.append(seq_name)

    pruned = Tree(self.tree)
    pruned.prune(kept_names, preserve_branch_length=True)

    total_length = 0.0
    node_count = 0.0
    for node in pruned.traverse(strategy="preorder"):
        total_length += node.dist
        node_count += 1.0
    return float(node_count) / float(total_length)
def load_refjson(self, refjson_fname):
    """Read the reference JSON file and initialise all reference-derived state.

    Exits through ``self.cfg.exit_user_error`` when the file is not valid
    JSON or fails validation. On success it fills in the reference data,
    the branch-to-taxonomy map (rebuilt for pre-1.6 files), the reference
    tree and its size-dependent settings, the helper objects and the
    mislabel counters.

    Args:
        refjson_fname: path of the reference JSON file.
    """
    try:
        self.refjson = RefJsonParser(refjson_fname)
    except ValueError:
        self.cfg.exit_user_error("ERROR: Invalid json file format!")

    # validate input json format
    validation = self.refjson.validate()
    valid, err = validation
    if not valid:
        self.cfg.log.error("ERROR: Parsing reference JSON file failed:\n%s", err)
        self.cfg.exit_user_error()

    parser = self.refjson
    self.rate = parser.get_rate()
    self.node_height = parser.get_node_height()
    self.origin_taxonomy = parser.get_origin_taxonomy()
    self.tax_tree = parser.get_tax_tree()
    self.cfg.compress_patterns = parser.get_pattern_compression()

    self.bid_taxonomy_map = parser.get_branch_tax_map()
    if not self.bid_taxonomy_map:
        # pre-1.6 reference files: rebuild the map from the two trees
        rebuild = TaxTreeHelper(self.cfg, self.origin_taxonomy)
        rebuild.set_mf_rooted_tree(self.tax_tree)
        rebuild.set_bf_unrooted_tree(parser.get_reftree())
        self.bid_taxonomy_map = rebuild.get_bid_taxonomy_map()
    self.write_bid_tax_map(self.bid_taxonomy_map, final=False)

    raxml_tree_str = parser.get_raxml_readable_tree()
    self.reftree = Tree(raxml_tree_str)
    self.reftree_size = len(self.reftree.get_leaves())

    # IMPORTANT: EPA heuristic rate is chosen from the tree size.
    self.cfg.resolve_auto_settings(self.reftree_size)
    # Loading a pre-optimized model requires the rate heterogeneity mode it
    # was optimized with, which is stored in the reference file.
    if self.cfg.epa_load_optmod:
        self.cfg.raxml_model = parser.get_ratehet_model()

    self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map,
                                             self.rate, self.node_height)
    self.taxtree_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy, self.tax_tree)
    code_name = parser.get_taxcode()
    self.tax_code = TaxCode(code_name)

    self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_map=self.origin_taxonomy)
    self.tax_common_ranks = self.taxonomy.get_common_ranks()

    self.mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
    self.rank_mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
def epa_post_process(self):
    """Derive taxonomic branch labels from the EPA-labelled reference tree.

    Parses ``self.reftree_lbl_str``, passes it to
    ``self.taxtree_helper.set_bf_unrooted_tree`` and stores the resulting
    taxonomic tree in ``self.reftree_tax`` and the branch map in
    ``self.bid_ranks_map``. In debug mode, the taxonomic tree, the labelled
    tree string and the branch map (tab-separated: branch id, rank, int,
    float) are also written to their configured files.
    """
    lbl_tree = Tree(self.reftree_lbl_str)
    self.taxtree_helper.set_bf_unrooted_tree(lbl_tree)
    self.reftree_tax = self.taxtree_helper.get_tax_tree()
    self.bid_ranks_map = self.taxtree_helper.get_bid_taxonomy_map()

    if self.cfg.debug:
        self.reftree_tax.write(outfile=self.reftree_tax_fname, format=3)
        with open(self.reftree_lbl_fname, "w") as outf:
            outf.write(self.reftree_lbl_str)
        with open(self.brmap_fname, "w") as outf:
            # items() works on Python 2 and 3; iteritems() is Python-2 only
            for bid, br_rec in self.bid_ranks_map.items():
                outf.write("%s\t%s\t%d\t%f\n" % (bid, br_rec[0], br_rec[1], br_rec[2]))
def get_speciation_rate_fast(self):
    """Estimate the speciation rate without calling ETE2's ``prune()``.

    ETE2's prune() is extremely slow on large trees. Instead, every leaf whose
    species label (the last rank in ``self.taxonomy``) has already been
    seen is deleted individually, preserving branch lengths. Labels equal
    to ``"-"`` are never treated as duplicates.

    Returns:
        number of nodes left in the tree divided by the sum of their ``dist``.

    Raises:
        ValueError: if a duplicate-species sequence is not a leaf of the tree.
    """
    root = Tree(self.tree)
    leaf_by_name = {n.name: n for n in root.traverse(strategy="postorder") if n.is_leaf()}

    # Drop every sequence after the first one of its species.
    seen = set()
    for seq_name in self.taxonomy.keys():
        label = self.taxonomy[seq_name][-1]
        if label == "-":
            continue
        if label not in seen:
            seen.add(label)
            continue
        leaf = leaf_by_name.get(seq_name, None)
        if not leaf:
            raise ValueError("Node names not found in the tree: " + seq_name)
        leaf.delete(preserve_branch_length=True)

    # sp_rate = number_of_sp_events / sum_of_branch_lengths
    total_length = 0.0
    node_count = 0
    for node in root.traverse(strategy="preorder"):
        total_length += node.dist
        node_count += 1
    return float(node_count) / float(total_length)
def __init__(self, config, args):
    """Prepare a run from the configuration and the parsed arguments.

    Loads the reference JSON named by ``config.refjson_fname`` (format
    version "1.2") and resolves the tree-size-dependent settings. It then
    creates the classification helper and the mislabel bookkeeping.

    Args:
        config: configuration object, stored as ``self.cfg``.
        args: parsed arguments providing ``method``, ``min_lhw``,
            ``jplace_fname``, ``ranktest``, ``output_dir`` and ``output_name``.
    """
    self.cfg = config
    self.method = args.method
    self.minlw = args.min_lhw
    self.jplace_fname = args.jplace_fname
    self.ranktest = args.ranktest
    self.output_fname = args.output_dir + "/" + args.output_name
    # switch off branch length filter
    self.brlen_pv = 0.
    self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
    self.reftree_lbl_fname = config.tmp_fname("%NAME%_lbl.tre")
    self.reftree_tax_fname = config.tmp_fname("%NAME%_tax.tre")
    self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
    self.reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")

    try:
        self.refjson = RefJsonParser(config.refjson_fname, ver="1.2")
    except ValueError:
        # sys.exit(msg) writes msg to stderr and exits with status 1; the
        # former print() + sys.exit() reported this fatal error as success.
        sys.exit("Invalid json file format!")

    # validate input json format
    # NOTE(review): the return value is ignored here, while newer code
    # unpacks (valid, err) from validate() - confirm what this parser version
    # returns before acting on it.
    self.refjson.validate()

    self.rate = self.refjson.get_rate()
    self.node_height = self.refjson.get_node_height()
    self.origin_taxonomy = self.refjson.get_origin_taxonomy()
    # NOTE(review): method name kept exactly as called originally; verify the
    # spelling against the RefJsonParser API.
    self.bid_taxonomy_map = self.refjson.get_bid_tanomomy_map()
    self.tax_tree = self.refjson.get_tax_tree()
    self.cfg.compress_patterns = self.refjson.get_pattern_compression()

    reftree_str = self.refjson.get_raxml_readable_tree()
    self.reftree = Tree(reftree_str)
    self.reftree_size = len(self.reftree.get_leaves())

    # IMPORTANT: set EPA heuristic rate based on tree size!
    self.cfg.resolve_auto_settings(self.reftree_size)
    # If we're loading the pre-optimized model, we MUST set the same rate
    # het. mode as in the ref file
    if self.cfg.epa_load_optmod:
        self.cfg.raxml_model = self.refjson.get_ratehet_model()

    self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map,
                                             self.brlen_pv, self.rate, self.node_height)

    self.TAXONOMY_RANKS_COUNT = 10
    self.mislabels = []
    self.mislabels_cnt = [0] * self.TAXONOMY_RANKS_COUNT
    self.rank_mislabels = []
    self.rank_mislabels_cnt = [0] * self.TAXONOMY_RANKS_COUNT
    self.misrank_conf_map = {}
def test_jplace_read(self):
    """EpaJsonParser reads the version, tree and placements of test.jplace."""
    parser = EpaJsonParser(os.path.join(self.testfile_dir, "test.jplace"))
    self.assertEqual(parser.get_raxml_version(), "8.2.3")

    tree = Tree(parser.get_tree())
    n_leaves = len(tree)
    self.assertEqual(n_leaves, 32)
    self.assertEqual(len(parser.get_placement()), 6)

    # branch ids must lie in [0, 2n - 3), the branch count of an unrooted
    # binary tree with n leaves
    branch_limit = n_leaves * 2 - 3
    for placement in parser.get_placement():
        self.assertFalse(placement["n"][0] in tree)
        self.assertTrue(len(placement["p"]) > 0)
        for edge in placement["p"]:
            branch_id = int(edge[0])
            weight = edge[2]
            self.assertTrue(0 <= branch_id < branch_limit)
            self.assertTrue(0.0 <= weight <= 1.0)
class TaxTreeHelperTests(unittest.TestCase):
    """Tests for TaxTreeHelper: outgroup detection and branch labelling."""

    def setUp(self):
        self.testfile_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "testfiles")
        self.tax_fname = os.path.join(self.testfile_dir, "test_clean.tax")
        self.taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX, self.tax_fname)
        tax_map = self.taxonomy.get_map()
        cfg = EpacConfig()
        self.taxtree_helper = TaxTreeHelper(cfg, tax_map)
        outgr_fname = os.path.join(self.testfile_dir, "outgroup.nw")
        self.expected_outgr = Tree(outgr_fname)

    def tearDown(self):
        self.taxonomy = None
        self.taxtree_helper = None

    def test_outgroup(self):
        """The outgroup computed from taxtree.nw has the leaves of outgroup.nw."""
        mfu_tree_fname = os.path.join(self.testfile_dir, "taxtree.nw")
        mfu_tree = Tree(mfu_tree_fname)
        self.taxtree_helper.set_mf_rooted_tree(mfu_tree)
        outgr = self.taxtree_helper.get_outgroup()
        self.assertEqual(outgr.get_leaf_names(), self.expected_outgr.get_leaf_names())

    def test_branch_labeling(self):
        """Branch labels for resolved_tree.nw match bid_tax_map.txt.

        Each non-empty line of the expected map holds a tab-separated
        branch id, rank id, rank difference (int) and branch length (float).
        """
        bfu_tree_fname = os.path.join(self.testfile_dir, "resolved_tree.nw")
        bfu_tree = Tree(bfu_tree_fname)
        map_fname = os.path.join(self.testfile_dir, "bid_tax_map.txt")
        self.expected_map = {}
        with open(map_fname) as inf:
            for line in inf:
                line = line.strip()
                if not line:
                    # tolerate blank (e.g. trailing) lines in the fixture
                    continue
                bid, rank_id, rdiff, brlen = line.split("\t")
                self.expected_map[bid] = (rank_id, int(rdiff), float(brlen))

        self.taxtree_helper.set_outgroup(self.expected_outgr)
        self.taxtree_helper.set_bf_unrooted_tree(bfu_tree)
        bid_tax_map = self.taxtree_helper.get_bid_taxonomy_map()
        self.assertEqual(len(bid_tax_map), 2 * len(bfu_tree) - 3)
        # items() works on Python 2 and 3; iterkeys() is Python-2 only
        for bid, e_rec in self.expected_map.items():
            rec = bid_tax_map[bid]
            self.assertEqual(e_rec[0], rec[0])
            self.assertEqual(e_rec[1], rec[1])
            self.assertAlmostEqual(e_rec[2], rec[2], 6)
class TaxTreeHelperTests(unittest.TestCase):
    """Tests for TaxTreeHelper: outgroup detection and branch labelling."""

    def setUp(self):
        self.testfile_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testfiles")
        self.tax_fname = os.path.join(self.testfile_dir, "test_clean.tax")
        self.taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX, self.tax_fname)
        tax_map = self.taxonomy.get_map()
        cfg = EpacConfig()
        self.taxtree_helper = TaxTreeHelper(cfg, tax_map)
        outgr_fname = os.path.join(self.testfile_dir, "outgroup.nw")
        self.expected_outgr = Tree(outgr_fname)

    def tearDown(self):
        self.taxonomy = None
        self.taxtree_helper = None

    def test_outgroup(self):
        """The outgroup computed from taxtree.nw has the leaves of outgroup.nw."""
        mfu_tree_fname = os.path.join(self.testfile_dir, "taxtree.nw")
        mfu_tree = Tree(mfu_tree_fname)
        self.taxtree_helper.set_mf_rooted_tree(mfu_tree)
        outgr = self.taxtree_helper.get_outgroup()
        self.assertEqual(outgr.get_leaf_names(), self.expected_outgr.get_leaf_names())

    def test_branch_labeling(self):
        """Branch labels for resolved_tree.nw match bid_tax_map.txt.

        Each non-empty line of the expected map holds a tab-separated
        branch id, rank id, rank difference (int) and branch length (float).
        """
        bfu_tree_fname = os.path.join(self.testfile_dir, "resolved_tree.nw")
        bfu_tree = Tree(bfu_tree_fname)
        map_fname = os.path.join(self.testfile_dir, "bid_tax_map.txt")
        self.expected_map = {}
        with open(map_fname) as inf:
            for line in inf:
                line = line.strip()
                if not line:
                    # tolerate blank (e.g. trailing) lines in the fixture
                    continue
                bid, rank_id, rdiff, brlen = line.split("\t")
                self.expected_map[bid] = (rank_id, int(rdiff), float(brlen))

        self.taxtree_helper.set_outgroup(self.expected_outgr)
        self.taxtree_helper.set_bf_unrooted_tree(bfu_tree)
        bid_tax_map = self.taxtree_helper.get_bid_taxonomy_map()
        self.assertEqual(len(bid_tax_map), 2 * len(bfu_tree) - 3)
        # items() works on Python 2 and 3; iterkeys() is Python-2 only
        for bid, e_rec in self.expected_map.items():
            rec = bid_tax_map[bid]
            self.assertEqual(e_rec[0], rec[0])
            self.assertEqual(e_rec[1], rec[1])
            self.assertAlmostEqual(e_rec[2], rec[2], 6)
def test_branch_labeling(self):
    """Branch labels for resolved_tree.nw match bid_tax_map.txt.

    Each non-empty line of the expected map holds a tab-separated branch id,
    rank id, rank difference (int) and branch length (float).
    """
    bfu_tree_fname = os.path.join(self.testfile_dir, "resolved_tree.nw")
    bfu_tree = Tree(bfu_tree_fname)
    map_fname = os.path.join(self.testfile_dir, "bid_tax_map.txt")
    self.expected_map = {}
    with open(map_fname) as inf:
        for line in inf:
            line = line.strip()
            if not line:
                # tolerate blank (e.g. trailing) lines in the fixture
                continue
            bid, rank_id, rdiff, brlen = line.split("\t")
            self.expected_map[bid] = (rank_id, int(rdiff), float(brlen))

    self.taxtree_helper.set_outgroup(self.expected_outgr)
    self.taxtree_helper.set_bf_unrooted_tree(bfu_tree)
    bid_tax_map = self.taxtree_helper.get_bid_taxonomy_map()
    self.assertEqual(len(bid_tax_map), 2 * len(bfu_tree) - 3)
    # items() works on Python 2 and 3; iterkeys() is Python-2 only
    for bid, e_rec in self.expected_map.items():
        rec = bid_tax_map[bid]
        self.assertEqual(e_rec[0], rec[0])
        self.assertEqual(e_rec[1], rec[1])
        self.assertAlmostEqual(e_rec[2], rec[2], 6)
def classify(self, query_fname, fout = None, method = "1", minlw = 0.0, pv = 0.02, minp = 0.9, ptp = False):
    """Assign taxonomic labels to query sequences from their EPA placements.

    If ``self.jplace_fname`` is set, placements are read from that file.
    Otherwise the queries are prepared with ``self.checkinput`` and placed
    with RAxML-EPA, and the temporary files are removed afterwards unless
    in debug mode.

    Each classified query yields ``name<TAB>ranks<TAB>`` followed by ``*``
    when ``novelty_check`` is true, or ``o`` otherwise. Queries with no
    usable assignment, and those listed in ``self.noalign``, yield
    ``name<TAB><TAB><TAB>?``.

    Args:
        query_fname: query sequence file (unused when a jplace file is given).
        fout: optional assignment output file; PTP clusters go to
            ``fout + ".species"``.
        method: classification method passed to ``classify_seq``.
        minlw: minimum like-weight ratio passed to classification, the
            novelty check and ``print_ranks``.
        pv: unused (the Erlang filter is disabled).
        minp: passed to ``checkinput``.
        ptp: additionally run EPA-PTP species delimitation.
    """
    if self.jplace_fname:
        jp = EpaJsonParser(self.jplace_fname)
    else:
        self.checkinput(query_fname, minp)
        # use the instance configuration, not a module-level ``config`` global
        raxml = RaxmlWrapper(self.cfg)
        reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")
        self.refjson.get_raxml_readable_tree(reftree_fname)
        optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.refjson.get_binary_model(optmod_fname)
        job_name = self.cfg.subst_name("epa_%NAME%")

        reftree_str = self.refjson.get_raxml_readable_tree()
        reftree = Tree(reftree_str)
        self.reftree_size = len(reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same rate
        # het. mode as in the ref file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        reduced_align_fname = raxml.reduce_alignment(self.epa_alignment)
        jp = raxml.run_epa(job_name, reduced_align_fname, reftree_fname, optmod_fname)

    placements = jp.get_placement()

    if fout:
        fo = open(fout, "w")
    else:
        fo = None

    # lines for unassigned queries, emitted after all assigned ones
    unassigned_lines = []
    for place in placements:
        taxon_name = place["n"][0]
        origin_taxon_name = EpacConfig.strip_query_prefix(taxon_name)
        edges = place["p"]
        # edges = self.erlang_filter(edges, p = pv)
        if len(edges) > 0:
            ranks, lws = self.classify_helper.classify_seq(edges, method, minlw)
            isnovo = self.novelty_check(place_edge = str(edges[0][0]), ranks = ranks, lws = lws, minlw = minlw)
            rankout = self.print_ranks(ranks, lws, minlw)
            if rankout is None:
                unassigned_lines.append(origin_taxon_name + "\t\t\t?\n")
            else:
                # reuse rankout instead of formatting the ranks a second time
                output = "%s\t%s\t" % (origin_taxon_name, rankout)
                if isnovo:
                    output += "*"
                else:
                    output += "o"
                if self.cfg.verbose:
                    print(output)
                if fo:
                    fo.write(output + "\n")
        else:
            unassigned_lines.append(origin_taxon_name + "\t\t\t?\n")

    if os.path.exists(self.noalign):
        with open(self.noalign) as fnoa:
            lines = fnoa.readlines()
        for line in lines:
            # strip the leading FASTA '>' from the sequence name line
            taxon_name = line.strip()[1:]
            origin_taxon_name = EpacConfig.strip_query_prefix(taxon_name)
            output = "%s\t\t\t?" % origin_taxon_name
            if self.cfg.verbose:
                print(output)
            if fo:
                fo.write(output + "\n")

    output2 = "".join(unassigned_lines)
    if self.cfg.verbose:
        print(output2)
    if fo:
        fo.write(output2)
        fo.close()

    #############################################
    #
    # EPA-PTP species delimitation
    #
    #############################################
    if ptp:
        full_aln = SeqGroup(self.epa_alignment)
        species_list = epa_2_ptp(epa_jp = jp, ref_jp = self.refjson, full_alignment = full_aln, min_lw = 0.5, debug = self.cfg.debug)

        if self.cfg.verbose:
            print("Species clusters:")

        fo2 = open(fout + ".species", "w") if fout else None
        try:
            for sp_cluster in species_list:
                s = ",".join(EpacConfig.strip_query_prefix(taxon) for taxon in sp_cluster)
                if fo2:
                    fo2.write(s + "\n")
                if self.cfg.verbose:
                    print(s)
        finally:
            # close the species file even if writing fails
            if fo2:
                fo2.close()
    #############################################

    if not self.jplace_fname:
        if not self.cfg.debug:
            raxml.cleanup(job_name)
            FileUtils.remove_if_exists(reduced_align_fname)
            FileUtils.remove_if_exists(reftree_fname)
            FileUtils.remove_if_exists(optmod_fname)
def run_final_epa_test(self):
    """Re-run EPA on the reference tree after removing the detected mislabels.

    Every sequence named in ``self.mislabels`` is deleted from copies of the
    reference tree and the taxonomic tree, and ``self.mislabels`` is then
    reset. EPA is run once on the pruned reference tree
    (``run_epa_once``). The branch-id -> taxonomy map is rebuilt for that
    tree, because branch numbering may change. Each placement is classified
    and compared with the sequence's original label via
    ``check_seq_tax_labels``. The final assignments are written with
    ``write_assignments(..., final=True)``.
    """
    self.reftree_outgroup = self.refjson.get_outgroup()

    tmp_reftree = self.reftree.copy(method="newick")
    name2refnode = {}
    for leaf in tmp_reftree.iter_leaves():
        name2refnode[leaf.name] = leaf

    tmp_taxtree = self.tax_tree.copy(method="newick")
    name2taxnode = {}
    for leaf in tmp_taxtree.iter_leaves():
        name2taxnode[leaf.name] = leaf

    for mis_rec in self.mislabels:
        rname = mis_rec['name']
        if rname in name2refnode:
            name2refnode[rname].delete()
        else:
            # print() with one argument works on Python 2 and 3
            print("Node not found in the reference tree: %s" % rname)
        if rname in name2taxnode:
            name2taxnode[rname].delete()
        else:
            print("Node not found in the taxonomic tree: %s" % rname)

    # remove unifurcation at the root
    if len(tmp_reftree.children) == 1:
        tmp_reftree = tmp_reftree.children[0]

    self.mislabels = []

    th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
    th.set_mf_rooted_tree(tmp_taxtree)

    epa_result = self.run_epa_once(tmp_reftree)
    reftree_epalbl_str = epa_result.get_std_newick_tree()
    placements = epa_result.get_placement()

    # update branchid-taxonomy mapping to account for possible changes in
    # branch numbering
    reftree_tax = Tree(reftree_epalbl_str)
    th.set_bf_unrooted_tree(reftree_tax)
    bid_tax_map = th.get_bid_taxonomy_map()
    self.write_bid_tax_map(bid_tax_map, final=True)

    cl = TaxClassifyHelper(self.cfg, bid_tax_map, self.rate, self.node_height)

    final_ass = {}
    for place in placements:
        seq_name = place["n"][0]

        # get original taxonomic label
        orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

        # EXPERIMENTAL FEATURE - disabled for now!
        # Some ranks of the original reference tree may be missing from the
        # pruned tree entirely (e.g. all sequences of a species were pruned
        # as suspicious). EPA cannot then recover the full original label,
        # so the idea is to mark such missing ranks as "Undefined":
        # orig_ranks = th.strip_missing_ranks(orig_ranks)

        # get EPA tax label
        ranks, lws = cl.classify_seq(place["p"])
        final_ass[seq_name] = (ranks, lws)

        # check if they match; the return value is not used here (presumably
        # mismatches are recorded on self, since self.mislabels was reset
        # above - confirm in check_seq_tax_labels)
        self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

    self.write_assignments(final_ass, final=True)
def classify(self, query_fname, minp = 0.9, ptp = False):
    """Assign taxonomic labels to query sequences from their EPA placements.

    If ``self.jplace_fname`` is set, placements are read from that file.
    Otherwise the queries are prepared with ``self.checkinput``, placed with
    RAxML-EPA, and the jplace result is moved to ``self.out_jplace_fname``.

    Each classified query yields ``name<TAB>ranks<TAB>`` followed by ``*``
    when ``novelty_check`` is true, or ``o`` otherwise. Unassigned queries
    (no edges, no printable ranks, or listed in ``self.noalign``) yield
    ``name<TAB><TAB><TAB>?``. Lines go to ``self.out_assign_fname`` (if
    set) and to stdout in verbose mode.

    Args:
        query_fname: query sequence file (unused when a jplace file is given).
        minp: passed to ``checkinput``.
        ptp: additionally run EPA-PTP species delimitation.
    """
    if self.jplace_fname:
        jp = EpaJsonParser(self.jplace_fname)
    else:
        self.checkinput(query_fname, minp)
        self.cfg.log.info("Running RAxML-EPA to place %d query sequences...\n" % self.query_count)
        # use the instance configuration, not a module-level ``config`` global
        raxml = RaxmlWrapper(self.cfg)
        reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")
        self.refjson.get_raxml_readable_tree(reftree_fname)
        optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.refjson.get_binary_model(optmod_fname)
        job_name = self.cfg.subst_name("epa_%NAME%")

        reftree_str = self.refjson.get_raxml_readable_tree()
        reftree = Tree(reftree_str)
        self.reftree_size = len(reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same rate
        # het. mode as in the ref file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        reduced_align_fname = raxml.reduce_alignment(self.epa_alignment)
        jp = raxml.run_epa(job_name, reduced_align_fname, reftree_fname, optmod_fname)
        raxml.copy_epa_jplace(job_name, self.out_jplace_fname, move=True)

    self.cfg.log.info("Assigning taxonomic labels based on EPA placements...\n")

    placements = jp.get_placement()

    if self.out_assign_fname:
        fo = open(self.out_assign_fname, "w")
    else:
        fo = None

    noassign_list = []
    for place in placements:
        taxon_name = place["n"][0]
        origin_taxon_name = EpacConfig.strip_query_prefix(taxon_name)
        edges = place["p"]
        if len(edges) > 0:
            ranks, lws = self.classify_helper.classify_seq(edges)
            isnovo = self.novelty_check(place_edge = str(edges[0][0]), ranks=ranks, lws=lws)
            rankout = self.print_ranks(ranks, lws, self.cfg.min_lhw)
            if rankout is None:
                noassign_list.append(origin_taxon_name)
            else:
                output = "%s\t%s\t" % (origin_taxon_name, rankout)
                if isnovo:
                    output += "*"
                else:
                    output += "o"
                if self.cfg.verbose:
                    print(output)
                if fo:
                    fo.write(output + "\n")
        else:
            noassign_list.append(origin_taxon_name)

    if os.path.exists(self.noalign):
        with open(self.noalign) as fnoa:
            lines = fnoa.readlines()
        for line in lines:
            # strip the leading FASTA '>' from the sequence name line
            taxon_name = line.strip()[1:]
            noassign_list.append(EpacConfig.strip_query_prefix(taxon_name))

    # Format each unassigned name itself; the previous code reused the stale
    # ``origin_taxon_name`` from the loops above, so every line repeated the
    # last name seen.
    for noassign_name in noassign_list:
        output = "%s\t\t\t?" % noassign_name
        if self.cfg.verbose:
            print(output)
        if fo:
            fo.write(output + "\n")

    if fo:
        fo.close()

    #############################################
    #
    # EPA-PTP species delimitation
    #
    #############################################
    if ptp:
        full_aln = SeqGroup(self.epa_alignment)
        species_list = epa_2_ptp(epa_jp = jp, ref_jp = self.refjson, full_alignment = full_aln, min_lw = 0.5, debug = self.cfg.debug)

        self.cfg.log.debug("Species clusters:")

        # ``fout`` is not defined in this method (it raised NameError), so the
        # clusters are written next to the assignment file.
        # NOTE(review): confirm the intended ".species" output location.
        if self.out_assign_fname:
            fo2 = open(self.out_assign_fname + ".species", "w")
        else:
            fo2 = None
        try:
            for sp_cluster in species_list:
                s = ",".join(EpacConfig.strip_query_prefix(taxon) for taxon in sp_cluster)
                if fo2:
                    fo2.write(s + "\n")
                self.cfg.log.debug(s)
        finally:
            # close the species file even if writing fails
            if fo2:
                fo2.close()
class LeaveOneTest:
    """Leave-one-out test that flags possibly mislabeled reference sequences.

    Reference sequences (and, with ``cfg.ranktest``, eligible taxonomic
    subtrees) are placed back onto the reference tree with RAxML-EPA and
    classified. Records whose assigned ranks disagree with the original
    annotation are collected as mislabels. If any are found, a final EPA run
    on a tree with those sequences pruned re-checks them.

    Output file names (``.mis``, ``.premis``, ``.misrank``, ``.stats``) are
    derived from ``cfg.out_fname``.
    """

    def __init__(self, config):
        """Prepare output/temporary file names and empty result containers.

        Calls ``config.exit_user_error()`` if the ``.mis`` output file already
        exists.
        """
        self.cfg = config
        self.mis_fname = self.cfg.out_fname("%NAME%.mis")
        self.premis_fname = self.cfg.out_fname("%NAME%.premis")
        self.misrank_fname = self.cfg.out_fname("%NAME%.misrank")
        self.stats_fname = self.cfg.out_fname("%NAME%.stats")

        if os.path.isfile(self.mis_fname):
            print("\nERROR: Output file already exists: %s" % self.mis_fname)
            print("Please specify a different job name using -n or remove old output files.")
            self.cfg.exit_user_error()

        self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
        self.reftree_lbl_fname = config.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = config.tmp_fname("%NAME%_tax.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")
        self.mislabels = []
        self.mislabels_cnt = []
        self.rank_mislabels = []
        self.rank_mislabels_cnt = []
        self.misrank_conf_map = {}

    def write_bid_tax_map(self, bid_tax_map, final):
        """In debug mode, dump a branch-id -> taxonomy map to a temp file.

        :param bid_tax_map: dict whose values are indexable; items 0, 1 and 2
            are written with ``%s``, ``%d`` and ``%f`` respectively
        :param final: selects the ``final`` (True) or ``l1out`` file suffix
        """
        if not self.cfg.debug:
            return
        fname_suffix = "final" if final else "l1out"
        bid_fname = self.cfg.tmp_fname("%NAME%_" + "bid_tax_map_%s.txt" % fname_suffix)
        with open(bid_fname, "w") as outf:
            for bid, bid_rec in bid_tax_map.items():
                outf.write("%s\t%s\t%d\t%f\n" % (bid, bid_rec[0], bid_rec[1], bid_rec[2]))

    def write_assignments(self, assign_map, final):
        """In debug mode, dump ``{seq_name: (ranks, lws)}`` to a temp file.

        :param final: selects the ``final`` (True) or ``l1out`` file suffix
        """
        if not self.cfg.debug:
            return
        fname_suffix = "final" if final else "l1out"
        assign_fname = self.cfg.tmp_fname("%NAME%_" + "taxassign_%s.txt" % fname_suffix)
        with open(assign_fname, "w") as outf:
            for seq_name, (ranks, lws) in assign_map.items():
                outf.write("%s\t%s\t%s\n" % (seq_name, ";".join(ranks),
                                             ";".join(["%.3f" % l for l in lws])))

    def load_refjson(self, refjson_fname):
        """Load and validate the reference JSON and derive the test state.

        Sets the rate/node-height data, taxonomy, taxonomic tree, branch-id
        taxonomy map (rebuilt when the file does not provide one), reference
        tree, EPA auto settings, classification helpers and the per-level
        mislabel counters.
        """
        try:
            self.refjson = RefJsonParser(refjson_fname)
        except ValueError:
            self.cfg.exit_user_error("ERROR: Invalid json file format!")

        # validate input json format
        (valid, err) = self.refjson.validate()
        if not valid:
            self.cfg.log.error("ERROR: Parsing reference JSON file failed:\n%s", err)
            self.cfg.exit_user_error()

        self.rate = self.refjson.get_rate()
        self.node_height = self.refjson.get_node_height()
        self.origin_taxonomy = self.refjson.get_origin_taxonomy()
        self.tax_tree = self.refjson.get_tax_tree()
        self.cfg.compress_patterns = self.refjson.get_pattern_compression()

        self.bid_taxonomy_map = self.refjson.get_branch_tax_map()
        if not self.bid_taxonomy_map:
            # old file format (before 1.6), need to rebuild this map from scratch
            th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
            th.set_mf_rooted_tree(self.tax_tree)
            th.set_bf_unrooted_tree(self.refjson.get_reftree())
            self.bid_taxonomy_map = th.get_bid_taxonomy_map()
        self.write_bid_tax_map(self.bid_taxonomy_map, final=False)

        reftree_str = self.refjson.get_raxml_readable_tree()
        self.reftree = Tree(reftree_str)
        self.reftree_size = len(self.reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same rate het. mode as in the ref file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map, self.rate, self.node_height)
        self.taxtree_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy, self.tax_tree)

        tax_code_name = self.refjson.get_taxcode()
        self.tax_code = TaxCode(tax_code_name)

        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_map=self.origin_taxonomy)
        self.tax_common_ranks = self.taxonomy.get_common_ranks()

        self.mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
        self.rank_mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS

    def run_epa_trainer(self):
        """Build the reference with ``epa_trainer``; fatal error if no reference JSON results."""
        epa_trainer.run_trainer(self.cfg)

        if not os.path.isfile(self.cfg.refjson_fname):
            self.cfg.log.error("\nBuilding reference tree failed, see error messages above.")
            self.cfg.exit_fatal_error()

    def classify_seq(self, placement):
        """Classify one jplace placement record (uses its ``"p"`` edge list).

        Returns the ``(ranks, lws)`` pair from the classification helper, or
        None (after printing an error) when the placement has no edges.
        """
        edges = placement["p"]
        if len(edges) > 0:
            return self.classify_helper.classify_seq(edges)
        else:
            print("ERROR: no placements! something is definitely wrong!")

    @staticmethod
    def _first_mislabel_level(orig_ranks, ranks):
        """Return the first level where ``ranks`` has a non-empty rank that
        differs from ``orig_ranks`` (over their common length), else -1."""
        for rank_lvl in range(min(len(orig_ranks), len(ranks))):
            rank = ranks[rank_lvl]
            if rank != Taxonomy.EMPTY_RANK and rank != orig_ranks[rank_lvl]:
                return rank_lvl
        return -1

    def _make_mis_rec(self, name, orig_ranks, ranks, lws, mislabel_lvl):
        """Build a mislabel record for a mismatch detected at ``mislabel_lvl``."""
        real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
        return {
            'name': name,
            'orig_level': mislabel_lvl,
            'real_level': real_lvl,
            'level_name': self.tax_code.rank_level_name(real_lvl)[0],
            'inv_level': -1 * real_lvl,  # just for sorting
            'orig_ranks': orig_ranks,
            'ranks': ranks,
            'lws': lws,
            'conf': lws[mislabel_lvl],
        }

    def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
        """Compare a sequence's original ranks with its EPA assignment.

        An empty ``ranks`` list yields a ``[NotIngroup]`` record; otherwise
        the first mismatching non-empty rank yields a regular record. Any
        record is appended to ``self.mislabels`` and returned; None is
        returned when the labels agree.
        """
        mis_rec = None
        if len(ranks) == 0:
            mis_rec = {
                'name': seq_name,
                'orig_level': -1,
                'real_level': 0,
                'level_name': "[NotIngroup]",
                'inv_level': 0,  # -1 * real_level, just for sorting
                'orig_ranks': orig_ranks,
                'ranks': [],
                'lws': [1.0],
                'conf': 1.0,  # == lws[0]
            }
        else:
            mislabel_lvl = self._first_mislabel_level(orig_ranks, ranks)
            if mislabel_lvl >= 0:
                mis_rec = self._make_mis_rec(seq_name, orig_ranks, ranks, lws, mislabel_lvl)

        if mis_rec:
            self.mislabels.append(mis_rec)

        return mis_rec

    def filter_mislabels(self):
        """Keep only mislabels whose confidence is >= ``cfg.conf_cutoff``."""
        self.mislabels = [rec for rec in self.mislabels if rec['conf'] >= self.cfg.conf_cutoff]

    def check_rank_tax_labels(self, rank_name, orig_ranks, ranks, lws):
        """Compare a left-out subtree's parent ranks with its EPA assignment.

        Appends the mislabel record to ``self.rank_mislabels`` and returns it,
        or returns None when the labels agree.
        """
        mislabel_lvl = self._first_mislabel_level(orig_ranks, ranks)
        if mislabel_lvl < 0:
            return None
        mis_rec = self._make_mis_rec(rank_name, orig_ranks, ranks, lws, mislabel_lvl)
        self.rank_mislabels.append(mis_rec)
        return mis_rec

    def mis_rec_to_string_old(self, mis_rec):
        """Legacy multi-line formatting of a mislabel record."""
        lvl = mis_rec['orig_level']
        output = mis_rec['name'] + "\t"
        output += "%s\t%s\t%s\t%.3f\n" % (mis_rec['level_name'],
                                          mis_rec['orig_ranks'][lvl], mis_rec['ranks'][lvl], mis_rec['lws'][lvl])
        output += ";".join(mis_rec['orig_ranks']) + "\n"
        output += ";".join(mis_rec['ranks']) + "\n"
        output += "\t".join(["%.3f" % conf for conf in mis_rec['lws']]) + "\n"
        return output

    def mis_rec_to_string(self, mis_rec):
        """Format a mislabel record as one tab-separated line (no newline).

        Names and ranks are translated back with the reference JSON's
        ``get_uncorr_*`` helpers. Records with a negative level (NotIngroup)
        show ``NA`` labels and the first likelihood weight.
        """
        lvl = mis_rec['orig_level']
        uncorr_name = EpacConfig.strip_ref_prefix(self.refjson.get_uncorr_seqid(mis_rec['name']))
        uncorr_orig_ranks = self.refjson.get_uncorr_ranks(mis_rec['orig_ranks'])
        uncorr_ranks = self.refjson.get_uncorr_ranks(mis_rec['ranks'])

        output = uncorr_name + "\t"
        if lvl >= 0:
            output += "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'],
                                              uncorr_orig_ranks[lvl], uncorr_ranks[lvl], mis_rec['lws'][lvl])
        else:
            output += "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'], "NA", "NA", mis_rec['lws'][0])
        output += Taxonomy.lineage_str(uncorr_orig_ranks) + "\t"
        output += Taxonomy.lineage_str(uncorr_ranks) + "\t"
        output += ";".join(["%.3f" % conf for conf in mis_rec['lws']])
        if 'rank_conf' in mis_rec:
            output += "\t%.3f" % mis_rec['rank_conf']
        return output

    def sort_mislabels(self):
        """Sort mislabel lists (highest rank level first, then by confidence
        and name, descending) and count them per ``real_level``."""
        self.mislabels = sorted(self.mislabels, key=itemgetter('inv_level', 'conf', 'name'), reverse=True)
        for mis_rec in self.mislabels:
            self.mislabels_cnt[mis_rec["real_level"]] += 1

        if self.cfg.ranktest:
            self.rank_mislabels = sorted(self.rank_mislabels, key=itemgetter('inv_level', 'conf', 'name'),
                                         reverse=True)
            for mis_rec in self.rank_mislabels:
                self.rank_mislabels_cnt[mis_rec["real_level"]] += 1

    def write_stats(self, toFile=False):
        """Log per-level mislabel counts; optionally also write them to the stats file.

        Only levels with at least one mislabeled sequence are reported. With
        ``cfg.ranktest`` the line also carries that level's rank-mislabel
        count. The rank column is per-level, like the sequence column. It
        previously held a running sum that only advanced at the reported
        levels.
        """
        self.cfg.log.info("Mislabeled sequences by rank:")
        stats = []
        for i, seq_cnt in enumerate(self.mislabels_cnt):
            if seq_cnt == 0:
                continue
            if i > 0:
                rname = self.tax_code.rank_level_name(i)[0].ljust(12)
            else:
                rname = "[NotIngroup]"
            output = "%s:\t%d" % (rname, seq_cnt)
            if self.cfg.ranktest:
                output += "\t%d" % self.rank_mislabels_cnt[i]
            self.cfg.log.info(output)
            stats.append(output)

        if toFile:
            with open(self.stats_fname, "w") as fo_stat:
                for line in stats:
                    fo_stat.write(line + "\n")

    def write_mislabels_header(self, fo, final, fields, records=None):
        """Write the ';'-prefixed column header to ``fo``.

        When ``final`` is set, the header is preceded by the DISCLAIMER text
        (each line prefixed with ';'). In verbose final mode it is also
        printed to the console, provided ``records`` (the list about to be
        written) is non-empty. ``records`` defaults to
        ``self.rank_mislabels`` for backward compatibility.
        """
        if records is None:
            records = self.rank_mislabels

        header = ";" + "\t".join(fields) + "\n"

        # write to file
        if final:
            for line in DISCLAIMER.split("\n"):
                fo.write(";%s\n" % line)
            fo.write(";\n")
        fo.write(header)

        # print to console
        if final and self.cfg.verbose and len(records) > 0:
            print(DISCLAIMER, "\n")
            print("Mislabeled sequences:\n")
            print(header)

    def write_rank_mislabels(self, final=True):
        """Write rank mislabels to the ``.misrank`` file (only with ``cfg.ranktest``)."""
        if not self.cfg.ranktest:
            return

        with open(self.misrank_fname, "w") as fo_all:
            fields = ["RankID", "MislabeledLevel", "OriginalLabel", "ProposedLabel", "Confidence",
                      "OriginalTaxonomyPath", "ProposedTaxonomyPath", "PerRankConfidence"]
            self.write_mislabels_header(fo_all, final, fields, self.rank_mislabels)
            for mis_rec in self.rank_mislabels:
                output = self.mis_rec_to_string(mis_rec) + "\n"
                fo_all.write(output)
                if self.cfg.verbose:
                    print(output)

    def write_mislabels(self, final=True):
        """Write sequence mislabels to ``.mis`` (final) or ``.premis`` (interim).

        For the final output, rank mislabels and the per-level stats are
        written as well.
        """
        out_fname = self.mis_fname if final else self.premis_fname

        with open(out_fname, "w") as fo_all:
            fields = ["SeqID", "MislabeledLevel", "OriginalLabel", "ProposedLabel", "Confidence",
                      "OriginalTaxonomyPath", "ProposedTaxonomyPath", "PerRankConfidence"]
            if self.cfg.ranktest:
                fields += ["HigherRankMisplacedConfidence"]

            self.write_mislabels_header(fo_all, final, fields, self.mislabels)

            for mis_rec in self.mislabels:
                output = self.mis_rec_to_string(mis_rec) + "\n"
                fo_all.write(output)
                if self.cfg.verbose and final:
                    print(output)

        if final:
            self.write_rank_mislabels()
            self.write_stats()

    def get_parent_tip_ranks(self, tax_tree):
        """Collect taxonomic subtrees eligible for the leave-one-rank-out test.

        An inner, non-root node qualifies when its lowest assigned rank level
        is >= 2, its parent's is >= 1, and it has between 2 and
        ``reftree_size - 4`` leaves.

        :returns: ``(rank_parent, rank_tips)``, both keyed by the node name
            (tax path): parent ranks and leaf names respectively
        """
        rank_tips = {}
        rank_parent = {}
        for node in tax_tree.traverse("postorder"):
            if node.is_leaf() or node.is_root():
                continue
            tax_path = node.name
            ranks = Taxonomy.split_rank_uid(tax_path)
            rank_lvl = Taxonomy.lowest_assigned_rank_level(ranks)
            if rank_lvl < 2:
                continue

            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
            parent_lvl = Taxonomy.lowest_assigned_rank_level(parent_ranks)
            if parent_lvl < 1:
                continue

            rank_seqs = node.get_leaf_names()
            rank_size = len(rank_seqs)
            if rank_size < 2 or rank_size > self.reftree_size - 4:
                continue

            rank_tips[tax_path] = node.get_leaf_names()
            rank_parent[tax_path] = parent_ranks

        return rank_parent, rank_tips

    def run_leave_subtree_out_test(self):
        """Run the leave-one-rank-out EPA test and record rank mislabels.

        Each placement result is matched to its subtree by position in the
        subtree list written for RAxML. Confidences of mislabeled ranks are
        stored in ``self.misrank_conf_map`` keyed by tax path.

        :returns: number of processed placements (0 if no subtree qualifies)
        """
        job_name = self.cfg.subst_name("l1out_rank_%NAME%")

        # create file with subtrees
        rank_parent, rank_tips = self.get_parent_tip_ranks(self.tax_tree)

        subtree_list = list(rank_tips.items())
        if len(subtree_list) == 0:
            return 0

        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
        with open(subtree_list_file, "w") as fout:
            for _tax_path, tips in subtree_list:
                fout.write("%s\n" % " ".join(tips))

        jp_list = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname,
                                     mode="l1o_subtree", subtree_fname=subtree_list_file)

        subtree_count = 0
        for jp in jp_list:
            placements = jp.get_placement()
            for place in placements:
                ranks, lws = self.classify_seq(place)
                tax_path = subtree_list[subtree_count][0]
                orig_ranks = Taxonomy.split_rank_uid(tax_path)
                rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
                rank_prefix = self.tax_code.guess_rank_level_name(orig_ranks, rank_level)[0]
                rank_name = orig_ranks[rank_level]
                if not rank_name.startswith(rank_prefix):
                    rank_name = rank_prefix + rank_name
                parent_ranks = rank_parent[tax_path]
                mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks, ranks, lws)
                if mis_rec:
                    self.misrank_conf_map[tax_path] = mis_rec['conf']
                subtree_count += 1

        return subtree_count

    def run_leave_seq_out_test(self):
        """Run (or load) the leave-one-sequence-out placements and check them.

        Placements are loaded from ``cfg.jplace_fname`` (a file or glob mask,
        or every ``*.jplace`` in a directory) when set; otherwise RAxML-EPA
        is run in ``l1o_seq`` mode. With ``cfg.ranktest``, each sequence
        mislabel gets a ``rank_conf`` equal to the highest confidence among
        its mislabeled ancestor ranks (0 if none).

        :returns: number of processed placements
        """
        job_name = self.cfg.subst_name("l1out_seq_%NAME%")
        placements = []
        if self.cfg.jplace_fname:
            if os.path.isdir(self.cfg.jplace_fname):
                jplace_fmask = os.path.join(self.cfg.jplace_fname, '*.jplace')
            else:
                jplace_fmask = self.cfg.jplace_fname

            for jplace_fname in glob.glob(jplace_fmask):
                jp = EpaJsonParser(jplace_fname)
                placements += jp.get_placement()

            self.cfg.log.debug("Loaded %d placements from %s\n", len(placements), jplace_fmask)
        else:
            jp = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname,
                                    mode="l1o_seq")
            placements = jp.get_placement()
            if self.cfg.output_interim_files:
                out_jplace_fname = self.cfg.out_fname("%NAME%.l1out_seq.jplace")
                self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True, mode="l1o_seq")

        seq_count = 0
        l1out_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # get EPA tax label
            ranks, lws = self.classify_seq(place)
            l1out_ass[seq_name] = (ranks, lws)

            # check if they match
            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

            # cross-check with higher rank mislabels
            if self.cfg.ranktest and mis_rec:
                rank_conf = 0
                for lvl in range(2, len(orig_ranks)):
                    tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
                    if tax_path in self.misrank_conf_map:
                        rank_conf = max(rank_conf, self.misrank_conf_map[tax_path])
                mis_rec['rank_conf'] = rank_conf

            seq_count += 1

        self.write_assignments(l1out_ass, final=False)

        return seq_count

    def prune_mislabels_from_tree(self, src_tree, tree_name):
        """Return a copy of ``src_tree`` with the leaves of all current mislabels deleted.

        Missing leaves are logged at debug level; ``tree_name`` is only used
        in that message.
        """
        pruned_tree = src_tree.copy(method="newick")
        name2node = {leaf.name: leaf for leaf in pruned_tree.iter_leaves()}

        for mis_rec in self.mislabels:
            rname = mis_rec['name']
            if rname in name2node:
                name2node[rname].delete()
            else:
                self.cfg.log.debug("Node not found in the %s tree: %s", tree_name, rname)

        return pruned_tree

    def run_final_epa_test(self):
        """Re-check the suspicious sequences on trees with all of them pruned.

        Resets ``self.mislabels`` and refills it from the final placements,
        which are loaded from ``cfg.final_jplace_fname`` when set or computed
        by ``run_epa_once`` on the pruned reference tree.
        """
        self.reftree_outgroup = self.refjson.get_outgroup()

        pruned_reftree = self.prune_mislabels_from_tree(self.reftree, "reference")
        # The taxonomic tree (not the reference tree) is what
        # set_mf_rooted_tree() receives, so prune self.tax_tree here.
        pruned_taxtree = self.prune_mislabels_from_tree(self.tax_tree, "taxonomic")

        # remove unifurcation at the root
        if len(pruned_reftree.children) == 1:
            pruned_reftree = pruned_reftree.children[0]

        self.mislabels = []

        th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
        th.set_mf_rooted_tree(pruned_taxtree)

        reftree_epalbl_str = None

        if self.cfg.final_jplace_fname:
            if os.path.isdir(self.cfg.final_jplace_fname):
                jplace_fmask = os.path.join(self.cfg.final_jplace_fname, '*.jplace')
            else:
                jplace_fmask = self.cfg.final_jplace_fname

            placements = []
            for jplace_fname in glob.glob(jplace_fmask):
                jp = EpaJsonParser(jplace_fname)
                placements += jp.get_placement()
                if not reftree_epalbl_str:
                    reftree_epalbl_str = jp.get_std_newick_tree()

            self.cfg.log.debug("Loaded %d final epa placements from %s\n", len(placements), jplace_fmask)
        else:
            epa_result = self.run_epa_once(pruned_reftree)
            reftree_epalbl_str = epa_result.get_std_newick_tree()
            placements = epa_result.get_placement()

        # update branchid-taxonomy mapping to account for possible changes in branch numbering
        reftree_tax = Tree(reftree_epalbl_str)
        th.set_bf_unrooted_tree(reftree_tax)
        bid_tax_map = th.get_bid_taxonomy_map()

        self.write_bid_tax_map(bid_tax_map, final=True)

        cl = TaxClassifyHelper(self.cfg, bid_tax_map, self.rate, self.node_height)

        final_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # get EPA tax label
            ranks, lws = cl.classify_seq(place["p"])
            final_ass[seq_name] = (ranks, lws)

            # check if they match (records are appended to self.mislabels)
            self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

        self.write_assignments(final_ass, final=True)

    def run_epa_once(self, reftree):
        """Run RAxML-EPA once on ``reftree`` without a pre-optimized model.

        :returns: the result object from ``self.raxml.run_epa``
        """
        reftree_fname = self.cfg.tmp_fname("final_ref_%NAME%.tre")
        job_name = self.cfg.subst_name("final_epa_%NAME%")

        reftree.write(outfile=reftree_fname)

        # IMPORTANT: don't load the model, since it's invalid for the pruned true !!!
        optmod_fname = ""
        epa_result = self.raxml.run_epa(job_name, self.refalign_fname, reftree_fname, optmod_fname)

        if self.cfg.output_interim_files:
            out_jplace_fname = self.cfg.out_fname("%NAME%.final_epa.jplace")
            self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True)

        return epa_result

    def run_test(self):
        """Run the complete leave-one-out mislabel test and write the results."""
        self.raxml = RaxmlWrapper(self.cfg)

        self.refjson.get_raxml_readable_tree(self.reftree_fname)
        self.refalign_fname = self.refjson.get_alignment(self.tmp_refaln)
        self.refjson.get_binary_model(self.optmod_fname)

        if self.cfg.ranktest:
            self.cfg.log.info("Running the leave-one-rank-out test...\n")
            self.run_leave_subtree_out_test()

        self.cfg.log.info("Running the leave-one-sequence-out test...\n")
        self.run_leave_seq_out_test()

        if len(self.mislabels) > 0:
            self.cfg.log.info("Leave-one-out test identified %d suspicious sequences; running final EPA test to check them...\n",
                              len(self.mislabels))
            if self.cfg.debug:
                self.write_mislabels(final=False)
            self.run_final_epa_test()

        self.filter_mislabels()
        self.sort_mislabels()
        self.write_mislabels()
        self.cfg.log.info("\nTotal mislabels: %d / %.2f %%", len(self.mislabels),
                          (float(len(self.mislabels)) / self.reftree_size * 100))
def run_final_epa_test(self):
    """Re-check the suspicious sequences on trees with all of them pruned.

    Resets ``self.mislabels`` and refills it (via ``check_seq_tax_labels``)
    from the final placements. These are loaded from
    ``cfg.final_jplace_fname`` (a file or glob mask, or every ``*.jplace`` in
    a directory) when set, or computed by ``run_epa_once`` on the pruned
    reference tree.
    """
    self.reftree_outgroup = self.refjson.get_outgroup()

    pruned_reftree = self.prune_mislabels_from_tree(self.reftree, "reference")
    # The taxonomic tree (not the reference tree) is what
    # set_mf_rooted_tree() receives, so prune self.tax_tree here.
    pruned_taxtree = self.prune_mislabels_from_tree(self.tax_tree, "taxonomic")

    # remove unifurcation at the root
    if len(pruned_reftree.children) == 1:
        pruned_reftree = pruned_reftree.children[0]

    self.mislabels = []

    th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
    th.set_mf_rooted_tree(pruned_taxtree)

    reftree_epalbl_str = None

    if self.cfg.final_jplace_fname:
        if os.path.isdir(self.cfg.final_jplace_fname):
            jplace_fmask = os.path.join(self.cfg.final_jplace_fname, '*.jplace')
        else:
            jplace_fmask = self.cfg.final_jplace_fname

        placements = []
        for jplace_fname in glob.glob(jplace_fmask):
            jp = EpaJsonParser(jplace_fname)
            placements += jp.get_placement()
            if not reftree_epalbl_str:
                reftree_epalbl_str = jp.get_std_newick_tree()

        # use the instance config instead of a module-level global
        self.cfg.log.debug("Loaded %d final epa placements from %s\n", len(placements), jplace_fmask)
    else:
        epa_result = self.run_epa_once(pruned_reftree)
        reftree_epalbl_str = epa_result.get_std_newick_tree()
        placements = epa_result.get_placement()

    # update branchid-taxonomy mapping to account for possible changes in branch numbering
    reftree_tax = Tree(reftree_epalbl_str)
    th.set_bf_unrooted_tree(reftree_tax)
    bid_tax_map = th.get_bid_taxonomy_map()

    self.write_bid_tax_map(bid_tax_map, final=True)

    cl = TaxClassifyHelper(self.cfg, bid_tax_map, self.rate, self.node_height)

    final_ass = {}
    for place in placements:
        seq_name = place["n"][0]

        # get original taxonomic label
        orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

        # EXPERIMENTAL FEATURE (disabled): ranks present in the original
        # reference tree can be completely missing from the pruned tree (e.g.
        # all sequences of a species were pruned as suspicious). EPA then
        # cannot recover the full original annotation; the idea was to set
        # such ranks to "Undefined" via th.strip_missing_ranks(orig_ranks).

        # get EPA tax label
        ranks, lws = cl.classify_seq(place["p"])
        final_ass[seq_name] = (ranks, lws)

        # check if they match (records are appended to self.mislabels)
        self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

    self.write_assignments(final_ass, final=True)
class LeaveOneTest:
    """Leave-one-out mislabel test (command-line-arguments variant).

    Places every reference sequence (and, with ``args.ranktest``, eligible
    taxonomic subtrees) back onto the reference tree with RAxML-EPA and
    records disagreements with the original taxonomic labels. Results go to
    ``<output_dir>/<output_name>`` with ``.mis``/``.premis``/``.misrank``/
    ``.stats`` suffixes.

    All ``print`` calls take a single argument, so output is identical under
    Python 2 and Python 3.
    """

    # Universal rank level -> (prefix, name).
    _UNI_RANK_NAMES = {
        0: ("?__", "Unknown"),
        1: ("k__", "Kingdom"),
        2: ("p__", "Phylum"),
        3: ("c__", "Class"),
        4: ("d__", "Subclass"),
        5: ("o__", "Order"),
        6: ("n__", "Suborder"),
        7: ("f__", "Family"),
        8: ("g__", "Genus"),
        9: ("s__", "Species"),
    }

    def __init__(self, config, args):
        """Read options from ``args``, load the reference JSON and set up helpers.

        Uses ``args.method``, ``min_lhw``, ``jplace_fname``, ``ranktest``,
        ``output_dir`` and ``output_name``. Exits with status 1 if the
        reference JSON cannot be parsed.
        """
        self.cfg = config
        self.method = args.method
        self.minlw = args.min_lhw
        self.jplace_fname = args.jplace_fname
        self.ranktest = args.ranktest
        self.output_fname = args.output_dir + "/" + args.output_name
        # switch off branch length filter
        self.brlen_pv = 0.

        self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
        self.reftree_lbl_fname = config.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = config.tmp_fname("%NAME%_tax.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")

        try:
            self.refjson = RefJsonParser(config.refjson_fname, ver="1.2")
        except ValueError:
            print("Invalid json file format!")
            # non-zero status: this is an error exit
            sys.exit(1)

        # validate input json format
        # NOTE(review): the return value is ignored here, while load_refjson in
        # the other LeaveOneTest variant unpacks (valid, err) -- confirm for ver="1.2".
        self.refjson.validate()

        self.rate = self.refjson.get_rate()
        self.node_height = self.refjson.get_node_height()
        self.origin_taxonomy = self.refjson.get_origin_taxonomy()
        self.bid_taxonomy_map = self.refjson.get_bid_tanomomy_map()
        self.tax_tree = self.refjson.get_tax_tree()
        self.cfg.compress_patterns = self.refjson.get_pattern_compression()

        reftree_str = self.refjson.get_raxml_readable_tree()
        self.reftree = Tree(reftree_str)
        self.reftree_size = len(self.reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same rate het. mode as in the ref file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map, self.brlen_pv,
                                                 self.rate, self.node_height)

        self.TAXONOMY_RANKS_COUNT = 10
        self.mislabels = []
        self.mislabels_cnt = [0] * self.TAXONOMY_RANKS_COUNT
        self.rank_mislabels = []
        self.rank_mislabels_cnt = [0] * self.TAXONOMY_RANKS_COUNT
        self.misrank_conf_map = {}

    def cleanup(self):
        """Remove the temporary reference alignment file."""
        FileUtils.remove_if_exists(self.tmp_refaln)

    def classify_seq(self, placement):
        """Classify one jplace placement record using ``self.method``/``self.minlw``.

        Returns the classification helper's ``(ranks, lws)`` pair, or None
        (after printing an error) when the placement has no edges.
        """
        edges = placement["p"]
        if len(edges) > 0:
            return self.classify_helper.classify_seq(edges, self.method, self.minlw)
        else:
            print("ERROR: no placements! something is definitely wrong!")

    def rank_level_name(self, uni_rank_level):
        """Return ``(prefix, name)`` for a universal rank level 0..9 (KeyError otherwise)."""
        return self._UNI_RANK_NAMES[uni_rank_level]

    def guess_rank_level(self, ranks, rank_level):
        """Guess the universal rank level of ``ranks[rank_level]``.

        Known prefixes (``k__``, ``p__``, ...) and suffixes (``dae``, ``ales``,
        ``neae``, ``ceae``) are checked first. Otherwise the level is the
        parent's guessed level + 1 (kingdom for level 0). An inferred level
        of 4 or 6 (Subclass/Suborder) is bumped by one when the lineage has
        fewer than 8 ranks.
        """
        rank_name = ranks[rank_level]
        real_level = 0
        # check common prefixes and suffixes
        if rank_name.startswith("k__") or rank_name.lower() in ["bacteria", "archaea", "eukaryota"]:
            real_level = 1
        elif rank_name.startswith("p__"):
            real_level = 2
        elif rank_name.startswith("c__"):
            real_level = 3
        elif rank_name.endswith("dae"):
            real_level = 4
        elif rank_name.startswith("o__") or rank_name.endswith("ales"):
            real_level = 5
        elif rank_name.endswith("neae"):
            real_level = 6
        elif rank_name.startswith("f__") or rank_name.endswith("ceae"):
            real_level = 7
        elif rank_name.startswith("g__"):
            real_level = 8
        elif rank_name.startswith("s__"):
            real_level = 9

        if real_level == 0:
            if rank_level == 0:  # kingdom
                real_level = 1
            else:
                parent_level = self.guess_rank_level(ranks, rank_level - 1)
                real_level = parent_level + 1
                if len(ranks) < 8 and (real_level in [4, 6]):
                    real_level += 1

        return real_level

    def guess_rank_level_name(self, ranks, rank_level):
        """Return ``(prefix, name)`` for the guessed level of ``ranks[rank_level]``."""
        return self.rank_level_name(self.guess_rank_level(ranks, rank_level))

    @staticmethod
    def _first_mislabel_level(orig_ranks, ranks):
        """Return the first level where ``ranks`` has a non-empty rank that
        differs from ``orig_ranks`` (over their common length), else -1."""
        for rank_lvl in range(min(len(orig_ranks), len(ranks))):
            rank = ranks[rank_lvl]
            if rank != Taxonomy.EMPTY_RANK and rank != orig_ranks[rank_lvl]:
                return rank_lvl
        return -1

    def _make_mis_rec(self, name, orig_ranks, ranks, lws, mislabel_lvl):
        """Build a mislabel record for a mismatch detected at ``mislabel_lvl``."""
        real_lvl = self.guess_rank_level(orig_ranks, mislabel_lvl)
        return {
            'name': name,
            'orig_level': mislabel_lvl,
            'real_level': real_lvl,
            'level_name': self.rank_level_name(real_lvl)[1],
            'inv_level': -1 * real_lvl,  # just for sorting
            'orig_ranks': orig_ranks,
            'ranks': ranks,
            'lws': lws,
            'conf': lws[mislabel_lvl],
        }

    def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
        """Record and return a sequence mislabel, or return None if labels agree.

        The stored name has the reference prefix stripped.
        """
        mislabel_lvl = self._first_mislabel_level(orig_ranks, ranks)
        if mislabel_lvl < 0:
            return None
        mis_rec = self._make_mis_rec(EpacConfig.strip_ref_prefix(seq_name), orig_ranks, ranks, lws,
                                     mislabel_lvl)
        self.mislabels.append(mis_rec)
        return mis_rec

    def check_rank_tax_labels(self, rank_name, orig_ranks, ranks, lws):
        """Record and return a rank (subtree) mislabel, or return None if labels agree."""
        mislabel_lvl = self._first_mislabel_level(orig_ranks, ranks)
        if mislabel_lvl < 0:
            return None
        mis_rec = self._make_mis_rec(rank_name, orig_ranks, ranks, lws, mislabel_lvl)
        self.rank_mislabels.append(mis_rec)
        return mis_rec

    def mis_rec_to_string_old(self, mis_rec):
        """Legacy multi-line formatting of a mislabel record."""
        lvl = mis_rec['orig_level']
        output = mis_rec['name'] + "\t"
        output += "%s\t%s\t%s\t%.3f\n" % (mis_rec['level_name'],
                                          mis_rec['orig_ranks'][lvl], mis_rec['ranks'][lvl], mis_rec['lws'][lvl])
        output += ";".join(mis_rec['orig_ranks']) + "\n"
        output += ";".join(mis_rec['ranks']) + "\n"
        output += "\t".join(["%.3f" % conf for conf in mis_rec['lws']]) + "\n"
        return output

    def mis_rec_to_string(self, mis_rec):
        """Format a mislabel record as one tab-separated line (no newline)."""
        lvl = mis_rec['orig_level']
        output = mis_rec['name'] + "\t"
        output += "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'],
                                          mis_rec['orig_ranks'][lvl], mis_rec['ranks'][lvl], mis_rec['lws'][lvl])
        output += Taxonomy.lineage_str(mis_rec['orig_ranks']) + "\t"
        output += Taxonomy.lineage_str(mis_rec['ranks']) + "\t"
        output += ";".join(["%.3f" % conf for conf in mis_rec['lws']])
        if 'rank_conf' in mis_rec:
            output += "\t%.3f" % mis_rec['rank_conf']
        return output

    def sort_mislabels(self):
        """Sort mislabels (highest rank level first, then by confidence,
        descending) and count them per ``real_level``."""
        self.mislabels = sorted(self.mislabels, key=itemgetter('inv_level', 'conf'), reverse=True)
        for mis_rec in self.mislabels:
            self.mislabels_cnt[mis_rec["real_level"]] += 1

        if self.ranktest:
            self.rank_mislabels = sorted(self.rank_mislabels, key=itemgetter('inv_level', 'conf'), reverse=True)
            for mis_rec in self.rank_mislabels:
                self.rank_mislabels_cnt[mis_rec["real_level"]] += 1

    def write_mislabels(self, final=True):
        """Write sequence mislabels to ``.mis`` (final) or ``.premis`` (interim).

        For the final output, also write ``.misrank`` (with ranktest) and the
        cumulative per-level counts to ``.stats`` (also printed).
        """
        if final:
            out_fname = "%s.mis" % self.output_fname
        else:
            out_fname = "%s.premis" % self.output_fname

        with open(out_fname, "w") as fo_all:
            fields = ["SeqID", "MislabeledLevel", "OriginalLabel", "ProposedLabel", "Confidence",
                      "OriginalTaxonomyPath", "ProposedTaxonomyPath", "PerRankConfidence"]
            if self.ranktest:
                fields += ["HigherRankMisplacedConfidence"]

            header = ";" + "\t".join(fields) + "\n"
            fo_all.write(header)
            if self.cfg.verbose and len(self.mislabels) > 0 and final:
                print("Mislabeled sequences:\n")
                print(header)

            for mis_rec in self.mislabels:
                output = self.mis_rec_to_string(mis_rec) + "\n"
                fo_all.write(output)
                if self.cfg.verbose and final:
                    print(output)

        if not final:
            return

        if self.ranktest:
            with open("%s.misrank" % self.output_fname, "w") as fo_all:
                fields = ["RankID", "MislabeledLevel", "OriginalLabel", "ProposedLabel", "Confidence",
                          "OriginalTaxonomyPath", "ProposedTaxonomyPath", "PerRankConfidence"]
                header = ";" + "\t".join(fields) + "\n"
                fo_all.write(header)
                if self.cfg.verbose and len(self.rank_mislabels) > 0:
                    print("\nMislabeled higher ranks:\n")
                    print(header)
                for mis_rec in self.rank_mislabels:
                    output = self.mis_rec_to_string(mis_rec) + "\n"
                    fo_all.write(output)
                    if self.cfg.verbose:
                        print(output)

        print("Mislabels counts by ranks:")
        with open("%s.stats" % self.output_fname, "w") as fo_stat:
            seq_sum = 0
            rank_sum = 0
            for i in range(1, self.TAXONOMY_RANKS_COUNT):
                rname = self.rank_level_name(i)[1].ljust(10)
                # Subclass/Suborder (4, 6) are listed only if they have mislabels
                if self.mislabels_cnt[i] > 0 or i not in [4, 6]:
                    seq_sum += self.mislabels_cnt[i]
                    output = "%s:\t%d" % (rname, seq_sum)
                    if self.ranktest:
                        rank_sum += self.rank_mislabels_cnt[i]
                        output += "\t%d" % rank_sum
                    fo_stat.write(output + "\n")
                    print(output)

    def get_orig_ranks(self, seq_name):
        """Return the ranks of ``seq_name``'s parent node in the taxonomic tree.

        Exits with status 1 if the leaf is missing or not unique.
        """
        nodes = self.tax_tree.get_leaves_by_name(seq_name)
        if len(nodes) != 1:
            print("FATAL ERROR: Sequence %s is not found in the taxonomic tree, or is present more than once!" % seq_name)
            sys.exit(1)
        seq_node = nodes[0]
        orig_ranks = Taxonomy.split_rank_uid(seq_node.up.name)
        return orig_ranks

    def run_leave_subtree_out_test(self):
        """Run the leave-one-rank-out EPA test and record rank mislabels.

        Eligible subtrees are inner, non-root taxonomic nodes with a lowest
        rank level >= 2, a parent level >= 1, and 2..reftree_size-4 leaves.

        :returns: number of processed placements (0 if no subtree qualifies)
        """
        job_name = self.cfg.subst_name("l1out_rank_%NAME%")

        # create file with subtrees
        rank_tips = {}
        rank_parent = {}
        for node in self.tax_tree.traverse("postorder"):
            if node.is_leaf() or node.is_root():
                continue
            tax_path = node.name
            ranks = Taxonomy.split_rank_uid(tax_path)
            rank_lvl = Taxonomy.lowest_assigned_rank_level(ranks)
            if rank_lvl < 2:
                continue

            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
            parent_lvl = Taxonomy.lowest_assigned_rank_level(parent_ranks)
            if parent_lvl < 1:
                continue

            rank_seqs = node.get_leaf_names()
            rank_size = len(rank_seqs)
            if rank_size < 2 or rank_size > self.reftree_size - 4:
                continue

            rank_tips[tax_path] = node.get_leaf_names()
            rank_parent[tax_path] = parent_ranks

        # a list is required: the entries are indexed below (dict views are not on Python 3)
        subtree_list = list(rank_tips.items())

        if len(subtree_list) == 0:
            return 0

        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
        with open(subtree_list_file, "w") as fout:
            for _tax_path, tips in subtree_list:
                fout.write("%s\n" % " ".join(tips))

        jp_list = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname,
                                     mode="l1o_subtree", subtree_fname=subtree_list_file)

        subtree_count = 0
        for jp in jp_list:
            placements = jp.get_placement()
            for place in placements:
                ranks, lws = self.classify_seq(place)
                tax_path = subtree_list[subtree_count][0]
                orig_ranks = Taxonomy.split_rank_uid(tax_path)
                rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
                rank_prefix = self.guess_rank_level_name(orig_ranks, rank_level)[0]
                rank_name = orig_ranks[rank_level]
                if not rank_name.startswith(rank_prefix):
                    rank_name = rank_prefix + rank_name
                parent_ranks = rank_parent[tax_path]
                mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks, ranks, lws)
                if mis_rec:
                    self.misrank_conf_map[tax_path] = mis_rec['conf']
                subtree_count += 1

        return subtree_count

    def run_leave_seq_out_test(self):
        """Run (or load from ``self.jplace_fname``) leave-one-sequence-out
        placements and check each against its original label.

        :returns: number of processed placements
        """
        job_name = self.cfg.subst_name("l1out_seq_%NAME%")
        if self.jplace_fname:
            jp = EpaJsonParser(self.jplace_fname)
        else:
            jp = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname,
                                    mode="l1o_seq")

        placements = jp.get_placement()
        seq_count = 0
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.get_orig_ranks(seq_name)

            # get EPA tax label
            ranks, lws = self.classify_seq(place)

            # check if they match
            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

            # cross-check with higher rank mislabels
            if self.ranktest and mis_rec:
                rank_conf = 0
                for lvl in range(2, len(orig_ranks)):
                    tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
                    if tax_path in self.misrank_conf_map:
                        rank_conf = max(rank_conf, self.misrank_conf_map[tax_path])
                mis_rec['rank_conf'] = rank_conf

            seq_count += 1

        return seq_count

    def run_final_epa_test(self):
        """Prune all current mislabels from copies of both trees and re-run EPA."""
        self.reftree_outgroup = self.refjson.get_outgroup()

        tmp_reftree = self.reftree.copy()
        tmp_taxtree = self.tax_tree.copy()

        for mis_rec in self.mislabels:
            name = mis_rec['name']
            rname = EpacConfig.REF_SEQ_PREFIX + name

            leaf_nodes = tmp_reftree.get_leaves_by_name(rname)
            if len(leaf_nodes) > 0:
                leaf_nodes[0].delete()
            else:
                print("Node not found in the reference tree: %s" % rname)

            leaf_nodes = tmp_taxtree.get_leaves_by_name(rname)
            if len(leaf_nodes) > 0:
                leaf_nodes[0].delete()
            else:
                print("Node not found in the taxonomic tree: %s" % rname)

        # remove unifurcation at the root
        if len(tmp_reftree.children) == 1:
            tmp_reftree = tmp_reftree.children[0]

        self.mislabels = []

        # NOTE(review): argument order differs from TaxTreeHelper(cfg, taxonomy)
        # used elsewhere in this file -- confirm against this version's signature.
        th = TaxTreeHelper(self.origin_taxonomy, self.cfg)
        th.set_mf_rooted_tree(tmp_taxtree)

        self.run_epa_once(tmp_reftree, th)

    def run_epa_once(self, reftree, th):
        """Run EPA on ``reftree`` (no pre-optimized model) and re-check all placements.

        ``th`` is used to rebuild the branch-id taxonomy map for the new tree.
        Temporary files are removed unless in debug mode.
        """
        reftree_fname = self.cfg.tmp_fname("final_ref_%NAME%.tre")
        job_name = self.cfg.subst_name("final_epa_%NAME%")

        reftree.write(outfile=reftree_fname)

        # IMPORTANT: don't load the model, since it's invalid for the pruned true !!!
        optmod_fname = ""
        epa_result = self.raxml.run_epa(job_name, self.refalign_fname, reftree_fname, optmod_fname)
        reftree_epalbl_str = epa_result.get_std_newick_tree()
        placements = epa_result.get_placement()

        # update branchid-taxonomy mapping to account for possible changes in branch numbering
        reftree_tax = Tree(reftree_epalbl_str)
        th.set_bf_unrooted_tree(reftree_tax)
        bid_tax_map = th.get_bid_taxonomy_map()

        cl = TaxClassifyHelper(self.cfg, bid_tax_map, self.brlen_pv, self.rate, self.node_height)

        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.get_orig_ranks(seq_name)

            # get EPA tax label
            ranks, lws = cl.classify_seq(place["p"])

            # check if they match (records are appended to self.mislabels)
            self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

        if not self.cfg.debug:
            self.raxml.cleanup(job_name)
            FileUtils.remove_if_exists(reftree_fname)

    def run_test(self):
        """Run the complete leave-one-out mislabel test and write the results."""
        self.raxml = RaxmlWrapper(self.cfg)

        print("Number of sequences in the reference: %d\n" % self.reftree_size)

        self.refjson.get_raxml_readable_tree(self.reftree_fname)
        self.refalign_fname = self.refjson.get_alignment(self.tmp_refaln)
        self.refjson.get_binary_model(self.optmod_fname)

        if self.ranktest:
            print("Running the leave-one-rank-out test...\n")
            self.run_leave_subtree_out_test()

        print("Running the leave-one-sequence-out test...\n")
        self.run_leave_seq_out_test()

        if len(self.mislabels) > 0:
            print("Leave-one-out test identified %d suspicious sequences; running final EPA test to check them...\n"
                  % len(self.mislabels))
            self.write_mislabels(final=False)
            self.run_final_epa_test()

        self.sort_mislabels()
        self.write_mislabels()
        print("\nPercentage of mislabeled sequences: %.2f %%" % (float(len(self.mislabels)) / self.reftree_size * 100))

        if not self.cfg.debug:
            FileUtils.remove_if_exists(self.reftree_fname)
            FileUtils.remove_if_exists(self.optmod_fname)
            FileUtils.remove_if_exists(self.refalign_fname)
class LeaveOneTest:
    """Leave-one-out test that flags reference sequences (and, when
    ``cfg.ranktest`` is set, higher-level ranks) whose original taxonomic
    label disagrees with the label inferred from EPA placement.

    Typical flow, as driven by ``run_test``:
      1. optional leave-one-rank-out test (``run_leave_subtree_out_test``),
      2. leave-one-sequence-out test (``run_leave_seq_out_test``),
      3. if any suspicious sequences were found, a final EPA run on the
         reference tree with those sequences pruned (``run_final_epa_test``),
      4. filtering by ``cfg.conf_cutoff``, sorting and writing the results.
    """

    def __init__(self, config):
        """Set up output/temp file names; abort if the .mis output already exists."""
        self.cfg = config
        self.mis_fname = self.cfg.out_fname("%NAME%.mis")
        self.premis_fname = self.cfg.out_fname("%NAME%.premis")
        self.misrank_fname = self.cfg.out_fname("%NAME%.misrank")
        self.stats_fname = self.cfg.out_fname("%NAME%.stats")

        if os.path.isfile(self.mis_fname):
            print("\nERROR: Output file already exists: %s" % self.mis_fname)
            print("Please specify a different job name using -n or remove old output files.")
            self.cfg.exit_user_error()

        self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
        self.reftree_lbl_fname = config.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = config.tmp_fname("%NAME%_tax.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")
        # Sequence-level and rank-level mislabel records plus per-level counters
        # (the counters are sized properly in load_refjson).
        self.mislabels = []
        self.mislabels_cnt = []
        self.rank_mislabels = []
        self.rank_mislabels_cnt = []
        # tax_path -> confidence of a detected higher-rank mislabel
        self.misrank_conf_map = {}

    def write_bid_tax_map(self, bid_tax_map, final):
        """Dump the branch-id -> taxonomy map to a temp file (debug mode only)."""
        if not self.cfg.debug:
            return
        fname_suffix = "final" if final else "l1out"
        bid_fname = self.cfg.tmp_fname("%NAME%_" + "bid_tax_map_%s.txt" % fname_suffix)
        with open(bid_fname, "w") as outf:
            for bid, bid_rec in bid_tax_map.items():
                outf.write("%s\t%s\t%d\t%f\n" % (bid, bid_rec[0], bid_rec[1], bid_rec[2]))

    def write_assignments(self, assign_map, final):
        """Dump per-sequence (ranks, lws) assignments to a temp file (debug mode only)."""
        if not self.cfg.debug:
            return
        fname_suffix = "final" if final else "l1out"
        assign_fname = self.cfg.tmp_fname("%NAME%_" + "taxassign_%s.txt" % fname_suffix)
        with open(assign_fname, "w") as outf:
            for seq_name, (ranks, lws) in assign_map.items():
                outf.write("%s\t%s\t%s\n" % (seq_name, ";".join(ranks), ";".join(["%.3f" % l for l in lws])))

    def load_refjson(self, refjson_fname):
        """Load and validate the reference JSON and derive all helper objects.

        Exits via ``cfg.exit_user_error`` if the file is not valid JSON or
        fails validation.
        """
        try:
            self.refjson = RefJsonParser(refjson_fname)
        except ValueError:
            self.cfg.exit_user_error("ERROR: Invalid json file format!")

        # validate input json format
        (valid, err) = self.refjson.validate()
        if not valid:
            self.cfg.log.error("ERROR: Parsing reference JSON file failed:\n%s", err)
            self.cfg.exit_user_error()

        self.rate = self.refjson.get_rate()
        self.node_height = self.refjson.get_node_height()
        self.origin_taxonomy = self.refjson.get_origin_taxonomy()
        self.tax_tree = self.refjson.get_tax_tree()
        self.cfg.compress_patterns = self.refjson.get_pattern_compression()

        self.bid_taxonomy_map = self.refjson.get_branch_tax_map()
        if not self.bid_taxonomy_map:
            # old file format (before 1.6), need to rebuild this map from scratch
            th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
            th.set_mf_rooted_tree(self.tax_tree)
            th.set_bf_unrooted_tree(self.refjson.get_reftree())
            self.bid_taxonomy_map = th.get_bid_taxonomy_map()

        self.write_bid_tax_map(self.bid_taxonomy_map, final=False)

        reftree_str = self.refjson.get_raxml_readable_tree()
        self.reftree = Tree(reftree_str)
        self.reftree_size = len(self.reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same rate het. mode as in the ref file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map, self.rate, self.node_height)
        self.taxtree_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy, self.tax_tree)

        tax_code_name = self.refjson.get_taxcode()
        self.tax_code = TaxCode(tax_code_name)

        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_map=self.origin_taxonomy)
        self.tax_common_ranks = self.taxonomy.get_common_ranks()

        self.mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
        self.rank_mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS

    def run_epa_trainer(self):
        """Build the reference via epa_trainer; fatal error if no refjson was produced."""
        epa_trainer.run_trainer(self.cfg)

        if not os.path.isfile(self.cfg.refjson_fname):
            self.cfg.log.error("\nBuilding reference tree failed, see error messages above.")
            self.cfg.exit_fatal_error()

    def classify_seq(self, placement):
        """Classify one jplace placement record via its edge list ``placement["p"]``.

        Returns whatever ``classify_helper.classify_seq`` returns; callers
        unpack it as ``(ranks, lws)``.
        """
        edges = placement["p"]
        if len(edges) > 0:
            return self.classify_helper.classify_seq(edges)
        # NOTE(review): returns None here, so callers that unpack the result
        # will fail with a TypeError; consider a hard error instead.
        print("ERROR: no placements! something is definitely wrong!")
        return None

    def _first_mismatch_level(self, orig_ranks, ranks):
        """Return the first rank level where a non-empty inferred rank
        differs from the original one, or -1 if they agree."""
        for rank_lvl in range(min(len(orig_ranks), len(ranks))):
            rank = ranks[rank_lvl]
            if rank != Taxonomy.EMPTY_RANK and rank != orig_ranks[rank_lvl]:
                return rank_lvl
        return -1

    def _make_mis_rec(self, name, orig_ranks, ranks, lws, mislabel_lvl):
        """Build a mislabel record for a mismatch found at ``mislabel_lvl``."""
        real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
        return {
            'name': name,
            'orig_level': mislabel_lvl,
            'real_level': real_lvl,
            'level_name': self.tax_code.rank_level_name(real_lvl)[0],
            'inv_level': -1 * real_lvl,  # just for sorting
            'orig_ranks': orig_ranks,
            'ranks': ranks,
            'lws': lws,
            'conf': lws[mislabel_lvl],
        }

    def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
        """Compare original vs. inferred ranks of one sequence.

        An empty ``ranks`` yields a "[NotIngroup]" record; otherwise a record
        is produced only if some level mismatches. Any record is appended to
        ``self.mislabels`` and returned; returns None if labels agree.
        """
        if not ranks:
            mis_rec = {
                'name': seq_name,
                'orig_level': -1,
                'real_level': 0,
                'level_name': "[NotIngroup]",
                'inv_level': 0,  # -1 * real_level, just for sorting
                'orig_ranks': orig_ranks,
                'ranks': [],
                'lws': [1.0],
                'conf': 1.0,
            }
        else:
            mislabel_lvl = self._first_mismatch_level(orig_ranks, ranks)
            if mislabel_lvl >= 0:
                mis_rec = self._make_mis_rec(seq_name, orig_ranks, ranks, lws, mislabel_lvl)
            else:
                mis_rec = None

        if mis_rec:
            self.mislabels.append(mis_rec)

        return mis_rec

    def filter_mislabels(self):
        """Keep only mislabel records whose confidence reaches ``cfg.conf_cutoff``."""
        cutoff = self.cfg.conf_cutoff
        self.mislabels = [rec for rec in self.mislabels if rec['conf'] >= cutoff]

    def check_rank_tax_labels(self, rank_name, orig_ranks, ranks, lws):
        """Rank-level counterpart of check_seq_tax_labels.

        Appends a record to ``self.rank_mislabels`` and returns it if a
        mismatch is found, otherwise returns None.
        """
        mislabel_lvl = self._first_mismatch_level(orig_ranks, ranks)
        if mislabel_lvl < 0:
            return None
        mis_rec = self._make_mis_rec(rank_name, orig_ranks, ranks, lws, mislabel_lvl)
        self.rank_mislabels.append(mis_rec)
        return mis_rec

    def mis_rec_to_string_old(self, mis_rec):
        """Legacy multi-line text representation of a mislabel record."""
        lvl = mis_rec['orig_level']
        output = mis_rec['name'] + "\t"
        output += "%s\t%s\t%s\t%.3f\n" % (mis_rec['level_name'],
                  mis_rec['orig_ranks'][lvl], mis_rec['ranks'][lvl], mis_rec['lws'][lvl])
        output += ";".join(mis_rec['orig_ranks']) + "\n"
        output += ";".join(mis_rec['ranks']) + "\n"
        output += "\t".join(["%.3f" % conf for conf in mis_rec['lws']]) + "\n"
        return output

    def mis_rec_to_string(self, mis_rec):
        """Render a mislabel record as one tab-separated output line (no newline).

        Names and ranks are mapped back to their uncorrected form via the
        reference JSON. Records with ``orig_level < 0`` print "NA" labels.
        """
        lvl = mis_rec['orig_level']
        uncorr_name = EpacConfig.strip_ref_prefix(self.refjson.get_uncorr_seqid(mis_rec['name']))
        uncorr_orig_ranks = self.refjson.get_uncorr_ranks(mis_rec['orig_ranks'])
        uncorr_ranks = self.refjson.get_uncorr_ranks(mis_rec['ranks'])

        output = uncorr_name + "\t"
        if lvl >= 0:
            output += "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'],
                      uncorr_orig_ranks[lvl], uncorr_ranks[lvl], mis_rec['lws'][lvl])
        else:
            output += "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'], "NA", "NA", mis_rec['lws'][0])

        output += Taxonomy.lineage_str(uncorr_orig_ranks) + "\t"
        output += Taxonomy.lineage_str(uncorr_ranks) + "\t"
        output += ";".join(["%.3f" % conf for conf in mis_rec['lws']])
        if 'rank_conf' in mis_rec:
            output += "\t%.3f" % mis_rec['rank_conf']

        return output

    def sort_mislabels(self):
        """Sort mislabel records (by inv_level, conf, name, descending) and
        recount them per real rank level.

        Counters are reset first, so calling this more than once does not
        double-count.
        """
        sort_key = itemgetter('inv_level', 'conf', 'name')
        self.mislabels = sorted(self.mislabels, key=sort_key, reverse=True)
        self.mislabels_cnt = [0] * len(self.mislabels_cnt)
        for mis_rec in self.mislabels:
            self.mislabels_cnt[mis_rec["real_level"]] += 1

        if self.cfg.ranktest:
            self.rank_mislabels = sorted(self.rank_mislabels, key=sort_key, reverse=True)
            self.rank_mislabels_cnt = [0] * len(self.rank_mislabels_cnt)
            for mis_rec in self.rank_mislabels:
                self.rank_mislabels_cnt[mis_rec["real_level"]] += 1

    def write_stats(self, toFile=False):
        """Log per-level mislabel counts; optionally also write them to the stats file.

        Each line holds the number of mislabeled sequences at that level and,
        with ``cfg.ranktest``, the number of mislabeled ranks at the same level.
        Levels with nothing to report are skipped.
        """
        self.cfg.log.info("Mislabeled sequences by rank:")
        stats = []
        for lvl, seq_cnt in enumerate(self.mislabels_cnt):
            rank_cnt = self.rank_mislabels_cnt[lvl] if self.cfg.ranktest else 0
            if seq_cnt == 0 and rank_cnt == 0:
                continue
            if lvl > 0:
                rname = self.tax_code.rank_level_name(lvl)[0].ljust(12)
            else:
                rname = "[NotIngroup]"
            output = "%s:\t%d" % (rname, seq_cnt)
            if self.cfg.ranktest:
                output += "\t%d" % rank_cnt
            self.cfg.log.info(output)
            stats.append(output)

        if toFile:
            with open(self.stats_fname, "w") as fo_stat:
                fo_stat.writelines(line + "\n" for line in stats)

    def write_mislabels(self, final=True):
        """Write sequence mislabels to the .mis (final) or .premis (interim) file.

        For final output, also writes rank mislabels (with ``cfg.ranktest``)
        and logs the per-level statistics.
        """
        out_fname = self.mis_fname if final else self.premis_fname

        with open(out_fname, "w") as fo_all:
            fields = ["SeqID", "MislabeledLevel", "OriginalLabel", "ProposedLabel",
                      "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath",
                      "PerRankConfidence"]
            if self.cfg.ranktest:
                fields += ["HigherRankMisplacedConfidence"]
            header = ";" + "\t".join(fields) + "\n"
            fo_all.write(header)
            if self.cfg.verbose and len(self.mislabels) > 0 and final:
                print("Mislabeled sequences:\n")
                print(header)
            for mis_rec in self.mislabels:
                output = self.mis_rec_to_string(mis_rec) + "\n"
                fo_all.write(output)
                if self.cfg.verbose and final:
                    print(output)

        if not final:
            return

        if self.cfg.ranktest:
            with open(self.misrank_fname, "w") as fo_all:
                fields = ["RankID", "MislabeledLevel", "OriginalLabel", "ProposedLabel",
                          "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath",
                          "PerRankConfidence"]
                header = ";" + "\t".join(fields) + "\n"
                fo_all.write(header)
                if self.cfg.verbose and len(self.rank_mislabels) > 0:
                    print("\nMislabeled higher ranks:\n")
                    print(header)
                for mis_rec in self.rank_mislabels:
                    output = self.mis_rec_to_string(mis_rec) + "\n"
                    fo_all.write(output)
                    if self.cfg.verbose:
                        print(output)

        self.write_stats()

    def run_leave_subtree_out_test(self):
        """Leave-one-rank-out test: prune each eligible taxonomic subtree,
        place it with EPA and compare the result to its parent's ranks.

        Returns the number of processed placements (0 if no subtree qualifies).
        """
        job_name = self.cfg.subst_name("l1out_rank_%NAME%")

        # collect the tips of every eligible inner rank node
        rank_tips = {}
        rank_parent = {}
        for node in self.tax_tree.traverse("postorder"):
            if node.is_leaf() or node.is_root():
                continue
            tax_path = node.name
            ranks = Taxonomy.split_rank_uid(tax_path)
            rank_lvl = Taxonomy.lowest_assigned_rank_level(ranks)
            if rank_lvl < 2:
                continue

            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
            parent_lvl = Taxonomy.lowest_assigned_rank_level(parent_ranks)
            if parent_lvl < 1:
                continue

            rank_seqs = node.get_leaf_names()
            rank_size = len(rank_seqs)
            if rank_size < 2 or rank_size > self.reftree_size - 4:
                continue

            rank_tips[tax_path] = rank_seqs
            rank_parent[tax_path] = parent_ranks

        # materialize once: the same order is used for the file and for lookup below
        subtree_list = list(rank_tips.items())

        if not subtree_list:
            return 0

        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
        with open(subtree_list_file, "w") as fout:
            for _, tips in subtree_list:
                fout.write("%s\n" % " ".join(tips))

        jp_list = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname,
                                     mode="l1o_subtree", subtree_fname=subtree_list_file)

        # assumes placements come back in the same order as the subtree file — TODO confirm
        subtree_count = 0
        for jp in jp_list:
            placements = jp.get_placement()
            for place in placements:
                ranks, lws = self.classify_seq(place)
                tax_path = subtree_list[subtree_count][0]
                orig_ranks = Taxonomy.split_rank_uid(tax_path)
                rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
                # NOTE(review): guess_rank_level_name is not defined among the
                # visible methods of this class; confirm it exists (it may
                # belong on self.tax_code).
                rank_prefix = self.guess_rank_level_name(orig_ranks, rank_level)[0]
                rank_name = orig_ranks[rank_level]
                if not rank_name.startswith(rank_prefix):
                    rank_name = rank_prefix + rank_name
                parent_ranks = rank_parent[tax_path]
                mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks, ranks, lws)
                if mis_rec:
                    self.misrank_conf_map[tax_path] = mis_rec['conf']
                subtree_count += 1

        return subtree_count

    def run_leave_seq_out_test(self):
        """Leave-one-sequence-out test.

        Placements are loaded from ``cfg.jplace_fname`` (a file, glob mask or
        directory of .jplace files) if given, otherwise computed with EPA.
        Each placed sequence is checked against its original label.
        Returns the number of processed placements.
        """
        job_name = self.cfg.subst_name("l1out_seq_%NAME%")

        placements = []
        if self.cfg.jplace_fname:
            if os.path.isdir(self.cfg.jplace_fname):
                jplace_fmask = os.path.join(self.cfg.jplace_fname, '*.jplace')
            else:
                jplace_fmask = self.cfg.jplace_fname

            for jplace_fname in glob.glob(jplace_fmask):
                jp = EpaJsonParser(jplace_fname)
                placements += jp.get_placement()

            self.cfg.log.debug("Loaded %d placements from %s\n", len(placements), jplace_fmask)
        else:
            jp = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname,
                                    mode="l1o_seq")
            placements = jp.get_placement()
            if self.cfg.output_interim_files:
                out_jplace_fname = self.cfg.out_fname("%NAME%.l1out_seq.jplace")
                self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True, mode="l1o_seq")

        seq_count = 0
        l1out_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # get EPA tax label
            ranks, lws = self.classify_seq(place)
            l1out_ass[seq_name] = (ranks, lws)

            # check if they match
            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

            # cross-check with higher rank mislabels
            if self.cfg.ranktest and mis_rec:
                rank_conf = 0
                for lvl in range(2, len(orig_ranks)):
                    tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
                    if tax_path in self.misrank_conf_map:
                        rank_conf = max(rank_conf, self.misrank_conf_map[tax_path])
                mis_rec['rank_conf'] = rank_conf

            seq_count += 1

        self.write_assignments(l1out_ass, final=False)

        return seq_count

    def run_final_epa_test(self):
        """Prune all currently suspected sequences from copies of the reference
        and taxonomy trees, re-run EPA on the pruned tree and re-check every
        placed sequence. Replaces ``self.mislabels`` with the new result.
        """
        self.reftree_outgroup = self.refjson.get_outgroup()

        tmp_reftree = self.reftree.copy(method="newick")
        name2refnode = dict((leaf.name, leaf) for leaf in tmp_reftree.iter_leaves())

        tmp_taxtree = self.tax_tree.copy(method="newick")
        name2taxnode = dict((leaf.name, leaf) for leaf in tmp_taxtree.iter_leaves())

        for mis_rec in self.mislabels:
            rname = mis_rec['name']
            if rname in name2refnode:
                name2refnode[rname].delete()
            else:
                print("Node not found in the reference tree: %s" % rname)
            if rname in name2taxnode:
                name2taxnode[rname].delete()
            else:
                print("Node not found in the taxonomic tree: %s" % rname)

        # remove unifurcation at the root
        if len(tmp_reftree.children) == 1:
            tmp_reftree = tmp_reftree.children[0]

        self.mislabels = []

        th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
        th.set_mf_rooted_tree(tmp_taxtree)

        epa_result = self.run_epa_once(tmp_reftree)

        reftree_epalbl_str = epa_result.get_std_newick_tree()
        placements = epa_result.get_placement()

        # update branchid-taxonomy mapping to account for possible changes in branch numbering
        reftree_tax = Tree(reftree_epalbl_str)
        th.set_bf_unrooted_tree(reftree_tax)
        bid_tax_map = th.get_bid_taxonomy_map()

        self.write_bid_tax_map(bid_tax_map, final=True)

        cl = TaxClassifyHelper(self.cfg, bid_tax_map, self.rate, self.node_height)

        final_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # EXPERIMENTAL FEATURE - disabled for now!
            # Ranks present in the original reference tree may be completely
            # missing from the pruned tree (e.g., all seqs of a species were
            # pruned as suspicious). EPA then cannot infer the full original
            # annotation; amending orig_ranks via th.strip_missing_ranks()
            # would account for that.

            # get EPA tax label
            ranks, lws = cl.classify_seq(place["p"])
            final_ass[seq_name] = (ranks, lws)

            # check if they match
            self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

        self.write_assignments(final_ass, final=True)

    def run_epa_once(self, reftree):
        """Write ``reftree`` to a temp file and run a single EPA job on it.

        Returns the EPA result object from ``self.raxml.run_epa``.
        """
        reftree_fname = self.cfg.tmp_fname("final_ref_%NAME%.tre")
        job_name = self.cfg.subst_name("final_epa_%NAME%")

        reftree.write(outfile=reftree_fname)

        # IMPORTANT: don't load the model, since it's invalid for the pruned tree !!!
        optmod_fname = ""
        epa_result = self.raxml.run_epa(job_name, self.refalign_fname, reftree_fname, optmod_fname)

        if self.cfg.output_interim_files:
            out_jplace_fname = self.cfg.out_fname("%NAME%.final_epa.jplace")
            self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True)

        return epa_result

    def run_test(self):
        """Run the complete mislabel detection pipeline and write the results."""
        self.raxml = RaxmlWrapper(self.cfg)

        self.refjson.get_raxml_readable_tree(self.reftree_fname)
        self.refalign_fname = self.refjson.get_alignment(self.tmp_refaln)
        self.refjson.get_binary_model(self.optmod_fname)

        if self.cfg.ranktest:
            self.cfg.log.info("Running the leave-one-rank-out test...\n")
            self.run_leave_subtree_out_test()

        self.cfg.log.info("Running the leave-one-sequence-out test...\n")
        self.run_leave_seq_out_test()

        if self.mislabels:
            self.cfg.log.info("Leave-one-out test identified %d suspicious sequences; running final EPA test to check them...\n",
                              len(self.mislabels))
            if self.cfg.debug:
                self.write_mislabels(final=False)
            self.run_final_epa_test()

        self.filter_mislabels()
        self.sort_mislabels()
        self.write_mislabels()
        self.cfg.log.info("\nTotal mislabels: %d / %.2f %%", len(self.mislabels),
                          (float(len(self.mislabels)) / self.reftree_size * 100))